test_handler.py (9821B)
"""Unit tests for ``webssh.handler``: MixinHandler, PrivateKey, WsockHandler."""
import unittest
import paramiko

from tornado.httputil import HTTPServerRequest
from tornado.options import options
from tests.utils import read_file, make_tests_data_path
from webssh import handler
from webssh.handler import (
    MixinHandler, WsockHandler, PrivateKey, InvalidValueError
)

try:
    from unittest.mock import Mock
except ImportError:
    from mock import Mock


# Sentinel distinguishing "attribute absent" from "attribute is None".
_MISSING = object()


def _override_attr(testcase, target, name, value):
    """Set ``target.<name>`` to *value* for the lifetime of *testcase*.

    The previous value is put back (or the attribute deleted if it did not
    exist) through ``testcase.addCleanup``. Cleanups run even when the test
    fails, so module globals such as ``handler.redirecting`` and tornado
    ``options`` never leak into other tests.

    :param testcase: the running ``unittest.TestCase``.
    :param target: object or module whose attribute is overridden.
    :param name: attribute name.
    :param value: temporary value to install.
    """
    original = getattr(target, name, _MISSING)
    setattr(target, name, value)
    if original is _MISSING:
        testcase.addCleanup(delattr, target, name)
    else:
        testcase.addCleanup(setattr, target, name, original)


class TestMixinHandler(unittest.TestCase):
    """Tests for the request-inspection helpers on ``MixinHandler``."""

    def test_is_forbidden(self):
        """Pin ``is_forbidden`` results (True / False / None) per request.

        Covers combinations of client address, trusted downstream list,
        hostname, ``options.fbidhttp``, ``handler.redirecting`` and the
        original request protocol.
        """
        mhandler = MixinHandler()
        # These are process-wide settings; restore them when the test ends.
        _override_attr(self, handler, 'redirecting', True)
        _override_attr(self, options, 'fbidhttp', True)

        context = Mock(
            address=('8.8.8.8', 8888),
            trusted_downstream=['127.0.0.1'],
            _orig_protocol='http'
        )
        hostname = '4.4.4.4'
        self.assertTrue(mhandler.is_forbidden(context, hostname))

        context = Mock(
            address=('8.8.8.8', 8888),
            trusted_downstream=[],
            _orig_protocol='http'
        )
        hostname = 'www.google.com'
        self.assertEqual(mhandler.is_forbidden(context, hostname), False)

        context = Mock(
            address=('8.8.8.8', 8888),
            trusted_downstream=[],
            _orig_protocol='http'
        )
        hostname = '4.4.4.4'
        self.assertTrue(mhandler.is_forbidden(context, hostname))

        context = Mock(
            address=('192.168.1.1', 8888),
            trusted_downstream=[],
            _orig_protocol='http'
        )
        hostname = 'www.google.com'
        self.assertIsNone(mhandler.is_forbidden(context, hostname))

        # Plain assignments below are still undone by the cleanups above.
        options.fbidhttp = False
        self.assertIsNone(mhandler.is_forbidden(context, hostname))

        hostname = '4.4.4.4'
        self.assertIsNone(mhandler.is_forbidden(context, hostname))

        handler.redirecting = False
        self.assertIsNone(mhandler.is_forbidden(context, hostname))

        context._orig_protocol = 'https'
        self.assertIsNone(mhandler.is_forbidden(context, hostname))

    def test_get_redirect_url(self):
        """Port 443 is omitted from the https URL; other ports are kept."""
        mhandler = MixinHandler()
        hostname = 'www.example.com'
        uri = '/'
        port = 443

        self.assertEqual(
            mhandler.get_redirect_url(hostname, port, uri=uri),
            'https://www.example.com/'
        )

        port = 4433
        self.assertEqual(
            mhandler.get_redirect_url(hostname, port, uri),
            'https://www.example.com:4433/'
        )

    def test_get_client_addr(self):
        """Real client address is used only when ``options.xheaders`` is on
        and ``get_real_client_addr`` returns something; otherwise the
        connection context address is returned."""
        mhandler = MixinHandler()
        client_addr = ('8.8.8.8', 8888)
        context_addr = ('127.0.0.1', 1234)
        # Global option; restored automatically after the test.
        _override_attr(self, options, 'xheaders', True)

        mhandler.context = Mock(address=context_addr)
        mhandler.get_real_client_addr = lambda: None
        self.assertEqual(mhandler.get_client_addr(), context_addr)

        mhandler.context = Mock(address=context_addr)
        mhandler.get_real_client_addr = lambda: client_addr
        self.assertEqual(mhandler.get_client_addr(), client_addr)

        options.xheaders = False
        mhandler.context = Mock(address=context_addr)
        mhandler.get_real_client_addr = lambda: client_addr
        self.assertEqual(mhandler.get_client_addr(), context_addr)

    def test_get_real_client_addr(self):
        """Address comes from X-Forwarded-* / X-Real-* headers.

        An out-of-range port header (65536) yields the fallback port 65535.
        """
        x_forwarded_for = '1.1.1.1'
        x_forwarded_port = 1111
        x_real_ip = '2.2.2.2'
        x_real_port = 2222
        fake_port = 65535

        mhandler = MixinHandler()
        mhandler.request = HTTPServerRequest(uri='/')
        mhandler.request.remote_ip = x_forwarded_for

        # No forwarding headers at all.
        self.assertIsNone(mhandler.get_real_client_addr())

        mhandler.request.headers.add('X-Forwarded-For', x_forwarded_for)
        self.assertEqual(mhandler.get_real_client_addr(),
                         (x_forwarded_for, fake_port))

        mhandler.request.headers.add('X-Forwarded-Port', fake_port + 1)
        self.assertEqual(mhandler.get_real_client_addr(),
                         (x_forwarded_for, fake_port))

        mhandler.request.headers['X-Forwarded-Port'] = x_forwarded_port
        self.assertEqual(mhandler.get_real_client_addr(),
                         (x_forwarded_for, x_forwarded_port))

        mhandler.request.remote_ip = x_real_ip

        mhandler.request.headers.add('X-Real-Ip', x_real_ip)
        self.assertEqual(mhandler.get_real_client_addr(),
                         (x_real_ip, fake_port))

        mhandler.request.headers.add('X-Real-Port', fake_port + 1)
        self.assertEqual(mhandler.get_real_client_addr(),
                         (x_real_ip, fake_port))

        mhandler.request.headers['X-Real-Port'] = x_real_port
        self.assertEqual(mhandler.get_real_client_addr(),
                         (x_real_ip, x_real_port))


