diff --git a/dask_expr/collection.py b/dask_expr/collection.py
index b4be52f84..da0c48960 100644
--- a/dask_expr/collection.py
+++ b/dask_expr/collection.py
@@ -22,6 +22,7 @@
 from dask_expr.merge import Merge
 from dask_expr.reductions import (
     DropDuplicates,
+    Len,
     MemoryUsageFrame,
     MemoryUsageIndex,
     NLargest,
@@ -85,6 +86,9 @@ def _meta(self):
     def size(self):
         return new_collection(self.expr.size)

+    def __len__(self):
+        return new_collection(Len(self.expr)).compute()
+
     @property
     def nbytes(self):
         raise NotImplementedError("nbytes is not implemented on DataFrame")
diff --git a/dask_expr/expr.py b/dask_expr/expr.py
index 015d88546..dad8e7343 100644
--- a/dask_expr/expr.py
+++ b/dask_expr/expr.py
@@ -20,6 +20,7 @@
     is_dataframe_like,
     is_index_like,
     is_series_like,
+    make_meta,
 )
 from dask.dataframe.dispatch import meta_nonempty
 from dask.utils import M, apply, funcname, import_required, is_arraylike
@@ -636,6 +637,23 @@ def visualize(self, filename="dask-expr.svg", format=None, **kwargs):
         return g


