Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2,386 changes: 727 additions & 1,659 deletions .basedpyright/baseline.json

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion benchmarks/bench_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def track_m2l_op_count(self, param):
dvec, tgt_rscale)
for i, expr in enumerate(result):
sac.assign_unique(f"coeff{i}", expr)
sac.run_global_cse()
sac = sac.run_global_cse()
insns = to_loopy_insns(sac.assignments.items())
counter = pymbolic.mapper.flop_counter.CSEAwareFlopCounter()

Expand Down
4 changes: 4 additions & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
intersphinx_mapping = {
"arraycontext": ("https://documen.tician.de/arraycontext/", None),
"boxtree": ("https://documen.tician.de/boxtree/", None),
"islpy": ("https://documen.tician.de/islpy", None),
"loopy": ("https://documen.tician.de/loopy/", None),
"matplotlib": ("https://matplotlib.org/stable/", None),
"numpy": ("https://numpy.org/doc/stable/", None),
Expand Down Expand Up @@ -47,6 +48,7 @@
"obj_array.ObjectArray1D": "obj:pytools.obj_array.ObjectArray1D",
# sympy
"sp.Matrix": "class:sympy.matrices.dense.DenseMatrix",
"sym.Basic": "class:sympy.core.basic.Basic",
"sym.Expr": "class:sympy.core.expr.Expr",
"sym.Symbol": "class:sympy.core.symbol.Symbol",
"sym.Matrix": "class:sympy.matrices.dense.DenseMatrix",
Expand All @@ -58,13 +60,15 @@
# loopy
"Assignment": "class:loopy.kernel.instruction.Assignment",
"CallInstruction": "class:loopy.kernel.instruction.CallInstruction",
"InstructionBase": "class:loopy.kernel.instruction.InstructionBase",
# arraycontext
"Array": "obj:arraycontext.Array",
"ArrayContext": "class:arraycontext.ArrayContext",
# boxtree
"FMMTraversalInfo": "class:boxtree.traversal.FMMTraversalInfo",
# sumpy
"ArithmeticExpr": "obj:sumpy.kernel.ArithmeticExpr",
"OptimizationPair": "obj:sumpy.cse.OptimizationPair",
}


Expand Down
17 changes: 12 additions & 5 deletions sumpy/array_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,15 @@


if TYPE_CHECKING:
from collections.abc import Iterator
from collections.abc import Iterator, Sequence

import islpy
from numpy.typing import DTypeLike

from arraycontext import ArrayContext
from loopy import TranslationUnit
from loopy.codegen import PreambleInfo
from loopy.kernel.instruction import InstructionBase
from pytools.tag import ToTagSetConvertible


Expand All @@ -58,7 +60,8 @@
# {{{ PyOpenCLArrayContext

