From 18292bf62974a66afccee1c044d127f39ace785d Mon Sep 17 00:00:00 2001 From: Anton Tolchanov Date: Sat, 6 Feb 2021 17:18:26 +0000 Subject: [PATCH] Add retries, close connection on timeout (fix #6) --- neohubapi/neohub.py | 65 ++++++++++++++++++++++---------------------- tests/test_neohub.py | 34 ++++++++++++++++++++++- 2 files changed, 66 insertions(+), 33 deletions(-) diff --git a/neohubapi/neohub.py b/neohubapi/neohub.py index 8096a00..e9dab9b 100644 --- a/neohubapi/neohub.py +++ b/neohubapi/neohub.py @@ -28,11 +28,12 @@ class NeoHubConnectionError(Error): class NeoHub: - def __init__(self, host='Neo-Hub', port=4242, request_timeout=5): + def __init__(self, host='Neo-Hub', port=4242, request_timeout=5, request_attempts=1): self._logger = logging.getLogger('neohub') self._host = host self._port = port self._request_timeout = request_timeout + self._request_attempts = request_attempts async def _send_message(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, message: str): encoded_message = bytearray(json.dumps(message) + "\0\r", "utf-8") @@ -49,39 +50,39 @@ class NeoHub: return data async def _send(self, message, expected_reply=None): - try: - reader, writer = await asyncio.open_connection(self._host, self._port) - except (socket.gaierror, ConnectionRefusedError) as e: - err = f'Could not connect to NeoHub at {self._host}: {e}' - self._logger.error(err) - raise NeoHubConnectionError from e + last_exception = None + for attempt in range(1, self._request_attempts+1): + try: + reader, writer = await asyncio.open_connection(self._host, self._port) + data = await asyncio.wait_for( + self._send_message(reader, writer, message), timeout=self._request_timeout) + json_string = data.decode('utf-8') + self._logger.debug(f"Received message: {json_string}") + reply = json.loads(json_string, object_hook=lambda d: SimpleNamespace(**d)) - try: - data = await asyncio.wait_for( - self._send_message(reader, writer, message), timeout=self._request_timeout) - except asyncio.TimeoutError as e: - self._logger.error(f'Timeout talking to NeoHub: {e}') - return False + if expected_reply is None: + return reply + if reply.__dict__ == expected_reply: + return True + self._logger.error(f"[{attempt}] Unexpected reply: {reply}") + except (socket.gaierror, ConnectionRefusedError) as e: + last_exception = NeoHubConnectionError(e) + self._logger.error(f"[{attempt}] Could not connect to NeoHub at {self._host}: {e}") + except asyncio.TimeoutError as e: + last_exception = e + self._logger.error(f"[{attempt}] Timed out while sending a message to {self._host}") + if writer is not None: + writer.close() + except json.decoder.JSONDecodeError as e: + last_exception = e + self._logger.error(f"[{attempt}] Could not decode JSON: {e}") + # Wait for 1/2 of the timeout value before retrying. + if self._request_attempts > 1 and attempt < self._request_attempts: + await asyncio.sleep(self._request_timeout / 2) - json_string = data.decode('utf-8') - self._logger.debug(f"Received message: {json_string}") - - try: - reply = json.loads(json_string, object_hook=lambda d: SimpleNamespace(**d)) - except json.decoder.JSONDecodeError as e: - if expected_reply is None: - raise(e) - else: - return False - - if expected_reply is None: - return reply - else: - if reply.__dict__ == expected_reply: - return True - else: - self._logger.error(f"Unexpected reply: {reply}") - return False + if expected_reply is None and last_exception is not None: + raise(last_exception) + return False async def firmware(self): """ diff --git a/tests/test_neohub.py b/tests/test_neohub.py index 8f6f597..443febb 100644 --- a/tests/test_neohub.py +++ b/tests/test_neohub.py @@ -1,6 +1,7 @@ import asyncio import json import pytest +import time from types import SimpleNamespace import neohubapi @@ -15,7 +16,8 @@ class FakeProtocol(asyncio.Protocol): def data_received(self, data): input = data.decode() - # self.handler() is set by create_protocol below. + # self.server and self.handler are set by create_protocol below. + self.server.inputs.append(input) output = self.handler(input).encode() + b'\0' self.transport.write(output) self.transport.close() @@ -25,11 +27,13 @@ class FakeServer: def __init__(self, loop, port): self.port = port self.loop = loop + self.inputs = [] async def start(self, handler): def create_protocol(): fake_protocol = FakeProtocol() fake_protocol.handler = handler + fake_protocol.server = self return fake_protocol self.server = await self.loop.create_server(create_protocol, HOST, self.port) @@ -75,7 +79,35 @@ async def test_send_invalid_json(fakeserver): # expected_reply is set, function returns False. assert await hub._send('test', {'message': 'ok'}) is False + assert len(fakeserver.inputs) == 1 # by default there are no retries. # expected_reply is not set, function raises exception. with pytest.raises(json.decoder.JSONDecodeError): await hub._send('test') + + +@pytest.mark.asyncio +async def test_send_timeout(fakeserver): + def handler(input): + time.sleep(0.2) + return '{"message": "ok"}' + await fakeserver.start(handler) + + hub = neohubapi.neohub.NeoHub(host=HOST, port=fakeserver.port, request_timeout=0.1) + + with pytest.raises(asyncio.TimeoutError): + await hub._send('test') + + +@pytest.mark.asyncio +async def test_send_retries(fakeserver): + def handler(input): + return '{"message": "error"}' + await fakeserver.start(handler) + + hub = neohubapi.neohub.NeoHub( + host=HOST, port=fakeserver.port, request_attempts=3, request_timeout=0.1) + + # after 3 attempts the result is still incorrect. + assert await hub._send('test', {'message': 'ok'}) is False + assert len(fakeserver.inputs) == 3