use HTTP_REFERER as fallback if HTTP_ORIGIN is not sent

Also refactor those function a bit and move doctests into a separate
module.
pull/80/merge
Martin Zimmermann 10 years ago
parent 09451ff707
commit e393711859

@ -64,7 +64,8 @@ local_manager = LocalManager([local])
from isso import db, migrate, wsgi, ext, views
from isso.core import ThreadedMixin, ProcessMixin, uWSGIMixin, Config
from isso.utils import parse, http, JSONRequest, origin, html
from isso.wsgi import origin, urlsplit
from isso.utils import http, JSONRequest, html
from isso.views import comments
from isso.ext.notifications import Stdout, SMTP
@ -183,7 +184,8 @@ def make_app(conf=None, threading=True, multiprocessing=False, uwsgi=False):
wrapper.append(partial(wsgi.CORSMiddleware,
origin=origin(isso.conf.getiter("general", "host")),
allowed=("Origin", "Content-Type"), exposed=("X-Set-Cookie", "Date")))
allowed=("Origin", "Referer", "Content-Type"),
exposed=("X-Set-Cookie", "Date")))
wrapper.extend([wsgi.SubURI, ProxyFix])
@ -222,7 +224,7 @@ def main():
sys.exit(1)
if conf.get("server", "listen").startswith("http://"):
host, port, _ = parse.host(conf.get("server", "listen"))
host, port, _ = urlsplit(conf.get("server", "listen"))
try:
from gevent.pywsgi import WSGIServer
WSGIServer((host, port), make_app(conf)).serve_forever()

@ -3,11 +3,6 @@
import os
import logging
try:
from urlparse import urlparse
except ImportError:
from urllib.parse import urlparse
from werkzeug.wsgi import DispatcherMiddleware
from werkzeug.wrappers import Response

