Now using the urlparse module instead of ad-hoc URL parsing
author    Jean-Michel Nirgal Vourgère <jmv@nirgal.com>
          Wed, 6 Jan 2010 17:05:52 +0000 (17:05 +0000)
committer Jean-Michel Nirgal Vourgère <jmv@nirgal.com>
          Wed, 6 Jan 2010 17:05:52 +0000 (17:05 +0000)
sproxy
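
The request URL is now handled with the urlparse module from the Python
standard library instead of manual string splitting. A minimal sketch of the
calls the patch relies on, using a made-up URL purely for illustration:

    import urlparse

    parsed = urlparse.urlparse('http://www.example.com:8080/dir/page?q=1', scheme='http')
    print parsed.scheme     # 'http'
    print parsed.netloc     # 'www.example.com:8080'
    print parsed.hostname   # 'www.example.com'
    print parsed.port       # 8080
    # origin-form request target, as sent upstream when abs_path=True:
    print urlparse.urlunparse(['', ''] + list(parsed[2:]))   # '/dir/page?q=1'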

diff --git a/sproxy b/sproxy
index cd3f852a58821b90bf69f9f7f59b42df6d4e1d4e..ab96976baf3c8acce1805bf9a59ccc3d280f5477 100755 (executable)
--- a/sproxy
+++ b/sproxy
@@ -12,6 +12,7 @@ import threading
 from gzip import GzipFile
 from StringIO import StringIO
 from OpenSSL import SSL
+import urlparse
 import base64
 
 
@@ -30,7 +31,7 @@ def format_ip_port(ip, port, *spam):
         return '%s:%s' % (ip, port)
 
 class HttpBase:
-    "Base class for both query & responses"
+    "Base class for both requests & responses"
 
     # Derived class can override that value:
     content_length_default = 0
@@ -44,15 +45,30 @@ class HttpBase:
     def set_line1(self, line1):
         self.line1 = line1
     
-    def get_line1(self):
+    def get_line1(self, *args, **kargs):
         return self.line1
     
     def add_header_line(self, line):
+        # RFC 2616 section 2.2:
+        #  Field names are case insensitive
+        # RFC 2616 section 4.2:
+        #  Multiple message-header fields with the same field-name MAY be
+        #  present in a message if and only if the entire field-value for that
+        #  header field is defined as a comma-separated list [i.e., #(values)].
+        #  It MUST be possible to combine the multiple header fields into one
+        #  "field-name: field-value" pair, without changing the semantics of the
+        #  message, by appending each subsequent field-value to the first, each
+        #  separated by a comma. 
+        # A proxy MUST NOT change the order of fields.
         sp = line.split(':', 1)
         #logging.debug(repr(sp))
         if len(sp)==2:
-            self.headers.append((sp[0].strip(), sp[1].strip()))
+            self.headers.append((sp[0].strip(), sp[1].strip(' \t')))
         else:
+            # FIXME headers can be on multiple lines
+            # See RFC 2616 section 2.2:
+            # HTTP/1.1 header field values can be folded onto multiple lines if
+            # the continuation line begins with a space or horizontal tab.
             self.headers.append((line,))
 
     def get_header_value(self, key):
@@ -61,9 +77,9 @@ class HttpBase:
                 return header[1]
         return None
     
-    def all_headers(self):
+    def all_headers(self, *args, **kargs):
         result = ''
-        line1 = self.get_line1()
+        line1 = self.get_line1(*args, **kargs)
         if line1 is not None:
             result += line1+'\r\n'
         for header in self.headers:
@@ -161,21 +177,37 @@ class HttpBase:
                 if line is None:
                     break # no more token, continue recv
                 if line == '':
-                    self.headers_complete = True
-                elif self.line1 == None:
+                    if self.line1 is not None:
+                        self.headers_complete = True
+                    # else
+                    # See RFC 2616 section 4.1:
+                    # If the server is reading the protocol stream at the beginning of a
+                    #  message and receives a CRLF first, it should ignore the CRLF
+                elif self.line1 is None:
                     self.set_line1(line)
                 else:
                     self.add_header_line(line)
             if self.is_data_complete():
                 return
     
-    def send_to(self, sock):
-        sock.send(self.all_headers())
-        if self.data != '':
-            sock.send(self.data)
+    def send_to(self, sock, abs_path=False):
+        """
+        Send this HTTP message to an open socket.
+        If abs_path is true, the request line carries only the path part
+        (scheme and hostname removed); otherwise it carries the full absolute URL.
+        """
+        try:
+            sock.send(self.all_headers(abs_path=abs_path))
+            if self.data != '':
+                sock.send(self.data)
+        except socket.error, err:
+            logging.error('Error during sock.send: %s', err.args[1])
+            # do nothing
 
     def debug_dump_line1(self):
-        logging.debug(self.get_line1())
+        line1 = self.get_line1()
+        if line1 is not None:
+            logging.debug(line1)
 
     def debug_dump_headers(self):
         for header in self.headers:
@@ -206,27 +238,69 @@ class HttpBase:
         else:
             self.debug_dump_line1()
 
-class HttpQuery(HttpBase):
+class HttpRequest(HttpBase):
     content_length_default = 0 # default is no data for queries
 
     def __init__(self):
         HttpBase.__init__(self)
         self.http_method = ''
-        self.url = '/'
         self.http_version = 'HTTP/1.1'
-        self.scheme = 'http'
+        self.parsed_url = None
 
     def set_line1(self, line1):
         self.line1 = line1
         splits = line1.split(' ')
         if len(splits) == 3:
-            self.http_method, self.url, self.http_version = splits
+            self.http_method, url, self.http_version = splits
+            self.parsed_url = urlparse.urlparse(url, scheme='http')
+        else:
+            logging.error("Can't parse http request line %s", line1)
 
-    def get_line1(self):
-        if self.http_method is not None:
-            return self.http_method+' '+self.url+' '+self.http_version
+    def get_line1(self, abs_path=False, *args, **kargs):
+        if self.http_method:
+            if abs_path:
+                url = urlparse.urlunparse(['', ''] + list(self.parsed_url[2:]))
+            else:
+                url = self.parsed_url.geturl()
+            return self.http_method + ' ' + url + ' ' + self.http_version
         return self.line1
 
+    def clean_host_request(self):
+        def split_it(host_port):
+            # 'www.google.com:80' -> 'www.google.com', '80'
+            # 'www.google.com' -> 'www.google.com', None
+            sp = host_port.split(':', 1)
+            if len(sp) == 2:
+                return sp
+            else:
+                return sp[0], None
+        def join_it(host, port):
+            result = host
+            if port:
+                result += ':' + str(port)
+            return result
+
+        request_hostname = self.parsed_url.hostname
+        request_port = self.parsed_url.port
+        request_netloc = join_it(request_hostname, request_port)
+        header_netloc = self.get_header_value('Host')
+        if header_netloc:
+            header_hostname, header_port = split_it(header_netloc)
+        else:
+            header_hostname, header_port = None, None
+        if not request_hostname and header_hostname:
+            # copy "Host" header into request netloc
+            # (rebuild through urlunparse/urlparse: parse results are immutable)
+            self.parsed_url = urlparse.urlparse(urlparse.urlunparse(
+                [self.parsed_url.scheme, header_netloc] + list(self.parsed_url[2:])))
+        elif request_hostname:
+            if not header_netloc:
+                # no "Host" header at all: add one from the request netloc
+                self.headers.append(('Host', request_netloc))
+            elif request_netloc != header_netloc:
+                # RFC 2616, section 5.2: when the Request-URI is absolute,
+                # any "Host" header value MUST be ignored
+                logging.warning('Overriding "Host" header value %s with request netloc %s', header_netloc, request_netloc)
+                # patch the header in place (headers are stored as tuples)
+                for i in range(len(self.headers)):
+                    if self.headers[i][0].lower() == 'host':
+                        self.headers[i] = (self.headers[i][0], request_netloc)
+
     def clean_hop_headers(self):
         #remove any Proxy-* header
         i = 0
@@ -275,11 +349,11 @@ class HttpErrorResponse(HttpResponse):
         self.add_header_line('Date: %s' % ctime())
         self.data = errmsg or errtitle
 
-def run_query(query):
-    host = query.get_header_value('Host')
-    sp = host.split(':', 1)
-    if len(sp)==2:
-        host, port = sp
+
+def run_request_http(request):
+    host = request.parsed_url.hostname
+    port = request.parsed_url.port
+    if port:
         port = int(port) # TODO except
     else:
         port = 80
@@ -320,7 +394,7 @@ def run_query(query):
     if sock is None:
         return HttpErrorResponse(404, 'Not found', 'Can\'t connect to %s.' % host)
 
-    query.send_to(sock)
+    request.send_to(sock, abs_path=True)
     response = HttpResponse()
     response.recv_from(sock)
     response.decompress_data()
@@ -333,11 +407,11 @@ class ProxyConnectionIn(threading.Thread):
         self.clientsocket = clientsocket
 
     def run(self):
-        query_in = HttpQuery()
-        query_in.recv_from(self.clientsocket)
+        request_in = HttpRequest()
+        request_in.recv_from(self.clientsocket)
         if options.auth:
             # RFC2617
-            proxy_auth = query_in.get_header_value('Proxy-Authorization')
+            proxy_auth = request_in.get_header_value('Proxy-Authorization')
             if proxy_auth is not None and proxy_auth.startswith('Basic '):
                 proxy_auth = proxy_auth[len('Basic '):]
                 # logging.debug('proxy_auth raw: %s', proxy_auth)
@@ -349,57 +423,48 @@ class ProxyConnectionIn(threading.Thread):
                     'Proxy requires an authorization.')
                 response.add_header_line('Proxy-Authenticate: Basic realm="Spy proxy"')
                 if options.debug_raw_messages:
-                    query_in.debug_dump('QUERY RAW')
+                    request_in.debug_dump('REQUEST RAW')
                     response.debug_dump('RESPONSE RAW')
                 response.send_to(self.clientsocket)
                 self.clientsocket.shutdown(socket.SHUT_RDWR)
                 self.clientsocket.close()
                 return
-        if query_in.http_method in ('OPTIONS', 'GET', 'HEAD', 'POST', 'PUT', 'DELETE', 'TRACE'):
+
+        if request_in.http_method in ('OPTIONS', 'GET', 'HEAD', 'POST', 'PUT', 'DELETE', 'TRACE'):
             if options.debug_raw_messages:
-                query_in.debug_dump('QUERY RAW')
-            sp = query_in.url.split('://', 1)
-            if len(sp) == 2:
-                scheme, query_address = sp
-            else:
-                logging.debug('Can\'t find scheme in url %s. Assuming http', query_in.url)
-                scheme, query_address = 'http', sp[0]
+                request_in.debug_dump('REQUEST RAW')
 
-            if scheme == 'http':
-                sp = query_address.split('/', 1)
-                if len(sp)!=2:
-                    self.clientsocket.send('HTTP/1.0 400 Bad Request\r\n\r\nCan\'t parse url.\r\n')
+            if request_in.parsed_url and request_in.parsed_url.scheme == 'http':
+                if not request_in.parsed_url.path:
+                    response = HttpErrorResponse(400, 'Bad Request', "URL is missing a path.")
+                    response.send_to(self.clientsocket)
                     self.clientsocket.shutdown(socket.SHUT_RDWR)
                     self.clientsocket.close()
                     return
-                host = sp[0]
-                query_in.url = '/'+sp[1]
-                query_host = query_in.get_header_value('Host')
-                #  logging.debug('host=%s', host)
-                #  logging.debug('query_host=%s', query_host)
-                if query_host is None:
-                    query_in.headers.append('Host', host)
-                elif query_host != host:
-                    logging.warning('Ignoring host value %s in query. Header value is %s', host, query_host)
-
-                query_in.clean_hop_headers()
-                query_in.headers.append(('Connection', 'close'))
+                request_in.clean_host_request()
+                request_in.clean_hop_headers()
+                request_in.headers.append(('Connection', 'close'))
 
-                query_in.debug_dump('QUERY')
+                request_in.debug_dump('REQUEST')
 
-                response = run_query(query_in)
+                response = run_request_http(request_in)
 
                 response.debug_dump('RESPONSE')
 
             else:
-                response = HttpErrorResponse(501, 'Not implemented', 'Unsupported scheme %s.' % scheme)
+                # parsed_url can be None when the request line could not be parsed
+                scheme = request_in.parsed_url.scheme if request_in.parsed_url else None
+                response = HttpErrorResponse(501, 'Not implemented', 'Unsupported scheme %s.' % scheme)
+
+            url = request_in.parsed_url.geturl() if request_in.parsed_url else request_in.line1
+            logging.info("%s %s %s %s %s", request_in.http_method, url, request_in.http_version, response.line1[9:12], len(response.data) or '-')
 
             response.send_to(self.clientsocket)
-            self.clientsocket.shutdown(socket.SHUT_RDWR)
+            try:
+                self.clientsocket.shutdown(socket.SHUT_RDWR)
+            except socket.error, err:
+                logging.error('Error during socket.shutdown: %s', err.args[1])
             self.clientsocket.close()
 
-        elif query_in.http_method == 'CONNECT':
-            query_in.debug_dump('QUERY CONNECT')
+        elif request_in.http_method == 'CONNECT':
+            request_in.debug_dump('REQUEST CONNECT')
             logging.warning('Method CONNECT in development')
             self.clientsocket.send('HTTP/1.1 200 Proceed\r\n\r\n\r\n')
 
@@ -409,18 +474,18 @@ class ProxyConnectionIn(threading.Thread):
             ssl_sock = SSL.Connection(ssl_context, self.clientsocket)
             ssl_sock.set_accept_state()
 
-            query_in_ssl = HttpQuery()
-            query_in_ssl.recv_from(ssl_sock)
+            request_in_ssl = HttpRequest()
+            request_in_ssl.recv_from(ssl_sock)
 
             if options.debug_raw_messages:
-                query_in_ssl.debug_dump('QUERY SSL RAW')
+                request_in_ssl.debug_dump('REQUEST SSL RAW')
 
-            query_in_ssl.clean_hop_headers()
-            query_in_ssl.headers.append(('Connection', 'close'))
+            request_in_ssl.clean_hop_headers()
+            request_in_ssl.headers.append(('Connection', 'close'))
 
-            query_in_ssl.debug_dump('QUERY')
+            request_in_ssl.debug_dump('REQUEST')
 
-            response = run_query(query_in_ssl)
+            response = run_request_http(request_in_ssl)
 
             response.debug_dump('RESPONSE')
 
@@ -430,9 +495,10 @@ class ProxyConnectionIn(threading.Thread):
             #self.clientsocket.shutdown(socket.SHUT_RDWR)
             #self.clientsocket.close()
         else:
-            query_in.debug_dump('QUERY')
-            logging.error('Method %s not supported', query_in.http_method)
-            self.clientsocket.send('HTTP/1.1 405 Method not allowed\r\n\r\nSorry method %s is not supported by the proxy.\r\n' % query_in.http_method)
+            request_in.debug_dump('REQUEST')
+            logging.error('Method %s not supported', request_in.http_method)
+            # FIXME RFC 2616, section 14.7: We should return an "Allow" header
+            self.clientsocket.send('HTTP/1.1 405 Method not allowed\r\n\r\nSorry method %s is not supported by the proxy.\r\n' % request_in.http_method)
             self.clientsocket.close()
 
 
@@ -465,7 +531,7 @@ def main():
     serversocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
     logging.info('Listening on %s', format_ip_port(*sockaddr))
     serversocket.bind(sockaddr)
-    serversocket.listen(3)
+    serversocket.listen(30)
 
     while True:
         try: