#!/usr/bin/python2.3

from optparse import OptionParser
import socket
import re
import struct

"""
Copyright (c) 2006
The Regents of the University of Michigan
All Rights Reserved

Permission is granted to use, copy, create derivative works, and
redistribute this software and such derivative works for any purpose,
so long as the name of the University of Michigan is not used in
any advertising or publicity pertaining to the use or distribution
of this software without specific, written prior authorization. If
the above copyright notice or any other identification of the
University of Michigan is included in any copy of any portion of
this software, then the disclaimer below must also be included.

This software is provided as is, without representation or warranty
of any kind either express or implied, including without limitation
the implied warranties of merchantability, fitness for a particular
purpose, or noninfringement.  The Regents of the University of
Michigan shall not be liable for any damages, including special,
indirect, incidental, or consequential damages, with respect to any
claim arising out of or in connection with the use of the software,
even if it has been or is hereafter advised of the possibility of
such damages.

1. What happens if a client matches more than one criterion?
   e.g. jupiter against *() *.umich.edu() *.citi.umich.edu()

   - mountd takes the first one that matches, ignoring the others.
   
2. What happens if the same path appears twice?

   - Its client(options) entries are taken in the order they appear.
"""

###########################################################

class Entry(object):
    """Holds a client(options) entry from the exports file.

    Attributes set by the constructor:
      client  -- client specification ("*" when the entry gave none)
      options -- text between the parentheses ("" when there were none)
      type    -- one of "wild", "gss", "netgroup", "netmask", "name"
      flavors -- security flavors the entry accepts (["sys"] by default)
    Depending on type, the constructor also sets regex ("wild"),
    addr and mask ("netmask") or fqdn ("name").
    """
    def __init__(self, str):
        """str is of form <client>(<options>), <client> or (<options>).

        A missing client means "*"; a missing option list means no options.
        """
        i = str.find('(')
        if i == -1:
            # Bare client with no option list.
            self.client = str
            self.options = ""
        else:
            self.client = str[:i]
            self.options = str[i+1:]
            if self.options.endswith(')'):
                self.options = self.options[:-1]
        if self.client == "":
            self.client = "*"
        self.type = self._settype()
        self._setflavors()

    def _settype(self):
        """Classify self.client and return its type name.

        Also sets the per-type attributes that matches() reads.
        """
        c = self.client
        if c == "" or c == "*":
            self.regex = re.compile(r".*")
            return "wild"
        if c.startswith("gss/"):
            return "gss"
        if c[0] == "@":
            return "netgroup"
        if ismask(c):
            fields = c.split('/')
            self.addr = fields[0]
            if len(fields) == 1:
                # A plain address must match exactly.
                bits = 32
            elif isquad(fields[1]):
                self.mask = socket.inet_aton(fields[1])
                return "netmask"
            else:
                bits = int(fields[1])
            # Netmask with the top `bits` bits set, e.g. 24 -> 255.255.255.0
            self.mask = struct.pack("!L", 2**32 - 2**(32 - bits))
            return "netmask"
        # An unescaped *, ? or [ makes this a wildcard host pattern.
        skip = False
        for x in c:
            if skip:
                skip = False
                continue
            if x in "*?[":
                self.regex = self._get_wild_pattern()
                return "wild"
            if x == '\\':
                skip = True
        self.fqdn = socket.getfqdn(c)
        return "name"

    def _setflavors(self):
        """Set self.flavors to the security flavors this entry accepts.

        A gss/<flavor> client contributes <flavor>; every sec=a:b:c option
        contributes its colon-separated flavors.  With neither, the entry
        accepts only "sys".
        """
        self.flavors = []
        if self.type == "gss":
            self.flavors.append(self.client[4:])
        for o in self.options.split(','):
            if o.startswith("sec="):
                for f in o[4:].split(':'):
                    if f not in self.flavors:
                        # append, not extend: extend would add each character
                        self.flavors.append(f)
        if not self.flavors:
            self.flavors = ["sys"]

    def __repr__(self):
        return "%s(%s)[%s]" % (self.client, self.options, self.type)

    def matches(self, address, sec):
        """Returns True if self.client matches address using flavor sec.

        address -- an Address; its .names and .quads are consulted
        sec     -- security flavor name the client uses, e.g. "sys"
        """
        # XXX Really have to carefully match mountd code
        if sec not in self.flavors:
            return False
        if self.type == "name":
            return self.fqdn in address.names
        elif self.type == "netmask":
            for q in address.quads:
                if match_mask(self.addr, q, self.mask):
                    return True
            return False
        elif self.type == "wild":
            for name in address.names:
                if self.regex.match(name):
                    return True
            return False
        elif self.type == "gss":
            return self.client[4:] == sec
        elif self.type == "netgroup":
            # Netgroup membership cannot be looked up here, so the entry is
            # treated as not matching instead of aborting the whole report.
            return False
        else:
            raise TypeError("Unknown type %s" % self.type)

    def wildmatch(self, text, pat):
        """Shell-style match of text against pat (* ? [..] and \\ escapes).

        Characters are compared case-insensitively.  Not called by
        matches(), which uses the regex from _get_wild_pattern() instead.
        """
        def domatch(t, p):
            ti = pi = 0
            while p[pi:]:
                if t[ti:] == '' and p[pi] != '*':
                    return None
                if p[pi] == '*':
                    try:
                        while p[pi] == '*':
                            pi += 1
                    except IndexError:
                        return True
                    while t[ti:]:
                        matched = domatch(t[ti:], p[pi:])
                        ti += 1
                        if matched != False:
                            return matched
                    return None
                while (True): # Allows 'break' to jump to correct point
                    if p[pi] == '?':
                        break;
                    if p[pi] == '\\':
                        pi += 1
                    if p[pi] == '[':
                        reverse = p[pi+1] == '^'
                        if reverse:
                            pi += 1
                        matched = False
                        if p[pi+1] == ']' or p[pi+1] == '-':
                            pi += 1
                            if p[pi].upper() == t[ti].upper():
                                matched = True
                        last = p[pi]
                        pi += 1
                        while (p[pi:] and p[pi] != ']'):
                            if p[pi] == '-' and p[pi] != ']':
                                pi += 1
                                if last <= t[ti] <= p[pi]:
                                    matched = True
                            elif p[pi].upper() == t[ti].upper():
                                matched = True
                            last = p[pi]
                            pi += 1
                        if matched == reverse:
                            return False
                        break
                    if p[pi].upper() != t[ti].upper():
                        return False
                    break
                ti += 1
                pi += 1
            return t[ti:] == ''
        if pat == '*' or pat == '':
            return True
        return domatch(text, pat) == True

    def _get_wild_pattern(self):
        """Returns a compiled regex corresponding to self.client.

        Does shell-style matching: * matches any run of characters, ?
        matches a single character, [...] is passed through as a regex
        character class, and a backslash makes the next character literal.
        Every other character is matched literally.  Matching ignores case,
        like wildmatch() in this class.
        """
        # XXX should we emulate mountd code, or use python re library?
        pat = "^"
        skip = False
        in_bracket = False
        for c in self.client:
            if skip:
                # The character after a backslash is always literal.
                out = re.escape(c)
                skip = False
            elif c == '\\': # Note this implies active w/in bracket
                out = ""
                skip = True
            elif in_bracket:
                out = c
                if c == ']':
                    in_bracket = False
            elif c == '[':
                out = c
                in_bracket = True
            elif c == '*':
                out = ".*"
            elif c == '?':
                out = '.'
            else:
                out = re.escape(c)
            pat += out
        pat += '$'
        return re.compile(pat, re.IGNORECASE)
        
