def main(
        M: int = 128,
        N: int = 128,
        K: int = 128,
        bm: int = 32,
        bn: int = 32,
        bk: int = 16,
        run_sequentially: bool = False,
        use_precompute: bool = False,
        use_compute: bool = False,
        run_kernel: bool = False,
        print_kernel: bool = False,
        print_device_code: bool = False
        ) -> None:
    """Build a tiled ``c = a @ b`` loopy kernel and optionally run or print it.

    :arg M, N, K: Problem sizes; ``a`` is ``(M, K)``, ``b`` is ``(K, N)`` and
        ``c`` is ``(M, N)``.
    :arg bm, bn, bk: Tile sizes used to split the ``i``, ``j`` and ``k`` inames.
    :arg run_sequentially: If *True*, skip tagging inames with group/local
        axes.
    :arg use_precompute: If *True*, apply :func:`loopy.precompute` to ``a_``.
    :arg use_compute: If *True*, apply the ``compute`` transform to ``a_``
        and ``b_``.
    :arg run_kernel: If *True*, execute on an OpenCL device with random input
        and print a short report including the relative error.
    :arg print_kernel: If *True*, print the kernel.
    :arg print_device_code: If *True*, print the generated device code.
    """
    kernel_args = [
        lp.GlobalArg("a", shape=(M, K), dtype=np.float64),
        lp.GlobalArg("b", shape=(K, N), dtype=np.float64),
        lp.GlobalArg("c", shape=(M, N), dtype=np.float64,
                     is_output=True)
    ]

    knl = lp.make_kernel(
        "{ [i, j, k] : 0 <= i < M and 0 <= j < N and 0 <= k < K }",
        """
        a_(is, ks) := a[is, ks]
        b_(ks, js) := b[ks, js]
        c[i, j] = sum([k], a_(i, k) * b_(k, j))
        """,
        kernel_args
    )

    # FIXME: without this, there are complaints about in-bounds access
    # guarantees for the instruction that stores into c
    knl = lp.fix_parameters(knl, M=M, N=N, K=K)

    # tile each loop: (iname, tile size, inner iname, outer iname)
    tilings = (
        ("i", bm, "ii", "io"),
        ("j", bn, "ji", "jo"),
        ("k", bk, "ki", "ko"),
    )
    for iname, tile, inner, outer in tilings:
        knl = lp.split_iname(
            knl, iname, tile, inner_iname=inner, outer_iname=outer)

    # FIXME: Given the input is already tiled, we shouldn't have to supply
    # compute bounds here.
    compute_map_a = nisl.make_map(f"""{{
        [is, ks] -> [ii_s, io, ki_s, ko] :
            is = io * {bm} + ii_s and
            ks = ko * {bk} + ki_s
    }}""")

    compute_map_b = nisl.make_map(f"""{{
        [ks, js] -> [ki_s, ko, ji_s, jo] :
            js = jo * {bn} + ji_s and
            ks = ko * {bk} + ki_s
    }}""")

    if use_compute:
        # (substitution rule, compute map, storage indices); a_ first, then b_
        staged = (
            ("a_", compute_map_a, ["ii_s", "ki_s"]),
            ("b_", compute_map_b, ["ki_s", "ji_s"]),
        )
        for rule, rule_map, storage in staged:
            knl = compute(
                knl,
                rule,
                compute_map=rule_map,
                storage_indices=storage,
                temporal_inames=["io", "ko", "jo"],
                temporary_address_space=lp.AddressSpace.LOCAL
            )

    if use_precompute:
        knl = lp.precompute(
            knl,
            "a_",
            sweep_inames=["ii", "ki"],
        )

    if not run_sequentially:
        iname_tags = {
            "io": "g.0",    # outer block loop over block rows
            "jo": "g.1",    # outer block loop over block cols

            "ii": "l.0",    # inner block loop over rows
            "ji": "l.1",    # inner block loop over cols

            "ii_s": "l.0",  # inner storage loop over a rows
            "ji_s": "l.0",  # inner storage loop over b cols
            "ki_s": "l.1"   # inner storage loop over a cols / b rows
        }
        knl = lp.tag_inames(knl, iname_tags)

    knl = lp.add_inames_for_unused_hw_axes(knl)

    if run_kernel:
        a = np.random.randn(M, K)
        b = np.random.randn(K, N)

        ctx = cl.create_some_context()
        queue = cl.CommandQueue(ctx)

        ex = knl.executor(ctx)
        _, out = ex(queue, a=a, b=b)

        print(20*"=", "Tiled matmul report", 20*"=")
        print(f"Problem size: M = {M:-4}, N = {N:-4}, K = {K:-4}")
        print(f"Tile size : BM = {bm:-4}, BN = {bn:-4}, BK = {bk:-4}")
        print(f"Relative error = {la.norm((a @ b) - out) / la.norm(out)}")
        print((40 + len(" Tiled matmul report "))*"=")

    if print_device_code:
        print(lp.generate_code_v2(knl).device_code())

    if print_kernel:
        print(knl)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()

    # boolean switches
    for flag in (
            "--precompute",
            "--compute",
            "--run-kernel",
            "--print-kernel",
            "--print-device-code",
            "--run-sequentially"):
        _ = parser.add_argument(flag, action="store_true")

    # problem sizes
    for flag in ("--m", "--n", "--k"):
        _ = parser.add_argument(flag, action="store", type=int, default=128)

    # tile sizes
    for flag, default in (("--bm", 32), ("--bn", 32), ("--bk", 16)):
        _ = parser.add_argument(
            flag, action="store", type=int, default=default)

    args = parser.parse_args()

    main(
        M=args.m,
        N=args.n,
        K=args.k,
        bm=args.bm,
        bn=args.bn,
        bk=args.bk,
        use_precompute=args.precompute,
        use_compute=args.compute,
        run_kernel=args.run_kernel,
        print_kernel=args.print_kernel,
        print_device_code=args.print_device_code,
        run_sequentially=args.run_sequentially
    )
