diff --git a/isso/__init__.py b/isso/__init__.py index 6d48af4..7122ff7 100644 --- a/isso/__init__.py +++ b/isso/__init__.py @@ -182,7 +182,8 @@ def make_app(conf=None, threading=True, multiprocessing=False, uwsgi=False): '/css': join(dirname(__file__), 'css/')})) wrapper.append(partial(wsgi.CORSMiddleware, - origin=origin(isso.conf.getiter("general", "host")))) + origin=origin(isso.conf.getiter("general", "host")), + allowed=("Origin", "Content-Type"), exposed=("X-Set-Cookie", ))) wrapper.extend([wsgi.SubURI, ProxyFix]) diff --git a/isso/wsgi.py b/isso/wsgi.py index da7e836..66e1f06 100644 --- a/isso/wsgi.py +++ b/isso/wsgi.py @@ -60,19 +60,25 @@ class SubURI(object): class CORSMiddleware(object): """Add Cross-origin resource sharing headers to every request.""" - def __init__(self, app, origin): + methods = ("HEAD", "GET", "POST", "PUT", "DELETE") + + def __init__(self, app, origin, allowed=[], exposed=[]): self.app = app self.origin = origin + self.allowed = allowed + self.exposed = exposed def __call__(self, environ, start_response): def add_cors_headers(status, headers, exc_info=None): headers = Headers(headers) headers.add("Access-Control-Allow-Origin", self.origin(environ)) - headers.add("Access-Control-Allow-Headers", "Origin, Content-Type") headers.add("Access-Control-Allow-Credentials", "true") - headers.add("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE") - headers.add("Access-Control-Expose-Headers", "X-Set-Cookie") + headers.add("Access-Control-Allow-Methods", ", ".join(self.methods)) + if self.allowed: + headers.add("Access-Control-Allow-Headers", ", ".join(self.allowed)) + if self.exposed: + headers.add("Access-Control-Expose-Headers", ", ".join(self.exposed)) return start_response(status, headers.to_list(), exc_info) if environ.get("REQUEST_METHOD") == "OPTIONS": diff --git a/specs/test_cors.py b/specs/test_cors.py index 93b9036..0479abe 100644 --- a/specs/test_cors.py +++ b/specs/test_cors.py @@ -22,21 +22,23 @@ class CORSTest(unittest.TestCase): def test_simple(self): - app = CORSMiddleware(hello_world, origin=origin([ - "https://example.tld/", - "http://example.tld/", - "http://example.tld", - ])) + app = CORSMiddleware(hello_world, + origin=origin([ + "https://example.tld/", + "http://example.tld/", + "http://example.tld", + ]), + allowed=("Foo", "Bar"), exposed=("Spam", )) client = Client(app, Response) rv = client.get("/", headers={"ORIGIN": "https://example.tld"}) self.assertEqual(rv.headers["Access-Control-Allow-Origin"], "https://example.tld") - self.assertEqual(rv.headers["Access-Control-Allow-Headers"], "Origin, Content-Type") self.assertEqual(rv.headers["Access-Control-Allow-Credentials"], "true") - self.assertEqual(rv.headers["Access-Control-Allow-Methods"], "GET, POST, PUT, DELETE") - self.assertEqual(rv.headers["Access-Control-Expose-Headers"], "X-Set-Cookie") + self.assertEqual(rv.headers["Access-Control-Allow-Methods"], "HEAD, GET, POST, PUT, DELETE") + self.assertEqual(rv.headers["Access-Control-Allow-Headers"], "Foo, Bar") + self.assertEqual(rv.headers["Access-Control-Expose-Headers"], "Spam") a = client.get("/", headers={"ORIGIN": "http://example.tld"}) self.assertEqual(a.headers["Access-Control-Allow-Origin"], "http://example.tld") @@ -50,7 +52,8 @@ class CORSTest(unittest.TestCase): def test_preflight(self): - app = CORSMiddleware(hello_world, origin=origin(["http://example.tld"])) + app = CORSMiddleware(hello_world, origin=origin(["http://example.tld"]), + allowed=("Foo", ), exposed=("Bar", )) client = Client(app, Response) rv = client.open(method="OPTIONS", path="/", headers={"ORIGIN": "http://example.tld"})