handler.py (19200B)
"""HTTP and WebSocket request handlers that bridge a browser to an SSH shell.

``IndexHandler`` accepts the login form, opens the SSH connection in a thread
pool and registers a ``Worker``; ``WsockHandler`` then attaches a websocket to
that worker and relays terminal data and resize events.
"""
import io
import json
import logging
import socket
import struct
import traceback
import weakref
import paramiko
# tornado.gen and tornado.websocket are referenced by attribute below
# (``tornado.gen.coroutine``, ``tornado.websocket.WebSocketHandler``); import
# them explicitly instead of relying on another module having loaded them.
import tornado.gen
import tornado.web
import tornado.websocket

from concurrent.futures import ThreadPoolExecutor
from tornado.ioloop import IOLoop
from tornado.options import options
from tornado.process import cpu_count
from webssh.utils import (
    is_valid_ip_address, is_valid_port, is_valid_hostname, to_bytes, to_str,
    to_int, to_ip_address, UnicodeType, is_ip_hostname, is_same_primary_domain,
    is_valid_encoding
)
from webssh.worker import Worker, recycle_worker, clients

try:
    from json.decoder import JSONDecodeError
except ImportError:
    JSONDecodeError = ValueError

try:
    from urllib.parse import urlparse
except ImportError:
    from urlparse import urlparse


# Port used when the login form does not supply one.
DEFAULT_PORT = 22

# When true, IndexHandler reports POST errors as a 200 JSON body.
swallow_http_errors = True
# Truthy when plain-http requests from public addresses should be redirected
# to https (read by MixinHandler.is_forbidden).
redirecting = None


class InvalidValueError(Exception):
    """Raised when a user-supplied form value or private key is invalid."""
    pass


class SSHClient(paramiko.SSHClient):
    """paramiko client with keyboard-interactive (2fa) support.

    ``totp`` is expected to be set on the instance by the caller before
    connecting; ``password`` is stored by ``_auth``.
    """

    def handler(self, title, instructions, prompt_list):
        """Answer keyboard-interactive prompts from the stored credentials.

        Prompts starting with "password" get ``self.password``, prompts
        starting with "verification" get ``self.totp``.

        Raises:
            ValueError: a prompt is neither of the two known kinds.
        """
        answers = []
        for prompt_, _ in prompt_list:
            prompt = prompt_.strip().lower()
            if prompt.startswith('password'):
                answers.append(self.password)
            elif prompt.startswith('verification'):
                answers.append(self.totp)
            else:
                raise ValueError('Unknown prompt: {}'.format(prompt_))
        return answers

    def auth_interactive(self, username, handler):
        """Run keyboard-interactive auth; requires a verification code.

        Raises:
            ValueError: no ``totp`` value was provided.
        """
        if not self.totp:
            raise ValueError('Need a verification code for 2fa.')
        self._transport.auth_interactive(username, handler)

    def _auth(self, username, password, pkey, *args):
        """Authenticate with public key and/or password, then 2fa if asked.

        Public key is tried first when ``pkey`` is given, then password. After
        either step, if the server still allows 'keyboard-interactive' or
        'password', keyboard-interactive auth is run. The last saved
        paramiko exception is re-raised when nothing succeeded.
        """
        self.password = password
        saved_exception = None
        two_factor = False
        allowed_types = set()
        two_factor_types = {'keyboard-interactive', 'password'}

        if pkey is not None:
            logging.info('Trying publickey authentication')
            try:
                allowed_types = set(
                    self._transport.auth_publickey(username, pkey)
                )
                two_factor = allowed_types & two_factor_types
                if not two_factor:
                    return
            except paramiko.SSHException as e:
                saved_exception = e

        if two_factor:
            logging.info('Trying publickey 2fa')
            return self.auth_interactive(username, self.handler)

        if password is not None:
            logging.info('Trying password authentication')
            try:
                self._transport.auth_password(username, password)
                return
            except paramiko.SSHException as e:
                saved_exception = e
                allowed_types = set(getattr(e, 'allowed_types', []))
                two_factor = allowed_types & two_factor_types

        if two_factor:
            logging.info('Trying password 2fa')
            return self.auth_interactive(username, self.handler)

        assert saved_exception is not None
        raise saved_exception


