diff --git a/packages/google-cloud-storage/google/cloud/storage/asyncio/_stream_multiplexer.py b/packages/google-cloud-storage/google/cloud/storage/asyncio/_stream_multiplexer.py new file mode 100644 index 000000000000..abd99e087cd3 --- /dev/null +++ b/packages/google-cloud-storage/google/cloud/storage/asyncio/_stream_multiplexer.py @@ -0,0 +1,198 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import asyncio +import logging +from typing import Awaitable, Callable, Dict, Optional, Set + +from google.cloud import _storage_v2 +from google.cloud.storage.asyncio.async_read_object_stream import ( + _AsyncReadObjectStream, +) + +logger = logging.getLogger(__name__) + +_DEFAULT_QUEUE_MAX_SIZE = 100 +_DEFAULT_PUT_TIMEOUT = 20.0 + + +class _StreamError: + """Wraps an error with the stream generation that produced it.""" + + def __init__(self, exception: Exception, generation: int): + self.exception = exception + self.generation = generation + + +class _StreamEnd: + """Signals the stream closed normally.""" + + pass + + +class _StreamMultiplexer: + """Multiplexes concurrent download tasks over a single bidi-gRPC stream. + + Routes responses from a background recv loop to per-task asyncio.Queues + keyed by read_id. Coordinates stream reopening via generation-gated + locking. 
+ + A slow consumer on one task will slow down the entire shared connection + due to bounded queue backpressure propagating through gRPC flow control. + """ + + def __init__( + self, + stream: _AsyncReadObjectStream, + queue_max_size: int = _DEFAULT_QUEUE_MAX_SIZE, + ): + self._stream = stream + self._stream_generation: int = 0 + self._queues: Dict[int, asyncio.Queue] = {} + self._reopen_lock = asyncio.Lock() + self._recv_task: Optional[asyncio.Task] = None + self._queue_max_size = queue_max_size + + @property + def stream_generation(self) -> int: + return self._stream_generation + + def register(self, read_ids: Set[int]) -> asyncio.Queue: + """Register read_ids for a task and return its response queue.""" + queue = asyncio.Queue(maxsize=self._queue_max_size) + for read_id in read_ids: + self._queues[read_id] = queue + return queue + + def unregister(self, read_ids: Set[int]) -> None: + """Remove read_ids from routing.""" + for read_id in read_ids: + self._queues.pop(read_id, None) + + def _get_unique_queues(self) -> Set[asyncio.Queue]: + return set(self._queues.values()) + + async def _put_with_timeout(self, queue: asyncio.Queue, item) -> None: + try: + await asyncio.wait_for(queue.put(item), timeout=_DEFAULT_PUT_TIMEOUT) + except asyncio.TimeoutError: + if queue not in self._get_unique_queues(): + logger.debug("Dropped item for unregistered queue.") + else: + logger.warning( + "Queue full for too long. Dropping item to prevent multiplexer hang." 
+ ) + + def _ensure_recv_loop(self) -> None: + if self._recv_task is None or self._recv_task.done(): + self._recv_task = asyncio.create_task(self._recv_loop()) + + def _stop_recv_loop(self) -> None: + if self._recv_task and not self._recv_task.done(): + self._recv_task.cancel() + + def _put_error_nowait(self, queue: asyncio.Queue, error: _StreamError) -> None: + while True: + try: + queue.put_nowait(error) + break + except asyncio.QueueFull: + try: + queue.get_nowait() + except asyncio.QueueEmpty: + pass + + async def _recv_loop(self) -> None: + try: + while True: + response = await self._stream.recv() + if response is None: + sentinel = _StreamEnd() + await asyncio.gather( + *( + self._put_with_timeout(queue, sentinel) + for queue in self._get_unique_queues() + ) + ) + return + + if response.object_data_ranges: + queues_to_notify: Set[asyncio.Queue] = set() + for data_range in response.object_data_ranges: + read_id = data_range.read_range.read_id + queue = self._queues.get(read_id) + if queue: + queues_to_notify.add(queue) + await asyncio.gather( + *( + self._put_with_timeout(queue, response) + for queue in queues_to_notify + ) + ) + else: + await asyncio.gather( + *( + self._put_with_timeout(queue, response) + for queue in self._get_unique_queues() + ) + ) + except asyncio.CancelledError: + raise + except Exception as e: + error = _StreamError(e, self._stream_generation) + for queue in self._get_unique_queues(): + self._put_error_nowait(queue, error) + + async def send(self, request: _storage_v2.BidiReadObjectRequest) -> int: + self._ensure_recv_loop() + await self._stream.send(request) + return self._stream_generation + + async def reopen_stream( + self, + broken_generation: int, + stream_factory: Callable[[], Awaitable[_AsyncReadObjectStream]], + ) -> None: + async with self._reopen_lock: + if self._stream_generation != broken_generation: + return + self._stop_recv_loop() + if self._recv_task: + try: + await self._recv_task + except (asyncio.CancelledError, 
Exception): + pass + error = _StreamError(Exception("Stream reopening"), self._stream_generation) + for queue in self._get_unique_queues(): + self._put_error_nowait(queue, error) + try: + await self._stream.close() + except Exception: + pass + self._stream = await stream_factory() + self._stream_generation += 1 + self._ensure_recv_loop() + + async def close(self) -> None: + self._stop_recv_loop() + if self._recv_task: + try: + await self._recv_task + except (asyncio.CancelledError, Exception): + pass + error = _StreamError(Exception("Multiplexer closed"), self._stream_generation) + for queue in self._get_unique_queues(): + self._put_error_nowait(queue, error) diff --git a/packages/google-cloud-storage/google/cloud/storage/asyncio/async_multi_range_downloader.py b/packages/google-cloud-storage/google/cloud/storage/asyncio/async_multi_range_downloader.py index cea21cb9ae66..9c0fdf05098a 100644 --- a/packages/google-cloud-storage/google/cloud/storage/asyncio/async_multi_range_downloader.py +++ b/packages/google-cloud-storage/google/cloud/storage/asyncio/async_multi_range_downloader.py @@ -25,6 +25,11 @@ from google.cloud import _storage_v2 from google.cloud.storage._helpers import generate_random_56_bit_integer +from google.cloud.storage.asyncio._stream_multiplexer import ( + _StreamEnd, + _StreamError, + _StreamMultiplexer, +) from google.cloud.storage.asyncio.async_grpc_client import ( AsyncGrpcClient, ) @@ -224,9 +229,7 @@ def __init__( self.read_obj_str: Optional[_AsyncReadObjectStream] = None self._is_stream_open: bool = False self._routing_token: Optional[str] = None - self._read_id_to_writable_buffer_dict = {} - self._read_id_to_download_ranges_id = {} - self._download_ranges_id_to_pending_read_ids = {} + self._multiplexer: Optional[_StreamMultiplexer] = None self.persisted_size: Optional[int] = None # updated after opening the stream self._open_retries: int = 0 @@ -328,6 +331,45 @@ async def _do_open(): self._is_stream_open = True await 
retry_policy(_do_open)() + self._multiplexer = _StreamMultiplexer(self.read_obj_str) + + def _create_stream_factory(self, state, metadata): + """Create a factory that opens a new stream with current routing state.""" + + async def factory(): + current_handle = state.get("read_handle") + current_token = state.get("routing_token") + + stream = _AsyncReadObjectStream( + client=self.client.grpc_client, + bucket_name=self.bucket_name, + object_name=self.object_name, + generation_number=self.generation, + read_handle=current_handle, + ) + + current_metadata = list(metadata) if metadata else [] + if current_token: + current_metadata.append( + ( + "x-goog-request-params", + f"routing_token={current_token}", + ) + ) + + await stream.open(metadata=current_metadata if current_metadata else None) + + if stream.generation_number: + self.generation = stream.generation_number + if stream.read_handle: + self.read_handle = stream.read_handle + + self.read_obj_str = stream + self._is_stream_open = True + + return stream + + return factory async def download_ranges( self, @@ -353,32 +395,8 @@ async def download_ranges( * (0, 0, buffer) : downloads 0 to end , i.e. entire object. * (100, 0, buffer) : downloads from 100 to end. - :type lock: asyncio.Lock - :param lock: (Optional) An asyncio lock to synchronize sends and recvs - on the underlying bidi-GRPC stream. This is required when multiple - coroutines are calling this method concurrently. - - i.e. Example usage with multiple coroutines: - - ``` - lock = asyncio.Lock() - task1 = asyncio.create_task(mrd.download_ranges(ranges1, lock)) - task2 = asyncio.create_task(mrd.download_ranges(ranges2, lock)) - await asyncio.gather(task1, task2) - - ``` - - If user want to call this method serially from multiple coroutines, - then providing a lock is not necessary. - - ``` - await mrd.download_ranges(ranges1) - await mrd.download_ranges(ranges2) - - # ... some other code code... 
- - ``` + :param lock: (Deprecated) This parameter is deprecated and has no effect. :type retry_policy: :class:`~google.api_core.retry_async.AsyncRetry` :param retry_policy: (Optional) The retry policy to use for the operation. @@ -397,9 +415,6 @@ async def download_ranges( if not self._is_stream_open: raise ValueError("Underlying bidi-gRPC stream is not open") - if lock is None: - lock = asyncio.Lock() - if retry_policy is None: retry_policy = AsyncRetry(predicate=_is_read_retryable) @@ -419,99 +434,97 @@ async def download_ranges( "routing_token": None, } - # Track attempts to manage stream reuse - attempt_count = 0 - - def send_ranges_and_get_bytes( - requests: List[_storage_v2.ReadRange], - state: Dict[str, Any], - metadata: Optional[List[Tuple[str, str]]] = None, - ): - async def generator(): - nonlocal attempt_count - attempt_count += 1 - - if attempt_count > 1: - logger.info( - f"Resuming download (attempt {attempt_count}) for {len(requests)} ranges." - ) + read_ids = set(download_states.keys()) + queue = self._multiplexer.register(read_ids) - async with lock: - current_handle = state.get("read_handle") - current_token = state.get("routing_token") + try: + attempt_count = 0 + last_broken_generation = None - # We reopen if it's a redirect (token exists) OR if this is a retry - # (not first attempt). This prevents trying to send data on a dead - # stream from a previous failed attempt. 
- should_reopen = ( - (attempt_count > 1) - or (current_token is not None) - or (metadata is not None) - ) + def send_and_recv_via_multiplexer( + requests: List[_storage_v2.ReadRange], + state: Dict[str, Any], + ): + async def generator(): + nonlocal attempt_count, last_broken_generation + attempt_count += 1 - if should_reopen: - if current_token: - logger.info( - f"Re-opening stream with routing token: {current_token}" - ) - - self.read_obj_str = _AsyncReadObjectStream( - client=self.client.grpc_client, - bucket_name=self.bucket_name, - object_name=self.object_name, - generation_number=self.generation, - read_handle=current_handle, + if attempt_count > 1: + logger.info( + f"Resuming download (attempt {attempt_count}) for {len(requests)} ranges." ) - # Inject routing_token into metadata if present - current_metadata = list(metadata) if metadata else [] - if current_token: - current_metadata.append( - ( - "x-goog-request-params", - f"routing_token={current_token}", - ) - ) - - await self.read_obj_str.open( - metadata=current_metadata if current_metadata else None + # Reopen stream if needed + should_reopen = ( + attempt_count > 1 and last_broken_generation is not None + ) or (attempt_count == 1 and metadata is not None) + if should_reopen: + broken_gen = ( + last_broken_generation + if attempt_count > 1 + else self._multiplexer.stream_generation + ) + stream_factory = self._create_stream_factory(state, metadata) + await self._multiplexer.reopen_stream( + broken_gen, stream_factory ) - self._is_stream_open = True - pending_read_ids = {r.read_id for r in requests} + my_generation = self._multiplexer.stream_generation # Send Requests + pending_read_ids = {r.read_id for r in requests} for i in range( 0, len(requests), _MAX_READ_RANGES_PER_BIDI_READ_REQUEST ): batch = requests[i : i + _MAX_READ_RANGES_PER_BIDI_READ_REQUEST] - await self.read_obj_str.send( - _storage_v2.BidiReadObjectRequest(read_ranges=batch) - ) + try: + await self._multiplexer.send( + 
_storage_v2.BidiReadObjectRequest(read_ranges=batch) + ) + except Exception: + last_broken_generation = my_generation + raise + # Receive Responses while pending_read_ids: - response = await self.read_obj_str.recv() - if response is None: + item = await queue.get() + + if isinstance(item, _StreamEnd): + if pending_read_ids: + last_broken_generation = my_generation + raise exceptions.ServiceUnavailable( + "Stream ended with pending read_ids" + ) break - if response.object_data_ranges: - for data_range in response.object_data_ranges: + + if isinstance(item, _StreamError): + if item.generation < my_generation: + continue # stale error, skip + last_broken_generation = item.generation + raise item.exception + + # Track completion + if item.object_data_ranges: + for data_range in item.object_data_ranges: if data_range.range_end: pending_read_ids.discard( data_range.read_range.read_id ) - yield response + yield item - return generator() + return generator() - strategy = _ReadResumptionStrategy() - retry_manager = _BidiStreamRetryManager( - strategy, lambda r, s: send_ranges_and_get_bytes(r, s, metadata=metadata) - ) + strategy = _ReadResumptionStrategy() + retry_manager = _BidiStreamRetryManager( + strategy, send_and_recv_via_multiplexer + ) - await retry_manager.execute(initial_state, retry_policy) + await retry_manager.execute(initial_state, retry_policy) - if initial_state.get("read_handle"): - self.read_handle = initial_state["read_handle"] + if initial_state.get("read_handle"): + self.read_handle = initial_state["read_handle"] + finally: + self._multiplexer.unregister(read_ids) async def close(self): """ @@ -520,8 +533,15 @@ async def close(self): if not self._is_stream_open: raise ValueError("Underlying bidi-gRPC stream is not open") + if self._multiplexer: + await self._multiplexer.close() + self._multiplexer = None + if self.read_obj_str: - await self.read_obj_str.close() + try: + await self.read_obj_str.close() + except (asyncio.CancelledError, 
exceptions.GoogleAPICallError): + pass self.read_obj_str = None self._is_stream_open = False diff --git a/packages/google-cloud-storage/tests/system/test_zonal.py b/packages/google-cloud-storage/tests/system/test_zonal.py index edd323b037ec..f30d627ae033 100644 --- a/packages/google-cloud-storage/tests/system/test_zonal.py +++ b/packages/google-cloud-storage/tests/system/test_zonal.py @@ -2,13 +2,14 @@ import asyncio import gc import os +import random import uuid from io import BytesIO # python additional imports import google_crc32c import pytest -from google.api_core.exceptions import FailedPrecondition, NotFound +from google.api_core.exceptions import FailedPrecondition, NotFound, OutOfRange from google.cloud.storage.asyncio.async_appendable_object_writer import ( _DEFAULT_FLUSH_INTERVAL_BYTES, @@ -594,3 +595,194 @@ async def _run(): gc.collect() event_loop.run_until_complete(_run()) + + +def test_mrd_concurrent_download( + storage_client, blobs_to_delete, event_loop, grpc_client +): + """ + Test that mrd can handle concurrent `download_ranges` calls correctly. + Tests overlapping ranges, high concurrency (len > 100 multiplexing batch limits), + mixed random chunk sizes (small/medium/large), and full object fetching alongside specific chunks. 
+ """ + object_size = 15 * 1024 * 1024 # 15MB + object_name = f"test_mrd_concurrent-{uuid.uuid4()}" + + async def _run(): + object_data = os.urandom(object_size) + + writer = AsyncAppendableObjectWriter(grpc_client, _ZONAL_BUCKET, object_name) + await writer.open() + await writer.append(object_data) + await writer.close(finalize_on_close=True) + + async with AsyncMultiRangeDownloader( + grpc_client, _ZONAL_BUCKET, object_name + ) as mrd: + tasks = [] + ranges_to_fetch = [] + + # Overlapping ranges & Mixed random chunk sizes + # Small chunks + for _ in range(60): + start = random.randint(0, object_size - 100) + length = random.randint(1, 100) + ranges_to_fetch.append((start, length)) + # Medium chunks + for _ in range(60): + start = random.randint(0, object_size - 100000) + length = random.randint(100, 100000) + ranges_to_fetch.append((start, length)) + # Large chunks + for _ in range(5): + start = random.randint(0, object_size - 2000000) + length = random.randint(1000000, 2000000) + ranges_to_fetch.append((start, length)) + + # Full object fetching concurrently + ranges_to_fetch.append((0, 0)) + + # High concurrency batching (Total > 100 ranges) + assert len(ranges_to_fetch) > 100 + random.shuffle(ranges_to_fetch) + + buffers = [BytesIO() for _ in range(len(ranges_to_fetch))] + + for idx, (start, length) in enumerate(ranges_to_fetch): + tasks.append( + asyncio.create_task( + mrd.download_ranges([(start, length, buffers[idx])]) + ) + ) + + await asyncio.gather(*tasks) + + # Validation + for idx, (start, length) in enumerate(ranges_to_fetch): + if length == 0: + expected_data = object_data[start:] + else: + expected_data = object_data[start : start + length] + assert buffers[idx].getvalue() == expected_data + + del writer + gc.collect() + blobs_to_delete.append(storage_client.bucket(_ZONAL_BUCKET).blob(object_name)) + + event_loop.run_until_complete(_run()) + + +def test_mrd_concurrent_download_cancellation( + storage_client, blobs_to_delete, event_loop, grpc_client 
+): + """ + Test task cancellation / abort mid-stream. + Tests that downloading gracefully manages memory and internal references + when tasks are canceled during active multiplexing, without breaking remaining downloads. + """ + object_size = 5 * 1024 * 1024 # 5MB + object_name = f"test_mrd_cancel-{uuid.uuid4()}" + + async def _run(): + object_data = os.urandom(object_size) + + writer = AsyncAppendableObjectWriter(grpc_client, _ZONAL_BUCKET, object_name) + await writer.open() + await writer.append(object_data) + await writer.close(finalize_on_close=True) + + async with AsyncMultiRangeDownloader( + grpc_client, _ZONAL_BUCKET, object_name + ) as mrd: + tasks = [] + num_chunks = 100 + chunk_size = object_size // num_chunks + buffers = [BytesIO() for _ in range(num_chunks)] + + for i in range(num_chunks): + start = i * chunk_size + tasks.append( + asyncio.create_task( + mrd.download_ranges([(start, chunk_size, buffers[i])]) + ) + ) + + # Let the loop start sending Bidi requests + await asyncio.sleep(0.01) + + # Cancel a subset of evenly distributed tasks + for i in range(0, num_chunks, 2): + tasks[i].cancel() + + results = await asyncio.gather(*tasks, return_exceptions=True) + + for i in range(num_chunks): + if i % 2 == 0: + assert isinstance(results[i], asyncio.CancelledError) + else: + start = i * chunk_size + expected_data = object_data[start : start + chunk_size] + assert buffers[i].getvalue() == expected_data + + del writer + gc.collect() + blobs_to_delete.append(storage_client.bucket(_ZONAL_BUCKET).blob(object_name)) + + event_loop.run_until_complete(_run()) + + +def test_mrd_concurrent_download_out_of_bounds( + storage_client, blobs_to_delete, event_loop, grpc_client +): + """ + Test out-of-bounds & edge ranges concurrent with valid requests. + Verifies isolation: invalid bounds generate correct exceptions and don't stall the stream + for concurrently valid requests. 
+ """ + object_size = 2 * 1024 * 1024 # 2MB + object_name = f"test_mrd_oob-{uuid.uuid4()}" + + async def _run(): + object_data = os.urandom(object_size) + + writer = AsyncAppendableObjectWriter(grpc_client, _ZONAL_BUCKET, object_name) + await writer.open() + await writer.append(object_data) + await writer.close(finalize_on_close=True) + + async with AsyncMultiRangeDownloader( + grpc_client, _ZONAL_BUCKET, object_name + ) as mrd: + b_valid = BytesIO() + t_valid = asyncio.create_task(mrd.download_ranges([(0, 100, b_valid)])) + + b_oob1 = BytesIO() + t_oob1 = asyncio.create_task( + mrd.download_ranges([(object_size + 1000, 100, b_oob1)]) + ) + + # EOF ask for 100 bytes + b_oob2 = BytesIO() + t_oob2 = asyncio.create_task( + mrd.download_ranges([(object_size, 100, b_oob2)]) + ) + + results = await asyncio.gather( + t_valid, t_oob1, t_oob2, return_exceptions=True + ) + + # Verify valid one processed correctly + assert b_valid.getvalue() == object_data[:100] + + # Verify fully OOB request returned Exception + assert isinstance(results[1], OutOfRange) + + # Verify request exactly at EOF successfully completed with 0 bytes + assert not isinstance(results[2], Exception) + assert b_oob2.getvalue() == b"" + + del writer + gc.collect() + blobs_to_delete.append(storage_client.bucket(_ZONAL_BUCKET).blob(object_name)) + + event_loop.run_until_complete(_run()) diff --git a/packages/google-cloud-storage/tests/unit/asyncio/test_async_multi_range_downloader.py b/packages/google-cloud-storage/tests/unit/asyncio/test_async_multi_range_downloader.py index 80df5a438173..9b9f63f32e4f 100644 --- a/packages/google-cloud-storage/tests/unit/asyncio/test_async_multi_range_downloader.py +++ b/packages/google-cloud-storage/tests/unit/asyncio/test_async_multi_range_downloader.py @@ -60,6 +60,9 @@ async def _make_mock_mrd( mock_stream.generation_number = _TEST_GENERATION_NUMBER mock_stream.persisted_size = _TEST_OBJECT_SIZE mock_stream.read_handle = _TEST_READ_HANDLE + mock_stream.is_stream_open = 
True + # Default recv blocks forever (tests override with specific side_effect) + mock_stream.recv = AsyncMock(side_effect=asyncio.Event().wait) mrd = await AsyncMultiRangeDownloader.create_mrd( mock_client, bucket_name, object_name, generation, read_handle @@ -102,69 +105,39 @@ async def test_create_mrd(self, mock_cls_async_read_object_stream): "google.cloud.storage.asyncio.async_multi_range_downloader._AsyncReadObjectStream" ) @pytest.mark.asyncio - async def test_download_ranges_via_async_gather( + async def test_download_ranges( self, mock_cls_async_read_object_stream, mock_random_int ): - # Arrange data = b"these_are_18_chars" crc32c = Checksum(data).digest() crc32c_int = int.from_bytes(crc32c, "big") - crc32c_checksum_for_data_slice = int.from_bytes( - Checksum(data[10:16]).digest(), "big" - ) mock_mrd, _ = await self._make_mock_mrd(mock_cls_async_read_object_stream) - - mock_random_int.side_effect = [456, 91011] + mock_random_int.side_effect = [456] mock_mrd.read_obj_str.send = AsyncMock() - mock_mrd.read_obj_str.recv = AsyncMock() - - mock_mrd.read_obj_str.recv.side_effect = [ - _storage_v2.BidiReadObjectResponse( - object_data_ranges=[ - _storage_v2.ObjectRangeData( - checksummed_data=_storage_v2.ChecksummedData( - content=data, crc32c=crc32c_int - ), - range_end=True, - read_range=_storage_v2.ReadRange( - read_offset=0, read_length=18, read_id=456 - ), - ) - ] - ), - _storage_v2.BidiReadObjectResponse( - object_data_ranges=[ - _storage_v2.ObjectRangeData( - checksummed_data=_storage_v2.ChecksummedData( - content=data[10:16], - crc32c=crc32c_checksum_for_data_slice, - ), - range_end=True, - read_range=_storage_v2.ReadRange( - read_offset=10, read_length=6, read_id=91011 - ), - ) - ], - ), - None, - ] - - # Act - buffer = BytesIO() - second_buffer = BytesIO() - lock = asyncio.Lock() - - task1 = asyncio.create_task(mock_mrd.download_ranges([(0, 18, buffer)], lock)) - task2 = asyncio.create_task( - mock_mrd.download_ranges([(10, 6, second_buffer)], lock) + 
mock_mrd.read_obj_str.recv = AsyncMock( + side_effect=[ + _storage_v2.BidiReadObjectResponse( + object_data_ranges=[ + _storage_v2.ObjectRangeData( + checksummed_data=_storage_v2.ChecksummedData( + content=data, crc32c=crc32c_int + ), + range_end=True, + read_range=_storage_v2.ReadRange( + read_offset=0, read_length=18, read_id=456 + ), + ) + ], + ), + None, + ] ) - await asyncio.gather(task1, task2) - # Assert + buffer = BytesIO() + await mock_mrd.download_ranges([(0, 18, buffer)]) assert buffer.getvalue() == data - assert second_buffer.getvalue() == data[10:16] @mock.patch( "google.cloud.storage.asyncio.async_multi_range_downloader.generate_random_56_bit_integer" @@ -173,50 +146,78 @@ async def test_download_ranges_via_async_gather( "google.cloud.storage.asyncio.async_multi_range_downloader._AsyncReadObjectStream" ) @pytest.mark.asyncio - async def test_download_ranges( + async def test_download_ranges_via_async_gather( self, mock_cls_async_read_object_stream, mock_random_int ): - # Arrange data = b"these_are_18_chars" crc32c = Checksum(data).digest() crc32c_int = int.from_bytes(crc32c, "big") + crc32c_checksum_for_data_slice = int.from_bytes( + Checksum(data[10:16]).digest(), "big" + ) mock_mrd, _ = await self._make_mock_mrd(mock_cls_async_read_object_stream) + mock_random_int.side_effect = [456, 91011] - mock_random_int.side_effect = [456] + send_count = 0 + both_sent = asyncio.Event() + + async def counting_send(request): + nonlocal send_count + send_count += 1 + if send_count >= 2: + both_sent.set() + + mock_mrd.read_obj_str.send = AsyncMock(side_effect=counting_send) + + recv_call_count = 0 + + async def controlled_recv(): + nonlocal recv_call_count + recv_call_count += 1 + if recv_call_count == 1: + await both_sent.wait() + return _storage_v2.BidiReadObjectResponse( + object_data_ranges=[ + _storage_v2.ObjectRangeData( + checksummed_data=_storage_v2.ChecksummedData( + content=data, crc32c=crc32c_int + ), + range_end=True, + read_range=_storage_v2.ReadRange( 
+ read_offset=0, read_length=18, read_id=456 + ), + ) + ] + ) + elif recv_call_count == 2: + return _storage_v2.BidiReadObjectResponse( + object_data_ranges=[ + _storage_v2.ObjectRangeData( + checksummed_data=_storage_v2.ChecksummedData( + content=data[10:16], + crc32c=crc32c_checksum_for_data_slice, + ), + range_end=True, + read_range=_storage_v2.ReadRange( + read_offset=10, read_length=6, read_id=91011 + ), + ) + ], + ) + return None - mock_mrd.read_obj_str.send = AsyncMock() - mock_mrd.read_obj_str.recv = AsyncMock() - mock_mrd.read_obj_str.recv.side_effect = [ - _storage_v2.BidiReadObjectResponse( - object_data_ranges=[ - _storage_v2.ObjectRangeData( - checksummed_data=_storage_v2.ChecksummedData( - content=data, crc32c=crc32c_int - ), - range_end=True, - read_range=_storage_v2.ReadRange( - read_offset=0, read_length=18, read_id=456 - ), - ) - ], - ), - None, - ] + mock_mrd.read_obj_str.recv = AsyncMock(side_effect=controlled_recv) - # Act buffer = BytesIO() - await mock_mrd.download_ranges([(0, 18, buffer)]) + second_buffer = BytesIO() + + task1 = asyncio.create_task(mock_mrd.download_ranges([(0, 18, buffer)])) + task2 = asyncio.create_task(mock_mrd.download_ranges([(10, 6, second_buffer)])) + await asyncio.gather(task1, task2) - # Assert - mock_mrd.read_obj_str.send.assert_called_once_with( - _storage_v2.BidiReadObjectRequest( - read_ranges=[ - _storage_v2.ReadRange(read_offset=0, read_length=18, read_id=456) - ] - ) - ) assert buffer.getvalue() == data + assert second_buffer.getvalue() == data[10:16] @pytest.mark.asyncio async def test_downloading_ranges_with_more_than_1000_should_throw_error(self): @@ -320,6 +321,7 @@ def test_init_raises_if_crc32c_c_extension_is_missing(self, mock_google_crc32c): async def test_download_ranges_raises_on_checksum_mismatch( self, mock_checksum_class ): + from google.cloud.storage.asyncio._stream_multiplexer import _StreamMultiplexer from google.cloud.storage.asyncio.async_multi_range_downloader import ( 
AsyncMultiRangeDownloader, ) @@ -353,6 +355,7 @@ async def test_download_ranges_raises_on_checksum_mismatch( mrd = AsyncMultiRangeDownloader(mock_client, "bucket", "object") mrd.read_obj_str = mock_stream mrd._is_stream_open = True + mrd._multiplexer = _StreamMultiplexer(mock_stream) with pytest.raises(DataCorruption) as exc_info: with mock.patch( @@ -419,6 +422,8 @@ async def test_create_mrd_with_generation_number( mock_stream.generation_number = _TEST_GENERATION_NUMBER mock_stream.persisted_size = _TEST_OBJECT_SIZE mock_stream.read_handle = _TEST_READ_HANDLE + mock_stream.is_stream_open = True + mock_stream.recv = AsyncMock(side_effect=asyncio.Event().wait) # Act mrd = await AsyncMultiRangeDownloader.create_mrd( @@ -521,59 +526,50 @@ async def test_on_open_error_logs_warning(self, mock_logger): async def test_download_ranges_resumption_logging( self, mock_cls_async_read_object_stream, mock_random_int, mock_logger ): - # Arrange mock_mrd, _ = await self._make_mock_mrd(mock_cls_async_read_object_stream) - mock_mrd.read_obj_str.send = AsyncMock() - mock_mrd.read_obj_str.recv = AsyncMock() - from google.api_core import exceptions as core_exceptions retryable_exc = core_exceptions.ServiceUnavailable("Retry me") - # mock send to raise exception ONCE then succeed - mock_mrd.read_obj_str.send.side_effect = [ - retryable_exc, - None, # Success on second try - ] - - # mock recv for second try - mock_mrd.read_obj_str.recv.side_effect = [ - _storage_v2.BidiReadObjectResponse( - object_data_ranges=[ - _storage_v2.ObjectRangeData( - checksummed_data=_storage_v2.ChecksummedData( - content=b"data", crc32c=123 - ), - range_end=True, - read_range=_storage_v2.ReadRange( - read_offset=0, read_length=4, read_id=123 - ), - ) - ] - ), - None, - ] + mock_mrd.read_obj_str.send = AsyncMock( + side_effect=[ + retryable_exc, + None, + ] + ) + + recv_call_count = 0 + + async def staged_recv(): + nonlocal recv_call_count + recv_call_count += 1 + if recv_call_count == 1: + return 
_storage_v2.BidiReadObjectResponse( + object_data_ranges=[ + _storage_v2.ObjectRangeData( + checksummed_data=_storage_v2.ChecksummedData( + content=b"data", crc32c=123 + ), + range_end=True, + read_range=_storage_v2.ReadRange( + read_offset=0, read_length=4, read_id=123 + ), + ) + ] + ) + return None + + mock_mrd.read_obj_str.recv = AsyncMock(side_effect=staged_recv) + mock_mrd.read_obj_str.is_stream_open = True mock_random_int.return_value = 123 - # Act buffer = BytesIO() - # Patch Checksum where it is likely used (reads_resumption_strategy or similar), - # but actually if we use google_crc32c directly, we should patch that or provide valid CRC. - # Since we can't reliably predict where Checksum is imported/used without more digging, - # let's provide a valid CRC for b"data". - # Checksum(b"data").digest() -> needs to match crc32c=123. - # But we can't force b"data" to have crc=123. - # So we MUST patch Checksum. - # It is used in google.cloud.storage.asyncio.retry.reads_resumption_strategy - with mock.patch( "google.cloud.storage.asyncio.retry.reads_resumption_strategy.Checksum" ) as mock_chk: mock_chk.return_value.digest.return_value = (123).to_bytes(4, "big") - await mock_mrd.download_ranges([(0, 4, buffer)]) - # Assert mock_logger.info.assert_any_call("Resuming download (attempt 2) for 1 ranges.") diff --git a/packages/google-cloud-storage/tests/unit/asyncio/test_stream_multiplexer.py b/packages/google-cloud-storage/tests/unit/asyncio/test_stream_multiplexer.py new file mode 100644 index 000000000000..4bf5bfaf4e3b --- /dev/null +++ b/packages/google-cloud-storage/tests/unit/asyncio/test_stream_multiplexer.py @@ -0,0 +1,503 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for the bidi-read stream multiplexer.

Covers the queue sentinel types, read_id registration/routing, the
background recv loop (routing, broadcast, end/error handling), send-time
generation stamping, generation-gated stream reopen, and close/teardown.
"""

import asyncio
from unittest.mock import AsyncMock

import pytest

from google.cloud import _storage_v2
from google.cloud.storage.asyncio._stream_multiplexer import (
    _DEFAULT_QUEUE_MAX_SIZE,
    _StreamEnd,
    _StreamError,
    _StreamMultiplexer,
)


class TestSentinelTypes:
    """Sentinels must carry enough context for consumers to react."""

    def test_stream_error_stores_exception_and_generation(self):
        exc = ValueError("test")
        error = _StreamError(exc, generation=3)
        assert error.exception is exc
        assert error.generation == 3

    def test_stream_end_is_instantiable(self):
        sentinel = _StreamEnd()
        assert isinstance(sentinel, _StreamEnd)


class TestStreamMultiplexerInit:
    """Constructor wiring and defaults."""

    def test_init_sets_stream_and_defaults(self):
        mock_stream = AsyncMock()
        mux = _StreamMultiplexer(mock_stream)
        assert mux._stream is mock_stream
        assert mux.stream_generation == 0
        assert mux._queues == {}
        assert mux._recv_task is None
        assert mux._queue_max_size == _DEFAULT_QUEUE_MAX_SIZE

    def test_init_custom_queue_size(self):
        mock_stream = AsyncMock()
        mux = _StreamMultiplexer(mock_stream, queue_max_size=50)
        assert mux._queue_max_size == 50


def _make_response(read_id, data=b"data", range_end=False):
    # Build a single-range BidiReadObjectResponse addressed to `read_id`.
    return _storage_v2.BidiReadObjectResponse(
        object_data_ranges=[
            _storage_v2.ObjectRangeData(
                checksummed_data=_storage_v2.ChecksummedData(content=data),
                read_range=_storage_v2.ReadRange(
                    read_id=read_id, read_offset=0, read_length=len(data)
                ),
                range_end=range_end,
            )
        ]
    )


def _make_multi_range_response(read_ids, data=b"data"):
    # Build one response carrying a range per read_id, to exercise
    # per-response queue deduplication in the recv loop.
    ranges = []
    for rid in read_ids:
        ranges.append(
            _storage_v2.ObjectRangeData(
                checksummed_data=_storage_v2.ChecksummedData(content=data),
                read_range=_storage_v2.ReadRange(
                    read_id=rid, read_offset=0, read_length=len(data)
                ),
            )
        )
    return _storage_v2.BidiReadObjectResponse(object_data_ranges=ranges)


class TestRegisterUnregister:
    """register()/unregister() only manage read_id -> queue routing."""

    def _make_multiplexer(self):
        mock_stream = AsyncMock()
        # recv blocks forever so a started recv loop stays idle during the test.
        mock_stream.recv = AsyncMock(side_effect=asyncio.Event().wait)
        return _StreamMultiplexer(mock_stream), mock_stream

    @pytest.mark.asyncio
    async def test_register_returns_bounded_queue(self):
        mux, _ = self._make_multiplexer()
        queue = mux.register({1, 2, 3})
        assert isinstance(queue, asyncio.Queue)
        assert queue.maxsize == _DEFAULT_QUEUE_MAX_SIZE
        mux.unregister({1, 2, 3})

    @pytest.mark.asyncio
    async def test_register_maps_read_ids_to_same_queue(self):
        mux, _ = self._make_multiplexer()
        queue = mux.register({10, 20})
        assert mux._queues[10] is queue
        assert mux._queues[20] is queue
        mux.unregister({10, 20})

    @pytest.mark.asyncio
    async def test_register_does_not_start_recv_loop(self):
        # The recv loop is started explicitly (lazily), not by register().
        mux, _ = self._make_multiplexer()
        assert mux._recv_task is None
        mux.register({1})
        assert mux._recv_task is None
        mux.unregister({1})

    @pytest.mark.asyncio
    async def test_two_registers_get_separate_queues(self):
        mux, _ = self._make_multiplexer()
        q1 = mux.register({1})
        q2 = mux.register({2})
        assert q1 is not q2
        assert mux._queues[1] is q1
        assert mux._queues[2] is q2
        mux.unregister({1, 2})

    @pytest.mark.asyncio
    async def test_unregister_removes_read_ids(self):
        mux, _ = self._make_multiplexer()
        mux.register({1, 2})
        mux.unregister({1})
        assert 1 not in mux._queues
        assert 2 in mux._queues
        mux.unregister({2})

    @pytest.mark.asyncio
    async def test_unregister_all_does_not_stop_recv_loop(self):
        # The recv loop is tied to the stream, not to registrations.
        mux, _ = self._make_multiplexer()
        mux.register({1})
        mux._ensure_recv_loop()
        recv_task = mux._recv_task
        assert recv_task is not None
        mux.unregister({1})
        await asyncio.sleep(0)
        assert not recv_task.cancelled()

    @pytest.mark.asyncio
    async def test_unregister_nonexistent_is_noop(self):
        mux, _ = self._make_multiplexer()
        mux.register({1})
        mux.unregister({999})
        assert 1 in mux._queues
        mux.unregister({1})


class TestRecvLoop:
    """The background recv loop routes, broadcasts, and terminates."""

    @pytest.mark.asyncio
    async def test_routes_response_by_read_id(self):
        mock_stream = AsyncMock()
        resp1 = _make_response(read_id=10, data=b"hello")
        resp2 = _make_response(read_id=20, data=b"world")
        # None signals normal end-of-stream after the two data responses.
        mock_stream.recv = AsyncMock(side_effect=[resp1, resp2, None])

        mux = _StreamMultiplexer(mock_stream)
        q1 = mux.register({10})
        q2 = mux.register({20})
        mux._ensure_recv_loop()

        item1 = await asyncio.wait_for(q1.get(), timeout=1)
        item2 = await asyncio.wait_for(q2.get(), timeout=1)

        assert item1 is resp1
        assert item2 is resp2
        end1 = await asyncio.wait_for(q1.get(), timeout=1)
        end2 = await asyncio.wait_for(q2.get(), timeout=1)
        assert isinstance(end1, _StreamEnd)
        assert isinstance(end2, _StreamEnd)
        mux.unregister({10, 20})

    @pytest.mark.asyncio
    async def test_deduplicates_when_multiple_read_ids_map_to_same_queue(self):
        # Two read_ids in one response map to one queue: one delivery, not two.
        mock_stream = AsyncMock()
        resp = _make_multi_range_response([10, 11])
        mock_stream.recv = AsyncMock(side_effect=[resp, None])

        mux = _StreamMultiplexer(mock_stream)
        queue = mux.register({10, 11})
        mux._ensure_recv_loop()

        item = await asyncio.wait_for(queue.get(), timeout=1)
        assert item is resp
        end = await asyncio.wait_for(queue.get(), timeout=1)
        assert isinstance(end, _StreamEnd)
        mux.unregister({10, 11})

    @pytest.mark.asyncio
    async def test_metadata_only_response_broadcast_to_all(self):
        # A response with no data ranges (e.g. a read handle refresh) cannot
        # be routed by read_id, so it is fanned out to every queue.
        mock_stream = AsyncMock()
        metadata_resp = _storage_v2.BidiReadObjectResponse(
            read_handle=_storage_v2.BidiReadHandle(handle=b"handle")
        )
        mock_stream.recv = AsyncMock(side_effect=[metadata_resp, None])

        mux = _StreamMultiplexer(mock_stream)
        q1 = mux.register({10})
        q2 = mux.register({20})
        mux._ensure_recv_loop()

        item1 = await asyncio.wait_for(q1.get(), timeout=1)
        item2 = await asyncio.wait_for(q2.get(), timeout=1)
        assert item1 is metadata_resp
        assert item2 is metadata_resp
        mux.unregister({10, 20})

    @pytest.mark.asyncio
    async def test_stream_end_sends_sentinel_to_all_queues(self):
        mock_stream = AsyncMock()
        mock_stream.recv = AsyncMock(return_value=None)

        mux = _StreamMultiplexer(mock_stream)
        q1 = mux.register({10})
        q2 = mux.register({20})
        mux._ensure_recv_loop()

        end1 = await asyncio.wait_for(q1.get(), timeout=1)
        end2 = await asyncio.wait_for(q2.get(), timeout=1)
        assert isinstance(end1, _StreamEnd)
        assert isinstance(end2, _StreamEnd)
        mux.unregister({10, 20})

    @pytest.mark.asyncio
    async def test_error_broadcasts_stream_error_to_all_queues(self):
        mock_stream = AsyncMock()
        exc = RuntimeError("stream broke")
        mock_stream.recv = AsyncMock(side_effect=exc)

        mux = _StreamMultiplexer(mock_stream)
        q1 = mux.register({10})
        q2 = mux.register({20})
        mux._ensure_recv_loop()

        # Let the recv loop hit the error and broadcast it.
        await asyncio.sleep(0.05)

        err1 = q1.get_nowait()
        err2 = q2.get_nowait()
        assert isinstance(err1, _StreamError)
        assert err1.exception is exc
        assert err1.generation == 0
        assert isinstance(err2, _StreamError)
        assert err2.exception is exc
        mux.unregister({10, 20})

    @pytest.mark.asyncio
    async def test_error_uses_put_nowait(self):
        mock_stream = AsyncMock()
        exc = RuntimeError("broke")
        mock_stream.recv = AsyncMock(side_effect=exc)

        mux = _StreamMultiplexer(mock_stream, queue_max_size=1)
        queue = mux.register({10})
        queue.put_nowait("filler")
        mux._ensure_recv_loop()

        await asyncio.sleep(0.05)

        # Queue is full (maxsize=1), but _put_error_nowait pops existing items
        # to ensure the error gets recorded.
        assert queue.qsize() == 1
        err = queue.get_nowait()
        assert isinstance(err, _StreamError)
        assert err.exception is exc
        mux.unregister({10})

    @pytest.mark.asyncio
    async def test_unknown_read_id_is_dropped(self):
        # A response for an unregistered read_id is silently dropped; the
        # registered task only ever sees the end-of-stream sentinel.
        mock_stream = AsyncMock()
        resp = _make_response(read_id=999)
        mock_stream.recv = AsyncMock(side_effect=[resp, None])

        mux = _StreamMultiplexer(mock_stream)
        queue = mux.register({10})
        mux._ensure_recv_loop()

        end = await asyncio.wait_for(queue.get(), timeout=1)
        assert isinstance(end, _StreamEnd)
        mux.unregister({10})


class TestSend:
    """send() forwards to the underlying stream and stamps the generation."""

    @pytest.mark.asyncio
    async def test_send_forwards_to_stream(self):
        mock_stream = AsyncMock()
        mock_stream.recv = AsyncMock(side_effect=asyncio.Event().wait)
        mux = _StreamMultiplexer(mock_stream)

        request = _storage_v2.BidiReadObjectRequest(
            read_ranges=[
                _storage_v2.ReadRange(read_id=1, read_offset=0, read_length=10)
            ]
        )
        gen = await mux.send(request)
        mock_stream.send.assert_called_once_with(request)
        assert gen == 0

    @pytest.mark.asyncio
    async def test_send_returns_current_generation(self):
        # The returned generation lets callers detect a reopen that happened
        # after their send.
        mock_stream = AsyncMock()
        mock_stream.recv = AsyncMock(side_effect=asyncio.Event().wait)
        mux = _StreamMultiplexer(mock_stream)
        mux._stream_generation = 5

        request = _storage_v2.BidiReadObjectRequest()
        gen = await mux.send(request)
        assert gen == 5

    @pytest.mark.asyncio
    async def test_send_propagates_exception(self):
        mock_stream = AsyncMock()
        mock_stream.send = AsyncMock(side_effect=RuntimeError("send failed"))
        mock_stream.recv = AsyncMock(side_effect=asyncio.Event().wait)
        mux = _StreamMultiplexer(mock_stream)

        with pytest.raises(RuntimeError, match="send failed"):
            await mux.send(_storage_v2.BidiReadObjectRequest())


class TestReopenStream:
    """Generation-gated reopen: exactly one reopen wins per generation."""

    @pytest.mark.asyncio
    async def test_reopen_bumps_generation_and_replaces_stream(self):
        old_stream = AsyncMock()
        old_stream.recv = AsyncMock(side_effect=asyncio.Event().wait)
        mux = _StreamMultiplexer(old_stream)
        mux.register({1})
        assert mux.stream_generation == 0

        new_stream = AsyncMock()
        new_stream.recv = AsyncMock(side_effect=asyncio.Event().wait)
        factory = AsyncMock(return_value=new_stream)

        await mux.reopen_stream(0, factory)

        assert mux.stream_generation == 1
        assert mux._stream is new_stream
        factory.assert_called_once()
        mux.unregister({1})

    @pytest.mark.asyncio
    async def test_reopen_skips_if_generation_mismatch(self):
        # A stale caller (observed an older generation) must not reopen a
        # stream someone else already replaced.
        mock_stream = AsyncMock()
        mock_stream.recv = AsyncMock(side_effect=asyncio.Event().wait)
        mux = _StreamMultiplexer(mock_stream)
        mux._stream_generation = 5
        mux.register({1})

        factory = AsyncMock()
        await mux.reopen_stream(3, factory)

        assert mux.stream_generation == 5
        factory.assert_not_called()
        mux.unregister({1})

    @pytest.mark.asyncio
    async def test_reopen_broadcasts_error_before_bump(self):
        # The broadcast error carries the OLD generation so consumers can
        # tell it refers to the stream they were reading from.
        old_stream = AsyncMock()
        old_stream.recv = AsyncMock(side_effect=asyncio.Event().wait)
        mux = _StreamMultiplexer(old_stream)
        queue = mux.register({1})

        new_stream = AsyncMock()
        new_stream.recv = AsyncMock(side_effect=asyncio.Event().wait)
        factory = AsyncMock(return_value=new_stream)

        await mux.reopen_stream(0, factory)

        err = queue.get_nowait()
        assert isinstance(err, _StreamError)
        assert err.generation == 0
        mux.unregister({1})

    @pytest.mark.asyncio
    async def test_reopen_starts_new_recv_loop(self):
        old_stream = AsyncMock()
        old_stream.recv = AsyncMock(side_effect=asyncio.Event().wait)
        mux = _StreamMultiplexer(old_stream)
        mux.register({1})
        old_recv_task = mux._recv_task

        new_stream = AsyncMock()
        new_stream.recv = AsyncMock(side_effect=asyncio.Event().wait)
        factory = AsyncMock(return_value=new_stream)

        await mux.reopen_stream(0, factory)

        assert mux._recv_task is not old_recv_task
        assert not mux._recv_task.done()
        mux.unregister({1})

    @pytest.mark.asyncio
    async def test_reopen_closes_old_stream_best_effort(self):
        # A failure while closing the old stream must not abort the reopen.
        old_stream = AsyncMock()
        old_stream.recv = AsyncMock(side_effect=asyncio.Event().wait)
        old_stream.close = AsyncMock(side_effect=RuntimeError("close failed"))
        mux = _StreamMultiplexer(old_stream)
        mux.register({1})

        new_stream = AsyncMock()
        new_stream.recv = AsyncMock(side_effect=asyncio.Event().wait)
        factory = AsyncMock(return_value=new_stream)

        await mux.reopen_stream(0, factory)
        assert mux.stream_generation == 1
        mux.unregister({1})

    @pytest.mark.asyncio
    async def test_concurrent_reopen_only_one_wins(self):
        old_stream = AsyncMock()
        old_stream.recv = AsyncMock(side_effect=asyncio.Event().wait)
        mux = _StreamMultiplexer(old_stream)
        mux.register({1})

        call_count = 0

        async def counting_factory():
            nonlocal call_count
            call_count += 1
            new_stream = AsyncMock()
            new_stream.recv = AsyncMock(side_effect=asyncio.Event().wait)
            return new_stream

        # Both callers observed generation 0; the reopen lock + generation
        # check must collapse them into a single reopen.
        await asyncio.gather(
            mux.reopen_stream(0, counting_factory),
            mux.reopen_stream(0, counting_factory),
        )

        assert call_count == 1
        assert mux.stream_generation == 1
        mux.unregister({1})

    @pytest.mark.asyncio
    async def test_reopen_factory_failure_leaves_generation_unchanged(self):
        """If stream_factory raises, generation is not bumped and recv loop
        is not restarted. The caller's retry manager will re-attempt reopen
        with the same generation, which will succeed because the generation
        check still matches."""
        old_stream = AsyncMock()
        old_stream.recv = AsyncMock(side_effect=asyncio.Event().wait)
        mux = _StreamMultiplexer(old_stream)
        mux.register({1})

        failing_factory = AsyncMock(side_effect=RuntimeError("open failed"))

        with pytest.raises(RuntimeError, match="open failed"):
            await mux.reopen_stream(0, failing_factory)

        # Generation was NOT bumped
        assert mux.stream_generation == 0
        # Recv loop was stopped and not restarted
        assert mux._recv_task is None or mux._recv_task.done()

        # A subsequent reopen with the same generation succeeds
        new_stream = AsyncMock()
        new_stream.recv = AsyncMock(side_effect=asyncio.Event().wait)
        ok_factory = AsyncMock(return_value=new_stream)

        await mux.reopen_stream(0, ok_factory)

        assert mux.stream_generation == 1
        assert mux._stream is new_stream
        assert mux._recv_task is not None and not mux._recv_task.done()
        mux.unregister({1})


class TestClose:
    """close() tears down the recv loop and unblocks waiting consumers."""

    @pytest.mark.asyncio
    async def test_close_cancels_recv_loop(self):
        mock_stream = AsyncMock()
        mock_stream.recv = AsyncMock(side_effect=asyncio.Event().wait)
        mux = _StreamMultiplexer(mock_stream)
        mux.register({1})
        mux._ensure_recv_loop()
        recv_task = mux._recv_task

        await mux.close()
        assert recv_task.cancelled()

    @pytest.mark.asyncio
    async def test_close_broadcasts_terminal_error(self):
        # Consumers blocked on queue.get() must be woken with a terminal
        # error rather than hanging forever.
        mock_stream = AsyncMock()
        mock_stream.recv = AsyncMock(side_effect=asyncio.Event().wait)
        mux = _StreamMultiplexer(mock_stream)
        q1 = mux.register({1})
        q2 = mux.register({2})

        await mux.close()

        err1 = q1.get_nowait()
        err2 = q2.get_nowait()
        assert isinstance(err1, _StreamError)
        assert isinstance(err2, _StreamError)

    @pytest.mark.asyncio
    async def test_close_with_no_tasks_is_noop(self):
        mock_stream = AsyncMock()
        mux = _StreamMultiplexer(mock_stream)
        await mux.close()  # should not raise