webssh

Web-based SSH client https://github.com/huashengdun/webssh webssh.huashengdun.org/
git clone http://git.hanabi.in/repos/webssh.git
Log | Files | Refs | README | LICENSE

test_handler.py (9821B)


      1 import unittest
      2 import paramiko
      3 
      4 from tornado.httputil import HTTPServerRequest
      5 from tornado.options import options
      6 from tests.utils import read_file, make_tests_data_path
      7 from webssh import handler
      8 from webssh.handler import (
      9     MixinHandler, WsockHandler, PrivateKey, InvalidValueError
     10 )
     11 
     12 try:
     13     from unittest.mock import Mock
     14 except ImportError:
     15     from mock import Mock
     16 
     17 
     18 class TestMixinHandler(unittest.TestCase):
     19 
     20     def test_is_forbidden(self):
     21         mhandler = MixinHandler()
     22         handler.redirecting = True
     23         options.fbidhttp = True
     24 
     25         context = Mock(
     26             address=('8.8.8.8', 8888),
     27             trusted_downstream=['127.0.0.1'],
     28             _orig_protocol='http'
     29         )
     30         hostname = '4.4.4.4'
     31         self.assertTrue(mhandler.is_forbidden(context, hostname))
     32 
     33         context = Mock(
     34             address=('8.8.8.8', 8888),
     35             trusted_downstream=[],
     36             _orig_protocol='http'
     37         )
     38         hostname = 'www.google.com'
     39         self.assertEqual(mhandler.is_forbidden(context, hostname), False)
     40 
     41         context = Mock(
     42             address=('8.8.8.8', 8888),
     43             trusted_downstream=[],
     44             _orig_protocol='http'
     45         )
     46         hostname = '4.4.4.4'
     47         self.assertTrue(mhandler.is_forbidden(context, hostname))
     48 
     49         context = Mock(
     50             address=('192.168.1.1', 8888),
     51             trusted_downstream=[],
     52             _orig_protocol='http'
     53         )
     54         hostname = 'www.google.com'
     55         self.assertIsNone(mhandler.is_forbidden(context, hostname))
     56 
     57         options.fbidhttp = False
     58         self.assertIsNone(mhandler.is_forbidden(context, hostname))
     59 
     60         hostname = '4.4.4.4'
     61         self.assertIsNone(mhandler.is_forbidden(context, hostname))
     62 
     63         handler.redirecting = False
     64         self.assertIsNone(mhandler.is_forbidden(context, hostname))
     65 
     66         context._orig_protocol = 'https'
     67         self.assertIsNone(mhandler.is_forbidden(context, hostname))
     68 
     69     def test_get_redirect_url(self):
     70         mhandler = MixinHandler()
     71         hostname = 'www.example.com'
     72         uri = '/'
     73         port = 443
     74 
     75         self.assertEqual(
     76             mhandler.get_redirect_url(hostname, port, uri=uri),
     77             'https://www.example.com/'
     78         )
     79 
     80         port = 4433
     81         self.assertEqual(
     82             mhandler.get_redirect_url(hostname, port, uri),
     83             'https://www.example.com:4433/'
     84         )
     85 
     86     def test_get_client_addr(self):
     87         mhandler = MixinHandler()
     88         client_addr = ('8.8.8.8', 8888)
     89         context_addr = ('127.0.0.1', 1234)
     90         options.xheaders = True
     91 
     92         mhandler.context = Mock(address=context_addr)
     93         mhandler.get_real_client_addr = lambda: None
     94         self.assertEqual(mhandler.get_client_addr(), context_addr)
     95 
     96         mhandler.context = Mock(address=context_addr)
     97         mhandler.get_real_client_addr = lambda: client_addr
     98         self.assertEqual(mhandler.get_client_addr(), client_addr)
     99 
    100         options.xheaders = False
    101         mhandler.context = Mock(address=context_addr)
    102         mhandler.get_real_client_addr = lambda: client_addr
    103         self.assertEqual(mhandler.get_client_addr(), context_addr)
    104 
    105     def test_get_real_client_addr(self):
    106         x_forwarded_for = '1.1.1.1'
    107         x_forwarded_port = 1111
    108         x_real_ip = '2.2.2.2'
    109         x_real_port = 2222
    110         fake_port = 65535
    111 
    112         mhandler = MixinHandler()
    113         mhandler.request = HTTPServerRequest(uri='/')
    114         mhandler.request.remote_ip = x_forwarded_for
    115 
    116         self.assertIsNone(mhandler.get_real_client_addr())
    117 
    118         mhandler.request.headers.add('X-Forwarded-For', x_forwarded_for)
    119         self.assertEqual(mhandler.get_real_client_addr(),
    120                          (x_forwarded_for, fake_port))
    121 
    122         mhandler.request.headers.add('X-Forwarded-Port', fake_port + 1)
    123         self.assertEqual(mhandler.get_real_client_addr(),
    124                          (x_forwarded_for, fake_port))
    125 
    126         mhandler.request.headers['X-Forwarded-Port'] = x_forwarded_port
    127         self.assertEqual(mhandler.get_real_client_addr(),
    128                          (x_forwarded_for, x_forwarded_port))
    129 
    130         mhandler.request.remote_ip = x_real_ip
    131 
    132         mhandler.request.headers.add('X-Real-Ip', x_real_ip)
    133         self.assertEqual(mhandler.get_real_client_addr(),
    134                          (x_real_ip, fake_port))
    135 
    136         mhandler.request.headers.add('X-Real-Port', fake_port + 1)
    137         self.assertEqual(mhandler.get_real_client_addr(),
    138                          (x_real_ip, fake_port))
    139 
    140         mhandler.request.headers['X-Real-Port'] = x_real_port
    141         self.assertEqual(mhandler.get_real_client_addr(),
    142                          (x_real_ip, x_real_port))
    143 
    144 
    145 class TestPrivateKey(unittest.TestCase):
    146 
    147     def get_pk_obj(self, fname, password=None):
    148         key = read_file(make_tests_data_path(fname))
    149         return PrivateKey(key, password=password, filename=fname)
    150 
    151     def _test_with_encrypted_key(self, fname, password, klass):
    152         pk = self.get_pk_obj(fname, password='')
    153         with self.assertRaises(InvalidValueError) as ctx:
    154             pk.get_pkey_obj()
    155         self.assertIn('Need a passphrase', str(ctx.exception))
    156 
    157         pk = self.get_pk_obj(fname, password='wrongpass')
    158         with self.assertRaises(InvalidValueError) as ctx:
    159             pk.get_pkey_obj()
    160         self.assertIn('wrong passphrase', str(ctx.exception))
    161 
    162         pk = self.get_pk_obj(fname, password=password)
    163         self.assertIsInstance(pk.get_pkey_obj(), klass)
    164 
    165     def test_class_with_invalid_key_length(self):
    166         key = u'a' * (PrivateKey.max_length + 1)
    167 
    168         with self.assertRaises(InvalidValueError) as ctx:
    169             PrivateKey(key)
    170         self.assertIn('Invalid key length', str(ctx.exception))
    171 
    172     def test_get_pkey_obj_with_invalid_key(self):
    173         key = u'a b c'
    174         fname = 'abc'
    175 
    176         pk = PrivateKey(key, filename=fname)
    177         with self.assertRaises(InvalidValueError) as ctx:
    178             pk.get_pkey_obj()
    179         self.assertIn('Invalid key {}'.format(fname), str(ctx.exception))
    180 
    181     def test_get_pkey_obj_with_plain_rsa_key(self):
    182         pk = self.get_pk_obj('test_rsa.key')
    183         self.assertIsInstance(pk.get_pkey_obj(), paramiko.RSAKey)
    184 
    185     def test_get_pkey_obj_with_plain_ed25519_key(self):
    186         pk = self.get_pk_obj('test_ed25519.key')
    187         self.assertIsInstance(pk.get_pkey_obj(), paramiko.Ed25519Key)
    188 
    189     def test_get_pkey_obj_with_encrypted_rsa_key(self):
    190         fname = 'test_rsa_password.key'
    191         password = 'television'
    192         self._test_with_encrypted_key(fname, password, paramiko.RSAKey)
    193 
    194     def test_get_pkey_obj_with_encrypted_ed25519_key(self):
    195         fname = 'test_ed25519_password.key'
    196         password = 'abc123'
    197         self._test_with_encrypted_key(fname, password, paramiko.Ed25519Key)
    198 
    199     def test_get_pkey_obj_with_encrypted_new_rsa_key(self):
    200         fname = 'test_new_rsa_password.key'
    201         password = '123456'
    202         self._test_with_encrypted_key(fname, password, paramiko.RSAKey)
    203 
    204     def test_get_pkey_obj_with_plain_new_dsa_key(self):
    205         pk = self.get_pk_obj('test_new_dsa.key')
    206         self.assertIsInstance(pk.get_pkey_obj(), paramiko.DSSKey)
    207 
    208     def test_parse_name(self):
    209         key = u'-----BEGIN PRIVATE KEY-----'
    210         pk = PrivateKey(key)
    211         name, _ = pk.parse_name(pk.iostr, pk.tag_to_name)
    212         self.assertIsNone(name)
    213 
    214         key = u'-----BEGIN xxx PRIVATE KEY-----'
    215         pk = PrivateKey(key)
    216         name, _ = pk.parse_name(pk.iostr, pk.tag_to_name)
    217         self.assertIsNone(name)
    218 
    219         key = u'-----BEGIN  RSA PRIVATE KEY-----'
    220         pk = PrivateKey(key)
    221         name, _ = pk.parse_name(pk.iostr, pk.tag_to_name)
    222         self.assertIsNone(name)
    223 
    224         key = u'-----BEGIN RSA  PRIVATE KEY-----'
    225         pk = PrivateKey(key)
    226         name, _ = pk.parse_name(pk.iostr, pk.tag_to_name)
    227         self.assertIsNone(name)
    228 
    229         key = u'-----BEGIN RSA PRIVATE  KEY-----'
    230         pk = PrivateKey(key)
    231         name, _ = pk.parse_name(pk.iostr, pk.tag_to_name)
    232         self.assertIsNone(name)
    233 
    234         for tag, to_name in PrivateKey.tag_to_name.items():
    235             key = u'-----BEGIN {} PRIVATE KEY----- \r\n'.format(tag)
    236             pk = PrivateKey(key)
    237             name, length = pk.parse_name(pk.iostr, pk.tag_to_name)
    238             self.assertEqual(name, to_name)
    239             self.assertEqual(length, len(key))
    240 
    241 
    242 class TestWsockHandler(unittest.TestCase):
    243 
    244     def test_check_origin(self):
    245         request = HTTPServerRequest(uri='/')
    246         obj = Mock(spec=WsockHandler, request=request)
    247 
    248         obj.origin_policy = 'same'
    249         request.headers['Host'] = 'www.example.com:4433'
    250         origin = 'https://www.example.com:4433'
    251         self.assertTrue(WsockHandler.check_origin(obj, origin))
    252 
    253         origin = 'https://www.example.com'
    254         self.assertFalse(WsockHandler.check_origin(obj, origin))
    255 
    256         obj.origin_policy = 'primary'
    257         self.assertTrue(WsockHandler.check_origin(obj, origin))
    258 
    259         origin = 'https://blog.example.com'
    260         self.assertTrue(WsockHandler.check_origin(obj, origin))
    261 
    262         origin = 'https://blog.example.org'
    263         self.assertFalse(WsockHandler.check_origin(obj, origin))
    264 
    265         origin = 'https://blog.example.org'
    266         obj.origin_policy = {'https://blog.example.org'}
    267         self.assertTrue(WsockHandler.check_origin(obj, origin))
    268 
    269         origin = 'http://blog.example.org'
    270         obj.origin_policy = {'http://blog.example.org'}
    271         self.assertTrue(WsockHandler.check_origin(obj, origin))
    272 
    273         origin = 'http://blog.example.org'
    274         obj.origin_policy = {'https://blog.example.org'}
    275         self.assertFalse(WsockHandler.check_origin(obj, origin))
    276 
    277         obj.origin_policy = '*'
    278         origin = 'https://blog.example.org'
    279         self.assertTrue(WsockHandler.check_origin(obj, origin))