webssh

Web-based SSH client https://github.com/huashengdun/webssh webssh.huashengdun.org/
git clone http://git.hanabi.in/repos/webssh.git
Log | Files | Refs | README | LICENSE

handler.py (19200B)


      1 import io
      2 import json
      3 import logging
      4 import socket
      5 import struct
      6 import traceback
      7 import weakref
      8 import paramiko
      9 import tornado.web
     10 
     11 from concurrent.futures import ThreadPoolExecutor
     12 from tornado.ioloop import IOLoop
     13 from tornado.options import options
     14 from tornado.process import cpu_count
     15 from webssh.utils import (
     16     is_valid_ip_address, is_valid_port, is_valid_hostname, to_bytes, to_str,
     17     to_int, to_ip_address, UnicodeType, is_ip_hostname, is_same_primary_domain,
     18     is_valid_encoding
     19 )
     20 from webssh.worker import Worker, recycle_worker, clients
     21 
     22 try:
     23     from json.decoder import JSONDecodeError
     24 except ImportError:
     25     JSONDecodeError = ValueError
     26 
     27 try:
     28     from urllib.parse import urlparse
     29 except ImportError:
     30     from urlparse import urlparse
     31 
     32 
DEFAULT_PORT = 22  # port used when the connection form leaves "port" empty

