webssh

Web-based SSH client: https://github.com/huashengdun/webssh, webssh.huashengdun.org/
git clone http://git.hanabi.in/repos/webssh.git
Log | Files | Refs | README | LICENSE

test_settings.py (6946B)


      1 import io
      2 import random
      3 import ssl
      4 import sys
      5 import os.path
      6 import unittest
      7 import paramiko
      8 import tornado.options as options
      9 
     10 from tests.utils import make_tests_data_path
     11 from webssh.policy import load_host_keys
     12 from webssh.settings import (
     13     get_host_keys_settings, get_policy_setting, base_dir, get_font_filename,
     14     get_ssl_context, get_trusted_downstream, get_origin_setting, print_version,
     15     check_encoding_setting
     16 )
     17 from webssh.utils import UnicodeType
     18 from webssh._version import __version__
     19 
     20 
     21 class TestSettings(unittest.TestCase):
     22 
     23     def test_print_version(self):
     24         sys_stdout = sys.stdout
     25         sys.stdout = io.StringIO() if UnicodeType == str else io.BytesIO()
     26 
     27         self.assertEqual(print_version(False), None)
     28         self.assertEqual(sys.stdout.getvalue(), '')
     29 
     30         with self.assertRaises(SystemExit):
     31             self.assertEqual(print_version(True), None)
     32         self.assertEqual(sys.stdout.getvalue(), __version__ + '\n')
     33 
     34         sys.stdout = sys_stdout
     35 
     36     def test_get_host_keys_settings(self):
     37         options.hostfile = ''
     38         options.syshostfile = ''
     39         dic = get_host_keys_settings(options)
     40 
     41         filename = os.path.join(base_dir, 'known_hosts')
     42         self.assertEqual(dic['host_keys'], load_host_keys(filename))
     43         self.assertEqual(dic['host_keys_filename'], filename)
     44         self.assertEqual(
     45             dic['system_host_keys'],
     46             load_host_keys(os.path.expanduser('~/.ssh/known_hosts'))
     47         )
     48 
     49         options.hostfile = make_tests_data_path('known_hosts_example')
     50         options.syshostfile = make_tests_data_path('known_hosts_example2')
     51         dic2 = get_host_keys_settings(options)
     52         self.assertEqual(dic2['host_keys'], load_host_keys(options.hostfile))
     53         self.assertEqual(dic2['host_keys_filename'], options.hostfile)
     54         self.assertEqual(dic2['system_host_keys'],
     55                          load_host_keys(options.syshostfile))
     56 
     57     def test_get_policy_setting(self):
     58         options.policy = 'warning'
     59         options.hostfile = ''
     60         options.syshostfile = ''
     61         settings = get_host_keys_settings(options)
     62         instance = get_policy_setting(options, settings)
     63         self.assertIsInstance(instance, paramiko.client.WarningPolicy)
     64 
     65         options.policy = 'autoadd'
     66         options.hostfile = ''
     67         options.syshostfile = ''
     68         settings = get_host_keys_settings(options)
     69         instance = get_policy_setting(options, settings)
     70         self.assertIsInstance(instance, paramiko.client.AutoAddPolicy)
     71         os.unlink(settings['host_keys_filename'])
     72 
     73         options.policy = 'reject'
     74         options.hostfile = ''
     75         options.syshostfile = ''
     76         settings = get_host_keys_settings(options)
     77         try:
     78             instance = get_policy_setting(options, settings)
     79         except ValueError:
     80             self.assertFalse(
     81                 settings['host_keys'] and settings['system_host_keys']
     82             )
     83         else:
     84             self.assertIsInstance(instance, paramiko.client.RejectPolicy)
     85 
     86     def test_get_ssl_context(self):
     87         options.certfile = ''
     88         options.keyfile = ''
     89         ssl_ctx = get_ssl_context(options)
     90         self.assertIsNone(ssl_ctx)
     91 
     92         options.certfile = 'provided'
     93         options.keyfile = ''
     94         with self.assertRaises(ValueError) as ctx:
     95             ssl_ctx = get_ssl_context(options)
     96         self.assertEqual('keyfile is not provided', str(ctx.exception))
     97 
     98         options.certfile = ''
     99         options.keyfile = 'provided'
    100         with self.assertRaises(ValueError) as ctx:
    101             ssl_ctx = get_ssl_context(options)
    102         self.assertEqual('certfile is not provided', str(ctx.exception))
    103 
    104         options.certfile = 'FileDoesNotExist'
    105         options.keyfile = make_tests_data_path('cert.key')
    106         with self.assertRaises(ValueError) as ctx:
    107             ssl_ctx = get_ssl_context(options)
    108         self.assertIn('does not exist', str(ctx.exception))
    109 
    110         options.certfile = make_tests_data_path('cert.key')
    111         options.keyfile = 'FileDoesNotExist'
    112         with self.assertRaises(ValueError) as ctx:
    113             ssl_ctx = get_ssl_context(options)
    114         self.assertIn('does not exist', str(ctx.exception))
    115 
    116         options.certfile = make_tests_data_path('cert.key')
    117         options.keyfile = make_tests_data_path('cert.key')
    118         with self.assertRaises(ssl.SSLError) as ctx:
    119             ssl_ctx = get_ssl_context(options)
    120 
    121         options.certfile = make_tests_data_path('cert.crt')
    122         options.keyfile = make_tests_data_path('cert.key')
    123         ssl_ctx = get_ssl_context(options)
    124         self.assertIsNotNone(ssl_ctx)
    125 
    126     def test_get_trusted_downstream(self):
    127         tdstream = ''
    128         result = set()
    129         self.assertEqual(get_trusted_downstream(tdstream), result)
    130 
    131         tdstream = '1.1.1.1, 2.2.2.2'
    132         result = set(['1.1.1.1', '2.2.2.2'])
    133         self.assertEqual(get_trusted_downstream(tdstream), result)
    134 
    135         tdstream = '1.1.1.1, 2.2.2.2, 2.2.2.2'
    136         result = set(['1.1.1.1', '2.2.2.2'])
    137         self.assertEqual(get_trusted_downstream(tdstream), result)
    138 
    139         tdstream = '1.1.1.1, 2.2.2.'
    140         with self.assertRaises(ValueError):
    141             get_trusted_downstream(tdstream)
    142 
    143     def test_get_origin_setting(self):
    144         options.debug = False
    145         options.origin = '*'
    146         with self.assertRaises(ValueError):
    147             get_origin_setting(options)
    148 
    149         options.debug = True
    150         self.assertEqual(get_origin_setting(options), '*')
    151 
    152         options.origin = random.choice(['Same', 'Primary'])
    153         self.assertEqual(get_origin_setting(options), options.origin.lower())
    154 
    155         options.origin = ''
    156         with self.assertRaises(ValueError):
    157             get_origin_setting(options)
    158 
    159         options.origin = ','
    160         with self.assertRaises(ValueError):
    161             get_origin_setting(options)
    162 
    163         options.origin = 'www.example.com,  https://www.example.org'
    164         result = {'http://www.example.com', 'https://www.example.org'}
    165         self.assertEqual(get_origin_setting(options), result)
    166 
    167         options.origin = 'www.example.com:80,  www.example.org:443'
    168         result = {'http://www.example.com', 'https://www.example.org'}
    169         self.assertEqual(get_origin_setting(options), result)
    170 
    171     def test_get_font_setting(self):
    172         font_dir = os.path.join(base_dir, 'tests', 'data', 'fonts')
    173         font = ''
    174         self.assertEqual(get_font_filename(font, font_dir), 'fake-font')
    175 
    176         font = 'fake-font'
    177         self.assertEqual(get_font_filename(font, font_dir), 'fake-font')
    178 
    179         font = 'wrong-name'
    180         with self.assertRaises(ValueError):
    181             get_font_filename(font, font_dir)
    182 
    183     def test_check_encoding_setting(self):
    184         self.assertIsNone(check_encoding_setting(''))
    185         self.assertIsNone(check_encoding_setting('utf-8'))
    186         with self.assertRaises(ValueError):
    187             check_encoding_setting('unknown-encoding')