class PrivateKey(object):
    """Parse a PEM-style private key string into a paramiko key object."""

    max_length = 16384  # rough number

    # Maps the word in "-----BEGIN <TAG> PRIVATE KEY-----" to the prefix of
    # the paramiko key class name (e.g. 'RSA' -> paramiko.RSAKey).
    tag_to_name = {
        'RSA': 'RSA',
        'DSA': 'DSS',
        'EC': 'ECDSA',
        'OPENSSH': 'Ed25519'
    }

    def __init__(self, privatekey, password=None, filename=''):
        """Store the key text and validate its length.

        Raises:
            InvalidValueError: the key text is longer than ``max_length``.
        """
        self.privatekey = privatekey
        self.filename = filename
        self.password = password
        self.check_length()
        self.iostr = io.StringIO(privatekey)
        self.last_exception = None

    def check_length(self):
        """Raise InvalidValueError if the key text exceeds ``max_length``."""
        if len(self.privatekey) > self.max_length:
            raise InvalidValueError('Invalid key length.')

    def parse_name(self, iostr, tag_to_name):
        """Scan ``iostr`` for a BEGIN header and map its tag to a key name.

        Returns:
            ``(name, length)`` where ``name`` is the mapped key type (or None
            when no known header was found) and ``length`` is the length of
            the last line read (0 when the stream had no lines).
        """
        name = None
        # Pre-bind so an empty stream does not leave ``line_`` undefined at
        # the return statement below.
        line_ = ''
        for line_ in iostr:
            line = line_.strip()
            if line and line.startswith('-----BEGIN ') and \
                    line.endswith(' PRIVATE KEY-----'):
                lst = line.split(' ')
                if len(lst) == 4:
                    tag = lst[1]
                    if tag:
                        name = tag_to_name.get(tag)
                        if name:
                            break
        return name, len(line_)

    def get_specific_pkey(self, name, offset, password):
        """Try to load the key as paramiko's ``<name>Key`` class.

        Returns:
            The key object, or None when parsing failed (the exception is
            kept in ``self.last_exception``).

        Raises:
            InvalidValueError: the key is encrypted and no passphrase given.
        """
        # Rewind to the start of the BEGIN line before every attempt.
        self.iostr.seek(offset)
        logging.debug('Reset offset to {}.'.format(offset))

        logging.debug('Try parsing it as {} type key'.format(name))
        pkeycls = getattr(paramiko, name+'Key')
        pkey = None

        try:
            pkey = pkeycls.from_private_key(self.iostr, password=password)
        except paramiko.PasswordRequiredException:
            raise InvalidValueError('Need a passphrase to decrypt the key.')
        except (paramiko.SSHException, ValueError) as exc:
            self.last_exception = exc
            logging.debug(str(exc))

        return pkey

    def get_pkey_obj(self):
        """Return a paramiko key object parsed from the stored key text.

        Raises:
            InvalidValueError: no recognised header, the key could not be
                parsed, or the passphrase is missing/wrong.
        """
        logging.info('Parsing private key {!r}'.format(self.filename))
        name, length = self.parse_name(self.iostr, self.tag_to_name)
        if not name:
            raise InvalidValueError('Invalid key {}.'.format(self.filename))

        # Position of the BEGIN line: current position minus that line.
        offset = self.iostr.tell() - length
        password = to_bytes(self.password) if self.password else None
        pkey = self.get_specific_pkey(name, offset, password)

        # An 'OPENSSH' header maps to Ed25519 first; on failure the other
        # key classes are tried against the same data.
        if pkey is None and name == 'Ed25519':
            for name in ['RSA', 'ECDSA', 'DSS']:
                pkey = self.get_specific_pkey(name, offset, password)
                if pkey:
                    break

        if pkey:
            return pkey

        logging.error(str(self.last_exception))
        msg = 'Invalid key'
        if self.password:
            # NOTE(review): the passphrase is echoed back in the error text;
            # left unchanged because callers may match on this message.
            msg += ' or wrong passphrase "{}" for decrypting it.'.format(
                self.password)
        raise InvalidValueError(msg)


