import aiohttp
from aiohttp import web

from .exception import RpcError
from .exception import RpcErrorCode
from .serializer import json
from .serializer import msgpack
from .constants import JSON_RPC_VERSION


class RpcWebsocketHandler(object):
    """aiohttp websocket handler that dispatches JSON-RPC style requests.

    Text frames are decoded/encoded with the str serializer and binary
    frames with the bytes serializer. A request's ``method`` must have the
    form ``"<ServiceClassName>.<method_name>"``; the service registered under
    that class name is awaited as ``service(method_name, **params)``.
    """

    def __init__(self, bytes_serializer=msgpack, str_serializer=json,
                 services=None):
        """Create a handler.

        :param bytes_serializer: object with ``loads``/``dumps`` used for
            binary websocket frames.
        :param str_serializer: object with ``loads``/``dumps`` used for
            text websocket frames.
        :param services: optional iterable of service objects to register
            (see ``register_service``).
        """
        self.set_bytes_serializer(bytes_serializer)
        self.set_str_serializer(str_serializer)
        self._services = {}
        # Explicit None check instead of swallowing TypeError, so passing a
        # non-iterable by mistake is reported rather than silently ignored.
        if services is not None:
            self.register_services(services)

    def set_bytes_serializer(self, serializer):
        """Set the serializer used for binary frames."""
        self._bytes_serializer = serializer

    def set_str_serializer(self, serializer):
        """Set the serializer used for text frames."""
        self._str_serializer = serializer

    async def __call__(self, request):
        """Serve one websocket connection for the aiohttp ``request``.

        A text frame equal to ``'close'`` closes the socket; any other text
        frame is handled with the str serializer, and binary frames with the
        bytes serializer. Each response is sent back on the same socket.

        :returns: the prepared ``web.WebSocketResponse``.
        """
        ws = web.WebSocketResponse()
        await ws.prepare(request)

        async for msg in ws:
            # NOTE(review): ``msg.tp`` / ``aiohttp.MsgType`` and the
            # un-awaited ``send_str``/``send_bytes`` below match an older
            # aiohttp API; presumably the project pins such a version.
            # In aiohttp 3.x the send methods are coroutines -- confirm
            # before upgrading.
            if msg.tp == aiohttp.MsgType.text:
                if msg.data == 'close':
                    await ws.close()
                else:
                    ws.send_str(await self._call_service(
                        msg.data, self._str_serializer))
            elif msg.tp == aiohttp.MsgType.binary:
                ws.send_bytes(await self._call_service(
                    msg.data, self._bytes_serializer))
            elif msg.tp == aiohttp.MsgType.error:
                print('ws connection closed '
                      'with exception {}'.format(ws.exception()))

        print('websocket connection closed')
        return ws

    def register_services(self, services):
        """Register every service in the iterable ``services``."""
        for service in services:
            self.register_service(service)

    def register_service(self, service):
        """Register ``service`` under its class name and return the registry.

        A later service of the same class name replaces the earlier one.
        """
        self._services[service.__class__.__name__] = service
        return self._services

    def parse_request(self, data, serializer):
        """Decode ``data`` and return ``(method, params, id)``.

        ``params`` defaults to an empty dict when the request omits it
        (JSON-RPC 2.0 allows ``params`` to be omitted). A missing
        ``method`` or ``id`` raises ``KeyError``.
        """
        request = serializer.loads(data)
        return request['method'], request.get('params', {}), request['id']

    def create_result(self, id, result, serializer):
        """Serialize a success response carrying ``result`` for request ``id``."""
        return serializer.dumps({
            'jsonrpc': JSON_RPC_VERSION,
            'result': result,
            'id': id,
        })

    def create_error(self, id, message, code, serializer):
        """Serialize an error response for request ``id``.

        ``code`` is expected to be an ``RpcErrorCode``-like object; its
        ``.value`` is emitted as the numeric error code.
        """
        return serializer.dumps({
            'jsonrpc': JSON_RPC_VERSION,
            'error': {
                'code': code.value,
                'message': message,
            },
            'id': id,
        })

    async def _call_service(self, data, serializer):
        """Decode one request, dispatch it, and return the serialized reply.

        Any ``RpcError`` raised while resolving or running the method
        (including an unknown method) becomes an error response. Errors
        raised while decoding the request, and non-``RpcError`` exceptions
        from the service, are not caught here.
        """
        method, params, request_id = self.parse_request(data, serializer)
        try:
            try:
                service_name, method_name = method.split('.')
            except ValueError:
                raise RpcError('Method `{}` not found'.format(method),
                               RpcErrorCode.METHOD_NOT_FOUND) from None
            # Look the service up separately so that a KeyError raised
            # *inside* the service method is not misreported as
            # "method not found".
            try:
                service = self._services[service_name]
            except KeyError:
                raise RpcError('Method `{}` not found'.format(method),
                               RpcErrorCode.METHOD_NOT_FOUND) from None
            result = await service(method_name, **params)
        except RpcError as e:
            return self.create_error(request_id, e.rpc_error_message,
                                     e.rpc_error_code, serializer)
        return self.create_result(request_id, result, serializer)