use HTTP_REFERER as fallback if HTTP_ORIGIN is not sent
Also refactor those function a bit and move doctests into a separate module.
This commit is contained in:
parent
09451ff707
commit
e393711859
@ -64,7 +64,8 @@ local_manager = LocalManager([local])
|
||||
|
||||
from isso import db, migrate, wsgi, ext, views
|
||||
from isso.core import ThreadedMixin, ProcessMixin, uWSGIMixin, Config
|
||||
from isso.utils import parse, http, JSONRequest, origin, html
|
||||
from isso.wsgi import origin, urlsplit
|
||||
from isso.utils import http, JSONRequest, html
|
||||
from isso.views import comments
|
||||
|
||||
from isso.ext.notifications import Stdout, SMTP
|
||||
@ -183,7 +184,8 @@ def make_app(conf=None, threading=True, multiprocessing=False, uwsgi=False):
|
||||
|
||||
wrapper.append(partial(wsgi.CORSMiddleware,
|
||||
origin=origin(isso.conf.getiter("general", "host")),
|
||||
allowed=("Origin", "Content-Type"), exposed=("X-Set-Cookie", "Date")))
|
||||
allowed=("Origin", "Referer", "Content-Type"),
|
||||
exposed=("X-Set-Cookie", "Date")))
|
||||
|
||||
wrapper.extend([wsgi.SubURI, ProxyFix])
|
||||
|
||||
@ -222,7 +224,7 @@ def main():
|
||||
sys.exit(1)
|
||||
|
||||
if conf.get("server", "listen").startswith("http://"):
|
||||
host, port, _ = parse.host(conf.get("server", "listen"))
|
||||
host, port, _ = urlsplit(conf.get("server", "listen"))
|
||||
try:
|
||||
from gevent.pywsgi import WSGIServer
|
||||
WSGIServer((host, port), make_app(conf)).serve_forever()
|
||||
|
@ -3,11 +3,6 @@
|
||||
import os
|
||||
import logging
|
||||
|
||||
try:
|
||||
from urlparse import urlparse
|
||||
except ImportError:
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from werkzeug.wsgi import DispatcherMiddleware
|
||||
from werkzeug.wrappers import Response
|
||||
|
||||
|
@ -5,7 +5,6 @@ from __future__ import division
|
||||
import pkg_resources
|
||||
werkzeug = pkg_resources.get_distribution("werkzeug")
|
||||
|
||||
import io
|
||||
import json
|
||||
import hashlib
|
||||
|
||||
@ -120,27 +119,5 @@ class JSONResponse(Response):
|
||||
|
||||
def __init__(self, obj, *args, **kwargs):
|
||||
kwargs["content_type"] = "application/json"
|
||||
return super(JSONResponse, self).__init__(
|
||||
super(JSONResponse, self).__init__(
|
||||
json.dumps(obj).encode("utf-8"), *args, **kwargs)
|
||||
|
||||
|
||||
def origin(hosts):
|
||||
"""
|
||||
Return a function that returns a valid HTTP Origin or localhost
|
||||
if none found.
|
||||
"""
|
||||
|
||||
hosts = [x.rstrip("/") for x in hosts]
|
||||
|
||||
def func(environ):
|
||||
|
||||
if not hosts:
|
||||
return "http://localhost/"
|
||||
|
||||
for host in hosts:
|
||||
if environ.get("HTTP_ORIGIN", None) == host:
|
||||
return host
|
||||
else:
|
||||
return hosts[0]
|
||||
|
||||
return func
|
||||
|
@ -7,7 +7,7 @@ try:
|
||||
except ImportError:
|
||||
import http.client as httplib
|
||||
|
||||
from isso.utils import parse
|
||||
from isso.wsgi import urlsplit
|
||||
|
||||
|
||||
class curl(object):
|
||||
@ -29,7 +29,7 @@ class curl(object):
|
||||
|
||||
def __enter__(self):
|
||||
|
||||
host, port, ssl = parse.host(self.host)
|
||||
host, port, ssl = urlsplit(self.host)
|
||||
http = httplib.HTTPSConnection if ssl else httplib.HTTPConnection
|
||||
|
||||
self.con = http(host, port, timeout=self.timeout)
|
||||
|
@ -1,20 +1,20 @@
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import re
|
||||
import datetime
|
||||
|
||||
from itertools import chain
|
||||
|
||||
import re
|
||||
|
||||
|
||||
try:
|
||||
from urllib import unquote
|
||||
from urlparse import urlparse
|
||||
except ImportError:
|
||||
from urllib.parse import urlparse, unquote
|
||||
from urllib.parse import unquote
|
||||
|
||||
import html5lib
|
||||
|
||||
from isso.compat import map, filter, PY2K, string_types, text_type as str
|
||||
from isso.compat import map, filter, PY2K
|
||||
|
||||
if PY2K: # http://bugs.python.org/issue12984
|
||||
from xml.dom.minidom import NamedNodeMap
|
||||
@ -51,33 +51,6 @@ def timedelta(value):
|
||||
return datetime.timedelta(**kwargs)
|
||||
|
||||
|
||||
def host(name):
|
||||
"""
|
||||
Parse :param name: into `httplib`-compatible host:port.
|
||||
|
||||
>>> host("http://example.tld/")
|
||||
('example.tld', 80, False)
|
||||
>>> host("https://example.tld/")
|
||||
('example.tld', 443, True)
|
||||
>>> host("example.tld")
|
||||
('example.tld', 80, False)
|
||||
>>> host("example.tld:42")
|
||||
('example.tld', 42, False)
|
||||
>>> host("https://example.tld:80/")
|
||||
('example.tld', 80, True)
|
||||
"""
|
||||
|
||||
if not (isinstance(name, string_types)):
|
||||
name = str(name)
|
||||
|
||||
if not name.startswith(('http://', 'https://')):
|
||||
name = 'http://' + name
|
||||
|
||||
rv = urlparse(name)
|
||||
if rv.scheme == 'https' and rv.port is None:
|
||||
return (rv.netloc, 443, True)
|
||||
return (rv.netloc.rsplit(':')[0], rv.port or 80, rv.scheme == 'https')
|
||||
|
||||
|
||||
def thread(data, default=u"Untitled.", id=None):
|
||||
"""
|
||||
|
57
isso/wsgi.py
57
isso/wsgi.py
@ -3,12 +3,13 @@
|
||||
import socket
|
||||
|
||||
try:
|
||||
from urllib.parse import quote
|
||||
from urllib.parse import quote, urlparse
|
||||
|
||||
from socketserver import ThreadingMixIn
|
||||
from http.server import HTTPServer
|
||||
except ImportError:
|
||||
from urllib import quote
|
||||
from urlparse import urlparse
|
||||
|
||||
from SocketServer import ThreadingMixIn
|
||||
from BaseHTTPServer import HTTPServer
|
||||
@ -16,6 +17,8 @@ except ImportError:
|
||||
from werkzeug.serving import WSGIRequestHandler
|
||||
from werkzeug.datastructures import Headers
|
||||
|
||||
from isso.compat import string_types
|
||||
|
||||
|
||||
def host(environ):
|
||||
"""
|
||||
@ -40,6 +43,58 @@ def host(environ):
|
||||
return url + quote(environ.get('SCRIPT_NAME', ''))
|
||||
|
||||
|
||||
def urlsplit(name):
|
||||
"""
|
||||
Parse :param:`name` into (netloc, port, ssl)
|
||||
"""
|
||||
|
||||
if not (isinstance(name, string_types)):
|
||||
name = str(name)
|
||||
|
||||
if not name.startswith(('http://', 'https://')):
|
||||
name = 'http://' + name
|
||||
|
||||
rv = urlparse(name)
|
||||
if rv.scheme == 'https' and rv.port is None:
|
||||
return (rv.netloc, 443, True)
|
||||
return (rv.netloc.rsplit(':')[0], rv.port or 80, rv.scheme == 'https')
|
||||
|
||||
|
||||
def urljoin(netloc, port, ssl):
|
||||
"""
|
||||
Basically the counter-part of :func:`urlsplit`.
|
||||
"""
|
||||
|
||||
rv = ("https" if ssl else "http") + "://" + netloc
|
||||
if ssl and port != 443 or not ssl and port != 80:
|
||||
rv += ":%i" % port
|
||||
return rv
|
||||
|
||||
|
||||
def origin(hosts):
|
||||
"""
|
||||
Return a function that returns a valid HTTP Origin or localhost
|
||||
if none found.
|
||||
"""
|
||||
|
||||
hosts = [urlsplit(h) for h in hosts]
|
||||
|
||||
def func(environ):
|
||||
|
||||
loc = environ.get("HTTP_ORIGIN", environ.get("HTTP_REFERER", None))
|
||||
|
||||
if not hosts or not loc:
|
||||
return "http://invalid.local"
|
||||
|
||||
for split in hosts:
|
||||
if urlsplit(loc) == split:
|
||||
return urljoin(*split)
|
||||
else:
|
||||
return urljoin(*hosts[0])
|
||||
|
||||
return func
|
||||
|
||||
|
||||
class SubURI(object):
|
||||
|
||||
def __init__(self, app):
|
||||
|
@ -9,8 +9,7 @@ except ImportError:
|
||||
from werkzeug.test import Client
|
||||
from werkzeug.wrappers import Response
|
||||
|
||||
from isso.wsgi import CORSMiddleware
|
||||
from isso.utils import origin
|
||||
from isso.wsgi import CORSMiddleware, origin
|
||||
|
||||
|
||||
def hello_world(environ, start_response):
|
||||
|
48
specs/test_wsgi.py
Normal file
48
specs/test_wsgi.py
Normal file
@ -0,0 +1,48 @@
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
try:
|
||||
import unittest2 as unittest
|
||||
except ImportError:
|
||||
import unittest
|
||||
|
||||
|
||||
from isso import wsgi
|
||||
|
||||
|
||||
class TestWSGIUtilities(unittest.TestCase):
|
||||
|
||||
def test_urlsplit(self):
|
||||
|
||||
examples = [
|
||||
("http://example.tld/", ('example.tld', 80, False)),
|
||||
("https://example.tld/", ('example.tld', 443, True)),
|
||||
("example.tld", ('example.tld', 80, False)),
|
||||
("example.tld:42", ('example.tld', 42, False)),
|
||||
("https://example.tld:80/", ('example.tld', 80, True))]
|
||||
|
||||
for (hostname, result) in examples:
|
||||
self.assertEqual(wsgi.urlsplit(hostname), result)
|
||||
|
||||
def test_urljoin(self):
|
||||
|
||||
examples = [
|
||||
(("example.tld", 80, False), "http://example.tld"),
|
||||
(("example.tld", 42, True), "https://example.tld:42"),
|
||||
(("example.tld", 443, True), "https://example.tld")]
|
||||
|
||||
for (split, result) in examples:
|
||||
self.assertEqual(wsgi.urljoin(*split), result)
|
||||
|
||||
def test_origin(self):
|
||||
|
||||
self.assertEqual(wsgi.origin([])({}), "http://invalid.local")
|
||||
|
||||
origin = wsgi.origin(["http://foo.bar/", "https://foo.bar"])
|
||||
self.assertEqual(origin({"HTTP_ORIGIN": "http://foo.bar"}),
|
||||
"http://foo.bar")
|
||||
self.assertEqual(origin({"HTTP_ORIGIN": "https://foo.bar"}),
|
||||
"https://foo.bar")
|
||||
self.assertEqual(origin({"HTTP_REFERER": "http://foo.bar"}),
|
||||
"http://foo.bar")
|
||||
self.assertEqual(origin({"HTTP_ORIGIN": "http://spam.baz"}),
|
||||
"http://foo.bar")
|
Loading…
Reference in New Issue
Block a user