Samesite - proxy that can cache partial transfers

samesite.py at [b5c328f916]
anonymous

samesite.py at [b5c328f916]

File samesite.py artifact d358a0b198 part of check-in b5c328f916


#!/usr/bin/env python

from __future__ import unicode_literals, print_function

import bsddb.dbshelve, copy, datetime, os, BaseHTTPServer, sys, spacemap, re, urllib2

class Config(object):
	"""Command-line options and the INI configuration of the proxy.

	The ``-c/--config`` option names the configuration file (default
	``samesite.conf``).  Every section other than ``general`` gets a ``dir``
	option resolved by ``_resolve_dir`` and a ``root`` option that defaults
	to the section name.  Options are read by selecting a section with
	``section()`` and then indexing, ``config[name]``; a missing option is
	filled in from ``_default`` (the ``general`` table for the ``general``
	section, the ``_other`` table for every other section) or set to None.
	"""
	# '_default' is a class attribute and must not also be listed as a slot.
	__slots__ = frozenset(['_config', '_section', 'options', 'root'])
	_default = {
		'general': {
			'port': '8008',
		},
		'_other': {
			'verbose': 'no',
			'noetag': 'no',
			'noparts': 'no',
			'strip': '',
			'sub': '',
	},}

	def __init__(self):
		"""Parse the command line and read the configuration file.

		Exits through ``optparse`` (status 2) when the file is not readable.
		"""
		import ConfigParser, optparse

		parser = optparse.OptionParser()
		parser.add_option('-c', '--config', dest = 'config', help = 'config file location', metavar = 'FILE', default = 'samesite.conf')
		(self.options, _) = parser.parse_args()

		# An explicit error instead of assert: asserts vanish under -O.
		if not os.access(self.options.config, os.R_OK):
			parser.error("Fatal error: can't read {}".format(self.options.config))

		# Relative 'dir' values are resolved against the config file's own
		# directory, or the current directory when the path has none.
		self.root = os.path.dirname(self.options.config) or os.getcwd()

		self._config = ConfigParser.ConfigParser()
		with open(self.options.config) as config_file:
			self._config.readfp(config_file)

		for section in self._config.sections():
			if section == 'general':
				continue
			if self._config.has_option(section, 'dir'):
				value = self._config.get(section, 'dir')
			else:
				value = None
			self._config.set(section, 'dir', self._resolve_dir(self.root, section, value))

			if not self._config.has_option(section, 'root'):
				self._config.set(section, 'root', section)

	@staticmethod
	def _resolve_dir(root, section, value):
		"""Return the cache directory for *section*.

		*value* is the configured ``dir`` option, or None when it is absent.
		None and ``'/'`` both mean ``<root>/<section>``.  Otherwise one
		trailing slash is dropped and a relative path is joined onto *root*
		exactly once.
		"""
		if value is None or value == '/':
			return os.path.join(root, section)
		if value.endswith('/'):
			value = value[:-1]
		if value.startswith('/'):
			return value
		return os.path.join(root, value)

	def section(self, section):
		"""Select *section* for later lookups, creating it if it is missing."""
		if not self._config.has_section(section):
			self._config.add_section(section)
		self._section = section

	def __getitem__(self, name):
		"""Return option *name* of the selected section.

		A missing option is first stored from the defaults (or as None), so
		later lookups see the same value.
		"""
		if not self._config.has_option(self._section, name):
			# Sections named in _default use their own table; others use '_other'.
			defaults = self._default.get(self._section, self._default['_other'])
			self._config.set(self._section, name, defaults.get(name))
		return self._config.get(self._section, name)

# The single shared configuration object: it parses the command line and reads
# the configuration file as soon as this module is loaded.
config = Config()

# Upstream response headers that describe the cached file; they are kept in the
# index record and compared against fresh values when the file is rechecked.
const_desc_fields = {'content-length', 'last-modified', 'pragma'}

# Upstream response headers that are neither stored nor reported as unknown.
const_ignore_fields = {
	'accept-ranges', 'age', 'cache-control', 'connection', 'content-type',
	'date', 'expires', 'referer', 'server', 'via',
	'x-cache', 'x-cache-lookup', 'x-livetool', 'x-powered-by',
}

# Chunk size in bytes for reading from upstream and from cached files.
block_size = 4096

