Https handling
authorJean-Michel Nirgal Vourgère <jmv@nirgal.com>
Fri, 8 Jan 2010 13:27:00 +0000 (13:27 +0000)
committerJean-Michel Nirgal Vourgère <jmv@nirgal.com>
Fri, 8 Jan 2010 13:27:00 +0000 (13:27 +0000)
sproxy

diff --git a/sproxy b/sproxy
index ab96976baf3c8acce1805bf9a59ccc3d280f5477..3580560602169a080e07b8065b92dea02b0ddea6 100755 (executable)
--- a/sproxy
+++ b/sproxy
@@ -3,6 +3,11 @@
 
 # To generate a certificate:
 # openssl req -nodes -new -x509 -keyout proxy.key -out proxy.crt -days 10000
+# openssl req -nodes -new -x509 -keyout proxy.key -out proxy.crt -days 10000 -subj "/O=Spy Proxy/CN=*" -newkey rsa:2048
+#
+# openssl req -nodes -new -subj "/CN=proxy" -days 10000 -keyout proxy.key -out proxy.csr
+# openssl ca -in proxy.csr -out proxy.crt -keyfile ca.key 
+
 
 import sys
 import logging
@@ -153,6 +158,7 @@ class HttpBase:
     def pop_data_line(self):
         """
         Extract a line separated by CRLF from the data buffer
+        Returns None if there is no CRLF left
         """
         p = self.data.find('\r\n')
         if p == -1:
@@ -167,7 +173,7 @@ class HttpBase:
             try:
                 new_raw_data = sock.recv(1500) # usual IP MTU, for speed
             except SSL.Error, err:
-                logging.debug('Error during recv_from: %s', err)
+                logging.debug('Error during sock.recv: %s', repr(err))
                 return # connection failure
             if not new_raw_data:
                 return # connection was closed
@@ -200,7 +206,7 @@ class HttpBase:
             sock.send(self.all_headers(abs_path=abs_path))
             if self.data != '':
                 sock.send(self.data)
-        except socket.error, err:
+        except (socket.error, SSL.SysCallError), err:
             logging.error('Error during sock.send: %s', err.args[1])
             # do nothing
 
@@ -238,8 +244,10 @@ class HttpBase:
         else:
             self.debug_dump_line1()
 
+
 class HttpRequest(HttpBase):
-    content_length_default = 0 # default is no data for queries
+    # default is no data for requests
+    content_length_default = 0
 
     def __init__(self):
         HttpBase.__init__(self)
@@ -247,12 +255,23 @@ class HttpRequest(HttpBase):
         self.http_version = 'HTTP/1.1'
         self.parsed_url = None
 
+    def recv_from(self, sock):
+        HttpBase.recv_from(self, sock)
+        if options.debug_raw_messages:
+            self.debug_dump('REQUEST RAW')
+    
+    def send_to(self, sock, *args, **kargs):
+        HttpBase.send_to(self, sock, *args, **kargs)
+        if options.debug_raw_messages:
+            self.debug_dump('REQUEST PATCHED')
+
+    
     def set_line1(self, line1):
         self.line1 = line1
         splits = line1.split(' ')
         if len(splits) == 3:
             self.http_method, url, self.http_version = splits
-            self.parsed_url = urlparse.urlparse(url, scheme='http')
+            self.parsed_url = urlparse.urlparse(url)
         else:
             logging.error("Can't parse http request line %s", line1)
 
@@ -287,9 +306,7 @@ class HttpRequest(HttpBase):
         header_hostname, header_port = split_it(header_necloc)
         if not request_hostname and header_hostname:
             # copy "Host" header into request netloc
-            self.parsed_url = urlparse.BaseResult(self.parsed_url.scheme,
-                header_hostname,
-                *self.parsed_url[2:])
+            self.parsed_url = urlparse.ParseResult(self.parsed_url.scheme, header_hostname, *self.parsed_url[2:])
         elif request_hostname:
             if request_netloc != header_necloc:
                 # RFC 2616, section 5.2: Host header must be ignored FIXME
@@ -301,8 +318,12 @@ class HttpRequest(HttpBase):
             elif not header_necloc:
                 self.headers.append(('Host', request_netloc))
 
+    def set_default_scheme(self, scheme):
+        if not self.parsed_url.scheme:
+            self.parsed_url = urlparse.ParseResult(scheme, *self.parsed_url[1:])
+
     def clean_hop_headers(self):
