commit f6d2776a20faa8292dbd4e562ad1c82e7aa8ed96
parent 88405eddac6f3dca0be0a628dd941a1c3fc016a7
Author: Sheng <webmaster0115@gmail.com>
Date: Wed, 10 Oct 2018 10:51:40 +0800
Let tornado parse xheaders
Diffstat:
3 files changed, 48 insertions(+), 24 deletions(-)
diff --git a/tests/test_handler.py b/tests/test_handler.py
@@ -9,25 +9,43 @@ from webssh.handler import MixinHandler, IndexHandler, InvalidValueError
class TestMixinHandler(unittest.TestCase):
def test_get_real_client_addr(self):
+ x_forwarded_for = '1.1.1.1'
+ x_forwarded_port = 1111
+ x_real_ip = '2.2.2.2'
+ x_real_port = 2222
+ fake_port = 65535
+
handler = MixinHandler()
handler.request = HTTPServerRequest(uri='/')
+ handler.request.remote_ip = x_forwarded_for
+
self.assertIsNone(handler.get_real_client_addr())
- ip = '127.0.0.1'
- handler.request.headers.add('X-Real-Ip', ip)
- self.assertEqual(handler.get_real_client_addr(), False)
+ handler.request.headers.add('X-Forwarded-For', x_forwarded_for)
+ self.assertEqual(handler.get_real_client_addr(),
+ (x_forwarded_for, fake_port))
+
+ handler.request.headers.add('X-Forwarded-Port', fake_port + 1)
+ self.assertEqual(handler.get_real_client_addr(),
+ (x_forwarded_for, fake_port))
+
+ handler.request.headers['X-Forwarded-Port'] = x_forwarded_port
+ self.assertEqual(handler.get_real_client_addr(),
+ (x_forwarded_for, x_forwarded_port))
- handler.request.headers.add('X-Real-Port', '12345x')
- self.assertEqual(handler.get_real_client_addr(), False)
+ handler.request.remote_ip = x_real_ip
- handler.request.headers.update({'X-Real-Port': '12345'})
- self.assertEqual(handler.get_real_client_addr(), (ip, 12345))
+ handler.request.headers.add('X-Real-Ip', x_real_ip)
+ self.assertEqual(handler.get_real_client_addr(),
+ (x_real_ip, fake_port))
- handler.request.headers.update({'X-Real-ip': None})
- self.assertEqual(handler.get_real_client_addr(), False)
+ handler.request.headers.add('X-Real-Port', fake_port + 1)
+ self.assertEqual(handler.get_real_client_addr(),
+ (x_real_ip, fake_port))
- handler.request.headers.update({'X-Real-Port': '12345x'})
- self.assertEqual(handler.get_real_client_addr(), False)
+ handler.request.headers['X-Real-Port'] = x_real_port
+ self.assertEqual(handler.get_real_client_addr(),
+ (x_real_ip, x_real_port))
class TestIndexHandler(unittest.TestCase):
diff --git a/webssh/handler.py b/webssh/handler.py
@@ -45,19 +45,22 @@ class MixinHandler(object):
return value
def get_real_client_addr(self):
- ip = self.request.headers.get('X-Real-Ip')
- port = self.request.headers.get('X-Real-Port')
+ ip = self.request.remote_ip
- if ip is None and port is None:
- return # suppose this app doesn't run after an nginx server
+ if ip == self.request.headers.get('X-Real-Ip'):
+ port = self.request.headers.get('X-Real-Port')
+ elif ip in self.request.headers.get('X-Forwarded-For', ''):
+ port = self.request.headers.get('X-Forwarded-Port')
+ else:
+ # not running behind an nginx server
+ return
- if is_valid_ipv4_address(ip) or is_valid_ipv6_address(ip):
- port = to_int(port)
- if port and is_valid_port(port):
- return (ip, port)
+ port = to_int(port)
+ if port is None or not is_valid_port(port):
+ # fake port
+ port = 65535
- logging.warning('Bad nginx configuration.')
- return False
+ return (ip, port)
class IndexHandler(MixinHandler, tornado.web.RequestHandler):
@@ -94,13 +97,15 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
def get_privatekey(self):
name = 'privatekey'
- lst = self.request.files.get(name) # multipart form
+ lst = self.request.files.get(name)
if lst:
+ # multipart form
self.privatekey_filename = lst[0]['filename']
data = lst[0]['body']
value = self.decode_argument(data, name=name).strip()
else:
- value = self.get_argument(name, u'') # urlencoded form
+ # urlencoded form
+ value = self.get_argument(name, u'')
if len(value) > KEY_MAX_SIZE:
raise InvalidValueError(
diff --git a/webssh/main.py b/webssh/main.py
@@ -28,7 +28,8 @@ def main():
options.parse_command_line()
loop = tornado.ioloop.IOLoop.current()
app = make_app(make_handlers(loop, options), get_app_settings(options))
- app.listen(options.port, options.address, max_body_size=max_body_size)
+ server_settings = dict(xheaders=True, max_body_size=max_body_size)
+ app.listen(options.port, options.address, **server_settings)
logging.info('Listening on {}:{}'.format(options.address, options.port))
loop.start()