+class Literal(Expr):
+    """Represent a literal (known) value as an `Expr`"""
+
+    _parameters = ["value"]
+
+    def _divisions(self):
+        return (None, None)
+
+    @property
+    def _meta(self):
+        return make_meta(self.value)
+
+    def _task(self, index: int):
+        assert index == 0
+        return self.value
+
+
 class Blockwise(Expr):
     """Super-class for block-wise operations

@@ -1031,6 +1049,33 @@ def _task(self, index: int):
         )


+class Lengths(Expr):
+    """Returns a tuple of partition lengths"""
+
+    _parameters = ["frame"]
+
+    @property
+    def _meta(self):
+        return tuple()
+
+    def _divisions(self):
+        return (None, None)
+
+    def _simplify_down(self):
+        if isinstance(self.frame, Elemwise):
+            child = max(self.frame.dependencies(), key=lambda expr: expr.npartitions)
+            return Lengths(child)
+
+    def _layer(self):
+        name = "part-" + self._name
+        dsk = {
+            (name, i): (len, (self.frame._name, i))
+            for i in range(self.frame.npartitions)
+        }
+        dsk[(self._name, 0)] = (tuple, list(dsk.keys()))
+        return dsk
+
+
 class ResetIndex(Elemwise):
     """Reset the index of a Series or DataFrame"""

diff --git a/dask_expr/io/io.py b/dask_expr/io/io.py
index 513193cd4..ebee423e9 100644
--- a/dask_expr/io/io.py
+++ b/dask_expr/io/io.py
@@ -1,9 +1,12 @@
+from __future__ import annotations
+
 import functools
 import math

 from dask.dataframe.io.io import sorted_division_locations

-from dask_expr.expr import Blockwise, Expr, PartitionsFiltered
+from dask_expr.expr import Blockwise, Expr, Lengths, Literal, PartitionsFiltered
+from dask_expr.reductions import Len


 class IO(Expr):
@@ -44,6 +47,7 @@ class FromPandas(PartitionsFiltered, BlockwiseIO):

     _parameters = ["frame", "npartitions", "sort", "_partitions"]
     _defaults = {"npartitions": 1, "sort": True, "_partitions": None}
+    _pd_length_stats = None

     @property
     def _meta(self):
@@ -68,6 +72,27 @@ def _divisions_and_locations(self):
             divisions = (None,) * len(locations)
         return divisions, locations

+    def _get_lengths(self) -> tuple | None:
+        if self._pd_length_stats is None:
+            locations = self._locations()
+            self._pd_length_stats = tuple(
+                offset - locations[i]
+                for i, offset in enumerate(locations[1:])
+                if not self._filtered or i in self._partitions
+            )
+        return self._pd_length_stats
+
+    def _simplify_up(self, parent):
+        if isinstance(parent, Lengths):
+            _lengths = self._get_lengths()
+            if _lengths:
+                return Literal(_lengths)
+
+        if isinstance(parent, Len):
+            _lengths = self._get_lengths()
+            if _lengths:
+                return Literal(sum(_lengths))
+
     def _divisions(self):
         return self._divisions_and_locations[0]

diff --git a/dask_expr/io/parquet.py b/dask_expr/io/parquet.py
index 388432878..dd7bea178 100644
--- a/dask_expr/io/parquet.py
+++ b/dask_expr/io/parquet.py
@@ -2,32 +2,44 @@
 import itertools
 import operator
+from collections import defaultdict
 from functools import cached_property

+import dask
+import pyarrow.parquet as pq
 from dask.dataframe.io.parquet.core import (
     ParquetFunctionWrapper,
+    aggregate_row_groups,
     get_engine,
-    process_statistics,
     set_index_columns,
+    sorted_columns,
 )
 from dask.dataframe.io.parquet.utils import _split_user_options
+from dask.dataframe.io.utils import _is_local_fs
+from dask.delayed import delayed
 from dask.utils import natural_sort_key

-from dask_expr.expr import EQ, GE, GT, LE, LT, NE, And, Expr, Filter, Or, Projection
+from dask_expr.expr import (
+    EQ,
+    GE,
+    GT,
+    LE,
+    LT,
+    NE,
+    And,
+    Expr,
+    Filter,
+    Lengths,
+    Literal,
+    Or,
+    Projection,
+)
 from dask_expr.io import BlockwiseIO, PartitionsFiltered
+from dask_expr.reductions import Len

 NONE_LABEL = "__null_dask_index__"


-def _list_columns(columns):
-    # Simple utility to convert columns to list
-    if isinstance(columns, (str, int)):
-        columns = [columns]
-    elif isinstance(columns, tuple):
-        columns = list(columns)
-    return columns
-
-
 class ReadParquet(PartitionsFiltered, BlockwiseIO):
     """Read a parquet dataset"""

@@ -68,6 +80,7 @@ class ReadParquet(PartitionsFiltered, BlockwiseIO):
         "_partitions": None,
         "_series": False,
     }
+    _pq_length_stats = None

     @property
     def engine(self):
@@ -104,6 +117,16 @@ def _simplify_up(self, parent):
                 kwargs["filters"] = filters.combine(kwargs["filters"]).to_list_tuple()
                 return ReadParquet(**kwargs)

+        if isinstance(parent, Lengths):
+            _lengths = self._get_lengths()
+            if _lengths:
+                return Literal(_lengths)
+
+        if isinstance(parent, Len):
+            _lengths = self._get_lengths()
+            if _lengths:
+                return Literal(sum(_lengths))
+
     @cached_property
     def _dataset_info(self):
         # Process and split user options
@@ -169,6 +192,7 @@ def _dataset_info(self):
         # Infer meta, accounting for index and columns arguments.
         meta = self.engine._create_dd_meta(dataset_info)
+
         index = dataset_info["index"]
         index = [index] if isinstance(index, str) else index
         meta, index, columns = set_index_columns(
             meta, index, self.operand("columns"), auto_index_allowed
@@ -196,21 +220,14 @@ def _plan(self):
             dataset_info
         )

-        # Parse dataset statistics from metadata (if available)
-        parts, divisions, _ = process_statistics(
-            parts,
-            stats,
-            dataset_info["filters"],
-            dataset_info["index"],
-            (
-                dataset_info["blocksize"]
-                if dataset_info["split_row_groups"] is True
-                else None
-            ),
-            dataset_info["split_row_groups"],
-            dataset_info["fs"],
-            dataset_info["aggregation_depth"],
-        )
+        # Make sure parts and stats are aligned
+        parts, stats = _align_statistics(parts, stats)
+
+        # Use statistics to aggregate partitions
+        parts, stats = _aggregate_row_groups(parts, stats, dataset_info)
+
+        # Use statistics to calculate divisions
+        divisions = _calculate_divisions(stats, dataset_info, len(parts))

         meta = dataset_info["meta"]
         if len(divisions) < 2:
@@ -234,6 +251,7 @@ def _plan(self):
         return {
             "func": io_func,
             "parts": parts,
+            "statistics": stats,
             "divisions": divisions,
         }

@@ -246,9 +264,104 @@ def _filtered_task(self, index: int):
             return (operator.getitem, tsk, self.columns[0])
         return tsk

+    def _get_lengths(self) -> tuple | None:
+        """Return known partition lengths using parquet statistics"""
+        if not self.filters:
+            self._update_length_statistics()
+            return tuple(
+                length
+                for i, length in enumerate(self._pq_length_stats)
+                if not self._filtered or i in self._partitions
+            )
+
+    def _update_length_statistics(self):
+        """Ensure that partition-length statistics are up to date"""
+
+        if not self._pq_length_stats:
+            if self._plan["statistics"]:
+                # Already have statistics from original API call
+                self._pq_length_stats = tuple(
+                    stat["num-rows"]
+                    for i, stat in enumerate(self._plan["statistics"])
+                    if not self._filtered or i in self._partitions
+                )
+            else:
+                # Need to go back and collect statistics
+                self._pq_length_stats = tuple(
+                    stat["num-rows"] for stat in _collect_pq_statistics(self)
+                )
+
+
+#
+# Helper functions
+#
+
+
+def _list_columns(columns):
+    # Simple utility to convert columns to list
+    if isinstance(columns, (str, int)):
+        columns = [columns]
+    elif isinstance(columns, tuple):
+        columns = list(columns)
+    return columns
+
+
+def _align_statistics(parts, statistics):
+    # Make sure parts and statistics are aligned
+    # (if statistics is not empty)
+    if statistics and len(parts) != len(statistics):
+        statistics = []
+    if statistics:
+        result = list(
+            zip(
+                *[
+                    (part, stats)
+                    for part, stats in zip(parts, statistics)
+                    if stats["num-rows"] > 0
+                ]
+            )
+        )
+        parts, statistics = result or [[], []]
+    return parts, statistics
+
+
+def _aggregate_row_groups(parts, statistics, dataset_info):
+    # Aggregate parts/statistics if we are splitting by row-group
+    blocksize = (
+        dataset_info["blocksize"] if dataset_info["split_row_groups"] is True else None
+    )
+    split_row_groups = dataset_info["split_row_groups"]
+    fs = dataset_info["fs"]
+    aggregation_depth = dataset_info["aggregation_depth"]
+
+    if statistics:
+        if blocksize or (split_row_groups and int(split_row_groups) > 1):
+            parts, statistics = aggregate_row_groups(
+                parts, statistics, blocksize, split_row_groups, fs, aggregation_depth
+            )
+    return parts, statistics
+
+
+def _calculate_divisions(statistics, dataset_info, npartitions):
+    # Use statistics to define divisions
+    divisions = None
+    if statistics:
dataset_info["kwargs"].get("calculate_divisions", None) + index = dataset_info["index"] + process_columns = index if index and len(index) == 1 else None + if (calculate_divisions is not False) and process_columns: + for sorted_column_info in sorted_columns( + statistics, columns=process_columns + ): + if sorted_column_info["name"] in index: + divisions = sorted_column_info["divisions"] + break + + return divisions or (None,) * (npartitions + 1) + # -# Filters +# Filtering logic # @@ -365,3 +478,123 @@ def extract_pq_filters(cls, pq_expr: ReadParquet, predicate_expr: Expr) -> _DNF: _filters = cls._Or([left, right]) return _DNF(_filters) + + +# +# Parquet-statistics handling +# + + +def _collect_pq_statistics( + expr: ReadParquet, columns: list | None = None +) -> list[dict] | None: + """Collect Parquet statistic for dataset paths""" + + # Be strict about columns argument + if columns: + if not isinstance(columns, list): + raise ValueError(f"Expected columns to be a list, got {type(columns)}.") + allowed = {expr._meta.index.name} | set(expr.columns) + if not set(columns).issubset(allowed): + raise ValueError(f"columns={columns} must be a subset of {allowed}") + + # Collect statistics using layer information + fs = expr._plan["func"].fs + parts = [ + part + for i, part in enumerate(expr._plan["parts"]) + if not expr._filtered or i in expr._partitions + ] + + # Execute with delayed for large and remote datasets + parallel = int(False if _is_local_fs(fs) else 16) + if parallel: + # Group parts corresponding to the same file. + # A single task should always parse statistics + # for all these parts at once (since they will + # all be in the same footer) + groups = defaultdict(list) + for part in parts: + for p in [part] if isinstance(part, dict) else part: + path = p.get("piece")[0] + groups[path].append(p) + group_keys = list(groups.keys()) + + # Compute and return flattened result + func = delayed(_read_partition_stats_group) + result = dask.compute( + [ + func( + list( + itertools.chain( + *[groups[k] for k in group_keys[i : i + parallel]] + ) + ), + fs, + columns=columns, + ) + for i in range(0, len(group_keys), parallel) + ] + )[0] + return list(itertools.chain(*result)) + else: + # Serial computation on client + return _read_partition_stats_group(parts, fs, columns=columns) + + +def _read_partition_stats_group(parts, fs, columns=None): + """Parse the statistics for a group of files""" + + def _read_partition_stats(part, fs, columns=None): + # Helper function to read Parquet-metadata + # statistics for a single partition + + if not isinstance(part, list): + part = [part] + + column_stats = {} + num_rows = 0 + columns = columns or [] + for p in part: + piece = p["piece"] + path = piece[0] + row_groups = None if piece[1] == [None] else piece[1] + with fs.open(path, default_cache="none") as f: + md = pq.ParquetFile(f).metadata + if row_groups is None: + row_groups = list(range(md.num_row_groups)) + for rg in row_groups: + row_group = md.row_group(rg) + num_rows += row_group.num_rows + for i in range(row_group.num_columns): + col = row_group.column(i) + name = col.path_in_schema + if name in columns: + if col.statistics and col.statistics.has_min_max: + if name in column_stats: + column_stats[name]["min"] = min( + column_stats[name]["min"], col.statistics.min + ) + column_stats[name]["max"] = max( + column_stats[name]["max"], col.statistics.max + ) + else: + column_stats[name] = { + "min": col.statistics.min, + "max": col.statistics.max, + } + + # Convert dict-of-dict to list-of-dict to be 
+        # Convert dict-of-dict to list-of-dict to be consistent
+        # with current `dd.read_parquet` convention (for now)
+        column_stats_list = [
+            {
+                "name": name,
+                "min": column_stats[name]["min"],
+                "max": column_stats[name]["max"],
+            }
+            for name in column_stats.keys()
+        ]
+        return {"num-rows": num_rows, "columns": column_stats_list}
+
+    # Helper function used by _collect_pq_statistics
+    return [_read_partition_stats(part, fs, columns=columns) for part in parts]
diff --git a/dask_expr/io/tests/test_io.py b/dask_expr/io/tests/test_io.py
index 3d9fa30d0..b2a7036fa 100644
--- a/dask_expr/io/tests/test_io.py
+++ b/dask_expr/io/tests/test_io.py
@@ -6,8 +6,9 @@
 from dask.dataframe.utils import assert_eq

 from dask_expr import from_dask_dataframe, from_pandas, optimize, read_csv, read_parquet
-from dask_expr.expr import Expr
+from dask_expr.expr import Expr, Lengths, Literal
 from dask_expr.io import ReadParquet
+from dask_expr.reductions import Len


 def _make_file(dir, format="parquet", df=None):
@@ -161,7 +162,7 @@ def test_io_culling(tmpdir, fmt):
     if fmt == "parquet":
         dd.from_pandas(pdf, 2).to_parquet(tmpdir)
         df = read_parquet(tmpdir)
-    elif fmt == "parquet":
+    elif fmt == "csv":
         dd.from_pandas(pdf, 2).to_csv(tmpdir)
         df = read_csv(tmpdir + "/*")
     else:
@@ -207,6 +208,19 @@ def test_parquet_complex_filters(tmpdir):
     assert_eq(got.optimize(), expect)


+def test_parquet_len(tmpdir):
+    df = read_parquet(_make_file(tmpdir))
+    pdf = df.compute()
+
+    assert len(df[df.a > 5]) == len(pdf[pdf.a > 5])
+
+    s = (df["b"] + 1).astype("Int32")
+    assert len(s) == len(pdf)
+
+    assert isinstance(Len(s.expr).optimize(), Literal)
+    assert isinstance(Lengths(s.expr).optimize(), Literal)
+
+
 @pytest.mark.parametrize("optimize", [True, False])
 def test_from_dask_dataframe(optimize):
     ddf = dd.from_dict({"a": range(100)}, npartitions=10)
diff --git a/dask_expr/tests/test_collection.py b/dask_expr/tests/test_collection.py
index e96dd8279..6c32da8c3 100644
--- a/dask_expr/tests/test_collection.py
+++ b/dask_expr/tests/test_collection.py
@@ -12,6 +12,7 @@

 from dask_expr import expr, from_pandas, optimize
 from dask_expr.datasets import timeseries
+from dask_expr.reductions import Len


 @pytest.fixture
@@ -598,6 +599,19 @@ def test_repartition_divisions(df, opt):
             assert part.max() < df2.divisions[p + 1]


+def test_len(df, pdf):
+    df2 = df[["x"]] + 1
+    assert len(df2) == len(pdf)
+
+    assert len(df[df.x > 5]) == len(pdf[pdf.x > 5])
+
+    first = df2.partitions[0].compute()
+    assert len(df2.partitions[0]) == len(first)
+
+    assert isinstance(Len(df2.expr).optimize(), expr.Literal)
+    assert isinstance(expr.Lengths(df2.expr).optimize(), expr.Literal)
+
+
 def test_drop_duplicates(df, pdf):
     assert_eq(df.drop_duplicates(), pdf.drop_duplicates())
     assert_eq(
diff --git a/dask_expr/tests/test_datasets.py b/dask_expr/tests/test_datasets.py
index ff2ac9d77..1a8532ccb 100644
--- a/dask_expr/tests/test_datasets.py
+++ b/dask_expr/tests/test_datasets.py
@@ -1,6 +1,8 @@
 from dask.dataframe.utils import assert_eq

+from dask_expr import new_collection
 from dask_expr.datasets import timeseries
+from dask_expr.expr import Lengths


 def test_timeseries():
@@ -48,3 +50,8 @@ def test_persist():
     assert_eq(a, b)
     assert len(a.dask) > len(b.dask)
     assert len(b.dask) == b.npartitions
+
+
+def test_lengths():
+    df = timeseries(freq="1H", start="2000-01-01", end="2000-01-03", seed=123)
+    assert len(df) == sum(new_collection(Lengths(df.expr).optimize()).compute())