Skip to content

Commit 9b38e26

Browse files
authored
fix: include function pointer arguments in call_edges for reduce_to (#898)
The call_edges method now collects both: 1. Direct calls: functions used as the func operand in Call terminators 2. Indirect calls: functions passed as arguments (ZeroSized constants) that may be called via function pointers This fixes the 'unknown function' issue when a function is passed as an argument to higher-order functions like Result::map.
1 parent 6e25bee commit 9b38e26

File tree

2 files changed

+32
-10
lines changed

2 files changed

+32
-10
lines changed

kmir/src/kmir/smir.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -201,26 +201,49 @@ def reduce_to(self, start_name: str) -> SMIRInfo:
201201

202202
@cached_property
203203
def call_edges(self) -> dict[Ty, set[Ty]]:
204-
# determines which functions are _called_ from others, by symbol name
205-
result = {}
204+
"""Determines which functions are called or referenced from others.
205+
206+
This includes:
207+
1. Direct calls: functions used as the `func` operand in Call terminators
208+
2. Indirect calls: functions passed as arguments (ZeroSized constants) that may be
209+
called via function pointers (e.g., closures passed to higher-order functions)
210+
"""
211+
result: dict[Ty, set[Ty]] = {}
212+
function_tys = set(self.function_symbols.keys())
213+
206214
for sym, item in self.items.items():
207215
if not SMIRInfo._is_func(item):
208216
continue
209217
# skip functions not present in the `function_symbols_reverse` table
210-
if not sym in self.function_symbols_reverse:
218+
if sym not in self.function_symbols_reverse:
211219
continue
212220
body = item['mono_item_kind']['MonoItemFn'].get('body')
213221
if body is None or 'blocks' not in body:
214222
# No MIR body means we cannot extract call edges; treat as having no outgoing edges.
215223
_LOGGER.debug(f'Skipping call edge extraction for {sym}: missing body')
216224
called_tys: set[Ty] = set()
217225
else:
218-
called_funs = [
219-
b['terminator']['kind']['Call']['func'] for b in body['blocks'] if 'Call' in b['terminator']['kind']
220-
]
221-
called_tys = {Ty(op['Constant']['const_']['ty']) for op in called_funs if 'Constant' in op}
222-
# TODO also add any constant operands used as arguments whose ty refer to a function
223-
# the semantics currently does not support this, see issue #488 and stable-mir-json issue #55
226+
called_tys = set()
227+
for block in body['blocks']:
228+
if 'Call' not in block['terminator']['kind']:
229+
continue
230+
call = block['terminator']['kind']['Call']
231+
232+
# 1. Direct call: the function being called
233+
func = call['func']
234+
if 'Constant' in func:
235+
called_tys.add(Ty(func['Constant']['const_']['ty']))
236+
237+
# 2. Indirect call: function pointers passed as arguments
238+
# These are ZeroSized constants whose ty is in the function table
239+
for arg in call.get('args', []):
240+
if 'Constant' in arg:
241+
const_ = arg['Constant'].get('const_', {})
242+
if const_.get('kind') == 'ZeroSized':
243+
ty = const_.get('ty')
244+
if isinstance(ty, int) and ty in function_tys:
245+
called_tys.add(Ty(ty))
246+
224247
for ty in self.function_symbols_reverse[sym]:
225248
result[Ty(ty)] = called_tys
226249
return result

kmir/src/tests/integration/data/prove-rs/488-support-function-pointer-calls.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ impl TryFrom<&[u8]> for EightBytes {
2020

2121
fn main() {
2222
let bytes: [u8;8] = [1,2,3,4,5,6,7,8];
23-
let _unused = EightBytes::from(bytes);
2423
let slice: &[u8] = &bytes;
2524
let thing: Result<EightBytes, _> = EightBytes::try_from(slice);
2625
assert!(thing.is_ok());

0 commit comments

Comments
 (0)