Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

micropython/streampair: Package to create bi-directional linked stream objects. #907

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions micropython/streampair/manifest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
metadata(
description="Create a bi-directional linked pair of stream objects", version="0.0.1"
)

module("streampair.py")
83 changes: 83 additions & 0 deletions micropython/streampair/streampair.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import io

from collections import deque
from micropython import ringbuffer, const

try:
from typing import Union, Tuple
except:
pass

# From micropython/py/stream.h
_MP_STREAM_ERROR = const(-1)
_MP_STREAM_FLUSH = const(1)
_MP_STREAM_SEEK = const(2)
_MP_STREAM_POLL = const(3)
_MP_STREAM_CLOSE = const(4)
_MP_STREAM_POLL_RD = const(0x0001)


def streampair(buffer_size: Union[int, Tuple[int, int]]=256):
"""
Returns two bi-directional linked stream objects where writes to one can be read from the other and vice/versa.
This can be used somewhat similarly to a socket.socketpair in python, like a pipe
of data that can be used to connect stream consumers (eg. asyncio.StreamWriter, mock Uart)
"""
try:
size_a, size_b = buffer_size
except TypeError:
size_a = size_b = buffer_size

a = ringbuffer(size_a)
b = ringbuffer(size_b)
return StreamPair(a, b), StreamPair(b, a)


class StreamPair(io.IOBase):

def __init__(self, own: ringbuffer, other: ringbuffer):
self.own = own
self.other = other
super().__init__()

def read(self, nbytes=-1):
return self.own.read(nbytes)

def readline(self):
return self.own.readline()

def readinto(self, buf, limit=-1):
return self.own.readinto(buf, limit)

def write(self, data):
return self.other.write(data)

def seek(self, offset, whence):
return self.own.seek(offset, whence)

def flush(self):
self.own.flush()
self.other.flush()

def close(self):
self.own.close()
self.other.close()

def any(self):
return self.own.any()

def ioctl(self, op, arg):
if op == _MP_STREAM_POLL:
if self.any():
return _MP_STREAM_POLL_RD
return 0

elif op ==_MP_STREAM_FLUSH:
return self.flush()
elif op ==_MP_STREAM_SEEK:
return self.seek(arg[0], arg[1])
elif op ==_MP_STREAM_CLOSE:
return self.close()

else:
return _MP_STREAM_ERROR
54 changes: 54 additions & 0 deletions micropython/streampair/test_streampair.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import asyncio
import unittest
from streampair import streampair

def async_test(f):
"""
Decorator to run an async test function
"""
def wrapper(*args, **kwargs):
loop = asyncio.new_event_loop()
# loop.set_exception_handler(_exception_handler)
t = loop.create_task(f(*args, **kwargs))
loop.run_until_complete(t)

return wrapper

class StreamPairTestCase(unittest.TestCase):

def test_streampair(self):
a, b = streampair()
assert a.write(b"foo") == 3
assert b.write(b"bar") == 3

assert (r := a.read()) == b"bar", r
assert (r := b.read()) == b"foo", r

@async_test
async def test_async_streampair(self):
a, b = streampair()
ar = asyncio.StreamReader(a)
bw = asyncio.StreamWriter(b)

br = asyncio.StreamReader(b)
aw = asyncio.StreamWriter(a)

aw.write(b"foo\n")
await aw.drain()
assert not a.any()
assert b.any()
assert (r := await br.readline()) == b"foo\n", r
assert not b.any()
assert not a.any()

bw.write(b"bar\n")
await bw.drain()
assert not b.any()
assert a.any()
assert (r := await ar.readline()) == b"bar\n", r
assert not b.any()
assert not a.any()


if __name__ == "__main__":
unittest.main()
Loading