commit 5c8bd84b95910d4fe56dba754e6fb863d53a0b67
parent b51e8239739a4579b7c8883c65e4ea55233ee576
Author: Sheng <webmaster0115@gmail.com>
Date: Thu, 10 Jan 2019 22:09:32 +0800
Added an option for configuring cross-origin websocket level
Diffstat:
5 files changed, 115 insertions(+), 3 deletions(-)
diff --git a/tests/test_handler.py b/tests/test_handler.py
@@ -5,7 +5,7 @@ from tornado.httputil import HTTPServerRequest
from tornado.options import options
from tests.utils import read_file, make_tests_data_path
from webssh.handler import (
- MixinHandler, IndexHandler, InvalidValueError, open_to_public
+ MixinHandler, IndexHandler, WsockHandler, InvalidValueError, open_to_public
)
try:
@@ -202,3 +202,30 @@ class TestIndexHandler(unittest.TestCase):
with self.assertRaises(paramiko.PasswordRequiredException):
pkey = IndexHandler.get_pkey_obj(key, '', fname)
+
+
+class TestWsockHandler(unittest.TestCase):
+
+ def test_check_origin(self):
+ request = HTTPServerRequest(uri='/')
+ obj = Mock(spec=WsockHandler, request=request)
+
+ options.cows = 0
+ request.headers['Host'] = 'www.example.com:4433'
+ origin = 'https://www.example.com:4433'
+ self.assertTrue(WsockHandler.check_origin(obj, origin))
+
+ origin = 'https://www.example.com'
+ self.assertFalse(WsockHandler.check_origin(obj, origin))
+
+ options.cows = 1
+ self.assertTrue(WsockHandler.check_origin(obj, origin))
+
+ origin = 'https://blog.example.com'
+ self.assertTrue(WsockHandler.check_origin(obj, origin))
+
+ origin = 'https://blog.example.org'
+ self.assertFalse(WsockHandler.check_origin(obj, origin))
+
+ options.cows = 2
+ self.assertTrue(WsockHandler.check_origin(obj, origin))
diff --git a/tests/test_utils.py b/tests/test_utils.py
@@ -3,7 +3,7 @@ import unittest
from webssh.utils import (
is_valid_ip_address, is_valid_port, is_valid_hostname, to_str, to_bytes,
to_int, on_public_network_interface, get_ips_by_name, is_ip_hostname,
- is_name_open_to_public
+ is_name_open_to_public, is_same_primary_domain
)
@@ -79,3 +79,32 @@ class TestUitls(unittest.TestCase):
self.assertTrue(is_ip_hostname('127.0.0.1'))
self.assertFalse(is_ip_hostname('localhost'))
self.assertFalse(is_ip_hostname('www.google.com'))
+
+ def test_is_same_primary_domain(self):
+ domain1 = 'localhost'
+ domain2 = 'localhost'
+ self.assertTrue(is_same_primary_domain(domain1, domain2))
+
+ domain1 = 'localhost'
+ domain2 = 'test'
+ self.assertFalse(is_same_primary_domain(domain1, domain2))
+
+ domain1 = 'example.com'
+ domain2 = 'example.com'
+ self.assertTrue(is_same_primary_domain(domain1, domain2))
+
+ domain1 = 'www.example.com'
+ domain2 = 'example.com'
+ self.assertTrue(is_same_primary_domain(domain1, domain2))
+
+ domain1 = 'wwwexample.com'
+ domain2 = 'example.com'
+ self.assertFalse(is_same_primary_domain(domain1, domain2))
+
+ domain1 = 'www.example.com'
+ domain2 = 'www2.example.com'
+ self.assertTrue(is_same_primary_domain(domain1, domain2))
+
+ domain1 = 'xxx.www.example.com'
+ domain2 = 'xxx.www2.example.com'
+ self.assertTrue(is_same_primary_domain(domain1, domain2))
diff --git a/webssh/handler.py b/webssh/handler.py
@@ -13,7 +13,8 @@ from tornado.ioloop import IOLoop
from tornado.options import options
from webssh.utils import (
is_valid_ip_address, is_valid_port, is_valid_hostname, to_bytes, to_str,
- to_int, to_ip_address, UnicodeType, is_name_open_to_public, is_ip_hostname
+ to_int, to_ip_address, UnicodeType, is_name_open_to_public, is_ip_hostname,
+ is_same_primary_domain
)
from webssh.worker import Worker, recycle_worker, clients
@@ -27,6 +28,11 @@ try:
except ImportError:
JSONDecodeError = ValueError
+try:
+ from urllib.parse import urlparse
+except ImportError:
+ from urlparse import urlparse
+
DELAY = 3
KEY_MAX_SIZE = 16384
@@ -364,6 +370,24 @@ class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler):
super(WsockHandler, self).initialize(loop)
self.worker_ref = None
+ def check_origin(self, origin):
+ cows = options.cows
+ parsed_origin = urlparse(origin)
+ origin = parsed_origin.netloc
+ origin = origin.lower()
+ logging.debug('origin: {}'.format(origin))
+
+ host = self.request.headers.get('Host')
+ logging.debug('host: {}'.format(host))
+
+ if cows == 0:
+ return origin == host
+ elif cows == 1:
+ return is_same_primary_domain(origin.rsplit(':', 1)[0],
+ host.rsplit(':', 1)[0])
+ else:
+ return True
+
def open(self):
self.src_addr = self.get_client_addr()
logging.info('Connected from {}:{}'.format(*self.src_addr))
diff --git a/webssh/settings.py b/webssh/settings.py
@@ -34,6 +34,10 @@ define('fbidhttp', type=bool, default=True,
help='Forbid public plain http incoming requests')
define('xheaders', type=bool, default=True, help='Support xheaders')
define('xsrf', type=bool, default=True, help='CSRF protection')
+define('cows', type=int, default=0, help='Cross origin websocket, '
+ '0: matches host name and port number'
+ '1: matches primary domain only'
+ '?: matches nothing, allow all cross-origin websockets')
define('wpintvl', type=int, default=0, help='Websocket ping interval')
define('maxconn', type=int, default=20, help='Maximum connections per client')
define('version', type=bool, help='Show version information',
diff --git a/webssh/utils.py b/webssh/utils.py
@@ -96,3 +96,31 @@ def is_name_open_to_public(name):
for ip in get_ips_by_name(name):
if on_public_network_interface(ip):
return True
+
+
+def is_same_primary_domain(domain1, domain2):
+ i = -1
+ dots = 0
+ l1 = len(domain1)
+ l2 = len(domain2)
+ m = 0 - min(l1, l2)
+
+ while i >= m:
+ c1 = domain1[i]
+ c2 = domain2[i]
+
+ if c1 == c2:
+ if c1 == '.':
+ dots += 1
+ if dots == 2:
+ return True
+ else:
+ return False
+
+ i -= 1
+
+ if l1 == l2:
+ return True
+
+ c = domain1[i] if l1 > m else domain2[i]
+ return c == '.'