commit 77b6fbfd8573b298de0a16dbbc175bffe151492f
parent db3ee2b784211abd3eafa4b111a85bca6036968d
Author: Sheng <webmaster0115@gmail.com>
Date: Mon, 15 Oct 2018 20:13:11 +0800
Block requests not come from trusted_downstream and public non-https requests
Diffstat:
5 files changed, 82 insertions(+), 6 deletions(-)
diff --git a/.travis.yml b/.travis.yml
@@ -15,6 +15,7 @@ matrix:
install:
- pip install -r requirements.txt
- pip install pytest pytest-cov codecov flake8
+ - if [[ $TRAVIS_PYTHON_VERSION == '2.7' ]]; then pip install mock; fi
script:
- pytest --cov=webssh
diff --git a/tests/test_handler.py b/tests/test_handler.py
@@ -1,13 +1,57 @@
import unittest
import paramiko
+from tornado.httpclient import HTTPRequest
from tornado.httputil import HTTPServerRequest
+from tornado.web import HTTPError
from tests.utils import read_file, make_tests_data_path
from webssh.handler import MixinHandler, IndexHandler, InvalidValueError
+try:
+ from unittest.mock import Mock
+except ImportError:
+ from mock import Mock
+
class TestMixinHandler(unittest.TestCase):
+ def test_is_forbidden(self):
+ handler = MixinHandler()
+ request = HTTPRequest('http://example.com/')
+ handler.request = request
+
+ context = Mock(
+ address=('8.8.8.8', 8888),
+ trusted_downstream=['127.0.0.1'],
+ _orig_protocol='http'
+ )
+ request.connection = Mock(context=context)
+ self.assertTrue(handler.is_forbidden())
+
+ context = Mock(
+ address=('8.8.8.8', 8888),
+ trusted_downstream=[],
+ _orig_protocol='http'
+ )
+ request.connection = Mock(context=context)
+ self.assertTrue(handler.is_forbidden())
+
+ context = Mock(
+ address=('192.168.1.1', 8888),
+ trusted_downstream=[],
+ _orig_protocol='http'
+ )
+ request.connection = Mock(context=context)
+ self.assertIsNone(handler.is_forbidden())
+
+ context = Mock(
+ address=('8.8.8.8', 8888),
+ trusted_downstream=[],
+ _orig_protocol='https'
+ )
+ request.connection = Mock(context=context)
+ self.assertIsNone(handler.is_forbidden())
+
def test_get_real_client_addr(self):
x_forwarded_for = '1.1.1.1'
x_forwarded_port = 1111
diff --git a/webssh/handler.py b/webssh/handler.py
@@ -13,7 +13,7 @@ from tornado.ioloop import IOLoop
from webssh.settings import swallow_http_errors
from webssh.utils import (
is_valid_ip_address, is_valid_port, is_valid_hostname,
- to_bytes, to_str, to_int, UnicodeType
+ to_bytes, to_str, to_int, to_ip_address, UnicodeType
)
from webssh.worker import Worker, recycle_worker, workers
@@ -39,6 +39,28 @@ class InvalidValueError(Exception):
class MixinHandler(object):
+ def prepare(self):
+ if self.is_forbidden():
+ raise tornado.web.HTTPError(403)
+
+ def is_forbidden(self):
+ """
+ Following requests are forbidden:
+ * requests not come from trusted_downstream (if set).
+ * non-https requests from a public network.
+ """
+ context = self.request.connection.context
+ ip = context.address[0]
+ lst = context.trusted_downstream
+
+ if lst and ip not in lst:
+ return True
+
+ if context._orig_protocol == 'http':
+ ipaddr = to_ip_address(ip)
+ if ipaddr.is_global:
+ return True
+
def set_default_headers(self):
self.set_header('Server', 'TornadoServer')
diff --git a/webssh/main.py b/webssh/main.py
@@ -6,7 +6,7 @@ from tornado.options import options
from webssh.handler import IndexHandler, WsockHandler
from webssh.settings import (
get_app_settings, get_host_keys_settings, get_policy_setting,
- get_ssl_context, max_body_size, xheaders
+ get_ssl_context, get_server_settings
)
@@ -31,12 +31,12 @@ def main():
loop = tornado.ioloop.IOLoop.current()
app = make_app(make_handlers(loop, options), get_app_settings(options))
ssl_ctx = get_ssl_context(options)
- kwargs = dict(xheaders=xheaders, max_body_size=max_body_size)
- app.listen(options.port, options.address, **kwargs)
+ server_settings = get_server_settings(options)
+ app.listen(options.port, options.address, **server_settings)
logging.info('Listening on {}:{}'.format(options.address, options.port))
if ssl_ctx:
- kwargs.update(ssl_options=ssl_ctx)
- app.listen(options.sslPort, options.sslAddress, **kwargs)
+ server_settings.update(ssl_options=ssl_ctx)
+ app.listen(options.sslPort, options.sslAddress, **server_settings)
logging.info('Listening on ssl {}:{}'.format(options.sslAddress,
options.sslPort))
loop.start()
diff --git a/webssh/settings.py b/webssh/settings.py
@@ -51,6 +51,15 @@ def get_app_settings(options):
return settings
+def get_server_settings(options):
+ settings = dict(
+ xheaders=xheaders,
+ max_body_size=max_body_size,
+ trusted_downstream=get_trusted_downstream(options)
+ )
+ return settings
+
+
def get_host_keys_settings(options):
if not options.hostFile:
host_keys_filename = os.path.join(base_dir, 'known_hosts')