# FIXME: more complicated function, or better yet define a set of functions
# with sympy and get the exact laplacian symbolically
def f(x, y, z):
    """Return the test function ``x^2 + y^2 + z^2`` evaluated at the inputs."""
    x_sq = x**2
    y_sq = y**2
    z_sq = z**2
    return x_sq + y_sq + z_sq


def laplacian_f(x, y, z):
    """Return the Laplacian of :func:`f`: the constant 6, shaped like *x*.

    *y* and *z* are accepted for signature symmetry with :func:`f` and are
    not used.
    """
    ones = np.ones_like(x)
    return 6 * ones
def main(
        use_compute: bool = False,
        print_device_code: bool = False,
        print_kernel: bool = False
        ) -> None:
    """Apply a 5-point-per-axis finite-difference Laplacian with loopy.

    Builds the stencil kernel on a ``64^3`` grid over ``[-1, 1]^3``, tiles the
    ``i`` and ``j`` loops, optionally stages ``u_`` through the ``compute``
    transform, runs the kernel on an OpenCL device, and prints the relative
    error against :func:`laplacian_f` on the interior points.

    :arg use_compute: If *True*, apply the ``compute`` transform to ``u_``.
    :arg print_device_code: If *True*, print the generated device code.
    :arg print_kernel: If *True*, print the kernel.
    """
    npts = 64
    pts = np.linspace(-1, 1, num=npts, endpoint=True)
    h = pts[1] - pts[0]

    # single source of truth for the precision used on host and device
    dtype = np.float32

    x, y, z = np.meshgrid(*(pts,)*3)
    x = x.reshape(*(npts,)*3).astype(dtype)
    y = y.reshape(*(npts,)*3).astype(dtype)
    z = z.reshape(*(npts,)*3).astype(dtype)

    # evaluate the input field once and reuse it for the kernel call
    f_ = f(x, y, z)
    c = (np.array([-1/12, 4/3, -5/2, 4/3, -1/12]) / h**2).astype(dtype)

    m = 5        # stencil width
    r = m // 2   # stencil radius

    bm = bn = m

    # FIXME: the usage on the k dimension is incorrect since we are only testing
    # tiling (i, j) planes
    knl = lp.make_kernel(
        "{ [i, j, k, l] : r <= i, j, k < npts - r and -r <= l < r + 1 }",
        """
        u_(is, js, ks) := u[is, js, ks]

        lap_u[i,j,k] = sum(
            [l],
            c[l+r] * (u_(i-l,j,k) + u_(i,j-l,k) + u[i,j,k-l])
        )
        """,
        [
            lp.GlobalArg("u", dtype=dtype, shape=(npts, npts, npts)),
            lp.GlobalArg("lap_u", dtype=dtype, shape=(npts, npts, npts),
                         is_output=True),
            # one-element tuple: "(m)" is just the integer m
            lp.GlobalArg("c", dtype=dtype, shape=(m,))
        ]
    )

    knl = lp.fix_parameters(knl, npts=npts, r=r)

    knl = lp.split_iname(knl, "i", bm, inner_iname="ii", outer_iname="io")
    knl = lp.split_iname(knl, "j", bn, inner_iname="ji", outer_iname="jo")

    # FIXME: need to split k dimension

    if use_compute:
        compute_map = nisl.make_map(
            f"""
            {{
                [is, js, ks] -> [io, ii_s, jo, ji_s, k_s] :
                    0 <= ii_s < {bm} and 0 <= ji_s < {bn} and 0 <= k_s < {npts} and
                    is = io * {bm} + ii_s and
                    js = jo * {bn} + ji_s and
                    ks = k_s
            }}
            """
        )

        knl = compute(
            knl,
            "u_",
            compute_map=compute_map,
            storage_indices=["ii_s", "ji_s", "k_s"],
            temporal_inames=["io", "jo"],
            temporary_name="u_compute",
            temporary_address_space=lp.AddressSpace.LOCAL,
            temporary_dtype=dtype
        )

    if print_device_code:
        print(lp.generate_code_v2(knl).device_code())

    if print_kernel:
        print(knl)

    ctx = cl.create_some_context()
    queue = cl.CommandQueue(ctx)

    ex = knl.executor(queue)
    _, lap_fd = ex(queue, u=f_, c=c)

    lap_true = laplacian_f(x, y, z)
    # compare on interior points only; the stencil is not applied within
    # r points of the boundary
    sl = (slice(r, npts - r),)*3

    print(la.norm(lap_true[sl] - lap_fd[0][sl]) / la.norm(lap_true[sl]))


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()

    _ = parser.add_argument("--compute", action="store_true")
    _ = parser.add_argument("--print-device-code", action="store_true")
    _ = parser.add_argument("--print-kernel", action="store_true")

    args = parser.parse_args()

    main(
        use_compute=args.compute,
        print_device_code=args.print_device_code,
        print_kernel=args.print_kernel
    )