@ -5,7 +5,6 @@ from __future__ import division
import pkg_resources
werkzeug = pkg_resources.get_distribution("werkzeug")
import io
import json
import hashlib
@ -120,27 +119,5 @@ class JSONResponse(Response):
def __init__(self, obj, *args, **kwargs):
kwargs["content_type"] = "application/json"
return super(JSONResponse, self).__init__(
super(JSONResponse, self).__init__(
json.dumps(obj).encode("utf-8"), *args, **kwargs)
def origin(hosts):
"""
Return a function that returns a valid HTTP Origin or localhost
if none found.
"""
hosts = [x.rstrip("/") for x in hosts]
def func(environ):
if not hosts:
return "http://localhost/"
for host in hosts:
if environ.get("HTTP_ORIGIN", None) == host:
return host
else:
return hosts[0]
return func

@ -7,7 +7,7 @@ try:
except ImportError:
import http.client as httplib
from isso.utils import parse
from isso.wsgi import urlsplit
class curl(object):
@ -29,7 +29,7 @@ class curl(object):
def __enter__(self):
host, port, ssl = parse.host(self.host)
host, port, ssl = urlsplit(self.host)
http = httplib.HTTPSConnection if ssl else httplib.HTTPConnection
self.con = http(host, port, timeout=self.timeout)

@ -1,20 +1,20 @@
from __future__ import print_function
import re
import datetime
from itertools import chain
import re
try:
from urllib import unquote
from urlparse import urlparse
except ImportError:
from urllib.parse import urlparse, unquote
from urllib.parse import unquote
import html5lib
from isso.compat import map, filter, PY2K, string_types, text_type as str
from isso.compat import map, filter, PY2K
if PY2K: # http://bugs.python.org/issue12984
from xml.dom.minidom import NamedNodeMap
@ -51,33 +51,6 @@ def timedelta(value):
return datetime.timedelta(**kwargs)
def host(name):
"""
Parse :param name: into `httplib`-compatible host:port.
>>> host("http://example.tld/")
('example.tld', 80, False)
>>> host("https://example.tld/")
('example.tld', 443, True)
>>> host("example.tld")
('example.tld', 80, False)
>>> host("example.tld:42")
('example.tld', 42, False)
>>> host("https://example.tld:80/")
('example.tld', 80, True)
"""
if not (isinstance(name, string_types)):
name = str(name)
if not name.startswith(('http://', 'https://')):
name = 'http://' + name
rv = urlparse(name)
if rv.scheme == 'https' and rv.port is None:
return (rv.netloc, 443, True)
return (rv.netloc.rsplit(':')[0], rv.port or 80, rv.scheme == 'https')
def thread(data, default=u"Untitled.", id=None):
"""

@ -3,12 +3,13 @@
import socket
try:
from urllib.parse import quote
from urllib.parse import quote, urlparse
from socketserver import ThreadingMixIn
from http.server import HTTPServer
except ImportError:
from urllib import quote
from urlparse import urlparse
from SocketServer import ThreadingMixIn
from BaseHTTPServer import HTTPServer
@ -16,6 +17,8 @@ except ImportError:
from werkzeug.serving import WSGIRequestHandler
from werkzeug.datastructures import Headers
from isso.compat import string_types
def host(environ):
"""
@ -40,6 +43,58 @@ def host(environ):
return url + quote(environ.get('SCRIPT_NAME', ''))
def urlsplit(name):
"""
Parse :param:`name` into (netloc, port, ssl)
"""
if not (isinstance(name, string_types)):
name = str(name)
if not name.startswith(('http://', 'https://')):
name = 'http://' + name
rv = urlparse(name)
if rv.scheme == 'https' and rv.port is None:
return (rv.netloc, 443, True)
return (rv.netloc.rsplit(':')[0], rv.port or 80, rv.scheme == 'https')
def urljoin(netloc, port, ssl):
"""
Basically the counter-part of :func:`urlsplit`.
"""
rv = ("https" if ssl else "http") + "://" + netloc
if ssl and port != 443 or not ssl and port != 80:
rv += ":%i" % port
return rv
def origin(hosts):
"""
Return a function that returns a valid HTTP Origin or localhost
if none found.
"""
hosts = [urlsplit(h) for h in hosts]
def func(environ):
loc = environ.get("HTTP_ORIGIN", environ.get("HTTP_REFERER", None))
if not hosts or not loc:
return "http://invalid.local"
for split in hosts:
if urlsplit(loc) == split:
return urljoin(*split)
else:
return urljoin(*hosts[0])
return func
class SubURI(object):
def __init__(self, app):

@ -9,8 +9,7 @@ except ImportError:
from werkzeug.test import Client
from werkzeug.wrappers import Response
from isso.wsgi import CORSMiddleware
from isso.utils import origin
from isso.wsgi import CORSMiddleware, origin
def hello_world(environ, start_response):

@ -0,0 +1,48 @@
# -*- encoding: utf-8 -*-
try:
import unittest2 as unittest
except ImportError:
import unittest
from isso import wsgi
class TestWSGIUtilities(unittest.TestCase):
def test_urlsplit(self):
examples = [
("http://example.tld/", ('example.tld', 80, False)),
("https://example.tld/", ('example.tld', 443, True)),
("example.tld", ('example.tld', 80, False)),
("example.tld:42", ('example.tld', 42, False)),
("https://example.tld:80/", ('example.tld', 80, True))]
for (hostname, result) in examples:
self.assertEqual(wsgi.urlsplit(hostname), result)
def test_urljoin(self):
examples = [
(("example.tld", 80, False), "http://example.tld"),
(("example.tld", 42, True), "https://example.tld:42"),
(("example.tld", 443, True), "https://example.tld")]
for (split, result) in examples:
self.assertEqual(wsgi.urljoin(*split), result)
def test_origin(self):
self.assertEqual(wsgi.origin([])({}), "http://invalid.local")
origin = wsgi.origin(["http://foo.bar/", "https://foo.bar"])
self.assertEqual(origin({"HTTP_ORIGIN": "http://foo.bar"}),
"http://foo.bar")
self.assertEqual(origin({"HTTP_ORIGIN": "https://foo.bar"}),
"https://foo.bar")
self.assertEqual(origin({"HTTP_REFERER": "http://foo.bar"}),
"http://foo.bar")
self.assertEqual(origin({"HTTP_ORIGIN": "http://spam.baz"}),
"http://foo.bar")
Loading…
Cancel
Save