diff --git a/isso/wsgi.py b/isso/wsgi.py index 9f7454a..19690f6 100644 --- a/isso/wsgi.py +++ b/isso/wsgi.py @@ -34,7 +34,7 @@ class CORSMiddleWare(object): else: origin = host.rstrip("/") - headers.append(("Access-Control-Allow-Origin", origin.encode("latin-1"))) + headers.append(("Access-Control-Allow-Origin", origin)) headers.append(("Access-Control-Allow-Headers", "Origin, Content-Type")) headers.append(("Access-Control-Allow-Credentials", "true")) headers.append(("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE")) diff --git a/specs/test_cors.py b/specs/test_cors.py index e69de29..eea4e4c 100644 --- a/specs/test_cors.py +++ b/specs/test_cors.py @@ -0,0 +1,52 @@ + +from werkzeug.test import Client +from werkzeug.wrappers import Response + +from isso.wsgi import CORSMiddleWare + + +def hello_world(environ, start_response): + start_response('200 OK', [('Content-Type', 'text/html')]) + return ["Hello, World."] + + +def test_simple_CORS(): + + app = CORSMiddleWare(hello_world, hosts=[ + "https://example.tld/", + "http://example.tld/", + "http://example.tld", + ]) + + client = Client(app, Response) + + rv = client.get("/", headers={"ORIGIN": "https://example.tld"}) + + assert rv.headers["Access-Control-Allow-Origin"] == "https://example.tld" + assert rv.headers["Access-Control-Allow-Headers"] == "Origin, Content-Type" + assert rv.headers["Access-Control-Allow-Credentials"] == "true" + assert rv.headers["Access-Control-Allow-Methods"] == "GET, POST, PUT, DELETE" + assert rv.headers["Access-Control-Expose-Headers"] == "X-Set-Cookie" + + a = client.get("/", headers={"ORIGIN": "http://example.tld/"}) + assert a.headers["Access-Control-Allow-Origin"] == "http://example.tld" + + b = client.get("/", headers={"ORIGIN": "http://example.tld"}) + assert a.headers["Access-Control-Allow-Origin"] == "http://example.tld" + + c = client.get("/", headers={"ORIGIN": "http://foo.other/"}) + assert a.headers["Access-Control-Allow-Origin"] == "http://example.tld" + + +def test_preflight_CORS(): + + app = CORSMiddleWare(hello_world, hosts=["http://example.tld"]) + client = Client(app, Response) + + rv = client.open(method="OPTIONS", path="/", headers={"ORIGIN": "http://example.tld"}) + assert rv.status_code == 200 + + for hdr in ("Origin", "Headers", "Credentials", "Methods"): + assert "Access-Control-Allow-%s" % hdr in rv.headers + + assert rv.headers["Access-Control-Allow-Origin"] == "http://example.tld"