commit 2653a3e35af499b90c1346323b85248619e337e4
parent c126856daa629d32568da28233e4f19b6007c257
Author: Sheng <webmaster0115@gmail.com>
Date: Sat, 29 Dec 2018 16:16:06 +0800
Added function for limiting connections for every client(ip)
Diffstat:
4 files changed, 68 insertions(+), 8 deletions(-)
diff --git a/tests/test_app.py b/tests/test_app.py
@@ -15,6 +15,7 @@ from webssh.settings import (
get_app_settings, get_server_settings, max_body_size
)
from webssh.utils import to_str
+from webssh.worker import clients
try:
from urllib.parse import urlencode
@@ -447,6 +448,7 @@ class OtherTestBase(AsyncHTTPTestCase):
hostfile = ''
syshostfile = ''
tdstream = ''
+ maxconn = 20
body = {
'hostname': '127.0.0.1',
'port': '',
@@ -464,6 +466,7 @@ class OtherTestBase(AsyncHTTPTestCase):
options.hostfile = self.hostfile
options.syshostfile = self.syshostfile
options.tdstream = self.tdstream
+ options.maxconn = self.maxconn
app = make_app(make_handlers(loop, options), get_app_settings(options))
return app
@@ -670,3 +673,46 @@ class TestAppWithPutRequest(OtherTestBase):
url, method='PUT', body=body, headers=self.headers
)
self.assertIn('Method Not Allowed', ctx.exception.message)
+
+
+class TestAppWithTooManyConnections(OtherTestBase):
+
+ maxconn = 1
+
+ def setUp(self):
+ clients.clear()
+ super(TestAppWithTooManyConnections, self).setUp()
+
+ @tornado.testing.gen_test
+ def test_app_with_too_many_connections(self):
+ url = self.get_url('/')
+ client = self.get_http_client()
+ body = urlencode(dict(self.body, username='foo'))
+ response = yield client.fetch(url, method='POST', body=body,
+ headers=self.headers)
+ data = json.loads(to_str(response.body))
+ worker_id = data['id']
+ self.assertIsNotNone(worker_id)
+ self.assertIsNotNone(data['encoding'])
+ self.assertIsNone(data['status'])
+
+ response = yield client.fetch(url, method='POST', body=body,
+ headers=self.headers)
+ data = json.loads(to_str(response.body))
+ self.assertIsNone(data['id'])
+ self.assertIsNone(data['encoding'])
+ self.assertEqual(data['status'], 'Too many connections.')
+
+ ws_url = url.replace('http', 'ws') + 'ws?id=' + worker_id
+ ws = yield tornado.websocket.websocket_connect(ws_url)
+ msg = yield ws.read_message()
+ self.assertIsNotNone(msg)
+
+ response = yield client.fetch(url, method='POST', body=body,
+ headers=self.headers)
+ data = json.loads(to_str(response.body))
+ self.assertIsNone(data['id'])
+ self.assertIsNone(data['encoding'])
+ self.assertEqual(data['status'], 'Too many connections.')
+
+ ws.close()
diff --git a/webssh/handler.py b/webssh/handler.py
@@ -15,7 +15,7 @@ from webssh.utils import (
is_valid_ip_address, is_valid_port, is_valid_hostname, to_bytes, to_str,
to_int, to_ip_address, UnicodeType, is_name_open_to_public, is_ip_hostname
)
-from webssh.worker import Worker, recycle_worker, workers
+from webssh.worker import Worker, recycle_worker, clients
try:
from concurrent.futures import Future
@@ -311,8 +311,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
chan = ssh.invoke_shell(term='xterm')
chan.setblocking(0)
- worker = Worker(self.loop, ssh, chan, dst_addr)
- worker.src_addr = self.get_client_addr()
+ worker = Worker(self.loop, ssh, chan, dst_addr, self.src_addr)
worker.encoding = self.get_default_encoding(ssh)
return worker
@@ -337,6 +336,10 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
# for testing purpose only
raise ValueError('Uncaught exception')
+ self.src_addr = self.get_client_addr()
+ if len(clients.get(self.src_addr[0], {})) >= options.maxconn:
+ raise tornado.web.HTTPError(403, 'Too many connections.')
+
future = Future()
t = threading.Thread(target=self.ssh_connect_wrapped, args=(future,))
t.setDaemon(True)
@@ -347,6 +350,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
except (ValueError, paramiko.SSHException) as exc:
self.result.update(status=str(exc))
else:
+ workers = clients.setdefault(worker.src_addr[0], {})
workers[worker.id] = worker
self.loop.call_later(DELAY, recycle_worker, worker)
self.result.update(id=worker.id, encoding=worker.encoding)
@@ -363,14 +367,20 @@ class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler):
def open(self):
self.src_addr = self.get_client_addr()
logging.info('Connected from {}:{}'.format(*self.src_addr))
+
+ workers = clients.get(self.src_addr[0])
+ if not workers:
+ self.close(reason='Websocket authentication failed.')
+ return
+
try:
worker_id = self.get_value('id')
except (tornado.web.MissingArgumentError, InvalidValueError) as exc:
self.close(reason=str(exc))
else:
worker = workers.get(worker_id)
- if worker and worker.src_addr[0] == self.src_addr[0]:
- workers.pop(worker.id)
+ if worker:
+ workers[worker_id] = None
self.set_nodelay(True)
worker.set_handler(self)
self.worker_ref = weakref.ref(worker)
diff --git a/webssh/settings.py b/webssh/settings.py
@@ -35,6 +35,7 @@ define('fbidhttp', type=bool, default=True,
define('xheaders', type=bool, default=True, help='Support xheaders')
define('xsrf', type=bool, default=True, help='CSRF protection')
define('wpintvl', type=int, default=0, help='Websocket ping interval')
+define('maxconn', type=int, default=20, help='Maximum connections per client')
define('version', type=bool, help='Show version information',
callback=print_version)
diff --git a/webssh/worker.py b/webssh/worker.py
@@ -7,23 +7,23 @@ from tornado.util import errno_from_exception
BUF_SIZE = 32 * 1024
-workers = {}
+clients = {}
def recycle_worker(worker):
if worker.handler:
return
logging.warning('Recycling worker {}'.format(worker.id))
- workers.pop(worker.id, None)
worker.close(reason='worker recycled')
class Worker(object):
- def __init__(self, loop, ssh, chan, dst_addr):
+ def __init__(self, loop, ssh, chan, dst_addr, src_addr):
self.loop = loop
self.ssh = ssh
self.chan = chan
self.dst_addr = dst_addr
+ self.src_addr = src_addr
self.fd = chan.fileno()
self.id = str(id(self))
self.data_to_dst = []
@@ -104,3 +104,6 @@ class Worker(object):
self.chan.close()
self.ssh.close()
logging.info('Connection to {}:{} lost'.format(*self.dst_addr))
+
+ clients[self.src_addr[0]].pop(self.id, None)
+ logging.debug(clients)