commit c35f801235703f581093ce35912d6f896a33ec45
parent 8a8d7412301cdde4ba4e3a74eb7cabb90533d474
Author: Sheng <webmaster0115@gmail.com>
Date: Sat, 19 Jan 2019 16:46:25 +0800
Support custom origin configuration
Diffstat:
6 files changed, 157 insertions(+), 20 deletions(-)
diff --git a/tests/test_handler.py b/tests/test_handler.py
@@ -215,7 +215,7 @@ class TestWsockHandler(unittest.TestCase):
request = HTTPServerRequest(uri='/')
obj = Mock(spec=WsockHandler, request=request)
- options.cows = 0
+ obj.origin_policy = 'same'
request.headers['Host'] = 'www.example.com:4433'
origin = 'https://www.example.com:4433'
self.assertTrue(WsockHandler.check_origin(obj, origin))
@@ -223,7 +223,7 @@ class TestWsockHandler(unittest.TestCase):
origin = 'https://www.example.com'
self.assertFalse(WsockHandler.check_origin(obj, origin))
- options.cows = 1
+ obj.origin_policy = 'primary'
self.assertTrue(WsockHandler.check_origin(obj, origin))
origin = 'https://blog.example.com'
@@ -232,5 +232,18 @@ class TestWsockHandler(unittest.TestCase):
origin = 'https://blog.example.org'
self.assertFalse(WsockHandler.check_origin(obj, origin))
- options.cows = 2
+ origin = 'https://blog.example.org'
+ obj.origin_policy = {'https://blog.example.org'}
+ self.assertTrue(WsockHandler.check_origin(obj, origin))
+
+ origin = 'http://blog.example.org'
+ obj.origin_policy = {'http://blog.example.org'}
+ self.assertTrue(WsockHandler.check_origin(obj, origin))
+
+ origin = 'http://blog.example.org'
+ obj.origin_policy = {'https://blog.example.org'}
+ self.assertFalse(WsockHandler.check_origin(obj, origin))
+
+ obj.origin_policy = '*'
+ origin = 'https://blog.example.org'
self.assertTrue(WsockHandler.check_origin(obj, origin))
diff --git a/tests/test_settings.py b/tests/test_settings.py
@@ -1,4 +1,5 @@
import io
+import random
import ssl
import sys
import os.path
@@ -10,7 +11,7 @@ from tests.utils import make_tests_data_path
from webssh.policy import load_host_keys
from webssh.settings import (
get_host_keys_settings, get_policy_setting, base_dir, print_version,
- get_ssl_context, get_trusted_downstream
+ get_ssl_context, get_trusted_downstream, get_origin_setting
)
from webssh.utils import UnicodeType
from webssh._version import __version__
@@ -137,3 +138,31 @@ class TestSettings(unittest.TestCase):
tdstream = '1.1.1.1, 2.2.2.'
with self.assertRaises(ValueError):
get_trusted_downstream(tdstream)
+
+ def test_get_origin_setting(self):
+ options.debug = False
+ options.origin = '*'
+ with self.assertRaises(ValueError):
+ get_origin_setting(options)
+
+ options.debug = True
+ self.assertEqual(get_origin_setting(options), '*')
+
+ options.origin = random.choice(['Same', 'Primary'])
+ self.assertEqual(get_origin_setting(options), options.origin.lower())
+
+ options.origin = ''
+ with self.assertRaises(ValueError):
+ get_origin_setting(options)
+
+ options.origin = ','
+ with self.assertRaises(ValueError):
+ get_origin_setting(options)
+
+ options.origin = 'www.example.com, https://www.example.org'
+ result = {'http://www.example.com', 'https://www.example.org'}
+ self.assertEqual(get_origin_setting(options), result)
+
+ options.origin = 'www.example.com:80, www.example.org:443'
+ result = {'http://www.example.com', 'https://www.example.org'}
+ self.assertEqual(get_origin_setting(options), result)
diff --git a/tests/test_utils.py b/tests/test_utils.py
@@ -2,7 +2,7 @@ import unittest
from webssh.utils import (
is_valid_ip_address, is_valid_port, is_valid_hostname, to_str, to_bytes,
- to_int, is_ip_hostname, is_same_primary_domain
+ to_int, is_ip_hostname, is_same_primary_domain, parse_origin_from_url
)
@@ -90,3 +90,34 @@ class TestUitls(unittest.TestCase):
domain1 = 'xxx.www.example.com'
domain2 = 'xxx.www2.example.com'
self.assertTrue(is_same_primary_domain(domain1, domain2))
+
+ def test_parse_origin_from_url(self):
+ url = ''
+ self.assertIsNone(parse_origin_from_url(url))
+
+ url = 'www.example.com'
+ self.assertEqual(parse_origin_from_url(url), 'http://www.example.com')
+
+ url = 'http://www.example.com'
+ self.assertEqual(parse_origin_from_url(url), 'http://www.example.com')
+
+ url = 'www.example.com:80'
+ self.assertEqual(parse_origin_from_url(url), 'http://www.example.com')
+
+ url = 'http://www.example.com:80'
+ self.assertEqual(parse_origin_from_url(url), 'http://www.example.com')
+
+ url = 'www.example.com:443'
+ self.assertEqual(parse_origin_from_url(url), 'https://www.example.com')
+
+ url = 'https://www.example.com'
+ self.assertEqual(parse_origin_from_url(url), 'https://www.example.com')
+
+ url = 'https://www.example.com:443'
+ self.assertEqual(parse_origin_from_url(url), 'https://www.example.com')
+
+ url = 'https://www.example.com:80'
+ self.assertEqual(parse_origin_from_url(url), url)
+
+ url = 'http://www.example.com:443'
+ self.assertEqual(parse_origin_from_url(url), url)
diff --git a/webssh/handler.py b/webssh/handler.py
@@ -57,6 +57,7 @@ class MixinHandler(object):
def initialize(self, loop=None):
self.check_request()
self.loop = loop
+ self.origin_policy = self.settings.get('origin_policy')
def check_request(self):
context = self.request.connection.context
@@ -364,22 +365,26 @@ class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler):
self.worker_ref = None
def check_origin(self, origin):
- cows = options.cows
+ if self.origin_policy == '*':
+ return True
+
parsed_origin = urlparse(origin)
- origin = parsed_origin.netloc
- origin = origin.lower()
- logging.debug('origin: {}'.format(origin))
+ netloc = parsed_origin.netloc.lower()
+ logging.debug('netloc: {}'.format(netloc))
host = self.request.headers.get('Host')
logging.debug('host: {}'.format(host))
- if cows == 0:
- return origin == host
- elif cows == 1:
- return is_same_primary_domain(origin.rsplit(':', 1)[0],
+ if netloc == host:
+ return True
+
+ if self.origin_policy == 'same':
+ return False
+ elif self.origin_policy == 'primary':
+ return is_same_primary_domain(netloc.rsplit(':', 1)[0],
host.rsplit(':', 1)[0])
else:
- return True
+ return origin in self.origin_policy
def open(self):
self.src_addr = self.get_client_addr()
diff --git a/webssh/settings.py b/webssh/settings.py
@@ -7,7 +7,7 @@ from tornado.options import define
from webssh.policy import (
load_host_keys, get_policy_class, check_policy_setting
)
-from webssh.utils import to_ip_address
+from webssh.utils import to_ip_address, parse_origin_from_url
from webssh._version import __version__
@@ -34,10 +34,12 @@ define('fbidhttp', type=bool, default=True,
help='Forbid public plain http incoming requests')
define('xheaders', type=bool, default=True, help='Support xheaders')
define('xsrf', type=bool, default=True, help='CSRF protection')
-define('cows', type=int, default=0, help='Cross origin websocket, '
- '0: matches host name and port number, '
- '1: matches primary domain only, '
- '?: matches nothing, allow all cross-origin websockets')
+define('origin', default='same', help='''Origin policy,
+'same': same origin policy, matches host name and port number;
+'primary': primary domain policy, matches primary domain only;
+'<domains>': custom domains policy, matches any domain in the <domains> list
+separated by comma;
+'*': wildcard policy, matches any domain, allowed in debug mode only.''')
define('wpintvl', type=int, default=0, help='Websocket ping interval')
define('maxconn', type=int, default=20, help='Maximum connections per client')
define('version', type=bool, help='Show version information',
@@ -54,7 +56,8 @@ def get_app_settings(options):
static_path=os.path.join(base_dir, 'webssh', 'static'),
websocket_ping_interval=options.wpintvl,
debug=options.debug,
- xsrf_cookies=options.xsrf
+ xsrf_cookies=options.xsrf,
+ origin_policy=get_origin_setting(options)
)
return settings
@@ -121,3 +124,28 @@ def get_trusted_downstream(tdstream):
to_ip_address(ip)
result.add(ip)
return result
+
+
+def get_origin_setting(options):
+ if options.origin == '*':
+ if not options.debug:
+ raise ValueError(
+ 'Wildcard origin policy is only allowed in debug mode.'
+ )
+ else:
+ return '*'
+
+ origin = options.origin.lower()
+ if origin in ['same', 'primary']:
+ return origin
+
+ origins = set()
+ for url in origin.split(','):
+ orig = parse_origin_from_url(url)
+ if orig:
+ origins.add(orig)
+
+ if not origins:
+ raise ValueError('Empty origin list')
+
+ return origins
diff --git a/webssh/utils.py b/webssh/utils.py
@@ -6,6 +6,11 @@ try:
except ImportError:
UnicodeType = str
+try:
+ from urllib.parse import urlparse
+except ImportError:
+ from urlparse import urlparse
+
numeric = re.compile(r'[0-9]+$')
allowed = re.compile(r'(?!-)[a-z0-9-]{1,63}(?<!-)$', re.IGNORECASE)
@@ -101,3 +106,29 @@ def is_same_primary_domain(domain1, domain2):
c = domain1[i] if l1 > m else domain2[i]
return c == '.'
+
+
+def parse_origin_from_url(url):
+ url = url.strip()
+ if not url:
+ return
+
+ if not (url.startswith('http://') or url.startswith('https://') or
+ url.startswith('//')):
+ url = '//' + url
+
+ parsed = urlparse(url)
+ port = parsed.port
+ scheme = parsed.scheme
+
+ if scheme == '':
+ scheme = 'https' if port == 443 else 'http'
+
+ if port == 443 and scheme == 'https':
+ netloc = parsed.netloc.replace(':443', '')
+ elif port == 80 and scheme == 'http':
+ netloc = parsed.netloc.replace(':80', '')
+ else:
+ netloc = parsed.netloc
+
+ return '{}://{}'.format(scheme, netloc)