# When true, errors raised while handling a POST are answered with a 200
# JSON body whose "status" carries the reason (see IndexHandler.write_error)
# instead of an HTML error page.
swallow_http_errors = True
# When truthy, plain-HTTP requests for a non-IP hostname coming from a
# non-private client address are redirected to HTTPS (see
# MixinHandler.is_forbidden). NOTE(review): presumably assigned at startup
# by the application module — confirm against the caller.
redirecting = None
     37 
     38 
     39 class InvalidValueError(Exception):
     40     pass
     41 
     42 
     43 class SSHClient(paramiko.SSHClient):
     44 
     45     def handler(self, title, instructions, prompt_list):
     46         answers = []
     47         for prompt_, _ in prompt_list:
     48             prompt = prompt_.strip().lower()
     49             if prompt.startswith('password'):
     50                 answers.append(self.password)
     51             elif prompt.startswith('verification'):
     52                 answers.append(self.totp)
     53             else:
     54                 raise ValueError('Unknown prompt: {}'.format(prompt_))
     55         return answers
     56 
     57     def auth_interactive(self, username, handler):
     58         if not self.totp:
     59             raise ValueError('Need a verification code for 2fa.')
     60         self._transport.auth_interactive(username, handler)
     61 
     62     def _auth(self, username, password, pkey, *args):
     63         self.password = password
     64         saved_exception = None
     65         two_factor = False
     66         allowed_types = set()
     67         two_factor_types = {'keyboard-interactive', 'password'}
     68 
     69         if pkey is not None:
     70             logging.info('Trying publickey authentication')
     71             try:
     72                 allowed_types = set(
     73                     self._transport.auth_publickey(username, pkey)
     74                 )
     75                 two_factor = allowed_types & two_factor_types
     76                 if not two_factor:
     77                     return
     78             except paramiko.SSHException as e:
     79                 saved_exception = e
     80 
     81         if two_factor:
     82             logging.info('Trying publickey 2fa')
     83             return self.auth_interactive(username, self.handler)
     84 
     85         if password is not None:
     86             logging.info('Trying password authentication')
     87             try:
     88                 self._transport.auth_password(username, password)
     89                 return
     90             except paramiko.SSHException as e:
     91                 saved_exception = e
     92                 allowed_types = set(getattr(e, 'allowed_types', []))
     93                 two_factor = allowed_types & two_factor_types
     94 
     95         if two_factor:
     96             logging.info('Trying password 2fa')
     97             return self.auth_interactive(username, self.handler)
     98 
     99         assert saved_exception is not None
    100         raise saved_exception
    101 
    102 
    103 class PrivateKey(object):
    104 
    105     max_length = 16384  # rough number
    106 
    107     tag_to_name = {
    108         'RSA': 'RSA',
    109         'DSA': 'DSS',
    110         'EC': 'ECDSA',
    111         'OPENSSH': 'Ed25519'
    112     }
    113 
    114     def __init__(self, privatekey, password=None, filename=''):
    115         self.privatekey = privatekey
    116         self.filename = filename
    117         self.password = password
    118         self.check_length()
    119         self.iostr = io.StringIO(privatekey)
    120         self.last_exception = None
    121 
    122     def check_length(self):
    123         if len(self.privatekey) > self.max_length:
    124             raise InvalidValueError('Invalid key length.')
    125 
    126     def parse_name(self, iostr, tag_to_name):
    127         name = None
    128         for line_ in iostr:
    129             line = line_.strip()
    130             if line and line.startswith('-----BEGIN ') and \
    131                     line.endswith(' PRIVATE KEY-----'):
    132                 lst = line.split(' ')
    133                 if len(lst) == 4:
    134                     tag = lst[1]
    135                     if tag:
    136                         name = tag_to_name.get(tag)
    137                         if name:
    138                             break
    139         return name, len(line_)
    140 
    141     def get_specific_pkey(self, name, offset, password):
    142         self.iostr.seek(offset)
    143         logging.debug('Reset offset to {}.'.format(offset))
    144 
    145         logging.debug('Try parsing it as {} type key'.format(name))
    146         pkeycls = getattr(paramiko, name+'Key')
    147         pkey = None
    148 
    149         try:
    150             pkey = pkeycls.from_private_key(self.iostr, password=password)
    151         except paramiko.PasswordRequiredException:
    152             raise InvalidValueError('Need a passphrase to decrypt the key.')
    153         except (paramiko.SSHException, ValueError) as exc:
    154             self.last_exception = exc
    155             logging.debug(str(exc))
    156 
    157         return pkey
    158 
    159     def get_pkey_obj(self):
    160         logging.info('Parsing private key {!r}'.format(self.filename))
    161         name, length = self.parse_name(self.iostr, self.tag_to_name)
    162         if not name:
    163             raise InvalidValueError('Invalid key {}.'.format(self.filename))
    164 
    165         offset = self.iostr.tell() - length
    166         password = to_bytes(self.password) if self.password else None
    167         pkey = self.get_specific_pkey(name, offset, password)
    168 
    169         if pkey is None and name == 'Ed25519':
    170             for name in ['RSA', 'ECDSA', 'DSS']:
    171                 pkey = self.get_specific_pkey(name, offset, password)
    172                 if pkey:
    173                     break
    174 
    175         if pkey:
    176             return pkey
    177 
    178         logging.error(str(self.last_exception))
    179         msg = 'Invalid key'
    180         if self.password:
    181             msg += ' or wrong passphrase "{}" for decrypting it.'.format(
    182                     self.password)
    183         raise InvalidValueError(msg)
    184 
    185 
    186 class MixinHandler(object):
    187 
    188     custom_headers = {
    189         'Server': 'TornadoServer'
    190     }
    191 
    192     html = ('<html><head><title>{code} {reason}</title></head><body>{code} '
    193             '{reason}</body></html>')
    194 
    195     def initialize(self, loop=None):
    196         self.check_request()
    197         self.loop = loop
    198         self.origin_policy = self.settings.get('origin_policy')
    199 
    200     def check_request(self):
    201         context = self.request.connection.context
    202         result = self.is_forbidden(context, self.request.host_name)
    203         self._transforms = []
    204         if result:
    205             self.set_status(403)
    206             self.finish(
    207                 self.html.format(code=self._status_code, reason=self._reason)
    208             )
    209         elif result is False:
    210             to_url = self.get_redirect_url(
    211                 self.request.host_name, options.sslport, self.request.uri
    212             )
    213             self.redirect(to_url, permanent=True)
    214         else:
    215             self.context = context
    216 
    217     def check_origin(self, origin):
    218         if self.origin_policy == '*':
    219             return True
    220 
    221         parsed_origin = urlparse(origin)
    222         netloc = parsed_origin.netloc.lower()
    223         logging.debug('netloc: {}'.format(netloc))
    224 
    225         host = self.request.headers.get('Host')
    226         logging.debug('host: {}'.format(host))
    227 
    228         if netloc == host:
    229             return True
    230 
    231         if self.origin_policy == 'same':
    232             return False
    233         elif self.origin_policy == 'primary':
    234             return is_same_primary_domain(netloc.rsplit(':', 1)[0],
    235                                           host.rsplit(':', 1)[0])
    236         else:
    237             return origin in self.origin_policy
    238 
    239     def is_forbidden(self, context, hostname):
    240         ip = context.address[0]
    241         lst = context.trusted_downstream
    242         ip_address = None
    243 
    244         if lst and ip not in lst:
    245             logging.warning(
    246                 'IP {!r} not found in trusted downstream {!r}'.format(ip, lst)
    247             )
    248             return True
    249 
    250         if context._orig_protocol == 'http':
    251             if redirecting and not is_ip_hostname(hostname):
    252                 ip_address = to_ip_address(ip)
    253                 if not ip_address.is_private:
    254                     # redirecting
    255                     return False
    256 
    257             if options.fbidhttp:
    258                 if ip_address is None:
    259                     ip_address = to_ip_address(ip)
    260                 if not ip_address.is_private:
    261                     logging.warning('Public plain http request is forbidden.')
    262                     return True
    263 
    264     def get_redirect_url(self, hostname, port, uri):
    265         port = '' if port == 443 else ':%s' % port
    266         return 'https://{}{}{}'.format(hostname, port, uri)
    267 
    268     def set_default_headers(self):
    269         for header in self.custom_headers.items():
    270             self.set_header(*header)
    271 
    272     def get_value(self, name):
    273         value = self.get_argument(name)
    274         if not value:
    275             raise InvalidValueError('Missing value {}'.format(name))
    276         return value
    277 
    278     def get_context_addr(self):
    279         return self.context.address[:2]
    280 
    281     def get_client_addr(self):
    282         if options.xheaders:
    283             return self.get_real_client_addr() or self.get_context_addr()
    284         else:
    285             return self.get_context_addr()
    286 
    287     def get_real_client_addr(self):
    288         ip = self.request.remote_ip
    289 
    290         if ip == self.request.headers.get('X-Real-Ip'):
    291             port = self.request.headers.get('X-Real-Port')
    292         elif ip in self.request.headers.get('X-Forwarded-For', ''):
    293             port = self.request.headers.get('X-Forwarded-Port')
    294         else:
    295             # not running behind an nginx server
    296             return
    297 
    298         port = to_int(port)
    299         if port is None or not is_valid_port(port):
    300             # fake port
    301             port = 65535
    302 
    303         return (ip, port)
    304 
    305 
    306 class NotFoundHandler(MixinHandler, tornado.web.ErrorHandler):
    307 
    308     def initialize(self):
    309         super(NotFoundHandler, self).initialize()
    310 
    311     def prepare(self):
    312         raise tornado.web.HTTPError(404)
    313 
    314 
    315 class IndexHandler(MixinHandler, tornado.web.RequestHandler):
    316 
    317     executor = ThreadPoolExecutor(max_workers=cpu_count()*5)
    318 
    319     def initialize(self, loop, policy, host_keys_settings):
    320         super(IndexHandler, self).initialize(loop)
    321         self.policy = policy
    322         self.host_keys_settings = host_keys_settings
    323         self.ssh_client = self.get_ssh_client()
    324         self.debug = self.settings.get('debug', False)
    325         self.font = self.settings.get('font', '')
    326         self.result = dict(id=None, status=None, encoding=None)
    327 
    328     def write_error(self, status_code, **kwargs):
    329         if swallow_http_errors and self.request.method == 'POST':
    330             exc_info = kwargs.get('exc_info')
    331             if exc_info:
    332                 reason = getattr(exc_info[1], 'log_message', None)
    333                 if reason:
    334                     self._reason = reason
    335             self.result.update(status=self._reason)
    336             self.set_status(200)
    337             self.finish(self.result)
    338         else:
    339             super(IndexHandler, self).write_error(status_code, **kwargs)
    340 
    341     def get_ssh_client(self):
    342         ssh = SSHClient()
    343         ssh._system_host_keys = self.host_keys_settings['system_host_keys']
    344         ssh._host_keys = self.host_keys_settings['host_keys']
    345         ssh._host_keys_filename = self.host_keys_settings['host_keys_filename']
    346         ssh.set_missing_host_key_policy(self.policy)
    347         return ssh
    348 
    349     def get_privatekey(self):
    350         name = 'privatekey'
    351         lst = self.request.files.get(name)
    352         if lst:
    353             # multipart form
    354             filename = lst[0]['filename']
    355             data = lst[0]['body']
    356             value = self.decode_argument(data, name=name).strip()
    357         else:
    358             # urlencoded form
    359             value = self.get_argument(name, u'')
    360             filename = ''
    361 
    362         return value, filename
    363 
    364     def get_hostname(self):
    365         value = self.get_value('hostname')
    366         if not (is_valid_hostname(value) or is_valid_ip_address(value)):
    367             raise InvalidValueError('Invalid hostname: {}'.format(value))
    368         return value
    369 
    370     def get_port(self):
    371         value = self.get_argument('port', u'')
    372         if not value:
    373             return DEFAULT_PORT
    374 
    375         port = to_int(value)
    376         if port is None or not is_valid_port(port):
    377             raise InvalidValueError('Invalid port: {}'.format(value))
    378         return port
    379 
    380     def lookup_hostname(self, hostname, port):
    381         key = hostname if port == 22 else '[{}]:{}'.format(hostname, port)
    382 
    383         if self.ssh_client._system_host_keys.lookup(key) is None:
    384             if self.ssh_client._host_keys.lookup(key) is None:
    385                 raise tornado.web.HTTPError(
    386                         403, 'Connection to {}:{} is not allowed.'.format(
    387                             hostname, port)
    388                     )
    389 
    390     def get_args(self):
    391         hostname = self.get_hostname()
    392         port = self.get_port()
    393         username = self.get_value('username')
    394         password = self.get_argument('password', u'')
    395         privatekey, filename = self.get_privatekey()
    396         passphrase = self.get_argument('passphrase', u'')
    397         totp = self.get_argument('totp', u'')
    398 
    399         if isinstance(self.policy, paramiko.RejectPolicy):
    400             self.lookup_hostname(hostname, port)
    401 
    402         if privatekey:
    403             pkey = PrivateKey(privatekey, passphrase, filename).get_pkey_obj()
    404         else:
    405             pkey = None
    406 
    407         self.ssh_client.totp = totp
    408         args = (hostname, port, username, password, pkey)
    409         logging.debug(args)
    410 
    411         return args
    412 
    413     def parse_encoding(self, data):
    414         try:
    415             encoding = to_str(data.strip(), 'ascii')
    416         except UnicodeDecodeError:
    417             return
    418 
    419         if is_valid_encoding(encoding):
    420             return encoding
    421 
    422     def get_default_encoding(self, ssh):
    423         commands = [
    424             '$SHELL -ilc "locale charmap"',
    425             '$SHELL -ic "locale charmap"'
    426         ]
    427 
    428         for command in commands:
    429             try:
    430                 _, stdout, _ = ssh.exec_command(command, get_pty=True)
    431             except paramiko.SSHException as exc:
    432                 logging.info(str(exc))
    433             else:
    434                 data = stdout.read()
    435                 logging.debug('{!r} => {!r}'.format(command, data))
    436                 result = self.parse_encoding(data)
    437                 if result:
    438                     return result
    439 
    440         logging.warning('Could not detect the default encoding.')
    441         return 'utf-8'
    442 
    443     def ssh_connect(self, args):
    444         ssh = self.ssh_client
    445         dst_addr = args[:2]
    446         logging.info('Connecting to {}:{}'.format(*dst_addr))
    447 
    448         try:
    449             ssh.connect(*args, timeout=options.timeout)
    450         except socket.error:
    451             raise ValueError('Unable to connect to {}:{}'.format(*dst_addr))
    452         except paramiko.BadAuthenticationType:
    453             raise ValueError('Bad authentication type.')
    454         except paramiko.AuthenticationException:
    455             raise ValueError('Authentication failed.')
    456         except paramiko.BadHostKeyException:
    457             raise ValueError('Bad host key.')
    458 
    459         term = self.get_argument('term', u'') or u'xterm'
    460         chan = ssh.invoke_shell(term=term)
    461         chan.setblocking(0)
    462         worker = Worker(self.loop, ssh, chan, dst_addr)
    463         worker.encoding = options.encoding if options.encoding else \
    464             self.get_default_encoding(ssh)
    465         return worker
    466 
    467     def check_origin(self):
    468         event_origin = self.get_argument('_origin', u'')
    469         header_origin = self.request.headers.get('Origin')
    470         origin = event_origin or header_origin
    471 
    472         if origin:
    473             if not super(IndexHandler, self).check_origin(origin):
    474                 raise tornado.web.HTTPError(
    475                     403, 'Cross origin operation is not allowed.'
    476                 )
    477 
    478             if not event_origin and self.origin_policy != 'same':
    479                 self.set_header('Access-Control-Allow-Origin', origin)
    480 
    481     def head(self):
    482         pass
    483 
    484     def get(self):
    485         self.render('index.html', debug=self.debug, font=self.font)
    486 
    487     @tornado.gen.coroutine
    488     def post(self):
    489         if self.debug and self.get_argument('error', u''):
    490             # for testing purpose only
    491             raise ValueError('Uncaught exception')
    492 
    493         ip, port = self.get_client_addr()
    494         workers = clients.get(ip, {})
    495         if workers and len(workers) >= options.maxconn:
    496             raise tornado.web.HTTPError(403, 'Too many live connections.')
    497 
    498         self.check_origin()
    499 
    500         try:
    501             args = self.get_args()
    502         except InvalidValueError as exc:
    503             raise tornado.web.HTTPError(400, str(exc))
    504 
    505         future = self.executor.submit(self.ssh_connect, args)
    506 
    507         try:
    508             worker = yield future
    509         except (ValueError, paramiko.SSHException) as exc:
    510             logging.error(traceback.format_exc())
    511             self.result.update(status=str(exc))
    512         else:
    513             if not workers:
    514                 clients[ip] = workers
    515             worker.src_addr = (ip, port)
    516             workers[worker.id] = worker
    517             self.loop.call_later(options.delay, recycle_worker, worker)
    518             self.result.update(id=worker.id, encoding=worker.encoding)
    519 
    520         self.write(self.result)
    521 
    522 
    523 class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler):
    524 
    525     def initialize(self, loop):
    526         super(WsockHandler, self).initialize(loop)
    527         self.worker_ref = None
    528 
    529     def open(self):
    530         self.src_addr = self.get_client_addr()
    531         logging.info('Connected from {}:{}'.format(*self.src_addr))
    532 
    533         workers = clients.get(self.src_addr[0])
    534         if not workers:
    535             self.close(reason='Websocket authentication failed.')
    536             return
    537 
    538         try:
    539             worker_id = self.get_value('id')
    540         except (tornado.web.MissingArgumentError, InvalidValueError) as exc:
    541             self.close(reason=str(exc))
    542         else:
    543             worker = workers.get(worker_id)
    544             if worker:
    545                 workers[worker_id] = None
    546                 self.set_nodelay(True)
    547                 worker.set_handler(self)
    548                 self.worker_ref = weakref.ref(worker)
    549                 self.loop.add_handler(worker.fd, worker, IOLoop.READ)
    550             else:
    551                 self.close(reason='Websocket authentication failed.')
    552 
    553     def on_message(self, message):
    554         logging.debug('{!r} from {}:{}'.format(message, *self.src_addr))
    555         worker = self.worker_ref()
    556         try:
    557             msg = json.loads(message)
    558         except JSONDecodeError:
    559             return
    560 
    561         if not isinstance(msg, dict):
    562             return
    563 
    564         resize = msg.get('resize')
    565         if resize and len(resize) == 2:
    566             try:
    567                 worker.chan.resize_pty(*resize)
    568             except (TypeError, struct.error, paramiko.SSHException):
    569                 pass
    570 
    571         data = msg.get('data')
    572         if data and isinstance(data, UnicodeType):
    573             worker.data_to_dst.append(data)
    574             worker.on_write()
    575 
    576     def on_close(self):
    577         logging.info('Disconnected from {}:{}'.format(*self.src_addr))
    578         if not self.close_reason:
    579             self.close_reason = 'client disconnected'
    580 
    581         worker = self.worker_ref() if self.worker_ref else None
    582         if worker:
    583             worker.close(reason=self.close_reason)