diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index bd7d1f55aaee5..08762c215b6ed 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -519,6 +519,7 @@ def __hash__(self):
         "pyspark.tests.test_util",
         "pyspark.tests.test_worker",
         "pyspark.tests.test_stage_sched",
+        "pyspark.tests.test_zero_copy_byte_stream",
         # unittests for upstream projects
         "pyspark.tests.upstream.pyarrow.test_pyarrow_array_cast",
         "pyspark.tests.upstream.pyarrow.test_pyarrow_array_type_inference",
diff --git a/python/pyspark/messages/README.md b/python/pyspark/messages/README.md
new file mode 100644
index 0000000000000..95e5962e326b6
--- /dev/null
+++ b/python/pyspark/messages/README.md
@@ -0,0 +1,23 @@
+## PySpark <-> Spark message interface
+
+This module implements message-based communication between PySpark and Spark.
+It introduces abstraction layers, which handle the receiving and sending of messages.
+Through these abstractions, the underlying data transport channel (Unix domain socket, gRPC, etc.)
+ can be *decoupled from the core PySpark logic*, which processes the data.
+
+Overall, introducing these abstractions allows the same PySpark code to work with
+different underlying data transport channels transparently.
+
+This module defines the following message types:
+
+### Spark -> PySpark
+
+1. Initialization - UDF payload, parameters, ...
+2. Data - Data to invoke the UDF on
+3. Finish - UDF has been invoked on all available data
+
+### PySpark -> Spark
+
+1. Response data
+2. Exceptions
+3. Finish - All processing is done
diff --git a/python/pyspark/messages/__init__.py b/python/pyspark/messages/__init__.py
new file mode 100644
index 0000000000000..ccb7b9323257f
--- /dev/null
+++ b/python/pyspark/messages/__init__.py
@@ -0,0 +1,22 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark.messages.zero_copy_byte_stream import ZeroCopyByteStream
+
+__all__ = [
+    "ZeroCopyByteStream",
+]
diff --git a/python/pyspark/messages/zero_copy_byte_stream.py b/python/pyspark/messages/zero_copy_byte_stream.py
new file mode 100644
index 0000000000000..611b4f928be08
--- /dev/null
+++ b/python/pyspark/messages/zero_copy_byte_stream.py
@@ -0,0 +1,176 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
#

import threading
from typing import Optional
from collections import deque


