Skip to content

Commit

Permalink
update unit tests to support video and audio
Browse files Browse the repository at this point in the history
reworked some variable names and types
  • Loading branch information
berysaidi committed Apr 25, 2024
1 parent 9e6a1a0 commit fdfd76e
Show file tree
Hide file tree
Showing 7 changed files with 102 additions and 48 deletions.
9 changes: 5 additions & 4 deletions aiortsp/rtsp/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,17 @@ def __init__(
self.connection: Optional[RTSPConnection] = None
self.transport: Optional[RTPTransport] = None
self.session: Optional[RTSPMediaSession] = None
self.payload_types = []
self.payload_types = {}
self.rtp_count = 0
self.rtcp_count = 0


def handle_rtp(self, rtp: RTP):
"""Queue packets for the iterator"""
self.rtp_count += 1
self.logger.info(f'rtp self.payload_types:{self.payload_types} incoming_type:{rtp.pt}')
self.logger.debug(f'rtp self.payload_types:{self.payload_types} incoming_type:{rtp.pt}')

for pt, media_type in self.payload_types:
for media_type, pt in self.payload_types.items():
if pt == rtp.pt:
self.logger.debug(f'adding {media_type} {rtp.pt} to queue')
self.queue.put_nowait((media_type, rtp))
Expand All @@ -69,7 +69,7 @@ def on_ready(self, connection: RTSPConnection, transport: RTPTransport, session:
for media_type in self.media_types:
self.logger.debug(f'setting session media to {media_type}')
pt = session.sdp.media_payload_type(media_type)
self.payload_types.extend([(pt, media_type)])
self.payload_types[media_type] = pt
transport.subscribe(self)
self.connection = connection
self.transport = transport
Expand Down Expand Up @@ -146,6 +146,7 @@ async def iter_packets(self) -> AsyncIterable[Tuple[str, RTP]]:
"""
Yield RTP packets as they come.
User can then do whatever they want, without too much boiler plate.
They can filter based on the media types they selected
"""
while True:
item = await self.queue.get()
Expand Down
6 changes: 3 additions & 3 deletions aiortsp/rtsp/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,15 @@ async def setup(self):
self.sdp = SDP(resp.content)
self.logger.debug('parsed SDP:\n%s', json.dumps(self.sdp, indent=2))

for index, media_type in enumerate(self.media_types):
for stream_number, media_type in enumerate(self.media_types):
setup_url = self.sdp.setup_url(self.media_url, media_type=media_type)
self.logger.info('setting up using URL: %s', setup_url)

# --- SETUP <url> RTSP/1.0 ---
headers = {}
self.transport.on_transport_request(headers, index)
self.transport.on_transport_request(headers, stream_number)
resp = await self.connection.send_request('SETUP', url=setup_url, headers=headers)
self.transport.on_transport_response(resp.headers, index)
self.transport.on_transport_response(resp.headers, stream_number)
self.logger.info('stream correctly setup: %s', resp)

# Store session ID
Expand Down
21 changes: 10 additions & 11 deletions aiortsp/transport/tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
TCP Interleaved transport.
"""
import logging
from typing import Optional, List

from aiortsp.rtcp.parser import RTCP
from aiortsp.rtsp.errors import RTSPError
Expand All @@ -23,23 +24,21 @@ class TCPTransport(RTPTransport):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.rtp_idx = [None] * self.num_streams
self.rtcp_idx = [None] * self.num_streams
self.rtp_idx: List[Optional[int]] = [None] * self.num_streams
self.rtcp_idx: List[Optional[int]] = [None] * self.num_streams

self.receive_buffer = kwargs.get('receive_buffer', DEFAULT_BUFFER_SIZE)
self.send_buffer = kwargs.get('send_buffer', DEFAULT_BUFFER_SIZE)

async def prepare(self):
"""
Register handlers for each stream. Assume each stream has one RTP and one RTCP channel.
Register handlers for each stream. Each stream has one RTP and one RTCP channel.
"""

for idx in range(self.num_streams):
rtp = self.connection.register_binary_handler(self.handle_rtp_bin)
rtcp = self.connection.register_binary_handler(self.handle_rtcp_bin)
self.rtp_idx[idx] = rtp
self.rtcp_idx[idx] = rtcp
self.logger.info('receiving interleaved RTP (%s) and RTCP (%s)', rtp, rtcp)
for stream_number in range(self.num_streams):
self.rtp_idx[stream_number] = self.connection.register_binary_handler(self.handle_rtp_bin)
self.rtcp_idx[stream_number] = self.connection.register_binary_handler(self.handle_rtcp_bin)
self.logger.info(f'receiving interleaved RTP ({self.rtp_idx[stream_number]}) and RTCP ({self.rtcp_idx[stream_number]})')

@property
def running(self) -> bool:
Expand Down Expand Up @@ -68,7 +67,7 @@ def handle_rtp_bin(self, binary: RTSPBinary):
self.handle_rtp_data(binary.data)

def on_transport_request(self, headers: dict, stream_number=0):
if stream_number not in range(len(self.rtp_idx)):
if stream_number not in range(self.num_streams):
raise ValueError(f"Invalid stream number {stream_number}")
rtp_idx = self.rtp_idx[stream_number]
rtcp_idx = self.rtcp_idx[stream_number]
Expand All @@ -91,7 +90,7 @@ def send_rtcp_report(self, rtcp: RTCP, stream_number=0):
"""
Send an RTCP report back to the server for a specific stream.
"""
if stream_number not in range(len(self.rtp_idx)):
if stream_number not in range(self.num_streams):
raise ValueError(f"Invalid stream number {stream_number}")

rtcp_channel = self.rtcp_idx[stream_index]
Expand Down
29 changes: 14 additions & 15 deletions aiortsp/transport/udp.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import logging
import socket
from random import randint
from typing import Tuple
from typing import Tuple, Optional, List

from aiortsp.rtcp.parser import RTCP
from aiortsp.rtsp.errors import RTSPError
Expand Down Expand Up @@ -109,13 +109,12 @@ class UDPTransport(RTPTransport):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

self.receive_buffer = kwargs.get('receive_buffer', DEFAULT_BUFFER_SIZE)
self.send_buffer = kwargs.get('send_buffer', DEFAULT_BUFFER_SIZE)
self.rtp_sinks = [None] * self.num_streams
self.rtcp_sinks = [None] * self.num_streams
self.server_rtp = [None] * self.num_streams
self.server_rtcp = [None] * self.num_streams
self.rtp_sinks: List[Optional[RTPSink]] = [None] * self.num_streams
self.rtcp_sinks: List[Optional[RTCPSink]] = [None] * self.num_streams
self.server_rtp: List[Optional[int]] = [None] * self.num_streams
self.server_rtcp: List[Optional[int]] = [None] * self.num_streams
self.rtcp_sender = None

@classmethod
Expand Down Expand Up @@ -162,25 +161,25 @@ async def prepare(self):
loop = asyncio.get_event_loop()

# Try to create RTP endpoint
for index in range(len(self.rtp_sinks)):
for stream_number in range(self.num_streams):
rtp_sock, rtcp_sock = await asyncio.wait_for(
loop.run_in_executor(None, self.get_socket_pair), 10)

rtp_transport, self.rtp_sinks[index] = await loop.create_datagram_endpoint(
rtp_transport, self.rtp_sinks[stream_number] = await loop.create_datagram_endpoint(
lambda: RTPSink(self),
sock=rtp_sock
)

# Try to create RTCP endpoint
rtcp_transport, self.rtcp_sinks[index] = await loop.create_datagram_endpoint(
rtcp_transport, self.rtcp_sinks[stream_number] = await loop.create_datagram_endpoint(
lambda: RTCPSink(self),
sock=rtcp_sock
)

self.logger.info(
'UDP Transport ready, will use ports %s-%s',
self.rtp_sinks[index].local_port,
self.rtcp_sinks[index].local_port,
self.rtp_sinks[stream_number].local_port,
self.rtcp_sinks[stream_number].local_port,
)

except Exception:
Expand Down Expand Up @@ -231,10 +230,10 @@ def on_transport_response(self, headers: dict, stream_number=0):

def close(self, error=None):
"""
Perform cleanup, which by default is closing both sinks.
Perform cleanup, which by default is closing all sinks.
"""
super().close(error)
for i in range(len(self.rtp_sinks)):
for i in range(self.num_streams):
self.rtp_sinks[i].close()
self.rtcp_sink[i].close()

Expand All @@ -260,12 +259,12 @@ async def warmup(self):
await super().warmup()

if self.server_rtp is not None:
for i in range(len(self.rtp_sinks)):
for i in range(self.num_streams):
self.logger.info('sending warmup RTP uplink traffic')
self.send_upstream(self.rtp_sinks[i], self.connection.host, self.server_rtp[i])

if self.server_rtcp is not None:
for i in range(len(self.rtp_sinks)):
for i in range(self.num_streams):
self.logger.info('sending warmup RTCP uplink traffic')
self.send_upstream(self.rtcp_sinks[i], self.connection.host, self.server_rtcp[i])

Expand Down
20 changes: 10 additions & 10 deletions examples/camera_feed.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
logger = logging.getLogger('RTSPLogger')
logger.propagate = True

def h264_decode(pkt, video_file):
def h264_decode(data, video_file):
'''
5.3. NAL Unit Header Usage
Expand All @@ -30,7 +30,7 @@ def h264_decode(pkt, video_file):
NAL_TYPE_MASK = 0x1F
FU_START_BIT_MASK = 0x80
NAL_START= b'\x00\x00\x00\x01'
nal_type = (pkt.data[0] & NAL_TYPE_MASK)
nal_type = (data[0] & NAL_TYPE_MASK)
'''
https://datatracker.ietf.org/doc/html/rfc6184#section-5.3
Expand All @@ -51,7 +51,7 @@ def h264_decode(pkt, video_file):
if nal_type in range(1, 24):
logger.debug("single NAL unit packet")
video_file.write(NAL_START)
video_file.write(pkt.data)
video_file.write(data)
# FU-A
elif nal_type == 28:
logger.debug("FU-A packet")
Expand Down Expand Up @@ -91,21 +91,21 @@ def h264_decode(pkt, video_file):
fragment NAL units to FU-Bs without organizing the incoming NAL
units to the NAL unit decoding order.
'''
if (pkt.data[1] & FU_START_BIT_MASK):
if (data[1] & FU_START_BIT_MASK):
logger.debug("FU-A packet start!")
video_file.write(NAL_START)
video_file.write(bytes([(pkt.data[0] & FNRI_MASK) | (pkt.data[1] & NAL_TYPE_MASK)]))
video_file.write(pkt.data[2:])
video_file.write(bytes([(data[0] & FNRI_MASK) | (data[1] & NAL_TYPE_MASK)]))
video_file.write(data[2:])
# STAP-A, STAP-B
elif nal_type in [24, 25]:
logger.debug("STRAP A/B packet")
# https://datatracker.ietf.org/doc/html/rfc6184#section-5.7.1
offset = 1
while offset < len(pkt.data):
nal_size = int.from_bytes(pkt.data[offset:offset+2], 'big')
while offset < len(data):
nal_size = int.from_bytes(data[offset:offset+2], 'big')
offset += 2
video_file.write(NAL_START)
video_file.write(pkt.data[offset:offset+nal_size])
video_file.write(data[offset:offset+nal_size])
offset += nal_size
elif nal_type in [26, 27]: # MTAP16, MTAP24
logger.debug("MTAP16 MTAP24 packet (ignored)")
Expand All @@ -123,7 +123,7 @@ async def main():
async for media_type, pkt in reader.iter_packets():
print(f'{media_type} {pkt.pt}')
if media_type == 'video':
h264_decode(pkt, video_file)
h264_decode(pkt.data, video_file)
elif media_type == 'audio':
audio_file.write(pkt.data)
else:
Expand Down
47 changes: 45 additions & 2 deletions tests/test_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@


@pytest.mark.asyncio
async def test_reader():
async def test_video_reader():
count = 0
server = await asyncio.start_server(handle_client_auth, '127.0.0.1', 5554)
try:
async with RTSPReader('rtspt://127.0.0.1:5554/media.sdp', timeout=2) as reader:
async with RTSPReader('rtspt://127.0.0.1:5554/media.sdp', timeout=2, media_types=['video']) as reader:
async for media_type, pkt in reader.iter_packets():
assert isinstance(pkt, RTP)
assert media_type == 'video'
Expand All @@ -27,6 +27,49 @@ async def test_reader():
finally:
server.close()

@pytest.mark.asyncio
async def test_audio_reader():
count = 0
server = await asyncio.start_server(handle_client_auth, '127.0.0.1', 5554)
try:
async with RTSPReader('rtspt://127.0.0.1:5554/media.sdp', timeout=2, media_types=['audio']) as reader:
async for media_type, pkt in reader.iter_packets():
assert isinstance(pkt, RTP)
assert media_type == 'audio'
count += 1

if count >= 2:
server.close()

assert count == 2
finally:
server.close()

@pytest.mark.asyncio
async def test_video_audio_reader():
count = 0
video_count = 0
audio_count = 0
server = await asyncio.start_server(handle_client_auth, '127.0.0.1', 5554)
try:
async with RTSPReader('rtspt://127.0.0.1:5554/media.sdp', timeout=2, media_types=['video', 'audio']) as reader:
async for media_type, pkt in reader.iter_packets():
assert isinstance(pkt, RTP)
count += 1
if media_type == 'audio':
audio_count += 1
if media_type == 'video':
video_count += 1

if count >= 4:
server.close()

assert count == 4
assert audio_count == 2
assert video_count == 2

finally:
server.close()

@pytest.mark.asyncio
async def test_reader_reconnect():
Expand Down
18 changes: 15 additions & 3 deletions tests/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
async def handle_client_auth(client_reader, client_writer):
parser = RTSPParser()
playing = False
media_num = 0

while True:
data = await client_reader.read(10000)
Expand Down Expand Up @@ -61,20 +62,31 @@ async def handle_client_auth(client_reader, client_writer):
response += 'a=control:trackID=2\r\n'
response += 'a=sendonly\r\n'
elif msg.method == 'SETUP':
response += 'Transport: RTP/AVP/TCP;unicast;interleaved=0-1;ssrc=E6EC9FEF;mode="PLAY"\r\n'
# not the best code there is...
if media_num == 0:
media_num += 1
response += 'Transport: RTP/AVP/TCP;unicast;interleaved=0-1;ssrc=E6EC9FEF;mode="PLAY"\r\n'
else:
response += 'Transport: RTP/AVP/TCP;unicast;interleaved=2-3;ssrc=E6EC9FEF;mode="PLAY"\r\n'
response += 'Session: 2sY7Pd2EPx8JY50-;timeout=60\r\n'

elif msg.method == 'PLAY':
playing = True
response += '\r\n'
print('RESPONSE', response)
client_writer.write(response.encode())

if playing:
# Send 2 RTP packets
# Send 2 Video RTP packets (notice the packet type 96 0x60)
rtp = bytearray.fromhex('2400002080605eaac639ab5e13cd9b86674d0029e29019077f1180b7010101a41e244540'
'2400001080605eabc639ab5e13cd9b8668ee3c80')
client_writer.write(rtp)

# Send 2 Audio RTP packets (notice the packet type 8 0x08)
rtp = bytearray.fromhex('2400002080085eaac639ab5e13cd9b86674d0029e29019077f1180b7010101a41e244540'
'2400001080085eabc639ab5e13cd9b8668ee3c80')
client_writer.write(rtp)

# Send an SR
sr = bytearray.fromhex('80c8000677ae8d65e051bc2bea33b0001fa8034c0000000000000000')
msg = bytearray([ord('$'), 1, 0, len(sr)])
Expand All @@ -95,6 +107,6 @@ async def test_session():
await asyncio.sleep(0.2)
rtcp = sess.stats.build_rtcp()
assert rtcp
assert sess.stats.received == 2
assert sess.stats.received == 4
finally:
server.close()

0 comments on commit fdfd76e

Please sign in to comment.