commit a576a41ea42556c70e838ca014c16940d717110f
parent 90e7ea032756f8aca62b84daee235c21541cc071
Author: Sheng <webmaster0115@gmail.com>
Date: Mon, 8 Oct 2018 20:25:21 +0800
Lookup hostname before connection under reject policy
Diffstat:
2 files changed, 43 insertions(+), 8 deletions(-)
diff --git a/tests/test_app.py b/tests/test_app.py
@@ -428,6 +428,9 @@ class OtherTestBase(AsyncHTTPTestCase):
sshserver_port = 3300
headers = {'Cookie': '_xsrf=yummy'}
debug = False
+ policy = None
+ hostFile = None
+ sysHostFile = None
body = {
'hostname': '127.0.0.1',
'port': '',
@@ -440,9 +443,9 @@ class OtherTestBase(AsyncHTTPTestCase):
self.body.update(port=str(self.sshserver_port))
loop = self.io_loop
options.debug = self.debug
- options.policy = random.choice(['warning', 'autoadd'])
- options.hostFile = ''
- options.sysHostFile = ''
+ options.policy = self.policy if self.policy else random.choice(['warning', 'autoadd']) # noqa
+ options.hostFile = self.hostFile if self.hostFile else ''
+ options.sysHostFile = self.sysHostFile if self.sysHostFile else ''
app = make_app(make_handlers(loop, options), get_app_settings(options))
return app
@@ -516,3 +519,21 @@ class TestAppMiscell(OtherTestBase):
recv = b''.join(lst).decode(data['encoding'])
self.assertEqual(send, recv)
ws.close()
+
+
+class TestAppWithRejectPolicy(OtherTestBase):
+
+ policy = 'reject'
+ hostFile = make_tests_data_path('known_hosts_example')
+
+ @tornado.testing.gen_test
+ def test_app_with_hostname_not_in_hostkeys(self):
+ url = self.get_url('/')
+ client = self.get_http_client()
+ body = urlencode(dict(self.body, username='foo'))
+ response = yield client.fetch(url, method='POST', body=body,
+ headers=self.headers)
+ data = json.loads(to_str(response.body))
+ self.assertIsNone(data['id'])
+ self.assertIsNone(data['encoding'])
+ self.assertEqual('Connection to 127.0.0.1 is not allowed.', data['status']) # noqa
diff --git a/webssh/handler.py b/webssh/handler.py
@@ -70,6 +70,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
self.loop = loop
self.policy = policy
self.host_keys_settings = host_keys_settings
+ self.ssh_client = self.get_ssh_client()
self.filename = None
self.result = dict(id=None, status=None, encoding=None)
@@ -87,6 +88,14 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
self.set_status(200)
self.finish(self.result)
+ def get_ssh_client(self):
+ ssh = paramiko.SSHClient()
+ ssh._system_host_keys = self.host_keys_settings['system_host_keys']
+ ssh._host_keys = self.host_keys_settings['host_keys']
+ ssh._host_keys_filename = self.host_keys_settings['host_keys_filename']
+ ssh.set_missing_host_key_policy(self.policy)
+ return ssh
+
def get_privatekey(self):
name = 'privatekey'
lst = self.request.files.get(name) # multipart form
@@ -143,6 +152,14 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
raise InvalidValueError('Invalid hostname: {}'.format(value))
return value
+ def lookup_hostname(self, hostname):
+ if isinstance(self.policy, paramiko.RejectPolicy):
+ if self.ssh_client._system_host_keys.lookup(hostname) is None:
+ if self.ssh_client._host_keys.lookup(hostname) is None:
+ raise ValueError(
+ 'Connection to {} is not allowed.'.format(hostname)
+ )
+
def get_port(self):
value = self.get_value('port')
try:
@@ -157,6 +174,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
def get_args(self):
hostname = self.get_hostname()
+ self.lookup_hostname(hostname)
port = self.get_port()
username = self.get_value('username')
password = self.get_argument('password', u'')
@@ -182,11 +200,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
return result if result else 'utf-8'
def ssh_connect(self):
- ssh = paramiko.SSHClient()
- ssh._system_host_keys = self.host_keys_settings['system_host_keys']
- ssh._host_keys = self.host_keys_settings['host_keys']
- ssh._host_keys_filename = self.host_keys_settings['host_keys_filename']
- ssh.set_missing_host_key_policy(self.policy)
+ ssh = self.ssh_client
try:
args = self.get_args()