commit e85ae1692e492a36c9abf36da7c2a36d2dc6afc0
parent cb86682551623a1d926764d54503f6a9ebe719a6
Author: Sheng <webmaster0115@gmail.com>
Date: Mon, 20 Aug 2018 18:35:50 +0800
Added to_bytes function to utils
Diffstat:
5 files changed, 51 insertions(+), 34 deletions(-)
diff --git a/tests/test_app.py b/tests/test_app.py
@@ -13,6 +13,7 @@ from tests.sshserver import run_ssh_server, banner
from tests.utils import encode_multipart_formdata, read_file
from webssh.main import make_app, make_handlers
from webssh.settings import get_app_settings, max_body_size, base_dir
+from webssh.utils import to_str
handler.DELAY = 0.1
@@ -22,7 +23,7 @@ class TestApp(AsyncHTTPTestCase):
running = [True]
sshserver_port = 2200
- body = u'hostname=127.0.0.1&port={}&username=robey&password=foo'.format(sshserver_port) # noqa
+ body = 'hostname=127.0.0.1&port={}&username=robey&password=foo'.format(sshserver_port) # noqa
body_dict = {
'hostname': '127.0.0.1',
'port': str(sshserver_port),
@@ -61,37 +62,37 @@ class TestApp(AsyncHTTPTestCase):
def test_app_with_invalid_form(self):
response = self.fetch('/')
self.assertEqual(response.code, 200)
- body = u'hostname=&port=&username=&password'
+ body = 'hostname=&port=&username=&password'
response = self.fetch('/', method="POST", body=body)
self.assertIn(b'"status": "Empty hostname"', response.body)
- body = u'hostname=127.0.0.1&port=&username=&password'
+ body = 'hostname=127.0.0.1&port=&username=&password'
response = self.fetch('/', method="POST", body=body)
self.assertIn(b'"status": "Empty port"', response.body)
- body = u'hostname=127.0.0.1&port=port&username=&password'
+ body = 'hostname=127.0.0.1&port=port&username=&password'
response = self.fetch('/', method="POST", body=body)
self.assertIn(b'"status": "Invalid port', response.body)
- body = u'hostname=127.0.0.1&port=70000&username=&password'
+ body = 'hostname=127.0.0.1&port=70000&username=&password'
response = self.fetch('/', method="POST", body=body)
self.assertIn(b'"status": "Invalid port', response.body)
- body = u'hostname=127.0.0.1&port=7000&username=&password'
+ body = 'hostname=127.0.0.1&port=7000&username=&password'
response = self.fetch('/', method="POST", body=body)
self.assertIn(b'"status": "Empty username"', response.body)
def test_app_with_wrong_credentials(self):
response = self.fetch('/')
self.assertEqual(response.code, 200)
- response = self.fetch('/', method="POST", body=self.body + u's')
+ response = self.fetch('/', method="POST", body=self.body + 's')
self.assertIn(b'Authentication failed.', response.body)
def test_app_with_correct_credentials(self):
response = self.fetch('/')
self.assertEqual(response.code, 200)
response = self.fetch('/', method="POST", body=self.body)
- data = json.loads(response.body.decode('utf-8'))
+ data = json.loads(to_str(response.body))
self.assertIsNone(data['status'])
self.assertIsNotNone(data['id'])
self.assertIsNotNone(data['encoding'])
@@ -104,7 +105,7 @@ class TestApp(AsyncHTTPTestCase):
self.assertEqual(response.code, 200)
response = yield client.fetch(url, method="POST", body=self.body)
- data = json.loads(response.body.decode('utf-8'))
+ data = json.loads(to_str(response.body))
self.assertIsNone(data['status'])
self.assertIsNotNone(data['id'])
self.assertIsNotNone(data['encoding'])
@@ -133,7 +134,7 @@ class TestApp(AsyncHTTPTestCase):
}
response = yield client.fetch(url, method="POST", headers=headers,
body=body)
- data = json.loads(response.body.decode('utf-8'))
+ data = json.loads(to_str(response.body))
self.assertIsNone(data['status'])
self.assertIsNotNone(data['id'])
self.assertIsNotNone(data['encoding'])
@@ -142,7 +143,7 @@ class TestApp(AsyncHTTPTestCase):
ws_url = url + 'ws?id=' + data['id']
ws = yield tornado.websocket.websocket_connect(ws_url)
msg = yield ws.read_message()
- self.assertEqual(msg.decode(data['encoding']), banner)
+ self.assertEqual(to_str(msg, data['encoding']), banner)
ws.close()
@tornado.testing.gen_test
@@ -153,7 +154,7 @@ class TestApp(AsyncHTTPTestCase):
self.assertEqual(response.code, 200)
privatekey = read_file(os.path.join(base_dir, 'tests', 'user_rsa_key'))
- privatekey = privatekey[:100] + u'bad' + privatekey[100:]
+ privatekey = privatekey[:100] + 'bad' + privatekey[100:]
files = [('privatekey', 'user_rsa_key', privatekey)]
content_type, body = encode_multipart_formdata(self.body_dict.items(),
files)
@@ -162,7 +163,7 @@ class TestApp(AsyncHTTPTestCase):
}
response = yield client.fetch(url, method="POST", headers=headers,
body=body)
- data = json.loads(response.body.decode('utf-8'))
+ data = json.loads(to_str(response.body))
self.assertIsNotNone(data['status'])
self.assertIsNone(data['id'])
self.assertIsNone(data['encoding'])
@@ -174,7 +175,7 @@ class TestApp(AsyncHTTPTestCase):
response = yield client.fetch(url)
self.assertEqual(response.code, 200)
- privatekey = u'h' * (2 * max_body_size)
+ privatekey = 'h' * (2 * max_body_size)
files = [('privatekey', 'user_rsa_key', privatekey)]
content_type, body = encode_multipart_formdata(self.body_dict.items(),
files)
@@ -193,7 +194,7 @@ class TestApp(AsyncHTTPTestCase):
self.assertEqual(response.code, 200)
response = yield client.fetch(url, method="POST", body=self.body)
- data = json.loads(response.body.decode('utf-8'))
+ data = json.loads(to_str(response.body))
self.assertIsNone(data['status'])
self.assertIsNotNone(data['id'])
self.assertIsNotNone(data['encoding'])
@@ -202,7 +203,7 @@ class TestApp(AsyncHTTPTestCase):
ws_url = url + 'ws?id=' + data['id']
ws = yield tornado.websocket.websocket_connect(ws_url)
msg = yield ws.read_message()
- self.assertEqual(msg.decode(data['encoding']), banner)
+ self.assertEqual(to_str(msg, data['encoding']), banner)
ws.close()
@tornado.testing.gen_test
@@ -214,7 +215,7 @@ class TestApp(AsyncHTTPTestCase):
body = self.body.replace('robey', 'bar')
response = yield client.fetch(url, method="POST", body=body)
- data = json.loads(response.body.decode('utf-8'))
+ data = json.loads(to_str(response.body))
self.assertIsNone(data['status'])
self.assertIsNotNone(data['id'])
self.assertIsNotNone(data['encoding'])
@@ -223,7 +224,7 @@ class TestApp(AsyncHTTPTestCase):
ws_url = url + 'ws?id=' + data['id']
ws = yield tornado.websocket.websocket_connect(ws_url)
msg = yield ws.read_message()
- self.assertEqual(msg.decode(data['encoding']), banner)
+ self.assertEqual(to_str(msg, data['encoding']), banner)
# messages below will be ignored silently
yield ws.write_message('hello')
diff --git a/tests/test_handler.py b/tests/test_handler.py
@@ -56,7 +56,7 @@ class TestIndexHandler(unittest.TestCase):
key = read_file(os.path.join(base_dir, 'tests', fname))
pkey = IndexHandler.get_specific_pkey(cls, key, None)
self.assertIsInstance(pkey, cls)
- pkey = IndexHandler.get_specific_pkey(cls, key, b'iginored')
+ pkey = IndexHandler.get_specific_pkey(cls, key, 'iginored')
self.assertIsInstance(pkey, cls)
pkey = IndexHandler.get_specific_pkey(cls, 'x'+key, None)
self.assertIsNone(pkey)
@@ -64,7 +64,7 @@ class TestIndexHandler(unittest.TestCase):
def test_get_specific_pkey_with_encrypted_key(self):
fname = 'test_rsa_password.key'
cls = paramiko.RSAKey
- password = b'television'
+ password = 'television'
key = read_file(os.path.join(base_dir, 'tests', fname))
pkey = IndexHandler.get_specific_pkey(cls, key, password)
@@ -81,7 +81,7 @@ class TestIndexHandler(unittest.TestCase):
key = read_file(os.path.join(base_dir, 'tests', fname))
pkey = IndexHandler.get_pkey_obj(key, None)
self.assertIsInstance(pkey, cls)
- pkey = IndexHandler.get_pkey_obj(key, u'iginored')
+ pkey = IndexHandler.get_pkey_obj(key, 'iginored')
self.assertIsInstance(pkey, cls)
with self.assertRaises(ValueError):
pkey = IndexHandler.get_pkey_obj('x'+key, None)
@@ -94,6 +94,6 @@ class TestIndexHandler(unittest.TestCase):
pkey = IndexHandler.get_pkey_obj(key, password)
self.assertIsInstance(pkey, cls)
with self.assertRaises(ValueError):
- pkey = IndexHandler.get_pkey_obj(key, u'wrongpass')
+ pkey = IndexHandler.get_pkey_obj(key, 'wrongpass')
with self.assertRaises(ValueError):
pkey = IndexHandler.get_pkey_obj('x'+key, password)
diff --git a/tests/test_utils.py b/tests/test_utils.py
@@ -1,7 +1,7 @@
import unittest
from webssh.utils import (is_valid_ipv4_address, is_valid_ipv6_address,
- is_valid_port, to_str)
+ is_valid_port, to_str, to_bytes)
class TestUitls(unittest.TestCase):
@@ -12,6 +12,12 @@ class TestUitls(unittest.TestCase):
self.assertEqual(to_str(b), u)
self.assertEqual(to_str(u), u)
+ def test_to_bytes(self):
+ b = b'hello'
+ u = u'hello'
+ self.assertEqual(to_bytes(b), b)
+ self.assertEqual(to_bytes(u), b)
+
def test_is_valid_ipv4_address(self):
self.assertFalse(is_valid_ipv4_address('127.0.0'))
self.assertFalse(is_valid_ipv4_address(b'127.0.0'))
diff --git a/webssh/handler.py b/webssh/handler.py
@@ -10,10 +10,9 @@ import paramiko
import tornado.web
from tornado.ioloop import IOLoop
-from tornado.util import basestring_type
from webssh.worker import Worker, recycle_worker, workers
from webssh.utils import (is_valid_ipv4_address, is_valid_ipv6_address,
- is_valid_port)
+ is_valid_port, to_bytes, to_str, UnicodeType)
try:
from concurrent.futures import Future
@@ -70,7 +69,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
data = self.request.files.get('privatekey')[0]['body']
except TypeError:
return
- return data.decode('utf-8')
+ return to_str(data)
@classmethod
def get_specific_pkey(cls, pkeycls, privatekey, password):
@@ -87,7 +86,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
@classmethod
def get_pkey_obj(cls, privatekey, password):
- password = password.encode('utf-8') if password else None
+ password = to_bytes(password)
pkey = cls.get_specific_pkey(paramiko.RSAKey, privatekey, password)\
or cls.get_specific_pkey(paramiko.DSSKey, privatekey, password)\
@@ -138,8 +137,8 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
except paramiko.SSHException:
result = None
else:
- data = stdout.read().decode('utf-8')
- result = parse_encoding(data)
+ data = stdout.read()
+ result = parse_encoding(to_str(data))
return result if result else 'utf-8'
@@ -247,7 +246,7 @@ class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler):
pass
data = msg.get('data')
- if data and isinstance(data, basestring_type):
+ if data and isinstance(data, UnicodeType):
worker.data_to_dst.append(data)
worker.on_write()
diff --git a/webssh/utils.py b/webssh/utils.py
@@ -1,10 +1,21 @@
import ipaddress
+try:
+ from types import UnicodeType
+except ImportError:
+ UnicodeType = str
-def to_str(s):
- if isinstance(s, bytes):
- return s.decode('utf-8')
- return s
+
+def to_str(bstr, encoding='utf-8'):
+ if isinstance(bstr, bytes):
+ return bstr.decode(encoding)
+ return bstr
+
+
+def to_bytes(ustr, encoding='utf-8'):
+ if isinstance(ustr, UnicodeType):
+ return ustr.encode(encoding)
+ return ustr
def is_valid_ipv4_address(ipstr):