def make_loopy_program(
domains, statements,
domains: str | Sequence[str | islpy.BasicSet],
statements: Sequence[InstructionBase | str] | str,
kernel_data: list[Any] | None = None, *,
name: str = "sumpy_loopy_kernel",
silenced_warnings: list[str] | str | None = None,
Expand Down Expand Up @@ -125,7 +128,7 @@ def is_cl_cpu(actx: ArrayContext) -> bool:

# {{{ pytest

def _acf():
def _acf() -> ArrayContext:
import pyopencl as cl
ctx = cl.create_some_context()
queue = cl.CommandQueue(ctx)
Expand All @@ -135,9 +138,13 @@ def _acf():

class PytestPyOpenCLArrayContextFactory(
_PytestPyOpenCLArrayContextFactoryWithClass):
actx_class = PyOpenCLArrayContext
@property
@override
def actx_class(self) -> type[ArrayContext]:
return PyOpenCLArrayContext

def __call__(self):
@override
def __call__(self) -> ArrayContext:
# NOTE: prevent any cache explosions during testing!
from sympy.core.cache import clear_cache
clear_cache()
Expand Down
112 changes: 64 additions & 48 deletions sumpy/assignment_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,47 +24,47 @@
"""

import logging
from typing import TYPE_CHECKING
from collections import defaultdict
from typing import TYPE_CHECKING, overload

from typing_extensions import override
from typing_extensions import Self, override

import sumpy.symbolic as sym


if TYPE_CHECKING:
from collections.abc import Sequence
from collections.abc import Mapping, Sequence


logger = logging.getLogger(__name__)

__doc__ = """

Manipulating batches of assignments
-----------------------------------

.. autoclass:: SymbolicAssignmentCollection

"""


class _SymbolGenerator:
taken_symbols: Mapping[str, sym.Basic]
base_to_count: dict[str, int]

def __init__(self, taken_symbols):
def __init__(self, taken_symbols: Mapping[str, sym.Basic]) -> None:
self.taken_symbols = taken_symbols
from collections import defaultdict
self.base_to_count = defaultdict(lambda: 0)

def _normalize(self, base):
def _normalize(self, base: str) -> str:
# Strip off any _N suffix, to avoid generating conflicting names.
import re
base = re.split(r"_\d+$", base)[0]
return base if base != "" else "expr"

def __call__(self, base="expr"):
def __call__(self, base: str = "expr") -> sym.Symbol:
base = self._normalize(base)
count = self.base_to_count[base]

def make_id_str(base, count):
def make_id_str(base: str, count: int) -> str:
return "{base}{suffix}".format(
base=base,
suffix="" if count == 0 else "_" + str(count - 1))
Expand All @@ -78,13 +78,14 @@ def make_id_str(base, count):

return sym.Symbol(id_str)

def __iter__(self):
def __iter__(self) -> _SymbolGenerator:
return self

def next(self):
def next(self) -> sym.Symbol:
return self()

__next__ = next
def __next__(self) -> sym.Symbol:
return self.next()


# {{{ collection of assignments
Expand All @@ -95,27 +96,29 @@ class SymbolicAssignmentCollection:
a = 5*x
b = a**2-k

In the above, *x* and *k* are external variables, and *a* and *b*
are variables managed by this object.
In the above, *x* and *k* are external variables, and *a* and *b* are
variables managed by this object.

This is a stateful object, but the only state changes allowed are additions
to *assignments*, and corresponding updates of its lookup tables.

This is a stateful object, but the only state changes allowed
are additions to *assignments*, and corresponding updates of
its lookup tables.
Note that user code is *only* allowed to hold on to *names* generated by
this class, but not expressions using names defined in this collection.

Note that user code is *only* allowed to hold on to *names* generated
by this class, but not expressions using names defined in this collection.
.. autoattribute:: assignments
.. automethod:: add_assignment
.. automethod:: assign_unique
.. automethod:: assign_temp
.. automethod:: run_global_cse
"""

assignments: dict[str, sym.Expr]
reversed_assignments: dict[sym.Expr, str]
assignments: dict[str, sym.Basic]
"""A mapping from *var_name* to expressions."""
reversed_assignments: dict[sym.Basic, str]
symbol_generator: _SymbolGenerator
all_dependencies_cache: dict[str, set[sym.Symbol]]

def __init__(self, assignments: dict[str, sym.Expr] | None = None):
"""
:arg assignments: mapping from *var_name* to expression
"""

def __init__(self, assignments: dict[str, sym.Basic] | None = None) -> None:
if assignments is None:
assignments = {}

Expand All @@ -126,13 +129,14 @@ def __init__(self, assignments: dict[str, sym.Expr] | None = None):
self.all_dependencies_cache = {}

@override
def __str__(self):
def __str__(self) -> str:
return "\n".join(
f"{name} <- {expr}"
for name, expr in self.assignments.items())

def get_all_dependencies(self, var_name: str):
def get_all_dependencies(self, var_name: str) -> set[sym.Symbol]:
"""Including recursive dependencies."""

try:
return self.all_dependencies_cache[var_name]
except KeyError:
Expand All @@ -157,9 +161,9 @@ def get_all_dependencies(self, var_name: str):

def add_assignment(self,
name: str,
expr: sym.Expr,
expr: sym.Basic,
root_name: str | None = None,
retain_name: bool = True):
retain_name: bool = True) -> str:
assert isinstance(name, str)
assert name not in self.assignments

