From c3d2b2480e43afa24f16698b1fc4af46ef0401ed Mon Sep 17 00:00:00 2001
From: Gustavo Delerue
Date: Fri, 13 Mar 2026 17:09:34 +0000
Subject: [PATCH] Refactor cfold propagation and add eager mode

Refactor `cfold` in `ecPhlCodeTx` to track propagated substitutions and
preserved dependencies explicitly, and use that state to decide when
folding can continue across assignments and structured instructions.

Add an eager `cfold*` variant that keeps folding by promoting preserved
variables into the propagated substitution when possible, and extend
`tests/cfold.ec` with coverage for both the default and eager behaviors.
---
 src/ecCoreModules.ml    |  36 +++++
 src/ecCoreModules.mli   |   5 +
 src/ecParser.mly        |   4 +-
 src/ecParsetree.ml      |   8 +-
 src/phl/ecPhlCodeTx.ml  | 321 ++++++++++++++++++++++++++--------------
 src/phl/ecPhlCodeTx.mli |   4 +-
 src/phl/ecPhlLoopTx.ml  |   2 +-
 tests/cfold.ec          |  62 ++++++++
 8 files changed, 321 insertions(+), 121 deletions(-)

diff --git a/src/ecCoreModules.ml b/src/ecCoreModules.ml
index 5d0773018d..891e897e28 100644
--- a/src/ecCoreModules.ml
+++ b/src/ecCoreModules.ml
@@ -55,6 +55,16 @@ let lv_of_expr e =
     LvTuple (List.map (fun e -> EcTypes.destr_var e, e_ty e) pvs)
   | _ -> failwith "failed to construct lv from expr"
 
+let explode_assgn (lv : lvalue) (e : expr) : ((prog_var * ty) * expr) list =
+  match lv, e with
+  | LvVar lv, e ->
+    [(lv, e)]
+  | LvTuple lvs, { e_node = Etuple es } ->
+    List.combine lvs es
+  | LvTuple lvs, e ->
+    List.mapi (fun i (pv, ty) ->
+      ((pv, ty), e_proj_simpl e i ty)) lvs
+
 (* -------------------------------------------------------------------- *)
 type instr = EcAst.instr
 
@@ -161,6 +171,14 @@ let is_while = _is_of_get get_while
 let is_match = _is_of_get get_match
 let is_raise = _is_of_get get_raise
 
+(* -------------------------------------------------------------------- *)
+let i_asgn_of_pve (pve : ((prog_var * ty) * expr) list) : instr option =
+  let lvs, es = List.split pve in
+
+  lvs
+  |> lv_of_list
+  |> omap (fun lvs -> i_asgn (lvs, e_tuple es))
+
 (* -------------------------------------------------------------------- *)
 let i_iter (f : instr -> unit) =
   let rec i_iter (i : instr) =
@@ -181,6 +199,24 @@ let i_iter (f : instr -> unit) =
 
   in fun (i : instr) -> i_iter i
 
+(* -------------------------------------------------------------------- *)
+let i_map_expr (tx : expr -> expr) =
+  let rec doit (i : instr) =
+    match i.i_node with
+    | Sasgn (lv, e) -> i_asgn (lv, (tx e))
+    | Sif (c, t, f) -> i_if (tx c, doit_s t, doit_s f)
+    | Smatch (e, cs) -> i_match (tx e, List.map (snd_map doit_s) cs)
+    | Swhile (c, bd) -> i_while (tx c, doit_s bd)
+    | Srnd (lv, e) -> i_rnd (lv, tx e)
+    | Sraise e -> i_raise (tx e)
+    | Sabstract (_ : memory) -> i
+    | Scall (lv, f, args) -> i_call (lv, f, List.map tx args)
+
+  and doit_s (s : stmt) =
+    stmt (List.map doit s.s_node) in
+
+  fun i -> doit i
+
 (* -------------------------------------------------------------------- *)
 module Uninit = struct (* FIXME: generalize this for use in ecPV *)
   let e_pv e =
diff --git a/src/ecCoreModules.mli b/src/ecCoreModules.mli
index 080d7f6474..69b7a3753a 100644
--- a/src/ecCoreModules.mli
+++ b/src/ecCoreModules.mli
@@ -15,6 +15,7 @@ val lv_to_list : lvalue -> prog_var list
 val lv_to_ty_list : lvalue -> (prog_var * ty) list
 val name_of_lv : lvalue -> string
 val lv_of_expr : expr -> lvalue
