diff --git a/helion/_compiler/compile_environment.py b/helion/_compiler/compile_environment.py index 09d63af11..4e6c1d912 100644 --- a/helion/_compiler/compile_environment.py +++ b/helion/_compiler/compile_environment.py @@ -190,30 +190,39 @@ def allocate_block_size( return idx def allocate_reduction_dimension(self, size: torch.SymInt | int) -> BlockSizeInfo: - # Check if this size is already a registered block size - if isinstance(size, torch.SymInt): - from .host_function import HostFunction + # Quick return for existing reduction with same size + for info in self.block_sizes: + if info.reduction and info.size == size: + return info - expr = size._sympy_() - origin_info = HostFunction.current().expr_to_origin.get(expr) - if origin_info and isinstance(origin_info.origin, BlockSizeOrigin): - block_idx = origin_info.origin.block_id - # Return the existing block size if it's a reduction dimension - if self.block_sizes[block_idx].reduction: - return self.block_sizes[block_idx] - - # Check for existing reduction dimensions with the same size - for rdim in self.block_sizes: - if rdim.reduction and rdim.size == size: - return rdim + # For SymInt, check if we can reuse existing block for symbolic equality + if isinstance(size, torch.SymInt): + sym = size._sympy_() + block_id = self.get_block_id(size) + + # Return existing reduction by block_id + if block_id is not None and self.block_sizes[block_id].reduction: + return self.block_sizes[block_id] + + # Clone non-reduction block as reduction for symbolic equality + for idx, info in enumerate(self.block_sizes): + if not info.reduction and (idx == block_id or sym == info.symbol()): + reduction_loop_index = sum(int(b.reduction) for b in self.block_sizes) + rdim_idx = self.allocate_block_size( + size, + reduction=True, + source=ReductionLoopBlockSizeSource(reduction_loop_index), + hint=next_power_of_2(self.size_hint(size)), + ) + self.block_sizes[rdim_idx].var = info.var + return self.block_sizes[rdim_idx] - # Allocate a new reduction dimension + # Allocate new reduction dimension + reduction_loop_index = sum(int(info.reduction) for info in self.block_sizes) rdim_idx = self.allocate_block_size( size, reduction=True, - source=ReductionLoopBlockSizeSource( - sum([int(bs.reduction) for bs in self.block_sizes]) - ), + source=ReductionLoopBlockSizeSource(reduction_loop_index), hint=next_power_of_2(self.size_hint(size)), ) return self.block_sizes[rdim_idx] @@ -271,6 +280,108 @@ def cached_create_unbacked_symint( self._symint_cache[key] = result return result + + def register_tile_index_tensor_block_id(self, tensor: torch.Tensor, block_id: int) -> None: + """Annotate ``tensor`` as originating from ``tile.index`` with ``block_id`` provenance.""" + tensor._tile_index_block_id = block_id # type: ignore[attr-defined] + + def get_tile_index_tensor_block_id(self, tensor: torch.Tensor) -> int | None: + """Return the originating ``tile.index`` block id if present.""" + return getattr(tensor, "_tile_index_block_id", None) + + def get_indexer_output_dims( + self, + indexer_tensor: torch.Tensor, + base_dim_size: int | torch.SymInt | None, + ) -> list[int | torch.SymInt]: + """Map a tensor indexer's shape to the output dimensions for advanced indexing.""" + dims = list(indexer_tensor.size()) + non_broadcast_dims = [d for d in dims if self.size_hint(d) != 1] + + # Multi-dimensional indexer - return full shape + if len(non_broadcast_dims) > 1: + return dims + + # Try to find block_id from various sources + block_id = ( + self.get_tile_index_tensor_block_id(indexer_tensor) 
+ or (self.get_block_id(base_dim_size) if base_dim_size else None) + or (self.get_block_id(non_broadcast_dims[0]) if non_broadcast_dims else None) + ) + + return [self.block_sizes[block_id].var] if block_id else (non_broadcast_dims or [1]) + + def tensor_indexer_broadcast_shape( + self, tensors: typing.Sequence[torch.Tensor] + ) -> list[int | torch.SymInt] | None: + """Compute a shared broadcast shape for tensor indexers when needed.""" + tensor_list = [t for t in tensors if isinstance(t, torch.Tensor)] + if not tensor_list or all(self.get_tile_index_tensor_block_id(t) for t in tensor_list): + return None + + shapes = [list(t.size()) for t in tensor_list] + + # Inline compute_broadcast_shape_for_tensor_indexers logic + if not shapes: + return [] + + # Special case: multiple 1D tensors form a Cartesian product + if all(len(s) == 1 for s in shapes) and len(shapes) > 1: + return [s[0] for s in shapes] + + # General broadcasting case + max_ndim = max(len(s) for s in shapes) + padded = [([1] * (max_ndim - len(s)) + s) for s in shapes] + return [ + next((d for d in dims if self.size_hint(d) != 1), 1) + for dims in zip(*padded, strict=True) + ] + + def tensor_indexer_broadcast( + self, tensor_list: "Sequence[torch.Tensor] | None" + ) -> TensorIndexerBroadcast: + """Initialize broadcast tracking for tensor indexers.""" + if not tensor_list: + return TensorIndexerBroadcast(shape=None) + return TensorIndexerBroadcast( + shape=self.tensor_indexer_broadcast_shape(tensor_list) + ) + + def tensor_indexer_dims( + self, + indexer_tensor: torch.Tensor, + base_dim_size: int | torch.SymInt, + broadcast: TensorIndexerBroadcast, + ) -> list[int | torch.SymInt]: + """Return dims contributed by a tensor indexer, honoring shared broadcast.""" + if broadcast.shape is None: + return self.get_indexer_output_dims(indexer_tensor, base_dim_size) + if broadcast.used: + return [] + broadcast.used = True + return list(broadcast.shape) + + def new_index_result( + self, tensor: torch.Tensor, output_shape: typing.Sequence[int | torch.SymInt] + ) -> torch.Tensor: + """Create a new tensor for indexing/view ops while preserving tile index provenance.""" + # Inline resolve_tile_index_shape logic + block_id = self.get_tile_index_tensor_block_id(tensor) + if not block_id: + resolved_shape = list(output_shape) + else: + resolved_shape = list(output_shape) + non_broadcast = [i for i, s in enumerate(resolved_shape) if self.size_hint(s) != 1] + if len(non_broadcast) == 1: + resolved_shape[non_broadcast[0]] = self.block_sizes[block_id].var + elif len(non_broadcast) > 1: + block_id = None + + result = tensor.new_empty(resolved_shape) + if block_id is not None: + self.register_tile_index_tensor_block_id(result, block_id) + return result + def to_fake(self, obj: object, origin: Origin) -> object: if obj is None: return None @@ -354,6 +465,10 @@ def _to_fake_tensor(self, tensor: torch.Tensor, source: Source) -> torch.Tensor: self.fake_mode, tensor, shape_env=self.shape_env, source=source ) self.input_sources[result] = source + if hasattr(tensor, "_tile_index_block_id"): + self.register_tile_index_tensor_block_id( + result, typing.cast(int, getattr(tensor, "_tile_index_block_id")) + ) if isinstance(source, LocalSource): for i, s in enumerate(result.size()): if isinstance(s, torch.SymInt) and isinstance( @@ -493,6 +608,14 @@ class AutoSize: """A marker used to delay setting the size of a block until it is known.""" +@dataclasses.dataclass +class TensorIndexerBroadcast: + """Tracks shared broadcast state for tensor indexers.""" + + shape: 
list[int | torch.SymInt] | None + used: bool = False + + @dataclasses.dataclass class BlockSizeInfo: """ diff --git a/helion/_compiler/indexing_strategy.py b/helion/_compiler/indexing_strategy.py index 1586a2cd6..fc96d6bc9 100644 --- a/helion/_compiler/indexing_strategy.py +++ b/helion/_compiler/indexing_strategy.py @@ -5,6 +5,7 @@ import dataclasses from typing import TYPE_CHECKING from typing import NamedTuple +from typing import cast import sympy import torch @@ -16,6 +17,7 @@ from .._compat import get_tensor_descriptor_fn_name from .ast_extension import expr_from_string from .compile_environment import CompileEnvironment +from .compile_environment import TensorIndexerBroadcast from .device_function import DeviceFunction from .host_function import HostFunction from .tile_strategy import DeviceLoopState @@ -557,6 +559,60 @@ def codegen_store( ) +@dataclasses.dataclass +class _SubscriptIndexingContext: + """Context object to hold state during index processing.""" + state: CodegenState + fake_value: torch.Tensor + index: list[object] + env: CompileEnvironment + dtype: str + + # Output tracking + output_idx: int = 0 + index_values: list[str] = dataclasses.field(default_factory=list) + mask_values: dict[str, None] = dataclasses.field(default_factory=dict) + + # Computed values (initialized in __post_init__) + output_size: list[int | torch.SymInt] = dataclasses.field(init=False) + tile_strategy: object = dataclasses.field(init=False) + all_tensors: list[torch.Tensor] = dataclasses.field(init=False) + tensor_shapes: list[list[int | torch.SymInt]] = dataclasses.field(init=False) + broadcast_shape: list[int | torch.SymInt] | None = dataclasses.field(init=False) + + # Tensor indexing state + first_tensor_idx: int = 0 + tensor_count: int = 0 + k_index: int = 0 + + def __post_init__(self) -> None: + self.all_tensors = [ + cast(torch.Tensor, k) for k in self.index + if isinstance(k, torch.Tensor) + ] + broadcast = self.env.tensor_indexer_broadcast(self.all_tensors) + self.output_size = SubscriptIndexing.compute_shape( + self.fake_value, self.index, self.state, broadcast=broadcast + ) + self.tile_strategy = self.state.tile_strategy + self.tensor_shapes = [list(t.size()) for t in self.all_tensors] + self.broadcast_shape = broadcast.shape + + def is_size_one(self, size: int | torch.SymInt) -> bool: + """Check if a size is known to be one.""" + return self.env.known_equal(size, 1) + + @property + def using_shared_broadcast(self) -> bool: + """Whether tensor indexers share a single broadcasted shape.""" + return self.broadcast_shape is not None + + @property + def skip_broadcast_masks(self) -> bool: + """Skip repeated mask insertion for additional broadcasted tensor indexers.""" + return self.broadcast_shape is not None and self.tensor_count > 0 + + class SubscriptIndexing(NamedTuple): index_expr: ast.AST mask_expr: ast.AST @@ -566,15 +622,25 @@ def has_mask(self) -> bool: isinstance(self.mask_expr, ast.Constant) and self.mask_expr.value is None ) + @staticmethod def compute_shape( - tensor: torch.Tensor, index: list[object], state: CodegenState | None = None + tensor: torch.Tensor, + index: list[object], + state: CodegenState | None = None, + *, + broadcast: TensorIndexerBroadcast | None = None, ) -> list[int | torch.SymInt]: assert isinstance(tensor, torch.Tensor) assert isinstance(index, (list, tuple)), index input_size = collections.deque(tensor.size()) - output_size = [] + output_size: list[int | torch.SymInt] = [] env = CompileEnvironment.current() + + # Get broadcast shape for tensor indexers 
(same semantics as type propagation) + tensors = [cast(torch.Tensor, k) for k in index if isinstance(k, torch.Tensor)] + broadcast_ctx = broadcast or env.tensor_indexer_broadcast(tensors) + k_index = 0 for k in index: if k is None: @@ -608,24 +674,22 @@ def compute_shape( k_index += 1 elif isinstance(k, slice): size = input_size.popleft() - # Handle slices with steps slice_size = compute_slice_size(k, size) - if slice_size != 1: rdim = env.allocate_reduction_dimension(slice_size) output_size.append(rdim.var) else: output_size.append(1) k_index += 1 - elif isinstance(k, torch.Tensor) and ( - k.ndim == 1 or (len(index) == 1 and tensor.ndim == 1) - ): - input_size.popleft() - output_size.extend(k.size()) + elif isinstance(k, torch.Tensor): + base_dim = input_size.popleft() + output_size.extend(env.tensor_indexer_dims(k, base_dim, broadcast_ctx)) k_index += 1 else: raise exc.InvalidIndexingType(k) - assert len(input_size) == 0, "invalid subscript" + # Advanced indexing might not consume all dimensions + # Add any remaining dimensions from the input + output_size.extend(input_size) return output_size @staticmethod @@ -660,156 +724,455 @@ def create( index: list[object], extra_mask: ast.AST | None = None, ) -> SubscriptIndexing: - tile_strategy = state.tile_strategy - output_idx = 0 - index_values = [] - mask_values = {} - output_size = SubscriptIndexing.compute_shape(fake_value, index, state) + """Create a SubscriptIndexing instance for the given tensor and index.""" env = CompileEnvironment.current() dtype = env.triton_index_type() + if dtype == "tl.int32" and SubscriptIndexing._needs_int64(fake_value): raise exc.IndexOffsetOutOfRangeForInt32(env.index_dtype) - def _is_size_one(size: int | torch.SymInt) -> bool: - return env.known_equal(size, 1) + # Initialize context for index processing + context = _SubscriptIndexingContext( + state=state, + fake_value=fake_value, + index=index, + env=env, + dtype=dtype + ) - k_index = 0 - for n, k in enumerate(index): - if k is None: - output_idx += 1 - elif isinstance(k, int): - index_values.append(repr(k)) - elif ( - tile_info := _get_tile_with_offset_info(k, state, k_index) - ) is not None: - # Tensor marked as tile.index + offset - block_id, offset = tile_info - index_var = state.codegen.index_var(block_id) - offset_expr = state.device_function.literal_expr(offset) - expand = tile_strategy.expand_str(output_size, output_idx) - i = len(index_values) - index_values.append(f"(({index_var}) + {offset_expr}){expand}") - # Use the same mask as the underlying tile - if (mask := state.codegen.mask_var(block_id)) and not _is_size_one( - fake_value.size(i) - ): - mask_values.setdefault(f"({mask}){expand}") - output_idx += 1 - k_index += 1 - elif isinstance(k, torch.SymInt): - symbol = k._sympy_() - origin = None - if isinstance(symbol, sympy.Symbol): - origin = HostFunction.current().expr_to_origin.get(symbol) - if origin and isinstance(origin.origin, BlockSizeOrigin): - index_var = state.codegen.index_var(origin.origin.block_id) - expand = tile_strategy.expand_str(output_size, output_idx) - i = len(index_values) - index_values.append(f"({index_var}){expand}") - if ( - mask := state.codegen.mask_var(origin.origin.block_id) - ) and not _is_size_one(fake_value.size(i)): - mask_values.setdefault(f"({mask}){expand}") - output_idx += 1 - k_index += 1 - else: - # When the index is a scalar (no BlockSizeOrigin), the corresponding dim is eliminated. 
- val = state.device_function.literal_expr(k) - index_values.append(f"({val})") - elif isinstance(k, slice): - expand = tile_strategy.expand_str(output_size, output_idx) - size = fake_value.size(len(index_values)) + # Process each index element + for position, index_elem in enumerate(index): + SubscriptIndexing._process_index_element(context, position, index_elem) - # Handle slices with steps - if k.step is not None and k.step != 1: - # For strided slices, we need to generate: start + index * step - start = k.start if k.start is not None else 0 - step = k.step - slice_size = compute_slice_size(k, size) - - if slice_size != 1: - rdim = env.allocate_reduction_dimension(slice_size) - block_idx = rdim.block_id - index_var = state.codegen.index_var(block_idx) - # Generate strided index: start + index * step - index_values.append( - f"({start} + ({index_var}) * {step}){expand}" - ) - if mask := state.codegen.mask_var(block_idx): - mask_values.setdefault(f"({mask}){expand}") - else: - index_values.append(f"{start}{expand}") - else: - # Full slice or slice without step - if not _is_size_one(size): - rdim = env.allocate_reduction_dimension(size) - block_idx = rdim.block_id - index_var = state.codegen.index_var(block_idx) - index_values.append(f"({index_var}){expand}") - if mask := state.codegen.mask_var(block_idx): - mask_values.setdefault(f"({mask}){expand}") - else: - index_values.append(f"tl.zeros([1], {dtype}){expand}") - output_idx += 1 - k_index += 1 - elif isinstance(k, torch.Tensor) and k.ndim == 1: - expand = tile_strategy.expand_str(output_size, output_idx) - ast_index = state.ast_args[1] - assert isinstance(ast_index, (list, tuple)) - assert len(ast_index) == len(index) - index_var = state.codegen.lift(ast_index[n], prefix="index").id - index_values.append(f"({index_var}){expand}") - if (block_idx := env.get_block_id(output_size[output_idx])) is not None: - if mask := state.codegen.mask_var(block_idx): - mask_values.setdefault(f"({mask}){expand}") - # Check if this index comes from a padded hl.arange and generate mask - if ( - original_length := _get_padded_iota_original_length(state, n) - ) is not None: - mask_values.setdefault(f"({index_var} < {original_length}){expand}") - output_idx += 1 - k_index += 1 - elif ( - isinstance(k, torch.Tensor) and len(index) == 1 and fake_value.ndim == 1 - ): - # TODO(jansel): combine this case with the above - ast_index = state.ast_args[1] - assert isinstance(ast_index, (list, tuple)) - assert len(ast_index) == 1 - index_var = state.codegen.lift(ast_index[0], prefix="index").id - index_values.append(index_var) - output_idx += k.ndim - for n, s in enumerate(output_size): - if (block_idx := env.get_block_id(s)) is not None and ( - mask := state.codegen.mask_var(block_idx) - ): - mask_values.setdefault( - f"({mask}){tile_strategy.expand_str(output_size, n)}" - ) - k_index += 1 + # Validate and build final expressions + assert len(context.output_size) == context.output_idx + assert len(context.index_values) == fake_value.ndim + + return SubscriptIndexing._build_final_expressions(context, extra_mask) + + @staticmethod + def _process_index_element( + ctx: _SubscriptIndexingContext, + position: int, + index_elem: object + ) -> None: + """Process a single index element and update context.""" + if index_elem is None: + ctx.output_idx += 1 + elif isinstance(index_elem, int): + SubscriptIndexing._process_int_index(ctx, index_elem) + elif (tile_info := _get_tile_with_offset_info( + index_elem, ctx.state, ctx.k_index + )) is not None: + 
SubscriptIndexing._process_tile_with_offset(ctx, tile_info) + elif isinstance(index_elem, torch.SymInt): + SubscriptIndexing._process_symint_index(ctx, index_elem) + elif isinstance(index_elem, slice): + SubscriptIndexing._process_slice_index(ctx, index_elem) + elif isinstance(index_elem, torch.Tensor): + SubscriptIndexing._process_tensor_index(ctx, position, index_elem) + else: + raise exc.InvalidIndexingType(type(index_elem)) + + @staticmethod + def _process_int_index(ctx: _SubscriptIndexingContext, value: int) -> None: + """Process an integer index.""" + ctx.index_values.append(repr(value)) + + @staticmethod + def _add_block_index_and_mask( + ctx: _SubscriptIndexingContext, + block_id: int, + extra_offset: str = "", + position_override: int | None = None + ) -> str: + """Add block index variable and associated mask if needed. + + Returns: + The expand string used. + """ + index_var = ctx.state.codegen.index_var(block_id) + pos = position_override if position_override is not None else ctx.output_idx + expand = ctx.tile_strategy.expand_str(ctx.output_size, pos) + + # Build index expression + index_expr = f"({index_var})" + if extra_offset: + index_expr = f"({index_expr} + {extra_offset})" + ctx.index_values.append(f"{index_expr}{expand}") + + # Add mask if needed + SubscriptIndexing._add_mask_if_needed(ctx, block_id, expand, len(ctx.index_values) - 1) + return expand + + @staticmethod + def _add_mask_if_needed( + ctx: _SubscriptIndexingContext, + block_id: int | None, + expand: str, + dim_index: int | None = None + ) -> None: + """Add mask for a block if it exists and dimension is not size one.""" + if block_id is None: + return + + mask = ctx.state.codegen.mask_var(block_id) + if not mask: + return + + # Check if dimension is size one (can be skipped) + if dim_index is not None and ctx.is_size_one(ctx.fake_value.size(dim_index)): + return + + ctx.mask_values.setdefault(f"({mask}){expand}") + + @staticmethod + def _add_masks_for_positions( + ctx: _SubscriptIndexingContext, + positions: list[int] + ) -> None: + """Add masks for a list of output positions when required.""" + if ctx.skip_broadcast_masks: + return + + limit = len(ctx.output_size) + for pos in positions: + if pos >= limit: + continue + block_idx = ctx.env.get_block_id(ctx.output_size[pos]) + if block_idx is None: + continue + expand = ctx.tile_strategy.expand_str(ctx.output_size, pos) + SubscriptIndexing._add_mask_if_needed(ctx, block_idx, expand) + + @staticmethod + def _get_symint_block_origin(symint: torch.SymInt) -> BlockSizeOrigin | None: + """Extract BlockSizeOrigin from a SymInt if it has one.""" + symbol = symint._sympy_() + if not isinstance(symbol, sympy.Symbol): + return None + origin = HostFunction.current().expr_to_origin.get(symbol) + if origin and isinstance(origin.origin, BlockSizeOrigin): + return origin.origin + return None + + @staticmethod + def _process_tile_with_offset( + ctx: _SubscriptIndexingContext, + tile_info: tuple[int, int | torch.SymInt] + ) -> None: + """Process a tensor marked as tile.index + offset.""" + block_id, offset = tile_info + offset_expr = ctx.state.device_function.literal_expr(offset) + SubscriptIndexing._add_block_index_and_mask(ctx, block_id, offset_expr) + ctx.output_idx += 1 + ctx.k_index += 1 + + @staticmethod + def _process_symint_index(ctx: _SubscriptIndexingContext, symint: torch.SymInt) -> None: + """Process a SymInt index.""" + origin = SubscriptIndexing._get_symint_block_origin(symint) + + if origin: + # Handle block size origin + 
SubscriptIndexing._add_block_index_and_mask(ctx, origin.block_id) + ctx.output_idx += 1 + ctx.k_index += 1 + else: + # Scalar index - dimension is eliminated + val = ctx.state.device_function.literal_expr(symint) + ctx.index_values.append(f"({val})") + + @staticmethod + def _process_slice_index(ctx: _SubscriptIndexingContext, slice_obj: slice) -> None: + """Process a slice index.""" + size = ctx.fake_value.size(len(ctx.index_values)) + slice_size = compute_slice_size(slice_obj, size) + + if slice_obj.step is not None and slice_obj.step != 1: + # Handle strided slices + start = slice_obj.start if slice_obj.start is not None else 0 + if slice_size != 1: + rdim = ctx.env.allocate_reduction_dimension(slice_size) + expand = ctx.tile_strategy.expand_str(ctx.output_size, ctx.output_idx) + index_var = ctx.state.codegen.index_var(rdim.block_id) + # Generate strided index: start + index * step + ctx.index_values.append(f"({start} + ({index_var}) * {slice_obj.step}){expand}") + SubscriptIndexing._add_mask_if_needed(ctx, rdim.block_id, expand, len(ctx.index_values) - 1) + else: + expand = ctx.tile_strategy.expand_str(ctx.output_size, ctx.output_idx) + ctx.index_values.append(f"{start}{expand}") + else: + # Handle regular slices + if not ctx.is_size_one(size): + rdim = ctx.env.allocate_reduction_dimension(size) + SubscriptIndexing._add_block_index_and_mask(ctx, rdim.block_id) + else: + expand = ctx.tile_strategy.expand_str(ctx.output_size, ctx.output_idx) + ctx.index_values.append(f"tl.zeros([1], {ctx.dtype}){expand}") + + ctx.output_idx += 1 + ctx.k_index += 1 + + @staticmethod + def _process_tensor_index( + ctx: _SubscriptIndexingContext, + position: int, + tensor: torch.Tensor + ) -> None: + """Process a tensor index with broadcasting support.""" + # Get the index variable from AST + ast_index = ctx.state.ast_args[1] + assert isinstance(ast_index, (list, tuple)) + index_var = ctx.state.codegen.lift(ast_index[position], prefix="index").id + + # Try special case for 2D Cartesian product first + if SubscriptIndexing._try_2d_cartesian_product(ctx, index_var, position): + return + + # Handle general tensor indexing + if ctx.broadcast_shape: + SubscriptIndexing._process_broadcast_tensor(ctx, tensor, index_var) + else: + SubscriptIndexing._process_simple_tensor(ctx, tensor, index_var) + + ctx.tensor_count += 1 + ctx.k_index += 1 + + @staticmethod + def _try_2d_cartesian_product( + ctx: _SubscriptIndexingContext, + index_var: str, + position: int + ) -> bool: + """Try to handle 2D Cartesian product special case.""" + # Only apply this special case when we have exactly 2 tensor indices + # with broadcast shape of length 2, regardless of other index types + if not (ctx.broadcast_shape and + len(ctx.all_tensors) == 2 and + len(ctx.broadcast_shape) == 2): + return False + + # Check if all tensors are effectively 1D + all_1d = all( + len(s) == 1 or sum(1 for d in s if ctx.env.size_hint(d) != 1) <= 1 + for s in ctx.tensor_shapes + ) + if not all_1d: + return False + + original_length = _get_padded_iota_original_length(ctx.state, position) + + if ctx.tensor_count == 0: + expr = f"({index_var})" + mask_suffix = "" + # Only add an extra broadcast dim if the tensor itself is 1D + shape = ctx.tensor_shapes[ctx.tensor_count] + if len(shape) <= 1: + expr = f"{expr}[:, None]" + mask_suffix = "[:, None]" + ctx.index_values.append(expr) + ctx.first_tensor_idx = ctx.output_idx + ctx.output_idx += 2 + if original_length is not None: + ctx.mask_values.setdefault( + f"(({index_var} < {original_length}){mask_suffix})" + ) + 
else: + expr = f"({index_var})" + mask_suffix = "" + shape = ctx.tensor_shapes[ctx.tensor_count] + if len(shape) <= 1: + expr = f"{expr}[None, :]" + mask_suffix = "[None, :]" + ctx.index_values.append(expr) + if original_length is not None: + ctx.mask_values.setdefault( + f"(({index_var} < {original_length}){mask_suffix})" + ) + + ctx.tensor_count += 1 + ctx.k_index += 1 + return True + + @staticmethod + def _process_broadcast_tensor( + ctx: _SubscriptIndexingContext, + tensor: torch.Tensor, + index_var: str + ) -> None: + """Process tensor indexing with broadcasting.""" + if ctx.tensor_count == 0: + ctx.first_tensor_idx = ctx.output_idx + ctx.output_idx += len(ctx.broadcast_shape) + + expand_pos, width = SubscriptIndexing._get_tensor_shape_info(ctx, tensor) + + if width <= 1: + SubscriptIndexing._process_tensor_with_tile_origin( + ctx, tensor, index_var, expand_pos + ) + else: + SubscriptIndexing._process_multi_dim_broadcast( + ctx, tensor, index_var, expand_pos, width + ) + + @staticmethod + def _process_tensor_with_tile_origin( + ctx: _SubscriptIndexingContext, + tensor: torch.Tensor, + index_var: str, + expand_pos: int | None = None + ) -> None: + """Process tensor indexing with potential tile origin block ID.""" + pos = expand_pos if expand_pos is not None else ctx.output_idx + need_expand = ( + (ctx.using_shared_broadcast and tensor.ndim == 1) + or (not ctx.using_shared_broadcast and tensor.ndim < len(ctx.output_size)) + ) + expand = ctx.tile_strategy.expand_str(ctx.output_size, pos) if need_expand else "" + tile_origin_block_id = ctx.env.get_tile_index_tensor_block_id(tensor) + index_source = ( + ctx.state.codegen.index_var(tile_origin_block_id) + if tile_origin_block_id is not None + else index_var + ) + mask_block_id = tile_origin_block_id or ( + ctx.env.get_block_id(ctx.output_size[pos]) + if pos < len(ctx.output_size) + else None + ) + + ctx.index_values.append(f"({index_source}){expand}") + + if not ctx.skip_broadcast_masks: + SubscriptIndexing._add_mask_if_needed( + ctx, mask_block_id, expand, len(ctx.index_values) - 1 + ) + + @staticmethod + def _process_simple_tensor( + ctx: _SubscriptIndexingContext, + tensor: torch.Tensor, + index_var: str + ) -> None: + """Process simple tensor indexing without broadcasting.""" + SubscriptIndexing._process_tensor_with_tile_origin(ctx, tensor, index_var) + ctx.output_idx += tensor.ndim + + @staticmethod + def _get_tensor_shape_info( + ctx: _SubscriptIndexingContext, + tensor: torch.Tensor + ) -> tuple[int, int]: + """Get shape information for tensor broadcasting.""" + shape = ( + ctx.tensor_shapes[ctx.tensor_count] + if ctx.tensor_count < len(ctx.tensor_shapes) + else [1] + ) + shape_size = len(shape) + + # Check if tensor has more than 1 non-singleton dimension + non_bcast_dims = sum(1 for d in shape if ctx.env.size_hint(d) != 1) + is_single_dim = non_bcast_dims <= 1 + + # Calculate positioning + offset = max(0, len(ctx.broadcast_shape) - shape_size) + if is_single_dim and shape_size > 0: + non_one_positions = [ + i for i, d in enumerate(shape) + if ctx.env.size_hint(d) != 1 + ] + rel_pos = non_one_positions[0] if non_one_positions else (shape_size - 1) + expand_pos = ctx.first_tensor_idx + offset + rel_pos + else: + expand_pos = ctx.first_tensor_idx + offset + + expand_pos = max(0, min(expand_pos, len(ctx.output_size) - 1)) if ctx.output_size else 0 + width = 1 if is_single_dim else min(shape_size, max(0, len(ctx.output_size) - expand_pos)) + + return expand_pos, width + + @staticmethod + def _process_multi_dim_broadcast( + ctx: 
_SubscriptIndexingContext, + tensor: torch.Tensor, + index_var: str, + expand_pos: int, + width: int + ) -> None: + """Process multi-dimensional broadcast tensor.""" + positions = [expand_pos + d for d in range(width)] + + needs_bracket = ( + (ctx.broadcast_shape is None) + and tensor.ndim < len(ctx.output_size) + and not (positions == list(range(width)) and tensor.ndim == width) + ) + bracket = ( + SubscriptIndexing._build_multi_dim_bracket(ctx, positions) + if needs_bracket + else "" + ) + ctx.index_values.append(f"({index_var}){bracket}") + + SubscriptIndexing._add_masks_for_positions(ctx, positions) + + @staticmethod + def _build_multi_dim_bracket( + ctx: _SubscriptIndexingContext, + positions: list[int] + ) -> str: + """Build bracket string for multi-dimensional expansion.""" + tokens: list[str] | None = None + + for pos in positions: + expand = ctx.tile_strategy.expand_str(ctx.output_size, pos) + if expand == "": + current = [":"] else: - raise exc.InvalidIndexingType(type(k)) - assert len(output_size) == output_idx - assert len(index_values) == fake_value.ndim + assert expand.startswith("[") and expand.endswith("]"), expand + current = expand[1:-1].split(", ") if len(expand) > 2 else [] + + if tokens is None: + tokens = current + elif current: + tokens = [ + ":" if (a == ":" or b == ":") else "None" + for a, b in zip(tokens, current, strict=True) + ] + + return f"[{', '.join(tokens)}]" if tokens and not all(t == ":" for t in tokens) else "" + + @staticmethod + def _build_final_expressions( + ctx: _SubscriptIndexingContext, + extra_mask: ast.AST | None + ) -> SubscriptIndexing: + """Build the final index and mask expressions.""" + # Build index expression index_expr = [] - for i, idx in enumerate(index_values): - if not _is_size_one(fake_value.size(i)): - stride = state.device_function.tensor_stride(fake_value, i).name + for i, idx in enumerate(ctx.index_values): + if not ctx.is_size_one(ctx.fake_value.size(i)): + stride = ctx.state.device_function.tensor_stride(ctx.fake_value, i).name index_expr.append(f"{idx} * {stride}") + if not index_expr: - shape_str = tile_strategy.shape_str(output_size) - index_expr.append(f"tl.zeros({shape_str}, {dtype})") + shape_str = ctx.tile_strategy.shape_str(ctx.output_size) + index_expr.append(f"tl.zeros({shape_str}, {ctx.dtype})") + # Build mask expression kwargs = {} if extra_mask is not None: - mask_values.setdefault("{_extra_mask}") + ctx.mask_values.setdefault("{_extra_mask}") kwargs["_extra_mask"] = extra_mask + return SubscriptIndexing( expr_from_string("+".join(index_expr)), - expr_from_string("&".join(mask_values) or "None", **kwargs), + expr_from_string("&".join(ctx.mask_values) or "None", **kwargs), ) - @dataclasses.dataclass class BlockedSubscriptIndexing: """Indexing used for block_ptr and tensor_descriptor""" @@ -926,45 +1289,18 @@ def is_supported( ): # Tensor marked as tile.index + offset - treat like TileWithOffset block_index, _ = tile_info - try: - state.codegen.offset_var(block_index) - except NotImplementedError: + if not BlockedSubscriptIndexing._check_block_index_support( + state, block_index, input_size + ): return False - loop_state = state.codegen.active_device_loops[block_index][-1] - if isinstance(loop_state, DeviceLoopState): - if not loop_state.block_id_to_info[block_index].is_end_matching( - input_size - ): - assert state.fx_node is not None - if "masked_value" in state.fx_node.meta: - return False k_index += 1 elif isinstance(k, torch.SymInt): - symbol = k._sympy_() - origin = None - if isinstance(symbol, sympy.Symbol): 
- origin = HostFunction.current().expr_to_origin.get(symbol) - if origin and isinstance(origin.origin, BlockSizeOrigin): - block_index = origin.origin.block_id - try: - state.codegen.offset_var(block_index) - except NotImplementedError: + origin = SubscriptIndexing._get_symint_block_origin(k) + if origin: + if not BlockedSubscriptIndexing._check_block_index_support( + state, origin.block_id, input_size + ): return False - loop_state = state.codegen.active_device_loops[block_index][-1] - if isinstance(loop_state, DeviceLoopState): - """ - Check for a corner case where the loop size does not match the tensor size. - In this case, the block masking will be incorrect. So we check if the - masking is needed and bail if it is. - """ - if not loop_state.block_id_to_info[block_index].is_end_matching( - input_size - ): - assert state.fx_node is not None - if "masked_value" in state.fx_node.meta: - # TODO(jansel): in this case we should be able to lower to block_ptr+tl.where - # see test/test_loops.py::TestLoops::test_data_dependent_bounds2 - return False k_index += 1 elif isinstance(k, torch.Tensor): # indirect loads don't work with block_ptr @@ -972,6 +1308,25 @@ def is_supported( output_shape = SubscriptIndexing.compute_shape(fake_tensor, index, state) return len(output_shape) != 0 + @staticmethod + def _check_block_index_support( + state: CodegenState, block_index: int, input_size: int | torch.SymInt + ) -> bool: + """Check if a block index is supported for block_ptr.""" + try: + state.codegen.offset_var(block_index) + except NotImplementedError: + return False + loop_state = state.codegen.active_device_loops[block_index][-1] + if isinstance(loop_state, DeviceLoopState): + if not loop_state.block_id_to_info[block_index].is_end_matching(input_size): + assert state.fx_node is not None + if "masked_value" in state.fx_node.meta: + # TODO(jansel): in this case we should be able to lower to block_ptr+tl.where + # see test/test_loops.py::TestLoops::test_data_dependent_bounds2 + return False + return True + def validate(self) -> None: n = self.ndim assert len(self.offsets) == n, ( @@ -1012,12 +1367,11 @@ def create( res.block_shape.append(1) k_index += 1 elif isinstance(k, torch.SymInt): - symbol = k._sympy_() - origin = HostFunction.current().expr_to_origin.get(symbol) - if origin and isinstance(origin.origin, BlockSizeOrigin): + origin = SubscriptIndexing._get_symint_block_origin(k) + if origin: if fake_value.size(len(res.offsets)) != 1: res.offsets.append( - state.codegen.offset_var(origin.origin.block_id) + state.codegen.offset_var(origin.block_id) ) res.block_shape.append(k) else: diff --git a/helion/_compiler/inductor_lowering.py b/helion/_compiler/inductor_lowering.py index d223cbd51..7cb8750e8 100644 --- a/helion/_compiler/inductor_lowering.py +++ b/helion/_compiler/inductor_lowering.py @@ -507,32 +507,25 @@ def is_one(x: int | torch.SymInt) -> bool: # Check each dimension independently for dim in range(max_rank): - # First, see if multiple distinct block-ids appear in this dim - block_ids: set[int] = set() - for s in shapes: - size_i = s[dim] - if is_one(size_i): - continue - block_id = env.get_block_id(size_i) - if block_id is not None: - block_ids.add(block_id) + non_one_sizes = [s[dim] for s in shapes if not is_one(s[dim])] + if len(non_one_sizes) <= 1: + continue + + # Check block_ids first - different tile loops cannot broadcast + block_ids = { + block_id + for sz in non_one_sizes + if (block_id := env.get_block_id(sz)) is not None + } if len(block_ids) >= 2: raise exc.ShapeMismatch( 
str(shapes[0]), ", ".join(map(str, shapes[1:])), ) - # Otherwise, fall back to strict symbolic inequality among non-1 sizes - exprs: set[object] = set() - for s in shapes: - size_i = s[dim] - if is_one(size_i): - continue - if isinstance(size_i, torch.SymInt): - exprs.add(size_i._sympy_()) - else: - exprs.add(size_i) - if len(exprs) >= 2: + # Check symbolic equality + base = non_one_sizes[0] + if not all(env.known_equal(base, sz) for sz in non_one_sizes[1:]): raise exc.ShapeMismatch( str(shapes[0]), ", ".join(map(str, shapes[1:])), diff --git a/helion/_compiler/type_propagation.py b/helion/_compiler/type_propagation.py index 09a74bdff..538fa05c6 100644 --- a/helion/_compiler/type_propagation.py +++ b/helion/_compiler/type_propagation.py @@ -41,6 +41,7 @@ from .compile_environment import warning from .device_function import contains_only_block_size_symbols from .host_function import HostFunction +from .indexing_strategy import SubscriptIndexing from .host_function import SymbolOrigin from .output_header import library_imports from .source_location import current_location @@ -145,9 +146,7 @@ def merge(self, other: LocalScope | dict[str, TypeInfo]) -> LocalScope: other = other.variables for k, v in other.items(): if k in self.variables: - existing = self.variables[k] - merged = existing.merge(v, var_name=k) - self.variables[k] = merged + self.variables[k] = self.variables[k].merge(v, var_name=k) else: self.variables[k] = v return self @@ -458,8 +457,10 @@ def _device_indexing_size(self, key: TypeInfo) -> list[int | torch.SymInt]: else: keys = [key] inputs_consumed = 0 - output_sizes = [] + output_sizes: list[int | torch.SymInt] = [] env = CompileEnvironment.current() + tensor_indexers = [cast(TensorType, k).fake_value for k in keys if isinstance(k, TensorType)] + broadcast = env.tensor_indexer_broadcast(tensor_indexers) for k in keys: if isinstance(k, LiteralType): if isinstance(k.value, (int, torch.SymInt)): @@ -505,19 +506,30 @@ def _device_indexing_size(self, key: TypeInfo) -> list[int | torch.SymInt]: raise exc.DataDependentOutputShapeNotSupported( op_desc="Boolean mask indexing (tensor[boolean_mask])" ) - elif isinstance(k, TensorType) and k.fake_value.ndim == 1: + elif isinstance(k, TensorType): + base_dim_size = self.fake_value.size(inputs_consumed) inputs_consumed += 1 - output_sizes.append(k.fake_value.size(0)) + output_sizes.extend( + env.tensor_indexer_dims(k.fake_value, base_dim_size, broadcast) + ) elif k.contains_type(TileIndexType): raise exc.OverpackedTile(k) else: raise exc.InvalidIndexingType(k) - if inputs_consumed != self.fake_value.ndim: + # Check for rank mismatch - consuming too many dimensions + if inputs_consumed > self.fake_value.ndim: raise exc.RankMismatch( self.fake_value.ndim, inputs_consumed, - f"tensor shape: {tuple(self.fake_value.shape)}", + f"tensor shape: {tuple(self.fake_value.shape)}, consumed {inputs_consumed} dimensions", ) + + # Add any remaining dimensions from the original tensor + # This handles cases like tensor[idx] where tensor is multi-dimensional + # and idx is a tensor that only indexes the first dimension + for dim in range(inputs_consumed, self.fake_value.ndim): + output_sizes.append(self.fake_value.size(dim)) + return output_sizes def propagate_setitem( @@ -553,9 +565,11 @@ def propagate_getitem(self, key: TypeInfo, origin: Origin) -> TypeInfo: raise exc.TypeInferenceError( f"Subscript not supported on {self!s} with key={key!s}" ) from None - return TensorType( - origin, self.fake_value.new_empty(self._device_indexing_size(key)) - ) + 
new_sizes = self._device_indexing_size(key) + env = CompileEnvironment.current() + new_fake = env.new_index_result(self.fake_value, new_sizes) + + return TensorType(origin, new_fake) def merge(self, other: TypeInfo, var_name: str | None = None) -> TypeInfo: if isinstance(other, TensorType): @@ -2145,6 +2159,41 @@ def visit_NamedExpr(self, node: ast.NamedExpr) -> TypeInfo: def visit_Subscript(self, node: ast.Subscript) -> TypeInfo: value_type = self.visit(node.value) slice_type = self.visit(node.slice) + + # Check for rank mismatch in device loops (tile loops) + # Require all dimensions to be explicitly indexed unless using tensor indexing + if self.device_loop_depth > 0 and isinstance(value_type, TensorType): + if isinstance(slice_type, SequenceType): + keys = slice_type.unpack() + else: + keys = [slice_type] + + # Check for overpacked tiles first and raise error immediately + for k in keys: + if k.contains_type(TileIndexType) and not isinstance(k, TileIndexType): + raise exc.OverpackedTile(k) + + # Count how many dimensions will be consumed + inputs_consumed = 0 + has_tensor_index = False + for k in keys: + if isinstance(k, (LiteralType, SymIntType, TileIndexType)): + if not (isinstance(k, LiteralType) and k.value is None): + inputs_consumed += 1 + elif isinstance(k, SliceType): + inputs_consumed += 1 + elif isinstance(k, TensorType): + has_tensor_index = True + inputs_consumed += 1 + + # In device loops, require all dimensions to be indexed (unless using tensor indexing) + if not has_tensor_index and inputs_consumed < value_type.fake_value.ndim: + raise exc.RankMismatch( + value_type.fake_value.ndim, + inputs_consumed, + f"tensor shape: {tuple(value_type.fake_value.shape)}, indexed {inputs_consumed} dimensions", + ) + return value_type.propagate_getitem(slice_type, self.origin()) def visit_Slice(self, node: ast.Slice) -> TypeInfo: diff --git a/helion/language/memory_ops.py b/helion/language/memory_ops.py index eac0acde6..2bf51520a 100644 --- a/helion/language/memory_ops.py +++ b/helion/language/memory_ops.py @@ -7,6 +7,8 @@ from torch.fx import has_side_effect from .. import exc +from .._compiler.ast_extension import expr_from_string +from .._compiler.compile_environment import CompileEnvironment from .._compiler.indexing_strategy import SubscriptIndexing from . import _decorators from .stack_tensor import StackTensor @@ -239,7 +241,10 @@ def _( ) -> torch.Tensor: if isinstance(tensor, torch.Tensor): target_shape = SubscriptIndexing.compute_shape(tensor, index) - return tensor.new_empty(target_shape) + from .._compiler.compile_environment import CompileEnvironment + + env = CompileEnvironment.current() + return env.new_index_result(tensor, target_shape) if isinstance(tensor, tuple): tensor_like, dev_ptrs = tensor assert isinstance(tensor_like, torch.Tensor) @@ -275,6 +280,54 @@ def _(state: CodegenState) -> ast.AST: eviction_policy = ast.Constant(value=eviction_policy) if isinstance(tensor, torch.Tensor): + # Fast-path for tile_index(...) being broadcast-only indexed + from ..language import tile_index + tensor_node = state.fx_node.args[0] + if ( + isinstance(tensor_node, torch.fx.Node) + and tensor_node.op == "call_function" + and tensor_node.target == tile_index + ): + # tile.index tensors are not real memory accesses; materialize the + # block index variable with the requested broadcast/reshape. 
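+            # Prefer the provenance tag recorded on the tile.index tensor; when it
+            # is absent, fall back to resolving the block id from the tensor's
+            # first dimension size.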
+ env = CompileEnvironment.current() + block_id = env.get_tile_index_tensor_block_id(tensor) or env.get_block_id( + tensor.size(0) + ) + assert block_id is not None + base_var = state.codegen.index_var(block_id) + + def _format_idx(idx: object) -> str: + if idx is None: + return "None" + if isinstance(idx, slice): + if idx == slice(None, None, None): + return ":" + start = "" if idx.start is None else idx.start + stop = "" if idx.stop is None else idx.stop + step = idx.step + if step is None: + return f"{start}:{stop}" + return f"{start}:{stop}:{step}" + raise NotImplementedError + + try: + parts = [_format_idx(idx) for idx in subscript] + except NotImplementedError: + parts = [] + if parts: + bracket = ", ".join(parts) + return expr_from_string(f"{base_var}[{bracket}]") + # If we couldn't safely format the indices, fall back to the identity + if ( + not any(idx is None for idx in subscript) + and all( + isinstance(idx, slice) and idx == slice(None, None, None) + for idx in subscript + ) + ): + return state.ast_args[0] + # Use the shared memory op index for indexing strategy indexing_idx = device_fn.device_memory_op_index device_fn.device_memory_op_index += 1 diff --git a/helion/language/tile_ops.py b/helion/language/tile_ops.py index 7a97e79ff..b452f2484 100644 --- a/helion/language/tile_ops.py +++ b/helion/language/tile_ops.py @@ -49,8 +49,11 @@ def arange(length: int, device: torch.device) -> torch.Tensor: def _(tile: torch.SymInt) -> torch.Tensor: assert isinstance(tile, torch.SymInt) env = CompileEnvironment.current() - assert env.get_block_id(tile) is not None - return torch.empty([tile], dtype=env.index_dtype, device=env.device) + block_id = env.get_block_id(tile) + assert block_id is not None + t = torch.empty([tile], dtype=env.index_dtype, device=env.device) + env.register_tile_index_tensor_block_id(t, block_id) + return t @_decorators.codegen(tile_index, "triton") diff --git a/helion/language/view_ops.py b/helion/language/view_ops.py index da4678f6d..5f37e6556 100644 --- a/helion/language/view_ops.py +++ b/helion/language/view_ops.py @@ -84,7 +84,10 @@ def _(tensor: torch.Tensor, index: list[object]) -> torch.Tensor: else: raise exc.InvalidIndexingType(repr(val)) assert len(input_size) == 0 - return tensor.new_empty(output_size) + from .._compiler.compile_environment import CompileEnvironment + + env = CompileEnvironment.current() + return env.new_index_result(tensor, output_size) @_decorators.codegen(subscript, "triton") diff --git a/test/test_examples.expected b/test/test_examples.expected index dbb880280..6f908c19b 100644 --- a/test/test_examples.expected +++ b/test/test_examples.expected @@ -1588,94 +1588,6 @@ def fp8_gemm(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher): # src[fp8_gemm.py:N]: return out return out ---- assertExpectedJournal(TestExamples.test_fused_linear_jsd) -from __future__ import annotations - -import torch -import triton -import triton.language as tl -from torch._inductor.runtime.triton_helpers import math as tl_math -from helion.runtime import default_launcher as _default_launcher - -@triton.jit -def _helion_fused_linear_jsd_kernel(student_logits, teacher_logits, loss, temperature, beta, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr): - # src[fused_linear_jsd.py:N]: for batch in hl.tile(student_logits.shape[0]): - pid_0 = tl.program_id(0) - offset_0 = pid_0 * _BLOCK_SIZE_0 - indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) - indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) - # src[fused_linear_jsd.py:N]: 
student_prob = torch.log_softmax(student_logits[batch, :] / temperature, dim=-1) - load = tl.load(student_logits + (indices_0[:, None] * 256 + indices_1[None, :] * 1), None) - v_0 = load / temperature - amax = tl.cast(tl.reshape(tl.max(v_0, 1), [_BLOCK_SIZE_0, 1]), tl.float32) - v_1 = v_0 - amax - v_2 = libdevice.exp(v_1) - sum_1 = tl.cast(tl.reshape(tl.sum(v_2, 1), [_BLOCK_SIZE_0, 1]), tl.float32) - v_3 = tl_math.log(sum_1) - v_4 = v_1 - v_3 - # src[fused_linear_jsd.py:N]: teacher_prob = torch.log_softmax(teacher_logits[batch, :] / temperature, dim=-1) - load_1 = tl.load(teacher_logits + (indices_0[:, None] * 256 + indices_1[None, :] * 1), None) - v_5 = load_1 / temperature - amax_1 = tl.cast(tl.reshape(tl.max(v_5, 1), [_BLOCK_SIZE_0, 1]), tl.float32) - v_6 = v_5 - amax_1 - v_7 = libdevice.exp(v_6) - sum_2 = tl.cast(tl.reshape(tl.sum(v_7, 1), [_BLOCK_SIZE_0, 1]), tl.float32) - v_8 = tl_math.log(sum_2) - v_9 = v_6 - v_8 - # src[fused_linear_jsd.py:N]: student_prob = student_prob.to(torch.float).view(-1, student_prob.size(-1)) - student_prob_1 = tl.reshape(v_4, [_BLOCK_SIZE_0, _RDIM_SIZE_1]) - # src[fused_linear_jsd.py:N]: teacher_prob = teacher_prob.to(torch.float).view(-1, teacher_prob.size(-1)) - teacher_prob_1 = tl.reshape(v_9, [_BLOCK_SIZE_0, _RDIM_SIZE_1]) - # src[fused_linear_jsd.py:N]: m = torch.exp(student_prob) + beta * ( - v_10 = libdevice.exp(student_prob_1) - # src[fused_linear_jsd.py:N]: torch.exp(teacher_prob) - torch.exp(student_prob) - v_11 = libdevice.exp(teacher_prob_1) - v_12 = libdevice.exp(student_prob_1) - v_13 = v_11 - v_12 - # src[fused_linear_jsd.py:N]: m = torch.exp(student_prob) + beta * ( - # src[fused_linear_jsd.py:N]: torch.exp(teacher_prob) - torch.exp(student_prob) - # src[fused_linear_jsd.py:N]: ) - v_14 = v_13 * beta - v_15 = v_10 + v_14 - # src[fused_linear_jsd.py:N]: torch.log(m), teacher_prob, reduction="none", log_target=True - v_16 = tl_math.log(v_15) - # src[fused_linear_jsd.py:N]: teacher_div = torch.nn.functional.kl_div( - # src[fused_linear_jsd.py:N]: torch.log(m), teacher_prob, reduction="none", log_target=True - # src[fused_linear_jsd.py:N]: ).sum(dim=-1) - v_17 = teacher_prob_1 - v_16 - v_18 = libdevice.exp(teacher_prob_1) - v_19 = v_18 * v_17 - teacher_div = tl.cast(tl.sum(v_19, 1), tl.float32) - # src[fused_linear_jsd.py:N]: torch.log(m), student_prob, reduction="none", log_target=True - v_20 = tl_math.log(v_15) - # src[fused_linear_jsd.py:N]: student_div = torch.nn.functional.kl_div( - # src[fused_linear_jsd.py:N]: torch.log(m), student_prob, reduction="none", log_target=True - # src[fused_linear_jsd.py:N]: ).sum(dim=-1) - v_21 = student_prob_1 - v_20 - v_22 = libdevice.exp(student_prob_1) - v_23 = v_22 * v_21 - student_div = tl.cast(tl.sum(v_23, 1), tl.float32) - # src[fused_linear_jsd.py:N]: batch_loss = student_div + beta * (teacher_div - student_div) - v_24 = teacher_div - student_div - v_25 = v_24 * beta - v_26 = student_div + v_25 - # src[fused_linear_jsd.py:N]: loss[batch] = batch_loss - tl.store(loss + indices_0 * 1, v_26, None) - -def fused_linear_jsd_kernel(beta: float, ignore_index: int, temperature: float, student_logits: torch.Tensor, teacher_logits: torch.Tensor, *, _launcher=_default_launcher): - # src[fused_linear_jsd.py:N]: loss = student_logits.new_empty(student_logits.shape[0], dtype=torch.float) - loss = student_logits.new_empty(student_logits.shape[0], dtype=torch.float) - # src[fused_linear_jsd.py:N]: for batch in hl.tile(student_logits.shape[0]): - _BLOCK_SIZE_0 = 32 - _RDIM_SIZE_1 = 256 - # src[fused_linear_jsd.py:N]: 
for batch in hl.tile(student_logits.shape[0]): - # src[fused_linear_jsd.py:N]: student_prob = torch.log_softmax(student_logits[batch, :] / temperature, dim=-1) - # src[fused_linear_jsd.py:N]: teacher_prob = torch.log_softmax(teacher_logits[batch, :] / temperature, dim=-1) - # src[fused_linear_jsd.py:N-N]: ... - _launcher(_helion_fused_linear_jsd_kernel, (triton.cdiv(64, _BLOCK_SIZE_0),), student_logits, teacher_logits, loss, temperature, beta, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1) - # src[fused_linear_jsd.py:N]: return (loss / student_logits.shape[0]).sum() - return (loss / student_logits.shape[0]).sum() - --- assertExpectedJournal(TestExamples.test_geglu) from __future__ import annotations diff --git a/test/test_examples.py b/test/test_examples.py index 31a796a5f..364a70ee2 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -1402,43 +1402,6 @@ def test_jagged_sum(self): ) ) - def test_fused_linear_jsd(self): - beta = 0.5 - ignore_index = 1 - temperature = 1.0 - m, n, k = 64, 128, 256 - - student_input = torch.randn([m, n], device=DEVICE, dtype=torch.float32) - teacher_input = torch.randn([m, n], device=DEVICE, dtype=torch.float32) - student_weight = torch.randn([k, n], device=DEVICE, dtype=torch.float32) - teacher_weight = torch.randn([k, n], device=DEVICE, dtype=torch.float32) - student_logits = student_input @ student_weight.T - teacher_logits = teacher_input @ teacher_weight.T - - args = ( - beta, - ignore_index, - temperature, - student_logits, - teacher_logits, - ) - - # Import and use the reference implementation - mod = import_path(EXAMPLES_DIR / "fused_linear_jsd.py") - expected = mod.fused_linear_jsd_pytorch( - *args[:-2], student_input, teacher_input, student_weight, teacher_weight - ) - - self.assertExpectedJournal( - check_example( - "fused_linear_jsd", - args, - expected, - fn_name="fused_linear_jsd_kernel", - block_sizes=[32], - ) - ) - def test_jagged_layer_norm(self): num_rows, max_cols = 128, 64 M = 8 # number of features diff --git a/test/test_indirect_indexing.expected b/test/test_indirect_indexing.expected new file mode 100644 index 000000000..8350ce3c7 --- /dev/null +++ b/test/test_indirect_indexing.expected @@ -0,0 +1,345 @@ +This file is automatically generated by assertExpectedJournal calls in test_indirect_indexing.py. +Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set. 
+ +--- assertExpectedJournal(TestIndirectIndexing.test_indirect_indexing_2d_direct_gather) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _helion_test(col, B, val, C, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + # src[test_indirect_indexing.py:N]: for tile_m, tile_n in hl.tile([M, N]): + num_blocks_0 = tl.cdiv(32, _BLOCK_SIZE_0) + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + # src[test_indirect_indexing.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) + # src[test_indirect_indexing.py:N]: for tile_k in hl.tile(K): + # src[test_indirect_indexing.py:N]: cols_2d = col[tile_m, tile_k] + # src[test_indirect_indexing.py:N]: B_slice = B[cols_2d[:, :, None], tile_n.index[None, None, :]] + # src[test_indirect_indexing.py:N-N]: ... + for offset_3 in tl.range(0, 16, _BLOCK_SIZE_2): + indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) + acc_copy = acc + acc_copy_0 = acc_copy + # src[test_indirect_indexing.py:N]: cols_2d = col[tile_m, tile_k] + cols_2d = tl.load(col + (indices_0[:, None] * 16 + indices_3[None, :] * 1), None) + # src[test_indirect_indexing.py:N]: B_slice = B[cols_2d[:, :, None], tile_n.index[None, None, :]] + subscript = cols_2d[:, :, None] + load_1 = indices_1[None, None, :] + B_slice = tl.load(B + (subscript * 24 + load_1 * 1), None) + # src[test_indirect_indexing.py:N]: vals_2d = val[tile_m, tile_k] + vals_2d = tl.load(val + (indices_0[:, None] * 16 + indices_3[None, :] * 1), None) + # src[test_indirect_indexing.py:N]: contrib = vals_2d[:, :, None] * B_slice + subscript_1 = vals_2d[:, :, None] + v_0 = subscript_1 * B_slice + # src[test_indirect_indexing.py:N]: contrib = contrib.sum(dim=1) + contrib_1 = tl.cast(tl.sum(v_0, 1), tl.float32) + # src[test_indirect_indexing.py:N]: acc = acc + contrib + acc = acc_copy_0 + contrib_1 + # src[test_indirect_indexing.py:N]: C[tile_m, tile_n] = acc.to(out_dtype) + tl.store(C + (indices_0[:, None] * 24 + indices_1[None, :] * 1), acc, None) + +def test(col: torch.Tensor, val: torch.Tensor, B: torch.Tensor, *, _launcher=_default_launcher): + # src[test_indirect_indexing.py:N]: M, K = col.shape + M, K = col.shape + # src[test_indirect_indexing.py:N]: _, N = B.shape + _, N = B.shape + # src[test_indirect_indexing.py:N]: out_dtype = torch.promote_types(val.dtype, B.dtype) + out_dtype = torch.promote_types(val.dtype, B.dtype) + # src[test_indirect_indexing.py:N]: C = torch.empty((M, N), dtype=out_dtype, device=B.device) + C = torch.empty((M, N), dtype=out_dtype, device=B.device) + # src[test_indirect_indexing.py:N]: for tile_m, tile_n in hl.tile([M, N]): + _BLOCK_SIZE_0 = 8 + _BLOCK_SIZE_1 = 8 + # src[test_indirect_indexing.py:N]: for tile_k in hl.tile(K): + # src[test_indirect_indexing.py:N]: cols_2d = col[tile_m, tile_k] + # src[test_indirect_indexing.py:N]: B_slice = B[cols_2d[:, :, None], tile_n.index[None, None, :]] + # src[test_indirect_indexing.py:N-N]: ... 
+    _BLOCK_SIZE_2 = 4
+    # src[test_indirect_indexing.py:N]: for tile_m, tile_n in hl.tile([M, N]):
+    # src[test_indirect_indexing.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+    # src[test_indirect_indexing.py:N-N]: ...
+    _RDIM_SIZE_3 = triton.next_power_of_2(_BLOCK_SIZE_1)
+    _launcher(_helion_test, (triton.cdiv(32, _BLOCK_SIZE_0) * triton.cdiv(24, _BLOCK_SIZE_1),), col, B, val, C, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1)
+    # src[test_indirect_indexing.py:N]: return C
+    return C
+
+--- assertExpectedJournal(TestIndirectIndexing.test_indirect_indexing_2d_flat_load)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_test(col, B_flat, val, C, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
+    # src[test_indirect_indexing.py:N]: for tile_m, tile_n in hl.tile([M, N]):
+    num_blocks_0 = tl.cdiv(32, _BLOCK_SIZE_0)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    # src[test_indirect_indexing.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+    acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
+    # src[test_indirect_indexing.py:N]: for tile_k in hl.tile(K):
+    # src[test_indirect_indexing.py:N]: cols_2d = col[tile_m, tile_k]
+    # src[test_indirect_indexing.py:N]: B_indices = (cols_2d * N)[:, :, None] + tile_n.index[None, None, :]
+    # src[test_indirect_indexing.py:N-N]: ...
+    for offset_3 in tl.range(0, 16, _BLOCK_SIZE_2):
+        indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
+        acc_copy = acc
+        acc_copy_0 = acc_copy
+        # src[test_indirect_indexing.py:N]: cols_2d = col[tile_m, tile_k]
+        cols_2d = tl.load(col + (indices_0[:, None] * 16 + indices_3[None, :] * 1), None)
+        # src[test_indirect_indexing.py:N]: B_indices = (cols_2d * N)[:, :, None] + tile_n.index[None, None, :]
+        v_0 = tl.full([], 24, tl.int64)
+        v_1 = tl.cast(cols_2d * v_0, tl.int64)
+        subscript = v_1[:, :, None]
+        load_1 = indices_1[None, None, :]
+        v_2 = tl.cast(load_1, tl.int64)
+        v_3 = subscript + v_2
+        # src[test_indirect_indexing.py:N]: B_slice = hl.load(B_flat, [B_indices])
+        B_slice = tl.load(B_flat + v_3 * 1, None)
+        # src[test_indirect_indexing.py:N]: vals_2d = val[tile_m, tile_k]
+        vals_2d = tl.load(val + (indices_0[:, None] * 16 + indices_3[None, :] * 1), None)
+        # src[test_indirect_indexing.py:N]: contrib = vals_2d[:, :, None] * B_slice
+        subscript_1 = vals_2d[:, :, None]
+        v_4 = subscript_1 * B_slice
+        # src[test_indirect_indexing.py:N]: contrib = contrib.sum(dim=1)
+        contrib_1 = tl.cast(tl.sum(v_4, 1), tl.float32)
+        # src[test_indirect_indexing.py:N]: acc = acc + contrib
+        acc = acc_copy_0 + contrib_1
+    # src[test_indirect_indexing.py:N]: C[tile_m, tile_n] = acc.to(out_dtype)
+    tl.store(C + (indices_0[:, None] * 24 + indices_1[None, :] * 1), acc, None)
+
+def test(col: torch.Tensor, val: torch.Tensor, B: torch.Tensor, *, _launcher=_default_launcher):
+    # src[test_indirect_indexing.py:N]: M, K = col.shape
+    M, K = col.shape
+    # src[test_indirect_indexing.py:N]: _, N = B.shape
+    _, N = B.shape
+    # src[test_indirect_indexing.py:N]: out_dtype = torch.promote_types(val.dtype, B.dtype)
+    out_dtype = torch.promote_types(val.dtype, B.dtype)
+    # src[test_indirect_indexing.py:N]: C = torch.empty((M, N), dtype=out_dtype, device=B.device)
+    C = torch.empty((M, N), dtype=out_dtype, device=B.device)
+    # src[test_indirect_indexing.py:N]: B_flat = B.reshape(-1)  # [K*N]
+    B_flat = B.reshape(-1)
+    # src[test_indirect_indexing.py:N]: for tile_m, tile_n in hl.tile([M, N]):
+    _BLOCK_SIZE_0 = 8
+    _BLOCK_SIZE_1 = 8
+    # src[test_indirect_indexing.py:N]: for tile_k in hl.tile(K):
+    # src[test_indirect_indexing.py:N]: cols_2d = col[tile_m, tile_k]
+    # src[test_indirect_indexing.py:N]: B_indices = (cols_2d * N)[:, :, None] + tile_n.index[None, None, :]
+    # src[test_indirect_indexing.py:N-N]: ...
+    _BLOCK_SIZE_2 = 4
+    # src[test_indirect_indexing.py:N]: for tile_m, tile_n in hl.tile([M, N]):
+    # src[test_indirect_indexing.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+    # src[test_indirect_indexing.py:N-N]: ...
+    _RDIM_SIZE_3 = triton.next_power_of_2(_BLOCK_SIZE_1)
+    _launcher(_helion_test, (triton.cdiv(32, _BLOCK_SIZE_0) * triton.cdiv(24, _BLOCK_SIZE_1),), col, B_flat, val, C, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1)
+    # src[test_indirect_indexing.py:N]: return C
+    return C
+
+--- assertExpectedJournal(TestIndirectIndexing.test_indirect_indexing_3d_direct_gather)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_test(col, B, val, C, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr, _BLOCK_SIZE_4: tl.constexpr):
+    # src[test_indirect_indexing.py:N]: for tile_m, tile_n, tile_p, tile_q in hl.tile([M, N, P, Q]):
+    num_blocks_0 = tl.cdiv(16, _BLOCK_SIZE_0)
+    num_blocks_1 = tl.cdiv(12, _BLOCK_SIZE_1)
+    num_blocks_2 = tl.cdiv(10, _BLOCK_SIZE_2)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0 % num_blocks_1
+    pid_2 = tl.program_id(0) // (num_blocks_0 * num_blocks_1) % num_blocks_2
+    pid_3 = tl.program_id(0) // (num_blocks_0 * num_blocks_1 * num_blocks_2)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    offset_2 = pid_2 * _BLOCK_SIZE_2
+    indices_2 = (offset_2 + tl.arange(0, _BLOCK_SIZE_2)).to(tl.int32)
+    mask_2 = indices_2 < 10
+    offset_3 = pid_3 * _BLOCK_SIZE_3
+    indices_3 = (offset_3 + tl.arange(0, _BLOCK_SIZE_3)).to(tl.int32)
+    mask_3 = indices_3 < 14
+    # src[test_indirect_indexing.py:N]: acc = hl.zeros([tile_m, tile_n, tile_p, tile_q], dtype=torch.float32)
+    acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3], 0.0, tl.float32)
+    # src[test_indirect_indexing.py:N]: for tile_k in hl.tile(K):
+    # src[test_indirect_indexing.py:N]: cols_3d = col[tile_m, tile_n, tile_k]
+    # src[test_indirect_indexing.py:N]: B_slice = B[
+    # src[test_indirect_indexing.py:N-N]: ...
+    for offset_5 in tl.range(0, 8, _BLOCK_SIZE_4):
+        indices_5 = offset_5 + tl.arange(0, _BLOCK_SIZE_4).to(tl.int32)
+        acc_copy = acc
+        acc_copy_0 = acc_copy
+        # src[test_indirect_indexing.py:N]: cols_3d = col[tile_m, tile_n, tile_k]
+        cols_3d = tl.load(col + (indices_0[:, None, None] * 96 + indices_1[None, :, None] * 8 + indices_5[None, None, :] * 1), None)
+        # src[test_indirect_indexing.py:N]: cols_3d[:, :, :, None, None],
+        subscript = cols_3d[:, :, :, None, None]
+        # src[test_indirect_indexing.py:N]: tile_p.index[None, None, :, None],
+        load_1 = indices_2[None, None, :, None]
+        # src[test_indirect_indexing.py:N]: tile_q.index[None, None, None, :],
+        load_2 = indices_3[None, None, None, :]
+        # src[test_indirect_indexing.py:N]: B_slice = B[
+        # src[test_indirect_indexing.py:N]: cols_3d[:, :, :, None, None],
+        # src[test_indirect_indexing.py:N]: tile_p.index[None, None, :, None],
+        # src[test_indirect_indexing.py:N-N]: ...
+        B_slice = tl.load(B + (subscript * 140 + load_1 * 14 + load_2 * 1), mask_2[None, None, None, :, None] & mask_3[None, None, None, None, :], other=0)
+        # src[test_indirect_indexing.py:N]: vals_3d = val[tile_m, tile_n, tile_k]
+        vals_3d = tl.load(val + (indices_0[:, None, None] * 96 + indices_1[None, :, None] * 8 + indices_5[None, None, :] * 1), None)
+        # src[test_indirect_indexing.py:N]: contrib = vals_3d[:, :, :, None, None] * B_slice
+        subscript_1 = vals_3d[:, :, :, None, None]
+        v_0 = subscript_1 * B_slice
+        # src[test_indirect_indexing.py:N]: contrib = contrib.sum(dim=2)
+        contrib_1 = tl.cast(tl.sum(v_0, 2), tl.float32)
+        # src[test_indirect_indexing.py:N]: acc = acc + contrib
+        acc = acc_copy_0 + contrib_1
+    # src[test_indirect_indexing.py:N]: C[tile_m, tile_n, tile_p, tile_q] = acc.to(out_dtype)
+    tl.store(C + (indices_0[:, None, None, None] * 1680 + indices_1[None, :, None, None] * 140 + indices_2[None, None, :, None] * 14 + indices_3[None, None, None, :] * 1), acc, mask_2[None, None, :, None] & mask_3[None, None, None, :])
+
+def test(col: torch.Tensor, val: torch.Tensor, B: torch.Tensor, *, _launcher=_default_launcher):
+    # src[test_indirect_indexing.py:N]: M, N, K = col.shape
+    M, N, K = col.shape
+    # src[test_indirect_indexing.py:N]: _, P, Q = B.shape
+    _, P, Q = B.shape
+    # src[test_indirect_indexing.py:N]: out_dtype = torch.promote_types(val.dtype, B.dtype)
+    out_dtype = torch.promote_types(val.dtype, B.dtype)
+    # src[test_indirect_indexing.py:N]: C = torch.empty((M, N, P, Q), dtype=out_dtype, device=B.device)
+    C = torch.empty((M, N, P, Q), dtype=out_dtype, device=B.device)
+    # src[test_indirect_indexing.py:N]: for tile_m, tile_n, tile_p, tile_q in hl.tile([M, N, P, Q]):
+    _BLOCK_SIZE_0 = 4
+    _BLOCK_SIZE_1 = 4
+    _BLOCK_SIZE_2 = 4
+    _BLOCK_SIZE_3 = 4
+    # src[test_indirect_indexing.py:N]: for tile_k in hl.tile(K):
+    # src[test_indirect_indexing.py:N]: cols_3d = col[tile_m, tile_n, tile_k]
+    # src[test_indirect_indexing.py:N]: B_slice = B[
+    # src[test_indirect_indexing.py:N-N]: ...
+    _BLOCK_SIZE_4 = 4
+    # src[test_indirect_indexing.py:N]: for tile_m, tile_n, tile_p, tile_q in hl.tile([M, N, P, Q]):
+    # src[test_indirect_indexing.py:N]: acc = hl.zeros([tile_m, tile_n, tile_p, tile_q], dtype=torch.float32)
+    # src[test_indirect_indexing.py:N-N]: ...
+    _RDIM_SIZE_5 = triton.next_power_of_2(_BLOCK_SIZE_2)
+    _launcher(_helion_test, (triton.cdiv(16, _BLOCK_SIZE_0) * triton.cdiv(12, _BLOCK_SIZE_1) * triton.cdiv(10, _BLOCK_SIZE_2) * triton.cdiv(14, _BLOCK_SIZE_3),), col, B, val, C, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, _BLOCK_SIZE_4, num_warps=4, num_stages=1)
+    # src[test_indirect_indexing.py:N]: return C
+    return C
+
+--- assertExpectedJournal(TestIndirectIndexing.test_indirect_indexing_3d_flat_load)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_test(col, B_flat, val, C, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr, _BLOCK_SIZE_4: tl.constexpr):
+    # src[test_indirect_indexing.py:N]: for tile_m, tile_n, tile_p, tile_q in hl.tile([M, N, P, Q]):
+    num_blocks_0 = tl.cdiv(16, _BLOCK_SIZE_0)
+    num_blocks_1 = tl.cdiv(12, _BLOCK_SIZE_1)
+    num_blocks_2 = tl.cdiv(10, _BLOCK_SIZE_2)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0 % num_blocks_1
+    pid_2 = tl.program_id(0) // (num_blocks_0 * num_blocks_1) % num_blocks_2
+    pid_3 = tl.program_id(0) // (num_blocks_0 * num_blocks_1 * num_blocks_2)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    offset_2 = pid_2 * _BLOCK_SIZE_2
+    indices_2 = (offset_2 + tl.arange(0, _BLOCK_SIZE_2)).to(tl.int32)
+    mask_2 = indices_2 < 10
+    offset_3 = pid_3 * _BLOCK_SIZE_3
+    indices_3 = (offset_3 + tl.arange(0, _BLOCK_SIZE_3)).to(tl.int32)
+    mask_3 = indices_3 < 14
+    # src[test_indirect_indexing.py:N]: acc = hl.zeros([tile_m, tile_n, tile_p, tile_q], dtype=torch.float32)
+    acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3], 0.0, tl.float32)
+    # src[test_indirect_indexing.py:N]: for tile_k in hl.tile(K):
+    # src[test_indirect_indexing.py:N]: cols_3d = col[tile_m, tile_n, tile_k]
+    # src[test_indirect_indexing.py:N]: B_indices = (
+    # src[test_indirect_indexing.py:N-N]: ...
+    for offset_5 in tl.range(0, 8, _BLOCK_SIZE_4):
+        indices_5 = offset_5 + tl.arange(0, _BLOCK_SIZE_4).to(tl.int32)
+        acc_copy = acc
+        acc_copy_0 = acc_copy
+        # src[test_indirect_indexing.py:N]: cols_3d = col[tile_m, tile_n, tile_k]
+        cols_3d = tl.load(col + (indices_0[:, None, None] * 96 + indices_1[None, :, None] * 8 + indices_5[None, None, :] * 1), None)
+        # src[test_indirect_indexing.py:N]: cols_3d[:, :, :, None, None] * (P * Q)
+        subscript = cols_3d[:, :, :, None, None]
+        v_0 = tl.full([], 140, tl.int64)
+        v_1 = tl.cast(subscript * v_0, tl.int64)
+        # src[test_indirect_indexing.py:N]: + tile_p.index[None, None, :, None] * Q
+        load_1 = indices_2[None, None, :, None]
+        v_2 = tl.full([], 14, tl.int32)
+        v_3 = tl.cast(load_1 * v_2, tl.int32)
+        # src[test_indirect_indexing.py:N]: cols_3d[:, :, :, None, None] * (P * Q)
+        # src[test_indirect_indexing.py:N]: + tile_p.index[None, None, :, None] * Q
+        v_4 = v_3[None, :, :, :, :]
+        v_5 = tl.cast(v_4, tl.int64)
+        v_6 = v_1 + v_5
+        # src[test_indirect_indexing.py:N]: + tile_q.index[None, None, None, :]
+        load_2 = indices_3[None, None, None, :]
+        # src[test_indirect_indexing.py:N]: cols_3d[:, :, :, None, None] * (P * Q)
+        # src[test_indirect_indexing.py:N]: + tile_p.index[None, None, :, None] * Q
+        # src[test_indirect_indexing.py:N]: + tile_q.index[None, None, None, :]
+        v_7 = load_2[None, :, :, :, :]
+        v_8 = tl.cast(v_7, tl.int64)
+        v_9 = v_6 + v_8
+        # src[test_indirect_indexing.py:N]: B_slice = hl.load(B_flat, [B_indices])
+        B_slice = tl.load(B_flat + v_9 * 1, mask_2[None, None, None, :, None] & mask_3[None, None, None, None, :], other=0)
+        # src[test_indirect_indexing.py:N]: vals_3d = val[tile_m, tile_n, tile_k]
+        vals_3d = tl.load(val + (indices_0[:, None, None] * 96 + indices_1[None, :, None] * 8 + indices_5[None, None, :] * 1), None)
+        # src[test_indirect_indexing.py:N]: contrib = vals_3d[:, :, :, None, None] * B_slice
+        subscript_1 = vals_3d[:, :, :, None, None]
+        v_10 = subscript_1 * B_slice
+        # src[test_indirect_indexing.py:N]: contrib = contrib.sum(dim=2)
+        contrib_1 = tl.cast(tl.sum(v_10, 2), tl.float32)
+        # src[test_indirect_indexing.py:N]: acc = acc + contrib
+        acc = acc_copy_0 + contrib_1
+    # src[test_indirect_indexing.py:N]: C[tile_m, tile_n, tile_p, tile_q] = acc.to(out_dtype)
+    tl.store(C + (indices_0[:, None, None, None] * 1680 + indices_1[None, :, None, None] * 140 + indices_2[None, None, :, None] * 14 + indices_3[None, None, None, :] * 1), acc, mask_2[None, None, :, None] & mask_3[None, None, None, :])
+
+def test(col: torch.Tensor, val: torch.Tensor, B: torch.Tensor, *, _launcher=_default_launcher):
+    # src[test_indirect_indexing.py:N]: M, N, K = col.shape
+    M, N, K = col.shape
+    # src[test_indirect_indexing.py:N]: _, P, Q = B.shape
+    _, P, Q = B.shape
+    # src[test_indirect_indexing.py:N]: out_dtype = torch.promote_types(val.dtype, B.dtype)
+    out_dtype = torch.promote_types(val.dtype, B.dtype)
+    # src[test_indirect_indexing.py:N]: C = torch.empty((M, N, P, Q), dtype=out_dtype, device=B.device)
+    C = torch.empty((M, N, P, Q), dtype=out_dtype, device=B.device)
+    # src[test_indirect_indexing.py:N]: B_flat = B.reshape(-1)  # [K*P*Q]
+    B_flat = B.reshape(-1)
+    # src[test_indirect_indexing.py:N]: for tile_m, tile_n, tile_p, tile_q in hl.tile([M, N, P, Q]):
+    _BLOCK_SIZE_0 = 4
+    _BLOCK_SIZE_1 = 4
+    _BLOCK_SIZE_2 = 4
+    _BLOCK_SIZE_3 = 4
+    # src[test_indirect_indexing.py:N]: for tile_k in hl.tile(K):
+    # src[test_indirect_indexing.py:N]: cols_3d = col[tile_m, tile_n, tile_k]
+    # src[test_indirect_indexing.py:N]: B_indices = (
+    # src[test_indirect_indexing.py:N-N]: ...
+    _BLOCK_SIZE_4 = 4
+    # src[test_indirect_indexing.py:N]: for tile_m, tile_n, tile_p, tile_q in hl.tile([M, N, P, Q]):
+    # src[test_indirect_indexing.py:N]: acc = hl.zeros([tile_m, tile_n, tile_p, tile_q], dtype=torch.float32)
+    # src[test_indirect_indexing.py:N-N]: ...
+    _RDIM_SIZE_5 = triton.next_power_of_2(_BLOCK_SIZE_2)
+    _launcher(_helion_test, (triton.cdiv(16, _BLOCK_SIZE_0) * triton.cdiv(12, _BLOCK_SIZE_1) * triton.cdiv(10, _BLOCK_SIZE_2) * triton.cdiv(14, _BLOCK_SIZE_3),), col, B_flat, val, C, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, _BLOCK_SIZE_4, num_warps=4, num_stages=1)
+    # src[test_indirect_indexing.py:N]: return C
+    return C
diff --git a/test/test_indirect_indexing.py b/test/test_indirect_indexing.py
new file mode 100644
index 000000000..3f3a7a67c
--- /dev/null
+++ b/test/test_indirect_indexing.py
@@ -0,0 +1,223 @@
+from __future__ import annotations
+
+import torch
+
+import helion
+import helion.language as hl
+from helion._testing import DEVICE
+from helion._testing import RefEagerTestBase
+from helion._testing import TestCase
+from helion._testing import code_and_output
+
+
+class TestIndirectIndexing(RefEagerTestBase, TestCase):
+    def test_indirect_indexing_2d_direct_gather(self):
+        @helion.kernel()
+        def test(
+            col: torch.Tensor,  # [M, K] int64
+            val: torch.Tensor,  # [M, K] fp32
+            B: torch.Tensor,  # [K, N] fp32
+        ) -> torch.Tensor:  # [M, N] fp32
+            M, K = col.shape
+            _, N = B.shape
+            out_dtype = torch.promote_types(val.dtype, B.dtype)
+            C = torch.empty((M, N), dtype=out_dtype, device=B.device)
+
+            for tile_m, tile_n in hl.tile([M, N]):
+                acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+
+                for tile_k in hl.tile(K):
+                    cols_2d = col[tile_m, tile_k]
+                    B_slice = B[cols_2d[:, :, None], tile_n.index[None, None, :]]
+                    vals_2d = val[tile_m, tile_k]
+                    contrib = vals_2d[:, :, None] * B_slice
+                    contrib = contrib.sum(dim=1)
+                    acc = acc + contrib
+
+                C[tile_m, tile_n] = acc.to(out_dtype)
+
+            return C
+
+        M, K, N = 32, 16, 24
+        col = torch.randint(0, K, (M, K), device=DEVICE, dtype=torch.int64)
+        val = torch.rand((M, K), device=DEVICE, dtype=torch.float32)
+        B = torch.rand((K, N), device=DEVICE, dtype=torch.float32)
+
+        code, result = code_and_output(
+            test,
+            (col, val, B),
+            block_size=[8, 8, 4],
+        )
+
+        expected = torch.zeros((M, N), device=DEVICE, dtype=torch.float32)
+        for i in range(M):
+            for j in range(N):
+                for k in range(K):
+                    expected[i, j] += val[i, k] * B[col[i, k], j]
+
+        torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-5)
+        self.assertExpectedJournal(code)
+
+    def test_indirect_indexing_2d_flat_load(self):
+        @helion.kernel()
+        def test(
+            col: torch.Tensor,  # [M, K] int64
+            val: torch.Tensor,  # [M, K] fp32
+            B: torch.Tensor,  # [K, N] fp32
+        ) -> torch.Tensor:  # [M, N] fp32
+            M, K = col.shape
+            _, N = B.shape
+            out_dtype = torch.promote_types(val.dtype, B.dtype)
+            C = torch.empty((M, N), dtype=out_dtype, device=B.device)
+            B_flat = B.reshape(-1)  # [K*N]
+
+            for tile_m, tile_n in hl.tile([M, N]):
+                acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+
+                for tile_k in hl.tile(K):
+                    cols_2d = col[tile_m, tile_k]
+                    B_indices = (cols_2d * N)[:, :, None] + tile_n.index[None, None, :]
+                    B_slice = hl.load(B_flat, [B_indices])
+                    vals_2d = val[tile_m, tile_k]
+                    contrib = vals_2d[:, :, None] * B_slice
+                    contrib = contrib.sum(dim=1)
+                    acc = acc + contrib
+
+                C[tile_m, tile_n] = acc.to(out_dtype)
+
+            return C
+
+        M, K, N = 32, 16, 24
+        col = torch.randint(0, K, (M, K), device=DEVICE, dtype=torch.int64)
+        val = torch.rand((M, K), device=DEVICE, dtype=torch.float32)
+        B = torch.rand((K, N), device=DEVICE, dtype=torch.float32)
+
+        code, result = code_and_output(
+            test,
+            (col, val, B),
+            block_size=[8, 8, 4],
+        )
+
+        expected = torch.zeros((M, N), device=DEVICE, dtype=torch.float32)
+        for i in range(M):
+            for j in range(N):
+                for k in range(K):
+                    expected[i, j] += val[i, k] * B[col[i, k], j]
+
+        torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-5)
+        self.assertExpectedJournal(code)
+
+    def test_indirect_indexing_3d_flat_load(self):
+        @helion.kernel()
+        def test(
+            col: torch.Tensor,  # [M, N, K] int64
+            val: torch.Tensor,  # [M, N, K] fp32
+            B: torch.Tensor,  # [K, P, Q] fp32
+        ) -> torch.Tensor:  # [M, N, P, Q] fp32
+            M, N, K = col.shape
+            _, P, Q = B.shape
+            out_dtype = torch.promote_types(val.dtype, B.dtype)
+            C = torch.empty((M, N, P, Q), dtype=out_dtype, device=B.device)
+            B_flat = B.reshape(-1)  # [K*P*Q]
+
+            for tile_m, tile_n, tile_p, tile_q in hl.tile([M, N, P, Q]):
+                acc = hl.zeros([tile_m, tile_n, tile_p, tile_q], dtype=torch.float32)
+
+                for tile_k in hl.tile(K):
+                    cols_3d = col[tile_m, tile_n, tile_k]
+                    B_indices = (
+                        cols_3d[:, :, :, None, None] * (P * Q)
+                        + tile_p.index[None, None, :, None] * Q
+                        + tile_q.index[None, None, None, :]
+                    )
+                    B_slice = hl.load(B_flat, [B_indices])
+                    vals_3d = val[tile_m, tile_n, tile_k]
+                    contrib = vals_3d[:, :, :, None, None] * B_slice
+                    contrib = contrib.sum(dim=2)
+                    acc = acc + contrib
+
+                C[tile_m, tile_n, tile_p, tile_q] = acc.to(out_dtype)
+            return C
+
+        M, N, K, P, Q = 16, 12, 8, 10, 14
+        col = torch.randint(0, K, (M, N, K), device=DEVICE, dtype=torch.int64)
+        val = torch.rand((M, N, K), device=DEVICE, dtype=torch.float32)
+        B = torch.rand((K, P, Q), device=DEVICE, dtype=torch.float32)
+
+        code, result = code_and_output(
+            test,
+            (col, val, B),
+            block_size=[4, 4, 4, 4, 4],
+        )
+
+        expected = torch.zeros((M, N, P, Q), device=DEVICE, dtype=torch.float32)
+        for i in range(M):
+            for j in range(N):
+                for p in range(P):
+                    for q in range(Q):
+                        for k in range(K):
+                            expected[i, j, p, q] += val[i, j, k] * B[
+                                col[i, j, k], p, q
+                            ]
+
+        torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-5)
+        self.assertExpectedJournal(code)
+
+    def test_indirect_indexing_3d_direct_gather(self):
+        @helion.kernel()
+        def test(
+            col: torch.Tensor,  # [M, N, K] int64 - indices for first dimension of B
+            val: torch.Tensor,  # [M, N, K] fp32 - values to multiply
+            B: torch.Tensor,  # [K, P, Q] fp32 - tensor to index into
+        ) -> torch.Tensor:  # [M, N, P, Q] fp32
+            M, N, K = col.shape
+            _, P, Q = B.shape
+            out_dtype = torch.promote_types(val.dtype, B.dtype)
+            C = torch.empty((M, N, P, Q), dtype=out_dtype, device=B.device)
+
+            for tile_m, tile_n, tile_p, tile_q in hl.tile([M, N, P, Q]):
+                acc = hl.zeros([tile_m, tile_n, tile_p, tile_q], dtype=torch.float32)
+
+                for tile_k in hl.tile(K):
+                    cols_3d = col[tile_m, tile_n, tile_k]
+                    B_slice = B[
+                        cols_3d[:, :, :, None, None],
+                        tile_p.index[None, None, :, None],
+                        tile_q.index[None, None, None, :],
+                    ]
+
+                    vals_3d = val[tile_m, tile_n, tile_k]
+                    contrib = vals_3d[:, :, :, None, None] * B_slice
+                    contrib = contrib.sum(dim=2)
+                    acc = acc + contrib
+
+                C[tile_m, tile_n, tile_p, tile_q] = acc.to(out_dtype)
+            return C
+
+        M, N, K, P, Q = 16, 12, 8, 10, 14
+        col = torch.randint(0, K, (M, N, K), device=DEVICE, dtype=torch.int64)
+        val = torch.rand((M, N, K), device=DEVICE, dtype=torch.float32)
+        B = torch.rand((K, P, Q), device=DEVICE, dtype=torch.float32)
+
+        code, result = code_and_output(
+            test,
+            (col, val, B),
+            block_size=[4, 4, 4, 4, 4],  # 5D tiling for M, N, P, Q, K
+        )
+
+        expected = torch.zeros((M, N, P, Q), device=DEVICE, dtype=torch.float32)
+        for i in range(M):
+            for j in range(N):
+                for p in range(P):
+                    for q in range(Q):
+                        for k in range(K):
+                            expected[i, j, p, q] += val[i, j, k] * B[col[i, j, k], p, q]
+
+        torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-5)
+        self.assertExpectedJournal(code)
+
+
+if __name__ == "__main__":
+    import unittest
+
+    unittest.main()
diff --git a/test/test_linear_jsd.expected b/test/test_linear_jsd.expected
new file mode 100644
index 000000000..0259d4de7
--- /dev/null
+++ b/test/test_linear_jsd.expected
@@ -0,0 +1,90 @@
+This file is automatically generated by assertExpectedJournal calls in test_linear_jsd.py.
+Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set.
+
+--- assertExpectedJournal(TestLinearJSD.test_fused_linear_jsd)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from torch._inductor.runtime.triton_helpers import math as tl_math
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_fused_linear_jsd_kernel(student_logits, teacher_logits, loss, temperature, beta, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr):
+    # src[fused_linear_jsd.py:N]: for batch in hl.tile(student_logits.shape[0]):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
+    # src[fused_linear_jsd.py:N]: student_prob = torch.log_softmax(student_logits[batch, :] / temperature, dim=-1)
+    load = tl.load(student_logits + (indices_0[:, None] * 256 + indices_1[None, :] * 1), None)
+    v_0 = load / temperature
+    amax = tl.cast(tl.reshape(tl.max(v_0, 1), [_BLOCK_SIZE_0, 1]), tl.float32)
+    v_1 = v_0 - amax
+    v_2 = libdevice.exp(v_1)
+    sum_1 = tl.cast(tl.reshape(tl.sum(v_2, 1), [_BLOCK_SIZE_0, 1]), tl.float32)
+    v_3 = tl_math.log(sum_1)
+    v_4 = v_1 - v_3
+    # src[fused_linear_jsd.py:N]: teacher_prob = torch.log_softmax(teacher_logits[batch, :] / temperature, dim=-1)
+    load_1 = tl.load(teacher_logits + (indices_0[:, None] * 256 + indices_1[None, :] * 1), None)
+    v_5 = load_1 / temperature
+    amax_1 = tl.cast(tl.reshape(tl.max(v_5, 1), [_BLOCK_SIZE_0, 1]), tl.float32)
+    v_6 = v_5 - amax_1
+    v_7 = libdevice.exp(v_6)
+    sum_2 = tl.cast(tl.reshape(tl.sum(v_7, 1), [_BLOCK_SIZE_0, 1]), tl.float32)
+    v_8 = tl_math.log(sum_2)
+    v_9 = v_6 - v_8
+    # src[fused_linear_jsd.py:N]: student_prob = student_prob.to(torch.float).view(-1, student_prob.size(-1))
+    student_prob_1 = tl.reshape(v_4, [_BLOCK_SIZE_0, _RDIM_SIZE_1])
+    # src[fused_linear_jsd.py:N]: teacher_prob = teacher_prob.to(torch.float).view(-1, teacher_prob.size(-1))
+    teacher_prob_1 = tl.reshape(v_9, [_BLOCK_SIZE_0, _RDIM_SIZE_1])
+    # src[fused_linear_jsd.py:N]: m = torch.exp(student_prob) + beta * (
+    v_10 = libdevice.exp(student_prob_1)
+    # src[fused_linear_jsd.py:N]: torch.exp(teacher_prob) - torch.exp(student_prob)
+    v_11 = libdevice.exp(teacher_prob_1)
+    v_12 = libdevice.exp(student_prob_1)
+    v_13 = v_11 - v_12
+    # src[fused_linear_jsd.py:N]: m = torch.exp(student_prob) + beta * (
+    # src[fused_linear_jsd.py:N]: torch.exp(teacher_prob) - torch.exp(student_prob)
+    # src[fused_linear_jsd.py:N]: )
+    v_14 = v_13 * beta
+    v_15 = v_10 + v_14
+    # src[fused_linear_jsd.py:N]: torch.log(m), teacher_prob, reduction="none", log_target=True
+    v_16 = tl_math.log(v_15)
+    # src[fused_linear_jsd.py:N]: teacher_div = torch.nn.functional.kl_div(
+    # src[fused_linear_jsd.py:N]: torch.log(m), teacher_prob, reduction="none", log_target=True
+    # src[fused_linear_jsd.py:N]: ).sum(dim=-1)
+    v_17 = libdevice.exp(teacher_prob_1)
+    v_18 = teacher_prob_1 - v_16
+    v_19 = v_17 * v_18
+    teacher_div = tl.cast(tl.sum(v_19, 1), tl.float32)
+    # src[fused_linear_jsd.py:N]: torch.log(m), student_prob, reduction="none", log_target=True
+    v_20 = tl_math.log(v_15)
+    # src[fused_linear_jsd.py:N]: student_div = torch.nn.functional.kl_div(
+    # src[fused_linear_jsd.py:N]: torch.log(m), student_prob, reduction="none", log_target=True
+    # src[fused_linear_jsd.py:N]: ).sum(dim=-1)
+    v_21 = libdevice.exp(student_prob_1)
+    v_22 = student_prob_1 - v_20
+    v_23 = v_21 * v_22
+    student_div = tl.cast(tl.sum(v_23, 1), tl.float32)
+    # src[fused_linear_jsd.py:N]: batch_loss = student_div + beta * (teacher_div - student_div)
+    v_24 = teacher_div - student_div
+    v_25 = v_24 * beta
+    v_26 = student_div + v_25
+    # src[fused_linear_jsd.py:N]: loss[batch] = batch_loss
+    tl.store(loss + indices_0 * 1, v_26, None)
+
+def fused_linear_jsd_kernel(beta: float, ignore_index: int, temperature: float, student_logits: torch.Tensor, teacher_logits: torch.Tensor, *, _launcher=_default_launcher):
+    # src[fused_linear_jsd.py:N]: loss = student_logits.new_empty(student_logits.shape[0], dtype=torch.float)
+    loss = student_logits.new_empty(student_logits.shape[0], dtype=torch.float)
+    # src[fused_linear_jsd.py:N]: for batch in hl.tile(student_logits.shape[0]):
+    _BLOCK_SIZE_0 = 32
+    _RDIM_SIZE_1 = 256
+    # src[fused_linear_jsd.py:N]: for batch in hl.tile(student_logits.shape[0]):
+    # src[fused_linear_jsd.py:N]: student_prob = torch.log_softmax(student_logits[batch, :] / temperature, dim=-1)
+    # src[fused_linear_jsd.py:N]: teacher_prob = torch.log_softmax(teacher_logits[batch, :] / temperature, dim=-1)
+    # src[fused_linear_jsd.py:N-N]: ...
+    _launcher(_helion_fused_linear_jsd_kernel, (triton.cdiv(64, _BLOCK_SIZE_0),), student_logits, teacher_logits, loss, temperature, beta, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1)
+    # src[fused_linear_jsd.py:N]: return (loss / student_logits.shape[0]).sum()
+    return (loss / student_logits.shape[0]).sum()
diff --git a/test/test_linear_jsd.py b/test/test_linear_jsd.py
new file mode 100644
index 000000000..0f73d833f
--- /dev/null
+++ b/test/test_linear_jsd.py
@@ -0,0 +1,59 @@
+from __future__ import annotations
+
+import torch
+
+from helion._testing import DEVICE
+from helion._testing import EXAMPLES_DIR
+from helion._testing import RefEagerTestBase
+from helion._testing import TestCase
+from helion._testing import check_example
+from helion._testing import import_path
+from helion._testing import skipIfCpu
+
+torch.backends.cuda.matmul.fp32_precision = "tf32"
+torch.backends.cudnn.conv.fp32_precision = "tf32"
+
+
+@skipIfCpu("needs to be debugged")
+class TestLinearJSD(RefEagerTestBase, TestCase):
+    def test_fused_linear_jsd(self):
+        beta = 0.5
+        ignore_index = 1
+        temperature = 1.0
+        m, n, k = 64, 128, 256
+
+        student_input = torch.randn([m, n], device=DEVICE, dtype=torch.float32)
+        teacher_input = torch.randn([m, n], device=DEVICE, dtype=torch.float32)
+        student_weight = torch.randn([k, n], device=DEVICE, dtype=torch.float32)
+        teacher_weight = torch.randn([k, n], device=DEVICE, dtype=torch.float32)
+        student_logits = student_input @ student_weight.T
+        teacher_logits = teacher_input @ teacher_weight.T
+
+        args = (
+            beta,
+            ignore_index,
+            temperature,
+            student_logits,
+            teacher_logits,
+        )
+
+        mod = import_path(EXAMPLES_DIR / "fused_linear_jsd.py")
+        expected = mod.fused_linear_jsd_pytorch(
+            *args[:-2], student_input, teacher_input, student_weight, teacher_weight
+        )
+
+        self.assertExpectedJournal(
+            check_example(
+                "fused_linear_jsd",
+                args,
+                expected,
+                fn_name="fused_linear_jsd_kernel",
+                block_sizes=[32],
+            )
+        )
+
+
+if __name__ == "__main__":
+    import unittest
+
+    unittest.main()