-        #remove any Proxy-* header
+        #remove any Proxy-* header, and hop by hop headers
         i = 0
         while i < len(self.headers):
             key = self.headers[i][0].lower()
@@ -311,11 +332,29 @@ class HttpRequest(HttpBase):
             else:
                 i += 1
     
+    def check_headers_valid(self):
+        if not self.http_method:
+            raise HttpErrorResponse(400, 'Bad Request', 'Http method is required')
+        if not self.parsed_url:
+            raise HttpErrorResponse(400, 'Bad Request', 'Http url is required')
+        if not self.http_version:
+            raise HttpErrorResponse(400, 'Bad Request', 'Http version is required')
 
 class HttpResponse(HttpBase):
     # for responses, default is data until connection closed :
     content_length_default = sys.maxint
     
+    def recv_from(self, sock):
+        HttpBase.recv_from(self, sock)
+        if options.debug_raw_messages:
+            self.debug_dump('RESPONSE RAW')
+    
+    def send_to(self, sock, *args, **kargs):
+        HttpBase.send_to(self, sock, *args, **kargs)
+        if options.debug_raw_messages:
+            self.debug_dump('RESPONSE PATCHED')
+
+
     def decompress_data(self):
         compression_scheme = self.get_header_value('Content-Encoding')
         if compression_scheme == 'gzip':
@@ -350,22 +389,20 @@ class HttpErrorResponse(HttpResponse):
         self.data = errmsg or errtitle
 
 
-def run_request_http(request):
-    host = request.parsed_url.hostname
-    port = request.parsed_url.port
-    if port:
-        port = int(port) # TODO except
-    else:
-        port = 80
-    
+def get_connected_sock(hostname, port):
     try:
-        addrinfo = socket.getaddrinfo(host, port, socket.AF_UNSPEC, socket.SOCK_STREAM)
+        port = int(port)
+    except ValueError:
+        HttpErrorResponse(500, 'Internal error', "Can't connect to port %s" % port)
+
+    try:
+        addrinfo = socket.getaddrinfo(hostname, port, socket.AF_UNSPEC, socket.SOCK_STREAM)
     except socket.gaierror, err:
         if err.args[1] == -2: # Name or service not known
             raise
 
-        logging.debug('Connection to %s failed: %s', format_ip_port(host, port), err.args[1])
-        return HttpErrorResponse(404, 'Not found', 'Can\'resolve %s.' % host)
+        logging.debug('Connection to %s failed: %s', format_ip_port(hostname, port), err.args[1])
+        raise HttpErrorResponse(404, 'Unknown host', 'Can\'resolve %s.' % hostname)
 
     for family, socktype, proto, canonname, sockaddr in addrinfo:
         if options.debug_connections:
@@ -392,11 +429,33 @@ def run_request_http(request):
             break
     
     if sock is None:
-        return HttpErrorResponse(404, 'Not found', 'Can\'t connect to %s.' % host)
+        raise HttpErrorResponse(404, 'Not found', 'Can\'t connect to %s.' % hostname)
+    return sock
+
 
+def run_request_http(request):
+    sock = get_connected_sock(request.parsed_url.hostname, request.parsed_url.port or 80)
     request.send_to(sock, abs_path=True)
     response = HttpResponse()
     response.recv_from(sock)
+    sock.shutdown(socket.SHUT_RDWR)
+    sock.close()
+    response.decompress_data()
+    return response
+
+
+def run_request_https(request):
+    sock = get_connected_sock(request.parsed_url.hostname, request.parsed_url.port or 443)
+    ssl_context = SSL.Context(SSL.SSLv23_METHOD)
+    #ssl_context.use_privatekey_file ('certs/proxy.key')
+    #ssl_context.use_certificate_file('certs/proxy.crt')
+    ssl_sock = SSL.Connection(ssl_context, sock)
+    ssl_sock.set_connect_state()
+    request.send_to(ssl_sock, abs_path=True)
+    response = HttpResponse()
+    response.recv_from(ssl_sock)
+    ssl_sock.shutdown()
+    ssl_sock.close()
     response.decompress_data()
     return response
 
@@ -406,100 +465,97 @@ class ProxyConnectionIn(threading.Thread):
         threading.Thread.__init__(self)
         self.clientsocket = clientsocket
 
