Add retries, close connection on timeout (fix #6)

This commit is contained in:
Anton Tolchanov 2021-02-06 17:18:26 +00:00
parent e769da03f7
commit 18292bf629
2 changed files with 66 additions and 33 deletions

View File

@ -28,11 +28,12 @@ class NeoHubConnectionError(Error):
class NeoHub: class NeoHub:
def __init__(self, host='Neo-Hub', port=4242, request_timeout=5): def __init__(self, host='Neo-Hub', port=4242, request_timeout=5, request_attempts=1):
self._logger = logging.getLogger('neohub') self._logger = logging.getLogger('neohub')
self._host = host self._host = host
self._port = port self._port = port
self._request_timeout = request_timeout self._request_timeout = request_timeout
self._request_attempts = request_attempts
async def _send_message(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, message: str): async def _send_message(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, message: str):
encoded_message = bytearray(json.dumps(message) + "\0\r", "utf-8") encoded_message = bytearray(json.dumps(message) + "\0\r", "utf-8")
@ -49,39 +50,39 @@ class NeoHub:
return data return data
async def _send(self, message, expected_reply=None): async def _send(self, message, expected_reply=None):
try: last_exception = None
reader, writer = await asyncio.open_connection(self._host, self._port) for attempt in range(1, self._request_attempts+1):
except (socket.gaierror, ConnectionRefusedError) as e: try:
err = f'Could not connect to NeoHub at {self._host}: {e}' reader, writer = await asyncio.open_connection(self._host, self._port)
self._logger.error(err) data = await asyncio.wait_for(
raise NeoHubConnectionError from e self._send_message(reader, writer, message), timeout=self._request_timeout)
json_string = data.decode('utf-8')
self._logger.debug(f"Received message: {json_string}")
reply = json.loads(json_string, object_hook=lambda d: SimpleNamespace(**d))
try: if expected_reply is None:
data = await asyncio.wait_for( return reply
self._send_message(reader, writer, message), timeout=self._request_timeout) if reply.__dict__ == expected_reply:
except asyncio.TimeoutError as e: return True
self._logger.error(f'Timeout talking to NeoHub: {e}') self._logger.error(f"[{attempt}] Unexpected reply: {reply}")
return False except (socket.gaierror, ConnectionRefusedError) as e:
last_exception = NeoHubConnectionError(e)
self._logger.error(f"[{attempt}] Could not connect to NeoHub at {self._host}: {e}")
except asyncio.TimeoutError as e:
last_exception = e
self._logger.error(f"[{attempt}] Timed out while sending a message to {self._host}")
if writer is not None:
writer.close()
except json.decoder.JSONDecodeError as e:
last_exception = e
self._logger.error(f"[{attempt}] Could not decode JSON: {e}")
# Wait for 1/2 of the timeout value before retrying.
if self._request_attempts > 1 and attempt < self._request_attempts:
await asyncio.sleep(self._request_timeout / 2)
json_string = data.decode('utf-8') if expected_reply is None and last_exception is not None:
self._logger.debug(f"Received message: {json_string}") raise(last_exception)
return False
try:
reply = json.loads(json_string, object_hook=lambda d: SimpleNamespace(**d))
except json.decoder.JSONDecodeError as e:
if expected_reply is None:
raise(e)
else:
return False
if expected_reply is None:
return reply
else:
if reply.__dict__ == expected_reply:
return True
else:
self._logger.error(f"Unexpected reply: {reply}")
return False
async def firmware(self): async def firmware(self):
""" """

View File

@ -1,6 +1,7 @@
import asyncio import asyncio
import json import json
import pytest import pytest
import time
from types import SimpleNamespace from types import SimpleNamespace
import neohubapi import neohubapi
@ -15,7 +16,8 @@ class FakeProtocol(asyncio.Protocol):
def data_received(self, data): def data_received(self, data):
input = data.decode() input = data.decode()
# self.handler() is set by create_protocol below. # self.server and self.handler are set by create_protocol below.
self.server.inputs.append(input)
output = self.handler(input).encode() + b'\0' output = self.handler(input).encode() + b'\0'
self.transport.write(output) self.transport.write(output)
self.transport.close() self.transport.close()
@ -25,11 +27,13 @@ class FakeServer:
def __init__(self, loop, port): def __init__(self, loop, port):
self.port = port self.port = port
self.loop = loop self.loop = loop
self.inputs = []
async def start(self, handler): async def start(self, handler):
def create_protocol(): def create_protocol():
fake_protocol = FakeProtocol() fake_protocol = FakeProtocol()
fake_protocol.handler = handler fake_protocol.handler = handler
fake_protocol.server = self
return fake_protocol return fake_protocol
self.server = await self.loop.create_server(create_protocol, HOST, self.port) self.server = await self.loop.create_server(create_protocol, HOST, self.port)
@ -75,7 +79,35 @@ async def test_send_invalid_json(fakeserver):
# expected_reply is set, function returns False. # expected_reply is set, function returns False.
assert await hub._send('test', {'message': 'ok'}) is False assert await hub._send('test', {'message': 'ok'}) is False
assert len(fakeserver.inputs) == 1 # by default there are no retries.
# expected_reply is not set, function raises exception. # expected_reply is not set, function raises exception.
with pytest.raises(json.decoder.JSONDecodeError): with pytest.raises(json.decoder.JSONDecodeError):
await hub._send('test') await hub._send('test')
@pytest.mark.asyncio
async def test_send_timeout(fakeserver):
def handler(input):
time.sleep(0.2)
return '{"message": "ok"}'
await fakeserver.start(handler)
hub = neohubapi.neohub.NeoHub(host=HOST, port=fakeserver.port, request_timeout=0.1)
with pytest.raises(asyncio.TimeoutError):
await hub._send('test')
@pytest.mark.asyncio
async def test_send_retries(fakeserver):
def handler(input):
return '{"message": "error"}'
await fakeserver.start(handler)
hub = neohubapi.neohub.NeoHub(
host=HOST, port=fakeserver.port, request_attempts=3, request_timeout=0.1)
# after 3 attempts the result is still incorrect.
assert await hub._send('test', {'message': 'ok'}) is False
assert len(fakeserver.inputs) == 3