class MixinHandler(object):
    """Shared request checks and client-address helpers for all handlers."""

    custom_headers = {
        'Server': 'TornadoServer'
    }

    html = ('<html><head><title>{code} {reason}</title></head><body>{code} '
            '{reason}</body></html>')

    def initialize(self, loop=None):
        """Validate the request and remember the loop and origin policy."""
        self.check_request()
        self.loop = loop
        self.origin_policy = self.settings.get('origin_policy')

    def check_request(self):
        """Reject (403), redirect to https, or accept the request.

        ``self.context`` is only set when the request is accepted.
        """
        context = self.request.connection.context
        result = self.is_forbidden(context, self.request.host_name)
        self._transforms = []
        if result:
            self.set_status(403)
            self.finish(
                self.html.format(code=self._status_code, reason=self._reason)
            )
        elif result is False:
            to_url = self.get_redirect_url(
                self.request.host_name, options.sslport, self.request.uri
            )
            self.redirect(to_url, permanent=True)
        else:
            self.context = context

    def check_origin(self, origin):
        """Return whether ``origin`` is allowed by ``self.origin_policy``.

        '*' allows everything; an origin whose netloc equals the Host header
        is always allowed; 'same' allows nothing else; 'primary' compares
        primary domains; any other policy is treated as a container of
        allowed origins.
        """
        if self.origin_policy == '*':
            return True

        parsed_origin = urlparse(origin)
        netloc = parsed_origin.netloc.lower()
        logging.debug('netloc: {}'.format(netloc))

        host = self.request.headers.get('Host')
        logging.debug('host: {}'.format(host))

        if netloc == host:
            return True

        if self.origin_policy == 'same':
            return False
        elif self.origin_policy == 'primary':
            return is_same_primary_domain(netloc.rsplit(':', 1)[0],
                                          host.rsplit(':', 1)[0])
        else:
            return origin in self.origin_policy

    def is_forbidden(self, context, hostname):
        """Decide how to treat the request based on its connection context.

        Returns:
            True  -- forbid the request,
            False -- redirect the request to https,
            None  -- allow the request (implicit fall-through).
        """
        ip = context.address[0]
        lst = context.trusted_downstream
        ip_address = None

        if lst and ip not in lst:
            logging.warning(
                'IP {!r} not found in trusted downstream {!r}'.format(ip, lst)
            )
            return True

        if context._orig_protocol == 'http':
            if redirecting and not is_ip_hostname(hostname):
                ip_address = to_ip_address(ip)
                if not ip_address.is_private:
                    # redirecting
                    return False

            if options.fbidhttp:
                if ip_address is None:
                    ip_address = to_ip_address(ip)
                if not ip_address.is_private:
                    logging.warning('Public plain http request is forbidden.')
                    return True

    def get_redirect_url(self, hostname, port, uri):
        """Build the https URL for ``uri``; port 443 is omitted."""
        port = '' if port == 443 else ':%s' % port
        return 'https://{}{}{}'.format(hostname, port, uri)

    def set_default_headers(self):
        """Apply ``custom_headers`` to every response."""
        for header in self.custom_headers.items():
            self.set_header(*header)

    def get_value(self, name):
        """Return a required, non-empty request argument.

        Raises:
            InvalidValueError: the argument is present but empty.
        """
        value = self.get_argument(name)
        if not value:
            raise InvalidValueError('Missing value {}'.format(name))
        return value

    def get_context_addr(self):
        """Return ``(ip, port)`` of the underlying connection."""
        return self.context.address[:2]

    def get_client_addr(self):
        """Return the client ``(ip, port)``, honouring proxy headers if set."""
        if options.xheaders:
            return self.get_real_client_addr() or self.get_context_addr()
        else:
            return self.get_context_addr()

    def get_real_client_addr(self):
        """Return ``(ip, port)`` taken from proxy headers, or None.

        None is returned when ``remote_ip`` matches neither X-Real-Ip nor
        X-Forwarded-For. An absent or invalid port is replaced by 65535.
        """
        ip = self.request.remote_ip

        if ip == self.request.headers.get('X-Real-Ip'):
            port = self.request.headers.get('X-Real-Port')
        elif ip in self.request.headers.get('X-Forwarded-For', ''):
            port = self.request.headers.get('X-Forwarded-Port')
        else:
            # not running behind an nginx server
            return

        port = to_int(port)
        if port is None or not is_valid_port(port):
            # fake port
            port = 65535

        return (ip, port)