+    def check_proxy_auth(self, request_in):
+        # RFC2617
+        if not options.auth:
+            return
+        proxy_auth = request_in.get_header_value('Proxy-Authorization')
+        if proxy_auth is not None and proxy_auth.startswith('Basic '):
+            proxy_auth = proxy_auth[len('Basic '):]
+            # logging.debug('proxy_auth raw: %s', proxy_auth)
+            proxy_auth = base64.b64decode(proxy_auth)
+            #logging.debug('proxy_auth: %s', proxy_auth)
+        if proxy_auth != options.auth:
+            response = HttpErrorResponse('407', 
+                'Proxy Authentication Required', 
+                'Proxy requires an authorization.')
+            response.add_header_line('Proxy-Authenticate: Basic realm="Spy proxy"')
+            raise response
+
     def run(self):
         request_in = HttpRequest()
         request_in.recv_from(self.clientsocket)
-        if options.auth:
-            # RFC2617
-            proxy_auth = request_in.get_header_value('Proxy-Authorization')
-            if proxy_auth is not None and proxy_auth.startswith('Basic '):
-                proxy_auth = proxy_auth[len('Basic '):]
-                # logging.debug('proxy_auth raw: %s', proxy_auth)
-                proxy_auth = base64.b64decode(proxy_auth)
-                #logging.debug('proxy_auth: %s', proxy_auth)
-            if proxy_auth != options.auth:
-                response = HttpErrorResponse('407', 
-                    'Proxy Authentication Required', 
-                    'Proxy requires an authorization.')
-                response.add_header_line('Proxy-Authenticate: Basic realm="Spy proxy"')
-                if options.debug_raw_messages:
-                    request_in.debug_dump('REQUEST RAW')
-                    response.debug_dump('RESPONSE RAW')
-                response.send_to(self.clientsocket)
-                self.clientsocket.shutdown(socket.SHUT_RDWR)
-                self.clientsocket.close()
-                return
+        response = None
 
-        if request_in.http_method in ('OPTIONS', 'GET', 'HEAD', 'POST', 'PUT', 'DELETE', 'TRACE'):
-            if options.debug_raw_messages:
-                request_in.debug_dump('REQUEST RAW')
-
-            if request_in.parsed_url and request_in.parsed_url.scheme == 'http':
-                if not request_in.parsed_url.path:
-                    response = HttpErrorResponse(400, 'Bad Request', "Url misses path.")
-                    response.send_to(self.clientsocket)
-                    self.clientsocket.shutdown(socket.SHUT_RDWR)
-                    self.clientsocket.close()
-                    return
-                request_in.clean_host_request()
-                request_in.clean_hop_headers()
-                request_in.headers.append(('Connection', 'close'))
+        try:
+            self.check_proxy_auth(request_in) # raises 407
+            request_in.check_headers_valid() # raises 400
 
-                request_in.debug_dump('REQUEST')
+            if request_in.http_method in ('OPTIONS', 'GET', 'HEAD', 'POST', 'PUT', 'DELETE', 'TRACE'):
+                if request_in.parsed_url and request_in.parsed_url.scheme == 'http':
+                    
+                    request_in.clean_host_request()
+                    request_in.clean_hop_headers()
+                    request_in.headers.append(('Connection', 'close'))
 
-                response = run_request_http(request_in)
+                    request_in.debug_dump('REQUEST')
 
-                response.debug_dump('RESPONSE')
+                    response = run_request_http(request_in)
 
-            else:
-                response = HttpErrorResponse(501, 'Not implemented', 'Unsupported scheme %s.' % request_in.parsed_url.scheme)
+                    response.debug_dump('RESPONSE')
 
-            logging.info("%s %s %s %s %s", request_in.http_method, request_in.parsed_url.geturl(), request_in.http_version, response.line1[9:12], len(response.data) or '-')
+                else:
+                    raise HttpErrorResponse(501, 'Not implemented', 'Unsupported scheme %s.' % request_in.parsed_url.scheme)
 
-            response.send_to(self.clientsocket)
-            try:
-                self.clientsocket.shutdown(socket.SHUT_RDWR)
-            except socket.error, err:
-                logging.error('Error during socket.shutdown: %s', err.args[1])
-            self.clientsocket.close()
+                logging.info("%s %s %s %s %s", request_in.http_method, request_in.parsed_url.geturl(), request_in.http_version, response.line1[9:12], len(response.data) or '-')
+
+            elif request_in.http_method == 'CONNECT':
+                HttpErrorResponse(200, 'Proceed', '\r\n').send_to(self.clientsocket)
+                #self.clientsocket.send('HTTP/1.1 200 Proceed\r\n\r\n')
 
