commit 1f835f5a70216c2b65c4ad9433e02c782c43deaf
parent fbb3e466b2d12a270c53a17899a4aa6286e9b92a
Author: Sheng <webmaster0115@gmail.com>
Date: Fri, 19 Oct 2018 18:18:55 +0800
Refactored handler.py
Diffstat:
2 files changed, 15 insertions(+), 17 deletions(-)
diff --git a/tests/test_handler.py b/tests/test_handler.py
@@ -18,33 +18,33 @@ class TestMixinHandler(unittest.TestCase):
handler = MixinHandler()
options.fbidhttp = True
- handler.context = Mock(
+ context = Mock(
address=('8.8.8.8', 8888),
trusted_downstream=['127.0.0.1'],
_orig_protocol='http'
)
- self.assertTrue(handler.is_forbidden())
+ self.assertTrue(handler.is_forbidden(context))
- handler.context = Mock(
+ context = Mock(
address=('8.8.8.8', 8888),
trusted_downstream=[],
_orig_protocol='http'
)
- self.assertTrue(handler.is_forbidden())
+ self.assertTrue(handler.is_forbidden(context))
- handler.context = Mock(
+ context = Mock(
address=('192.168.1.1', 8888),
trusted_downstream=[],
_orig_protocol='http'
)
- self.assertIsNone(handler.is_forbidden())
+ self.assertIsNone(handler.is_forbidden(context))
- handler.context = Mock(
+ context = Mock(
address=('8.8.8.8', 8888),
trusted_downstream=[],
_orig_protocol='https'
)
- self.assertIsNone(handler.is_forbidden())
+ self.assertIsNone(handler.is_forbidden(context))
def test_get_client_addr(self):
handler = MixinHandler()
diff --git a/webssh/handler.py b/webssh/handler.py
@@ -45,22 +45,22 @@ class MixinHandler(object):
'Server': 'TornadoServer'
}
- def initialize(self):
+ def initialize(self, loop=None):
conn = self.request.connection
- self.context = conn.context
- if self.is_forbidden():
+ if self.is_forbidden(conn.context):
result = '{} 403 Forbidden\r\n\r\n'.format(self.request.version)
conn.stream.write(to_bytes(result))
conn.close()
raise ValueError('Accesss denied')
+ self.loop = loop
+ self.context = conn.context
- def is_forbidden(self):
+ def is_forbidden(self, context):
"""
Following requests are forbidden:
* requests not come from trusted_downstream (if set).
* plain http requests from a public network.
"""
- context = self.context
ip = context.address[0]
lst = context.trusted_downstream
@@ -123,14 +123,13 @@ class NotFoundHandler(MixinHandler, tornado.web.ErrorHandler):
class IndexHandler(MixinHandler, tornado.web.RequestHandler):
def initialize(self, loop, policy, host_keys_settings):
- self.loop = loop
+ super(IndexHandler, self).initialize(loop)
self.policy = policy
self.host_keys_settings = host_keys_settings
self.ssh_client = self.get_ssh_client()
self.privatekey_filename = None
self.debug = self.settings.get('debug', False)
self.result = dict(id=None, status=None, encoding=None)
- super(IndexHandler, self).initialize()
def write_error(self, status_code, **kwargs):
if self.request.method != 'POST' or not swallow_http_errors:
@@ -329,9 +328,8 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler):
def initialize(self, loop):
- self.loop = loop
+ super(WsockHandler, self).initialize(loop)
self.worker_ref = None
- super(WsockHandler, self).initialize()
def open(self):
self.src_addr = self.get_client_addr()