class TestPrivateKey(unittest.TestCase):
    """Tests for ``PrivateKey`` validation and paramiko key loading."""

    def get_pk_obj(self, fname, password=None):
        """Build a ``PrivateKey`` from the test-data file *fname*."""
        key = read_file(make_tests_data_path(fname))
        return PrivateKey(key, password=password, filename=fname)

    def _test_with_encrypted_key(self, fname, password, klass):
        """Check that an encrypted key is loaded only with the correct passphrase.

        An empty passphrase and a wrong passphrase must each raise
        ``InvalidValueError`` with a specific message. The correct one must
        yield an instance of *klass*.
        """
        pk = self.get_pk_obj(fname, password='')
        with self.assertRaises(InvalidValueError) as ctx:
            pk.get_pkey_obj()
        self.assertIn('Need a passphrase', str(ctx.exception))

        pk = self.get_pk_obj(fname, password='wrongpass')
        with self.assertRaises(InvalidValueError) as ctx:
            pk.get_pkey_obj()
        self.assertIn('wrong passphrase', str(ctx.exception))

        pk = self.get_pk_obj(fname, password=password)
        self.assertIsInstance(pk.get_pkey_obj(), klass)

    def test_class_with_invalid_key_length(self):
        key = u'a' * (PrivateKey.max_length + 1)

        with self.assertRaises(InvalidValueError) as ctx:
            PrivateKey(key)
        self.assertIn('Invalid key length', str(ctx.exception))

    def test_get_pkey_obj_with_invalid_key(self):
        key = u'a b c'
        fname = 'abc'

        pk = PrivateKey(key, filename=fname)
        with self.assertRaises(InvalidValueError) as ctx:
            pk.get_pkey_obj()
        self.assertIn('Invalid key {}'.format(fname), str(ctx.exception))

    def test_get_pkey_obj_with_plain_rsa_key(self):
        pk = self.get_pk_obj('test_rsa.key')
        self.assertIsInstance(pk.get_pkey_obj(), paramiko.RSAKey)

    def test_get_pkey_obj_with_plain_ed25519_key(self):
        pk = self.get_pk_obj('test_ed25519.key')
        self.assertIsInstance(pk.get_pkey_obj(), paramiko.Ed25519Key)

    def test_get_pkey_obj_with_encrypted_rsa_key(self):
        fname = 'test_rsa_password.key'
        password = 'television'
        self._test_with_encrypted_key(fname, password, paramiko.RSAKey)

    def test_get_pkey_obj_with_encrypted_ed25519_key(self):
        fname = 'test_ed25519_password.key'
        password = 'abc123'
        self._test_with_encrypted_key(fname, password, paramiko.Ed25519Key)

    def test_get_pkey_obj_with_encrypted_new_rsa_key(self):
        fname = 'test_new_rsa_password.key'
        password = '123456'
        self._test_with_encrypted_key(fname, password, paramiko.RSAKey)

    def test_get_pkey_obj_with_plain_new_dsa_key(self):
        pk = self.get_pk_obj('test_new_dsa.key')
        self.assertIsInstance(pk.get_pkey_obj(), paramiko.DSSKey)

    def test_parse_name(self):
        """Malformed BEGIN lines give no name. Well-formed ones map via ``tag_to_name``."""
        # Each header is a distinct malformation. All must be rejected.
        malformed_keys = [
            u'-----BEGIN PRIVATE KEY-----',       # no tag
            u'-----BEGIN xxx PRIVATE KEY-----',   # unknown tag
            u'-----BEGIN  RSA PRIVATE KEY-----',  # doubled space before tag
            u'-----BEGIN RSA  PRIVATE KEY-----',  # doubled space after tag
            u'-----BEGIN RSA PRIVATE  KEY-----',  # doubled space in trailer
        ]
        for key in malformed_keys:
            pk = PrivateKey(key)
            name, _ = pk.parse_name(pk.iostr, pk.tag_to_name)
            self.assertIsNone(name, key)

        for tag, to_name in PrivateKey.tag_to_name.items():
            key = u'-----BEGIN {} PRIVATE KEY----- \r\n'.format(tag)
            pk = PrivateKey(key)
            name, length = pk.parse_name(pk.iostr, pk.tag_to_name)
            self.assertEqual(name, to_name)
            self.assertEqual(length, len(key))