def aff_to_expr(aff: isl.Aff | nisl.Aff) -> ArithmeticExpression:
    """Convert an affine expression to a pymbolic arithmetic expression.

    The result is ``(constant + sum(coeff * var) + sum(coeff * div)) // denom``
    where all coefficients have been scaled by the denominator of *aff*.

    :arg aff: Either an :class:`islpy.Aff`, whose dimensions are addressed by
        ``(dim_type, index)``, or a namedisl ``Aff``, whose dimensions are
        addressed by name.
    """
    from pymbolic import var

    # FIXME: remove this once namedisl is the standard in loopy
    denom = aff.get_denominator_val().to_python()
    result = (aff.get_constant_val()*denom).to_python()

    if isinstance(aff, isl.Aff):
        for dt in [isl.dim_type.in_, isl.dim_type.param]:
            for i in range(aff.dim(dt)):
                coeff = (aff.get_coefficient_val(dt, i)*denom).to_python()
                if coeff:
                    dim_name = not_none(aff.get_dim_name(dt, i))
                    result += coeff*var(dim_name)

        for i in range(aff.dim(isl.dim_type.div)):
            coeff = (aff.get_coefficient_val(isl.dim_type.div, i)*denom).to_python()
            if coeff:
                result += coeff*aff_to_expr(aff.get_div(i))

    else:
        # Visit input names, then parameter names, in declaration order so
        # that the generated expression is deterministic.
        named_dims = (
            *aff.dim_type_names(isl.dim_type.in_),
            *aff.dim_type_names(isl.dim_type.param),
        )

        for name in named_dims:
            coeff = (aff.get_coefficient_val(name) * denom).to_python()
            if coeff:
                # accumulate: plain assignment here would discard the
                # constant term and every previously added coefficient
                result += coeff * var(name)

        for name in aff.dim_type_names(isl.dim_type.div):
            coeff = (aff.get_coefficient_val(name) * denom).to_python()
            if coeff:
                result += coeff * aff_to_expr(aff.get_div(name))

    assert not isinstance(result, complex)
    return flatten(result // denom)
AccessTuple: TypeAlias = tuple[Expression, ...]


# FIXME: move to loopy/symbolic.py
def gather_vars(expr) -> set[str]:
    """Return the names of all :class:`pymbolic.primitives.Variable`
    dependencies of *expr*."""
    names: set[str] = set()
    for dep in DependencyMapper()(expr):
        if isinstance(dep, p.Variable):
            names.add(dep.name)
    return names


# FIXME: move to loopy/symbolic.py
def space_from_exprs(exprs, ctx=isl.DEFAULT_CONTEXT):
    """Create a set space whose dimensions are the sorted variable names
    occurring in *exprs*."""
    all_names: set[str] = set()
    for expr in exprs:
        all_names |= gather_vars(expr)

    return isl.Space.create_from_names(
        ctx,
        set=sorted(all_names)
    )


# FIXME: remove this and rely on namedisl
def align_map_domain_to_set(m: isl.Map, s: isl.Set) -> isl.Map:
    """
    Reorder the input dimensions of `m` so they follow the dimension order
    of `s`. The dimensions are temporarily moved into the parameter space
    and moved back one at a time in the target order.

    Example:
        m = { [a, b, c, d] -> [e, f, g, h] }
        s = { [d, c, b, a] }
        result = { [d, c, b, a] -> [e, f, g, h] }
    """
    in_space = m.get_space().domain()
    target_space = s.get_space()

    ndim = in_space.dim(isl.dim_type.set)
    assert target_space.dim(isl.dim_type.set) == ndim, "dimension count mismatch"

    current_order = [
        in_space.get_dim_name(isl.dim_type.set, pos) for pos in range(ndim)
    ]
    target_order = [
        target_space.get_dim_name(isl.dim_type.set, pos) for pos in range(ndim)
    ]
    assert set(current_order) == set(target_order), \
        "dimension names must be the same set"

    # park all input dimensions behind the existing parameters
    first_free_param = m.dim(isl.dim_type.param)
    m = m.move_dims(
        isl.dim_type.param, first_free_param, isl.dim_type.in_, 0, ndim)

    # pull them back as inputs, one by one, in the target order
    for dst, name in enumerate(target_order):
        src = m.find_dim_by_name(isl.dim_type.param, name)
        m = m.move_dims(isl.dim_type.in_, dst, isl.dim_type.param, src, 1)

    return m
class UsageSiteExpressionGatherer(RuleAwareIdentityMapper[[]]):
    """
    Collects the argument expressions of every invocation of one
    substitution rule, selected by name (and optionally by tags).

    Each matching invocation appends one entry to
    :attr:`usage_expressions`: the invocation's arguments in the order of
    the rule's signature.
    """
    def __init__(
                self,
                rule_mapping_ctx: SubstitutionRuleMappingContext,
                subst_expander: SubstitutionRuleExpander,
                kernel: LoopKernel,
                subst_name: str,
                subst_tag: Set[Tag] | Tag | None = None
            ) -> None:

        super().__init__(rule_mapping_ctx)

        # a single tag is normalized to a one-element set
        if isinstance(subst_tag, Tag):
            tag_filter = {subst_tag}
        else:
            tag_filter = subst_tag

        self.subst_expander: SubstitutionRuleExpander = subst_expander
        self.kernel: LoopKernel = kernel
        self.subst_name: str = subst_name
        self.subst_tag: Set[Tag] | None = tag_filter

        self.usage_expressions: list[Sequence[Expression]] = []

    @override
    def map_subst_rule(
                self,
                name: str,
                tags: Set[Tag] | None,
                arguments: Sequence[Expression],
                expn_state: ExpansionState,
            ) -> Expression:

        tag_matches = self.subst_tag is None or self.subst_tag == tags
        if not (name == self.subst_name and tag_matches):
            # not the rule we are looking for: default handling
            return super().map_subst_rule(
                name, tags, arguments, expn_state
            )

        rule = self.rule_mapping_context.old_subst_rules[name]
        arg_ctx = self.make_new_arg_context(
            name, rule.arguments, arguments, expn_state.arg_context
        )

        usage = [arg_ctx[arg_name] for arg_name in rule.arguments]
        self.usage_expressions.append(usage)

        # the mapped expression of a matching invocation is a constant 0;
        # only the recorded usage is of interest
        return 0
class RuleInvocationReplacer(RuleAwareIdentityMapper[[]]):
    """
    Replaces invocations of one substitution rule with subscripts into a
    temporary variable.

    The index expressions for each invocation are taken from
    *usage_descriptors*, keyed by the tuple of the invocation's argument
    expressions. Instructions in which a replacement happened gain a
    dependency on the compute instruction.
    """
    def __init__(
                self,
                ctx: SubstitutionRuleMappingContext,
                subst_name: str,
                subst_tag: Sequence[Tag] | None,
                usage_descriptors: Mapping[AccessTuple, isl.Map],
                storage_indices: Sequence[str],
                temporary_name: str,
                compute_insn_id: str,
                global_usage_map: isl.Map
            ) -> None:
        """
        :arg subst_name: Name of the rule whose invocations are replaced.
        :arg subst_tag: Stored, but not consulted when matching invocations.
        :arg usage_descriptors: Maps the argument tuple of a usage site to
            the map from which the temporary's index expressions are read.
        :arg storage_indices: Stored names of the storage indices.
        :arg temporary_name: Name of the temporary that replaces invocations.
        :arg compute_insn_id: Id of the instruction that fills the temporary.
        :arg global_usage_map: Accepted for interface compatibility; unused.
        """

        super().__init__(ctx)

        self.subst_name: str = subst_name
        self.subst_tag: Sequence[Tag] | None = subst_tag

        self.usage_descriptors: Mapping[AccessTuple, isl.Map] = \
                usage_descriptors
        self.storage_indices: Sequence[str] = storage_indices

        self.temporary_name: str = temporary_name
        self.compute_insn_id: str = compute_insn_id

        # FIXME: may not always be the case (i.e. global barrier between
        # compute insn and uses)
        self.compute_dep_id: str = compute_insn_id

        self.replaced_something: bool = False

    @override
    def map_subst_rule(
                self,
                name: str,
                tags: Set[Tag] | None,
                arguments: Sequence[Expression],
                expn_state: ExpansionState
            ) -> Expression:
        """Return ``temporary_name[...]`` for invocations of the target rule.

        :raises ValueError: if the number of arguments does not match the
            rule's signature.
        """

        if name != self.subst_name:
            return super().map_subst_rule(name, tags, arguments, expn_state)

        rule = self.rule_mapping_context.old_subst_rules[name]
        arg_ctx = self.make_new_arg_context(
            name, rule.arguments, arguments, expn_state.arg_context
        )
        args = [arg_ctx[arg_name] for arg_name in rule.arguments]

        # FIXME: usage within footprint check? likely required if user supplies
        # bounds on storage indices because we are not guaranteed to capture the
        # footprint of all usage sites

        if len(arguments) != len(rule.arguments):
            # one formatted message (previously two unformatted strings were
            # passed, so the exception text was a tuple with literal "{name}")
            raise ValueError(
                f"Number of arguments passed to rule {name} "
                f"does not match the signature of {name}.")

        local_pwmaff = self.usage_descriptors[tuple(args)].as_pw_multi_aff()

        index_exprs = tuple(
            pw_aff_to_expr(local_pwmaff.get_at(dim))
            for dim in range(local_pwmaff.dim(isl.dim_type.out))
        )

        self.replaced_something = True

        return var(self.temporary_name)[index_exprs]

    @override
    def map_kernel(
                self,
                kernel: LoopKernel,
                within: StackMatch = lambda knl, insn, stack: True,
                map_args: bool = True,
                map_tvs: bool = True
            ) -> LoopKernel:
        """Rewrite the instructions of *kernel*.

        *within*, *map_args* and *map_tvs* are accepted for signature
        compatibility and are not consulted here.
        """

        new_insns: list[lp.InstructionBase] = []
        for insn in kernel.instructions:
            # per-instruction flag, set by map_subst_rule
            self.replaced_something = False

            if (isinstance(insn, lp.MultiAssignmentBase) and not
                    contains_a_subst_rule_invocation(kernel, insn)):
                new_insns.append(insn)
                continue

            # bind the current instruction explicitly instead of closing
            # over the loop variable that is reassigned below
            insn = insn.with_transformed_expressions(
                lambda expr, insn=insn: self(expr, kernel, insn)
            )

            if self.replaced_something:
                insn = insn.copy(
                    depends_on=(
                        insn.depends_on | frozenset([self.compute_dep_id])
                    )
                )

            # FIXME: determine compute insn dependencies

            new_insns.append(insn)

        return kernel.copy(instructions=new_insns)
@for_each_kernel
def compute(
            kernel: LoopKernel,
            substitution: str,
            compute_map: nisl.Map,
            storage_indices: Sequence[str],

            # NOTE: how can we deduce this?
            temporal_inames: Sequence[str],

            temporary_name: str | None = None,
            temporary_address_space: AddressSpace | None = None,

            # FIXME: typing
            temporary_dtype = None
        ) -> LoopKernel:
    """
    Inserts an instruction to compute an expression given by *substitution*
    and replaces all invocations of *substitution* with the result of the
    compute instruction.

    :arg substitution: The substitution rule for which the compute
        transform should be applied.

    :arg compute_map: A namedisl map representing a relation between
        substitution rule indices and tuples `(a, l)`, where `a` is a vector of
        storage indices and `l` is a vector of "timestamps".

    :arg storage_indices: An ordered sequence of names of storage indices. Used
        to create inames for the loops that cover the required set of compute
        points, and as the axes of the temporary.

    :arg temporal_inames: Names that are removed from the names occurring at
        a usage site before the usage map is projected onto that site.

    :arg temporary_name: Name of the temporary holding the computed values.
        Defaults to ``substitution + "_temp"``.

    :arg temporary_address_space: Address space of the temporary.

    :arg temporary_dtype: Data type of the temporary; *None* is allowed.

    :returns: A kernel with the updated domain, the added compute
        instruction (id ``substitution + "_compute"``), the rewritten usage
        sites and the added temporary.
    """
    # FIXME: use namedisl directly
    compute_map = compute_map._reconstruct_isl_object()

    # {{{ construct necessary pieces; footprint, global usage map

    ctx = SubstitutionRuleMappingContext(
        kernel.substitutions, kernel.get_var_name_generator())
    expander = SubstitutionRuleExpander(kernel.substitutions)
    expr_gatherer = UsageSiteExpressionGatherer(
        ctx, expander, kernel, substitution, None
    )

    # run only for its side effect of recording the usage expressions
    _ = expr_gatherer.map_kernel(kernel)
    usage_exprs = expr_gatherer.usage_expressions

    all_exprs = [
        expr
        for usage in usage_exprs
        for expr in usage
    ]

    space = space_from_exprs(all_exprs)

    # add compute inames to domain / kernel
    domain_changer = DomainChanger(kernel, kernel.all_inames())
    domain = domain_changer.domain

    range_space = isl.Space.create_from_names(
        ctx=space.get_ctx(),
        set=list(storage_indices)
    )
    map_space = space.map_from_domain_and_range(range_space)
    global_usage_map = isl.Map.empty(map_space)

    for usage in usage_exprs:

        # FIXME package sequence of pymbolic exprs -> multipwaff up as a
        # function in loopy.symbolic
        local_usage_mpwaff = isl.MultiPwAff.zero(map_space)

        for i in range(len(storage_indices)):
            local_usage_mpwaff = local_usage_mpwaff.set_pw_aff(
                i,
                pwaff_from_expr(space, usage[i])
            )

        local_usage_map = local_usage_mpwaff.as_map()

        # FIXME: fix with namedisl
        # remove unnecessary names from domain and intersect with usage map
        usage_names = frozenset(
            local_usage_map.get_dim_name(isl.dim_type.in_, dim)
            for dim in range(local_usage_map.dim(isl.dim_type.in_))
        )

        domain_names = frozenset(
            domain.get_dim_name(isl.dim_type.set, dim)
            for dim in range(domain.dim(isl.dim_type.set))
        )

        domain_tmp = domain.project_out_except(
            usage_names & domain_names, [isl.dim_type.set]
        )

        local_usage_map = align_map_domain_to_set(local_usage_map, domain_tmp)
        local_usage_map = local_usage_map.intersect_domain(domain_tmp)
        global_usage_map = global_usage_map | local_usage_map

    # {{{ FIXME: this shouldn't need to be done here; will be handled by namedisl

    global_usage_map = global_usage_map.apply_range(compute_map)

    # pair up input and output dimensions that carry the same name ...
    common_dims = {
        dim1: dim2
        for dim1 in range(global_usage_map.dim(isl.dim_type.in_))
        for dim2 in range(global_usage_map.dim(isl.dim_type.out))
        if (
            global_usage_map.get_dim_name(isl.dim_type.in_, dim1)
            ==
            global_usage_map.get_dim_name(isl.dim_type.out, dim2)
        )
    }

    # ... and constrain each such pair to be equal
    for pos1, pos2 in common_dims.items():
        global_usage_map = global_usage_map.equate(
            isl.dim_type.in_, pos1,
            isl.dim_type.out, pos2
        )

    # }}}

    # }}}

    # {{{ compute bounds and update kernel domain

    footprint = global_usage_map.range()
    footprint_tmp, domain = isl.align_two(footprint, domain)
    # NOTE: only the first basic set of the intersection is kept
    domain = (domain & footprint_tmp).get_basic_sets()[0]

    new_domains = domain_changer.get_domains_with(domain)
    kernel = kernel.copy(domains=new_domains)

    # }}}

    # {{{ compute index expressions

    usage_substs: dict[AccessTuple, isl.Map] = {}
    for usage in usage_exprs:
        # find the relevant names
        relevant_names = gather_vars(usage)

        # project out irrelevant names
        relevant_names = frozenset(relevant_names) - frozenset(temporal_inames)

        local_iname_to_storage = global_usage_map.project_out_except(
            relevant_names,
            [isl.dim_type.in_]
        )

        local_iname_to_storage = local_iname_to_storage.project_out_except(
            storage_indices,
            [isl.dim_type.out]
        )

        # map usage -> resulting map
        usage_substs[tuple(usage)] = local_iname_to_storage

    # }}}

    # {{{ create compute instruction in kernel

    # reversed compute map: output dims are the substitution rule indices
    compute_pw_aff = compute_map.reverse().as_pw_multi_aff()
    storage_ax_to_global_expr = {
        compute_pw_aff.get_dim_name(isl.dim_type.out, dim):
        pw_aff_to_expr(compute_pw_aff.get_at(dim))
        for dim in range(compute_pw_aff.dim(isl.dim_type.out))
    }

    expr_subst_map = RuleAwareSubstitutionMapper(
        ctx,
        make_subst_func(storage_ax_to_global_expr),
        within=parse_stack_match(None)
    )

    subst_expr = kernel.substitutions[substitution].expression
    compute_expression = expr_subst_map(subst_expr, kernel, None)

    if not temporary_name:
        temporary_name = substitution + "_temp"

    assignee = var(temporary_name)[tuple(
        var(iname) for iname in storage_indices
    )]

    # the compute instruction is nested in every output dim of compute_map
    within_inames = frozenset(
        compute_map.get_dim_name(isl.dim_type.out, dim)
        for dim in range(compute_map.dim(isl.dim_type.out))
    )

    compute_insn_id = substitution + "_compute"
    compute_insn = lp.Assignment(
        id=compute_insn_id,
        assignee=assignee,
        expression=compute_expression,
        within_inames=within_inames
    )

    new_insns = list(kernel.instructions)
    new_insns.append(compute_insn)
    kernel = kernel.copy(instructions=new_insns)

    # }}}

    # {{{ replace invocations with new compute instruction

    ctx = SubstitutionRuleMappingContext(
        kernel.substitutions, kernel.get_var_name_generator()
    )

    replacer = RuleInvocationReplacer(
        ctx,
        substitution,
        None,
        usage_substs,
        storage_indices,
        temporary_name,
        compute_insn_id,
        global_usage_map
    )

    kernel = replacer.map_kernel(kernel)

    # }}}

    # {{{ create temporary variable for result of compute

    loopy_type = to_loopy_type(temporary_dtype, allow_none=True)

    # FIXME: fix with namedisl?
    shape_domain = footprint.project_out_except(storage_indices,
                                                [isl.dim_type.set])
    # keep no parameter: project all of them out
    shape_domain = shape_domain.project_out_except((), [isl.dim_type.param])

    # extent along each storage axis is (maximum index) + 1
    temp_shape = tuple(
        pw_aff_to_expr(shape_domain.dim_max(dim)) + 1
        for dim in range(shape_domain.dim(isl.dim_type.out))
    )

    new_temp_vars = dict(kernel.temporary_variables)

    # FIXME: temp_var might already exist, handle the case where it does
    temp_var = lp.TemporaryVariable(
        name=temporary_name,
        dtype=loopy_type,
        base_indices=(0,)*len(temp_shape),
        shape=temp_shape,
        address_space=temporary_address_space,
        dim_names=tuple(storage_indices)
    )

    new_temp_vars[temporary_name] = temp_var

    kernel = kernel.copy(
        temporary_variables=new_temp_vars
    )

    # }}}

    # FIXME: anything else?

    return kernel