webssh

Web based ssh client https://github.com/huashengdun/webssh webssh.huashengdun.org/
git clone http://git.hanabi.in/repos/webssh.git
Log | Files | Refs | README | LICENSE

commit 48acf09f21479e665ad97f71ec319069d178dc64
parent f3d9d297bbb81db4d916737e1809c9f8c97131f6
Author: Sheng <webmaster0115@gmail.com>
Date:   Sat, 25 Aug 2018 17:35:31 +0800

Move method get_value to MixinHandler

Diffstat:
Mtests/test_app.py | 85+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++----
Mwebssh/handler.py | 55++++++++++++++++++++++++++++++++++++-------------------
2 files changed, 117 insertions(+), 23 deletions(-)

diff --git a/tests/test_app.py b/tests/test_app.py @@ -66,6 +66,23 @@ class TestApp(AsyncHTTPTestCase): def test_app_with_invalid_form(self): response = self.fetch('/') self.assertEqual(response.code, 200) + + body = 'port=7000&username=admin&password' + response = self.fetch('/', method='POST', body=body) + self.assertIn(b'Missing argument hostname', response.body) + + body = 'hostname=127.0.0.1&username=admin&password' + response = self.fetch('/', method='POST', body=body) + self.assertIn(b'Missing argument port', response.body) + + body = 'hostname=127.0.0.1&port=7000&password' + response = self.fetch('/', method='POST', body=body) + self.assertIn(b'Missing argument username', response.body) + + body = 'hostname=127.0.0.1&port=7000&username=admin' + response = self.fetch('/', method='POST', body=body) + self.assertIn(b'Missing argument password', response.body) + body = 'hostname=&port=&username=&password' response = self.fetch('/', method='POST', body=body) self.assertIn(b'The hostname field is required', response.body) @@ -74,6 +91,10 @@ class TestApp(AsyncHTTPTestCase): response = self.fetch('/', method='POST', body=body) self.assertIn(b'The port field is required', response.body) + body = 'hostname=127.0.0.1&port=7000&username=&password' + response = self.fetch('/', method='POST', body=body) + self.assertIn(b'The username field is required', response.body) + body = 'hostname=127.0.0&port=22&username=&password' response = self.fetch('/', method='POST', body=body) self.assertIn(b'Invalid hostname', response.body) @@ -90,10 +111,6 @@ class TestApp(AsyncHTTPTestCase): response = self.fetch('/', method='POST', body=body) self.assertIn(b'Invalid port', response.body) - body = 'hostname=127.0.0.1&port=7000&username=&password' - response = self.fetch('/', method='POST', body=body) - self.assertIn(b'The username field is required', response.body) # noqa - def test_app_with_wrong_credentials(self): response = self.fetch('/') self.assertEqual(response.code, 200) @@ -151,6 +168,66 @@ class TestApp(AsyncHTTPTestCase): ws.close() @tornado.testing.gen_test + def test_app_with_correct_credentials_but_without_id_argument(self): + url = self.get_url('/') + client = self.get_http_client() + response = yield client.fetch(url) + self.assertEqual(response.code, 200) + + response = yield client.fetch(url, method='POST', body=self.body) + data = json.loads(to_str(response.body)) + self.assertIsNone(data['status']) + self.assertIsNotNone(data['id']) + self.assertIsNotNone(data['encoding']) + + url = url.replace('http', 'ws') + ws_url = url + 'ws' + ws = yield tornado.websocket.websocket_connect(ws_url) + msg = yield ws.read_message() + self.assertIsNone(msg) + self.assertIn('Missing argument id', ws.close_reason) + + @tornado.testing.gen_test + def test_app_with_correct_credentials_but_epmpty_id(self): + url = self.get_url('/') + client = self.get_http_client() + response = yield client.fetch(url) + self.assertEqual(response.code, 200) + + response = yield client.fetch(url, method='POST', body=self.body) + data = json.loads(to_str(response.body)) + self.assertIsNone(data['status']) + self.assertIsNotNone(data['id']) + self.assertIsNotNone(data['encoding']) + + url = url.replace('http', 'ws') + ws_url = url + 'ws?id=' + ws = yield tornado.websocket.websocket_connect(ws_url) + msg = yield ws.read_message() + self.assertIsNone(msg) + self.assertIn('field is required', ws.close_reason) + + @tornado.testing.gen_test + def test_app_with_correct_credentials_but_wrong_id(self): + url = self.get_url('/') + client = self.get_http_client() + response = yield client.fetch(url) + self.assertEqual(response.code, 200) + + response = yield client.fetch(url, method='POST', body=self.body) + data = json.loads(to_str(response.body)) + self.assertIsNone(data['status']) + self.assertIsNotNone(data['id']) + self.assertIsNotNone(data['encoding']) + + url = url.replace('http', 'ws') + ws_url = url + 'ws?id=1' + data['id'] + ws = yield tornado.websocket.websocket_connect(ws_url) + msg = yield ws.read_message() + self.assertIsNone(msg) + self.assertIn('Websocket authentication failed', ws.close_reason) + + @tornado.testing.gen_test def test_app_with_correct_credentials_user_bar(self): url = self.get_url('/') client = self.get_http_client() diff --git a/webssh/handler.py b/webssh/handler.py @@ -40,6 +40,22 @@ def parse_encoding(data): class MixinHandler(object): + arguments_required = {} # agruments must be deliverd + empty_allowed = {} # emtpy value alllowed + + def get_value(self, name): + is_required = name in self.arguments_required + + try: + value = self.get_argument(name) + except tornado.web.MissingArgumentError: + if is_required: + raise + else: + if not value and is_required and name not in self.empty_allowed: + raise ValueError('The {} field is required.'.format(name)) + return value + def get_real_client_addr(self): ip = self.request.headers.get('X-Real-Ip') port = self.request.headers.get('X-Real-Port') @@ -62,6 +78,9 @@ class MixinHandler(object): class IndexHandler(MixinHandler, tornado.web.RequestHandler): + arguments_required = {'hostname', 'port', 'username', 'password'} + empty_allowed = {'password'} + def initialize(self, loop, policy, host_keys_settings): self.loop = loop self.policy = policy @@ -71,10 +90,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler): def get_privatekey(self): lst = self.request.files.get('privatekey') # multipart form if not lst: - try: - return self.get_argument('privatekey') # urlencoded form - except tornado.web.MissingArgumentError: - pass + return self.get_value('privatekey') # urlencoded form else: self.filename = lst[0]['filename'] data = lst[0]['body'] @@ -136,17 +152,11 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler): raise ValueError('Invalid port: {}'.format(value)) - def get_value(self, name): - value = self.get_argument(name) - if not value: - raise ValueError('The {} field is required.'.format(name)) - return value - def get_args(self): hostname = self.get_hostname() port = self.get_port() username = self.get_value('username') - password = self.get_argument('password') + password = self.get_value('password') privatekey = self.get_privatekey() pkey = self.get_pkey_obj(privatekey, password, self.filename) \ if privatekey else None @@ -234,6 +244,8 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler): class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler): + arguments_required = {'id'} + def initialize(self, loop): self.loop = loop self.worker_ref = None @@ -244,15 +256,20 @@ class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler): def open(self): self.src_addr = self.get_client_addr() logging.info('Connected from {}:{}'.format(*self.src_addr)) - worker = workers.get(self.get_argument('id')) - if worker and worker.src_addr[0] == self.src_addr[0]: - workers.pop(worker.id) - self.set_nodelay(True) - worker.set_handler(self) - self.worker_ref = weakref.ref(worker) - self.loop.add_handler(worker.fd, worker, IOLoop.READ) + try: + worker_id = self.get_value('id') + except (tornado.web.MissingArgumentError, ValueError) as exc: + self.close(reason=str(exc)) else: - self.close(reason='Websocket authentication failed.') + worker = workers.get(worker_id) + if worker and worker.src_addr[0] == self.src_addr[0]: + workers.pop(worker.id) + self.set_nodelay(True) + worker.set_handler(self) + self.worker_ref = weakref.ref(worker) + self.loop.add_handler(worker.fd, worker, IOLoop.READ) + else: + self.close(reason='Websocket authentication failed.') def on_message(self, message): logging.debug('{!r} from {}:{}'.format(message, *self.src_addr))