commit 0775c0c3ae198e9bf9093081aebb9778c1b5e9ee
parent 27c587745cfdce5de326003e32a32c4effe886dc
Author: Sheng <webmaster0115@gmail.com>
Date: Mon, 8 Jul 2019 15:37:32 +0800
Refactored handler.py
Diffstat:
M | tests/test_app.py | | | 50 | +++++++++++++++++++++----------------------------- |
M | webssh/handler.py | | | 108 | +++++++++++++++++++++++++++++++++++-------------------------------------------- |
2 files changed, 69 insertions(+), 89 deletions(-)
diff --git a/tests/test_app.py b/tests/test_app.py
@@ -40,12 +40,12 @@ class TestAppBase(AsyncHTTPTestCase):
self.assertEqual(response.code, 400)
self.assertIn(b'Bad Request', response.body)
- def assert_status_in(self, data, status):
+ def assert_status_in(self, status, data):
self.assertIsNone(data['encoding'])
self.assertIsNone(data['id'])
self.assertIn(status, data['status'])
- def assert_status_equal(self, data, status):
+ def assert_status_equal(self, status, data):
self.assertIsNone(data['encoding'])
self.assertIsNone(data['id'])
self.assertEqual(status, data['status'])
@@ -172,7 +172,7 @@ class TestAppBasic(TestAppBase):
def test_app_with_wrong_credentials(self):
response = self.sync_post('/', self.body + 's')
- self.assert_status_in(json.loads(to_str(response.body)), 'Authentication failed.') # noqa
+ self.assert_status_in('Authentication failed.', json.loads(to_str(response.body))) # noqa
def test_app_with_correct_credentials(self):
response = self.sync_post('/', self.body)
@@ -442,10 +442,10 @@ class TestAppBasic(TestAppBase):
self.body_dict.update(username='keyonly', password='foo')
response = yield self.async_post('/', self.body_dict)
self.assertEqual(response.code, 200)
- self.assert_status_in(json.loads(to_str(response.body)), 'Bad authentication type') # noqa
+ self.assert_status_in('Bad authentication type', json.loads(to_str(response.body))) # noqa
@tornado.testing.gen_test
- def test_app_with_user_pass2fa_with_correct_password_and_passcode(self):
+ def test_app_with_user_pass2fa_with_correct_passwords(self):
self.body_dict.update(username='pass2fa', password='password',
totp='passcode')
response = yield self.async_post('/', self.body_dict)
@@ -454,25 +454,7 @@ class TestAppBasic(TestAppBase):
self.assert_status_none(data)
@tornado.testing.gen_test
- def test_app_with_user_pass2fa_with_wrong_password(self):
- self.body_dict.update(username='pass2fa', password='wrongpassword',
- totp='passcode')
- response = yield self.async_post('/', self.body_dict)
- self.assertEqual(response.code, 200)
- data = json.loads(to_str(response.body))
- self.assertIn('Authentication failed', data['status'])
-
- @tornado.testing.gen_test
- def test_app_with_user_pass2fa_with_wrong_passcode(self):
- self.body_dict.update(username='pass2fa', password='password',
- totp='wrongpasscode')
- response = yield self.async_post('/', self.body_dict)
- self.assertEqual(response.code, 200)
- data = json.loads(to_str(response.body))
- self.assertIn('Authentication failed', data['status'])
-
- @tornado.testing.gen_test
- def test_app_with_user_pass2fa_with_wrong_pkey_correct_passwords(self): # noqa
+ def test_app_with_user_pass2fa_with_wrong_pkey_correct_passwords(self):
url = self.get_url('/')
privatekey = read_file(make_tests_data_path('user_rsa_key'))
self.body_dict.update(username='pass2fa', password='password',
@@ -482,7 +464,7 @@ class TestAppBasic(TestAppBase):
self.assert_status_none(data)
@tornado.testing.gen_test
- def test_app_with_user_pkey2fa_with_correct_password_and_passcode(self):
+ def test_app_with_user_pkey2fa_with_correct_passwords(self):
url = self.get_url('/')
privatekey = read_file(make_tests_data_path('user_rsa_key'))
self.body_dict.update(username='pkey2fa', password='password',
@@ -499,7 +481,7 @@ class TestAppBasic(TestAppBase):
privatekey=privatekey, totp='passcode')
response = yield self.async_post(url, self.body_dict)
data = json.loads(to_str(response.body))
- self.assertIn('Authentication failed', data['status'])
+ self.assert_status_in('Authentication failed', data)
@tornado.testing.gen_test
def test_app_with_user_pkey2fa_with_wrong_passcode(self):
@@ -509,7 +491,17 @@ class TestAppBasic(TestAppBase):
privatekey=privatekey, totp='wrongpasscode')
response = yield self.async_post(url, self.body_dict)
data = json.loads(to_str(response.body))
- self.assertIn('Authentication failed', data['status'])
+ self.assert_status_in('Authentication failed', data)
+
+ @tornado.testing.gen_test
+ def test_app_with_user_pkey2fa_with_empty_passcode(self):
+ url = self.get_url('/')
+ privatekey = read_file(make_tests_data_path('user_rsa_key'))
+ self.body_dict.update(username='pkey2fa', password='password',
+ privatekey=privatekey, totp='')
+ response = yield self.async_post(url, self.body_dict)
+ data = json.loads(to_str(response.body))
+ self.assert_status_in('Need a verification code', data)
class OtherTestBase(TestAppBase):
@@ -747,13 +739,13 @@ class TestAppWithCrossOriginOperation(OtherTestBase):
def test_app_with_wrong_event_origin(self):
body = dict(self.body, _origin='localhost')
response = yield self.async_post('/', body)
- self.assert_status_equal(json.loads(to_str(response.body)), 'Cross origin operation is not allowed.') # noqa
+ self.assert_status_equal('Cross origin operation is not allowed.', json.loads(to_str(response.body))) # noqa
@tornado.testing.gen_test
def test_app_with_wrong_header_origin(self):
headers = dict(Origin='localhost')
response = yield self.async_post('/', self.body, headers=headers)
- self.assert_status_equal(json.loads(to_str(response.body)), 'Cross origin operation is not allowed.') # noqa
+ self.assert_status_equal('Cross origin operation is not allowed.', json.loads(to_str(response.body)), ) # noqa
@tornado.testing.gen_test
def test_app_with_correct_event_origin(self):
diff --git a/webssh/handler.py b/webssh/handler.py
@@ -36,80 +36,68 @@ swallow_http_errors = True
redirecting = None
-def make_handler(password, totp):
+class InvalidValueError(Exception):
+ pass
+
- def handler(title, instructions, prompt_list):
+class SSHClient(paramiko.SSHClient):
+
+ def handler(self, title, instructions, prompt_list):
answers = []
for prompt_, _ in prompt_list:
prompt = prompt_.strip().lower()
if prompt.startswith('password'):
- answers.append(password)
+ answers.append(self.password)
elif prompt.startswith('verification'):
- answers.append(totp)
+ answers.append(self.totp)
else:
raise ValueError('Unknown prompt: {}'.format(prompt_))
return answers
- return handler
-
-
-def auth_interactive(transport, username, handler):
- if not handler:
- raise ValueError('Need a verification code for 2fa.')
- transport.auth_interactive(username, handler)
+ def auth_interactive(self, username, handler):
+ if not self.totp:
+ raise ValueError('Need a verification code for 2fa.')
+ self._transport.auth_interactive(username, handler)
+ def _auth(self, username, password, pkey, *args):
+ self.password = password
+ saved_exception = None
+ two_factor = False
+ allowed_types = set()
+ two_factor_types = {'keyboard-interactive', 'password'}
-def auth(self, username, password, pkey, *args):
- handler = None
- saved_exception = None
- two_factor = False
- allowed_types = set()
- two_factor_types = {"keyboard-interactive", "password"}
-
- if self._totp:
- handler = make_handler(password, self._totp)
-
- if pkey is not None:
- logging.info('Trying public key authentication')
- try:
- allowed_types = set(
- self._transport.auth_publickey(username, pkey)
- )
- two_factor = allowed_types & two_factor_types
- if not two_factor:
+ if pkey is not None:
+ logging.info('Trying publickey authentication')
+ try:
+ allowed_types = set(
+ self._transport.auth_publickey(username, pkey)
+ )
+ two_factor = allowed_types & two_factor_types
+ if not two_factor:
+ return
+ except paramiko.SSHException as e:
+ saved_exception = e
+
+ if two_factor:
+ logging.info('Trying publickey 2fa')
+ return self.auth_interactive(username, self.handler)
+
+ if password is not None:
+ logging.info('Trying password authentication')
+ try:
+ self._transport.auth_password(username, password)
return
- except paramiko.SSHException as e:
- saved_exception = e
-
- if two_factor:
- logging.info('Trying publickey 2fa')
- return auth_interactive(self._transport, username, handler)
-
- if password is not None:
- logging.info('Trying password authentication')
- try:
- self._transport.auth_password(username, password)
- return
- except paramiko.SSHException as e:
- saved_exception = e
- allowed_types = set(getattr(e, 'allowed_types', []))
- two_factor = allowed_types & two_factor_types
+ except paramiko.SSHException as e:
+ saved_exception = e
+ allowed_types = set(getattr(e, 'allowed_types', []))
+ two_factor = allowed_types & two_factor_types
- if two_factor:
- logging.info('Trying password 2fa')
- return auth_interactive(self._transport, username, handler)
+ if two_factor:
+ logging.info('Trying password 2fa')
+ return self.auth_interactive(username, self.handler)
- # if we got an auth-failed exception earlier, re-raise it
- if saved_exception is not None:
+ assert saved_exception is not None
raise saved_exception
- raise paramiko.SSHException("No authentication methods available")
-
-
-paramiko.client.SSHClient._auth = auth
-
-
-class InvalidValueError(Exception):
- pass
class PrivateKey(object):
@@ -327,7 +315,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
super(IndexHandler, self).write_error(status_code, **kwargs)
def get_ssh_client(self):
- ssh = paramiko.SSHClient()
+ ssh = SSHClient()
ssh._system_host_keys = self.host_keys_settings['system_host_keys']
ssh._host_keys = self.host_keys_settings['host_keys']
ssh._host_keys_filename = self.host_keys_settings['host_keys_filename']
@@ -392,7 +380,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
else:
pkey = None
- self.ssh_client._totp = totp
+ self.ssh_client.totp = totp
args = (hostname, port, username, password, pkey)
logging.debug(args)