Squash consecutive column assignments without duplicating overwritten columns.

When `_simplify_down` merges an inner Assign into the outer one, key/value
pairs of the inner Assign whose key is assigned again by the outer Assign are
now dropped (new helper `_remove_common_columns`), and keys already present in
a Projection's column list are no longer appended a second time.  A regression
test covers assigning to the same column several times in a row.

diff --git a/dask_expr/_expr.py b/dask_expr/_expr.py
index 3a42d3d96..c32cea1d0 100644
--- a/dask_expr/_expr.py
+++ b/dask_expr/_expr.py
@@ -1893,21 +1893,29 @@ def _tree_repr_argument_construction(self, i, op, header):
     def _node_label_args(self):
         return self.operands
 
+    def _remove_common_columns(self, other):
+        if set(self.keys) & set(other.keys):
+            keys = set(self.keys)
+            operands = [[k, v] for k, v in zip(other.keys, other.vals) if k not in keys]
+            return [other.frame] + list(flatten(operands)) + self.operands[1:]
+        else:
+            return other.operands + self.operands[1:]
+
     def _simplify_down(self):
         if isinstance(self.frame, Assign):
             if self._check_for_previously_created_column(self.frame):
                 # don't squash if we are using a column that was previously created
                 return
-            return Assign(*self.frame.operands, *self.operands[1:])
+            return Assign(*self._remove_common_columns(self.frame))
         elif isinstance(self.frame, Projection) and isinstance(
             self.frame.frame, Assign
         ):
             if self._check_for_previously_created_column(self.frame.frame):
                 return
             new_columns = self.frame.operands[1].copy()
-            new_columns.extend(self.keys)
+            new_columns.extend([k for k in self.keys if k not in new_columns])
             return Projection(
-                Assign(*self.frame.frame.operands, *self.operands[1:]), new_columns
+                Assign(*self._remove_common_columns(self.frame.frame)), new_columns
             )
 
     def _check_for_previously_created_column(self, child):
diff --git a/dask_expr/tests/test_collection.py b/dask_expr/tests/test_collection.py
index f2bee98f7..50bbc041d 100644
--- a/dask_expr/tests/test_collection.py
+++ b/dask_expr/tests/test_collection.py
@@ -2689,6 +2689,33 @@ def apply_func(x):
     assert result.expr._depth() == 13.0  # this grew exponentially previously
 
 
+def test_assign_overwriting_column(df, pdf):
+    pdf = pd.DataFrame(
+        {"partner": ["A", np.nan, "C", "A", np.nan, "C", "A", np.nan, "C"]},
+        dtype="string[pyarrow]",
+    )
+
+    df = from_pandas(pdf, npartitions=2)
+
+    df["partner1"] = ""
+    df["partner1"] = "google"
+
+    df["partner2"] = ""
+    df["partner2"] = np.nan
+    df["partner2"] = df["partner2"].mask(cond=(df["partner"] == "A"), other="Blackhawk")
+
+    pdf["partner1"] = ""
+    pdf["partner1"] = "google"
+
+    pdf["partner2"] = ""
+    pdf["partner2"] = np.nan
+    pdf["partner2"] = pdf["partner2"].mask(
+        cond=(pdf["partner"] == "A"), other="Blackhawk"
+    )
+    df.compute()
+    assert_eq(df, pdf, check_dtype=False)
+
+
 def test_dropna_merge(df, pdf):
     dropped_na = df.dropna(subset=["x"])
     result = dropped_na.merge(dropped_na, on="x")