commit 2b8b978ca20169e1fd75198b33ab83ee21878a16
parent 6d62642c7f111ea7d0aea306831edcc16876f135
Author: Sheng <webmaster0115@gmail.com>
Date: Thu, 27 Jun 2019 12:52:19 +0800
Added PrivateKey class
Diffstat:
3 files changed, 93 insertions(+), 102 deletions(-)
diff --git a/tests/test_app.py b/tests/test_app.py
@@ -358,7 +358,7 @@ class TestAppBasic(TestAppBase):
if swallow_http_errors:
response = yield self.async_post(url, body, headers=headers)
- self.assertIn(b'Invalid private key', response.body)
+ self.assertIn(b'Invalid key', response.body)
else:
with self.assertRaises(HTTPError) as ctx:
yield self.async_post(url, body, headers=headers)
@@ -367,7 +367,7 @@ class TestAppBasic(TestAppBase):
@tornado.testing.gen_test
def test_app_auth_with_pubkey_exceeds_key_max_size(self):
url = self.get_url('/')
- privatekey = 'h' * (handler.KEY_MAX_SIZE * 2)
+ privatekey = 'h' * (handler.PrivateKey.max_length + 1)
files = [('privatekey', 'user_rsa_key', privatekey)]
content_type, body = encode_multipart_formdata(self.body_dict.items(),
files)
@@ -376,7 +376,7 @@ class TestAppBasic(TestAppBase):
}
if swallow_http_errors:
response = yield self.async_post(url, body, headers=headers)
- self.assertIn(b'Invalid private key', response.body)
+ self.assertIn(b'Invalid key', response.body)
else:
with self.assertRaises(HTTPError) as ctx:
yield self.async_post(url, body, headers=headers)
diff --git a/tests/test_handler.py b/tests/test_handler.py
@@ -6,7 +6,7 @@ from tornado.options import options
from tests.utils import read_file, make_tests_data_path
from webssh import handler
from webssh.handler import (
- MixinHandler, IndexHandler, WsockHandler, InvalidValueError
+ MixinHandler, WsockHandler, PrivateKey, InvalidValueError
)
try:
@@ -142,73 +142,59 @@ class TestMixinHandler(unittest.TestCase):
(x_real_ip, x_real_port))
-class TestIndexHandler(unittest.TestCase):
+class TestPrivateKey(unittest.TestCase):
- def test_get_specific_pkey_with_plain_key(self):
- fname = 'test_rsa.key'
- cls = paramiko.RSAKey
+ def get_pk_obj(self, fname, password=None):
key = read_file(make_tests_data_path(fname))
+ return PrivateKey(key, password=password, filename=fname)
- pkey = IndexHandler.get_specific_pkey(cls, key, None)
- self.assertIsInstance(pkey, cls)
+ def _test_with_encrypted_key(self, fname, password, klass):
+ pk = self.get_pk_obj(fname, password='')
+ with self.assertRaises(InvalidValueError) as ctx:
+ pk.get_pkey_obj()
+ self.assertIn('Need a password', str(ctx.exception))
- pkey = IndexHandler.get_specific_pkey(cls, key, 'iginored')
- self.assertIsInstance(pkey, cls)
+ pk = self.get_pk_obj(fname, password='wrongpass')
+ with self.assertRaises(InvalidValueError) as ctx:
+ pk.get_pkey_obj()
+ self.assertIn('wrong password', str(ctx.exception))
- pkey = IndexHandler.get_specific_pkey(cls, 'x'+key, None)
- self.assertIsNone(pkey)
+ pk = self.get_pk_obj(fname, password=password)
+ self.assertIsInstance(pk.get_pkey_obj(), klass)
- def test_get_specific_pkey_with_encrypted_key(self):
- fname = 'test_rsa_password.key'
- cls = paramiko.RSAKey
- password = 'television'
+ def test_class_with_invalid_key_length(self):
+ key = u'a' * (PrivateKey.max_length + 1)
- key = read_file(make_tests_data_path(fname))
- pkey = IndexHandler.get_specific_pkey(cls, key, password)
- self.assertIsInstance(pkey, cls)
+ with self.assertRaises(InvalidValueError) as ctx:
+ PrivateKey(key)
+ self.assertIn('Invalid key length', str(ctx.exception))
- pkey = IndexHandler.get_specific_pkey(cls, 'x'+key, None)
- self.assertIsNone(pkey)
+ def test_get_pkey_obj_with_invalid_key(self):
+ key = u'a b c'
+ fname = 'abc'
+ pk = PrivateKey(key, filename=fname)
with self.assertRaises(InvalidValueError) as ctx:
- pkey = IndexHandler.get_specific_pkey(cls, key, None)
- self.assertIn('Need a password', str(ctx.exception))
+ pk.get_pkey_obj()
+ self.assertIn('Invalid key {}'.format(fname), str(ctx.exception))
- def test_get_pkey_obj_with_plain_key(self):
- fname = 'test_ed25519.key'
- cls = paramiko.Ed25519Key
- key = read_file(make_tests_data_path(fname))
+ def test_get_pkey_obj_with_plain_rsa_key(self):
+ pk = self.get_pk_obj('test_rsa.key')
+ self.assertIsInstance(pk.get_pkey_obj(), paramiko.RSAKey)
- pkey = IndexHandler.get_pkey_obj(key, None, fname)
- self.assertIsInstance(pkey, cls)
+ def test_get_pkey_obj_with_plain_ed25519_key(self):
+ pk = self.get_pk_obj('test_ed25519.key')
+ self.assertIsInstance(pk.get_pkey_obj(), paramiko.Ed25519Key)
- pkey = IndexHandler.get_pkey_obj(key, 'iginored', fname)
- self.assertIsInstance(pkey, cls)
-
- with self.assertRaises(InvalidValueError) as ctx:
- pkey = IndexHandler.get_pkey_obj('x'+key, None, fname)
- self.assertIn('Invalid private key', str(ctx.exception))
+ def test_get_pkey_obj_with_encrypted_rsa_key(self):
+ fname = 'test_rsa_password.key'
+ password = 'television'
+ self._test_with_encrypted_key(fname, password, paramiko.RSAKey)
- def test_get_pkey_obj_with_encrypted_key(self):
+ def test_get_pkey_obj_with_encrypted_ed25519_key(self):
fname = 'test_ed25519_password.key'
password = 'abc123'
- cls = paramiko.Ed25519Key
- key = read_file(make_tests_data_path(fname))
-
- pkey = IndexHandler.get_pkey_obj(key, password, fname)
- self.assertIsInstance(pkey, cls)
-
- with self.assertRaises(InvalidValueError) as ctx:
- pkey = IndexHandler.get_pkey_obj(key, 'wrongpass', fname)
- self.assertIn('Wrong password', str(ctx.exception))
-
- with self.assertRaises(InvalidValueError) as ctx:
- pkey = IndexHandler.get_pkey_obj('x'+key, '', fname)
- self.assertIn('Invalid private key', str(ctx.exception))
-
- with self.assertRaises(InvalidValueError) as ctx:
- pkey = IndexHandler.get_specific_pkey(cls, key, None)
- self.assertIn('Need a password', str(ctx.exception))
+ self._test_with_encrypted_key(fname, password, paramiko.Ed25519Key)
class TestWsockHandler(unittest.TestCase):
diff --git a/webssh/handler.py b/webssh/handler.py
@@ -30,7 +30,6 @@ except ImportError:
DELAY = 3
-KEY_MAX_SIZE = 16384
DEFAULT_PORT = 22
swallow_http_errors = True
@@ -41,6 +40,53 @@ class InvalidValueError(Exception):
pass
+class PrivateKey(object):
+
+ max_length = 16384 # rough number
+
+ tag_to_name = {
+ 'RSA': 'RSA',
+ 'DSA': 'DSS',
+ 'EC': 'ECDSA',
+ 'OPENSSH': 'Ed25519'
+ }
+
+ def __init__(self, privatekey, password=None, filename=''):
+ self.privatekey = privatekey.strip()
+ self.filename = filename
+ self.password = password
+ self.check_length()
+
+ def check_length(self):
+ if len(self.privatekey) > self.max_length:
+ raise InvalidValueError('Invalid key length.')
+
+ def get_name(self):
+ lst = self.privatekey.split(' ', 2)
+ if len(lst) > 1:
+ return self.tag_to_name.get(lst[1])
+
+ def get_pkey_obj(self):
+ name = self.get_name()
+ if not name:
+ raise InvalidValueError('Invalid key {}.'.format(self.filename))
+
+ logging.info('Parsing {} key'.format(name))
+ pkeycls = getattr(paramiko, name+'Key')
+ password = to_bytes(self.password) if self.password else None
+ try:
+ return pkeycls.from_private_key(io.StringIO(self.privatekey),
+ password=password)
+ except paramiko.PasswordRequiredException:
+ raise InvalidValueError('Need a password to decrypt the key.')
+ except paramiko.SSHException as exc:
+ logging.error(str(exc))
+ raise InvalidValueError(
+ 'Invalid key or wrong password "{}" for decrypting it.'
+ .format(self.password)
+ )
+
+
class MixinHandler(object):
custom_headers = {
@@ -176,7 +222,6 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
self.policy = policy
self.host_keys_settings = host_keys_settings
self.ssh_client = self.get_ssh_client()
- self.privatekey_filename = None
self.debug = self.settings.get('debug', False)
self.result = dict(id=None, status=None, encoding=None)
@@ -206,53 +251,15 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
lst = self.request.files.get(name)
if lst:
# multipart form
- self.privatekey_filename = lst[0]['filename']
+ filename = lst[0]['filename']
data = lst[0]['body']
value = self.decode_argument(data, name=name).strip()
else:
# urlencoded form
value = self.get_argument(name, u'')
+ filename = ''
- if len(value) > KEY_MAX_SIZE:
- raise InvalidValueError(
- 'Invalid private key: {}'.format(self.privatekey_filename)
- )
- return value
-
- @classmethod
- def get_specific_pkey(cls, pkeycls, privatekey, password):
- logging.info('Trying {}'.format(pkeycls.__name__))
- try:
- pkey = pkeycls.from_private_key(io.StringIO(privatekey),
- password=password)
- except paramiko.PasswordRequiredException:
- raise InvalidValueError(
- 'Need a password to decrypt the private key.'
- )
- except paramiko.SSHException:
- pass
- else:
- return pkey
-
- @classmethod
- def get_pkey_obj(cls, privatekey, password, filename):
- bpass = to_bytes(password) if password else None
-
- pkey = cls.get_specific_pkey(paramiko.RSAKey, privatekey, bpass)\
- or cls.get_specific_pkey(paramiko.DSSKey, privatekey, bpass)\
- or cls.get_specific_pkey(paramiko.ECDSAKey, privatekey, bpass)\
- or cls.get_specific_pkey(paramiko.Ed25519Key, privatekey, bpass)
-
- if not pkey:
- if not password:
- error = 'Invalid private key: {}'.format(filename)
- else:
- error = (
- 'Wrong password {!r} for decrypting the private key.'
- ) .format(password)
- raise InvalidValueError(error)
-
- return pkey
+ return value, filename
def get_hostname(self):
value = self.get_value('hostname')
@@ -287,11 +294,9 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
self.lookup_hostname(hostname, port)
username = self.get_value('username')
password = self.get_argument('password', u'')
- privatekey = self.get_privatekey()
+ privatekey, filename = self.get_privatekey()
if privatekey:
- pkey = self.get_pkey_obj(
- privatekey, password, self.privatekey_filename
- )
+ pkey = PrivateKey(privatekey, password, filename).get_pkey_obj()
password = None
else:
pkey = None