webssh

Web based ssh client https://github.com/huashengdun/webssh webssh.huashengdun.org/
git clone http://git.hanabi.in/repos/webssh.git
Log | Files | Refs | README | LICENSE

commit 2653a3e35af499b90c1346323b85248619e337e4
parent c126856daa629d32568da28233e4f19b6007c257
Author: Sheng <webmaster0115@gmail.com>
Date:   Sat, 29 Dec 2018 16:16:06 +0800

Added function for limiting connections for every client(ip)

Diffstat:
Mtests/test_app.py | 46++++++++++++++++++++++++++++++++++++++++++++++
Mwebssh/handler.py | 20+++++++++++++++-----
Mwebssh/settings.py | 1+
Mwebssh/worker.py | 9++++++---
4 files changed, 68 insertions(+), 8 deletions(-)

diff --git a/tests/test_app.py b/tests/test_app.py @@ -15,6 +15,7 @@ from webssh.settings import ( get_app_settings, get_server_settings, max_body_size ) from webssh.utils import to_str +from webssh.worker import clients try: from urllib.parse import urlencode @@ -447,6 +448,7 @@ class OtherTestBase(AsyncHTTPTestCase): hostfile = '' syshostfile = '' tdstream = '' + maxconn = 20 body = { 'hostname': '127.0.0.1', 'port': '', @@ -464,6 +466,7 @@ class OtherTestBase(AsyncHTTPTestCase): options.hostfile = self.hostfile options.syshostfile = self.syshostfile options.tdstream = self.tdstream + options.maxconn = self.maxconn app = make_app(make_handlers(loop, options), get_app_settings(options)) return app @@ -670,3 +673,46 @@ class TestAppWithPutRequest(OtherTestBase): url, method='PUT', body=body, headers=self.headers ) self.assertIn('Method Not Allowed', ctx.exception.message) + + +class TestAppWithTooManyConnections(OtherTestBase): + + maxconn = 1 + + def setUp(self): + clients.clear() + super(TestAppWithTooManyConnections, self).setUp() + + @tornado.testing.gen_test + def test_app_with_too_many_connections(self): + url = self.get_url('/') + client = self.get_http_client() + body = urlencode(dict(self.body, username='foo')) + response = yield client.fetch(url, method='POST', body=body, + headers=self.headers) + data = json.loads(to_str(response.body)) + worker_id = data['id'] + self.assertIsNotNone(worker_id) + self.assertIsNotNone(data['encoding']) + self.assertIsNone(data['status']) + + response = yield client.fetch(url, method='POST', body=body, + headers=self.headers) + data = json.loads(to_str(response.body)) + self.assertIsNone(data['id']) + self.assertIsNone(data['encoding']) + self.assertEqual(data['status'], 'Too many connections.') + + ws_url = url.replace('http', 'ws') + 'ws?id=' + worker_id + ws = yield tornado.websocket.websocket_connect(ws_url) + msg = yield ws.read_message() + self.assertIsNotNone(msg) + + response = yield client.fetch(url, method='POST', body=body, + headers=self.headers) + data = json.loads(to_str(response.body)) + self.assertIsNone(data['id']) + self.assertIsNone(data['encoding']) + self.assertEqual(data['status'], 'Too many connections.') + + ws.close() diff --git a/webssh/handler.py b/webssh/handler.py @@ -15,7 +15,7 @@ from webssh.utils import ( is_valid_ip_address, is_valid_port, is_valid_hostname, to_bytes, to_str, to_int, to_ip_address, UnicodeType, is_name_open_to_public, is_ip_hostname ) -from webssh.worker import Worker, recycle_worker, workers +from webssh.worker import Worker, recycle_worker, clients try: from concurrent.futures import Future @@ -311,8 +311,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler): chan = ssh.invoke_shell(term='xterm') chan.setblocking(0) - worker = Worker(self.loop, ssh, chan, dst_addr) - worker.src_addr = self.get_client_addr() + worker = Worker(self.loop, ssh, chan, dst_addr, self.src_addr) worker.encoding = self.get_default_encoding(ssh) return worker @@ -337,6 +336,10 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler): # for testing purpose only raise ValueError('Uncaught exception') + self.src_addr = self.get_client_addr() + if len(clients.get(self.src_addr[0], {})) >= options.maxconn: + raise tornado.web.HTTPError(403, 'Too many connections.') + future = Future() t = threading.Thread(target=self.ssh_connect_wrapped, args=(future,)) t.setDaemon(True) @@ -347,6 +350,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler): except (ValueError, paramiko.SSHException) as exc: self.result.update(status=str(exc)) else: + workers = clients.setdefault(worker.src_addr[0], {}) workers[worker.id] = worker self.loop.call_later(DELAY, recycle_worker, worker) self.result.update(id=worker.id, encoding=worker.encoding) @@ -363,14 +367,20 @@ class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler): def open(self): self.src_addr = self.get_client_addr() logging.info('Connected from {}:{}'.format(*self.src_addr)) + + workers = clients.get(self.src_addr[0]) + if not workers: + self.close(reason='Websocket authentication failed.') + return + try: worker_id = self.get_value('id') except (tornado.web.MissingArgumentError, InvalidValueError) as exc: self.close(reason=str(exc)) else: worker = workers.get(worker_id) - if worker and worker.src_addr[0] == self.src_addr[0]: - workers.pop(worker.id) + if worker: + workers[worker_id] = None self.set_nodelay(True) worker.set_handler(self) self.worker_ref = weakref.ref(worker) diff --git a/webssh/settings.py b/webssh/settings.py @@ -35,6 +35,7 @@ define('fbidhttp', type=bool, default=True, define('xheaders', type=bool, default=True, help='Support xheaders') define('xsrf', type=bool, default=True, help='CSRF protection') define('wpintvl', type=int, default=0, help='Websocket ping interval') +define('maxconn', type=int, default=20, help='Maximum connections per client') define('version', type=bool, help='Show version information', callback=print_version) diff --git a/webssh/worker.py b/webssh/worker.py @@ -7,23 +7,23 @@ from tornado.util import errno_from_exception BUF_SIZE = 32 * 1024 -workers = {} +clients = {} def recycle_worker(worker): if worker.handler: return logging.warning('Recycling worker {}'.format(worker.id)) - workers.pop(worker.id, None) worker.close(reason='worker recycled') class Worker(object): - def __init__(self, loop, ssh, chan, dst_addr): + def __init__(self, loop, ssh, chan, dst_addr, src_addr): self.loop = loop self.ssh = ssh self.chan = chan self.dst_addr = dst_addr + self.src_addr = src_addr self.fd = chan.fileno() self.id = str(id(self)) self.data_to_dst = [] @@ -104,3 +104,6 @@ class Worker(object): self.chan.close() self.ssh.close() logging.info('Connection to {}:{} lost'.format(*self.dst_addr)) + + clients[self.src_addr[0]].pop(self.id, None) + logging.debug(clients)