From f897d3b2647372af1c1b90b70a81e624dd1ffced Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Tue, 29 Aug 2023 09:24:50 -0500 Subject: [PATCH] [RFC] Cache Expr objects Much of our overhead comes from doing computational work in other libraries (pandas, arrow, ...) that could be cached. We do cache a lot of stuff today, but we store these caches on the object itself. When we recreate objects (for example, in optimization), we lose those caches. One solution here is to cache the objects themselves, so that `Op(...) is Op(...)`. This technique is a bit magical, but is used in other projects like SymPy, where it has had good performance impacts (although they use it because they make many more very small objects). Maybe this isn't relevant for us. Ideally we wouldn't recreate objects often in optimization (this is why we return the original object if arguments match). But maybe it's hard to be careful. If so, this might provide a bit of a sledgehammer approach. This isn't done yet; in particular, there are open questions about non-hashable inputs like pandas dataframes. Hopefully it is a useful proof of concept. 
--- dask_expr/_expr.py | 36 +++++++++++++++++++++++++++++- dask_expr/tests/test_collection.py | 8 +++++++ 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/dask_expr/_expr.py b/dask_expr/_expr.py index 7a2f2dcfe..620f13682 100644 --- a/dask_expr/_expr.py +++ b/dask_expr/_expr.py @@ -4,6 +4,7 @@ import numbers import operator import os +import weakref from collections import defaultdict from collections.abc import Generator, Mapping @@ -34,6 +35,20 @@ no_default = "__no_default__" +_object_cache = weakref.WeakValueDictionary() + + +def normalize_arg(arg): + if isinstance(arg, list): + return tuple(arg) + if isinstance(arg, dict): + return tuple(sorted(arg.items())) + if isinstance(arg, pd.core.base.PandasObject): + return (type(arg), id(arg)) # not quite safe + if isinstance(arg, np.ndarray): + return (type(arg), id(arg)) # not quite safe + return arg + class Expr: """Primary class for all Expressions @@ -46,7 +61,26 @@ class Expr: _defaults = {} _is_length_preserving = False - def __init__(self, *args, **kwargs): + def __new__(cls, *args, **kwargs): + key = ( + cls, + tuple(map(normalize_arg, args)), + tuple(sorted(toolz.valmap(normalize_arg, kwargs).items())), + ) + + try: + return _object_cache[key] + except KeyError: + obj = object.__new__(cls) + cls._init(obj, *args, **kwargs) + _object_cache[key] = obj + return obj + except Exception: # can not hash + obj = object.__new__(cls) + cls._init(obj, *args, **kwargs) + return obj + + def _init(self, *args, **kwargs): operands = list(args) for parameter in type(self)._parameters[len(operands) :]: try: diff --git a/dask_expr/tests/test_collection.py b/dask_expr/tests/test_collection.py index 13a76e849..e5f7b1d43 100644 --- a/dask_expr/tests/test_collection.py +++ b/dask_expr/tests/test_collection.py @@ -1212,3 +1212,11 @@ def test_shape(df, pdf): def test_size(df, pdf): assert_eq(df.size, pdf.size) + + +def test_object_caching(df): + a = df + 1 + b = df + 1 + assert a._expr is b._expr + assert a._meta is 
b._meta + del a, b