settings.py (6583B)
import logging
import os.path
import ssl
import sys

from tornado.options import define
from webssh.policy import (
    load_host_keys, get_policy_class, check_policy_setting
)
from webssh.utils import (
    to_ip_address, parse_origin_from_url, is_valid_encoding
)
from webssh._version import __version__


def print_version(flag):
    """Print the package version and exit the process when *flag* is truthy.

    Registered below as the callback of the ``version`` option.
    """
    if flag:
        print(__version__)
        sys.exit(0)


define('address', default='', help='Listen address')
define('port', type=int, default=8888, help='Listen port')
define('ssladdress', default='', help='SSL listen address')
define('sslport', type=int, default=4433, help='SSL listen port')
define('certfile', default='', help='SSL certificate file')
define('keyfile', default='', help='SSL private key file')
define('debug', type=bool, default=False, help='Debug mode')
define('policy', default='warning',
       help='Missing host key policy, reject|autoadd|warning')
define('hostfile', default='', help='User defined host keys file')
define('syshostfile', default='', help='System wide host keys file')
define('tdstream', default='', help='Trusted downstream, separated by comma')
define('redirect', type=bool, default=True, help='Redirecting http to https')
define('fbidhttp', type=bool, default=True,
       help='Forbid public plain http incoming requests')
define('xheaders', type=bool, default=True, help='Support xheaders')
define('xsrf', type=bool, default=True, help='CSRF protection')
define('origin', default='same', help='''Origin policy,
'same': same origin policy, matches host name and port number;
'primary': primary domain policy, matches primary domain only;
'<domains>': custom domains policy, matches any domain in the <domains> list
separated by comma;
'*': wildcard policy, matches any domain, allowed in debug mode only.''')
define('wpintvl', type=float, default=0, help='Websocket ping interval')
define('timeout', type=float, default=3, help='SSH connection timeout')
define('delay', type=float, default=3, help='The delay to call recycle_worker')
define('maxconn', type=int, default=20,
       help='Maximum live connections (ssh sessions) per client')
define('font', default='', help='custom font filename')
define('encoding', default='',
       help='''The default character encoding of ssh servers.
Example: --encoding='utf-8' to solve the problem with some switches&routers''')
define('version', type=bool, help='Show version information',
       callback=print_version)


# Parent of the directory that contains this module.
base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Path components, relative to base_dir, of the directory holding font files.
font_dirs = ['webssh', 'static', 'css', 'fonts']
# 1 MiB; passed to the server settings as max_body_size.
max_body_size = 1 * 1024 * 1024


class Font(object):
    """A font file exposed through the application settings.

    ``family`` is the filename up to its first dot; ``url`` is the
    '/'-joined path made of *dirs* followed by the filename.
    """

    def __init__(self, filename, dirs):
        self.family = self.get_family(filename)
        self.url = self.get_url(filename, dirs)

    def get_family(self, filename):
        return filename.split('.')[0]

    def get_url(self, filename, dirs):
        # This is a URL path, so always join with '/': os.path.join would
        # produce backslash separators on Windows.
        return '/'.join(dirs + [filename])


def get_app_settings(options):
    """Build the keyword settings for the web application from *options*.

    Raises whatever get_font_filename / get_origin_setting raise for an
    invalid font or origin configuration.
    """
    settings = dict(
        template_path=os.path.join(base_dir, 'webssh', 'templates'),
        static_path=os.path.join(base_dir, 'webssh', 'static'),
        websocket_ping_interval=options.wpintvl,
        debug=options.debug,
        xsrf_cookies=options.xsrf,
        font=Font(
            get_font_filename(options.font,
                              os.path.join(base_dir, *font_dirs)),
            # Drop the leading 'webssh' component so the path starts at
            # 'static', matching static_path above.
            font_dirs[1:]
        ),
        origin_policy=get_origin_setting(options)
    )
    return settings


def get_server_settings(options):
    """Build the keyword settings for the HTTP server from *options*."""
    settings = dict(
        xheaders=options.xheaders,
        max_body_size=max_body_size,
        trusted_downstream=get_trusted_downstream(options.tdstream)
    )
    return settings


