Add retries, close connection on timeout (fix #6)
This commit is contained in:
parent
e769da03f7
commit
18292bf629
|
@ -28,11 +28,12 @@ class NeoHubConnectionError(Error):
|
||||||
|
|
||||||
|
|
||||||
class NeoHub:
|
class NeoHub:
|
||||||
def __init__(self, host='Neo-Hub', port=4242, request_timeout=5):
|
def __init__(self, host='Neo-Hub', port=4242, request_timeout=5, request_attempts=1):
|
||||||
self._logger = logging.getLogger('neohub')
|
self._logger = logging.getLogger('neohub')
|
||||||
self._host = host
|
self._host = host
|
||||||
self._port = port
|
self._port = port
|
||||||
self._request_timeout = request_timeout
|
self._request_timeout = request_timeout
|
||||||
|
self._request_attempts = request_attempts
|
||||||
|
|
||||||
async def _send_message(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, message: str):
|
async def _send_message(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, message: str):
|
||||||
encoded_message = bytearray(json.dumps(message) + "\0\r", "utf-8")
|
encoded_message = bytearray(json.dumps(message) + "\0\r", "utf-8")
|
||||||
|
@ -49,38 +50,38 @@ class NeoHub:
|
||||||
return data
|
return data
|
||||||
|
|
||||||
async def _send(self, message, expected_reply=None):
|
async def _send(self, message, expected_reply=None):
|
||||||
|
last_exception = None
|
||||||
|
for attempt in range(1, self._request_attempts+1):
|
||||||
try:
|
try:
|
||||||
reader, writer = await asyncio.open_connection(self._host, self._port)
|
reader, writer = await asyncio.open_connection(self._host, self._port)
|
||||||
except (socket.gaierror, ConnectionRefusedError) as e:
|
|
||||||
err = f'Could not connect to NeoHub at {self._host}: {e}'
|
|
||||||
self._logger.error(err)
|
|
||||||
raise NeoHubConnectionError from e
|
|
||||||
|
|
||||||
try:
|
|
||||||
data = await asyncio.wait_for(
|
data = await asyncio.wait_for(
|
||||||
self._send_message(reader, writer, message), timeout=self._request_timeout)
|
self._send_message(reader, writer, message), timeout=self._request_timeout)
|
||||||
except asyncio.TimeoutError as e:
|
|
||||||
self._logger.error(f'Timeout talking to NeoHub: {e}')
|
|
||||||
return False
|
|
||||||
|
|
||||||
json_string = data.decode('utf-8')
|
json_string = data.decode('utf-8')
|
||||||
self._logger.debug(f"Received message: {json_string}")
|
self._logger.debug(f"Received message: {json_string}")
|
||||||
|
|
||||||
try:
|
|
||||||
reply = json.loads(json_string, object_hook=lambda d: SimpleNamespace(**d))
|
reply = json.loads(json_string, object_hook=lambda d: SimpleNamespace(**d))
|
||||||
except json.decoder.JSONDecodeError as e:
|
|
||||||
if expected_reply is None:
|
|
||||||
raise(e)
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if expected_reply is None:
|
if expected_reply is None:
|
||||||
return reply
|
return reply
|
||||||
else:
|
|
||||||
if reply.__dict__ == expected_reply:
|
if reply.__dict__ == expected_reply:
|
||||||
return True
|
return True
|
||||||
else:
|
self._logger.error(f"[{attempt}] Unexpected reply: {reply}")
|
||||||
self._logger.error(f"Unexpected reply: {reply}")
|
except (socket.gaierror, ConnectionRefusedError) as e:
|
||||||
|
last_exception = NeoHubConnectionError(e)
|
||||||
|
self._logger.error(f"[{attempt}] Could not connect to NeoHub at {self._host}: {e}")
|
||||||
|
except asyncio.TimeoutError as e:
|
||||||
|
last_exception = e
|
||||||
|
self._logger.error(f"[{attempt}] Timed out while sending a message to {self._host}")
|
||||||
|
if writer is not None:
|
||||||
|
writer.close()
|
||||||
|
except json.decoder.JSONDecodeError as e:
|
||||||
|
last_exception = e
|
||||||
|
self._logger.error(f"[{attempt}] Could not decode JSON: {e}")
|
||||||
|
# Wait for 1/2 of the timeout value before retrying.
|
||||||
|
if self._request_attempts > 1 and attempt < self._request_attempts:
|
||||||
|
await asyncio.sleep(self._request_timeout / 2)
|
||||||
|
|
||||||
|
if expected_reply is None and last_exception is not None:
|
||||||
|
raise(last_exception)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def firmware(self):
|
async def firmware(self):
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import pytest
|
import pytest
|
||||||
|
import time
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
|
|
||||||
import neohubapi
|
import neohubapi
|
||||||
|
@ -15,7 +16,8 @@ class FakeProtocol(asyncio.Protocol):
|
||||||
|
|
||||||
def data_received(self, data):
|
def data_received(self, data):
|
||||||
input = data.decode()
|
input = data.decode()
|
||||||
# self.handler() is set by create_protocol below.
|
# self.server and self.handler are set by create_protocol below.
|
||||||
|
self.server.inputs.append(input)
|
||||||
output = self.handler(input).encode() + b'\0'
|
output = self.handler(input).encode() + b'\0'
|
||||||
self.transport.write(output)
|
self.transport.write(output)
|
||||||
self.transport.close()
|
self.transport.close()
|
||||||
|
@ -25,11 +27,13 @@ class FakeServer:
|
||||||
def __init__(self, loop, port):
|
def __init__(self, loop, port):
|
||||||
self.port = port
|
self.port = port
|
||||||
self.loop = loop
|
self.loop = loop
|
||||||
|
self.inputs = []
|
||||||
|
|
||||||
async def start(self, handler):
|
async def start(self, handler):
|
||||||
def create_protocol():
|
def create_protocol():
|
||||||
fake_protocol = FakeProtocol()
|
fake_protocol = FakeProtocol()
|
||||||
fake_protocol.handler = handler
|
fake_protocol.handler = handler
|
||||||
|
fake_protocol.server = self
|
||||||
return fake_protocol
|
return fake_protocol
|
||||||
self.server = await self.loop.create_server(create_protocol, HOST, self.port)
|
self.server = await self.loop.create_server(create_protocol, HOST, self.port)
|
||||||
|
|
||||||
|
@ -75,7 +79,35 @@ async def test_send_invalid_json(fakeserver):
|
||||||
|
|
||||||
# expected_reply is set, function returns False.
|
# expected_reply is set, function returns False.
|
||||||
assert await hub._send('test', {'message': 'ok'}) is False
|
assert await hub._send('test', {'message': 'ok'}) is False
|
||||||
|
assert len(fakeserver.inputs) == 1 # by default there are no retries.
|
||||||
|
|
||||||
# expected_reply is not set, function raises exception.
|
# expected_reply is not set, function raises exception.
|
||||||
with pytest.raises(json.decoder.JSONDecodeError):
|
with pytest.raises(json.decoder.JSONDecodeError):
|
||||||
await hub._send('test')
|
await hub._send('test')
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_timeout(fakeserver):
|
||||||
|
def handler(input):
|
||||||
|
time.sleep(0.2)
|
||||||
|
return '{"message": "ok"}'
|
||||||
|
await fakeserver.start(handler)
|
||||||
|
|
||||||
|
hub = neohubapi.neohub.NeoHub(host=HOST, port=fakeserver.port, request_timeout=0.1)
|
||||||
|
|
||||||
|
with pytest.raises(asyncio.TimeoutError):
|
||||||
|
await hub._send('test')
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_retries(fakeserver):
|
||||||
|
def handler(input):
|
||||||
|
return '{"message": "error"}'
|
||||||
|
await fakeserver.start(handler)
|
||||||
|
|
||||||
|
hub = neohubapi.neohub.NeoHub(
|
||||||
|
host=HOST, port=fakeserver.port, request_attempts=3, request_timeout=0.1)
|
||||||
|
|
||||||
|
# after 3 attempts the result is still incorrect.
|
||||||
|
assert await hub._send('test', {'message': 'ok'}) is False
|
||||||
|
assert len(fakeserver.inputs) == 3
|
||||||
|
|
Loading…
Reference in New Issue