class MyRequestHandler(BaseHTTPServer.BaseHTTPRequestHandler):
	"""Caching handler for HEAD and GET requests.

	The request's Host header selects a section of the global ``config``.
	Files are cached below that section's ``dir``: complete files at the
	request path, incomplete ones below ``.parts``.  Their metadata lives in
	a ``bsddb.dbshelve`` named ``.index`` in the same directory.  Missing or
	stale data is fetched from ``http://<root><path>`` before the client is
	answered.
	"""

	# Client request headers that pass the header check.  Of these, only
	# cache-control, cookie, referer and user-agent are forwarded upstream.
	_PROXY_IGNORED = frozenset([
		'accept', 'accept-charset', 'accept-encoding', 'accept-language',
		'cache-control', 'connection', 'content-length', 'cookie',
		'host',
		'if-modified-since', 'if-unmodified-since',
		'referer',
		'user-agent',
		'via',
		'x-forwarded-for', 'x-last-hr', 'x-last-http-status-code', 'x-removed', 'x-real-ip', 'x-retry-count',
	])

	def __process(self):
		"""Handle one request.

		Selects the host's config section, opens the index and serves the
		request.  The index is closed on every exit path.
		"""
		host = self.headers.get('host')
		if not host:
			# Without Host there is no config section to pick a cache from.
			self.send_error(400, 'Missing Host header')
			return

		# The query string is not part of the cache key.
		my_path = self.path.split('?', 1)[0]

		config.section(host)

		if config['sub'] is not None and config['strip']:
			my_path = re.sub(config['strip'], config['sub'], my_path)

		# my_path becomes a file name under the cache directory; '..' segments
		# would let a request read or write outside of it.
		if '..' in my_path.split('/'):
			self.send_error(400, 'Invalid path')
			return

		cache_dir = config['dir']
		if cache_dir is None:
			# Config only sets 'dir' for sections present in the config file.
			self.send_error(502, 'Host is not configured')
			return
		if not os.path.isdir(cache_dir):
			os.makedirs(cache_dir)

		# The index maps each path to a record dict:
		#   '_parts' - SpaceMap of byte ranges still missing (empty when the
		#              file is complete, None when nothing is known),
		#   '_time'  - when upstream was last checked,
		#   everything else - descriptive upstream headers.
		index = bsddb.dbshelve.open(cache_dir + os.sep + '.index')
		try:
			self.__serve(index, my_path)
		finally:
			index.close()

	def __serve(self, index, my_path):
		"""Validate request headers, refresh the cache entry if needed, reply.

		index: the open shelve of the current host section.
		my_path: request path without query string, after strip/sub rewriting.
		"""
		# reload: cached data cannot be used and the file must be fetched again
		reload = False
		# recheck: upstream must be queried; a changed file then forces reload
		recheck = False
		# file_stat: stat of the cached full or partial file, when one exists
		file_stat = None
		# requested_ranges: SpaceMap of the byte range the client asked for
		requested_ranges = None
		# client_pragma: value of the client's Pragma header, if it sent one
		client_pragma = None

		info = 'Checking file: ' + my_path

		desc_fields = const_desc_fields.copy()
		ignore_fields = const_ignore_fields.copy()
		if config['noetag'] == 'no':
			desc_fields.add('etag')
		else:
			ignore_fields.add('etag')

		print('===============[ {} request ]==='.format(self.command))

		for header in self.headers:
			if header in self._PROXY_IGNORED:
				pass
			elif header == 'range':
				# Only a closed range "bytes=first-last" is understood.
				is_range = re.match(r'bytes=(\d+)-(\d+)', self.headers[header])
				if not is_range:
					print('Unsupported range - ', self.headers[header], sep='')
					self.send_error(501, 'Unsupported Range header')
					return
				# SpaceMap ends are exclusive, HTTP "last" is inclusive.
				requested_ranges = spacemap.SpaceMap({int(is_range.group(1)): int(is_range.group(2)) + 1})
			elif header == 'pragma':
				client_pragma = self.headers[header]
			else:
				print('Unknown header - ', header, ': ', self.headers[header], sep='')
				self.send_error(501, 'Unsupported request header')
				return
			print(header, self.headers[header])

		# Only '%20' is decoded when mapping the URL path to a file name.
		local_path = my_path.replace('%20', ' ')
		file_name = config['dir'] + os.sep + local_path
		# incomplete downloads are kept under '.parts' until no range is missing
		temp_name = config['dir'] + os.sep + '.parts' + local_path

		if my_path not in index:
			info += '\nThis one is new.'
			reload = True
			record = {}
		else:
			# Shelve values are unpickled copies; record is never written back.
			record = index[my_path]
			if client_pragma is not None:
				# The client's Pragma applies to this request, so
				# "Pragma: no-cache" forces the upstream recheck below.
				record['pragma'] = client_pragma
			if os.access(file_name, os.R_OK):
				info += '\nFull file found.'
				file_stat = os.stat(file_name)
			elif '_parts' in record and os.access(temp_name, os.R_OK):
				info += '\nPartial file found.'
				file_stat = os.stat(temp_name)
				recheck = True
			else:
				info += '\nFile not found or inaccessible.'
				record['_parts'] = None
				reload = True

		if '_parts' not in record:
			record['_parts'] = None

		# '_parts' may hold a SpaceMap; comparisons with it keep using ==.
		if record['_parts'] == None:
			recheck = True

		# a complete file whose size disagrees with the index is fetched again
		if not reload and record['_parts'] == spacemap.SpaceMap():
			if 'content-length' in record and file_stat and file_stat.st_size != int(record['content-length']):
				info += '\nFile size is {} and stored file size is {}.'.format(file_stat.st_size, record['content-length'])
				record['_parts'] = None
				reload = True

		if not reload and record.get('pragma') == 'no-cache':
			info += '\nPragma on: recheck imminent.'
			recheck = True

		# re-validate with upstream when the last check is over 4 hours old
		if not recheck and not reload and '_time' in record and record['_time'] + datetime.timedelta(hours = 4) < datetime.datetime.now():
			info += '\nFile is old - rechecking.'
			recheck = True

		print(info)
		if reload or recheck:
			source = None
			try:
				request = 'http://' + config['root'] + self.path
				my_headers = {}
				for header in ('cache-control', 'cookie', 'referer', 'user-agent'):
					if header in self.headers:
						my_headers[header] = self.headers[header]

				# needed: SpaceMap of the ranges to fetch in this request
				needed = None
				if self.command != 'HEAD':
					if record['_parts'] != None:
						if config['noparts'] != 'no' or requested_ranges == None or requested_ranges == spacemap.SpaceMap():
							needed = record['_parts']
						else:
							needed = record['_parts'] & requested_ranges
					elif config['noparts'] == 'no' and requested_ranges != None and requested_ranges != spacemap.SpaceMap():
						needed = requested_ranges
					print('Missing ranges: {}, requested ranges: {}, needed ranges: {}.'.format(record['_parts'], requested_ranges, needed))
					if needed != None and len(needed) > 0:
						ranges = []
						needed.rewind()
						while True:
							span = needed.pop()
							if span[0] is None:
								break
							# HTTP byte ranges are inclusive; SpaceMap ends are not
							ranges.append('{}-{}'.format(span[0], span[1] - 1))
						my_headers['range'] = 'bytes=' + ','.join(ranges)

				source = urllib2.urlopen(urllib2.Request(request, headers = my_headers))
				new_record = {'_parts': record['_parts']}
				headers = source.info()

				# Keep only the descriptive headers.  With Content-Range the
				# total length comes from that header, not Content-Length.
				for header in headers:
					if header in desc_fields:
						if header == 'content-length':
							if 'content-range' not in headers:
								new_record[header] = int(headers[header])
						else:
							new_record[header] = headers[header]
					elif header == 'content-range':
						content_range = re.match(r'^bytes (\d+)-(\d+)/(\d+)$', headers[header])
						if not content_range:
							print('Content-Range unrecognized.')
							self.send_error(502, 'Content-Range unrecognized')
							return
						new_record['content-length'] = int(content_range.group(3))
					elif header not in ignore_fields:
						print('Undefined header "', header, '": ', headers[header], sep='')

				# New headers or a changed header value force a full reload.
				# Headers that disappeared are only reported.
				# '_time' and 'pragma' never count.
				old_keys = set(record.keys())
				old_keys.discard('_time')
				old_keys.discard('pragma')
				new_keys = set(new_record.keys())
				more_keys = new_keys - old_keys
				more_keys.discard('pragma')
				less_keys = old_keys - new_keys
				if more_keys:
					if old_keys:
						print('More headers appear:', more_keys)
					reload = True
				elif less_keys:
					print('Less headers appear:', less_keys)
				else:
					for key in record.keys():
						if key[0] != '_' and key != 'pragma' and record[key] != new_record[key]:
							print('Header "', key, '" changed from [', record[key], '] to [', new_record[key], ']', sep='')
							print(type(record[key]), type(new_record[key]))
							reload = True

				if reload:
					print('Reloading.')
					if os.access(temp_name, os.R_OK):
						os.unlink(temp_name)
					if os.access(file_name, os.R_OK):
						os.unlink(file_name)
					if 'content-length' in new_record:
						new_record['_parts'] = spacemap.SpaceMap({0: int(new_record['content-length'])})
				if not new_record['_parts']:
					new_record['_parts'] = spacemap.SpaceMap()
				print(new_record)

				if 'content-length' in new_record:
					if needed == None:
						needed = new_record['_parts']
					elif len(needed) > 1:
						# several ranges would come back as multipart/byteranges
						print("Multipart requests currently not supported.")
						self.send_error(501, 'Multipart requests currently not supported')
						return

				new_record['_time'] = datetime.datetime.now()
				if self.command != 'HEAD':
					# Data goes into the .parts file.  It is moved to its final
					# name only when no range is missing.
					if not os.access(temp_name, os.R_OK):
						empty_name = config['dir'] + os.sep + '.tmp'
						with open(empty_name, 'w+b'):
							pass
						os.renames(empty_name, temp_name)
					# Fall back to the missing ranges whenever nothing narrower
					# was chosen; previously this crashed on None.
					if needed == None:
						needed = new_record['_parts']
					with open(temp_name, 'r+b') as temp_file:
						needed.rewind()
						while True:
							(start, end) = needed.pop()
							if start is None:
								break
							stream_last = start
							old_record = copy.copy(new_record)
							req_block_size = min(block_size, end - start)
							buffer = source.read(req_block_size)
							while buffer and stream_last < end:
								stream_pos = stream_last + len(buffer)
								assert stream_pos <= end, 'Received more data then requested: pos:{} start:{} end:{}.'.format(stream_pos, start, end)
								temp_file.seek(stream_last)
								temp_file.write(buffer)
								new_record['_parts'] = new_record['_parts'] - spacemap.SpaceMap({stream_last: stream_pos})
								# Store the state from before this block; the
								# final record is written after the loop.
								index[my_path] = old_record
								index.sync()
								old_record = copy.copy(new_record)
								stream_last = stream_pos
								if end - stream_last < block_size:
									req_block_size = end - stream_last
								buffer = source.read(req_block_size)

				index[my_path] = new_record
				index.sync()

			except urllib2.URLError as error:
				# urlopen failed (HTTP error status or unreachable upstream).
				# The reply below uses whatever the index already holds.
				print(error)
			finally:
				if source is not None:
					source.close()

		self.__reply(index, my_path, file_name, temp_name, requested_ranges)

	def __reply(self, index, my_path, file_name, temp_name, requested_ranges):
		"""Answer the client from the cache.

		Sends 502 when the index has no record for the path or when no
		cached file exists to read from.
		"""
		if my_path not in index:
			self.send_response(502)
			self.end_headers()
			return
		stored = index[my_path]
		print(stored)

		if not os.access(file_name, os.R_OK) and os.access(temp_name, os.R_OK) and '_parts' in stored and stored['_parts'] == spacemap.SpaceMap():
			# Download complete: move it into place.  os.renames also prunes
			# directories of the old path that the move leaves empty.
			print('Moving temporary file to new destination.')
			os.renames(temp_name, file_name)

		if self.command == 'HEAD':
			self.send_response(200)
			if 'content-length' in stored:
				self.send_header('content-length', stored['content-length'])
			self.send_header('accept-ranges', 'bytes')
			self.send_header('content-type', 'application/octet-stream')
			if 'last-modified' in stored:
				self.send_header('last-modified', stored['last-modified'])
			self.end_headers()
			return

		if ('_parts' in stored and stored['_parts'] != spacemap.SpaceMap()) or not os.access(file_name, os.R_OK):
			file_name = temp_name
		if not os.access(file_name, os.R_OK):
			self.send_response(502)
			self.end_headers()
			return

		with open(file_name, 'rb') as real_file:
			file_stat = os.fstat(real_file.fileno())
			if 'range' in self.headers:
				self.send_response(206)
				ranges = []
				requested_ranges.rewind()
				while True:
					pair = requested_ranges.pop()
					if pair[0] is None:
						break
					ranges.append('{}-{}'.format(pair[0], pair[1] - 1))
				# '*' stands for an unknown complete length
				self.send_header('content-range', 'bytes {}/{}'.format(','.join(ranges), stored.get('content-length', '*')))
			else:
				self.send_response(200)
				self.send_header('content-length', str(file_stat.st_size))
				requested_ranges = spacemap.SpaceMap({0: file_stat.st_size})
			if 'last-modified' in stored:
				self.send_header('last-modified', stored['last-modified'])
			self.send_header('content-type', 'application/octet-stream')
			self.end_headers()
			if self.command == 'GET':
				if len(requested_ranges) > 0:
					requested_ranges.rewind()
					(start, end) = requested_ranges.pop()
				else:
					start = 0
					# XXX ugly hack
					end = stored.get('content-length', 0)
				real_file.seek(start)
				req_block_size = min(block_size, end - start)
				buffer = real_file.read(req_block_size)
				while buffer:
					self.wfile.write(buffer)
					start += len(buffer)
					if req_block_size > end - start:
						req_block_size = end - start
					if req_block_size == 0:
						break
					buffer = real_file.read(req_block_size)

	def do_HEAD(self):
		"""Serve a HEAD request through the cache."""
		return self.__process()

	def do_GET(self):
		"""Serve a GET request through the cache."""
		return self.__process()

# Listen on the loopback interface only, on the port from the [general]
# section (8008 unless configured), and handle requests until killed.
config.section('general')
server = BaseHTTPServer.HTTPServer(
	('127.0.0.1', int(config['port'])),
	MyRequestHandler,
)
server.serve_forever()