From 948e8b9701051aef32c8afd53a7c680899cea9f6 Mon Sep 17 00:00:00 2001 From: Brady Planden Date: Sat, 29 Nov 2025 10:51:17 +0000 Subject: [PATCH] fix: fallback to external functions when LLVM intrinsics is not available --- diffsl/src/execution/llvm/codegen.rs | 466 +++++++++++++-------------- 1 file changed, 232 insertions(+), 234 deletions(-) diff --git a/diffsl/src/execution/llvm/codegen.rs b/diffsl/src/execution/llvm/codegen.rs index 6be1e10..a631bb2 100644 --- a/diffsl/src/execution/llvm/codegen.rs +++ b/diffsl/src/execution/llvm/codegen.rs @@ -1303,259 +1303,257 @@ impl<'ctx> CodeGen<'ctx> { } fn get_function(&mut self, name: &str) -> Option> { - match self.functions.get(name) { - Some(&func) => Some(func), - None => { - let function = match name { - // support some llvm intrinsics - "sin" | "cos" | "tan" | "exp" | "log" | "log10" | "sqrt" | "abs" - | "copysign" | "pow" | "min" | "max" => { - let arg_len = 1; - let intrinsic_name = match name { - "min" => "minnum", - "max" => "maxnum", - "abs" => "fabs", - _ => name, - }; - let llvm_name = - format!("llvm.{}.{}", intrinsic_name, self.diffsl_real_type.as_str()); - let intrinsic = Intrinsic::find(&llvm_name).unwrap(); - let ret_type = self.real_type; - - let args_types = std::iter::repeat_n(ret_type, arg_len) - .map(|f| f.into()) - .collect::>(); - // if we get an intrinsic, we don't need to add to the list of functions and can return early - return intrinsic.get_declaration(&self.module, args_types.as_slice()); + // Check cache for function + if let Some(&func) = self.functions.get(name) { + return Some(func); + } + + let function = match name { + // support some llvm intrinsics + "sin" | "cos" | "tan" | "exp" | "log" | "log10" | "sqrt" | "abs" | "copysign" + | "pow" | "min" | "max" => { + let intrinsic_name = match name { + "min" => "minnum", + "max" => "maxnum", + "abs" => "fabs", + _ => name, + }; + let llvm_name = + format!("llvm.{}.{}", intrinsic_name, self.diffsl_real_type.as_str()); + + let args_types: Vec = vec![self.real_type.into()]; + + // Try intrinsic first, fall back to libm for all functions + Some(match Intrinsic::find(&llvm_name) { + Some(intrinsic) => intrinsic.get_declaration(&self.module, &args_types)?, + None => { + // Fallback: declare external libm function + let args_types_meta: Vec = + vec![self.real_type.into()]; + let fn_type = self.real_type.fn_type(&args_types_meta, false); + self.module + .add_function(name, fn_type, Some(Linkage::External)) } - // some custom functions - "sigmoid" => { - let arg_len = 1; - let ret_type = self.real_type; + }) + } + // some custom functions + "sigmoid" => { + let arg_len = 1; + let ret_type = self.real_type; - let args_types = std::iter::repeat_n(ret_type, arg_len) - .map(|f| f.into()) - .collect::>(); + let args_types = std::iter::repeat_n(ret_type, arg_len) + .map(|f| f.into()) + .collect::>(); - let fn_type = ret_type.fn_type(args_types.as_slice(), false); - let fn_val = self.module.add_function(name, fn_type, None); + let fn_type = ret_type.fn_type(args_types.as_slice(), false); + let fn_val = self.module.add_function(name, fn_type, None); - for arg in fn_val.get_param_iter() { - arg.into_float_value().set_name("x"); - } + for arg in fn_val.get_param_iter() { + arg.into_float_value().set_name("x"); + } - let current_block = self.builder.get_insert_block().unwrap(); - let basic_block = self.context.append_basic_block(fn_val, "entry"); - self.builder.position_at_end(basic_block); - let x = fn_val.get_nth_param(0)?.into_float_value(); - let one = self.real_type.const_float(1.0); - let negx = self.builder.build_float_neg(x, name).ok()?; - let exp = self.get_function("exp").unwrap(); - let exp_negx = self - .builder - .build_call(exp, &[BasicMetadataValueEnum::FloatValue(negx)], name) - .ok()?; - let one_plus_exp_negx = self - .builder - .build_float_add( - exp_negx - .try_as_basic_value() - .unwrap_basic() - .into_float_value(), - one, - name, - ) - .ok()?; - let sigmoid = self - .builder - .build_float_div(one, one_plus_exp_negx, name) - .ok()?; - self.builder.build_return(Some(&sigmoid)).ok(); - self.builder.position_at_end(current_block); - Some(fn_val) - } - "arcsinh" | "arccosh" => { - let arg_len = 1; - let ret_type = self.real_type; + let current_block = self.builder.get_insert_block().unwrap(); + let basic_block = self.context.append_basic_block(fn_val, "entry"); + self.builder.position_at_end(basic_block); + let x = fn_val.get_nth_param(0)?.into_float_value(); + let one = self.real_type.const_float(1.0); + let negx = self.builder.build_float_neg(x, name).ok()?; + let exp = self.get_function("exp").unwrap(); + let exp_negx = self + .builder + .build_call(exp, &[BasicMetadataValueEnum::FloatValue(negx)], name) + .ok()?; + let one_plus_exp_negx = self + .builder + .build_float_add( + exp_negx + .try_as_basic_value() + .unwrap_basic() + .into_float_value(), + one, + name, + ) + .ok()?; + let sigmoid = self + .builder + .build_float_div(one, one_plus_exp_negx, name) + .ok()?; + self.builder.build_return(Some(&sigmoid)).ok(); + self.builder.position_at_end(current_block); + Some(fn_val) + } + "arcsinh" | "arccosh" => { + let arg_len = 1; + let ret_type = self.real_type; - let args_types = std::iter::repeat_n(ret_type, arg_len) - .map(|f| f.into()) - .collect::>(); + let args_types = std::iter::repeat_n(ret_type, arg_len) + .map(|f| f.into()) + .collect::>(); - let fn_type = ret_type.fn_type(args_types.as_slice(), false); - let fn_val = self.module.add_function(name, fn_type, None); + let fn_type = ret_type.fn_type(args_types.as_slice(), false); + let fn_val = self.module.add_function(name, fn_type, None); - for arg in fn_val.get_param_iter() { - arg.into_float_value().set_name("x"); - } + for arg in fn_val.get_param_iter() { + arg.into_float_value().set_name("x"); + } - let current_block = self.builder.get_insert_block().unwrap(); - let basic_block = self.context.append_basic_block(fn_val, "entry"); - self.builder.position_at_end(basic_block); - let x = fn_val.get_nth_param(0)?.into_float_value(); - let one = match name { - "arccosh" => self.real_type.const_float(-1.0), - "arcsinh" => self.real_type.const_float(1.0), - _ => panic!("unknown function"), - }; - let x_squared = self.builder.build_float_mul(x, x, name).ok()?; - let one_plus_x_squared = - self.builder.build_float_add(x_squared, one, name).ok()?; - let sqrt = self.get_function("sqrt").unwrap(); - let sqrt_one_plus_x_squared = self - .builder - .build_call( - sqrt, - &[BasicMetadataValueEnum::FloatValue(one_plus_x_squared)], - name, - ) - .unwrap() - .try_as_basic_value() - .unwrap_basic() - .into_float_value(); - let x_plus_sqrt_one_plus_x_squared = self - .builder - .build_float_add(x, sqrt_one_plus_x_squared, name) - .ok()?; - let ln = self.get_function("log").unwrap(); - let result = self - .builder - .build_call( - ln, - &[BasicMetadataValueEnum::FloatValue( - x_plus_sqrt_one_plus_x_squared, - )], - name, - ) - .unwrap() - .try_as_basic_value() - .unwrap_basic() - .into_float_value(); - self.builder.build_return(Some(&result)).ok(); - self.builder.position_at_end(current_block); - Some(fn_val) - } - "heaviside" => { - let arg_len = 1; - let ret_type = self.real_type; + let current_block = self.builder.get_insert_block().unwrap(); + let basic_block = self.context.append_basic_block(fn_val, "entry"); + self.builder.position_at_end(basic_block); + let x = fn_val.get_nth_param(0)?.into_float_value(); + let one = match name { + "arccosh" => self.real_type.const_float(-1.0), + "arcsinh" => self.real_type.const_float(1.0), + _ => panic!("unknown function"), + }; + let x_squared = self.builder.build_float_mul(x, x, name).ok()?; + let one_plus_x_squared = self.builder.build_float_add(x_squared, one, name).ok()?; + let sqrt = self.get_function("sqrt").unwrap(); + let sqrt_one_plus_x_squared = self + .builder + .build_call( + sqrt, + &[BasicMetadataValueEnum::FloatValue(one_plus_x_squared)], + name, + ) + .unwrap() + .try_as_basic_value() + .unwrap_basic() + .into_float_value(); + let x_plus_sqrt_one_plus_x_squared = self + .builder + .build_float_add(x, sqrt_one_plus_x_squared, name) + .ok()?; + let ln = self.get_function("log").unwrap(); + let result = self + .builder + .build_call( + ln, + &[BasicMetadataValueEnum::FloatValue( + x_plus_sqrt_one_plus_x_squared, + )], + name, + ) + .unwrap() + .try_as_basic_value() + .unwrap_basic() + .into_float_value(); + self.builder.build_return(Some(&result)).ok(); + self.builder.position_at_end(current_block); + Some(fn_val) + } + "heaviside" => { + let arg_len = 1; + let ret_type = self.real_type; - let args_types = std::iter::repeat_n(ret_type, arg_len) - .map(|f| f.into()) - .collect::>(); + let args_types = std::iter::repeat_n(ret_type, arg_len) + .map(|f| f.into()) + .collect::>(); - let fn_type = ret_type.fn_type(args_types.as_slice(), false); - let fn_val = self.module.add_function(name, fn_type, None); + let fn_type = ret_type.fn_type(args_types.as_slice(), false); + let fn_val = self.module.add_function(name, fn_type, None); - for arg in fn_val.get_param_iter() { - arg.into_float_value().set_name("x"); - } + for arg in fn_val.get_param_iter() { + arg.into_float_value().set_name("x"); + } - let current_block = self.builder.get_insert_block().unwrap(); - let basic_block = self.context.append_basic_block(fn_val, "entry"); - self.builder.position_at_end(basic_block); - let x = fn_val.get_nth_param(0)?.into_float_value(); - let zero = self.real_type.const_float(0.0); - let one = self.real_type.const_float(1.0); - let result = self - .builder - .build_select( - self.builder - .build_float_compare(FloatPredicate::OGE, x, zero, "x >= 0") - .unwrap(), - one, - zero, - name, - ) - .ok()?; - self.builder.build_return(Some(&result)).ok(); - self.builder.position_at_end(current_block); - Some(fn_val) - } - "tanh" | "sinh" | "cosh" => { - let arg_len = 1; - let ret_type = self.real_type; + let current_block = self.builder.get_insert_block().unwrap(); + let basic_block = self.context.append_basic_block(fn_val, "entry"); + self.builder.position_at_end(basic_block); + let x = fn_val.get_nth_param(0)?.into_float_value(); + let zero = self.real_type.const_float(0.0); + let one = self.real_type.const_float(1.0); + let result = self + .builder + .build_select( + self.builder + .build_float_compare(FloatPredicate::OGE, x, zero, "x >= 0") + .unwrap(), + one, + zero, + name, + ) + .ok()?; + self.builder.build_return(Some(&result)).ok(); + self.builder.position_at_end(current_block); + Some(fn_val) + } + "tanh" | "sinh" | "cosh" => { + let arg_len = 1; + let ret_type = self.real_type; - let args_types = std::iter::repeat_n(ret_type, arg_len) - .map(|f| f.into()) - .collect::>(); + let args_types = std::iter::repeat_n(ret_type, arg_len) + .map(|f| f.into()) + .collect::>(); - let fn_type = ret_type.fn_type(args_types.as_slice(), false); - let fn_val = self.module.add_function(name, fn_type, None); + let fn_type = ret_type.fn_type(args_types.as_slice(), false); + let fn_val = self.module.add_function(name, fn_type, None); - for arg in fn_val.get_param_iter() { - arg.into_float_value().set_name("x"); - } + for arg in fn_val.get_param_iter() { + arg.into_float_value().set_name("x"); + } - let current_block = self.builder.get_insert_block().unwrap(); - let basic_block = self.context.append_basic_block(fn_val, "entry"); - self.builder.position_at_end(basic_block); - let x = fn_val.get_nth_param(0)?.into_float_value(); - let negx = self.builder.build_float_neg(x, name).ok()?; - let exp = self.get_function("exp").unwrap(); - let exp_negx = self - .builder - .build_call(exp, &[BasicMetadataValueEnum::FloatValue(negx)], name) - .ok()?; - let expx = self - .builder - .build_call(exp, &[BasicMetadataValueEnum::FloatValue(x)], name) - .ok()?; - let expx_minus_exp_negx = self - .builder - .build_float_sub( - expx.try_as_basic_value().unwrap_basic().into_float_value(), - exp_negx - .try_as_basic_value() - .unwrap_basic() - .into_float_value(), - name, - ) - .ok()?; - let expx_plus_exp_negx = self - .builder - .build_float_add( - expx.try_as_basic_value().unwrap_basic().into_float_value(), - exp_negx - .try_as_basic_value() - .unwrap_basic() - .into_float_value(), - name, - ) - .ok()?; - let result = match name { - "tanh" => self - .builder - .build_float_div(expx_minus_exp_negx, expx_plus_exp_negx, name) - .ok()?, - "sinh" => self - .builder - .build_float_div( - expx_minus_exp_negx, - self.real_type.const_float(2.0), - name, - ) - .ok()?, - "cosh" => self - .builder - .build_float_div( - expx_plus_exp_negx, - self.real_type.const_float(2.0), - name, - ) - .ok()?, - _ => panic!("unknown function"), - }; - self.builder.build_return(Some(&result)).ok(); - self.builder.position_at_end(current_block); - Some(fn_val) - } - _ => None, - }?; - self.functions.insert(name.to_owned(), function); - Some(function) + let current_block = self.builder.get_insert_block().unwrap(); + let basic_block = self.context.append_basic_block(fn_val, "entry"); + self.builder.position_at_end(basic_block); + let x = fn_val.get_nth_param(0)?.into_float_value(); + let negx = self.builder.build_float_neg(x, name).ok()?; + let exp = self.get_function("exp").unwrap(); + let exp_negx = self + .builder + .build_call(exp, &[BasicMetadataValueEnum::FloatValue(negx)], name) + .ok()?; + let expx = self + .builder + .build_call(exp, &[BasicMetadataValueEnum::FloatValue(x)], name) + .ok()?; + let expx_minus_exp_negx = self + .builder + .build_float_sub( + expx.try_as_basic_value().unwrap_basic().into_float_value(), + exp_negx + .try_as_basic_value() + .unwrap_basic() + .into_float_value(), + name, + ) + .ok()?; + let expx_plus_exp_negx = self + .builder + .build_float_add( + expx.try_as_basic_value().unwrap_basic().into_float_value(), + exp_negx + .try_as_basic_value() + .unwrap_basic() + .into_float_value(), + name, + ) + .ok()?; + let result = match name { + "tanh" => self + .builder + .build_float_div(expx_minus_exp_negx, expx_plus_exp_negx, name) + .ok()?, + "sinh" => self + .builder + .build_float_div(expx_minus_exp_negx, self.real_type.const_float(2.0), name) + .ok()?, + "cosh" => self + .builder + .build_float_div(expx_plus_exp_negx, self.real_type.const_float(2.0), name) + .ok()?, + _ => panic!("unknown function"), + }; + self.builder.build_return(Some(&result)).ok(); + self.builder.position_at_end(current_block); + Some(fn_val) } - } + _ => None, + }?; + self.functions.insert(name.to_owned(), function); + Some(function) } + /// Returns the `FunctionValue` representing the function being compiled. #[inline] fn fn_value(&self) -> FunctionValue<'ctx> {