diff --git a/include/infinicore/ops.hpp b/include/infinicore/ops.hpp
index a7249ec9d..6ef03f0cf 100644
--- a/include/infinicore/ops.hpp
+++ b/include/infinicore/ops.hpp
@@ -15,3 +15,4 @@
 #include "ops/rope.hpp"
 #include "ops/silu.hpp"
 #include "ops/swiglu.hpp"
+#include "ops/mul.hpp"
diff --git a/python/infinicore/fusion/__init__.py b/python/infinicore/fusion/__init__.py
new file mode 100644
index 000000000..0432caff3
--- /dev/null
+++ b/python/infinicore/fusion/__init__.py
@@ -0,0 +1,31 @@
+"""
+InfiniCore Fusion Module - runtime operator-fusion scheduler
+
+Provides automatic operator fusion built on ninetoothed/ntops, supporting:
+- subgraph description and caching
+- static heuristic fusion decisions
+- runtime configuration switches
+"""
+
+from infinicore.fusion.subgraph import OpNode, SubGraph
+from infinicore.fusion.fusion_config import FusionConfig
+from infinicore.fusion.heuristics import FusionHeuristics
+from infinicore.fusion.fusion_scheduler import FusionScheduler
+from infinicore.fusion.graph_converter import (
+    convert_graph_to_subgraph,
+    match_fusion_pattern,
+    find_fusable_subgraphs,
+    GraphOpInfo,
+)
+
+__all__ = [
+    "OpNode",
+    "SubGraph",
+    "FusionConfig",
+    "FusionHeuristics",
+    "FusionScheduler",
+    "convert_graph_to_subgraph",
+    "match_fusion_pattern",
+    "find_fusable_subgraphs",
+    "GraphOpInfo",
+]
diff --git a/python/infinicore/fusion/fusion_config.py b/python/infinicore/fusion/fusion_config.py
new file mode 100644
index 000000000..86cbdd4cf
--- /dev/null
+++ b/python/infinicore/fusion/fusion_config.py
@@ -0,0 +1,47 @@
+"""
+Fusion scheduling configuration module.
+"""
+
+from dataclasses import dataclass
+
+
+@dataclass
+class FusionConfig:
+    """
+    Fusion scheduling configuration.
+
+    Runtime parameters controlling fusion behavior; supports dynamic
+    toggling and tuning.
+
+    Attributes:
+        enable_fusion: master switch; when False, every subgraph falls back to standard execution
+        enable_cache: kernel-cache switch; when off, every call recompiles
+        max_graph_size: maximum subgraph node count; larger graphs are not considered for fusion
+        fallback_on_error: whether to fall back to standard execution when fusion fails
+        debug_mode: debug mode; prints fusion-decision information when enabled
+        min_tensor_elements: V1 heuristic - minimum tensor element count threshold
+        min_nodes_for_fusion: V1 heuristic - minimum node count before fusion is attempted
+
+    Example:
+        >>> config = FusionConfig(enable_fusion=True, debug_mode=True)
+        >>> scheduler = FusionScheduler(config)
+    """
+    # Core switches
+    enable_fusion: bool = True
+    enable_cache: bool = True
+    fallback_on_error: bool = True
+    debug_mode: bool = False
+
+    # Graph-size limit
+    max_graph_size: int = 10
+
+    # V1 static heuristic parameters
+    min_tensor_elements: int = 1024
+    min_nodes_for_fusion: int = 2
+
+    def __repr__(self) -> str:
+        return (
+            f"FusionConfig("
+            f"enable_fusion={self.enable_fusion}, "
+            f"min_elements={self.min_tensor_elements}, "
+            f"min_nodes={self.min_nodes_for_fusion})"
+        )
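Note: since `FusionConfig` is a plain mutable dataclass and the scheduler reads it on every `dispatch()` call, fusion can be toggled at runtime without rebuilding the scheduler. A minimal sketch using only names introduced in this patch:

```python
from infinicore.fusion import FusionConfig, FusionScheduler

config = FusionConfig(enable_fusion=True, debug_mode=True)
scheduler = FusionScheduler(config)

# ... serve some requests ...

config.enable_fusion = False  # subsequent dispatches take the fallback path
config.debug_mode = False
```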
diff --git a/python/infinicore/fusion/fusion_scheduler.py b/python/infinicore/fusion/fusion_scheduler.py
new file mode 100644
index 000000000..fa209b2c8
--- /dev/null
+++ b/python/infinicore/fusion/fusion_scheduler.py
@@ -0,0 +1,295 @@
+"""
+Fusion scheduler module - the runtime dispatch core.
+
+Receives a subgraph description and dynamically picks an execution path
+based on the configuration:
+1. Fused path: invoke the fused kernel compiled by ninetoothed
+2. Fallback path: invoke standard InfiniCore operators one by one
+"""
+
+from typing import Dict, Tuple, Optional, Any
+
+from infinicore.fusion.subgraph import SubGraph, OpNode
+from infinicore.fusion.fusion_config import FusionConfig
+from infinicore.fusion.heuristics import FusionHeuristics
+from infinicore.fusion.kernel_compiler import KernelCompiler, CompiledKernel, FusionError
+
+
+class FusionScheduler:
+    """
+    Runtime fusion scheduler.
+
+    Core responsibilities:
+    1. Receive a subgraph and its input tensors
+    2. Decide whether to fuse based on the heuristic rules
+    3. Manage the cache of compiled kernels
+    4. Provide fallback to standard execution
+
+    Example:
+        >>> config = FusionConfig(enable_fusion=True, debug_mode=True)
+        >>> scheduler = FusionScheduler(config)
+        >>>
+        >>> graph = SubGraph(
+        ...     nodes=(OpNode("silu", ("x",), ("y1",)), OpNode("mul", ("y1", "x"), ("y2",))),
+        ...     input_names=("x",),
+        ...     output_names=("y2",),
+        ... )
+        >>>
+        >>> outputs = scheduler.dispatch(graph, {"x": tensor_x})
+    """
+
+    def __init__(self, config: Optional[FusionConfig] = None):
+        self.config = config or FusionConfig()
+        self._kernel_cache: Dict[str, CompiledKernel] = {}
+        self._heuristics = FusionHeuristics(self.config)
+        self._compiler = KernelCompiler(self.config)
+        self._op_registry: Dict[str, callable] = {}
+        self._init_op_registry()
+
+    def _init_op_registry(self):
+        """Initialize the operator registry (used for fallback execution)."""
+        # Initialize with empty registry first
+        self._op_registry = {}
+
+        # 1. Try to register functional ops (silu, gelu, etc.)
+        try:
+            import infinicore.nn.functional as F
+            self._op_registry.update({
+                "silu": F.silu,
+                "gelu": F.gelu,
+                "relu": F.relu,
+            })
+            if hasattr(F, 'rms_norm'):
+                self._op_registry["rms_norm"] = F.rms_norm
+        except (ImportError, AttributeError):
+            # Fallback to torch.nn.functional for functional ops
+            try:
+                import torch
+                import torch.nn.functional as TorchF
+                self._op_registry.setdefault("silu", TorchF.silu)
+                self._op_registry.setdefault("gelu", TorchF.gelu)
+                self._op_registry.setdefault("relu", TorchF.relu)
+
+                # Create a compatible rms_norm wrapper
+                # torch.rms_norm(input, normalized_shape, weight=None, eps=1e-5)
+                # Our graph passes (input, weight), so we adapt the signature
+                def _torch_rms_norm_wrapper(input_tensor, weight, eps=1e-5):
+                    # Infer normalized_shape from weight shape
+                    normalized_shape = weight.shape
+                    return TorchF.rms_norm(input_tensor, normalized_shape, weight, eps)
+
+                self._op_registry.setdefault("rms_norm", _torch_rms_norm_wrapper)
+            except ImportError:
+                pass
+
+        # 2. Try to register core ops (add, mul, etc.)
+        try:
+            import infinicore
+            self._op_registry.update({
+                "add": infinicore.add,
+                "mul": infinicore.mul,
+            })
+        except (ImportError, AttributeError):
+            # Fallback to torch for development/testing if infinicore is missing
+            try:
+                import torch
+                self._op_registry.setdefault("add", torch.add)
+                self._op_registry.setdefault("mul", torch.mul)
+            except ImportError:
+                pass
+
+        if self.config.debug_mode and not self._op_registry:
+            print("[FusionScheduler] Warning: No operators registered for fallback execution")
+
+    def dispatch(
+        self,
+        graph: SubGraph,
+        inputs: Dict[str, Any]
+    ) -> Dict[str, Any]:
+        """
+        Dispatch a subgraph for execution.
+
+        Args:
+            graph: subgraph description (operator sequence + data dependencies)
+            inputs: dict of input tensors, keyed by tensor name
+
+        Returns:
+            outputs: dict of output tensors
+        """
+        # Extract input shape and dtype information
+        input_shapes = self._get_input_shapes(inputs)
+        input_dtypes = self._get_input_dtypes(inputs)
+
+        # Check whether fusion should be attempted
+        if not self._heuristics.should_fuse(graph, input_shapes):
+            if self.config.debug_mode:
+                print(f"[FusionScheduler] Skipping fusion for {graph.cache_key(input_dtypes, input_shapes)}")
+            return self._fallback_execute(graph, inputs)
+
+        # Check the cache
+        cache_key = graph.cache_key(input_dtypes, input_shapes)
+
+        if self.config.enable_cache and cache_key in self._kernel_cache:
+            if self.config.debug_mode:
+                print(f"[FusionScheduler] Cache hit: {cache_key}")
+            compiled_kernel = self._kernel_cache[cache_key]
+            return self._execute_fused(compiled_kernel, inputs, graph)
+
+        # Try to compile a fused kernel
+        try:
+            compiled_kernel = self._compiler.compile(graph, input_dtypes, input_shapes)
+
+            if self.config.enable_cache:
+                self._kernel_cache[cache_key] = compiled_kernel
+
+            if self.config.debug_mode:
+                print(f"[FusionScheduler] Compilation success: {cache_key}")
+
+            return self._execute_fused(compiled_kernel, inputs, graph)
+
+        except FusionError as e:
+            if self.config.debug_mode:
+                print(f"[FusionScheduler] Fusion failed: {e}")
+
+            if self.config.fallback_on_error:
+                return self._fallback_execute(graph, inputs)
+            else:
+                raise
+
+    def _execute_fused(
+        self,
+        compiled_kernel: CompiledKernel,
+        inputs: Dict[str, Any],
+        graph: SubGraph
+    ) -> Dict[str, Any]:
+        """
+        Execute a fused kernel.
+
+        A ninetoothed fused kernel expects **every tensor of every original
+        kernel** as an argument, in the same order the Nodes were built at
+        compile time (without deduplication).
+
+        For example, SwiGLU (silu + mul):
+        - silu: (gate, gate_activated)
+        - mul: (gate_activated, up, output)
+        - the fused kernel expects: (gate, gate_activated, gate_activated, up, output) = 5 arguments
+        """
+        import torch
+
+        # Reference tensor used when allocating new tensors
+        ref_tensor = next(iter(inputs.values()))
+
+        # First collect all unique tensor names for preallocation
+        unique_names = set()
+        for node in graph.nodes:
+            unique_names.update(node.inputs)
+            unique_names.update(node.outputs)
+
+        # Build the tensor dict: inputs already exist, the rest must be allocated.
+        # Note: intermediates are assumed to share the reference tensor's
+        # shape/dtype, which holds for the elementwise ops supported in V1.
+        tensor_dict = dict(inputs)
+        for name in unique_names:
+            if name not in tensor_dict:
+                tensor_dict[name] = torch.empty_like(ref_tensor)
+
+        # Build the argument list in compile-time order (no dedup; identical
+        # names share one tensor object)
+        all_tensor_args = []
+        for node in graph.nodes:
+            for tensor_name in list(node.inputs) + list(node.outputs):
+                all_tensor_args.append(tensor_name)
+
+        # Substitute tensor objects for names
+        ordered_args = [tensor_dict[name] for name in all_tensor_args]
+
+        if self.config.debug_mode:
+            print(f"[FusionScheduler] Executing fused kernel with {len(ordered_args)} args: {all_tensor_args}")
+
+        # Invoke the fused kernel
+        compiled_kernel(*ordered_args)
+
+        # Return the output tensors
+        return {name: tensor_dict[name] for name in graph.output_names}
+
+    def _fallback_execute(
+        self,
+        graph: SubGraph,
+        inputs: Dict[str, Any]
+    ) -> Dict[str, Any]:
+        """
+        Fallback execution: invoke standard operators one by one.
+
+        Executes each node in topological order; intermediate results are
+        stored in the values dict.
+        """
+        if self.config.debug_mode:
+            print(f"[FusionScheduler] Fallback execution for graph with {len(graph.nodes)} nodes")
+
+        # Initialize the value dict
+        values: Dict[str, Any] = dict(inputs)
+
+        # Execute in topological order
+        for node in graph.nodes:
+            op_func = self._op_registry.get(node.op_type)
+
+            if op_func is None:
+                raise RuntimeError(f"Operator '{node.op_type}' not registered for fallback")
+
+            # Collect inputs
+            node_inputs = [values[name] for name in node.inputs]
+
+            # Resolve attributes
+            kwargs = {}
+            if node.attrs:
+                kwargs = dict(node.attrs)
+
+            # Execute the operator
+            result = op_func(*node_inputs, **kwargs)
+
+            # Store outputs
+            if len(node.outputs) == 1:
+                values[node.outputs[0]] = result
+            else:
+                for i, out_name in enumerate(node.outputs):
+                    values[out_name] = result[i]
+
+        # Return the final outputs
+        return {name: values[name] for name in graph.output_names}
+
+    def _get_input_shapes(self, inputs: Dict[str, Any]) -> Dict[str, Tuple[int, ...]]:
+        """Extract the shapes of the input tensors."""
+        shapes = {}
+        for name, tensor in inputs.items():
+            if hasattr(tensor, 'shape'):
+                shapes[name] = tuple(tensor.shape)
+            else:
+                shapes[name] = ()
+        return shapes
+
+    def _get_input_dtypes(self, inputs: Dict[str, Any]) -> Dict[str, str]:
+        """Extract the dtypes of the input tensors."""
+        dtypes = {}
+        for name, tensor in inputs.items():
+            if hasattr(tensor, 'dtype'):
+                dtypes[name] = str(tensor.dtype)
+            else:
+                dtypes[name] = "unknown"
+        return dtypes
+
+    def clear_cache(self):
+        """Clear the kernel cache."""
+        self._kernel_cache.clear()
+        if self.config.debug_mode:
+            print("[FusionScheduler] Cache cleared")
+
+    def get_cache_stats(self) -> Dict[str, Any]:
+        """Return cache statistics."""
+        return {
+            "size": len(self._kernel_cache),
+            "keys": list(self._kernel_cache.keys()),
+        }
+
+    def register_op(self, op_type: str, op_func: callable):
+        """Register a custom operator for fallback execution."""
+        self._op_registry[op_type] = op_func
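Note: the fallback path only needs callables in the op registry, so the scheduler can be exercised without ntops/ninetoothed installed. A minimal sketch, assuming torch stands in for InfiniCore tensors and using a hypothetical custom "square" op registered via `register_op`:

```python
import torch
from infinicore.fusion import FusionConfig, FusionScheduler, OpNode, SubGraph

scheduler = FusionScheduler(FusionConfig(enable_fusion=False))  # force fallback
scheduler.register_op("square", lambda t: t * t)                # hypothetical custom op

graph = SubGraph(
    nodes=(OpNode("silu", ("x",), ("h",)), OpNode("square", ("h",), ("y",))),
    input_names=("x",),
    output_names=("y",),
)
out = scheduler.dispatch(graph, {"x": torch.randn(4, 4096)})["y"]
```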
diff --git a/python/infinicore/fusion/graph_converter.py b/python/infinicore/fusion/graph_converter.py
new file mode 100644
index 000000000..f46f6e6c6
--- /dev/null
+++ b/python/infinicore/fusion/graph_converter.py
@@ -0,0 +1,256 @@
+"""
+Graph converter module - converts a Graph recorded by InfiniCore into a
+SubGraph the FusionScheduler can process.
+
+This module is the key bridge for InfiniLM integration:
+1. During InfiniLM inference, start/stop_graph_recording() captures operator calls
+2. This module converts the recorded Graph into a SubGraph
+3. The FusionScheduler analyzes the SubGraph and decides whether to fuse
+"""
+
+from typing import Dict, List, Tuple, Optional, Any
+from dataclasses import dataclass
+
+from infinicore.fusion.subgraph import SubGraph, OpNode
+
+
+@dataclass
+class GraphOpInfo:
+    """Information about an operator in a recorded graph."""
+    op_type: str
+    input_names: Tuple[str, ...]
+    output_names: Tuple[str, ...]
+    attrs: Optional[Dict[str, Any]] = None
+
+
+def convert_graph_to_subgraph(graph) -> Optional[SubGraph]:
+    """
+    Convert a Graph recorded by InfiniCore into a SubGraph.
+
+    Uses the C++ Graph.operators() interface to extract operator information
+    directly.
+
+    Args:
+        graph: an infinicore.Graph object obtained from stop_graph_recording()
+
+    Returns:
+        SubGraph: a subgraph description the FusionScheduler can process
+        None: if the graph is empty or cannot be converted
+    """
+    if graph is None:
+        return None
+
+    # New interface: Graph.operators() returns a list of GraphOperator
+    if not hasattr(graph, 'operators'):
+        # Older InfiniCore versions: use the fallback logic
+        return _convert_graph_legacy(graph)
+
+    try:
+        operators = graph.operators()
+    except Exception:
+        return None
+
+    if not operators:
+        return None
+
+    nodes = []
+    all_tensor_names = set()
+
+    for i, op in enumerate(operators):
+        # Operator type (e.g. "Gemm" -> "gemm")
+        op_type = op.op_type.lower() if op.op_type else f"op_{i}"
+
+        # Extract inputs/outputs from tensor_metas
+        inputs = []
+        outputs = []
+
+        for j, meta in enumerate(op.tensor_metas):
+            name = f"t_{i}_{j}"
+            all_tensor_names.add(name)
+            if meta.is_input:
+                inputs.append(name)
+            else:
+                outputs.append(name)
+
+        # If no tensors were captured, use placeholders
+        if not inputs and not outputs:
+            inputs = [f"input_{i}"]
+            outputs = [f"output_{i}"]
+
+        nodes.append(OpNode(
+            op_type=op_type,
+            inputs=tuple(inputs),
+            outputs=tuple(outputs),
+        ))
+
+    # Infer the subgraph's external inputs/outputs
+    if nodes:
+        graph_inputs = nodes[0].inputs
+        graph_outputs = nodes[-1].outputs
+    else:
+        graph_inputs = ()
+        graph_outputs = ()
+
+    return SubGraph(
+        nodes=tuple(nodes),
+        input_names=graph_inputs,
+        output_names=graph_outputs,
+    )
+
+
+def _convert_graph_legacy(graph) -> Optional[SubGraph]:
+    """Fallback: conversion logic for older Graph objects."""
+    underlying = getattr(graph, '_graph', None)
+    if underlying is None:
+        return None
+
+    ops_info = _extract_ops_from_graph(underlying)
+    if not ops_info:
+        return None
+
+    return _build_subgraph_from_ops(ops_info)
+
+
+def _extract_ops_from_graph(underlying_graph) -> List[GraphOpInfo]:
+    """
+    Extract operator information from the underlying Graph object.
+
+    Args:
+        underlying_graph: the C++ Graph object (_infinicore.Graph)
+
+    Returns:
+        List of operator info
+
+    Note:
+        This is currently a placeholder implementation. A full implementation
+        requires:
+        1. adding a nodes() or get_operations() method to the C++ Graph class
+        2. exposing it to Python via pybind11
+    """
+    ops_info = []
+
+    # Check for get_nodes or a similar method
+    if hasattr(underlying_graph, 'get_nodes'):
+        for node in underlying_graph.get_nodes():
+            op_info = GraphOpInfo(
+                op_type=node.op_type,
+                input_names=tuple(node.inputs),
+                output_names=tuple(node.outputs),
+                attrs=dict(node.attrs) if hasattr(node, 'attrs') else None
+            )
+            ops_info.append(op_info)
+    elif hasattr(underlying_graph, 'nodes'):
+        # Another possible interface
+        for node in underlying_graph.nodes:
+            op_info = GraphOpInfo(
+                op_type=str(node.op_type),
+                input_names=tuple(str(i) for i in node.inputs),
+                output_names=tuple(str(o) for o in node.outputs),
+            )
+            ops_info.append(op_info)
+    else:
+        # The Graph does not expose node information
+        # TODO: extend the C++ pybind11 interface
+        pass
+
+    return ops_info
+
+
+def _build_subgraph_from_ops(ops_info: List[GraphOpInfo]) -> Optional[SubGraph]:
+    """
+    Build a SubGraph from a list of operator info.
+
+    Args:
+        ops_info: list of operator info
+
+    Returns:
+        a SubGraph object
+    """
+    if not ops_info:
+        return None
+
+    # Build the OpNode list
+    nodes = []
+    for op in ops_info:
+        node = OpNode(
+            op_type=op.op_type,
+            inputs=op.input_names,
+            outputs=op.output_names,
+            attrs=tuple(op.attrs.items()) if op.attrs else None
+        )
+        nodes.append(node)
+
+    # Infer input and output names
+    all_inputs = set()
+    all_outputs = set()
+
+    for op in ops_info:
+        all_inputs.update(op.input_names)
+        all_outputs.update(op.output_names)
+
+    # Real graph inputs = inputs not produced by any operator
+    graph_inputs = tuple(sorted(all_inputs - all_outputs))
+    # Real graph outputs = outputs of the last operator
+    graph_outputs = ops_info[-1].output_names if ops_info else ()
+
+    return SubGraph(
+        nodes=tuple(nodes),
+        input_names=graph_inputs,
+        output_names=graph_outputs,
+    )
+
+
+# ============================================================
+# Pattern-matching helpers
+# ============================================================
+
+def match_fusion_pattern(graph: SubGraph, pattern: SubGraph) -> bool:
+    """
+    Check whether a subgraph matches a given fusion pattern.
+
+    Args:
+        graph: subgraph to match
+        pattern: fusion pattern template (e.g. the SwiGLU pattern)
+
+    Returns:
+        True if it matches, otherwise False
+    """
+    if len(graph.nodes) != len(pattern.nodes):
+        return False
+
+    for g_node, p_node in zip(graph.nodes, pattern.nodes):
+        if g_node.op_type != p_node.op_type:
+            return False
+
+    return True
+
+
+def find_fusable_subgraphs(
+    graph: SubGraph,
+    patterns: List[SubGraph]
+) -> List[Tuple[int, int, SubGraph]]:
+    """
+    Find all fusable subgraphs within a graph.
+
+    Args:
+        graph: the full computation graph
+        patterns: list of fusion patterns
+
+    Returns:
+        A list of (start index, end index, matched pattern) tuples
+    """
+    results = []
+
+    for pattern in patterns:
+        pattern_len = len(pattern.nodes)
+
+        for start_idx in range(len(graph.nodes) - pattern_len + 1):
+            # Extract the candidate slice
+            sub_nodes = graph.nodes[start_idx:start_idx + pattern_len]
+            sub_graph = SubGraph(
+                nodes=sub_nodes,
+                input_names=graph.input_names,  # simplification
+                output_names=graph.output_names,
+            )
+
+            if match_fusion_pattern(sub_graph, pattern):
+                results.append((start_idx, start_idx + pattern_len, pattern))
+
+    return results
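Note: matching compares only node counts and op_type sequences, so a sliding-window scan over a recorded graph is enough to locate candidates. A small sketch against the predefined patterns (graph contents hypothetical):

```python
from infinicore.fusion import OpNode, SubGraph, find_fusable_subgraphs
from infinicore.fusion.patterns import LLM_FUSION_PATTERNS

graph = SubGraph(
    nodes=(
        OpNode("gemm", ("x", "w"), ("gate",)),
        OpNode("silu", ("gate",), ("gate_act",)),
        OpNode("mul", ("gate_act", "up"), ("ffn_out",)),
    ),
    input_names=("x", "w", "up"),
    output_names=("ffn_out",),
)

for start, end, pattern in find_fusable_subgraphs(graph, LLM_FUSION_PATTERNS):
    print(start, end, [n.op_type for n in pattern.nodes])  # 1 3 ['silu', 'mul']
```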
diff --git a/python/infinicore/fusion/heuristics.py b/python/infinicore/fusion/heuristics.py
new file mode 100644
index 000000000..732d82680
--- /dev/null
+++ b/python/infinicore/fusion/heuristics.py
@@ -0,0 +1,296 @@
+import json
+import os
+from typing import Dict, Tuple, Set, Optional, Any
+
+from infinicore.fusion.subgraph import SubGraph
+from infinicore.fusion.fusion_config import FusionConfig
+
+
+def _detect_hardware_environment() -> str:
+    """
+    Detect the current hardware environment.
+
+    Returns:
+        "muxi" | "tianshu" | "default"
+    """
+    try:
+        from infinicore.lib import _infinicore
+
+        all_device_types = tuple(_infinicore.Device.Type.__members__.values())[:-1]
+        all_device_count = tuple(_infinicore.get_device_count(dt) for dt in all_device_types)
+
+        for device_type, count in zip(all_device_types, all_device_count):
+            if count > 0:
+                if device_type == _infinicore.Device.Type.MOORE:
+                    return "muxi"
+                elif device_type == _infinicore.Device.Type.METAX:
+                    return "tianshu"
+        return "default"
+    except Exception:
+        return "default"
+
+
+def _get_supported_ops() -> Set[str]:
+    """Return the set of fusable operators, kept in sync with kernel_compiler."""
+    fallback_ops = {
+        "silu", "gelu", "relu", "sigmoid",
+        "add", "mul", "sub", "div",
+        "rms_norm", "layer_norm",
+    }
+
+    try:
+        from infinicore.fusion.kernel_compiler import get_supported_fusion_ops
+        ops = get_supported_fusion_ops()
+        # If kernel_compiler returns an empty set (ntops unavailable), use the fallback
+        return ops if ops else fallback_ops
+    except ImportError:
+        return fallback_ops
+
+
+# V1 set of fusable operator types (populated lazily)
+SUPPORTED_OPS: Set[str] = set()
+
+
+class FusionHeuristics:
+    """
+    Static heuristic rules - decide whether fusion is worthwhile.
+
+    The V1 implementation filters on simple rules:
+    1. node-count check
+    2. tensor-size check
+    3. operator-type check
+    4. profile decision (total unfused time vs. fused time)
+
+    If the profile is missing or malformed, an error is printed and False is
+    returned.
+    """
+
+    def __init__(self, config: FusionConfig, profile_path: Optional[str] = None):
+        self.config = config
+        self._supported_ops: Optional[Set[str]] = None
+        self._profile_cache: Optional[Dict[str, Any]] = None
+        self._profile_path_cached: Optional[str] = None
+
+        # An explicitly supplied path overrides auto-detection
+        if profile_path is not None:
+            self.profile_path = profile_path
+            return
+
+        # Auto-detect the hardware environment and build the profile path
+        env = _detect_hardware_environment()
+        profile_dir = os.path.join(os.path.dirname(__file__), "profile_result")
+
+        if env == "muxi":
+            self.profile_path = os.path.join(profile_dir, "profile_result_muxi.json")
+        elif env == "tianshu":
+            self.profile_path = os.path.join(profile_dir, "profile_result_tianshu.json")
+        else:
+            self.profile_path = os.path.join(profile_dir, "default.json")
+
+    def _get_ops(self) -> Set[str]:
+        """Return the supported operator set (cached)."""
+        if self._supported_ops is None:
+            self._supported_ops = _get_supported_ops()
+        return self._supported_ops
+
+    # ---------------- profile helpers ----------------
+
+    def _load_profile(self) -> Dict[str, Any]:
+        """
+        Load profile data (cached).
+
+        Returns:
+            {"unfused": {...}, "fused": {...}}
+        """
+        # cache hit
+        if self._profile_cache is not None and self._profile_path_cached == self.profile_path:
+            return self._profile_cache
+
+        if not os.path.exists(self.profile_path):
+            raise FileNotFoundError(f"Profile not found: {self.profile_path}")
+
+        with open(self.profile_path, "r", encoding="utf-8") as f:
+            data = json.load(f)
+
+        if "unfused" not in data or "fused" not in data:
+            raise ValueError("Profile must contain 'unfused' and 'fused' keys")
+
+        self._profile_cache = data
+        self._profile_path_cached = self.profile_path
+        return data
+
+    def _shape_to_key(self, shape: Tuple[int, ...]) -> str:
+        """tuple -> '[1, 512, 4096]'"""
+        return "[" + ", ".join(str(int(x)) for x in shape) + "]"
+
+    def _pick_profile_shape_key(
+        self,
+        graph: SubGraph,
+        input_shapes: Dict[str, Tuple[int, ...]]
+    ) -> Optional[str]:
+        """
+        Pick a representative shape key for the profile lookup.
+        Prefer graph.input_names so we do not pick a weight tensor.
+        """
+        for name in graph.input_names:
+            shape = input_shapes.get(name)
+            if shape:
+                return self._shape_to_key(shape)
+
+        for shape in input_shapes.values():
+            if shape:
+                return self._shape_to_key(shape)
+
+        return None
+
+    def _fused_op_key(self, graph: SubGraph) -> str:
+        """Build the fused-op key, e.g. 'add+rms_norm'."""
+        return "+".join(node.op_type for node in graph.nodes)
+
+    def _parse_shape_key(self, shape_key: str) -> Tuple[int, ...]:
+        """'[1, 512, 4096]' -> (1, 512, 4096)"""
+        return tuple(int(x.strip()) for x in shape_key.strip("[]").split(","))
+
+    def _lookup_nearest_shape(
+        self,
+        shape_key: str,
+        shape_map: Dict[str, Any],
+    ) -> Optional[Any]:
+        """
+        Find the closest shape bucket.
+        Rule: rank, B and H must match exactly; the token (S) dimension is
+        matched as a lower bound.
+        """
+        target = self._parse_shape_key(shape_key)
+        best_key = None
+        best_s = None
+
+        for k in shape_map.keys():
+            try:
+                cand = self._parse_shape_key(k)
+            except Exception:
+                continue
+
+            if len(cand) != len(target):
+                continue
+
+            # B and H must match exactly
+            if cand[0] != target[0] or cand[-1] != target[-1]:
+                continue
+
+            s = cand[1]
+            if s <= target[1]:
+                if best_s is None or s > best_s:
+                    best_s = s
+                    best_key = k
+
+        # If every bucket's S exceeds the target S, take the smallest bucket
+        if best_key is None:
+            for k in shape_map.keys():
+                try:
+                    cand = self._parse_shape_key(k)
+                except Exception:
+                    continue
+                if len(cand) == len(target) and cand[0] == target[0] and cand[-1] == target[-1]:
+                    if best_s is None or cand[1] < best_s:
+                        best_s = cand[1]
+                        best_key = k
+
+        return shape_map.get(best_key) if best_key else None
+
+    def should_fuse(
+        self,
+        graph: SubGraph,
+        input_shapes: Dict[str, Tuple[int, ...]],
+        margin: float = 0.0,
+    ) -> bool:
+        """
+        Decide whether to fuse (static rules + profile decision).
+
+        Decision rule: unfused_time > fused_time * (1 + margin)
+        """
+        # Master switch
+        if not self.config.enable_fusion:
+            return False
+
+        # Node-count check
+        if len(graph.nodes) < self.config.min_nodes_for_fusion:
+            if self.config.debug_mode:
+                print(f"[Fusion] Skip: node count {len(graph.nodes)} < {self.config.min_nodes_for_fusion}")
+            return False
+
+        # Graph-size cap
+        if len(graph.nodes) > self.config.max_graph_size:
+            if self.config.debug_mode:
+                print(f"[Fusion] Skip: node count {len(graph.nodes)} > {self.config.max_graph_size}")
+            return False
+
+        # Tensor-size check
+        for name, shape in input_shapes.items():
+            num_elements = 1
+            for dim in shape:
+                num_elements *= dim
+            if num_elements < self.config.min_tensor_elements:
+                if self.config.debug_mode:
+                    print(f"[Fusion] Skip: tensor '{name}' elements {num_elements} < {self.config.min_tensor_elements}")
+                return False
+
+        # Operator-type check
+        supported = self._get_ops()
+        for node in graph.nodes:
+            if node.op_type not in supported:
+                if self.config.debug_mode:
+                    print(f"[Fusion] Skip: unsupported op '{node.op_type}'")
+                return False
+
+        # Profile decision
+        shape_key = self._pick_profile_shape_key(graph, input_shapes)
+        if shape_key is None:
+            print("[Fusion][Error] Cannot pick representative shape key")
+            return False
+
+        op_key = self._fused_op_key(graph)
+
+        try:
+            profile = self._load_profile()
+        except Exception as e:
+            print(f"[Fusion][Error] Failed to load profile: {e}")
+            return False
+
+        if self.config.debug_mode:
+            print(f"[Fusion] Using profile: {self.profile_path}")
+
+        try:
+            unfused_map = profile["unfused"][op_key]
+            fused_map = profile["fused"][op_key]
+        except Exception as e:
+            print(f"[Fusion][Error] Invalid profile structure for op '{op_key}': {e}")
+            return False
+
+        t_unfused = unfused_map.get(shape_key)
+        if t_unfused is None:
+            t_unfused = self._lookup_nearest_shape(shape_key, unfused_map)
+
+        t_fused = fused_map.get(shape_key)
+        if t_fused is None:
+            t_fused = self._lookup_nearest_shape(shape_key, fused_map)
+
+        if t_unfused is None or t_fused is None:
+            print(f"[Fusion][Error] Profile missing: op='{op_key}', shape={shape_key}")
+            return False
+
+        try:
+            t_unfused_f = float(t_unfused)
+            t_fused_f = float(t_fused)
+            margin_f = float(margin)
+        except Exception as e:
+            print(f"[Fusion][Error] Invalid profile values: {e}")
+            return False
+
+        decision = t_unfused_f > t_fused_f * (1.0 + margin_f)
+
+        if self.config.debug_mode:
+            print(
+                f"[Fusion] op='{op_key}', shape={shape_key}: "
+                f"unfused={t_unfused_f:.4f} fused={t_fused_f:.4f} margin={margin_f:.3f} => "
+                f"{'FUSE' if decision else 'SKIP'}"
+            )
+
+        return decision
+
+    def get_supported_ops(self) -> Set[str]:
+        """Return a copy of the operator types currently supported for fusion."""
+        return self._get_ops().copy()
diff --git a/python/infinicore/fusion/kernel_compiler.py b/python/infinicore/fusion/kernel_compiler.py
new file mode 100644
index 000000000..96de0452d
--- /dev/null
+++ b/python/infinicore/fusion/kernel_compiler.py
@@ -0,0 +1,302 @@
+"""
+Kernel compiler module - wraps the ninetoothed fusion compilation capability.
+
+Converts the InfiniCore SubGraph representation into a form ninetoothed can
+process, then invokes the fusion module to compile fused operators.
+"""
+
+from typing import Dict, Tuple, Callable, Optional, Any
+
+from infinicore.fusion.subgraph import SubGraph, OpNode
+from infinicore.fusion.fusion_config import FusionConfig
+
+
+class FusionError(Exception):
+    """Raised when fusion compilation fails."""
+    pass
+
+
+class CompiledKernel:
+    """
+    Wrapper around a compiled fused kernel.
+
+    Attributes:
+        kernel: the callable fused kernel function
+        graph: the original subgraph
+        cache_key: the cache key
+    """
+
+    def __init__(
+        self,
+        kernel: Callable,
+        graph: SubGraph,
+        cache_key: str,
+        metadata: Optional[Dict[str, Any]] = None
+    ):
+        self.kernel = kernel
+        self.graph = graph
+        self.cache_key = cache_key
+        self.metadata = metadata or {}
+
+    def __call__(self, *args, **kwargs):
+        return self.kernel(*args, **kwargs)
+
+    def __repr__(self) -> str:
+        return f"CompiledKernel(key={self.cache_key}, nodes={len(self.graph)})"
+
+
+# Operator registry: maps operator names to (premake function, handling kind).
+# Standard elementwise operators only need ndim; special operators need extra
+# arguments.
+_OP_REGISTRY: Dict[str, Dict[str, Any]] = {}
+
+
+def _init_op_registry():
+    """Initialize the operator registry lazily to avoid circular imports."""
+    global _OP_REGISTRY
+    if _OP_REGISTRY:
+        return
+
+    try:
+        import ntops.kernels.silu
+        import ntops.kernels.gelu
+        import ntops.kernels.relu
+        import ntops.kernels.sigmoid
+        import ntops.kernels.mul
+        import ntops.kernels.add
+        import ntops.kernels.sub
+        import ntops.kernels.div
+        import ntops.kernels.rms_norm
+
+        _OP_REGISTRY = {
+            # Elementwise activation functions (input, output)
+            "silu": {
+                "premake": ntops.kernels.silu.premake,
+                "type": "unary",  # single input, single output
+            },
+            "gelu": {
+                "premake": ntops.kernels.gelu.premake,
+                "type": "unary",
+            },
+            "relu": {
+                "premake": ntops.kernels.relu.premake,
+                "type": "unary",
+            },
+            "sigmoid": {
+                "premake": ntops.kernels.sigmoid.premake,
+                "type": "unary",
+            },
+            # Elementwise binary operators (input, other, output)
+            "mul": {
+                "premake": ntops.kernels.mul.premake,
+                "type": "binary",
+            },
+            "add": {
+                "premake": ntops.kernels.add.premake,
+                "type": "binary",
+            },
+            "sub": {
+                "premake": ntops.kernels.sub.premake,
+                "type": "binary",
+            },
+            "div": {
+                "premake": ntops.kernels.div.premake,
+                "type": "binary",
+            },
+            # RMSNorm (input, weight, eps, output, num_normalized_elements)
+            "rms_norm": {
+                "premake": ntops.kernels.rms_norm.premake,
+                "type": "rms_norm",  # special handling
+            },
+        }
+    except ImportError:
+        # If ntops is unavailable, the registry stays empty
+        pass
+
+
+def get_supported_fusion_ops() -> set:
+    """Return the set of operators supported for fusion."""
+    _init_op_registry()
+    return set(_OP_REGISTRY.keys())
+
+
+class KernelCompiler:
+    """
+    Kernel compiler - wraps the ninetoothed fusion compilation capability.
+
+    Responsibilities:
+    1. Convert a SubGraph into a form ninetoothed can process
+    2. Invoke ninetoothed.fusion to fuse the operators
+    3. Return the compiled, callable kernel
+    """
+
+    def __init__(self, config: FusionConfig):
+        self.config = config
+        self._ntops_available = False
+        self._ninetoothed_available = False
+        self._init_backends()
+
+    def _init_backends(self):
+        """Initialize backend dependencies."""
+        try:
+            import ntops
+            import ninetoothed
+            self._ntops = ntops
+            self._ninetoothed = ninetoothed
+            self._ntops_available = True
+            self._ninetoothed_available = True
+            _init_op_registry()
+        except ImportError as e:
+            if self.config.debug_mode:
+                print(f"[KernelCompiler] Backend not available: {e}")
+
+    @property
+    def is_available(self) -> bool:
+        """Check whether the compiler backend is available."""
+        return self._ntops_available and self._ninetoothed_available and bool(_OP_REGISTRY)
+
+    def compile(
+        self,
+        graph: SubGraph,
+        input_dtypes: Dict[str, str],
+        input_shapes: Dict[str, Tuple[int, ...]]
+    ) -> CompiledKernel:
+        """
+        Compile a subgraph into a fused kernel.
+
+        Args:
+            graph: subgraph description
+            input_dtypes: input tensor dtypes
+            input_shapes: input tensor shapes
+
+        Returns:
+            CompiledKernel: the compiled fused kernel
+
+        Raises:
+            FusionError: raised when compilation fails
+        """
+        if not self.is_available:
+            raise FusionError("Backend not available: ntops or ninetoothed not installed")
+
+        cache_key = graph.cache_key(input_dtypes, input_shapes)
+
+        if self.config.debug_mode:
+            print(f"[KernelCompiler] Compiling graph: {graph}")
+            print(f"[KernelCompiler] Cache key: {cache_key}")
+
+        try:
+            # Infer the tensor rank
+            ndim = self._infer_tensor_ndim(input_shapes)
+
+            # Step 1: create a _Handle object for each operator
+            handles = self._create_handles_for_graph(graph, ndim)
+
+            # Step 2: build the ninetoothed Node list
+            nodes = self._build_fusion_nodes(handles, graph)
+
+            # Step 3: run fusion
+            from ninetoothed.fusion import _fuse_nodes
+            fused_nodes = _fuse_nodes(nodes)
+
+            if len(fused_nodes) != 1:
+                raise FusionError(f"Fusion produced {len(fused_nodes)} nodes, expected 1")
+
+            fused_kernel = fused_nodes[0].kernel
+
+            return CompiledKernel(
+                kernel=fused_kernel,
+                graph=graph,
+                cache_key=cache_key,
+                metadata={
+                    "input_dtypes": input_dtypes,
+                    "input_shapes": input_shapes,
+                    "num_original_nodes": len(graph.nodes),
+                }
+            )
+
+        except Exception as e:
+            raise FusionError(f"Compilation failed: {e}") from e
+
+    def _infer_tensor_ndim(self, input_shapes: Dict[str, Tuple[int, ...]]) -> int:
+        """Infer the tensor rank from the input shapes."""
+        if not input_shapes:
+            return 2  # default to 2-D
+
+        # Use the rank of the first non-empty shape
+        for name, shape in input_shapes.items():
+            if shape:
+                return len(shape)
+
+        return 2
+
+    def _create_handle_for_op(self, op_type: str, ndim: int) -> Any:
+        """
+        Create a ninetoothed kernel handle (_Handle) for a single operator.
+
+        Args:
+            op_type: operator type name
+            ndim: tensor rank
+
+        Returns:
+            a _Handle object usable for fusion
+        """
+        if op_type not in _OP_REGISTRY:
+            raise FusionError(f"Operator '{op_type}' not in fusion registry")
+
+        op_info = _OP_REGISTRY[op_type]
+        premake_fn = op_info["premake"]
+        op_kind = op_info["type"]
+
+        # Dispatch on the premake signature for each operator kind
+        if op_kind in ("unary", "binary"):
+            # Standard elementwise operator: premake(ndim)
+            arrangement, application, tensors = premake_fn(ndim)
+        elif op_kind == "rms_norm":
+            # RMSNorm: premake(ndim, num_normalized_dims)
+            # LLMs usually normalize over the last dimension
+            arrangement, application, tensors = premake_fn(ndim, num_normalized_dims=1)
+        else:
+            raise FusionError(f"Unknown operator kind: {op_kind}")
+
+        # Call ninetoothed.make to create the _Handle
+        handle = self._ninetoothed.make(
+            arrangement,
+            application,
+            tensors,
+            num_warps=4,
+            num_stages=2,
+        )
+
+        if self.config.debug_mode:
+            print(f"[KernelCompiler] Created handle for {op_type}: {handle}")
+
+        return handle
+
+    def _create_handles_for_graph(self, graph: SubGraph, ndim: int) -> Dict[str, Any]:
+        """Create _Handle objects for every operator in the graph."""
+        handles = {}
+        for op_node in graph.nodes:
+            if op_node.op_type not in handles:
+                handles[op_node.op_type] = self._create_handle_for_op(op_node.op_type, ndim)
+        return handles
+
+    def _build_fusion_nodes(self, handles: Dict[str, Any], graph: SubGraph) -> list:
+        """
+        Build the Node list that ninetoothed fusion can process.
+
+        Wraps each OpNode as a ninetoothed.fusion.Node; a Node holds the
+        _Handle object plus runtime argument information.
+        """
+        from ninetoothed.fusion import Node
+
+        nodes = []
+        for op_node in graph.nodes:
+            handle = handles[op_node.op_type]
+
+            # A Node needs args to establish data dependencies.
+            # Pass inputs + outputs as arguments so ninetoothed can identify
+            # shared data objects (string names here).
+            node_args = list(op_node.inputs) + list(op_node.outputs)
+            node = Node(handle, args=tuple(node_args), kwargs={})
+            nodes.append(node)
+
+        return nodes
diff --git a/python/infinicore/fusion/patterns/__init__.py b/python/infinicore/fusion/patterns/__init__.py
new file mode 100644
index 000000000..212c941a8
--- /dev/null
+++ b/python/infinicore/fusion/patterns/__init__.py
@@ -0,0 +1,15 @@
+"""
+Patterns package initialization.
+"""
+
+from infinicore.fusion.patterns.llm_patterns import (
+    create_swiglu_pattern,
+    create_add_rms_norm_pattern,
+    LLM_FUSION_PATTERNS,
+)
+
+__all__ = [
+    "create_swiglu_pattern",
+    "create_add_rms_norm_pattern",
+    "LLM_FUSION_PATTERNS",
+]
diff --git a/python/infinicore/fusion/patterns/llm_patterns.py b/python/infinicore/fusion/patterns/llm_patterns.py
new file mode 100644
index 000000000..7674c3689
--- /dev/null
+++ b/python/infinicore/fusion/patterns/llm_patterns.py
@@ -0,0 +1,85 @@
+"""
+Common fusion patterns for LLM inference.
+
+Defines fusable operator combinations that frequently appear in
+large-language-model inference.
+"""
+
+from typing import List
+
+from infinicore.fusion.subgraph import SubGraph, OpNode
+
+
+def create_swiglu_pattern() -> SubGraph:
+    """
+    Create the SwiGLU activation fusion pattern.
+
+    SwiGLU = SiLU(gate) * up
+
+    Common in the FFN layers of models such as LLaMA and Mistral.
+    """
+    return SubGraph(
+        nodes=(
+            OpNode(
+                op_type="silu",
+                inputs=("gate",),
+                outputs=("gate_activated",),
+            ),
+            OpNode(
+                op_type="mul",
+                inputs=("gate_activated", "up"),
+                outputs=("output",),
+            ),
+        ),
+        input_names=("gate", "up"),
+        output_names=("output",),
+    )
+
+
+def create_add_rms_norm_pattern() -> SubGraph:
+    """
+    Create the residual-add + RMSNorm fusion pattern.
+
+    output = rms_norm(x + residual, weight)
+
+    Common as the post-processing step of a Transformer layer.
+    """
+    return SubGraph(
+        nodes=(
+            OpNode(
+                op_type="add",
+                inputs=("x", "residual"),
+                outputs=("sum",),
+            ),
+            OpNode(
+                op_type="rms_norm",
+                inputs=("sum", "weight"),
+                outputs=("output",),
+            ),
+        ),
+        input_names=("x", "residual", "weight"),
+        output_names=("output",),
+    )
+
+
+def create_gelu_pattern() -> SubGraph:
+    """
+    Create the GELU activation pattern (single operator, for testing).
+    """
+    return SubGraph(
+        nodes=(
+            OpNode(
+                op_type="gelu",
+                inputs=("x",),
+                outputs=("output",),
+            ),
+        ),
+        input_names=("x",),
+        output_names=("output",),
+    )
+
+
+# Predefined LLM fusion patterns
+LLM_FUSION_PATTERNS: List[SubGraph] = [
+    create_swiglu_pattern(),
+    create_add_rms_norm_pattern(),
+]
diff --git a/python/infinicore/fusion/subgraph.py b/python/infinicore/fusion/subgraph.py
new file mode 100644
index 000000000..aa974e8a6
--- /dev/null
+++ b/python/infinicore/fusion/subgraph.py
@@ -0,0 +1,103 @@
+"""
+Subgraph representation module - lightweight, hashable subgraph data structures.
+
+Design principles:
+1. Decoupled: no dependency on ninetoothed or torch.fx
+2. Cache-friendly: frozen dataclasses support __hash__ and __eq__
+3. Serializable: easy to log and debug
+"""
+
+from dataclasses import dataclass
+from typing import Tuple, Optional, Any, Dict
+import hashlib
+
+
+@dataclass(frozen=True)
+class OpNode:
+    """
+    Operator node (immutable, used as part of the cache key).
+
+    Attributes:
+        op_type: operator type identifier (e.g. "rms_norm", "silu", "mul")
+        inputs: tuple of input tensor names
+        outputs: tuple of output tensor names
+        attrs: operator attributes (tuple-ized so they stay hashable)
+
+    Example:
+        >>> node = OpNode(
+        ...     op_type="silu",
+        ...     inputs=("x",),
+        ...     outputs=("y",),
+        ... )
+    """
+    op_type: str
+    inputs: Tuple[str, ...]
+    outputs: Tuple[str, ...]
+    attrs: Optional[Tuple[Tuple[str, Any], ...]] = None
+
+    def __hash__(self) -> int:
+        return hash((self.op_type, self.inputs, self.outputs, self.attrs))
+
+    def __repr__(self) -> str:
+        return f"OpNode({self.op_type}, inputs={self.inputs}, outputs={self.outputs})"
+
+
+@dataclass(frozen=True)
+class SubGraph:
+    """
+    Subgraph representation (immutable, used as part of the cache key).
+
+    Represents a fusable operator sequence; nodes are in topological order.
+
+    Attributes:
+        nodes: tuple of operator nodes in topological order
+        input_names: external input names of the subgraph
+        output_names: external output names of the subgraph
+
+    Example:
+        >>> graph = SubGraph(
+        ...     nodes=(
+        ...         OpNode("silu", ("x",), ("y1",)),
+        ...         OpNode("mul", ("y1", "x"), ("y2",)),
+        ...     ),
+        ...     input_names=("x",),
+        ...     output_names=("y2",),
+        ... )
+    """
+    nodes: Tuple[OpNode, ...]
+    input_names: Tuple[str, ...]
+    output_names: Tuple[str, ...]
+
+    def __hash__(self) -> int:
+        return hash((self.nodes, self.input_names, self.output_names))
+
+    def cache_key(
+        self,
+        input_dtypes: Dict[str, str],
+        input_shapes: Dict[str, Tuple[int, ...]]
+    ) -> str:
+        """
+        Build a cache key (graph structure + dtypes + shapes).
+
+        Different dtype or shape combinations produce different kernels, so
+        they must be part of the cache key.
+
+        Args:
+            input_dtypes: dict of input tensor dtypes
+            input_shapes: dict of input tensor shapes
+
+        Returns:
+            a 16-character hexadecimal hash string
+        """
+        key_data = (
+            hash(self),
+            tuple(sorted(input_dtypes.items())),
+            tuple((k, v) for k, v in sorted(input_shapes.items()))
+        )
+        return hashlib.sha256(str(key_data).encode()).hexdigest()[:16]
+
+    def __len__(self) -> int:
+        """Return the number of nodes in the subgraph."""
+        return len(self.nodes)
+
+    def __repr__(self) -> str:
+        return f"SubGraph(nodes={len(self.nodes)}, inputs={self.input_names}, outputs={self.output_names})"
diff --git a/src/infinicore/utils.hpp b/src/infinicore/utils.hpp
index cf8e69789..4daf5f125 100644
--- a/src/infinicore/utils.hpp
+++ b/src/infinicore/utils.hpp
@@ -8,10 +8,31 @@
 inline struct SpdlogInitializer {
     SpdlogInitializer() {
-        if (!std::getenv("INFINICORE_LOG_LEVEL")) {
+        const char *log_level_env = std::getenv("INFINICORE_LOG_LEVEL");
+        if (!log_level_env) {
             spdlog::set_level(spdlog::level::info);
         } else {
-            spdlog::cfg::load_env_levels("INFINICORE_LOG_LEVEL");
+            std::string level_str(log_level_env);
+            spdlog::level::level_enum level;
+            if (level_str == "trace") {
+                level = spdlog::level::trace;
+            } else if (level_str == "debug") {
+                level = spdlog::level::debug;
+            } else if (level_str == "info") {
+                level = spdlog::level::info;
+            } else if (level_str == "warn" || level_str == "warning") {
+                level = spdlog::level::warn;
+            } else if (level_str == "error") {
+                level = spdlog::level::err;
+            } else if (level_str == "critical") {
+                level = spdlog::level::critical;
+            } else if (level_str == "off") {
+                level = spdlog::level::off;
+            } else {
+                // Default to info if unknown level
+                level = spdlog::level::info;
+            }
+            spdlog::set_level(level);
         }
         // Set pattern for logging
         // Using SPDLOG_* macros enables source location support (%s and %#)
@@ -23,17 +44,14 @@
 #define STRINGIZE_(x) #x
 #define STRINGIZE(x) STRINGIZE_(x)
 
-#define INFINICORE_CHECK_ERROR(call) \
-    do { \
-        SPDLOG_DEBUG("Entering `" #call "` at `" __FILE__ ":" STRINGIZE(__LINE__) "`."); \
-        infiniStatus_t ret = (call); \
-        SPDLOG_DEBUG("Exiting `" #call "` at `" __FILE__ ":" STRINGIZE(__LINE__) "`."); \
-        if (ret != INFINI_STATUS_SUCCESS) { \
-            throw std::runtime_error("`" #call "` failed with error: " + std::string(infini_status_string(ret)) \
-                                     + " from " + std::string(__func__) \
-                                     + " at " + std::string(__FILE__) \
-                                     + ":" + std::to_string(__LINE__) + "."); \
-        } \
+#define INFINICORE_CHECK_ERROR(call) \
+    do { \
+        SPDLOG_DEBUG("Entering `" #call "` at `" __FILE__ ":" STRINGIZE(__LINE__) "`."); \
+        infiniStatus_t ret = (call); \
+        SPDLOG_DEBUG("Exiting `" #call "` at `" __FILE__ ":" STRINGIZE(__LINE__) "`."); \
+        if (ret != INFINI_STATUS_SUCCESS) { \
+            throw std::runtime_error(#call " failed with error: " + std::string(infini_status_string(ret))); \
+        } \
     } while (false)
 
 #define INFINICORE_ASSERT_TENSORS_SAME_DEVICE(FIRST___, ...) \
diff --git a/src/infiniop/ops/add_rms_norm/metax/add_rms_norm_metax.cuh b/src/infiniop/ops/add_rms_norm/metax/add_rms_norm_metax.cuh
new file mode 100644
index 000000000..3d6b13b53
--- /dev/null
+++ b/src/infiniop/ops/add_rms_norm/metax/add_rms_norm_metax.cuh
@@ -0,0 +1,8 @@
+#ifndef __ADD_RMS_NORM_METAX_CUH__
+#define __ADD_RMS_NORM_METAX_CUH__
+
+#include "../add_rms_norm.h"
+
+DESCRIPTOR(metax)
+
+#endif
diff --git a/src/infiniop/ops/add_rms_norm/metax/add_rms_norm_metax.maca b/src/infiniop/ops/add_rms_norm/metax/add_rms_norm_metax.maca
new file mode 100644
index 000000000..aad96b649
--- /dev/null
+++ b/src/infiniop/ops/add_rms_norm/metax/add_rms_norm_metax.maca
@@ -0,0 +1,167 @@
+#include "../../../devices/metax/metax_common.h"
+#include "add_rms_norm_metax.cuh"
+
+#include "../../../devices/metax/metax_kernel_common.h"
+#include
+
+#include "../../../reduce/cuda/reduce.cuh"
+
+#include "../cuda/kernel.cuh"
+
+template <unsigned int BLOCK_SIZE, typename Tdata, typename Tweight, typename Tcompute>
+INFINIOP_METAX_KERNEL add_rmsnormKernel(
+    Tdata *__restrict__ y,
+    Tdata *__restrict__ residual_out,
+    ptrdiff_t stride_y_batch,
+    ptrdiff_t stride_y_nhead,
+    ptrdiff_t stride_residual_out_batch,
+    ptrdiff_t stride_residual_out_nhead,
+    const Tdata *__restrict__ a,
+    ptrdiff_t stride_a_batch,
+    ptrdiff_t stride_a_nhead,
+    const Tdata *__restrict__ b,
+    ptrdiff_t stride_b_batch,
+    ptrdiff_t stride_b_nhead,
+    const Tweight *__restrict__ w,
+    size_t nhead,
+    size_t dim,
+    float epsilon) {
+    add_rmsnormBlock<BLOCK_SIZE, Tdata, Tweight, Tcompute>(
+        y, residual_out,
+        stride_y_batch, stride_y_nhead,
+        stride_residual_out_batch, stride_residual_out_nhead,
+        a, stride_a_batch, stride_a_nhead,
+        b, stride_b_batch, stride_b_nhead,
+        w, nhead, dim, epsilon);
+}
+
+namespace op::add_rms_norm::metax {
+
+struct Descriptor::Opaque {
+    std::shared_ptr<device::metax::Handle::Internal> internal;
+};
+
+Descriptor::~Descriptor() {
+    delete _opaque;
+}
+
+infiniStatus_t Descriptor::create(
+    infiniopHandle_t handle,
+    Descriptor **desc_ptr,
+    infiniopTensorDescriptor_t y_desc,
+    infiniopTensorDescriptor_t a_desc,
+    infiniopTensorDescriptor_t b_desc,
+    infiniopTensorDescriptor_t weight_desc,
+    float epsilon,
+    infiniopTensorDescriptor_t residual_out_desc) {
+    auto result = AddRMSNormInfo::create(y_desc, a_desc, b_desc, weight_desc, epsilon, residual_out_desc);
+    CHECK_RESULT(result);
+    auto info = result.take();
+
+    *desc_ptr = new Descriptor(
+        new Opaque{reinterpret_cast<device::metax::Handle *>(handle)->internal()},
+        std::move(info),
+        0,
+        handle->device, handle->device_id);
+    return INFINI_STATUS_SUCCESS;
+}
+
+// launch kernel with different data types
+template <unsigned int BLOCK_SIZE>
+infiniStatus_t launchKernel(
+    uint32_t batch_size, size_t nhead, size_t dim,
+    void *y, infiniDtype_t atype, ptrdiff_t stride_y_batch, ptrdiff_t stride_y_nhead,
+    void *residual_out, ptrdiff_t stride_residual_out_batch, ptrdiff_t stride_residual_out_nhead,
+    const void *a, ptrdiff_t stride_a_batch, ptrdiff_t stride_a_nhead,
+    const void *b, ptrdiff_t stride_b_batch, ptrdiff_t stride_b_nhead,
+    const void *w, infiniDtype_t wtype,
+    float epsilon,
+    hcStream_t stream) {
+
+#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute)                                                 \
+    add_rmsnormKernel<BLOCK_SIZE, Tdata, Tweight, Tcompute><<<batch_size, BLOCK_SIZE, 0, stream>>>( \
+        reinterpret_cast<Tdata *>(y),                                                           \
+        reinterpret_cast<Tdata *>(residual_out),                                                \
+        stride_y_batch,                                                                         \
+        stride_y_nhead,                                                                         \
+        stride_residual_out_batch,                                                              \
+        stride_residual_out_nhead,                                                              \
+        reinterpret_cast<const Tdata *>(a),                                                     \
+        stride_a_batch,                                                                         \
+        stride_a_nhead,                                                                         \
+        reinterpret_cast<const Tdata *>(b),                                                     \
+        stride_b_batch,                                                                         \
+        stride_b_nhead,                                                                         \
+        reinterpret_cast<const Tweight *>(w),                                                   \
+        nhead,                                                                                  \
+        dim,                                                                                    \
+        epsilon)
+
+    if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
+        LAUNCH_KERNEL(half, half, float);
+    } else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_BF16) {
+        LAUNCH_KERNEL(half, __hpcc_bfloat16, float);
+    } else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) {
+        LAUNCH_KERNEL(half, float, float);
+    } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) {
+        LAUNCH_KERNEL(__hpcc_bfloat16, __hpcc_bfloat16, float);
+    } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F16) {
+        LAUNCH_KERNEL(__hpcc_bfloat16, half, float);
+    } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) {
+        LAUNCH_KERNEL(__hpcc_bfloat16, float, float);
+    } else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {
+        LAUNCH_KERNEL(float, float, float);
+    } else {
+        return INFINI_STATUS_BAD_TENSOR_DTYPE;
+    }
+
+#undef LAUNCH_KERNEL
+
+    return INFINI_STATUS_SUCCESS;
+}
+
+infiniStatus_t Descriptor::calculate(
+    void *workspace, size_t workspace_size,
+    void *y, const void *a, const void *b, const void *weight,
+    void *residual_out, void *stream_) const {
+
+    if (workspace_size < _workspace_size) {
+        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
+    }
+
+    auto stride_a_batch = _info.a_strides[0];
+    auto stride_a_nhead = _info.a_strides[1];
+    auto stride_b_batch = _info.b_strides[0];
+    auto stride_b_nhead = _info.b_strides[1];
+    auto stride_y_batch = _info.y_strides[0];
+    auto stride_y_nhead = _info.y_strides[1];
+    auto stride_residual_out_batch = _info.residual_out_strides[0];
+    auto stride_residual_out_nhead = _info.residual_out_strides[1];
+    auto dim = _info.dim();
+    uint32_t batch_size = static_cast<uint32_t>(_info.shape[0]);
+    size_t nhead = _info.shape.size() > 2 ? _info.shape[1] : 1;
+    auto stream = reinterpret_cast<hcStream_t>(stream_);
+
+    // launch kernel with different block sizes
+    if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_1024) {
+        CHECK_STATUS(launchKernel<METAX_BLOCK_SIZE_1024>(
+            batch_size, nhead, dim,
+            y, _info.atype, stride_y_batch, stride_y_nhead,
+            residual_out, stride_residual_out_batch, stride_residual_out_nhead,
+            a, stride_a_batch, stride_a_nhead,
+            b, stride_b_batch, stride_b_nhead,
+            weight, _info.wtype, _info.epsilon, stream));
+    } else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_512) {
+        CHECK_STATUS(launchKernel<METAX_BLOCK_SIZE_512>(
+            batch_size, nhead, dim,
+            y, _info.atype, stride_y_batch, stride_y_nhead,
+            residual_out, stride_residual_out_batch, stride_residual_out_nhead,
+            a, stride_a_batch, stride_a_nhead,
+            b, stride_b_batch, stride_b_nhead,
+            weight, _info.wtype, _info.epsilon, stream));
+    } else {
+        return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
+    }
+    return INFINI_STATUS_SUCCESS;
+}
+} // namespace op::add_rms_norm::metax
diff --git a/src/infiniop/ops/add_rms_norm/operator.cc b/src/infiniop/ops/add_rms_norm/operator.cc
index a856e5447..625828823 100644
--- a/src/infiniop/ops/add_rms_norm/operator.cc
+++ b/src/infiniop/ops/add_rms_norm/operator.cc
@@ -17,8 +17,7 @@
 // #include "bang/add_rms_norm_bang.h"
 #endif
 #ifdef ENABLE_METAX_API
-// TODO: Add Metax implementation
-// #include "metax/add_rms_norm_metax.cuh"
+#include "metax/add_rms_norm_metax.cuh"
 #endif
 #ifdef ENABLE_MOORE_API
 // TODO: Add Moore implementation
@@ -67,6 +66,9 @@ __C infiniStatus_t infiniopCreateAddRMSNormDescriptor(
 #ifdef ENABLE_HYGON_API
     CREATE(INFINI_DEVICE_HYGON, nvidia);
 #endif
+#ifdef ENABLE_METAX_API
+    CREATE(INFINI_DEVICE_METAX, metax);
+#endif
 #ifdef ENABLE_KUNLUN_API
     // CREATE(INFINI_DEVICE_KUNLUN, kunlun);
 #endif
@@ -100,6 +102,9 @@ __C infiniStatus_t infiniopGetAddRMSNormWorkspaceSize(infiniopAddRMSNormDescript
 #ifdef ENABLE_HYGON_API
     GET(INFINI_DEVICE_HYGON, nvidia);
 #endif
+#ifdef ENABLE_METAX_API
+    GET(INFINI_DEVICE_METAX, metax);
+#endif
 #ifdef ENABLE_KUNLUN_API
     // GET(INFINI_DEVICE_KUNLUN, kunlun);
 #endif
@@ -144,6 +149,9 @@ __C infiniStatus_t infiniopAddRMSNorm(
 #ifdef ENABLE_HYGON_API
     CALCULATE(INFINI_DEVICE_HYGON, nvidia);
 #endif
+#ifdef ENABLE_METAX_API
+    CALCULATE(INFINI_DEVICE_METAX, metax);
+#endif
 #ifdef ENABLE_KUNLUN_API
     // CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
 #endif
@@ -179,6 +187,9 @@ __C infiniStatus_t infiniopDestroyAddRMSNormDescriptor(infiniopAddRMSNormDescrip
 #ifdef ENABLE_HYGON_API
     DESTROY(INFINI_DEVICE_HYGON, nvidia);
 #endif
+#ifdef ENABLE_METAX_API
+    DESTROY(INFINI_DEVICE_METAX, metax);
+#endif
 #ifdef ENABLE_KUNLUN_API
     // DESTROY(INFINI_DEVICE_KUNLUN, kunlun);
 #endif
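Note: putting the pieces together, the intended integration flow per the graph_converter docstring is record → convert → dispatch. A hedged sketch, assuming `infinicore.start_graph_recording()`/`stop_graph_recording()` are the recording entry points named in this patch; the model step and input tensors are hypothetical:

```python
import torch
import infinicore
from infinicore.fusion import FusionConfig, FusionScheduler, convert_graph_to_subgraph

scheduler = FusionScheduler(FusionConfig(enable_fusion=True, debug_mode=True))

infinicore.start_graph_recording()           # capture operator calls
hidden = model.forward(input_ids)            # hypothetical InfiniLM step
recorded = infinicore.stop_graph_recording()

subgraph = convert_graph_to_subgraph(recorded)
if subgraph is not None:
    feeds = {name: torch.randn(1, 512, 4096) for name in subgraph.input_names}
    outputs = scheduler.dispatch(subgraph, feeds)
```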