class TestWsockHandler(unittest.TestCase):
    """Tests for ``WsockHandler.check_origin`` under each origin policy."""

    def test_check_origin(self):
        """Exercise the 'same', 'primary', explicit-set and '*' policies."""
        request = HTTPServerRequest(uri='/')
        obj = Mock(spec=WsockHandler, request=request)

        obj.origin_policy = 'same'
        request.headers['Host'] = 'www.example.com:4433'
        origin = 'https://www.example.com:4433'
        self.assertTrue(WsockHandler.check_origin(obj, origin))

        origin = 'https://www.example.com'
        self.assertFalse(WsockHandler.check_origin(obj, origin))

        obj.origin_policy = 'primary'
        self.assertTrue(WsockHandler.check_origin(obj, origin))

        origin = 'https://blog.example.com'
        self.assertTrue(WsockHandler.check_origin(obj, origin))

        origin = 'https://blog.example.org'
        self.assertFalse(WsockHandler.check_origin(obj, origin))

        # Explicit allow-set: scheme is part of the match.
        obj.origin_policy = {'https://blog.example.org'}
        self.assertTrue(WsockHandler.check_origin(obj, origin))

        origin = 'http://blog.example.org'
        obj.origin_policy = {'http://blog.example.org'}
        self.assertTrue(WsockHandler.check_origin(obj, origin))

        obj.origin_policy = {'https://blog.example.org'}
        self.assertFalse(WsockHandler.check_origin(obj, origin))

        obj.origin_policy = '*'
        origin = 'https://blog.example.org'
        self.assertTrue(WsockHandler.check_origin(obj, origin))