class ZeroCopyByteStream:
    """
    Accepts chunks of bytes as zero-copy memory views. Implements
    a file-like interface on top of the received chunks.

    read() calls that access bytes from a single chunk are served
    as zero-copy reads. If a read() call crosses chunk boundaries,
    memory copies are required. The latter case is unexpected and only
    implemented for correctness.

    This implementation is thread-safe.
    """

    def __init__(self, initial_view: Optional[memoryview] = None):
        if not isinstance(initial_view, memoryview) and initial_view is not None:
            raise TypeError(
                "Only memoryview and None are allowed as the initial "
                + f"ZeroCopyByteStream view. Received type {type(initial_view)} instead."
            )

        # Chunks queued behind the current one, consumed FIFO.
        self._chunks = deque[memoryview]()
        # Chunk currently being read from; None when no chunk is active.
        self._current_chunk = initial_view
        # Read offset into self._current_chunk.
        self._current_position = 0
        # Set once finish() is called; no further chunks may be added.
        self._eof = False
        # Guards all mutable state; readers wait on it for data or EOF.
        self._condition = threading.Condition()

    def add_next_chunk(self, chunk: memoryview) -> None:
        """
        Adds the next chunk as a read source.

        Chunks can only be added if the stream has not
        been finished before.

        The chunk to be added cannot be None.

        Raises
        ------
        TypeError
            If ``chunk`` is not a memoryview.
        ValueError
            If the stream has already been finished.
        """
        if not isinstance(chunk, memoryview):
            raise TypeError(
                "Only memoryviews can be added to the ZeroCopyByteStreams. "
                + f"Received {type(chunk)} instead."
            )
        with self._condition:
            if self._eof:
                raise ValueError("Cannot add chunk after ZeroCopyByteStream has been finished")
            self._chunks.append(chunk)
            # One new chunk can satisfy (at least) one waiting reader.
            self._condition.notify()

    def finish(self) -> None:
        """
        Marks the stream as ended.

        Idempotent: can be called multiple times.
        """
        with self._condition:
            self._eof = True
            # Wake *all* blocked readers: every one of them must observe EOF.
            # With notify() only a single reader would wake up and any other
            # reader blocked in read() would wait forever.
            self._condition.notify_all()

    @property
    def finished(self) -> bool:
        """
        Returns whether the stream has been marked as finished and was fully
        consumed. If finished() == True, any attempts to read the stream
        will raise an `EOFError`.
        """
        with self._condition:
            # It is finished if: we read all content of the current chunk,
            # there are no remaining chunks, and the input has been marked as done
            return self._current_chunk is None and len(self._chunks) == 0 and self._eof

    def _try_read_bytes(self, size: int) -> Optional[memoryview]:
        """
        Reads up to ``size`` bytes from the current or next available chunk.
        Returns a zero-copy memoryview slice (may be shorter than ``size``
        if the current chunk doesn't have enough data), or None on EOF.

        Blocks until at least some data is available or EOF is reached.

        Internal, assumes to be run inside locked self._condition!
        """
        # Ensure we have a current chunk
        while self._current_chunk is None:
            try:
                self._current_chunk = self._chunks.popleft()
                self._current_position = 0
            except IndexError:
                # No chunks available - check for EOF
                if self._eof:
                    return None
                # Block until data arrives or EOF is signaled
                self._condition.wait()

        remaining = len(self._current_chunk) - self._current_position
        to_read = min(remaining, size)

        # Read slice from current chunk (zero-copy)
        result = self._current_chunk[self._current_position : self._current_position + to_read]
        self._current_position += to_read

        # If entire chunk consumed, clear it for next chunk
        assert self._current_position <= len(self._current_chunk), (
            f"Current position {self._current_position} was unexpectedly "
            + f"larger than max position {len(self._current_chunk)}"
        )
        if self._current_position == len(self._current_chunk):
            self._current_chunk = None
            self._current_position = 0

        return result

    def read(self, size: int) -> memoryview:
        """
        Reads size bytes. If the read fails because the
        stream was exhausted before size bytes could be read
        an `EOFError` will be raised.

        It is required that size >= 0.

        Returns a zero-copy view when the read fits inside a single
        chunk; otherwise the bytes are copied into a fresh buffer.
        """
        if size < 0:
            raise ValueError(
                "ZeroCopyByteStream.read() cannot be called"
                + f" with negative size. Received {size}"
            )

        if size == 0:
            return memoryview(b"")

        with self._condition:
            first = self._try_read_bytes(size)

            # Zero-copy: either EOF or the read fits in a single chunk
            if first is None:
                raise EOFError(
                    f"ZeroCopyByteStream.read() tried to read {size} byte(s), "
                    + "however, the stream was exhausted after reading 0 byte."
                )
            elif len(first) == size:
                return first

            # Slow path: read crosses chunk boundaries (requires copy)
            buf = bytearray(size)
            buf[0 : len(first)] = first
            bytes_copied = len(first)

            while bytes_copied < size:
                chunk = self._try_read_bytes(size - bytes_copied)
                if chunk is None:
                    raise EOFError(
                        f"ZeroCopyByteStream.read() tried to read {size} byte(s), "
                        + "however, the stream was exhausted after reading "
                        + f"{bytes_copied} byte."
                    )
                buf[bytes_copied : bytes_copied + len(chunk)] = chunk
                bytes_copied += len(chunk)

            return memoryview(buf)
# See the License for the specific language governing permissions and
# limitations under the License.
#
import threading
import unittest

from pyspark.messages.zero_copy_byte_stream import ZeroCopyByteStream


