commit 8a8d7412301cdde4ba4e3a74eb7cabb90533d474
parent 9f6d900b23d0567a9fe958d14cbf4587c29738ce
Author: Sheng <webmaster0115@gmail.com>
Date: Wed, 16 Jan 2019 22:58:49 +0800
Refactored method is_forbidden
Diffstat:
4 files changed, 35 insertions(+), 34 deletions(-)
diff --git a/tests/test_handler.py b/tests/test_handler.py
@@ -19,57 +19,53 @@ class TestMixinHandler(unittest.TestCase):
def test_is_forbidden(self):
mhandler = MixinHandler()
- handler.https_server_enabled = True
+ handler.redirecting = True
options.fbidhttp = True
- options.redirect = True
context = Mock(
address=('8.8.8.8', 8888),
trusted_downstream=['127.0.0.1'],
_orig_protocol='http'
)
- self.assertTrue(mhandler.is_forbidden(context, ''))
+ hostname = '4.4.4.4'
+ self.assertTrue(mhandler.is_forbidden(context, hostname))
context = Mock(
address=('8.8.8.8', 8888),
trusted_downstream=[],
_orig_protocol='http'
)
-
hostname = 'www.google.com'
self.assertEqual(mhandler.is_forbidden(context, hostname), False)
- handler.https_server_enabled = False
- self.assertTrue(mhandler.is_forbidden(context, hostname))
-
- options.redirect = False
- self.assertTrue(mhandler.is_forbidden(context, hostname))
-
- context = Mock(
- address=('192.168.1.1', 8888),
- trusted_downstream=[],
- _orig_protocol='http'
- )
- self.assertIsNone(mhandler.is_forbidden(context, ''))
-
context = Mock(
address=('8.8.8.8', 8888),
trusted_downstream=[],
- _orig_protocol='https'
+ _orig_protocol='http'
)
- self.assertIsNone(mhandler.is_forbidden(context, ''))
+ hostname = '4.4.4.4'
+ self.assertTrue(mhandler.is_forbidden(context, hostname))
context = Mock(
- address=('8.8.8.8', 8888),
+ address=('192.168.1.1', 8888),
trusted_downstream=[],
_orig_protocol='http'
)
- hostname = '8.8.8.8'
- self.assertTrue(mhandler.is_forbidden(context, hostname))
+ hostname = 'www.google.com'
+ self.assertIsNone(mhandler.is_forbidden(context, hostname))
options.fbidhttp = False
self.assertIsNone(mhandler.is_forbidden(context, hostname))
+ hostname = '4.4.4.4'
+ self.assertIsNone(mhandler.is_forbidden(context, hostname))
+
+ handler.redirecting = False
+ self.assertIsNone(mhandler.is_forbidden(context, hostname))
+
+ context._orig_protocol = 'https'
+ self.assertIsNone(mhandler.is_forbidden(context, hostname))
+
def test_get_redirect_url(self):
mhandler = MixinHandler()
hostname = 'www.example.com'
diff --git a/tests/test_main.py b/tests/test_main.py
@@ -11,12 +11,12 @@ class TestMain(unittest.TestCase):
app = Application()
app.listen = lambda x, y, **kwargs: 1
- handler.https_server_enabled = False
+ handler.redirecting = None
server_settings = dict()
app_listen(app, 80, '127.0.0.1', server_settings)
- self.assertFalse(handler.https_server_enabled)
+ self.assertFalse(handler.redirecting)
- handler.https_server_enabled = False
+ handler.redirecting = None
server_settings = dict(ssl_options='enabled')
app_listen(app, 80, '127.0.0.1', server_settings)
- self.assertTrue(handler.https_server_enabled)
+ self.assertTrue(handler.redirecting)
diff --git a/webssh/handler.py b/webssh/handler.py
@@ -38,7 +38,7 @@ KEY_MAX_SIZE = 16384
DEFAULT_PORT = 22
swallow_http_errors = True
-https_server_enabled = False
+redirecting = None
class InvalidValueError(Exception):
@@ -78,6 +78,7 @@ class MixinHandler(object):
def is_forbidden(self, context, hostname):
ip = context.address[0]
lst = context.trusted_downstream
+ ip_address = None
if lst and ip not in lst:
logging.warning(
@@ -85,15 +86,19 @@ class MixinHandler(object):
)
return True
- if context._orig_protocol == 'http' and \
- not to_ip_address(ip).is_private:
- if options.redirect and https_server_enabled:
- if not is_ip_hostname(hostname):
+ if context._orig_protocol == 'http':
+ if redirecting and not is_ip_hostname(hostname):
+ ip_address = to_ip_address(ip)
+ if not ip_address.is_private:
# redirecting
return False
+
if options.fbidhttp:
- logging.warning('Public plain http request is forbidden.')
- return True
+ if ip_address is None:
+ ip_address = to_ip_address(ip)
+ if not ip_address.is_private:
+ logging.warning('Public plain http request is forbidden.')
+ return True
def get_redirect_url(self, hostname, port, uri):
port = '' if port == 443 else ':%s' % port
diff --git a/webssh/main.py b/webssh/main.py
@@ -34,7 +34,7 @@ def app_listen(app, port, address, server_settings):
server_type = 'http'
else:
server_type = 'https'
- handler.https_server_enabled = True
+ handler.redirecting = True if options.redirect else False
logging.info(
'Listening on {}:{} ({})'.format(address, port, server_type)
)