From e3937118596b5d052ab74fb3fa61d48689b06cff Mon Sep 17 00:00:00 2001 From: Martin Zimmermann Date: Fri, 28 Mar 2014 11:11:48 +0100 Subject: [PATCH] use HTTP_REFERER as fallback if HTTP_ORIGIN is not sent Also refactor those functions a bit and move doctests into a separate module. --- isso/__init__.py | 8 +++--- isso/dispatch.py | 5 ---- isso/utils/__init__.py | 25 +----------------- isso/utils/http.py | 4 +-- isso/utils/parse.py | 37 ++++----------------------- isso/wsgi.py | 57 +++++++++++++++++++++++++++++++++++++++++- specs/test_cors.py | 3 +-- specs/test_wsgi.py | 48 +++++++++++++++++++++++++++++++++++ 8 files changed, 118 insertions(+), 69 deletions(-) create mode 100644 specs/test_wsgi.py diff --git a/isso/__init__.py b/isso/__init__.py index 168b99e..61a65f6 100644 --- a/isso/__init__.py +++ b/isso/__init__.py @@ -64,7 +64,8 @@ local_manager = LocalManager([local]) from isso import db, migrate, wsgi, ext, views from isso.core import ThreadedMixin, ProcessMixin, uWSGIMixin, Config -from isso.utils import parse, http, JSONRequest, origin, html +from isso.wsgi import origin, urlsplit +from isso.utils import http, JSONRequest, html from isso.views import comments from isso.ext.notifications import Stdout, SMTP @@ -183,7 +184,8 @@ def make_app(conf=None, threading=True, multiprocessing=False, uwsgi=False): wrapper.append(partial(wsgi.CORSMiddleware, origin=origin(isso.conf.getiter("general", "host")), - allowed=("Origin", "Content-Type"), exposed=("X-Set-Cookie", "Date"))) + allowed=("Origin", "Referer", "Content-Type"), + exposed=("X-Set-Cookie", "Date"))) wrapper.extend([wsgi.SubURI, ProxyFix]) @@ -222,7 +224,7 @@ def main(): sys.exit(1) if conf.get("server", "listen").startswith("http://"): - host, port, _ = parse.host(conf.get("server", "listen")) + host, port, _ = urlsplit(conf.get("server", "listen")) try: from gevent.pywsgi import WSGIServer WSGIServer((host, port), make_app(conf)).serve_forever() diff --git a/isso/dispatch.py b/isso/dispatch.py index 
fd13b98..14441e2 100644 --- a/isso/dispatch.py +++ b/isso/dispatch.py @@ -3,11 +3,6 @@ import os import logging -try: - from urlparse import urlparse -except ImportError: - from urllib.parse import urlparse - from werkzeug.wsgi import DispatcherMiddleware from werkzeug.wrappers import Response diff --git a/isso/utils/__init__.py b/isso/utils/__init__.py index 8dd7e4d..6a6da27 100644 --- a/isso/utils/__init__.py +++ b/isso/utils/__init__.py @@ -5,7 +5,6 @@ from __future__ import division import pkg_resources werkzeug = pkg_resources.get_distribution("werkzeug") -import io import json import hashlib @@ -120,27 +119,5 @@ class JSONResponse(Response): def __init__(self, obj, *args, **kwargs): kwargs["content_type"] = "application/json" - return super(JSONResponse, self).__init__( + super(JSONResponse, self).__init__( json.dumps(obj).encode("utf-8"), *args, **kwargs) - - -def origin(hosts): - """ - Return a function that returns a valid HTTP Origin or localhost - if none found. - """ - - hosts = [x.rstrip("/") for x in hosts] - - def func(environ): - - if not hosts: - return "http://localhost/" - - for host in hosts: - if environ.get("HTTP_ORIGIN", None) == host: - return host - else: - return hosts[0] - - return func diff --git a/isso/utils/http.py b/isso/utils/http.py index cbd179c..7a347c6 100644 --- a/isso/utils/http.py +++ b/isso/utils/http.py @@ -7,7 +7,7 @@ try: except ImportError: import http.client as httplib -from isso.utils import parse +from isso.wsgi import urlsplit class curl(object): @@ -29,7 +29,7 @@ class curl(object): def __enter__(self): - host, port, ssl = parse.host(self.host) + host, port, ssl = urlsplit(self.host) http = httplib.HTTPSConnection if ssl else httplib.HTTPConnection self.con = http(host, port, timeout=self.timeout) diff --git a/isso/utils/parse.py b/isso/utils/parse.py index b6e39b0..cf5ec78 100644 --- a/isso/utils/parse.py +++ b/isso/utils/parse.py @@ -1,20 +1,20 @@ from __future__ import print_function -import re import datetime - 
from itertools import chain +import re + + try: from urllib import unquote - from urlparse import urlparse except ImportError: - from urllib.parse import urlparse, unquote + from urllib.parse import unquote import html5lib -from isso.compat import map, filter, PY2K, string_types, text_type as str +from isso.compat import map, filter, PY2K if PY2K: # http://bugs.python.org/issue12984 from xml.dom.minidom import NamedNodeMap @@ -51,33 +51,6 @@ def timedelta(value): return datetime.timedelta(**kwargs) -def host(name): - """ - Parse :param name: into `httplib`-compatible host:port. - - >>> host("http://example.tld/") - ('example.tld', 80, False) - >>> host("https://example.tld/") - ('example.tld', 443, True) - >>> host("example.tld") - ('example.tld', 80, False) - >>> host("example.tld:42") - ('example.tld', 42, False) - >>> host("https://example.tld:80/") - ('example.tld', 80, True) - """ - - if not (isinstance(name, string_types)): - name = str(name) - - if not name.startswith(('http://', 'https://')): - name = 'http://' + name - - rv = urlparse(name) - if rv.scheme == 'https' and rv.port is None: - return (rv.netloc, 443, True) - return (rv.netloc.rsplit(':')[0], rv.port or 80, rv.scheme == 'https') - def thread(data, default=u"Untitled.", id=None): """ diff --git a/isso/wsgi.py b/isso/wsgi.py index 66e1f06..f0a9380 100644 --- a/isso/wsgi.py +++ b/isso/wsgi.py @@ -3,12 +3,13 @@ import socket try: - from urllib.parse import quote + from urllib.parse import quote, urlparse from socketserver import ThreadingMixIn from http.server import HTTPServer except ImportError: from urllib import quote + from urlparse import urlparse from SocketServer import ThreadingMixIn from BaseHTTPServer import HTTPServer @@ -16,6 +17,8 @@ except ImportError: from werkzeug.serving import WSGIRequestHandler from werkzeug.datastructures import Headers +from isso.compat import string_types + def host(environ): """ @@ -40,6 +43,58 @@ def host(environ): return url + 
quote(environ.get('SCRIPT_NAME', '')) +def urlsplit(name): + """ + Parse :param:`name` into (netloc, port, ssl) + """ + + if not (isinstance(name, string_types)): + name = str(name) + + if not name.startswith(('http://', 'https://')): + name = 'http://' + name + + rv = urlparse(name) + if rv.scheme == 'https' and rv.port is None: + return (rv.netloc, 443, True) + return (rv.netloc.rsplit(':')[0], rv.port or 80, rv.scheme == 'https') + + +def urljoin(netloc, port, ssl): + """ + Basically the counter-part of :func:`urlsplit`. + """ + + rv = ("https" if ssl else "http") + "://" + netloc + if ssl and port != 443 or not ssl and port != 80: + rv += ":%i" % port + return rv + + +def origin(hosts): + """ + Return a function that returns a valid HTTP Origin or localhost + if none found. + """ + + hosts = [urlsplit(h) for h in hosts] + + def func(environ): + + loc = environ.get("HTTP_ORIGIN", environ.get("HTTP_REFERER", None)) + + if not hosts or not loc: + return "http://invalid.local" + + for split in hosts: + if urlsplit(loc) == split: + return urljoin(*split) + else: + return urljoin(*hosts[0]) + + return func + + class SubURI(object): def __init__(self, app): diff --git a/specs/test_cors.py b/specs/test_cors.py index 0479abe..c8f7258 100644 --- a/specs/test_cors.py +++ b/specs/test_cors.py @@ -9,8 +9,7 @@ except ImportError: from werkzeug.test import Client from werkzeug.wrappers import Response -from isso.wsgi import CORSMiddleware -from isso.utils import origin +from isso.wsgi import CORSMiddleware, origin def hello_world(environ, start_response): diff --git a/specs/test_wsgi.py b/specs/test_wsgi.py new file mode 100644 index 0000000..516858e --- /dev/null +++ b/specs/test_wsgi.py @@ -0,0 +1,48 @@ +# -*- encoding: utf-8 -*- + +try: + import unittest2 as unittest +except ImportError: + import unittest + + +from isso import wsgi + + +class TestWSGIUtilities(unittest.TestCase): + + def test_urlsplit(self): + + examples = [ + ("http://example.tld/", ('example.tld', 80, 
False)), + ("https://example.tld/", ('example.tld', 443, True)), + ("example.tld", ('example.tld', 80, False)), + ("example.tld:42", ('example.tld', 42, False)), + ("https://example.tld:80/", ('example.tld', 80, True))] + + for (hostname, result) in examples: + self.assertEqual(wsgi.urlsplit(hostname), result) + + def test_urljoin(self): + + examples = [ + (("example.tld", 80, False), "http://example.tld"), + (("example.tld", 42, True), "https://example.tld:42"), + (("example.tld", 443, True), "https://example.tld")] + + for (split, result) in examples: + self.assertEqual(wsgi.urljoin(*split), result) + + def test_origin(self): + + self.assertEqual(wsgi.origin([])({}), "http://invalid.local") + + origin = wsgi.origin(["http://foo.bar/", "https://foo.bar"]) + self.assertEqual(origin({"HTTP_ORIGIN": "http://foo.bar"}), + "http://foo.bar") + self.assertEqual(origin({"HTTP_ORIGIN": "https://foo.bar"}), + "https://foo.bar") + self.assertEqual(origin({"HTTP_REFERER": "http://foo.bar"}), + "http://foo.bar") + self.assertEqual(origin({"HTTP_ORIGIN": "http://spam.baz"}), + "http://foo.bar")