class ZeroCopyByteStreamTests(unittest.TestCase):
    """Unit tests for ZeroCopyByteStream: reads, EOF semantics, blocking behavior."""

    # ---- Basic single-chunk reads (zero-copy fast path) ----

    def test_read_exact_chunk(self):
        zcbs = ZeroCopyByteStream(initial_view=memoryview(b"hello"))
        self.assertEqual(bytes(zcbs.read(5)), b"hello")

    def test_read_partial_chunk(self):
        zcbs = ZeroCopyByteStream(initial_view=memoryview(b"hello world"))
        zcbs.finish()
        first = zcbs.read(5)
        second = zcbs.read(6)
        self.assertEqual(bytes(first), b"hello")
        self.assertEqual(bytes(second), b" world")
        # Once drained and finished, further reads must fail.
        with self.assertRaises(EOFError):
            zcbs.read(1)

    def test_read_multiple_chunks_sequentially(self):
        zcbs = ZeroCopyByteStream(initial_view=memoryview(b"aaa"))
        for payload in (b"bbb", b"ccc"):
            zcbs.add_next_chunk(memoryview(payload))

        for expected in (b"aaa", b"bbb", b"ccc"):
            self.assertEqual(bytes(zcbs.read(3)), expected)

    # ---- Cross-boundary reads (slow path with copy) ----

    def test_read_across_two_chunks(self):
        zcbs = ZeroCopyByteStream(initial_view=memoryview(b"aaa"))
        zcbs.add_next_chunk(memoryview(b"bbb"))

        self.assertEqual(bytes(zcbs.read(6)), b"aaabbb")

    def test_read_across_three_chunks(self):
        zcbs = ZeroCopyByteStream(initial_view=memoryview(b"ab"))
        for payload in (b"cd", b"ef"):
            zcbs.add_next_chunk(memoryview(payload))

        self.assertEqual(bytes(zcbs.read(6)), b"abcdef")

    def test_read_partial_then_cross_boundary(self):
        zcbs = ZeroCopyByteStream(initial_view=memoryview(b"aabb"))
        zcbs.add_next_chunk(memoryview(b"ccdd"))

        # Zero-copy read of the first 2 bytes of chunk 1.
        self.assertEqual(bytes(zcbs.read(2)), b"aa")
        # This read straddles the chunk boundary.
        self.assertEqual(bytes(zcbs.read(4)), b"bbcc")
        # The tail of chunk 2 is served zero-copy again.
        self.assertEqual(bytes(zcbs.read(2)), b"dd")

    def test_cross_boundary_read_consumes_full_middle_chunk(self):
        zcbs = ZeroCopyByteStream(initial_view=memoryview(b"aa"))
        for payload in (b"bb", b"cc"):
            zcbs.add_next_chunk(memoryview(payload))

        # Advance 1 byte into the first chunk.
        zcbs.read(1)

        # 5 bytes = 1 left in chunk1 + all of chunk2 + all of chunk3.
        self.assertEqual(bytes(zcbs.read(5)), b"abbcc")

    # ---- EOF handling ----

    def test_eof_throws_eof_error(self):
        zcbs = ZeroCopyByteStream()
        zcbs.finish()

        self.assertTrue(zcbs.finished)
        with self.assertRaises(EOFError):
            zcbs.read(1)

    def test_eof_after_consuming_all_data(self):
        zcbs = ZeroCopyByteStream(initial_view=memoryview(b"data"))
        zcbs.finish()

        # Data is still pending, so the stream is not finished yet.
        self.assertFalse(zcbs.finished)

        self.assertEqual(bytes(zcbs.read(4)), b"data")

        self.assertTrue(zcbs.finished)
        with self.assertRaises(EOFError):
            zcbs.read(1)

    def test_eof_with_out_of_bounds_read(self):
        zcbs = ZeroCopyByteStream(initial_view=memoryview(b"data"))
        zcbs.finish()

        self.assertEqual(bytes(zcbs.read(3)), b"dat")

        # One byte left, so the stream is not yet finished…
        self.assertFalse(zcbs.finished)
        # …but asking for 2 bytes runs off the end.
        with self.assertRaises(EOFError):
            zcbs.read(2)

    def test_eof_during_cross_boundary_read(self):
        """EOF in the middle of a cross-boundary read raises EOFError."""
        zcbs = ZeroCopyByteStream(initial_view=memoryview(b"ab"))
        zcbs.add_next_chunk(memoryview(b"cd"))
        zcbs.finish()

        # Only 4 bytes exist; requesting 5 exhausts the stream mid-read.
        with self.assertRaises(EOFError):
            zcbs.read(5)

    def test_finished_property(self):
        zcbs = ZeroCopyByteStream(initial_view=memoryview(b"x"))
        self.assertFalse(zcbs.finished)

        zcbs.read(1)
        self.assertFalse(zcbs.finished)  # not yet marked EOF

        zcbs.finish()
        self.assertTrue(zcbs.finished)

    # ---- Threading / blocking behavior ----

    def test_read_blocks_until_chunk_available(self):
        zcbs = ZeroCopyByteStream()
        captured = None

        def consume():
            nonlocal captured
            captured = zcbs.read(3)

        worker = threading.Thread(target=consume)
        worker.start()

        # With no data queued the consumer must stay parked.
        worker.join(timeout=2)
        self.assertTrue(worker.is_alive(), "Reader should be blocked waiting for data")

        zcbs.add_next_chunk(memoryview(b"abc"))
        worker.join(timeout=2)
        self.assertFalse(worker.is_alive(), "Reader should have unblocked")
        self.assertEqual(bytes(captured), b"abc")

    def test_read_blocks_until_eof(self):
        zcbs = ZeroCopyByteStream()
        saw_eof = False
        started = threading.Event()

        def consume():
            nonlocal saw_eof
            started.set()
            try:
                zcbs.read(1)
            except EOFError:
                saw_eof = True

        worker = threading.Thread(target=consume)
        worker.start()
        started.wait()

        # The consumer must stay blocked while no data and no EOF arrive.
        worker.join(timeout=2)
        self.assertTrue(worker.is_alive())

        zcbs.finish()
        worker.join(timeout=2)
        self.assertFalse(worker.is_alive())
        self.assertTrue(saw_eof)

    def test_cross_boundary_read_blocks_for_next_chunk(self):
        """A cross-boundary read parks until the missing chunk shows up."""
        zcbs = ZeroCopyByteStream(initial_view=memoryview(b"aa"))
        captured = None

        def consume():
            nonlocal captured
            captured = zcbs.read(4)

        worker = threading.Thread(target=consume)
        worker.start()

        # "aa" has been consumed; 2 more bytes are still outstanding.
        worker.join(timeout=2)
        self.assertTrue(worker.is_alive(), "Reader should block waiting for more data")

        zcbs.add_next_chunk(memoryview(b"bb"))
        worker.join(timeout=2)
        self.assertFalse(worker.is_alive())
        self.assertEqual(bytes(captured), b"aabb")

    def test_read_of_zero_bytes_succeeds_without_data(self):
        """A zero-byte read returns immediately even on an empty stream."""
        zcbs = ZeroCopyByteStream(initial_view=None)

        self.assertEqual(zcbs.read(0), memoryview(b""))

    def test_negative_read_throws_value_error(self):
        zcbs = ZeroCopyByteStream(initial_view=None)

        with self.assertRaises(ValueError):
            zcbs.read(-1)

    # ---- add_next_chunk assertions ----

    def test_add_none_chunk_raises(self):
        zcbs = ZeroCopyByteStream()
        with self.assertRaises(TypeError):
            zcbs.add_next_chunk(None)

    def test_add_chunk_after_finish_raises(self):
        zcbs = ZeroCopyByteStream()
        zcbs.finish()
        with self.assertRaises(ValueError):
            zcbs.add_next_chunk(memoryview(b"data"))

    # ---- Initial view ----

    def test_no_initial_view(self):
        zcbs = ZeroCopyByteStream()
        zcbs.add_next_chunk(memoryview(b"hello"))
        self.assertEqual(bytes(zcbs.read(5)), b"hello")

    def test_initial_view_none(self):
        zcbs = ZeroCopyByteStream(initial_view=None)
        zcbs.add_next_chunk(memoryview(b"test"))
        self.assertEqual(bytes(zcbs.read(4)), b"test")

    def test_invalid_initial_view(self):
        with self.assertRaises(TypeError):
            ZeroCopyByteStream(initial_view=5)


if __name__ == "__main__":
    from pyspark.testing import main

    main()