+val explode_assgn : lvalue -> expr -> ((prog_var * ty) * expr) list
 
 (* --------------------------------------------------------------------- *)
 type instr = EcAst.instr
@@ -76,8 +77,12 @@ val is_while : instr -> bool
 val is_match : instr -> bool
 val is_raise : instr -> bool
 
+(* -------------------------------------------------------------------- *)
+val i_asgn_of_pve : ((prog_var * ty) * expr) list -> instr option
+
 (* -------------------------------------------------------------------- *)
 val i_iter : (instr -> unit) -> instr -> unit
+val i_map_expr : (expr -> expr) -> instr -> instr
 
 (* -------------------------------------------------------------------- *)
 val get_uninit_read : stmt -> Sx.t
diff --git a/src/ecParser.mly b/src/ecParser.mly
index 52e9f4b499..d97f4ccb93 100644
--- a/src/ecParser.mly
+++ b/src/ecParser.mly
@@ -2990,8 +2990,8 @@ interleave_info:
 | INTERLEAVE info=loc(interleave_info)
     { Pinterleave info }
 
-| CFOLD s=side? c=codepos n=word?
-    { Pcfold (s, c, n) }
+| CFOLD eager=boption(STAR) side=side? start=codepos length=word?
+    { Pcfold { side; start; length; eager; } }
 
 | RND s=side? info=rnd_info c=prefix(COLON, semrndpos)?
     { Prnd (s, c, info) }
diff --git a/src/ecParsetree.ml b/src/ecParsetree.ml
index 8e12874e82..877f370741 100644
--- a/src/ecParsetree.ml
+++ b/src/ecParsetree.ml
@@ -757,7 +757,7 @@ type phltactic =
   | Pcond of pcond_info
   | Pmatch of matchmode
   | Pswap of ((oside * pswap_kind) located list)
-  | Pcfold of (oside * pcodepos * int option)
+  | Pcfold of pcfold
   | Pinline of inline_info
   | Poutline of outline_info
   | Pinterleave of interleave_info located