-        elif request_in.http_method == 'CONNECT':
-            request_in.debug_dump('REQUEST CONNECT')
-            logging.warning('Method CONNECT in development')
-            self.clientsocket.send('HTTP/1.1 200 Proceed\r\n\r\n\r\n')
+                ssl_context = SSL.Context(SSL.SSLv23_METHOD)
+                ssl_context.use_privatekey_file ('certs/proxy.key')
+                ssl_context.use_certificate_file('certs/proxy.crt')
+                ssl_sock = SSL.Connection(ssl_context, self.clientsocket)
+                ssl_sock.set_accept_state()
 
-            ssl_context = SSL.Context(SSL.SSLv23_METHOD)
-            ssl_context.use_privatekey_file ('proxy.key')
-            ssl_context.use_certificate_file('proxy.crt')
-            ssl_sock = SSL.Connection(ssl_context, self.clientsocket)
-            ssl_sock.set_accept_state()
+                request_in_ssl = HttpRequest()
+                request_in_ssl.recv_from(ssl_sock)
 
-            request_in_ssl = HttpRequest()
-            request_in_ssl.recv_from(ssl_sock)
+                request_in_ssl.check_headers_valid() # raises 400
 
-            if options.debug_raw_messages:
-                request_in_ssl.debug_dump('REQUEST SSL RAW')
+                request_in_ssl.clean_hop_headers()
+                request_in_ssl.headers.append(('Connection', 'close'))
 
-            request_in_ssl.clean_hop_headers()
-            request_in_ssl.headers.append(('Connection', 'close'))
+                request_in_ssl.clean_host_request()
+                request_in_ssl.set_default_scheme('https')
 
-            request_in_ssl.debug_dump('REQUEST')
+                response = run_request_https(request_in_ssl)
 
-            response = run_request_http(request_in_ssl)
+                response.send_to(ssl_sock)
+                
+                logging.info("%s %s %s %s %s", request_in_ssl.http_method, request_in_ssl.parsed_url.geturl(), request_in_ssl.http_version, response.line1[9:12], len(response.data) or '-')
+                ssl_sock.shutdown()
+                ssl_sock.close()
+                #self.clientsocket.shutdown(socket.SHUT_RDWR)
+                #self.clientsocket.close()
+                return # bypass classic socket shutdown
+            else:
+                request_in.debug_dump('REQUEST')
+                logging.error('Method %s not supported', request_in.http_method)
+                # FIXME RFC 2616, section 14.7: We should return an "Allow" header
+                self.clientsocket.send('HTTP/1.1 405 Method not allowed\r\n\r\nSorry method %s is not supported by the proxy.\r\n' % request_in.http_method)
+                self.clientsocket.close()
 
-            response.debug_dump('RESPONSE')
+        except HttpErrorResponse, error:
+            response = error
 
-            response.send_to(ssl_sock)
-            ssl_sock.shutdown()
-            ssl_sock.close()
-            #self.clientsocket.shutdown(socket.SHUT_RDWR)
-            #self.clientsocket.close()
-        else:
-            request_in.debug_dump('REQUEST')
-            logging.error('Method %s not supported', request_in.http_method)
-            # FIXME RFC 2616, section 14.7: We should return an "Allow" header
-            self.clientsocket.send('HTTP/1.1 405 Method not allowed\r\n\r\nSorry method %s is not supported by the proxy.\r\n' % request_in.http_method)
-            self.clientsocket.close()
+        response.send_to(self.clientsocket)
+        try:
+            self.clientsocket.shutdown(socket.SHUT_RDWR)
+        except socket.error, err:
+            logging.error('Error during socket.shutdown: %s', err.args[1])
+        self.clientsocket.close()
 
 
 def main():
@@ -545,6 +601,7 @@ def main():
         cnx_thread = ProxyConnectionIn(clientsocket)
         cnx_thread .run()
 
+
 if __name__ == '__main__':
     from optparse import OptionParser #, OptionGroup
     
@@ -567,7 +624,7 @@ if __name__ == '__main__':
 
     parser.add_option('--log-full-transfers',
         action='store_true', dest='log_full_transfers', default=False,
-        help="log full queries and responses")
+        help="log full requests and responses")
     
     parser.add_option('-d', '--debug',
         action='store_true', dest='debug', default=False,