commit a1c9378048088cde18839a7419bb2a37f58d31f2
parent a6663c408e2f1b333bbbff25d057641c1ed092ab
Author: Sheng <webmaster0115@gmail.com>
Date: Wed, 23 Jan 2019 21:48:03 +0800
Support CORS
Diffstat:
2 files changed, 56 insertions(+), 9 deletions(-)
diff --git a/tests/test_app.py b/tests/test_app.py
@@ -720,12 +720,12 @@ class TestAppWithTooManyConnections(OtherTestBase):
ws.close()
-class TestAppWithCrossOriginConnect(OtherTestBase):
+class TestAppWithCrossOriginOperation(OtherTestBase):
origin = 'http://www.example.com'
@tornado.testing.gen_test
- def test_app_with_cross_orgin_connect(self):
+ def test_app_with_wrong_event_origin(self):
url = self.get_url('/')
client = self.get_http_client()
body = urlencode(dict(self.body, username='foo', _origin='localhost'))
@@ -734,8 +734,29 @@ class TestAppWithCrossOriginConnect(OtherTestBase):
data = json.loads(to_str(response.body))
self.assertIsNone(data['id'])
self.assertIsNone(data['encoding'])
- self.assertIn('Cross origin frame', data['status'])
+ self.assertEqual(
+ 'Cross origin operation is not allowed.', data['status']
+ )
+
+ @tornado.testing.gen_test
+ def test_app_with_wrong_header_origin(self):
+ url = self.get_url('/')
+ client = self.get_http_client()
+ body = urlencode(dict(self.body, username='foo'))
+ headers = dict(self.headers, Origin='localhost')
+ response = yield client.fetch(url, method='POST', body=body,
+ headers=headers)
+ data = json.loads(to_str(response.body))
+ self.assertIsNone(data['id'])
+ self.assertIsNone(data['encoding'])
+ self.assertEqual(
+ 'Cross origin operation is not allowed.', data['status']
+ )
+ @tornado.testing.gen_test
+ def test_app_with_correct_event_origin(self):
+ url = self.get_url('/')
+ client = self.get_http_client()
body = urlencode(dict(self.body, username='foo', _origin=self.origin))
response = yield client.fetch(url, method='POST', body=body,
headers=self.headers)
@@ -743,3 +764,20 @@ class TestAppWithCrossOriginConnect(OtherTestBase):
self.assertIsNotNone(data['id'])
self.assertIsNotNone(data['encoding'])
self.assertIsNone(data['status'])
+ self.assertIsNone(response.headers.get('Access-Control-Allow-Origin'))
+
+ @tornado.testing.gen_test
+ def test_app_with_correct_header_origin(self):
+ url = self.get_url('/')
+ client = self.get_http_client()
+ body = urlencode(dict(self.body, username='foo'))
+ headers = dict(self.headers, Origin=self.origin)
+ response = yield client.fetch(url, method='POST', body=body,
+ headers=headers)
+ data = json.loads(to_str(response.body))
+ self.assertIsNotNone(data['id'])
+ self.assertIsNotNone(data['encoding'])
+ self.assertIsNone(data['status'])
+ self.assertEqual(
+ response.headers.get('Access-Control-Allow-Origin'), self.origin
+ )
diff --git a/webssh/handler.py b/webssh/handler.py
@@ -346,6 +346,20 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
else:
future.set_result(worker)
+ def check_origin(self):
+ event_origin = self.get_argument('_origin', u'')
+ header_origin = self.request.headers.get('Origin')
+ origin = event_origin or header_origin
+
+ if origin:
+ if not super(IndexHandler, self).check_origin(origin):
+ raise tornado.web.HTTPError(
+ 403, 'Cross origin operation is not allowed.'
+ )
+
+ if not event_origin and self.origin_policy != 'same':
+ self.set_header('Access-Control-Allow-Origin', origin)
+
def head(self):
pass
@@ -362,12 +376,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
if len(clients.get(self.src_addr[0], {})) >= options.maxconn:
raise tornado.web.HTTPError(403, 'Too many connections.')
- origin = self.get_argument('_origin', u'')
- if origin:
- if not self.check_origin(origin):
- raise tornado.web.HTTPError(
- 403, 'Cross origin frame operation is not allowed.'
- )
+ self.check_origin()
future = Future()
t = threading.Thread(target=self.ssh_connect_wrapped, args=(future,))