class NotFoundHandler(MixinHandler, tornado.web.ErrorHandler):
    """Fallback handler that answers every request with 404."""

    def initialize(self):
        """Run the shared request checks without an event loop."""
        super(NotFoundHandler, self).initialize()

    def prepare(self):
        """Always raise HTTP 404."""
        raise tornado.web.HTTPError(404)


class IndexHandler(MixinHandler, tornado.web.RequestHandler):
    """Serve the login page and create SSH workers from the login form."""

    # Blocking SSH connects run here so the IOLoop is not stalled.
    executor = ThreadPoolExecutor(max_workers=cpu_count()*5)

    def initialize(self, loop, policy, host_keys_settings):
        """Set up per-request state and a preconfigured SSH client."""
        super(IndexHandler, self).initialize(loop)
        self.policy = policy
        self.host_keys_settings = host_keys_settings
        self.ssh_client = self.get_ssh_client()
        self.debug = self.settings.get('debug', False)
        self.font = self.settings.get('font', '')
        self.result = dict(id=None, status=None, encoding=None)

    def write_error(self, status_code, **kwargs):
        """Report POST errors as a 200 JSON result when swallowing is on.

        The exception's ``log_message`` (if any) becomes the ``status`` field;
        other requests use the default error page.
        """
        if swallow_http_errors and self.request.method == 'POST':
            exc_info = kwargs.get('exc_info')
            if exc_info:
                reason = getattr(exc_info[1], 'log_message', None)
                if reason:
                    self._reason = reason
            self.result.update(status=self._reason)
            self.set_status(200)
            self.finish(self.result)
        else:
            super(IndexHandler, self).write_error(status_code, **kwargs)

    def get_ssh_client(self):
        """Return an SSHClient wired to the configured host keys and policy."""
        ssh = SSHClient()
        ssh._system_host_keys = self.host_keys_settings['system_host_keys']
        ssh._host_keys = self.host_keys_settings['host_keys']
        ssh._host_keys_filename = self.host_keys_settings['host_keys_filename']
        ssh.set_missing_host_key_policy(self.policy)
        return ssh

    def get_privatekey(self):
        """Return ``(key_text, filename)`` from an upload or a form field."""
        name = 'privatekey'
        lst = self.request.files.get(name)
        if lst:
            # multipart form
            filename = lst[0]['filename']
            data = lst[0]['body']
            value = self.decode_argument(data, name=name).strip()
        else:
            # urlencoded form
            value = self.get_argument(name, u'')
            filename = ''

        return value, filename

    def get_hostname(self):
        """Return the validated 'hostname' argument.

        Raises:
            InvalidValueError: missing, or neither a hostname nor an IP.
        """
        value = self.get_value('hostname')
        if not (is_valid_hostname(value) or is_valid_ip_address(value)):
            raise InvalidValueError('Invalid hostname: {}'.format(value))
        return value

    def get_port(self):
        """Return the validated 'port' argument, or DEFAULT_PORT if empty.

        Raises:
            InvalidValueError: the value is not a valid port number.
        """
        value = self.get_argument('port', u'')
        if not value:
            return DEFAULT_PORT

        port = to_int(value)
        if port is None or not is_valid_port(port):
            raise InvalidValueError('Invalid port: {}'.format(value))
        return port

    def lookup_hostname(self, hostname, port):
        """Raise HTTP 403 unless the host is in a known-hosts store."""
        key = hostname if port == 22 else '[{}]:{}'.format(hostname, port)

        if self.ssh_client._system_host_keys.lookup(key) is None:
            if self.ssh_client._host_keys.lookup(key) is None:
                raise tornado.web.HTTPError(
                    403, 'Connection to {}:{} is not allowed.'.format(
                        hostname, port)
                )

    def get_args(self):
        """Collect and validate the connection arguments from the form.

        Returns:
            ``(hostname, port, username, password, pkey)`` for
            ``ssh_connect``; ``pkey`` is None when no key was supplied.
        """
        hostname = self.get_hostname()
        port = self.get_port()
        username = self.get_value('username')
        password = self.get_argument('password', u'')
        privatekey, filename = self.get_privatekey()
        passphrase = self.get_argument('passphrase', u'')
        totp = self.get_argument('totp', u'')

        # With RejectPolicy, unknown hosts are refused before connecting.
        if isinstance(self.policy, paramiko.RejectPolicy):
            self.lookup_hostname(hostname, port)

        if privatekey:
            pkey = PrivateKey(privatekey, passphrase, filename).get_pkey_obj()
        else:
            pkey = None

        self.ssh_client.totp = totp
        args = (hostname, port, username, password, pkey)
        # NOTE(review): this debug record contains the plaintext password.
        logging.debug(args)

        return args

    def parse_encoding(self, data):
        """Return ``data`` as an encoding name if valid, else None."""
        try:
            encoding = to_str(data.strip(), 'ascii')
        except UnicodeDecodeError:
            return

        if is_valid_encoding(encoding):
            return encoding

    def get_default_encoding(self, ssh):
        """Ask the remote shell for its charmap; fall back to 'utf-8'."""
        commands = [
            '$SHELL -ilc "locale charmap"',
            '$SHELL -ic "locale charmap"'
        ]

        for command in commands:
            try:
                _, stdout, _ = ssh.exec_command(command, get_pty=True)
            except paramiko.SSHException as exc:
                logging.info(str(exc))
            else:
                data = stdout.read()
                logging.debug('{!r} => {!r}'.format(command, data))
                result = self.parse_encoding(data)
                if result:
                    return result

        logging.warning('Could not detect the default encoding.')
        return 'utf-8'

    def ssh_connect(self, args):
        """Open the SSH connection and shell; return the new Worker.

        Submitted to ``executor`` by ``post`` because connecting blocks.

        Raises:
            ValueError: connection, authentication or host-key failure.
        """
        ssh = self.ssh_client
        dst_addr = args[:2]
        logging.info('Connecting to {}:{}'.format(*dst_addr))

        try:
            ssh.connect(*args, timeout=options.timeout)
        except socket.error:
            raise ValueError('Unable to connect to {}:{}'.format(*dst_addr))
        except paramiko.BadAuthenticationType:
            raise ValueError('Bad authentication type.')
        except paramiko.AuthenticationException:
            raise ValueError('Authentication failed.')
        except paramiko.BadHostKeyException:
            raise ValueError('Bad host key.')

        term = self.get_argument('term', u'') or u'xterm'
        chan = ssh.invoke_shell(term=term)
        chan.setblocking(0)
        worker = Worker(self.loop, ssh, chan, dst_addr)
        worker.encoding = options.encoding if options.encoding else \
            self.get_default_encoding(ssh)
        return worker

    def check_origin(self):
        """Enforce the origin policy for the POST request.

        The origin comes from the '_origin' argument or the Origin header.
        When only the header is present and the policy is not 'same', the
        origin is echoed in Access-Control-Allow-Origin.

        Raises:
            tornado.web.HTTPError: 403 when the origin is not allowed.
        """
        event_origin = self.get_argument('_origin', u'')
        header_origin = self.request.headers.get('Origin')
        origin = event_origin or header_origin

        if origin:
            if not super(IndexHandler, self).check_origin(origin):
                raise tornado.web.HTTPError(
                    403, 'Cross origin operation is not allowed.'
                )

            if not event_origin and self.origin_policy != 'same':
                self.set_header('Access-Control-Allow-Origin', origin)

    def head(self):
        """Answer HEAD requests with an empty response."""
        pass

    def get(self):
        """Render the login page."""
        self.render('index.html', debug=self.debug, font=self.font)

    @tornado.gen.coroutine
    def post(self):
        """Create an SSH worker for the client and return its id as JSON.

        Connection errors are reported in the ``status`` field of the
        result; invalid form values become HTTP 400 and exceeding
        ``options.maxconn`` becomes HTTP 403.
        """
        if self.debug and self.get_argument('error', u''):
            # for testing purpose only
            raise ValueError('Uncaught exception')

        ip, port = self.get_client_addr()
        workers = clients.get(ip, {})
        if workers and len(workers) >= options.maxconn:
            raise tornado.web.HTTPError(403, 'Too many live connections.')

        self.check_origin()

        try:
            args = self.get_args()
        except InvalidValueError as exc:
            raise tornado.web.HTTPError(400, str(exc))

        future = self.executor.submit(self.ssh_connect, args)

        try:
            worker = yield future
        except (ValueError, paramiko.SSHException) as exc:
            logging.error(traceback.format_exc())
            self.result.update(status=str(exc))
        else:
            if not workers:
                clients[ip] = workers
            worker.src_addr = (ip, port)
            workers[worker.id] = worker
            # recycle_worker is scheduled after options.delay for this worker.
            self.loop.call_later(options.delay, recycle_worker, worker)
            self.result.update(id=worker.id, encoding=worker.encoding)

        self.write(self.result)


