commit c90f1e5a5799e1a7a8b95f0412e1072622f19122
parent 331eaa2ba736efa83e1cadc032baafe0be6b690b
Author: Sheng <webmaster0115@gmail.com>
Date: Thu, 26 Apr 2018 21:51:01 +0800
Added test_app.py
Diffstat:
4 files changed, 252 insertions(+), 7 deletions(-)
diff --git a/tests/sshserver.py b/tests/sshserver.py
@@ -0,0 +1,123 @@
+#!/usr/bin/env python
+
+# Copyright (C) 2003-2007 Robey Pointer <robeypointer@gmail.com>
+#
+# This file is part of paramiko.
+#
+# Paramiko is free software; you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation; either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
+# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
+
+from binascii import hexlify
+import socket
+# import sys
+import threading
+# import traceback
+
+import paramiko
+from paramiko.py3compat import u, decodebytes
+
+
+# setup logging
+paramiko.util.log_to_file('sshserver.log')
+
+host_key = paramiko.RSAKey(filename='tests/test_rsa.key')
+# host_key = paramiko.DSSKey(filename='test_dss.key')
+
+print('Read key: ' + u(hexlify(host_key.get_fingerprint())))
+
+
+class Server (paramiko.ServerInterface):
+ # 'data' is the output of base64.b64encode(key)
+ # (using the "user_rsa_key" files)
+ data = (b'AAAAB3NzaC1yc2EAAAABIwAAAIEAyO4it3fHlmGZWJaGrfeHOVY7RWO3P9M7hp'
+ b'fAu7jJ2d7eothvfeuoRFtJwhUmZDluRdFyhFY/hFAh76PJKGAusIqIQKlkJxMC'
+ b'KDqIexkgHAfID/6mqvmnSJf0b5W8v5h2pI/stOSwTQ+pxVhwJ9ctYDhRSlF0iT'
+ b'UWT10hcuO4Ks8=')
+ good_pub_key = paramiko.RSAKey(data=decodebytes(data))
+
+ def __init__(self):
+ self.event = threading.Event()
+
+ def check_channel_request(self, kind, chanid):
+ if kind == 'session':
+ return paramiko.OPEN_SUCCEEDED
+ return paramiko.OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
+
+ def check_auth_password(self, username, password):
+ if (username == 'robey') and (password == 'foo'):
+ return paramiko.AUTH_SUCCESSFUL
+ return paramiko.AUTH_FAILED
+
+ def check_auth_publickey(self, username, key):
+ print('Auth attempt with key: ' + u(hexlify(key.get_fingerprint())))
+ if (username == 'robey') and (key == self.good_pub_key):
+ return paramiko.AUTH_SUCCESSFUL
+ return paramiko.AUTH_FAILED
+
+ def get_allowed_auths(self, username):
+ return 'password,publickey'
+
+ def check_channel_shell_request(self, channel):
+ self.event.set()
+ return True
+
+ def check_channel_pty_request(self, channel, term, width, height,
+ pixelwidth, pixelheight, modes):
+ return True
+
+
+def run_ssh_server(app):
+ # now connect
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ sock.bind(('', 2200))
+ sock.listen(100)
+
+ while not app._tear_down:
+ client, addr = sock.accept()
+ print('Got a connection!')
+
+ t = paramiko.Transport(client)
+ t.load_server_moduli()
+ t.add_server_key(host_key)
+ server = Server()
+ try:
+ t.start_server(server=server)
+ except Exception as e:
+ print(e)
+ continue
+
+ # wait for auth
+ chan = t.accept(20)
+ if chan is None:
+ print('*** No channel.')
+ continue
+
+ print('Authenticated!')
+
+ server.event.wait(10)
+ if not server.event.is_set():
+ print('*** Client never asked for a shell.')
+ continue
+
+ chan.send('\r\n\r\nWelcome!\r\n\r\n')
+
+ try:
+ sock.close()
+ except Exception:
+ pass
+
+
+if __name__ == '__main__':
+ run_ssh_server(False)
diff --git a/tests/test_app.py b/tests/test_app.py
@@ -0,0 +1,120 @@
+import json
+import handler
+import random
+import threading
+import tornado.websocket
+import tornado.gen
+
+from tornado.testing import AsyncHTTPTestCase
+from tornado.options import options
+from main import make_app, make_handlers
+from settings import get_app_settings
+from tests.sshserver import run_ssh_server
+
+
+handler.DELAY = 0.1
+
+
+class TestApp(AsyncHTTPTestCase):
+
+ _tear_down = False
+
+ def get_app(self):
+ loop = self.io_loop
+ self._tear_down = False
+ options.debug = True
+ options.policy = random.choice(['warning', 'autoadd'])
+ options.hostFile = ''
+ options.sysHostFile = ''
+ app = make_app(make_handlers(loop, options), get_app_settings(options))
+ return app
+
+ @classmethod
+ def tearDownClass(cls):
+ cls._tear_down = True
+
+ def test_app_with_invalid_form(self):
+ response = self.fetch('/')
+ self.assertEqual(response.code, 200)
+ body = u'hostname=&port=&username=&password'
+ response = self.fetch('/', method="POST", body=body)
+ self.assertIn(b'"status": "Empty hostname"', response.body)
+
+ body = u'hostname=127.0.0.1&port=&username=&password'
+ response = self.fetch('/', method="POST", body=body)
+ self.assertIn(b'"status": "Empty port"', response.body)
+
+ body = u'hostname=127.0.0.1&port=port&username=&password'
+ response = self.fetch('/', method="POST", body=body)
+ self.assertIn(b'"status": "Invalid port', response.body)
+
+ body = u'hostname=127.0.0.1&port=70000&username=&password'
+ response = self.fetch('/', method="POST", body=body)
+ self.assertIn(b'"status": "Invalid port', response.body)
+
+ body = u'hostname=127.0.0.1&port=7000&username=&password'
+ response = self.fetch('/', method="POST", body=body)
+ self.assertIn(b'"status": "Empty username"', response.body)
+
+ body = u'hostname=127.0.0.1&port=22&username=user&password'
+ response = self.fetch('/', method="POST", body=body)
+ self.assertIn(b'Unable to connect to', response.body)
+
+ def test_app_with_wrong_credentials(self):
+ response = self.fetch('/')
+ self.assertEqual(response.code, 200)
+ body = u'hostname=127.0.0.1&port=2200&username=robey&password=foos'
+ response = self.fetch('/', method="POST", body=body)
+ self.assertIn(b'Authentication failed.', response.body)
+
+ def test_app_with_correct_credentials(self):
+ response = self.fetch('/')
+ self.assertEqual(response.code, 200)
+ body = u'hostname=127.0.0.1&port=2200&username=robey&password=foo'
+ response = self.fetch('/', method="POST", body=body)
+ worker_id = json.loads(response.body.decode('utf-8'))['id']
+ self.assertIsNotNone(worker_id)
+
+ @tornado.testing.gen_test
+ def test_app_with_correct_credentials_timeout(self):
+ url = self.get_url('/')
+ client = self.get_http_client()
+ response = yield client.fetch(url)
+ self.assertEqual(response.code, 200)
+
+ body = u'hostname=127.0.0.1&port=2200&username=robey&password=foo'
+ response = yield client.fetch(url, method="POST", body=body)
+ worker_id = json.loads(response.body.decode('utf-8'))['id']
+ self.assertIsNotNone(worker_id)
+
+ url = url.replace('http', 'ws')
+ ws_url = url + 'ws?id=' + worker_id
+ yield tornado.gen.sleep(handler.DELAY + 0.1)
+ ws = yield tornado.websocket.websocket_connect(ws_url)
+ msg = yield ws.read_message()
+ self.assertIsNone(msg)
+ ws.close()
+
+ @tornado.testing.gen_test
+ def test_app_with_correct_credentials_welcome(self):
+ url = self.get_url('/')
+ client = self.get_http_client()
+ response = yield client.fetch(url)
+ self.assertEqual(response.code, 200)
+
+ body = u'hostname=127.0.0.1&port=2200&username=robey&password=foo'
+ response = yield client.fetch(url, method="POST", body=body)
+ worker_id = json.loads(response.body.decode('utf-8'))['id']
+ self.assertIsNotNone(worker_id)
+
+ url = url.replace('http', 'ws')
+ ws_url = url + 'ws?id=' + worker_id
+ ws = yield tornado.websocket.websocket_connect(ws_url)
+ msg = yield ws.read_message()
+ self.assertIn('Welcome!', msg)
+ ws.close()
+
+
+t = threading.Thread(target=run_ssh_server, args=(TestApp,))
+t.setDaemon(True)
+t.start()
diff --git a/webssh/main.py b/webssh/main.py
@@ -8,25 +8,27 @@ from settings import (get_app_settings, get_host_keys_settings,
get_policy_setting)
-def make_app(loop, policy, host_keys_settings, app_settings):
+def make_handlers(loop, options):
+ host_keys_settings = get_host_keys_settings(options)
+ policy = get_policy_setting(options, host_keys_settings)
+
handlers = [
(r'/', IndexHandler, dict(loop=loop, policy=policy,
host_keys_settings=host_keys_settings)),
(r'/ws', WsockHandler, dict(loop=loop))
]
+ return handlers
+
+def make_app(handlers, app_settings):
app = tornado.web.Application(handlers, **app_settings)
return app
def main():
parse_command_line()
- app_settings = get_app_settings(options)
- host_keys_settings = get_host_keys_settings(options)
- policy = get_policy_setting(options, host_keys_settings)
-
loop = tornado.ioloop.IOLoop.current()
- app = make_app(loop, policy, host_keys_settings, app_settings)
+ app = make_app(make_handlers(loop, options), get_app_settings(options))
app.listen(options.port, options.address)
logging.info('Listening on {}:{}'.format(options.address, options.port))
loop.start()
diff --git a/webssh/settings.py b/webssh/settings.py
@@ -23,7 +23,7 @@ def get_app_settings(options):
template_path=os.path.join(base_dir, 'templates'),
static_path=os.path.join(base_dir, 'static'),
cookie_secret=uuid.uuid4().hex,
- xsrf_cookies=True,
+ xsrf_cookies=(not options.debug),
debug=options.debug
)
return settings