Expand All @@ -176,23 +180,34 @@ def add_assignment(self,

return name

def assign_unique(self, name_base: str, expr: sym.Expr):
def assign_unique(self, name_base: str, expr: sym.Basic) -> str:
"""Assign *expr* to a new variable whose name is based on *name_base*.
Return the new variable name.
"""
new_name = self.symbol_generator(name_base).name

return self.add_assignment(new_name, expr)

def assign_temp(self, name_base: str, expr: sym.Expr):
def assign_temp(self, name_base: str, expr: sym.Basic) -> str:
"""If *expr* is mapped to a existing variable, then return the existing
variable or assign *expr* to a new variable whose name is based on
*name_base*. Return the variable name *expr* is mapped to in either case.
"""
new_name = self.symbol_generator(name_base).name
return self.add_assignment(new_name, expr, retain_name=False)

def run_global_cse(self, extra_exprs: Sequence[sym.Expr] | None = None):
@overload
def run_global_cse(self, extra_exprs: None = None) -> Self: ...

@overload
def run_global_cse(self,
extra_exprs: Sequence[sym.Expr]
) -> tuple[Self, Sequence[sym.Basic]]: ...

def run_global_cse(self,
extra_exprs: Sequence[sym.Expr] | None = None
) -> tuple[Self, Sequence[sym.Basic]] | Self:
orig_extra_exprs = extra_exprs
if extra_exprs is None:
extra_exprs = []

Expand All @@ -212,31 +227,32 @@ def run_global_cse(self, extra_exprs: Sequence[sym.Expr] | None = None):
# from sumpy.symbolic import checked_cse

from sumpy.cse import cse
new_assignments, new_exprs = cse([*assign_exprs, *extra_exprs],
new_cse_assignments, new_exprs = cse(
[*assign_exprs, *extra_exprs],
symbols=self.symbol_generator)

new_assign_exprs = new_exprs[:len(assign_exprs)]
new_extra_exprs = new_exprs[len(assign_exprs):]

for name, new_expr in zip(assign_names, new_assign_exprs, strict=True):
self.assignments[name] = new_expr
result_assignments: dict[str, sym.Basic] = {}

for name, value in new_assignments:
for name, value in new_cse_assignments:
assert isinstance(name, sym.Symbol)
self.add_assignment(name.name, value)
result_assignments[name.name] = value

for name, new_expr in zip(assign_names, new_assign_exprs, strict=True):
# We want the assignment collection to be ordered correctly
# to make it easier for loopy to schedule.
# Deleting the original assignments and adding them again
# makes them occur after the CSE'd expression preserving
# the order of operations.
del self.assignments[name]
self.assignments[name] = new_expr
result_assignments = {
**result_assignments,
**dict(zip(assign_names, new_assign_exprs, strict=True)),
}

logger.info("common subexpression elimination: done after %.2f s",
time.time() - start_time)
return new_extra_exprs

result = type(self)(result_assignments)
if orig_extra_exprs is None:
return result
else:
return result, new_extra_exprs

# }}}

Expand Down
2 changes: 1 addition & 1 deletion sumpy/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,7 +736,7 @@ def map_common_subexpression_uncached(
# {{{ to-loopy conversion

def to_loopy_insns(
assignments: Iterable[tuple[str, sym.Expr]],
assignments: Iterable[tuple[str, sym.Basic]],
vector_names: AbstractSet[str] | None = None,
pymbolic_expr_maps: Sequence[Callable[[Expression], Expression]] = (),
complex_dtype: DTypeLike | None = None,
Expand Down
Loading
Loading