commit 976b00ca85424d87bd2f3222970aa811da50ef8e
parent 86abf6912d459d27fab0fac018f65a54308cb0de
Author: Sheng <webmaster0115@gmail.com>
Date: Thu, 4 Jul 2019 18:17:36 +0800
Added function clear_worker
Diffstat:
2 files changed, 25 insertions(+), 8 deletions(-)
diff --git a/webssh/handler.py b/webssh/handler.py
@@ -348,7 +348,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
chan = ssh.invoke_shell(term='xterm')
chan.setblocking(0)
- worker = Worker(self.loop, ssh, chan, dst_addr, self.src_addr)
+ worker = Worker(self.loop, ssh, chan, dst_addr)
worker.encoding = self.get_default_encoding(ssh)
return worker
@@ -378,8 +378,9 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
# for testing purpose only
raise ValueError('Uncaught exception')
- self.src_addr = self.get_client_addr()
- if len(clients.get(self.src_addr[0], {})) >= options.maxconn:
+ ip, port = self.get_client_addr()
+ workers = clients.get(ip, {})
+ if workers and len(workers) >= options.maxconn:
raise tornado.web.HTTPError(403, 'Too many live connections.')
self.check_origin()
@@ -397,7 +398,9 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
logging.error(traceback.format_exc())
self.result.update(status=str(exc))
else:
- workers = clients.setdefault(worker.src_addr[0], {})
+ if not workers:
+ clients[ip] = workers
+ worker.src_addr = (ip, port)
workers[worker.id] = worker
self.loop.call_later(DELAY, recycle_worker, worker)
self.result.update(id=worker.id, encoding=worker.encoding)
diff --git a/webssh/worker.py b/webssh/worker.py
@@ -7,7 +7,22 @@ from tornado.util import errno_from_exception
BUF_SIZE = 32 * 1024
-clients = {}
+clients = {} # {ip: {id: worker}}
+
+
+def clear_worker(worker, clients):
+ ip = worker.src_addr[0]
+ workers = clients.get(ip)
+ if workers:
+ try:
+ workers.pop(worker.id)
+ except KeyError:
+ pass
+ else:
+ if not workers:
+ clients.pop(ip)
+ if not clients:
+ clients.clear()
def recycle_worker(worker):
@@ -18,12 +33,11 @@ def recycle_worker(worker):
class Worker(object):
- def __init__(self, loop, ssh, chan, dst_addr, src_addr):
+ def __init__(self, loop, ssh, chan, dst_addr):
self.loop = loop
self.ssh = ssh
self.chan = chan
self.dst_addr = dst_addr
- self.src_addr = src_addr
self.fd = chan.fileno()
self.id = str(id(self))
self.data_to_dst = []
@@ -110,5 +124,5 @@ class Worker(object):
self.ssh.close()
logging.info('Connection to {}:{} lost'.format(*self.dst_addr))
- clients[self.src_addr[0]].pop(self.id, None)
+ clear_worker(self, clients)
logging.debug(clients)