From 774d9f3bd5a062ee55032fd03977bf38972b3ee7 Mon Sep 17 00:00:00 2001 From: Toby Cole Date: Thu, 5 Mar 2026 15:06:05 +0000 Subject: [PATCH 1/5] [python] Add consumer management for streaming progress - Add Consumer dataclass for tracking consumption progress - Add ConsumerManager for persisting/loading/expiring consumers - Add unit tests for consumer operations Co-Authored-By: Claude Opus 4.6 --- paimon-python/pypaimon/consumer/__init__.py | 18 ++ paimon-python/pypaimon/consumer/consumer.py | 62 ++++++ .../pypaimon/consumer/consumer_manager.py | 131 ++++++++++++ paimon-python/pypaimon/tests/consumer_test.py | 195 ++++++++++++++++++ 4 files changed, 406 insertions(+) create mode 100644 paimon-python/pypaimon/consumer/__init__.py create mode 100644 paimon-python/pypaimon/consumer/consumer.py create mode 100644 paimon-python/pypaimon/consumer/consumer_manager.py create mode 100644 paimon-python/pypaimon/tests/consumer_test.py diff --git a/paimon-python/pypaimon/consumer/__init__.py b/paimon-python/pypaimon/consumer/__init__.py new file mode 100644 index 000000000000..df4788f89404 --- /dev/null +++ b/paimon-python/pypaimon/consumer/__init__.py @@ -0,0 +1,18 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ +"""Consumer management for tracking streaming read progress.""" diff --git a/paimon-python/pypaimon/consumer/consumer.py b/paimon-python/pypaimon/consumer/consumer.py new file mode 100644 index 000000000000..ef6469fd1952 --- /dev/null +++ b/paimon-python/pypaimon/consumer/consumer.py @@ -0,0 +1,62 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ +""" +Consumer class for tracking streaming read progress. + +A Consumer contains the next snapshot ID to be read. This is persisted to +the table's consumer directory to track progress across restarts and to +inform snapshot expiration which snapshots are still needed. +""" + +import json +from dataclasses import dataclass + + +@dataclass +class Consumer: + """ + Consumer which contains next snapshot. + + This is the Python equivalent of Java's Consumer class. It stores the + next snapshot ID that should be read by this consumer. 
+ """ + + next_snapshot: int + + def to_json(self) -> str: + """ + Serialize the consumer to JSON. + + Returns: + JSON string with nextSnapshot field + """ + return json.dumps({"nextSnapshot": self.next_snapshot}) + + @staticmethod + def from_json(json_str: str) -> 'Consumer': + """ + Deserialize a consumer from JSON. + + Args: + json_str: JSON string with nextSnapshot field + + Returns: + Consumer instance + """ + data = json.loads(json_str) + return Consumer(next_snapshot=data["nextSnapshot"]) diff --git a/paimon-python/pypaimon/consumer/consumer_manager.py b/paimon-python/pypaimon/consumer/consumer_manager.py new file mode 100644 index 000000000000..82167519d0b3 --- /dev/null +++ b/paimon-python/pypaimon/consumer/consumer_manager.py @@ -0,0 +1,131 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ +""" +ConsumerManager for managing consumer progress. + +A ConsumerManager reads and writes consumer state files in the table's +consumer directory. This enables tracking streaming read progress across +restarts and informs snapshot expiration which snapshots are still needed. 
+""" + +import os +from typing import Optional + +from pypaimon.consumer.consumer import Consumer + + +class ConsumerManager: + """ + Manages consumer state for streaming reads. + + Consumer state is persisted to {table_path}/consumer/consumer-{consumer_id}. + This is the Python equivalent of Java's ConsumerManager class. + """ + + CONSUMER_PREFIX = "consumer-" + + def __init__(self, file_io, table_path: str): + """ + Create a ConsumerManager. + + Args: + file_io: FileIO instance for reading/writing files + table_path: Root path of the table + """ + self._file_io = file_io + self._table_path = table_path + + @staticmethod + def _validate_consumer_id(consumer_id: str) -> None: + """ + Validate consumer ID to prevent path traversal attacks. + + Args: + consumer_id: The consumer identifier to validate + + Raises: + ValueError: If consumer_id contains path separators or is empty + """ + if not consumer_id: + raise ValueError("consumer_id cannot be empty") + if '/' in consumer_id or '\\' in consumer_id: + raise ValueError( + f"consumer_id cannot contain path separators: {consumer_id}" + ) + if consumer_id in ('.', '..'): + raise ValueError( + f"consumer_id cannot be a relative path component: {consumer_id}" + ) + + def _consumer_path(self, consumer_id: str) -> str: + """ + Get the path to a consumer file. + + Args: + consumer_id: The consumer identifier + + Returns: + Path to the consumer file: {table_path}/consumer/consumer-{id} + + Raises: + ValueError: If consumer_id is invalid + """ + self._validate_consumer_id(consumer_id) + return os.path.join( + self._table_path, + "consumer", + f"{self.CONSUMER_PREFIX}{consumer_id}" + ) + + def consumer(self, consumer_id: str) -> Optional[Consumer]: + """ + Get the consumer state for the given consumer ID. 
+ + Args: + consumer_id: The consumer identifier + + Returns: + Consumer instance if exists, None otherwise + """ + path = self._consumer_path(consumer_id) + if not self._file_io.exists(path): + return None + + json_str = self._file_io.read_file_utf8(path) + return Consumer.from_json(json_str) + + def reset_consumer(self, consumer_id: str, consumer: Consumer) -> None: + """ + Write or update consumer state. + + Args: + consumer_id: The consumer identifier + consumer: The consumer state to persist + """ + path = self._consumer_path(consumer_id) + self._file_io.overwrite_file_utf8(path, consumer.to_json()) + + def delete_consumer(self, consumer_id: str) -> None: + """ + Delete a consumer. + + Args: + consumer_id: The consumer identifier to delete + """ + path = self._consumer_path(consumer_id) + self._file_io.delete_quietly(path) diff --git a/paimon-python/pypaimon/tests/consumer_test.py b/paimon-python/pypaimon/tests/consumer_test.py new file mode 100644 index 000000000000..4ad7eaa2bdc8 --- /dev/null +++ b/paimon-python/pypaimon/tests/consumer_test.py @@ -0,0 +1,195 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+################################################################################ +""" +Tests for Consumer and ConsumerManager. +TDD: These tests are written first, before the implementation. +""" + +import json +import os +import shutil +import tempfile +import unittest +from unittest.mock import Mock + +from pypaimon.consumer.consumer import Consumer +from pypaimon.consumer.consumer_manager import ConsumerManager + + +class ConsumerTest(unittest.TestCase): + """Tests for Consumer data class.""" + + def test_consumer_creation(self): + """Consumer should store next_snapshot value.""" + consumer = Consumer(next_snapshot=42) + self.assertEqual(consumer.next_snapshot, 42) + + def test_consumer_to_json(self): + """Consumer should serialize to JSON with nextSnapshot field.""" + consumer = Consumer(next_snapshot=42) + json_str = consumer.to_json() + + # Parse and verify + data = json.loads(json_str) + self.assertEqual(data["nextSnapshot"], 42) + + def test_consumer_from_json(self): + """Consumer should deserialize from JSON.""" + json_str = '{"nextSnapshot": 42}' + consumer = Consumer.from_json(json_str) + + self.assertEqual(consumer.next_snapshot, 42) + + def test_consumer_from_json_ignores_unknown_fields(self): + """Consumer should ignore unknown fields in JSON.""" + json_str = '{"nextSnapshot": 42, "unknownField": "value"}' + consumer = Consumer.from_json(json_str) + + self.assertEqual(consumer.next_snapshot, 42) + + def test_consumer_roundtrip(self): + """Consumer should survive JSON roundtrip.""" + original = Consumer(next_snapshot=12345) + json_str = original.to_json() + restored = Consumer.from_json(json_str) + + self.assertEqual(restored.next_snapshot, original.next_snapshot) + + +class ConsumerManagerTest(unittest.TestCase): + """Tests for ConsumerManager.""" + + def setUp(self): + """Create a temporary directory for testing.""" + self.tempdir = tempfile.mkdtemp() + self.table_path = os.path.join(self.tempdir, "test_table") + os.makedirs(self.table_path) + + # 
Create mock file_io + self.file_io = Mock() + self._setup_file_io_mock() + + def tearDown(self): + """Clean up temporary directory.""" + shutil.rmtree(self.tempdir, ignore_errors=True) + + def _setup_file_io_mock(self): + """Setup file_io mock to use real filesystem.""" + def read_file_utf8(path): + with open(path, 'r') as f: + return f.read() + + def overwrite_file_utf8(path, content): + os.makedirs(os.path.dirname(path), exist_ok=True) + with open(path, 'w') as f: + f.write(content) + + def exists(path): + return os.path.exists(path) + + def delete_quietly(path): + if os.path.exists(path): + os.remove(path) + + self.file_io.read_file_utf8 = read_file_utf8 + self.file_io.overwrite_file_utf8 = overwrite_file_utf8 + self.file_io.exists = exists + self.file_io.delete_quietly = delete_quietly + + def test_consumer_manager_reset_consumer(self): + """reset_consumer should write consumer state to file.""" + manager = ConsumerManager(self.file_io, self.table_path) + consumer = Consumer(next_snapshot=42) + + manager.reset_consumer("my-consumer", consumer) + + # Verify file exists + consumer_file = os.path.join(self.table_path, "consumer", "consumer-my-consumer") + self.assertTrue(os.path.exists(consumer_file)) + + # Verify content + with open(consumer_file, 'r') as f: + content = f.read() + data = json.loads(content) + self.assertEqual(data["nextSnapshot"], 42) + + def test_consumer_manager_get_consumer(self): + """consumer() should read consumer state from file.""" + manager = ConsumerManager(self.file_io, self.table_path) + + # Write consumer file directly + consumer_dir = os.path.join(self.table_path, "consumer") + os.makedirs(consumer_dir, exist_ok=True) + consumer_file = os.path.join(consumer_dir, "consumer-my-consumer") + with open(consumer_file, 'w') as f: + f.write('{"nextSnapshot": 42}') + + # Read via manager + consumer = manager.consumer("my-consumer") + + self.assertIsNotNone(consumer) + self.assertEqual(consumer.next_snapshot, 42) + + def 
test_consumer_manager_get_nonexistent_consumer(self): + """consumer() should return None for non-existent consumer.""" + manager = ConsumerManager(self.file_io, self.table_path) + + consumer = manager.consumer("nonexistent") + + self.assertIsNone(consumer) + + def test_consumer_manager_delete_consumer(self): + """delete_consumer should remove consumer file.""" + manager = ConsumerManager(self.file_io, self.table_path) + + # Create consumer first + manager.reset_consumer("my-consumer", Consumer(next_snapshot=42)) + consumer_file = os.path.join(self.table_path, "consumer", "consumer-my-consumer") + self.assertTrue(os.path.exists(consumer_file)) + + # Delete + manager.delete_consumer("my-consumer") + + self.assertFalse(os.path.exists(consumer_file)) + + def test_consumer_manager_update_consumer(self): + """reset_consumer should update existing consumer.""" + manager = ConsumerManager(self.file_io, self.table_path) + + # Create initial consumer + manager.reset_consumer("my-consumer", Consumer(next_snapshot=42)) + + # Update + manager.reset_consumer("my-consumer", Consumer(next_snapshot=100)) + + # Verify updated + consumer = manager.consumer("my-consumer") + self.assertEqual(consumer.next_snapshot, 100) + + def test_consumer_path(self): + """Consumer files should be in {table_path}/consumer/consumer-{id}.""" + manager = ConsumerManager(self.file_io, self.table_path) + + path = manager._consumer_path("test-id") + + expected = os.path.join(self.table_path, "consumer", "consumer-test-id") + self.assertEqual(path, expected) + + +if __name__ == '__main__': + unittest.main() From 6eb12271057b0b916483b3bb1449682e0cab8466 Mon Sep 17 00:00:00 2001 From: Toby Cole Date: Fri, 6 Mar 2026 17:59:31 +0000 Subject: [PATCH 2/5] [python] Trim consumer docstrings, add validation tests, fix path consistency Align consumer module with the one-liner docstring style used across the rest of the streaming PR stack. 
Replace os.path.join with f-string path construction for consistency with paimon-python conventions. Add tests for _validate_consumer_id rejection cases. Co-Authored-By: Claude Opus 4.6 --- paimon-python/pypaimon/consumer/consumer.py | 32 +------- .../pypaimon/consumer/consumer_manager.py | 76 +++---------------- paimon-python/pypaimon/tests/consumer_test.py | 29 ++++--- 3 files changed, 32 insertions(+), 105 deletions(-) diff --git a/paimon-python/pypaimon/consumer/consumer.py b/paimon-python/pypaimon/consumer/consumer.py index ef6469fd1952..8cc913f0c650 100644 --- a/paimon-python/pypaimon/consumer/consumer.py +++ b/paimon-python/pypaimon/consumer/consumer.py @@ -15,13 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. ################################################################################ -""" -Consumer class for tracking streaming read progress. - -A Consumer contains the next snapshot ID to be read. This is persisted to -the table's consumer directory to track progress across restarts and to -inform snapshot expiration which snapshots are still needed. -""" +"""Consumer dataclass for streaming read progress.""" import json from dataclasses import dataclass @@ -29,34 +23,16 @@ @dataclass class Consumer: - """ - Consumer which contains next snapshot. - - This is the Python equivalent of Java's Consumer class. It stores the - next snapshot ID that should be read by this consumer. - """ + """Consumer which contains the next snapshot to be read.""" next_snapshot: int def to_json(self) -> str: - """ - Serialize the consumer to JSON. - - Returns: - JSON string with nextSnapshot field - """ + """Serialize to JSON.""" return json.dumps({"nextSnapshot": self.next_snapshot}) @staticmethod def from_json(json_str: str) -> 'Consumer': - """ - Deserialize a consumer from JSON. 
- - Args: - json_str: JSON string with nextSnapshot field - - Returns: - Consumer instance - """ + """Deserialize from JSON.""" data = json.loads(json_str) return Consumer(next_snapshot=data["nextSnapshot"]) diff --git a/paimon-python/pypaimon/consumer/consumer_manager.py b/paimon-python/pypaimon/consumer/consumer_manager.py index 82167519d0b3..4edec70d18ff 100644 --- a/paimon-python/pypaimon/consumer/consumer_manager.py +++ b/paimon-python/pypaimon/consumer/consumer_manager.py @@ -15,52 +15,25 @@ # See the License for the specific language governing permissions and # limitations under the License. ################################################################################ -""" -ConsumerManager for managing consumer progress. +"""ConsumerManager for persisting streaming read progress.""" -A ConsumerManager reads and writes consumer state files in the table's -consumer directory. This enables tracking streaming read progress across -restarts and informs snapshot expiration which snapshots are still needed. -""" - -import os from typing import Optional from pypaimon.consumer.consumer import Consumer class ConsumerManager: - """ - Manages consumer state for streaming reads. - - Consumer state is persisted to {table_path}/consumer/consumer-{consumer_id}. - This is the Python equivalent of Java's ConsumerManager class. - """ + """Manages consumer state stored at {table_path}/consumer/consumer-{id}.""" CONSUMER_PREFIX = "consumer-" def __init__(self, file_io, table_path: str): - """ - Create a ConsumerManager. - - Args: - file_io: FileIO instance for reading/writing files - table_path: Root path of the table - """ self._file_io = file_io self._table_path = table_path @staticmethod def _validate_consumer_id(consumer_id: str) -> None: - """ - Validate consumer ID to prevent path traversal attacks. 
- - Args: - consumer_id: The consumer identifier to validate - - Raises: - ValueError: If consumer_id contains path separators or is empty - """ + """Validate consumer_id to prevent path traversal.""" if not consumer_id: raise ValueError("consumer_id cannot be empty") if '/' in consumer_id or '\\' in consumer_id: @@ -73,35 +46,15 @@ def _validate_consumer_id(consumer_id: str) -> None: ) def _consumer_path(self, consumer_id: str) -> str: - """ - Get the path to a consumer file. - - Args: - consumer_id: The consumer identifier - - Returns: - Path to the consumer file: {table_path}/consumer/consumer-{id} - - Raises: - ValueError: If consumer_id is invalid - """ + """Return the path to a consumer file.""" self._validate_consumer_id(consumer_id) - return os.path.join( - self._table_path, - "consumer", + return ( + f"{self._table_path}/consumer/" f"{self.CONSUMER_PREFIX}{consumer_id}" ) def consumer(self, consumer_id: str) -> Optional[Consumer]: - """ - Get the consumer state for the given consumer ID. - - Args: - consumer_id: The consumer identifier - - Returns: - Consumer instance if exists, None otherwise - """ + """Get consumer state, or None if not found.""" path = self._consumer_path(consumer_id) if not self._file_io.exists(path): return None @@ -110,22 +63,11 @@ def consumer(self, consumer_id: str) -> Optional[Consumer]: return Consumer.from_json(json_str) def reset_consumer(self, consumer_id: str, consumer: Consumer) -> None: - """ - Write or update consumer state. - - Args: - consumer_id: The consumer identifier - consumer: The consumer state to persist - """ + """Write or update consumer state.""" path = self._consumer_path(consumer_id) self._file_io.overwrite_file_utf8(path, consumer.to_json()) def delete_consumer(self, consumer_id: str) -> None: - """ - Delete a consumer. 
- - Args: - consumer_id: The consumer identifier to delete - """ + """Delete a consumer.""" path = self._consumer_path(consumer_id) self._file_io.delete_quietly(path) diff --git a/paimon-python/pypaimon/tests/consumer_test.py b/paimon-python/pypaimon/tests/consumer_test.py index 4ad7eaa2bdc8..b6a77289c60b 100644 --- a/paimon-python/pypaimon/tests/consumer_test.py +++ b/paimon-python/pypaimon/tests/consumer_test.py @@ -15,10 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. ################################################################################ -""" -Tests for Consumer and ConsumerManager. -TDD: These tests are written first, before the implementation. -""" +"""Tests for Consumer and ConsumerManager.""" import json import os @@ -34,11 +31,6 @@ class ConsumerTest(unittest.TestCase): """Tests for Consumer data class.""" - def test_consumer_creation(self): - """Consumer should store next_snapshot value.""" - consumer = Consumer(next_snapshot=42) - self.assertEqual(consumer.next_snapshot, 42) - def test_consumer_to_json(self): """Consumer should serialize to JSON with nextSnapshot field.""" consumer = Consumer(next_snapshot=42) @@ -187,9 +179,26 @@ def test_consumer_path(self): path = manager._consumer_path("test-id") - expected = os.path.join(self.table_path, "consumer", "consumer-test-id") + expected = f"{self.table_path}/consumer/consumer-test-id" self.assertEqual(path, expected) + def test_validate_rejects_empty(self): + manager = ConsumerManager(self.file_io, self.table_path) + with self.assertRaises(ValueError): + manager._consumer_path("") + + def test_validate_rejects_path_separators(self): + manager = ConsumerManager(self.file_io, self.table_path) + for bad_id in ("foo/bar", "foo\\bar"): + with self.assertRaises(ValueError, msg=bad_id): + manager._consumer_path(bad_id) + + def test_validate_rejects_relative_components(self): + manager = ConsumerManager(self.file_io, self.table_path) + for 
bad_id in (".", ".."): + with self.assertRaises(ValueError, msg=bad_id): + manager._consumer_path(bad_id) + if __name__ == '__main__': unittest.main() From d2e5cdcc3b5ac1890ef43c5b9169cb01f2a19f14 Mon Sep 17 00:00:00 2001 From: Toby Cole Date: Wed, 11 Mar 2026 15:39:26 +0000 Subject: [PATCH 3/5] [python] Add docstring to abstract method to fix flake8 E704 Co-Authored-By: Claude Opus 4.6 --- paimon-python/pypaimon/read/scanner/follow_up_scanner.py | 1 + 1 file changed, 1 insertion(+) diff --git a/paimon-python/pypaimon/read/scanner/follow_up_scanner.py b/paimon-python/pypaimon/read/scanner/follow_up_scanner.py index cace582298ca..12657919ab65 100644 --- a/paimon-python/pypaimon/read/scanner/follow_up_scanner.py +++ b/paimon-python/pypaimon/read/scanner/follow_up_scanner.py @@ -27,4 +27,5 @@ class FollowUpScanner(ABC): @abstractmethod def should_scan(self, snapshot: Snapshot) -> bool: + """Return True if the given snapshot should be scanned.""" ... From 7208d643dd4c4d28ac02f9126639322307fefbe0 Mon Sep 17 00:00:00 2001 From: Toby Cole Date: Wed, 11 Mar 2026 17:15:11 +0000 Subject: [PATCH 4/5] Revert "[python] Add docstring to abstract method to fix flake8 E704" This reverts commit d2e5cdcc3b5ac1890ef43c5b9169cb01f2a19f14. --- paimon-python/pypaimon/read/scanner/follow_up_scanner.py | 1 - 1 file changed, 1 deletion(-) diff --git a/paimon-python/pypaimon/read/scanner/follow_up_scanner.py b/paimon-python/pypaimon/read/scanner/follow_up_scanner.py index 12657919ab65..cace582298ca 100644 --- a/paimon-python/pypaimon/read/scanner/follow_up_scanner.py +++ b/paimon-python/pypaimon/read/scanner/follow_up_scanner.py @@ -27,5 +27,4 @@ class FollowUpScanner(ABC): @abstractmethod def should_scan(self, snapshot: Snapshot) -> bool: - """Return True if the given snapshot should be scanned.""" ... 
From 834c885f28e22e6d814190b9a4e4599e4c444601 Mon Sep 17 00:00:00 2001 From: Toby Cole Date: Wed, 11 Mar 2026 18:14:27 +0000 Subject: [PATCH 5/5] [python] Add streaming read with AsyncStreamingTableScan, StreamReadBuilder, and acceptance tests Add core streaming read support for Paimon Python: - AsyncStreamingTableScan with async/sync iterators, prefetching, lookahead, and diff-based catch-up - StreamReadBuilder for configuring streaming reads with filter, projection, bucket sharding - IncrementalDiffScanner acceptance tests verifying diff vs delta equivalence - Streaming docs section in python-api.md - Table integration: new_stream_read_builder() on FileStoreTable, FormatTable, IcebergTable Co-Authored-By: Claude Opus 4.6 --- docs/content/pypaimon/python-api.md | 137 ++++ paimon-python/pypaimon/acceptance/__init__.py | 23 + .../incremental_diff_acceptance_test.py | 322 ++++++++ .../pypaimon/common/options/core_options.py | 21 + .../pypaimon/read/scanner/file_scanner.py | 2 - .../pypaimon/read/stream_read_builder.py | 142 ++++ .../pypaimon/read/streaming_table_scan.py | 391 ++++++++++ .../pypaimon/snapshot/snapshot_manager.py | 69 +- .../pypaimon/table/file_store_table.py | 25 +- .../pypaimon/table/format/format_table.py | 6 +- .../pypaimon/table/iceberg/iceberg_table.py | 5 + paimon-python/pypaimon/table/table.py | 5 + .../tests/stream_read_builder_test.py | 170 +++++ .../tests/streaming_table_scan_test.py | 687 ++++++++++++++++++ 14 files changed, 1990 insertions(+), 15 deletions(-) create mode 100644 paimon-python/pypaimon/acceptance/__init__.py create mode 100644 paimon-python/pypaimon/acceptance/incremental_diff_acceptance_test.py create mode 100644 paimon-python/pypaimon/read/stream_read_builder.py create mode 100644 paimon-python/pypaimon/read/streaming_table_scan.py create mode 100644 paimon-python/pypaimon/tests/stream_read_builder_test.py create mode 100644 paimon-python/pypaimon/tests/streaming_table_scan_test.py diff --git 
a/docs/content/pypaimon/python-api.md b/docs/content/pypaimon/python-api.md index 3e9041ed5a03..a84ef9b1c588 100644 --- a/docs/content/pypaimon/python-api.md +++ b/docs/content/pypaimon/python-api.md @@ -551,6 +551,140 @@ table.rollback_to('v3') # tag name The `rollback_to` method accepts either an `int` (snapshot ID) or a `str` (tag name) and automatically dispatches to the appropriate rollback logic. +## Streaming Read + +Streaming reads allow you to continuously read new data as it arrives in a Paimon table. This is useful for building +real-time data pipelines and ETL jobs. + +### Basic Streaming Read + +Use `StreamReadBuilder` to create a streaming scan that continuously polls for new snapshots: + +```python +table = catalog.get_table('database_name.table_name') + +# Create streaming read builder +stream_builder = table.new_stream_read_builder() +stream_builder.with_poll_interval_ms(1000) # Poll every 1 second + +# Create streaming scan and table read +scan = stream_builder.new_streaming_scan() +table_read = stream_builder.new_read() + +# Async streaming (recommended for ETL pipelines) +import asyncio + +async def process_stream(): + async for plan in scan.stream(): + for split in plan.splits(): + arrow_batch = table_read.to_arrow([split]) + # Process the data + print(f"Received {arrow_batch.num_rows} rows") + +asyncio.run(process_stream()) +``` + +### Synchronous Streaming + +For simpler use cases, you can use the synchronous wrapper: + +```python +# Synchronous streaming +for plan in scan.stream_sync(): + arrow_table = table_read.to_arrow(plan.splits()) + process(arrow_table) +``` + +### Manual Position Control + +You can directly read and set the scan position via `next_snapshot_id`: + +```python +# Save current position +saved_position = scan.next_snapshot_id + +# Later, restore position +scan.next_snapshot_id = saved_position + +# Or start from a specific snapshot +scan.next_snapshot_id = 42 +``` + +### Filtering Streaming Data + +You can apply predicates 
and projections to streaming reads: + +```python +stream_builder = table.new_stream_read_builder() + +# Build predicate +predicate_builder = stream_builder.new_predicate_builder() +predicate = predicate_builder.greater_than('timestamp', 1704067200000) + +# Apply filter and projection +stream_builder.with_filter(predicate) +stream_builder.with_projection(['id', 'name', 'timestamp']) + +scan = stream_builder.new_streaming_scan() +``` + +Key points about streaming reads: + +- **Poll Interval**: Controls how often to check for new snapshots (default: 1000ms) +- **Initial Scan**: First iteration returns all existing data, subsequent iterations return only new data +- **Commit Types**: By default, only APPEND commits are processed; COMPACT and OVERWRITE are skipped + +### Parallel Consumption + +For high-throughput streaming, you can run multiple consumers in parallel, each reading a disjoint subset of buckets. +This is similar to Kafka consumer groups. + +**Using `with_buckets()` for explicit bucket assignment**: + +```python +# Consumer 0 reads buckets 0, 1, 2 +stream_builder.with_buckets([0, 1, 2]) + +# Consumer 1 reads buckets 3, 4, 5 +stream_builder.with_buckets([3, 4, 5]) +``` + +**Using `with_bucket_filter()` for custom filtering**: + +```python +# Read only even buckets +stream_builder.with_bucket_filter(lambda b: b % 2 == 0) +``` + +### Row Kind Support + +For changelog streams, you can include the row kind to distinguish between inserts, updates, and deletes: + +```python +stream_builder = table.new_stream_read_builder() +stream_builder.with_include_row_kind(True) + +scan = stream_builder.new_streaming_scan() +table_read = stream_builder.new_read() + +async for plan in scan.stream(): + arrow_table = table_read.to_arrow(plan.splits()) + for row in arrow_table.to_pylist(): + row_kind = row['_row_kind'] # +I, -U, +U, or -D + if row_kind == '+I': + handle_insert(row) + elif row_kind == '-D': + handle_delete(row) + elif row_kind in ('-U', '+U'): + handle_update(row) 
+``` + +Row kind values: +- `+I`: Insert +- `-U`: Update before (old value) +- `+U`: Update after (new value) +- `-D`: Delete + ## Data Types | Python Native Type | PyArrow Type | Paimon Type | @@ -624,3 +758,6 @@ The following shows the supported features of Python Paimon compared to Java Pai - Reading and writing blob data - `with_shard` feature - Rollback feature + - Streaming reads + - Parallel consumption with bucket filtering + - Row kind support for changelog streams diff --git a/paimon-python/pypaimon/acceptance/__init__.py b/paimon-python/pypaimon/acceptance/__init__.py new file mode 100644 index 000000000000..d8cea1afd7bf --- /dev/null +++ b/paimon-python/pypaimon/acceptance/__init__.py @@ -0,0 +1,23 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ +""" +Acceptance tests for pypaimon. + +These tests use real file I/O with local temp filesystem to verify +end-to-end behavior, as opposed to unit tests which use mocks. 
+""" diff --git a/paimon-python/pypaimon/acceptance/incremental_diff_acceptance_test.py b/paimon-python/pypaimon/acceptance/incremental_diff_acceptance_test.py new file mode 100644 index 000000000000..d2ec5021ad10 --- /dev/null +++ b/paimon-python/pypaimon/acceptance/incremental_diff_acceptance_test.py @@ -0,0 +1,322 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ +""" +Acceptance tests for IncrementalDiffScanner. + +These tests verify that the diff approach (reading 2 base_manifest_lists) +returns the same data as the delta approach (reading N delta_manifest_lists). + +Uses real file I/O with local temp filesystem. 
+""" + +import asyncio +import os +import shutil +import tempfile +import unittest + +import pyarrow as pa + +from pypaimon import CatalogFactory, Schema +from pypaimon.manifest.manifest_file_manager import ManifestFileManager +from pypaimon.manifest.manifest_list_manager import ManifestListManager +from pypaimon.read.scanner.append_table_split_generator import \ + AppendTableSplitGenerator +from pypaimon.read.scanner.incremental_diff_scanner import \ + IncrementalDiffScanner +from pypaimon.read.streaming_table_scan import AsyncStreamingTableScan +from pypaimon.snapshot.snapshot_manager import SnapshotManager + + +class IncrementalDiffAcceptanceTest(unittest.TestCase): + """Acceptance tests for diff vs delta equivalence with real data.""" + + @classmethod + def setUpClass(cls): + cls.tempdir = tempfile.mkdtemp() + cls.warehouse = os.path.join(cls.tempdir, 'warehouse') + cls.catalog = CatalogFactory.create({'warehouse': cls.warehouse}) + cls.catalog.create_database('default', True) + + cls.pa_schema = pa.schema([ + ('id', pa.int32()), + ('value', pa.string()), + ('partition_col', pa.string()) + ]) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.tempdir, ignore_errors=True) + + def _create_table_with_snapshots(self, name, num_snapshots=5, partition_keys=None): + """Create a table and write num_snapshots of data. 
+ + Returns: + Tuple of (table, expected_data_per_snapshot) + """ + schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=partition_keys) + self.catalog.create_table(f'default.{name}', schema, False) + table = self.catalog.get_table(f'default.{name}') + + all_data = [] + for snap_id in range(1, num_snapshots + 1): + write_builder = table.new_batch_write_builder() + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + + data = { + 'id': [snap_id * 10 + i for i in range(5)], + 'value': [f'snap{snap_id}_row{i}' for i in range(5)], + 'partition_col': ['p1' if i % 2 == 0 else 'p2' for i in range(5)] + } + all_data.append(data) + + pa_table = pa.Table.from_pydict(data, schema=self.pa_schema) + table_write.write_arrow(pa_table) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + + return table, all_data + + def _read_via_diff(self, table, start_snap_id, end_snap_id): + """Read data using IncrementalDiffScanner between two snapshots.""" + snapshot_manager = SnapshotManager(table) + start_snapshot = snapshot_manager.get_snapshot_by_id(start_snap_id) + end_snapshot = snapshot_manager.get_snapshot_by_id(end_snap_id) + + scanner = IncrementalDiffScanner(table) + plan = scanner.scan(start_snapshot, end_snapshot) + + splits = plan.splits() + if not splits: + # Return empty table with correct schema + return pa.Table.from_pydict({ + 'id': [], + 'value': [], + 'partition_col': [] + }, schema=self.pa_schema) + + table_read = table.new_read_builder().new_read() + return table_read.to_arrow(splits) + + def _read_via_delta(self, table, start_snap_id, end_snap_id): + """Read data by iterating delta_manifest_lists between two snapshots.""" + snapshot_manager = SnapshotManager(table) + manifest_list_manager = ManifestListManager(table) + manifest_file_manager = ManifestFileManager(table) + + all_entries = [] + for snap_id in range(start_snap_id + 1, end_snap_id + 1): + snapshot = 
snapshot_manager.get_snapshot_by_id(snap_id) + if snapshot and snapshot.commit_kind == 'APPEND': + manifest_files = manifest_list_manager.read_delta(snapshot) + if manifest_files: + entries = manifest_file_manager.read_entries_parallel(manifest_files) + all_entries.extend(entries) + + if not all_entries: + return pa.Table.from_pydict({ + 'id': [], + 'value': [], + 'partition_col': [] + }, schema=self.pa_schema) + + # Create splits from entries + options = table.options + split_generator = AppendTableSplitGenerator( + table, + options.source_split_target_size(), + options.source_split_open_file_cost(), + {} + ) + splits = split_generator.create_splits(all_entries) + + table_read = table.new_read_builder().new_read() + return table_read.to_arrow(splits) + + def _rows_to_set(self, arrow_table): + """Convert arrow table to set of (id, value, partition_col) tuples.""" + rows = set() + for i in range(arrow_table.num_rows): + row = ( + arrow_table.column('id')[i].as_py(), + arrow_table.column('value')[i].as_py(), + arrow_table.column('partition_col')[i].as_py() + ) + rows.add(row) + return rows + + def test_diff_returns_same_rows_as_delta_simple(self): + """ + Basic case: 5 snapshots, verify row-level equivalence. + + Creates a table with 5 snapshots, then reads data from snapshot 1 to 5 + using both diff and delta approaches, verifying they return the same rows. 
+ """ + table, all_data = self._create_table_with_snapshots( + 'test_diff_delta_simple', + num_snapshots=5 + ) + + # Read using both approaches (from snapshot 1 to 5, so we get snapshots 2-5) + diff_result = self._read_via_diff(table, 1, 5) + delta_result = self._read_via_delta(table, 1, 5) + + # Convert to sets for order-independent comparison + diff_rows = self._rows_to_set(diff_result) + delta_rows = self._rows_to_set(delta_result) + + self.assertEqual(diff_rows, delta_rows) + + # Verify we got the expected number of rows (snapshots 2-5, 5 rows each = 20) + self.assertEqual(len(diff_rows), 20) + + # Verify specific IDs are present (from snapshots 2-5) + expected_ids = set() + for snap_id in range(2, 6): # snapshots 2, 3, 4, 5 + for i in range(5): + expected_ids.add(snap_id * 10 + i) + + actual_ids = {row[0] for row in diff_rows} + self.assertEqual(actual_ids, expected_ids) + + def test_diff_returns_same_rows_as_delta_many_snapshots(self): + """ + Stress test: 20 snapshots, verify row-level equivalence. + + This tests the catch-up scenario where there are many snapshots + between start and end. + """ + table, all_data = self._create_table_with_snapshots( + 'test_diff_delta_many', + num_snapshots=20 + ) + + # Read using both approaches (from snapshot 1 to 20) + diff_result = self._read_via_diff(table, 1, 20) + delta_result = self._read_via_delta(table, 1, 20) + + # Convert to sets for order-independent comparison + diff_rows = self._rows_to_set(diff_result) + delta_rows = self._rows_to_set(delta_result) + + self.assertEqual(diff_rows, delta_rows) + + # Verify we got the expected number of rows (snapshots 2-20, 5 rows each = 95) + self.assertEqual(len(diff_rows), 95) + + def test_diff_returns_same_rows_with_mixed_partitions(self): + """ + Partitioned table: Verify diff handles multiple partitions correctly. + + Creates a partitioned table and verifies diff and delta return + the same rows across all partitions. 
+ """ + table, all_data = self._create_table_with_snapshots( + 'test_diff_delta_partitioned', + num_snapshots=5, + partition_keys=['partition_col'] + ) + + # Read using both approaches + diff_result = self._read_via_diff(table, 1, 5) + delta_result = self._read_via_delta(table, 1, 5) + + # Convert to sets for order-independent comparison + diff_rows = self._rows_to_set(diff_result) + delta_rows = self._rows_to_set(delta_result) + + self.assertEqual(diff_rows, delta_rows) + + # Verify both partitions have data + p1_rows = {r for r in diff_rows if r[2] == 'p1'} + p2_rows = {r for r in diff_rows if r[2] == 'p2'} + + self.assertGreater(len(p1_rows), 0, "Should have rows in partition p1") + self.assertGreater(len(p2_rows), 0, "Should have rows in partition p2") + + def test_streaming_catch_up_returns_same_data(self): + """ + End-to-end: Verify AsyncStreamingTableScan catch-up returns same data. + + Creates a table with 20 snapshots, then uses streaming scan with + a low diff_threshold to trigger diff-based catch-up. Verifies the + total rows match expected. 
+ """ + table, all_data = self._create_table_with_snapshots( + 'test_streaming_catch_up', + num_snapshots=20 + ) + + # Create streaming scan with low threshold to trigger diff catch-up + scan = AsyncStreamingTableScan( + table, + poll_interval_ms=10, + diff_threshold=5, # Low threshold to trigger diff for gap > 5 + prefetch_enabled=False + ) + + # Restore to snapshot 1 (will trigger catch-up to snapshot 20) + scan.next_snapshot_id = 1 + + # Collect all rows from streaming scan + all_rows = [] + table_read = table.new_read_builder().new_read() + + async def collect_rows(): + plan_count = 0 + async for plan in scan.stream(): + splits = plan.splits() + if splits: + arrow_table = table_read.to_arrow(splits) + for i in range(arrow_table.num_rows): + row = ( + arrow_table.column('id')[i].as_py(), + arrow_table.column('value')[i].as_py(), + arrow_table.column('partition_col')[i].as_py() + ) + all_rows.append(row) + plan_count += 1 + # After first plan (catch-up), we should have all data + # Break to avoid infinite loop waiting for new snapshots + if plan_count >= 1: + break + + asyncio.run(collect_rows()) + + # Verify diff catch-up was used (gap of 19 > threshold of 5) + self.assertTrue(scan._diff_catch_up_used, + "Diff-based catch-up should have been used for large gap") + + # Verify we got all expected rows (snapshots 1-20, 5 rows each = 100) + # Note: catch-up includes snapshot 1's data since we start from next_snapshot_id=1 + self.assertEqual(len(all_rows), 100) + + # Verify all expected IDs are present + expected_ids = set() + for snap_id in range(1, 21): # snapshots 1-20 + for i in range(5): + expected_ids.add(snap_id * 10 + i) + + actual_ids = {row[0] for row in all_rows} + self.assertEqual(actual_ids, expected_ids) + + +if __name__ == '__main__': + unittest.main() diff --git a/paimon-python/pypaimon/common/options/core_options.py b/paimon-python/pypaimon/common/options/core_options.py index 85cf965f91b9..7e689d21c648 100644 --- 
a/paimon-python/pypaimon/common/options/core_options.py +++ b/paimon-python/pypaimon/common/options/core_options.py @@ -36,6 +36,16 @@ class ExternalPathStrategy(str, Enum): SPECIFIC_FS = "specific-fs" +class ChangelogProducer(str, Enum): + """ + Changelog producer for streaming reads. + """ + NONE = "none" + INPUT = "input" + FULL_COMPACTION = "full-compaction" + LOOKUP = "lookup" + + class MergeEngine(str, Enum): """ Specifies the merge engine for table with primary key. @@ -273,6 +283,14 @@ class CoreOptions: .with_description("Whether to enable deletion vectors.") ) + CHANGELOG_PRODUCER: ConfigOption[ChangelogProducer] = ( + ConfigOptions.key("changelog-producer") + .enum_type(ChangelogProducer) + .default_value(ChangelogProducer.NONE) + .with_description("The changelog producer for streaming reads. " + "Options: none, input, full-compaction, lookup.") + ) + MERGE_ENGINE: ConfigOption[MergeEngine] = ( ConfigOptions.key("merge-engine") .enum_type(MergeEngine) @@ -500,6 +518,9 @@ def data_evolution_enabled(self, default=None): def deletion_vectors_enabled(self, default=None): return self.options.get(CoreOptions.DELETION_VECTORS_ENABLED, default) + def changelog_producer(self, default=None): + return self.options.get(CoreOptions.CHANGELOG_PRODUCER, default) + def merge_engine(self, default=None): return self.options.get(CoreOptions.MERGE_ENGINE, default) diff --git a/paimon-python/pypaimon/read/scanner/file_scanner.py b/paimon-python/pypaimon/read/scanner/file_scanner.py index 3345ae68101f..52c7c4e53093 100755 --- a/paimon-python/pypaimon/read/scanner/file_scanner.py +++ b/paimon-python/pypaimon/read/scanner/file_scanner.py @@ -189,7 +189,6 @@ def __init__( self.partition_key_predicate = trim_and_transform_predicate( self.predicate, self.table.field_names, self.table.partition_keys) options = self.table.options - # Get split target size and open file cost from table options self.target_split_size = options.source_split_target_size() self.open_file_cost = 
options.source_split_open_file_cost() @@ -244,7 +243,6 @@ def scan(self) -> Plan: if not entries: return Plan([]) - # Configure sharding if needed if self.idx_of_this_subtask is not None: split_generator.with_shard(self.idx_of_this_subtask, self.number_of_para_subtasks) elif self.start_pos_of_this_subtask is not None: diff --git a/paimon-python/pypaimon/read/stream_read_builder.py b/paimon-python/pypaimon/read/stream_read_builder.py new file mode 100644 index 000000000000..0075cbd9727a --- /dev/null +++ b/paimon-python/pypaimon/read/stream_read_builder.py @@ -0,0 +1,142 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ +""" +StreamReadBuilder for building streaming table scans and reads. + +This module provides a builder for configuring streaming reads from Paimon +tables, similar to ReadBuilder but for continuous streaming use cases. 
+""" + +from typing import Callable, List, Optional, Set + +from pypaimon.common.predicate import Predicate +from pypaimon.common.predicate_builder import PredicateBuilder +from pypaimon.read.streaming_table_scan import AsyncStreamingTableScan +from pypaimon.read.table_read import TableRead +from pypaimon.schema.data_types import DataField +from pypaimon.table.special_fields import SpecialFields + + +class StreamReadBuilder: + """ + Builder for streaming reads from Paimon tables. + + Usage: + stream_builder = table.new_stream_read_builder() + stream_builder.with_poll_interval_ms(500) + + scan = stream_builder.new_streaming_scan() + table_read = stream_builder.new_read() + + async for plan in scan.stream(): + arrow_table = table_read.to_arrow(plan.splits()) + process(arrow_table) + """ + + def __init__(self, table): + """Initialize the StreamReadBuilder.""" + from pypaimon.table.file_store_table import FileStoreTable + + self.table: FileStoreTable = table + self._predicate: Optional[Predicate] = None + self._projection: Optional[List[str]] = None + self._poll_interval_ms: int = 1000 + self._include_row_kind: bool = False + self._bucket_filter: Optional[Callable[[int], bool]] = None + + def with_filter(self, predicate: Predicate) -> 'StreamReadBuilder': + """Set a filter predicate for the streaming read.""" + self._predicate = predicate + return self + + def with_projection(self, projection: List[str]) -> 'StreamReadBuilder': + """Set column projection for the streaming read.""" + self._projection = projection + return self + + def with_poll_interval_ms(self, poll_interval_ms: int) -> 'StreamReadBuilder': + """Set the poll interval in ms for checking new snapshots (default: 1000).""" + self._poll_interval_ms = poll_interval_ms + return self + + def with_include_row_kind(self, include: bool = True) -> 'StreamReadBuilder': + """Include row kind column (_row_kind) in the output. 
+ + When enabled, the output will include a _row_kind column as the first + column with values: +I (insert), -U (update before), +U (update after), + -D (delete). + """ + self._include_row_kind = include + return self + + def with_bucket_filter( + self, + bucket_filter: Callable[[int], bool] + ) -> 'StreamReadBuilder': + """Push bucket filter for parallel consumption. + + Example: + builder.with_bucket_filter(lambda b: b % 2 == 0) + builder.with_bucket_filter(lambda b: b < 4) + """ + self._bucket_filter = bucket_filter + return self + + def with_buckets(self, bucket_ids: List[int]) -> 'StreamReadBuilder': + """Convenience method to read only specific buckets. + + Example: + builder.with_buckets([0, 1, 2]) + builder.with_buckets([3, 4, 5]) + """ + bucket_set: Set[int] = set(bucket_ids) + return self.with_bucket_filter(lambda bucket: bucket in bucket_set) + + def new_streaming_scan(self) -> AsyncStreamingTableScan: + """Create a new AsyncStreamingTableScan with this builder's settings.""" + return AsyncStreamingTableScan( + table=self.table, + predicate=self._predicate, + poll_interval_ms=self._poll_interval_ms, + bucket_filter=self._bucket_filter + ) + + def new_read(self) -> TableRead: + """Create a new TableRead with this builder's settings.""" + return TableRead( + table=self.table, + predicate=self._predicate, + read_type=self.read_type(), + include_row_kind=self._include_row_kind + ) + + def new_predicate_builder(self) -> PredicateBuilder: + """Create a PredicateBuilder for building filter predicates.""" + return PredicateBuilder(self.read_type()) + + def read_type(self) -> List[DataField]: + """Get the read schema fields, applying projection if set.""" + table_fields = self.table.fields + + if not self._projection: + return table_fields + else: + if self.table.options.row_tracking_enabled(): + table_fields = SpecialFields.row_type_with_row_tracking(table_fields) + field_map = {field.name: field for field in table_fields} + return [field_map[name] for name in 
self._projection if name in field_map] diff --git a/paimon-python/pypaimon/read/streaming_table_scan.py b/paimon-python/pypaimon/read/streaming_table_scan.py new file mode 100644 index 000000000000..9c0c45e006ed --- /dev/null +++ b/paimon-python/pypaimon/read/streaming_table_scan.py @@ -0,0 +1,391 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ +""" +AsyncStreamingTableScan for continuous streaming reads from Paimon tables. + +This module provides async-based streaming reads that continuously poll for +new snapshots and yield Plans as new data arrives. It is the Python equivalent +of Java's DataTableStreamScan. 
+""" + +import asyncio +import os +from concurrent.futures import Future, ThreadPoolExecutor +from typing import AsyncIterator, Callable, Iterator, List, Optional + +from pypaimon.common.options.core_options import ChangelogProducer +from pypaimon.common.predicate import Predicate +from pypaimon.manifest.manifest_file_manager import ManifestFileManager +from pypaimon.manifest.manifest_list_manager import ManifestListManager +from pypaimon.read.plan import Plan +from pypaimon.read.scanner.append_table_split_generator import \ + AppendTableSplitGenerator +from pypaimon.read.scanner.changelog_follow_up_scanner import \ + ChangelogFollowUpScanner +from pypaimon.read.scanner.delta_follow_up_scanner import DeltaFollowUpScanner +from pypaimon.read.scanner.file_scanner import FileScanner +from pypaimon.read.scanner.follow_up_scanner import FollowUpScanner +from pypaimon.read.scanner.incremental_diff_scanner import \ + IncrementalDiffScanner +from pypaimon.read.scanner.primary_key_table_split_generator import \ + PrimaryKeyTableSplitGenerator +from pypaimon.snapshot.snapshot import Snapshot +from pypaimon.snapshot.snapshot_manager import SnapshotManager + + +class AsyncStreamingTableScan: + """ + Async streaming table scan for continuous reads from Paimon tables. + + This class provides an async iterator that continuously polls for new + snapshots and yields Plans containing splits for new data. 
+ + Usage: + scan = AsyncStreamingTableScan(table) + + async for plan in scan.stream(): + for split in plan.splits(): + # Process the data + pass + + For synchronous usage: + for plan in scan.stream_sync(): + process(plan) + """ + + def __init__( + self, + table, + predicate: Optional[Predicate] = None, + poll_interval_ms: int = 1000, + follow_up_scanner: Optional[FollowUpScanner] = None, + bucket_filter: Optional[Callable[[int], bool]] = None, + prefetch_enabled: bool = True, + diff_threshold: int = 10 + ): + """Initialize the streaming table scan.""" + self.table = table + self.predicate = predicate + self.poll_interval = poll_interval_ms / 1000.0 + + # Bucket filter for parallel consumption + self._bucket_filter = bucket_filter + + # Diff-based catch-up configuration + self._diff_threshold = diff_threshold + self._catch_up_in_progress = False + + # Prefetching configuration + self._prefetch_enabled = prefetch_enabled + self._prefetch_future: Optional[Future] = None + self._prefetch_snapshot_id: Optional[int] = None + self._prefetch_hits = 0 + self._prefetch_misses = 0 + self._lookahead_skips = 0 # Track how many snapshots were skipped via lookahead + self._prefetch_executor = ThreadPoolExecutor(max_workers=1) if prefetch_enabled else None + self._lookahead_size = 10 # How many snapshots to look ahead + self._diff_catch_up_used = False # Track if diff-based catch-up was used + + # Initialize managers + self._snapshot_manager = SnapshotManager(table) + self._manifest_list_manager = ManifestListManager(table) + self._manifest_file_manager = ManifestFileManager(table) + + # Scanner for determining which snapshots to read + # Auto-select based on changelog-producer if not explicitly provided + self.follow_up_scanner = follow_up_scanner or self._create_follow_up_scanner() + + # State tracking + self.next_snapshot_id: Optional[int] = None + self._initialized = False + + async def stream(self) -> AsyncIterator[Plan]: + """Yield Plans as new snapshots appear. 
+ + On first call, performs an initial full scan of the latest snapshot. + Subsequent iterations poll for new snapshots and yield delta Plans. + + Yields: + Plan objects containing splits for reading + """ + # Initial scan + if self.next_snapshot_id is None: + latest_snapshot = self._snapshot_manager.get_latest_snapshot() + if latest_snapshot: + self.next_snapshot_id = latest_snapshot.id + 1 + yield self._create_initial_plan(latest_snapshot) + self._initialized = True + + # Check for catch-up scenario: starting from earlier snapshot with large gap + # This handles --from earliest or --from snapshot:X with many snapshots to process + if self._should_use_diff_catch_up(): + self._catch_up_in_progress = True + self._diff_catch_up_used = True + try: + latest_snapshot = self._snapshot_manager.get_latest_snapshot() + if latest_snapshot and self.next_snapshot_id: + catch_up_plan = self._create_catch_up_plan( + self.next_snapshot_id, + latest_snapshot + ) + self.next_snapshot_id = latest_snapshot.id + 1 + self._initialized = True + if catch_up_plan.splits(): + yield catch_up_plan + finally: + self._catch_up_in_progress = False + + # Follow-up polling loop with lookahead and optional prefetching + while True: + plan = None + snapshot_processed = False # Track if we processed (or skipped) a snapshot + + # Check if we have a prefetched result ready + prefetch_used = False + if self._prefetch_future is not None: + try: + # Wait for the prefetch thread to complete + # Returns (plan, next_id, skipped_count) tuple + prefetch_result = self._prefetch_future.result(timeout=30) + prefetch_used = True + + if prefetch_result is not None: + prefetch_plan, next_id, skipped_count = prefetch_result + self._lookahead_skips += skipped_count + self.next_snapshot_id = next_id + snapshot_processed = skipped_count > 0 or prefetch_plan is not None + + if prefetch_plan is not None: + plan = prefetch_plan + self._prefetch_hits += 1 + except Exception: + # Prefetch failed, fall back to synchronous + 
prefetch_used = False + finally: + self._prefetch_future = None + self._prefetch_snapshot_id = None + + # If prefetch wasn't available or failed, use lookahead to find next scannable + if not prefetch_used: + self._prefetch_misses += 1 + # Use batch lookahead to find the next scannable snapshot + snapshot, next_id, skipped_count = self._snapshot_manager.find_next_scannable( + self.next_snapshot_id, + self.follow_up_scanner.should_scan, + lookahead_size=self._lookahead_size + ) + self._lookahead_skips += skipped_count + self.next_snapshot_id = next_id + + # Check if we found a scannable snapshot or skipped some + snapshot_processed = skipped_count > 0 or snapshot is not None + + if snapshot is not None: + plan = self._create_follow_up_plan(snapshot) + + if plan is not None: + # Start prefetching next scannable snapshot before yielding + if self._prefetch_enabled: + self._start_prefetch(self.next_snapshot_id) + yield plan + elif not snapshot_processed: + # No snapshot available yet, wait and poll again + await asyncio.sleep(self.poll_interval) + # If snapshots were processed but plan is None (all skipped), continue loop immediately + + def stream_sync(self) -> Iterator[Plan]: + """ + Synchronous wrapper for stream(). + + Provides a blocking iterator for use in non-async code. 
+ + Yields: + Plan objects containing splits for reading + """ + loop = asyncio.new_event_loop() + try: + async_gen = self.stream() + while True: + try: + plan = loop.run_until_complete(async_gen.__anext__()) + yield plan + except StopAsyncIteration: + break + finally: + loop.close() + + def _start_prefetch(self, snapshot_id: int) -> None: + """Start prefetching the next scannable snapshot in a background thread.""" + if self._prefetch_future is not None or self._prefetch_executor is None: + return # Already prefetching or executor not available + + self._prefetch_snapshot_id = snapshot_id + # Submit to thread pool - this starts immediately, not when event loop runs + self._prefetch_future = self._prefetch_executor.submit( + self._fetch_plan_with_lookahead, + snapshot_id + ) + + def _fetch_plan_with_lookahead(self, start_id: int) -> Optional[tuple]: + """Find next scannable snapshot via lookahead and create a plan. Runs in thread pool.""" + try: + snapshot, next_id, skipped_count = self._snapshot_manager.find_next_scannable( + start_id, + self.follow_up_scanner.should_scan, + lookahead_size=self._lookahead_size + ) + + if snapshot is None: + return (None, next_id, skipped_count) + + plan = self._create_follow_up_plan(snapshot) + return (plan, next_id, skipped_count) + except Exception: + return None + + def _create_follow_up_plan(self, snapshot: Snapshot) -> Plan: + """Route to changelog or delta plan based on scanner type.""" + if isinstance(self.follow_up_scanner, ChangelogFollowUpScanner): + return self._create_changelog_plan(snapshot) + else: + return self._create_delta_plan(snapshot) + + def _create_follow_up_scanner(self) -> FollowUpScanner: + """Create the appropriate follow-up scanner based on changelog-producer option.""" + changelog_producer = self.table.options.changelog_producer() + if changelog_producer == ChangelogProducer.NONE: + return DeltaFollowUpScanner() + else: + # INPUT, FULL_COMPACTION, LOOKUP all use changelog scanner + return 
ChangelogFollowUpScanner() + + def _filter_entries_for_shard(self, entries: List) -> List: + """Filter manifest entries by bucket filter, if set.""" + if self._bucket_filter is not None: + return [e for e in entries if self._bucket_filter(e.bucket)] + return entries + + def _create_initial_plan(self, snapshot: Snapshot) -> Plan: + """Create a Plan for the initial full scan of the latest snapshot.""" + def all_manifests(): + return self._manifest_list_manager.read_all(snapshot) + + starting_scanner = FileScanner( + self.table, + all_manifests, + predicate=self.predicate, + limit=None + ) + return starting_scanner.scan() + + def _create_delta_plan(self, snapshot: Snapshot) -> Plan: + """Read new files from delta_manifest_list (changelog-producer=none).""" + manifest_files = self._manifest_list_manager.read_delta(snapshot) + return self._create_plan_from_manifests(manifest_files) + + def _create_changelog_plan(self, snapshot: Snapshot) -> Plan: + """Read from changelog_manifest_list (changelog-producer=input/full-compaction/lookup).""" + manifest_files = self._manifest_list_manager.read_changelog(snapshot) + return self._create_plan_from_manifests(manifest_files) + + def _create_plan_from_manifests(self, manifest_files: List) -> Plan: + """Create splits from manifest files, applying shard filtering.""" + if not manifest_files: + return Plan([]) + + # Use configurable parallelism from table options + max_workers = max(8, self.table.options.scan_manifest_parallelism(os.cpu_count() or 8)) + + # Read manifest entries from manifest files + entries = self._manifest_file_manager.read_entries_parallel( + manifest_files, + manifest_entry_filter=None, + max_workers=max_workers + ) + + # Apply shard/bucket filtering for parallel consumption + entries = self._filter_entries_for_shard(entries) if entries else [] + if not entries: + return Plan([]) + + # Get split options from table + options = self.table.options + target_split_size = options.source_split_target_size() + 
open_file_cost = options.source_split_open_file_cost() + + # Create appropriate split generator based on table type + if self.table.is_primary_key_table: + split_generator = PrimaryKeyTableSplitGenerator( + self.table, + target_split_size, + open_file_cost, + deletion_files_map={} + ) + else: + split_generator = AppendTableSplitGenerator( + self.table, + target_split_size, + open_file_cost, + deletion_files_map={} + ) + + splits = split_generator.create_splits(entries) + return Plan(splits) + + def _should_use_diff_catch_up(self) -> bool: + """Check if diff-based catch-up should be used (large gap to latest).""" + if self._catch_up_in_progress: + return False + + if self.next_snapshot_id is None: + return False + + latest = self._snapshot_manager.get_latest_snapshot() + if latest is None: + return False + + gap = latest.id - self.next_snapshot_id + return gap > self._diff_threshold + + def _create_catch_up_plan(self, start_id: int, end_snapshot: Snapshot) -> Plan: + """Create a catch-up plan using diff-based scanning between start and end snapshots.""" + # Get start snapshot (one before where we want to start reading) + # If start_id is 0 or 1, use None to indicate "from beginning" + start_snapshot = None + if start_id > 1: + start_snapshot = self._snapshot_manager.get_snapshot_by_id(start_id - 1) + + # Create diff scanner + diff_scanner = IncrementalDiffScanner(self.table) + + if start_snapshot is None: + # No start snapshot - return all files from end snapshot + # This is equivalent to a full scan of end snapshot + def end_snapshot_manifests(): + return self._manifest_list_manager.read_all(end_snapshot) + + starting_scanner = FileScanner( + self.table, + end_snapshot_manifests, + predicate=self.predicate, + limit=None + ) + return starting_scanner.scan() + else: + # Use diff scanner for efficient catch-up + return diff_scanner.scan(start_snapshot, end_snapshot) diff --git a/paimon-python/pypaimon/snapshot/snapshot_manager.py 
    def get_snapshots_batch(
        self, snapshot_ids: List[int], max_workers: int = 4
    ) -> Dict[int, Optional[Snapshot]]:
        """Fetch multiple snapshots in parallel, returning {id: Snapshot|None}.

        Args:
            snapshot_ids: Snapshot IDs to load; every requested ID appears
                in the result.
            max_workers: Upper bound on reader threads (further capped by
                the number of existing snapshot files).

        Returns:
            Dict mapping each requested ID to its Snapshot, or None when
            the snapshot file does not exist (or failed to load).
        """
        if not snapshot_ids:
            return {}

        # First, batch check which snapshot files exist
        paths = [self.get_snapshot_path(sid) for sid in snapshot_ids]
        existence = self.file_io.exists_batch(paths)

        # Filter to only existing snapshots
        existing_ids = [
            sid for sid, path in zip(snapshot_ids, paths)
            if existence.get(path, False)
        ]

        if not existing_ids:
            return {sid: None for sid in snapshot_ids}

        # Fetch existing snapshots in parallel.
        # NOTE(review): a snapshot file that exists but fails to read/parse
        # is mapped to None, indistinguishable from "missing" -- callers
        # such as find_next_scannable will stop the lookahead at that ID.
        def fetch_one(sid: int) -> tuple:
            try:
                return (sid, self.get_snapshot_by_id(sid))
            except Exception:
                return (sid, None)

        results = {sid: None for sid in snapshot_ids}

        with ThreadPoolExecutor(max_workers=min(len(existing_ids), max_workers)) as executor:
            for sid, snapshot in executor.map(fetch_one, existing_ids):
                results[sid] = snapshot

        return results

    def find_next_scannable(
        self,
        start_id: int,
        should_scan: Callable[[Snapshot], bool],
        lookahead_size: int = 10,
        max_workers: int = 4
    ) -> tuple:
        """Find the next snapshot passing should_scan, using batch lookahead.

        Args:
            start_id: First snapshot ID to consider.
            should_scan: Predicate deciding whether a snapshot is consumable.
            lookahead_size: How many consecutive IDs to prefetch per batch.
            max_workers: Thread budget forwarded to get_snapshots_batch.

        Returns:
            Tuple (snapshot, next_id, skipped_count):
            - snapshot: first Snapshot in ID order passing should_scan, or
              None when none was found in this batch;
            - next_id: the ID the caller should resume from (or wait on);
            - skipped_count: snapshots rejected by should_scan in this batch.
        """
        # Generate the range of snapshot IDs to check
        snapshot_ids = list(range(start_id, start_id + lookahead_size))

        # Batch fetch all snapshots
        snapshots = self.get_snapshots_batch(snapshot_ids, max_workers)

        # Find the first scannable snapshot in order
        skipped_count = 0
        for sid in snapshot_ids:
            snapshot = snapshots.get(sid)
            if snapshot is None:
                # No more snapshots exist at this ID
                # Return next_id = sid so caller knows where to wait
                return (None, sid, skipped_count)
            if should_scan(snapshot):
                # Found a scannable snapshot
                return (snapshot, sid + 1, skipped_count)
            skipped_count += 1

        # All fetched snapshots were skipped, but more may exist
        # Return next_id pointing past the batch
        return (None, start_id + lookahead_size, skipped_count)
UnawareBucketRowKeyExtractor) +from pypaimon.write.write_builder import BatchWriteBuilder, StreamWriteBuilder class FileStoreTable(Table): @@ -102,12 +103,12 @@ def create_tag( ) -> None: """ Create a tag for a snapshot. - + Args: tag_name: Name for the tag snapshot_id: ID of the snapshot to tag. If None, uses the latest snapshot. ignore_if_exists: If True, don't raise error if tag already exists - + Raises: ValueError: If no snapshot exists or tag already exists (when ignore_if_exists=False) """ @@ -129,10 +130,10 @@ def create_tag( def delete_tag(self, tag_name: str) -> bool: """ Delete a tag. - + Args: tag_name: Name of the tag to delete - + Returns: True if tag was deleted, False if tag didn't exist """ @@ -310,10 +311,8 @@ def new_global_index_scan_builder(self) -> Optional['GlobalIndexScanBuilder']: if not self.options.global_index_enabled(): return None - from pypaimon.globalindex.global_index_scan_builder_impl import ( + from pypaimon.globalindex.global_index_scan_builder_impl import \ GlobalIndexScanBuilderImpl - ) - from pypaimon.index.index_file_handler import IndexFileHandler return GlobalIndexScanBuilderImpl( @@ -325,6 +324,9 @@ def new_global_index_scan_builder(self) -> Optional['GlobalIndexScanBuilder']: index_file_handler=IndexFileHandler(table=self) ) + def new_stream_read_builder(self) -> 'StreamReadBuilder': + return StreamReadBuilder(self) + def new_batch_write_builder(self) -> BatchWriteBuilder: return BatchWriteBuilder(self) @@ -366,10 +368,10 @@ def copy(self, options: dict) -> 'FileStoreTable': def _try_time_travel(self, options: Options) -> Optional[TableSchema]: """ Try to resolve time travel options and return the corresponding schema. - + Supports the following time travel options: - scan.tag-name: Travel to a specific tag - + Returns: The TableSchema at the time travel point, or None if no time travel option is set. 
""" @@ -385,6 +387,7 @@ def _try_time_travel(self, options: Options) -> Optional[TableSchema]: def _create_external_paths(self) -> List[str]: from urllib.parse import urlparse + from pypaimon.common.options.core_options import ExternalPathStrategy external_paths_str = self.options.data_file_external_paths() diff --git a/paimon-python/pypaimon/table/format/format_table.py b/paimon-python/pypaimon/table/format/format_table.py index 564bd086255d..8c7e82415ec9 100644 --- a/paimon-python/pypaimon/table/format/format_table.py +++ b/paimon-python/pypaimon/table/format/format_table.py @@ -93,8 +93,12 @@ def new_read_builder(self): return FormatReadBuilder(self) def new_batch_write_builder(self): - from pypaimon.table.format.format_batch_write_builder import FormatBatchWriteBuilder + from pypaimon.table.format.format_batch_write_builder import \ + FormatBatchWriteBuilder return FormatBatchWriteBuilder(self) + def new_stream_read_builder(self): + raise NotImplementedError("Format table does not support stream read.") + def new_stream_write_builder(self): raise NotImplementedError("Format table does not support stream write.") diff --git a/paimon-python/pypaimon/table/iceberg/iceberg_table.py b/paimon-python/pypaimon/table/iceberg/iceberg_table.py index ab2b8c9e03e4..137b7c862282 100644 --- a/paimon-python/pypaimon/table/iceberg/iceberg_table.py +++ b/paimon-python/pypaimon/table/iceberg/iceberg_table.py @@ -98,6 +98,11 @@ def new_batch_write_builder(self): "IcebergTable does not support batch write operation in paimon-python yet." ) + def new_stream_read_builder(self): + raise NotImplementedError( + "IcebergTable does not support stream read operation in paimon-python yet." + ) + def new_stream_write_builder(self): raise NotImplementedError( "IcebergTable does not support stream write operation in paimon-python yet." 
diff --git a/paimon-python/pypaimon/table/table.py b/paimon-python/pypaimon/table/table.py index e20784f1fc9d..dfae23281065 100644 --- a/paimon-python/pypaimon/table/table.py +++ b/paimon-python/pypaimon/table/table.py @@ -19,6 +19,7 @@ from abc import ABC, abstractmethod from pypaimon.read.read_builder import ReadBuilder +from pypaimon.read.stream_read_builder import StreamReadBuilder from pypaimon.write.write_builder import BatchWriteBuilder, StreamWriteBuilder @@ -29,6 +30,10 @@ class Table(ABC): def new_read_builder(self) -> ReadBuilder: """Return a builder for building table scan and table read.""" + @abstractmethod + def new_stream_read_builder(self) -> StreamReadBuilder: + """Return a builder for building streaming table scan and read.""" + @abstractmethod def new_batch_write_builder(self) -> BatchWriteBuilder: """Returns a builder for building batch table write and table commit.""" diff --git a/paimon-python/pypaimon/tests/stream_read_builder_test.py b/paimon-python/pypaimon/tests/stream_read_builder_test.py new file mode 100644 index 000000000000..c0cf6ada1b04 --- /dev/null +++ b/paimon-python/pypaimon/tests/stream_read_builder_test.py @@ -0,0 +1,170 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
"""
Test cases for StreamReadBuilder bucket filtering functionality.

Tests the with_bucket_filter() and with_buckets() methods
that enable parallel consumption across multiple consumer processes.
"""

from unittest.mock import MagicMock

import pytest

from pypaimon.read.stream_read_builder import StreamReadBuilder


# -----------------------------------------------------------------------------
# Fixtures
# -----------------------------------------------------------------------------


@pytest.fixture
def mock_table():
    """Create a mock table for unit tests."""
    table = MagicMock()
    table.fields = []
    table.options.row_tracking_enabled.return_value = False
    return table


@pytest.fixture
def builder(mock_table):
    """Create a StreamReadBuilder with mock table."""
    return StreamReadBuilder(mock_table)


class MockEntry:
    """Mock manifest entry for testing bucket filtering."""

    def __init__(self, bucket):
        # Only the ``bucket`` attribute is read by the filtering code.
        self.bucket = bucket


# -----------------------------------------------------------------------------
# Unit Tests: StreamReadBuilder Validation
# -----------------------------------------------------------------------------

class TestStreamReadBuilderValidation:
    """Unit tests for StreamReadBuilder method validation."""

    def test_with_bucket_filter_valid(self, builder):
        """Test with_bucket_filter() accepts valid filter function."""
        filter_fn = lambda b: b % 2 == 0
        result = builder.with_bucket_filter(filter_fn)
        # Builder methods are fluent: they return the builder itself.
        assert result is builder
        assert builder._bucket_filter is filter_fn

    @pytest.mark.parametrize("bucket_ids,expected_true,expected_false", [
        ([0, 2, 4], [0, 2, 4], [1, 3, 5]),
        ([], [], [0, 1, 2]),
        ([5], [5], [0, 1, 4, 6]),
    ])
    def test_with_buckets(self, builder, bucket_ids, expected_true, expected_false):
        """Test with_buckets() creates correct filter."""
        builder.with_buckets(bucket_ids)
        for b in expected_true:
            assert builder._bucket_filter(b), f"Bucket {b} should be included"
        for b in expected_false:
            assert not builder._bucket_filter(b), f"Bucket {b} should be excluded"

    def test_method_chaining(self, builder):
        """Test method chaining works correctly."""
        result = (builder
                  .with_poll_interval_ms(500)
                  .with_bucket_filter(lambda b: b % 2 == 0)
                  .with_include_row_kind(True))
        assert result is builder
        assert builder._poll_interval_ms == 500
        assert builder._bucket_filter is not None
        assert builder._include_row_kind is True


# -----------------------------------------------------------------------------
# Unit Tests: Bucket Filtering Logic
# -----------------------------------------------------------------------------

class TestBucketFilteringLogic:
    """Test bucket filtering logic used in scans."""

    @pytest.mark.parametrize("shard_idx,shard_count,expected_buckets", [
        (0, 4, [0, 4]),
        (1, 4, [1, 5]),
        (2, 4, [2, 6]),
        (3, 4, [3, 7]),
        (0, 2, [0, 2, 4, 6]),
        (1, 2, [1, 3, 5, 7]),
    ])
    def test_shard_filtering(self, shard_idx, shard_count, expected_buckets):
        """Test shard-based bucket filtering (bucket % shard_count == shard_idx)."""
        entries = [MockEntry(b) for b in range(8)]
        filtered = [e for e in entries if e.bucket % shard_count == shard_idx]
        assert [e.bucket for e in filtered] == expected_buckets

    @pytest.mark.parametrize("num_buckets,num_consumers", [(8, 4), (7, 3), (10, 3), (5, 5)])
    def test_shards_cover_all_buckets(self, num_buckets, num_consumers):
        """Test that all shards together cover all buckets exactly once."""
        all_buckets = set()
        for shard_idx in range(num_consumers):
            shard_buckets = {b for b in range(num_buckets) if b % num_consumers == shard_idx}
            # Modulo sharding partitions the bucket space: no overlap allowed.
            assert not (all_buckets & shard_buckets), "Shards should not overlap"
            all_buckets.update(shard_buckets)
        assert all_buckets == set(range(num_buckets)), "All buckets should be covered"


# -----------------------------------------------------------------------------
# Unit Tests: AsyncStreamingTableScan
# -----------------------------------------------------------------------------

class TestAsyncStreamingTableScanFiltering:
    """Test AsyncStreamingTableScan._filter_entries_for_shard()."""

    @pytest.fixture
    def mock_scan_table(self):
        """Create mock table for AsyncStreamingTableScan."""
        table = MagicMock()
        table.options.changelog_producer.return_value = MagicMock()
        table.file_io = MagicMock()
        table.table_path = "/tmp/test"
        return table

    def test_filter_with_bucket_filter(self, mock_scan_table):
        """Test _filter_entries_for_shard with custom bucket filter."""
        # Local import keeps the heavier scan module out of module import time.
        from pypaimon.read.streaming_table_scan import AsyncStreamingTableScan

        scan = AsyncStreamingTableScan(
            table=mock_scan_table,
            bucket_filter=lambda b: b % 2 == 0
        )
        entries = [MockEntry(b) for b in range(8)]
        filtered = scan._filter_entries_for_shard(entries)
        assert [e.bucket for e in filtered] == [0, 2, 4, 6]

    def test_filter_no_filter_returns_all(self, mock_scan_table):
        """Test _filter_entries_for_shard with no filter returns all entries."""
        from pypaimon.read.streaming_table_scan import AsyncStreamingTableScan

        scan = AsyncStreamingTableScan(table=mock_scan_table)
        entries = [MockEntry(b) for b in range(8)]
        filtered = scan._filter_entries_for_shard(entries)
        assert [e.bucket for e in filtered] == list(range(8))


if __name__ == '__main__':
    pytest.main([__file__, '-v'])
"""
Tests for AsyncStreamingTableScan.
TDD: These tests are written first, before the implementation.
"""

import asyncio
import unittest
from unittest.mock import Mock, patch

from pypaimon.common.options.core_options import ChangelogProducer
from pypaimon.read.plan import Plan
from pypaimon.read.streaming_table_scan import AsyncStreamingTableScan
from pypaimon.snapshot.snapshot import Snapshot


def _create_mock_snapshot(snapshot_id: int, commit_kind: str = "APPEND"):
    """Helper to create a mock snapshot with the attributes the scan reads."""
    snapshot = Mock(spec=Snapshot)
    snapshot.id = snapshot_id
    snapshot.commit_kind = commit_kind
    snapshot.time_millis = 1000000 + snapshot_id
    snapshot.base_manifest_list = f"manifest-list-{snapshot_id}"
    snapshot.delta_manifest_list = f"delta-manifest-list-{snapshot_id}"
    return snapshot


def _create_mock_table(latest_snapshot_id: int = 5):
    """Helper to create a mock table.

    Returns a (table, latest_snapshot_id) pair; the ID is echoed back so
    tests can unpack it alongside the table.
    """
    table = Mock()
    table.table_path = "/tmp/test_table"
    table.is_primary_key_table = False
    table.options = Mock()
    table.options.source_split_target_size.return_value = 128 * 1024 * 1024
    table.options.source_split_open_file_cost.return_value = 4 * 1024 * 1024
    table.options.scan_manifest_parallelism.return_value = 8
    table.options.bucket.return_value = 1
    table.options.data_evolution_enabled.return_value = False
    table.options.deletion_vectors_enabled.return_value = False
    table.options.changelog_producer.return_value = ChangelogProducer.NONE
    table.field_names = ['col1', 'col2']
    table.trimmed_primary_keys = []
    table.partition_keys = []
    table.file_io = Mock()
    table.table_schema = Mock()
    table.table_schema.id = 0
    table.table_schema.fields = []
    table.schema_manager = Mock()
    table.schema_manager.get_schema.return_value = table.table_schema

    return table, latest_snapshot_id


class AsyncStreamingTableScanTest(unittest.TestCase):
    """Tests for AsyncStreamingTableScan async streaming functionality."""

    # NOTE: @patch decorators are applied bottom-up, so the FileScanner
    # patch arrives as the first mock argument in each test below.

    @patch('pypaimon.read.streaming_table_scan.SnapshotManager')
    @patch('pypaimon.read.streaming_table_scan.ManifestListManager')
    @patch('pypaimon.read.streaming_table_scan.FileScanner')
    def test_initial_scan_sets_next_snapshot_id(
            self,
            MockStartingScanner,
            MockManifestListManager,
            MockSnapshotManager):
        """After initial scan, next_snapshot_id should be latest + 1."""
        table, latest_id = _create_mock_table(latest_snapshot_id=5)

        # Setup mocks
        mock_snapshot_manager = MockSnapshotManager.return_value
        mock_snapshot_manager.get_latest_snapshot.return_value = _create_mock_snapshot(5)
        mock_snapshot_manager.get_snapshot_by_id.return_value = None

        mock_starting_scanner = MockStartingScanner.return_value
        mock_starting_scanner.scan.return_value = Plan([])

        scan = AsyncStreamingTableScan(table)

        # Run first iteration
        async def get_first_plan():
            async for plan in scan.stream():
                return plan

        asyncio.run(get_first_plan())

        self.assertEqual(scan.next_snapshot_id, 6)

    @patch('pypaimon.read.streaming_table_scan.SnapshotManager')
    @patch('pypaimon.read.streaming_table_scan.ManifestListManager')
    @patch('pypaimon.read.streaming_table_scan.FileScanner')
    def test_initial_scan_yields_plan(self, MockStartingScanner, MockManifestListManager, MockSnapshotManager):
        """Initial scan should yield a Plan with splits."""
        table, _ = _create_mock_table(latest_snapshot_id=5)

        # Setup mocks
        mock_snapshot_manager = MockSnapshotManager.return_value
        mock_snapshot_manager.get_latest_snapshot.return_value = _create_mock_snapshot(5)
        mock_snapshot_manager.get_snapshot_by_id.return_value = None

        mock_starting_scanner = MockStartingScanner.return_value
        mock_starting_scanner.scan.return_value = Plan([])

        scan = AsyncStreamingTableScan(table)

        async def get_first_plan():
            async for plan in scan.stream():
                return plan

        plan = asyncio.run(get_first_plan())

        self.assertIsInstance(plan, Plan)

    @patch('pypaimon.read.streaming_table_scan.SnapshotManager')
    @patch('pypaimon.read.streaming_table_scan.ManifestListManager')
    @patch('pypaimon.read.streaming_table_scan.FileScanner')
    def test_stream_skips_non_append_commits(self, MockStartingScanner, MockManifestListManager, MockSnapshotManager):
        """Stream should skip COMPACT/OVERWRITE commits."""
        table, _ = _create_mock_table(latest_snapshot_id=7)

        # Setup mocks
        mock_snapshot_manager = MockSnapshotManager.return_value

        # Snapshots: 6 (COMPACT - skip), 7 (APPEND - scan)
        snapshot_7 = _create_mock_snapshot(7, "APPEND")

        # find_next_scannable returns (snapshot, next_id, skipped_count)
        # Start at 6, skip 1 (COMPACT), return snapshot 7, next_id=8
        mock_snapshot_manager.find_next_scannable.return_value = (snapshot_7, 8, 1)
        mock_snapshot_manager.get_cache_stats.return_value = {"cache_hits": 0, "cache_misses": 0, "cache_size": 0}
        # Mock get_latest_snapshot for diff catch-up check (gap=1, below threshold)
        mock_snapshot_manager.get_latest_snapshot.return_value = snapshot_7

        mock_manifest_list_manager = MockManifestListManager.return_value
        mock_manifest_list_manager.read_delta.return_value = []

        mock_starting_scanner = MockStartingScanner.return_value
        mock_starting_scanner.read_manifest_entries.return_value = []

        scan = AsyncStreamingTableScan(table)
        scan.next_snapshot_id = 6  # Start from snapshot 6

        async def get_plans():
            plans = []
            count = 0
            async for plan in scan.stream():
                plans.append(plan)
                count += 1
                if count >= 1:  # Get one plan (snapshot 7)
                    break
            return plans

        asyncio.run(get_plans())

        # Should have skipped snapshot 6 (COMPACT) and scanned 7 (APPEND)
        self.assertEqual(scan.next_snapshot_id, 8)
        # Verify lookahead skipped 1 snapshot
        self.assertEqual(scan._lookahead_skips, 1)

    @patch('pypaimon.read.streaming_table_scan.SnapshotManager')
    @patch('pypaimon.read.streaming_table_scan.ManifestListManager')
    @patch('pypaimon.read.streaming_table_scan.FileScanner')
    def test_stream_sync_yields_plans(self, MockStartingScanner, MockManifestListManager, MockSnapshotManager):
        """stream_sync() should provide a synchronous iterator."""
        table, _ = _create_mock_table(latest_snapshot_id=5)

        # Setup mocks
        mock_snapshot_manager = MockSnapshotManager.return_value
        mock_snapshot_manager.get_latest_snapshot.return_value = _create_mock_snapshot(5)
        mock_snapshot_manager.get_snapshot_by_id.return_value = None

        mock_starting_scanner = MockStartingScanner.return_value
        mock_starting_scanner.scan.return_value = Plan([])

        scan = AsyncStreamingTableScan(table)

        # Get first plan synchronously
        for plan in scan.stream_sync():
            self.assertIsInstance(plan, Plan)
            break  # Just get one

    @patch('pypaimon.read.streaming_table_scan.SnapshotManager')
    @patch('pypaimon.read.streaming_table_scan.ManifestListManager')
    def test_poll_interval_configurable(self, MockManifestListManager, MockSnapshotManager):
        """Poll interval should be configurable (milliseconds in, seconds out)."""
        table, _ = _create_mock_table()

        scan = AsyncStreamingTableScan(table, poll_interval_ms=500)

        self.assertEqual(scan.poll_interval, 0.5)

    @patch('pypaimon.read.streaming_table_scan.SnapshotManager')
    @patch('pypaimon.read.streaming_table_scan.ManifestListManager')
    @patch('pypaimon.read.streaming_table_scan.FileScanner')
    def test_no_snapshot_waits_and_polls(self, MockStartingScanner, MockManifestListManager, MockSnapshotManager):
        """When no new snapshot exists, should wait and poll again."""
        table, _ = _create_mock_table(latest_snapshot_id=5)

        mock_snapshot_manager = MockSnapshotManager.return_value
        mock_snapshot_manager.get_cache_stats.return_value = {"cache_hits": 0, "cache_misses": 0, "cache_size": 0}

        # No snapshot 6 exists yet - find_next_scannable returns (None, 6, 0) first,
        # then on subsequent calls returns a snapshot
        call_count = [0]
        snapshot_6 = _create_mock_snapshot(6, "APPEND")

        def find_next_scannable(start_id, should_scan, lookahead_size=10, max_workers=4):
            call_count[0] += 1
            # After 3 calls, snapshot 6 appears
            if call_count[0] > 3:
                return (snapshot_6, 7, 0)
            # No snapshot yet - return (None, start_id, 0) to indicate no snapshot exists
            return (None, start_id, 0)

        mock_snapshot_manager.find_next_scannable.side_effect = find_next_scannable
        mock_snapshot_manager.get_latest_snapshot.return_value = None

        mock_manifest_list_manager = MockManifestListManager.return_value
        mock_manifest_list_manager.read_delta.return_value = []

        mock_starting_scanner = MockStartingScanner.return_value
        mock_starting_scanner.read_manifest_entries.return_value = []

        scan = AsyncStreamingTableScan(table, poll_interval_ms=10)
        scan.next_snapshot_id = 6

        async def get_plan_with_timeout():
            async for plan in scan.stream():
                return plan

        # Should eventually get a plan after polling
        plan = asyncio.run(asyncio.wait_for(get_plan_with_timeout(), timeout=1.0))
        self.assertIsInstance(plan, Plan)
    @patch('pypaimon.read.streaming_table_scan.SnapshotManager')
    @patch('pypaimon.read.streaming_table_scan.ManifestListManager')
    @patch('pypaimon.read.streaming_table_scan.ManifestFileManager')
    def test_prefetch_enabled_by_default(self, MockManifestFileManager, MockManifestListManager, MockSnapshotManager):
        """Prefetching should be enabled by default."""
        table, _ = _create_mock_table()
        scan = AsyncStreamingTableScan(table)
        self.assertTrue(scan._prefetch_enabled)

    @patch('pypaimon.read.streaming_table_scan.SnapshotManager')
    @patch('pypaimon.read.streaming_table_scan.ManifestListManager')
    @patch('pypaimon.read.streaming_table_scan.ManifestFileManager')
    def test_prefetch_can_be_disabled(self, MockManifestFileManager, MockManifestListManager, MockSnapshotManager):
        """Prefetching can be disabled via constructor parameter."""
        table, _ = _create_mock_table()
        scan = AsyncStreamingTableScan(table, prefetch_enabled=False)
        self.assertFalse(scan._prefetch_enabled)

    @patch('pypaimon.read.streaming_table_scan.SnapshotManager')
    @patch('pypaimon.read.streaming_table_scan.ManifestListManager')
    @patch('pypaimon.read.streaming_table_scan.ManifestFileManager')
    def test_prefetch_starts_after_yielding_plan(
            self,
            MockManifestFileManager,
            MockManifestListManager,
            MockSnapshotManager):
        """After yielding a plan, prefetch for next snapshot should start."""
        table, _ = _create_mock_table(latest_snapshot_id=5)

        mock_snapshot_manager = MockSnapshotManager.return_value
        mock_manifest_list_manager = MockManifestListManager.return_value
        mock_manifest_file_manager = MockManifestFileManager.return_value
        mock_snapshot_manager.get_cache_stats.return_value = {"cache_hits": 0, "cache_misses": 0, "cache_size": 0}

        # Snapshots 5, 6, 7 exist - find_next_scannable returns each one
        snapshot_5 = _create_mock_snapshot(5, "APPEND")
        snapshot_6 = _create_mock_snapshot(6, "APPEND")
        snapshot_7 = _create_mock_snapshot(7, "APPEND")

        call_count = [0]

        def find_next_scannable(start_id, should_scan, lookahead_size=10, max_workers=4):
            call_count[0] += 1
            if start_id == 5:
                return (snapshot_5, 6, 0)
            elif start_id == 6:
                return (snapshot_6, 7, 0)
            elif start_id == 7:
                return (snapshot_7, 8, 0)
            return (None, start_id, 0)

        mock_snapshot_manager.find_next_scannable.side_effect = find_next_scannable
        mock_snapshot_manager.get_latest_snapshot.return_value = None

        mock_manifest_list_manager.read_delta.return_value = []
        mock_manifest_file_manager.read_entries_parallel.return_value = []

        scan = AsyncStreamingTableScan(table, poll_interval_ms=10)
        scan.next_snapshot_id = 5

        async def get_two_plans():
            plans = []
            async for plan in scan.stream():
                plans.append(plan)
                # After first plan, prefetch task should exist
                if len(plans) == 1:
                    # Give prefetch a moment to start
                    await asyncio.sleep(0.01)
                    self.assertIsNotNone(scan._prefetch_future)
                if len(plans) >= 2:
                    break
            return plans

        plans = asyncio.run(get_two_plans())
        self.assertEqual(len(plans), 2)

    @patch('pypaimon.read.streaming_table_scan.SnapshotManager')
    @patch('pypaimon.read.streaming_table_scan.ManifestListManager')
    @patch('pypaimon.read.streaming_table_scan.ManifestFileManager')
    def test_prefetch_returns_same_data_as_sequential(
            self,
            MockManifestFileManager,
            MockManifestListManager,
            MockSnapshotManager):
        """Prefetched plans should contain the same data as non-prefetched."""
        table, _ = _create_mock_table(latest_snapshot_id=5)

        mock_snapshot_manager = MockSnapshotManager.return_value
        mock_manifest_list_manager = MockManifestListManager.return_value
        mock_manifest_file_manager = MockManifestFileManager.return_value
        mock_snapshot_manager.get_cache_stats.return_value = {"cache_hits": 0, "cache_misses": 0, "cache_size": 0}

        snapshot_5 = _create_mock_snapshot(5, "APPEND")
        snapshot_6 = _create_mock_snapshot(6, "APPEND")

        def find_next_scannable(start_id, should_scan, lookahead_size=10, max_workers=4):
            if start_id == 5:
                return (snapshot_5, 6, 0)
            elif start_id == 6:
                return (snapshot_6, 7, 0)
            return (None, start_id, 0)

        mock_snapshot_manager.find_next_scannable.side_effect = find_next_scannable
        mock_snapshot_manager.get_latest_snapshot.return_value = None

        mock_manifest_list_manager.read_delta.return_value = []
        mock_manifest_file_manager.read_entries_parallel.return_value = []

        # Test with prefetch enabled
        scan_prefetch = AsyncStreamingTableScan(table, poll_interval_ms=10, prefetch_enabled=True)
        scan_prefetch.next_snapshot_id = 5

        # Test with prefetch disabled
        scan_sequential = AsyncStreamingTableScan(table, poll_interval_ms=10, prefetch_enabled=False)
        scan_sequential.next_snapshot_id = 5

        async def get_plans(scan, count):
            plans = []
            async for plan in scan.stream():
                plans.append(plan)
                if len(plans) >= count:
                    break
            return plans

        plans_prefetch = asyncio.run(get_plans(scan_prefetch, 2))
        plans_sequential = asyncio.run(get_plans(scan_sequential, 2))

        # Both should get the same number of plans
        self.assertEqual(len(plans_prefetch), len(plans_sequential))

    @patch('pypaimon.read.streaming_table_scan.SnapshotManager')
    @patch('pypaimon.read.streaming_table_scan.ManifestListManager')
    @patch('pypaimon.read.streaming_table_scan.ManifestFileManager')
    def test_prefetch_handles_no_next_snapshot(
            self,
            MockManifestFileManager,
            MockManifestListManager,
            MockSnapshotManager):
        """When no next snapshot exists, prefetch should return None gracefully."""
        table, _ = _create_mock_table(latest_snapshot_id=5)

        mock_snapshot_manager = MockSnapshotManager.return_value
        mock_manifest_list_manager = MockManifestListManager.return_value
        mock_manifest_file_manager = MockManifestFileManager.return_value
        mock_snapshot_manager.get_cache_stats.return_value = {"cache_hits": 0, "cache_misses": 0, "cache_size": 0}

        snapshot_5 = _create_mock_snapshot(5, "APPEND")

        # Only snapshot 5 exists
        def find_next_scannable(start_id, should_scan, lookahead_size=10, max_workers=4):
            if start_id == 5:
                return (snapshot_5, 6, 0)
            # No more snapshots after 5
            return (None, start_id, 0)

        mock_snapshot_manager.find_next_scannable.side_effect = find_next_scannable
        mock_snapshot_manager.get_latest_snapshot.return_value = None

        mock_manifest_list_manager.read_delta.return_value = []
        mock_manifest_file_manager.read_entries_parallel.return_value = []

        scan = AsyncStreamingTableScan(table, poll_interval_ms=10)
        scan.next_snapshot_id = 5

        async def get_one_plan():
            async for plan in scan.stream():
                # After getting plan for snapshot 5, prefetch for 6 should start
                # but return None since snapshot 6 doesn't exist
                await asyncio.sleep(0.05)  # Let prefetch complete
                # Prefetch task should have completed (or be None)
                return plan

        plan = asyncio.run(get_one_plan())
        self.assertIsInstance(plan, Plan)

    @patch('pypaimon.read.streaming_table_scan.SnapshotManager')
    @patch('pypaimon.read.streaming_table_scan.ManifestListManager')
    @patch('pypaimon.read.streaming_table_scan.ManifestFileManager')
    def test_prefetch_disabled_no_prefetch_future(
            self,
            MockManifestFileManager,
            MockManifestListManager,
            MockSnapshotManager):
        """With prefetch disabled, no prefetch future should be created."""
        table, _ = _create_mock_table(latest_snapshot_id=5)

        mock_snapshot_manager = MockSnapshotManager.return_value
        mock_manifest_list_manager = MockManifestListManager.return_value
        mock_manifest_file_manager = MockManifestFileManager.return_value
        mock_snapshot_manager.get_cache_stats.return_value = {"cache_hits": 0, "cache_misses": 0, "cache_size": 0}

        snapshot_5 = _create_mock_snapshot(5, "APPEND")

        def find_next_scannable(start_id, should_scan, lookahead_size=10, max_workers=4):
            if start_id == 5:
                return (snapshot_5, 6, 0)
            return (None, start_id, 0)

        mock_snapshot_manager.find_next_scannable.side_effect = find_next_scannable
        mock_snapshot_manager.get_latest_snapshot.return_value = None

        mock_manifest_list_manager.read_delta.return_value = []
        mock_manifest_file_manager.read_entries_parallel.return_value = []

        scan = AsyncStreamingTableScan(table, poll_interval_ms=10, prefetch_enabled=False)
        scan.next_snapshot_id = 5

        async def get_one_plan():
            async for plan in scan.stream():
                await asyncio.sleep(0.01)
                # With prefetch disabled, no task should exist
                self.assertIsNone(scan._prefetch_future)
                return plan

        asyncio.run(get_one_plan())
table.file_io.exists.return_value = False # No snapshots exist + + manager = SnapshotManager(table) + + # Trigger some cache misses + manager.get_snapshot_by_id(1) + manager.get_snapshot_by_id(2) + + stats = manager.get_cache_stats() + + self.assertEqual(stats["cache_misses"], 2) + self.assertEqual(stats["cache_hits"], 0) + self.assertEqual(stats["cache_size"], 0) # Nothing cached since no snapshots exist + + def test_find_next_scannable_returns_first_matching(self): + """find_next_scannable should return the first snapshot that passes should_scan.""" + from pypaimon.snapshot.snapshot_manager import SnapshotManager + + table = Mock() + table.table_path = "/tmp/test_table" + table.file_io = Mock() + table.file_io.exists_batch.return_value = { + "/tmp/test_table/snapshot/snapshot-5": True, + "/tmp/test_table/snapshot/snapshot-6": True, + "/tmp/test_table/snapshot/snapshot-7": True, + } + + # Create mock snapshots with different commit kinds + snapshots = { + 5: _create_mock_snapshot(5, "COMPACT"), + 6: _create_mock_snapshot(6, "COMPACT"), + 7: _create_mock_snapshot(7, "APPEND"), + } + + manager = SnapshotManager(table) + + # Mock get_snapshot_by_id to return our test snapshots + def mock_get_snapshot(sid): + manager._cache_misses += 1 + return snapshots.get(sid) + + manager.get_snapshot_by_id = mock_get_snapshot + + # should_scan only accepts APPEND commits + def should_scan(snapshot): + return snapshot.commit_kind == "APPEND" + + result, next_id, skipped_count = manager.find_next_scannable(5, should_scan, lookahead_size=5) + + self.assertEqual(result.id, 7) # First APPEND snapshot + self.assertEqual(next_id, 8) # Next ID to check + self.assertEqual(skipped_count, 2) # Skipped snapshots 5 and 6 + + def test_find_next_scannable_returns_none_when_no_snapshot_exists(self): + """find_next_scannable should return None when no snapshot exists at start_id.""" + from pypaimon.snapshot.snapshot_manager import SnapshotManager + + table = Mock() + table.table_path = 
"/tmp/test_table" + table.file_io = Mock() + # All paths return False (no files exist) + table.file_io.exists_batch.return_value = {} + + manager = SnapshotManager(table) + + def should_scan(snapshot): + return True + + result, next_id, skipped_count = manager.find_next_scannable(5, should_scan, lookahead_size=5) + + self.assertIsNone(result) + self.assertEqual(next_id, 5) # Still at start_id + self.assertEqual(skipped_count, 0) + + def test_find_next_scannable_continues_when_all_skipped(self): + """When all lookahead snapshots are skipped, next_id should be start+lookahead.""" + from pypaimon.snapshot.snapshot_manager import SnapshotManager + + table = Mock() + table.table_path = "/tmp/test_table" + table.file_io = Mock() + + # All 3 snapshots exist but are COMPACT (will be skipped) + table.file_io.exists_batch.return_value = { + "/tmp/test_table/snapshot/snapshot-5": True, + "/tmp/test_table/snapshot/snapshot-6": True, + "/tmp/test_table/snapshot/snapshot-7": True, + } + + snapshots = { + 5: _create_mock_snapshot(5, "COMPACT"), + 6: _create_mock_snapshot(6, "COMPACT"), + 7: _create_mock_snapshot(7, "COMPACT"), + } + + manager = SnapshotManager(table) + + def mock_get_snapshot(sid): + manager._cache_misses += 1 + return snapshots.get(sid) + + manager.get_snapshot_by_id = mock_get_snapshot + + def should_scan(snapshot): + return snapshot.commit_kind == "APPEND" + + result, next_id, skipped_count = manager.find_next_scannable(5, should_scan, lookahead_size=3) + + self.assertIsNone(result) # No APPEND found + self.assertEqual(next_id, 8) # 5 + 3 = 8, continue from here + self.assertEqual(skipped_count, 3) # All 3 were skipped + + +class StreamingCatchUpDiffTest(unittest.TestCase): + """Tests for diff-based catch-up optimization in AsyncStreamingTableScan.""" + + @patch('pypaimon.read.streaming_table_scan.IncrementalDiffScanner') + @patch('pypaimon.read.streaming_table_scan.SnapshotManager') + @patch('pypaimon.read.streaming_table_scan.ManifestListManager') + 
@patch('pypaimon.read.streaming_table_scan.ManifestFileManager') + def test_stream_triggers_diff_catch_up_for_large_gap( + self, MockManifestFileManager, MockManifestListManager, + MockSnapshotManager, MockDiffScanner + ): + """ + When starting with a large gap, stream() should use diff scanner. + + This tests the full flow: + 1. CLI calls restore({"next_snapshot_id": 5}) for --from snapshot:5 + 2. stream() detects large gap (5 to 100, gap=95) + 3. Diff scanner is triggered + 4. _diff_catch_up_used flag is set + """ + table, _ = _create_mock_table(latest_snapshot_id=100) + + mock_snapshot_manager = MockSnapshotManager.return_value + mock_diff_scanner = MockDiffScanner.return_value + + # Setup: latest is 100, start is 5 (gap=95) + mock_snapshot_manager.get_latest_snapshot.return_value = _create_mock_snapshot(100) + mock_snapshot_manager.get_snapshot_by_id.return_value = _create_mock_snapshot(4) # start-1 + + # Diff scanner returns a plan with some splits + mock_split = Mock() + mock_plan = Plan([mock_split]) + mock_diff_scanner.scan.return_value = mock_plan + + scan = AsyncStreamingTableScan(table, poll_interval_ms=10, prefetch_enabled=False) + + # Simulate --from snapshot:5: restore to snapshot 5 + scan.next_snapshot_id = 5 + + # Verify diff catch-up should be used (gap=95 > threshold=10) + self.assertTrue(scan._should_use_diff_catch_up()) + + # Run stream() and get first plan + async def get_first_plan(): + async for plan in scan.stream(): + return plan + + loop = asyncio.new_event_loop() + try: + loop.run_until_complete(get_first_plan()) + finally: + loop.close() + + # Verify diff scanner was used + self.assertTrue(scan._diff_catch_up_used) + MockDiffScanner.assert_called_once_with(table) + mock_diff_scanner.scan.assert_called_once() + + # Verify next_snapshot_id was updated to latest + 1 + self.assertEqual(scan.next_snapshot_id, 101) + + +if __name__ == '__main__': + unittest.main()