class Address(object):
    """Holds name and ip address info"""
    def __init__(self, client):
        self.given_name = client
        triple = self._gethost(client)
        self.names = [triple[0]] + triple[1]
        self.quads = triple[2]

    def _gethost(self, name):
        """Try to duplicate mountd method of getting host"""
        try:
            addr = socket.gethostbyname_ex(name)
        except:
            # XXX this creates an empty name, as opposed to empty list
            # should we try harder to create empty list?
            return ("", [], [client])
        try:
            host = socket.gethostbyaddr(addr[2][0])
            if len(addr[2]) > 1:
                try:
                    host = socket.gethostbyname_ex(host[0])
                except:
                    pass

        except:
            host = addr
        return host
    
    def show_access(self, exportdata, sec="sys", debug=False):
        """Reports access self would have to various paths in exports file"""
        for path in exportdata:
            access = False
            for entry in exportdata[path]:
                if entry.matches(self, sec):
                    if debug: print "MATCH - %s: %s: %s" % (self, path, entry)
                    access = True
                    break
            if access:
                print "ALLOW: %s %s" % (path, entry)
            else:
                print "DENY : %s" % path

    def __repr__(self):
        return "%s, %s" % (self.names, self.quads)
    
###########################################################

def isquad(str):
    """Determines if str is a dot-quad address such as "10.0.0.1".

    There must be exactly four dot-separated fields, each made up only of
    decimal digits and with a value in 0..255.  Signs and surrounding
    whitespace, which int() alone would tolerate, are rejected.
    """
    a = str.split('.')
    if len(a) != 4:
        return False
    for q in a:
        if not q.isdigit() or int(q) > 255:
            return False
    return True

