Index: squid-tagger ================================================================== --- squid-tagger +++ squid-tagger @@ -1,93 +1,141 @@ -#!/usr/bin/env python-shared +#!/usr/bin/env python3.1 + +import configparser, optparse, os, postgresql.api, re, sys, _thread -import cPickle,psycopg2,re,sys,thread +class Logger: + __slots__ = frozenset(['_silent', '_syslog']) -class Logger(object): - __slots__=frozenset(['_silent','_syslog']) - def __init__(this,silent=True): + def __init__(self, silent = True): if silent: - this._silent=True + self._silent = True else: import syslog - this._syslog=syslog - this._syslog.openlog('squidTag') - this._silent=False - def info(this,message): - if not this._silent: - this._syslog.syslog(this._syslog.LOG_INFO,message) - def notice(this,message): - if not this._silent: - this._syslog.syslog(this._syslog.LOG_NOTICE,message) - -class tagDB(object): - __slots__=frozenset(['_prepared','_cursor']) - def __init__(this): - this._prepared=set() - this._cursor=False - def _curs(this): - if not this._cursor: - this._cursor=psycopg2.connect('host=%s dbname=%s user=%s password=%s'%('pkunk','squidTag','squidTag','NachJas%')).cursor() - return this._cursor - def check(this,ip_address,site): - return this._curs().execute("select redirect_url from site_rules where site <@ tripdomain(%s) and netmask >> %s limit 1",(site,ip_address,)) - def statusmessage(this): - return this._curs().statusmessage - def fetchone(this): - return this._curs().fetchone() - -class CheckerThread(object): - __slots__=frozenset(['_db','_lock','_lock_queue','_log','_queue']) - def __init__(this,db,log): - this._db=db - this._log=log - this._lock=thread.allocate_lock() - this._lock_queue=thread.allocate_lock() - this._lock.acquire() - this._queue=[] - thread.start_new_thread(this._start,()) - def _start(this): + self._syslog = syslog + self._syslog.openlog('squidTag') + self._silent = False + + def info(self, message): + if not self._silent: + self._syslog.syslog(self._syslog.LOG_INFO, message) + + def notice(self, message): + if not self._silent: + self._syslog.syslog(self._syslog.LOG_NOTICE, message) + +class tagDB: + __slots__ = frozenset(['_prepared', '_db']) + + def __init__(self): + self._prepared = set() + self._db = False + + def _curs(self): + if not self._db: + config.section('database') + # needs thinking + #connector = postgresql.api.Connector( + #user = config['user'], password = config['password'], + #database = config['database'], + self._db = postgresql.open( + 'pq://{0}:{1}@{2}/{3}'.format( + config['user'], + config['password'], + config['host'], + config['database'], + )) + return(self._db) + + def check(self, ip_address, site): + # doesn't work for inet + #stmt = self._curs().prepare("select redirect_url from site_rules where site <@ tripdomain($1) and netmask >> '$2' limit 1") + #result = stmt(site, ip_address) + stmt = self._curs().prepare("select redirect_url from site_rules where site <@ tripdomain('{0}') and netmask >> '{1}' limit 1".format(site, ip_address)) + result = stmt() + if len(result) > 0: + return result[0] + else: + return None + +class CheckerThread: + __slots__ = frozenset(['_db', '_lock', '_lock_queue', '_log', '_queue']) + + def __init__(self, db, log): + self._db = db + self._log = log + self._lock = _thread.allocate_lock() + self._lock_queue = _thread.allocate_lock() + self._lock.acquire() + self._queue = [] + _thread.start_new_thread(self._start, ()) + + def _start(self): while True: - this._lock.acquire() - this._lock_queue.acquire() - if len(this._queue)>1 and this._lock.locked(): - this._lock.release() - req=this._queue.pop(0) - this._lock_queue.release() - this._log.info('trying %s\n'%req[1]) - this._db.check(req[2],req[1]) - this._log.info("Got '%s' from database.\n"%this._db.statusmessage()) - row=this._db.fetchone() + self._lock.acquire() + self._lock_queue.acquire() + if len(self._queue) > 1 and self._lock.locked(): + self._lock.release() + req = self._queue.pop(0) + self._lock_queue.release() + self._log.info('trying %s\n'%req[1]) + row = self._db.check(req[2], req[1]) if row != None and row[0] != None: - writeline('%s 302:%s\n'%(req[0],row[0])) + writeline('%s 302:%s\n'%(req[0], row[0])) else: writeline('%s -\n'%req[0]) - def check(this,line): - request=re.compile('^([0-9]+)\ (http|ftp):\/\/([-\w.:]+)\/([^ ]*)\ ([0-9.]+)\/(-|[\w\.]+)\ (-|\w+)\ (-|GET|HEAD|POST).*$').match(line) + + def check(self, line): + request = re.compile('^([0-9]+)\ (http|ftp):\/\/([-\w.:]+)\/([^ ]*)\ ([0-9.]+)\/(-|[\w\.]+)\ (-|\w+)\ (-|GET|HEAD|POST).*$').match(line) if request: - site=request.group(3) - ip_address=request.group(5) - id=request.group(1) - this._lock_queue.acquire() - this._queue.append((id,site,ip_address)) - if this._lock.locked(): - this._lock.release() - this._lock_queue.release() - this._log.info('request %s queued (%s)\n'%(id,line)) + site = request.group(3) + ip_address = request.group(5) + id = request.group(1) + self._lock_queue.acquire() + self._queue.append((id, site, ip_address)) + if self._lock.locked(): + self._lock.release() + self._lock_queue.release() + self._log.info('request %s queued (%s)\n'%(id, line)) else: - this._log.info('bad request\n') + self._log.info('bad request\n') writeline(line) def writeline(string): log.info('sending: %s'%string) sys.stdout.write(string) sys.stdout.flush() -log=Logger(False) -db=tagDB() -checker=CheckerThread(db,log) +class Config: + __slots__ = frozenset(['_config', '_section']) + + def __init__(self): + parser = optparse.OptionParser() + parser.add_option('-c', '--config', dest = 'config', + help = 'config file location', metavar = 'FILE', + default = '/usr/local/etc/squid-tagger.conf') + + (options, args) = parser.parse_args() + + if not os.access(options.config, os.R_OK): + print("Can't read {0}: exitting".format(options.config)) + sys.exit(2) + + self._config = configparser.ConfigParser() + self._config.readfp(open(options.config)) + + def section(self, section): + self._section = section + + def __getitem__(self, name): + return self._config.get(self._section, name) + +config = Config() + +log = Logger(False) +db = tagDB() +checker = CheckerThread(db,log) while True: - line=sys.stdin.readline() - if len(line)==0: + line = sys.stdin.readline() + if len(line) == 0: break checker.check(line)