diff --git a/dask_expr/_collection.py b/dask_expr/_collection.py
index ce39a1c7e..646b2c3ee 100644
--- a/dask_expr/_collection.py
+++ b/dask_expr/_collection.py
@@ -3,6 +3,7 @@
 import datetime
 import functools
 import inspect
+import os
 import warnings
 from collections.abc import Callable, Hashable, Mapping
 from numbers import Integral, Number
@@ -4526,6 +4527,7 @@ def read_csv(
             storage_options=storage_options,
             kwargs=kwargs,
             header=header,
+            _cwd=_get_cwd(path, kwargs),
         )
     )
 
@@ -4551,6 +4553,7 @@ def read_table(
             storage_options=storage_options,
             kwargs=kwargs,
             header=header,
+            _cwd=_get_cwd(path, kwargs),
         )
     )
 
@@ -4576,10 +4579,25 @@ def read_fwf(
             storage_options=storage_options,
             kwargs=kwargs,
             header=header,
+            _cwd=_get_cwd(path, kwargs),
         )
     )
 
 
+def _get_protocol(urlpath):
+    if "://" in urlpath:
+        protocol, _ = urlpath.split("://", 1)
+        if len(protocol) > 1:
+            # excludes Windows paths
+            return protocol
+    return None
+
+
+def _get_cwd(path, kwargs):
+    protocol = kwargs.pop("protocol", None) or _get_protocol(path) or "file"
+    return os.getcwd() if protocol == "file" else None
+
+
 def read_parquet(
     path=None,
     columns=None,
@@ -4630,6 +4648,7 @@ def read_parquet(
             filesystem=filesystem,
             engine=_set_parquet_engine(engine),
             kwargs=kwargs,
+            _cwd=_get_cwd(path, kwargs),
             _series=isinstance(columns, str),
         )
     )
diff --git a/dask_expr/_core.py b/dask_expr/_core.py
index 3aac656c4..31f4e0f58 100644
--- a/dask_expr/_core.py
+++ b/dask_expr/_core.py
@@ -10,25 +10,28 @@
 import pandas as pd
 import toolz
 from dask.dataframe.core import is_dataframe_like, is_index_like, is_series_like
+from dask.delayed import Delayed
 from dask.utils import funcname, import_required, is_arraylike
+from toolz.dicttoolz import merge
 
 from dask_expr._util import _BackendData, _tokenize_deterministic
 
 
 def _unpack_collections(o):
     if isinstance(o, Expr):
-        return o
-
-    if hasattr(o, "expr"):
-        return o.expr
+        return o, o._name
+    elif hasattr(o, "expr") and not isinstance(o, Delayed):
+        return o.expr, o.expr._name
     else:
-        return o
+        return o, None
 
 
 class Expr:
     _parameters = []
     _defaults = {}
     _instances = weakref.WeakValueDictionary()
+    _dependents = defaultdict(list)
+    _seen = set()
 
     def __new__(cls, *args, **kwargs):
         operands = list(args)
@@ -38,15 +41,75 @@ def __new__(cls, *args, **kwargs):
             except KeyError:
                 operands.append(cls._defaults[parameter])
         assert not kwargs, kwargs
+
+        parsed_operands = []
+        children = set()
+        _subgraphs = []
+        _subgraph_instances = []
+        _graph_instances = {}
+        for o in operands:
+            expr, name = _unpack_collections(o)
+            parsed_operands.append(expr)
+            if name is not None:
+                children.add(name)
+                _subgraphs.append(expr._graph)
+                _subgraph_instances.append(expr._graph_instances)
+                _graph_instances[name] = expr
+
         inst = object.__new__(cls)
-        inst.operands = [_unpack_collections(o) for o in operands]
+        inst.operands = parsed_operands
         _name = inst._name
+
+        # Graph instances is a mapping name -> Expr instance
+        # Graph itself is a mapping of dependencies mapping names to a set of names
+
         if _name in Expr._instances:
-            return Expr._instances[_name]
+            inst = Expr._instances[_name]
+            inst._graph_instances.update(merge(_graph_instances, *_subgraph_instances))
+            inst._graph.update(merge(*_subgraphs))
+            inst._graph[_name].update(children)
+            # Probably a bad idea to have a self ref
+            inst._graph_instances[_name] = inst
+
+        else:
+            Expr._instances[_name] = inst
+            inst._graph_instances = merge(_graph_instances, *_subgraph_instances)
+            inst._graph = merge(*_subgraphs)
+            inst._graph[_name] = children
+            # Probably a bad idea to have a self ref
+            inst._graph_instances[_name] = inst
+
+        if inst._name in Expr._seen:
+            # We already registered inst as a dependent of all it's
+            # dependencies, so we don't need to do it again
+            return inst
+
+        Expr._seen.add(inst._name)
+        for dep in inst.dependencies():
+            Expr._dependents[dep._name].append(inst)
 
-        Expr._instances[_name] = inst
         return inst
 
+    @functools.cached_property
+    def _dependent_graph(self):
+        # Reset to clear tracking
+        Expr._dependents = defaultdict(list)
+        Expr._seen = set()
+        rv = Expr._dependents
+        # This should be O(E)
+        tmp = defaultdict(set)
+        for expr, dependencies in self._graph.items():
+            for dep in dependencies:
+                tmp[dep].add(expr)
+        for name, exprs in tmp.items():
+            rv[name] = [self._graph_instances[e] for e in exprs]
+        return rv
+
+    def __hash__(self):
+        raise TypeError(
+            "Expr objects can't be used in sets or dicts or similar, use the _name instead"
+        )
+
     def _tune_down(self):
         return None
 
@@ -150,6 +213,9 @@ def dependencies(self):
         # Dependencies are `Expr` operands only
         return [operand for operand in self.operands if isinstance(operand, Expr)]
 
+    def dependents(self):
+        return self._dependent_graph
+
     def _task(self, index: int):
         """The task for the i'th partition
 
@@ -318,8 +384,6 @@ def simplify_once(self, dependents: defaultdict, simplified: dict):
             changed = False
             for operand in expr.operands:
                 if isinstance(operand, Expr):
-                    # Bandaid for now, waiting for Singleton
-                    dependents[operand._name].append(weakref.ref(expr))
                     new = operand.simplify_once(
                         dependents=dependents, simplified=simplified
                     )
@@ -340,7 +404,7 @@
     def simplify(self) -> Expr:
         expr = self
         while True:
-            dependents = collect_dependents(expr)
+            dependents = expr.dependents()
            new = expr.simplify_once(dependents=dependents, simplified={})
             if new._name == expr._name:
                 break
@@ -712,19 +776,3 @@ def find_operations(self, operation: type | tuple[type]) -> Generator[Expr]:
             or issubclass(operation, Expr)
         ), "`operation` must be`Expr` subclass)"
         return (expr for expr in self.walk() if isinstance(expr, operation))
-
-
-def collect_dependents(expr) -> defaultdict:
-    dependents = defaultdict(list)
-    stack = [expr]
-    seen = set()
-    while stack:
-        node = stack.pop()
-        if node._name in seen:
-            continue
-        seen.add(node._name)
-
-        for dep in node.dependencies():
-            stack.append(dep)
-            dependents[dep._name].append(weakref.ref(node))
-    return dependents
diff --git a/dask_expr/_expr.py b/dask_expr/_expr.py
index 6d800eab6..c269e064f 100644
--- a/dask_expr/_expr.py
+++ b/dask_expr/_expr.py
@@ -1150,7 +1150,7 @@ def _simplify_up(self, parent, dependents):
         ):
             predicate = None
             if self.frame.ndim == 1 and self.ndim == 2:
-                name = self.frame._meta.name
+                name = self._meta.columns[0]
                 # Avoid Projection since we are already a Series
                 subs = Projection(self, name)
                 predicate = parent.predicate.substitute(subs, self.frame)
@@ -2076,9 +2076,7 @@ def _simplify_up(self, parent, dependents):
                 parent, dependents
             ):
                 parents = [
-                    p().columns
-                    for p in dependents[self._name]
-                    if p() is not None and not isinstance(p(), Filter)
+                    p.columns for p in dependents[self._name] if not isinstance(p, Filter)
                 ]
                 predicate = None
                 if not set(flatten(parents, list)).issubset(set(self.frame.columns)):
@@ -2107,7 +2105,7 @@ def _simplify_up(self, parent, dependents):
         if col in (self.name, "index", self.frame._meta.index.name):
             return
         if all(
-            isinstance(d(), Projection) and d().operand("columns") == col
+            isinstance(d, Projection) and d.operand("columns") == col
             for d in dependents[self._name]
         ):
             return type(self)(self.frame, True, self.name)
@@ -2715,10 +2713,6 @@ class _DelayedExpr(Expr):
     # TODO
     _parameters = ["obj"]
 
-    def __init__(self, obj):
-        self.obj = obj
-        self.operands = [obj]
-
     def __str__(self):
         return f"{type(self).__name__}({str(self.obj)})"
 
@@ -3451,7 +3445,7 @@ def determine_column_projection(expr, parent, dependents, additional_columns=Non
         column_union = []
     else:
         column_union = parent.columns.copy()
-    parents = [x() for x in dependents[expr._name] if x() is not None]
+    parents = dependents[expr._name]
 
     seen = set()
     for p in parents:
@@ -3511,7 +3505,7 @@ def plain_column_projection(expr, parent, dependents, additional_columns=None):
 
 
 def is_filter_pushdown_available(expr, parent, dependents, allow_reduction=True):
-    parents = [x() for x in dependents[expr._name] if x() is not None]
+    parents = dependents[expr._name]
     filters = {e._name for e in parents if isinstance(e, Filter)}
     if len(filters) != 1:
         # Don't push down if not exactly one Filter
@@ -3618,7 +3612,7 @@ def _check_dependents_are_predicates(
             continue
         seen.add(e._name)
 
-        e_dependents = {x()._name for x in dependents[e._name] if x() is not None}
+        e_dependents = {x._name for x in dependents[e._name]}
 
         if not allow_reduction:
             if isinstance(e, Reduction):
diff --git a/dask_expr/io/csv.py b/dask_expr/io/csv.py
index 6be02153d..7bc0618e4 100644
--- a/dask_expr/io/csv.py
+++ b/dask_expr/io/csv.py
@@ -14,6 +14,7 @@ class ReadCSV(PartitionsFiltered, BlockwiseIO):
         "_partitions",
         "storage_options",
         "kwargs",
+        "_cwd",  # needed for tokenization
         "_series",
     ]
     _defaults = {
@@ -24,6 +25,7 @@ class ReadCSV(PartitionsFiltered, BlockwiseIO):
         "_partitions": None,
         "storage_options": None,
         "_series": False,
+        "_cwd": None,
     }
     _absorb_projections = True
 
diff --git a/dask_expr/io/parquet.py b/dask_expr/io/parquet.py
index 8cd9e8c3d..f722a4d8f 100644
--- a/dask_expr/io/parquet.py
+++ b/dask_expr/io/parquet.py
@@ -426,6 +426,7 @@ class ReadParquet(PartitionsFiltered, BlockwiseIO):
         "filesystem",
         "engine",
         "kwargs",
+        "_cwd",  # needed for tokenization
         "_partitions",
         "_series",
         "_dataset_info_cache",
@@ -449,6 +450,7 @@ class ReadParquet(PartitionsFiltered, BlockwiseIO):
         "_partitions": None,
         "_series": False,
         "_dataset_info_cache": None,
+        "_cwd": None,
     }
     _pq_length_stats = None
     _absorb_projections = True
diff --git a/dask_expr/io/tests/test_io.py b/dask_expr/io/tests/test_io.py
index cca5d63c0..62619c314 100644
--- a/dask_expr/io/tests/test_io.py
+++ b/dask_expr/io/tests/test_io.py
@@ -1,5 +1,6 @@
 import glob
 import os
+from pathlib import Path
 
 import dask.array as da
 import dask.dataframe as dd
@@ -30,14 +31,14 @@
 pd = _backend_library()
 
 
-def _make_file(dir, format="parquet", df=None):
+def _make_file(dir, format="parquet", df=None, **kwargs):
     fn = os.path.join(str(dir), f"myfile.{format}")
     if df is None:
         df = pd.DataFrame({c: range(10) for c in "abcde"})
     if format == "csv":
-        df.to_csv(fn)
+        df.to_csv(fn, **kwargs)
     elif format == "parquet":
-        df.to_parquet(fn)
+        df.to_parquet(fn, **kwargs)
     else:
         ValueError(f"{format} not a supported format")
     return fn
@@ -413,6 +414,33 @@ def test_combine_similar_no_projection_on_one_branch(tmpdir):
     assert_eq(df, pdf)
 
 
+@pytest.mark.parametrize(
+    "fmt, func, kwargs",
+    [
+        ("parquet", read_parquet, {}),
+        ("csv", read_csv, {"index": False}),
+    ],
+)
+def test_chdir_different_files(tmpdir, fmt, func, kwargs):
+    cwd = os.getcwd()
+
+    try:
+        pdf = pd.DataFrame({"x": [0, 1, 2, 3] * 4, "y": range(16)})
+        os.chdir(tmpdir)
+        _make_file(tmpdir, format=fmt, df=pdf, **kwargs)
+        df = func(f"myfile.{fmt}")
+
+        new_dir = Path(tmpdir).joinpath("new_dir")
+        new_dir.mkdir()
+        os.chdir(new_dir)
+        pdf2 = pd.DataFrame({"x": [0, 100, 200, 300] * 4, "y": range(16)})
+        _make_file(new_dir, format=fmt, df=pdf2, **kwargs)
+        df2 = func(f"myfile.{fmt}")
+        assert_eq(df.sum() + df2.sum(), pd.Series([2424, 240], index=["x", "y"]))
+    finally:
+        os.chdir(cwd)
+
+
 @pytest.mark.parametrize("meta", [True, False])
 @pytest.mark.parametrize("label", [None, "foo"])
 @pytest.mark.parametrize("allow_projection", [True, False])
diff --git a/dask_expr/tests/test_shuffle.py b/dask_expr/tests/test_shuffle.py
index c7b977284..7484cbdfa 100644
--- a/dask_expr/tests/test_shuffle.py
+++ b/dask_expr/tests/test_shuffle.py
@@ -1,3 +1,4 @@
+import random
 from collections import OrderedDict
 
 import dask
@@ -640,7 +641,7 @@ def test_set_index_sort_values_one_partition(pdf):
 
 def test_set_index_triggers_calc_when_accessing_divisions(pdf, df):
     divisions_lru.data = OrderedDict()
-    query = df.set_index("x")
+    query = df.fillna(random.randint(1, 100)).set_index("x")
    assert len(divisions_lru.data) == 0
     divisions = query.divisions  # noqa: F841
     assert len(divisions_lru.data) == 1
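
Note: the two helpers added to dask_expr/_collection.py above are small enough to read in isolation. The sketch below (the example paths and printed values are illustrative, not part of the patch) shows why only local paths pull the working directory into the expression operands, and hence into tokenization, which is what test_chdir_different_files exercises:

    import os

    def _get_protocol(urlpath):
        # As in the patch: "s3://bucket/key" yields "s3"; a single-character
        # prefix (a Windows drive letter) is not treated as a protocol.
        if "://" in urlpath:
            protocol, _ = urlpath.split("://", 1)
            if len(protocol) > 1:
                return protocol
        return None

    def _get_cwd(path, kwargs):
        # As in the patch: only the local "file" protocol picks up the cwd.
        protocol = kwargs.pop("protocol", None) or _get_protocol(path) or "file"
        return os.getcwd() if protocol == "file" else None

    # Two ReadCSV("myfile.csv") expressions built from different working
    # directories now carry different _cwd operands and therefore tokenize
    # to different names; remote URLs get None and are unaffected.
    print(_get_cwd("myfile.csv", {}))              # e.g. "/home/user/data"
    print(_get_cwd("s3://bucket/myfile.csv", {}))  # None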