Spy Proxy, alpha version
author Jean-Michel Nirgal Vourgère <jmv@nirgal.com>
Sun, 3 Jan 2010 17:12:55 +0000 (17:12 +0000)
committer Jean-Michel Nirgal Vourgère <jmv@nirgal.com>
Sun, 3 Jan 2010 17:12:55 +0000 (17:12 +0000)
sproxy [new file with mode: 0755]

diff --git a/sproxy b/sproxy
new file mode 100755 (executable)
index 0000000..cd3f852
--- /dev/null
+++ b/sproxy
@@ -0,0 +1,525 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+# To generate a certificate:
+# openssl req -nodes -new -x509 -keyout proxy.key -out proxy.crt -days 10000
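+#
+# Example invocation (illustrative values; the options are defined at the
+# bottom of this file):
+#   ./sproxy --bind 127.0.0.1 --port 8080 --auth login:password --debug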
+
+import sys
+import logging
+from time import ctime
+import socket
+import threading
+from gzip import GzipFile
+from StringIO import StringIO
+from OpenSSL import SSL
+import base64
+
+
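+# IPv4 clients connecting to an AF_INET6 listening socket show up with
+# IPv4-mapped addresses such as '::ffff:192.0.2.1'; format_ip_port() strips
+# this prefix for readability.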
+IPV4_IN_IPV6_PREFIX = '::ffff:'
+
+def format_ip_port(ip, port, *spam):
+    "Build a nice printable string for a given sockaddr"
+
+    if ip.startswith(IPV4_IN_IPV6_PREFIX):
+        ip = ip[len(IPV4_IN_IPV6_PREFIX):]
+    if ip == '':
+        ip = '*'
+    if ':' in ip:
+        return '[%s]:%s' % (ip, port)
+    else:
+        return '%s:%s' % (ip, port)
+
+class HttpBase:
+    "Base class for both query & responses"
+
+    # Derived classes can override this value:
+    content_length_default = 0
+
+    def __init__(self):
+        self.line1 = None
+        self.headers = []
+        self.data = ''
+        self.headers_complete = False
+    
+    def set_line1(self, line1):
+        self.line1 = line1
+    
+    def get_line1(self):
+        return self.line1
+    
+    def add_header_line(self, line):
+        sp = line.split(':', 1)
+        #logging.debug(repr(sp))
+        if len(sp)==2:
+            self.headers.append((sp[0].strip(), sp[1].strip()))
+        else:
+            self.headers.append((line,))
+
+    def get_header_value(self, key):
+        for header in self.headers:
+            if header[0] == key:
+                return header[1]
+        return None
+    
+    def all_headers(self):
+        result = ''
+        line1 = self.get_line1()
+        if line1 is not None:
+            result += line1+'\r\n'
+        for header in self.headers:
+            result += ': '.join(header)+'\r\n'
+        result += '\r\n'
+        return result
+
+    def is_data_complete(self):
+        if not self.headers_complete:
+            return False
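+        # Chunked transfer coding (RFC 2616 section 3.6.1): the body arrives as
+        # a series of "<hex size>CRLF<data>CRLF" chunks ended by a zero-size
+        # chunk, e.g. "4\r\nWiki\r\n5\r\npedia\r\n0\r\n\r\n" carries "Wikipedia".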
+        if self.get_header_value('Transfer-Encoding')=='chunked':
+            if not hasattr(self, 'chunk_size'):
+                self.chunk_size = 0
+                self.chunk_data = ''
+            while True:
+                if self.chunk_size == 0:
+                    hex_chunk_size = self.pop_data_line()
+                    if options.debug_length:
+                        logging.debug('hex chunk_size=%s', hex_chunk_size) # TODO extensions
+                    if hex_chunk_size is None:
+                        return False # need more data
+                    self.chunk_size = int(hex_chunk_size, 16) # CRLF
+                    if options.debug_length:
+                        logging.debug('chunk_size=%s', self.chunk_size)
+                    if self.chunk_size == 0:
+                        logging.warning('chunk-transfer trailer? :%s', repr(self.data))
+                        self.data = self.chunk_data # TODO trailers
+                        # remove any Transfer-Encoding: chunked header
+                        # update Content-Length
+                        content_length_updated = False
+                        i = 0
+                        while i < len(self.headers):
+                            key = self.headers[i][0].lower()
+                            if key == 'transfer-encoding':
+                                del self.headers[i]
+                            else:
+                                if key == 'content-length':
+                                    self.headers[i] = ('Content-Length', str(len(self.data)))
+                                    content_length_updated = True
+                                i += 1
+                        if not content_length_updated:
+                            self.headers.append(('Content-Length', str(len(self.data))))
+                        #self.headers_complete = False
+                        break # we're done with chunking
+                    else:
+                        self.chunk_size += 2 # include the CRLF that terminates the chunk data
+                l = len(self.data)
+                if self.chunk_size <= l:
+                    l = self.chunk_size
+                    need_more_data = False
+                else:
+                    need_more_data = True
+                self.chunk_data += self.data[:l]
+                self.data = self.data[l:]
+                self.chunk_size -= l
+                if need_more_data:
+                    return False
+                else:
+                    self.chunk_data = self.chunk_data[:-2] # CRLF
+
+        l = self.get_header_value('Content-Length')
+        if l is not None:
+            l = int(l) # TODO except ValueError
+        else:
+            l = self.content_length_default
+        if options.debug_length:
+            logging.debug('Expected length=%s', l)
+            logging.debug('Current length=%s', len(self.data))
+        return len(self.data) >= l
+
+    def pop_data_line(self):
+        """
+        Extract a line separated by CRLF from the data buffer
+        """
+        p = self.data.find('\r\n')
+        if p == -1:
+            return None
+        line = self.data[:p]
+        self.data = self.data[p+2:]
+        return line
+
+    def recv_from(self, sock):
+        self.data = '' # unparsed data
+        while True:
+            try:
+                new_raw_data = sock.recv(1500) # usual IP MTU, for speed
+            except (SSL.Error, socket.error), err:
+                logging.debug('Error during recv_from: %s', err)
+                return # connection failure
+            if not new_raw_data:
+                return # connection was closed
+            self.data += new_raw_data
+            while not self.headers_complete:
+                line = self.pop_data_line()
+                if line is None:
+                    break # no complete line yet, keep receiving
+                if line == '':
+                    self.headers_complete = True
+                elif self.line1 is None:
+                    self.set_line1(line)
+                else:
+                    self.add_header_line(line)
+            if self.is_data_complete():
+                return
+    
+    def send_to(self, sock):
+        sock.sendall(self.all_headers()) # sendall: plain send() may transmit only part of the buffer
+        if self.data != '':
+            sock.sendall(self.data)
+
+    def debug_dump_line1(self):
+        logging.debug(self.get_line1())
+
+    def debug_dump_headers(self):
+        for header in self.headers:
+            if len(header)==2:
+                logging.debug('%s: %s', repr(header[0]), repr(header[1]))
+            else:
+                logging.debug('%s (NO VALUE)', repr(header[0]))
+
+    def debug_dump_data(self):
+        DUMP_DATA_MAX_LEN = 160
+        if self.data:
+            data_length = len(self.data)
+            truncate = data_length > DUMP_DATA_MAX_LEN
+            if truncate:
+                printed_data = repr(self.data[:DUMP_DATA_MAX_LEN])+'...'
+            else:
+                printed_data = repr(self.data)
+
+            logging.debug('data: (%s bytes) %s', data_length, printed_data)
+
+    def debug_dump(self, title='DUMP'):
+        if options.log_full_transfers:
+            l = len(title)
+            logging.debug(title+' '+('-'*(80-l-1)))
+            self.debug_dump_line1()
+            self.debug_dump_headers()
+            self.debug_dump_data()
+            logging.debug('-'*80)
+        else:
+            self.debug_dump_line1()
+
+class HttpQuery(HttpBase):
+    content_length_default = 0 # default is no data for queries
+
+    def __init__(self):
+        HttpBase.__init__(self)
+        self.http_method = ''
+        self.url = '/'
+        self.http_version = 'HTTP/1.1'
+        self.scheme = 'http'
+
+    def set_line1(self, line1):
+        self.line1 = line1
+        splits = line1.split(' ')
+        if len(splits) == 3:
+            self.http_method, self.url, self.http_version = splits
+
+    def get_line1(self):
+        if self.http_method: # set_line1() successfully parsed the request line
+            return self.http_method+' '+self.url+' '+self.http_version
+        return self.line1
+
+    def clean_hop_headers(self):
+        # remove any Proxy-* header as well as the hop-by-hop headers of RFC 2616 section 13.5.1
+        i = 0
+        while i < len(self.headers):
+            key = self.headers[i][0].lower()
+            if key.startswith('proxy-') or key in ('connection', 'keep-alive', 'te', 'trailers', 'transfer-encoding', 'upgrade'):
+                del self.headers[i]
+            else:
+                i += 1
+    
+
+class HttpResponse(HttpBase):
+    # for responses, the default is to read data until the connection is closed
+    content_length_default = sys.maxint
+    
+    def decompress_data(self):
+        compression_scheme = self.get_header_value('Content-Encoding')
+        if compression_scheme == 'gzip':
+            gzf = GzipFile(fileobj=StringIO(self.data))
+            try:
+                plain_data = gzf.read()
+            except IOError, err:
+                logging.error('Error while decompressing gzip data: %s', err)
+                self.debug_dump('RESPONSE BEFORE DECOMPRESSION')
+                #raise
+            else:
+                # remove any Content-Encoding header
+                # update Content-Length
+                i = 0
+                while i < len(self.headers):
+                    key = self.headers[i][0].lower()
+                    if key == 'content-encoding':
+                        del self.headers[i]
+                    else:
+                        if key == 'content-length':
+                            self.headers[i] = ('Content-Length', str(len(plain_data)))
+                        i += 1
+                self.data = plain_data
+
+
+class HttpErrorResponse(HttpResponse):
+    def __init__(self, errcode, errtitle, errmsg=None):
+        HttpResponse.__init__(self)
+        self.set_line1('HTTP/1.1 %s %s' % (errcode, errtitle))
+        self.add_header_line('Server: Spy proxy')
+        self.add_header_line('Date: %s' % ctime())
+        self.data = errmsg or errtitle
+
+def run_query(query):
+    host = query.get_header_value('Host')
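+    # the Host header may carry an explicit port, e.g. 'example.com:8080'
+    # (IPv6 literals such as '[::1]:8080' are not handled here)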
+    sp = host.split(':', 1)
+    if len(sp)==2:
+        host, port = sp
+        port = int(port) # TODO except
+    else:
+        port = 80
+    
+    try:
+        addrinfo = socket.getaddrinfo(host, port, socket.AF_UNSPEC, socket.SOCK_STREAM)
+    except socket.gaierror, err:
+        if err.args[0] != -2: # -2 = Name or service not known
+            raise
+
+        logging.debug('Resolution of %s failed: %s', format_ip_port(host, port), err.args[1])
+        return HttpErrorResponse(404, 'Not found', 'Can\'t resolve %s.' % host)
+
+    for family, socktype, proto, canonname, sockaddr in addrinfo:
+        if options.debug_connections:
+            logging.debug('Connecting to %s ...', format_ip_port(*sockaddr))
+        sock = socket.socket(family, socktype, proto)
+        try:
+            sock.connect(sockaddr)
+        except socket.gaierror, err:
+            if err.args[0] != -2:
+                raise
+            if options.debug_connections:
+                logging.debug('Connection to %s failed: %s', format_ip_port(*sockaddr), err.args[1])
+            sock = None
+        except socket.error, err:
+            if err.args[0] not in (111, 113): # Connection refused, No route to host
+                raise
+            if options.debug_connections:
+                logging.debug('Connection to %s failed: %s', format_ip_port(*sockaddr), err.args[1])
+            sock = None
+        else:
+            # Connection successful
+            if options.debug_connections:
+                logging.debug('Connected to %s', format_ip_port(*sockaddr))
+            break
+    
+    if sock is None:
+        return HttpErrorResponse(404, 'Not found', 'Can\'t connect to %s.' % host)
+
+    query.send_to(sock)
+    response = HttpResponse()
+    response.recv_from(sock)
+    response.decompress_data()
+    return response
+
+
+class ProxyConnectionIn(threading.Thread):
+    def __init__(self, clientsocket):
+        threading.Thread.__init__(self)
+        self.clientsocket = clientsocket
+
+    def run(self):
+        query_in = HttpQuery()
+        query_in.recv_from(self.clientsocket)
+        if options.auth:
+            # RFC2617
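+            # The client is expected to send something like
+            #   Proxy-Authorization: Basic bG9naW46cGFzc3dvcmQ=
+            # where the base64 part decodes to 'login:password' and is compared
+            # against the --auth LOGIN:PASSWORD option.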
+            proxy_auth = query_in.get_header_value('Proxy-Authorization')
+            if proxy_auth is not None and proxy_auth.startswith('Basic '):
+                proxy_auth = proxy_auth[len('Basic '):]
+                # logging.debug('proxy_auth raw: %s', proxy_auth)
+                proxy_auth = base64.b64decode(proxy_auth)
+                #logging.debug('proxy_auth: %s', proxy_auth)
+            if proxy_auth != options.auth:
+                response = HttpErrorResponse(407,
+                    'Proxy Authentication Required',
+                    'Proxy requires an authorization.')
+                response.add_header_line('Proxy-Authenticate: Basic realm="Spy proxy"')
+                if options.debug_raw_messages:
+                    query_in.debug_dump('QUERY RAW')
+                    response.debug_dump('RESPONSE RAW')
+                response.send_to(self.clientsocket)
+                self.clientsocket.shutdown(socket.SHUT_RDWR)
+                self.clientsocket.close()
+                return
+        if query_in.http_method in ('OPTIONS', 'GET', 'HEAD', 'POST', 'PUT', 'DELETE', 'TRACE'):
+            if options.debug_raw_messages:
+                query_in.debug_dump('QUERY RAW')
+            sp = query_in.url.split('://', 1)
+            if len(sp) == 2:
+                scheme, query_address = sp
+            else:
+                logging.debug('Can\'t find scheme in url %s. Assuming http', query_in.url)
+                scheme, query_address = 'http', sp[0]
+
+            if scheme == 'http':
+                sp = query_address.split('/', 1)
+                if len(sp)!=2:
+                    self.clientsocket.send('HTTP/1.0 400 Bad Request\r\n\r\nCan\'t parse url.\r\n')
+                    self.clientsocket.shutdown(socket.SHUT_RDWR)
+                    self.clientsocket.close()
+                    return
+                host = sp[0]
+                query_in.url = '/'+sp[1]
+                query_host = query_in.get_header_value('Host')
+                #  logging.debug('host=%s', host)
+                #  logging.debug('query_host=%s', query_host)
+                if query_host is None:
+                    query_in.headers.append(('Host', host))
+                elif query_host != host:
+                    logging.warning('Ignoring host value %s in query. Header value is %s', host, query_host)
+
+                query_in.clean_hop_headers()
+                query_in.headers.append(('Connection', 'close'))
+
+                query_in.debug_dump('QUERY')
+
+                response = run_query(query_in)
+
+                response.debug_dump('RESPONSE')
+
+            else:
+                response = HttpErrorResponse(501, 'Not implemented', 'Unsupported scheme %s.' % scheme)
+
+            response.send_to(self.clientsocket)
+            self.clientsocket.shutdown(socket.SHUT_RDWR)
+            self.clientsocket.close()
+
+        elif query_in.http_method == 'CONNECT':
+            query_in.debug_dump('QUERY CONNECT')
+            logging.warning('Method CONNECT in development')
+            self.clientsocket.send('HTTP/1.1 200 Connection established\r\n\r\n')
+
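+            # Rather than tunnelling the encrypted stream, the proxy terminates
+            # TLS itself with proxy.key/proxy.crt (see the openssl command at
+            # the top of this file), so the client must trust that certificate.
+            # The decrypted query is then replayed through run_query(), which at
+            # this stage opens a plain connection to the host of the Host header.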
+            ssl_context = SSL.Context(SSL.SSLv23_METHOD)
+            ssl_context.use_privatekey_file('proxy.key')
+            ssl_context.use_certificate_file('proxy.crt')
+            ssl_sock = SSL.Connection(ssl_context, self.clientsocket)
+            ssl_sock.set_accept_state()
+
+            query_in_ssl = HttpQuery()
+            query_in_ssl.recv_from(ssl_sock)
+
+            if options.debug_raw_messages:
+                query_in_ssl.debug_dump('QUERY SSL RAW')
+
+            query_in_ssl.clean_hop_headers()
+            query_in_ssl.headers.append(('Connection', 'close'))
+
+            query_in_ssl.debug_dump('QUERY')
+
+            response = run_query(query_in_ssl)
+
+            response.debug_dump('RESPONSE')
+
+            response.send_to(ssl_sock)
+            ssl_sock.shutdown()
+            ssl_sock.close()
+            #self.clientsocket.shutdown(socket.SHUT_RDWR)
+            #self.clientsocket.close()
+        else:
+            query_in.debug_dump('QUERY')
+            logging.error('Method %s not supported', query_in.http_method)
+            self.clientsocket.send('HTTP/1.1 405 Method Not Allowed\r\n\r\nSorry, method %s is not supported by the proxy.\r\n' % query_in.http_method)
+            self.clientsocket.close()
+
+
+def main():
+    if options.debug:
+        loglevel = logging.DEBUG
+    else:
+        loglevel = logging.INFO
+    logging.basicConfig(level=loglevel, format='%(asctime)s %(levelname)s %(message)s')
+
+    if options.listen_host:
+        addrinfo = socket.getaddrinfo(options.listen_host, options.listen_port, socket.AF_UNSPEC, socket.SOCK_STREAM)
+        family, socktype, proto, canonname, sockaddr = addrinfo[0]
+    else:
+        if socket.has_ipv6:
+            family = socket.AF_INET6
+        else:
+            family = socket.AF_INET
+        socktype = socket.SOCK_STREAM
+        try:
+            options.listen_port = int(options.listen_port)
+        except ValueError:
+            options.listen_port = socket.getservbyname(options.listen_port)
+        sockaddr = (options.listen_host, options.listen_port)
+
+    # TODO: multicast require more stuff
+    # see http://code.activestate.com/recipes/442490/
+
+    serversocket = socket.socket(family, socktype) #, proto) # TODO
+    serversocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+    logging.info('Listening on %s', format_ip_port(*sockaddr))
+    serversocket.bind(sockaddr)
+    serversocket.listen(3)
+
+    while True:
+        try:
+            clientsocket, address = serversocket.accept()
+        except KeyboardInterrupt:
+            logging.info('Ctrl+C received. Shutting down.')
+            break
+
+        if options.debug_connections:
+            logging.debug('Connection from %s', format_ip_port(*address))
+        cnx_thread = ProxyConnectionIn(clientsocket)
+        cnx_thread.start() # one thread per client; calling run() directly would block the accept loop
+
+if __name__ == '__main__':
+    from optparse import OptionParser #, OptionGroup
+    
+    parser = OptionParser(usage='%prog [options]')
+    
+    parser.add_option('-b', '--bind',
+        action='store', type='str', dest='listen_host', default='',
+        metavar='HOST',
+        help="listen address, default='%default'")
+
+    parser.add_option('-p', '--port',
+        action='store', type='str', dest='listen_port', default='8080',
+        metavar='PORT',
+        help="listen port, default=%default")
+
+    parser.add_option('--auth',
+        action='store', type='str', dest='auth', default='',
+        metavar='LOGIN:PASSWORD',
+        help="proxy authentification, default='%default'")
+
+    parser.add_option('--log-full-transfers',
+        action='store_true', dest='log_full_transfers', default=False,
+        help="log full queries and responses")
+    
+    parser.add_option('-d', '--debug',
+        action='store_true', dest='debug', default=False,
+        help="debug mode")
+    
+    parser.add_option('--debug-raw-messages',
+        action='store_true', dest='debug_raw_messages', default=False,
+        help="dump raw messages before they are patched")
+   
+    parser.add_option('--debug-connections',
+        action='store_true', dest='debug_connections', default=False,
+        help="dump connections information")
+    
+    parser.add_option('--debug-length',
+        action='store_true', dest='debug_length', default=False,
+        help="dump lengthes information")
+    
+
+    options, args = parser.parse_args()
+    main()
+