From 375066c4068c09ccf3c3c88147ab3fe5fc09d30e Mon Sep 17 00:00:00 2001 From: Addison Date: Thu, 20 Nov 2025 11:28:05 -0600 Subject: [PATCH 01/18] add compute in its current form --- loopy/target/c/compyte | 2 +- loopy/transform/compute.py | 204 +++++++++++++++++++++++++++++++++++++ 2 files changed, 205 insertions(+), 1 deletion(-) create mode 100644 loopy/transform/compute.py diff --git a/loopy/target/c/compyte b/loopy/target/c/compyte index 2b168ca39..955160ac2 160000 --- a/loopy/target/c/compyte +++ b/loopy/target/c/compyte @@ -1 +1 @@ -Subproject commit 2b168ca396aec2259da408f441f5e38ac9f95cb6 +Subproject commit 955160ac2f504dabcd8641471a56146fa1afe35d diff --git a/loopy/transform/compute.py b/loopy/transform/compute.py new file mode 100644 index 000000000..4b325a7af --- /dev/null +++ b/loopy/transform/compute.py @@ -0,0 +1,204 @@ +# DomainChanger +# iname nesting order <=> tree +# loop transformations +# - traverse syntax tree +# - affine map inames +# +# index views for warp tiling + +from pymbolic.mapper.substitutor import make_subst_func +from loopy.kernel import LoopKernel +import islpy as isl + +import loopy as lp +from loopy.kernel.data import AddressSpace +from loopy.kernel.function_interface import CallableKernel, ScalarCallable +from loopy.kernel.instruction import MultiAssignmentBase +from loopy.kernel.tools import DomainChanger +from loopy.match import parse_stack_match +from loopy.symbolic import RuleAwareSubstitutionMapper, SubstitutionRuleMappingContext, aff_from_expr, aff_to_expr, pw_aff_to_expr +from loopy.transform.precompute import RuleInvocationGatherer, RuleInvocationReplacer, contains_a_subst_rule_invocation +from loopy.translation_unit import TranslationUnit + +import pymbolic.primitives as prim +from pymbolic import var + +from pytools.tag import Tag + + +def compute( + t_unit: TranslationUnit, +substitution: str, + *args, + **kwargs + ) -> TranslationUnit: + """ + Entrypoint for performing a compute transformation on all kernels in a + translation unit. See :func:`_compute_inner` for more details. + """ + + assert isinstance(t_unit, TranslationUnit) + new_callables = {} + + for id, callable in t_unit.callables_table.items(): + if isinstance(callable, CallableKernel): + kernel = _compute_inner( + callable.subkernel, + substitution, + *args, **kwargs + ) + + callable = callable.copy(subkernel=kernel) + elif isinstance(callable, ScalarCallable): + pass + else: + raise NotImplementedError() + + new_callables[id] = callable + + return t_unit + +def _compute_inner( + kernel: LoopKernel, + substitution: str, + transform_map: isl.Map, + compute_map: isl.Map, + storage_inames: list[str], + default_tag: Tag | str | None = None, + temporary_address_space: AddressSpace | None = None + ) -> LoopKernel: + """ + Inserts an instruction to compute an expression given by :arg:`substitution` + and replaces all invocations of :arg:`substitution` with the result of the + compute instruction. + + :arg substitution: The substitution rule for which the compute + transform should be applied. + + :arg transform_map: An :class:`isl.Map` representing the affine + transformation from the original iname domain to the transformed iname + domain. + + :arg compute_map: An :class:`isl.Map` representing a relation between + substitution rule indices and tuples `(a, l)`, where `a` is a vector of + storage indices and `l` is a vector of "timestamps". This map describes + """ + + if not temporary_address_space: + temporary_address_space = AddressSpace.GLOBAL + + # {{{ normalize names + + iname_to_storage_map = { + iname : (iname + "_store" if iname in kernel.all_inames() else iname) + for iname in storage_inames + } + + new_storage_axes = list(iname_to_storage_map.values()) + + for dim in range(compute_map.dim(isl.dim_type.out)): + for iname, storage_ax in iname_to_storage_map.items(): + if compute_map.get_dim_name(isl.dim_type.out, dim) == iname: + compute_map = compute_map.set_dim_name( + isl.dim_type.out, dim, storage_ax + ) + + # }}} + + # {{{ update kernel domain to contain storage inames + + storage_domain = compute_map.range().project_out_except( + new_storage_axes, [isl.dim_type.set] + ) + + # FIXME: likely need to do some more digging to find proper domain to update + new_domain = kernel.domains[0] + for ax in new_storage_axes: + new_domain = new_domain.add_dims(isl.dim_type.set, 1) + + new_domain = new_domain.set_dim_name( + isl.dim_type.set, + new_domain.dim(isl.dim_type.set) - 1, + ax + ) + + new_domain, storage_domain = isl.align_two(new_domain, storage_domain) + new_domain = new_domain & storage_domain + kernel = kernel.copy(domains=[new_domain]) + + # }}} + + # {{{ express substitution inputs as pw affs of (storage, time) names + + compute_pw_aff = compute_map.reverse().as_pw_multi_aff() + subst_inp_names = [ + compute_map.get_dim_name(isl.dim_type.in_, i) + for i in range(compute_map.dim(isl.dim_type.in_)) + ] + storage_ax_to_global_expr = dict.fromkeys(subst_inp_names) + for dim in range(compute_pw_aff.dim(isl.dim_type.out)): + subst_inp = compute_map.get_dim_name(isl.dim_type.in_, dim) + storage_ax_to_global_expr[subst_inp] = \ + pw_aff_to_expr(compute_pw_aff.get_at(dim)) + + # }}} + + # {{{ generate instruction from compute map + + rule_mapping_ctx = SubstitutionRuleMappingContext( + kernel.substitutions, kernel.get_var_name_generator()) + + expr_subst_map = RuleAwareSubstitutionMapper( + rule_mapping_ctx, + make_subst_func(storage_ax_to_global_expr), + within=parse_stack_match(None) + ) + + subst_expr = kernel.substitutions[substitution].expression + compute_expression = expr_subst_map(subst_expr, kernel, None) + + temporary_name = substitution + "_temp" + assignee = var(temporary_name)[tuple( + var(iname) for iname in new_storage_axes + )] + + compute_insn_id = substitution + "_compute" + compute_insn = lp.Assignment( + id=compute_insn_id, + assignee=assignee, + expression=compute_expression, + ) + + compute_dep_id = compute_insn_id + new_insns = [compute_insn] + + # add global sync if we are storing in global memory + if temporary_address_space == lp.AddressSpace.GLOBAL: + gbarrier_id = kernel.make_unique_instruction_id( + based_on=substitution + "_barrier" + ) + + from loopy.kernel.instruction import BarrierInstruction + barrier_insn = BarrierInstruction( + id=gbarrier_id, + depends_on=frozenset([compute_insn_id]), + synchronization_kind="global", + mem_kind="global" + ) + + compute_dep_id = gbarrier_id + + # }}} + + # {{{ replace substitution rule with newly created instruction + + # FIXME: get these properly (see `precompute`) + subst_name = substitution + subst_tag = None + within = None # do we need this? + + + + # }}} + + return kernel From 745f841ca174d2cf97f1bcf0335330436e991b8f Mon Sep 17 00:00:00 2001 From: Addison Date: Thu, 20 Nov 2025 11:41:05 -0600 Subject: [PATCH 02/18] align compyte with inducer/main --- loopy/target/c/compyte | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/loopy/target/c/compyte b/loopy/target/c/compyte index 955160ac2..2b168ca39 160000 --- a/loopy/target/c/compyte +++ b/loopy/target/c/compyte @@ -1 +1 @@ -Subproject commit 955160ac2f504dabcd8641471a56146fa1afe35d +Subproject commit 2b168ca396aec2259da408f441f5e38ac9f95cb6 From 88966446dc586b0a0d2ef44b9f410bc330774f21 Mon Sep 17 00:00:00 2001 From: Addison Date: Thu, 20 Nov 2025 11:49:42 -0600 Subject: [PATCH 03/18] clean up comments and typos --- loopy/transform/compute.py | 24 ++++++++---------------- 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/loopy/transform/compute.py b/loopy/transform/compute.py index 4b325a7af..e1e54fa2d 100644 --- a/loopy/transform/compute.py +++ b/loopy/transform/compute.py @@ -1,27 +1,19 @@ -# DomainChanger -# iname nesting order <=> tree -# loop transformations -# - traverse syntax tree -# - affine map inames -# -# index views for warp tiling - -from pymbolic.mapper.substitutor import make_subst_func -from loopy.kernel import LoopKernel import islpy as isl import loopy as lp +from loopy.kernel import LoopKernel from loopy.kernel.data import AddressSpace from loopy.kernel.function_interface import CallableKernel, ScalarCallable -from loopy.kernel.instruction import MultiAssignmentBase -from loopy.kernel.tools import DomainChanger from loopy.match import parse_stack_match -from loopy.symbolic import RuleAwareSubstitutionMapper, SubstitutionRuleMappingContext, aff_from_expr, aff_to_expr, pw_aff_to_expr -from loopy.transform.precompute import RuleInvocationGatherer, RuleInvocationReplacer, contains_a_subst_rule_invocation +from loopy.symbolic import ( + RuleAwareSubstitutionMapper, + SubstitutionRuleMappingContext, + pw_aff_to_expr +) from loopy.translation_unit import TranslationUnit -import pymbolic.primitives as prim from pymbolic import var +from pymbolic.mapper.substitutor import make_subst_func from pytools.tag import Tag @@ -81,7 +73,7 @@ def _compute_inner( :arg compute_map: An :class:`isl.Map` representing a relation between substitution rule indices and tuples `(a, l)`, where `a` is a vector of - storage indices and `l` is a vector of "timestamps". This map describes + storage indices and `l` is a vector of "timestamps". """ if not temporary_address_space: From 6230e11ff78e97bd074a19910c9094f38cbf0a69 Mon Sep 17 00:00:00 2001 From: Addison Date: Wed, 10 Dec 2025 00:46:09 -0600 Subject: [PATCH 04/18] switch to alpha namedisl usage --- loopy/transform/compute.py | 119 +++++++------------------------------ 1 file changed, 22 insertions(+), 97 deletions(-) diff --git a/loopy/transform/compute.py b/loopy/transform/compute.py index e1e54fa2d..cb01401db 100644 --- a/loopy/transform/compute.py +++ b/loopy/transform/compute.py @@ -1,16 +1,18 @@ import islpy as isl +import namedisl as nisl import loopy as lp from loopy.kernel import LoopKernel from loopy.kernel.data import AddressSpace -from loopy.kernel.function_interface import CallableKernel, ScalarCallable +from loopy.kernel.instruction import MultiAssignmentBase from loopy.match import parse_stack_match from loopy.symbolic import ( RuleAwareSubstitutionMapper, SubstitutionRuleMappingContext, pw_aff_to_expr ) -from loopy.translation_unit import TranslationUnit +from loopy.transform.precompute import contains_a_subst_rule_invocation +from loopy.translation_unit import for_each_kernel from pymbolic import var from pymbolic.mapper.substitutor import make_subst_func @@ -18,43 +20,11 @@ from pytools.tag import Tag +@for_each_kernel def compute( - t_unit: TranslationUnit, -substitution: str, - *args, - **kwargs - ) -> TranslationUnit: - """ - Entrypoint for performing a compute transformation on all kernels in a - translation unit. See :func:`_compute_inner` for more details. - """ - - assert isinstance(t_unit, TranslationUnit) - new_callables = {} - - for id, callable in t_unit.callables_table.items(): - if isinstance(callable, CallableKernel): - kernel = _compute_inner( - callable.subkernel, - substitution, - *args, **kwargs - ) - - callable = callable.copy(subkernel=kernel) - elif isinstance(callable, ScalarCallable): - pass - else: - raise NotImplementedError() - - new_callables[id] = callable - - return t_unit - -def _compute_inner( kernel: LoopKernel, substitution: str, - transform_map: isl.Map, - compute_map: isl.Map, + compute_map: isl.Map | nisl.Map, storage_inames: list[str], default_tag: Tag | str | None = None, temporary_address_space: AddressSpace | None = None @@ -67,14 +37,12 @@ def _compute_inner( :arg substitution: The substitution rule for which the compute transform should be applied. - :arg transform_map: An :class:`isl.Map` representing the affine - transformation from the original iname domain to the transformed iname - domain. - :arg compute_map: An :class:`isl.Map` representing a relation between substitution rule indices and tuples `(a, l)`, where `a` is a vector of storage indices and `l` is a vector of "timestamps". """ + if isinstance(compute_map, isl.Map): + compute_map = nisl.make_map(compute_map) if not temporary_address_space: temporary_address_space = AddressSpace.GLOBAL @@ -86,52 +54,29 @@ def _compute_inner( for iname in storage_inames } - new_storage_axes = list(iname_to_storage_map.values()) - - for dim in range(compute_map.dim(isl.dim_type.out)): - for iname, storage_ax in iname_to_storage_map.items(): - if compute_map.get_dim_name(isl.dim_type.out, dim) == iname: - compute_map = compute_map.set_dim_name( - isl.dim_type.out, dim, storage_ax - ) + compute_map = compute_map.rename_dims(iname_to_storage_map) # }}} # {{{ update kernel domain to contain storage inames - storage_domain = compute_map.range().project_out_except( - new_storage_axes, [isl.dim_type.set] - ) + new_storage_axes = list(iname_to_storage_map.values()) - # FIXME: likely need to do some more digging to find proper domain to update + # FIXME: use DomainChanger to add domain to kernel + storage_domain = compute_map.range().project_out_except(new_storage_axes) new_domain = kernel.domains[0] - for ax in new_storage_axes: - new_domain = new_domain.add_dims(isl.dim_type.set, 1) - - new_domain = new_domain.set_dim_name( - isl.dim_type.set, - new_domain.dim(isl.dim_type.set) - 1, - ax - ) - - new_domain, storage_domain = isl.align_two(new_domain, storage_domain) - new_domain = new_domain & storage_domain - kernel = kernel.copy(domains=[new_domain]) # }}} # {{{ express substitution inputs as pw affs of (storage, time) names compute_pw_aff = compute_map.reverse().as_pw_multi_aff() - subst_inp_names = [ - compute_map.get_dim_name(isl.dim_type.in_, i) - for i in range(compute_map.dim(isl.dim_type.in_)) - ] - storage_ax_to_global_expr = dict.fromkeys(subst_inp_names) - for dim in range(compute_pw_aff.dim(isl.dim_type.out)): - subst_inp = compute_map.get_dim_name(isl.dim_type.in_, dim) - storage_ax_to_global_expr[subst_inp] = \ - pw_aff_to_expr(compute_pw_aff.get_at(dim)) + + # FIXME: remove PwAff._obj usage when ready + storage_ax_to_global_expr = { + dim_name : pw_aff_to_expr(compute_pw_aff.get_at(dim_name)._obj) + for dim_name in compute_map.dim_type_names(isl.dim_type.in_) + } # }}} @@ -161,34 +106,14 @@ def _compute_inner( expression=compute_expression, ) - compute_dep_id = compute_insn_id - new_insns = [compute_insn] - - # add global sync if we are storing in global memory - if temporary_address_space == lp.AddressSpace.GLOBAL: - gbarrier_id = kernel.make_unique_instruction_id( - based_on=substitution + "_barrier" - ) - - from loopy.kernel.instruction import BarrierInstruction - barrier_insn = BarrierInstruction( - id=gbarrier_id, - depends_on=frozenset([compute_insn_id]), - synchronization_kind="global", - mem_kind="global" - ) - - compute_dep_id = gbarrier_id - # }}} # {{{ replace substitution rule with newly created instruction - # FIXME: get these properly (see `precompute`) - subst_name = substitution - subst_tag = None - within = None # do we need this? - + for insn in kernel.instructions: + if contains_a_subst_rule_invocation(kernel, insn) \ + and isinstance(insn, MultiAssignmentBase): + print(insn) # }}} From 80839bad239b9833b7804ab2f750a7bd03fb38a8 Mon Sep 17 00:00:00 2001 From: Addison Date: Wed, 10 Dec 2025 12:40:49 -0600 Subject: [PATCH 05/18] start using namedisl in places other than compute --- loopy/symbolic.py | 37 ++++++++++++++++++++++++++----------- loopy/transform/compute.py | 2 +- 2 files changed, 27 insertions(+), 12 deletions(-) diff --git a/loopy/symbolic.py b/loopy/symbolic.py index ba6d71a80..442eb8572 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -48,6 +48,7 @@ from constantdict import constantdict from typing_extensions import Self, override +import namedisl as nisl import islpy as isl import pymbolic.primitives as p import pytools.lex @@ -2044,23 +2045,37 @@ def map_subscript(self, expr: p.Subscript) -> Set[p.Subscript]: # {{{ (pw)aff to expr conversion -def aff_to_expr(aff: isl.Aff) -> ArithmeticExpression: +def aff_to_expr(aff: isl.Aff | nisl.Aff) -> ArithmeticExpression: from pymbolic import var denom = aff.get_denominator_val().to_python() - result = (aff.get_constant_val()*denom).to_python() - for dt in [isl.dim_type.in_, isl.dim_type.param]: - for i in range(aff.dim(dt)): - coeff = (aff.get_coefficient_val(dt, i)*denom).to_python() + if isinstance(aff, isl.Aff): + for dt in [isl.dim_type.in_, isl.dim_type.param]: + for i in range(aff.dim(dt)): + coeff = (aff.get_coefficient_val(dt, i)*denom).to_python() + if coeff: + dim_name = not_none(aff.get_dim_name(dt, i)) + result += coeff*var(dim_name) + + for i in range(aff.dim(isl.dim_type.div)): + coeff = (aff.get_coefficient_val(isl.dim_type.div, i)*denom).to_python() + if coeff: + result += coeff*aff_to_expr(aff.get_div(i)) + + else: + in_names = set(aff.dim_type_names(isl.dim_type.in_)) + param_names = set(aff.dim_type_names(isl.dim_type.param)) + + for name in in_names | param_names: + coeff = (aff.get_coefficient_val(name) * denom).to_python() if coeff: - dim_name = not_none(aff.get_dim_name(dt, i)) - result += coeff*var(dim_name) + result = coeff * var(name) - for i in range(aff.dim(isl.dim_type.div)): - coeff = (aff.get_coefficient_val(isl.dim_type.div, i)*denom).to_python() - if coeff: - result += coeff*aff_to_expr(aff.get_div(i)) + for name in aff.dim_type_names(isl.dim_type.div): + coeff = (aff.get_coefficient_val(name) * denom).to_python() + if coeff: + result += coeff * aff_to_expr(aff.get_div(name)) assert not isinstance(result, complex) return flatten(result // denom) diff --git a/loopy/transform/compute.py b/loopy/transform/compute.py index cb01401db..b3e06f2e9 100644 --- a/loopy/transform/compute.py +++ b/loopy/transform/compute.py @@ -74,7 +74,7 @@ def compute( # FIXME: remove PwAff._obj usage when ready storage_ax_to_global_expr = { - dim_name : pw_aff_to_expr(compute_pw_aff.get_at(dim_name)._obj) + dim_name : pw_aff_to_expr(compute_pw_aff.get_at(dim_name)) for dim_name in compute_map.dim_type_names(isl.dim_type.in_) } From c1ba35bb2d20f093f86df785378b63b1092ba0dd Mon Sep 17 00:00:00 2001 From: Addison Date: Wed, 10 Dec 2025 12:43:44 -0600 Subject: [PATCH 06/18] add namedisl objects to a type signature --- loopy/symbolic.py | 3 ++- loopy/transform/compute.py | 1 - 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/loopy/symbolic.py b/loopy/symbolic.py index 442eb8572..f47e32f9d 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -2048,6 +2048,7 @@ def map_subscript(self, expr: p.Subscript) -> Set[p.Subscript]: def aff_to_expr(aff: isl.Aff | nisl.Aff) -> ArithmeticExpression: from pymbolic import var + # FIXME: remove this once namedisl is the standard in loopy denom = aff.get_denominator_val().to_python() result = (aff.get_constant_val()*denom).to_python() if isinstance(aff, isl.Aff): @@ -2082,7 +2083,7 @@ def aff_to_expr(aff: isl.Aff | nisl.Aff) -> ArithmeticExpression: def pw_aff_to_expr( - pw_aff: int | isl.PwAff | isl.Aff, + pw_aff: int | isl.PwAff | isl.Aff | nisl.PwAff | nisl.Aff, int_ok: bool = False ) -> ArithmeticExpression: if isinstance(pw_aff, int): diff --git a/loopy/transform/compute.py b/loopy/transform/compute.py index b3e06f2e9..59ddf8a2e 100644 --- a/loopy/transform/compute.py +++ b/loopy/transform/compute.py @@ -72,7 +72,6 @@ def compute( compute_pw_aff = compute_map.reverse().as_pw_multi_aff() - # FIXME: remove PwAff._obj usage when ready storage_ax_to_global_expr = { dim_name : pw_aff_to_expr(compute_pw_aff.get_at(dim_name)) for dim_name in compute_map.dim_type_names(isl.dim_type.in_) From 7265536a3a1fc4cd24883946876b4f16a12bcac2 Mon Sep 17 00:00:00 2001 From: Addison Date: Mon, 16 Mar 2026 09:22:53 -0500 Subject: [PATCH 07/18] compute transform up to and including instruction creation + insertion; missing invocation replacement --- loopy/transform/compute.py | 176 ++++++++++++++++++++++++------------- 1 file changed, 116 insertions(+), 60 deletions(-) diff --git a/loopy/transform/compute.py b/loopy/transform/compute.py index 59ddf8a2e..5c3c02130 100644 --- a/loopy/transform/compute.py +++ b/loopy/transform/compute.py @@ -1,32 +1,55 @@ -import islpy as isl +import loopy as lp +from loopy.kernel.tools import DomainChanger import namedisl as nisl -import loopy as lp from loopy.kernel import LoopKernel from loopy.kernel.data import AddressSpace -from loopy.kernel.instruction import MultiAssignmentBase from loopy.match import parse_stack_match from loopy.symbolic import ( + RuleAwareIdentityMapper, RuleAwareSubstitutionMapper, SubstitutionRuleMappingContext, - pw_aff_to_expr + pw_aff_to_expr, + pwaff_from_expr +) +from loopy.transform.precompute import ( + RuleInvocationGatherer, + contains_a_subst_rule_invocation ) -from loopy.transform.precompute import contains_a_subst_rule_invocation from loopy.translation_unit import for_each_kernel - from pymbolic import var from pymbolic.mapper.substitutor import make_subst_func -from pytools.tag import Tag +import islpy as isl +import pymbolic.primitives as p +from pymbolic.mapper.dependency import DependencyMapper + +from pymbolic.mapper import IdentityMapper + + +def gather_vars(expr) -> set[str]: + deps = DependencyMapper()(expr) + return { + dep.name + for dep in deps + if isinstance(dep, p.Variable) + } +def space_from_exprs(exprs, ctx=isl.DEFAULT_CONTEXT): + names = sorted(set().union(*(gather_vars(expr) for expr in exprs))) + set_names = [name for name in names] + + return isl.Space.create_from_names( + ctx, + set=set_names + ) @for_each_kernel def compute( kernel: LoopKernel, substitution: str, - compute_map: isl.Map | nisl.Map, - storage_inames: list[str], - default_tag: Tag | str | None = None, + compute_map: nisl.Map, + storage_indices: frozenset[str], temporary_address_space: AddressSpace | None = None ) -> LoopKernel: """ @@ -40,52 +63,86 @@ def compute( :arg compute_map: An :class:`isl.Map` representing a relation between substitution rule indices and tuples `(a, l)`, where `a` is a vector of storage indices and `l` is a vector of "timestamps". - """ - if isinstance(compute_map, isl.Map): - compute_map = nisl.make_map(compute_map) - - if not temporary_address_space: - temporary_address_space = AddressSpace.GLOBAL - - # {{{ normalize names - - iname_to_storage_map = { - iname : (iname + "_store" if iname in kernel.all_inames() else iname) - for iname in storage_inames - } - - compute_map = compute_map.rename_dims(iname_to_storage_map) - - # }}} - # {{{ update kernel domain to contain storage inames - - new_storage_axes = list(iname_to_storage_map.values()) - - # FIXME: use DomainChanger to add domain to kernel - storage_domain = compute_map.range().project_out_except(new_storage_axes) - new_domain = kernel.domains[0] - - # }}} + :arg storage_indices: A :class:`frozenset` of names of storage indices. Used + to create inames for the loops that cover the required footprint. + """ + compute_map = compute_map._reconstruct_isl_object() - # {{{ express substitution inputs as pw affs of (storage, time) names + # construct union of usage footprints to determine bounds on compute inames + ctx = SubstitutionRuleMappingContext( + kernel.substitutions, kernel.get_var_name_generator()) + inv_gatherer = RuleInvocationGatherer( + ctx, kernel, substitution, None, parse_stack_match(None) + ) + for insn in kernel.instructions: + if (isinstance(insn, lp.MultiAssignmentBase) and + contains_a_subst_rule_invocation(kernel, insn)): + for assignee in insn.assignees: + _ = inv_gatherer(assignee, kernel, insn) + _ = inv_gatherer(insn.expression, kernel, insn) + + access_descriptors = inv_gatherer.access_descriptors + + acc_desc_exprs = [ + arg + for ad in access_descriptors + if ad.args is not None + for arg in ad.args + ] + + space = space_from_exprs(acc_desc_exprs) + + footprint = isl.Set.empty(isl.Space.create_from_names( + ctx=space.get_ctx(), + set=list(storage_indices) + )) + for ad in access_descriptors: + if not ad.args: + continue + + nout = len(ad.args) + + range_space = isl.Space.alloc(space.get_ctx(), 0, nout, 0).domain() + map_space = space.map_from_domain_and_range(range_space) + pw_multi_aff = isl.MultiPwAff.zero(map_space) + + for i, arg in enumerate(ad.args): + if arg is not None: + pw_multi_aff = pw_multi_aff.set_pw_aff( + i, + pwaff_from_expr(space, arg) + ) + + usage_map = pw_multi_aff.as_map() + iname_to_timespace = usage_map.apply_range(compute_map).coalesce() + iname_to_storage = iname_to_timespace.project_out_except( + storage_indices, [isl.dim_type.out] + ) + + footprint = footprint | iname_to_storage.range() + + # add compute inames to domain / kernel + domain_changer = DomainChanger(kernel, kernel.all_inames()) + domain = domain_changer.domain + + footprint, domain = isl.align_two(footprint, domain) + domain = domain & footprint + + new_domains = domain_changer.get_domains_with(domain) + kernel = kernel.copy(domains=new_domains) + + # create compute instruction in kernel compute_pw_aff = compute_map.reverse().as_pw_multi_aff() - storage_ax_to_global_expr = { - dim_name : pw_aff_to_expr(compute_pw_aff.get_at(dim_name)) - for dim_name in compute_map.dim_type_names(isl.dim_type.in_) + compute_pw_aff.get_dim_name(isl.dim_type.out, dim) : + pw_aff_to_expr(compute_pw_aff.get_at(dim)) + for dim in range(compute_pw_aff.dim(isl.dim_type.out)) } - # }}} - - # {{{ generate instruction from compute map - - rule_mapping_ctx = SubstitutionRuleMappingContext( - kernel.substitutions, kernel.get_var_name_generator()) - expr_subst_map = RuleAwareSubstitutionMapper( - rule_mapping_ctx, + ctx, make_subst_func(storage_ax_to_global_expr), within=parse_stack_match(None) ) @@ -95,26 +152,25 @@ def compute( temporary_name = substitution + "_temp" assignee = var(temporary_name)[tuple( - var(iname) for iname in new_storage_axes + var(iname) for iname in storage_indices )] + within_inames = frozenset( + compute_map.get_dim_name(isl.dim_type.out, dim) + for dim in range(compute_map.dim(isl.dim_type.out)) + ) + compute_insn_id = substitution + "_compute" compute_insn = lp.Assignment( id=compute_insn_id, assignee=assignee, expression=compute_expression, + within_inames=within_inames ) - # }}} - - # {{{ replace substitution rule with newly created instruction - - for insn in kernel.instructions: - if contains_a_subst_rule_invocation(kernel, insn) \ - and isinstance(insn, MultiAssignmentBase): - print(insn) - - - # }}} + new_insns = list(kernel.instructions) + new_insns.append(compute_insn) + kernel = kernel.copy(instructions=new_insns) + print(kernel) return kernel From 56af4fe810ca2c7a9813cb33a51f7e45f08498dc Mon Sep 17 00:00:00 2001 From: Addison Date: Tue, 17 Mar 2026 17:43:06 -0500 Subject: [PATCH 08/18] invocation replacement; dependencies still need handling --- loopy/transform/compute.py | 291 +++++++++++++++++++++++++++++++------ 1 file changed, 248 insertions(+), 43 deletions(-) diff --git a/loopy/transform/compute.py b/loopy/transform/compute.py index 5c3c02130..02df1eaa7 100644 --- a/loopy/transform/compute.py +++ b/loopy/transform/compute.py @@ -1,30 +1,35 @@ +from collections.abc import Sequence, Set +from dataclasses import dataclass +from typing import override import loopy as lp from loopy.kernel.tools import DomainChanger import namedisl as nisl from loopy.kernel import LoopKernel -from loopy.kernel.data import AddressSpace -from loopy.match import parse_stack_match +from loopy.kernel.data import AddressSpace, SubstitutionRule +from loopy.match import StackMatch, parse_stack_match from loopy.symbolic import ( + ExpansionState, RuleAwareIdentityMapper, RuleAwareSubstitutionMapper, + SubstitutionRuleExpander, SubstitutionRuleMappingContext, + get_dependencies, pw_aff_to_expr, pwaff_from_expr ) from loopy.transform.precompute import ( - RuleInvocationGatherer, contains_a_subst_rule_invocation ) from loopy.translation_unit import for_each_kernel -from pymbolic import var +from pymbolic import ArithmeticExpression, var from pymbolic.mapper.substitutor import make_subst_func import islpy as isl import pymbolic.primitives as p from pymbolic.mapper.dependency import DependencyMapper - -from pymbolic.mapper import IdentityMapper +from pymbolic.typing import Expression +from pytools.tag import Tag def gather_vars(expr) -> set[str]: @@ -35,6 +40,7 @@ def gather_vars(expr) -> set[str]: if isinstance(dep, p.Variable) } + def space_from_exprs(exprs, ctx=isl.DEFAULT_CONTEXT): names = sorted(set().union(*(gather_vars(expr) for expr in exprs))) set_names = [name for name in names] @@ -44,12 +50,181 @@ def space_from_exprs(exprs, ctx=isl.DEFAULT_CONTEXT): set=set_names ) + +@dataclass(frozen=True) +class UsageDescriptor: + usage: Sequence[Expression] + global_map: isl.Map + local_map: isl.Map + + @override + def __str__(self): + return ( + f"USAGE = {self.usage}\n" + + f"GLOBAL MAP = {self.global_map}\n" + + f"LOCAL MAP = {self.local_map}" + ) + + +class UsageSiteExpressionGatherer(RuleAwareIdentityMapper[[]]): + """ + Gathers all expressions used as inputs to a particular substitution rule, + identified by name. + """ + def __init__( + self, + rule_mapping_ctx: SubstitutionRuleMappingContext, + subst_expander: SubstitutionRuleExpander, + kernel: LoopKernel, + subst_name: str, + subst_tag: Set[Tag] | Tag | None = None + ) -> None: + + super().__init__(rule_mapping_ctx) + + self.subst_expander: SubstitutionRuleExpander = subst_expander + self.kernel: LoopKernel = kernel + self.subst_name: str = subst_name + self.subst_tag: Set[Tag] | None = ( + {subst_tag} if isinstance(subst_tag, Tag) else subst_tag + ) + + self.usage_expressions: list[Sequence[Expression]] = [] + + + @override + def map_subst_rule( + self, + name: str, + tags: Set[Tag] | None, + arguments: Sequence[Expression], + expn_state: ExpansionState, + ) -> Expression: + + if name != self.subst_name: + return super().map_subst_rule( + name, tags, arguments, expn_state + ) + + if self.subst_tag is not None and self.subst_tag != tags: + return super().map_subst_rule( + name, tags, arguments, expn_state + ) + + rule = self.rule_mapping_context.old_subst_rules[name] + arg_ctx = self.make_new_arg_context( + name, rule.arguments, arguments, expn_state.arg_context + ) + + self.usage_expressions.append([ + arg_ctx[arg_name] for arg_name in rule.arguments + ]) + + return 0 + + +class RuleInvocationReplacer(RuleAwareIdentityMapper[[]]): + def __init__( + self, + ctx: SubstitutionRuleMappingContext, + subst_name: str, + subst_tag: Sequence[Tag] | None, + usage_descriptors: Sequence[UsageDescriptor], + storage_indices: Sequence[str], + temporary_name: str, + compute_insn_id: str, + compute_map: isl.Map + ) -> None: + + super().__init__(ctx) + + self.subst_name: str = subst_name + self.subst_tag: Sequence[Tag] | None = subst_tag + + self.usage_descriptors: Sequence[UsageDescriptor] = usage_descriptors + self.storage_indices: Sequence[str] = storage_indices + + self.temporary_name: str = temporary_name + self.compute_insn_id: str = compute_insn_id + + + @override + def map_subst_rule( + self, + name: str, + tags: Set[Tag] | None, + arguments: Sequence[Expression], + expn_state: ExpansionState + ) -> Expression: + + if not name == self.subst_name: + return super().map_subst_rule(name, tags, arguments, expn_state) + + rule = self.rule_mapping_context.old_subst_rules[name] + arg_ctx = self.make_new_arg_context( + name, rule.arguments, arguments, expn_state.arg_context + ) + args = [arg_ctx[arg_name] for arg_name in rule.arguments] + + # FIXME: footprint check? likely required if user supplies bounds on + # storage indices because we are not guaranteed to capture the footprint + # of all usage sites + + if not len(arguments) == len(rule.arguments): + raise ValueError("Number of arguments passed to rule {name} ", + "does not match the signature of {name}.") + + index_exprs: Sequence[Expression] = [] + for usage_descr in self.usage_descriptors: + if args == usage_descr.usage: + local_pwmaff = usage_descr.local_map.as_pw_multi_aff() + + for dim in range(local_pwmaff.dim(isl.dim_type.out)): + index_exprs.append(pw_aff_to_expr(local_pwmaff.get_at(dim))) + + break + + new_expression = var(self.temporary_name)[tuple(index_exprs)] + + return new_expression + + + @override + def map_kernel( + self, + kernel: LoopKernel, + within: StackMatch = lambda knl, insn, stack: True, + map_args: bool = True, + map_tvs: bool = True + ) -> LoopKernel: + + new_insns = [] + for insn in kernel.instructions: + if (isinstance(insn, lp.MultiAssignmentBase) and not + contains_a_subst_rule_invocation(kernel, insn)): + new_insns.append(insn) + continue + + insn = insn.with_transformed_expressions( + lambda expr: self(expr, kernel, insn) + ) + + new_insns.append(insn) + + return kernel.copy(instructions=new_insns) + + @for_each_kernel def compute( kernel: LoopKernel, substitution: str, compute_map: nisl.Map, - storage_indices: frozenset[str], + storage_indices: Sequence[str], + + # NOTE: how can we deduce this? + temporal_inames: Sequence[str], + + temporary_name: str | None = None, temporary_address_space: AddressSpace | None = None ) -> LoopKernel: """ @@ -64,65 +239,76 @@ def compute( substitution rule indices and tuples `(a, l)`, where `a` is a vector of storage indices and `l` is a vector of "timestamps". - :arg storage_indices: A :class:`frozenset` of names of storage indices. Used - to create inames for the loops that cover the required footprint. + :arg storage_indices: An ordered sequence of names of storage indices. Used + to create inames for the loops that cover the required set of compute points. """ compute_map = compute_map._reconstruct_isl_object() # construct union of usage footprints to determine bounds on compute inames ctx = SubstitutionRuleMappingContext( kernel.substitutions, kernel.get_var_name_generator()) - inv_gatherer = RuleInvocationGatherer( - ctx, kernel, substitution, None, parse_stack_match(None) + expander = SubstitutionRuleExpander(kernel.substitutions) + expr_gatherer = UsageSiteExpressionGatherer( + ctx, expander, kernel, substitution, None ) - for insn in kernel.instructions: - if (isinstance(insn, lp.MultiAssignmentBase) and - contains_a_subst_rule_invocation(kernel, insn)): - for assignee in insn.assignees: - _ = inv_gatherer(assignee, kernel, insn) - _ = inv_gatherer(insn.expression, kernel, insn) + _ = expr_gatherer.map_kernel(kernel) + usage_exprs = expr_gatherer.usage_expressions - access_descriptors = inv_gatherer.access_descriptors - - acc_desc_exprs = [ - arg - for ad in access_descriptors - if ad.args is not None - for arg in ad.args + all_exprs = [ + expr + for usage in usage_exprs + for expr in usage ] - space = space_from_exprs(acc_desc_exprs) + space = space_from_exprs(all_exprs) - footprint = isl.Set.empty(isl.Space.create_from_names( - ctx=space.get_ctx(), - set=list(storage_indices) - )) - for ad in access_descriptors: - if not ad.args: - continue + footprint = isl.Set.empty( + isl.Space.create_from_names( + ctx=space.get_ctx(), + set=list(storage_indices) + ) + ) - nout = len(ad.args) + usage_descrs: Sequence[UsageDescriptor] = [] + for usage in usage_exprs: - range_space = isl.Space.alloc(space.get_ctx(), 0, nout, 0).domain() + range_space = isl.Space.create_from_names( + ctx=space.get_ctx(), + set=list(storage_indices) + ) map_space = space.map_from_domain_and_range(range_space) + pw_multi_aff = isl.MultiPwAff.zero(map_space) - for i, arg in enumerate(ad.args): - if arg is not None: - pw_multi_aff = pw_multi_aff.set_pw_aff( - i, - pwaff_from_expr(space, arg) - ) + for i, arg in enumerate(usage): + pw_multi_aff = pw_multi_aff.set_pw_aff( + i, + pwaff_from_expr(space, arg) + ) usage_map = pw_multi_aff.as_map() - iname_to_timespace = usage_map.apply_range(compute_map).coalesce() + + iname_to_timespace = usage_map.apply_range(compute_map) iname_to_storage = iname_to_timespace.project_out_except( storage_indices, [isl.dim_type.out] ) + local_map = iname_to_storage.project_out_except( + kernel.all_inames() - frozenset(temporal_inames), + [isl.dim_type.in_] + ) + footprint = footprint | iname_to_storage.range() + usage_descrs.append( + UsageDescriptor( + usage, + iname_to_storage, + local_map + ) + ) + # add compute inames to domain / kernel domain_changer = DomainChanger(kernel, kernel.all_inames()) domain = domain_changer.domain @@ -138,7 +324,7 @@ def compute( storage_ax_to_global_expr = { compute_pw_aff.get_dim_name(isl.dim_type.out, dim) : pw_aff_to_expr(compute_pw_aff.get_at(dim)) - for dim in range(compute_pw_aff.dim(isl.dim_type.out)) + for dim in range(compute_pw_aff.dim(isl.dim_type.out)) } expr_subst_map = RuleAwareSubstitutionMapper( @@ -150,7 +336,9 @@ def compute( subst_expr = kernel.substitutions[substitution].expression compute_expression = expr_subst_map(subst_expr, kernel, None) - temporary_name = substitution + "_temp" + if not temporary_name: + temporary_name = substitution + "_temp" + assignee = var(temporary_name)[tuple( var(iname) for iname in storage_indices )] @@ -172,5 +360,22 @@ def compute( new_insns.append(compute_insn) kernel = kernel.copy(instructions=new_insns) + ctx = SubstitutionRuleMappingContext( + kernel.substitutions, kernel.get_var_name_generator() + ) + + replacer = RuleInvocationReplacer( + ctx, + substitution, + None, + usage_descrs, + storage_indices, + temporary_name, + compute_insn_id, + compute_map + ) + + kernel = replacer.map_kernel(kernel) + print(kernel) return kernel From 9d121834435cd1a7ff7bf95805bcef13be6d356e Mon Sep 17 00:00:00 2001 From: Addison Date: Tue, 17 Mar 2026 22:59:58 -0500 Subject: [PATCH 09/18] rough sketch of compute transform; inames not schedulable because of duplicates --- loopy/transform/compute.py | 100 +++++++++++++++++++++++-------------- 1 file changed, 63 insertions(+), 37 deletions(-) diff --git a/loopy/transform/compute.py b/loopy/transform/compute.py index 02df1eaa7..342a3911c 100644 --- a/loopy/transform/compute.py +++ b/loopy/transform/compute.py @@ -1,8 +1,9 @@ -from collections.abc import Sequence, Set +from collections.abc import Mapping, Sequence, Set from dataclasses import dataclass from typing import override import loopy as lp from loopy.kernel.tools import DomainChanger +from loopy.types import to_loopy_type import namedisl as nisl from loopy.kernel import LoopKernel @@ -50,22 +51,6 @@ def space_from_exprs(exprs, ctx=isl.DEFAULT_CONTEXT): set=set_names ) - -@dataclass(frozen=True) -class UsageDescriptor: - usage: Sequence[Expression] - global_map: isl.Map - local_map: isl.Map - - @override - def __str__(self): - return ( - f"USAGE = {self.usage}\n" + - f"GLOBAL MAP = {self.global_map}\n" + - f"LOCAL MAP = {self.local_map}" - ) - - class UsageSiteExpressionGatherer(RuleAwareIdentityMapper[[]]): """ Gathers all expressions used as inputs to a particular substitution rule, @@ -129,7 +114,7 @@ def __init__( ctx: SubstitutionRuleMappingContext, subst_name: str, subst_tag: Sequence[Tag] | None, - usage_descriptors: Sequence[UsageDescriptor], + usage_descriptors: Mapping[tuple[Expression, ...], isl.Map], storage_indices: Sequence[str], temporary_name: str, compute_insn_id: str, @@ -141,12 +126,19 @@ def __init__( self.subst_name: str = subst_name self.subst_tag: Sequence[Tag] | None = subst_tag - self.usage_descriptors: Sequence[UsageDescriptor] = usage_descriptors + self.usage_descriptors: Mapping[tuple[Expression, ...], isl.Map] = \ + usage_descriptors self.storage_indices: Sequence[str] = storage_indices self.temporary_name: str = temporary_name self.compute_insn_id: str = compute_insn_id + # FIXME: may not always be the case (i.e. global barrier between + # compute insn and uses) + self.compute_dep_id: str = compute_insn_id + + self.replaced_something: bool = False + @override def map_subst_rule( @@ -175,17 +167,17 @@ def map_subst_rule( "does not match the signature of {name}.") index_exprs: Sequence[Expression] = [] - for usage_descr in self.usage_descriptors: - if args == usage_descr.usage: - local_pwmaff = usage_descr.local_map.as_pw_multi_aff() - for dim in range(local_pwmaff.dim(isl.dim_type.out)): - index_exprs.append(pw_aff_to_expr(local_pwmaff.get_at(dim))) + # FIXME: make self.usage_descriptors a constantdict + local_pwmaff = self.usage_descriptors[tuple(args)].as_pw_multi_aff() - break + for dim in range(local_pwmaff.dim(isl.dim_type.out)): + index_exprs.append(pw_aff_to_expr(local_pwmaff.get_at(dim))) new_expression = var(self.temporary_name)[tuple(index_exprs)] + self.replaced_something = True + return new_expression @@ -198,8 +190,10 @@ def map_kernel( map_tvs: bool = True ) -> LoopKernel: - new_insns = [] + new_insns: Sequence[lp.InstructionBase] = [] for insn in kernel.instructions: + self.replaced_something = False + if (isinstance(insn, lp.MultiAssignmentBase) and not contains_a_subst_rule_invocation(kernel, insn)): new_insns.append(insn) @@ -209,6 +203,15 @@ def map_kernel( lambda expr: self(expr, kernel, insn) ) + if self.replaced_something: + insn = insn.copy( + depends_on=( + insn.depends_on | frozenset([self.compute_insn_id]) + ) + ) + + # FIXME: determine compute insn dependencies + new_insns.append(insn) return kernel.copy(instructions=new_insns) @@ -270,7 +273,7 @@ def compute( ) ) - usage_descrs: Sequence[UsageDescriptor] = [] + usage_descrs: Mapping[tuple[Expression, ...], isl.Map] = {} for usage in usage_exprs: range_space = isl.Space.create_from_names( @@ -301,20 +304,14 @@ def compute( footprint = footprint | iname_to_storage.range() - usage_descrs.append( - UsageDescriptor( - usage, - iname_to_storage, - local_map - ) - ) + usage_descrs[tuple(usage)] = local_map # add compute inames to domain / kernel domain_changer = DomainChanger(kernel, kernel.all_inames()) domain = domain_changer.domain - footprint, domain = isl.align_two(footprint, domain) - domain = domain & footprint + footprint_tmp, domain = isl.align_two(footprint, domain) + domain = domain & footprint_tmp new_domains = domain_changer.get_domains_with(domain) kernel = kernel.copy(domains=new_domains) @@ -377,5 +374,34 @@ def compute( kernel = replacer.map_kernel(kernel) - print(kernel) + # FIXME: accept dtype as an argument + import numpy as np + loopy_type = to_loopy_type(np.float64, allow_none=True) + + # WARNING: this can result in symbolic shapes, is that allowed? + temp_shape = tuple( + pw_aff_to_expr(footprint.dim_max(dim)) + 1 + for dim in range(footprint.dim(isl.dim_type.out)) + ) + + new_temp_vars = dict(kernel.temporary_variables) + + # FIXME: temp_var might already exist, handle the case where it does + temp_var = lp.TemporaryVariable( + name=temporary_name, + dtype=loopy_type, + base_indices=(0,)*len(temp_shape), + shape=temp_shape, + address_space=temporary_address_space, + dim_names=tuple(storage_indices) + ) + + new_temp_vars[temporary_name] = temp_var + + kernel = kernel.copy( + temporary_variables=new_temp_vars + ) + + # FIXME: handle iname tagging + return kernel From 8667a01f196bf6f580b5014e74d522a64534e5e1 Mon Sep 17 00:00:00 2001 From: Addison Date: Tue, 17 Mar 2026 23:14:44 -0500 Subject: [PATCH 10/18] compute working for tiled matmul; write race condition warning --- loopy/transform/compute.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/loopy/transform/compute.py b/loopy/transform/compute.py index 342a3911c..afdad1bf4 100644 --- a/loopy/transform/compute.py +++ b/loopy/transform/compute.py @@ -311,7 +311,7 @@ def compute( domain = domain_changer.domain footprint_tmp, domain = isl.align_two(footprint, domain) - domain = domain & footprint_tmp + domain = (domain & footprint_tmp).get_basic_sets()[0] new_domains = domain_changer.get_domains_with(domain) kernel = kernel.copy(domains=new_domains) From 959e68b78ddc044b68d0da8bad5e7885a480e3d9 Mon Sep 17 00:00:00 2001 From: Addison Date: Tue, 17 Mar 2026 23:15:57 -0500 Subject: [PATCH 11/18] add compute matmul example --- examples/compute-tiled-matmul.py | 120 +++++++++++++++++++++++++++++++ 1 file changed, 120 insertions(+) create mode 100644 examples/compute-tiled-matmul.py diff --git a/examples/compute-tiled-matmul.py b/examples/compute-tiled-matmul.py new file mode 100644 index 000000000..979f93421 --- /dev/null +++ b/examples/compute-tiled-matmul.py @@ -0,0 +1,120 @@ +import namedisl as nisl + +import loopy as lp +from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2 +from loopy.transform.compute import compute + +import numpy as np +import numpy.linalg as la +import pyopencl as cl + + +def main( + use_precompute: bool = False, + use_compute: bool = False, + run_kernel: bool = False + ) -> None: + + knl = lp.make_kernel( + "{ [i, j, k] : 0 <= i, j, k < 128 }", + """ + a_(is, ks) := a[is, ks] + b_(ks, js) := b[ks, js] + out[i, j] = sum([k], a_(i, k) * b_(k, j)) + """, + [ + lp.GlobalArg("a", shape=(128, 128), dtype=np.float64), + lp.GlobalArg("b", shape=(128, 128), dtype=np.float64), + lp.GlobalArg("out", shape=(128, 128), dtype=np.float64, + is_output=True) + ] + ) + + bm = bn = 32 + bk = 16 + + knl = lp.split_iname(knl, "i", bm, inner_iname="ii", outer_iname="io") + knl = lp.split_iname(knl, "j", bn, inner_iname="ji", outer_iname="jo") + knl = lp.split_iname(knl, "k", bk, inner_iname="ki", outer_iname="ko") + + compute_map_a = nisl.make_map(f"""{{ + [is, ks] -> [io, ii_s, ko, ki_s] : + 0 <= ii_s < {bm} and 0 <= ki_s < {bk} and + is = io * {bm} + ii_s and + ks = ko * {bk} + ki_s + }}""") + + compute_map_b = nisl.make_map(f"""{{ + [ks, js] -> [ko, ki_s, jo, ji_s] : + 0 <= ji_s < {bn} and 0 <= ki_s < {bk} and + js = jo * {bn} + ji_s and + ks = ko * {bk} + ki_s + }}""") + + if use_compute: + knl = compute( + knl, + "a_", + compute_map=compute_map_a, + storage_indices=["ii_s", "ki_s"], + temporal_inames=["io", "ko", "jo"], + temporary_address_space=lp.AddressSpace.LOCAL + ) + + knl = compute( + knl, + "b_", + compute_map=compute_map_b, + storage_indices=["ki_s", "ji_s"], + temporal_inames=["io", "ko", "jo"], + temporary_address_space=lp.AddressSpace.LOCAL + ) + + knl = lp.tag_inames( + knl, { + "io" : "g.0", + "jo" : "g.1", + "ii" : "l.0", + "ji" : "l.1", + } + ) + + knl = lp.add_inames_for_unused_hw_axes(knl) + + if use_precompute: + knl = lp.precompute( + knl, + "a_", + sweep_inames=["ii", "ki"], + ) + + if run_kernel: + a = np.random.randn(128, 128) + b = np.random.randn(128, 128) + + ctx = cl.create_some_context() + queue = cl.CommandQueue(ctx) + + ex = knl.executor(ctx) + _, out = ex(queue, a=a, b=b) + + print(la.norm((a @ b) - out) / la.norm(out)) + + knl = lp.generate_code_v2(knl).device_code() + + print(knl) + + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + + _ = parser.add_argument("--precompute", action="store_true") + _ = parser.add_argument("--compute", action="store_true") + _ = parser.add_argument("--run-kernel", action="store_true") + + args = parser.parse_args() + + main(args.precompute, args.compute, args.run_kernel) From a566b6ee3e51acdaddb4b96516250c33dbbae2f3 Mon Sep 17 00:00:00 2001 From: Addison Date: Tue, 17 Mar 2026 23:20:05 -0500 Subject: [PATCH 12/18] clean up compute matmul example --- examples/compute-tiled-matmul.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/examples/compute-tiled-matmul.py b/examples/compute-tiled-matmul.py index 979f93421..17620b41d 100644 --- a/examples/compute-tiled-matmul.py +++ b/examples/compute-tiled-matmul.py @@ -12,7 +12,9 @@ def main( use_precompute: bool = False, use_compute: bool = False, - run_kernel: bool = False + run_kernel: bool = False, + print_kernel: bool = False, + print_device_code: bool = False ) -> None: knl = lp.make_kernel( @@ -98,11 +100,12 @@ def main( ex = knl.executor(ctx) _, out = ex(queue, a=a, b=b) - print(la.norm((a @ b) - out) / la.norm(out)) + print(f"Relative error = {la.norm((a @ b) - out) / la.norm(out)}") - knl = lp.generate_code_v2(knl).device_code() - - print(knl) + if print_device_code: + print(lp.generate_code_v2(knl).device_code()) + elif print_kernel: + print(knl) @@ -114,7 +117,15 @@ def main( _ = parser.add_argument("--precompute", action="store_true") _ = parser.add_argument("--compute", action="store_true") _ = parser.add_argument("--run-kernel", action="store_true") + _ = parser.add_argument("--print-kernel", action="store_true") + _ = parser.add_argument("--print-device-code", action="store_true") args = parser.parse_args() - main(args.precompute, args.compute, args.run_kernel) + main( + args.precompute, + args.compute, + args.run_kernel, + args.print_kernel, + args.print_device_code + ) From cf920b597ec8fac3ea1d2f69856e506a9a9d863e Mon Sep 17 00:00:00 2001 From: Addison Date: Thu, 19 Mar 2026 08:23:43 -0500 Subject: [PATCH 13/18] improve matmul example with more parameters, better post-compute transformations --- examples/compute-tiled-matmul.py | 71 +++++--- .../compute-examples/compute-tiled-matmul.py | 156 ++++++++++++++++++ .../finite-difference-2-5D.py | 0 loopy/transform/compute.py | 6 +- 4 files changed, 203 insertions(+), 30 deletions(-) create mode 100644 examples/python/compute-examples/compute-tiled-matmul.py create mode 100644 examples/python/compute-examples/finite-difference-2-5D.py diff --git a/examples/compute-tiled-matmul.py b/examples/compute-tiled-matmul.py index 17620b41d..07dad6e1a 100644 --- a/examples/compute-tiled-matmul.py +++ b/examples/compute-tiled-matmul.py @@ -10,31 +10,34 @@ def main( - use_precompute: bool = False, - use_compute: bool = False, - run_kernel: bool = False, - print_kernel: bool = False, - print_device_code: bool = False - ) -> None: + M: int = 128, + N: int = 128, + K: int = 128, + bm: int = 32, + bn: int = 32, + bk: int = 16, + use_precompute: bool = False, + use_compute: bool = False, + run_kernel: bool = False, + print_kernel: bool = False, + print_device_code: bool = False +) -> None: knl = lp.make_kernel( - "{ [i, j, k] : 0 <= i, j, k < 128 }", + "{ [i, j, k] : 0 <= i < M and 0 <= j < N and 0 <= k < K }", """ a_(is, ks) := a[is, ks] b_(ks, js) := b[ks, js] - out[i, j] = sum([k], a_(i, k) * b_(k, j)) + c[i, j] = sum([k], a_(i, k) * b_(k, j)) """, [ - lp.GlobalArg("a", shape=(128, 128), dtype=np.float64), - lp.GlobalArg("b", shape=(128, 128), dtype=np.float64), - lp.GlobalArg("out", shape=(128, 128), dtype=np.float64, + lp.GlobalArg("a", shape=(M, K), dtype=np.float64), + lp.GlobalArg("b", shape=(K, N), dtype=np.float64), + lp.GlobalArg("c", shape=(M, N), dtype=np.float64, is_output=True) ] ) - bm = bn = 32 - bk = 16 - knl = lp.split_iname(knl, "i", bm, inner_iname="ii", outer_iname="io") knl = lp.split_iname(knl, "j", bn, inner_iname="ji", outer_iname="jo") knl = lp.split_iname(knl, "k", bk, inner_iname="ki", outer_iname="ko") @@ -72,16 +75,7 @@ def main( temporary_address_space=lp.AddressSpace.LOCAL ) - knl = lp.tag_inames( - knl, { - "io" : "g.0", - "jo" : "g.1", - "ii" : "l.0", - "ji" : "l.1", - } - ) - - knl = lp.add_inames_for_unused_hw_axes(knl) + # knl = lp.add_inames_for_unused_hw_axes(knl) if use_precompute: knl = lp.precompute( @@ -90,9 +84,20 @@ def main( sweep_inames=["ii", "ki"], ) + # knl = lp.tag_inames( + # knl, { + # "io" : "g.0", + # "jo" : "g.1", + # "ii" : "l.0", + # "ji" : "l.1", + # "ii_s": "l.0", + # "ji_s": "l.1" + # } + # ) + if run_kernel: - a = np.random.randn(128, 128) - b = np.random.randn(128, 128) + a = np.random.randn(M, K) + b = np.random.randn(K, N) ctx = cl.create_some_context() queue = cl.CommandQueue(ctx) @@ -120,9 +125,23 @@ def main( _ = parser.add_argument("--print-kernel", action="store_true") _ = parser.add_argument("--print-device-code", action="store_true") + _ = parser.add_argument("--m", action="store", type=int, default=128) + _ = parser.add_argument("--n", action="store", type=int, default=128) + _ = parser.add_argument("--k", action="store", type=int, default=128) + + _ = parser.add_argument("--bm", action="store", type=int, default=32) + _ = parser.add_argument("--bn", action="store", type=int, default=32) + _ = parser.add_argument("--bk", action="store", type=int, default=16) + args = parser.parse_args() main( + args.m, + args.n, + args.k, + args.bm, + args.bn, + args.bk, args.precompute, args.compute, args.run_kernel, diff --git a/examples/python/compute-examples/compute-tiled-matmul.py b/examples/python/compute-examples/compute-tiled-matmul.py new file mode 100644 index 000000000..67536f701 --- /dev/null +++ b/examples/python/compute-examples/compute-tiled-matmul.py @@ -0,0 +1,156 @@ +import namedisl as nisl + +import loopy as lp +from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2 +from loopy.transform.compute import compute + +import numpy as np +import numpy.linalg as la +import pyopencl as cl + + +def main( + M: int = 128, + N: int = 128, + K: int = 128, + bm: int = 32, + bn: int = 32, + bk: int = 16, + use_precompute: bool = False, + use_compute: bool = False, + run_kernel: bool = False, + print_kernel: bool = False, + print_device_code: bool = False + ) -> None: + + knl = lp.make_kernel( + "{ [i, j, k] : 0 <= i < M and 0 <= j < N and 0 <= k < K }", + """ + a_(is, ks) := a[is, ks] + b_(ks, js) := b[ks, js] + c[i, j] = sum([k], a_(i, k) * b_(k, j)) + """, + [ + lp.GlobalArg("a", shape=(M, K), dtype=np.float64), + lp.GlobalArg("b", shape=(K, N), dtype=np.float64), + lp.GlobalArg("c", shape=(M, N), dtype=np.float64, + is_output=True) + ] + ) + + # FIXME: without this, there are complaints about in-bounds access guarantees + knl = lp.fix_parameters(knl, M=M, N=N, K=K) + + knl = lp.split_iname(knl, "i", bm, inner_iname="ii", outer_iname="io") + knl = lp.split_iname(knl, "j", bn, inner_iname="ji", outer_iname="jo") + knl = lp.split_iname(knl, "k", bk, inner_iname="ki", outer_iname="ko") + + compute_map_a = nisl.make_map(f"""{{ + [is, ks] -> [io, ii_s, ko, ki_s] : + 0 <= ii_s < {bm} and 0 <= ki_s < {bk} and + is = io * {bm} + ii_s and + ks = ko * {bk} + ki_s + }}""") + + compute_map_b = nisl.make_map(f"""{{ + [ks, js] -> [ko, ki_s, jo, ji_s] : + 0 <= ji_s < {bn} and 0 <= ki_s < {bk} and + js = jo * {bn} + ji_s and + ks = ko * {bk} + ki_s + }}""") + + if use_compute: + knl = compute( + knl, + "a_", + compute_map=compute_map_a, + storage_indices=["ii_s", "ki_s"], + temporal_inames=["io", "ko", "jo"], + temporary_address_space=lp.AddressSpace.LOCAL + ) + + knl = compute( + knl, + "b_", + compute_map=compute_map_b, + storage_indices=["ki_s", "ji_s"], + temporal_inames=["io", "ko", "jo"], + temporary_address_space=lp.AddressSpace.LOCAL + ) + + if use_precompute: + knl = lp.precompute( + knl, + "a_", + sweep_inames=["ii", "ki"], + ) + + knl = lp.tag_inames( + knl, { + "io" : "g.0", # outer block loop over block rows + "jo" : "g.1", # outer block loop over block cols + + "ii" : "l.0", # inner block loop over rows + "ji" : "l.1", # inner block loop over cols + + "ii_s" : "l.0", # inner storage loop over a rows + "ji_s" : "l.0", # inner storage loop over b cols + "ki_s" : "l.1" # inner storage loop over a cols / b rows + } + ) + + knl = lp.add_inames_for_unused_hw_axes(knl) + + if run_kernel: + a = np.random.randn(M, K) + b = np.random.randn(K, N) + + ctx = cl.create_some_context() + queue = cl.CommandQueue(ctx) + + ex = knl.executor(ctx) + _, out = ex(queue, a=a, b=b) + + print(f"Relative error = {la.norm((a @ b) - out) / la.norm(out)}") + + if print_device_code: + print(lp.generate_code_v2(knl).device_code()) + + if print_kernel: + print(knl) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + + _ = parser.add_argument("--precompute", action="store_true") + _ = parser.add_argument("--compute", action="store_true") + _ = parser.add_argument("--run-kernel", action="store_true") + _ = parser.add_argument("--print-kernel", action="store_true") + _ = parser.add_argument("--print-device-code", action="store_true") + + _ = parser.add_argument("--m", action="store", type=int, default=128) + _ = parser.add_argument("--n", action="store", type=int, default=128) + _ = parser.add_argument("--k", action="store", type=int, default=128) + + _ = parser.add_argument("--bm", action="store", type=int, default=32) + _ = parser.add_argument("--bn", action="store", type=int, default=32) + _ = parser.add_argument("--bk", action="store", type=int, default=16) + + args = parser.parse_args() + + main( + args.m, + args.n, + args.k, + args.bm, + args.bn, + args.bk, + args.precompute, + args.compute, + args.run_kernel, + args.print_kernel, + args.print_device_code + ) diff --git a/examples/python/compute-examples/finite-difference-2-5D.py b/examples/python/compute-examples/finite-difference-2-5D.py new file mode 100644 index 000000000..e69de29bb diff --git a/loopy/transform/compute.py b/loopy/transform/compute.py index afdad1bf4..6312f596a 100644 --- a/loopy/transform/compute.py +++ b/loopy/transform/compute.py @@ -1,5 +1,4 @@ from collections.abc import Mapping, Sequence, Set -from dataclasses import dataclass from typing import override import loopy as lp from loopy.kernel.tools import DomainChanger @@ -7,7 +6,7 @@ import namedisl as nisl from loopy.kernel import LoopKernel -from loopy.kernel.data import AddressSpace, SubstitutionRule +from loopy.kernel.data import AddressSpace from loopy.match import StackMatch, parse_stack_match from loopy.symbolic import ( ExpansionState, @@ -15,7 +14,6 @@ RuleAwareSubstitutionMapper, SubstitutionRuleExpander, SubstitutionRuleMappingContext, - get_dependencies, pw_aff_to_expr, pwaff_from_expr ) @@ -23,7 +21,7 @@ contains_a_subst_rule_invocation ) from loopy.translation_unit import for_each_kernel -from pymbolic import ArithmeticExpression, var +from pymbolic import var from pymbolic.mapper.substitutor import make_subst_func import islpy as isl From 4c7d1ac8dd764ec7a12f51eb7f485dcb04e9a513 Mon Sep 17 00:00:00 2001 From: Addison Date: Thu, 19 Mar 2026 19:07:53 -0500 Subject: [PATCH 14/18] add 2.5D FD example base; minor stylistic changes --- examples/compute-tiled-matmul.py | 150 ------------------ .../compute-examples/compute-tiled-matmul.py | 52 +++--- .../finite-difference-2-5D.py | 71 +++++++++ loopy/transform/compute.py | 11 +- 4 files changed, 109 insertions(+), 175 deletions(-) delete mode 100644 examples/compute-tiled-matmul.py diff --git a/examples/compute-tiled-matmul.py b/examples/compute-tiled-matmul.py deleted file mode 100644 index 07dad6e1a..000000000 --- a/examples/compute-tiled-matmul.py +++ /dev/null @@ -1,150 +0,0 @@ -import namedisl as nisl - -import loopy as lp -from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2 -from loopy.transform.compute import compute - -import numpy as np -import numpy.linalg as la -import pyopencl as cl - - -def main( - M: int = 128, - N: int = 128, - K: int = 128, - bm: int = 32, - bn: int = 32, - bk: int = 16, - use_precompute: bool = False, - use_compute: bool = False, - run_kernel: bool = False, - print_kernel: bool = False, - print_device_code: bool = False -) -> None: - - knl = lp.make_kernel( - "{ [i, j, k] : 0 <= i < M and 0 <= j < N and 0 <= k < K }", - """ - a_(is, ks) := a[is, ks] - b_(ks, js) := b[ks, js] - c[i, j] = sum([k], a_(i, k) * b_(k, j)) - """, - [ - lp.GlobalArg("a", shape=(M, K), dtype=np.float64), - lp.GlobalArg("b", shape=(K, N), dtype=np.float64), - lp.GlobalArg("c", shape=(M, N), dtype=np.float64, - is_output=True) - ] - ) - - knl = lp.split_iname(knl, "i", bm, inner_iname="ii", outer_iname="io") - knl = lp.split_iname(knl, "j", bn, inner_iname="ji", outer_iname="jo") - knl = lp.split_iname(knl, "k", bk, inner_iname="ki", outer_iname="ko") - - compute_map_a = nisl.make_map(f"""{{ - [is, ks] -> [io, ii_s, ko, ki_s] : - 0 <= ii_s < {bm} and 0 <= ki_s < {bk} and - is = io * {bm} + ii_s and - ks = ko * {bk} + ki_s - }}""") - - compute_map_b = nisl.make_map(f"""{{ - [ks, js] -> [ko, ki_s, jo, ji_s] : - 0 <= ji_s < {bn} and 0 <= ki_s < {bk} and - js = jo * {bn} + ji_s and - ks = ko * {bk} + ki_s - }}""") - - if use_compute: - knl = compute( - knl, - "a_", - compute_map=compute_map_a, - storage_indices=["ii_s", "ki_s"], - temporal_inames=["io", "ko", "jo"], - temporary_address_space=lp.AddressSpace.LOCAL - ) - - knl = compute( - knl, - "b_", - compute_map=compute_map_b, - storage_indices=["ki_s", "ji_s"], - temporal_inames=["io", "ko", "jo"], - temporary_address_space=lp.AddressSpace.LOCAL - ) - - # knl = lp.add_inames_for_unused_hw_axes(knl) - - if use_precompute: - knl = lp.precompute( - knl, - "a_", - sweep_inames=["ii", "ki"], - ) - - # knl = lp.tag_inames( - # knl, { - # "io" : "g.0", - # "jo" : "g.1", - # "ii" : "l.0", - # "ji" : "l.1", - # "ii_s": "l.0", - # "ji_s": "l.1" - # } - # ) - - if run_kernel: - a = np.random.randn(M, K) - b = np.random.randn(K, N) - - ctx = cl.create_some_context() - queue = cl.CommandQueue(ctx) - - ex = knl.executor(ctx) - _, out = ex(queue, a=a, b=b) - - print(f"Relative error = {la.norm((a @ b) - out) / la.norm(out)}") - - if print_device_code: - print(lp.generate_code_v2(knl).device_code()) - elif print_kernel: - print(knl) - - - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser() - - _ = parser.add_argument("--precompute", action="store_true") - _ = parser.add_argument("--compute", action="store_true") - _ = parser.add_argument("--run-kernel", action="store_true") - _ = parser.add_argument("--print-kernel", action="store_true") - _ = parser.add_argument("--print-device-code", action="store_true") - - _ = parser.add_argument("--m", action="store", type=int, default=128) - _ = parser.add_argument("--n", action="store", type=int, default=128) - _ = parser.add_argument("--k", action="store", type=int, default=128) - - _ = parser.add_argument("--bm", action="store", type=int, default=32) - _ = parser.add_argument("--bn", action="store", type=int, default=32) - _ = parser.add_argument("--bk", action="store", type=int, default=16) - - args = parser.parse_args() - - main( - args.m, - args.n, - args.k, - args.bm, - args.bn, - args.bk, - args.precompute, - args.compute, - args.run_kernel, - args.print_kernel, - args.print_device_code - ) diff --git a/examples/python/compute-examples/compute-tiled-matmul.py b/examples/python/compute-examples/compute-tiled-matmul.py index 67536f701..5b5f47c2f 100644 --- a/examples/python/compute-examples/compute-tiled-matmul.py +++ b/examples/python/compute-examples/compute-tiled-matmul.py @@ -16,6 +16,7 @@ def main( bm: int = 32, bn: int = 32, bk: int = 16, + run_sequentially: bool = False, use_precompute: bool = False, use_compute: bool = False, run_kernel: bool = False, @@ -85,19 +86,20 @@ def main( sweep_inames=["ii", "ki"], ) - knl = lp.tag_inames( - knl, { - "io" : "g.0", # outer block loop over block rows - "jo" : "g.1", # outer block loop over block cols + if not run_sequentially: + knl = lp.tag_inames( + knl, { + "io" : "g.0", # outer block loop over block rows + "jo" : "g.1", # outer block loop over block cols - "ii" : "l.0", # inner block loop over rows - "ji" : "l.1", # inner block loop over cols + "ii" : "l.0", # inner block loop over rows + "ji" : "l.1", # inner block loop over cols - "ii_s" : "l.0", # inner storage loop over a rows - "ji_s" : "l.0", # inner storage loop over b cols - "ki_s" : "l.1" # inner storage loop over a cols / b rows - } - ) + "ii_s" : "l.0", # inner storage loop over a rows + "ji_s" : "l.0", # inner storage loop over b cols + "ki_s" : "l.1" # inner storage loop over a cols / b rows + } + ) knl = lp.add_inames_for_unused_hw_axes(knl) @@ -111,7 +113,11 @@ def main( ex = knl.executor(ctx) _, out = ex(queue, a=a, b=b) + print(20*"=", "Tiled matmul report", 20*"=") + print(f"Problem size: M = {M:-4}, N = {N:-4}, K = {K:-4}") + print(f"Tile size : BM = {bm:-4}, BN = {bn:-4}, BK = {bk:-4}") print(f"Relative error = {la.norm((a @ b) - out) / la.norm(out)}") + print((40 + len(" Tiled matmul report "))*"=") if print_device_code: print(lp.generate_code_v2(knl).device_code()) @@ -130,6 +136,7 @@ def main( _ = parser.add_argument("--run-kernel", action="store_true") _ = parser.add_argument("--print-kernel", action="store_true") _ = parser.add_argument("--print-device-code", action="store_true") + _ = parser.add_argument("--run-sequentially", action="store_true") _ = parser.add_argument("--m", action="store", type=int, default=128) _ = parser.add_argument("--n", action="store", type=int, default=128) @@ -142,15 +149,16 @@ def main( args = parser.parse_args() main( - args.m, - args.n, - args.k, - args.bm, - args.bn, - args.bk, - args.precompute, - args.compute, - args.run_kernel, - args.print_kernel, - args.print_device_code + M=args.m, + N=args.n, + K=args.k, + bm=args.bm, + bn=args.bn, + bk=args.bk, + use_precompute=args.precompute, + use_compute=args.compute, + run_kernel=args.run_kernel, + print_kernel=args.print_kernel, + print_device_code=args.print_device_code, + run_sequentially=args.run_sequentially ) diff --git a/examples/python/compute-examples/finite-difference-2-5D.py b/examples/python/compute-examples/finite-difference-2-5D.py index e69de29bb..b1cb5c7c0 100644 --- a/examples/python/compute-examples/finite-difference-2-5D.py +++ b/examples/python/compute-examples/finite-difference-2-5D.py @@ -0,0 +1,71 @@ +import loopy as lp +from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2 +import numpy as np +import numpy.linalg as la +import pyopencl as cl + + +# FIXME: more complicated function, or better yet define a set of functions +# with sympy and get the exact laplacian symbolically +def f(x, y, z): + return x**2 + y**2 + z**2 + + +def laplacian_f(x, y, z): + return 6 * np.ones_like(x) + + +def main(use_compute: bool = False) -> None: + knl = lp.make_kernel( + "{ [i, j, k, l] : r <= i, j, k < npts - r and -r <= l < r + 1 }", + """ + u_(i, j, k) := u[i, j, k] + + lap_u[i,j,k] = sum([l], c[l+2] * (u[i-l,j,k] + u[i,j-l,k] + u[i,j,k-l])) + """ + ) + + if use_compute: + raise NotImplementedError("WIP") + + npts = 50 + pts = np.linspace(-1, 1, num=npts, endpoint=True) + h = pts[1] - pts[0] + + x, y, z = np.meshgrid(*(pts,)*3) + + x = x.reshape(*(npts,)*3) + y = y.reshape(*(npts,)*3) + z = z.reshape(*(npts,)*3) + + f_ = f(x, y, z) + lap_fd = np.zeros_like(f_) + c = np.array([-1/12, 4/3, -5/2, 4/3, -1/12]) / h**2 + + m = 5 + r = m // 2 + + knl = lp.fix_parameters(knl, npts=npts, r=r) + + ctx = cl.create_some_context() + queue = cl.CommandQueue(ctx) + + ex = knl.executor(queue) + _, lap_fd = ex(queue, u=f(x, y, z), c=c) + + lap_true = laplacian_f(x, y, z) + sl = (slice(r, npts - r),)*3 + + print(la.norm(lap_true[sl] - lap_fd[0][sl]) / la.norm(lap_true[sl])) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + + _ = parser.add_argument("--compute", action="store_true") + + args = parser.parse_args() + + main(use_compute=args.compute) diff --git a/loopy/transform/compute.py b/loopy/transform/compute.py index 6312f596a..8ff7a6637 100644 --- a/loopy/transform/compute.py +++ b/loopy/transform/compute.py @@ -1,5 +1,6 @@ from collections.abc import Mapping, Sequence, Set from typing import override +from typing_extensions import TypeAlias import loopy as lp from loopy.kernel.tools import DomainChanger from loopy.types import to_loopy_type @@ -31,6 +32,9 @@ from pytools.tag import Tag +AccessTuple: TypeAlias = AccessTuple + + def gather_vars(expr) -> set[str]: deps = DependencyMapper()(expr) return { @@ -112,7 +116,7 @@ def __init__( ctx: SubstitutionRuleMappingContext, subst_name: str, subst_tag: Sequence[Tag] | None, - usage_descriptors: Mapping[tuple[Expression, ...], isl.Map], + usage_descriptors: Mapping[AccessTuple, isl.Map], storage_indices: Sequence[str], temporary_name: str, compute_insn_id: str, @@ -124,7 +128,7 @@ def __init__( self.subst_name: str = subst_name self.subst_tag: Sequence[Tag] | None = subst_tag - self.usage_descriptors: Mapping[tuple[Expression, ...], isl.Map] = \ + self.usage_descriptors: Mapping[AccessTuple, isl.Map] = \ usage_descriptors self.storage_indices: Sequence[str] = storage_indices @@ -271,7 +275,7 @@ def compute( ) ) - usage_descrs: Mapping[tuple[Expression, ...], isl.Map] = {} + usage_descrs: Mapping[AccessTuple, isl.Map] = {} for usage in usage_exprs: range_space = isl.Space.create_from_names( @@ -291,6 +295,7 @@ def compute( usage_map = pw_multi_aff.as_map() iname_to_timespace = usage_map.apply_range(compute_map) + iname_to_storage = iname_to_timespace.project_out_except( storage_indices, [isl.dim_type.out] ) From 781df54974b549082491c703d757d317c118a2eb Mon Sep 17 00:00:00 2001 From: Addison Date: Mon, 23 Mar 2026 09:28:24 -0500 Subject: [PATCH 15/18] improvements to 2.5D example; bug fixes in compute transform --- .../finite-difference-2-5D.py | 99 +++++++++++++++---- loopy/transform/compute.py | 21 ++-- 2 files changed, 91 insertions(+), 29 deletions(-) diff --git a/examples/python/compute-examples/finite-difference-2-5D.py b/examples/python/compute-examples/finite-difference-2-5D.py index b1cb5c7c0..f1525da12 100644 --- a/examples/python/compute-examples/finite-difference-2-5D.py +++ b/examples/python/compute-examples/finite-difference-2-5D.py @@ -1,7 +1,12 @@ import loopy as lp from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2 +from loopy.transform.compute import compute + +import namedisl as nisl + import numpy as np import numpy.linalg as la + import pyopencl as cl @@ -15,38 +20,88 @@ def laplacian_f(x, y, z): return 6 * np.ones_like(x) -def main(use_compute: bool = False) -> None: - knl = lp.make_kernel( - "{ [i, j, k, l] : r <= i, j, k < npts - r and -r <= l < r + 1 }", - """ - u_(i, j, k) := u[i, j, k] - - lap_u[i,j,k] = sum([l], c[l+2] * (u[i-l,j,k] + u[i,j-l,k] + u[i,j,k-l])) - """ - ) - - if use_compute: - raise NotImplementedError("WIP") - - npts = 50 +def main( + use_compute: bool = False, + print_device_code: bool = False, + print_kernel: bool = False + ) -> None: + npts = 64 pts = np.linspace(-1, 1, num=npts, endpoint=True) h = pts[1] - pts[0] x, y, z = np.meshgrid(*(pts,)*3) - x = x.reshape(*(npts,)*3) - y = y.reshape(*(npts,)*3) - z = z.reshape(*(npts,)*3) + dtype = np.float32 + x = x.reshape(*(npts,)*3).astype(np.float32) + y = y.reshape(*(npts,)*3).astype(np.float32) + z = z.reshape(*(npts,)*3).astype(np.float32) f_ = f(x, y, z) lap_fd = np.zeros_like(f_) - c = np.array([-1/12, 4/3, -5/2, 4/3, -1/12]) / h**2 + c = (np.array([-1/12, 4/3, -5/2, 4/3, -1/12]) / h**2).astype(dtype) m = 5 r = m // 2 + bm = bn = m + + # FIXME: the usage on the k dimension is incorrect since we are only testing + # tiling (i, j) planes + knl = lp.make_kernel( + "{ [i, j, k, l] : r <= i, j, k < npts - r and -r <= l < r + 1 }", + """ + u_(is, js, ks) := u[is, js, ks] + + lap_u[i,j,k] = sum( + [l], + c[l+r] * (u_(i-l,j,k) + u_(i,j-l,k) + u[i,j,k-l]) + ) + """, + [ + lp.GlobalArg("u", dtype=dtype, shape=(npts,npts,npts)), + lp.GlobalArg("lap_u", dtype=dtype, shape=(npts,npts,npts), + is_output=True), + lp.GlobalArg("c", dtype=dtype, shape=(m)) + ] + ) + knl = lp.fix_parameters(knl, npts=npts, r=r) + knl = lp.split_iname(knl, "i", bm, inner_iname="ii", outer_iname="io") + knl = lp.split_iname(knl, "j", bn, inner_iname="ji", outer_iname="jo") + + # FIXME: need to split k dimension + + if use_compute: + compute_map = nisl.make_map( + f""" + {{ + [is, js, ks] -> [io, ii_s, jo, ji_s, k_s] : + 0 <= ii_s < {bm} and 0 <= ji_s < {bn} and 0 <= k_s < {npts} and + is = io * {bm} + ii_s and + js = jo * {bn} + ji_s and + ks = k_s + }} + """ + ) + + knl = compute( + knl, + "u_", + compute_map=compute_map, + storage_indices=["ii_s", "ji_s", "k_s"], + temporal_inames=["io", "jo"], + temporary_name="u_compute", + temporary_address_space=lp.AddressSpace.LOCAL, + temporary_dtype=np.float32 + ) + + if print_device_code: + print(lp.generate_code_v2(knl).device_code()) + + if print_kernel: + print(knl) + ctx = cl.create_some_context() queue = cl.CommandQueue(ctx) @@ -65,7 +120,13 @@ def main(use_compute: bool = False) -> None: parser = argparse.ArgumentParser() _ = parser.add_argument("--compute", action="store_true") + _ = parser.add_argument("--print-device-code", action="store_true") + _ = parser.add_argument("--print-kernel", action="store_true") args = parser.parse_args() - main(use_compute=args.compute) + main( + use_compute=args.compute, + print_device_code=args.print_device_code, + print_kernel=args.print_kernel + ) diff --git a/loopy/transform/compute.py b/loopy/transform/compute.py index 8ff7a6637..78c02aed7 100644 --- a/loopy/transform/compute.py +++ b/loopy/transform/compute.py @@ -32,7 +32,7 @@ from pytools.tag import Tag -AccessTuple: TypeAlias = AccessTuple +AccessTuple: TypeAlias = tuple[Expression, ...] def gather_vars(expr) -> set[str]: @@ -230,7 +230,10 @@ def compute( temporal_inames: Sequence[str], temporary_name: str | None = None, - temporary_address_space: AddressSpace | None = None + temporary_address_space: AddressSpace | None = None, + + # FIXME: typing + temporary_dtype = None ) -> LoopKernel: """ Inserts an instruction to compute an expression given by :arg:`substitution` @@ -277,7 +280,6 @@ def compute( usage_descrs: Mapping[AccessTuple, isl.Map] = {} for usage in usage_exprs: - range_space = isl.Space.create_from_names( ctx=space.get_ctx(), set=list(storage_indices) @@ -286,27 +288,26 @@ def compute( pw_multi_aff = isl.MultiPwAff.zero(map_space) - for i, arg in enumerate(usage): + # FIXME: this will not work if usages are not ordered properly + for i in range(len(storage_indices)): pw_multi_aff = pw_multi_aff.set_pw_aff( i, - pwaff_from_expr(space, arg) + pwaff_from_expr(space, usage[i]) ) usage_map = pw_multi_aff.as_map() iname_to_timespace = usage_map.apply_range(compute_map) - iname_to_storage = iname_to_timespace.project_out_except( storage_indices, [isl.dim_type.out] ) + footprint = footprint | iname_to_storage.range() + local_map = iname_to_storage.project_out_except( kernel.all_inames() - frozenset(temporal_inames), [isl.dim_type.in_] ) - - footprint = footprint | iname_to_storage.range() - usage_descrs[tuple(usage)] = local_map # add compute inames to domain / kernel @@ -379,7 +380,7 @@ def compute( # FIXME: accept dtype as an argument import numpy as np - loopy_type = to_loopy_type(np.float64, allow_none=True) + loopy_type = to_loopy_type(temporary_dtype, allow_none=True) # WARNING: this can result in symbolic shapes, is that allowed? temp_shape = tuple( From cccc6a100654b6ecf1085068addf375836c01245 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Mon, 23 Mar 2026 11:02:12 -0500 Subject: [PATCH 16/18] Feedback from meeting --- examples/python/compute-examples/compute-tiled-matmul.py | 5 ++--- loopy/transform/compute.py | 3 +++ 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/python/compute-examples/compute-tiled-matmul.py b/examples/python/compute-examples/compute-tiled-matmul.py index 5b5f47c2f..d204a1922 100644 --- a/examples/python/compute-examples/compute-tiled-matmul.py +++ b/examples/python/compute-examples/compute-tiled-matmul.py @@ -40,22 +40,21 @@ def main( ) # FIXME: without this, there are complaints about in-bounds access guarantees - knl = lp.fix_parameters(knl, M=M, N=N, K=K) + # knl = lp.fix_parameters(knl, M=M, N=N, K=K) knl = lp.split_iname(knl, "i", bm, inner_iname="ii", outer_iname="io") knl = lp.split_iname(knl, "j", bn, inner_iname="ji", outer_iname="jo") knl = lp.split_iname(knl, "k", bk, inner_iname="ki", outer_iname="ko") + # FIXME: Given the input is already tiled, we shouldn't have to supply compute bounds here. compute_map_a = nisl.make_map(f"""{{ [is, ks] -> [io, ii_s, ko, ki_s] : - 0 <= ii_s < {bm} and 0 <= ki_s < {bk} and is = io * {bm} + ii_s and ks = ko * {bk} + ki_s }}""") compute_map_b = nisl.make_map(f"""{{ [ks, js] -> [ko, ki_s, jo, ji_s] : - 0 <= ji_s < {bn} and 0 <= ki_s < {bk} and js = jo * {bn} + ji_s and ks = ko * {bk} + ki_s }}""") diff --git a/loopy/transform/compute.py b/loopy/transform/compute.py index 78c02aed7..eb966003f 100644 --- a/loopy/transform/compute.py +++ b/loopy/transform/compute.py @@ -286,6 +286,7 @@ def compute( ) map_space = space.map_from_domain_and_range(range_space) + # FIXME package sequence of pymbolic exprs -> multipwaff up as a function in loopy.symbolic pw_multi_aff = isl.MultiPwAff.zero(map_space) # FIXME: this will not work if usages are not ordered properly @@ -295,8 +296,10 @@ def compute( pwaff_from_expr(space, usage[i]) ) + # FIXME intersect the (kernel) domain with the domain (of the map) here. usage_map = pw_multi_aff.as_map() + # FIXME defer as much of this project-y work to be done once, later iname_to_timespace = usage_map.apply_range(compute_map) iname_to_storage = iname_to_timespace.project_out_except( storage_indices, [isl.dim_type.out] From 6a7f7192a164e2c6b2049f3406d2192a8b70f067 Mon Sep 17 00:00:00 2001 From: Addison Date: Tue, 24 Mar 2026 00:39:42 -0500 Subject: [PATCH 17/18] update footprint finding to minimize projections, find bounds --- .../compute-examples/compute-tiled-matmul.py | 9 +- loopy/transform/compute.py | 155 ++++++++++++++---- 2 files changed, 132 insertions(+), 32 deletions(-) diff --git a/examples/python/compute-examples/compute-tiled-matmul.py b/examples/python/compute-examples/compute-tiled-matmul.py index d204a1922..25c33d46e 100644 --- a/examples/python/compute-examples/compute-tiled-matmul.py +++ b/examples/python/compute-examples/compute-tiled-matmul.py @@ -39,8 +39,9 @@ def main( ] ) - # FIXME: without this, there are complaints about in-bounds access guarantees - # knl = lp.fix_parameters(knl, M=M, N=N, K=K) + # FIXME: without this, there are complaints about in-bounds access + # guarantees for the instruction that stores into c + knl = lp.fix_parameters(knl, M=M, N=N, K=K) knl = lp.split_iname(knl, "i", bm, inner_iname="ii", outer_iname="io") knl = lp.split_iname(knl, "j", bn, inner_iname="ji", outer_iname="jo") @@ -48,13 +49,13 @@ def main( # FIXME: Given the input is already tiled, we shouldn't have to supply compute bounds here. compute_map_a = nisl.make_map(f"""{{ - [is, ks] -> [io, ii_s, ko, ki_s] : + [is, ks] -> [ii_s, io, ki_s, ko] : is = io * {bm} + ii_s and ks = ko * {bk} + ki_s }}""") compute_map_b = nisl.make_map(f"""{{ - [ks, js] -> [ko, ki_s, jo, ji_s] : + [ks, js] -> [ki_s, ko, ji_s, jo] : js = jo * {bn} + ji_s and ks = ko * {bk} + ki_s }}""") diff --git a/loopy/transform/compute.py b/loopy/transform/compute.py index eb966003f..d1823f31a 100644 --- a/loopy/transform/compute.py +++ b/loopy/transform/compute.py @@ -53,6 +53,40 @@ def space_from_exprs(exprs, ctx=isl.DEFAULT_CONTEXT): set=set_names ) + +def align_map_domain_to_set(m: isl.Map, s: isl.Set) -> isl.Map: + """ + Permute the domain dimensions of `m` to match the ordering of `s`, + by routing through the parameter space to preserve constraints. + + Example: + m = { [a, b, c, d] -> [e, f, g, h] } + s = { [d, c, b, a] } + result = { [d, c, b, a] -> [e, f, g, h] } + """ + dom_space = m.get_space().domain() + set_space = s.get_space() + + n = dom_space.dim(isl.dim_type.set) + assert set_space.dim(isl.dim_type.set) == n, "dimension count mismatch" + + dom_names = [dom_space.get_dim_name(isl.dim_type.set, i) for i in range(n)] + set_names = [set_space.get_dim_name(isl.dim_type.set, i) for i in range(n)] + assert set(dom_names) == set(set_names), "dimension names must be the same set" + + # Step 1: move all domain dims into the parameter space + n_params = m.dim(isl.dim_type.param) + m = m.move_dims(isl.dim_type.param, n_params, isl.dim_type.in_, 0, n) + + # Step 2: move each param back to in_ in the order dictated by set_names. + # find_dim_by_name accounts for shifting indices as dims are moved out. + for i, name in enumerate(set_names): + param_idx = m.find_dim_by_name(isl.dim_type.param, name) + m = m.move_dims(isl.dim_type.in_, i, isl.dim_type.param, param_idx, 1) + + return m + + class UsageSiteExpressionGatherer(RuleAwareIdentityMapper[[]]): """ Gathers all expressions used as inputs to a particular substitution rule, @@ -120,7 +154,7 @@ def __init__( storage_indices: Sequence[str], temporary_name: str, compute_insn_id: str, - compute_map: isl.Map + global_usage_map: isl.Map ) -> None: super().__init__(ctx) @@ -278,51 +312,113 @@ def compute( ) ) - usage_descrs: Mapping[AccessTuple, isl.Map] = {} + # add compute inames to domain / kernel + domain_changer = DomainChanger(kernel, kernel.all_inames()) + domain = domain_changer.domain + + range_space = isl.Space.create_from_names( + ctx=space.get_ctx(), + set=list(storage_indices) + ) + map_space = space.map_from_domain_and_range(range_space) + global_usage_map = isl.Map.empty(map_space) + for usage in usage_exprs: - range_space = isl.Space.create_from_names( - ctx=space.get_ctx(), - set=list(storage_indices) - ) - map_space = space.map_from_domain_and_range(range_space) # FIXME package sequence of pymbolic exprs -> multipwaff up as a function in loopy.symbolic - pw_multi_aff = isl.MultiPwAff.zero(map_space) + local_usage_mpwaff = isl.MultiPwAff.zero(map_space) # FIXME: this will not work if usages are not ordered properly for i in range(len(storage_indices)): - pw_multi_aff = pw_multi_aff.set_pw_aff( + local_usage_mpwaff = local_usage_mpwaff.set_pw_aff( i, pwaff_from_expr(space, usage[i]) ) # FIXME intersect the (kernel) domain with the domain (of the map) here. - usage_map = pw_multi_aff.as_map() + local_usage_map = local_usage_mpwaff.as_map() - # FIXME defer as much of this project-y work to be done once, later - iname_to_timespace = usage_map.apply_range(compute_map) - iname_to_storage = iname_to_timespace.project_out_except( - storage_indices, [isl.dim_type.out] + # FIXME: fix with namedisl + # remove unnecessary names from domain and intersect with usage map + usage_names = frozenset( + local_usage_map.get_dim_name(isl.dim_type.in_, dim) + for dim in range(local_usage_map.dim(isl.dim_type.in_)) ) - footprint = footprint | iname_to_storage.range() + domain_names = frozenset( + domain.get_dim_name(isl.dim_type.set, dim) + for dim in range(domain.dim(isl.dim_type.set)) + ) - local_map = iname_to_storage.project_out_except( - kernel.all_inames() - frozenset(temporal_inames), - [isl.dim_type.in_] + domain_tmp = domain.project_out_except( + usage_names & domain_names, [isl.dim_type.set] ) - usage_descrs[tuple(usage)] = local_map - # add compute inames to domain / kernel - domain_changer = DomainChanger(kernel, kernel.all_inames()) - domain = domain_changer.domain + local_usage_map = align_map_domain_to_set(local_usage_map, domain_tmp) + + local_usage_map = local_usage_map.intersect_domain(domain_tmp) + + # U : G -> S + # C : S -> I + # C o U : G -> I + global_usage_map = global_usage_map | local_usage_map + + + # {{{ FIXME: this shouldn't need to be done here; will be handled by namedisl + + global_usage_map = global_usage_map.apply_range(compute_map) + common_dims = { + dim1 : dim2 + for dim1 in range(global_usage_map.dim(isl.dim_type.in_)) + for dim2 in range(global_usage_map.dim(isl.dim_type.out)) + if ( + global_usage_map.get_dim_name(isl.dim_type.in_, dim1) + == + global_usage_map.get_dim_name(isl.dim_type.out, dim2) + ) + } + + for pos1, pos2 in common_dims.items(): + global_usage_map = global_usage_map.equate( + isl.dim_type.in_, pos1, + isl.dim_type.out, pos2 + ) + + # }}} + print(domain) + footprint = global_usage_map.range() footprint_tmp, domain = isl.align_two(footprint, domain) domain = (domain & footprint_tmp).get_basic_sets()[0] new_domains = domain_changer.get_domains_with(domain) kernel = kernel.copy(domains=new_domains) + # {{{ find all index expressions + + usage_substs: Mapping[AccessTuple, isl.Map] = {} + for usage in usage_exprs: + # find the relevant names + relevant_names = gather_vars(usage) + + # project out irrelevant names + relevant_names = frozenset(relevant_names) - frozenset(temporal_inames) + + local_iname_to_storage = global_usage_map.project_out_except( + relevant_names, + [isl.dim_type.in_] + ) + + local_iname_to_storage = local_iname_to_storage.project_out_except( + storage_indices, + [isl.dim_type.out] + ) + + # map usage -> resulting map + usage_substs[tuple(usage)] = local_iname_to_storage + + # }}} + # create compute instruction in kernel compute_pw_aff = compute_map.reverse().as_pw_multi_aff() storage_ax_to_global_expr = { @@ -372,23 +468,26 @@ def compute( ctx, substitution, None, - usage_descrs, + usage_substs, storage_indices, temporary_name, compute_insn_id, - compute_map + global_usage_map ) kernel = replacer.map_kernel(kernel) # FIXME: accept dtype as an argument - import numpy as np loopy_type = to_loopy_type(temporary_dtype, allow_none=True) - # WARNING: this can result in symbolic shapes, is that allowed? + # FIXME: need a better way to determine the shape + shape_domain = footprint.project_out_except(storage_indices, + [isl.dim_type.set]) + shape_domain = shape_domain.project_out_except("", [isl.dim_type.param]) + temp_shape = tuple( - pw_aff_to_expr(footprint.dim_max(dim)) + 1 - for dim in range(footprint.dim(isl.dim_type.out)) + pw_aff_to_expr(shape_domain.dim_max(dim)) + 1 + for dim in range(shape_domain.dim(isl.dim_type.out)) ) new_temp_vars = dict(kernel.temporary_variables) From aa8b612997680836b29f442847a1ebe4914d4601 Mon Sep 17 00:00:00 2001 From: Addison Date: Tue, 24 Mar 2026 00:48:58 -0500 Subject: [PATCH 18/18] add/remove FIXMEs --- loopy/transform/compute.py | 52 ++++++++++++++++++++++---------------- 1 file changed, 30 insertions(+), 22 deletions(-) diff --git a/loopy/transform/compute.py b/loopy/transform/compute.py index d1823f31a..e67d583c2 100644 --- a/loopy/transform/compute.py +++ b/loopy/transform/compute.py @@ -35,6 +35,7 @@ AccessTuple: TypeAlias = tuple[Expression, ...] +# FIXME: move to loopy/symbolic.py def gather_vars(expr) -> set[str]: deps = DependencyMapper()(expr) return { @@ -44,6 +45,7 @@ def gather_vars(expr) -> set[str]: } +# FIXME: move to loopy/symbolic.py def space_from_exprs(exprs, ctx=isl.DEFAULT_CONTEXT): names = sorted(set().union(*(gather_vars(expr) for expr in exprs))) set_names = [name for name in names] @@ -54,6 +56,7 @@ def space_from_exprs(exprs, ctx=isl.DEFAULT_CONTEXT): ) +# FIXME: remove this and rely on namedisl def align_map_domain_to_set(m: isl.Map, s: isl.Set) -> isl.Map: """ Permute the domain dimensions of `m` to match the ordering of `s`, @@ -74,12 +77,9 @@ def align_map_domain_to_set(m: isl.Map, s: isl.Set) -> isl.Map: set_names = [set_space.get_dim_name(isl.dim_type.set, i) for i in range(n)] assert set(dom_names) == set(set_names), "dimension names must be the same set" - # Step 1: move all domain dims into the parameter space n_params = m.dim(isl.dim_type.param) m = m.move_dims(isl.dim_type.param, n_params, isl.dim_type.in_, 0, n) - # Step 2: move each param back to in_ in the order dictated by set_names. - # find_dim_by_name accounts for shifting indices as dims are moved out. for i, name in enumerate(set_names): param_idx = m.find_dim_by_name(isl.dim_type.param, name) m = m.move_dims(isl.dim_type.in_, i, isl.dim_type.param, param_idx, 1) @@ -194,9 +194,9 @@ def map_subst_rule( ) args = [arg_ctx[arg_name] for arg_name in rule.arguments] - # FIXME: footprint check? likely required if user supplies bounds on - # storage indices because we are not guaranteed to capture the footprint - # of all usage sites + # FIXME: usage within footprint check? likely required if user supplies + # bounds on storage indices because we are not guaranteed to capture the + # footprint of all usage sites if not len(arguments) == len(rule.arguments): raise ValueError("Number of arguments passed to rule {name} ", @@ -204,7 +204,6 @@ def map_subst_rule( index_exprs: Sequence[Expression] = [] - # FIXME: make self.usage_descriptors a constantdict local_pwmaff = self.usage_descriptors[tuple(args)].as_pw_multi_aff() for dim in range(local_pwmaff.dim(isl.dim_type.out)): @@ -284,9 +283,11 @@ def compute( :arg storage_indices: An ordered sequence of names of storage indices. Used to create inames for the loops that cover the required set of compute points. """ + # FIXME: use namedisl directly compute_map = compute_map._reconstruct_isl_object() - # construct union of usage footprints to determine bounds on compute inames + # {{{ construct necessary pieces; footprint, global usage map + ctx = SubstitutionRuleMappingContext( kernel.substitutions, kernel.get_var_name_generator()) expander = SubstitutionRuleExpander(kernel.substitutions) @@ -328,14 +329,12 @@ def compute( # FIXME package sequence of pymbolic exprs -> multipwaff up as a function in loopy.symbolic local_usage_mpwaff = isl.MultiPwAff.zero(map_space) - # FIXME: this will not work if usages are not ordered properly for i in range(len(storage_indices)): local_usage_mpwaff = local_usage_mpwaff.set_pw_aff( i, pwaff_from_expr(space, usage[i]) ) - # FIXME intersect the (kernel) domain with the domain (of the map) here. local_usage_map = local_usage_mpwaff.as_map() # FIXME: fix with namedisl @@ -355,15 +354,9 @@ def compute( ) local_usage_map = align_map_domain_to_set(local_usage_map, domain_tmp) - local_usage_map = local_usage_map.intersect_domain(domain_tmp) - - # U : G -> S - # C : S -> I - # C o U : G -> I global_usage_map = global_usage_map | local_usage_map - # {{{ FIXME: this shouldn't need to be done here; will be handled by namedisl global_usage_map = global_usage_map.apply_range(compute_map) @@ -386,7 +379,10 @@ def compute( # }}} - print(domain) + # }}} + + # {{{ compute bounds and update kernel domain + footprint = global_usage_map.range() footprint_tmp, domain = isl.align_two(footprint, domain) domain = (domain & footprint_tmp).get_basic_sets()[0] @@ -394,7 +390,9 @@ def compute( new_domains = domain_changer.get_domains_with(domain) kernel = kernel.copy(domains=new_domains) - # {{{ find all index expressions + # }}} + + # {{{ compute index expressions usage_substs: Mapping[AccessTuple, isl.Map] = {} for usage in usage_exprs: @@ -419,7 +417,8 @@ def compute( # }}} - # create compute instruction in kernel + # {{{ create compute instruction in kernel + compute_pw_aff = compute_map.reverse().as_pw_multi_aff() storage_ax_to_global_expr = { compute_pw_aff.get_dim_name(isl.dim_type.out, dim) : @@ -460,6 +459,10 @@ def compute( new_insns.append(compute_insn) kernel = kernel.copy(instructions=new_insns) + # }}} + + # {{{ replace invocations with new compute instruction + ctx = SubstitutionRuleMappingContext( kernel.substitutions, kernel.get_var_name_generator() ) @@ -477,10 +480,13 @@ def compute( kernel = replacer.map_kernel(kernel) - # FIXME: accept dtype as an argument + # }}} + + # {{{ create temporary variable for result of compute + loopy_type = to_loopy_type(temporary_dtype, allow_none=True) - # FIXME: need a better way to determine the shape + # FIXME: fix with namedisl? shape_domain = footprint.project_out_except(storage_indices, [isl.dim_type.set]) shape_domain = shape_domain.project_out_except("", [isl.dim_type.param]) @@ -508,6 +514,8 @@ def compute( temporary_variables=new_temp_vars ) - # FIXME: handle iname tagging + # }}} + + # FIXME: anything else? return kernel