commit a68eff592f68c8e1a899e1b3a2b1dd1e6667c203
parent af60cd1cd59abf0587b1b6a4d50e0d7c8d992f28
Author: Sheng <webmaster0115@gmail.com>
Date: Tue, 16 Oct 2018 15:14:34 +0800
Added attribute custom_headers to MixinHandler
Diffstat:
2 files changed, 17 insertions(+), 4 deletions(-)
diff --git a/tests/test_app.py b/tests/test_app.py
@@ -586,22 +586,30 @@ class TestAppWithTrustedStream(OtherTestBase):
class TestAppNotFoundHandler(OtherTestBase):
+ custom_headers = handler.MixinHandler.custom_headers
+
def test_with_not_found_get_request(self):
response = self.fetch('/pathnotfound', method='GET')
self.assertEqual(response.code, 404)
- self.assertEqual(response.headers['Server'], 'TornadoServer')
+ self.assertEqual(
+ response.headers['Server'], self.custom_headers['Server']
+ )
self.assertIn(b'404: Not Found', response.body)
def test_with_not_found_post_request(self):
response = self.fetch('/pathnotfound', method='POST',
body=urlencode(self.body), headers=self.headers)
self.assertEqual(response.code, 404)
- self.assertEqual(response.headers['Server'], 'TornadoServer')
+ self.assertEqual(
+ response.headers['Server'], self.custom_headers['Server']
+ )
self.assertIn(b'404: Not Found', response.body)
def test_with_not_found_put_request(self):
response = self.fetch('/pathnotfound', method='PUT',
body=urlencode(self.body), headers=self.headers)
self.assertEqual(response.code, 404)
- self.assertEqual(response.headers['Server'], 'TornadoServer')
+ self.assertEqual(
+ response.headers['Server'], self.custom_headers['Server']
+ )
self.assertIn(b'404: Not Found', response.body)
diff --git a/webssh/handler.py b/webssh/handler.py
@@ -39,6 +39,10 @@ class InvalidValueError(Exception):
class MixinHandler(object):
+ custom_headers = {
+ 'Server': 'TornadoServer'
+ }
+
def prepare(self):
if self.is_forbidden():
raise tornado.web.HTTPError(403)
@@ -66,7 +70,8 @@ class MixinHandler(object):
return True
def set_default_headers(self):
- self.set_header('Server', 'TornadoServer')
+ for header in self.custom_headers.items():
+ self.set_header(*header)
def get_value(self, name):
value = self.get_argument(name)