test_policy.py (4340B)
import os
import unittest
import paramiko

from shutil import copyfile
from paramiko.client import RejectPolicy, WarningPolicy
from tests.utils import make_tests_data_path
from webssh.policy import (
    AutoAddPolicy, get_policy_dictionary, load_host_keys,
    get_policy_class, check_policy_setting
)


def _remove_if_exists(path):
    """Delete *path*, ignoring the OSError raised when it cannot be removed
    (e.g. because it does not exist)."""
    try:
        os.unlink(path)
    except OSError:
        pass


def _first_host_entry(path):
    """Return the first entry parsed by paramiko from the known-hosts file
    at *path* (relies on the private ``HostKeys._entries`` list)."""
    return paramiko.hostkeys.HostKeys(path)._entries[0]


class TestPolicy(unittest.TestCase):
    """Tests for the host-key policy helpers exposed by ``webssh.policy``."""

    def test_get_policy_dictionary(self):
        """Each policy class is registered under its lower-cased name."""
        classes = [AutoAddPolicy, RejectPolicy, WarningPolicy]
        dic = get_policy_dictionary()
        for cls in classes:
            val = dic[cls.__name__.lower()]
            self.assertIs(cls, val)

    def test_load_host_keys(self):
        """Nonexistent paths and directories load no keys; a real
        known-hosts file loads the same keys paramiko parses from it."""
        path = '/path-not-exists'
        host_keys = load_host_keys(path)
        self.assertFalse(host_keys)

        path = '/tmp'
        host_keys = load_host_keys(path)
        self.assertFalse(host_keys)

        path = make_tests_data_path('known_hosts_example')
        host_keys = load_host_keys(path)
        self.assertEqual(host_keys, paramiko.hostkeys.HostKeys(path))

    def test_get_policy_class(self):
        """Short policy names map to classes; unknown names raise
        ValueError."""
        keys = ['autoadd', 'reject', 'warning']
        vals = [AutoAddPolicy, RejectPolicy, WarningPolicy]
        for key, val in zip(keys, vals):
            cls = get_policy_class(key)
            self.assertIs(cls, val)

        key = 'non-exists'
        with self.assertRaises(ValueError):
            get_policy_class(key)

    def test_check_policy_setting(self):
        """RejectPolicy with no loaded host keys raises ValueError, while
        AutoAddPolicy creates the configured host-keys file."""
        host_keys_filename = make_tests_data_path('host_keys_test.db')
        host_keys_settings = dict(
            host_keys=paramiko.hostkeys.HostKeys(),
            system_host_keys=paramiko.hostkeys.HostKeys(),
            host_keys_filename=host_keys_filename
        )

        with self.assertRaises(ValueError):
            check_policy_setting(RejectPolicy, host_keys_settings)

        # Start from a clean slate so the existence check below proves that
        # check_policy_setting created the file, and remove it afterwards
        # (even on failure) so no artifact is left in the tests data dir.
        _remove_if_exists(host_keys_filename)
        self.addCleanup(_remove_if_exists, host_keys_filename)
        check_policy_setting(AutoAddPolicy, host_keys_settings)
        self.assertTrue(os.path.exists(host_keys_filename))

    def test_is_missing_host_key(self):
        """Exercise AutoAddPolicy.is_missing_host_key against keys loaded
        both as user host keys and as system host keys."""
        client = paramiko.SSHClient()
        file1 = make_tests_data_path('known_hosts_example')
        file2 = make_tests_data_path('known_hosts_example2')
        client.load_host_keys(file1)
        client.load_system_host_keys(file2)

        autoadd = AutoAddPolicy()

        # A hostname/key pair that was loaded is not missing.
        for f in [file1, file2]:
            entry = _first_host_entry(f)
            hostname = entry.hostnames[0]
            key = entry.key
            self.assertIsNone(
                autoadd.is_missing_host_key(client, hostname, key)
            )

        # A key reporting an unrecognised type name counts as missing.
        for f in [file1, file2]:
            entry = _first_host_entry(f)
            hostname = entry.hostnames[0]
            key = entry.key
            key.get_name = lambda: 'unknown'
            try:
                self.assertTrue(
                    autoadd.is_missing_host_key(client, hostname, key)
                )
            finally:
                # Undo the instance-level patch even if the assertion fails.
                del key.get_name

        # A hostname altered by dropping its first character counts as
        # missing.
        for f in [file1, file2]:
            entry = _first_host_entry(f)
            hostname = entry.hostnames[0][1:]
            key = entry.key
            self.assertTrue(
                autoadd.is_missing_host_key(client, hostname, key)
            )

        # NOTE(review): presumably known_hosts_example3 pairs a known
        # hostname with a different key; the test requires this entry to
        # raise BadHostKeyException.
        file3 = make_tests_data_path('known_hosts_example3')
        entry = _first_host_entry(file3)
        hostname = entry.hostnames[0]
        key = entry.key
        with self.assertRaises(paramiko.BadHostKeyException):
            autoadd.is_missing_host_key(client, hostname, key)

    def test_missing_host_key(self):
        """AutoAddPolicy.missing_host_key adds the key to the client and
        the client's host keys match the backing file afterwards."""
        client = paramiko.SSHClient()
        file1 = make_tests_data_path('known_hosts_example')
        file2 = make_tests_data_path('known_hosts_example2')
        filename = make_tests_data_path('known_hosts')
        copyfile(file1, filename)
        # Remove the scratch copy even if an assertion below fails.
        self.addCleanup(_remove_if_exists, filename)
        client.load_host_keys(filename)
        n1 = len(client._host_keys)

        autoadd = AutoAddPolicy()
        entry = _first_host_entry(file2)
        hostname = entry.hostnames[0]
        key = entry.key
        autoadd.missing_host_key(client, hostname, key)
        self.assertEqual(len(client._host_keys), n1 + 1)
        self.assertEqual(paramiko.hostkeys.HostKeys(filename),
                         client._host_keys)