@@ -282,11 +282,7 @@ def block_size_var(self, block_id: int) -> str | None:
282282
283283 var_name = self .new_var (f"_BLOCK_SIZE_{ block_id } " )
284284 self .block_size_var_cache [key ] = var_name
285- host_expr = HostFunction .current ().literal_expr (block_value )
286- if self .constexpr_arg (var_name , host_expr ):
287- self .codegen .host_statements .append (
288- statement_from_string (f"{ var_name } = { host_expr } " )
289- )
285+ self .constexpr_arg_with_host_def (var_name , block_value )
290286
291287 return self .block_size_var_cache [key ]
292288
@@ -484,14 +480,50 @@ def expr_arg(self, sym: sympy.Expr, origin: Origin) -> SymbolArgument:
484480 self ._expr_args [sym ] = arg
485481 return self ._expr_args [sym ]
486482
487- def constexpr_arg (self , name : str , host_str : str | None = None ) -> bool :
483+ def constexpr_arg (self , name : str , value : object | None = None ) -> bool :
488484 """Create a constexpr argument, returns True if created, False if already exists."""
489485 if name in self ._constexpr_args :
490486 return False
491- self ._constexpr_args [name ] = rv = ConstExprArg (name , host_str or name )
487+ host_str = name if value is None else self ._format_constexpr_value (value )
488+ self ._constexpr_args [name ] = rv = ConstExprArg (name , host_str )
492489 self .arguments .append (rv )
493490 return True
494491
492+ def constexpr_arg_with_host_def (self , name : str , value : object ) -> None :
493+ """Create a constexpr argument and add its host-side definition if needed."""
494+ if self .constexpr_arg (name , value ):
495+ host_expr = self ._constexpr_args [name ].host_str ()
496+ self .codegen .host_statements .append (
497+ statement_from_string (f"{ name } = { host_expr } " )
498+ )
499+
500+ def _format_constexpr_value (self , value : object ) -> str :
501+ if isinstance (value , str ):
502+ return value
503+ if isinstance (value , (int , float , bool )):
504+ return repr (value )
505+
506+ # Extract sympy expression from torch symbolic types
507+ if isinstance (value , (torch .SymInt , torch .SymFloat , torch .SymBool )):
508+ value = value ._sympy_ ()
509+
510+ # Handle sympy expressions (sanitize by replacing triton_helpers functions)
511+ if isinstance (value , sympy .Expr ):
512+ expr = value .replace (
513+ lambda node : isinstance (node , sympy .Function )
514+ and getattr (node .func , "__name__" , "" )
515+ == "triton_helpers.div_floor_integer" ,
516+ lambda node : sympy .floor (node .args [0 ] / node .args [1 ]), # pyright: ignore[reportAttributeAccessIssue]
517+ ).replace (
518+ lambda node : isinstance (node , sympy .Function )
519+ and getattr (node .func , "__name__" , "" )
520+ == "triton_helpers.remainder_integer" ,
521+ lambda node : sympy .Mod (node .args [0 ], node .args [1 ]), # pyright: ignore[reportAttributeAccessIssue]
522+ )
523+ return HostFunction .current ().sympy_expr (expr )
524+
525+ return HostFunction .current ().literal_expr (value )
526+
495527 def _tensor_property (
496528 self ,
497529 prop_cls : type [_P ],
0 commit comments