def get_host_keys_settings(options):
    """Load the user and system host keys named by *options*.

    ``hostfile`` defaults to ``<base_dir>/known_hosts`` and ``syshostfile``
    defaults to ``~/.ssh/known_hosts``. Returns a dict with ``host_keys``,
    ``system_host_keys`` and ``host_keys_filename`` (the user file's path).
    """
    if not options.hostfile:
        host_keys_filename = os.path.join(base_dir, 'known_hosts')
    else:
        host_keys_filename = options.hostfile
    host_keys = load_host_keys(host_keys_filename)

    if not options.syshostfile:
        filename = os.path.expanduser('~/.ssh/known_hosts')
    else:
        filename = options.syshostfile
    system_host_keys = load_host_keys(filename)

    settings = dict(
        host_keys=host_keys,
        system_host_keys=system_host_keys,
        host_keys_filename=host_keys_filename
    )
    return settings


def get_policy_setting(options, host_keys_settings):
    """Return an instance of the missing-host-key policy named by *options*.

    The policy class name is logged, and check_policy_setting is run against
    *host_keys_settings* before instantiation (presumably a validation step
    that may raise — confirm in webssh.policy).
    """
    policy_class = get_policy_class(options.policy)
    logging.info(policy_class.__name__)
    check_policy_setting(policy_class, host_keys_settings)
    return policy_class()


def get_ssl_context(options):
    """Return a server-side SSLContext, or None when SSL is not configured.

    Returns None if neither ``certfile`` nor ``keyfile`` is set.

    Raises:
        ValueError: only one of the two is set, or a named file does not
            exist.
    """
    if not options.certfile and not options.keyfile:
        return None
    elif not options.certfile:
        raise ValueError('certfile is not provided')
    elif not options.keyfile:
        raise ValueError('keyfile is not provided')
    elif not os.path.isfile(options.certfile):
        raise ValueError('File {!r} does not exist'.format(options.certfile))
    elif not os.path.isfile(options.keyfile):
        raise ValueError('File {!r} does not exist'.format(options.keyfile))
    else:
        # Purpose.CLIENT_AUTH yields a context suitable for the server side.
        ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
        ssl_ctx.load_cert_chain(options.certfile, options.keyfile)
        return ssl_ctx


def get_trusted_downstream(tdstream):
    """Parse a comma-separated list of trusted downstream IPs into a set.

    Blank entries are skipped. Each entry is passed to to_ip_address,
    presumably to validate it (assumed to raise on invalid input — confirm
    in webssh.utils). The stripped string, not the parsed address, is kept.
    """
    result = set()
    for ip in tdstream.split(','):
        ip = ip.strip()
        if ip:
            to_ip_address(ip)
            result.add(ip)
    return result


def get_origin_setting(options):
    """Resolve the origin policy from *options*.

    Returns '*' (debug mode only), 'same' or 'primary' (case-insensitive),
    or a set of origins parsed from the comma-separated custom list.

    Raises:
        ValueError: wildcard requested outside debug mode, or the custom
            list yields no origins.
    """
    if options.origin == '*':
        if not options.debug:
            raise ValueError(
                'Wildcard origin policy is only allowed in debug mode.'
            )
        return '*'

    origin = options.origin.lower()
    if origin in ['same', 'primary']:
        return origin

    origins = set()
    for url in origin.split(','):
        orig = parse_origin_from_url(url)
        if orig:
            origins.add(orig)

    if not origins:
        raise ValueError('Empty origin list')

    return origins


def get_font_filename(font, font_dir):
    """Resolve the font file to use from *font_dir*.

    Only regular files whose names do not start with '.' are considered.
    If *font* is given it must be one of them. If *font* is empty, the
    alphabetically first candidate is returned. If there are no candidates,
    *font* is returned unchanged.

    Raises:
        ValueError: *font* is given but not found in *font_dir*.
    """
    filenames = {f for f in os.listdir(font_dir) if not f.startswith('.')
                 and os.path.isfile(os.path.join(font_dir, f))}
    if font:
        if font not in filenames:
            raise ValueError(
                'Font file {!r} not found'.format(os.path.join(font_dir, font))
            )
    elif filenames:
        # Choose deterministically: set.pop() depends on per-process string
        # hash randomization, so the default font could differ between runs.
        font = min(filenames)

    return font


def check_encoding_setting(encoding):
    """Raise ValueError if *encoding* is non-empty and not a valid encoding."""
    if encoding and not is_valid_encoding(encoding):
        raise ValueError('Unknown character encoding {!r}.'.format(encoding))