@@ -812,6 +812,12 @@ and rwprgm = [
   | `IdAssign of pcodepos * pqsymbol
 ]
 
+and pcfold =
+  { side : oside
+  ; start : pcodepos
+  ; length : int option
+  ; eager : bool }
+
 (* -------------------------------------------------------------------- *)
 type include_exclude = [ `Include | `Exclude ]
 type pdbmap1 = {
diff --git a/src/phl/ecPhlCodeTx.ml b/src/phl/ecPhlCodeTx.ml
index 2968953af1..a49af0d92f 100644
--- a/src/phl/ecPhlCodeTx.ml
+++ b/src/phl/ecPhlCodeTx.ml
@@ -185,93 +185,172 @@ let t_set_match_r (side : oside) (cpos : Position.codepos) (id : symbol) pattern
   (t_zip (set_match_stmt id pattern)) tc
 
 (* -------------------------------------------------------------------- *)
-let cfold_stmt ?(simplify = true) (pf, hyps) (me : memenv) (olen : int option) (zpr : Zpr.zipper) =
-  let env = LDecl.toenv hyps in
-
-  let simplify : expr -> expr =
-    if simplify then (fun e ->
-      let e = ss_inv_of_expr (fst me) e in
-      let e = map_ss_inv1 (EcReduction.simplify EcReduction.nodelta hyps) e in
-      let e = expr_of_ss_inv e in
-      e
-    ) else identity in
-
-  let for_instruction ((subst as subst0) : (expr, unit) Mpv.t) (i : instr) =
-    let wr = EcPV.i_write env i in
-    let i = Mpv.isubst env subst i in
-
-    let (subst, asgn) =
-      List.fold_left_map (fun subst (pv, e) ->
-        let exception Remove in
-
-        try
-          if PV.mem_pv env pv wr then raise Remove;
-          let rd = EcPV.e_read env e in
-          if PV.mem_pv env pv rd then raise Remove;
-          subst, None
-
-        with Remove ->
-          Mpv.remove env pv subst, Some ((pv, e.e_ty), e)
-      ) subst (EcPV.Mnpv.bindings (Mpv.pvs subst)) in
-
-    let asgn = List.filter_map identity asgn in
-
-    let mk_asgn (lve : ((prog_var * ty) * expr) list) =
-      let lvs, es = List.split lve in
-      lv_of_list lvs
-      |> Option.map (fun lv -> i_asgn (lv, e_tuple es))
-      |> Option.to_list in
-
-    let exception Interrupt in
-
-    try
-      let subst, aout =
-        let exception Default in
-
-        try
-          match i.i_node with
-          | Sasgn (lv, e) ->
-            (* We already removed the variables of `lv` & the rhs from the substitution *)
-            (* We are only interested in the variables of `lv` that are in `wr` *)
-            let es =
-              match simplify e, lv with
-              | { e_node = Etuple es }, LvTuple _ -> es
-              | _, LvTuple _ -> raise Default
-              | e, _ -> [e] in
-
-            let lv = lv_to_ty_list lv in
-
-            let tosubst, asgn2 = List.partition (fun ((pv, _), _) ->
-              Mpv.mem env pv subst0
-            ) (List.combine lv es) in
-
-            let subst =
-              List.fold_left
-                (fun subst ((pv, _), e) -> Mpv.add env pv e subst)
-                subst tosubst in
-
-            let asgn =
-              List.filter
-                (fun ((pv, _), _) -> not (Mpv.mem env pv subst))
-                asgn in
-
-            (subst, mk_asgn asgn @ mk_asgn asgn2)
-
-          | Srnd _ ->
-            (subst, mk_asgn asgn @ [i])
-
-          | _ -> raise Default
-
-        with Default ->
-          if List.exists
-               (fun (pv, _) -> Mpv.mem env pv subst0)
-               (fst (PV.elements wr))
-          then raise Interrupt;
-          (subst, mk_asgn asgn @ [i])
-
-      in `Continue (subst, aout)
+(*
+  Works on a block starting at an assignment to local variables.
+
+  It initializes:
+  - propagate: a substitution mapping the assigned variables to their values
+  - preserve : for each propagated variable, the variables that must keep their
+    current value for that propagated expression to remain valid
+
+  It then scans subsequent instructions from left to right.
+
+  For assignments:
+  - if the assigned variable is preserved, stop in non-eager mode; in eager
+    mode, substitute in the right-hand side and promote that variable to the
+    propagated substitution
+  - if the assigned variable is already propagated, update its propagated value
+    and recompute its preservation set
+  - otherwise, substitute propagated values in the right-hand side and keep the
+    assignment
+
+  For calls, loops, conditionals, matches, and random samplings:
+  - continue only if none of the currently propagated or preserved variables is
+    written by the instruction; in that case, substitute propagated values in
+    the instruction
+  - otherwise, stop
+
+  For abstract instructions without calls:
+  - continue only if they neither read nor write propagated or preserved
+    variables
+  - otherwise, stop
+
+  When the scan stops, the remaining propagated substitution is materialized as
+  assignments appended after the transformed prefix.
+*)
+
+let cfold_stmt
+  ?(simplify : bool = true)
+  ?(eager : bool = true)
+  ((pf, hyps) : proofenv * LDecl.hyps)
+  (me : memenv)
+  (olen : int option)
+  (zpr : Zpr.zipper)
+=
+  let env = LDecl.toenv hyps in
+
+  let e_simplify (e : expr) =
+    let e = form_of_expr ~m:(fst me) e in
+    let e = EcReduction.simplify EcReduction.nodelta hyps e in
+    expr_of_ss_inv { m = fst me; inv = e } in
+
+  let i_simplify (i : instr) =
+    i_map_expr e_simplify i in
+
+  let e_simplify, i_simplify =
+    if   simplify
+    then (e_simplify, i_simplify)
+    else (identity, identity) in
+
+  (*
+    Process one instruction under the current propagated substitution and
+    preservation map.
+
+    - `Continue ((subst, preserve), is)` means that propagation may proceed,
+      with updated state and replacement instructions `is`
+    - `Interrupt` means that propagation stops before this instruction
+
+    In eager mode, assigning to a preserved variable does not stop the scan:
+    the assigned expression is first substituted, then that variable is
+    promoted into the propagated substitution.
+  *)
+  let for_instruction (subst, preserve: (expr, unit) Mpv.t * (PV.t Mnpv.t)) (i : instr) =
+    let esubst subst e =
+      EcPV.Mpv.esubst env subst e |> e_simplify
+    in
+    let isubst subst i =
+      EcPV.Mpv.isubst env subst i |> i_simplify
+    in
+    let is_preserved preserve pv =
+      Mnpv.exists (fun _ preserve -> EcPV.PV.mem_pv env pv preserve) preserve
+    in
+    let is_propagated subst pv =
+      Mnpv.contains (Mpv.pvs subst) pv
+    in
+    let propagated_pvs subst =
+      (Mpv.pvs subst) |> Mnpv.bindings |> List.fst
+    in
+    (* Update preserve vars on assignment to given PV *)
+    (* Do not include any propagated vars, since these *)
+    (* are automatically preserved by construction *)
+    let update_preserved preserve subst pv e =
+      let rd = EcPV.e_read env e in
+      let rd = List.fold_left (fun rd pv ->
+        EcPV.PV.remove env pv rd
+      ) rd (propagated_pvs subst)
+      in
+      Mnpv.add pv rd preserve
+    in
+    let promote_preserved_to_propagated subst preserve pv (e:expr) =
+      (* [pv] is deferred from now on: drop it from every preservation  *)
+      (* set, but start preserving what its propagated value [e] reads. *)
+      let preserve = Mnpv.map (PV.remove env pv) preserve in
+      let subst = Mpv.add env pv e subst in
+      let preserve = update_preserved preserve subst pv e in
+      (subst, preserve)
+    in
 
-    with Interrupt -> `Interrupt
+    match i.i_node with
+    | Sasgn (lv, e) ->
+      let asgns = explode_assgn lv (esubst subst e) in
+      let exception Abort in
+      begin try
+        let (subst, preserve), asgns = List.fold_left_map (fun (subst, preserve) ((pv, t), e) ->
+          (* 1. When hitting an assignment to a preserved var *)
+          if is_preserved preserve pv then
+            if eager (* 1.1 Promote to propagated on eager *)
+            then
+              (* [e] was substituted up front: tuple assignments are parallel *)
+              promote_preserved_to_propagated subst preserve pv e, None
+            else raise Abort (* 1.2 Fail on non-eager *)
+          else
+          (* 2. When not preserved and not propagated, do nothing *)
+          if not (is_propagated subst pv) then
+            (subst, preserve), Some ((pv, t), e)
+          (* 3. When propagated, propagate *)
+          else
+            (* [e] already refers to the state before this instruction *)
+            let preserve = update_preserved preserve subst pv e in
+            let subst = Mpv.add env pv e subst in
+            (subst, preserve), None
+        ) (subst, preserve) asgns
+        in
+        let asgns = List.filter_map identity asgns in
+        `Continue ((subst, preserve), Option.to_list (i_asgn_of_pve asgns))
+      with Abort -> `Interrupt
+      end
+
+    | Srnd _
+    | Scall _
+    | Swhile _
+    | Sif _
+    | Smatch _ ->
+      let wr = EcPV.i_write env i in
+      let spvs = Mnpv.keys (Mpv.pvs subst) in
+      let ppvs = List.map (fun (_, pvs) -> List.fst (fst (PV.elements pvs))) (Mnpv.bindings preserve) in
+      if
+        let check = List.for_all (fun pv ->
+          not @@ EcPV.PV.mem_pv env pv wr) in
+        check spvs && check (List.flatten ppvs)
+      then
+        `Continue ((subst, preserve), [isubst subst i])
+      else
+        `Interrupt
+
+    | Sraise _ -> `Interrupt
+
+    | Sabstract id ->
+      let aus = EcEnv.AbsStmt.byid id env in
+      begin match aus with
+      | { aus_calls = []; aus_reads; aus_writes } ->
+        if List.for_all (fun (pv, _) ->
+          not ((is_propagated subst pv) || (is_preserved preserve pv))
+        ) (aus_reads @ aus_writes) then
+          `Continue ((subst, preserve), [i])
+        else
+          `Interrupt
+      | _ -> `Interrupt
+      end
   in
 
   let body, epilog =
@@ -283,68 +362,80 @@ let cfold_stmt ?(simplify = true) (pf, hyps) (me : memenv) (olen : int option) (
       tc_error pf "expecting at least %d instructions" olen;
     List.takedrop (olen+1) zpr.z_tail in
 
-  let lv, subst, body, rem =
+  let _lv, (subst, _preserve), body, rem =
     match body with
     | { i_node = Sasgn (lv, e) } :: is ->
-      let es =
-        match simplify e, lv with
-        | { e_node = Etuple es }, LvTuple _ -> es
-        | _, LvTuple _ ->
-            tc_error pf
-              "the left-value is a tuple but the right-hand expression \
-               is not a tuple expression";
-        | e, _ -> [e] in
-      let lv = lv_to_ty_list lv in
-
+      let asgns = explode_assgn lv e in
+      let lv = List.fst asgns in
+
       if not (List.for_all (is_loc -| fst) lv) then
         tc_error pf "left-values must be made of local variables only";
 
+      (* Variables in the domain of substs
+         are variables to be propagated *)
       let subst =
         List.fold_left
           (fun subst ((pv, _), e) -> Mpv.add env pv e subst)
-          Mpv.empty (List.combine lv es) in
+          Mpv.empty asgns in
+
+      let preserve =
+        List.fold_left
+          (fun preserve ((pv, _), e) ->
+            Mnpv.add
+              pv
+              EcPV.(PV.remove env pv (e_read env e))
+              preserve)
+          Mnpv.empty
+          asgns
+      in
 
-      let subst, is, rem =
-        List.fold_left_map_while for_instruction subst is in
+      let (subst, preserve), is, rem =
+        List.fold_left_map_while for_instruction (subst, preserve) is in
 
-      lv, subst, List.flatten is, rem
+      lv, (subst, preserve), List.flatten is, rem
 
     | _ ->
         tc_error pf "cannot find a left-value assignment at given position"
   in
 
-  let lv, es =
-    List.filter_map (fun ((pv, _) as pvty) ->
-      match Mpv.find env pv subst with
-      | e -> Some (pvty, e)
-      | exception Not_found -> None
-    ) lv |> List.split in
+  let asgns = Mnpv.bindings (Mpv.pvs subst) in
+
+  let lv, es = List.map (fun (pv, e) ->
+    (pv, e_ty e), e) asgns |> List.split
+  in
 
   let asgn =
     lv_of_list lv
     |> Option.map (fun lv -> i_asgn (lv, e_tuple es))
     |> Option.to_list in
 
-  let zpr = { zpr with Zpr.z_tail = body @ asgn @ rem @ epilog } in
+  let zpr =
+    { zpr with Zpr.z_tail = body @ asgn @ rem @ epilog } in
+
   (me, zpr, [])
 
 (* -------------------------------------------------------------------- *)
-let t_cfold_r side cpos olen g =
+let t_cfold
+  ~(eager : bool)
+  (side : side option)
+  (cpos : Position.codepos)
+  (olen : int option)
+  (tc : tcenv1)
+=
   let tr = fun side -> `Fold (side, cpos, olen) in
-  let cb = fun cenv _ me zpr -> cfold_stmt cenv me olen zpr in
-  t_code_transform side ~bdhoare:true cpos tr (t_zip cb) g
+  let cb = fun cenv _ me zpr -> cfold_stmt ~eager cenv me olen zpr in
+  t_code_transform side ~bdhoare:true cpos tr (t_zip cb) tc
 
 (* -------------------------------------------------------------------- *)
 let t_kill = FApi.t_low3 "code-tx-kill" t_kill_r
 let t_alias = FApi.t_low3 "code-tx-alias" t_alias_r
 let t_set = FApi.t_low4 "code-tx-set" t_set_r
 let t_set_match = FApi.t_low4 "code-tx-set-match" t_set_match_r
-let t_cfold = FApi.t_low3 "code-tx-cfold" t_cfold_r
 
 (* -------------------------------------------------------------------- *)
-let process_cfold (side, cpos, olen) tc =
-  let cpos = EcLowPhlGoal.tc1_process_codepos tc (side, cpos) in
-  t_cfold side cpos olen tc
+let process_cfold (info : pcfold) tc =
+  let cpos = EcLowPhlGoal.tc1_process_codepos tc (info.side, info.start) in
+  t_cfold ~eager:info.eager info.side cpos info.length tc
 
 let process_kill (side, cpos, len) tc =
   let cpos = EcLowPhlGoal.tc1_process_codepos tc (side, cpos) in
diff --git a/src/phl/ecPhlCodeTx.mli b/src/phl/ecPhlCodeTx.mli
index b1dab22744..468b4b0562 100644
--- a/src/phl/ecPhlCodeTx.mli
+++ b/src/phl/ecPhlCodeTx.mli
@@ -12,14 +12,14 @@ val t_kill : oside -> codepos -> int option -> backward
 val t_alias : oside -> codepos -> psymbol option -> backward
 val t_set : oside -> codepos -> bool * psymbol -> expr -> backward
 val t_set_match : oside -> codepos -> symbol -> unienv * mevmap * form -> backward
-val t_cfold : oside -> codepos -> int option -> backward
+val t_cfold : eager:bool -> oside -> codepos -> int option -> backward
 
 (* -------------------------------------------------------------------- *)
 val process_kill : oside * pcodepos * int option -> backward
 val process_alias : oside * pcodepos * psymbol option -> backward
 val process_set : oside * pcodepos * bool * psymbol * pexpr -> backward
 val process_set_match : oside * pcodepos * psymbol * pformula -> backward
-val process_cfold : oside * pcodepos * int option -> backward
+val process_cfold : pcfold -> backward
 val process_case : oside * pcodepos -> backward
 
 (* -------------------------------------------------------------------- *)
diff --git a/src/phl/ecPhlLoopTx.ml b/src/phl/ecPhlLoopTx.ml
index 60aa5e71a6..cbe4027de1 100644
--- a/src/phl/ecPhlLoopTx.ml
+++ b/src/phl/ecPhlLoopTx.ml
@@ -334,7 +334,7 @@ let process_unroll_for ~cfold side cpos tc =
       let cpos = EcMatching.Position.shift ~offset:(-1) cpos in
       let clen = blen * (List.length zs - 1) in
 
-      FApi.t_last (EcPhlCodeTx.t_cfold side cpos (Some clen)) tcenv
+      FApi.t_last (EcPhlCodeTx.t_cfold ~eager:false side cpos (Some clen)) tcenv
     end else tcenv
 
 (* -------------------------------------------------------------------- *)
diff --git a/tests/cfold.ec b/tests/cfold.ec
index 3d6d435623..0bd44f036e 100644
--- a/tests/cfold.ec
+++ b/tests/cfold.ec
@@ -1,6 +1,68 @@
 (* -------------------------------------------------------------------- *)
 require import AllCore Distr.
 
+(* -------------------------------------------------------------------- *)
+theory CfoldSelf.
+  module M = {
+    proc f(a : int, b : int) : int = {
+      var c : int;
+      var d : int;
+
+      c <- c;
+      c <- c + 1;
+      c <- c + d;
+      d <- b + a;
+      c <- d;
+      if (a + b = c) {
+        c <- 0;
+        a <- c;
+      } else {
+        c <- 1;
+        b <- c;
+      }
+      return c;
+    }
+  }.
+
+  lemma L : hoare[M.f : true ==> res = 0].
+  proof.
+    proc.
+    cfold 1.
+    by auto => /> ?; apply addzC.
+  qed.
+end CfoldSelf.
+
+(* -------------------------------------------------------------------- *)
+theory CfoldStarSelf.
+  module M = {
+    proc f(a : int, b : int) : int = {
+      var c : int;
+      var d : int;
+
+      c <- c;
+      c <- c + 1;
+      c <- c + d;
+      d <- b + a;
+      c <- d;
+      if (a + b = c) {
+        c <- 0;
+        a <- c;
+      } else {
+        c <- 1;
+        b <- c;
+      }
+      return c;
+    }
+  }.
+
+  lemma L : hoare[M.f : true ==> res = 0].
+  proof.
+    proc.
+    cfold* 1.
+    by auto => /> ?; apply addzC.
+  qed.
+end CfoldStarSelf.
+
 (* -------------------------------------------------------------------- *)
 theory CfoldStopIf.
   module M = {