def ismask(str):
    """Determines if str is an address with an optional netmask.

    Accepted forms are "a.b.c.d", "a.b.c.d/m.m.m.m" and "a.b.c.d/N" with
    0 <= N <= 32; a plain dot-quad therefore also returns True.
    """
    # XXX Surely a lot of this mask manipulation must exist in a library already
    parts = str.split('/')
    if len(parts) == 1:
        return isquad(str)
    if len(parts) != 2 or not isquad(parts[0]):
        return False
    mask = parts[1]
    if isquad(mask):
        return True
    try:
        prefix = int(mask)
    except ValueError:
        return False
    return 0 <= prefix <= 32

def match_mask(ip1, ip2, mask):
    """Returns True if ip1&mask == ip2&mask.

    ip1, ip2 -- dotted-quad address strings
    mask     -- 4-byte packed netmask in network byte order
    """
    def as_int(packed):
        # Read a 4-byte network-order value as an unsigned integer.
        return struct.unpack("!L", packed)[0]
    first = as_int(socket.inet_aton(ip1))
    second = as_int(socket.inet_aton(ip2))
    netmask = as_int(mask)
    return (first & netmask) == (second & netmask)

###########################################################

def get_options():
    """Parse the command line.

    Returns the optparse options object with an extra `clients` attribute
    holding one Address per positional argument.  Calls the parser's
    error() when no client is given.
    """
    default_file = "/etc/exports"
    default_flavor = "sys"
    parser = OptionParser(usage="%prog [options] client ...")
    parser.add_option("-f", "--file", default=default_file, metavar="FILE",
                      help="Parse FILE as exports file [%s]" % default_file)
    parser.add_option("--sec", default=default_flavor, metavar="FLAVOR",
                      help="Assume client is using security FLAVOR [%s]"
                           % default_flavor)
    opts, args = parser.parse_args()
    if not args:
        parser.error("No client given")
    opts.clients = [Address(name) for name in args]
    return opts

def readlines(filename):
    """Parse file, concatenating lines and removing whitespace as necessary.

    Returns the file's logical lines: backslash-continued lines are joined,
    runs of whitespace are collapsed to single spaces, and blank lines and
    lines starting with '#' are dropped.
    """
    def full_lines(fd):
        """Iterate over logical lines, joining lines that end with \\ .

        The backslash-newline is replaced by a space so that tokens on
        either side of it stay separate.
        """
        line = ""
        for partial_line in fd:
            line += partial_line.rstrip()
            if line.endswith('\\'):
                line = line[:-1] + ' '
            else:
                yield line.strip()
                line = ""
        if line:
            # The file ended on a continuation line; keep what was read.
            yield line.strip()

    fd = open(filename)
    try:
        return [' '.join(l.split()) for l in full_lines(fd)
                if l != '' and l[0] != '#']
    finally:
        fd.close()

def parse_lines(lines):
    """Build a dict mapping each export path to its list of Entry objects.

    Every line is "<path> <entry> <entry> ...".  When a path occurs on
    several lines, its entries are concatenated in file order.
    """
    table = {}
    for line in lines:
        words = line.split()
        table.setdefault(words[0], []).extend([Entry(w) for w in words[1:]])
    return table

###########################################################

def main(opts, debug=True):
    try:
        lines = readlines(opts.file)
    except StandardError, e:
        print "Error trying to read %s:\n%s" % (opts.file, e)
        return
    d = parse_lines(lines)
    if debug:
        print "Reading from file %s:" % opts.file
        for path in d:
            print "%s %s" % (path, ' '.join([str(e) for e in d[path]]))
        print
    for c in opts.clients:
        print c.given_name
        if debug: print c
        c.show_access(d, opts.sec)
        print

if __name__ == "__main__":
    import sys
    try:
        opts = get_options()
    except Exception, e:
        if e.code:
            print e
            print "Failure reading commandline options"
        sys.exit(1)
    main(opts)
