commit 48acf09f21479e665ad97f71ec319069d178dc64
parent f3d9d297bbb81db4d916737e1809c9f8c97131f6
Author: Sheng <webmaster0115@gmail.com>
Date: Sat, 25 Aug 2018 17:35:31 +0800
Move method get_value to MixinHandler
Diffstat:
M | tests/test_app.py | | | 85 | +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++---- |
M | webssh/handler.py | | | 55 | ++++++++++++++++++++++++++++++++++++------------------- |
2 files changed, 117 insertions(+), 23 deletions(-)
diff --git a/tests/test_app.py b/tests/test_app.py
@@ -66,6 +66,23 @@ class TestApp(AsyncHTTPTestCase):
def test_app_with_invalid_form(self):
response = self.fetch('/')
self.assertEqual(response.code, 200)
+
+ body = 'port=7000&username=admin&password'
+ response = self.fetch('/', method='POST', body=body)
+ self.assertIn(b'Missing argument hostname', response.body)
+
+ body = 'hostname=127.0.0.1&username=admin&password'
+ response = self.fetch('/', method='POST', body=body)
+ self.assertIn(b'Missing argument port', response.body)
+
+ body = 'hostname=127.0.0.1&port=7000&password'
+ response = self.fetch('/', method='POST', body=body)
+ self.assertIn(b'Missing argument username', response.body)
+
+ body = 'hostname=127.0.0.1&port=7000&username=admin'
+ response = self.fetch('/', method='POST', body=body)
+ self.assertIn(b'Missing argument password', response.body)
+
body = 'hostname=&port=&username=&password'
response = self.fetch('/', method='POST', body=body)
self.assertIn(b'The hostname field is required', response.body)
@@ -74,6 +91,10 @@ class TestApp(AsyncHTTPTestCase):
response = self.fetch('/', method='POST', body=body)
self.assertIn(b'The port field is required', response.body)
+ body = 'hostname=127.0.0.1&port=7000&username=&password'
+ response = self.fetch('/', method='POST', body=body)
+ self.assertIn(b'The username field is required', response.body)
+
body = 'hostname=127.0.0&port=22&username=&password'
response = self.fetch('/', method='POST', body=body)
self.assertIn(b'Invalid hostname', response.body)
@@ -90,10 +111,6 @@ class TestApp(AsyncHTTPTestCase):
response = self.fetch('/', method='POST', body=body)
self.assertIn(b'Invalid port', response.body)
- body = 'hostname=127.0.0.1&port=7000&username=&password'
- response = self.fetch('/', method='POST', body=body)
- self.assertIn(b'The username field is required', response.body) # noqa
-
def test_app_with_wrong_credentials(self):
response = self.fetch('/')
self.assertEqual(response.code, 200)
@@ -151,6 +168,66 @@ class TestApp(AsyncHTTPTestCase):
ws.close()
@tornado.testing.gen_test
+ def test_app_with_correct_credentials_but_without_id_argument(self):
+ url = self.get_url('/')
+ client = self.get_http_client()
+ response = yield client.fetch(url)
+ self.assertEqual(response.code, 200)
+
+ response = yield client.fetch(url, method='POST', body=self.body)
+ data = json.loads(to_str(response.body))
+ self.assertIsNone(data['status'])
+ self.assertIsNotNone(data['id'])
+ self.assertIsNotNone(data['encoding'])
+
+ url = url.replace('http', 'ws')
+ ws_url = url + 'ws'
+ ws = yield tornado.websocket.websocket_connect(ws_url)
+ msg = yield ws.read_message()
+ self.assertIsNone(msg)
+ self.assertIn('Missing argument id', ws.close_reason)
+
+ @tornado.testing.gen_test
+ def test_app_with_correct_credentials_but_epmpty_id(self):
+ url = self.get_url('/')
+ client = self.get_http_client()
+ response = yield client.fetch(url)
+ self.assertEqual(response.code, 200)
+
+ response = yield client.fetch(url, method='POST', body=self.body)
+ data = json.loads(to_str(response.body))
+ self.assertIsNone(data['status'])
+ self.assertIsNotNone(data['id'])
+ self.assertIsNotNone(data['encoding'])
+
+ url = url.replace('http', 'ws')
+ ws_url = url + 'ws?id='
+ ws = yield tornado.websocket.websocket_connect(ws_url)
+ msg = yield ws.read_message()
+ self.assertIsNone(msg)
+ self.assertIn('field is required', ws.close_reason)
+
+ @tornado.testing.gen_test
+ def test_app_with_correct_credentials_but_wrong_id(self):
+ url = self.get_url('/')
+ client = self.get_http_client()
+ response = yield client.fetch(url)
+ self.assertEqual(response.code, 200)
+
+ response = yield client.fetch(url, method='POST', body=self.body)
+ data = json.loads(to_str(response.body))
+ self.assertIsNone(data['status'])
+ self.assertIsNotNone(data['id'])
+ self.assertIsNotNone(data['encoding'])
+
+ url = url.replace('http', 'ws')
+ ws_url = url + 'ws?id=1' + data['id']
+ ws = yield tornado.websocket.websocket_connect(ws_url)
+ msg = yield ws.read_message()
+ self.assertIsNone(msg)
+ self.assertIn('Websocket authentication failed', ws.close_reason)
+
+ @tornado.testing.gen_test
def test_app_with_correct_credentials_user_bar(self):
url = self.get_url('/')
client = self.get_http_client()
diff --git a/webssh/handler.py b/webssh/handler.py
@@ -40,6 +40,22 @@ def parse_encoding(data):
class MixinHandler(object):
+ arguments_required = {} # agruments must be deliverd
+ empty_allowed = {} # emtpy value alllowed
+
+ def get_value(self, name):
+ is_required = name in self.arguments_required
+
+ try:
+ value = self.get_argument(name)
+ except tornado.web.MissingArgumentError:
+ if is_required:
+ raise
+ else:
+ if not value and is_required and name not in self.empty_allowed:
+ raise ValueError('The {} field is required.'.format(name))
+ return value
+
def get_real_client_addr(self):
ip = self.request.headers.get('X-Real-Ip')
port = self.request.headers.get('X-Real-Port')
@@ -62,6 +78,9 @@ class MixinHandler(object):
class IndexHandler(MixinHandler, tornado.web.RequestHandler):
+ arguments_required = {'hostname', 'port', 'username', 'password'}
+ empty_allowed = {'password'}
+
def initialize(self, loop, policy, host_keys_settings):
self.loop = loop
self.policy = policy
@@ -71,10 +90,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
def get_privatekey(self):
lst = self.request.files.get('privatekey') # multipart form
if not lst:
- try:
- return self.get_argument('privatekey') # urlencoded form
- except tornado.web.MissingArgumentError:
- pass
+ return self.get_value('privatekey') # urlencoded form
else:
self.filename = lst[0]['filename']
data = lst[0]['body']
@@ -136,17 +152,11 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
raise ValueError('Invalid port: {}'.format(value))
- def get_value(self, name):
- value = self.get_argument(name)
- if not value:
- raise ValueError('The {} field is required.'.format(name))
- return value
-
def get_args(self):
hostname = self.get_hostname()
port = self.get_port()
username = self.get_value('username')
- password = self.get_argument('password')
+ password = self.get_value('password')
privatekey = self.get_privatekey()
pkey = self.get_pkey_obj(privatekey, password, self.filename) \
if privatekey else None
@@ -234,6 +244,8 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler):
+ arguments_required = {'id'}
+
def initialize(self, loop):
self.loop = loop
self.worker_ref = None
@@ -244,15 +256,20 @@ class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler):
def open(self):
self.src_addr = self.get_client_addr()
logging.info('Connected from {}:{}'.format(*self.src_addr))
- worker = workers.get(self.get_argument('id'))
- if worker and worker.src_addr[0] == self.src_addr[0]:
- workers.pop(worker.id)
- self.set_nodelay(True)
- worker.set_handler(self)
- self.worker_ref = weakref.ref(worker)
- self.loop.add_handler(worker.fd, worker, IOLoop.READ)
+ try:
+ worker_id = self.get_value('id')
+ except (tornado.web.MissingArgumentError, ValueError) as exc:
+ self.close(reason=str(exc))
else:
- self.close(reason='Websocket authentication failed.')
+ worker = workers.get(worker_id)
+ if worker and worker.src_addr[0] == self.src_addr[0]:
+ workers.pop(worker.id)
+ self.set_nodelay(True)
+ worker.set_handler(self)
+ self.worker_ref = weakref.ref(worker)
+ self.loop.add_handler(worker.fd, worker, IOLoop.READ)
+ else:
+ self.close(reason='Websocket authentication failed.')
def on_message(self, message):
logging.debug('{!r} from {}:{}'.format(message, *self.src_addr))