From a525a795f835e590ede265a49c720d33b2d8c833 Mon Sep 17 00:00:00 2001 From: Devin Petersohn Date: Fri, 24 Apr 2026 09:40:58 -0700 Subject: [PATCH 1/2] [SPARK-56322][CONNECT][PYTHON] Fix TypeError when self-joining observed DataFrames --- python/pyspark/sql/connect/plan.py | 12 +- .../sql/tests/connect/test_connect_plan.py | 95 ++++++++++++++- python/pyspark/sql/tests/test_observation.py | 114 ++++++++++++++++++ 3 files changed, 214 insertions(+), 7 deletions(-) diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py index 2a71a94bde21e..3dac4fc47ee70 100644 --- a/python/pyspark/sql/connect/plan.py +++ b/python/pyspark/sql/connect/plan.py @@ -1120,7 +1120,7 @@ def plan(self, session: "SparkConnectClient") -> proto.Relation: @property def observations(self) -> Dict[str, "Observation"]: - return dict(**super().observations, **self.right.observations) + return {**super().observations, **self.right.observations} def print(self, indent: int = 0) -> str: i = " " * indent @@ -1213,7 +1213,7 @@ def plan(self, session: "SparkConnectClient") -> proto.Relation: @property def observations(self) -> Dict[str, "Observation"]: - return dict(**super().observations, **self.right.observations) + return {**super().observations, **self.right.observations} def print(self, indent: int = 0) -> str: assert self.left is not None @@ -1288,7 +1288,7 @@ def plan(self, session: "SparkConnectClient") -> proto.Relation: @property def observations(self) -> Dict[str, "Observation"]: - return dict(**super().observations, **self.right.observations) + return {**super().observations, **self.right.observations} def print(self, indent: int = 0) -> str: i = " " * indent @@ -1354,10 +1354,10 @@ def plan(self, session: "SparkConnectClient") -> proto.Relation: @property def observations(self) -> Dict[str, "Observation"]: - return dict( + return { **super().observations, **(self.other.observations if self.other is not None else {}), - ) + } def print(self, indent: int = 0) -> str: assert self._child is not None @@ -1664,7 +1664,7 @@ def observations(self) -> Dict[str, "Observation"]: observations = {str(self._observation._name): self._observation} else: observations = {} - return dict(**super().observations, **observations) + return {**super().observations, **observations} class NAFill(LogicalPlan): diff --git a/python/pyspark/sql/tests/connect/test_connect_plan.py b/python/pyspark/sql/tests/connect/test_connect_plan.py index 10a21b979bc4f..a0ea44c8e531d 100644 --- a/python/pyspark/sql/tests/connect/test_connect_plan.py +++ b/python/pyspark/sql/tests/connect/test_connect_plan.py @@ -27,11 +27,21 @@ ) from pyspark.errors import PySparkValueError +from unittest.mock import MagicMock + if should_test_connect: import pyspark.sql.connect.proto as proto from pyspark.sql.connect.column import Column from pyspark.sql.connect.dataframe import DataFrame - from pyspark.sql.connect.plan import WriteOperation, Read + from pyspark.sql.connect.plan import ( + WriteOperation, + Read, + Join, + SetOperation, + CollectMetrics, + LogicalPlan, + ) + from pyspark.sql.connect.observation import Observation from pyspark.sql.connect.readwriter import DataFrameReader from pyspark.sql.connect.expressions import LiteralExpression from pyspark.sql.connect.functions import col, lit, max, min, sum @@ -1131,6 +1141,89 @@ def test_literal_to_any_conversion(self): LiteralExpression._to_value(proto_lit, DoubleType) +if should_test_connect: + + class _StubPlan(LogicalPlan): + """Minimal LogicalPlan that returns a fixed observations dict.""" + + def __init__(self, observations=None): + super().__init__(None) + self._obs = observations or {} + + @property + def observations(self): + return self._obs + + def plan(self, session): + raise NotImplementedError + + def print(self, indent=0): + return "" + + +@unittest.skipIf(not should_test_connect, connect_requirement_message) +class TestObservationMerging(unittest.TestCase): + """Verify that observations are deduplicated when plan branches share the same key.""" + + def test_join_with_duplicate_observation_names(self): + obs = MagicMock() + obs._name = "shared" + shared = {"shared": obs} + + left = _StubPlan(observations=shared) + right = _StubPlan(observations=shared) + + join = Join.__new__(Join) + join._child = left + join.right = right + + result = join.observations + self.assertEqual(result, {"shared": obs}) + + def test_join_with_distinct_observations(self): + obs_a = MagicMock() + obs_a._name = "a" + obs_b = MagicMock() + obs_b._name = "b" + + left = _StubPlan(observations={"a": obs_a}) + right = _StubPlan(observations={"b": obs_b}) + + join = Join.__new__(Join) + join._child = left + join.right = right + + result = join.observations + self.assertEqual(result, {"a": obs_a, "b": obs_b}) + + def test_set_operation_with_duplicate_observation_names(self): + obs = MagicMock() + obs._name = "shared" + shared = {"shared": obs} + + left = _StubPlan(observations=shared) + right = _StubPlan(observations=shared) + + set_op = SetOperation.__new__(SetOperation) + set_op._child = left + set_op.other = right + + result = set_op.observations + self.assertEqual(result, {"shared": obs}) + + def test_collect_metrics_with_duplicate_observation_name(self): + obs = Observation("my_metric") + parent = _StubPlan(observations={"my_metric": obs}) + + cm = CollectMetrics.__new__(CollectMetrics) + cm._child = parent + cm._observation = obs + cm._exprs = [] + + result = cm.observations + self.assertEqual(result, {"my_metric": obs}) + + if __name__ == "__main__": from pyspark.testing import main diff --git a/python/pyspark/sql/tests/test_observation.py b/python/pyspark/sql/tests/test_observation.py index cdb33dd2fd497..440bc207a756e 100644 --- a/python/pyspark/sql/tests/test_observation.py +++ b/python/pyspark/sql/tests/test_observation.py @@ -18,6 +18,7 @@ from pyspark.sql import Row, Observation, functions as F from pyspark.sql.types import StructType, LongType from pyspark.errors import ( + AnalysisException, PySparkAssertionError, PySparkException, PySparkTypeError, @@ -86,6 +87,16 @@ def test_observe(self): messageParameters={}, ) + new_observation = Observation("metric") + with self.assertRaises(AnalysisException) as pe: + observed.observe(new_observation, 2 * F.count(F.lit(1)).alias("cnt")).collect() + + self.check_error( + exception=pe.exception, + errorClass="DUPLICATED_METRICS_NAME", + messageParameters={"metricName": "metric"}, + ) + # observation requires name (if given) to be non empty string with self.assertRaisesRegex(PySparkTypeError, "`name` should be str, got int"): Observation(123) @@ -263,6 +274,109 @@ def test_observation_errors_propagated_to_client(self): self.assertIn("test error", str(cm.exception)) + def test_observe_self_join(self): + # SPARK-56322: self-joining an observed DataFrame + obs = Observation("my_observation") + df = ( + self.spark.range(100) + .selectExpr("id", "CASE WHEN id < 10 THEN 'A' ELSE 'B' END AS group_key") + .observe(obs, F.count(F.lit(1)).alias("row_count")) + ) + + df1 = df.where("id < 20") + df2 = df.where("id % 2 == 0") + + joined = df1.alias("a").join(df2.alias("b"), on=["id"], how="inner") + result = joined.collect() + + # The join should produce rows where id < 20 AND id is even + expected_ids = sorted([i for i in range(20) if i % 2 == 0]) + actual_ids = sorted([row.id for row in result]) + self.assertEqual(actual_ids, expected_ids) + + # The observation should have been collected + self.assertEqual(obs.get, {"row_count": 100}) + + # Check the error conditions + with self.assertRaises(PySparkAssertionError) as pe: + joined.observe(obs, F.count(F.lit(1)).alias("row_count")).collect() + + self.check_error( + exception=pe.exception, + errorClass="REUSE_OBSERVATION", + messageParameters={}, + ) + + obs2 = Observation("my_observation") + with self.assertRaises(AnalysisException) as pe: + joined.observe(obs2, 2 * F.count(F.lit(1)).alias("row_count")).collect() + + self.check_error( + exception=pe.exception, + errorClass="DUPLICATED_METRICS_NAME", + messageParameters={"metricName": "my_observation"}, + ) + + def test_observe_lateral_join(self): + # SPARK-56322: lateral self-joining an observed DataFrame + obs = Observation("lateral_join_observation") + df = self.spark.range(50).observe(obs, F.count(F.lit(1)).alias("row_count")) + + joined = ( + df.alias("left") + .lateralJoin( + df.alias("right"), on=F.expr("right.id between left.id - 1 and left.id + 1") + ) + .selectExpr("left.id as left_id", "right.id as right_id") + ) + result = joined.collect() + + # Joins on row 0 should produce rows 0 and 1 + bounded_matches = sorted([r.right_id for r in result if r.left_id == 0]) + self.assertEqual(bounded_matches, [0, 1]) + + # Joins on row 25 should produce rows 24, 25, and 26 + unbounded_matches = sorted([r.right_id for r in result if r.left_id == 25]) + self.assertEqual(unbounded_matches, [24, 25, 26]) + + # The observation should have been collected + self.assertEqual(obs.get, {"row_count": 50}) + + # Check the error conditions + with self.assertRaises(PySparkAssertionError) as reused: + joined.observe(obs, F.count(F.lit(1)).alias("row_count")).collect() + + self.check_error( + exception=reused.exception, + errorClass="REUSE_OBSERVATION", + messageParameters={}, + ) + + obs2 = Observation("lateral_join_observation") + with self.assertRaises(AnalysisException) as pe: + joined.observe(obs2, F.count(2 * F.lit(1)).alias("row_count")).collect() + + self.check_error( + exception=pe.exception, + errorClass="DUPLICATED_METRICS_NAME", + messageParameters={"metricName": "lateral_join_observation"}, + ) + + def test_observe_self_join_union(self): + # SPARK-56322: union of observed DataFrames with same observation + obs = Observation("union_obs") + df = self.spark.range(50).observe(obs, F.count(F.lit(1)).alias("cnt")) + + df1 = df.where("id < 25") + df2 = df.where("id >= 25") + + unioned = df1.union(df2) + result = unioned.collect() + + actual_ids = sorted([row.id for row in result]) + self.assertEqual(actual_ids, list(range(50))) + self.assertEqual(obs.get, {"cnt": 50}) + class DataFrameObservationTests( DataFrameObservationTestsMixin, From d68b0f314611e2e26743b7429b189ac1a8233b5e Mon Sep 17 00:00:00 2001 From: Devin Petersohn Date: Fri, 24 Apr 2026 09:41:04 -0700 Subject: [PATCH 2/2] [SPARK-46160][PYTHON] Add axis parameter to DataFrame.shift in pandas API on Spark Co-Authored-By: Claude --- python/pyspark/pandas/frame.py | 71 ++++++++++++-- .../pandas/tests/frame/test_time_series.py | 93 +++++++++++++++++++ 2 files changed, 158 insertions(+), 6 deletions(-) diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py index fb86e37999eb4..a5ba0c3d6fa65 100644 --- a/python/pyspark/pandas/frame.py +++ b/python/pyspark/pandas/frame.py @@ -4790,12 +4790,17 @@ def insert( psdf = psdf[columns] self._update_internal_frame(psdf._internal) - # TODO(SPARK-46156): add frep and axis parameter - def shift(self, periods: int = 1, fill_value: Optional[Any] = None) -> "DataFrame": + # TODO(SPARK-46160): add freq parameter + def shift( + self, + periods: int = 1, + fill_value: Optional[Any] = None, + axis: Axis = 0, + ) -> "DataFrame": """ Shift DataFrame by desired number of periods. - .. note:: the current implementation of shift uses Spark's Window without + .. note:: When axis=0, the current implementation of shift uses Spark's Window without specifying partition specification. This leads to moving all data into a single partition in a single machine and could cause serious performance degradation. Avoid this method with very large datasets. @@ -4807,6 +4812,13 @@ def shift(self, periods: int = 1, fill_value: Optional[Any] = None) -> "DataFram fill_value : object, optional The scalar value to use for newly introduced missing values. The default depends on the dtype of self. For numeric data, np.nan is used. + axis : {0 or 'index', 1 or 'columns'}, default 0 + Axis along which to shift: + + * 0 or 'index': shift each column independently (down/up rows) + * 1 or 'columns': shift each row independently (across columns) + + .. versionchanged:: 4.2.0 Returns ------- @@ -4835,10 +4847,57 @@ def shift(self, periods: int = 1, fill_value: Optional[Any] = None) -> "DataFram 3 10 13 17 4 20 23 27 + Shift across columns with axis=1: + + >>> df = ps.DataFrame({'A': [1, 2, 3], 'B': [4, 5, 6], 'C': [7, 8, 9]}, + ... columns=['A', 'B', 'C']) + >>> df.shift(periods=1, axis=1).sort_index() + A B C + 0 NaN 1.0 4.0 + 1 NaN 2.0 5.0 + 2 NaN 3.0 6.0 """ - return self._apply_series_op( - lambda psser: psser._shift(periods, fill_value), should_resolve=True - ) + if not isinstance(periods, int): + raise TypeError("periods should be an int; however, got [%s]" % type(periods).__name__) + + axis = validate_axis(axis) + + if axis == 0: + return self._apply_series_op( + lambda psser: psser._shift(periods, fill_value), should_resolve=True + ) + else: + # Infer result schema from a small sample, following the same + # pattern as apply() (shortcut_limit). + limit = get_option("compute.shortcut_limit") + pdf = self.head(limit + 1)._to_internal_pandas() + pdf_shifted = pdf.shift(periods=periods, fill_value=fill_value, axis=1) + if len(pdf) <= limit: + return DataFrame(InternalFrame.from_pandas(pdf_shifted)) + + # Use the shifted sample to infer return types so that the UDF + # path produces consistent dtypes with the fast path. + psdf_shifted = DataFrame(InternalFrame.from_pandas(pdf_shifted)) + data_fields = [ + field.normalize_spark_type() for field in psdf_shifted._internal.data_fields + ] + return_schema = StructType([field.struct_field for field in data_fields]) + + column_label_strings = [ + name_like_string(label) for label in self._internal.column_labels + ] + + @pandas_udf(returnType=return_schema) # type: ignore[call-overload] + def shift_axis_1(*cols: pd.Series) -> pd.DataFrame: + pdf_row = pd.concat(cols, axis=1, keys=column_label_strings) + return pdf_row.shift(periods=periods, fill_value=fill_value, axis=1) + + shifted_struct_col = shift_axis_1(*self._internal.data_spark_columns) + new_data_columns = [ + shifted_struct_col[col_name].alias(col_name) for col_name in column_label_strings + ] + internal = self._internal.with_new_columns(new_data_columns, data_fields=data_fields) + return DataFrame(internal) # TODO(SPARK-46161): axis should support 1 or 'columns' either at this moment def diff(self, periods: int = 1, axis: Axis = 0) -> "DataFrame": diff --git a/python/pyspark/pandas/tests/frame/test_time_series.py b/python/pyspark/pandas/tests/frame/test_time_series.py index 9fe4b39073b93..73e6638b7bdcd 100644 --- a/python/pyspark/pandas/tests/frame/test_time_series.py +++ b/python/pyspark/pandas/tests/frame/test_time_series.py @@ -60,6 +60,99 @@ def test_shift(self): self.assert_eq(pdf.shift().shift(-1), psdf.shift().shift(-1)) self.assert_eq(pdf.shift(0), psdf.shift(0)) + def test_shift_axis(self): + # SPARK-46160: shift with axis parameter + pdf = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6], "C": [7, 8, 9]}) + psdf = ps.from_pandas(pdf) + + # Test axis=0 (explicit, should match default behavior) + self.assert_eq(pdf.shift(axis=0).sort_index(), psdf.shift(axis=0).sort_index()) + + # Test axis=1 (shift across columns) + self.assert_eq(pdf.shift(axis=1).sort_index(), psdf.shift(axis=1).sort_index()) + + # Test axis='index' and axis='columns' + self.assert_eq(pdf.shift(axis="index").sort_index(), psdf.shift(axis="index").sort_index()) + self.assert_eq( + pdf.shift(axis="columns").sort_index(), psdf.shift(axis="columns").sort_index() + ) + + # Test various periods with axis=1 + for periods in [1, -1, 2, -2, 0]: + self.assert_eq( + pdf.shift(periods=periods, axis=1).sort_index(), + psdf.shift(periods=periods, axis=1).sort_index(), + ) + + # Test fill_value with axis=1 + self.assert_eq( + pdf.shift(periods=1, fill_value=0, axis=1).sort_index(), + psdf.shift(periods=1, fill_value=0, axis=1).sort_index(), + ) + + # Test with single column DataFrame + pdf_single = pd.DataFrame({"A": [1, 2, 3]}) + psdf_single = ps.from_pandas(pdf_single) + self.assert_eq( + pdf_single.shift(axis=1).sort_index(), + psdf_single.shift(axis=1).sort_index(), + ) + + # Test with NaN values + pdf_nan = pd.DataFrame({"A": [1, np.nan, 3], "B": [4, 3, np.nan]}) + psdf_nan = ps.from_pandas(pdf_nan) + self.assert_eq( + pdf_nan.shift(axis=1).sort_index(), + psdf_nan.shift(axis=1).sort_index(), + ) + + # Test with multi-index columns + columns = pd.MultiIndex.from_tuples([("x", "A"), ("x", "B"), ("y", "C")]) + pdf.columns = columns + psdf.columns = columns + self.assert_eq(pdf.shift(axis=1).sort_index(), psdf.shift(axis=1).sort_index()) + + # Test with large dataset to ensure UDF path is used (>1000 rows) + rng = np.random.RandomState(42) + pdf_large = pd.DataFrame({"A": rng.rand(1500), "B": rng.rand(1500), "C": rng.rand(1500)}) + psdf_large = ps.from_pandas(pdf_large) + self.assert_eq( + pdf_large.shift(axis=1).sort_index(), + psdf_large.shift(axis=1).sort_index(), + ) + + # Test fill_value on UDF path (large dataset) + self.assert_eq( + pdf_large.shift(periods=1, fill_value=0, axis=1).sort_index(), + psdf_large.shift(periods=1, fill_value=0, axis=1).sort_index(), + ) + + # Test periods larger than number of columns (should produce all NaN) + self.assert_eq( + pdf.shift(periods=5, axis=1).sort_index(), + psdf.shift(periods=5, axis=1).sort_index(), + ) + + # Test with mixed numeric types (int + float) + pdf_mixed = pd.DataFrame({"A": [1, 2, 3], "B": [4.0, 5.0, 6.0], "C": [7, 8, 9]}) + psdf_mixed = ps.from_pandas(pdf_mixed) + self.assert_eq( + pdf_mixed.shift(axis=1).sort_index(), + psdf_mixed.shift(axis=1).sort_index(), + ) + + # Test with empty DataFrame + pdf_empty = pd.DataFrame({"A": pd.Series([], dtype="float64")}) + psdf_empty = ps.from_pandas(pdf_empty) + self.assert_eq( + pdf_empty.shift(axis=1).sort_index(), + psdf_empty.shift(axis=1).sort_index(), + ) + + # Test invalid axis value + with self.assertRaisesRegex(ValueError, "No axis named"): + psdf.shift(axis=2) + def test_first_valid_index(self): pdf = pd.DataFrame( {"a": [None, 2, 3, 2], "b": [None, 2.0, 3.0, 1.0], "c": [None, 200, 400, 200]},