commit 68468585ee6da62d61ff128b94a9dba818aef15a
parent a8a444d7ed12576ed30cff136ae33b8f5ce6fdc4
Author: Sheng <webmaster0115@gmail.com>
Date: Thu, 18 Oct 2018 20:25:30 +0800
Added a command line option xheaders
Diffstat:
3 files changed, 40 insertions(+), 23 deletions(-)
diff --git a/tests/test_handler.py b/tests/test_handler.py
@@ -1,7 +1,6 @@
import unittest
import paramiko
-from tornado.httpclient import HTTPRequest
from tornado.httputil import HTTPServerRequest
from tornado.options import options
from tests.utils import read_file, make_tests_data_path
@@ -17,42 +16,55 @@ class TestMixinHandler(unittest.TestCase):
def test_is_forbidden(self):
handler = MixinHandler()
- request = HTTPRequest('http://example.com/')
- handler.request = request
options.fbidhttp = True
- context = Mock(
+ handler.context = Mock(
address=('8.8.8.8', 8888),
trusted_downstream=['127.0.0.1'],
_orig_protocol='http'
)
- request.connection = Mock(context=context)
self.assertTrue(handler.is_forbidden())
- context = Mock(
+ handler.context = Mock(
address=('8.8.8.8', 8888),
trusted_downstream=[],
_orig_protocol='http'
)
- request.connection = Mock(context=context)
self.assertTrue(handler.is_forbidden())
- context = Mock(
+ handler.context = Mock(
address=('192.168.1.1', 8888),
trusted_downstream=[],
_orig_protocol='http'
)
- request.connection = Mock(context=context)
self.assertIsNone(handler.is_forbidden())
- context = Mock(
+ handler.context = Mock(
address=('8.8.8.8', 8888),
trusted_downstream=[],
_orig_protocol='https'
)
- request.connection = Mock(context=context)
self.assertIsNone(handler.is_forbidden())
+ def test_get_client_addr(self):
+ handler = MixinHandler()
+ client_addr = ('8.8.8.8', 8888)
+ context_addr = ('127.0.0.1', 1234)
+ options.xheaders = True
+
+ handler.context = Mock(address=context_addr)
+ handler.get_real_client_addr = lambda: None
+ self.assertEqual(handler.get_client_addr(), context_addr)
+
+ handler.context = Mock(address=context_addr)
+ handler.get_real_client_addr = lambda: client_addr
+ self.assertEqual(handler.get_client_addr(), client_addr)
+
+ options.xheaders = False
+ handler.context = Mock(address=context_addr)
+ handler.get_real_client_addr = lambda: client_addr
+ self.assertEqual(handler.get_client_addr(), context_addr)
+
def test_get_real_client_addr(self):
x_forwarded_for = '1.1.1.1'
x_forwarded_port = 1111
diff --git a/webssh/handler.py b/webssh/handler.py
@@ -46,19 +46,21 @@ class MixinHandler(object):
}
def initialize(self):
+ conn = self.request.connection
+ self.context = conn.context
if self.is_forbidden():
result = '{} 403 Forbidden\r\n\r\n'.format(self.request.version)
- self.request.connection.stream.write(to_bytes(result))
- self.request.connection.close()
+ conn.stream.write(to_bytes(result))
+ conn.close()
raise ValueError('Accesss denied')
def is_forbidden(self):
"""
Following requests are forbidden:
* requests not come from trusted_downstream (if set).
- * non-https requests from a public network.
+ * plain http requests from a public network.
"""
- context = self.request.connection.context
+ context = self.context
ip = context.address[0]
lst = context.trusted_downstream
@@ -71,7 +73,7 @@ class MixinHandler(object):
if options.fbidhttp and context._orig_protocol == 'http':
ipaddr = to_ip_address(ip)
if not ipaddr.is_private:
- logging.warning('Public non-https request is forbidden.')
+ logging.warning('Public plain http request is forbidden.')
return True
def set_default_headers(self):
@@ -85,8 +87,10 @@ class MixinHandler(object):
return value
def get_client_addr(self):
- return self.get_real_client_addr() or self.request.connection.context.\
- address
+ if options.xheaders:
+ return self.get_real_client_addr() or self.context.address
+ else:
+ return self.context.address
def get_real_client_addr(self):
ip = self.request.remote_ip
diff --git a/webssh/settings.py b/webssh/settings.py
@@ -30,8 +30,10 @@ define('policy', default='warning',
help='Missing host key policy, reject|autoadd|warning')
define('hostfile', default='', help='User defined host keys file')
define('syshostfile', default='', help='System wide host keys file')
-define('tdstream', default='', help='trusted downstream, separated by comma')
-define('fbidhttp', type=bool, default=True, help='forbid public http request')
+define('tdstream', default='', help='Trusted downstream, separated by comma')
+define('fbidhttp', type=bool, default=True,
+ help='Forbid public plain http incoming requests')
+define('xheaders', type=bool, default=True, help='Support xheaders')
define('wpintvl', type=int, default=0, help='Websocket ping interval')
define('version', type=bool, help='Show version information',
callback=print_version)
@@ -39,7 +41,6 @@ define('version', type=bool, help='Show version information',
base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
max_body_size = 1 * 1024 * 1024
-xheaders = True
def get_app_settings(options):
@@ -55,7 +56,7 @@ def get_app_settings(options):
def get_server_settings(options):
settings = dict(
- xheaders=xheaders,
+ xheaders=options.xheaders,
max_body_size=max_body_size,
trusted_downstream=get_trusted_downstream(options)
)
@@ -121,4 +122,4 @@ def detect_is_open_to_public(options):
result = on_public_network_interfaces(get_ips_by_name(options.address))
if not result and options.fbidhttp:
options.fbidhttp = False
- logging.info('Forbid public http: {}'.format(options.fbidhttp))
+ logging.info('Forbid public plain http: {}'.format(options.fbidhttp))