test_settings.py (6946B)
import io
import random
import ssl
import sys
import os.path
import unittest
import paramiko
import tornado.options as options

from tests.utils import make_tests_data_path
from webssh.policy import load_host_keys
from webssh.settings import (
    get_host_keys_settings, get_policy_setting, base_dir, get_font_filename,
    get_ssl_context, get_trusted_downstream, get_origin_setting, print_version,
    check_encoding_setting
)
from webssh.utils import UnicodeType
from webssh._version import __version__


class TestSettings(unittest.TestCase):
    """Tests for the setting helpers in ``webssh.settings``.

    NOTE: the tests assign the attributes they need directly on the
    ``tornado.options`` module object (imported here as ``options``).
    Those assignments are not reset between tests, so each test sets
    every attribute it relies on.
    """

    def test_print_version(self):
        """print_version(False) prints nothing; print_version(True) prints
        the version followed by a newline and raises SystemExit."""
        sys_stdout = sys.stdout
        # Capture output in a buffer of the interpreter's native ``str``
        # type: text when UnicodeType is str, bytes otherwise.
        sys.stdout = io.StringIO() if UnicodeType == str else io.BytesIO()
        try:
            self.assertEqual(print_version(False), None)
            self.assertEqual(sys.stdout.getvalue(), '')

            with self.assertRaises(SystemExit):
                print_version(True)
            self.assertEqual(sys.stdout.getvalue(), __version__ + '\n')
        finally:
            # Restore stdout even when an assertion fails, so later tests
            # are not left writing into the capture buffer.
            sys.stdout = sys_stdout

    def test_get_host_keys_settings(self):
        """Empty hostfile/syshostfile fall back to default paths; explicit
        paths are loaded and reported as given."""
        options.hostfile = ''
        options.syshostfile = ''
        dic = get_host_keys_settings(options)

        # Defaults: <base_dir>/known_hosts and the user's ~/.ssh/known_hosts.
        filename = os.path.join(base_dir, 'known_hosts')
        self.assertEqual(dic['host_keys'], load_host_keys(filename))
        self.assertEqual(dic['host_keys_filename'], filename)
        self.assertEqual(
            dic['system_host_keys'],
            load_host_keys(os.path.expanduser('~/.ssh/known_hosts'))
        )

        options.hostfile = make_tests_data_path('known_hosts_example')
        options.syshostfile = make_tests_data_path('known_hosts_example2')
        dic2 = get_host_keys_settings(options)
        self.assertEqual(dic2['host_keys'], load_host_keys(options.hostfile))
        self.assertEqual(dic2['host_keys_filename'], options.hostfile)
        self.assertEqual(dic2['system_host_keys'],
                         load_host_keys(options.syshostfile))

    def test_get_policy_setting(self):
        """Each policy name maps to the matching paramiko missing-host-key
        policy class."""
        options.policy = 'warning'
        options.hostfile = ''
        options.syshostfile = ''
        settings = get_host_keys_settings(options)
        instance = get_policy_setting(options, settings)
        self.assertIsInstance(instance, paramiko.client.WarningPolicy)

        options.policy = 'autoadd'
        options.hostfile = ''
        options.syshostfile = ''
        settings = get_host_keys_settings(options)
        instance = get_policy_setting(options, settings)
        try:
            self.assertIsInstance(instance, paramiko.client.AutoAddPolicy)
        finally:
            # The autoadd run is expected to leave a host-keys file behind.
            # Remove it even if the assertion fails. os.unlink still raises
            # if the file was never created, as before.
            os.unlink(settings['host_keys_filename'])

        options.policy = 'reject'
        options.hostfile = ''
        options.syshostfile = ''
        settings = get_host_keys_settings(options)
        try:
            instance = get_policy_setting(options, settings)
        except ValueError:
            # 'reject' may only refuse to build a policy when at least one
            # of the host-key collections is empty.
            self.assertFalse(
                settings['host_keys'] and settings['system_host_keys']
            )
        else:
            self.assertIsInstance(instance, paramiko.client.RejectPolicy)

    def test_get_ssl_context(self):
        """No cert/key yields None.

        Missing or nonexistent cert/key files raise ValueError. A key file
        used as the certificate raises ssl.SSLError. A valid pair yields a
        context.
        """
        options.certfile = ''
        options.keyfile = ''
        ssl_ctx = get_ssl_context(options)
        self.assertIsNone(ssl_ctx)

        options.certfile = 'provided'
        options.keyfile = ''
        with self.assertRaises(ValueError) as ctx:
            get_ssl_context(options)
        self.assertEqual('keyfile is not provided', str(ctx.exception))

        options.certfile = ''
        options.keyfile = 'provided'
        with self.assertRaises(ValueError) as ctx:
            get_ssl_context(options)
        self.assertEqual('certfile is not provided', str(ctx.exception))

        options.certfile = 'FileDoesNotExist'
        options.keyfile = make_tests_data_path('cert.key')
        with self.assertRaises(ValueError) as ctx:
            get_ssl_context(options)
        self.assertIn('does not exist', str(ctx.exception))

        options.certfile = make_tests_data_path('cert.key')
        options.keyfile = 'FileDoesNotExist'
        with self.assertRaises(ValueError) as ctx:
            get_ssl_context(options)
        self.assertIn('does not exist', str(ctx.exception))

        # Both paths exist, but a private key is not a valid certificate.
        options.certfile = make_tests_data_path('cert.key')
        options.keyfile = make_tests_data_path('cert.key')
        with self.assertRaises(ssl.SSLError):
            get_ssl_context(options)

        options.certfile = make_tests_data_path('cert.crt')
        options.keyfile = make_tests_data_path('cert.key')
        ssl_ctx = get_ssl_context(options)
        self.assertIsNotNone(ssl_ctx)

    def test_get_trusted_downstream(self):
        """Comma-separated addresses parse into a de-duplicated set; a
        malformed address raises ValueError."""
        tdstream = ''
        result = set()
        self.assertEqual(get_trusted_downstream(tdstream), result)

        tdstream = '1.1.1.1, 2.2.2.2'
        result = set(['1.1.1.1', '2.2.2.2'])
        self.assertEqual(get_trusted_downstream(tdstream), result)

        # Duplicates collapse into a single entry.
        tdstream = '1.1.1.1, 2.2.2.2, 2.2.2.2'
        result = set(['1.1.1.1', '2.2.2.2'])
        self.assertEqual(get_trusted_downstream(tdstream), result)

        tdstream = '1.1.1.1, 2.2.2.'
        with self.assertRaises(ValueError):
            get_trusted_downstream(tdstream)

    def test_get_origin_setting(self):
        """'*' is only accepted in debug mode.

        'same'/'primary' are case-insensitive keywords. Other values parse
        into a set of scheme-qualified origins. Empty or separator-only
        values raise ValueError.
        """
        options.debug = False
        options.origin = '*'
        with self.assertRaises(ValueError):
            get_origin_setting(options)

        # debug stays enabled for all remaining checks below.
        options.debug = True
        self.assertEqual(get_origin_setting(options), '*')

        # Check both keywords on every run. Picking one at random let a
        # regression in either keyword pass about half the time.
        for origin in ('Same', 'Primary'):
            options.origin = origin
            self.assertEqual(get_origin_setting(options), origin.lower())

        options.origin = ''
        with self.assertRaises(ValueError):
            get_origin_setting(options)

        options.origin = ','
        with self.assertRaises(ValueError):
            get_origin_setting(options)

        # A bare host defaults to http; an explicit scheme is kept.
        options.origin = 'www.example.com, https://www.example.org'
        result = {'http://www.example.com', 'https://www.example.org'}
        self.assertEqual(get_origin_setting(options), result)

        # Port 80 maps to http and 443 to https; the port itself is dropped.
        options.origin = 'www.example.com:80, www.example.org:443'
        result = {'http://www.example.com', 'https://www.example.org'}
        self.assertEqual(get_origin_setting(options), result)

    def test_get_font_setting(self):
        """Empty and exact font names resolve to 'fake-font' in the test
        fonts directory; an unknown name raises ValueError."""
        font_dir = os.path.join(base_dir, 'tests', 'data', 'fonts')
        font = ''
        self.assertEqual(get_font_filename(font, font_dir), 'fake-font')

        font = 'fake-font'
        self.assertEqual(get_font_filename(font, font_dir), 'fake-font')

        font = 'wrong-name'
        with self.assertRaises(ValueError):
            get_font_filename(font, font_dir)

    def test_check_encoding_setting(self):
        """Empty and known encodings are accepted (return None); an unknown
        encoding raises ValueError."""
        self.assertIsNone(check_encoding_setting(''))
        self.assertIsNone(check_encoding_setting('utf-8'))
        with self.assertRaises(ValueError):
            check_encoding_setting('unknown-encoding')