webssh

Web based ssh client https://github.com/huashengdun/webssh webssh.huashengdun.org/
git clone http://git.hanabi.in/repos/webssh.git
Log | Files | Refs | README | LICENSE

commit 5519a1701648690559cdda689648c9c3925a0b35
parent df26d0e677ff38fce4116359f6e47001c6d4e136
Author: Sheng <webmaster0115@gmail.com>
Date:   Sun, 26 Aug 2018 15:13:02 +0800

Return 400 http error for invalid post requests

Diffstat:
Mtests/test_app.py | 101++++++++++++++++++++++++++++++++++++++++++++-----------------------------------
Mtests/test_handler.py | 12+++++++-----
Mwebssh/handler.py | 70++++++++++++++++++++++++++++++++++++----------------------------------
Mwebssh/settings.py | 4++--
4 files changed, 101 insertions(+), 86 deletions(-)

diff --git a/tests/test_app.py b/tests/test_app.py @@ -37,11 +37,13 @@ class TestApp(AsyncHTTPTestCase): def get_app(self): loop = self.io_loop - options.debug = True + options.debug = False options.policy = random.choice(['warning', 'autoadd']) options.hostFile = '' options.sysHostFile = '' - app = make_app(make_handlers(loop, options), get_app_settings(options)) + settings = get_app_settings(options) + settings.update(xsrf_cookies=False) + app = make_app(make_handlers(loop, options), settings) return app @classmethod @@ -63,7 +65,7 @@ class TestApp(AsyncHTTPTestCase): options.update(max_body_size=max_body_size) return options - def test_app_with_invalid_form(self): + def test_app_with_invalid_form_for_missing_argument(self): response = self.fetch('/') self.assertEqual(response.code, 200) @@ -82,44 +84,67 @@ class TestApp(AsyncHTTPTestCase): response = self.fetch('/', method='POST', body=body) self.assertIn(b'Missing argument username', response.body) - body = 'hostname=127.0.0.1&port=7000&username=admin' - response = self.fetch('/', method='POST', body=body) - self.assertEqual(response.code, 400) - self.assertIn(b'Missing argument password', response.body) - body = 'hostname=&port=&username=&password' response = self.fetch('/', method='POST', body=body) - self.assertIn(b'The hostname field is required', response.body) + self.assertEqual(response.code, 400) + self.assertIn(b'Missing argument hostname', response.body) body = 'hostname=127.0.0.1&port=&username=&password' response = self.fetch('/', method='POST', body=body) - self.assertIn(b'The port field is required', response.body) + self.assertEqual(response.code, 400) + self.assertIn(b'Missing argument port', response.body) body = 'hostname=127.0.0.1&port=7000&username=&password' response = self.fetch('/', method='POST', body=body) - self.assertIn(b'The username field is required', response.body) + self.assertEqual(response.code, 400) + self.assertIn(b'Missing argument username', response.body) + def test_app_with_invalid_form_for_invalid_value(self): body = 'hostname=127.0.0&port=22&username=&password' response = self.fetch('/', method='POST', body=body) self.assertIn(b'Invalid hostname', response.body) body = 'hostname=http://www.googe.com&port=22&username=&password' response = self.fetch('/', method='POST', body=body) + self.assertEqual(response.code, 400) self.assertIn(b'Invalid hostname', response.body) body = 'hostname=127.0.0.1&port=port&username=&password' response = self.fetch('/', method='POST', body=body) + self.assertEqual(response.code, 400) self.assertIn(b'Invalid port', response.body) body = 'hostname=127.0.0.1&port=70000&username=&password' response = self.fetch('/', method='POST', body=body) + self.assertEqual(response.code, 400) self.assertIn(b'Invalid port', response.body) + def test_app_with_wrong_hostname_ip(self): + body = 'hostname=127.0.0.1&port=7000&username=admin' + response = self.fetch('/', method='POST', body=body) + self.assertEqual(response.code, 200) + self.assertIn(b'Unable to connect to', response.body) + + def test_app_with_wrong_hostname_domain(self): + body = 'hostname=xxxxxxxxxxxx&port=7000&username=admin' + response = self.fetch('/', method='POST', body=body) + self.assertEqual(response.code, 200) + self.assertIn(b'Unable to connect to', response.body) + + def test_app_with_wrong_port(self): + body = 'hostname=127.0.0.1&port=7000&username=admin' + response = self.fetch('/', method='POST', body=body) + self.assertEqual(response.code, 200) + self.assertIn(b'Unable to connect to', response.body) + def test_app_with_wrong_credentials(self): response = self.fetch('/') self.assertEqual(response.code, 200) response = self.fetch('/', method='POST', body=self.body + 's') - self.assertIn(b'Authentication failed.', response.body) + data = json.loads(to_str(response.body)) + self.assertIsNone(data['encoding']) + self.assertIsNone(data['id']) + self.assertIn('Authentication failed.', data['status']) def test_app_with_correct_credentials(self): response = self.fetch('/') @@ -192,7 +217,7 @@ class TestApp(AsyncHTTPTestCase): self.assertIn('Missing argument id', ws.close_reason) @tornado.testing.gen_test - def test_app_with_correct_credentials_but_epmpty_id(self): + def test_app_with_correct_credentials_but_empty_id(self): url = self.get_url('/') client = self.get_http_client() response = yield client.fetch(url) @@ -209,7 +234,7 @@ class TestApp(AsyncHTTPTestCase): ws = yield tornado.websocket.websocket_connect(ws_url) msg = yield ws.read_message() self.assertIsNone(msg) - self.assertIn('field is required', ws.close_reason) + self.assertIn('Missing argument id', ws.close_reason) @tornado.testing.gen_test def test_app_with_correct_credentials_but_wrong_id(self): @@ -345,12 +370,10 @@ class TestApp(AsyncHTTPTestCase): headers = { 'Content-Type': content_type, 'content-length': str(len(body)) } - response = yield client.fetch(url, method='POST', headers=headers, - body=body) - data = json.loads(to_str(response.body)) - self.assertIsNone(data['id']) - self.assertIsNone(data['encoding']) - self.assertTrue(data['status'].startswith('Invalid private key')) + with self.assertRaises(HTTPError) as ctx: + yield client.fetch(url, method='POST', headers=headers, body=body) + self.assertEqual(ctx.exception.code, 400) + self.assertIn('Invalid private key', ctx.exception.message) @tornado.testing.gen_test def test_app_auth_with_pubkey_exceeds_key_max_size(self): @@ -366,12 +389,10 @@ class TestApp(AsyncHTTPTestCase): headers = { 'Content-Type': content_type, 'content-length': str(len(body)) } - response = yield client.fetch(url, method='POST', headers=headers, - body=body) - data = json.loads(to_str(response.body)) - self.assertIsNone(data['id']) - self.assertIsNone(data['encoding']) - self.assertTrue(data['status'].startswith('Invalid private key')) + with self.assertRaises(HTTPError) as ctx: + yield client.fetch(url, method='POST', headers=headers, body=body) + self.assertEqual(ctx.exception.code, 400) + self.assertIn('Invalid private key', ctx.exception.message) @tornado.testing.gen_test def test_app_auth_with_pubkey_cannot_be_decoded_by_multipart_form(self): @@ -390,22 +411,10 @@ class TestApp(AsyncHTTPTestCase): headers = { 'Content-Type': content_type, 'content-length': str(len(body)) } - with self.assertRaises(HTTPError) as exc: + with self.assertRaises(HTTPError) as ctx: yield client.fetch(url, method='POST', headers=headers, body=body) - self.assertIn('Bad Request (Invalid unicode', exc.msg) - - @tornado.testing.gen_test - def test_app_auth_with_pubkey_cannot_be_decoded_by_urlencoded_form(self): - url = self.get_url('/') - client = self.get_http_client() - response = yield client.fetch(url) - self.assertEqual(response.code, 200) - - privatekey = b'h' * 1024 + b'\xb4\xed\xce\xf3' - body = self.body.encode() + b'&privatekey=' + privatekey - with self.assertRaises(HTTPError) as exc: - yield client.fetch(url, method='POST', body=body) - self.assertIn('Bad Request (Invalid unicode', exc.msg) + self.assertEqual(ctx.exception.code, 400) + self.assertIn('Invalid unicode', ctx.exception.message) @tornado.testing.gen_test def test_app_post_form_with_large_body_size_by_multipart_form(self): @@ -422,9 +431,10 @@ class TestApp(AsyncHTTPTestCase): 'Content-Type': content_type, 'content-length': str(len(body)) } - with self.assertRaises(HTTPError) as exc: + with self.assertRaises(HTTPError) as ctx: yield client.fetch(url, method='POST', headers=headers, body=body) - self.assertIsNone(exc.msg) + self.assertEqual(ctx.exception.code, 400) + self.assertIn('Bad Request', ctx.exception.message) @tornado.testing.gen_test def test_app_post_form_with_large_body_size_by_urlencoded_form(self): @@ -435,6 +445,7 @@ class TestApp(AsyncHTTPTestCase): privatekey = 'h' * (2 * max_body_size) body = self.body + '&privatekey=' + privatekey - with self.assertRaises(HTTPError) as exc: + with self.assertRaises(HTTPError) as ctx: yield client.fetch(url, method='POST', body=body) - self.assertIsNone(exc.msg) + self.assertEqual(ctx.exception.code, 400) + self.assertIn('Bad Request', ctx.exception.message) diff --git a/tests/test_handler.py b/tests/test_handler.py @@ -3,7 +3,9 @@ import paramiko from tornado.httputil import HTTPServerRequest from tests.utils import read_file, make_tests_data_path -from webssh.handler import MixinHandler, IndexHandler, parse_encoding +from webssh.handler import ( + MixinHandler, IndexHandler, parse_encoding, InvalidException +) class TestHandler(unittest.TestCase): @@ -70,7 +72,7 @@ class TestIndexHandler(unittest.TestCase): pkey = IndexHandler.get_specific_pkey(cls, 'x'+key, None) self.assertIsNone(pkey) - with self.assertRaises(ValueError): + with self.assertRaises(paramiko.PasswordRequiredException): pkey = IndexHandler.get_specific_pkey(cls, key, None) def test_get_pkey_obj_with_plain_key(self): @@ -81,7 +83,7 @@ class TestIndexHandler(unittest.TestCase): self.assertIsInstance(pkey, cls) pkey = IndexHandler.get_pkey_obj(key, 'iginored', fname) self.assertIsInstance(pkey, cls) - with self.assertRaises(ValueError) as exc: + with self.assertRaises(InvalidException) as exc: pkey = IndexHandler.get_pkey_obj('x'+key, None, fname) self.assertIn('Invalid private key', str(exc)) @@ -92,9 +94,9 @@ class TestIndexHandler(unittest.TestCase): key = read_file(make_tests_data_path(fname)) pkey = IndexHandler.get_pkey_obj(key, password, fname) self.assertIsInstance(pkey, cls) - with self.assertRaises(ValueError) as exc: + with self.assertRaises(InvalidException) as exc: pkey = IndexHandler.get_pkey_obj(key, 'wrongpass', fname) self.assertIn('Wrong password', str(exc)) - with self.assertRaises(ValueError) as exc: + with self.assertRaises(InvalidException) as exc: pkey = IndexHandler.get_pkey_obj('x'+key, password, fname) self.assertIn('Invalid private key', str(exc)) diff --git a/webssh/handler.py b/webssh/handler.py @@ -38,20 +38,25 @@ def parse_encoding(data): return s.strip('"').split('.')[-1] +class InvalidException(Exception): + pass + + class MixinHandler(object): - def get_value(self, name): - is_required = name in self.arguments_required + def write_error(self, status_code, **kwargs): + exc_info = kwargs.get('exc_info') + if exc_info and len(exc_info) > 1: + info = str(exc_info[1]) + if info: + self._reason = info.split(':', 1)[-1].strip() + super(MixinHandler, self).write_error(status_code, **kwargs) - try: - value = self.get_argument(name) - except tornado.web.MissingArgumentError: - if is_required: - raise - else: - if not value and is_required: - raise ValueError('The {} field is required.'.format(name)) - return value + def get_value(self, name): + value = self.get_argument(name) + if not value: + raise tornado.web.MissingArgumentError(name) + return value def get_real_client_addr(self): ip = self.request.headers.get('X-Real-Ip') @@ -75,8 +80,6 @@ class MixinHandler(object): class IndexHandler(MixinHandler, tornado.web.RequestHandler): - arguments_required = {'hostname', 'port', 'username', 'password'} - def initialize(self, loop, policy, host_keys_settings): self.loop = loop self.policy = policy @@ -86,14 +89,14 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler): def get_privatekey(self): lst = self.request.files.get('privatekey') # multipart form if not lst: - return self.get_value('privatekey') # urlencoded form + return self.get_argument('privatekey', u'') # urlencoded form else: self.filename = lst[0]['filename'] data = lst[0]['body'] if len(data) > KEY_MAX_SIZE: - raise ValueError( - 'Invalid private key: {}'.format(self.filename) - ) + raise InvalidException( + 'Invalid private key: {}'.format(self.filename) + ) return self.decode_argument(data, name=self.filename) @classmethod @@ -103,7 +106,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler): pkey = pkeycls.from_private_key(io.StringIO(privatekey), password=password) except paramiko.PasswordRequiredException: - raise ValueError('Need password to decrypt the private key.') + raise except paramiko.SSHException: pass else: @@ -125,7 +128,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler): error = ( 'Wrong password {!r} for decrypting the private key.' ) .format(password) - raise ValueError(error) + raise InvalidException(error) return pkey @@ -133,7 +136,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler): value = self.get_value('hostname') if not (is_valid_hostname(value) | is_valid_ipv4_address(value) | is_valid_ipv6_address(value)): - raise ValueError('Invalid hostname: {}'.format(value)) + raise InvalidException('Invalid hostname: {}'.format(value)) return value def get_port(self): @@ -146,19 +149,13 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler): if is_valid_port(port): return port - raise ValueError('Invalid port: {}'.format(value)) - - def get_password(self): - try: - return self.get_value('password') - except ValueError: - return '' + raise InvalidException('Invalid port: {}'.format(value)) def get_args(self): hostname = self.get_hostname() port = self.get_port() username = self.get_value('username') - password = self.get_password() + password = self.get_argument('password', u'') privatekey = self.get_privatekey() pkey = self.get_pkey_obj(privatekey, password, self.filename) \ if privatekey else None @@ -188,7 +185,11 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler): ssh._host_keys_filename = self.host_keys_settings['host_keys_filename'] ssh.set_missing_host_key_policy(self.policy) - args = self.get_args() + try: + args = self.get_args() + except InvalidException as exc: + raise tornado.web.HTTPError(400, str(exc)) + dst_addr = (args[0], args[1]) logging.info('Connecting to {}:{}'.format(*dst_addr)) @@ -197,6 +198,8 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler): except socket.error: raise ValueError('Unable to connect to {}:{}'.format(*dst_addr)) except paramiko.BadAuthenticationType: + raise ValueError('Bad authentication type.') + except paramiko.AuthenticationException: raise ValueError('Authentication failed.') except paramiko.BadHostKeyException: raise ValueError('Bad host key.') @@ -233,7 +236,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler): try: worker = yield future - except ValueError as exc: + except (ValueError, paramiko.SSHException) as exc: status = str(exc) else: worker_id = worker.id @@ -246,8 +249,6 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler): class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler): - arguments_required = {'id'} - def initialize(self, loop): self.loop = loop self.worker_ref = None @@ -260,8 +261,9 @@ class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler): logging.info('Connected from {}:{}'.format(*self.src_addr)) try: worker_id = self.get_value('id') - except (tornado.web.MissingArgumentError, ValueError) as exc: - self.close(reason=str(exc)) + except tornado.web.MissingArgumentError as exc: + self.close(reason=exc.log_message) + raise else: worker = workers.get(worker_id) if worker and worker.src_addr[0] == self.src_addr[0]: diff --git a/webssh/settings.py b/webssh/settings.py @@ -36,8 +36,8 @@ def get_app_settings(options): template_path=os.path.join(base_dir, 'webssh', 'templates'), static_path=os.path.join(base_dir, 'webssh', 'static'), websocket_ping_interval=options.wpIntvl, - xsrf_cookies=(not options.debug), - debug=options.debug + debug=options.debug, + xsrf_cookies=True ) return settings