From b10c5494bb1a194d7a43f4f3c5c57ca94f2aaa97 Mon Sep 17 00:00:00 2001
From: Stevengre
Date: Tue, 16 Dec 2025 20:58:06 +0800
Subject: [PATCH] fix: include function pointer arguments in call_edges for reduce_to

The call_edges method now collects both:
1. Direct calls: functions used as the func operand in Call terminators
2. Indirect calls: functions passed as arguments (ZeroSized constants)
   that may be called via function pointers

This fixes the 'unknown function' issue when a function is passed as an
argument to higher-order functions like Result::map.
---
 kmir/src/kmir/smir.py                         | 41 +++++++++++++++----
 .../488-support-function-pointer-calls.rs     |  1 -
 2 files changed, 32 insertions(+), 10 deletions(-)

diff --git a/kmir/src/kmir/smir.py b/kmir/src/kmir/smir.py
index e40157e50..f609a9fe8 100644
--- a/kmir/src/kmir/smir.py
+++ b/kmir/src/kmir/smir.py
@@ -201,13 +201,21 @@ def reduce_to(self, start_name: str) -> SMIRInfo:
 
     @cached_property
     def call_edges(self) -> dict[Ty, set[Ty]]:
-        # determines which functions are _called_ from others, by symbol name
-        result = {}
+        """Determines which functions are called or referenced from others.
+
+        This includes:
+        1. Direct calls: functions used as the `func` operand in Call terminators
+        2. Indirect calls: functions passed as arguments (ZeroSized constants) that may be
+           called via function pointers (e.g., closures passed to higher-order functions)
+        """
+        result: dict[Ty, set[Ty]] = {}
+        function_tys = set(self.function_symbols.keys())
+
         for sym, item in self.items.items():
             if not SMIRInfo._is_func(item):
                 continue
             # skip functions not present in the `function_symbols_reverse` table
-            if not sym in self.function_symbols_reverse:
+            if sym not in self.function_symbols_reverse:
                 continue
             body = item['mono_item_kind']['MonoItemFn'].get('body')
             if body is None or 'blocks' not in body:
@@ -215,12 +223,27 @@ def call_edges(self) -> dict[Ty, set[Ty]]:
                 _LOGGER.debug(f'Skipping call edge extraction for {sym}: missing body')
                 called_tys: set[Ty] = set()
             else:
-                called_funs = [
-                    b['terminator']['kind']['Call']['func'] for b in body['blocks'] if 'Call' in b['terminator']['kind']
-                ]
-                called_tys = {Ty(op['Constant']['const_']['ty']) for op in called_funs if 'Constant' in op}
-                # TODO also add any constant operands used as arguments whose ty refer to a function
-                # the semantics currently does not support this, see issue #488 and stable-mir-json issue #55
+                called_tys = set()
+                for block in body['blocks']:
+                    if 'Call' not in block['terminator']['kind']:
+                        continue
+                    call = block['terminator']['kind']['Call']
+
+                    # 1. Direct call: the function being called
+                    func = call['func']
+                    if 'Constant' in func:
+                        called_tys.add(Ty(func['Constant']['const_']['ty']))
+
+                    # 2. Indirect call: function pointers passed as arguments
+                    # These are ZeroSized constants whose ty is in the function table
+                    for arg in call.get('args', []):
+                        if 'Constant' in arg:
+                            const_ = arg['Constant'].get('const_', {})
+                            if const_.get('kind') == 'ZeroSized':
+                                ty = const_.get('ty')
+                                if isinstance(ty, int) and ty in function_tys:
+                                    called_tys.add(Ty(ty))
+
             for ty in self.function_symbols_reverse[sym]:
                 result[Ty(ty)] = called_tys
         return result
diff --git a/kmir/src/tests/integration/data/prove-rs/488-support-function-pointer-calls.rs b/kmir/src/tests/integration/data/prove-rs/488-support-function-pointer-calls.rs
index 39cdce47a..2d093e81a 100644
--- a/kmir/src/tests/integration/data/prove-rs/488-support-function-pointer-calls.rs
+++ b/kmir/src/tests/integration/data/prove-rs/488-support-function-pointer-calls.rs
@@ -20,7 +20,6 @@ impl TryFrom<&[u8]> for EightBytes {
 
 fn main() {
     let bytes: [u8;8] = [1,2,3,4,5,6,7,8];
-    let _unused = EightBytes::from(bytes);
     let slice: &[u8] = &bytes;
     let thing: Result = EightBytes::try_from(slice);
     assert!(thing.is_ok());