From 4e8ca96cb0b3a2a2acb6a97dd1e06382266a8b29 Mon Sep 17 00:00:00 2001 From: Andrew Leech Date: Wed, 7 Aug 2024 16:38:14 +1000 Subject: [PATCH] micropython/streampair: Package to create bi-directional linked stream objects. Signed-off-by: Andrew Leech --- micropython/streampair/manifest.py | 5 ++ micropython/streampair/streampair.py | 85 +++++++++++++++++++++++ micropython/streampair/test_streampair.py | 44 ++++++++++++ 3 files changed, 134 insertions(+) create mode 100644 micropython/streampair/manifest.py create mode 100644 micropython/streampair/streampair.py create mode 100644 micropython/streampair/test_streampair.py diff --git a/micropython/streampair/manifest.py b/micropython/streampair/manifest.py new file mode 100644 index 000000000..454291696 --- /dev/null +++ b/micropython/streampair/manifest.py @@ -0,0 +1,5 @@ +metadata( + description="Create a bi-directional linked pair of stream objects", version="0.0.1" +) + +module("streampair.py") diff --git a/micropython/streampair/streampair.py b/micropython/streampair/streampair.py new file mode 100644 index 000000000..664e6b7c5 --- /dev/null +++ b/micropython/streampair/streampair.py @@ -0,0 +1,85 @@ +import io + +from collections import deque +from micropython import const + +try: + from typing import Union, Tuple +except: + pass + +# From micropython/py/stream.h +_MP_STREAM_ERROR = const(-1) +_MP_STREAM_FLUSH = const(1) +_MP_STREAM_SEEK = const(2) +_MP_STREAM_POLL = const(3) +_MP_STREAM_CLOSE = const(4) +_MP_STREAM_POLL_RD = const(0x0001) + + +def streampair(buffer_size: Union[int, Tuple[int, int]]=256): + """ + Returns two bi-directional linked stream objects where writes to one can be read from the other and vice/versa. + This can be used somewhat similarly to a socket.socketpair in python, like a pipe + of data that can be used to connect stream consumers (eg. asyncio.StreamWriter, mock Uart) + """ + try: + size_a, size_b = buffer_size + except TypeError: + size_a = size_b = buffer_size + + a = deque(size_a) + b = deque(size_b) + return StreamPair(a, b), StreamPair(b, a) + + +class StreamPair(io.IOBase): + + def __init__(self, own: deque, other: deque): + self.own = own + self.other = other + super().__init__() + + def read(self, nbytes=-1): + return self.own.read(nbytes) + + def readline(self): + return self.own.readline() + + def readinto(self, buf, limit=-1): + return self.own.readinto(buf, limit) + + def write(self, data): + return self.other.write(data) + + def seek(self, offset, whence): + return self.own.seek(offset, whence) + + def flush(self): + self.own.flush() + self.other.flush() + + def close(self): + self.own.close() + self.other.close() + + def any(self): + pos = self.own.tell() + end = self.own.seek(0, 2) + self.own.seek(pos, 0) + return end - pos + + def ioctl(self, op, arg): + if op == _MP_STREAM_POLL: + if self.any(): + return _MP_STREAM_POLL_RD + + elif op ==_MP_STREAM_FLUSH: + return self.flush() + elif op ==_MP_STREAM_SEEK: + return self.seek(arg[0], arg[1]) + elif op ==_MP_STREAM_CLOSE: + return self.close() + + else: + return _MP_STREAM_ERROR diff --git a/micropython/streampair/test_streampair.py b/micropython/streampair/test_streampair.py new file mode 100644 index 000000000..c29aec585 --- /dev/null +++ b/micropython/streampair/test_streampair.py @@ -0,0 +1,44 @@ +import asyncio +import unittest +from streampair import streampair + +def async_test(f): + """ + Decorator to run an async test function + """ + def wrapper(*args, **kwargs): + loop = asyncio.new_event_loop() + # loop.set_exception_handler(_exception_handler) + t = loop.create_task(f(*args, **kwargs)) + loop.run_until_complete(t) + + return wrapper + +class StreamPairTestCase(unittest.TestCase): + + def test_streampair(self): + a, b = streampair() + assert a.write("foo") == 3 + assert b.write("bar") == 3 + + assert (r := a.read()) == "bar", r + assert (r := b.read()) == "foo", r + + @async_test + async def test_async_streampair(self): + a, b = streampair() + ar = asyncio.StreamReader(a) + bw = asyncio.StreamWriter(b) + + br = asyncio.StreamReader(b) + aw = asyncio.StreamWriter(a) + + assert aw.write("foo") == 3 + assert await br.read() == "foo" + + assert bw.write("bar") == 3 + assert await ar.read() == "bar" + + +if __name__ == "__main__": + unittest.main()