class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler):
    """Websocket endpoint that relays terminal traffic to a Worker."""

    def initialize(self, loop):
        """Run the shared checks; no worker is attached yet."""
        super(WsockHandler, self).initialize(loop)
        self.worker_ref = None

    def open(self):
        """Attach this socket to the worker named by the 'id' argument.

        The socket is closed when the client has no workers, the id is
        missing/empty, or no worker with that id is available.
        """
        self.src_addr = self.get_client_addr()
        logging.info('Connected from {}:{}'.format(*self.src_addr))

        workers = clients.get(self.src_addr[0])
        if not workers:
            self.close(reason='Websocket authentication failed.')
            return

        try:
            worker_id = self.get_value('id')
        except (tornado.web.MissingArgumentError, InvalidValueError) as exc:
            self.close(reason=str(exc))
        else:
            worker = workers.get(worker_id)
            if worker:
                # Blank the slot so the same id cannot be attached twice.
                workers[worker_id] = None
                self.set_nodelay(True)
                worker.set_handler(self)
                self.worker_ref = weakref.ref(worker)
                self.loop.add_handler(worker.fd, worker, IOLoop.READ)
            else:
                self.close(reason='Websocket authentication failed.')

    def on_message(self, message):
        """Handle a JSON message carrying 'resize' and/or 'data'.

        Messages that are not JSON objects are ignored. If no live worker is
        attached the socket is closed instead of raising.
        """
        logging.debug('{!r} from {}:{}'.format(message, *self.src_addr))

        # worker_ref is None when open() rejected the socket, and the weak
        # reference returns None once the worker has been collected.
        worker = self.worker_ref() if self.worker_ref else None
        if worker is None:
            self.close(reason='No worker found')
            return

        try:
            msg = json.loads(message)
        except JSONDecodeError:
            return

        if not isinstance(msg, dict):
            return

        resize = msg.get('resize')
        # Require a JSON array so len() cannot raise on numbers/booleans.
        if isinstance(resize, (list, tuple)) and len(resize) == 2:
            try:
                worker.chan.resize_pty(*resize)
            except (TypeError, struct.error, paramiko.SSHException):
                pass

        data = msg.get('data')
        if data and isinstance(data, UnicodeType):
            worker.data_to_dst.append(data)
            worker.on_write()

    def on_close(self):
        """Close the attached worker, if any, with the close reason."""
        logging.info('Disconnected from {}:{}'.format(*self.src_addr))
        if not self.close_reason:
            self.close_reason = 'client disconnected'

        worker = self.worker_ref() if self.worker_ref else None
        if worker:
            worker.close(reason=self.close_reason)