diff --git a/dask_expr/_core.py b/dask_expr/_core.py
index 3aac656c4..f70cae81b 100644
--- a/dask_expr/_core.py
+++ b/dask_expr/_core.py
@@ -53,6 +53,12 @@ def _tune_down(self):
     def _tune_up(self, parent):
         return None

+    def _pipe_down(self):
+        return None
+
+    def _pipe_up(self, parent):
+        return None
+
     def _cull_down(self):
         return None

@@ -342,6 +348,7 @@ def simplify(self) -> Expr:
         while True:
             dependents = collect_dependents(expr)
             new = expr.simplify_once(dependents=dependents, simplified={})
+            new = new.rewrite("pipe")
             if new._name == expr._name:
                 break
             expr = new
diff --git a/dask_expr/_expr.py b/dask_expr/_expr.py
index bdd821fd0..67c56560f 100644
--- a/dask_expr/_expr.py
+++ b/dask_expr/_expr.py
@@ -1205,7 +1205,7 @@ def _meta(self):
         args = [
             meta_nonempty(op._meta) if isinstance(op, Expr) else op for op in self._args
         ]
-        return self.operation(*args, **self._kwargs)
+        return make_meta(self.operation(*args, **self._kwargs))

     @staticmethod
     def operation(df, index, sorted_index):
@@ -2062,6 +2062,9 @@ class ResetIndex(Elemwise):
     operation = M.reset_index
     _filter_passthrough = True

+    def __new__(cls, *args, **kwargs):
+        return super().__new__(cls, *args, **kwargs)
+
     @functools.cached_property
     def _kwargs(self) -> dict:
         kwargs = {"drop": self.drop}
@@ -2099,7 +2102,9 @@ def _simplify_up(self, parent, dependents):
             return self._filter_simplification(parent, predicate)

         if isinstance(parent, Projection):
-            if self.frame.ndim == 1 and not self.drop and not isinstance(parent, list):
+            if self.frame.ndim == 1 and not self.drop:
+                if isinstance(parent.operand("columns"), list):
+                    return
                 col = parent.operand("columns")
                 if col in (self.name, "index"):
                     return
diff --git a/dask_expr/_groupby.py b/dask_expr/_groupby.py
index faed278b3..5654382ec 100644
--- a/dask_expr/_groupby.py
+++ b/dask_expr/_groupby.py
@@ -136,7 +136,7 @@ class GroupByApplyConcatApply(ApplyConcatApply, GroupByBase):
     @functools.cached_property
     def _meta_chunk(self):
         meta = meta_nonempty(self.frame._meta)
-        return self.chunk(meta, *self._by_meta, **self.chunk_kwargs)
+        return make_meta(self.chunk(meta, *self._by_meta, **self.chunk_kwargs))

     @property
     def _chunk_cls_args(self):
@@ -201,6 +201,7 @@ class SingleAggregation(GroupByApplyConcatApply, GroupByBase):
         "split_out",
         "sort",
         "shuffle_method",
+        "_pipeline_breaker_counter",
     ]
     _defaults = {
         "observed": None,
@@ -212,6 +213,7 @@ class SingleAggregation(GroupByApplyConcatApply, GroupByBase):
         "split_out": None,
         "sort": None,
         "shuffle_method": None,
+        "_pipeline_breaker_counter": None,
     }

     groupby_chunk = None
@@ -251,7 +253,11 @@ def aggregate_kwargs(self) -> dict:
         }

     def _simplify_up(self, parent, dependents):
-        return groupby_projection(self, parent, dependents)
+        if isinstance(parent, Projection):
+            return groupby_projection(self, parent, dependents)
+
+    def _pipe_down(self):
+        return self._adjust_for_pipelinebreaker()


 class GroupbyAggregationBase(GroupByApplyConcatApply, GroupByBase):
@@ -1479,6 +1485,7 @@ def _single_agg(
                 split_out,
                 self.sort,
                 shuffle_method,
+                None,
                 *self.by,
             )
         )
@@ -2161,6 +2168,7 @@ def nunique(self, split_every=None, split_out=True, shuffle_method=None):
                 split_out,
                 self.sort,
                 shuffle_method,
+                None,
                 *self.by,
             )
         )
diff --git a/dask_expr/_reductions.py b/dask_expr/_reductions.py
index f8072f9bf..eebd72df4 100644
--- a/dask_expr/_reductions.py
+++ b/dask_expr/_reductions.py
@@ -507,6 +507,41 @@ def _lower(self):
             ignore_index=getattr(self, "ignore_index", True),
         )

+    def _adjust_for_pipelinebreaker(self):
+        if self._pipeline_breaker_counter is not None:
+            return
+        from dask_expr.io.io import IO
+
+        seen = set()
+        stack = self.dependencies()
+        io_nodes = []
+        counter = 1
+
+        while stack:
+            node = stack.pop()
+
+            if node._name in seen:
+                continue
+            seen.add(node._name)
+
+            if isinstance(node, IO):
+                io_nodes.append(node)
+                continue
+            elif isinstance(node, ApplyConcatApply):
+                counter += 1
+                continue
+            stack.extend(node.dependencies())
+        if len(io_nodes) == 0:
+            return
+        io_nodes_new = [
+            io.substitute_parameters({"_pipeline_breaker_counter": counter})
+            for io in io_nodes
+        ]
+        expr = self
+        for io_node_old, io_node_new in zip(io_nodes, io_nodes_new):
+            expr = expr.substitute(io_node_old, io_node_new)
+        return expr.substitute_parameters({"_pipeline_breaker_counter": counter})
+

 class Unique(ApplyConcatApply):
     _parameters = ["frame", "split_every", "split_out", "shuffle_method"]
@@ -773,13 +808,23 @@ def _simplify_up(self, parent, dependents):
         if isinstance(parent, Projection):
             return plain_column_projection(self, parent, dependents)

+    def _pipe_down(self):
+        return self._adjust_for_pipelinebreaker()
+

 class Sum(Reduction):
-    _parameters = ["frame", "skipna", "numeric_only", "split_every"]
+    _parameters = [
+        "frame",
+        "skipna",
+        "numeric_only",
+        "split_every",
+        "_pipeline_breaker_counter",
+    ]
     _defaults = {
         "split_every": False,
         "numeric_only": False,
         "skipna": True,
+        "_pipeline_breaker_counter": None,
     }

     reduction_chunk = M.sum
@@ -1090,8 +1135,21 @@ def reduction_aggregate(cls, vals, order):


 class Mean(Reduction):
-    _parameters = ["frame", "skipna", "numeric_only", "split_every", "axis"]
-    _defaults = {"skipna": True, "numeric_only": False, "split_every": False, "axis": 0}
+    _parameters = [
+        "frame",
+        "skipna",
+        "numeric_only",
+        "split_every",
+        "axis",
+        "_pipeline_breaker_counter",
+    ]
+    _defaults = {
+        "skipna": True,
+        "numeric_only": False,
+        "split_every": False,
+        "axis": 0,
+        "_pipeline_breaker_counter": None,
+    }

     @functools.cached_property
     def _meta(self):
@@ -1267,8 +1325,21 @@ def _nlast(df, columns, n, ascending):


 class NFirst(NLargest):
-    _parameters = ["frame", "n", "_columns", "ascending", "split_every"]
-    _defaults = {"n": 5, "_columns": None, "ascending": None, "split_every": None}
+    _parameters = [
+        "frame",
+        "n",
+        "_columns",
+        "ascending",
+        "split_every",
+        "_pipeline_breaker_counter",
+    ]
+    _defaults = {
+        "n": 5,
+        "_columns": None,
+        "ascending": None,
+        "split_every": None,
+        "_pipeline_breaker_counter": None,
+    }

     reduction_chunk = _nfirst
     reduction_aggregate = _nfirst
diff --git a/dask_expr/io/io.py b/dask_expr/io/io.py
index 25ef1abc2..6e9544677 100644
--- a/dask_expr/io/io.py
+++ b/dask_expr/io/io.py
@@ -320,6 +320,7 @@ class FromPandas(PartitionsFiltered, BlockwiseIO):
         "columns",
         "_partitions",
         "_series",
+        "_pipeline_breaker_counter",
     ]
     _defaults = {
         "npartitions": None,
@@ -328,6 +329,7 @@ class FromPandas(PartitionsFiltered, BlockwiseIO):
         "_partitions": None,
         "_series": False,
         "chunksize": None,
+        "_pipeline_breaker_counter": None,
     }
     _pd_length_stats = None
     _absorb_projections = True
diff --git a/dask_expr/io/parquet.py b/dask_expr/io/parquet.py
index a103e60c0..dcbf03f72 100644
--- a/dask_expr/io/parquet.py
+++ b/dask_expr/io/parquet.py
@@ -402,6 +402,7 @@ class ReadParquet(PartitionsFiltered, BlockwiseIO):
         "_partitions",
         "_series",
         "_dataset_info_cache",
+        "_pipeline_breaker_counter",
     ]
     _defaults = {
         "columns": None,
@@ -422,6 +423,7 @@ class ReadParquet(PartitionsFiltered, BlockwiseIO):
         "_partitions": None,
         "_series": False,
         "_dataset_info_cache": None,
+        "_pipeline_breaker_counter": None,
     }
     _pq_length_stats = None
     _absorb_projections = True
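
A minimal sketch of what the new "pipe" rewrite pass is meant to do, for reviewers. It assumes the dask-expr collection API (dx.from_pandas and the .expr accessor) and the Expr.rewrite(kind) hook that simplify() now invokes with "pipe" in the _core.py hunk above; the frame contents and variable names below are illustrative only, not part of this patch:

import pandas as pd
import dask_expr as dx

pdf = pd.DataFrame({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8]})
df = dx.from_pandas(pdf, npartitions=2)

# Sum is a Reduction (an ApplyConcatApply subclass), i.e. a pipeline
# breaker. rewrite("pipe") dispatches to _pipe_down(), which for Sum
# calls the _adjust_for_pipelinebreaker() method added above.
expr = df["a"].sum().expr
piped = expr.rewrite("pipe")

# _adjust_for_pipelinebreaker() walks the dependency tree once, counts
# the ApplyConcatApply nodes it crosses, and substitutes every IO leaf
# (here the FromPandas node) with a copy whose _pipeline_breaker_counter
# operand carries that count. Because operands feed the deterministic
# token, the annotated read hashes differently from the same read
# feeding a purely blockwise pipeline, so the two expressions no longer
# alias in the optimized graph.
print(expr._name, "->", piped._name)

Threading the counter through _parameters with a None default keeps the annotation invisible to users and to unaffected expressions, and lets the existing substitute()/substitute_parameters() machinery perform the rewrite without any new graph plumbing.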