Index: squid-tagger.py
==================================================================
--- squid-tagger.py
+++ squid-tagger.py
@@ -1,89 +1,170 @@
-#!/usr/bin/env python3.1
+#!/usr/bin/env python
+
+from __future__ import division, print_function, unicode_literals
+
+import gevent.monkey
+gevent.monkey.patch_all()
+
+import fcntl, gevent.core, gevent.pool, gevent.queue, gevent.socket, os, psycopg2, re, sys
+
+# //inclusion start
+# Copyright (C) 2010 Daniele Varrazzo
+# and licensed under the MIT license:
+
+def gevent_wait_callback(conn, timeout=None):
+	"""A wait callback useful to allow gevent to work with Psycopg."""
+	while 1:
+		state = conn.poll()
+		if state == psycopg2.extensions.POLL_OK:
+			break
+		elif state == psycopg2.extensions.POLL_READ:
+			gevent.socket.wait_read(conn.fileno(), timeout=timeout)
+		elif state == psycopg2.extensions.POLL_WRITE:
+			gevent.socket.wait_write(conn.fileno(), timeout=timeout)
+		else:
+			raise psycopg2.OperationalError("Bad result from poll: %r" % state)
+
+if not hasattr(psycopg2.extensions, 'set_wait_callback'):
+	raise ImportError("support for coroutines not available in this Psycopg version (%s)" % psycopg2.__version__)
+psycopg2.extensions.set_wait_callback(gevent_wait_callback)
+
+# //inclusion end
+
+# tiny wrapper around a file to make reads from it geventable
+# or should i move this somewhere?
+
+class FReadlineQueue(gevent.queue.Queue):
+	# storing file descriptor, leftover
+	__slots__ = frozenset(['_fd', '_tail'])
+
+	def __init__(self, fd):
+		# initialising class
+		gevent.queue.Queue.__init__(self)
+		# storing file descriptor
+		self._fd = fd
+		# using empty tail
+		self._tail = ''
+		# setting up event
+		self._install_wait()
+
+	def _install_wait(self):
+		fileno = self._fd.fileno()
+		# putting file to nonblocking mode
+		fcntl.fcntl(fileno, fcntl.F_SETFL, fcntl.fcntl(fileno, fcntl.F_GETFL) | os.O_NONBLOCK)
+		# installing event handler
+		gevent.core.read_event(fileno, self._wait_helper)
+
+	def _wait_helper(self, ev, evtype):
+		# reading one buffer from stream
+		buf = self._fd.read(4096)
+		# splitting stream by line ends
+		rows = buf.split('\n')
+		# adding tail to the first element if there is some tail
+		if len(self._tail) > 0:
+			rows[0] = self._tail + rows[0]
+		# popping out last (incomplete) element
+		self._tail = rows.pop(-1)
+		# dropping all complete elements to the queue
+		for row in rows:
+			self.put_nowait(row)
+		if len(buf) > 0:
+			# no EOF, reinstalling event handler
+			gevent.core.read_event(self._fd.fileno(), self._wait_helper)
+		else:
+			# EOF found, sending EOF to queue
+			self.put_nowait(None)

-import postgresql.api, re, sys
+stdin = FReadlineQueue(sys.stdin)
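
FReadlineQueue above turns a blocking stream into a gevent queue of complete lines, with None queued as the EOF sentinel. A minimal consumer sketch (consume is an illustrative name, not part of the patch):

    import gevent

    def consume(queue):
        while True:
            line = queue.get()   # cooperatively blocks until a full line arrives
            if line is None:     # EOF sentinel queued by _wait_helper()
                break
            print('got: %r' % line)

    gevent.spawn(consume, stdin).join()
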

 # wrapper around syslog, can be muted
-class Logger:
+class Logger(object):
 	__slots__ = frozenset(['_syslog'])

 	def __init__(self):
 		config.section('log')
 		if config['silent'] == 'yes':
 			self._syslog = None
 		else:
 			import syslog
 			self._syslog = syslog
-			self._syslog.openlog('squidTag')
+			self._syslog.openlog(str('squidTag'))

 	def info(self, message):
-		if self._syslog:
+		if self._syslog != None:
 			self._syslog.syslog(self._syslog.LOG_INFO, message)

 	def notice(self, message):
-		if self._syslog:
+		if self._syslog != None:
 			self._syslog.syslog(self._syslog.LOG_NOTICE, message)

 # wrapper around database
-class tagDB:
-	__slots__ = frozenset(('_check_stmt', '_db'))
+class tagDB(object):
+	__slots__ = frozenset(['_cursor', '_db'])

 	def __init__(self):
 		config.section('database')
-		self._db = postgresql.open(
-			'pq://{}:{}@{}/{}'.format(
-				config['user'],
-				config['password'],
-				config['host'],
-				config['database'],
-		) )
-		self._check_stmt = None
+		self._db = psycopg2.connect(
+			database = config['database'],
+			host = config['host'],
+			user = config['user'],
+			password = config['password'],
+		)
+		self._cursor = self._db.cursor()
+
+	def _field_names(self):
+		names = []
+		for record in self._cursor.description:
+			names.append(record.name)
+		return(names)

 	def check(self, site, ip_address):
-		if self._check_stmt == None:
-			self._check_stmt = self._db.prepare("select redirect_url, regexp from site_rule where site <@ tripdomain($1) and netmask >>= $2::text::inet order by array_length(site, 1) desc")
-		return(self._check_stmt(site, ip_address))
+		self._cursor.execute("select redirect_url, regexp from site_rule where site <@ tripdomain(%s) and netmask >>= %s order by array_length(site, 1) desc", [site, ip_address])
+		return(self._cursor.fetchall())

 	def dump(self):
-		return(self._db.prepare("copy (select untrip(site) as site, tag, regexp from urls order by site, tag) to stdout csv header")())
+		self._cursor.execute("select untrip(site) as site, tag::text, regexp from urls order by site, tag")
+		return(self._field_names(), self._cursor.fetchall())

 	def load(self, data):
-		with self._db.xact():
-			if config.options.flush_db:
-				self._db.execute('delete from urls;')
-			insert = self._db.prepare("insert into urls (site, tag, regexp) values (tripdomain($1), $2::text::text[], $3)")
-			for row in data:
-				if len(row) == 2:
-					insert(row[0], row[1], None)
-				else:
-					insert(row[0], row[1], row[2])
-			self._db.execute("update urls set regexp = NULL where regexp = ''")
-		self._db.execute('vacuum analyze urls;')
+		if config.options.flush_db:
+			self._cursor.execute('delete from urls;')
+		bundle = []
+		for row in data:
+			if len(row) == 2:
+				bundle.append([row[0], row[1], None])
+			else:
+				bundle.append([row[0], row[1], row[2]])
+		self._cursor.executemany("insert into urls (site, tag, regexp) values (tripdomain(%s), %s, %s)", bundle)
+		self._cursor.execute("update urls set regexp = NULL where regexp = ''")
+		self._db.commit()

 	def load_conf(self, csv_data):
-		with self._db.xact():
-			self._db.execute('delete from rules;')
-			insertconf = self._db.prepare("insert into rules (netmask, redirect_url, from_weekday, to_weekday, from_time, to_time, tag) values ($1::text::cidr, $2, $3, $4, $5::text::time, $6::text::time, $7::text::text[])")
-			for row in csv_data:
-				insertconf(row[0], row[1], int(row[2]), int(row[3]), row[4], row[5], row[6])
-		self._db.execute('vacuum analyze rules;')
+		self._cursor.execute('delete from rules;')
+		bundle = []
+		for row in csv_data:
+			bundle.append([row[0], row[1], int(row[2]), int(row[3]), row[4], row[5], row[6]])
+		self._cursor.executemany("insert into rules (netmask, redirect_url, from_weekday, to_weekday, from_time, to_time, tag) values (%s::text::cidr, %s, %s, %s, %s::text::time, %s::text::time, %s::text::text[])", bundle)
+		self._db.commit()

 	def dump_conf(self):
-		return(self._db.prepare("copy (select netmask, redirect_url, from_weekday, to_weekday, from_time, to_time, tag from rules) to stdout csv header")())
+		self._cursor.execute("select netmask, redirect_url, from_weekday, to_weekday, from_time, to_time, tag::text from rules")
+		return(self._field_names(), self._cursor.fetchall())
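
With psycopg2 the $n prepared-statement placeholders become %s parameters bound at execute() time, and the wait callback installed at the top keeps each query cooperative under gevent. A usage sketch for the hot path (the host name and client address are illustrative; this assumes the config global is initialised as in the __main__ code below):

    db = tagDB()
    for redirect_url, regexp in db.check('www.example.com', '192.168.1.5'):
        print('redirect to {} (regexp: {})'.format(redirect_url, regexp))
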

 # abstract class with basic checking functionality
-class Checker:
-	__slots__ = frozenset(['_db', '_log', '_request'])
+class Checker(object):
+	__slots__ = frozenset(['_db', '_log', '_queue', '_request'])

-	def __init__(self):
+	def __init__(self, queue):
 		self._db = tagDB()
 		self._log = Logger()
 		self._log.info('started\n')
 		self._request = re.compile('^([0-9]+)\ (http|ftp):\/\/([-\w.:]+)\/([^ ]*)\ ([0-9.]+)\/(-|[\w\.]+)\ (-|\w+)\ (-|GET|HEAD|POST).*$')
+		self._queue = queue

 	def process(self, id, site, ip_address, url_path, line = None):
-		self._log.info('trying {}\n'.format(site))
+		#self._log.info('trying {}\n'.format(site))
 		result = self._db.check(site, ip_address)
 		reply = None
 		for row in result:
 			if row != None and row[0] != None:
 				if row[1] != None:
@@ -100,174 +181,42 @@
 		if reply != None:
 			self.writeline('{} {}\n'.format(id, reply))
 			return(True)
 		self.writeline('{}\n'.format(id))

-	def check(self, line):
-		request = self._request.match(line)
-		if request:
-			id = request.group(1)
-			#proto = request.group(2)
-			site = request.group(3)
-			url_path = request.group(4)
-			ip_address = request.group(5)
-			self.process(id, site, ip_address, url_path, line)
-			return(True)
-		else:
-			self._log.info('bad request\n')
-			self.writeline(line)
-			return(False)
+	def check(self):
+		while True:
+			line = self._queue.get()
+			if line == None:
+				break
+			self._log.info('request: ' + line)
+			request = self._request.match(line)
+			if request:
+				id = request.group(1)
+				#proto = request.group(2)
+				site = request.group(3)
+				url_path = request.group(4)
+				ip_address = request.group(5)
+				self.process(id, site, ip_address, url_path, line)
+			else:
+				self._log.info('bad request\n')
+				self.writeline(line + '\n')

 	def writeline(self, string):
 		self._log.info('sending: ' + string)
 		sys.stdout.write(string)
 		sys.stdout.flush()

 	def loop(self):
-		while True:
-			line = sys.stdin.readline()
-			if len(line) == 0:
-				break
-			self.check(line)
-
-# threaded checking facility
-class CheckerThread(Checker):
-	__slots__ = frozenset(['_lock', '_lock_exit', '_lock_queue', '_queue'])
-
-	def __init__(self):
-		import _thread
-
-		# basic initialisation
-		Checker.__init__(self)
-
-		# Spin lock. Loop acquires it on start then releases it when holding queue
-		# lock. This way the thread proceeds without stops while queue has data and
-		# gets stalled when no data present. The lock is released by queue writer
-		# after storing something into the queue
-		self._lock = _thread.allocate_lock()
-		self._lock_exit = _thread.allocate_lock()
-		self._lock_queue = _thread.allocate_lock()
-		self._lock.acquire()
-		self._queue = []
-		_thread.start_new_thread(self._start, ())
-
-	def _start(self):
-		while True:
-			self._lock.acquire()
-			with self._lock_queue:
-				# yes this should be written this way, and yes, this is why I hate threading
-				if len(self._queue) > 1:
-					if self._lock.locked():
-						self._lock.release()
-				req = self._queue.pop(0)
-			Checker.process(self, req[0], req[1], req[2], req[3])
-			with self._lock_queue:
-				if len(self._queue) == 0:
-					if self._lock_exit.locked():
-						self._lock_exit.release()
-
-	def process(self, id, site, ip_address, url_path, line):
-		with self._lock_queue:
-			self._queue.append((id, site, ip_address, url_path))
-			self._log.info('request {} queued ({})\n'.format(id, line))
-			if not self._lock_exit.locked():
-				self._lock_exit.acquire()
-			if self._lock.locked():
-				self._lock.release()
-
-	def loop(self):
-		while True:
-			line = sys.stdin.readline()
-			if len(line) == 0:
-				break
-			self.check(line)
-		self._lock_exit.acquire()
-
-# kqueue enabled class for BSD's
-class CheckerKqueue(Checker):
-	__slots__ = frozenset(['_kq', '_select', '_queue'])
-
-	def __init__(self):
-		# basic initialisation
-		Checker.__init__(self)
-
-		# importing select module
-		import select
-		self._select = select
-
-		# kreating kqueue
-		self._kq = self._select.kqueue()
-		assert self._kq.fileno() != -1, "Fatal error: can't initialise kqueue."
-
-		# watching sys.stdin for data
-		self._kq.control([self._select.kevent(sys.stdin, self._select.KQ_FILTER_READ, self._select.KQ_EV_ADD)], 0)
-
-		# creating data queue
-		self._queue = []
-
-	def loop(self):
-		# Wait for data by default
-		timeout = None
-		eof = False
-		buffer = ''
-		while True:
-			# checking if there is any data or witing for data to arrive
-			kevs = self._kq.control(None, 1, timeout)
-
-			for kev in kevs:
-				if kev.filter == self._select.KQ_FILTER_READ and kev.data > 0:
-					# reading data in
-					new_buffer = sys.stdin.read(kev.data)
-					# if no data was sent - we have reached end of file
-					if len(new_buffer) == 0:
-						eof = True
-					else:
-						# adding current buffer to old buffer remains
-						buffer += new_buffer
-						# splitting to lines
-						lines = buffer.split('\n')
-						# last line that was not terminate by newline returns to buffer
-						buffer = lines[-1]
-						# an only if there was at least one newline
-						if len(lines) > 1:
-							for line in lines[:-1]:
-								# add data to the queue
-								if self.check(line + '\n'):
-									# don't wait for more data, start processing
-									timeout = 0
-
-					# detect end of stream and exit if possible
-					if kev.flags >> 15 == 1:
-						self._kq.control([self._select.kevent(sys.stdin, self._select.KQ_FILTER_READ, self._select.KQ_EV_DELETE)], 0)
-						eof = True
-						#timeout = 0
-
-			if len(kevs) == 0:
-				if len(self._queue) > 0:
-					# get one request and process it
-					req = self._queue.pop(0)
-					Checker.process(self, req[0], req[1], req[2], req[3])
-					if len(self._queue) == 0:
-						# wait for data - we have nothing to process
-						timeout = None
-
-			# if queue is empty and we reached end of stream - we can exit
-			if len(self._queue) == 0 and eof:
-				break
-
-	def process(self, id, site, ip_address, url_path, line):
-		# simply adding data to the queue
-		self._queue.append((id, site, ip_address, url_path))
-		self._log.info('request {} queued ({})\n'.format(id, line))
+		pool = gevent.pool.Pool()
+		pool.spawn(self.check)
+		pool.join()
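
For reference, the request regexp in Checker.__init__ parses Squid's concurrent redirector protocol, one request per line. The sample line below is illustrative:

    import re

    request = re.compile('^([0-9]+)\ (http|ftp):\/\/([-\w.:]+)\/([^ ]*)\ ([0-9.]+)\/(-|[\w\.]+)\ (-|\w+)\ (-|GET|HEAD|POST).*$')
    m = request.match('0 http://www.example.com/index.html 192.168.1.5/- - GET')
    # m.group(1) -> '0' (request id), m.group(3) -> 'www.example.com' (site),
    # m.group(4) -> 'index.html' (url_path), m.group(5) -> '192.168.1.5' (client ip)
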

 # this class processes the config file
 # and substitutes default values
 class Config:
 	__slots__ = frozenset(['_config', '_default', '_section', 'options'])

 	_default = {
-		'reactor': {
-			'reactor': 'thread',
-		},
 		'log': {
 			'silent': 'no',
 		},
 		'database': {
 			'host': 'localhost',
@@ -274,11 +223,11 @@
 			'database': 'squidTag',
 	},}

 	# function to read in config file
 	def __init__(self):
-		import configparser, optparse, os
+		import ConfigParser, optparse, os

 		parser = optparse.OptionParser()
 		parser.add_option('-c', '--config', dest = 'config', help = 'config file location', metavar = 'FILE', default = '/usr/local/etc/squid-tagger.conf')
@@ -300,11 +249,11 @@
 		(self.options, args) = parser.parse_args()

 		assert os.access(self.options.config, os.R_OK), "Fatal error: can't read {}".format(self.options.config)

-		self._config = configparser.ConfigParser()
+		self._config = ConfigParser.ConfigParser()
 		self._config.readfp(open(self.options.config))

 	# function to select config file section or create one
 	def section(self, section):
 		if not self._config.has_section(section):
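
The import rename pins the script to Python 2's ConfigParser module. If both interpreters ever need to be supported, the usual compatibility shim looks like this (a sketch only; the patch itself targets Python 2):

    try:
        import configparser                      # Python 3 name
    except ImportError:
        import ConfigParser as configparser      # Python 2 name
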
@@ -333,17 +282,19 @@
 	tagdb = tagDB()

 	data_fields = ['site', 'tag', 'regexp']
 	conf_fields = ['netmask', 'redirect_url', 'from_weekday', 'to_weekday', 'from_time', 'to_time', 'tag']

 	if config.options.dump or config.options.dump_conf:
+		csv_writer = csv.writer(sys.stdout)
 		if config.options.dump:
 			dump = tagdb.dump()
 		elif config.options.dump_conf:
 			dump = tagdb.dump_conf()

-		for line in dump:
-			sys.stdout.write(line.decode('utf-8'))
+		csv_writer.writerow(dump[0])
+		for line in dump[1]:
+			csv_writer.writerow(line)

 	elif config.options.load or config.options.load_conf:
 		csv_reader = csv.reader(sys.stdin)
 		first_row = next(csv_reader)
@@ -357,14 +308,6 @@

 		assert first_row == fields, 'File must contain csv data with these columns: ' + repr(fields)

 		load(csv_reader)

 	else:
 		# main loop
-		config.section('reactor')
-		if config['reactor'] == 'thread':
-			checker = CheckerThread()
-		elif config['reactor'] == 'plain':
-			checker = Checker()
-		elif config['reactor'] == 'kqueue':
-			checker = CheckerKqueue()
-
-		checker.loop()
+		Checker(stdin).loop()
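
With this change dump() and dump_conf() return a (field_names, rows) pair instead of streaming a server-side "copy ... to stdout csv header", so the caller builds the CSV itself. A usage sketch (assumes config and a reachable database are set up as above):

    import csv, sys

    fields, rows = tagDB().dump()   # e.g. (['site', 'tag', 'regexp'], [...])
    writer = csv.writer(sys.stdout)
    writer.writerow(fields)         # header row the old server-side copy produced
    for row in rows:
        writer.writerow(row)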