diff --git a/.cargo/config.toml b/.cargo/config.toml new file mode 100644 index 00000000..b438243a --- /dev/null +++ b/.cargo/config.toml @@ -0,0 +1,5 @@ +[alias] +testall = ["test", "--all", "--release"] + +[target.'cfg(all())'] +rustflags = ["-C", "target-cpu=native"] \ No newline at end of file diff --git a/.gitignore b/.gitignore index beebda35..5996f514 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ /target .vscode /docs/benchmark_graphs/.venv +minimal_zkVM.synctex.gz \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index f00002a1..5a4ccc2a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -16,7 +16,6 @@ name = "air" version = "0.1.0" dependencies = [ "multilinear-toolkit", - "p3-air", "p3-koala-bear", "p3-util", "rand", @@ -24,6 +23,14 @@ dependencies = [ "utils", ] +[[package]] +name = "air" +version = "0.3.0" +source = "git+https://github.com/leanEthereum/multilinear-toolkit.git#eeb6404188e4b4c9207b9e8e8a2120156602d92e" +dependencies = [ + "p3-field", +] + [[package]] name = "ansi_term" version = "0.12.1" @@ -69,7 +76,7 @@ version = "1.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc" dependencies = [ - "windows-sys 0.61.2", + "windows-sys", ] [[package]] @@ -80,7 +87,7 @@ checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d" dependencies = [ "anstyle", "once_cell_polyfill", - "windows-sys 0.61.2", + "windows-sys", ] [[package]] @@ -92,15 +99,13 @@ checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" [[package]] name = "backend" version = "0.3.0" -source = "git+https://github.com/leanEthereum/multilinear-toolkit.git#62766141561550c3540f9f644085fec53d721f16" +source = "git+https://github.com/leanEthereum/multilinear-toolkit.git#eeb6404188e4b4c9207b9e8e8a2120156602d92e" dependencies = [ - "fiat-shamir", "itertools", "p3-field", "p3-util", "rand", "rayon", - "tracing", ] [[package]] @@ -129,9 +134,9 @@ checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" [[package]] name = "clap" -version = "4.5.53" +version = "4.5.54" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c9e340e012a1bf4935f5282ed1436d1489548e8f72308207ea5df0e23d2d03f8" +checksum = "c6e6ff9dcd79cff5cd969a17a545d79e84ab086e444102a591e288a8aa3ce394" dependencies = [ "clap_builder", "clap_derive", @@ -139,9 +144,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.53" +version = "4.5.54" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d76b5d13eaa18c901fd2f7fca939fefe3a0727a953561fefdf3b2922b8569d00" +checksum = "fa42cf4d2b7a41bc8f663a7cab4031ebafa1bf3875705bfaf8466dc60ab52c00" dependencies = [ "anstream", "anstyle", @@ -163,9 +168,9 @@ dependencies = [ [[package]] name = "clap_lex" -version = "0.7.6" +version = "0.7.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1d728cc89cf3aee9ff92b05e62b19ee65a02b5702cff7d5a377e32c6ae29d8d" +checksum = "c3e64b0cc0439b12df2fa678eae89a1c56a529fd067a9115f7827f1fffd22b32" [[package]] name = "colorchoice" @@ -175,32 +180,23 @@ checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" [[package]] name = "colored" -version = "3.0.0" +version = "3.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fde0e0ec90c9dfb3b4b1a0891a7dcd0e2bffde2f7efed5fe7c9bb00e5bfb915e" +checksum = "faf9468729b8cbcea668e36183cb69d317348c2e08e994829fb56ebfdfbaac34" dependencies = [ - "windows-sys 0.59.0", + "windows-sys", ] [[package]] name = "constraints-folder" version = "0.3.0" -source = "git+https://github.com/leanEthereum/multilinear-toolkit.git#62766141561550c3540f9f644085fec53d721f16" +source = "git+https://github.com/leanEthereum/multilinear-toolkit.git#eeb6404188e4b4c9207b9e8e8a2120156602d92e" dependencies = [ - "fiat-shamir", - "p3-air", + "air 0.3.0", + "backend", "p3-field", ] -[[package]] -name = "convert_case" -version = "0.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "633458d4ef8c78b72454de2d54fd6ab2e60f9e02be22f3c6104cdc8a4e0fceb9" -dependencies = [ - "unicode-segmentation", -] - [[package]] name = "cpufeatures" version = "0.2.17" @@ -245,29 +241,6 @@ dependencies = [ "typenum", ] -[[package]] -name = "derive_more" -version = "2.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d751e9e49156b02b44f9c1815bcb94b984cdcc4396ecc32521c739452808b134" -dependencies = [ - "derive_more-impl", -] - -[[package]] -name = "derive_more-impl" -version = "2.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "799a97264921d8623a957f6c3b9011f3b5492f557bbb7a5a19b7fa6d06ba8dcb" -dependencies = [ - "convert_case", - "proc-macro2", - "quote", - "rustc_version", - "syn", - "unicode-xid", -] - [[package]] name = "digest" version = "0.10.7" @@ -293,11 +266,11 @@ checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" [[package]] name = "fiat-shamir" version = "0.1.0" -source = "git+https://github.com/leanEthereum/fiat-shamir.git#bcf23c766f2e930acf11e68777449483a55af077" +source = "git+https://github.com/leanEthereum/fiat-shamir.git#493be07bdd8669a816d28c3befc08bc3e68e590a" dependencies = [ - "p3-challenger", "p3-field", - "p3-koala-bear", + "p3-symmetric", + "rayon", "serde", ] @@ -337,9 +310,9 @@ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" [[package]] name = "indexmap" -version = "2.12.1" +version = "2.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ad4bb2b565bca0645f4d68c5c9af97fba094e9791da685bf83cb5f3ce74acf2" +checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017" dependencies = [ "equivalent", "hashbrown", @@ -362,9 +335,9 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.16" +version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ee5b5339afb4c41626dde77b7a611bd4f2c202b897852b4bcf5d03eddc61010" +checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2" [[package]] name = "keccak" @@ -385,11 +358,15 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" name = "lean-multisig" version = "0.1.0" dependencies = [ + "air 0.1.0", "clap", + "lean_vm", "multilinear-toolkit", "p3-koala-bear", - "poseidon_circuit", + "rand", "rec_aggregation", + "sub_protocols", + "utils", "whir-p3", "xmss", ] @@ -398,12 +375,9 @@ dependencies = [ name = "lean_compiler" version = "0.1.0" dependencies = [ - "air", + "air 0.1.0", "lean_vm", - "lookup", "multilinear-toolkit", - "p3-air", - "p3-challenger", "p3-koala-bear", "p3-poseidon2", "p3-symmetric", @@ -422,21 +396,17 @@ dependencies = [ name = "lean_prover" version = "0.1.0" dependencies = [ - "air", + "air 0.1.0", "itertools", "lean_compiler", "lean_vm", - "lookup", "multilinear-toolkit", - "p3-air", - "p3-challenger", "p3-koala-bear", "p3-poseidon2", "p3-symmetric", "p3-util", "pest", "pest_derive", - "poseidon_circuit", "rand", "sub_protocols", "tracing", @@ -450,15 +420,11 @@ dependencies = [ name = "lean_vm" version = "0.1.0" dependencies = [ - "air", + "air 0.1.0", "colored", - "derive_more", "itertools", - "lookup", "multilinear-toolkit", "num_enum", - "p3-air", - "p3-challenger", "p3-koala-bear", "p3-poseidon2", "p3-symmetric", @@ -467,7 +433,6 @@ dependencies = [ "pest_derive", "rand", "strum", - "sub_protocols", "thiserror", "tracing", "utils", @@ -477,9 +442,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.178" +version = "0.2.180" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37c93d8daa9d8a012fd8ab92f088405fb202ea0b6ab73ee2482ae66af4f42091" +checksum = "bcc35a38544a891a5f7c865aca548a982ccb3b8650a5b06d0fd33a10283c56fc" [[package]] name = "log" @@ -487,20 +452,6 @@ version = "0.4.29" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" -[[package]] -name = "lookup" -version = "0.1.0" -dependencies = [ - "multilinear-toolkit", - "p3-challenger", - "p3-koala-bear", - "p3-util", - "rand", - "tracing", - "utils", - "whir-p3", -] - [[package]] name = "matchers" version = "0.2.0" @@ -519,8 +470,9 @@ checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273" [[package]] name = "multilinear-toolkit" version = "0.3.0" -source = "git+https://github.com/leanEthereum/multilinear-toolkit.git#62766141561550c3540f9f644085fec53d721f16" +source = "git+https://github.com/leanEthereum/multilinear-toolkit.git#eeb6404188e4b4c9207b9e8e8a2120156602d92e" dependencies = [ + "air 0.3.0", "backend", "constraints-folder", "fiat-shamir", @@ -528,6 +480,7 @@ dependencies = [ "p3-util", "rayon", "sumcheck", + "tracing", ] [[package]] @@ -536,7 +489,7 @@ version = "0.50.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" dependencies = [ - "windows-sys 0.61.2", + "windows-sys", ] [[package]] @@ -601,19 +554,10 @@ version = "1.70.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" -[[package]] -name = "p3-air" -version = "0.3.0" -source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#1db9df28abd6db586eaa891af2416d94d1b026ae" -dependencies = [ - "p3-field", - "p3-matrix", -] - [[package]] name = "p3-baby-bear" version = "0.3.0" -source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#1db9df28abd6db586eaa891af2416d94d1b026ae" +source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#87bad263fb05630a27ba2d72ac1509d4b9fe91b1" dependencies = [ "p3-field", "p3-mds", @@ -626,7 +570,7 @@ dependencies = [ [[package]] name = "p3-challenger" version = "0.3.0" -source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#1db9df28abd6db586eaa891af2416d94d1b026ae" +source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#87bad263fb05630a27ba2d72ac1509d4b9fe91b1" dependencies = [ "p3-field", "p3-maybe-rayon", @@ -638,7 +582,7 @@ dependencies = [ [[package]] name = "p3-commit" version = "0.3.0" -source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#1db9df28abd6db586eaa891af2416d94d1b026ae" +source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#87bad263fb05630a27ba2d72ac1509d4b9fe91b1" dependencies = [ "itertools", "p3-challenger", @@ -652,7 +596,7 @@ dependencies = [ [[package]] name = "p3-dft" version = "0.3.0" -source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#1db9df28abd6db586eaa891af2416d94d1b026ae" +source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#87bad263fb05630a27ba2d72ac1509d4b9fe91b1" dependencies = [ "itertools", "p3-field", @@ -665,7 +609,7 @@ dependencies = [ [[package]] name = "p3-field" version = "0.3.0" -source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#1db9df28abd6db586eaa891af2416d94d1b026ae" +source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#87bad263fb05630a27ba2d72ac1509d4b9fe91b1" dependencies = [ "itertools", "num-bigint", @@ -680,7 +624,7 @@ dependencies = [ [[package]] name = "p3-interpolation" version = "0.3.0" -source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#1db9df28abd6db586eaa891af2416d94d1b026ae" +source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#87bad263fb05630a27ba2d72ac1509d4b9fe91b1" dependencies = [ "p3-field", "p3-matrix", @@ -691,7 +635,7 @@ dependencies = [ [[package]] name = "p3-koala-bear" version = "0.3.0" -source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#1db9df28abd6db586eaa891af2416d94d1b026ae" +source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#87bad263fb05630a27ba2d72ac1509d4b9fe91b1" dependencies = [ "itertools", "num-bigint", @@ -707,7 +651,7 @@ dependencies = [ [[package]] name = "p3-matrix" version = "0.3.0" -source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#1db9df28abd6db586eaa891af2416d94d1b026ae" +source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#87bad263fb05630a27ba2d72ac1509d4b9fe91b1" dependencies = [ "itertools", "p3-field", @@ -722,7 +666,7 @@ dependencies = [ [[package]] name = "p3-maybe-rayon" version = "0.3.0" -source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#1db9df28abd6db586eaa891af2416d94d1b026ae" +source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#87bad263fb05630a27ba2d72ac1509d4b9fe91b1" dependencies = [ "rayon", ] @@ -730,7 +674,7 @@ dependencies = [ [[package]] name = "p3-mds" version = "0.3.0" -source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#1db9df28abd6db586eaa891af2416d94d1b026ae" +source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#87bad263fb05630a27ba2d72ac1509d4b9fe91b1" dependencies = [ "p3-dft", "p3-field", @@ -742,7 +686,7 @@ dependencies = [ [[package]] name = "p3-merkle-tree" version = "0.3.0" -source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#1db9df28abd6db586eaa891af2416d94d1b026ae" +source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#87bad263fb05630a27ba2d72ac1509d4b9fe91b1" dependencies = [ "itertools", "p3-commit", @@ -759,7 +703,7 @@ dependencies = [ [[package]] name = "p3-monty-31" version = "0.3.0" -source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#1db9df28abd6db586eaa891af2416d94d1b026ae" +source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#87bad263fb05630a27ba2d72ac1509d4b9fe91b1" dependencies = [ "itertools", "num-bigint", @@ -781,7 +725,7 @@ dependencies = [ [[package]] name = "p3-poseidon2" version = "0.3.0" -source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#1db9df28abd6db586eaa891af2416d94d1b026ae" +source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#87bad263fb05630a27ba2d72ac1509d4b9fe91b1" dependencies = [ "p3-field", "p3-mds", @@ -793,7 +737,7 @@ dependencies = [ [[package]] name = "p3-symmetric" version = "0.3.0" -source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#1db9df28abd6db586eaa891af2416d94d1b026ae" +source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#87bad263fb05630a27ba2d72ac1509d4b9fe91b1" dependencies = [ "itertools", "p3-field", @@ -803,7 +747,7 @@ dependencies = [ [[package]] name = "p3-util" version = "0.3.0" -source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#1db9df28abd6db586eaa891af2416d94d1b026ae" +source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#87bad263fb05630a27ba2d72ac1509d4b9fe91b1" dependencies = [ "rayon", "serde", @@ -817,9 +761,9 @@ checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" [[package]] name = "pest" -version = "2.8.4" +version = "2.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cbcfd20a6d4eeba40179f05735784ad32bdaef05ce8e8af05f180d45bb3e7e22" +checksum = "2c9eb05c21a464ea704b53158d358a31e6425db2f63a1a7312268b05fe2b75f7" dependencies = [ "memchr", "ucd-trie", @@ -827,9 +771,9 @@ dependencies = [ [[package]] name = "pest_derive" -version = "2.8.4" +version = "2.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "51f72981ade67b1ca6adc26ec221be9f463f2b5839c7508998daa17c23d94d7f" +checksum = "68f9dbced329c441fa79d80472764b1a2c7e57123553b8519b36663a2fb234ed" dependencies = [ "pest", "pest_generator", @@ -837,9 +781,9 @@ dependencies = [ [[package]] name = "pest_generator" -version = "2.8.4" +version = "2.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dee9efd8cdb50d719a80088b76f81aec7c41ed6d522ee750178f83883d271625" +checksum = "3bb96d5051a78f44f43c8f712d8e810adb0ebf923fc9ed2655a7f66f63ba8ee5" dependencies = [ "pest", "pest_meta", @@ -850,9 +794,9 @@ dependencies = [ [[package]] name = "pest_meta" -version = "2.8.4" +version = "2.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf1d70880e76bdc13ba52eafa6239ce793d85c8e43896507e43dd8984ff05b82" +checksum = "602113b5b5e8621770cfd490cfd90b9f84ab29bd2b0e49ad83eb6d186cef2365" dependencies = [ "pest", "sha2", @@ -864,21 +808,6 @@ version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" -[[package]] -name = "poseidon_circuit" -version = "0.1.0" -dependencies = [ - "multilinear-toolkit", - "p3-koala-bear", - "p3-monty-31", - "p3-poseidon2", - "rand", - "sub_protocols", - "tracing", - "utils", - "whir-p3", -] - [[package]] name = "ppv-lite86" version = "0.2.21" @@ -899,18 +828,18 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.103" +version = "1.0.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ee95bc4ef87b8d5ba32e8b7714ccc834865276eab0aed5c9958d00ec45f49e8" +checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" dependencies = [ "unicode-ident", ] [[package]] name = "quote" -version = "1.0.42" +version = "1.0.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a338cc41d27e6cc6dce6cefc13a0729dfbb81c262b1f519331575dd80ef3067f" +checksum = "21b2ebcf727b7760c461f091f9f0f539b77b8e87f2fd88131e7f1b433b3cece4" dependencies = [ "proc-macro2", ] @@ -943,9 +872,9 @@ dependencies = [ [[package]] name = "rand_core" -version = "0.9.3" +version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38" +checksum = "76afc826de14238e6e8c374ddcc1fa19e374fd8dd986b0d2af0d02377261d83c" dependencies = [ "getrandom", ] @@ -974,15 +903,12 @@ dependencies = [ name = "rec_aggregation" version = "0.1.0" dependencies = [ - "air", + "air 0.1.0", "bincode", "lean_compiler", "lean_prover", "lean_vm", - "lookup", "multilinear-toolkit", - "p3-air", - "p3-challenger", "p3-koala-bear", "p3-poseidon2", "p3-symmetric", @@ -1014,27 +940,12 @@ version = "0.8.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7a2d987857b319362043e95f5353c0535c1f58eec5336fdfcf626430af7def58" -[[package]] -name = "rustc_version" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" -dependencies = [ - "semver", -] - [[package]] name = "rustversion" version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" -[[package]] -name = "semver" -version = "1.0.27" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2" - [[package]] name = "serde" version = "1.0.228" @@ -1067,9 +978,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.147" +version = "1.0.149" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6af14725505314343e673e9ecb7cd7e8a36aa9791eb936235a3567cc31447ae4" +checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" dependencies = [ "itoa", "memchr", @@ -1151,8 +1062,7 @@ dependencies = [ name = "sub_protocols" version = "0.1.0" dependencies = [ - "derive_more", - "lookup", + "lean_vm", "multilinear-toolkit", "p3-koala-bear", "p3-util", @@ -1165,12 +1075,12 @@ dependencies = [ [[package]] name = "sumcheck" version = "0.3.0" -source = "git+https://github.com/leanEthereum/multilinear-toolkit.git#62766141561550c3540f9f644085fec53d721f16" +source = "git+https://github.com/leanEthereum/multilinear-toolkit.git#eeb6404188e4b4c9207b9e8e8a2120156602d92e" dependencies = [ + "air 0.3.0", "backend", "constraints-folder", "fiat-shamir", - "p3-air", "p3-field", "p3-util", "rayon", @@ -1178,9 +1088,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.111" +version = "2.0.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "390cc9a294ab71bdb1aa2e99d13be9c753cd2d7bd6560c77118597410c4d2e87" +checksum = "d4d107df263a3013ef9b1879b0df87d706ff80f65a86ea879bd9c31f9b307c2a" dependencies = [ "proc-macro2", "quote", @@ -1189,18 +1099,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "2.0.17" +version = "2.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f63587ca0f12b72a0600bcba1d40081f830876000bb46dd2337a3051618f4fc8" +checksum = "4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "2.0.17" +version = "2.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3ff15c8ecd7de3849db632e14d18d2571fa09dfc5ed93479bc4485c7a517c913" +checksum = "ebc4ee7f67670e9b64d05fa4253e753e016c6c95ff35b89b7941d6b856dec1d5" dependencies = [ "proc-macro2", "quote", @@ -1280,9 +1190,9 @@ dependencies = [ [[package]] name = "tracing-forest" -version = "0.3.0" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92bdb3c949c9e81b71f78ba782f956b896019d82cc2f31025d21e04adab4d695" +checksum = "f09cb459317a3811f76644334473239d696cd8efc606963ae7d1c308cead3b74" dependencies = [ "ansi_term", "smallvec", @@ -1348,18 +1258,6 @@ version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5" -[[package]] -name = "unicode-segmentation" -version = "1.12.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" - -[[package]] -name = "unicode-xid" -version = "0.2.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" - [[package]] name = "utf8parse" version = "0.2.2" @@ -1371,8 +1269,6 @@ name = "utils" version = "0.1.0" dependencies = [ "multilinear-toolkit", - "p3-air", - "p3-challenger", "p3-koala-bear", "p3-poseidon2", "p3-symmetric", @@ -1397,9 +1293,9 @@ checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" [[package]] name = "wasip2" -version = "1.0.1+wasi-0.2.4" +version = "1.0.2+wasi-0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0562428422c63773dad2c345a1882263bbf4d65cf3f42e90921f787ef5ad58e7" +checksum = "9517f9239f02c069db75e65f174b3da828fe5f5b945c4dd26bd25d89c03ebcf5" dependencies = [ "wit-bindgen", ] @@ -1407,12 +1303,11 @@ dependencies = [ [[package]] name = "whir-p3" version = "0.1.0" -source = "git+https://github.com/TomWambsgans/whir-p3?branch=lean-multisig#04fb1c1f2e3bbd14e6e4aee32621656eb3f3949f" +source = "git+https://github.com/TomWambsgans/whir-p3?branch=lean-multisig#bc7bf99c224d63582945f065f555b9824d843d3c" dependencies = [ "itertools", "multilinear-toolkit", "p3-baby-bear", - "p3-challenger", "p3-commit", "p3-dft", "p3-field", @@ -1425,7 +1320,6 @@ dependencies = [ "p3-util", "rand", "rayon", - "thiserror", "tracing", "tracing-forest", "tracing-subscriber", @@ -1459,15 +1353,6 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" -[[package]] -name = "windows-sys" -version = "0.59.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" -dependencies = [ - "windows-targets", -] - [[package]] name = "windows-sys" version = "0.61.2" @@ -1477,70 +1362,6 @@ dependencies = [ "windows-link", ] -[[package]] -name = "windows-targets" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" -dependencies = [ - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_gnullvm", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", -] - -[[package]] -name = "windows_aarch64_gnullvm" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" - -[[package]] -name = "windows_aarch64_msvc" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" - -[[package]] -name = "windows_i686_gnu" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" - -[[package]] -name = "windows_i686_gnullvm" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" - -[[package]] -name = "windows_i686_msvc" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" - -[[package]] -name = "windows_x86_64_gnu" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" - -[[package]] -name = "windows_x86_64_gnullvm" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" - -[[package]] -name = "windows_x86_64_msvc" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" - [[package]] name = "winnow" version = "0.7.14" @@ -1552,22 +1373,18 @@ dependencies = [ [[package]] name = "wit-bindgen" -version = "0.46.0" +version = "0.51.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f17a85883d4e6d00e8a97c586de764dabcc06133f7f1d55dce5cdc070ad7fe59" +checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5" [[package]] name = "witness_generation" version = "0.1.0" dependencies = [ - "air", - "derive_more", + "air 0.1.0", "lean_compiler", "lean_vm", - "lookup", "multilinear-toolkit", - "p3-air", - "p3-challenger", "p3-koala-bear", "p3-monty-31", "p3-poseidon2", @@ -1575,7 +1392,6 @@ dependencies = [ "p3-util", "pest", "pest_derive", - "poseidon_circuit", "rand", "sub_protocols", "tracing", @@ -1598,18 +1414,18 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.8.31" +version = "0.8.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd74ec98b9250adb3ca554bdde269adf631549f51d8a8f8f0a10b50f1cb298c3" +checksum = "668f5168d10b9ee831de31933dc111a459c97ec93225beb307aed970d1372dfd" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.8.31" +version = "0.8.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8a8d209fdf45cf5138cbb5a506f6b52522a25afccc534d1475dad8e31105c6a" +checksum = "2c7962b26b0a8685668b671ee4b54d007a67d4eaf05fda79ac0ecf41e32270f1" dependencies = [ "proc-macro2", "quote", @@ -1618,6 +1434,6 @@ dependencies = [ [[package]] name = "zmij" -version = "0.1.9" +version = "1.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d0095ecd462946aa3927d9297b63ef82fb9a5316d7a37d134eeb36e58228615a" +checksum = "dfcd145825aace48cff44a8844de64bf75feec3080e0aa5cdbde72961ae51a65" diff --git a/Cargo.toml b/Cargo.toml index 15369a4b..bc37a467 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,19 +48,16 @@ utils = { path = "crates/utils" } lean_vm = { path = "crates/lean_vm" } xmss = { path = "crates/xmss" } sub_protocols = { path = "crates/sub_protocols" } -lookup = { path = "crates/lookup" } lean_compiler = { path = "crates/lean_compiler" } lean_prover = { path = "crates/lean_prover" } rec_aggregation = { path = "crates/rec_aggregation" } witness_generation = { path = "crates/lean_prover/witness_generation" } -poseidon_circuit = { path = "crates/poseidon_circuit" } # External thiserror = "2.0" clap = { version = "4.5.52", features = ["derive"] } rand = "0.9.2" sha3 = "0.10.8" -derive_more = { version = "2.0.1", features = ["full"] } pest = "2.7" pest_derive = "2.7" itertools = "0.14.0" @@ -77,47 +74,24 @@ p3-koala-bear = { git = "https://github.com/TomWambsgans/Plonky3.git", branch = p3-baby-bear = { git = "https://github.com/TomWambsgans/Plonky3.git", branch = "lean-multisig" } p3-poseidon2 = { git = "https://github.com/TomWambsgans/Plonky3.git", branch = "lean-multisig" } p3-symmetric = { git = "https://github.com/TomWambsgans/Plonky3.git", branch = "lean-multisig" } -p3-air = { git = "https://github.com/TomWambsgans/Plonky3.git", branch = "lean-multisig" } p3-goldilocks = { git = "https://github.com/TomWambsgans/Plonky3.git", branch = "lean-multisig" } -p3-challenger = { git = "https://github.com/TomWambsgans/Plonky3.git", branch = "lean-multisig" } p3-util = { git = "https://github.com/TomWambsgans/Plonky3.git", branch = "lean-multisig" } p3-monty-31 = { git = "https://github.com/TomWambsgans/Plonky3.git", branch = "lean-multisig" } whir-p3 = { git = "https://github.com/TomWambsgans/whir-p3", branch = "lean-multisig" } multilinear-toolkit = { git = "https://github.com/leanEthereum/multilinear-toolkit.git" } - [dependencies] clap.workspace = true rec_aggregation.workspace = true xmss.workspace = true -poseidon_circuit.workspace = true +air.workspace = true +rand.workspace = true +sub_protocols.workspace = true +utils.workspace = true p3-koala-bear.workspace = true +lean_vm.workspace = true multilinear-toolkit.workspace = true whir-p3.workspace = true -# [patch."https://github.com/TomWambsgans/Plonky3.git"] -# p3-koala-bear = { path = "../Plonky3/koala-bear" } -# p3-field = { path = "../Plonky3/field" } -# p3-poseidon2 = { path = "../Plonky3/poseidon2" } -# p3-symmetric = { path = "../Plonky3/symmetric" } -# p3-air = { path = "../Plonky3/air" } -# p3-merkle-tree = { path = "../Plonky3/merkle-tree" } -# p3-commit = { path = "../Plonky3/commit" } -# p3-matrix = { path = "../Plonky3/matrix" } -# p3-dft = { path = "../Plonky3/dft" } -# p3-challenger = { path = "../Plonky3/challenger" } -# p3-monty-31 = { path = "../Plonky3/monty-31" } -# p3-maybe-rayon = { path = "../Plonky3/maybe-rayon" } -# p3-util = { path = "../Plonky3/util" } - -# [patch."https://github.com/TomWambsgans/whir-p3.git"] -# whir-p3 = { path = "../whir-p3" } - -# [patch."https://github.com/leanEthereum/multilinear-toolkit.git"] -# multilinear-toolkit = { path = "../multilinear-toolkit" } - -# [profile.release] -# opt-level = 1 - [profile.release] lto = "thin" diff --git a/README.md b/README.md index b317b721..2e171dcd 100644 --- a/README.md +++ b/README.md @@ -1,95 +1,53 @@ -

♦ leanMultisig ♦

+

leanMultisig

-XMSS + minimal [zkVM](minimal_zkVM.pdf) = lightweight PQ signatures, with unbounded aggregation +

+ +

-## Status +Minimal hash-based zkVM, targeting recursion and aggregation of hash-based signatures, for a Post-Quantum Ethereum. -- branch [main](https://github.com/leanEthereum/leanMultisig): optimized for **prover efficiency** -- branch [lean-vm-simple](https://github.com/leanEthereum/leanMultisig/tree/lean-vm-simple): optimized for **simplicity** - -Both versions will eventually merge into one. +Documentation: [PDF](minimal_zkVM.pdf) ## Proving System -- [WHIR](https://eprint.iacr.org/2024/1586.pdf) -- [SuperSpartan](https://eprint.iacr.org/2023/552.pdf), with AIR-specific optimizations developed by W. Borgeaud in [A simple multivariate AIR argument inspired by SuperSpartan](https://solvable.group/posts/super-air/#fnref:1) -- [Univariate Skip](https://eprint.iacr.org/2024/108.pdf) -- [Logup*](https://eprint.iacr.org/2025/946.pdf) -- ... +- multilinear with [WHIR](https://eprint.iacr.org/2024/1586.pdf) +- [SuperSpartan](https://eprint.iacr.org/2023/552.pdf), with [AIR-specific optimizations](https://solvable.group/posts/super-air/#fnref:1) +- [Logup](https://eprint.iacr.org/2023/1284.pdf) / [Logup*](https://eprint.iacr.org/2025/946.pdf) The VM design is inspired by the famous [Cairo paper](https://eprint.iacr.org/2021/1063.pdf). -## Benchmarks - -Benchmarks are performed on 2 laptops: -- i9-12900H, 32 gb of RAM -- mac m4 max - -target ≈ 128 bits of security, currently using conjecture: 4.12 of [WHIR](https://eprint.iacr.org/2024/1586.pdf), "up to capacity" (TODO: provable security) - -### Poseidon2 - -Poseidon2 over 16 KoalaBear field elements. - -```console -RUSTFLAGS='-C target-cpu=native' cargo run --release -- poseidon --log-n-perms 20 -``` - -![Alt text](docs/benchmark_graphs/graphs/raw_poseidons.svg) - -### Recursion - -The full recursion program is not finished yet. Instead, we prove validity of a WHIR opening, with 25 variables, and rate = 1/4. +### Security -- 1-to-1: Recursive proof of a single WHIR opening -- n-to-1: Recursive proof of many WHIR openings (≈ 8) (we report prover time per WHIR) +123 bits of security. Johnson bound + degree 5 extension of koala-bear -> **no proximity gaps conjecture**. (TODO 128 bits, which requires hash digests bigger than 8 koala-bears). -```console -RUSTFLAGS='-C target-cpu=native' cargo run --release -- recursion --count 8 -``` - -![Alt text](docs/benchmark_graphs/graphs/recursive_whir_opening.svg) - - -### XMSS aggregation - -```console -RUSTFLAGS='-C target-cpu=native' cargo run --release -- xmss --n-signatures 1775 -``` - -[Trivial encoding](docs/XMSS_trivial_encoding.pdf) (for now). - - -![Alt text](docs/benchmark_graphs/graphs/xmss_aggregated.svg) - -![Alt text](docs/benchmark_graphs/graphs/xmss_aggregated_overhead.svg) +## Benchmarks -### Fibonacci: +Machine: M4 Max 48GB (CPU only) -n = 2,000,000 +| Benchmark | Current | Target | +| -------------------------- | -------------------- | --------------- | +| Poseidon2 (16 koala-bears) | `560K Poseidon2 / s` | n/a | +| 2 -> 1 Recursion | `1.35 s` | `0.25 s ` | +| XMSS aggregation | `554 XMSS / s` | `1000 XMSS / s` | -``` -FIB_N=2000000 RUSTFLAGS='-C target-cpu=native' cargo test --release --package lean_prover --test test_zkvm -- --nocapture -- test_prove_fibonacci --exact --nocapture -``` +*Expect incoming perf improvements.* -Proving time: +To reproduce: +- `cargo run --release -- poseidon --log-n-perms 20` +- `cargo run --release -- recursion --n 2` +- `cargo run --release -- xmss --n-signatures 1350` -- i9-12900H: 2.0 s (1.0 MHz) -- mac m4 max: 1.2 s (1.7 MHz) +(Small detail remaining in recursion: final (multilinear) evaluation of the guest program bytecode, there are multiple ways of handling it... TBD soon) ### Proof size -With conjecture "up to capacity", and rate = 1/2, current proofs are about ≈ 400 - 500 KiB. On the [lean-vm-simple](https://github.com/leanEthereum/leanMultisig/tree/lean-vm-simple) branch, proofs are ≈ 300 KiB. This part has not (at all) been optimized (no Merkle pruning...): big gains are expected. - -Target: -- 256 KiB for fast proof (rate = 1/2) -- close to 128 KiB for slower proofs (rate = 1/4 or 1/8) +WHIR intial rate = 1/4. Proof size ≈ 380 KiB. TODO: Merkle pruning + WHIR batch opening -> 256 KiB. (To go below 256 KiB -> rate 1/8 or 1/16 in the final recursion). ## Credits -- [Plonky3](https://github.com/Plonky3/Plonky3) for its various performant crates (Finite fields, poseidon2 AIR etc) +- [Plonky3](https://github.com/Plonky3/Plonky3) for its various performant crates - [whir-p3](https://github.com/tcoratger/whir-p3): a Plonky3-compatible WHIR implementation - [Whirlaway](https://github.com/TomWambsgans/Whirlaway): Multilinear snark for AIR + minimal zkVM diff --git a/TODO.md b/TODO.md index 715e4e3a..83747478 100644 --- a/TODO.md +++ b/TODO.md @@ -2,40 +2,24 @@ ## Perf -- WHIR univariate skip? -- Opti recursion bytecode -- packing (SIMD) everywhere -- one can "move out" the variable of the eq(.) polynomials out of the sumcheck computation in WHIR (as done in the PIOP) -- Structured AIR: often no all the columns use both up/down -> only handle the used ones to speed up the PIOP zerocheck -- Use Univariate Skip to commit to tables with k.2^n rows (k small) -- opti logup* GKR when the indexes are not a power of 2 (which is the case in the execution table) -- incremental merkle paths in whir-p3 -- Avoid embedding overhead on the flag, len, and index columns in the AIR table for dot products -- Lev's trick to skip some low-level modular reduction -- Sumcheck, case z = 0, no need to fold, only keep first half of the values (done in PR 33 by Lambda) -- Custom AVX2 / AVX512 / Neon implem in Plonky3 for all of the finite field operations (done for degree 4 extension, but not degree 5) -- Many times, we evaluate different multilinear polynomials (different columns of the same table etc) at a common point. OPTI = compute the eq(.) once, and then dot_product with everything -- To commit to multiple AIR table using 1 single pcs, the most general form our "packed pcs" api should accept is: - a list of n (n not a power of 2) columns, each ending with m repeated values (in this manner we can reduce proof size when they are a lot of columns (poseidons ...)) -- in the runner of leanISA program, if we call 2 times the same function with the same arguments, we can reuse the same memory frame +- 128 bits security +- Merkle pruning - the interpreter of leanISA (+ witness generation) can be partially parallelized when there are some independent loops -- (1 - x).r1 + x.r2 = x.(r2 - r1) + r1 TODO this opti is not everywhere currently + TODO generalize this with the univariate skip -- opti compute_eval_eq when scalar = ONE -- Dmitry's range check, bonus: we can spare 2 memory cells if the value being range check is small (using the zeros present by conventio on the public memory) - Make everything "padding aware" (including WHIR, logup*, AIR, etc) - Opti WHIR: in sumcheck we know more than f(0) + f(1), we know f(0) and f(1) -- Opti WHIR https://github.com/tcoratger/whir-p3/issues/303 and https://github.com/tcoratger/whir-p3/issues/306 -- Avoid committing to the 3 index columns, and replace it by a sumcheck? Using this idea, we would only commit to PC and FP for the execution table. Idea by Georg (Powdr). Do we even need to commit to FP then? -- Avoid the embedding overhead in logup, when denominators = "c - index", as it was previously done -- SIMD (Packing) for PoW grinding in Fiat-Shamir (has been implemented in the lean-vm-simple branch by [x-senpai-x](https://github.com/x-senpai-x), see [here](https://github.com/leanEthereum/fiat-shamir/blob/d80da40a76c00aaa6d35fe5e51c3bf31eaf8fe17/src/prover.rs#L98)) - +- Opti WHIR https://github.com/tcoratger/whir-p3/issues/303 and https://github.com/tcoratger/whir-p3/issues/306 ? +- Avoid the embedding overhead in logup, when denominators = "c - index" +- Proof size: replace all equality checks in the verifier algo by value deduction +- Poseidon in 'Compression' mode everywhere (except in 'Sponge' mode? cf. eprint 2014/223) +- XMSS: move from toy implem (usefull for benchmark) to a secure implem +- Recursion: Remove the few hardcoded constants that depend on the guest execution (cycles etc) - About the ordering of the variables in sumchecks, currently we do as follows: [a, b, c, d, e, f, g, h] (1st round of sumcheck) [(a-r).a + r.e, (1-r).b + r.f, (1-r).c + r.g, (1-r).d + r.h] (2nd round of sumcheck) ... etc -This is otpimal for packing (SIMD) but not optimal when to comes to padding. +This is optimal for packing (SIMD) but not optimal when to comes to padding. When there are a lot of "ending" zeros, the optimal way of folding is: [a, b, c, d, e, 0, 0, 0] (1st round of sumcheck) @@ -48,92 +32,15 @@ But we can get the bost of both worlds (suggested by Lev, TODO implement): [(1-r).a + r.c, (1-r).b + r.d, (1-r).e + r.g, (1-r).f + r.h, (1-r).i, 0, 0, 0] (2nd round of sumcheck) ... etc -About "the packed pcs" (similar to SP1 Jagged PCS, slightly less efficient, but simpler (no sumchecks)): -- The best strategy is probably to pack as much as possible (the cost increasing the density = additional inner evaluations), if we can fit below a power of 2 - epsilon (epsilon = 20% for instance, tbd), if the sum of the non zero data is just above a power of 2, no packed technique, even the best, can help us, so we should spread aniway (to reduce the pressure of inner evaluations) -- About those inner evaluations, there is a trick: we need to compute M1(a, b, c, d, ...) then M2(b, c, d, ...), then M3(c, d, ...) -> The trick = compute the "eq(.) for (b, c, d), then dot product with M3. Then expand to eq(b, c, d, ...), dot product with M2. Then expand to eq(a, b, c, d, ...), dot product with M1. The idea is that in this order, computing each "eq" is easier is we start from the previous one. -- Currently the packed pcs works as follows: - -``` -┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐ -| || || || || || || || || || || || || || | -| || || || || || || || || || || || || || | -| || || || || || || || || || || || || || | -| || || || || || || || || || || || || || | -| || || || || || || || || || || || || || | -| || || || || || || || || || || || || || | -└─┘└─┘└─┘└─┘└─┘└─┘└─┘└─┘└─┘└─┘└─┘└─┘└─┘└─┘ -┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐ -| || || || || || || || || || || || || || | -| || || || || || || || || || || || || || | -└─┘└─┘└─┘└─┘└─┘└─┘└─┘└─┘└─┘└─┘└─┘└─┘└─┘└─┘ -┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐ -└─┘└─┘└─┘└─┘└─┘└─┘└─┘└─┘└─┘└─┘└─┘└─┘└─┘└─┘ -``` - -But we reduce proof size a lot using instead (TODO): - -``` -┌────────────────────────┐┌──────────┐┌─┐ -| || || | -| || || | -| || || | -| || || | -| || || | -| || || | -└────────────────────────┘└──────────┘└─┘ -┌────────────────────────┐┌──────────┐┌─┐ -| || || | -| || || | -└────────────────────────┘└──────────┘└─┘ -┌────────────────────────┐┌──────────┐┌─┐ -└────────────────────────┘└──────────┘└─┘ -``` - ## Security: -Fiat Shamir: add a claim tracing feature, to ensure all the claims are indeed checked (Lev) - -## Not Perf - -- Whir batching: handle the case where the second polynomial is too small compared to the first one -- bounddary condition on dot_product table: first flag = 1 -- verify correctness of the Grand Product check -- Proof size: replace all equality checks in the verifier algo by value deduction +- Fiat Shamir: add a claim tracing feature, to ensure all the claims are indeed checked (Lev) +- Double Check AIR constraints, logup overflows etc +- Formal Verification -- KoalaBear extension of degree 5: the current implem (in a fork of Plonky3) has not been been optimized -- KoalaBear extension of degree 6: in order to use the (proven) Johnson bound in WHIR -- current "packed PCS" is not optimal in the end: can lead to [16][4][2][2] (instead of [16][8]) -- make test_packed_pcs pass again -- Poseidon AIR: handle properly the compression mode ? (where output = poseidon(input) + input) (both in WHIR / XMSS) -- XMSS: implem the hash tweak (almost no performance impact as long as we use 1 tweak / XMSS, but this requires further security analysis) -- Grinding before GKR (https://eprint.iacr.org/2025/118) - - -# Random ideas +# Ideas - About range checks, that can currently be done in 3 cycles (see 2.5.3 of the zkVM pdf) + 3 memory cells used. For small ranges we can save 2 memory cells. - -## Known leanISA compiler bugs: - -### Non exhaustive conditions in inlined functions - -Currently, to inline functions we simply replace the "return" keyword by some variable assignment, i.e. -we do not properly handle conditions, we would need to add some JUMPs ... - -``` -fn works(x) inline -> 1 { - if x == 0 { - return 0; - } else { - return 1; - } -} - -fn doesnt_work(x) inline -> 1 { - if x == 0 { - return 0; // will be compiled to `res = 0`; - } // the bug: we do not JUMP here, when inlined - return 1; // will be compiled to `res = 1`; -> invalid (res = 0 and = 1 at the same time) -} -``` - +- Avoid committing to the 3 index columns, and replace it by a sumcheck? Idea by Georg (Powdr). Advantage: Less commitment surface. Drawback: increase the number of instances in the final WHIR batch opening -> proof size overhead +- Lev's trick to skip some low-level modular reduction? + diff --git a/crates/air/Cargo.toml b/crates/air/Cargo.toml index e8eb3f85..57afa0a7 100644 --- a/crates/air/Cargo.toml +++ b/crates/air/Cargo.toml @@ -9,7 +9,6 @@ workspace = true [dependencies] tracing.workspace = true utils.workspace = true -p3-air.workspace = true p3-util.workspace = true multilinear-toolkit.workspace = true diff --git a/crates/air/src/lib.rs b/crates/air/src/lib.rs index df7cb7ec..a3544550 100644 --- a/crates/air/src/lib.rs +++ b/crates/air/src/lib.rs @@ -6,6 +6,19 @@ mod utils; mod validity_check; mod verify; +use multilinear_toolkit::prelude::{Field, MultilinearPoint}; pub use prove::*; pub use validity_check::*; pub use verify::*; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct AirClaims { + pub point: MultilinearPoint, + pub evals_f: Vec, + pub evals_ef: Vec, + + // only for columns with a "shift", in case univariate skip == 1 + pub down_point: Option>, + pub evals_f_on_down_columns: Vec, + pub evals_ef_on_down_columns: Vec, +} diff --git a/crates/air/src/prove.rs b/crates/air/src/prove.rs index ad6b8409..51cf62ca 100644 --- a/crates/air/src/prove.rs +++ b/crates/air/src/prove.rs @@ -1,10 +1,9 @@ use multilinear_toolkit::prelude::*; -use p3_air::Air; use p3_util::{log2_ceil_usize, log2_strict_usize}; use tracing::{info_span, instrument}; -use utils::{FSProver, fold_multilinear_chunks, multilinears_linear_combination}; +use utils::{fold_multilinear_chunks, multilinears_linear_combination}; -use crate::{uni_skip_utils::matrix_next_mle_folded, utils::column_shifted}; +use crate::{AirClaims, uni_skip_utils::matrix_next_mle_folded, utils::column_shifted}; /* @@ -15,17 +14,15 @@ cf https://eprint.iacr.org/2023/552.pdf and https://solvable.group/posts/super-a #[instrument(name = "prove air", skip_all)] #[allow(clippy::too_many_arguments)] pub fn prove_air>, A: Air>( - prover_state: &mut FSProver>, + prover_state: &mut impl FSProver, air: &A, - mut extra_data: A::ExtraData, + extra_data: A::ExtraData, univariate_skips: usize, columns_f: &[impl AsRef<[PF]>], columns_ef: &[impl AsRef<[EF]>], - last_row_shifted_f: &[PF], - last_row_shifted_ef: &[EF], - virtual_column_statements: Option>, // point should be randomness generated after committing to the columns + virtual_column_statement: Option>, // point should be randomness generated after committing to the columns store_intermediate_foldings: bool, -) -> (MultilinearPoint, Vec, Vec) +) -> AirClaims where A::ExtraData: AlphaPowersMut + AlphaPowers, { @@ -50,16 +47,11 @@ where // ) // .unwrap(); - let alpha = prover_state.sample(); // random challenge for batching constraints - - *extra_data.alpha_powers_mut() = alpha - .powers() - .take(air.n_constraints() + virtual_column_statements.as_ref().map_or(0, |s| s.values.len())) - .collect(); + assert!(extra_data.alpha_powers().len() >= air.n_constraints() + virtual_column_statement.is_some() as usize); let n_sc_rounds = log_n_rows + 1 - univariate_skips; - let zerocheck_challenges = virtual_column_statements + let zerocheck_challenges = virtual_column_statement .as_ref() .map(|st| st.point.0.clone()) .unwrap_or_else(|| prover_state.sample_vec(n_sc_rounds)); @@ -68,14 +60,12 @@ where let shifted_rows_f = air .down_column_indexes_f() .par_iter() - .zip_eq(last_row_shifted_f) - .map(|(&col_index, &final_value)| column_shifted(columns_f[col_index], final_value.as_base().unwrap())) + .map(|&col_index| column_shifted(columns_f[col_index])) .collect::>(); let shifted_rows_ef = air .down_column_indexes_ef() .par_iter() - .zip_eq(last_row_shifted_ef) - .map(|(&col_index, &final_value)| column_shifted(columns_ef[col_index], final_value)) + .map(|&col_index| column_shifted(columns_ef[col_index])) .collect::>(); let mut columns_up_down_f = columns_f.to_vec(); // orginal columns, followed by shifted ones @@ -98,11 +88,11 @@ where air, &extra_data, Some((zerocheck_challenges, None)), - virtual_column_statements.is_none(), + virtual_column_statement.is_none(), prover_state, - virtual_column_statements + virtual_column_statement .as_ref() - .map(|st| dot_product(st.values.iter().copied(), alpha.powers())) + .map(|st| st.value) .unwrap_or_else(|| EF::ZERO), store_intermediate_foldings, ) @@ -110,29 +100,41 @@ where prover_state.add_extension_scalars(&inner_sums); - open_columns( - prover_state, - univariate_skips, - &air.down_column_indexes_f(), - &air.down_column_indexes_ef(), - &columns_f, - &columns_ef, - &outer_sumcheck_challenge, - ) + if univariate_skips == 1 { + open_columns_no_skip( + prover_state, + &inner_sums, + &air.down_column_indexes_f(), + &air.down_column_indexes_ef(), + &columns_f, + &columns_ef, + &outer_sumcheck_challenge, + ) + } else if shifted_rows_f.is_empty() && shifted_rows_ef.is_empty() { + // usefull for poseidon2 benchmark + open_flat_columns( + prover_state, + univariate_skips, + &columns_f, + &columns_ef, + &outer_sumcheck_challenge, + ) + } else { + panic!( + "Currently unsupported for simplicty (checkout c7944152a4325b1e1913446e6684112099db5d78 for a version that supported this case)" + ); + } } #[instrument(skip_all)] -fn open_columns>>( - prover_state: &mut FSProver>, +fn open_flat_columns>>( + prover_state: &mut impl FSProver, univariate_skips: usize, - columns_with_shift_f: &[usize], - columns_with_shift_ef: &[usize], columns_f: &[&[PF]], columns_ef: &[&[EF]], outer_sumcheck_challenge: &[EF], -) -> (MultilinearPoint, Vec, Vec) { - let n_up_down_columns = - columns_f.len() + columns_ef.len() + columns_with_shift_f.len() + columns_with_shift_ef.len(); +) -> AirClaims { + let n_up_down_columns = columns_f.len() + columns_ef.len(); let batching_scalars = prover_state.sample_vec(log2_ceil_usize(n_up_down_columns)); let eval_eq_batching_scalars = eval_eq(&batching_scalars)[..n_up_down_columns].to_vec(); @@ -152,47 +154,10 @@ fn open_columns>>( }); } - let columns_shifted_f = &columns_with_shift_f.iter().map(|&i| columns_f[i]).collect::>(); - let columns_shifted_ef = &columns_with_shift_ef.iter().map(|&i| columns_ef[i]).collect::>(); - - let mut batched_column_down = if columns_shifted_f.is_empty() { - tracing::warn!("TODO optimize open_columns when no shifted F columns"); - vec![EF::ZERO; batched_column_up.len()] - } else { - multilinears_linear_combination( - columns_shifted_f, - &eval_eq_batching_scalars[columns_f.len() + columns_ef.len()..][..columns_shifted_f.len()], - ) - }; - - if !columns_shifted_ef.is_empty() { - let batched_column_down_ef = multilinears_linear_combination( - columns_shifted_ef, - &eval_eq_batching_scalars[columns_f.len() + columns_ef.len() + columns_shifted_f.len()..], - ); - batched_column_down - .par_iter_mut() - .zip(&batched_column_down_ef) - .for_each(|(a, &b)| { - *a += b; - }); - } - - let sub_evals = info_span!("fold_multilinear_chunks").in_scope(|| { - let sub_evals_up = fold_multilinear_chunks( - &batched_column_up, - &MultilinearPoint(outer_sumcheck_challenge[1..].to_vec()), - ); - let sub_evals_down = fold_multilinear_chunks( - &column_shifted(&batched_column_down, EF::ZERO), - &MultilinearPoint(outer_sumcheck_challenge[1..].to_vec()), - ); - sub_evals_up - .iter() - .zip(sub_evals_down.iter()) - .map(|(&up, &down)| up + down) - .collect::>() - }); + let sub_evals = fold_multilinear_chunks( + &batched_column_up, + &MultilinearPoint(outer_sumcheck_challenge[1..].to_vec()), + ); prover_state.add_extension_scalars(&sub_evals); let epsilons = prover_state.sample_vec(univariate_skips); @@ -203,14 +168,8 @@ fn open_columns>>( // TODO opti in case of flat AIR (no need of `matrix_next_mle_folded`) let matrix_up = eval_eq(&point); - let matrix_down = matrix_next_mle_folded(&point); let inner_mle = info_span!("packing").in_scope(|| { - MleGroupOwned::ExtensionPacked(vec![ - pack_extension(&matrix_up), - pack_extension(&batched_column_up), - pack_extension(&matrix_down), - pack_extension(&batched_column_down), - ]) + MleGroupOwned::ExtensionPacked(vec![pack_extension(&matrix_up), pack_extension(&batched_column_up)]) }); let (inner_challenges, _, _) = info_span!("structured columns sumcheck").in_scope(|| { @@ -218,7 +177,7 @@ fn open_columns>>( 1, inner_mle, None, - &MySumcheck, + &ProductComputation {}, &vec![], None, false, @@ -244,40 +203,115 @@ fn open_columns>>( prover_state.add_extension_scalars(&evaluations_remaining_to_prove_f); prover_state.add_extension_scalars(&evaluations_remaining_to_prove_ef); - ( - inner_challenges, - evaluations_remaining_to_prove_f, - evaluations_remaining_to_prove_ef, - ) + AirClaims { + point: inner_challenges, + evals_f: evaluations_remaining_to_prove_f, + evals_ef: evaluations_remaining_to_prove_ef, + down_point: None, + evals_f_on_down_columns: vec![], + evals_ef_on_down_columns: vec![], + } } -struct MySumcheck; +#[instrument(skip_all)] +fn open_columns_no_skip>>( + prover_state: &mut impl FSProver, + inner_evals: &[EF], + columns_with_shift_f: &[usize], + columns_with_shift_ef: &[usize], + columns_f: &[&[PF]], + columns_ef: &[&[EF]], + outer_sumcheck_challenge: &[EF], +) -> AirClaims { + let n_columns_f_up = columns_f.len(); + let n_columns_ef_up = columns_ef.len(); + let n_columns_f_down = columns_with_shift_f.len(); + let n_columns_ef_down = columns_with_shift_ef.len(); + let n_down_columns = n_columns_f_down + n_columns_ef_down; + assert_eq!(inner_evals.len(), n_columns_f_up + n_columns_ef_up + n_down_columns); + + let evals_up_f = inner_evals[..n_columns_f_up].to_vec(); + let evals_down_f = &inner_evals[n_columns_f_up..][..n_columns_f_down]; + let evals_up_ef = inner_evals[n_columns_f_up + n_columns_f_down..][..n_columns_ef_up].to_vec(); + let evals_down_ef = &inner_evals[n_columns_f_up + n_columns_f_down + n_columns_ef_up..]; + + if n_down_columns == 0 { + return AirClaims { + point: MultilinearPoint(outer_sumcheck_challenge.to_vec()), + evals_f: evals_up_f, + evals_ef: evals_up_ef, + down_point: None, + evals_f_on_down_columns: vec![], + evals_ef_on_down_columns: vec![], + }; + } -impl>> SumcheckComputation for MySumcheck { - type ExtraData = Vec; + let batching_scalar = prover_state.sample(); + let batching_scalar_powers = batching_scalar.powers().collect_n(n_down_columns); - fn degree(&self) -> usize { - 2 - } - #[inline(always)] - fn eval_base(&self, _: &[PF], _: &[EF], _: &Self::ExtraData) -> EF { - unreachable!() - } - #[inline(always)] - fn eval_extension(&self, point: &[EF], _: &[EF], _: &Self::ExtraData) -> EF { - point[0] * point[1] + point[2] * point[3] - } - #[inline(always)] - fn eval_packed_base(&self, _: &[PFPacking], _: &[EFPacking], _: &Self::ExtraData) -> EFPacking { - unreachable!() + let columns_shifted_f = &columns_with_shift_f.iter().map(|&i| columns_f[i]).collect::>(); + let columns_shifted_ef = &columns_with_shift_ef.iter().map(|&i| columns_ef[i]).collect::>(); + + let mut batched_column_down = + multilinears_linear_combination(columns_shifted_f, &batching_scalar_powers[..n_columns_f_down]); + + if n_columns_ef_down > 0 { + let batched_column_down_ef = + multilinears_linear_combination(columns_shifted_ef, &batching_scalar_powers[n_columns_f_down..]); + batched_column_down + .par_iter_mut() + .zip(&batched_column_down_ef) + .for_each(|(a, &b)| { + *a += b; + }); } - #[inline(always)] - fn eval_packed_extension( - &self, - point: &[EFPacking], - _: &[EFPacking], - _: &Self::ExtraData, - ) -> EFPacking { - point[0] * point[1] + point[2] * point[3] + + let matrix_down = matrix_next_mle_folded(outer_sumcheck_challenge); + let inner_mle = info_span!("packing").in_scope(|| { + MleGroupOwned::ExtensionPacked(vec![pack_extension(&matrix_down), pack_extension(&batched_column_down)]) + }); + + let inner_sum = dot_product( + evals_down_f.iter().chain(evals_down_ef).copied(), + batching_scalar_powers.iter().copied(), + ); + + let (inner_challenges, _, _) = info_span!("structured columns sumcheck").in_scope(|| { + sumcheck_prove::( + 1, + inner_mle, + None, + &ProductComputation {}, + &vec![], + None, + false, + prover_state, + inner_sum, + false, + ) + }); + + let (evals_f_on_down_columns, evals_ef_on_down_columns) = info_span!("final evals").in_scope(|| { + ( + columns_shifted_f + .par_iter() + .map(|col| col.evaluate(&inner_challenges)) + .collect::>(), + columns_shifted_ef + .par_iter() + .map(|col| col.evaluate(&inner_challenges)) + .collect::>(), + ) + }); + prover_state.add_extension_scalars(&evals_f_on_down_columns); + prover_state.add_extension_scalars(&evals_ef_on_down_columns); + + AirClaims { + point: MultilinearPoint(outer_sumcheck_challenge.to_vec()), + evals_f: evals_up_f, + evals_ef: evals_up_ef, + down_point: Some(inner_challenges), + evals_f_on_down_columns, + evals_ef_on_down_columns, } } diff --git a/crates/air/src/uni_skip_utils.rs b/crates/air/src/uni_skip_utils.rs index b1b42a37..65cef1b3 100644 --- a/crates/air/src/uni_skip_utils.rs +++ b/crates/air/src/uni_skip_utils.rs @@ -15,6 +15,8 @@ pub fn matrix_next_mle_folded>>(outer_challenges: &[F]) res[i] += *v; } } + res[(1 << n) - 1] += outer_challenges.iter().copied().product::(); + res } @@ -35,9 +37,13 @@ mod tests { let matrix = matrix_next_mle_folded(&x_bools); for y in 0..1 << n_vars { let y_bools = to_big_endian_in_field::(y, n_vars); - let expected = F::from_bool(x + 1 == y); + let expected = F::from_bool(if (x, y) == ((1 << n_vars) - 1, (1 << n_vars) - 1) { + true + } else { + x + 1 == y + }); assert_eq!(matrix.evaluate(&MultilinearPoint(y_bools.clone())), expected); - assert_eq!(next_mle(&[x_bools.clone(), y_bools].concat()), expected); + assert_eq!(next_mle(&x_bools, &y_bools), expected); } } } diff --git a/crates/air/src/utils.rs b/crates/air/src/utils.rs index 63ae1972..0614f515 100644 --- a/crates/air/src/utils.rs +++ b/crates/air/src/utils.rs @@ -11,7 +11,7 @@ use multilinear_toolkit::prelude::*; /// ... ... ... /// (0 0 0 0 ... 0 1 0) /// (0 0 0 0 ... 0 0 1) -/// (0 0 0 0 ... 0 0 0) +/// (0 0 0 0 ... 0 0 1) /// /// # Arguments /// - `point`: A slice of 2n field elements representing two n-bit vectors concatenated. @@ -34,49 +34,31 @@ use multilinear_toolkit::prelude::*; /// /// # Returns /// Field element: 1 if y = x + 1, 0 otherwise. -pub(crate) fn next_mle(point: &[F]) -> F { - // Check that the point length is even: we split into x and y of equal length. - assert_eq!(point.len() % 2, 0, "Input point must have an even number of variables."); - let n = point.len() / 2; - - // Split point into x (first n) and y (last n). - let (x, y) = point.split_at(n); - - // Sum contributions for each possible carry position k = 0..n-1. - (0..n) - .map(|k| { - // Term 1: bits to the left of k match - // - // For i > k, enforce x_i == y_i. - // Using equality polynomial: x_i * y_i + (1 - x_i)*(1 - y_i). - // - // Indices are reversed because bits are big-endian. - let eq_high_bits = (k + 1..n) - .map(|i| x[n - 1 - i] * y[n - 1 - i] + (F::ONE - x[n - 1 - i]) * (F::ONE - y[n - 1 - i])) - .product::(); - - // Term 2: carry bit at position k - // - // Enforce x_k = 0 and y_k = 1. - // Condition: (1 - x_k) * y_k. - let carry_bit = (F::ONE - x[n - 1 - k]) * y[n - 1 - k]; - - // Term 3: bits to the right of k are 1 in x and 0 in y - // - // For i < k, enforce x_i = 1 and y_i = 0. - // Condition: x_i * (1 - y_i). - let low_bits_are_one_zero = (0..k).map(|i| x[n - 1 - i] * (F::ONE - y[n - 1 - i])).product::(); +pub(crate) fn next_mle(x: &[F], y: &[F]) -> F { + assert_eq!(x.len(), y.len()); + let n = x.len(); + let mut eq_prefix = Vec::with_capacity(n + 1); + eq_prefix.push(F::ONE); + for i in 0..n { + let eq_i = x[i] * y[i] + (F::ONE - x[i]) * (F::ONE - y[i]); + eq_prefix.push(eq_prefix[i] * eq_i); + } + let mut low_suffix = vec![F::ONE; n + 1]; + for i in (0..n).rev() { + low_suffix[i] = low_suffix[i + 1] * x[i] * (F::ONE - y[i]); + } + let mut sum = F::ZERO; + for arr in 0..n { + let carry = (F::ONE - x[arr]) * y[arr]; + sum += eq_prefix[arr] * carry * low_suffix[arr + 1]; + } - // Multiply the three terms for this k, representing one "carry pattern". - eq_high_bits * carry_bit * low_bits_are_one_zero - }) - // Sum over all carry positions: any valid "k" gives contribution 1. - .sum() + sum + x.iter().chain(y).copied().product::() } -pub(crate) fn column_shifted(column: &[F], final_value: F) -> Vec { +pub(crate) fn column_shifted(column: &[F]) -> Vec { let mut down = unsafe { uninitialized_vec(column.len()) }; parallel_clone(&column[1..], &mut down[..column.len() - 1]); - down[column.len() - 1] = final_value; + down[column.len() - 1] = column[column.len() - 1]; down } diff --git a/crates/air/src/validity_check.rs b/crates/air/src/validity_check.rs index f9cb2876..1afd6007 100644 --- a/crates/air/src/validity_check.rs +++ b/crates/air/src/validity_check.rs @@ -1,5 +1,4 @@ -use multilinear_toolkit::prelude::{ExtensionField, PF}; -use p3_air::Air; +use multilinear_toolkit::prelude::*; use tracing::instrument; use utils::ConstraintChecker; @@ -68,8 +67,8 @@ pub fn check_air_validity>>( let up_ef = (0..air.n_columns_ef_air()) .map(|j| columns_ef[j][n_rows - 1]) .collect::>(); - assert_eq!(last_row_f.len(), air.down_column_indexes_f().len()); - assert_eq!(last_row_ef.len(), air.down_column_indexes_ef().len()); + assert_eq!(last_row_f.len(), air.n_down_columns_f()); + assert_eq!(last_row_ef.len(), air.n_down_columns_ef()); let mut constraints_checker = ConstraintChecker { up_f, up_ef, diff --git a/crates/air/src/verify.rs b/crates/air/src/verify.rs index cdbfec62..2e3d00ef 100644 --- a/crates/air/src/verify.rs +++ b/crates/air/src/verify.rs @@ -1,63 +1,49 @@ use multilinear_toolkit::prelude::*; -use p3_air::Air; use p3_util::log2_ceil_usize; -use crate::utils::next_mle; +use crate::{AirClaims, utils::next_mle}; #[allow(clippy::type_complexity)] #[allow(clippy::too_many_arguments)] pub fn verify_air>, A: Air>( - verifier_state: &mut FSVerifier>, + verifier_state: &mut impl FSVerifier, air: &A, - mut extra_data: A::ExtraData, + extra_data: A::ExtraData, univariate_skips: usize, log_n_rows: usize, - last_row_f: &[PF], - last_row_ef: &[EF], - virtual_column_statements: Option>, // point should be randomness generated after committing to the columns -) -> Result<(MultilinearPoint, Vec, Vec), ProofError> + virtual_column_statement: Option>, // point should be randomness generated after committing to the columns +) -> ProofResult> where A::ExtraData: AlphaPowersMut + AlphaPowers, { - let alpha = verifier_state.sample(); // random challenge for batching constraints - - *extra_data.alpha_powers_mut() = alpha - .powers() - .take(air.n_constraints() + virtual_column_statements.as_ref().map_or(0, |s| s.values.len())) - .collect(); + assert!(extra_data.alpha_powers().len() >= air.n_constraints() + virtual_column_statement.is_some() as usize); let n_sc_rounds = log_n_rows + 1 - univariate_skips; - let zerocheck_challenges = virtual_column_statements + let zerocheck_challenges = virtual_column_statement .as_ref() .map(|st| st.point.0.clone()) .unwrap_or_else(|| verifier_state.sample_vec(n_sc_rounds)); assert_eq!(zerocheck_challenges.len(), n_sc_rounds); let (sc_sum, outer_statement) = - sumcheck_verify_with_univariate_skip::(verifier_state, air.degree() + 1, log_n_rows, univariate_skips)?; + sumcheck_verify_with_univariate_skip::(verifier_state, air.degree_air() + 1, log_n_rows, univariate_skips)?; if sc_sum - != virtual_column_statements + != virtual_column_statement .as_ref() - .map(|st| dot_product(st.values.iter().copied(), alpha.powers())) + .map(|st| st.value) .unwrap_or_else(|| EF::ZERO) { return Err(ProofError::InvalidProof); } - let outer_selector_evals = univariate_selectors::>(univariate_skips) - .iter() - .map(|s| s.evaluate(outer_statement.point[0])) - .collect::>(); - - let mut inner_sums = verifier_state.next_extension_scalars_vec( - air.n_columns_air() + air.down_column_indexes_f().len() + air.down_column_indexes_ef().len(), - )?; + let inner_evals = verifier_state + .next_extension_scalars_vec(air.n_columns_air() + air.n_down_columns_f() + air.n_down_columns_ef())?; - let n_columns_down_f = air.down_column_indexes_f().len(); - let constraint_evals = SumcheckComputation::eval_extension( - air, - &inner_sums[..air.n_columns_f_air() + n_columns_down_f], - &inner_sums[air.n_columns_f_air() + n_columns_down_f..], + let n_columns_down_f = air.n_down_columns_f(); + let n_columns_down_ef = air.n_down_columns_ef(); + let constraint_evals = air.eval_extension( + &inner_evals[..air.n_columns_f_air() + n_columns_down_f], + &inner_evals[air.n_columns_f_air() + n_columns_down_f..], &extra_data, ); @@ -67,76 +53,54 @@ where return Err(ProofError::InvalidProof); } - inner_sums = [ - inner_sums[..air.n_columns_f_air()].to_vec(), - inner_sums[air.n_columns_f_air() + n_columns_down_f..][..air.n_columns_ef_air()].to_vec(), - inner_sums[air.n_columns_f_air()..][..n_columns_down_f].to_vec(), - inner_sums[air.n_columns_f_air() + n_columns_down_f + air.n_columns_ef_air()..].to_vec(), - ] - .concat(); - - open_columns( - verifier_state, - air.n_columns_f_air(), - air.n_columns_ef_air(), - univariate_skips, - &air.down_column_indexes_f(), - &air.down_column_indexes_ef(), - inner_sums, - &Evaluation::new(outer_statement.point[1..].to_vec(), outer_statement.value), - &outer_selector_evals, - log_n_rows, - last_row_f, - last_row_ef, - ) + if univariate_skips == 1 { + open_columns_no_skip(verifier_state, air, log_n_rows, &inner_evals, &outer_statement.point) + } else if n_columns_down_f == 0 && n_columns_down_ef == 0 { + // usefull for poseidon2 benchmark + let outer_selector_evals = univariate_selectors::>(univariate_skips) + .iter() + .map(|s| s.evaluate(outer_statement.point[0])) + .collect::>(); + open_flat_columns( + verifier_state, + air.n_columns_f_air(), + air.n_columns_ef_air(), + univariate_skips, + inner_evals, + &Evaluation::new(outer_statement.point[1..].to_vec(), outer_statement.value), + &outer_selector_evals, + log_n_rows, + ) + } else { + panic!( + "Currently unsupported for simplicty (checkout c7944152a4325b1e1913446e6684112099db5d78 for a version that supported this case)" + ); + } } #[allow(clippy::too_many_arguments)] // TODO #[allow(clippy::type_complexity)] -fn open_columns>>( - verifier_state: &mut FSVerifier>, +fn open_flat_columns>>( + verifier_state: &mut impl FSVerifier, n_columns_f: usize, n_columns_ef: usize, univariate_skips: usize, - columns_with_shift_f: &[usize], - columns_with_shift_ef: &[usize], - mut evals_up_and_down: Vec, + inner_evals: Vec, outer_sumcheck_challenge: &Evaluation, outer_selector_evals: &[EF], log_n_rows: usize, - last_row_f: &[PF], - last_row_ef: &[EF], -) -> Result<(MultilinearPoint, Vec, Vec), ProofError> { +) -> ProofResult> { let n_columns = n_columns_f + n_columns_ef; - assert_eq!( - n_columns + last_row_f.len() + last_row_ef.len(), - evals_up_and_down.len() - ); - let last_row_selector = outer_selector_evals[(1 << univariate_skips) - 1] - * outer_sumcheck_challenge.point.iter().copied().product::(); - for (&last_row_value, down_col_eval) in last_row_f.iter().zip(&mut evals_up_and_down[n_columns..]) { - *down_col_eval -= last_row_selector * last_row_value; - } - for (&last_row_value, down_col_eval) in last_row_ef - .iter() - .zip(&mut evals_up_and_down[n_columns + last_row_f.len()..]) - { - *down_col_eval -= last_row_selector * last_row_value; - } - let batching_scalars = verifier_state.sample_vec(log2_ceil_usize(n_columns + last_row_f.len() + last_row_ef.len())); + let batching_scalars = verifier_state.sample_vec(log2_ceil_usize(n_columns)); let eval_eq_batching_scalars = eval_eq(&batching_scalars); let batching_scalars_up = &eval_eq_batching_scalars[..n_columns]; - let batching_scalars_down = &eval_eq_batching_scalars[n_columns..]; let sub_evals = verifier_state.next_extension_scalars_vec(1 << univariate_skips)?; if dot_product::(sub_evals.iter().copied(), outer_selector_evals.iter().copied()) - != dot_product::( - evals_up_and_down.iter().copied(), - eval_eq_batching_scalars.iter().copied(), - ) + != dot_product::(inner_evals.iter().copied(), eval_eq_batching_scalars.iter().copied()) { return Err(ProofError::InvalidProof); } @@ -151,14 +115,6 @@ fn open_columns>>( let matrix_up_sc_eval = MultilinearPoint([epsilons.0.clone(), outer_sumcheck_challenge.point.0.clone()].concat()) .eq_poly_outside(&inner_sumcheck_stement.point); - let matrix_down_sc_eval = next_mle( - &[ - epsilons.0, - outer_sumcheck_challenge.point.to_vec(), - inner_sumcheck_stement.point.0.clone(), - ] - .concat(), - ); let evaluations_remaining_to_verify_f = verifier_state.next_extension_scalars_vec(n_columns_f)?; let evaluations_remaining_to_verify_ef = verifier_state.next_extension_scalars_vec(n_columns_ef)?; @@ -171,26 +127,84 @@ fn open_columns>>( batching_scalars_up.iter().copied(), evaluations_remaining_to_verify.iter().copied(), ); - let mut columns_with_shift = columns_with_shift_f.to_vec(); - columns_with_shift.extend_from_slice( - columns_with_shift_ef - .iter() - .map(|&x| x + n_columns_f) - .collect::>() - .as_slice(), + + if inner_sumcheck_stement.value != matrix_up_sc_eval * batched_col_up_sc_eval { + return Err(ProofError::InvalidProof); + } + Ok(AirClaims { + point: inner_sumcheck_stement.point.clone(), + evals_f: evaluations_remaining_to_verify_f, + evals_ef: evaluations_remaining_to_verify_ef, + down_point: None, + evals_f_on_down_columns: vec![], + evals_ef_on_down_columns: vec![], + }) +} + +fn open_columns_no_skip>>( + verifier_state: &mut impl FSVerifier, + air: &A, + log_n_rows: usize, + inner_evals: &[EF], + outer_sumcheck_challenge: &[EF], +) -> ProofResult> { + let n_columns_f_up = air.n_columns_f_air(); + let n_columns_ef_up = air.n_columns_ef_air(); + let n_columns_f_down = air.n_down_columns_f(); + let n_columns_ef_down = air.n_down_columns_ef(); + let n_down_columns = n_columns_f_down + n_columns_ef_down; + assert_eq!(inner_evals.len(), n_columns_f_up + n_columns_ef_up + n_down_columns); + + let evals_up_f = inner_evals[..n_columns_f_up].to_vec(); + let evals_down_f = inner_evals[n_columns_f_up..][..n_columns_f_down].to_vec(); + let evals_up_ef = inner_evals[n_columns_f_up + n_columns_f_down..][..n_columns_ef_up].to_vec(); + let evals_down_ef = inner_evals[n_columns_f_up + n_columns_f_down + n_columns_ef_up..].to_vec(); + + if n_down_columns == 0 { + return Ok(AirClaims { + point: MultilinearPoint(outer_sumcheck_challenge.to_vec()), + evals_f: evals_up_f, + evals_ef: evals_up_ef, + down_point: None, + evals_f_on_down_columns: vec![], + evals_ef_on_down_columns: vec![], + }); + } + + let batching_scalar = verifier_state.sample(); + let batching_scalar_powers = batching_scalar.powers().collect_n(n_down_columns); + + let inner_sum: EF = dot_product( + evals_down_f.into_iter().chain(evals_down_ef), + batching_scalar_powers.iter().copied(), ); - let batched_col_down_sc_eval = (0..columns_with_shift.len()) - .map(|i| evaluations_remaining_to_verify[columns_with_shift[i]] * batching_scalars_down[i]) - .sum::(); - if inner_sumcheck_stement.value - != matrix_up_sc_eval * batched_col_up_sc_eval + matrix_down_sc_eval * batched_col_down_sc_eval - { + let (inner_sum_retrieved, inner_sumcheck_stement) = sumcheck_verify(verifier_state, log_n_rows, 2)?; + + if inner_sum != inner_sum_retrieved { return Err(ProofError::InvalidProof); } - Ok(( - inner_sumcheck_stement.point.clone(), - evaluations_remaining_to_verify_f, - evaluations_remaining_to_verify_ef, - )) + + let matrix_down_sc_eval = next_mle(outer_sumcheck_challenge, &inner_sumcheck_stement.point); + + let evals_f_on_down_columns = verifier_state.next_extension_scalars_vec(n_columns_f_down)?; + let evals_ef_on_down_columns = verifier_state.next_extension_scalars_vec(n_columns_ef_down)?; + let evaluations_remaining_to_verify = [evals_f_on_down_columns.clone(), evals_ef_on_down_columns.clone()].concat(); + let batched_col_down_sc_eval = dot_product::( + batching_scalar_powers.iter().copied(), + evaluations_remaining_to_verify.iter().copied(), + ); + + if inner_sumcheck_stement.value != matrix_down_sc_eval * batched_col_down_sc_eval { + return Err(ProofError::InvalidProof); + } + + Ok(AirClaims { + point: MultilinearPoint(outer_sumcheck_challenge.to_vec()), + evals_f: evals_up_f, + evals_ef: evals_up_ef, + down_point: Some(inner_sumcheck_stement.point.clone()), + evals_f_on_down_columns, + evals_ef_on_down_columns, + }) } diff --git a/crates/air/tests/complex_air.rs b/crates/air/tests/complex_air.rs deleted file mode 100644 index d2cc7c76..00000000 --- a/crates/air/tests/complex_air.rs +++ /dev/null @@ -1,214 +0,0 @@ -use multilinear_toolkit::prelude::*; -use p3_air::{Air, AirBuilder}; -use p3_koala_bear::{KoalaBear, QuinticExtensionFieldKB}; -use rand::{Rng, SeedableRng, rngs::StdRng}; -use utils::{build_prover_state, build_verifier_state}; - -use air::{check_air_validity, prove_air, verify_air}; - -const UNIVARIATE_SKIPS: usize = 3; - -const N_COLS_F: usize = 2; - -type F = KoalaBear; -type EF = QuinticExtensionFieldKB; - -struct ExampleStructuredAir; - -impl Air - for ExampleStructuredAir -{ - type ExtraData = Vec; - - fn n_columns_f_air(&self) -> usize { - N_COLS_F - } - fn n_columns_ef_air(&self) -> usize { - N_COLUMNS - N_COLS_F - } - fn degree(&self) -> usize { - N_PREPROCESSED_COLUMNS - } - fn n_constraints(&self) -> usize { - 50 // too much, but ok for tests - } - fn down_column_indexes_f(&self) -> Vec { - vec![] - } - fn down_column_indexes_ef(&self) -> Vec { - (N_PREPROCESSED_COLUMNS - N_COLS_F..N_COLUMNS - N_COLS_F).collect::>() - } - - #[inline] - fn eval(&self, builder: &mut AB, _: &Self::ExtraData) { - let up_f = builder.up_f().to_vec(); - let up_ef = builder.up_ef().to_vec(); - let down_ef = builder.down_ef().to_vec(); - assert_eq!(up_f.len(), N_COLS_F); - assert_eq!(up_f.len() + up_ef.len(), N_COLUMNS); - assert_eq!(down_ef.len(), N_COLUMNS - N_PREPROCESSED_COLUMNS); - - if VIRTUAL_COLUMN { - // virtual column A = col_0 * col_1 + col_2 - // virtual column B = col_0 - col_1 - builder.eval_virtual_column(up_ef[0].clone() + up_f[0].clone() * up_f[1].clone()); - builder.eval_virtual_column(AB::EF::from(up_f[0].clone() - up_f[1].clone())); - } - - for j in N_PREPROCESSED_COLUMNS..N_COLUMNS { - builder.assert_eq_ef( - down_ef[j - N_PREPROCESSED_COLUMNS].clone(), - up_ef[j - N_COLS_F].clone() - + AB::F::from_usize(j) - + (0..N_PREPROCESSED_COLUMNS - N_COLS_F) - .map(|k| up_ef[k].clone()) - .product::() - * up_f[0].clone() - * up_f[1].clone(), - ); - } - } -} - -fn generate_trace( - n_rows: usize, -) -> (Vec>, Vec>) { - let mut rng = StdRng::seed_from_u64(0); - let mut trace_f = vec![]; - for _ in 0..N_COLS_F { - trace_f.push((0..n_rows).map(|_| rng.random()).collect::>()); - } - let mut trace_ef = vec![]; - for _ in N_COLS_F..N_PREPROCESSED_COLUMNS { - trace_ef.push((0..n_rows).map(|_| rng.random()).collect::>()); - } - let mut witness_cols = vec![vec![EF::ZERO]; N_COLUMNS - N_PREPROCESSED_COLUMNS]; - for i in 1..n_rows { - for (j, witness_col) in witness_cols.iter_mut().enumerate() { - let witness_cols_j_i_min_1 = witness_col[i - 1]; - witness_col.push( - witness_cols_j_i_min_1 - + F::from_usize(j + N_PREPROCESSED_COLUMNS) - + (0..3).map(|k| trace_ef[k][i - 1]).product::() * trace_f[0][i - 1] * trace_f[1][i - 1], - ); - } - } - trace_ef.extend(witness_cols); - (trace_f, trace_ef) -} - -#[test] -fn test_air() { - test_air_helper::(); - test_air_helper::(); -} - -fn test_air_helper() { - const N_COLUMNS: usize = 17; - const N_PREPROCESSED_COLUMNS: usize = 5; - let log_n_rows = 12; - let n_rows = 1 << log_n_rows; - let mut prover_state = build_prover_state::(false); - - let (columns_plus_one_f, columns_plus_one_ef) = generate_trace::(n_rows + 1); - let columns_ref_f = columns_plus_one_f.iter().map(|col| &col[..n_rows]).collect::>(); - let columns_ref_ef = columns_plus_one_ef.iter().map(|col| &col[..n_rows]).collect::>(); - let mut last_row_ef = columns_plus_one_ef.iter().map(|col| col[n_rows]).collect::>(); - last_row_ef = last_row_ef[N_PREPROCESSED_COLUMNS - N_COLS_F..].to_vec(); - - let virtual_column_statement_prover = if VIRTUAL_COLUMN { - let virtual_column_a = columns_ref_f[0] - .iter() - .zip(columns_ref_f[1].iter()) - .zip(columns_ref_ef[0].iter()) - .map(|((&a, &b), &c)| c + a * b) - .collect::>(); - let virtual_column_evaluation_point = - MultilinearPoint(prover_state.sample_vec(log_n_rows + 1 - UNIVARIATE_SKIPS)); - let selectors = univariate_selectors::>(UNIVARIATE_SKIPS); - let virtual_column_value_a = evaluate_univariate_multilinear::<_, _, _, true>( - &virtual_column_a, - &virtual_column_evaluation_point, - &selectors, - None, - ); - let virtual_column_b = columns_ref_f[0] - .iter() - .zip(columns_ref_f[1].iter()) - .map(|(&a, &b)| EF::from(a - b)) - .collect::>(); - let virtual_column_value_b = evaluate_univariate_multilinear::<_, _, _, true>( - &virtual_column_b, - &virtual_column_evaluation_point, - &selectors, - None, - ); - prover_state.add_extension_scalar(virtual_column_value_a); - prover_state.add_extension_scalar(virtual_column_value_b); - - Some(MultiEvaluation::new( - virtual_column_evaluation_point.0.clone(), - vec![virtual_column_value_a, virtual_column_value_b], - )) - } else { - None - }; - - let air = ExampleStructuredAir:: {}; - - check_air_validity(&air, &vec![], &columns_ref_f, &columns_ref_ef, &[], &last_row_ef).unwrap(); - - let (point_prover, evaluations_remaining_to_prove_f, evaluations_remaining_to_prove_ef) = prove_air( - &mut prover_state, - &air, - vec![], - UNIVARIATE_SKIPS, - &columns_ref_f, - &columns_ref_ef, - &[], - &last_row_ef, - virtual_column_statement_prover, - true, - ); - let mut verifier_state = build_verifier_state(prover_state); - - let virtual_column_statement_verifier = if VIRTUAL_COLUMN { - let virtual_column_evaluation_point = - MultilinearPoint(verifier_state.sample_vec(log_n_rows + 1 - UNIVARIATE_SKIPS)); - let virtual_column_value_a = verifier_state.next_extension_scalar().unwrap(); - let virtual_column_value_b = verifier_state.next_extension_scalar().unwrap(); - Some(MultiEvaluation::new( - virtual_column_evaluation_point.0.clone(), - vec![virtual_column_value_a, virtual_column_value_b], - )) - } else { - None - }; - - let (point_verifier, evaluations_remaining_to_verify_f, evaluations_remaining_to_verify_ef) = verify_air( - &mut verifier_state, - &air, - vec![], - UNIVARIATE_SKIPS, - log_n_rows, - &[], - &last_row_ef, - virtual_column_statement_verifier, - ) - .unwrap(); - assert_eq!(point_prover, point_verifier); - assert_eq!(&evaluations_remaining_to_prove_f, &evaluations_remaining_to_verify_f); - assert_eq!(&evaluations_remaining_to_prove_ef, &evaluations_remaining_to_verify_ef); - for i in 0..N_COLS_F { - assert_eq!( - columns_ref_f[i].evaluate(&point_prover), - evaluations_remaining_to_verify_f[i] - ); - } - for i in 0..N_COLUMNS - N_COLS_F { - assert_eq!( - columns_ref_ef[i].evaluate(&point_prover), - evaluations_remaining_to_verify_ef[i] - ); - } -} diff --git a/crates/air/tests/fib_air.rs b/crates/air/tests/fib_air.rs deleted file mode 100644 index 968a3fe4..00000000 --- a/crates/air/tests/fib_air.rs +++ /dev/null @@ -1,119 +0,0 @@ -use multilinear_toolkit::prelude::*; -use p3_air::{Air, AirBuilder}; -use p3_koala_bear::{KoalaBear, QuinticExtensionFieldKB}; -use utils::{build_prover_state, build_verifier_state}; - -use air::{check_air_validity, prove_air, verify_air}; - -const UNIVARIATE_SKIPS: usize = 3; - -type F = KoalaBear; -type EF = QuinticExtensionFieldKB; - -struct FibonacciAir; - -impl Air for FibonacciAir { - type ExtraData = Vec; - - fn n_columns_f_air(&self) -> usize { - 1 - } - fn n_columns_ef_air(&self) -> usize { - 1 - } - fn degree(&self) -> usize { - 1 - } - fn n_constraints(&self) -> usize { - 10 // too much, but ok for tests - } - fn down_column_indexes_f(&self) -> Vec { - vec![0] - } - fn down_column_indexes_ef(&self) -> Vec { - vec![0] - } - #[inline] - fn eval(&self, builder: &mut AB, _: &Self::ExtraData) { - let a_up = builder.up_f()[0].clone(); - let b_up = builder.up_ef()[0].clone(); - let a_down = builder.down_f()[0].clone(); - let b_down = builder.down_ef()[0].clone(); - builder.assert_eq_ef(b_down, b_up.clone() + a_up); - builder.assert_eq_ef(AB::EF::from(a_down), b_up); - } -} - -fn generate_trace(n_rows: usize) -> (Vec, Vec) { - let mut col_a = vec![F::ONE]; - let mut col_b = vec![EF::ONE]; - for i in 1..n_rows { - let a_next = col_b[i - 1].as_base().unwrap(); - let b_next = col_b[i - 1] + col_a[i - 1]; - col_a.push(a_next); - col_b.push(b_next); - } - (col_a, col_b) -} - -#[test] -fn test_air_fibonacci() { - let log_n_rows = 14; - let n_rows = 1 << log_n_rows; - let mut prover_state = build_prover_state::(false); - - let (columns_plus_one_f, columns_plus_one_ef) = generate_trace(n_rows + 1); - let columns_ref_f = vec![&columns_plus_one_f[..n_rows]]; - let columns_ref_ef = vec![&columns_plus_one_ef[..n_rows]]; - let last_row_f = vec![columns_plus_one_f[n_rows]]; - let last_row_ef = vec![columns_plus_one_ef[n_rows]]; - - let air = FibonacciAir {}; - - check_air_validity( - &air, - &vec![], - &columns_ref_f, - &columns_ref_ef, - &last_row_f, - &last_row_ef, - ) - .unwrap(); - - let (point_prover, evaluations_remaining_to_prove_f, evaluations_remaining_to_prove_ef) = prove_air( - &mut prover_state, - &air, - vec![], - UNIVARIATE_SKIPS, - &columns_ref_f, - &columns_ref_ef, - &last_row_f, - &last_row_ef, - None, - true, - ); - let mut verifier_state = build_verifier_state(prover_state); - - let (point_verifier, evaluations_remaining_to_verify_f, evaluations_remaining_to_verify_ef) = verify_air( - &mut verifier_state, - &air, - vec![], - UNIVARIATE_SKIPS, - log_n_rows, - &last_row_f, - &last_row_ef, - None, - ) - .unwrap(); - assert_eq!(point_prover, point_verifier); - assert_eq!(&evaluations_remaining_to_prove_f, &evaluations_remaining_to_verify_f); - assert_eq!(&evaluations_remaining_to_prove_ef, &evaluations_remaining_to_verify_ef); - assert_eq!( - columns_ref_f[0].evaluate(&point_prover), - evaluations_remaining_to_verify_f[0] - ); - assert_eq!( - columns_ref_ef[0].evaluate(&point_prover), - evaluations_remaining_to_verify_ef[0] - ); -} diff --git a/crates/lean_compiler/Cargo.toml b/crates/lean_compiler/Cargo.toml index 8441e6dc..e33b21b2 100644 --- a/crates/lean_compiler/Cargo.toml +++ b/crates/lean_compiler/Cargo.toml @@ -14,14 +14,13 @@ xmss.workspace = true rand.workspace = true p3-poseidon2.workspace = true p3-koala-bear.workspace = true -p3-challenger.workspace = true -p3-air.workspace = true p3-symmetric.workspace = true p3-util.workspace = true whir-p3.workspace = true tracing.workspace = true air.workspace = true sub_protocols.workspace = true -lookup.workspace = true lean_vm.workspace = true -multilinear-toolkit.workspace = true \ No newline at end of file +multilinear-toolkit.workspace = true + +[dev-dependencies] \ No newline at end of file diff --git a/crates/lean_compiler/snark_lib.py b/crates/lean_compiler/snark_lib.py new file mode 100644 index 00000000..dcda6560 --- /dev/null +++ b/crates/lean_compiler/snark_lib.py @@ -0,0 +1,98 @@ +# Import this in zkDSL .py files to make them executable as normal Python + +import math +from typing import Any + +# Type annotations +Mut = Any +Const = Any +Imu = Any + + +# @inline decorator (does nothing in Python execution) +def inline(fn): + return fn + + +# unroll(a, b) returns range(a, b) for Python execution +def unroll(a: int, b: int): + return range(a, b) + + +# Array - simulates write-once memory with pointer arithmetic +class Array: + def __init__(self, size: int): + # TODO + return + + def __getitem__(self, idx): + # TODO + return + + def __setitem__(self, idx, value): + # TODO + return + + def __add__(self, offset: int): + # TODO + return + + def __len__(self): + # TODO + return + + +# DynArray - dynamic array with push/pop (compile-time construct) +class DynArray: + def __init__(self, initial: list): + self._data = list(initial) + + def __getitem__(self, idx): + return self._data[idx] + + def __len__(self): + return len(self._data) + + def push(self, value): + self._data.append(value) + + def pop(self): + self._data.pop() + + +# Built-in constants +ZERO_VEC_PTR = 0 +ONE_VEC_PTR = 16 +NONRESERVED_PROGRAM_INPUT_START = 58 + + +def poseidon16(left, right, output, mode): + _ = left, right, output, mode + + +def dot_product(a, b, result, length, mode): + _ = a, b, result, length, mode + + +def hint_decompose_bits(value, bits, n_bits, endian): + _ = value, bits, n_bits, endian + + +def log2_ceil(x: int) -> int: + assert x > 0 + return math.ceil(math.log2(x)) + + +def next_multiple_of(x: int, n: int) -> int: + return x + (n - x % n) % n + + +def saturating_sub(a: int, b: int) -> int: + return max(0, a - b) + + +def debug_assert(cond, msg=None): + if not cond: + if msg: + raise AssertionError(msg) + raise AssertionError() diff --git a/crates/lean_compiler/src/a_simplify_lang.rs b/crates/lean_compiler/src/a_simplify_lang.rs index 72192c41..f2735f81 100644 --- a/crates/lean_compiler/src/a_simplify_lang.rs +++ b/crates/lean_compiler/src/a_simplify_lang.rs @@ -1,19 +1,10 @@ -use crate::{ - Counter, F, - lang::{ - AssignmentTarget, Condition, ConstExpression, ConstMallocLabel, Context, Expression, Function, Line, - MathOperation, Program, Scope, SimpleExpr, Var, - }, - parser::ConstArrayValue, -}; -use lean_vm::{ - Boolean, BooleanExpr, CustomHint, FileId, FunctionName, SourceLineNumber, SourceLocation, Table, TableT, -}; +use crate::{F, lang::*, parser::ConstArrayValue}; +use lean_vm::{ALL_TABLES, Boolean, BooleanExpr, CustomHint, FunctionName, SourceLocation, Table, TableT}; use std::{ collections::{BTreeMap, BTreeSet}, fmt::{Display, Formatter}, }; -use utils::ToUsize; +use utils::{Counter, ToUsize}; #[derive(Debug, Clone)] pub struct SimpleProgram { @@ -23,7 +14,6 @@ pub struct SimpleProgram { #[derive(Debug, Clone)] pub struct SimpleFunction { pub name: String, - pub file_id: FileId, pub arguments: Vec, pub n_returned_vars: usize, pub instructions: Vec, @@ -85,7 +75,7 @@ pub enum SimpleLine { condition: SimpleExpr, then_branch: Vec, else_branch: Vec, - line_number: SourceLineNumber, + location: SourceLocation, }, AssertZero { operation: MathOperation, @@ -96,7 +86,7 @@ pub enum SimpleLine { function_name: String, args: Vec, return_data: Vec, - line_number: SourceLineNumber, + location: SourceLocation, }, FunctionRet { return_data: Vec, @@ -105,16 +95,15 @@ pub enum SimpleLine { table: Table, args: Vec, }, - Panic, + Panic { + message: Option, + }, // Hints /// each field element x is decomposed to: (a0, a1, a2, ..., a11, b) where: /// x = a0 + a1.4 + a2.4^2 + a3.4^3 + ... + a11.4^11 + b.2^24 /// and ai < 4, b < 2^7 - 1 /// The decomposition is unique, and always exists (except for x = -1) CustomHint(CustomHint, Vec), - PrivateInputStart { - result: Var, - }, Print { line_info: String, content: Vec, @@ -122,8 +111,6 @@ pub enum SimpleLine { HintMAlloc { var: Var, size: SimpleExpr, - vectorized: bool, - vectorized_len: SimpleExpr, }, ConstMalloc { // always not vectorized @@ -135,7 +122,12 @@ pub enum SimpleLine { LocationReport { location: SourceLocation, }, - DebugAssert(BooleanExpr, SourceLineNumber), + DebugAssert(BooleanExpr, SourceLocation), + /// Range check: assert val <= bound + RangeCheck { + val: SimpleExpr, + bound: SimpleExpr, + }, } impl SimpleLine { @@ -147,27 +139,83 @@ impl SimpleLine { arg1: SimpleExpr::zero(), } } + + /// Returns mutable references to all nested blocks (arms of match, branches of if). + pub fn nested_blocks_mut(&mut self) -> Vec<&mut Vec> { + match self { + Self::Match { arms, .. } => arms.iter_mut().collect(), + Self::IfNotZero { + then_branch, + else_branch, + .. + } => vec![then_branch, else_branch], + Self::ForwardDeclaration { .. } + | Self::Assignment { .. } + | Self::RawAccess { .. } + | Self::AssertZero { .. } + | Self::FunctionCall { .. } + | Self::FunctionRet { .. } + | Self::Precompile { .. } + | Self::Panic { .. } + | Self::CustomHint(..) + | Self::Print { .. } + | Self::HintMAlloc { .. } + | Self::ConstMalloc { .. } + | Self::LocationReport { .. } + | Self::DebugAssert(..) + | Self::RangeCheck { .. } => vec![], + } + } + + pub fn nested_blocks(&self) -> Vec<&Vec> { + match self { + Self::Match { arms, .. } => arms.iter().collect(), + Self::IfNotZero { + then_branch, + else_branch, + .. + } => vec![then_branch, else_branch], + Self::ForwardDeclaration { .. } + | Self::Assignment { .. } + | Self::RawAccess { .. } + | Self::AssertZero { .. } + | Self::FunctionCall { .. } + | Self::FunctionRet { .. } + | Self::Precompile { .. } + | Self::Panic { .. } + | Self::CustomHint(..) + | Self::Print { .. } + | Self::HintMAlloc { .. } + | Self::ConstMalloc { .. } + | Self::LocationReport { .. } + | Self::DebugAssert(..) + | Self::RangeCheck { .. } => vec![], + } + } } -pub fn simplify_program(mut program: Program) -> SimpleProgram { +fn ends_with_early_exit(block: &[SimpleLine]) -> bool { + match block.last() { + Some(SimpleLine::Panic { .. }) | Some(SimpleLine::FunctionRet { .. }) => true, + Some(last) => { + let nested = last.nested_blocks(); + !nested.is_empty() && nested.iter().all(|b| ends_with_early_exit(b)) + } + None => false, + } +} + +pub fn simplify_program(mut program: Program) -> Result { check_program_scoping(&program); - handle_inlined_functions(&mut program); - // Iterate between unrolling and const argument handling until fixed point let mut unroll_counter = Counter::new(); - let mut max_iterations = 100; - loop { - let mut any_change = false; + let mut inline_counter = Counter::new(); + compile_time_transform_in_program(&mut program, &mut unroll_counter, &mut inline_counter)?; - any_change |= unroll_loops_in_program(&mut program, &mut unroll_counter); - any_change |= handle_const_arguments(&mut program); + // Remove all inlined functions (they've been inlined) + program.functions.retain(|_, func| !func.inlined); - max_iterations -= 1; - assert!(max_iterations > 0, "Too many iterations while simplifying program"); - if !any_change { - break; - } - } + validate_program_vectors(&program)?; // Remove all const functions - they should all have been specialized by now let const_func_names: Vec<_> = program @@ -180,6 +228,9 @@ pub fn simplify_program(mut program: Program) -> SimpleProgram { program.functions.remove(&name); } + let mut mutable_loop_counter = MutableLoopTransformCounter::default(); + transform_mutable_in_loops_in_program(&mut program, &mut mutable_loop_counter); + let mut new_functions = BTreeMap::new(); let mut counters = Counters::default(); let mut const_malloc = ConstMalloc::default(); @@ -189,2387 +240,3126 @@ pub fn simplify_program(mut program: Program) -> SimpleProgram { }; for (name, func) in &program.functions { let mut array_manager = ArrayManager::default(); + let mut mut_tracker = MutableVarTracker::default(); + let mut vec_tracker = VectorTracker::default(); + + // Register mutable arguments and capture their initial versioned names + // BEFORE simplifying the body + let arguments: Vec = func + .arguments + .iter() + .map(|arg| { + assert!(!arg.is_const); + if arg.is_mutable { + mut_tracker.register_mutable(&arg.name); + // Capture the initial versioned name (version 0) + mut_tracker.current_name(&arg.name) + } else { + mut_tracker.assigned.insert(arg.name.clone()); + arg.name.clone() + } + }) + .collect(); + let mut state = SimplifyState { counters: &mut counters, array_manager: &mut array_manager, + mut_tracker: &mut mut_tracker, + vec_tracker: &mut vec_tracker, }; let simplified_instructions = simplify_lines( &ctx, &mut state, &mut const_malloc, &mut new_functions, - func.file_id, func.n_returned_vars, &func.body, false, - ); - let arguments = func - .arguments - .iter() - .map(|(v, is_const)| { - assert!(!is_const,); - v.clone() - }) - .collect::>(); + )?; let simplified_function = SimpleFunction { name: name.clone(), - file_id: func.file_id, arguments, n_returned_vars: func.n_returned_vars, instructions: simplified_instructions, }; - if !func.assume_always_returns { - check_function_always_returns(&simplified_function); - } + check_function_always_returns(&simplified_function)?; new_functions.insert(name.clone(), simplified_function); const_malloc.map.clear(); } - SimpleProgram { + Ok(SimpleProgram { functions: new_functions, - } + }) } -fn unroll_loops_in_program(program: &mut Program, unroll_counter: &mut Counter) -> bool { - let mut changed = false; - for func in program.functions.values_mut() { - changed |= unroll_loops_in_lines(&mut func.body, &program.const_arrays, unroll_counter); +#[derive(Debug, Clone, Default)] +pub struct VectorLenTracker { + vectors: BTreeMap, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum VectorLenValue { + Scalar, + Vector(Vec), +} + +impl VectorLenTracker { + fn register(&mut self, var: &Var, value: VectorLenValue) { + self.vectors.insert(var.clone(), value); + } + + fn is_vector(&self, var: &Var) -> bool { + self.vectors.contains_key(var) + } + + pub fn get(&self, var: &Var) -> Option<&VectorLenValue> { + self.vectors.get(var) + } + + fn get_mut(&mut self, var: &Var) -> Option<&mut VectorLenValue> { + self.vectors.get_mut(var) } - changed } -fn unroll_loops_in_lines( - lines: &mut Vec, - const_arrays: &BTreeMap, - unroll_counter: &mut Counter, -) -> bool { - let mut changed = false; - let mut i = 0; - while i < lines.len() { - // First, recursively process nested structures - match &mut lines[i] { - Line::ForLoop { body, .. } => { - changed |= unroll_loops_in_lines(body, const_arrays, unroll_counter); - } - Line::IfCondition { - then_branch, - else_branch, - .. - } => { - changed |= unroll_loops_in_lines(then_branch, const_arrays, unroll_counter); - changed |= unroll_loops_in_lines(else_branch, const_arrays, unroll_counter); - } - Line::Match { arms, .. } => { - for (_, arm_body) in arms { - changed |= unroll_loops_in_lines(arm_body, const_arrays, unroll_counter); - } - } - _ => {} +impl VectorLenValue { + pub fn push(&mut self, elem: Self) { + match self { + Self::Vector(v) => v.push(elem), + _ => panic!("push on scalar"), } + } - // Now try to unroll if it's an unrollable loop - if let Line::ForLoop { - iterator, - start, - end, - body, - rev, - unroll: true, - line_number: _, - } = &lines[i] - && let (Some(start_val), Some(end_val)) = (start.naive_eval(const_arrays), end.naive_eval(const_arrays)) - { - let start_usize = start_val.to_usize(); - let end_usize = end_val.to_usize(); - let unroll_index = unroll_counter.next(); - - let (internal_vars, _) = find_variable_usage(body, const_arrays); + pub fn pop(&mut self) -> Option { + match self { + Self::Vector(v) => v.pop(), + _ => panic!("pop on scalar"), + } + } - let mut range: Vec<_> = (start_usize..end_usize).collect(); - if *rev { - range.reverse(); - } + pub fn len(&self) -> usize { + match self { + Self::Vector(v) => v.len(), + _ => panic!("len on scalar"), + } + } - let iterator = iterator.clone(); - let body = body.clone(); + pub fn is_vector(&self) -> bool { + matches!(self, Self::Vector(_)) + } - let mut unrolled = Vec::new(); - for j in range { - let mut body_copy = body.clone(); - replace_vars_for_unroll(&mut body_copy, &iterator, unroll_index, j, &internal_vars); - unrolled.extend(body_copy); - } + fn get(&self, i: usize) -> Option<&Self> { + match self { + Self::Vector(v) => v.get(i), + _ => None, + } + } - let num_inserted = unrolled.len(); - lines.splice(i..=i, unrolled); - changed = true; - i += num_inserted; - continue; + fn get_mut(&mut self, i: usize) -> Option<&mut Self> { + match self { + Self::Vector(v) => v.get_mut(i), + _ => None, } + } - i += 1; + pub fn navigate(&self, idx: &[F]) -> Option<&Self> { + idx.iter().try_fold(self, |v, &i| v.get(i.to_usize())) } - changed + + pub fn navigate_mut(&mut self, idx: &[F]) -> Option<&mut Self> { + idx.iter().try_fold(self, |v, &i| v.get_mut(i.to_usize())) + } +} + +fn build_vector_len_value(elements: &[VecLiteral]) -> VectorLenValue { + let mut vec_elements = Vec::new(); + + for elem in elements { + let elem_len_value = build_vector_len_value_from_element(elem); + vec_elements.push(elem_len_value); + } + + VectorLenValue::Vector(vec_elements) } -/// Analyzes a simplified function to verify that it returns on each code path. -fn check_function_always_returns(func: &SimpleFunction) { - check_block_always_returns(&func.name, &func.instructions); +fn build_vector_len_value_from_element(element: &VecLiteral) -> VectorLenValue { + match element { + VecLiteral::Vec(inner) => build_vector_len_value(inner), + VecLiteral::Expr(_) => VectorLenValue::Scalar, + } } -fn check_block_always_returns(function_name: &String, instructions: &[SimpleLine]) { - match instructions.last() { - Some(SimpleLine::Match { value: _, arms }) => { - for arm in arms { - check_block_always_returns(function_name, arm); - } - } - Some(SimpleLine::IfNotZero { - condition: _, - then_branch, - else_branch, - line_number: _, - }) => { - check_block_always_returns(function_name, then_branch); - check_block_always_returns(function_name, else_branch); +fn compile_time_transform_in_program( + program: &mut Program, + unroll_counter: &mut Counter, + inline_counter: &mut Counter, +) -> Result<(), String> { + let const_arrays = program.const_arrays.clone(); + + // Collect inlined functions + let inlined_functions: BTreeMap<_, _> = program + .functions + .iter() + .filter(|(_, func)| func.inlined) + .map(|(name, func)| (name.clone(), func.clone())) + .collect(); + + for func in inlined_functions.values() { + if func.has_mutable_arguments() { + return Err("Inlined functions with mutable arguments are not supported yet".to_string()); } - Some(SimpleLine::FunctionRet { return_data: _ }) => { - // good + if func.has_const_arguments() { + return Err("Inlined functions with constant arguments are not supported yet".to_string()); } - Some(SimpleLine::Panic) => { - // good + } + + // Process all functions, including newly created specialized ones + let mut processed: BTreeSet = BTreeSet::new(); + loop { + let to_process: Vec<_> = program + .functions + .iter() + .filter(|(name, func)| !func.inlined && !func.has_const_arguments() && !processed.contains(*name)) + .map(|(name, _)| name.clone()) + .collect(); + + if to_process.is_empty() { + break; } - _ => { - panic!("Cannot prove that function always returns: {function_name}"); + + let existing_functions = program.functions.clone(); + for func_name in to_process { + processed.insert(func_name.clone()); + let func = program.functions.get_mut(&func_name).unwrap(); + let mut new_functions = BTreeMap::new(); + compile_time_transform_in_lines( + &mut func.body, + &const_arrays, + &existing_functions, + &inlined_functions, + &mut new_functions, + unroll_counter, + inline_counter, + )?; + // Add new specialized functions - they'll be processed in the next iteration of this loop + for (name, new_func) in new_functions { + program.functions.entry(name).or_insert(new_func); + } } } + Ok(()) } -/// Analyzes the program to verify that each variable is defined in each context where it is used. -fn check_program_scoping(program: &Program) { - for (_, function) in program.functions.iter() { - let mut scope = Scope { vars: BTreeSet::new() }; - for (arg, _) in function.arguments.iter() { - scope.vars.insert(arg.clone()); +fn compile_time_transform_in_lines( + lines: &mut Vec, + const_arrays: &BTreeMap, + existing_functions: &BTreeMap, + inlined_functions: &BTreeMap, + new_functions: &mut BTreeMap, + unroll_counter: &mut Counter, + inline_counter: &mut Counter, +) -> Result<(), String> { + let mut vector_len_tracker = VectorLenTracker::default(); + let mut const_var_exprs: BTreeMap = BTreeMap::new(); // used to simplify expressions containing variables with known constant values + + let mut i = 0; + while i < lines.len() { + let line = &mut lines[i]; + + for expr in line.expressions_mut() { + substitute_const_vars_in_expr(expr, &const_var_exprs); + compile_time_transform_in_expr(expr, const_arrays, &vector_len_tracker); } - let mut ctx = Context { - scopes: vec![scope], - const_arrays: program.const_arrays.clone(), - }; - check_block_scoping(&function.body, &mut ctx); - } -} + // Extract nested inlined calls from expressions (e.g., `x = a + inlined_func(b)`) + if let Some(new_lines) = extract_inlined_calls(line, inlined_functions, inline_counter)? { + lines.splice(i..=i, new_lines); + continue; + } -/// Analyzes the block to verify that each variable is defined in each context where it is used. -fn check_block_scoping(block: &[Line], ctx: &mut Context) { - for line in block.iter() { match line { - Line::ForwardDeclaration { var } => { - ctx.add_var(var); - } - Line::Match { value, arms } => { - check_expr_scoping(value, ctx); - for (_, arm) in arms { - ctx.scopes.push(Scope { vars: BTreeSet::new() }); - check_block_scoping(arm, ctx); - ctx.scopes.pop(); - } - } Line::Statement { targets, value, .. } => { - check_expr_scoping(value, ctx); - // First: add new variables to scope - for target in targets { - if let AssignmentTarget::Var(var) = target - && !ctx.defines(var) + if let Some(inlined) = try_inline_call(value, targets, inlined_functions, const_arrays, inline_counter) + { + lines.splice(i..=i, inlined); + continue; + } + if let Expression::FunctionCall { + function_name, args, .. + } = value + && let Some(func) = existing_functions.get(function_name.as_str()) + && func.has_const_arguments() + { + let mut const_evals = Vec::new(); + for (arg_expr, arg) in args.iter().zip(&func.arguments) { + if arg.is_const { + if let Some(const_eval) = arg_expr.as_scalar() { + const_evals.push((arg.name.clone(), const_eval)); + } else { + return Err(format!( + "Cannot evaluate const argument '{}' for function '{}'", + arg.name, function_name + )); + } + } + } + let const_funct_name = format!( + "{function_name}_{}", + const_evals + .iter() + .map(|(v, c)| format!("{v}={c}")) + .collect::>() + .join("_") + ); + *function_name = const_funct_name.clone(); + *args = args + .iter() + .zip(&func.arguments) + .filter(|(_, arg)| !arg.is_const) + .map(|(e, _)| e.clone()) + .collect(); + if !new_functions.contains_key(&const_funct_name) + && !existing_functions.contains_key(&const_funct_name) { - ctx.add_var(var); + let mut new_body = func.body.clone(); + replace_vars_by_const_in_lines(&mut new_body, &const_evals.iter().cloned().collect()); + new_functions.insert( + const_funct_name.clone(), + Function { + name: const_funct_name, + arguments: func.arguments.iter().filter(|a| !a.is_const).cloned().collect(), + inlined: false, + body: new_body, + n_returned_vars: func.n_returned_vars, + }, + ); } } - // Second pass: check array access targets - for target in targets { - if let AssignmentTarget::ArrayAccess { array: _, index } = target { - check_expr_scoping(index, ctx); + if targets.len() == 1 + && let AssignmentTarget::Var { var, is_mutable: false } = &targets[0] + && let Some(value_const) = value.as_scalar() + { + const_var_exprs.insert(var.clone(), value_const); + } + } + + Line::VecDeclaration { var, elements, .. } => { + vector_len_tracker.register(var, build_vector_len_value(elements)); + } + + Line::Push { + vector, + indices, + element, + .. + } => { + let Some(const_indices) = indices.iter().map(|idx| idx.as_scalar()).collect::>>() else { + return Err("push with non-constant indices".to_string()); + }; + let new_element = build_vector_len_value_from_element(element); + let vector_value = vector_len_tracker + .get_mut(vector) + .ok_or_else(|| "pushing to undeclared vector".to_string())?; + if const_indices.is_empty() { + vector_value.push(new_element); + } else { + let target = vector_value + .navigate_mut(&const_indices) + .ok_or_else(|| "push target index out of bounds".to_string())?; + if !target.is_vector() { + return Err("push target is not a vector".to_string()); } + target.push(new_element); } } - Line::Assert { boolean, .. } => { - check_boolean_scoping(boolean, ctx); + + Line::Pop { + vector, + indices, + location, + } => { + let Some(const_indices) = indices.iter().map(|idx| idx.as_scalar()).collect::>>() else { + return Err(format!("line {}: pop with non-constant indices", location)); + }; + let vector_value = vector_len_tracker + .get_mut(vector) + .ok_or_else(|| format!("line {}: pop on undeclared vector '{}'", location, vector))?; + if const_indices.is_empty() { + if vector_value.len() == 0 { + return Err(format!("line {}: pop on empty vector '{}'", location, vector)); + } + vector_value.pop(); + } else { + let target = vector_value + .navigate_mut(&const_indices) + .ok_or_else(|| format!("line {}: pop target index out of bounds", location))?; + if !target.is_vector() { + return Err(format!("line {}: pop target is not a vector", location)); + } + if target.len() == 0 { + return Err(format!("line {}: pop on empty vector", location)); + } + target.pop(); + } } + Line::IfCondition { condition, then_branch, else_branch, - line_number: _, + .. } => { - check_condition_scoping(condition, ctx); - for branch in [then_branch, else_branch] { - ctx.scopes.push(Scope { vars: BTreeSet::new() }); - check_block_scoping(branch, ctx); - ctx.scopes.pop(); + if let Some(constant_condition) = condition.eval_with(&|expr| expr.as_scalar()) { + let chosen_branch = if constant_condition { then_branch } else { else_branch }.clone(); + lines.splice(i..=i, chosen_branch); + continue; } } + Line::ForLoop { iterator, start, end, body, - rev: _, - unroll: _, - line_number: _, - } => { - check_expr_scoping(start, ctx); - check_expr_scoping(end, ctx); - let mut new_scope_vars = BTreeSet::new(); - new_scope_vars.insert(iterator.clone()); - ctx.scopes.push(Scope { vars: new_scope_vars }); - check_block_scoping(body, ctx); - ctx.scopes.pop(); - } - Line::FunctionRet { return_data } => { - for expr in return_data { - check_expr_scoping(expr, ctx); - } - } - Line::Precompile { table: _, args } => { - for arg in args { - check_expr_scoping(arg, ctx); - } - } - Line::Break | Line::Panic | Line::LocationReport { .. } => {} - Line::Print { line_info: _, content } => { - for expr in content { - check_expr_scoping(expr, ctx); - } - } - Line::MAlloc { - var, - size, - vectorized: _, - vectorized_len, + unroll: true, + .. } => { - check_expr_scoping(size, ctx); - check_expr_scoping(vectorized_len, ctx); - if !ctx.defines(var) { - ctx.add_var(var); - } - } - Line::CustomHint(_, args) => { - for arg in args { - check_expr_scoping(arg, ctx); - } - } - Line::PrivateInputStart { result } => { - ctx.add_var(result); + let (Some(start), Some(end)) = (start.as_scalar(), end.as_scalar()) else { + return Err("Cannot unroll loop with non-constant bounds".to_string()); + }; + let unroll_index = unroll_counter.get_next(); + let (internal_vars, _) = find_variable_usage(body, const_arrays); + let iterator = iterator.clone(); + let body = body.clone(); + let mut unrolled = Vec::new(); + for j in start.to_usize()..end.to_usize() { + let mut body_copy = body.clone(); + replace_vars_for_unroll(&mut body_copy, &iterator, unroll_index, j, &internal_vars); + unrolled.extend(body_copy); + } + lines.splice(i..=i, unrolled); + continue; } + _ => {} + } + + for block in lines[i].nested_blocks_mut() { + compile_time_transform_in_lines( + block, + const_arrays, + existing_functions, + inlined_functions, + new_functions, + unroll_counter, + inline_counter, + )?; } + i += 1; } + Ok(()) } -/// Analyzes the expression to verify that each variable is defined in the given context. -fn check_expr_scoping(expr: &Expression, ctx: &Context) { - match expr { - Expression::Value(simple_expr) => { - check_simple_expr_scoping(simple_expr, ctx); - } - Expression::ArrayAccess { array: _, index } => { - for idx in index { - check_expr_scoping(idx, ctx); - } - } - Expression::MathExpr(_, args) => { - for arg in args { - check_expr_scoping(arg, ctx); - } - } - Expression::FunctionCall { args, .. } => { - for arg in args { - check_expr_scoping(arg, ctx); - } - } - Expression::Len { indices, .. } => { - for idx in indices { - check_expr_scoping(idx, ctx); +/// Try to inline a function call. Returns Some(inlined_lines) if successful. +fn try_inline_call( + value: &Expression, + targets: &[AssignmentTarget], + inlined_functions: &BTreeMap, + const_arrays: &BTreeMap, + inline_counter: &mut Counter, +) -> Option> { + let Expression::FunctionCall { + function_name, + args, + location, + } = value + else { + return None; + }; + let func = inlined_functions.get(function_name)?; + + // If any arg is not simple, extract it first + if args.iter().any(|a| !matches!(a, Expression::Value(_))) { + let mut new_lines = vec![]; + let mut new_args = vec![]; + for arg in args { + if let Expression::Value(v) = arg { + new_args.push(Expression::Value(v.clone())); + } else { + let tmp = format!("@inline_arg_{}", inline_counter.get_next()); + new_lines.push(Line::ForwardDeclaration { + var: tmp.clone(), + is_mutable: false, + }); + new_lines.push(Line::Statement { + targets: vec![AssignmentTarget::Var { + var: tmp.clone(), + is_mutable: false, + }], + value: arg.clone(), + location: *location, + }); + new_args.push(Expression::var(tmp)); } } + new_lines.push(Line::Statement { + targets: targets.to_vec(), + value: Expression::FunctionCall { + function_name: function_name.clone(), + args: new_args, + location: *location, + }, + location: *location, + }); + return Some(new_lines); } + + // All args are simple - inline the function body + let args_map: BTreeMap = func + .arguments + .iter() + .zip(args) + .map(|(arg, expr)| { + let Expression::Value(v) = expr else { unreachable!() }; + (arg.name.clone(), v.clone()) + }) + .collect(); + + let mut body = func.body.clone(); + inline_lines(&mut body, &args_map, const_arrays, targets, inline_counter.get_next()); + Some(body) } -/// Analyzes the simple expression to verify that each variable is defined in the given context. -fn check_simple_expr_scoping(expr: &SimpleExpr, ctx: &Context) { - match expr { - SimpleExpr::Memory(VarOrConstMallocAccess::Var(v)) => { - assert!(ctx.defines(v), "Variable used but not defined: {v}"); +/// Extract nested inlined function calls from expressions, replacing them with temp vars. +fn extract_inlined_calls( + line: &mut Line, + inlined_functions: &BTreeMap, + counter: &mut Counter, +) -> Result>, String> { + fn extract( + expr: &mut Expression, + funcs: &BTreeMap, + counter: &mut Counter, + out: &mut Vec, + ) -> Result<(), String> { + for inner in expr.inner_exprs_mut() { + extract(inner, funcs, counter, out)?; + } + if let Expression::FunctionCall { + function_name, + args, + location, + } = expr + && let Some(func) = funcs.get(function_name) + { + if func.n_returned_vars != 1 { + return Err(format!( + "Inlined function '{}' with {} return values cannot appear in expression", + function_name, func.n_returned_vars + )); + } + let tmp = format!("@inline_tmp_{}", counter.get_next()); + out.push(Line::ForwardDeclaration { + var: tmp.clone(), + is_mutable: false, + }); + out.push(Line::Statement { + targets: vec![AssignmentTarget::Var { + var: tmp.clone(), + is_mutable: false, + }], + value: Expression::FunctionCall { + function_name: function_name.clone(), + args: args.clone(), + location: *location, + }, + location: *location, + }); + *expr = Expression::var(tmp); } - SimpleExpr::Memory(VarOrConstMallocAccess::ConstMallocAccess { .. }) => {} - SimpleExpr::Constant(_) => {} + Ok(()) } -} - -fn check_boolean_scoping(boolean: &BooleanExpr, ctx: &Context) { - check_expr_scoping(&boolean.left, ctx); - check_expr_scoping(&boolean.right, ctx); -} -fn check_condition_scoping(condition: &Condition, ctx: &Context) { - match condition { - Condition::AssumeBoolean(expr) => { - check_expr_scoping(expr, ctx); + let mut extractions = vec![]; + // For direct inlined calls, only extract from arguments; otherwise extract from all expressions + match line { + Line::Statement { + value: Expression::FunctionCall { + function_name, args, .. + }, + .. + } if inlined_functions.contains_key(function_name) => { + for arg in args.iter_mut() { + extract(arg, inlined_functions, counter, &mut extractions)?; + } } - Condition::Comparison(boolean) => { - check_boolean_scoping(boolean, ctx); + _ => { + for expr in line.expressions_mut() { + extract(expr, inlined_functions, counter, &mut extractions)?; + } } } -} -#[derive(Debug, Clone, Default)] -struct Counters { - aux_vars: usize, - loops: usize, -} - -impl Counters { - fn aux_var(&mut self) -> Var { - let var = format!("@aux_var_{}", self.aux_vars); - self.aux_vars += 1; - var + if extractions.is_empty() { + Ok(None) + } else { + extractions.push(line.clone()); + Ok(Some(extractions)) } } -struct SimplifyContext<'a> { - functions: &'a BTreeMap, - const_arrays: &'a BTreeMap, +fn compile_time_transform_in_expr( + expr: &mut Expression, + const_arrays: &BTreeMap, + vector_len_tracker: &VectorLenTracker, +) -> bool { + if expr.is_scalar() { + return false; + } + let mut changed = false; + for inner_expr in expr.inner_exprs_mut() { + changed |= compile_time_transform_in_expr(inner_expr, const_arrays, vector_len_tracker); + } + if let Some(scalar) = expr.compile_time_eval(const_arrays, vector_len_tracker) { + *expr = Expression::scalar(scalar); + changed = true; + } + changed } -struct SimplifyState<'a> { - counters: &'a mut Counters, - array_manager: &'a mut ArrayManager, -} +fn substitute_const_vars_in_expr(expr: &mut Expression, const_var_exprs: &BTreeMap) -> bool { + if let Expression::Value(SimpleExpr::Memory(VarOrConstMallocAccess::Var(var))) = expr + && let Some(replacement) = const_var_exprs.get(var) + { + *expr = Expression::scalar(*replacement); + return true; + } -#[derive(Debug, Clone, Default)] -struct ArrayManager { - counter: usize, - aux_vars: BTreeMap<(Var, Expression), Var>, // (array, index) -> aux_var - valid: BTreeSet, // currently valid aux vars + let mut changed = false; + for inner in expr.inner_exprs_mut() { + changed |= substitute_const_vars_in_expr(inner, const_var_exprs); + } + changed } -#[derive(Debug, Clone, Default)] -pub struct ConstMalloc { +// ============================================================================ +// TRANSFORMATION: Mutable variables in non-unrolled loops +// ============================================================================ +// +// This transformation handles mutable variables that are modified inside +// non-unrolled loops by using buffers to store intermediate values. +// +// For a loop like: +// for i in start..end { x += i; } +// +// We transform it to: +// size = end - start; +// x_buff = Array(size + 1); +// x_buff[0] = x; +// for i in start..end { +// buff_idx = i - start; +// mut x_body = x_buff[buff_idx]; +// x_body += i; +// x_buff[buff_idx + 1] = x_body; +// } +// x = x_buff[size]; + +/// Counter for generating unique variable names in the mutable loop transformation +#[derive(Default)] +struct MutableLoopTransformCounter { counter: usize, - map: BTreeMap, } -impl ArrayManager { - fn get_aux_var(&mut self, array: &Var, index: &Expression) -> Var { - if let Some(var) = self.aux_vars.get(&(array.clone(), index.clone())) { - return var.clone(); - } - let new_var = format!("@arr_aux_{}", self.counter); +impl MutableLoopTransformCounter { + fn next_suffix(&mut self) -> usize { + let c = self.counter; self.counter += 1; - self.aux_vars.insert((array.clone(), index.clone()), new_var.clone()); - new_var + c } } -#[allow(clippy::too_many_arguments)] -fn simplify_lines( - ctx: &SimplifyContext<'_>, - state: &mut SimplifyState<'_>, - const_malloc: &mut ConstMalloc, - new_functions: &mut BTreeMap, - file_id: FileId, - n_returned_vars: usize, +/// Finds mutable variables that are: +/// 1. Defined OUTSIDE this block (external) +/// 2. Re-assigned INSIDE this block +fn find_modified_external_vars(lines: &[Line], const_arrays: &BTreeMap) -> BTreeSet { + // Use the existing find_variable_usage to get external variables + // (variables that are read but not defined in this block) + let (internal_vars, external_vars) = find_variable_usage(lines, const_arrays); + + // Now find which external variables are assigned to (modified) + let mut modified_external_vars = BTreeSet::new(); + find_assigned_external_vars_helper( + lines, + const_arrays, + &internal_vars, + &external_vars, + &mut modified_external_vars, + ); + + modified_external_vars +} + +/// Helper to find external variables that are assigned to inside a block. +fn find_assigned_external_vars_helper( lines: &[Line], - in_a_loop: bool, -) -> Vec { - let mut res = Vec::new(); + const_arrays: &BTreeMap, + internal_vars: &BTreeSet, + external_vars: &BTreeSet, + modified_external_vars: &mut BTreeSet, +) { for line in lines { match line { - Line::ForwardDeclaration { var } => { - res.push(SimpleLine::ForwardDeclaration { var: var.clone() }); + Line::Statement { targets, .. } => { + for target in targets { + if let AssignmentTarget::Var { var, is_mutable } = target { + // Only non-mutable assignments can be modifications + // (is_mutable: true means it's the initial declaration) + if !*is_mutable + && external_vars.contains(var) + && !internal_vars.contains(var) + && !const_arrays.contains_key(var) + { + modified_external_vars.insert(var.clone()); + } + } + } } - Line::Match { value, arms } => { - let simple_value = simplify_expr(ctx, state, const_malloc, value, &mut res); - let mut simple_arms = vec![]; - for (i, (pattern, statements)) in arms.iter().enumerate() { - assert_eq!(*pattern, i, "match patterns should be consecutive, starting from 0"); - simple_arms.push(simplify_lines( - ctx, - state, - const_malloc, - new_functions, - file_id, - n_returned_vars, - statements, - in_a_loop, - )); + _ => { + for block in line.nested_blocks() { + find_assigned_external_vars_helper( + block, + const_arrays, + internal_vars, + external_vars, + modified_external_vars, + ); } - res.push(SimpleLine::Match { - value: simple_value, - arms: simple_arms, - }); } - Line::Statement { - targets, - value, - line_number, + } + } +} + +fn transform_mutable_in_loops_in_program(program: &mut Program, counter: &mut MutableLoopTransformCounter) { + for func in program.functions.values_mut() { + transform_mutable_in_loops_in_lines(&mut func.body, &program.const_arrays, counter); + } +} + +fn transform_mutable_in_loops_in_lines( + lines: &mut Vec, + const_arrays: &BTreeMap, + counter: &mut MutableLoopTransformCounter, +) { + let mut i = 0; + while i < lines.len() { + match &mut lines[i] { + Line::ForLoop { body, unroll: true, .. } => { + transform_mutable_in_loops_in_lines(body, const_arrays, counter); + i += 1; + } + Line::ForLoop { + iterator, + start, + end, + body, + unroll: false, + location, } => { - match value { - Expression::FunctionCall { function_name, args } => { - // Function call - may have zero, one, or multiple targets - let function = ctx.functions.get(function_name).unwrap_or_else(|| { - panic!("Function used but not defined: {function_name}, at line {line_number}") - }); - if targets.len() != function.n_returned_vars { - panic!( - "Expected {} returned vars (and not {}) in call to {function_name}, at line {line_number}", - function.n_returned_vars, - targets.len() - ); - } - if args.len() != function.arguments.len() { - panic!( - "Expected {} arguments (and not {}) in call to {function_name}, at line {line_number}", - function.arguments.len(), - args.len() - ); - } + transform_mutable_in_loops_in_lines(body, const_arrays, counter); + let modified_vars = find_modified_external_vars(body, const_arrays); - let simplified_args = args - .iter() - .map(|arg| simplify_expr(ctx, state, const_malloc, arg, &mut res)) - .collect::>(); + if modified_vars.is_empty() { + // No mutable variables modified, no transformation needed + i += 1; + continue; + } - let mut temp_vars = Vec::new(); - let mut array_targets: Vec<(usize, Var, Box)> = Vec::new(); + let suffix = counter.next_suffix(); - for (i, target) in targets.iter().enumerate() { - match target { - AssignmentTarget::Var(var) => { - temp_vars.push(var.clone()); - } - AssignmentTarget::ArrayAccess { array, index } => { - temp_vars.push(state.counters.aux_var()); - array_targets.push((i, array.clone(), index.clone())); - } - } - } + // Generate the transformed code + let mut new_lines = Vec::new(); - res.push(SimpleLine::FunctionCall { - function_name: function_name.clone(), - args: simplified_args, - return_data: temp_vars.clone(), - line_number: *line_number, - }); + let location = *location; - // For array access targets, add DEREF instructions to copy temp to array element - for (i, array, index) in array_targets { - handle_array_assignment( - ctx, - state, - const_malloc, - &mut res, - &array, - &[*index], - ArrayAccessType::ArrayIsAssigned(Expression::Value( - VarOrConstMallocAccess::Var(temp_vars[i].clone()).into(), - )), - ); - } - } - _ => { - assert!(targets.len() == 1, "Non-function call must have exactly one target"); - let target = &targets[0]; + // Create size variable: @loop_size_{suffix} = end - start + // TODO opti if start is zero + let size_var = format!("@loop_size_{suffix}"); + new_lines.push(Line::Statement { + targets: vec![AssignmentTarget::Var { + var: size_var.clone(), + is_mutable: false, + }], + value: Expression::MathExpr(MathOperation::Sub, vec![end.clone(), start.clone()]), + location, + }); - match target { - AssignmentTarget::Var(var) => { - // Variable assignment - match value { - Expression::Value(val) => { - res.push(SimpleLine::equality(var.clone(), val.clone())); - } - Expression::ArrayAccess { array, index } => { - handle_array_assignment( - ctx, - state, - const_malloc, - &mut res, - array, - index, - ArrayAccessType::VarIsAssigned(var.clone()), - ); - } - Expression::MathExpr(operation, args) => { - let args_simplified = args - .iter() - .map(|arg| simplify_expr(ctx, state, const_malloc, arg, &mut res)) - .collect::>(); - // If all operands are constants, evaluate at compile time and assign the result - if let Some(const_args) = SimpleExpr::try_vec_as_constant(&args_simplified) { - let result = ConstExpression::MathExpr(*operation, const_args); - res.push(SimpleLine::equality(var.clone(), SimpleExpr::Constant(result))); - } else { - // general case - res.push(SimpleLine::Assignment { - var: var.clone().into(), - operation: *operation, - arg0: args_simplified[0].clone(), - arg1: args_simplified[1].clone(), - }); - } - } - Expression::Len { .. } => unreachable!(), - Expression::FunctionCall { .. } => { - unreachable!("FunctionCall should be handled above") - } - } + let mut var_to_buff: BTreeMap = BTreeMap::new(); // var -> (buff_name, body_name) + + for var in &modified_vars { + let buff_name = format!("@loop_buff_{var}_{suffix}"); + let body_name = format!("@loop_body_{var}_{suffix}"); + + // buff = Array(size + 1) + new_lines.push(Line::Statement { + targets: vec![AssignmentTarget::Var { + var: buff_name.clone(), + is_mutable: false, + }], + value: Expression::FunctionCall { + function_name: "Array".to_string(), + args: vec![Expression::MathExpr( + // TODO opti in case there is only one mutated var + MathOperation::Add, + vec![Expression::var(size_var.clone()), Expression::one()], + )], + location, + }, + location, + }); + + // buff[0] = var (current value) + new_lines.push(Line::Statement { + targets: vec![AssignmentTarget::ArrayAccess { + array: buff_name.clone(), + index: Box::new(Expression::zero()), + }], + value: Expression::var(var.clone()), + location, + }); + + var_to_buff.insert(var.clone(), (buff_name, body_name)); + } + + // Transform the loop body + let iterator = iterator.clone(); + let mut new_body = Vec::new(); + + // buff_idx = i - start + let buff_idx_var = format!("@loop_buff_idx_{suffix}"); + new_body.push(Line::Statement { + targets: vec![AssignmentTarget::Var { + var: buff_idx_var.clone(), + is_mutable: false, + }], + value: Expression::MathExpr( + MathOperation::Sub, + vec![Expression::var(iterator.clone()), start.clone()], // TODO opti if start is zero + ), + location, + }); + + // For each modified variable: mut body_var = buff[buff_idx] + for (var, (buff_name, body_name)) in &var_to_buff { + new_body.push(Line::Statement { + targets: vec![AssignmentTarget::Var { + var: body_name.clone(), + is_mutable: true, + }], + value: Expression::ArrayAccess { + array: buff_name.clone(), + index: vec![Expression::Value( + VarOrConstMallocAccess::Var(buff_idx_var.clone()).into(), + )], + }, + location, + }); + + // Replace all references to var with body_name in the original body + replace_var_in_lines(body, var, body_name); + } + + // Add the original body (now modified to use body_vars) + new_body.append(body); + + // next_idx = buff_idx + 1 + let next_idx_var = format!("@loop_next_idx_{suffix}"); + new_body.push(Line::Statement { + targets: vec![AssignmentTarget::Var { + var: next_idx_var.clone(), + is_mutable: false, + }], + value: Expression::MathExpr( + MathOperation::Add, + vec![Expression::var(buff_idx_var.clone()), Expression::one()], + ), + location, + }); + + // For each modified variable: buff[next_idx] = body_var + for (buff_name, body_name) in var_to_buff.values() { + new_body.push(Line::Statement { + targets: vec![AssignmentTarget::ArrayAccess { + array: buff_name.clone(), + index: Expression::var(next_idx_var.clone()).into(), + }], + value: Expression::var(body_name.clone()), + location, + }); + } + + // Create the new loop + new_lines.push(Line::ForLoop { + iterator: iterator.clone(), + start: start.clone(), + end: end.clone(), + body: new_body, + unroll: false, + location, + }); + + // After the loop: var = buff[size] + for (var, (buff_name, _body_name)) in &var_to_buff { + new_lines.push(Line::Statement { + targets: vec![AssignmentTarget::Var { + var: var.clone(), + is_mutable: false, + }], + value: Expression::ArrayAccess { + array: buff_name.clone(), + index: vec![Expression::var(size_var.clone())], + }, + location, + }); + } + + // Replace the original loop with the new lines + let num_new = new_lines.len(); + lines.splice(i..=i, new_lines); + i += num_new; + } + line @ (Line::IfCondition { .. } | Line::Match { .. }) => { + for block in line.nested_blocks_mut() { + transform_mutable_in_loops_in_lines(block, const_arrays, counter); + } + i += 1; + } + _ => { + i += 1; + } + } + } +} + +/// Replaces all occurrences of a variable with another variable in a list of lines. +/// This is used to replace references to mutable variables with their body counterparts. +fn replace_var_in_lines(lines: &mut [Line], old_var: &Var, new_var: &Var) { + for line in lines { + match line { + Line::ForwardDeclaration { var, .. } => { + if var == old_var { + *var = new_var.clone(); + } + } + Line::Statement { targets, .. } => { + for target in targets { + match target { + AssignmentTarget::Var { var, .. } => { + if var == old_var { + *var = new_var.clone(); } - AssignmentTarget::ArrayAccess { array, index } => { - // Array element assignment - handle_array_assignment( - ctx, - state, - const_malloc, - &mut res, - array, - std::slice::from_ref(&**index), - ArrayAccessType::ArrayIsAssigned(value.clone()), - ); + } + AssignmentTarget::ArrayAccess { array, index } => { + if array == old_var { + *array = new_var.clone(); } + replace_var_in_expr(index, old_var, new_var); } } } } - Line::Assert { - boolean, - line_number, - debug, - } => { - let left = simplify_expr(ctx, state, const_malloc, &boolean.left, &mut res); - let right = simplify_expr(ctx, state, const_malloc, &boolean.right, &mut res); + _ => {} + } + for expr in line.expressions_mut() { + replace_var_in_expr(expr, old_var, new_var); + } + for block in line.nested_blocks_mut() { + replace_var_in_lines(block, old_var, new_var); + } + } +} - if *debug { - res.push(SimpleLine::DebugAssert( - BooleanExpr { - left, - right, - kind: boolean.kind, - }, - *line_number, - )); - } else { - match boolean.kind { - Boolean::Different => { - let diff_var = state.counters.aux_var(); - res.push(SimpleLine::Assignment { - var: diff_var.clone().into(), - operation: MathOperation::Sub, - arg0: left, - arg1: right, - }); - res.push(SimpleLine::IfNotZero { - condition: diff_var.into(), - then_branch: vec![], - else_branch: vec![SimpleLine::Panic], - line_number: *line_number, - }); - } - Boolean::Equal => { - let (var, other): (VarOrConstMallocAccess, _) = if let Ok(left) = left.clone().try_into() { - (left, right) - } else if let Ok(right) = right.clone().try_into() { - (right, left) - } else { - // Both are constants - evaluate at compile time - if let (SimpleExpr::Constant(left_const), SimpleExpr::Constant(right_const)) = - (&left, &right) - && let (Some(left_val), Some(right_val)) = - (left_const.naive_eval(), right_const.naive_eval()) - { - if left_val == right_val { - // Assertion passes at compile time, no code needed - continue; - } else { - panic!( - "Compile-time assertion failed: {} != {} (lines {})", - left_val.to_usize(), - right_val.to_usize(), - line_number - ); - } - } - panic!("Unsupported equality assertion: {left:?}, {right:?}") - }; - res.push(SimpleLine::equality(var, other)); - } - Boolean::LessThan => unreachable!(), +fn replace_var_in_expr(expr: &mut Expression, old_var: &Var, new_var: &Var) { + match expr { + Expression::Value(simple_expr) => { + if let SimpleExpr::Memory(VarOrConstMallocAccess::Var(var)) = simple_expr + && var == old_var + { + *var = new_var.clone(); + } + } + Expression::ArrayAccess { array, .. } => { + if array == old_var { + *array = new_var.clone(); + } + } + _ => {} + } + for inner_expr in expr.inner_exprs_mut() { + replace_var_in_expr(inner_expr, old_var, new_var); + } +} + +fn check_function_always_returns(func: &SimpleFunction) -> Result<(), String> { + check_block_always_returns(&func.name, &func.instructions) +} + +fn check_block_always_returns(function_name: &String, instructions: &[SimpleLine]) -> Result<(), String> { + if let Some(last_instruction) = instructions.last() { + if matches!( + last_instruction, + SimpleLine::FunctionRet { return_data: _ } | SimpleLine::Panic { .. } + ) { + return Ok(()); + } + let inner_blocks = last_instruction.nested_blocks(); + if !inner_blocks.is_empty() { + for block in inner_blocks { + check_block_always_returns(function_name, block)?; + } + return Ok(()); + } + } + Err(format!("Cannot prove that function always returns: {function_name}")) +} + +fn check_program_scoping(program: &Program) { + for (_, function) in program.functions.iter() { + let mut scope = Scope { vars: BTreeSet::new() }; + for arg in function.arguments.iter() { + scope.vars.insert(arg.name.clone()); + } + let mut ctx = Context { + scopes: vec![scope], + const_arrays: program.const_arrays.clone(), + }; + + check_block_scoping(&function.body, &mut ctx); + } +} + +fn check_block_scoping(block: &[Line], ctx: &mut Context) { + for line in block.iter() { + match line { + Line::ForwardDeclaration { var, .. } => { + ctx.add_var(var); + } + Line::Match { value, arms, .. } => { + check_expr_scoping(value, ctx); + for (_, arm) in arms { + ctx.scopes.push(Scope { vars: BTreeSet::new() }); + check_block_scoping(arm, ctx); + ctx.scopes.pop(); + } + } + Line::Statement { targets, value, .. } => { + check_expr_scoping(value, ctx); + // First: add new variables to scope + for target in targets { + if let AssignmentTarget::Var { var, .. } = target + && !ctx.defines(var) + { + ctx.add_var(var); + } + } + // Second pass: check array access targets + for target in targets { + if let AssignmentTarget::ArrayAccess { array: _, index } = target { + check_expr_scoping(index, ctx); } } } + Line::Assert { boolean, .. } => { + check_boolean_scoping(boolean, ctx); + } Line::IfCondition { condition, then_branch, else_branch, - line_number, + location: _, } => { - let (condition_simplified, then_branch, else_branch) = match condition { - Condition::Comparison(condition) => { - // Transform if a == b then X else Y into if a != b then Y else X - - let (left, right, then_branch, else_branch) = match condition.kind { - Boolean::Equal => (&condition.left, &condition.right, else_branch, then_branch), // switched - Boolean::Different => (&condition.left, &condition.right, then_branch, else_branch), - Boolean::LessThan => unreachable!(), - }; - - let left_simplified = simplify_expr(ctx, state, const_malloc, left, &mut res); - let right_simplified = simplify_expr(ctx, state, const_malloc, right, &mut res); - - let diff_var = state.counters.aux_var(); - res.push(SimpleLine::Assignment { - var: diff_var.clone().into(), - operation: MathOperation::Sub, - arg0: left_simplified, - arg1: right_simplified, - }); - (diff_var.into(), then_branch, else_branch) - } - Condition::AssumeBoolean(condition) => { - let condition_simplified = simplify_expr(ctx, state, const_malloc, condition, &mut res); - (condition_simplified, then_branch, else_branch) - } - }; - - let mut array_manager_then = state.array_manager.clone(); - let mut state_then = SimplifyState { - counters: state.counters, - array_manager: &mut array_manager_then, - }; - let then_branch_simplified = simplify_lines( - ctx, - &mut state_then, - const_malloc, - new_functions, - file_id, - n_returned_vars, - then_branch, - in_a_loop, - ); - let mut array_manager_else = array_manager_then.clone(); - array_manager_else.valid = state.array_manager.valid.clone(); // Crucial: remove the access added in the IF branch - - let mut state_else = SimplifyState { - counters: state.counters, - array_manager: &mut array_manager_else, - }; - let else_branch_simplified = simplify_lines( - ctx, - &mut state_else, - const_malloc, - new_functions, - file_id, - n_returned_vars, - else_branch, - in_a_loop, - ); - - *state.array_manager = array_manager_else.clone(); - // keep the intersection both branches - state.array_manager.valid = state - .array_manager - .valid - .intersection(&array_manager_then.valid) - .cloned() - .collect(); - - res.push(SimpleLine::IfNotZero { - condition: condition_simplified, - then_branch: then_branch_simplified, - else_branch: else_branch_simplified, - line_number: *line_number, - }); + check_condition_scoping(condition, ctx); + for branch in [then_branch, else_branch] { + ctx.scopes.push(Scope { vars: BTreeSet::new() }); + check_block_scoping(branch, ctx); + ctx.scopes.pop(); + } } Line::ForLoop { iterator, start, end, body, - rev, - unroll, - line_number, + unroll: _, + location: _, } => { - assert!(!*unroll, "Unrolled loops should have been handled already"); - - if *rev { - unimplemented!("Reverse for non-unrolled loops are not implemented yet"); + check_expr_scoping(start, ctx); + check_expr_scoping(end, ctx); + let mut new_scope_vars = BTreeSet::new(); + new_scope_vars.insert(iterator.clone()); + ctx.scopes.push(Scope { vars: new_scope_vars }); + check_block_scoping(body, ctx); + ctx.scopes.pop(); + } + Line::FunctionRet { return_data } => { + for expr in return_data { + check_expr_scoping(expr, ctx); } - - let mut loop_const_malloc = ConstMalloc { - counter: const_malloc.counter, - ..ConstMalloc::default() - }; - let valid_aux_vars_in_array_manager_before = state.array_manager.valid.clone(); - state.array_manager.valid.clear(); - let simplified_body = simplify_lines( - ctx, - state, - &mut loop_const_malloc, - new_functions, - file_id, - 0, - body, - true, - ); - const_malloc.counter = loop_const_malloc.counter; - state.array_manager.valid = valid_aux_vars_in_array_manager_before; // restore the valid aux vars - - let func_name = format!("@loop_{}_line_{}", state.counters.loops, line_number); - state.counters.loops += 1; - - // Find variables used inside loop but defined outside - let (_, mut external_vars) = find_variable_usage(body, ctx.const_arrays); - - // Include variables in start/end - for expr in [start, end] { - for var in vars_in_expression(expr, ctx.const_arrays) { - external_vars.insert(var); - } + } + Line::Panic { .. } | Line::LocationReport { .. } => {} + Line::VecDeclaration { var, elements, .. } => { + // Check expressions in vec elements + check_vec_literal_scoping(elements, ctx); + // Add the vector variable to scope + ctx.add_var(var); + } + Line::Push { + vector, + indices, + element, + .. + } => { + // Check the vector variable is in scope + assert!(ctx.defines(vector), "Vector variable '{}' not in scope", vector); + // Check indices are in scope + for idx in indices { + check_expr_scoping(idx, ctx); } - external_vars.remove(iterator); // Iterator is internal to loop + // Check the pushed element + check_vec_literal_element_scoping(element, ctx); + } + Line::Pop { vector, indices, .. } => { + // Check the vector variable is in scope + assert!(ctx.defines(vector), "Vector variable '{}' not in scope", vector); + // Check indices are in scope + for idx in indices { + check_expr_scoping(idx, ctx); + } + } + } + } +} - let mut external_vars: Vec<_> = external_vars.into_iter().collect(); +fn validate_program_vectors(program: &Program) -> Result<(), String> { + let inlined_functions = program.inlined_function_names(); + for f in program.functions.values() { + validate_vectors(&f.body, &BTreeSet::new(), &inlined_functions, None)?; + } + Ok(()) +} - let start_simplified = simplify_expr(ctx, state, const_malloc, start, &mut res); - let mut end_simplified = simplify_expr(ctx, state, const_malloc, end, &mut res); - if let SimpleExpr::Memory(VarOrConstMallocAccess::ConstMallocAccess { malloc_label, offset }) = - end_simplified.clone() - { - // we use an auxilary variable to store the end value (const malloc inside non-unrolled loops does not work) - let aux_end_var = state.counters.aux_var(); - res.push(SimpleLine::equality( - aux_end_var.clone(), - VarOrConstMallocAccess::ConstMallocAccess { malloc_label, offset }, - )); - end_simplified = VarOrConstMallocAccess::Var(aux_end_var).into(); - } +fn validate_vectors( + lines: &[Line], + outer: &BTreeSet, + inlined: &BTreeSet, + restrict: Option, +) -> Result<(), String> { + let mut local: BTreeSet = BTreeSet::new(); + macro_rules! all { + () => { + outer.union(&local).cloned().collect::>() + }; + } - for (simplified, original) in [ - (start_simplified.clone(), start.clone()), - (end_simplified.clone(), end.clone()), - ] { - if !matches!(original, Expression::Value(_)) { - // the simplified var is auxiliary - if let SimpleExpr::Memory(VarOrConstMallocAccess::Var(var)) = simplified { - external_vars.push(var); - } - } + for line in lines { + match line { + Line::VecDeclaration { var, elements, .. } => { + local.insert(var.clone()); + validate_vec_lit(elements, &all!(), inlined)?; + } + Line::Push { + vector, + element, + location, + .. + } => { + if restrict.is_some() && outer.contains(vector) { + return Err(format!("line {}: push to outer-scope vector '{}'", location, vector)); + } + validate_vec_lit(std::slice::from_ref(element), &all!(), inlined)?; + if !local.contains(vector) && !outer.contains(vector) { + return Err(format!("line {}: unknown vector '{}'", location, vector)); } - - // Create function arguments: iterator + external variables - let mut func_args = vec![iterator.clone()]; - func_args.extend(external_vars.clone()); - - // Create recursive function body - let recursive_func = create_recursive_function( - func_name.clone(), - SourceLocation { - line_number: *line_number, - file_id, - }, - func_args, - iterator.clone(), - end_simplified, - simplified_body, - &external_vars, - ); - new_functions.insert(func_name.clone(), recursive_func); - - // Replace loop with initial function call - let mut call_args = vec![start_simplified]; - call_args.extend(external_vars.iter().map(|v| v.clone().into())); - - res.push(SimpleLine::FunctionCall { - function_name: func_name, - args: call_args, - return_data: vec![], - line_number: *line_number, - }); - } - Line::FunctionRet { return_data } => { - assert!(!in_a_loop, "Function return inside a loop is not currently supported"); - assert!( - return_data.len() == n_returned_vars, - "Wrong number of return values in return statement; expected {n_returned_vars} but got {}", - return_data.len() - ); - let simplified_return_data = return_data - .iter() - .map(|ret| simplify_expr(ctx, state, const_malloc, ret, &mut res)) - .collect::>(); - res.push(SimpleLine::FunctionRet { - return_data: simplified_return_data, - }); - } - Line::Precompile { table, args } => { - let simplified_args = args - .iter() - .map(|arg| simplify_expr(ctx, state, const_malloc, arg, &mut res)) - .collect::>(); - res.push(SimpleLine::Precompile { - table: *table, - args: simplified_args, - }); } - Line::Print { line_info, content } => { - let simplified_content = content - .iter() - .map(|var| simplify_expr(ctx, state, const_malloc, var, &mut res)) - .collect::>(); - res.push(SimpleLine::Print { - line_info: line_info.clone(), - content: simplified_content, - }); + Line::Pop { vector, location, .. } => { + if restrict.is_some() && outer.contains(vector) { + return Err(format!("line {}: pop from outer-scope vector '{}'", location, vector)); + } + if !local.contains(vector) && !outer.contains(vector) { + return Err(format!("line {}: unknown vector '{}'", location, vector)); + } } - Line::Break => { - assert!(in_a_loop, "Break statement outside of a loop"); - res.push(SimpleLine::FunctionRet { return_data: vec![] }); + Line::Statement { value, .. } => { + check_vec_in_call(value, &all!(), inlined)?; } - Line::MAlloc { - var, - size, - vectorized, - vectorized_len, + Line::IfCondition { + then_branch, + else_branch, + location, + .. } => { - let simplified_size = simplify_expr(ctx, state, const_malloc, size, &mut res); - let simplified_vectorized_len = simplify_expr(ctx, state, const_malloc, vectorized_len, &mut res); - match simplified_size { - SimpleExpr::Constant(const_size) if !*vectorized => { - let label = const_malloc.counter; - const_malloc.counter += 1; - const_malloc.map.insert(var.clone(), label); - res.push(SimpleLine::ConstMalloc { - var: var.clone(), - size: const_size, - label, - }); - } - _ => { - res.push(SimpleLine::HintMAlloc { - var: var.clone(), - size: simplified_size, - vectorized: *vectorized, - vectorized_len: simplified_vectorized_len, - }); - } - } - } - Line::PrivateInputStart { result } => { - res.push(SimpleLine::PrivateInputStart { result: result.clone() }); + validate_vectors(then_branch, &all!(), inlined, Some(*location))?; + validate_vectors(else_branch, &all!(), inlined, Some(*location))?; } - Line::CustomHint(hint, args) => { - let simplified_args = args - .iter() - .map(|expr| simplify_expr(ctx, state, const_malloc, expr, &mut res)) - .collect::>(); - res.push(SimpleLine::CustomHint(*hint, simplified_args)); + Line::ForLoop { + body, unroll, location, .. + } => { + validate_vectors(body, &all!(), inlined, if *unroll { None } else { Some(*location) })?; } - Line::Panic => { - res.push(SimpleLine::Panic); + Line::Match { arms, location, .. } => { + for (_, arm) in arms { + validate_vectors(arm, &all!(), inlined, Some(*location))?; + } } - Line::LocationReport { location } => { - res.push(SimpleLine::LocationReport { location: *location }); + _ => {} + } + } + Ok(()) +} + +fn validate_vec_lit(elems: &[VecLiteral], vecs: &BTreeSet, inlined: &BTreeSet) -> Result<(), String> { + for e in elems { + match e { + VecLiteral::Expr(expr) => check_vec_in_call(expr, vecs, inlined)?, + VecLiteral::Vec(inner) => validate_vec_lit(inner, vecs, inlined)?, + } + } + Ok(()) +} + +fn check_vec_in_call(expr: &Expression, vecs: &BTreeSet, inlined: &BTreeSet) -> Result<(), String> { + if let Expression::FunctionCall { + function_name, + args, + location, + } = expr + && !inlined.contains(function_name) + { + for arg in args { + if let Expression::Value(SimpleExpr::Memory(VarOrConstMallocAccess::Var(v))) = arg + && vecs.contains(v) + { + return Err(format!( + "line {}: vector '{}' passed to function '{}'", + location, v, function_name + )); } } } + Ok(()) +} - res +fn check_vec_literal_scoping(elements: &[VecLiteral], ctx: &Context) { + for elem in elements { + check_vec_literal_element_scoping(elem, ctx); + } } -fn simplify_expr( - ctx: &SimplifyContext<'_>, - state: &mut SimplifyState<'_>, - const_malloc: &ConstMalloc, - expr: &Expression, - lines: &mut Vec, -) -> SimpleExpr { +fn check_vec_literal_element_scoping(elem: &VecLiteral, ctx: &Context) { + match elem { + VecLiteral::Expr(expr) => check_expr_scoping(expr, ctx), + VecLiteral::Vec(inner) => check_vec_literal_scoping(inner, ctx), + } +} + +fn check_expr_scoping(expr: &Expression, ctx: &Context) { + if let Expression::Value(simple_expr) = expr { + check_simple_expr_scoping(simple_expr, ctx); + } + for inner_expr in expr.inner_exprs() { + check_expr_scoping(inner_expr, ctx); + } +} + +fn check_simple_expr_scoping(expr: &SimpleExpr, ctx: &Context) { match expr { - Expression::Value(value) => value.clone(), - Expression::ArrayAccess { array, index } => { - // Check for const array access - if let Some(arr) = ctx.const_arrays.get(array) { - let simplified_index = index - .iter() - .map(|idx| { - simplify_expr(ctx, state, const_malloc, idx, lines) - .as_constant() - .expect("Const array access index should be constant") - .naive_eval() - .expect("Const array access index should be constant") - .to_usize() - }) - .collect::>(); + SimpleExpr::Memory(VarOrConstMallocAccess::Var(v)) => { + assert!(ctx.defines(v), "Variable used but not defined: {v}"); + } + SimpleExpr::Memory(VarOrConstMallocAccess::ConstMallocAccess { .. }) | SimpleExpr::Constant(_) => {} + } +} - return SimpleExpr::Constant(ConstExpression::from( - arr.navigate(&simplified_index) - .expect("Const array access index out of bounds") - .as_scalar() - .expect("Const array access should return a scalar"), - )); - } +fn check_boolean_scoping(boolean: &BooleanExpr, ctx: &Context) { + check_expr_scoping(&boolean.left, ctx); + check_expr_scoping(&boolean.right, ctx); +} - assert_eq!(index.len(), 1); - let index = index[0].clone(); +fn check_condition_scoping(condition: &Condition, ctx: &Context) { + match condition { + Condition::AssumeBoolean(expr) => { + check_expr_scoping(expr, ctx); + } + Condition::Comparison(boolean) => { + check_boolean_scoping(boolean, ctx); + } + } +} - if let Some(label) = const_malloc.map.get(array) - && let Ok(offset) = ConstExpression::try_from(index.clone()) - { - return VarOrConstMallocAccess::ConstMallocAccess { - malloc_label: *label, - offset, - } - .into(); - } +#[derive(Debug, Clone, Default)] +struct Counters { + aux_vars: Counter, + loops: Counter, +} - let aux_arr = state.array_manager.get_aux_var(array, &index); // auxiliary var to store m[array + index] +impl Counters { + fn aux_var(&mut self) -> Var { + let var = format!("@aux_var_{}", self.aux_vars.get_next()); + var + } +} - if !state.array_manager.valid.insert(aux_arr.clone()) { - return VarOrConstMallocAccess::Var(aux_arr).into(); - } +struct SimplifyContext<'a> { + functions: &'a BTreeMap, + const_arrays: &'a BTreeMap, +} - handle_array_assignment( - ctx, - state, - const_malloc, - lines, - array, - &[index], - ArrayAccessType::VarIsAssigned(aux_arr.clone()), - ); - VarOrConstMallocAccess::Var(aux_arr).into() +struct SimplifyState<'a> { + counters: &'a mut Counters, + array_manager: &'a mut ArrayManager, + mut_tracker: &'a mut MutableVarTracker, + vec_tracker: &'a mut VectorTracker, +} + +#[derive(Debug, Clone, Default)] +struct ArrayManager { + counter: usize, + aux_vars: BTreeMap<(Var, Expression), Var>, // (array, index) -> aux_var + valid: BTreeSet, // currently valid aux vars +} + +/// Tracks the current "version" of each mutable variable for SSA-like transformation +#[derive(Debug, Clone, Default, PartialEq, Eq)] +struct MutableVarTracker { + /// For mutable variables: maps original variable name -> current version number (0 = original) + versions: BTreeMap, + /// Tracks assigned immutable variables to detect illegal reassignment + assigned: BTreeSet, +} + +impl MutableVarTracker { + fn is_mutable(&self, var: &Var) -> bool { + self.versions.contains_key(var) + } + + fn register_mutable(&mut self, var: &Var) { + self.versions.insert(var.clone(), 0); + } + + fn current_name(&self, var: &Var) -> Var { + if self.is_mutable(var) { + format!("@mut_{var}_{}", self.versions.get(var).copied().unwrap_or(0)) + } else { + var.clone() } - Expression::MathExpr(operation, args) => { - let simplified_args = args - .iter() - .map(|arg| simplify_expr(ctx, state, const_malloc, arg, lines)) - .collect::>(); - if let Some(const_args) = SimpleExpr::try_vec_as_constant(&simplified_args) { - return SimpleExpr::Constant(ConstExpression::MathExpr(*operation, const_args)); - } - let aux_var = state.counters.aux_var(); - assert_eq!(simplified_args.len(), 2); - lines.push(SimpleLine::Assignment { - var: aux_var.clone().into(), - operation: *operation, - arg0: simplified_args[0].clone(), - arg1: simplified_args[1].clone(), - }); - VarOrConstMallocAccess::Var(aux_var).into() + } + + fn current_version(&self, var: &Var) -> usize { + self.versions.get(var).copied().unwrap_or(0) + } + + fn increment_version(&mut self, var: &Var) -> Var { + let version = self.versions.entry(var.clone()).or_insert(0); + *version += 1; + format!("@mut_{var}_{version}") + } + + fn check_immutable_assignment(&mut self, var: &Var) -> Result<(), String> { + if var.starts_with('@') || self.assigned.insert(var.clone()) { + Ok(()) + } else { + Err(format!( + "Cannot reassign immutable variable '{var}'. Use 'mut {var}' for mutable variables, or 'assert {var} == ;' to check equality" + )) } - Expression::FunctionCall { function_name, args } => { - let function = ctx - .functions - .get(function_name) - .unwrap_or_else(|| panic!("Function used but not defined: {function_name}")); - assert_eq!( - function.n_returned_vars, 1, - "Nested function call to '{function_name}' must return exactly 1 value, but returns {}", - function.n_returned_vars - ); + } - let simplified_args = args + /// Unifies mutable variable versions across multiple branches. + /// Returns forward declarations to add before the branching construct. + fn unify_branch_versions( + &mut self, + snapshot_versions: &BTreeMap, + branch_versions: &[BTreeMap], + branches: &mut [Vec], + ) -> Vec { + let mut forward_decls = Vec::new(); + + let branch_exits_early: Vec = branches.iter().map(|b| ends_with_early_exit(b)).collect(); + + for var in self.versions.clone().keys() { + let snapshot_v = snapshot_versions.get(var).copied().unwrap_or(0); + let versions: Vec = branch_versions .iter() - .map(|arg| simplify_expr(ctx, state, const_malloc, arg, lines)) - .collect::>(); + .map(|v| v.get(var).copied().unwrap_or(0)) + .collect(); - // Create a temporary variable for the function result - let result_var = state.counters.aux_var(); + // Only consider versions from branches that don't exit early for unification + let continuing_versions: Vec = versions + .iter() + .zip(branch_exits_early.iter()) + .filter(|&(_, exits)| !exits) + .map(|(&v, _)| v) + .collect(); + + // If all branches exit early, no unification needed - just keep the snapshot version + if continuing_versions.is_empty() { + self.versions.insert(var.clone(), snapshot_v); + continue; + } + + // Check if all continuing branches have the same version + if continuing_versions.iter().all(|&v| v == continuing_versions[0]) { + // All continuing branches have the same version + let branch_v = continuing_versions[0]; + if branch_v > snapshot_v { + // A new versioned variable was created in all continuing branches + let versioned_var = format!("@mut_{var}_{branch_v}"); + forward_decls.push(SimpleLine::ForwardDeclaration { + var: versioned_var.clone(), + }); + // Remove forward declarations from inside the branches to avoid shadowing + for branch in branches.iter_mut() { + remove_forward_declarations(branch, &versioned_var); + } + } + self.versions.insert(var.clone(), branch_v); + } else { + // Versions differ among continuing branches - need to unify + let max_version = continuing_versions.iter().copied().max().unwrap(); + let unified_version = max_version + 1; + let unified_var = format!("@mut_{var}_{unified_version}"); - lines.push(SimpleLine::FunctionCall { - function_name: function_name.clone(), - args: simplified_args, - return_data: vec![result_var.clone()], - line_number: 0, // No source line number for nested calls - }); + forward_decls.push(SimpleLine::ForwardDeclaration { + var: unified_var.clone(), + }); + + // Add equality assignment at the end of each branch that doesn't exit early + for (branch_idx, branch_v) in versions.iter().enumerate() { + if branch_exits_early[branch_idx] { + // Skip branches that exit early - they never reach code after the if/match + continue; + } + let branch_var_name: Var = format!("@mut_{var}_{branch_v}"); + branches[branch_idx].push(SimpleLine::equality(unified_var.clone(), branch_var_name)); + } - VarOrConstMallocAccess::Var(result_var).into() + self.versions.insert(var.clone(), unified_version); + } } - Expression::Len { .. } => unreachable!(), + + forward_decls } } -/// Returns (internal_vars, external_vars) -pub fn find_variable_usage( - lines: &[Line], - const_arrays: &BTreeMap, -) -> (BTreeSet, BTreeSet) { - let mut internal_vars = BTreeSet::new(); - let mut external_vars = BTreeSet::new(); +/// Compile-time vector. Scalars hold variable names; Vectors hold nested values. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum VectorValue { + Scalar { var: Var }, + Vector(Vec), +} - let on_new_expr = |expr: &Expression, internal_vars: &BTreeSet, external_vars: &mut BTreeSet| { - for var in vars_in_expression(expr, const_arrays) { - if !internal_vars.contains(&var) && !const_arrays.contains_key(&var) { - external_vars.insert(var); - } +impl VectorValue { + pub fn len(&self) -> usize { + match self { + Self::Vector(v) => v.len(), + _ => panic!("len on scalar"), } - }; + } - let on_new_condition = - |condition: &Condition, internal_vars: &BTreeSet, external_vars: &mut BTreeSet| match condition { - Condition::Comparison(comp) => { - on_new_expr(&comp.left, internal_vars, external_vars); - on_new_expr(&comp.right, internal_vars, external_vars); - } - Condition::AssumeBoolean(expr) => { - on_new_expr(expr, internal_vars, external_vars); - } - }; + pub fn is_vector(&self) -> bool { + matches!(self, Self::Vector(_)) + } + + fn get(&self, i: usize) -> Option<&Self> { + match self { + Self::Vector(v) => v.get(i), + _ => None, + } + } + + fn get_mut(&mut self, i: usize) -> Option<&mut Self> { + match self { + Self::Vector(v) => v.get_mut(i), + _ => None, + } + } + + pub fn navigate(&self, idx: &[usize]) -> Option<&Self> { + idx.iter().try_fold(self, |v, &i| v.get(i)) + } + + pub fn navigate_mut(&mut self, idx: &[usize]) -> Option<&mut Self> { + idx.iter().try_fold(self, |v, &i| v.get_mut(i)) + } + + pub fn push(&mut self, elem: Self) { + match self { + Self::Vector(v) => v.push(elem), + _ => panic!("push on scalar"), + } + } + + pub fn pop(&mut self) -> Option { + match self { + Self::Vector(v) => v.pop(), + _ => panic!("pop on scalar"), + } + } +} + +#[derive(Debug, Clone, Default)] +struct VectorTracker { + vectors: BTreeMap, +} + +impl VectorTracker { + fn register(&mut self, var: &Var, value: VectorValue) { + self.vectors.insert(var.clone(), value); + } + fn is_vector(&self, var: &Var) -> bool { + self.vectors.contains_key(var) + } + + fn get(&self, var: &Var) -> Option<&VectorValue> { + self.vectors.get(var) + } + + fn get_mut(&mut self, var: &Var) -> Option<&mut VectorValue> { + self.vectors.get_mut(var) + } +} + +#[derive(Debug, Clone, Default)] +pub struct ConstMalloc { + counter: usize, + map: BTreeMap, +} + +impl ArrayManager { + fn get_aux_var(&mut self, array: &Var, index: &Expression) -> Var { + if let Some(var) = self.aux_vars.get(&(array.clone(), index.clone())) { + return var.clone(); + } + let new_var = format!("@arr_aux_{}", self.counter); + self.counter += 1; + self.aux_vars.insert((array.clone(), index.clone()), new_var.clone()); + new_var + } +} + +fn build_vector_value( + ctx: &SimplifyContext<'_>, + state: &mut SimplifyState<'_>, + const_malloc: &mut ConstMalloc, + elements: &[VecLiteral], + lines: &mut Vec, + location: SourceLocation, +) -> Result { + let mut vec_elements = Vec::new(); + + for elem in elements { + vec_elements.push(build_vector_value_from_element( + ctx, + state, + const_malloc, + elem, + lines, + location, + )?); + } + + Ok(VectorValue::Vector(vec_elements)) +} + +fn build_vector_value_from_element( + ctx: &SimplifyContext<'_>, + state: &mut SimplifyState<'_>, + const_malloc: &mut ConstMalloc, + element: &VecLiteral, + lines: &mut Vec, + location: SourceLocation, +) -> Result { + match element { + VecLiteral::Vec(inner) => build_vector_value(ctx, state, const_malloc, inner, lines, location), + VecLiteral::Expr(expr) => { + // Scalar expression - create auxiliary variable and emit assignment + let aux_var = state.counters.aux_var(); + let simplified_value = simplify_expr(ctx, state, const_malloc, expr, lines)?; + lines.push(SimpleLine::equality(aux_var.clone(), simplified_value)); + Ok(VectorValue::Scalar { var: aux_var }) + } + } +} + +#[allow(clippy::too_many_arguments)] +fn simplify_lines( + ctx: &SimplifyContext<'_>, + state: &mut SimplifyState<'_>, + const_malloc: &mut ConstMalloc, + new_functions: &mut BTreeMap, + n_returned_vars: usize, + lines: &[Line], + in_a_loop: bool, +) -> Result, String> { + let mut res = Vec::new(); for line in lines { match line { - Line::ForwardDeclaration { var } => { - internal_vars.insert(var.clone()); + Line::ForwardDeclaration { var, is_mutable } => { + if *is_mutable { + state.mut_tracker.register_mutable(var); + } + let versioned_var = if *is_mutable { + state.mut_tracker.current_name(var) + } else { + var.clone() + }; + res.push(SimpleLine::ForwardDeclaration { var: versioned_var }); } - Line::Match { value, arms } => { - on_new_expr(value, &internal_vars, &mut external_vars); - for (_, statements) in arms { - let (stmt_internal, stmt_external) = find_variable_usage(statements, const_arrays); - internal_vars.extend(stmt_internal); - external_vars.extend(stmt_external.into_iter().filter(|v| !internal_vars.contains(v))); + Line::Match { value, arms, .. } => { + let simple_value = simplify_expr(ctx, state, const_malloc, value, &mut res)?; + + // Snapshot state before processing arms + let mut_tracker_snapshot = state.mut_tracker.clone(); + let array_manager_snapshot = state.array_manager.clone(); + + let mut simple_arms = vec![]; + let mut arm_versions = vec![]; + + for (i, (pattern, statements)) in arms.iter().enumerate() { + assert_eq!(*pattern, i, "match patterns should be consecutive, starting from 0"); + + // Restore snapshot for each arm + *state.mut_tracker = mut_tracker_snapshot.clone(); + *state.array_manager = array_manager_snapshot.clone(); + + let arm_simplified = simplify_lines( + ctx, + state, + const_malloc, + new_functions, + n_returned_vars, + statements, + in_a_loop, + )?; + simple_arms.push(arm_simplified); + arm_versions.push(state.mut_tracker.versions.clone()); } + + // Unify mutable variable versions across all arms + let forward_decls = state.mut_tracker.unify_branch_versions( + &mut_tracker_snapshot.versions, + &arm_versions, + &mut simple_arms, + ); + res.extend(forward_decls); + + // Restore array manager to snapshot state + *state.array_manager = array_manager_snapshot; + + res.push(SimpleLine::Match { + value: simple_value, + arms: simple_arms, + }); } - Line::Statement { targets, value, .. } => { - on_new_expr(value, &internal_vars, &mut external_vars); - for target in targets { - match target { - AssignmentTarget::Var(var) => { - internal_vars.insert(var.clone()); + Line::Statement { + targets, + value, + location, + } => { + // Helper function to get the target variable name, handling mutable variable versioning + let get_target_var_name = + |state: &mut SimplifyState<'_>, var: &Var, is_mutable: bool| -> Result { + if is_mutable { + // First assignment with `mut` - register as mutable + state.mut_tracker.register_mutable(var); + // Return versioned name so subsequent reads can find it + Ok(state.mut_tracker.current_name(var)) + } else if state.mut_tracker.is_mutable(var) { + // Increment version and get new variable name + Ok(state.mut_tracker.increment_version(var)) + } else { + // Check for reassignment of immutable variable + state.mut_tracker.check_immutable_assignment(var)?; + Ok(var.clone()) } - AssignmentTarget::ArrayAccess { array, index } => { - assert!(!const_arrays.contains_key(array), "Cannot assign to const array"); - if !internal_vars.contains(array) { - external_vars.insert(array.clone()); + }; + + match value { + Expression::FunctionCall { + function_name, args, .. + } => { + // Special handling for Array builtin + if function_name == "Array" { + if args.len() != 1 { + return Err(format!( + "Array expects exactly 1 argument, got {}, at {location}", + args.len() + )); } - on_new_expr(index, &internal_vars, &mut external_vars); + if targets.len() != 1 { + return Err(format!( + "Array expects exactly 1 return target, got {}, at {location}", + targets.len() + )); + } + let target = &targets[0]; + match target { + AssignmentTarget::Var { var, is_mutable } => { + let target_var = get_target_var_name(state, var, *is_mutable)?; + let simplified_size = simplify_expr(ctx, state, const_malloc, &args[0], &mut res)?; + match simplified_size { + SimpleExpr::Constant(const_size) => { + let label = const_malloc.counter; + const_malloc.counter += 1; + const_malloc.map.insert(target_var.clone(), label); + res.push(SimpleLine::ConstMalloc { + var: target_var, + size: const_size, + label, + }); + } + _ => { + res.push(SimpleLine::HintMAlloc { + var: target_var, + size: simplified_size, + }); + } + } + } + AssignmentTarget::ArrayAccess { .. } => { + return Err(format!( + "Array does not support array access as return target, at {location}" + )); + } + } + continue; + } + + // Special handling for print builtin + if function_name == "print" { + if !targets.is_empty() { + return Err(format!("print should not return values, at {location}")); + } + let simplified_content = args + .iter() + .map(|arg| simplify_expr(ctx, state, const_malloc, arg, &mut res)) + .collect::, _>>()?; + res.push(SimpleLine::Print { + line_info: format!("line {}", location.line_number), + content: simplified_content, + }); + continue; + } + + // Special handling for precompile functions (poseidon16, dot_product) + if let Some(table) = ALL_TABLES.into_iter().find(|p| p.name() == function_name) + && !table.is_execution_table() + { + if !targets.is_empty() { + return Err(format!( + "Precompile {function_name} should not return values, at {location}" + )); + } + let simplified_args = args + .iter() + .map(|arg| simplify_expr(ctx, state, const_malloc, arg, &mut res)) + .collect::, _>>()?; + res.push(SimpleLine::Precompile { + table, + args: simplified_args, + }); + continue; + } + + // Special handling for custom hints + if let Some(hint) = CustomHint::find_by_name(function_name) { + if !targets.is_empty() { + return Err(format!( + "Custom hint {function_name} should not return values, at {location}" + )); + } + if !hint.n_args_range().contains(&args.len()) { + return Err(format!( + "Custom hint {function_name}: invalid number of arguments, at {location}" + )); + } + let simplified_args = args + .iter() + .map(|arg| simplify_expr(ctx, state, const_malloc, arg, &mut res)) + .collect::, _>>()?; + res.push(SimpleLine::CustomHint(hint, simplified_args)); + continue; + } + + // Regular function call - may have zero, one, or multiple targets + let function = ctx + .functions + .get(function_name) + .ok_or_else(|| format!("Function used but not defined: {function_name}, at {location}"))?; + if targets.len() != function.n_returned_vars { + return Err(format!( + "Expected {} returned vars (and not {}) in call to {function_name}, at {location}", + function.n_returned_vars, + targets.len() + )); + } + if args.len() != function.arguments.len() { + return Err(format!( + "Expected {} arguments (and not {}) in call to {function_name}, at {location}", + function.arguments.len(), + args.len() + )); + } + + let simplified_args = args + .iter() + .map(|arg| simplify_expr(ctx, state, const_malloc, arg, &mut res)) + .collect::, _>>()?; + + let mut temp_vars = Vec::new(); + let mut array_targets: Vec<(usize, Var, Box)> = Vec::new(); + + for (i, target) in targets.iter().enumerate() { + match target { + AssignmentTarget::Var { var, is_mutable } => { + let target_var = get_target_var_name(state, var, *is_mutable)?; + // Add forward declaration for new versioned variable + if *is_mutable || state.mut_tracker.current_version(var) > 0 { + res.push(SimpleLine::ForwardDeclaration { + var: target_var.clone(), + }); + } + temp_vars.push(target_var); + } + AssignmentTarget::ArrayAccess { array, index } => { + temp_vars.push(state.counters.aux_var()); + array_targets.push((i, array.clone(), index.clone())); + } + } + } + + res.push(SimpleLine::FunctionCall { + function_name: function_name.clone(), + args: simplified_args, + return_data: temp_vars.clone(), + location: *location, + }); + + // For array access targets, add DEREF instructions to copy temp to array element + for (i, array, index) in array_targets { + let simplified_index = simplify_expr(ctx, state, const_malloc, &index, &mut res)?; + let simplified_value = VarOrConstMallocAccess::Var(temp_vars[i].clone()).into(); + handle_array_assignment( + ctx, + state, + &mut res, + &array, + &[simplified_index], + ArrayAccessType::ArrayIsAssigned(simplified_value), + ); + } + } + _ => { + assert!(targets.len() == 1, "Non-function call must have exactly one target"); + let target = &targets[0]; + + match target { + AssignmentTarget::Var { var, is_mutable } => { + // IMPORTANT: Simplify RHS BEFORE updating version tracker + // This ensures the RHS uses the current (old) version of any mutable variables + match value { + Expression::Value(val) => { + let simplified_val = simplify_expr( + ctx, + state, + const_malloc, + &Expression::Value(val.clone()), + &mut res, + )?; + let target_var = get_target_var_name(state, var, *is_mutable)?; + if state.mut_tracker.is_mutable(var) + && state.mut_tracker.current_version(var) > 0 + { + res.push(SimpleLine::ForwardDeclaration { + var: target_var.clone(), + }); + } + res.push(SimpleLine::equality(target_var, simplified_val)); + } + Expression::ArrayAccess { array, index } => { + // Check if array is a vector (needs to be handled before simplifying indices) + let versioned_array = state.mut_tracker.current_name(array); + if state.vec_tracker.is_vector(&versioned_array) { + // Use simplify_expr which handles vectors correctly + let simplified_val = simplify_expr( + ctx, + state, + const_malloc, + &Expression::ArrayAccess { + array: array.clone(), + index: index.clone(), + }, + &mut res, + )?; + let target_var = get_target_var_name(state, var, *is_mutable)?; + res.push(SimpleLine::equality(target_var, simplified_val)); + } else { + // Pre-simplify indices before version update + let simplified_index = index + .iter() + .map(|idx| simplify_expr(ctx, state, const_malloc, idx, &mut res)) + .collect::, _>>()?; + let target_var = get_target_var_name(state, var, *is_mutable)?; + if state.mut_tracker.is_mutable(var) + && state.mut_tracker.current_version(var) > 0 + { + res.push(SimpleLine::ForwardDeclaration { + var: target_var.clone(), + }); + } + handle_array_assignment( + ctx, + state, + &mut res, + array, + &simplified_index, + ArrayAccessType::VarIsAssigned(target_var), + ); + } + } + Expression::MathExpr(operation, args) => { + let args_simplified = args + .iter() + .map(|arg| simplify_expr(ctx, state, const_malloc, arg, &mut res)) + .collect::, _>>()?; + let target_var = get_target_var_name(state, var, *is_mutable)?; + if state.mut_tracker.is_mutable(var) + && state.mut_tracker.current_version(var) > 0 + { + res.push(SimpleLine::ForwardDeclaration { + var: target_var.clone(), + }); + } + // If all operands are constants, evaluate at compile time + if let Some(const_args) = SimpleExpr::try_vec_as_constant(&args_simplified) { + let result = ConstExpression::MathExpr(*operation, const_args); + res.push(SimpleLine::equality(target_var, SimpleExpr::Constant(result))); + } else { + res.push(SimpleLine::Assignment { + var: target_var.into(), + operation: *operation, + arg0: args_simplified[0].clone(), + arg1: args_simplified[1].clone(), + }); + } + } + Expression::Len { .. } => unreachable!(), + Expression::FunctionCall { .. } => { + unreachable!("FunctionCall should be handled above") + } + } + } + AssignmentTarget::ArrayAccess { array, index } => { + // Array element assignment - pre-simplify index first + let simplified_index = simplify_expr(ctx, state, const_malloc, index, &mut res)?; + + // Optimization: direct math assignment to const_malloc array with constant index + if let SimpleExpr::Constant(offset) = &simplified_index + && let Some(label) = const_malloc.map.get(array) + && let Expression::MathExpr(operation, args) = value + { + let var = VarOrConstMallocAccess::ConstMallocAccess { + malloc_label: *label, + offset: offset.clone(), + }; + let simplified_args = args + .iter() + .map(|arg| simplify_expr(ctx, state, const_malloc, arg, &mut res)) + .collect::, _>>()?; + // If all operands are constants, evaluate at compile time + if let Some(const_args) = SimpleExpr::try_vec_as_constant(&simplified_args) { + let result = ConstExpression::MathExpr(*operation, const_args); + res.push(SimpleLine::equality(var, SimpleExpr::Constant(result))); + } else { + assert_eq!(simplified_args.len(), 2); + res.push(SimpleLine::Assignment { + var, + operation: *operation, + arg0: simplified_args[0].clone(), + arg1: simplified_args[1].clone(), + }); + } + } else { + // General case: pre-simplify value and use handle_array_assignment + let simplified_value = simplify_expr(ctx, state, const_malloc, value, &mut res)?; + handle_array_assignment( + ctx, + state, + &mut res, + array, + &[simplified_index], + ArrayAccessType::ArrayIsAssigned(simplified_value), + ); + } + } + } + } + } + } + Line::Assert { + boolean, + debug, + location, + } => { + let left = simplify_expr(ctx, state, const_malloc, &boolean.left, &mut res)?; + let right = simplify_expr(ctx, state, const_malloc, &boolean.right, &mut res)?; + if *debug { + res.push(SimpleLine::DebugAssert( + BooleanExpr { + left, + right, + kind: boolean.kind, + }, + *location, + )); + } else { + match boolean.kind { + Boolean::Different => { + let diff_var = state.counters.aux_var(); + res.push(SimpleLine::Assignment { + var: diff_var.clone().into(), + operation: MathOperation::Sub, + arg0: left, + arg1: right, + }); + res.push(SimpleLine::IfNotZero { + condition: diff_var.into(), + then_branch: vec![], + else_branch: vec![SimpleLine::Panic { message: None }], + location: *location, + }); } + Boolean::Equal => { + let (var, other): (VarOrConstMallocAccess, _) = if let Ok(left) = left.clone().try_into() { + (left, right) + } else if let Ok(right) = right.clone().try_into() { + (right, left) + } else { + // Both are constants - evaluate at compile time + if let (SimpleExpr::Constant(left_const), SimpleExpr::Constant(right_const)) = + (&left, &right) + && let (Some(left_val), Some(right_val)) = + (left_const.naive_eval(), right_const.naive_eval()) + { + if left_val == right_val { + // Assertion passes at compile time, no code needed + continue; + } else { + return Err(format!( + "Compile-time assertion failed: {} != {} ({})", + left_val.to_usize(), + right_val.to_usize(), + location + )); + } + } + return Err(format!("Unsupported equality assertion: {left:?}, {right:?}")); + }; + res.push(SimpleLine::equality(var, other)); + } + Boolean::LessThan => { + // assert left < right is equivalent to assert left <= right - 1 + let bound_minus_one = state.counters.aux_var(); + res.push(SimpleLine::Assignment { + var: bound_minus_one.clone().into(), + operation: MathOperation::Sub, + arg0: right, + arg1: SimpleExpr::one(), + }); + + // We add a debug assert for sanity + res.push(SimpleLine::DebugAssert( + BooleanExpr { + kind: Boolean::LessOrEqual, + left: left.clone(), + right: bound_minus_one.clone().into(), + }, + *location, + )); + + res.push(SimpleLine::RangeCheck { + val: left, + bound: bound_minus_one.into(), + }); + } + Boolean::LessOrEqual => { + // Range check: assert left <= right + + // we add a debug assert for sanity + res.push(SimpleLine::DebugAssert( + BooleanExpr { + kind: Boolean::LessOrEqual, + left: left.clone(), + right: right.clone(), + }, + *location, + )); + + res.push(SimpleLine::RangeCheck { + val: left, + bound: right, + }); + } + } + } + } + Line::IfCondition { + condition, + then_branch, + else_branch, + location, + } => { + let (condition_simplified, then_branch, else_branch) = match condition { + Condition::Comparison(condition) => { + // Transform if a == b then X else Y into if a != b then Y else X + + let (left, right, then_branch, else_branch) = match condition.kind { + Boolean::Equal => (&condition.left, &condition.right, else_branch, then_branch), // switched + Boolean::Different => (&condition.left, &condition.right, then_branch, else_branch), + Boolean::LessThan | Boolean::LessOrEqual => unreachable!(), + }; + + let left_simplified = simplify_expr(ctx, state, const_malloc, left, &mut res)?; + let right_simplified = simplify_expr(ctx, state, const_malloc, right, &mut res)?; + + let diff_var = state.counters.aux_var(); + res.push(SimpleLine::Assignment { + var: diff_var.clone().into(), + operation: MathOperation::Sub, + arg0: left_simplified, + arg1: right_simplified, + }); + (diff_var.into(), then_branch, else_branch) } - } - } - Line::IfCondition { - condition, - then_branch, - else_branch, - line_number: _, - } => { - on_new_condition(condition, &internal_vars, &mut external_vars); + Condition::AssumeBoolean(condition) => { + let condition_simplified = simplify_expr(ctx, state, const_malloc, condition, &mut res)?; - let (then_internal, then_external) = find_variable_usage(then_branch, const_arrays); - let (else_internal, else_external) = find_variable_usage(else_branch, const_arrays); + (condition_simplified, then_branch, else_branch) + } + }; - internal_vars.extend(then_internal.union(&else_internal).cloned()); - external_vars.extend( - then_external - .union(&else_external) - .filter(|v| !internal_vars.contains(*v)) - .cloned(), - ); - } - Line::Assert { boolean, .. } => { - on_new_condition( - &Condition::Comparison(boolean.clone()), - &internal_vars, - &mut external_vars, + // Snapshot state before processing branches + let mut_tracker_snapshot = state.mut_tracker.clone(); + let vec_tracker_snapshot = state.vec_tracker.clone(); + + let mut array_manager_then = state.array_manager.clone(); + let mut mut_tracker_then = state.mut_tracker.clone(); + let mut vec_tracker_then = state.vec_tracker.clone(); + let mut state_then = SimplifyState { + counters: state.counters, + array_manager: &mut array_manager_then, + mut_tracker: &mut mut_tracker_then, + vec_tracker: &mut vec_tracker_then, + }; + let then_branch_simplified = simplify_lines( + ctx, + &mut state_then, + const_malloc, + new_functions, + n_returned_vars, + then_branch, + in_a_loop, + )?; + let then_versions = mut_tracker_then.versions.clone(); + + let mut array_manager_else = array_manager_then.clone(); + array_manager_else.valid = state.array_manager.valid.clone(); // Crucial: remove the access added in the IF branch + + // Restore state for else branch + let mut mut_tracker_else = mut_tracker_snapshot.clone(); + let mut vec_tracker_else = vec_tracker_snapshot.clone(); + + let mut state_else = SimplifyState { + counters: state.counters, + array_manager: &mut array_manager_else, + mut_tracker: &mut mut_tracker_else, + vec_tracker: &mut vec_tracker_else, + }; + let else_branch_simplified = simplify_lines( + ctx, + &mut state_else, + const_malloc, + new_functions, + n_returned_vars, + else_branch, + in_a_loop, + )?; + let else_versions = mut_tracker_else.versions.clone(); + + // Unify mutable variable versions across both branches + let branch_versions = vec![then_versions, else_versions]; + let mut branches = vec![then_branch_simplified, else_branch_simplified]; + let forward_decls = state.mut_tracker.unify_branch_versions( + &mut_tracker_snapshot.versions, + &branch_versions, + &mut branches, ); - } - Line::FunctionRet { return_data } => { - for ret in return_data { - on_new_expr(ret, &internal_vars, &mut external_vars); - } - } - Line::MAlloc { var, size, .. } => { - on_new_expr(size, &internal_vars, &mut external_vars); - internal_vars.insert(var.clone()); - } - Line::Precompile { table: _, args } => { - for arg in args { - on_new_expr(arg, &internal_vars, &mut external_vars); - } - } - Line::Print { content, .. } => { - for var in content { - on_new_expr(var, &internal_vars, &mut external_vars); - } - } - Line::PrivateInputStart { result } => { - internal_vars.insert(result.clone()); - } - Line::CustomHint(_, args) => { - for expr in args { - on_new_expr(expr, &internal_vars, &mut external_vars); - } + res.extend(forward_decls); + let [then_branch_simplified, else_branch_simplified] = <[_; 2]>::try_from(branches).unwrap(); + + *state.array_manager = array_manager_else.clone(); + // keep the intersection both branches + state.array_manager.valid = state + .array_manager + .valid + .intersection(&array_manager_then.valid) + .cloned() + .collect(); + + res.push(SimpleLine::IfNotZero { + condition: condition_simplified, + then_branch: then_branch_simplified, + else_branch: else_branch_simplified, + location: *location, + }); } Line::ForLoop { iterator, start, end, body, - rev: _, - unroll: _, - line_number: _, + unroll, + location, } => { - let (body_internal, body_external) = find_variable_usage(body, const_arrays); - internal_vars.extend(body_internal); - internal_vars.insert(iterator.clone()); - external_vars.extend(body_external.difference(&internal_vars).cloned()); - on_new_expr(start, &internal_vars, &mut external_vars); - on_new_expr(end, &internal_vars, &mut external_vars); - } - Line::Panic | Line::Break | Line::LocationReport { .. } => {} - } - } + assert!(!*unroll, "Unrolled loops should have been handled already"); - (internal_vars, external_vars) -} + let mut loop_const_malloc = ConstMalloc { + counter: const_malloc.counter, + ..ConstMalloc::default() + }; + let valid_aux_vars_in_array_manager_before = state.array_manager.valid.clone(); + state.array_manager.valid.clear(); -fn inline_simple_expr(simple_expr: &mut SimpleExpr, args: &BTreeMap, inlining_count: usize) { - if let SimpleExpr::Memory(VarOrConstMallocAccess::Var(var)) = simple_expr { - if let Some(replacement) = args.get(var) { - *simple_expr = replacement.clone(); - } else { - *var = format!("@inlined_var_{inlining_count}_{var}"); - } - } -} + // Loop body becomes a separate function, so immutable assignments inside + // shouldn't affect outer scope (but mutable variable versions persist) + let assigned_before = std::mem::take(&mut state.mut_tracker.assigned); + let simplified_body = simplify_lines(ctx, state, &mut loop_const_malloc, new_functions, 0, body, true)?; + state.mut_tracker.assigned = assigned_before; -fn inline_expr(expr: &mut Expression, args: &BTreeMap, inlining_count: usize) { - match expr { - Expression::Value(value) => { - inline_simple_expr(value, args, inlining_count); - } - Expression::ArrayAccess { array, index } => { - if let Some(replacement) = args.get(array) { - let SimpleExpr::Memory(VarOrConstMallocAccess::Var(new_array)) = replacement else { - panic!("Cannot inline array access with non-variable array argument"); - }; - *array = new_array.clone(); - } else { - *array = format!("@inlined_var_{inlining_count}_{array}"); - } - for idx in index { - inline_expr(idx, args, inlining_count); - } - } - Expression::MathExpr(_, math_args) => { - for arg in math_args { - inline_expr(arg, args, inlining_count); - } - } - Expression::FunctionCall { args: func_args, .. } => { - for arg in func_args { - inline_expr(arg, args, inlining_count); - } - } - Expression::Len { indices, .. } => { - for idx in indices { - inline_expr(idx, args, inlining_count); - } - } - } -} + const_malloc.counter = loop_const_malloc.counter; + state.array_manager.valid = valid_aux_vars_in_array_manager_before; // restore the valid aux vars -fn inline_lines( - lines: &mut Vec, - args: &BTreeMap, - res: &[AssignmentTarget], - inlining_count: usize, -) { - let inline_comparison = |comparison: &mut BooleanExpr| { - inline_expr(&mut comparison.left, args, inlining_count); - inline_expr(&mut comparison.right, args, inlining_count); - }; + let func_name = format!("@loop_{}_{}", state.counters.loops.get_next(), location); - let inline_condition = |condition: &mut Condition| match condition { - Condition::Comparison(comparison) => inline_comparison(comparison), - Condition::AssumeBoolean(expr) => inline_expr(expr, args, inlining_count), - }; + // Find variables used inside loop but defined outside + let (_, mut external_vars) = find_variable_usage(body, ctx.const_arrays); - let inline_internal_var = |var: &mut Var| { - assert!( - !args.contains_key(var), - "Variable {var} is both an argument and declared in the inlined function" - ); - *var = format!("@inlined_var_{inlining_count}_{var}"); - }; + // Include variables in start/end + for expr in [start, end] { + for var in vars_in_expression(expr, ctx.const_arrays) { + external_vars.insert(var); + } + } + external_vars.remove(iterator); // Iterator is internal to loop - let mut lines_to_replace = vec![]; - for (i, line) in lines.iter_mut().enumerate() { - match line { - Line::ForwardDeclaration { var } => { - inline_internal_var(var); - } - Line::Match { value, arms } => { - inline_expr(value, args, inlining_count); - for (_, statements) in arms { - inline_lines(statements, args, res, inlining_count); + let mut external_vars: Vec<_> = external_vars + .into_iter() + .map(|var| state.mut_tracker.current_name(&var)) + .collect(); + + let start_simplified = simplify_expr(ctx, state, const_malloc, start, &mut res)?; + let mut end_simplified = simplify_expr(ctx, state, const_malloc, end, &mut res)?; + if let SimpleExpr::Memory(VarOrConstMallocAccess::ConstMallocAccess { malloc_label, offset }) = + end_simplified.clone() + { + // we use an auxilary variable to store the end value (const malloc inside non-unrolled loops does not work) + let aux_end_var = state.counters.aux_var(); + res.push(SimpleLine::equality( + aux_end_var.clone(), + VarOrConstMallocAccess::ConstMallocAccess { malloc_label, offset }, + )); + end_simplified = VarOrConstMallocAccess::Var(aux_end_var).into(); } - } - Line::Statement { targets, value, .. } => { - inline_expr(value, args, inlining_count); - for target in targets { - match target { - AssignmentTarget::Var(var) => { - inline_internal_var(var); - } - AssignmentTarget::ArrayAccess { array, index } => { - if let Some(replacement) = args.get(array) { - // Array is a function argument - replace with the argument's var name - let SimpleExpr::Memory(VarOrConstMallocAccess::Var(new_array)) = replacement else { - panic!("Cannot inline array access target with non-variable array argument"); - }; - *array = new_array.clone(); - } else { - // Internal variable - rename with inlining prefix - *array = format!("@inlined_var_{inlining_count}_{array}"); - } - inline_expr(index, args, inlining_count); + + for (simplified, original) in [ + (start_simplified.clone(), start.clone()), + (end_simplified.clone(), end.clone()), + ] { + if !matches!(original, Expression::Value(_)) { + // the simplified var is auxiliary + if let SimpleExpr::Memory(VarOrConstMallocAccess::Var(var)) = simplified { + external_vars.push(var); } } } - } - Line::IfCondition { - condition, - then_branch, - else_branch, - line_number: _, - } => { - inline_condition(condition); - inline_lines(then_branch, args, res, inlining_count); - inline_lines(else_branch, args, res, inlining_count); - } - Line::Assert { boolean, .. } => { - inline_comparison(boolean); + // Create function arguments: iterator + external variables + let mut func_args = vec![iterator.clone()]; + func_args.extend(external_vars.clone()); + + // Create recursive function body + let recursive_func = create_recursive_function( + func_name.clone(), + *location, + func_args, + iterator.clone(), + end_simplified, + simplified_body, + &external_vars, + ); + new_functions.insert(func_name.clone(), recursive_func); + + // Replace loop with initial function call + let mut call_args = vec![start_simplified]; + call_args.extend(external_vars.iter().map(|v| v.clone().into())); + + res.push(SimpleLine::FunctionCall { + function_name: func_name, + args: call_args, + return_data: vec![], + location: *location, + }); } Line::FunctionRet { return_data } => { - assert_eq!(return_data.len(), res.len()); - - for expr in return_data.iter_mut() { - inline_expr(expr, args, inlining_count); - } - lines_to_replace.push(( - i, - res.iter() - .zip(return_data) - .map(|(target, expr)| Line::Statement { - targets: vec![target.clone()], - value: expr.clone(), - line_number: 0, - }) - .collect::>(), - )); + if in_a_loop { + return Err("Function return inside a loop is not currently supported".to_string()); + } + if return_data.len() != n_returned_vars { + return Err(format!( + "Wrong number of return values in return statement; expected {n_returned_vars} but got {}", + return_data.len() + )); + } + let simplified_return_data = return_data + .iter() + .map(|ret| simplify_expr(ctx, state, const_malloc, ret, &mut res)) + .collect::, _>>()?; + res.push(SimpleLine::FunctionRet { + return_data: simplified_return_data, + }); } - Line::MAlloc { var, size, .. } => { - inline_expr(size, args, inlining_count); - inline_internal_var(var); + Line::Panic { message } => { + res.push(SimpleLine::Panic { + message: message.clone(), + }); } - Line::Precompile { - table: _, - args: precompile_args, - } => { - for arg in precompile_args { - inline_expr(arg, args, inlining_count); - } + Line::LocationReport { location } => { + res.push(SimpleLine::LocationReport { location: *location }); } - Line::Print { content, .. } => { - for var in content { - inline_expr(var, args, inlining_count); + Line::VecDeclaration { + var, + elements, + location, + } => { + let vector_value = build_vector_value(ctx, state, const_malloc, elements, &mut res, *location)?; + state.vec_tracker.register(var, vector_value); + // No SimpleLine for the variable itself - vector metadata is compile-time only + } + Line::Push { + vector, + indices, + element, + location, + } => { + // Get the vector and check it's a tracked vector + if !state.vec_tracker.is_vector(vector) { + return Err(format!( + "push called on non-vector variable '{}', at {}", + vector, location + )); } - } - Line::CustomHint(_, decomposed_args) => { - for expr in decomposed_args { - inline_expr(expr, args, inlining_count); + + // Evaluate indices at compile time + let const_indices: Vec = indices + .iter() + .map(|idx| { + let simplified = simplify_expr(ctx, state, const_malloc, idx, &mut res)?; + let const_val = simplified + .as_constant() + .ok_or_else(|| format!("push index must be a compile-time constant, at {}", location))?; + let val = const_val + .naive_eval() + .ok_or_else(|| format!("push index must be evaluable at compile time, at {}", location))?; + Ok(val.to_usize()) + }) + .collect::, String>>()?; + + // Build VectorValue for the element being pushed + let new_element = + build_vector_value_from_element(ctx, state, const_malloc, element, &mut res, *location)?; + + // Navigate to the target vector and push + let vector_value = state + .vec_tracker + .get_mut(vector) + .expect("Vector should exist after is_vector check"); + + if const_indices.is_empty() { + // Push directly to the top-level vector + vector_value.push(new_element); + } else { + // Navigate to the nested vector and push + let target = vector_value + .navigate_mut(&const_indices) + .ok_or_else(|| format!("push target index out of bounds, at {}", location))?; + if !target.is_vector() { + return Err(format!("push target must be a vector, not a scalar, at {}", location)); + } + target.push(new_element); } } - Line::PrivateInputStart { result } => { - inline_internal_var(result); - } - Line::ForLoop { - iterator, - start, - end, - body, - rev: _, - unroll: _, - line_number: _, + Line::Pop { + vector, + indices, + location, } => { - inline_lines(body, args, res, inlining_count); - inline_internal_var(iterator); - inline_expr(start, args, inlining_count); - inline_expr(end, args, inlining_count); + // Get the vector and check it's a tracked vector + if !state.vec_tracker.is_vector(vector) { + return Err(format!( + "pop called on non-vector variable '{}', at {}", + vector, location + )); + } + + // Evaluate indices at compile time + let const_indices: Vec = indices + .iter() + .map(|idx| { + let simplified = simplify_expr(ctx, state, const_malloc, idx, &mut res)?; + let const_val = simplified + .as_constant() + .ok_or_else(|| format!("pop index must be a compile-time constant, at {}", location))?; + let val = const_val + .naive_eval() + .ok_or_else(|| format!("pop index must be evaluable at compile time, at {}", location))?; + Ok(val.to_usize()) + }) + .collect::, String>>()?; + + // Navigate to the target vector and pop + let vector_value = state + .vec_tracker + .get_mut(vector) + .expect("Vector should exist after is_vector check"); + + if const_indices.is_empty() { + // Pop directly from the top-level vector + if vector_value.len() == 0 { + return Err(format!("pop on empty vector '{}', at {}", vector, location)); + } + vector_value.pop(); + } else { + // Navigate to the nested vector and pop + let target = vector_value + .navigate_mut(&const_indices) + .ok_or_else(|| format!("pop target index out of bounds, at {}", location))?; + if !target.is_vector() { + return Err(format!("pop target must be a vector, not a scalar, at {}", location)); + } + if target.len() == 0 { + return Err(format!("pop on empty vector, at {}", location)); + } + target.pop(); + } } - Line::Panic | Line::Break | Line::LocationReport { .. } => {} } } - for (i, new_lines) in lines_to_replace.into_iter().rev() { - lines.splice(i..=i, new_lines); - } + + Ok(res) } -fn vars_in_expression(expr: &Expression, const_arrays: &BTreeMap) -> BTreeSet { - let mut vars = BTreeSet::new(); +fn simplify_expr( + ctx: &SimplifyContext<'_>, + state: &mut SimplifyState<'_>, + const_malloc: &ConstMalloc, + expr: &Expression, + lines: &mut Vec, +) -> Result { match expr { Expression::Value(value) => { + // Translate mutable variable references to their current versioned name if let SimpleExpr::Memory(VarOrConstMallocAccess::Var(var)) = value { - vars.insert(var.clone()); + let versioned_var = state.mut_tracker.current_name(var); + Ok(versioned_var.into()) + } else { + Ok(value.clone()) } } Expression::ArrayAccess { array, index } => { - if !const_arrays.contains_key(array) { - vars.insert(array.clone()); - } - for idx in index { - vars.extend(vars_in_expression(idx, const_arrays)); - } - } - Expression::MathExpr(_, args) => { - for arg in args { - vars.extend(vars_in_expression(arg, const_arrays)); - } - } - Expression::FunctionCall { args, .. } => { - for arg in args { - vars.extend(vars_in_expression(arg, const_arrays)); - } - } - Expression::Len { indices, .. } => { - for idx in indices { - vars.extend(vars_in_expression(idx, const_arrays)); - } - } - } - vars -} + // Check for const array access + if let Some(arr) = ctx.const_arrays.get(array) { + let simplified_index = index + .iter() + .map(|idx| { + idx.as_scalar() + .ok_or_else(|| "Const array access index must be a compile-time constant".to_string()) + }) + .collect::, String>>()?; -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub enum ArrayAccessType { - VarIsAssigned(Var), // var = array[index] - ArrayIsAssigned(Expression), // array[index] = expr -} + return Ok(SimpleExpr::Constant(ConstExpression::scalar( + arr.navigate(&simplified_index) + .unwrap_or_else(|| panic!("Const array index out of bounds for array '{}'", array)) + .as_scalar() + .expect("Const array access should return a scalar"), + ))); + } + + // Check for compile-time vector access + let versioned_array = state.mut_tracker.current_name(array); + if state.vec_tracker.is_vector(&versioned_array) { + // Vector access - indices must all be compile-time constant + // First, simplify all indices (this may mutate state) + let mut const_indices: Vec = Vec::new(); + for idx in index { + let simplified = simplify_expr(ctx, state, const_malloc, idx, lines)?; + let SimpleExpr::Constant(const_expr) = simplified else { + return Err("Vector index must be compile-time constant".to_string()); + }; + let val = const_expr + .naive_eval() + .ok_or_else(|| "Cannot evaluate vector index".to_string())? + .to_usize(); + const_indices.push(val); + } + + // Now we can borrow vec_tracker again + let vector_value = state.vec_tracker.get(&versioned_array).unwrap(); + + // Navigate to the element + let element = vector_value + .navigate(&const_indices) + .ok_or_else(|| format!("Vector index out of bounds: {:?}", const_indices))?; + + match element { + VectorValue::Scalar { var } => { + // Return memory reference to this variable + return Ok(SimpleExpr::Memory(VarOrConstMallocAccess::Var(var.clone()))); + } + VectorValue::Vector(_) => { + return Err("Cannot use nested vector as expression value".to_string()); + } + } + } -fn handle_array_assignment( - ctx: &SimplifyContext<'_>, - state: &mut SimplifyState<'_>, - const_malloc: &ConstMalloc, - res: &mut Vec, - array: &Var, - index: &[Expression], - access_type: ArrayAccessType, -) { - let simplified_index = index - .iter() - .map(|idx| simplify_expr(ctx, state, const_malloc, idx, res)) - .collect::>(); + assert_eq!(index.len(), 1); + let index = index[0].clone(); - if let ArrayAccessType::VarIsAssigned(var) = &access_type - && let Some(const_array) = ctx.const_arrays.get(array) - { - let idx = simplified_index - .iter() - .map(|idx| { - idx.as_constant() - .expect("Const array access index should be constant") - .naive_eval() - .unwrap() - .to_usize() - }) - .collect::>(); - let value = const_array - .navigate(&idx) - .expect("Const array access index out of bounds") - .as_scalar() - .expect("Const array access should return a scalar"); - res.push(SimpleLine::equality(var.clone(), ConstExpression::from(value))); - return; - } + if let Some(label) = const_malloc.map.get(array) + && let Ok(offset) = ConstExpression::try_from(index.clone()) + { + return Ok(VarOrConstMallocAccess::ConstMallocAccess { + malloc_label: *label, + offset, + } + .into()); + } - if simplified_index.len() == 1 - && let SimpleExpr::Constant(offset) = simplified_index[0].clone() - && let Some(label) = const_malloc.map.get(array) - && let ArrayAccessType::ArrayIsAssigned(Expression::MathExpr(operation, args)) = &access_type - { - let var = VarOrConstMallocAccess::ConstMallocAccess { - malloc_label: *label, - offset, - }; - let simplified_args = args - .iter() - .map(|arg| simplify_expr(ctx, state, const_malloc, arg, res)) - .collect::>(); - if let Some(const_args) = SimpleExpr::try_vec_as_constant(&simplified_args) { - let result = ConstExpression::MathExpr(*operation, const_args); - res.push(SimpleLine::equality(var.clone(), SimpleExpr::Constant(result.clone()))); - } - assert_eq!(simplified_args.len(), 2); - res.push(SimpleLine::Assignment { - var, - operation: *operation, - arg0: simplified_args[0].clone(), - arg1: simplified_args[1].clone(), - }); - return; - } + let versioned_array = state.mut_tracker.current_name(array); + let aux_arr = state.array_manager.get_aux_var(&versioned_array, &index); // auxiliary var to store m[array + index] - let value_simplified = match access_type { - ArrayAccessType::VarIsAssigned(var) => SimpleExpr::Memory(VarOrConstMallocAccess::Var(var)), - ArrayAccessType::ArrayIsAssigned(expr) => simplify_expr(ctx, state, const_malloc, &expr, res), - }; + if !state.array_manager.valid.insert(aux_arr.clone()) { + return Ok(VarOrConstMallocAccess::Var(aux_arr).into()); + } - // TODO opti: in some case we could use ConstMallocAccess - assert_eq!(simplified_index.len(), 1); - let simplified_index = simplified_index[0].clone(); - let (index_var, shift) = match simplified_index { - SimpleExpr::Constant(c) => (SimpleExpr::Memory(VarOrConstMallocAccess::Var(array.clone())), c), - _ => { - // Create pointer variable: ptr = array + index - let ptr_var = state.counters.aux_var(); - res.push(SimpleLine::Assignment { - var: ptr_var.clone().into(), - operation: MathOperation::Add, - arg0: SimpleExpr::Memory(VarOrConstMallocAccess::Var(array.clone())), - arg1: simplified_index, + let simplified_index = simplify_expr(ctx, state, const_malloc, &index, lines)?; + handle_array_assignment( + ctx, + state, + lines, + array, + &[simplified_index], + ArrayAccessType::VarIsAssigned(aux_arr.clone()), + ); + Ok(VarOrConstMallocAccess::Var(aux_arr).into()) + } + Expression::MathExpr(operation, args) => { + let simplified_args = args + .iter() + .map(|arg| simplify_expr(ctx, state, const_malloc, arg, lines)) + .collect::, _>>()?; + if let Some(const_args) = SimpleExpr::try_vec_as_constant(&simplified_args) { + return Ok(SimpleExpr::Constant(ConstExpression::MathExpr(*operation, const_args))); + } + let aux_var = state.counters.aux_var(); + assert_eq!(simplified_args.len(), 2); + lines.push(SimpleLine::Assignment { + var: aux_var.clone().into(), + operation: *operation, + arg0: simplified_args[0].clone(), + arg1: simplified_args[1].clone(), }); - ( - SimpleExpr::Memory(VarOrConstMallocAccess::Var(ptr_var)), - ConstExpression::zero(), - ) + Ok(VarOrConstMallocAccess::Var(aux_var).into()) } - }; - - res.push(SimpleLine::RawAccess { - res: value_simplified, - index: index_var, - shift, - }); -} + Expression::FunctionCall { + function_name, + args, + location, + } => { + let function = ctx + .functions + .get(function_name) + .unwrap_or_else(|| panic!("Function used but not defined: {function_name}")); + if function.n_returned_vars != 1 { + return Err(format!( + "Nested function calls must return exactly one value (function {function_name} returns {} values)", + function.n_returned_vars + )); + } -fn create_recursive_function( - name: String, - location: SourceLocation, - args: Vec, - iterator: Var, - end: SimpleExpr, - mut body: Vec, - external_vars: &[Var], -) -> SimpleFunction { - // Add iterator increment - let next_iter = format!("@incremented_{iterator}"); - body.push(SimpleLine::Assignment { - var: next_iter.clone().into(), - operation: MathOperation::Add, - arg0: iterator.clone().into(), - arg1: SimpleExpr::one(), - }); + let simplified_args = args + .iter() + .map(|arg| simplify_expr(ctx, state, const_malloc, arg, lines)) + .collect::, _>>()?; - // Add recursive call - let mut recursive_args: Vec = vec![next_iter.into()]; - recursive_args.extend(external_vars.iter().map(|v| v.clone().into())); + // Create a temporary variable for the function result + let result_var = state.counters.aux_var(); - body.push(SimpleLine::FunctionCall { - function_name: name.clone(), - args: recursive_args, - return_data: vec![], - line_number: location.line_number, - }); - body.push(SimpleLine::FunctionRet { return_data: vec![] }); + lines.push(SimpleLine::FunctionCall { + function_name: function_name.clone(), + args: simplified_args, + return_data: vec![result_var.clone()], + location: *location, + }); - let diff_var = format!("@diff_{iterator}"); + Ok(VarOrConstMallocAccess::Var(result_var).into()) + } + Expression::Len { array, indices } => { + // Check for compile-time vector len() + let versioned_array = state.mut_tracker.current_name(array); + if state.vec_tracker.is_vector(&versioned_array) { + // Evaluate indices at compile time - first simplify to avoid borrow issues + let mut const_indices: Vec = Vec::new(); + for idx in indices { + let simplified = simplify_expr(ctx, state, const_malloc, idx, lines)?; + let SimpleExpr::Constant(const_expr) = simplified else { + return Err("Vector len() index must be compile-time constant".to_string()); + }; + let val = const_expr + .naive_eval() + .ok_or_else(|| "Cannot evaluate len() index".to_string())? + .to_usize(); + const_indices.push(val); + } + + // Now we can borrow vec_tracker again + let vector_value = state.vec_tracker.get(&versioned_array).unwrap(); + + // Navigate and get length + let target = if const_indices.is_empty() { + vector_value + } else { + vector_value + .navigate(&const_indices) + .ok_or_else(|| "len() index out of bounds".to_string())? + }; - let instructions = vec![ - SimpleLine::Assignment { - var: diff_var.clone().into(), - operation: MathOperation::Sub, - arg0: iterator.into(), - arg1: end, - }, - SimpleLine::IfNotZero { - condition: diff_var.into(), - then_branch: body, - else_branch: vec![SimpleLine::FunctionRet { return_data: vec![] }], - line_number: location.line_number, - }, - ]; + return Ok(SimpleExpr::Constant(ConstExpression::from(target.len()))); + } - SimpleFunction { - name, - file_id: location.file_id, - arguments: args, - n_returned_vars: 0, - instructions, + // Fall through to const array handling (should be unreachable for vectors) + unreachable!("len() should have been resolved at parse time for const arrays") + } } } -fn replace_vars_for_unroll_in_expr( - expr: &mut Expression, - iterator: &Var, - unroll_index: usize, - iterator_value: usize, - internal_vars: &BTreeSet, -) { - match expr { - Expression::Value(value_expr) => match value_expr { - SimpleExpr::Memory(VarOrConstMallocAccess::Var(var)) => { - if var == iterator { - *value_expr = SimpleExpr::Constant(ConstExpression::from(iterator_value)); - } else if internal_vars.contains(var) { - *var = format!("@unrolled_{unroll_index}_{iterator_value}_{var}"); - } - } - SimpleExpr::Constant(_) | SimpleExpr::Memory(VarOrConstMallocAccess::ConstMallocAccess { .. }) => {} - }, - Expression::ArrayAccess { array, index } => { - assert!(array != iterator, "Weird"); - if internal_vars.contains(array) { - *array = format!("@unrolled_{unroll_index}_{iterator_value}_{array}"); - } - for index in index { - replace_vars_for_unroll_in_expr(index, iterator, unroll_index, iterator_value, internal_vars); +fn remove_forward_declarations(lines: &mut Vec, var: &Var) { + for i in (0..lines.len()).rev() { + if let SimpleLine::ForwardDeclaration { var: decl_var } = &lines[i] + && decl_var == var + { + lines.remove(i); + } else { + for block in lines[i].nested_blocks_mut() { + remove_forward_declarations(block, var); } } - Expression::MathExpr(_, args) => { - for arg in args { - replace_vars_for_unroll_in_expr(arg, iterator, unroll_index, iterator_value, internal_vars); + } +} + +/// Returns (internal_vars, external_vars) +pub fn find_variable_usage( + lines: &[Line], + const_arrays: &BTreeMap, +) -> (BTreeSet, BTreeSet) { + let mut internal_vars = BTreeSet::new(); + let mut external_vars = BTreeSet::new(); + + let on_new_expr = |expr: &Expression, internal_vars: &BTreeSet, external_vars: &mut BTreeSet| { + for var in vars_in_expression(expr, const_arrays) { + if !internal_vars.contains(&var) && !const_arrays.contains_key(&var) { + external_vars.insert(var); } } - Expression::FunctionCall { args, .. } => { - for arg in args { - replace_vars_for_unroll_in_expr(arg, iterator, unroll_index, iterator_value, internal_vars); + }; + + let on_new_condition = + |condition: &Condition, internal_vars: &BTreeSet, external_vars: &mut BTreeSet| match condition { + Condition::Comparison(comp) => { + on_new_expr(&comp.left, internal_vars, external_vars); + on_new_expr(&comp.right, internal_vars, external_vars); } - } - Expression::Len { indices, .. } => { - for idx in indices { - replace_vars_for_unroll_in_expr(idx, iterator, unroll_index, iterator_value, internal_vars); + Condition::AssumeBoolean(expr) => { + on_new_expr(expr, internal_vars, external_vars); } - } - } -} + }; -fn replace_vars_for_unroll( - lines: &mut [Line], - iterator: &Var, - unroll_index: usize, - iterator_value: usize, - internal_vars: &BTreeSet, -) { for line in lines { match line { - Line::Match { value, arms } => { - replace_vars_for_unroll_in_expr(value, iterator, unroll_index, iterator_value, internal_vars); + Line::ForwardDeclaration { var, .. } => { + internal_vars.insert(var.clone()); + } + Line::Match { value, arms, .. } => { + on_new_expr(value, &internal_vars, &mut external_vars); for (_, statements) in arms { - replace_vars_for_unroll(statements, iterator, unroll_index, iterator_value, internal_vars); + let (stmt_internal, stmt_external) = find_variable_usage(statements, const_arrays); + internal_vars.extend(stmt_internal); + external_vars.extend(stmt_external.into_iter().filter(|v| !internal_vars.contains(v))); } } - Line::ForwardDeclaration { var } => { - *var = format!("@unrolled_{unroll_index}_{iterator_value}_{var}"); - } Line::Statement { targets, value, .. } => { - replace_vars_for_unroll_in_expr(value, iterator, unroll_index, iterator_value, internal_vars); + on_new_expr(value, &internal_vars, &mut external_vars); for target in targets { match target { - AssignmentTarget::Var(var) => { - assert!(var != iterator, "Weird"); - *var = format!("@unrolled_{unroll_index}_{iterator_value}_{var}"); + AssignmentTarget::Var { var, .. } => { + // Only mark as internal if not already used as external + // This ensures re-assignments to external (mutable) variables + // keep them as external + if !external_vars.contains(var) { + internal_vars.insert(var.clone()); + } } AssignmentTarget::ArrayAccess { array, index } => { - assert!(array != iterator, "Weird"); - if internal_vars.contains(array) { - *array = format!("@unrolled_{unroll_index}_{iterator_value}_{array}"); + assert!(!const_arrays.contains_key(array), "Cannot assign to const array"); + if !internal_vars.contains(array) { + external_vars.insert(array.clone()); } - replace_vars_for_unroll_in_expr( - index, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); + on_new_expr(index, &internal_vars, &mut external_vars); } } } } - Line::Assert { boolean, .. } => { - replace_vars_for_unroll_in_expr( - &mut boolean.left, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - replace_vars_for_unroll_in_expr( - &mut boolean.right, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - } Line::IfCondition { condition, then_branch, else_branch, - line_number: _, + .. } => { - match condition { - Condition::Comparison(cond) => { - replace_vars_for_unroll_in_expr( - &mut cond.left, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - replace_vars_for_unroll_in_expr( - &mut cond.right, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - } - Condition::AssumeBoolean(expr) => { - replace_vars_for_unroll_in_expr(expr, iterator, unroll_index, iterator_value, internal_vars); - } - } + on_new_condition(condition, &internal_vars, &mut external_vars); - replace_vars_for_unroll(then_branch, iterator, unroll_index, iterator_value, internal_vars); - replace_vars_for_unroll(else_branch, iterator, unroll_index, iterator_value, internal_vars); + let (then_internal, then_external) = find_variable_usage(then_branch, const_arrays); + let (else_internal, else_external) = find_variable_usage(else_branch, const_arrays); + + internal_vars.extend(then_internal.union(&else_internal).cloned()); + external_vars.extend( + then_external + .union(&else_external) + .filter(|v| !internal_vars.contains(*v)) + .cloned(), + ); + } + Line::Assert { boolean, .. } => { + on_new_condition( + &Condition::Comparison(boolean.clone()), + &internal_vars, + &mut external_vars, + ); + } + Line::FunctionRet { return_data } => { + for ret in return_data { + on_new_expr(ret, &internal_vars, &mut external_vars); + } } Line::ForLoop { - iterator: other_iterator, + iterator, start, end, body, - rev: _, - unroll: _, - line_number: _, + .. } => { - assert!(other_iterator != iterator); - *other_iterator = format!("@unrolled_{unroll_index}_{iterator_value}_{other_iterator}"); - replace_vars_for_unroll_in_expr(start, iterator, unroll_index, iterator_value, internal_vars); - replace_vars_for_unroll_in_expr(end, iterator, unroll_index, iterator_value, internal_vars); - replace_vars_for_unroll(body, iterator, unroll_index, iterator_value, internal_vars); + let (body_internal, body_external) = find_variable_usage(body, const_arrays); + internal_vars.extend(body_internal); + internal_vars.insert(iterator.clone()); + external_vars.extend(body_external.difference(&internal_vars).cloned()); + on_new_expr(start, &internal_vars, &mut external_vars); + on_new_expr(end, &internal_vars, &mut external_vars); } - Line::FunctionRet { return_data } => { - for ret in return_data { - replace_vars_for_unroll_in_expr(ret, iterator, unroll_index, iterator_value, internal_vars); - } + Line::Panic { .. } | Line::LocationReport { .. } => {} + Line::VecDeclaration { var, elements, .. } => { + // Process expressions in vec elements + process_vec_elements_usage(elements, &internal_vars, &mut external_vars, const_arrays); + // Add the vector variable to internal vars + internal_vars.insert(var.clone()); } - Line::Precompile { table: _, args } => { - for arg in args { - replace_vars_for_unroll_in_expr(arg, iterator, unroll_index, iterator_value, internal_vars); + Line::Push { + vector, + indices, + element, + .. + } => { + // The vector variable is used + if !internal_vars.contains(vector) { + external_vars.insert(vector.clone()); } - } - Line::Print { line_info, content } => { - // Print statements are not unrolled, so we don't need to change them - *line_info += &format!(" (unrolled {unroll_index} {iterator_value})"); - for var in content { - replace_vars_for_unroll_in_expr(var, iterator, unroll_index, iterator_value, internal_vars); + // Process index expressions + for idx in indices { + on_new_expr(idx, &internal_vars, &mut external_vars); } + // Process the pushed element + process_vec_element_usage(element, &internal_vars, &mut external_vars, const_arrays); } - Line::MAlloc { - var, - size, - vectorized: _, - vectorized_len, - } => { - assert!(var != iterator, "Weird"); - *var = format!("@unrolled_{unroll_index}_{iterator_value}_{var}"); - replace_vars_for_unroll_in_expr(size, iterator, unroll_index, iterator_value, internal_vars); - replace_vars_for_unroll_in_expr(vectorized_len, iterator, unroll_index, iterator_value, internal_vars); - } - Line::PrivateInputStart { result } => { - assert!(result != iterator, "Weird"); - *result = format!("@unrolled_{unroll_index}_{iterator_value}_{result}"); - } - Line::CustomHint(_, decomposed_args) => { - for expr in decomposed_args { - replace_vars_for_unroll_in_expr(expr, iterator, unroll_index, iterator_value, internal_vars); + Line::Pop { vector, indices, .. } => { + // The vector variable is used + if !internal_vars.contains(vector) { + external_vars.insert(vector.clone()); + } + // Process index expressions + for idx in indices { + on_new_expr(idx, &internal_vars, &mut external_vars); } } - Line::Break | Line::Panic | Line::LocationReport { .. } => {} } } -} -fn handle_inlined_functions(program: &mut Program) { - let inlined_functions = program - .functions - .iter() - .filter(|(_, func)| func.inlined) - .map(|(name, func)| (name.clone(), func.clone())) - .collect::>(); + (internal_vars, external_vars) +} - for func in inlined_functions.values() { - assert!( - !func.has_const_arguments(), - "Inlined functions with constant arguments are not supported yet" - ); +fn process_vec_elements_usage( + elements: &[VecLiteral], + internal_vars: &BTreeSet, + external_vars: &mut BTreeSet, + const_arrays: &BTreeMap, +) { + for elem in elements { + process_vec_element_usage(elem, internal_vars, external_vars, const_arrays); } +} - // Process inline functions iteratively to handle dependencies - // Repeat until all inline function calls are resolved - let mut max_iterations = 10; - - let mut counter1 = Counter::new(); - let mut counter2 = Counter::new(); - - while max_iterations > 0 { - let mut any_changes = false; - - // Process non-inlined functions - for func in program.functions.values_mut() { - if !func.inlined { - let old_body = func.body.clone(); - - let mut ctx = Context::new(); - for (var, _) in func.arguments.iter() { - ctx.add_var(var); - } - func.body = handle_inlined_functions_helper( - &mut ctx, - &func.body, - &inlined_functions, - &mut counter1, - &mut counter2, - ); - - if func.body != old_body { - any_changes = true; - } - } - } - - // Process inlined functions that may call other inlined functions - // We need to update them so that when they get inlined later, they don't have unresolved calls - for func in program.functions.values_mut() { - if func.inlined { - let old_body = func.body.clone(); - - let mut ctx = Context::new(); - for (var, _) in func.arguments.iter() { - ctx.add_var(var); - } - handle_inlined_functions_helper(&mut ctx, &func.body, &inlined_functions, &mut counter1, &mut counter2); - - if func.body != old_body { - any_changes = true; +fn process_vec_element_usage( + elem: &VecLiteral, + internal_vars: &BTreeSet, + external_vars: &mut BTreeSet, + const_arrays: &BTreeMap, +) { + match elem { + VecLiteral::Expr(expr) => { + for var in vars_in_expression(expr, const_arrays) { + if !internal_vars.contains(&var) && !const_arrays.contains_key(&var) { + external_vars.insert(var); } } } - - if !any_changes { - break; + VecLiteral::Vec(inner) => { + process_vec_elements_usage(inner, internal_vars, external_vars, const_arrays); } - - max_iterations -= 1; - } - - assert!(max_iterations > 0, "Too many iterations processing inline functions"); - - // Remove all inlined functions from the program (they've been inlined) - for func_name in inlined_functions.keys() { - program.functions.remove(func_name); } } -/// Recursively extracts inlined function calls from an expression. -/// Returns the modified expression and lines to prepend (forward declarations and function calls). -fn extract_inlined_calls_from_expr( - expr: &Expression, - inlined_functions: &BTreeMap, - inlined_var_counter: &mut Counter, -) -> (Expression, Vec) { - let mut lines = vec![]; +enum VarTransform { + ReplaceWithExpr(SimpleExpr), + Rename(String), + Keep, +} - match expr { - Expression::Value(_) => (expr.clone(), vec![]), - Expression::ArrayAccess { array, index } => { - let mut index_new = vec![]; - for idx in index { - let (idx, idx_lines) = extract_inlined_calls_from_expr(idx, inlined_functions, inlined_var_counter); - lines.extend(idx_lines); - index_new.push(idx); - } - ( - Expression::ArrayAccess { - array: array.clone(), - index: index_new, - }, - lines, - ) - } - Expression::MathExpr(operation, args) => { - let mut args_new = vec![]; - for arg in args { - let (arg, arg_lines) = extract_inlined_calls_from_expr(arg, inlined_functions, inlined_var_counter); - lines.extend(arg_lines); - args_new.push(arg); +impl VarTransform { + fn apply_to_var(self, var: &mut Var) { + match self { + VarTransform::ReplaceWithExpr(SimpleExpr::Memory(VarOrConstMallocAccess::Var(new_var))) => { + *var = new_var; } - (Expression::MathExpr(*operation, args_new), lines) - } - Expression::FunctionCall { function_name, args } => { - let mut args_new = vec![]; - for arg in args { - let (arg, arg_lines) = extract_inlined_calls_from_expr(arg, inlined_functions, inlined_var_counter); - args_new.push(arg); - lines.extend(arg_lines); - } - - if inlined_functions.contains_key(function_name) { - let aux_var = format!("@inlined_var_{}", inlined_var_counter.next()); - lines.push(Line::ForwardDeclaration { var: aux_var.clone() }); - lines.push(Line::Statement { - targets: vec![AssignmentTarget::Var(aux_var.clone())], - value: Expression::FunctionCall { - function_name: function_name.clone(), - args: args.clone(), - }, - line_number: 0, - }); - (Expression::Value(VarOrConstMallocAccess::Var(aux_var).into()), lines) - } else { - (expr.clone(), lines) + VarTransform::ReplaceWithExpr(_) => { + panic!("Cannot replace variable with non-variable expression in this context"); } - } - Expression::Len { array, indices } => { - let mut new_indices = vec![]; - for idx in indices.iter() { - let (idx, idx_lines) = extract_inlined_calls_from_expr(idx, inlined_functions, inlined_var_counter); - lines.extend(idx_lines); - new_indices.push(idx); + VarTransform::Rename(new_name) => { + *var = new_name; } - ( - Expression::Len { - array: array.clone(), - indices: new_indices, - }, - lines, - ) + VarTransform::Keep => {} } } } -fn extract_inlined_calls_from_boolean_expr( - boolean: &BooleanExpr, - inlined_functions: &BTreeMap, - inlined_var_counter: &mut Counter, -) -> (BooleanExpr, Vec) { - let (left, mut lines) = extract_inlined_calls_from_expr(&boolean.left, inlined_functions, inlined_var_counter); - let (right, right_lines) = extract_inlined_calls_from_expr(&boolean.right, inlined_functions, inlined_var_counter); - lines.extend(right_lines); - let boolean = BooleanExpr { - kind: boolean.kind, - left, - right, - }; - (boolean, lines) -} - -fn extract_inlined_calls_from_condition( - condition: &Condition, - inlined_functions: &BTreeMap, - inlined_var_counter: &mut Counter, -) -> (Condition, Vec) { - match condition { - Condition::AssumeBoolean(expr) => { - let (expr, expr_lines) = extract_inlined_calls_from_expr(expr, inlined_functions, inlined_var_counter); - (Condition::AssumeBoolean(expr), expr_lines) - } - Condition::Comparison(boolean) => { - let (boolean, boolean_lines) = - extract_inlined_calls_from_boolean_expr(boolean, inlined_functions, inlined_var_counter); - (Condition::Comparison(boolean), boolean_lines) +fn transform_vars_in_simple_expr(simple_expr: &mut SimpleExpr, transform: &impl Fn(&Var) -> VarTransform) { + if let SimpleExpr::Memory(VarOrConstMallocAccess::Var(var)) = simple_expr { + match transform(var) { + VarTransform::ReplaceWithExpr(replacement) => { + *simple_expr = replacement; + } + VarTransform::Rename(new_name) => { + *var = new_name; + } + VarTransform::Keep => {} } } } -fn handle_inlined_functions_helper( - ctx: &mut Context, - lines_in: &Vec, - inlined_functions: &BTreeMap, - inlined_var_counter: &mut Counter, - total_inlined_counter: &mut Counter, -) -> Vec { - let mut lines_out = vec![]; - for line in lines_in { - match line { - Line::Break | Line::Panic | Line::LocationReport { .. } => { - lines_out.push(line.clone()); - } - Line::Statement { - targets, - value: Expression::FunctionCall { function_name, args }, - line_number: _, - } => { - if let Some(func) = inlined_functions.get(function_name) { - let mut inlined_lines = vec![]; - - // Only add forward declarations for variable targets, not array accesses - for target in targets.iter() { - if let AssignmentTarget::Var(var) = target - && !ctx.defines(var) - { - inlined_lines.push(Line::ForwardDeclaration { var: var.clone() }); - ctx.add_var(var); - } - } - - let mut simplified_args = vec![]; - for arg in args { - if let Expression::Value(simple_expr) = arg { - simplified_args.push(simple_expr.clone()); - } else { - let aux_var = format!("@inlined_var_{}", inlined_var_counter.next()); - // Check if the argument is a function call to an inlined function - // If so, create a Line::Statement so it gets inlined in subsequent iterations - if let Expression::FunctionCall { - function_name: arg_func_name, - args: arg_args, - } = arg - { - if inlined_functions.contains_key(arg_func_name) { - inlined_lines.push(Line::ForwardDeclaration { var: aux_var.clone() }); - inlined_lines.push(Line::Statement { - targets: vec![AssignmentTarget::Var(aux_var.clone())], - value: Expression::FunctionCall { - function_name: arg_func_name.clone(), - args: arg_args.clone(), - }, - line_number: 0, - }); - } else { - inlined_lines.push(Line::Statement { - targets: vec![AssignmentTarget::Var(aux_var.clone())], - value: arg.clone(), - line_number: 0, - }); - } - } else { - inlined_lines.push(Line::Statement { - targets: vec![AssignmentTarget::Var(aux_var.clone())], - value: arg.clone(), - line_number: 0, - }); - } - simplified_args.push(VarOrConstMallocAccess::Var(aux_var).into()); - } - } - assert_eq!(simplified_args.len(), func.arguments.len()); - let inlined_args = func - .arguments - .iter() - .zip(&simplified_args) - .map(|((var, _), expr)| (var.clone(), expr.clone())) - .collect::>(); - let mut func_body = func.body.clone(); - inline_lines(&mut func_body, &inlined_args, targets, total_inlined_counter.next()); - inlined_lines.extend(func_body); - lines_out.extend(inlined_lines); - } else { - lines_out.push(line.clone()); - } - } - Line::Statement { - targets, - value, - line_number, - } => { - let (value, value_lines) = - extract_inlined_calls_from_expr(value, inlined_functions, inlined_var_counter); - lines_out.extend(value_lines); - for target in targets { - if let AssignmentTarget::Var(var) = target - && !ctx.defines(var) - { - ctx.add_var(var); - } - } - lines_out.push(Line::Statement { - targets: targets.clone(), - value, - line_number: *line_number, - }); - } - Line::IfCondition { - condition, - then_branch, - else_branch, - line_number, - } => { - extract_inlined_calls_from_condition(condition, inlined_functions, inlined_var_counter); - ctx.scopes.push(Scope::default()); - let then_branch_out = handle_inlined_functions_helper( - ctx, - then_branch, - inlined_functions, - inlined_var_counter, - total_inlined_counter, - ); - ctx.scopes.pop(); - ctx.scopes.push(Scope::default()); - let else_branch_out = handle_inlined_functions_helper( - ctx, - else_branch, - inlined_functions, - inlined_var_counter, - total_inlined_counter, - ); - ctx.scopes.pop(); - lines_out.push(Line::IfCondition { - condition: condition.clone(), - then_branch: then_branch_out, - else_branch: else_branch_out, - line_number: *line_number, - }); - } - Line::Match { value, arms } => { - let mut arms_out: Vec<(usize, Vec)> = Vec::new(); - for (i, arm) in arms { - ctx.scopes.push(Scope::default()); - let arm_out = handle_inlined_functions_helper( - ctx, - arm, - inlined_functions, - inlined_var_counter, - total_inlined_counter, - ); - ctx.scopes.pop(); - arms_out.push((*i, arm_out)); - } - lines_out.push(Line::Match { - value: value.clone(), - arms: arms_out, - }); - } - Line::ForwardDeclaration { var } => { - lines_out.push(line.clone()); - ctx.add_var(var); - } - Line::PrivateInputStart { result } => { - lines_out.push(line.clone()); - if !ctx.defines(result) { - ctx.add_var(result); - } - } - Line::ForLoop { - iterator, - start, - end, - body, - rev, - unroll, - line_number, - } => { - // Handle inlining in the loop bounds - let (start, start_lines) = - extract_inlined_calls_from_expr(start, inlined_functions, inlined_var_counter); - lines_out.extend(start_lines); - let (end, end_lines) = extract_inlined_calls_from_expr(end, inlined_functions, inlined_var_counter); - lines_out.extend(end_lines); - - // Handle inlining in the loop body - ctx.scopes.push(Scope::default()); - ctx.add_var(iterator); - let loop_body_out = handle_inlined_functions_helper( - ctx, - body, - inlined_functions, - inlined_var_counter, - total_inlined_counter, - ); - ctx.scopes.pop(); - - // Push modified loop - lines_out.push(Line::ForLoop { - iterator: iterator.clone(), - start, - end, - body: loop_body_out, - rev: *rev, - unroll: *unroll, - line_number: *line_number, - }); - } - Line::Assert { - debug, - boolean, - line_number, - } => { - let (boolean, boolean_lines) = - extract_inlined_calls_from_boolean_expr(boolean, inlined_functions, inlined_var_counter); - lines_out.extend(boolean_lines); - lines_out.push(Line::Assert { - debug: *debug, - boolean, - line_number: *line_number, - }); - } - Line::Print { line_info, content } => { - let mut new_content = vec![]; - for expr in content { - let (expr, expr_lines) = - extract_inlined_calls_from_expr(expr, inlined_functions, inlined_var_counter); - lines_out.extend(expr_lines); - new_content.push(expr); - } - lines_out.push(Line::Print { - line_info: line_info.clone(), - content: new_content, - }); +fn transform_vars_in_expr(expr: &mut Expression, transform: &impl Fn(&Var) -> VarTransform) { + match expr { + Expression::Value(value) => { + transform_vars_in_simple_expr(value, transform); + } + Expression::ArrayAccess { array, .. } | Expression::Len { array, .. } => { + transform(array).apply_to_var(array); + } + Expression::MathExpr(_, _) | Expression::FunctionCall { .. } => {} + } + for inner_expr in expr.inner_exprs_mut() { + transform_vars_in_expr(inner_expr, transform); + } +} + +fn transform_vars_in_lines(lines: &mut [Line], transform: &impl Fn(&Var) -> VarTransform) { + for line in lines { + for expr in line.expressions_mut() { + transform_vars_in_expr(expr, transform); + } + for block in line.nested_blocks_mut() { + transform_vars_in_lines(block, transform); + } + match line { + Line::ForwardDeclaration { var, .. } => { + transform(var).apply_to_var(var); } - Line::FunctionRet { return_data } => { - let mut new_return_data = vec![]; - for expr in return_data { - let (expr, expr_lines) = - extract_inlined_calls_from_expr(expr, inlined_functions, inlined_var_counter); - lines_out.extend(expr_lines); - new_return_data.push(expr); + Line::Statement { targets, .. } => { + for target in targets { + match target { + AssignmentTarget::Var { var, .. } => { + transform(var).apply_to_var(var); + } + AssignmentTarget::ArrayAccess { array, .. } => { + transform(array).apply_to_var(array); + } + } } - lines_out.push(Line::FunctionRet { - return_data: new_return_data, - }); } - Line::Precompile { table, args } => { - let mut new_args = vec![]; - for expr in args { - let (expr, new_lines) = - extract_inlined_calls_from_expr(expr, inlined_functions, inlined_var_counter); - lines_out.extend(new_lines); - new_args.push(expr); - } - lines_out.push(Line::Precompile { - table: *table, - args: new_args, - }); + Line::ForLoop { iterator, .. } => { + transform(iterator).apply_to_var(iterator); } - Line::MAlloc { - var, - size, - vectorized, - vectorized_len, - } => { - let (size, size_lines) = extract_inlined_calls_from_expr(size, inlined_functions, inlined_var_counter); - lines_out.extend(size_lines); - let (vectorized_len, vectorized_len_lines) = - extract_inlined_calls_from_expr(vectorized_len, inlined_functions, inlined_var_counter); - lines_out.extend(vectorized_len_lines); - - if !ctx.defines(var) { - ctx.add_var(var); - } - lines_out.push(Line::MAlloc { - var: var.clone(), - size, - vectorized: *vectorized, - vectorized_len, - }); + Line::VecDeclaration { var, .. } => { + transform(var).apply_to_var(var); } - Line::CustomHint(hint, args) => { - let mut new_args = vec![]; - for expr in args { - let (expr, new_lines) = - extract_inlined_calls_from_expr(expr, inlined_functions, inlined_var_counter); - lines_out.extend(new_lines); - new_args.push(expr); - } - lines_out.push(Line::CustomHint(*hint, new_args)); + Line::Push { vector, .. } => { + transform(vector).apply_to_var(vector); } - }; + Line::Pop { vector, .. } => { + transform(vector).apply_to_var(vector); + } + _ => {} + } } - lines_out } -fn handle_const_arguments(program: &mut Program) -> bool { - let mut any_changes = false; - let mut new_functions = BTreeMap::::new(); - let constant_functions = program - .functions - .iter() - .filter(|(_, func)| func.has_const_arguments()) - .map(|(name, func)| (name.clone(), func.clone())) - .collect::>(); +fn inline_lines( + lines: &mut Vec, + args: &BTreeMap, + const_arrays: &BTreeMap, + res: &[AssignmentTarget], + inlining_count: usize, +) { + let transform = |var: &Var| -> VarTransform { + if let Some(replacement) = args.get(var) { + VarTransform::ReplaceWithExpr(replacement.clone()) + } else if const_arrays.contains_key(var) { + VarTransform::Keep + } else { + VarTransform::Rename(format!("@inlined_var_{inlining_count}_{var}")) + } + }; - // First pass: process non-const functions that call const functions - for func in program.functions.values_mut() { - if !func.has_const_arguments() { - any_changes |= handle_const_arguments_helper( - func.file_id, - &mut func.body, - &constant_functions, - &mut new_functions, - &program.const_arrays, - ); + transform_vars_in_lines(lines, &transform); + replace_function_ret_in_lines(lines, res); +} + +fn replace_function_ret_in_lines(lines: &mut Vec, res: &[AssignmentTarget]) { + // First recurse into nested blocks + for line in lines.iter_mut() { + for block in line.nested_blocks_mut() { + replace_function_ret_in_lines(block, res); } } - // Process newly created functions recursively until no more changes - let mut changed = true; - let mut const_depth = 0; - while changed { - changed = false; - const_depth += 1; - assert!(const_depth < 100, "Too many levels of constant arguments"); - let mut additional_functions = BTreeMap::new(); - - // Collect all function names to process - let function_names: Vec = new_functions.keys().cloned().collect(); - - for name in function_names { - if let Some(func) = new_functions.get_mut(&name) { - let initial_count = additional_functions.len(); - handle_const_arguments_helper( - func.file_id, - &mut func.body, - &constant_functions, - &mut additional_functions, - &program.const_arrays, - ); - if additional_functions.len() > initial_count { - changed = true; - any_changes = true; - } - } + // Then handle FunctionRet → Statement conversion at this level + let mut lines_to_replace = vec![]; + for (i, line) in lines.iter().enumerate() { + if let Line::FunctionRet { return_data } = line { + assert_eq!(return_data.len(), res.len()); + lines_to_replace.push(( + i, + res.iter() + .zip(return_data.iter()) + .map(|(target, expr)| Line::Statement { + targets: vec![target.clone()], + value: expr.clone(), + location: SourceLocation { + file_id: 0, + line_number: 0, + }, // TODO + }) + .collect::>(), + )); } + } + for (i, new_lines) in lines_to_replace.into_iter().rev() { + lines.splice(i..=i, new_lines); + } +} - // Add any newly discovered functions - for (name, func) in additional_functions { - if let std::collections::btree_map::Entry::Vacant(e) = new_functions.entry(name) { - e.insert(func); - changed = true; - any_changes = true; - } +fn vars_in_expression(expr: &Expression, const_arrays: &BTreeMap) -> BTreeSet { + let mut vars = BTreeSet::new(); + match expr { + Expression::Value(SimpleExpr::Memory(VarOrConstMallocAccess::Var(var))) => { + vars.insert(var.clone()); + } + Expression::ArrayAccess { array, .. } if !const_arrays.contains_key(array) => { + vars.insert(array.clone()); } + _ => {} } + for inner_expr in expr.inner_exprs() { + vars.extend(vars_in_expression(inner_expr, const_arrays)); + } + vars +} + +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum ArrayAccessType { + VarIsAssigned(Var), // var = array[index] + ArrayIsAssigned(SimpleExpr), // array[index] = expr +} - any_changes |= !new_functions.is_empty(); +fn handle_array_assignment( + ctx: &SimplifyContext<'_>, + state: &mut SimplifyState<'_>, + res: &mut Vec, + array: &Var, + simplified_index: &[SimpleExpr], + access_type: ArrayAccessType, +) { + // Convert array name to versioned name if it's a mutable variable + let array = state.mut_tracker.current_name(array); - for (name, func) in new_functions { - assert!(!program.functions.contains_key(&name)); - program.functions.insert(name, func); + if let ArrayAccessType::VarIsAssigned(var) = &access_type + && let Some(const_array) = ctx.const_arrays.get(&array) + { + let idx = simplified_index + .iter() + .map(|idx| { + idx.as_constant() + .expect("Const array access index should be constant") + .naive_eval() + .unwrap() + }) + .collect::>(); + let value = const_array + .navigate(&idx) + .expect("Const array access index out of bounds") + .as_scalar() + .expect("Const array access should return a scalar"); + res.push(SimpleLine::equality(var.clone(), ConstExpression::scalar(value))); + return; } - // DON'T remove const functions here - they might be needed in subsequent iterations - // They will be removed at the end of simplify_program + let value_simplified = match access_type { + ArrayAccessType::VarIsAssigned(var) => SimpleExpr::Memory(VarOrConstMallocAccess::Var(var)), + ArrayAccessType::ArrayIsAssigned(expr) => expr, + }; + + // TODO opti: in some case we could use ConstMallocAccess + assert_eq!(simplified_index.len(), 1); + let simplified_index = simplified_index[0].clone(); + let (index_var, shift) = match simplified_index { + SimpleExpr::Constant(c) => (SimpleExpr::Memory(VarOrConstMallocAccess::Var(array.clone())), c), + _ => { + // Create pointer variable: ptr = array + index + let ptr_var = state.counters.aux_var(); + res.push(SimpleLine::Assignment { + var: ptr_var.clone().into(), + operation: MathOperation::Add, + arg0: SimpleExpr::Memory(VarOrConstMallocAccess::Var(array.clone())), + arg1: simplified_index, + }); + ( + SimpleExpr::Memory(VarOrConstMallocAccess::Var(ptr_var)), + ConstExpression::zero(), + ) + } + }; - any_changes + res.push(SimpleLine::RawAccess { + res: value_simplified, + index: index_var, + shift, + }); } -fn handle_const_arguments_helper( - file_id: FileId, - lines: &mut [Line], - constant_functions: &BTreeMap, - new_functions: &mut BTreeMap, - const_arrays: &BTreeMap, -) -> bool { - let mut changed = false; - 'outer: for line in lines { - match line { - Line::Statement { - targets: _, - value: Expression::FunctionCall { function_name, args }, - line_number: _, - } => { - if let Some(func) = constant_functions.get(function_name.as_str()) { - // Check if all const arguments can be evaluated - let mut const_evals = Vec::new(); - for (arg_expr, (arg_var, is_constant)) in args.iter().zip(&func.arguments) { - if *is_constant { - if let Some(const_eval) = arg_expr.naive_eval(const_arrays) { - const_evals.push((arg_var.clone(), const_eval)); - } else { - // Skip this call, will be handled in a later pass after more unrolling - continue 'outer; - } - } - } +fn create_recursive_function( + name: String, + location: SourceLocation, + args: Vec, + iterator: Var, + end: SimpleExpr, + mut body: Vec, + external_vars: &[Var], +) -> SimpleFunction { + // Add iterator increment + let next_iter = format!("@incremented_{iterator}"); + body.push(SimpleLine::Assignment { + var: next_iter.clone().into(), + operation: MathOperation::Add, + arg0: iterator.clone().into(), + arg1: SimpleExpr::one(), + }); - let const_funct_name = format!( - "{function_name}_{}", - const_evals - .iter() - .map(|(arg_var, const_eval)| { format!("{arg_var}={const_eval}") }) - .collect::>() - .join("_") - ); + // Add recursive call + let mut recursive_args: Vec = vec![next_iter.into()]; + recursive_args.extend(external_vars.iter().map(|v| v.clone().into())); - *function_name = const_funct_name.clone(); // change the name of the function called - // ... and remove constant arguments - *args = args - .iter() - .zip(&func.arguments) - .filter(|(_, (_, is_constant))| !is_constant) - .filter(|(_, (_, is_const))| !is_const) - .map(|(arg_expr, _)| arg_expr.clone()) - .collect(); + body.push(SimpleLine::FunctionCall { + function_name: name.clone(), + args: recursive_args, + return_data: vec![], + location, + }); + body.push(SimpleLine::FunctionRet { return_data: vec![] }); - changed = true; + let diff_var = format!("@diff_{iterator}"); - if new_functions.contains_key(&const_funct_name) { - continue; - } + let instructions = vec![ + SimpleLine::Assignment { + var: diff_var.clone().into(), + operation: MathOperation::Sub, + arg0: iterator.into(), + arg1: end, + }, + SimpleLine::IfNotZero { + condition: diff_var.into(), + then_branch: body, + else_branch: vec![SimpleLine::FunctionRet { return_data: vec![] }], + location, + }, + ]; - let mut new_body = func.body.clone(); - replace_vars_by_const_in_lines(&mut new_body, &const_evals.iter().cloned().collect()); - new_functions.insert( - const_funct_name.clone(), - Function { - name: const_funct_name, - file_id, - arguments: func - .arguments - .iter() - .filter(|(_, is_const)| !is_const) - .cloned() - .collect(), - inlined: false, - body: new_body, - n_returned_vars: func.n_returned_vars, - assume_always_returns: func.assume_always_returns, - }, - ); - } - } - Line::Statement { .. } => {} - Line::IfCondition { - then_branch, - else_branch, - .. - } => { - changed |= handle_const_arguments_helper( - file_id, - then_branch, - constant_functions, - new_functions, - const_arrays, - ); - changed |= handle_const_arguments_helper( - file_id, - else_branch, - constant_functions, - new_functions, - const_arrays, - ); - } - Line::ForLoop { body, unroll: _, .. } => { - // TODO we should unroll before const arguments handling - handle_const_arguments_helper(file_id, body, constant_functions, new_functions, const_arrays); - } - Line::Match { arms, .. } => { - for (_, arm) in arms { - changed |= - handle_const_arguments_helper(file_id, arm, constant_functions, new_functions, const_arrays); - } - } - _ => {} - } + SimpleFunction { + name, + arguments: args, + n_returned_vars: 0, + instructions, } - changed +} + +fn replace_vars_for_unroll( + lines: &mut [Line], + iterator: &Var, + unroll_index: usize, + iterator_value: usize, + internal_vars: &BTreeSet, +) { + let transform = |var: &Var| -> VarTransform { + if var == iterator { + VarTransform::ReplaceWithExpr(SimpleExpr::Constant(ConstExpression::from(iterator_value))) + } else if internal_vars.contains(var) { + VarTransform::Rename(format!("@unrolled_{unroll_index}_{iterator_value}_{var}")) + } else { + VarTransform::Keep + } + }; + + transform_vars_in_lines(lines, &transform); } fn replace_vars_by_const_in_expr(expr: &mut Expression, map: &BTreeMap) { @@ -2577,7 +3367,7 @@ fn replace_vars_by_const_in_expr(expr: &mut Expression, map: &BTreeMap) Expression::Value(value) => match &value { SimpleExpr::Memory(VarOrConstMallocAccess::Var(var)) => { if let Some(const_value) = map.get(var) { - *value = SimpleExpr::scalar(const_value.to_usize()); + *value = SimpleExpr::scalar(*const_value); } } SimpleExpr::Memory(VarOrConstMallocAccess::ConstMallocAccess { .. }) => { @@ -2612,20 +3402,20 @@ fn replace_vars_by_const_in_expr(expr: &mut Expression, map: &BTreeMap) fn replace_vars_by_const_in_lines(lines: &mut [Line], map: &BTreeMap) { for line in lines { match line { - Line::Match { value, arms } => { + Line::Match { value, arms, .. } => { replace_vars_by_const_in_expr(value, map); for (_, statements) in arms { replace_vars_by_const_in_lines(statements, map); } } - Line::ForwardDeclaration { var } => { + Line::ForwardDeclaration { var, .. } => { assert!(!map.contains_key(var), "Variable {var} is a constant"); } Line::Statement { targets, value, .. } => { replace_vars_by_const_in_expr(value, map); for target in targets { match target { - AssignmentTarget::Var(var) => { + AssignmentTarget::Var { var, .. } => { assert!(!map.contains_key(var), "Variable {var} is a constant"); } AssignmentTarget::ArrayAccess { array, index } => { @@ -2639,7 +3429,7 @@ fn replace_vars_by_const_in_lines(lines: &mut [Line], map: &BTreeMap) { condition, then_branch, else_branch, - line_number: _, + .. } => { match condition { Condition::Comparison(cond) => { @@ -2667,29 +3457,11 @@ fn replace_vars_by_const_in_lines(lines: &mut [Line], map: &BTreeMap) { replace_vars_by_const_in_expr(ret, map); } } - Line::Precompile { table: _, args } => { - for arg in args { - replace_vars_by_const_in_expr(arg, map); - } - } - Line::Print { content, .. } => { - for var in content { - replace_vars_by_const_in_expr(var, map); - } - } - Line::CustomHint(_, decomposed_args) => { - for expr in decomposed_args { - replace_vars_by_const_in_expr(expr, map); - } - } - Line::PrivateInputStart { result } => { - assert!(!map.contains_key(result), "Variable {result} is a constant"); + Line::LocationReport { .. } | Line::Panic { .. } => {} + Line::VecDeclaration { .. } | Line::Push { .. } | Line::Pop { .. } => { + // VecDeclaration, Push and Pop contain VecLiteral elements which may have expressions + // but these are compile-time constructs handled separately } - Line::MAlloc { var, size, .. } => { - assert!(!map.contains_key(var), "Variable {var} is a constant"); - replace_vars_by_const_in_expr(size, map); - } - Line::Panic | Line::Break | Line::LocationReport { .. } => {} } } } @@ -2722,21 +3494,27 @@ impl SimpleLine { let arms_str = arms .iter() .enumerate() - .map(|(pattern, stmt)| { + .map(|(index, body)| { + let body = body + .iter() + .map(|line| line.to_string_with_indent(indent + 2)) + .collect::>() + .join("\n"); + format!( - "{} => {}", - pattern, - stmt.iter() - .map(|line| line.to_string_with_indent(indent + 1)) - .collect::>() - .join("\n") + "{}{} => {{{}\n{}}}", + " ".repeat(indent + 1), + index, + body, + " ".repeat(indent + 1), ) }) .collect::>() - .join(", "); + .join("\n"); format!("match {value} {{\n{arms_str}\n{spaces}}}") } + Self::Assignment { var, operation, @@ -2762,7 +3540,7 @@ impl SimpleLine { condition, then_branch, else_branch, - line_number: _, + .. } => { let then_str = then_branch .iter() @@ -2786,7 +3564,7 @@ impl SimpleLine { function_name, args, return_data, - line_number: _, + .. } => { let args_str = args.iter().map(|arg| format!("{arg}")).collect::>().join(", "); let return_data_str = return_data @@ -2823,29 +3601,23 @@ impl SimpleLine { let content_str = content.iter().map(|c| format!("{c}")).collect::>().join(", "); format!("print({content_str})") } - Self::HintMAlloc { - var, - size, - vectorized, - vectorized_len, - } => { - if *vectorized { - format!("{var} = malloc_vec({size}, {vectorized_len})") - } else { - format!("{var} = malloc({size})") - } + Self::HintMAlloc { var, size } => { + format!("{var} = Array({size})") } Self::ConstMalloc { var, size, label: _ } => { - format!("{var} = malloc({size})") + format!("{var} = Array({size})") } - Self::PrivateInputStart { result } => { - format!("private_input_start({result})") - } - Self::Panic => "panic".to_string(), + Self::Panic { message } => match message { + Some(msg) => format!("assert False, \"{msg}\""), + None => "assert False".to_string(), + }, Self::LocationReport { .. } => Default::default(), Self::DebugAssert(bool, _) => { format!("debug_assert({bool})") } + Self::RangeCheck { val, bound } => { + format!("range_check({val} <= {bound})") + } }; format!("{spaces}{line_str}") } @@ -2868,11 +3640,11 @@ impl Display for SimpleFunction { .join("\n"); if self.instructions.is_empty() { - write!(f, "fn {}({}) -> {} {{}}", self.name, args_str, self.n_returned_vars) + write!(f, "def {}({}) -> {} {{}}", self.name, args_str, self.n_returned_vars) } else { write!( f, - "fn {}({}) -> {} {{\n{}\n}}", + "def {}({}) -> {} {{\n{}\n}}", self.name, args_str, self.n_returned_vars, instructions_str ) } diff --git a/crates/lean_compiler/src/b_compile_intermediate.rs b/crates/lean_compiler/src/b_compile_intermediate.rs index 67e32992..0cb86aaf 100644 --- a/crates/lean_compiler/src/b_compile_intermediate.rs +++ b/crates/lean_compiler/src/b_compile_intermediate.rs @@ -145,7 +145,6 @@ fn compile_function( compiler.args_count = function.arguments.len(); compile_lines( - function.file_id, &Label::function(function.name.clone()), &function.instructions, compiler, @@ -154,7 +153,6 @@ fn compile_function( } fn compile_lines( - file_id: FileId, function_name: &Label, lines: &[SimpleLine], compiler: &mut Compiler, @@ -223,8 +221,7 @@ fn compile_lines( for arm in arms.iter() { compiler.stack_pos = saved_stack_pos; compiler.stack_frame_layout.scopes.push(ScopeLayout::default()); - let arm_instructions = - compile_lines(file_id, function_name, arm, compiler, Some(end_label.clone()))?; + let arm_instructions = compile_lines(function_name, arm, compiler, Some(end_label.clone()))?; compiled_arms.push(arm_instructions); compiler.stack_frame_layout.scopes.pop(); new_stack_pos = new_stack_pos.max(compiler.stack_pos); @@ -244,7 +241,7 @@ fn compile_lines( instructions.push(IntermediateInstruction::Computation { operation: Operation::Mul, arg_a: value_simplified, - arg_c: ConstExpression::Value(ConstantValue::MatchBlockSize { match_index }).into(), + arg_b: ConstExpression::Value(ConstantValue::MatchBlockSize { match_index }).into(), res: value_scaled_offset.clone(), }); @@ -255,7 +252,7 @@ fn compile_lines( instructions.push(IntermediateInstruction::Computation { operation: Operation::Add, arg_a: value_scaled_offset, - arg_c: ConstExpression::Value(ConstantValue::MatchFirstBlockStart { match_index }).into(), + arg_b: ConstExpression::Value(ConstantValue::MatchFirstBlockStart { match_index }).into(), res: jump_dest_offset.clone(), }); instructions.push(IntermediateInstruction::Jump { @@ -263,7 +260,7 @@ fn compile_lines( updated_fp: None, }); - let remaining = compile_lines(file_id, function_name, &lines[i + 1..], compiler, final_jump)?; + let remaining = compile_lines(function_name, &lines[i + 1..], compiler, final_jump)?; compiler.bytecode.insert(end_label, remaining); compiler.stack_frame_layout.scopes.pop(); @@ -278,15 +275,15 @@ fn compile_lines( condition, then_branch, else_branch, - line_number, + location, } => { let if_id = compiler.if_counter; compiler.if_counter += 1; let (if_label, else_label, end_label) = ( - Label::if_label(if_id, *line_number), - Label::else_label(if_id, *line_number), - Label::if_else_end(if_id, *line_number), + Label::if_label(if_id, *location), + Label::else_label(if_id, *location), + Label::if_else_end(if_id, *location), ); // c: condition @@ -306,7 +303,7 @@ fn compile_lines( instructions.push(IntermediateInstruction::Computation { operation: Operation::Mul, arg_a: condition_simplified.clone(), - arg_c: IntermediateValue::MemoryAfterFp { + arg_b: IntermediateValue::MemoryAfterFp { offset: condition_inverse_offset.into(), }, res: IntermediateValue::MemoryAfterFp { @@ -324,7 +321,7 @@ fn compile_lines( arg_a: IntermediateValue::MemoryAfterFp { offset: one_minus_product_offset.into(), }, - arg_c: IntermediateValue::MemoryAfterFp { + arg_b: IntermediateValue::MemoryAfterFp { offset: product_offset.into(), }, res: ConstExpression::one().into(), @@ -336,7 +333,7 @@ fn compile_lines( arg_a: IntermediateValue::MemoryAfterFp { offset: one_minus_product_offset.into(), }, - arg_c: condition_simplified, + arg_b: condition_simplified, res: ConstExpression::zero().into(), }); @@ -355,16 +352,14 @@ fn compile_lines( let saved_stack_pos = compiler.stack_pos; compiler.stack_frame_layout.scopes.push(ScopeLayout::default()); - let then_instructions = - compile_lines(file_id, function_name, then_branch, compiler, Some(end_label.clone()))?; + let then_instructions = compile_lines(function_name, then_branch, compiler, Some(end_label.clone()))?; let then_stack_pos = compiler.stack_pos; compiler.stack_pos = saved_stack_pos; compiler.stack_frame_layout.scopes.pop(); compiler.stack_frame_layout.scopes.push(ScopeLayout::default()); - let else_instructions = - compile_lines(file_id, function_name, else_branch, compiler, Some(end_label.clone()))?; + let else_instructions = compile_lines(function_name, else_branch, compiler, Some(end_label.clone()))?; compiler.bytecode.insert(if_label, then_instructions); compiler.bytecode.insert(else_label, else_instructions); @@ -372,7 +367,7 @@ fn compile_lines( compiler.stack_frame_layout.scopes.pop(); compiler.stack_pos = compiler.stack_pos.max(then_stack_pos); - let remaining = compile_lines(file_id, function_name, &lines[i + 1..], compiler, final_jump)?; + let remaining = compile_lines(function_name, &lines[i + 1..], compiler, final_jump)?; compiler.bytecode.insert(end_label, remaining); // It is not necessary to update compiler.stack_size here because the preceding call to // compile_lines should have done so. @@ -405,12 +400,11 @@ fn compile_lines( function_name: callee_function_name, args, return_data, - line_number, + location, } => { let call_id = compiler.call_counter; compiler.call_counter += 1; - let return_label = Label::return_from_call(call_id, *line_number); - + let return_label = Label::return_from_call(call_id, *location); let new_fp_pos = compiler.stack_pos; compiler.stack_pos += 1; @@ -446,13 +440,7 @@ fn compile_lines( }); } - instructions.extend(compile_lines( - file_id, - function_name, - &lines[i + 1..], - compiler, - final_jump, - )?); + instructions.extend(compile_lines(function_name, &lines[i + 1..], compiler, final_jump)?); instructions }; @@ -465,17 +453,34 @@ fn compile_lines( } SimpleLine::Precompile { table, args, .. } => { - if *table == Table::poseidon24_mem() { - assert_eq!(args.len(), 3); - } else { - assert_eq!(args.len(), 4); + match table { + Table::DotProduct(_) => assert_eq!(args.len(), 5), + Table::Poseidon16(_) => assert_eq!(args.len(), 4), + Table::Execution(_) => unreachable!(), } + // if arg_c is constant, create a variable (in memory) to hold it + let arg_c = if let SimpleExpr::Constant(cst) = &args[2] { + instructions.push(IntermediateInstruction::Computation { + operation: Operation::Add, + arg_a: IntermediateValue::Constant(cst.clone()), + arg_b: IntermediateValue::Constant(0.into()), + res: IntermediateValue::MemoryAfterFp { + offset: compiler.stack_pos.into(), + }, + }); + let offset = compiler.stack_pos; + compiler.stack_pos += 1; + IntermediateValue::MemoryAfterFp { offset: offset.into() } + } else { + IntermediateValue::from_simple_expr(&args[2], compiler) + }; instructions.push(IntermediateInstruction::Precompile { table: *table, arg_a: IntermediateValue::from_simple_expr(&args[0], compiler), arg_b: IntermediateValue::from_simple_expr(&args[1], compiler), - arg_c: IntermediateValue::from_simple_expr(&args[2], compiler), - aux: args.get(3).unwrap_or(&SimpleExpr::zero()).as_constant().unwrap(), + arg_c, + aux_1: args.get(3).unwrap_or(&SimpleExpr::zero()).as_constant().unwrap(), + aux_2: args.get(4).unwrap_or(&SimpleExpr::zero()).as_constant().unwrap(), }); } @@ -489,7 +494,7 @@ fn compile_lines( instructions.push(IntermediateInstruction::Computation { operation: Operation::Add, arg_a: IntermediateValue::Constant(0.into()), - arg_c: IntermediateValue::Constant(0.into()), + arg_b: IntermediateValue::Constant(0.into()), res: zero_value_offset.clone(), }); instructions.push(IntermediateInstruction::Jump { @@ -500,13 +505,13 @@ fn compile_lines( compile_function_ret(&mut instructions, return_data, compiler); } } - SimpleLine::Panic => instructions.push(IntermediateInstruction::Panic), - SimpleLine::HintMAlloc { - var, - size, - vectorized, - vectorized_len, - } => { + SimpleLine::Panic { message } => { + instructions.push(IntermediateInstruction::PanicHint { + message: message.clone(), + }); + instructions.push(IntermediateInstruction::Panic); + } + SimpleLine::HintMAlloc { var, size } => { if !compiler.is_in_scope(var) { let current_scope_layout = compiler.stack_frame_layout.scopes.last_mut().unwrap(); current_scope_layout @@ -517,8 +522,6 @@ fn compile_lines( instructions.push(IntermediateInstruction::RequestMemory { offset: compiler.get_offset(&var.clone().into()), size: IntermediateValue::from_simple_expr(size, compiler), - vectorized: *vectorized, - vectorized_len: IntermediateValue::from_simple_expr(vectorized_len, compiler), }); } SimpleLine::ConstMalloc { var, size, label } => { @@ -539,18 +542,6 @@ fn compile_lines( .collect::>(); instructions.push(IntermediateInstruction::CustomHint(*hint, simplified_args)); } - SimpleLine::PrivateInputStart { result } => { - if !compiler.is_in_scope(result) { - let current_scope_layout = compiler.stack_frame_layout.scopes.last_mut().unwrap(); - current_scope_layout - .var_positions - .insert(result.clone(), compiler.stack_pos); - compiler.stack_pos += 1; - } - instructions.push(IntermediateInstruction::PrivateInputStart { - res_offset: compiler.get_offset(&result.clone().into()), - }); - } SimpleLine::Print { line_info, content } => { instructions.push(IntermediateInstruction::Print { line_info: line_info.clone(), @@ -563,17 +554,88 @@ fn compile_lines( SimpleLine::LocationReport { location } => { instructions.push(IntermediateInstruction::LocationReport { location: *location }); } - SimpleLine::DebugAssert(boolean, line_number) => { + SimpleLine::DebugAssert(boolean, location) => { let boolean_simplified = BooleanExpr { kind: boolean.kind, left: IntermediateValue::from_simple_expr(&boolean.left, compiler), right: IntermediateValue::from_simple_expr(&boolean.right, compiler), }; - let location = SourceLocation { - file_id, - line_number: *line_number, + instructions.push(IntermediateInstruction::DebugAssert(boolean_simplified, *location)); + } + SimpleLine::RangeCheck { val, bound } => { + // Range check for val <= bound compiles to: + // 1. DEREF: m[fp + aux1] = m[m[fp + val_offset]] - proves val < M + // 2. ADD: m[fp + val_offset] + m[fp + aux2] = bound - computes complement + // 3. DEREF: m[fp + aux3] = m[m[fp + aux2]] - proves complement < M + // + // DerefHint records constraints: memory[target] = memory[memory[src]] + // These are resolved at end of execution in correct order. + + // Get the offset of the value being range-checked + let val_offset = match val { + SimpleExpr::Memory(var_or_const) => compiler.get_offset(var_or_const), + SimpleExpr::Constant(val_const) => { + // For constants, we need to store in a temp variable first + let temp_offset = compiler.stack_pos; + compiler.stack_pos += 1; + instructions.push(IntermediateInstruction::Computation { + operation: Operation::Add, + arg_a: IntermediateValue::Constant(val_const.clone()), + arg_b: IntermediateValue::Constant(ConstExpression::zero()), + res: IntermediateValue::MemoryAfterFp { + offset: ConstExpression::from_usize(temp_offset), + }, + }); + ConstExpression::from_usize(temp_offset) + } }; - instructions.push(IntermediateInstruction::DebugAssert(boolean_simplified, location)); + + // Allocate 3 auxiliary cells + let aux1_offset = ConstExpression::from_usize(compiler.stack_pos); + compiler.stack_pos += 1; + let aux2_offset = ConstExpression::from_usize(compiler.stack_pos); + compiler.stack_pos += 1; + let aux3_offset = ConstExpression::from_usize(compiler.stack_pos); + compiler.stack_pos += 1; + + // DerefHint for first DEREF: memory[aux1] = memory[memory[val_offset]] + instructions.push(IntermediateInstruction::DerefHint { + offset_src: val_offset.clone(), + offset_target: aux1_offset.clone(), + }); + + // 1. DEREF: m[fp + aux1] = m[m[fp + val_offset]] + instructions.push(IntermediateInstruction::Deref { + shift_0: val_offset.clone(), + shift_1: ConstExpression::zero(), + res: IntermediateValue::MemoryAfterFp { offset: aux1_offset }, + }); + + // 2. ADD: m[fp + val_offset] + m[fp + aux2] = bound + let bound_value = IntermediateValue::from_simple_expr(bound, compiler); + instructions.push(IntermediateInstruction::Computation { + operation: Operation::Add, + arg_a: IntermediateValue::MemoryAfterFp { + offset: val_offset.clone(), + }, + arg_b: IntermediateValue::MemoryAfterFp { + offset: aux2_offset.clone(), + }, + res: bound_value, + }); + + // DerefHint for second DEREF: memory[aux3] = memory[memory[aux2]] + instructions.push(IntermediateInstruction::DerefHint { + offset_src: aux2_offset.clone(), + offset_target: aux3_offset.clone(), + }); + + // 3. DEREF: m[fp + aux3] = m[m[fp + aux2]] + instructions.push(IntermediateInstruction::Deref { + shift_0: aux2_offset, + shift_1: ConstExpression::zero(), + res: IntermediateValue::MemoryAfterFp { offset: aux3_offset }, + }); } } } @@ -602,7 +664,7 @@ fn handle_const_malloc( instructions.push(IntermediateInstruction::Computation { operation: Operation::Add, arg_a: IntermediateValue::Constant(compiler.stack_pos.into()), - arg_c: IntermediateValue::Fp, + arg_b: IntermediateValue::Fp, res: IntermediateValue::MemoryAfterFp { offset: compiler.get_offset(&var.clone().into()), }, @@ -621,8 +683,6 @@ fn setup_function_call( IntermediateInstruction::RequestMemory { offset: new_fp_pos.into(), size: ConstExpression::function_size(Label::function(func_name)).into(), - vectorized: false, - vectorized_len: IntermediateValue::Constant(ConstExpression::zero()), }, IntermediateInstruction::Deref { shift_0: new_fp_pos.into(), diff --git a/crates/lean_compiler/src/c_compile_final.rs b/crates/lean_compiler/src/c_compile_final.rs index a5e2918f..b1136066 100644 --- a/crates/lean_compiler/src/c_compile_final.rs +++ b/crates/lean_compiler/src/c_compile_final.rs @@ -1,4 +1,4 @@ -use crate::{F, NONRESERVED_PROGRAM_INPUT_START, ZERO_VEC_PTR, ir::*, lang::*}; +use crate::{F, ir::*, lang::*}; use lean_vm::*; use multilinear_toolkit::prelude::*; use std::collections::BTreeMap; @@ -10,10 +10,11 @@ impl IntermediateInstruction { Self::RequestMemory { .. } | Self::Print { .. } | Self::CustomHint { .. } - | Self::PrivateInputStart { .. } | Self::Inverse { .. } | Self::LocationReport { .. } - | Self::DebugAssert { .. } => true, + | Self::DebugAssert { .. } + | Self::DerefHint { .. } + | Self::PanicHint { .. } => true, Self::Computation { .. } | Self::Panic | Self::Deref { .. } @@ -48,7 +49,7 @@ pub fn compile_to_low_level_bytecode( let starting_frame_memory = *intermediate_bytecode .memory_size_per_function .get("main") - .expect("Missing main function"); + .ok_or("Missing main function")?; let mut hints = BTreeMap::new(); let mut label_to_pc = BTreeMap::new(); @@ -218,36 +219,34 @@ fn compile_block( IntermediateInstruction::Computation { operation, mut arg_a, - mut arg_c, + mut arg_b, res, } => { if let Some(arg_a_cst) = try_as_constant(&arg_a, compiler) - && let Some(arg_b_cst) = try_as_constant(&arg_c, compiler) + && let Some(arg_b_cst) = try_as_constant(&arg_b, compiler) { // res = constant +/x constant let op_res = operation.compute(arg_a_cst, arg_b_cst); - let res: MemOrFp = res.try_into_mem_or_fp(compiler).unwrap(); - low_level_bytecode.push(Instruction::Computation { operation: Operation::Add, arg_a: MemOrConstant::zero(), - arg_c: res, + arg_c: res.try_into_mem_or_fp(compiler).unwrap(), res: MemOrConstant::Constant(op_res), }); pc += 1; continue; } - if arg_c.is_constant() { - std::mem::swap(&mut arg_a, &mut arg_c); + if arg_b.is_constant() { + std::mem::swap(&mut arg_a, &mut arg_b); } low_level_bytecode.push(Instruction::Computation { operation, arg_a: try_as_mem_or_constant(&arg_a).unwrap(), - arg_c: try_as_mem_or_fp(&arg_c).unwrap(), + arg_c: try_as_mem_or_fp(&arg_b).unwrap(), res: try_as_mem_or_constant(&res).unwrap(), }); } @@ -281,7 +280,7 @@ fn compile_block( updated_fp, } => codegen_jump(hints, low_level_bytecode, condition, dest, updated_fp), IntermediateInstruction::Jump { dest, updated_fp } => { - let one = IntermediateValue::Constant(ConstExpression::Value(ConstantValue::Scalar(1))); + let one = ConstExpression::one().into(); codegen_jump(hints, low_level_bytecode, one, dest, updated_fp) } IntermediateInstruction::Precompile { @@ -289,19 +288,16 @@ fn compile_block( arg_a, arg_b, arg_c, - aux, + aux_1, + aux_2, } => { low_level_bytecode.push(Instruction::Precompile { table, arg_a: try_as_mem_or_constant(&arg_a).unwrap(), arg_b: try_as_mem_or_constant(&arg_b).unwrap(), arg_c: try_as_mem_or_fp(&arg_c).unwrap(), - aux: eval_const_expression_usize(&aux, compiler), - }); - } - IntermediateInstruction::PrivateInputStart { res_offset } => { - hints.entry(pc).or_default().push(Hint::PrivateInputStart { - res_offset: eval_const_expression_usize(&res_offset, compiler), + aux_1: eval_const_expression_usize(&aux_1, compiler), + aux_2: eval_const_expression_usize(&aux_2, compiler), }); } IntermediateInstruction::CustomHint(hint, args) => { @@ -320,20 +316,12 @@ fn compile_block( }; hints.entry(pc).or_default().push(hint); } - IntermediateInstruction::RequestMemory { - offset, - size, - vectorized, - vectorized_len, - } => { + IntermediateInstruction::RequestMemory { offset, size } => { let size = try_as_mem_or_constant(&size).unwrap(); - let vectorized_len = try_as_constant(&vectorized_len, compiler).unwrap().to_usize(); let hint = Hint::RequestMemory { function_name: function_name.clone(), offset: eval_const_expression_usize(&offset, compiler), - vectorized, size, - vectorized_len, }; hints.entry(pc).or_default().push(hint); } @@ -362,6 +350,20 @@ fn compile_block( ); hints.entry(pc).or_default().push(hint); } + IntermediateInstruction::DerefHint { + offset_src, + offset_target, + } => { + let hint = Hint::DerefHint { + offset_src: eval_const_expression_usize(&offset_src, compiler), + offset_target: eval_const_expression_usize(&offset_target, compiler), + }; + hints.entry(pc).or_default().push(hint); + } + IntermediateInstruction::PanicHint { message } => { + let hint = Hint::Panic { message }; + hints.entry(pc).or_default().push(hint); + } } if !instruction.is_hint() { @@ -376,10 +378,7 @@ fn count_real_instructions(instrs: &[IntermediateInstruction]) -> usize { fn eval_constant_value(constant: &ConstantValue, compiler: &Compiler) -> usize { match constant { - ConstantValue::Scalar(scalar) => *scalar, - ConstantValue::PublicInputStart => NONRESERVED_PROGRAM_INPUT_START, - ConstantValue::PointerToZeroVector => ZERO_VEC_PTR, - ConstantValue::PointerToOneVector => ONE_VEC_PTR, + ConstantValue::Scalar(scalar) => scalar.to_usize(), ConstantValue::FunctionSize { function_name } => { let func_name_str = match function_name { Label::Function(name) => name, diff --git a/crates/lean_compiler/src/grammar.pest b/crates/lean_compiler/src/grammar.pest index bc2905a6..bbb354e8 100644 --- a/crates/lean_compiler/src/grammar.pest +++ b/crates/lean_compiler/src/grammar.pest @@ -1,78 +1,111 @@ -WHITESPACE = _{ " " | "\t" | "\n" | "\r" } +WHITESPACE = _{ " " | "\t" } +// Comments are stripped by preprocessing (remove_comments function) +// Newlines are significant and represented as tokens + +// Newline marker - makes line boundaries significant +newline = _{ "" } // Program structure -program = { SOI ~ import_statement* ~ constant_declaration* ~ function* ~ EOI } +program = { SOI ~ (import_statement ~ newline)* ~ (constant_declaration ~ newline)* ~ function* ~ EOI } -// Imports -import_statement = { "import" ~ filepath ~ ";" } +import_statement = { "from" ~ module_path ~ "import" ~ "*" } +module_path = { identifier ~ ("." ~ identifier)* } // Constants -constant_declaration = { "const" ~ identifier ~ "=" ~ (array_literal | expression) ~ ";" } -array_literal = { "[" ~ (array_element ~ ("," ~ array_element)*)? ~ "]" } +constant_declaration = { identifier ~ "=" ~ (array_literal | expression) } +array_literal = { "[" ~ (array_element ~ ("," ~ array_element)* ~ ","?)? ~ "]" } array_element = { array_literal | expression } // Functions -function = { pragma? ~ "fn" ~ identifier ~ "(" ~ parameter_list? ~ ")" ~ inlined_statement? ~ return_count? ~ "{" ~ statement* ~ "}" } -pragma = { "#![assume_always_returns]" } -parameter_list = { parameter ~ ("," ~ parameter)* } -parameter = { (const_keyword)? ~ identifier } -const_keyword = { "const" } -inlined_statement = { "inline" } -return_count = { "->" ~ number } - -// Statements +decorator = { "@" ~ identifier ~ newline } +function = { decorator? ~ "def" ~ identifier ~ "(" ~ parameter_list? ~ ")" ~ ":" ~ newline ~ statement* ~ end_block } +parameter_list = { parameter ~ ("," ~ parameter)* ~ ","? } +parameter = { identifier ~ param_annotation? } +param_annotation = { ":" ~ ("Const" | "Mut") } +end_block = { "" ~ newline } + +// Statements - explicitly exclude end_block from being matched as a statement +// Simple statements end with newline, compound statements (with blocks) have their own structure statement = { + !"" ~ ( + if_statement | + for_statement | + match_statement | + simple_statement ~ newline + ) +} + +simple_statement = { forward_declaration | - if_statement | - for_statement | - match_statement | return_statement | - break_statement | - continue_statement | assert_statement | debug_assert_statement | + vec_declaration | + push_statement | + pop_statement | assignment } -return_statement = { "return" ~ (tuple_expression)? ~ ";" } +// Vector declaration: var = vec![...] (vectors are implicitly mutable for push) +vec_declaration = { identifier ~ "=" ~ vec_literal } + +// Push statement: vec_var.push(element) or vec_var[i][j].push(element) +push_statement = { push_target ~ "." ~ "push" ~ "(" ~ vec_element ~ ")" } +push_target = { identifier ~ ("[" ~ expression ~ "]")* } -break_statement = { "break" ~ ";" } -continue_statement = { "continue" ~ ";" } +// Pop statement: vec_var.pop() or vec_var[i][j].pop() +pop_statement = { pop_target ~ "." ~ "pop" ~ "(" ~ ")" } +pop_target = { identifier ~ ("[" ~ expression ~ "]")* } -forward_declaration = { "var" ~ identifier ~ ";" } +return_statement = { "return" ~ (("(" ~ tuple_expression ~ ")") | tuple_expression)? } -// General assignment: LHS is optional list of variables/array accesses, RHS is any expression -assignment = { (assignment_target_list ~ "=")? ~ expression ~ ";" } -assignment_target_list = { assignment_target ~ ("," ~ assignment_target)* } -assignment_target = { array_access_expr | identifier } -if_statement = { "if" ~ condition ~ "{" ~ statement* ~ "}" ~ else_if_clause* ~ else_clause? } +mut_keyword = @{ "mut" ~ !(ASCII_ALPHANUMERIC | "_") } +mut_annotation = { ":" ~ "Mut" } +im_annotation = { ":" ~ "Imu" } -condition = { assumed_bool_expr | comparison } +// Forward declaration: x: Imu or x: Mut (not followed by =) +forward_declaration = { identifier ~ (im_annotation | mut_annotation) ~ !("=") } + +// General assignment: LHS is optional, RHS is any expression +// Compound operators (+=, -=, *=, /=) only allow single target without mut (enforced in parser) +assignment = { (assignment_target_list ~ assign_op)? ~ expression } +assign_op = { "+=" | "-=" | "*=" | "/=" | "=" } +assignment_target_list = { ("(" ~ simple_target_list ~ ")") | simple_target_list } +simple_target_list = { assignment_target ~ ("," ~ assignment_target)* ~ ","? } +assignment_target = { (identifier ~ mut_annotation) | array_access_expr | identifier } + +if_statement = { "if" ~ condition ~ ":" ~ newline ~ statement* ~ end_block ~ else_if_clause* ~ else_clause? } + +condition = { assumed_bool_expr | "(" ~ comparison ~ ")" | comparison } assumed_bool_expr = { "!!assume_bool" ~ "(" ~ expression ~ ")" } // Comparisons (shared between conditions and assertions) comparison = { add_expr ~ comparison_op ~ add_expr } -comparison_op = { "==" | "!=" | "<" } +comparison_op = { "==" | "!=" | "<=" | "<" } -else_if_clause = { "else" ~ "if" ~ condition ~ "{" ~ statement* ~ "}" } +else_if_clause = { "else" ~ "if" ~ condition ~ ":" ~ newline ~ statement* ~ end_block } -else_clause = { "else" ~ "{" ~ statement* ~ "}" } +else_clause = { "else" ~ ":" ~ newline ~ statement* ~ end_block } -for_statement = { "for" ~ identifier ~ "in" ~ rev_clause? ~ expression ~ ".." ~ expression ~ unroll_clause? ~ "{" ~ statement* ~ "}" } -rev_clause = { "rev" } -unroll_clause = { "unroll" } +for_statement = { "for" ~ identifier ~ "in" ~ (unroll_range | range) ~ ":" ~ newline ~ statement* ~ end_block } +range = { "range" ~ "(" ~ expression ~ "," ~ expression ~ ")" } +unroll_range = { "unroll" ~ "(" ~ expression ~ "," ~ expression ~ ")" } -match_statement = { "match" ~ expression ~ "{" ~ match_arm* ~ "}" } -match_arm = { pattern ~ "=>" ~ "{" ~ statement* ~ "}" } +match_statement = { "match" ~ expression ~ ":" ~ newline ~ match_arm* ~ end_block } +match_arm = { "case" ~ pattern ~ ":" ~ newline ~ statement* ~ end_block } pattern = { constant_value } -assert_statement = { "assert" ~ comparison ~ ";" } -debug_assert_statement = { "debug_assert" ~ comparison ~ ";" } +assert_keyword = @{ "assert" ~ !(ASCII_ALPHANUMERIC | "_") } +debug_assert_keyword = @{ "debug_assert" ~ !(ASCII_ALPHANUMERIC | "_") } +assert_statement = { assert_keyword ~ (assert_false | "(" ~ comparison ~ ")" | comparison) } +assert_false = { "False" ~ ("," ~ string_literal)? } +debug_assert_statement = { debug_assert_keyword ~ "(" ~ comparison ~ ")" } +string_literal = @{ "\"" ~ (!"\"" ~ ANY)* ~ "\"" } // Expressions -tuple_expression = { expression ~ ("," ~ expression)* } +tuple_expression = { expression ~ ("," ~ expression)* ~ ","? } expression = { add_expr } add_expr = { sub_expr ~ ("+" ~ sub_expr)* } sub_expr = { mul_expr ~ ("-" ~ mul_expr)* } @@ -90,6 +123,10 @@ primary = { function_call_expr | var_or_constant } + +// DynArray literal: DynArray([elem1, elem2, ...]) - compile-time dynamic arrays +vec_literal = { "DynArray" ~ "(" ~ "[" ~ (vec_element ~ ("," ~ vec_element)* ~ ","?)? ~ "]" ~ ")" } +vec_element = { vec_literal | expression } function_call_expr = { identifier ~ "(" ~ tuple_expression? ~ ")" } log2_ceil_expr = { "log2_ceil" ~ "(" ~ expression ~ ")" } next_multiple_of_expr = { "next_multiple_of" ~ "(" ~ expression ~ "," ~ expression ~ ")" } @@ -100,10 +137,8 @@ array_access_expr = { identifier ~ ("[" ~ expression ~ "]")+ } // Basic elements var_or_constant = { constant_value | identifier } -constant_value = { number | "public_input_start" | "pointer_to_zero_vector" | "pointer_to_one_vector" } +constant_value = { number } // Lexical elements identifier = @{ (ASCII_ALPHA | "_") ~ (ASCII_ALPHANUMERIC | "_")* } number = @{ ASCII_DIGIT+ } -filepath = { "\"" ~ filepath_character* ~ "\"" } -filepath_character = { ASCII_ALPHANUMERIC | "-" | "_" | " " | "." | "+" | "/" } diff --git a/crates/lean_compiler/src/ir/instruction.rs b/crates/lean_compiler/src/ir/instruction.rs index 02fc285d..c6a3ecd3 100644 --- a/crates/lean_compiler/src/ir/instruction.rs +++ b/crates/lean_compiler/src/ir/instruction.rs @@ -9,7 +9,7 @@ pub enum IntermediateInstruction { Computation { operation: Operation, arg_a: IntermediateValue, - arg_c: IntermediateValue, + arg_b: IntermediateValue, res: IntermediateValue, }, Deref { @@ -32,7 +32,8 @@ pub enum IntermediateInstruction { arg_a: IntermediateValue, arg_b: IntermediateValue, arg_c: IntermediateValue, - aux: ConstExpression, + aux_1: ConstExpression, + aux_2: ConstExpression, }, // HINTS (does not appears in the final bytecode) Inverse { @@ -43,12 +44,14 @@ pub enum IntermediateInstruction { RequestMemory { offset: ConstExpression, // m[fp + offset] where the hint will be stored size: IntermediateValue, // the hint - vectorized: bool, // if true, will be (2^vectorized_len)-alligned, and the returned pointer will be "divied" by 2^vectorized_len - vectorized_len: IntermediateValue, }, CustomHint(CustomHint, Vec), - PrivateInputStart { - res_offset: ConstExpression, + /// Deref hint for range checks - records constraint resolved at end of execution + DerefHint { + /// Offset of cell containing the address to dereference + offset_src: ConstExpression, + /// Offset of cell where result will be stored + offset_target: ConstExpression, }, Print { line_info: String, // information about the line where the print occurs @@ -59,38 +62,41 @@ pub enum IntermediateInstruction { location: SourceLocation, }, DebugAssert(BooleanExpr, SourceLocation), + PanicHint { + message: Option, + }, } impl IntermediateInstruction { pub fn computation( operation: MathOperation, arg_a: IntermediateValue, - arg_c: IntermediateValue, + arg_b: IntermediateValue, res: IntermediateValue, ) -> Self { match operation { MathOperation::Add => Self::Computation { operation: Operation::Add, arg_a, - arg_c, + arg_b, res, }, MathOperation::Mul => Self::Computation { operation: Operation::Mul, arg_a, - arg_c, + arg_b, res, }, MathOperation::Sub => Self::Computation { operation: Operation::Add, arg_a: res, - arg_c, + arg_b, res: arg_a, }, MathOperation::Div => Self::Computation { operation: Operation::Mul, arg_a: res, - arg_c, + arg_b, res: arg_a, }, MathOperation::Exp @@ -107,7 +113,7 @@ impl IntermediateInstruction { Self::Computation { operation: Operation::Add, arg_a: left, - arg_c: IntermediateValue::Constant(ConstExpression::zero()), + arg_b: IntermediateValue::Constant(ConstExpression::zero()), res: right, } } @@ -119,13 +125,13 @@ impl Display for IntermediateInstruction { Self::Computation { operation, arg_a, - arg_c, + arg_b, res, } => { - write!(f, "{res} = {arg_a} {operation} {arg_c}") + write!(f, "{res} = {arg_a} {operation} {arg_b}") } Self::Deref { shift_0, shift_1, res } => write!(f, "{res} = m[m[fp + {shift_0}] + {shift_1}]"), - Self::Panic => write!(f, "panic"), + Self::Panic => write!(f, "assert False"), Self::Jump { dest, updated_fp } => { if let Some(fp) = updated_fp { write!(f, "jump {dest} with fp = {fp}") @@ -133,9 +139,6 @@ impl Display for IntermediateInstruction { write!(f, "jump {dest}") } } - Self::PrivateInputStart { res_offset } => { - write!(f, "m[fp + {res_offset}] = private_input_start()") - } Self::JumpIfNotZero { condition, dest, @@ -152,24 +155,16 @@ impl Display for IntermediateInstruction { arg_a, arg_b, arg_c, - aux, + aux_1, + aux_2, } => { - write!(f, "{}({arg_a}, {arg_b}, {arg_c}, {aux})", table.name()) + write!(f, "{}({arg_a}, {arg_b}, {arg_c}, {aux_1}, {aux_2})", table.name()) } Self::Inverse { arg, res_offset } => { write!(f, "m[fp + {res_offset}] = inverse({arg})") } - Self::RequestMemory { - offset, - size, - vectorized, - vectorized_len, - } => { - if *vectorized { - write!(f, "m[fp + {offset}] = request_memory_vec({size}, {vectorized_len})") - } else { - write!(f, "m[fp + {offset}] = request_memory({size})") - } + Self::RequestMemory { offset, size } => { + write!(f, "m[fp + {offset}] = request_memory({size})") } Self::CustomHint(hint, args) => { write!(f, "{}(", hint.name())?; @@ -195,6 +190,16 @@ impl Display for IntermediateInstruction { Self::DebugAssert(boolean_expr, _) => { write!(f, "debug_assert {boolean_expr}") } + Self::DerefHint { + offset_src, + offset_target, + } => { + write!(f, "m[fp + {offset_target}] = m[m[fp + {offset_src}]]") + } + Self::PanicHint { message } => match message { + Some(msg) => write!(f, "panic hint: \"{msg}\""), + None => write!(f, "panic hint"), + }, } } } diff --git a/crates/lean_compiler/src/lang.rs b/crates/lean_compiler/src/lang.rs index 4afe46a0..927484e1 100644 --- a/crates/lean_compiler/src/lang.rs +++ b/crates/lean_compiler/src/lang.rs @@ -5,7 +5,7 @@ use std::collections::{BTreeMap, BTreeSet}; use std::fmt::{Display, Formatter}; use utils::ToUsize; -use crate::a_simplify_lang::VarOrConstMallocAccess; +use crate::a_simplify_lang::{VarOrConstMallocAccess, VectorLenTracker}; use crate::{F, parser::ConstArrayValue}; pub use lean_vm::{FileId, FunctionName, SourceLocation}; @@ -18,20 +18,65 @@ pub struct Program { pub filepaths: BTreeMap, } -#[derive(Debug, Clone)] +impl Program { + pub fn inlined_function_names(&self) -> BTreeSet { + self.functions + .iter() + .filter(|(_, func)| func.inlined) + .map(|(name, _)| name.clone()) + .collect() + } + + pub fn non_constant_functions_mut(&mut self) -> BTreeSet<&mut Function> { + self.functions + .iter_mut() + .filter(|(_, func)| !func.has_const_arguments()) + .map(|(_, func)| func) + .collect() + } +} + +/// A function argument with its modifiers +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] +pub struct FunctionArg { + pub name: Var, + pub is_const: bool, + pub is_mutable: bool, +} + +impl FunctionArg { + pub fn new(name: Var, is_const: bool, is_mutable: bool) -> Self { + Self { + name, + is_const, + is_mutable, + } + } + + pub fn simple(name: Var) -> Self { + Self { + name, + is_const: false, + is_mutable: false, + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] pub struct Function { pub name: String, - pub file_id: FileId, - pub arguments: Vec<(Var, bool)>, // (name, is_const) + pub arguments: Vec, pub inlined: bool, pub n_returned_vars: usize, pub body: Vec, - pub assume_always_returns: bool, } impl Function { pub fn has_const_arguments(&self) -> bool { - self.arguments.iter().any(|(_, is_const)| *is_const) + self.arguments.iter().any(|arg| arg.is_const) + } + pub fn has_mutable_arguments(&self) -> bool { + self.arguments.iter().any(|arg| arg.is_mutable) } } @@ -46,14 +91,14 @@ pub enum SimpleExpr { impl SimpleExpr { pub fn zero() -> Self { - Self::scalar(0) + Self::scalar(F::ZERO) } pub fn one() -> Self { - Self::scalar(1) + Self::scalar(F::ONE) } - pub fn scalar(scalar: usize) -> Self { + pub fn scalar(scalar: F) -> Self { Self::Constant(ConstantValue::Scalar(scalar).into()) } @@ -103,10 +148,7 @@ impl SimpleExpr { #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum ConstantValue { - Scalar(usize), - PublicInputStart, - PointerToZeroVector, // In the memory of chunks of 8 field elements - PointerToOneVector, // In the memory of chunks of 8 field elements + Scalar(F), FunctionSize { function_name: Label }, Label(Label), MatchBlockSize { match_index: usize }, @@ -121,7 +163,7 @@ pub enum ConstExpression { impl From for ConstExpression { fn from(value: usize) -> Self { - Self::Value(ConstantValue::Scalar(value)) + Self::Value(ConstantValue::Scalar(F::from_usize(value))) } } @@ -148,21 +190,25 @@ impl TryFrom for ConstExpression { impl ConstExpression { pub const fn zero() -> Self { - Self::scalar(0) + Self::scalar(F::ZERO) } pub const fn one() -> Self { - Self::scalar(1) + Self::scalar(F::ONE) } pub const fn label(label: Label) -> Self { Self::Value(ConstantValue::Label(label)) } - pub const fn scalar(scalar: usize) -> Self { + pub const fn scalar(scalar: F) -> Self { Self::Value(ConstantValue::Scalar(scalar)) } + pub fn from_usize(value: usize) -> Self { + Self::Value(ConstantValue::Scalar(F::from_usize(value))) + } + pub const fn function_size(function_name: Label) -> Self { Self::Value(ConstantValue::FunctionSize { function_name }) } @@ -184,7 +230,7 @@ impl ConstExpression { pub fn naive_eval(&self) -> Option { self.eval_with(&|value| match value { - ConstantValue::Scalar(scalar) => Some(F::from_usize(*scalar)), + ConstantValue::Scalar(scalar) => Some(*scalar), _ => None, }) } @@ -211,6 +257,34 @@ impl Display for Condition { } } +impl Condition { + pub fn expressions_mut(&mut self) -> Vec<&mut Expression> { + match self { + Self::AssumeBoolean(expr) => vec![expr], + Self::Comparison(cmp) => vec![&mut cmp.left, &mut cmp.right], + } + } + + pub fn eval_with(&self, eval_expr: &impl Fn(&Expression) -> Option) -> Option { + match self { + Self::AssumeBoolean(expr) => { + let val = eval_expr(expr)?; + Some(val != F::ZERO) + } + Self::Comparison(cmp) => { + let left = eval_expr(&cmp.left)?; + let right = eval_expr(&cmp.right)?; + Some(match cmp.kind { + Boolean::Equal => left == right, + Boolean::Different => left != right, + Boolean::LessThan => left < right, + Boolean::LessOrEqual => left <= right, + }) + } + } + } +} + #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum Expression { Value(SimpleExpr), @@ -222,6 +296,7 @@ pub enum Expression { FunctionCall { function_name: String, args: Vec, + location: SourceLocation, }, Len { array: String, @@ -281,6 +356,9 @@ impl Display for MathOperation { } impl MathOperation { + pub fn is_unary(&self) -> bool { + self.num_args() == 1 + } pub fn num_args(&self) -> usize { match self { Self::Log2Ceil => 1, @@ -323,32 +401,34 @@ impl From for Expression { } } -impl From for Expression { - fn from(var: Var) -> Self { - Self::Value(var.into()) - } -} - impl Expression { - pub fn naive_eval(&self, const_arrays: &BTreeMap) -> Option { + pub fn compile_time_eval( + &self, + const_arrays: &BTreeMap, + vector_len: &VectorLenTracker, + ) -> Option { // Handle Len specially since it needs const_arrays if let Self::Len { array, indices } = self { - let idx: Option> = indices + let idx = indices .iter() - .map(|e| e.naive_eval(const_arrays).map(|f| f.to_usize())) - .collect(); - let idx = idx?; - let arr = const_arrays.get(array)?; - let target = arr.navigate(&idx)?; - return Some(F::from_usize(target.len())); + .map(|e| e.compile_time_eval(const_arrays, vector_len)) + .collect::>>()?; + if let Some(arr) = const_arrays.get(array) { + let target = arr.navigate(&idx)?; + return Some(F::from_usize(target.len())); + } + if let Some(arr) = vector_len.get(array) { + let target = arr.navigate(&idx)?; + return Some(F::from_usize(target.len())); + } + return None; } self.eval_with( &|value: &SimpleExpr| value.as_constant()?.naive_eval(), &|arr, indexes| { let array = const_arrays.get(arr)?; assert_eq!(indexes.len(), array.depth()); - let idx = indexes.iter().map(|e| e.to_usize()).collect::>(); - array.navigate(&idx)?.as_scalar().map(F::from_usize) + array.navigate(&indexes)?.as_scalar() }, ) } @@ -379,92 +459,175 @@ impl Expression { } } - pub fn scalar(scalar: usize) -> Self { + pub fn inner_exprs_mut(&mut self) -> Vec<&mut Self> { + match self { + Self::Value(_) => vec![], + Self::ArrayAccess { index, .. } => index.iter_mut().collect(), + Self::MathExpr(_, args) => args.iter_mut().collect(), + Self::FunctionCall { args, .. } => args.iter_mut().collect(), + Self::Len { indices, .. } => indices.iter_mut().collect(), + } + } + + pub fn inner_exprs(&self) -> Vec<&Self> { + match self { + Self::Value(_) => vec![], + Self::ArrayAccess { index, .. } => index.iter().collect(), + Self::MathExpr(_, args) => args.iter().collect(), + Self::FunctionCall { args, .. } => args.iter().collect(), + Self::Len { indices, .. } => indices.iter().collect(), + } + } + + pub fn var(var: Var) -> Self { + SimpleExpr::from(var).into() + } + + pub fn scalar(scalar: F) -> Self { SimpleExpr::scalar(scalar).into() } + pub fn as_scalar(&self) -> Option { + match self { + Self::Value(SimpleExpr::Constant(ConstExpression::Value(ConstantValue::Scalar(start_val)))) => { + Some(*start_val) + } + _ => None, + } + } + + pub fn is_scalar(&self) -> bool { + self.as_scalar().is_some() + } + pub fn zero() -> Self { - Self::scalar(0) + Self::scalar(F::ZERO) + } + + pub fn one() -> Self { + Self::scalar(F::ONE) } } #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum AssignmentTarget { - Var(Var), - ArrayAccess { array: Var, index: Box }, + Var { var: Var, is_mutable: bool }, + ArrayAccess { array: Var, index: Box }, // always immutable } impl Display for AssignmentTarget { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { - Self::Var(var) => write!(f, "{var}"), + Self::Var { var, is_mutable } => { + if *is_mutable { + write!(f, "{var}: Mut") + } else { + write!(f, "{var}") + } + } Self::ArrayAccess { array, index } => write!(f, "{array}[{index}]"), } } } +impl AssignmentTarget { + pub fn index_expression_mut(&mut self) -> Option<&mut Expression> { + match self { + Self::Var { .. } => None, + Self::ArrayAccess { index, .. } => Some(index), + } + } +} + +/// A compile-time dynamic array literal: DynArray(elem1, elem2, ...) +/// Elements can be expressions or nested DynArray literals. +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum VecLiteral { + /// A scalar expression element + Expr(Expression), + /// A nested vector literal + Vec(Vec), +} + +impl VecLiteral { + pub fn all_exprs_mut_in_slice(arr: &mut [Self]) -> Vec<&mut Expression> { + let mut exprs = Vec::new(); + for elem in arr { + match elem { + Self::Expr(expr) => exprs.push(expr), + Self::Vec(nested) => { + exprs.extend(Self::all_exprs_mut_in_slice(nested)); + } + } + } + exprs + } +} + #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum Line { Match { value: Expression, arms: Vec<(usize, Vec)>, + location: SourceLocation, }, ForwardDeclaration { var: Var, + is_mutable: bool, }, Statement { targets: Vec, // LHS - can be empty for standalone calls value: Expression, // RHS - any expression - line_number: SourceLineNumber, + location: SourceLocation, }, Assert { debug: bool, boolean: BooleanExpr, - line_number: SourceLineNumber, + location: SourceLocation, }, IfCondition { condition: Condition, then_branch: Vec, else_branch: Vec, - line_number: SourceLineNumber, + location: SourceLocation, }, ForLoop { iterator: Var, start: Expression, end: Expression, body: Vec, - rev: bool, unroll: bool, - line_number: SourceLineNumber, + location: SourceLocation, }, FunctionRet { return_data: Vec, }, - Precompile { - table: Table, - args: Vec, + Panic { + message: Option, }, - Break, - Panic, - // Hints: - Print { - line_info: String, - content: Vec, + // noop, debug purpose only + LocationReport { + location: SourceLocation, }, - MAlloc { + /// Compile-time dynamic array declaration: var = DynArray(...) + VecDeclaration { var: Var, - size: Expression, - vectorized: bool, - vectorized_len: Expression, + elements: Vec, + location: SourceLocation, }, - PrivateInputStart { - result: Var, + /// Compile-time vector push: push(vec_var, element) or push(vec_var[i][j], element) + Push { + vector: Var, + indices: Vec, + element: VecLiteral, + location: SourceLocation, }, - // noop, debug purpose only - LocationReport { + /// Compile-time vector pop: vec_var.pop() or vec_var[i][j].pop() + Pop { + vector: Var, + indices: Vec, location: SourceLocation, }, - CustomHint(CustomHint, Vec), } /// A context specifying which variables are in scope. @@ -523,7 +686,9 @@ impl Display for Expression { let args_str = args.iter().map(|arg| format!("{arg}")).collect::>().join(", "); write!(f, "{math_expr}({args_str})") } - Self::FunctionCall { function_name, args } => { + Self::FunctionCall { + function_name, args, .. + } => { let args_str = args.iter().map(|arg| format!("{arg}")).collect::>().join(", "); write!(f, "{function_name}({args_str})") } @@ -543,7 +708,7 @@ impl Line { // print nothing Default::default() } - Self::Match { value, arms } => { + Self::Match { value, arms, .. } => { let arms_str = arms .iter() .map(|(const_expr, body)| { @@ -552,14 +717,18 @@ impl Line { .map(|line| line.to_string_with_indent(indent + 1)) .collect::>() .join("\n"); - format!("{const_expr} => {{\n{body_str}\n{spaces}}}") + format!("case {const_expr}: {{\n{body_str}\n{spaces}}}") }) .collect::>() .join("\n"); - format!("match {value} {{\n{arms_str}\n{spaces}}}") + format!("match {value}: {{\n{arms_str}\n{spaces}}}") } - Self::ForwardDeclaration { var } => { - format!("var {var}") + Self::ForwardDeclaration { var, is_mutable } => { + if *is_mutable { + format!("{var}: Mut") + } else { + format!("{var}: Imu") + } } Self::Statement { targets, value, .. } => { if targets.is_empty() { @@ -573,19 +742,16 @@ impl Line { format!("{targets_str} = {value}") } } - Self::PrivateInputStart { result } => { - format!("{result} = private_input_start()") - } Self::Assert { debug, boolean, - line_number: _, + location: _, } => format!("{}assert {}", if *debug { "debug_" } else { "" }, boolean), Self::IfCondition { condition, then_branch, else_branch, - line_number: _, + location: _, } => { let then_str = then_branch .iter() @@ -610,24 +776,18 @@ impl Line { start, end, body, - rev, unroll, - line_number: _, + location: _, } => { let body_str = body .iter() .map(|line| line.to_string_with_indent(indent + 1)) .collect::>() .join("\n"); + let range_fn = if *unroll { "unroll" } else { "range" }; format!( - "for {} in {}{}..{} {}{{\n{}\n{}}}", - iterator, - start, - if *rev { "rev " } else { "" }, - end, - if *unroll { "unroll " } else { "" }, - body_str, - spaces + "for {} in {}({}, {}) {{\n{}\n{}}}", + iterator, range_fn, start, end, body_str, spaces ) } Self::FunctionRet { return_data } => { @@ -638,50 +798,117 @@ impl Line { .join(", "); format!("return {return_data_str}") } - Self::Precompile { - table: precompile, - args, + Self::Panic { message } => match message { + Some(msg) => format!("assert False, \"{msg}\""), + None => "assert False".to_string(), + }, + Self::VecDeclaration { var, elements, .. } => { + format!("{var} = DynArray({})", elements.len()) + } + Self::Push { + vector, + indices, + element, + .. } => { format!( - "{}({})", - precompile.name(), - args.iter().map(|arg| format!("{arg}")).collect::>().join(", ") + "{}[{}].push({})", + vector, + indices.iter().map(|i| format!("{i}")).collect::>().join("]["), + element ) } - Self::Print { line_info: _, content } => { - let content_str = content.iter().map(|c| format!("{c}")).collect::>().join(", "); - format!("print({content_str})") - } - Self::MAlloc { - var, - size, - vectorized, - vectorized_len, - } => { - if *vectorized { - format!("{var} = malloc_vec({size}, {vectorized_len})") + Self::Pop { vector, indices, .. } => { + if indices.is_empty() { + format!("{}.pop()", vector) } else { - format!("{var} = malloc({size})") + format!( + "{}[{}].pop()", + vector, + indices.iter().map(|i| format!("{i}")).collect::>().join("][") + ) } } - Self::CustomHint(hint, args) => { - let args_str = args.iter().map(|arg| format!("{arg}")).collect::>().join(", "); - format!("{}({args_str})", hint.name()) - } - Self::Break => "break".to_string(), - Self::Panic => "panic".to_string(), }; format!("{spaces}{line_str}") } + + pub fn nested_blocks(&self) -> Vec<&Vec> { + match self { + Self::Match { arms, .. } => arms.iter().map(|(_, body)| body).collect(), + Self::IfCondition { + then_branch, + else_branch, + .. + } => vec![then_branch, else_branch], + Self::ForLoop { body, .. } => vec![body], + Self::ForwardDeclaration { .. } + | Self::Statement { .. } + | Self::Assert { .. } + | Self::FunctionRet { .. } + | Self::Panic { .. } + | Self::LocationReport { .. } + | Self::VecDeclaration { .. } + | Self::Push { .. } + | Self::Pop { .. } => vec![], + } + } + + pub fn nested_blocks_mut(&mut self) -> Vec<&mut Vec> { + match self { + Self::Match { arms, .. } => arms.iter_mut().map(|(_, body)| body).collect(), + Self::IfCondition { + then_branch, + else_branch, + .. + } => vec![then_branch, else_branch], + Self::ForLoop { body, .. } => vec![body], + Self::ForwardDeclaration { .. } + | Self::Statement { .. } + | Self::Assert { .. } + | Self::FunctionRet { .. } + | Self::Panic { .. } + | Self::LocationReport { .. } + | Self::VecDeclaration { .. } + | Self::Push { .. } + | Self::Pop { .. } => vec![], + } + } + + /// Returns mutable references to all expressions contained in this line. + /// Does NOT include expressions inside nested blocks (use nested_blocks_mut for those). + pub fn expressions_mut(&mut self) -> Vec<&mut Expression> { + match self { + Self::Match { value, .. } => vec![value], + Self::Statement { targets, value, .. } => { + let mut exprs = vec![value]; + for target in targets { + if let Some(idx) = target.index_expression_mut() { + exprs.push(idx); + } + } + exprs + } + Self::Assert { boolean, .. } => vec![&mut boolean.left, &mut boolean.right], + Self::IfCondition { condition, .. } => condition.expressions_mut(), + Self::ForLoop { start, end, .. } => vec![start, end], + Self::FunctionRet { return_data } => return_data.iter_mut().collect(), + Self::Push { indices, element, .. } => { + let mut exprs = indices.iter_mut().collect::>(); + exprs.extend(VecLiteral::all_exprs_mut_in_slice(std::slice::from_mut(element))); + exprs + } + Self::Pop { indices, .. } => indices.iter_mut().collect(), + Self::VecDeclaration { elements, .. } => VecLiteral::all_exprs_mut_in_slice(elements), + Self::ForwardDeclaration { .. } | Self::Panic { .. } | Self::LocationReport { .. } => vec![], + } + } } impl Display for ConstantValue { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { Self::Scalar(scalar) => write!(f, "{scalar}"), - Self::PublicInputStart => write!(f, "@public_input_start"), - Self::PointerToZeroVector => write!(f, "@pointer_to_zero_vector"), - Self::PointerToOneVector => write!(f, "@pointer_to_one_vector"), Self::FunctionSize { function_name } => { write!(f, "@function_size_{function_name}") } @@ -696,6 +923,22 @@ impl Display for ConstantValue { } } +impl Display for VecLiteral { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Self::Expr(expr) => write!(f, "{expr}"), + Self::Vec(elements) => { + let elements_str = elements + .iter() + .map(|elem| format!("{elem}")) + .collect::>() + .join(", "); + write!(f, "DynArray([{elements_str}])") + } + } + } +} + impl Display for SimpleExpr { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { @@ -765,9 +1008,14 @@ impl Display for Function { let args_str = self .arguments .iter() - .map(|arg| match arg { - (name, true) => format!("const {name}"), - (name, false) => name.to_string(), + .map(|arg| { + if arg.is_const { + format!("const {}", arg.name) + } else if arg.is_mutable { + format!("mut {}", arg.name) + } else { + arg.name.to_string() + } }) .collect::>() .join(", "); @@ -780,11 +1028,11 @@ impl Display for Function { .join("\n"); if self.body.is_empty() { - write!(f, "fn {}({}) -> {} {{}}", self.name, args_str, self.n_returned_vars) + write!(f, "def {}({}) -> {} {{}}", self.name, args_str, self.n_returned_vars) } else { write!( f, - "fn {}({}) -> {} {{\n{}\n}}", + "def {}({}) -> {} {{\n{}\n}}", self.name, args_str, self.n_returned_vars, instructions_str ) } diff --git a/crates/lean_compiler/src/lib.rs b/crates/lean_compiler/src/lib.rs index e7887969..a59e0d7c 100644 --- a/crates/lean_compiler/src/lib.rs +++ b/crates/lean_compiler/src/lib.rs @@ -1,4 +1,5 @@ use std::collections::BTreeMap; +use std::fmt; use lean_vm::*; @@ -14,6 +15,77 @@ pub mod ir; mod lang; mod parser; +pub use parser::{ParseError, RESERVED_FUNCTION_NAMES}; + +pub use lean_vm::RunnerError; + +#[derive(Debug)] +pub enum CompileError { + Parse(ParseError), + Compile(String), + Io(std::io::Error), +} + +impl fmt::Display for CompileError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Parse(e) => write!(f, "{e}"), + Self::Compile(e) => write!(f, "Compile error: {e}"), + Self::Io(e) => write!(f, "IO error: {e}"), + } + } +} + +impl std::error::Error for CompileError {} + +impl From for CompileError { + fn from(e: ParseError) -> Self { + Self::Parse(e) + } +} + +impl From for CompileError { + fn from(e: String) -> Self { + Self::Compile(e) + } +} + +impl From for CompileError { + fn from(e: std::io::Error) -> Self { + Self::Io(e) + } +} + +/// Error type for compile and run operations +#[derive(Debug)] +pub enum Error { + Compile(CompileError), + Runtime(RunnerError), +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Compile(e) => write!(f, "{e}"), + Self::Runtime(e) => write!(f, "Runtime error: {e}"), + } + } +} + +impl std::error::Error for Error {} + +impl From for Error { + fn from(e: CompileError) -> Self { + Self::Compile(e) + } +} + +impl From for Error { + fn from(e: RunnerError) -> Self { + Self::Runtime(e) + } +} + #[derive(Debug, Clone)] pub enum ProgramSource { Raw(String), @@ -21,7 +93,7 @@ pub enum ProgramSource { } impl ProgramSource { - pub fn get_content(&self, flags: &CompilationFlags) -> Result { + pub fn get_content(&self, flags: &CompilationFlags) -> Result { match self { ProgramSource::Raw(src) => { let mut result = src.clone(); @@ -31,7 +103,7 @@ impl ProgramSource { Ok(result) } ProgramSource::Filepath(fp) => { - let mut result = std::fs::read_to_string(fp)?; + let mut result = std::fs::read_to_string(fp).map_err(|e| format!("Failed to read file {fp}: {e}"))?; for (key, value) in flags.replacements.iter() { result = result.replace(key, value); } @@ -47,61 +119,43 @@ pub struct CompilationFlags { pub replacements: BTreeMap, } -pub fn compile_program_with_flags(input: &ProgramSource, flags: CompilationFlags) -> Bytecode { - let parsed_program = parse_program(input, flags).unwrap(); - // println!("Parsed program: {}", parsed_program.to_string()); +pub fn try_compile_program_with_flags( + input: &ProgramSource, + flags: CompilationFlags, +) -> Result { + let parsed_program = parse_program(input, flags)?; let function_locations = parsed_program.function_locations.clone(); let source_code = parsed_program.source_code.clone(); let filepaths = parsed_program.filepaths.clone(); - let simple_program = simplify_program(parsed_program); - // println!("Simplified program: {}", simple_program); - let intermediate_bytecode = compile_to_intermediate_bytecode(simple_program).unwrap(); - // println!("Intermediate Bytecode:\n\n{}", intermediate_bytecode.to_string()); + let simple_program = simplify_program(parsed_program)?; + let intermediate_bytecode = compile_to_intermediate_bytecode(simple_program)?; + let bytecode = compile_to_low_level_bytecode(intermediate_bytecode, function_locations, source_code, filepaths)?; + Ok(bytecode) +} + +pub fn compile_program_with_flags(input: &ProgramSource, flags: CompilationFlags) -> Bytecode { + try_compile_program_with_flags(input, flags).unwrap() +} - // println!("Function Locations: \n"); - // for (loc, name) in function_locations.iter() { - // println!("{name}: {loc}"); - // } - /* let compiled = */ - compile_to_low_level_bytecode(intermediate_bytecode, function_locations, source_code, filepaths).unwrap() // ; - // println!("\n\nCompiled Program:\n\n{compiled}"); - // compiled +pub fn try_compile_program(input: &ProgramSource) -> Result { + try_compile_program_with_flags(input, Default::default()) } pub fn compile_program(input: &ProgramSource) -> Bytecode { - compile_program_with_flags(input, CompilationFlags::default()) + try_compile_program(input).unwrap() } -pub fn compile_and_run( - source: &ProgramSource, +pub fn try_compile_and_run( + input: &ProgramSource, (public_input, private_input): (&[F], &[F]), - no_vec_runtime_memory: usize, // size of the "non-vectorized" runtime memory profiler: bool, -) { - let bytecode = compile_program(source); - let summary = execute_bytecode( - &bytecode, - (public_input, private_input), - no_vec_runtime_memory, - profiler, - (&vec![], &vec![]), - Default::default(), - ) - .summary; - println!("{summary}"); +) -> Result { + let bytecode = try_compile_program(input)?; + let result = try_execute_bytecode(&bytecode, (public_input, private_input), profiler, &vec![])?; + Ok(result.summary) } -#[derive(Debug, Clone, Default)] -struct Counter(usize); - -impl Counter { - const fn next(&mut self) -> usize { - let val = self.0; - self.0 += 1; - val - } - - const fn new() -> Self { - Self(0) - } +pub fn compile_and_run(input: &ProgramSource, (public_input, private_input): (&[F], &[F]), profiler: bool) { + let summary = try_compile_and_run(input, (public_input, private_input), profiler).unwrap(); + println!("{summary}"); } diff --git a/crates/lean_compiler/src/parser/error.rs b/crates/lean_compiler/src/parser/error.rs index 01191b04..f300273c 100644 --- a/crates/lean_compiler/src/parser/error.rs +++ b/crates/lean_compiler/src/parser/error.rs @@ -9,6 +9,9 @@ pub enum ParseError { /// High-level semantic validation error SemanticError(SemanticError), + + /// IO error (e.g., file not found) + IoError(std::io::Error), } /// Semantic errors that occur during AST construction and validation. @@ -56,6 +59,12 @@ impl From for ParseError { } } +impl From for ParseError { + fn from(error: std::io::Error) -> Self { + Self::IoError(error) + } +} + impl Display for SemanticError { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.message)?; @@ -71,6 +80,7 @@ impl Display for ParseError { match self { Self::SyntaxError(e) => write!(f, "Syntax error: {e}"), Self::SemanticError(e) => write!(f, "Semantic error: {e}"), + Self::IoError(e) => write!(f, "IO error: {e}"), } } } diff --git a/crates/lean_compiler/src/parser/lexer.rs b/crates/lean_compiler/src/parser/lexer.rs deleted file mode 100644 index 991c17e4..00000000 --- a/crates/lean_compiler/src/parser/lexer.rs +++ /dev/null @@ -1,51 +0,0 @@ -/// Preprocesses source code by removing comments and normalizing whitespace. -pub fn preprocess_source(input: &str) -> String { - input - .lines() - .map(|line| { - // Remove line comments (everything after //) - if let Some(pos) = line.find("//") { - line[..pos].trim_end() - } else { - line - } - }) - .collect::>() - .join("\n") -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_comment_removal() { - let input = r#" - let x = 5; // This is a comment - let y = 10; - // This whole line is a comment - let z = x + y; - "#; - - let expected = r#" - let x = 5; - let y = 10; - - let z = x + y; - "#; - - assert_eq!(preprocess_source(input), expected); - } - - #[test] - fn test_no_comments() { - let input = "let x = 5;\nlet y = 10;"; - assert_eq!(preprocess_source(input), input); - } - - #[test] - fn test_empty_lines_preserved() { - let input = "line1\n\nline3"; - assert_eq!(preprocess_source(input), input); - } -} diff --git a/crates/lean_compiler/src/parser/mod.rs b/crates/lean_compiler/src/parser/mod.rs index 8a0e1f2b..b69fb97d 100644 --- a/crates/lean_compiler/src/parser/mod.rs +++ b/crates/lean_compiler/src/parser/mod.rs @@ -2,8 +2,8 @@ mod error; mod grammar; -mod lexer; mod parsers; +pub use error::ParseError; pub use parsers::ConstArrayValue; -pub use parsers::program::parse_program; +pub use parsers::{function::RESERVED_FUNCTION_NAMES, program::parse_program}; diff --git a/crates/lean_compiler/src/parser/parsers/expression.rs b/crates/lean_compiler/src/parser/parsers/expression.rs index 28473e5f..987b9cd9 100644 --- a/crates/lean_compiler/src/parser/parsers/expression.rs +++ b/crates/lean_compiler/src/parser/parsers/expression.rs @@ -1,8 +1,11 @@ +use lean_vm::{F, SourceLocation}; +use multilinear_toolkit::prelude::*; + use super::literal::{VarOrConstantParser, evaluate_const_expr}; use super::{ConstArrayValue, Parse, ParseContext, next_inner_pair}; use crate::lang::MathOperation; use crate::{ - lang::{ConstExpression, ConstantValue, Expression, SimpleExpr}, + lang::{ConstExpression, ConstantValue, Expression, SimpleExpr, VecLiteral}, parser::{ error::{ParseResult, SemanticError}, grammar::{ParsePair, Rule}, @@ -45,6 +48,10 @@ impl Parse for MathOperation { let mut inner = pair.into_inner(); let mut expr = ExpressionParser.parse(next_inner_pair(&mut inner, "math expr left")?, ctx)?; + if self.is_unary() { + return Ok(Expression::MathExpr(*self, vec![expr])); + } + for right in inner { let right_expr = ExpressionParser.parse(right, ctx)?; expr = Expression::MathExpr(*self, vec![expr, right_expr]); @@ -58,6 +65,7 @@ pub struct FunctionCallExprParser; impl Parse for FunctionCallExprParser { fn parse(&self, pair: ParsePair<'_>, ctx: &mut ParseContext) -> ParseResult { + let line_number = pair.line_col().0; let mut inner = pair.into_inner(); let function_name = next_inner_pair(&mut inner, "function name")?.as_str().to_string(); @@ -70,7 +78,14 @@ impl Parse for FunctionCallExprParser { Vec::new() }; - Ok(Expression::FunctionCall { function_name, args }) + Ok(Expression::FunctionCall { + function_name, + args, + location: SourceLocation { + file_id: ctx.current_file_id, + line_number, + }, + }) } } @@ -90,7 +105,7 @@ impl Parse for ArrayAccessParser { } } -/// Parser for len() expressions on const arrays (supports indexed access like len(ARR[i])). +/// Parser for len() expressions on const arrays and vectors (supports indexed access like len(ARR[i])). pub struct LenParser; impl Parse for LenParser { @@ -104,67 +119,99 @@ impl Parse for LenParser { .as_str() .to_string(); - // Check if the array exists - if ctx.get_const_array(&ident).is_none() { - return Err(SemanticError::with_context( - format!("len() argument '{ident}' is not a const array"), - "len expression", - ) - .into()); - } - let mut index_exprs = Vec::new(); for index_pair in arg_inner { index_exprs.push(ExpressionParser.parse(index_pair, ctx)?); } - // Try to evaluate indices at parse time - let mut indices = Vec::new(); - let mut all_const = true; - for index_expr in &index_exprs { - if let Some(index_val) = evaluate_const_expr(index_expr, ctx) { - indices.push(index_val); - } else { - all_const = false; - break; + // Check if this is a const array - if so, try to evaluate at parse time + if let Some(base_array) = ctx.get_const_array(&ident) { + // Try to evaluate indices at parse time + let mut indices = Vec::new(); + let mut all_const = true; + for index_expr in &index_exprs { + if let Some(index_val) = evaluate_const_expr(index_expr, ctx) { + indices.push(index_val); + } else { + all_const = false; + break; + } + } + + // If all indices are constants, evaluate len() now + if all_const { + let target = if indices.is_empty() { + base_array + } else { + base_array.navigate(&indices).ok_or_else(|| { + SemanticError::with_context( + format!( + "len() index out of bounds for '{ident}': [{}]", + indices.iter().map(|i| i.to_string()).collect::>().join("][") + ), + "len expression", + ) + })? + }; + + let length = match target { + ConstArrayValue::Scalar(_) => { + return Err(SemanticError::with_context( + "Cannot call len() on a scalar value", + "len expression", + ) + .into()); + } + ConstArrayValue::Array(arr) => arr.len(), + }; + + return Ok(Expression::Value(SimpleExpr::Constant(ConstExpression::Value( + ConstantValue::Scalar(F::from_usize(length)), + )))); } } - // If all indices are constants, evaluate len() now - if all_const { - let base_array = ctx.get_const_array(&ident).unwrap(); - let target = if indices.is_empty() { - base_array - } else { - base_array.navigate(&indices).ok_or_else(|| { - SemanticError::with_context( - format!( - "len() index out of bounds for '{ident}': [{}]", - indices.iter().map(|i| i.to_string()).collect::>().join("][") - ), - "len expression", - ) - })? - }; - - let length = match target { - ConstArrayValue::Scalar(_) => { - return Err( - SemanticError::with_context("Cannot call len() on a scalar value", "len expression").into(), - ); - } - ConstArrayValue::Array(arr) => arr.len(), - }; + // Defer evaluation for non-const arrays (could be vectors) or non-const indices + Ok(Expression::Len { + array: ident, + indices: index_exprs, + }) + } +} - Ok(Expression::Value(SimpleExpr::Constant(ConstExpression::Value( - ConstantValue::Scalar(length), - )))) - } else { - // Defer evaluation - return Expression::Len - Ok(Expression::Len { - array: ident, - indices: index_exprs, - }) +/// Parser for vec![...] literals (compile-time vectors) +/// Parses into the VecLiteral enum (separate from Expression) +pub struct VecLiteralParser; + +impl Parse for VecLiteralParser { + fn parse(&self, pair: ParsePair<'_>, ctx: &mut ParseContext) -> ParseResult { + // vec_literal = { "vec!" ~ "[" ~ (vec_element ~ ("," ~ vec_element)*)? ~ "]" } + // vec_element = { vec_literal | expression } + let elements: Vec = pair + .into_inner() + .map(|elem_pair| VecElementParser.parse(elem_pair, ctx)) + .collect::, _>>()?; + + Ok(VecLiteral::Vec(elements)) + } +} + +/// Parser for vec element (either a nested vec_literal or an expression) +pub struct VecElementParser; + +impl Parse for VecElementParser { + fn parse(&self, pair: ParsePair<'_>, ctx: &mut ParseContext) -> ParseResult { + match pair.as_rule() { + Rule::vec_element => { + // vec_element contains either vec_literal or expression + let inner = next_inner_pair(&mut pair.into_inner(), "vec element")?; + match inner.as_rule() { + Rule::vec_literal => VecLiteralParser.parse(inner, ctx), + _ => Ok(VecLiteral::Expr(ExpressionParser.parse(inner, ctx)?)), + } + } + Rule::vec_literal => VecLiteralParser.parse(pair, ctx), + _ => Ok(VecLiteral::Expr(ExpressionParser.parse(pair, ctx)?)), } } } diff --git a/crates/lean_compiler/src/parser/parsers/function.rs b/crates/lean_compiler/src/parser/parsers/function.rs index ebfd7d63..40a47084 100644 --- a/crates/lean_compiler/src/parser/parsers/function.rs +++ b/crates/lean_compiler/src/parser/parsers/function.rs @@ -2,14 +2,46 @@ use super::expression::ExpressionParser; use super::statement::StatementParser; use super::{Parse, ParseContext, next_inner_pair}; use crate::{ - SourceLineNumber, - lang::{AssignmentTarget, Expression, Function, Line, SourceLocation}, + a_simplify_lang::VarOrConstMallocAccess, + lang::{AssignmentTarget, Expression, Function, FunctionArg, Line, MathOperation, SimpleExpr, SourceLocation}, parser::{ error::{ParseResult, SemanticError}, grammar::{ParsePair, Rule}, }, }; -use lean_vm::{ALL_TABLES, CustomHint, LOG_VECTOR_LEN, Table, TableT}; +use lean_vm::{ALL_TABLES, TableT}; + +/// Reserved function names that users cannot define. +pub const RESERVED_FUNCTION_NAMES: &[&str] = &[ + // Built-in functions + "print", + "Array", + "DynArray", + "push", // Compile-time vector push + // Compile-time only functions + "len", + "log2_ceil", + "next_multiple_of", + "saturating_sub", + // Custom hints (manually listed since CustomHint doesn't re-export strum iterator) + "hint_decompose_bits_xmss", + "hint_decompose_bits", +]; + +/// Check if a function name is reserved. +fn is_reserved_function_name(name: &str) -> bool { + // Check static reserved names + if RESERVED_FUNCTION_NAMES.contains(&name) { + return true; + } + // Check precompile names (poseidon16, dot_product, execution) + for table in ALL_TABLES { + if table.name() == name && !table.is_execution_table() { + return true; + } + } + false +} /// Parser for complete function definitions. pub struct FunctionParser; @@ -17,18 +49,29 @@ pub struct FunctionParser; impl Parse for FunctionParser { fn parse(&self, pair: ParsePair<'_>, ctx: &mut ParseContext) -> ParseResult { let mut inner = pair.into_inner().peekable(); - let assume_always_returns = match inner.peek().map(|x| x.as_rule()) { - Some(Rule::pragma) => { - inner.next(); - true + + // Parse optional @inline decorator + let inlined = match inner.peek().map(|x| x.as_rule()) { + Some(Rule::decorator) => { + let decorator = inner.next().unwrap(); + let decorator_name = decorator.into_inner().next().unwrap().as_str(); + if decorator_name == "inline" { + true + } else { + return Err(SemanticError::new(format!("Unknown decorator '@{decorator_name}'")).into()); + } } _ => false, }; + let name = next_inner_pair(&mut inner, "function name")?.as_str().to_string(); + // Check for reserved function names + if is_reserved_function_name(&name) { + return Err(SemanticError::new(format!("Cannot define function with reserved name '{name}'")).into()); + } + let mut arguments = Vec::new(); - let mut n_returned_vars = 0; - let mut inlined = false; let mut body = Vec::new(); for pair in inner { @@ -40,12 +83,6 @@ impl Parse for FunctionParser { } } } - Rule::inlined_statement => { - inlined = true; - } - Rule::return_count => { - n_returned_vars = ReturnCountParser.parse(pair, ctx)?; - } Rule::statement => { Self::add_statement_with_location(&mut body, pair, ctx)?; } @@ -53,14 +90,14 @@ impl Parse for FunctionParser { } } + let n_returned_vars = Self::infer_return_count(&name, &body)?; + Ok(Function { name, - file_id: ctx.current_file_id, arguments, inlined, n_returned_vars, body, - assume_always_returns, }) } } @@ -84,35 +121,62 @@ impl FunctionParser { Ok(()) } -} -/// Parser for function parameters. -pub struct ParameterParser; + /// Infer the number of return values from return statements in the function body. + /// All return statements must return the same number of values. + fn infer_return_count(func_name: &str, body: &[Line]) -> ParseResult { + let mut return_counts: Vec = Vec::new(); + Self::collect_return_counts(body, &mut return_counts); -impl Parse<(String, bool)> for ParameterParser { - fn parse(&self, pair: ParsePair<'_>, _ctx: &mut ParseContext) -> ParseResult<(String, bool)> { - let mut inner = pair.into_inner(); - let first = next_inner_pair(&mut inner, "parameter")?; + match return_counts.as_slice() { + [] => Err(SemanticError::new(format!("Function '{func_name}' has no return statements")).into()), + [first, rest @ ..] => { + if rest.iter().any(|&count| count != *first) { + return Err( + SemanticError::new(format!("Inconsistent return counts in function '{func_name}'")).into(), + ); + } + Ok(*first) + } + } + } - if first.as_rule() == Rule::const_keyword { - let identifier = next_inner_pair(&mut inner, "identifier after 'const'")?; - Ok((identifier.as_str().to_string(), true)) - } else { - Ok((first.as_str().to_string(), false)) + fn collect_return_counts(body: &[Line], counts: &mut Vec) { + for line in body { + if let Line::FunctionRet { return_data } = line { + counts.push(return_data.len()); + } + for block in line.nested_blocks() { + Self::collect_return_counts(block, counts); + } } } } -/// Parser for function return count declarations. -pub struct ReturnCountParser; +/// Parser for function parameters. +pub struct ParameterParser; -impl Parse for ReturnCountParser { - fn parse(&self, pair: ParsePair<'_>, ctx: &mut ParseContext) -> ParseResult { - let count_str = next_inner_pair(&mut pair.into_inner(), "return count")?.as_str(); +impl Parse for ParameterParser { + fn parse(&self, pair: ParsePair<'_>, _ctx: &mut ParseContext) -> ParseResult { + let mut inner = pair.into_inner(); + let name = next_inner_pair(&mut inner, "parameter name")?.as_str().to_string(); - ctx.get_constant(count_str) - .or_else(|| count_str.parse().ok()) - .ok_or_else(|| SemanticError::new("Invalid return count").into()) + // Check for optional type annotation (: Const or : Mut) + let (is_const, is_mutable) = if let Some(annotation) = inner.next() { + match annotation.as_str().trim() { + ": Const" => (true, false), + ": Mut" => (false, true), + other => return Err(SemanticError::new(format!("Invalid parameter annotation: {other}")).into()), + } + } else { + (false, false) + }; + + Ok(FunctionArg { + name, + is_const, + is_mutable, + }) } } @@ -121,11 +185,13 @@ pub struct AssignmentTargetParser; impl Parse for AssignmentTargetParser { fn parse(&self, pair: ParsePair<'_>, ctx: &mut ParseContext) -> ParseResult { - let inner = next_inner_pair(&mut pair.into_inner(), "assignment target")?; + let mut inner = pair.into_inner().peekable(); - match inner.as_rule() { + let first_pair = next_inner_pair(&mut inner, "assignment target")?; + + match first_pair.as_rule() { Rule::array_access_expr => { - let mut inner_pairs = inner.into_inner(); + let mut inner_pairs = first_pair.into_inner(); let array = next_inner_pair(&mut inner_pairs, "array name")?.as_str().to_string(); let index = ExpressionParser.parse(next_inner_pair(&mut inner_pairs, "array index")?, ctx)?; Ok(AssignmentTarget::ArrayAccess { @@ -133,7 +199,18 @@ impl Parse for AssignmentTargetParser { index: Box::new(index), }) } - Rule::identifier => Ok(AssignmentTarget::Var(inner.as_str().to_string())), + Rule::identifier => { + let var = first_pair.as_str().to_string(); + // Check for mut_annotation (: Mut) following the identifier + let is_mutable = inner + .peek() + .map(|p| p.as_rule() == Rule::mut_annotation) + .unwrap_or(false); + if is_mutable { + inner.next(); // consume the mut_annotation + } + Ok(AssignmentTarget::Var { var, is_mutable }) + } _ => Err(SemanticError::new("Expected identifier or array access").into()), } } @@ -146,37 +223,170 @@ impl Parse for AssignmentParser { let line_number = pair.line_col().0; let mut inner = pair.into_inner().peekable(); - // Check if there's an assignment_target_list (LHS) - let mut targets: Vec = Vec::new(); - if let Some(first) = inner.peek() + // Check if there's assignment_target_list and assign_op + let lhs_info = if let Some(first) = inner.peek() && first.as_rule() == Rule::assignment_target_list { - targets = inner - .next() - .unwrap() - .into_inner() - .map(|item| AssignmentTargetParser.parse(item, ctx)) - .collect::>>()?; - } + let target_list = inner.next().unwrap(); + let op_pair = next_inner_pair(&mut inner, "assignment operator")?; + Some(Self::parse_lhs(target_list, op_pair, ctx)?) + } else { + None + }; - // Parse the expression (RHS) + // Parse the RHS expression let expr_pair = next_inner_pair(&mut inner, "expression")?; - let expr = ExpressionParser.parse(expr_pair, ctx)?; + let rhs_expr = ExpressionParser.parse(expr_pair, ctx)?; + let location = SourceLocation { + file_id: ctx.current_file_id, + line_number, + }; + match lhs_info { + Some(LhsInfo::Compound { target, lhs_expr, op }) => { + // Desugar: target op= expr -> target = target op expr + let desugared_expr = Expression::MathExpr(op, vec![lhs_expr, rhs_expr]); + Ok(Line::Statement { + targets: vec![target], + value: desugared_expr, + location, + }) + } + Some(LhsInfo::Simple { mut targets }) => { + for target in &mut targets { + if let AssignmentTarget::Var { var, .. } = target + && var == "_" + { + *var = ctx.next_trash_var(); + } + } + Self::finalize_simple_assignment(location, targets, rhs_expr) + } + None => { + // No LHS - expression statement (e.g., function call) + Self::finalize_simple_assignment(location, Vec::new(), rhs_expr) + } + } + } +} - for target in &mut targets { - if let AssignmentTarget::Var(var) = target - && var == "_" - { - *var = ctx.next_trash_var(); +/// Parsed LHS information +enum LhsInfo { + Compound { + target: AssignmentTarget, + lhs_expr: Expression, + op: MathOperation, + }, + Simple { + targets: Vec, + }, +} + +impl AssignmentParser { + /// Parse assignment LHS (target list + operator) and return structured info + fn parse_lhs( + target_list_pair: ParsePair<'_>, + op_pair: ParsePair<'_>, + ctx: &mut ParseContext, + ) -> ParseResult { + let op_str = op_pair.as_str(); + + if op_str == "=" { + // Simple assignment - parse target list + let mut inner = target_list_pair.into_inner(); + let first = next_inner_pair(&mut inner, "assignment target")?; + + let targets = match first.as_rule() { + Rule::simple_target_list => first + .into_inner() + .map(|item| AssignmentTargetParser.parse(item, ctx)) + .collect::>>()?, + _ => return Err(SemanticError::new("Expected assignment target").into()), + }; + Ok(LhsInfo::Simple { targets }) + } else { + // Compound assignment - validate constraints + let mut outer = target_list_pair.into_inner(); + let inner_list = next_inner_pair(&mut outer, "assignment target")?; + + // Must be simple_target_list with exactly one target + let targets: Vec<_> = inner_list.into_inner().collect(); + + if targets.len() != 1 { + return Err(SemanticError::new("Compound assignment operators only allow a single target").into()); } + + let target_pair = targets.into_iter().next().unwrap(); + let (target, lhs_expr) = Self::parse_compound_target(target_pair, ctx)?; + + let op = match op_str { + "+=" => MathOperation::Add, + "-=" => MathOperation::Sub, + "*=" => MathOperation::Mul, + "/=" => MathOperation::Div, + _ => return Err(SemanticError::new("Invalid compound operator").into()), + }; + + Ok(LhsInfo::Compound { target, lhs_expr, op }) } + } - match &expr { - Expression::FunctionCall { function_name, args } => { - Self::handle_function_call(line_number, function_name.clone(), args.clone(), targets) + /// Parse a single target for compound assignment (no mut allowed) + fn parse_compound_target( + pair: ParsePair<'_>, + ctx: &mut ParseContext, + ) -> ParseResult<(AssignmentTarget, Expression)> { + let mut inner = pair.into_inner().peekable(); + + let target_inner = next_inner_pair(&mut inner, "assignment target")?; + + // Check for mut annotation (: Mut) - not allowed in compound assignment + if inner + .peek() + .map(|p| p.as_rule() == Rule::mut_annotation) + .unwrap_or(false) + { + return Err(SemanticError::new("Cannot use ': Mut' with compound assignment operators").into()); + } + + match target_inner.as_rule() { + Rule::array_access_expr => { + let mut arr_inner = target_inner.into_inner(); + let array = next_inner_pair(&mut arr_inner, "array name")?.as_str().to_string(); + let indices: Vec = arr_inner + .map(|idx_pair| ExpressionParser.parse(idx_pair, ctx)) + .collect::>>()?; + + let target = AssignmentTarget::ArrayAccess { + array: array.clone(), + index: Box::new(indices[0].clone()), + }; + let lhs_expr = Expression::ArrayAccess { array, index: indices }; + Ok((target, lhs_expr)) + } + Rule::identifier => { + let var = target_inner.as_str().to_string(); + let target = AssignmentTarget::Var { + var: var.clone(), + is_mutable: false, + }; + let lhs_expr = Expression::Value(SimpleExpr::Memory(VarOrConstMallocAccess::Var(var))); + Ok((target, lhs_expr)) } + _ => Err(SemanticError::new("Expected identifier or array access").into()), + } + } + + /// Finalize a simple assignment (handles function calls vs regular expressions) + fn finalize_simple_assignment( + location: SourceLocation, + targets: Vec, + expr: Expression, + ) -> ParseResult { + match &expr { + Expression::FunctionCall { + function_name, args, .. + } => Self::handle_function_call(location, function_name.clone(), args.clone(), targets), _ => { - // Non-function-call expression - must have exactly one target if targets.is_empty() { return Err(SemanticError::new("Expression statement has no effect").into()); } @@ -186,11 +396,10 @@ impl Parse for AssignmentParser { ) .into()); } - Ok(Line::Statement { targets, value: expr, - line_number, + location, }) } } @@ -199,108 +408,21 @@ impl Parse for AssignmentParser { impl AssignmentParser { fn handle_function_call( - line_number: SourceLineNumber, + location: SourceLocation, function_name: String, args: Vec, return_data: Vec, ) -> ParseResult { - // Helper to extract a single variable from return_data for builtins - let require_single_var = |return_data: &[AssignmentTarget], builtin_name: &str| -> ParseResult { - if return_data.len() != 1 { - return Err(SemanticError::new(format!("Invalid {builtin_name} call: expected 1 return value")).into()); - } - match &return_data[0] { - AssignmentTarget::Var(v) => Ok(v.clone()), - AssignmentTarget::ArrayAccess { .. } => Err(SemanticError::new(format!( - "{builtin_name} does not support array access as return target" - )) - .into()), - } - }; - - match function_name.as_str() { - "malloc" => { - if args.len() != 1 { - return Err(SemanticError::new("Invalid malloc call").into()); - } - let var = require_single_var(&return_data, "malloc")?; - Ok(Line::MAlloc { - var, - size: args[0].clone(), - vectorized: false, - vectorized_len: Expression::zero(), - }) - } - "malloc_vec" => { - let vectorized_len = if args.len() == 1 { - Expression::scalar(LOG_VECTOR_LEN) - } else if args.len() == 2 { - args[1].clone() - } else { - return Err(SemanticError::new("Invalid malloc_vec call").into()); - }; - let var = require_single_var(&return_data, "malloc")?; - Ok(Line::MAlloc { - var, - size: args[0].clone(), - vectorized: true, - vectorized_len, - }) - } - "print" => { - if !return_data.is_empty() { - return Err(SemanticError::new("Print function should not return values").into()); - } - Ok(Line::Print { - line_info: function_name.clone(), - content: args, - }) - } - "private_input_start" => { - if !args.is_empty() { - return Err(SemanticError::new("Invalid private_input_start call").into()); - } - let result = require_single_var(&return_data, "private_input_start")?; - Ok(Line::PrivateInputStart { result }) - } - "panic" => { - if !return_data.is_empty() || !args.is_empty() { - return Err(SemanticError::new("Panic has no args and returns no values").into()); - } - Ok(Line::Panic) - } - _ => { - // Check for special precompile functions - if let Some(table) = ALL_TABLES.into_iter().find(|p| p.name() == function_name) - && table != Table::execution() - { - return Ok(Line::Precompile { table, args }); - } - - // Check for custom hint - if let Some(hint) = CustomHint::find_by_name(&function_name) { - if !return_data.is_empty() { - return Err(SemanticError::new(format!( - "Custom hint: \"{function_name}\" should not return values", - )) - .into()); - } - if !hint.n_args_range().contains(&args.len()) { - return Err(SemanticError::new(format!( - "Custom hint: \"{function_name}\" : invalid number of arguments", - )) - .into()); - } - return Ok(Line::CustomHint(hint, args)); - } - // Regular function call - allow array access targets - Ok(Line::Statement { - targets: return_data, - value: Expression::FunctionCall { function_name, args }, - line_number, - }) - } - } + // Function calls (print, precompiles, custom hints) are handled in a_simplify_lang.rs + Ok(Line::Statement { + targets: return_data, + value: Expression::FunctionCall { + function_name, + args, + location, + }, + location, + }) } } diff --git a/crates/lean_compiler/src/parser/parsers/literal.rs b/crates/lean_compiler/src/parser/parsers/literal.rs index 2c63a6e0..5bcdb0f3 100644 --- a/crates/lean_compiler/src/parser/parsers/literal.rs +++ b/crates/lean_compiler/src/parser/parsers/literal.rs @@ -1,3 +1,6 @@ +use lean_vm::{NONRESERVED_PROGRAM_INPUT_START, ONE_VEC_PTR, PRIVATE_INPUT_START_PTR, ZERO_VEC_PTR}; +use multilinear_toolkit::prelude::*; + use super::expression::ExpressionParser; use super::{ConstArrayValue, Parse, ParseContext, ParsedConstant, next_inner_pair}; use crate::a_simplify_lang::VarOrConstMallocAccess; @@ -9,8 +12,6 @@ use crate::{ grammar::{ParsePair, Rule}, }, }; -use multilinear_toolkit::prelude::*; -use utils::ToUsize; /// Parser for constant declarations. pub struct ConstantDeclarationParser; @@ -31,7 +32,10 @@ impl Parse<(String, ParsedConstant)> for ConstantDeclarationParser { let expr = ExpressionParser.parse(value_pair, ctx)?; let value = evaluate_const_expr(&expr, ctx).ok_or_else(|| { - SemanticError::with_context(format!("Failed to evaluate constant: {name}"), "constant declaration") + SemanticError::with_context( + format!("Failed to evaluate constant: {name}, with expression: {}", expr), + "constant declaration", + ) })?; Ok((name, ParsedConstant::Scalar(value))) @@ -87,21 +91,19 @@ fn parse_array_literal(pair: ParsePair<'_>, ctx: &mut ParseContext, const_name: } /// Evaluate a const expression to a usize value at parse time. -pub fn evaluate_const_expr(expr: &crate::lang::Expression, ctx: &ParseContext) -> Option { +pub fn evaluate_const_expr(expr: &crate::lang::Expression, ctx: &ParseContext) -> Option { expr.eval_with( &|simple_expr| match simple_expr { SimpleExpr::Constant(cst) => cst.naive_eval(), - SimpleExpr::Memory(VarOrConstMallocAccess::Var(var)) => ctx.get_constant(var).map(F::from_usize), + SimpleExpr::Memory(VarOrConstMallocAccess::Var(var)) => ctx.get_constant(var), SimpleExpr::Memory(VarOrConstMallocAccess::ConstMallocAccess { .. }) => None, }, &|arr, index| { // Support const array access in expressions - let idx = index.iter().map(|e| e.to_usize()).collect::>(); let array = ctx.get_const_array(arr)?; - array.navigate(&idx)?.as_scalar().map(F::from_usize) + array.navigate(&index)?.as_scalar() }, ) - .map(|f| f.to_usize()) } /// Parser for variable or constant references. @@ -126,15 +128,12 @@ impl VarOrConstantParser { fn parse_identifier_or_constant(text: &str, ctx: &ParseContext) -> ParseResult { match text { // Special built-in constants - "public_input_start" => Ok(SimpleExpr::Constant(ConstExpression::Value( - ConstantValue::PublicInputStart, - ))), - "pointer_to_zero_vector" => Ok(SimpleExpr::Constant(ConstExpression::Value( - ConstantValue::PointerToZeroVector, - ))), - "pointer_to_one_vector" => Ok(SimpleExpr::Constant(ConstExpression::Value( - ConstantValue::PointerToOneVector, + "NONRESERVED_PROGRAM_INPUT_START" => Ok(SimpleExpr::Constant(ConstExpression::from( + NONRESERVED_PROGRAM_INPUT_START, ))), + "PRIVATE_INPUT_START_PTR" => Ok(SimpleExpr::Constant(ConstExpression::from(PRIVATE_INPUT_START_PTR))), + "ZERO_VEC_PTR" => Ok(SimpleExpr::Constant(ConstExpression::from(ZERO_VEC_PTR))), + "ONE_VEC_PTR" => Ok(SimpleExpr::Constant(ConstExpression::from(ONE_VEC_PTR))), _ => { // Check if it's a const array (error case - can't use array as value) if ctx.get_const_array(text).is_some() { @@ -154,7 +153,7 @@ impl VarOrConstantParser { // Try to parse as numeric literal else if let Ok(value) = text.parse::() { Ok(SimpleExpr::Constant(ConstExpression::Value(ConstantValue::Scalar( - value, + F::from_usize(value), )))) } // Otherwise treat as variable reference @@ -169,22 +168,23 @@ impl VarOrConstantParser { /// Parser for constant expressions used in match patterns. pub struct ConstExprParser; -impl Parse for ConstExprParser { - fn parse(&self, pair: ParsePair<'_>, ctx: &mut ParseContext) -> ParseResult { +impl Parse for ConstExprParser { + fn parse(&self, pair: ParsePair<'_>, ctx: &mut ParseContext) -> ParseResult { let inner = pair.into_inner().next().unwrap(); match inner.as_rule() { Rule::constant_value => { let text = inner.as_str(); match text { - "public_input_start" => { - Err(SemanticError::new("public_input_start cannot be used as match pattern").into()) - } + "NONRESERVED_PROGRAM_INPUT_START" => Err(SemanticError::new( + "NONRESERVED_PROGRAM_INPUT_START cannot be used as match pattern", + ) + .into()), _ => { if let Some(value) = ctx.get_constant(text) { Ok(value) } else if let Ok(value) = text.parse::() { - Ok(value) + Ok(F::from_usize(value)) } else { Err(SemanticError::with_context( format!("Invalid constant expression in match pattern: {text}"), diff --git a/crates/lean_compiler/src/parser/parsers/mod.rs b/crates/lean_compiler/src/parser/parsers/mod.rs index bc9e0556..bc87f5e1 100644 --- a/crates/lean_compiler/src/parser/parsers/mod.rs +++ b/crates/lean_compiler/src/parser/parsers/mod.rs @@ -1,3 +1,6 @@ +use lean_vm::F; +use utils::ToUsize; + use crate::lang::FileId; use crate::parser::{ error::{ParseResult, SemanticError}, @@ -16,7 +19,7 @@ pub mod statement; /// Supports arbitrary nesting: `[[1, 2], [3, 4, 5], []]` #[derive(Debug, Clone, PartialEq, Eq)] pub enum ConstArrayValue { - Scalar(usize), + Scalar(F), Array(Vec), } @@ -48,17 +51,17 @@ impl ConstArrayValue { } } - pub fn as_scalar(&self) -> Option { + pub fn as_scalar(&self) -> Option { match self { Self::Scalar(v) => Some(*v), Self::Array(_) => None, } } - pub fn navigate(&self, indices: &[usize]) -> Option<&Self> { + pub fn navigate(&self, indices: &[F]) -> Option<&Self> { let mut current = self; for &idx in indices { - current = current.get(idx)?; + current = current.get(idx.to_usize())?; } Some(current) } @@ -67,7 +70,7 @@ impl ConstArrayValue { /// Represents a parsed constant value (scalar or array). #[derive(Debug, Clone)] pub enum ParsedConstant { - Scalar(usize), + Scalar(F), Array(ConstArrayValue), } @@ -75,7 +78,7 @@ pub enum ParsedConstant { #[derive(Debug)] pub struct ParseContext { /// Compile-time scalar constants defined in the program - pub constants: BTreeMap, + pub constants: BTreeMap, /// Compile-time array constants defined in the program (supports nested arrays) pub const_arrays: BTreeMap, /// Counter for generating unique trash variable names @@ -88,6 +91,10 @@ pub struct ParseContext { pub current_file_id: FileId, /// Absolute filepaths imported so far (also includes the root filepath) pub imported_filepaths: BTreeSet, + /// Stack of files currently being imported (for circular import detection) + pub import_stack: Vec, + /// Root directory for resolving imports (directory of the entry point file) + pub import_root: String, /// Next unused file ID pub next_file_id: usize, /// Compilation flags @@ -101,6 +108,11 @@ impl ParseContext { ProgramSource::Raw(_) => ("".to_string(), BTreeSet::new()), ProgramSource::Filepath(fp) => (fp.clone(), [fp.clone()].into_iter().collect()), }; + let import_stack = vec![current_filepath.clone()]; + let import_root = std::path::Path::new(¤t_filepath) + .parent() + .map(|p| p.to_string_lossy().to_string()) + .unwrap_or_default(); Ok(Self { constants: BTreeMap::new(), const_arrays: BTreeMap::new(), @@ -108,6 +120,8 @@ impl ParseContext { current_filepath, current_file_id: 0, imported_filepaths, + import_stack, + import_root, current_source_code, next_file_id: 1, flags, @@ -115,7 +129,7 @@ impl ParseContext { } /// Adds a scalar constant to the context. - pub fn add_constant(&mut self, name: String, value: usize) -> Result<(), SemanticError> { + pub fn add_constant(&mut self, name: String, value: F) -> Result<(), SemanticError> { if self.constants.contains_key(&name) || self.const_arrays.contains_key(&name) { Err(SemanticError::with_context( format!("Defined multiple times: {name}"), @@ -141,7 +155,7 @@ impl ParseContext { } /// Looks up a scalar constant value. - pub fn get_constant(&self, name: &str) -> Option { + pub fn get_constant(&self, name: &str) -> Option { self.constants.get(name).copied() } diff --git a/crates/lean_compiler/src/parser/parsers/program.rs b/crates/lean_compiler/src/parser/parsers/program.rs index 7bddf0f4..49097691 100644 --- a/crates/lean_compiler/src/parser/parsers/program.rs +++ b/crates/lean_compiler/src/parser/parsers/program.rs @@ -6,7 +6,6 @@ use crate::{ parser::{ error::{ParseError, ParseResult, SemanticError}, grammar::{ParsePair, Rule, parse_source}, - lexer, parsers::{Parse, ParseContext, ParsedConstant, next_inner_pair}, }, }; @@ -40,31 +39,46 @@ impl Parse for ProgramParser { // Visit the imported file and parse it into the context // and program; also keep track of which files have been // imported and do not import the same file twice. + // Imports are resolved from the import root (entry point directory), + // matching Python's behavior with PYTHONPATH. let filepath = ImportStatementParser.parse(item, ctx)?; - let filepath = Path::new(&ctx.current_filepath) - .parent() - .expect("Empty filepath") + let filepath = Path::new(&ctx.import_root) .join(filepath) .to_str() .expect("Invalid UTF-8 in filepath") .to_string(); + + // Check for circular imports + if ctx.import_stack.contains(&filepath) { + let cycle: Vec<_> = ctx + .import_stack + .iter() + .skip_while(|p| *p != &filepath) + .cloned() + .collect(); + return Err(SemanticError::new(format!( + "Circular import detected: {} -> {}", + cycle.join(" -> "), + filepath + )) + .into()); + } + if !ctx.imported_filepaths.contains(&filepath) { let saved_filepath = ctx.current_filepath.clone(); let saved_file_id = ctx.current_file_id; ctx.current_filepath = filepath.clone(); ctx.imported_filepaths.insert(filepath.clone()); - ctx.current_source_code = ProgramSource::Filepath(filepath.clone()) - .get_content(&ctx.flags) - .unwrap(); + ctx.import_stack.push(filepath.clone()); + ctx.current_source_code = ProgramSource::Filepath(filepath).get_content(&ctx.flags)?; let subprogram = parse_program_helper(ctx)?; + ctx.import_stack.pop(); functions.extend(subprogram.functions); function_locations.extend(subprogram.function_locations); source_code.extend(subprogram.source_code); filepaths.extend(subprogram.filepaths); ctx.current_filepath = saved_filepath; ctx.current_file_id = saved_file_id; - // It is unnecessary to save and restore current_source_code because it will not - // be referenced again for the same file. } } Rule::function => { @@ -104,37 +118,168 @@ pub struct ImportStatementParser; impl Parse for ImportStatementParser { fn parse(&self, pair: ParsePair<'_>, _ctx: &mut ParseContext) -> ParseResult { let mut inner = pair.into_inner(); - let item = next_inner_pair(&mut inner, "filepath")?; + let item = next_inner_pair(&mut inner, "module_path")?; match item.as_rule() { - Rule::filepath => { - let inner = item.into_inner(); - let mut filepath = String::new(); - for item in inner { - match item.as_rule() { - Rule::filepath_character => { - filepath.push_str(item.as_str()); - } - _ => { - return Err(SemanticError::with_context( - format!("Expected a filepath character, got: {}", item.as_str()), - "filepath character", - ) - .into()); - } - } - } + Rule::module_path => { + let parts: Vec<&str> = item.into_inner().map(|p| p.as_str()).collect(); + // Convert module.path to path/to/file.py + let filepath = format!("{}.py", parts.join("/")); Ok(filepath) } - _ => Err( - SemanticError::with_context(format!("Expected a filepath, got: {}", item.as_str()), "filepath").into(), - ), + _ => Err(SemanticError::with_context( + format!("Expected a module path, got: {}", item.as_str()), + "module_path", + ) + .into()), + } + } +} + +pub fn remove_comments(input: &str) -> String { + let mut s = input; + let mut result = String::with_capacity(input.len()); + while !s.is_empty() { + // Handle # line comments (but not #![...] pragmas) + if s.starts_with('#') && !s.starts_with("#![") { + s = s.find('\n').map_or("", |i| &s[i..]); + // Handle """ block comments + } else if let Some(rest) = s.strip_prefix("\"\"\"") { + s = rest.find("\"\"\"").map_or("", |i| &rest[i + 3..]); + // Find next potential comment start + } else if let Some(i) = s[1..].find(['#', '"']) { + result.push_str(&s[..i + 1]); + s = &s[i + 1..]; + } else { + result.push_str(s); + break; + } + } + result +} + +/// Removes the snark_lib import if it's on the first line. +/// This import is only used for Python execution compatibility and is not relevant to the zkDSL. +pub fn remove_snark_lib_import(input: &str) -> String { + let first_line = input.lines().next().unwrap_or(""); + let trimmed = first_line.trim(); + let is_snark_lib_import = + (trimmed.starts_with("import ") || trimmed.starts_with("from ")) && trimmed.contains("snark_lib"); + if is_snark_lib_import { + input + .strip_prefix(first_line) + .unwrap_or(input) + .trim_start_matches('\n') + .to_string() + } else { + input.to_string() + } +} + +/// Preprocesses Python-like indentation syntax into explicit block markers. +/// Handles line continuations (`\`) and implicit continuation inside parentheses/brackets/braces. +/// Converts indentation-based blocks to markers. +pub fn preprocess_indentation(input: &str) -> Result { + let mut result = String::with_capacity(input.len() * 2); + let mut indent_stack: Vec = vec![0]; + + // First, collect logical lines by joining continued lines + // Continuation happens with `\` or when inside unclosed parentheses/brackets/braces + let mut logical_lines: Vec<(usize, String)> = Vec::new(); // (starting line number, content) + let mut current_logical_line = String::new(); + let mut logical_line_start = 1; + let mut paren_depth = 0i32; // tracks (), [], {} + + for (i, line) in input.lines().enumerate() { + let line_number = i + 1; + let trimmed = line.trim_end(); + + if current_logical_line.is_empty() { + logical_line_start = line_number; + } + + // Count parentheses/brackets/braces in this line + for c in trimmed.chars() { + match c { + '(' | '[' | '{' => paren_depth += 1, + ')' | ']' | '}' => paren_depth -= 1, + _ => {} + } + } + + // Check for explicit line continuation with `\` + if let Some(without_backslash) = trimmed.strip_suffix('\\') { + current_logical_line.push_str(without_backslash.trim_end()); + current_logical_line.push(' '); + } else if paren_depth > 0 { + // Implicit continuation: inside unclosed parens/brackets/braces + current_logical_line.push_str(trimmed); + current_logical_line.push(' '); + } else { + current_logical_line.push_str(line); + logical_lines.push((logical_line_start, std::mem::take(&mut current_logical_line))); + } + } + // Handle any remaining content (file ending with `\` or unclosed parens) + if !current_logical_line.is_empty() { + logical_lines.push((logical_line_start, current_logical_line)); + } + + // Process each logical line + for (line_number, line) in logical_lines { + let indent = line + .chars() + .take_while(|c| *c == ' ' || *c == '\t') + .map(|c| if c == '\t' { 4 } else { 1 }) + .sum::(); + + let trimmed = line.trim(); + + if trimmed.is_empty() { + continue; + } + + let current_indent = *indent_stack.last().unwrap(); + + if indent > current_indent { + return Err(ParseError::from(format!( + "Unexpected indentation at line {line_number}: expected {current_indent} spaces, got {indent}" + ))); + } + + if indent < current_indent { + while indent_stack.len() > 1 && indent < *indent_stack.last().unwrap() { + indent_stack.pop(); + result.push_str(""); + } + if indent != *indent_stack.last().unwrap() { + return Err(ParseError::from(format!( + "Invalid indentation at line {line_number}: got {indent} spaces, which doesn't match any block level" + ))); + } + } + + result.push_str(trimmed); + result.push_str(""); + + // Handle indent (open block after colon) + if trimmed.ends_with(':') && !trimmed.starts_with("import") { + indent_stack.push(indent + 4); // expect indented block } } + + // Close any remaining open blocks + while indent_stack.len() > 1 { + indent_stack.pop(); + result.push_str(""); + } + + Ok(result) } fn parse_program_helper(ctx: &mut ParseContext) -> Result { - // Preprocess source to remove comments - let processed_input = lexer::preprocess_source(&ctx.current_source_code); + let without_snark_lib_import = remove_snark_lib_import(&ctx.current_source_code); + let without_comments = remove_comments(&without_snark_lib_import); + let processed_input = preprocess_indentation(&without_comments)?; // Parse grammar into AST nodes let program_pair = parse_source(&processed_input)?; diff --git a/crates/lean_compiler/src/parser/parsers/statement.rs b/crates/lean_compiler/src/parser/parsers/statement.rs index 0052a25e..565c39c8 100644 --- a/crates/lean_compiler/src/parser/parsers/statement.rs +++ b/crates/lean_compiler/src/parser/parsers/statement.rs @@ -1,12 +1,13 @@ use lean_vm::{Boolean, BooleanExpr}; +use utils::ToUsize; -use super::expression::ExpressionParser; +use super::expression::{ExpressionParser, VecElementParser, VecLiteralParser}; use super::function::{AssignmentParser, TupleExpressionParser}; use super::literal::ConstExprParser; use super::{Parse, ParseContext, next_inner_pair}; use crate::{ SourceLineNumber, - lang::{Condition, Expression, Line, SourceLocation}, + lang::{Condition, Expression, Line, SourceLocation, VecLiteral}, parser::{ error::{ParseResult, SemanticError}, grammar::{ParsePair, Rule}, @@ -18,35 +19,34 @@ pub struct StatementParser; impl Parse for StatementParser { fn parse(&self, pair: ParsePair<'_>, ctx: &mut ParseContext) -> ParseResult { - let inner = next_inner_pair(&mut pair.into_inner(), "statement body")?; + let mut inner_iter = pair.into_inner(); + let inner = next_inner_pair(&mut inner_iter, "statement body")?; match inner.as_rule() { - Rule::forward_declaration => ForwardDeclarationParser.parse(inner, ctx), - Rule::assignment => AssignmentParser.parse(inner, ctx), + // Compound statements (have their own block structure) Rule::if_statement => IfStatementParser.parse(inner, ctx), Rule::for_statement => ForStatementParser.parse(inner, ctx), Rule::match_statement => MatchStatementParser.parse(inner, ctx), - Rule::return_statement => ReturnStatementParser.parse(inner, ctx), - Rule::assert_statement => AssertParser::.parse(inner, ctx), - Rule::debug_assert_statement => AssertParser::.parse(inner, ctx), - Rule::break_statement => Ok(Line::Break), - Rule::continue_statement => Err(SemanticError::new("Continue statement not implemented yet").into()), + // Simple statements (wrapped in simple_statement rule) + Rule::simple_statement => { + let simple_inner = next_inner_pair(&mut inner.into_inner(), "simple statement body")?; + match simple_inner.as_rule() { + Rule::forward_declaration => ForwardDeclarationParser.parse(simple_inner, ctx), + Rule::assignment => AssignmentParser.parse(simple_inner, ctx), + Rule::return_statement => ReturnStatementParser.parse(simple_inner, ctx), + Rule::assert_statement => AssertParser::.parse(simple_inner, ctx), + Rule::debug_assert_statement => AssertParser::.parse(simple_inner, ctx), + Rule::vec_declaration => VecDeclarationParser.parse(simple_inner, ctx), + Rule::push_statement => PushStatementParser.parse(simple_inner, ctx), + Rule::pop_statement => PopStatementParser.parse(simple_inner, ctx), + _ => Err(SemanticError::new("Unknown simple statement").into()), + } + } _ => Err(SemanticError::new("Unknown statement").into()), } } } -/// Parser for forward declarations of variables. -pub struct ForwardDeclarationParser; - -impl Parse for ForwardDeclarationParser { - fn parse(&self, pair: ParsePair<'_>, _ctx: &mut ParseContext) -> ParseResult { - let mut inner = pair.into_inner(); - let var = next_inner_pair(&mut inner, "variable name")?.as_str().to_string(); - Ok(Line::ForwardDeclaration { var }) - } -} - /// Parser for if-else conditional statements. pub struct IfStatementParser; @@ -72,7 +72,9 @@ impl Parse for IfStatementParser { ConditionParser.parse(next_inner_pair(&mut inner, "else if condition")?, ctx)?; let mut else_if_branch = Vec::new(); for else_if_item in inner { - Self::add_statement_with_location(&mut else_if_branch, else_if_item, ctx)?; + if else_if_item.as_rule() == Rule::statement { + Self::add_statement_with_location(&mut else_if_branch, else_if_item, ctx)?; + } } else_if_branches.push((else_if_condition, else_if_branch, line_number)); } @@ -95,7 +97,10 @@ impl Parse for IfStatementParser { condition: else_if_condition, then_branch: else_if_branch, else_branch: Vec::new(), - line_number, + location: SourceLocation { + file_id: ctx.current_file_id, + line_number, + }, }); inner_else_branch = match &mut inner_else_branch[0] { Line::IfCondition { else_branch, .. } => else_branch, @@ -109,7 +114,10 @@ impl Parse for IfStatementParser { condition, then_branch, else_branch: outer_else_branch, - line_number, + location: SourceLocation { + file_id: ctx.current_file_id, + line_number, + }, }) } } @@ -168,6 +176,7 @@ impl ComparisonParser { "==" => Boolean::Equal, "!=" => Boolean::Different, "<" => Boolean::LessThan, + "<=" => Boolean::LessOrEqual, _ => unreachable!(), }; @@ -184,30 +193,18 @@ impl Parse for ForStatementParser { let mut inner = pair.into_inner(); let iterator = next_inner_pair(&mut inner, "loop iterator")?.as_str().to_string(); - // Check for optional reverse clause - let mut rev = false; - if let Some(next_peek) = inner.clone().next() - && next_peek.as_rule() == Rule::rev_clause - { - rev = true; - inner.next(); // Consume the rev clause - } + // Next is either range or unroll_range + let range_pair = next_inner_pair(&mut inner, "range expression")?; + let unroll = matches!(range_pair.as_rule(), Rule::unroll_range); - let start = ExpressionParser.parse(next_inner_pair(&mut inner, "loop start")?, ctx)?; - let end = ExpressionParser.parse(next_inner_pair(&mut inner, "loop end")?, ctx)?; + let mut range_inner = range_pair.into_inner(); + let start = ExpressionParser.parse(next_inner_pair(&mut range_inner, "loop start")?, ctx)?; + let end = ExpressionParser.parse(next_inner_pair(&mut range_inner, "loop end")?, ctx)?; - let mut unroll = false; let mut body = Vec::new(); - for item in inner { - match item.as_rule() { - Rule::unroll_clause => { - unroll = true; - } - Rule::statement => { - Self::add_statement_with_location(&mut body, item, ctx)?; - } - _ => {} + if item.as_rule() == Rule::statement { + Self::add_statement_with_location(&mut body, item, ctx)?; } } @@ -216,9 +213,11 @@ impl Parse for ForStatementParser { start, end, body, - rev, unroll, - line_number, + location: SourceLocation { + file_id: ctx.current_file_id, + line_number, + }, }) } } @@ -249,6 +248,7 @@ pub struct MatchStatementParser; impl Parse for MatchStatementParser { fn parse(&self, pair: ParsePair<'_>, ctx: &mut ParseContext) -> ParseResult { + let line_number = pair.line_col().0; let mut inner = pair.into_inner(); let value = ExpressionParser.parse(next_inner_pair(&mut inner, "match value")?, ctx)?; @@ -258,7 +258,7 @@ impl Parse for MatchStatementParser { if arm_pair.as_rule() == Rule::match_arm { let mut arm_inner = arm_pair.into_inner(); let const_expr = next_inner_pair(&mut arm_inner, "match pattern")?; - let pattern = ConstExprParser.parse(const_expr, ctx)?; + let pattern = ConstExprParser.parse(const_expr, ctx)?.to_usize(); let mut statements = Vec::new(); for stmt in arm_inner { @@ -270,8 +270,11 @@ impl Parse for MatchStatementParser { arms.push((pattern, statements)); } } - - Ok(Line::Match { value, arms }) + let location = SourceLocation { + file_id: ctx.current_file_id, + line_number, + }; + Ok(Line::Match { value, arms, location }) } } @@ -319,13 +322,157 @@ pub struct AssertParser; impl Parse for AssertParser { fn parse(&self, pair: ParsePair<'_>, ctx: &mut ParseContext) -> ParseResult { let line_number = pair.line_col().0; - let comparison = next_inner_pair(&mut pair.into_inner(), "comparison")?; - let boolean = ComparisonParser::parse(comparison, ctx)?; + let mut inner = pair.into_inner(); + // Skip the assert_keyword / debug_assert_keyword + let _ = next_inner_pair(&mut inner, "assert keyword")?; + let next = next_inner_pair(&mut inner, "comparison or assert_false")?; + + match next.as_rule() { + Rule::assert_false => { + // assert False or assert False, "message" + let mut false_inner = next.into_inner(); + let message = false_inner.next().map(|s| { + let text = s.as_str(); + // Strip the quotes from the string literal + text[1..text.len() - 1].to_string() + }); + Ok(Line::Panic { message }) + } + Rule::comparison => { + let boolean = ComparisonParser::parse(next, ctx)?; + Ok(Line::Assert { + debug: DEBUG, + boolean, + location: SourceLocation { + file_id: ctx.current_file_id, + line_number, + }, + }) + } + _ => Err(SemanticError::new("Expected comparison or False in assert statement").into()), + } + } +} - Ok(Line::Assert { - debug: DEBUG, - boolean, - line_number, +/// Parser for vector declarations: `var = vec![...]` (vectors are implicitly mutable for push) +pub struct VecDeclarationParser; + +impl Parse for VecDeclarationParser { + fn parse(&self, pair: ParsePair<'_>, ctx: &mut ParseContext) -> ParseResult { + let line_number = pair.line_col().0; + let mut inner = pair.into_inner(); + + // Parse variable name + let var = next_inner_pair(&mut inner, "variable name")?.as_str().to_string(); + + // Parse the vec_literal + let vec_literal_pair = next_inner_pair(&mut inner, "vec literal")?; + let vec_literal = VecLiteralParser.parse(vec_literal_pair, ctx)?; + + // Extract elements from the VecLiteral::Vec + let elements = match vec_literal { + VecLiteral::Vec(elems) => elems, + VecLiteral::Expr(_) => { + return Err(SemanticError::new("Expected vec literal, got expression").into()); + } + }; + + Ok(Line::VecDeclaration { + var, + elements, + location: SourceLocation { + file_id: ctx.current_file_id, + line_number, + }, + }) + } +} + +/// Parser for push statements: `vec_var.push(element);` or `vec_var[i][j].push(element);` +pub struct PushStatementParser; + +impl Parse for PushStatementParser { + fn parse(&self, pair: ParsePair<'_>, ctx: &mut ParseContext) -> ParseResult { + let line_number = pair.line_col().0; + let mut inner = pair.into_inner(); + + // Parse the push_target (identifier with optional indices) + let push_target = next_inner_pair(&mut inner, "push target")?; + let mut target_inner = push_target.into_inner(); + + // First element is the vector variable name + let vector = next_inner_pair(&mut target_inner, "vector variable")? + .as_str() + .to_string(); + + // Remaining elements are index expressions + let indices: Vec = target_inner + .map(|idx_pair| ExpressionParser.parse(idx_pair, ctx)) + .collect::, _>>()?; + + // Parse the element to push (vec_element can be vec_literal or expression) + let element_pair = next_inner_pair(&mut inner, "push element")?; + let element = VecElementParser.parse(element_pair, ctx)?; + + Ok(Line::Push { + vector, + indices, + element, + location: SourceLocation { + file_id: ctx.current_file_id, + line_number, + }, }) } } + +/// Parser for pop statements: `vec_var.pop();` or `vec_var[i][j].pop();` +pub struct PopStatementParser; + +impl Parse for PopStatementParser { + fn parse(&self, pair: ParsePair<'_>, ctx: &mut ParseContext) -> ParseResult { + let line_number = pair.line_col().0; + let mut inner = pair.into_inner(); + + // Parse the pop_target (identifier with optional indices) + let pop_target = next_inner_pair(&mut inner, "pop target")?; + let mut target_inner = pop_target.into_inner(); + + // First element is the vector variable name + let vector = next_inner_pair(&mut target_inner, "vector variable")? + .as_str() + .to_string(); + + // Remaining elements are index expressions + let indices: Vec = target_inner + .map(|idx_pair| ExpressionParser.parse(idx_pair, ctx)) + .collect::, _>>()?; + + Ok(Line::Pop { + vector, + indices, + location: SourceLocation { + file_id: ctx.current_file_id, + line_number, + }, + }) + } +} + +/// Parser for forward declarations: `x: Imu` or `x: Mut` +pub struct ForwardDeclarationParser; + +impl Parse for ForwardDeclarationParser { + fn parse(&self, pair: ParsePair<'_>, _ctx: &mut ParseContext) -> ParseResult { + let mut inner = pair.into_inner(); + + // Parse variable name + let var = next_inner_pair(&mut inner, "variable name")?.as_str().to_string(); + + // Check for : Mut or : Imu annotation + let annotation = next_inner_pair(&mut inner, "type annotation")?; + let is_mutable = annotation.as_rule() == Rule::mut_annotation; + + Ok(Line::ForwardDeclaration { var, is_mutable }) + } +} diff --git a/crates/lean_compiler/tests/bar.snark b/crates/lean_compiler/tests/bar.snark deleted file mode 100644 index 563193a7..00000000 --- a/crates/lean_compiler/tests/bar.snark +++ /dev/null @@ -1,3 +0,0 @@ -fn bar(x) -> 1 { - return x * 2; -} diff --git a/crates/lean_compiler/tests/circular_import.snark b/crates/lean_compiler/tests/circular_import.snark deleted file mode 100644 index e7df5cad..00000000 --- a/crates/lean_compiler/tests/circular_import.snark +++ /dev/null @@ -1 +0,0 @@ -import "circular_import.snark"; diff --git a/crates/lean_compiler/tests/foo.snark b/crates/lean_compiler/tests/foo.snark deleted file mode 100644 index 41987a0a..00000000 --- a/crates/lean_compiler/tests/foo.snark +++ /dev/null @@ -1 +0,0 @@ -const FOO = 3; diff --git a/crates/lean_compiler/tests/program_0.snark b/crates/lean_compiler/tests/program_0.snark deleted file mode 100644 index 6e59f905..00000000 --- a/crates/lean_compiler/tests/program_0.snark +++ /dev/null @@ -1,5 +0,0 @@ -import "asdfasdfadsfasdf.snark"; - -fn main() { - return; -} \ No newline at end of file diff --git a/crates/lean_compiler/tests/program_1.snark b/crates/lean_compiler/tests/program_1.snark deleted file mode 100644 index b9ca401e..00000000 --- a/crates/lean_compiler/tests/program_1.snark +++ /dev/null @@ -1,10 +0,0 @@ - import "bar.snark"; -import "foo.snark"; - -fn bar() { - return; -} - -fn main() { - return; -} \ No newline at end of file diff --git a/crates/lean_compiler/tests/program_2.snark b/crates/lean_compiler/tests/program_2.snark deleted file mode 100644 index 9a9f6ed6..00000000 --- a/crates/lean_compiler/tests/program_2.snark +++ /dev/null @@ -1,8 +0,0 @@ - import "bar.snark"; -import "foo.snark"; - -const FOO = 5; - -fn main() { - return; -} \ No newline at end of file diff --git a/crates/lean_compiler/tests/program_3.snark b/crates/lean_compiler/tests/program_3.snark deleted file mode 100644 index bcaebaa2..00000000 --- a/crates/lean_compiler/tests/program_3.snark +++ /dev/null @@ -1,6 +0,0 @@ -import "foo.snark"; -import "foo.snark"; - -fn main() { - return; -} \ No newline at end of file diff --git a/crates/lean_compiler/tests/program_4.snark b/crates/lean_compiler/tests/program_4.snark deleted file mode 100644 index bbe66620..00000000 --- a/crates/lean_compiler/tests/program_4.snark +++ /dev/null @@ -1,5 +0,0 @@ -import "circular_import.snark"; - -fn main() { - return; -} \ No newline at end of file diff --git a/crates/lean_compiler/tests/program_6.snark b/crates/lean_compiler/tests/program_6.snark deleted file mode 100644 index 9f1049fb..00000000 --- a/crates/lean_compiler/tests/program_6.snark +++ /dev/null @@ -1,8 +0,0 @@ -import "bar.snark"; -import "foo.snark"; - -fn main() { - x = bar(FOO); - assert x == 6; - return; -} \ No newline at end of file diff --git a/crates/lean_compiler/tests/test_compiler.rs b/crates/lean_compiler/tests/test_compiler.rs index e89eea89..b4c9db36 100644 --- a/crates/lean_compiler/tests/test_compiler.rs +++ b/crates/lean_compiler/tests/test_compiler.rs @@ -1,1613 +1,161 @@ use lean_compiler::*; use lean_vm::*; -use utils::{poseidon16_permute, poseidon24_permute}; - -const DEFAULT_NO_VEC_RUNTIME_MEMORY: usize = 1 << 15; - -#[test] -#[should_panic] -fn test_duplicate_function_name() { - let program = r#" - fn a() -> 1 { - return 0; - } - - fn a() -> 1 { - return 1; - } - - fn main() { - a(); - return; - } - "#; - compile_and_run( - &ProgramSource::Raw(program.to_string()), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); -} - -#[test] -#[should_panic] -fn test_duplicate_constant_name() { - let program = r#" - const A = 1; - const A = 0; - - fn main() { - return; - } - "#; - compile_and_run( - &ProgramSource::Raw(program.to_string()), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); -} - -#[test] -#[should_panic] -fn test_wrong_n_returned_vars_1() { - let program = r#" - fn main() { - a, b = f(); - return; - } - - fn f() -> 1 { - return 0; - } - "#; - compile_and_run( - &ProgramSource::Raw(program.to_string()), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); -} - -#[test] -#[should_panic] -fn test_wrong_n_returned_vars_2() { - let program = r#" - fn main() { - a = f(); - return; - } - - fn f() -> 1 { - return 0, 1; - } - "#; - compile_and_run( - &ProgramSource::Raw(program.to_string()), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); -} - -#[test] -#[should_panic] -fn test_no_return() { - let program = r#" - fn main() { - a = f(); - return; - } - - fn f() -> 1 { - } - - fn g() -> 1 { - return 0; - } - "#; - compile_and_run( - &ProgramSource::Raw(program.to_string()), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); -} - -#[test] -fn test_assumed_return() { - let program = r#" - fn main() { - a = f(); - return; - } - - #![assume_always_returns] - fn f() -> 1 { - if 1 == 1 { - return 0; - } else { - print(1); - } - } - "#; - compile_and_run( - &ProgramSource::Raw(program.to_string()), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); -} - -#[test] -fn test_fibonacci_program() { - // a program to check the value of the 30th Fibonacci number (832040) - let program = r#" - fn main() { - fibonacci(0, 1, 0, 30); - return; - } - - fn fibonacci(a, b, i, n) { - if i == n { - print(a); - return; - } - fibonacci(b, a + b, i + 1, n); - return; - } - "#; - compile_and_run( - &ProgramSource::Raw(program.to_string()), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); -} - -#[test] -fn test_edge_case_0() { - let program = r#" - fn main() { - a = malloc(1); - a[0] = 0; - for i in 0..1 { - x = 1 + a[i]; - } - for i in 0..1 { - y = 1 + a[i]; - } - return; - } - "#; - compile_and_run( - &ProgramSource::Raw(program.to_string()), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); -} - -#[test] -fn test_edge_case_1() { - let program = r#" - fn main() { - a = malloc(1); - a[0] = 0; - assert a[8 - 8] == 0; - return; - } - "#; - compile_and_run( - &ProgramSource::Raw(program.to_string()), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); -} - -#[test] -fn test_edge_case_2() { - let program = r#" - fn main() { - for i in 0..5 unroll { - x = i; - print(x); - } - for i in 0..3 unroll { - x = i; - print(x); - } - return; - } - "#; - compile_and_run( - &ProgramSource::Raw(program.to_string()), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); -} - -#[test] -fn test_decompose_bits() { - let program = r#" - fn main() { - x = 2**20 - 1; - a = malloc(31); - print(a); - hint_decompose_bits(x, a); - for i in 0..20 { - debug_assert a[i] == 1; - assert a[i] == 1; - } - for i in 20..31 { - assert a[i] == 0; - } - return; - } - "#; - compile_and_run( - &ProgramSource::Raw(program.to_string()), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); -} - -#[test] -fn test_unroll() { - // a program to check the value of the 30th Fibonacci number (832040) - let program = r#" - fn main() { - for i in 0..5 unroll { - for j in i..2*i unroll { - print(i, j); - } - } - return; - } - "#; - compile_and_run( - &ProgramSource::Raw(program.to_string()), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); -} - -#[test] -fn test_rev_unroll() { - // a program to check the value of the 30th Fibonacci number (832040) - let program = r#" - fn main() { - print(785 * 78 + 874 - 1); - return; - } - "#; - compile_and_run( - &ProgramSource::Raw(program.to_string()), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); -} - -#[test] -fn test_mini_program_0() { - let program = r#" - fn main() { - for i in 0..5 { - for j in i..2*i*(2+1) { - print(i, j); - if i == 4 { - if j == 7 { - break; - } - } - } - } - return; - } - "#; - compile_and_run( - &ProgramSource::Raw(program.to_string()), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); -} - -#[test] -fn test_mini_program_1() { - let program = r#" - const N = 10; - - fn main() { - arr = malloc(N); - fill_array(arr); - print_array(arr); - return; - } - - fn fill_array(arr) { - for i in 0..N { - if i == 0 { - arr[i] = 10; - } else if i == 1 { - arr[i] = 20; - } else if i == 2 { - arr[i] = 30; - } else { - i_plus_one = i + 1; - arr[i] = i_plus_one; - } - } - return; - } - - fn print_array(arr) { - for i in 0..N { - arr_i = arr[i]; - print(arr_i); - } - return; - } - "#; - compile_and_run( - &ProgramSource::Raw(program.to_string()), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); -} - -#[test] -fn test_mini_program_2() { - let program = r#" - fn main() { - for i in 0..10 { - for j in i..10 { - for k in j..10 { - sum, prod = compute_sum_and_product(i, j, k); - if sum == 10 { - print(i, j, k, prod); - } - } - } - } - return; - } - - fn compute_sum_and_product(a, b, c) -> 2 { - s1 = a + b; - sum = s1 + c; - p1 = a * b; - product = p1 * c; - return sum, product; - } - "#; - compile_and_run( - &ProgramSource::Raw(program.to_string()), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); -} +use multilinear_toolkit::prelude::BasedVectorSpace; +use rand::{Rng, SeedableRng, rngs::StdRng}; +use utils::poseidon16_permute; #[test] -fn test_mini_program_3() { +fn test_poseidon() { let program = r#" - fn main() { - a = public_input_start / 8; - b = a + 1; - c = malloc_vec(2); - poseidon16(a, b, c, 0); +def main(): + a = NONRESERVED_PROGRAM_INPUT_START + b = a + 8 + c = Array(2*8) + poseidon16(a, b, c, 0) - c_shifted = c * 8; - d_shifted = (c + 1) * 8; - - for i in 0..8 { - cc = c_shifted[i]; - print(cc); - } - for i in 0..8 { - dd = d_shifted[i]; - print(dd); - } - return; - } + for i in range(0, 8): + cc = c[i] + print(cc) + for i in range(0, 8): + dd = c[i+8] + print(dd) + return "#; let public_input: [F; 16] = (0..16).map(F::new).collect::>().try_into().unwrap(); - compile_and_run( - &ProgramSource::Raw(program.to_string()), - (&public_input, &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); + compile_and_run(&ProgramSource::Raw(program.to_string()), (&public_input, &[]), false); let _ = dbg!(poseidon16_permute(public_input)); } #[test] -fn test_mini_program_4() { - let program = r#" - fn main() { - a = public_input_start / 8; - c = a + 2; - f = malloc_vec(1); - poseidon24(a, c, f); - - f_shifted = f * 8; - for j in 0..8 { - print(f_shifted[j]); - } - return; - } - "#; - let public_input: [F; 24] = (0..24).map(F::new).collect::>().try_into().unwrap(); - compile_and_run( - &ProgramSource::Raw(program.to_string()), - (&public_input, &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); - - dbg!(&poseidon24_permute(public_input)[16..]); -} - -#[test] -fn test_mini_program_5() { - let program = r#" - fn main() { - arr = malloc(10); - arr[6] = 42; - arr[8] = 11; - sum_1 = func_1(arr[6], arr[8]); - assert sum_1 == 53; - return; - } - - fn func_1(i, j) inline -> 1 { - for k in 0..i { - for u in 0..j { - assert k + u != 1000000; - } - } - return i + j; - } - - "#; - compile_and_run( - &ProgramSource::Raw(program.to_string()), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); -} - -#[test] -fn test_inlined() { - let program = r#" - fn main() { - x = 1; - y = 2; - i, j, k = func_1(x, y); - assert i == 2; - assert j == 3; - assert k == 2130706432; - - g = malloc_vec(1); - h = malloc_vec(1); - g_ptr = g * 8; - h_ptr = h * 8; - for i in 0..8 { - g_ptr[i] = i; - } - for i in 0..8 unroll { - h_ptr[i] = i; - } - assert_vectorized_eq_1(g, h); - assert_vectorized_eq_2(g, h); - assert_vectorized_eq_3(g, h); - assert_vectorized_eq_4(g, h); - assert_vectorized_eq_5(g, h); - return; - } - - fn func_1(a, b) inline -> 3 { - x = a * b; - y = a + b; - return x, y, a - b; - } - - fn assert_vectorized_eq_1(x, y) { - x_ptr = x * 8; - y_ptr = y * 8; - for i in 0..4 unroll { - assert x_ptr[i] == y_ptr[i]; - } - for i in 4..8 { - assert x_ptr[i] == y_ptr[i]; - } - return; - } - - fn assert_vectorized_eq_2(x, y) inline { - x_ptr = x * 8; - y_ptr = y * 8; - for i in 0..4 unroll { - assert x_ptr[i] == y_ptr[i]; - } - for i in 4..8 { - assert x_ptr[i] == y_ptr[i]; - } - return; - } - fn assert_vectorized_eq_3(x, y) inline { - u = x + 7; - assert_vectorized_eq_1(u-7, y * 7 / 7); - return; - } - fn assert_vectorized_eq_4(x, y) { - dot_product_ee(x * 8, pointer_to_one_vector * 8, y * 8, 1); - dot_product_ee(x * 8 + 3, pointer_to_one_vector * 8, y * 8 + 3, 1); - return; - } - fn assert_vectorized_eq_5(x, y) inline { - dot_product_ee(x * 8, pointer_to_one_vector * 8, y * 8, 1); - dot_product_ee(x * 8 + 3, pointer_to_one_vector * 8, y * 8 + 3, 1); - return; - } - "#; - compile_and_run( - &ProgramSource::Raw(program.to_string()), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); -} - -#[test] -fn test_inlined_2() { - let program = r#" - fn main() { - b = is_one(); - c = b; - return; - } - - fn is_one() inline -> 1 { - if !!assume_bool(1) { - return 1; - } else { - return 0; - } - } - "#; - compile_and_run( - &ProgramSource::Raw(program.to_string()), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); -} - -#[test] -fn test_inlined_3() { - let program = r#" - fn main() { - x = func(); - return; - } - fn func() -> 1 { - var a; - if 0 == 0 { - a = aux(); - } - return a; - } - - fn aux() inline -> 1 { - return 1; - } - "#; - compile_and_run( - &ProgramSource::Raw(program.to_string()), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); -} - -#[test] -fn test_match() { - let program = r#" - fn main() { - for x in 0..3 unroll { - func_match(x); - } - for x in 0..2 unroll { - match x { - 0 => { - y = 10 * (x + 8); - z = 10 * y; - print(z); - } - 1 => { - y = 10 * x; - z = func_2(y); - print(z); - } - } - } - return; - } - - fn func_match(x) inline { - match x { - 0 => { - print(41); - } - 1 => { - y = func_1(x); - print(y + 1); - } - 2 => { - y = 10 * x; - print(y); - } - } - return; - } - - fn func_1(x) -> 1 { - return x * x * x * x; - } - - fn func_2(x) inline -> 1 { - return x * x * x * x * x * x; - } - "#; - compile_and_run( - &ProgramSource::Raw(program.to_string()), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); -} - -#[test] -fn test_match_shrink() { - let program = r#" - fn main() { - match 1 { - 0 => { - y = 90; - } - 1 => { - y = 10; - z = func_2(y); - } - } - return; - } - - fn func_2(x) inline -> 1 { - return x * x; - } - "#; - compile_and_run( - &ProgramSource::Raw(program.to_string()), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); -} - -// #[test] -// fn inline_bug_mre() { -// let program = r#" -// fn main() { -// boolean(0); -// return; -// } - -// fn boolean(a) inline -> 1 { -// if a == 0 { -// return 0; -// } -// return 1; -// } -// "#; -// compile_and_run(program.to_string(), (&[], &[])); -// } - -#[test] -fn test_const_functions_calling_const_functions() { - // Test that const functions can call other const functions - let program = r#" - fn main() { - y = compute_value(3); - print(y); - return; - } - - fn compute_value(const n) -> 1 { - result = complex_computation(n, 5); - return result; - } - - fn complex_computation(const a, const b) -> 1 { - return a * a + b * b; - } - "#; - - compile_and_run( - &ProgramSource::Raw(program.to_string()), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); -} - -#[test] -fn test_inline_functions_calling_inline_functions() { - let program = r#" - fn main() { - x = double(3); - y = quad(x); - print(y); - return; - } - - fn double(a) inline -> 1 { - return a + a; - } - - fn quad(b) inline -> 1 { - result = double(b); - return result + result; - } - "#; - - compile_and_run( - &ProgramSource::Raw(program.to_string()), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); -} - -#[test] -fn test_nested_inline_functions() { - let program = r#" - fn main() { - result = level_one(3); - print(result); - return; - } - - fn level_one(x) inline -> 1 { - result = level_two(x); - return result; - } - - fn level_two(y) inline -> 1 { - result = level_three(y); - return result; - } - - fn level_three(z) inline -> 1 { - return z * z * z; - } - "#; - - compile_and_run( - &ProgramSource::Raw(program.to_string()), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); -} - -#[test] -fn test_const_and_nonconst_malloc_sharing_name() { - let program = r#" - fn main() { - f(1); - return; - } - - fn f(n) { - if 0 == 0 { - res = malloc(2); - res[1] = 0; - return; - } else { - res = malloc(n * 1); - return; - } - } - "#; - - compile_and_run( - &ProgramSource::Raw(program.to_string()), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); -} - -#[test] -fn test_debug_assert_eq() { - let program = r#" - fn main() { - a = 10; - b = 20; - debug_assert a * 2 == b; - debug_assert a != b; - debug_assert a < b; - return; - } - "#; - compile_and_run( - &ProgramSource::Raw(program.to_string()), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); -} - -#[should_panic] -#[test] -fn test_debug_assert_eq_fail() { - let program = r#" - fn main() { - a = 10; - b = 25; - debug_assert a * 2 == b; - return; - } - "#; - compile_and_run( - &ProgramSource::Raw(program.to_string()), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); -} - -#[should_panic] -#[test] -fn test_debug_assert_not_eq_fail() { - let program = r#" - fn main() { - a = 10; - b = 10; - debug_assert a != b; - return; - } - "#; - compile_and_run( - &ProgramSource::Raw(program.to_string()), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); -} - -#[should_panic] -#[test] -fn test_debug_assert_lt_fail() { +fn test_div_extension_field() { let program = r#" - fn main() { - a = 30; - b = 20; - debug_assert a < b; - return; - } - "#; - compile_and_run( - &ProgramSource::Raw(program.to_string()), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); -} - -#[test] -fn test_next_multiple_of() { - let program = r#" - fn main() { - a = double(next_multiple_of(12, 8)); - assert a == 32; - return; - } - - fn double(const n) -> 1 { - return next_multiple_of(n, n) * 2; - } - "#; - compile_and_run( - &ProgramSource::Raw(program.to_string()), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); -} +# Dot product precompile: +BE = 1 # base-extension +EE = 0 # extension-extension +DIM = 5 -#[test] -fn test_const_array() { - let program = r#" - const FIVE = 5; - const ARR = [4, FIVE, 4 + 2, 3 * 2 + 1]; - fn main() { - for i in 1..len(ARR) unroll { - x = i + 4; - assert ARR[i] == x; - } - four = 4; - assert len(ARR) == four; - res = func(2); - six = 6; - assert res == six; - nothing(ARR[0]); - mem_arr = malloc(len(ARR)); - for i in 0..len(ARR) unroll { - mem_arr[i] = ARR[i]; - } - for i in 0..ARR[0] { - print(2**ARR[0]); - } - print(2**ARR[1]); - return; - } +def main(): + n = NONRESERVED_PROGRAM_INPUT_START + d = NONRESERVED_PROGRAM_INPUT_START + DIM + q = NONRESERVED_PROGRAM_INPUT_START + 2 * DIM + computed_q_1 = div_ext_1(n, d) + computed_q_2 = div_ext_2(n, d) + assert_eq_ext(computed_q_2, q) + assert_eq_ext(computed_q_1, q) + return - fn func(const x) -> 1 { - return ARR[x]; - } - fn nothing(x) { - return; - } - "#; - compile_and_run( - &ProgramSource::Raw(program.to_string()), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); -} +def assert_eq_ext(x, y): + for i in unroll(0, DIM): + assert x[i] == y[i] + return -#[test] -fn test_const_malloc_end_iterator_loop() { - let program = r#" - fn main() { - x = malloc(2); - x[0] = 3; - x[1] = 5; - for i in 0..2 unroll { - for j in 0..x[i] { - print(i, j); - } - } - return; - } - "#; - compile_and_run( - &ProgramSource::Raw(program.to_string()), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); -} - -#[test] -fn test_array_return_targets() { - let program = r#" - fn main() { - a = malloc(10); - b = malloc(10); - a[1], b[4] = get_two_values(); - assert a[1] == 42; - assert b[4] == 99; - - i = 2; - j = 3; - a[i], b[j] = get_two_values(); - assert a[2] == 42; - assert b[3] == 99; - - x, a[5] = get_two_values(); - assert x == 42; - assert a[5] == 99; - - return; - } - - fn get_two_values() -> 2 { - return 42, 99; - } - "#; - compile_and_run( - &ProgramSource::Raw(program.to_string()), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); -} - -#[test] -fn test_array_return_targets_with_expressions() { - let program = r#" - fn main() { - arr = malloc(20); - for i in 0..5 { - arr[i * 2], arr[i * 2 + 1] = compute_pair(i); - } - assert arr[0] == 0; - assert arr[1] == 0; - assert arr[2] == 1; - assert arr[3] == 2; - assert arr[4] == 2; - assert arr[5] == 4; - return; - } - - fn compute_pair(n) -> 2 { - return n, n * 2; - } - "#; - compile_and_run( - &ProgramSource::Raw(program.to_string()), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); -} - -#[test] -fn intertwined_unrolled_loops_and_const_function_arguments() { - let program = r#" - const ARR = [10, 100]; - fn main() { - buff = malloc(3); - buff[0] = 0; - for i in 0..2 unroll { - res = f1(ARR[i]); - buff[i + 1] = res; - } - assert buff[2] == 1390320454; - return; - } - - fn f1(const x) -> 1 { - buff = malloc(9); - buff[0] = 1; - for i in x..x+4 unroll { - for j in i..i+2 unroll { - index = (i - x) * 2 + (j - i); - res = f2(i, j); - buff[index+1] = buff[index] * res; - } - } - return buff[8]; - } - - fn f2(const x, const y) -> 1 { - buff = malloc(7); - buff[0] = 0; - for i in x..x+2 unroll { - for j in i..i+3 unroll { - index = (i - x) * 3 + (j - i); - buff[index+1] = buff[index] + i + j; - } - } - return buff[4]; - } - "#; - compile_and_run( - &ProgramSource::Raw(program.to_string()), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); -} +def div_ext_1(n, d): + quotient = Array(DIM) + dot_product(d, quotient, n, 1, EE) + return quotient -#[test] -fn test_direct_const_arr_access() { - let program = r#" - const ARR = [10, 100]; - fn main() { - a = ARR[0]; - assert a == 10; - return; - } +def div_ext_2(n, d): + quotient = Array(DIM) + dot_product(quotient, d, n, 1, EE) + return quotient "#; - compile_and_run( - &ProgramSource::Raw(program.to_string()), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); -} -#[test] -fn test_const_fibonacci() { - let program = r#" - fn main() { - res = fib(8); - assert res == 21; - return; - } - fn fib(const n) -> 1 { - if n == 0 { - return 0; - } - if n == 1 { - return 1; - } - a = fib(saturating_sub(n, 1)); - b = fib(saturating_sub(n, 2)); - return a + b; - } - "#; - compile_and_run( - &ProgramSource::Raw(program.to_string()), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); + let mut rng = StdRng::seed_from_u64(0); + let n: EF = rng.random(); + let d: EF = rng.random(); + let q = n / d; + let mut public_input = vec![]; + public_input.extend(n.as_basis_coefficients_slice()); + public_input.extend(d.as_basis_coefficients_slice()); + public_input.extend(q.as_basis_coefficients_slice()); + compile_and_run(&ProgramSource::Raw(program.to_string()), (&public_input, &[]), false); } -fn run_program_in_files(i: usize) { +fn test_data_dir() -> String { let manifest_dir = env!("CARGO_MANIFEST_DIR"); - let path = format!("{manifest_dir}/tests/program_{i}.snark"); - compile_and_run( - &ProgramSource::Filepath(path), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); -} - -#[test] -#[should_panic] -fn test_undefined_import() { - run_program_in_files(0); -} - -#[test] -#[should_panic] -fn test_imported_function_name_clash() { - run_program_in_files(1); -} - -#[test] -#[should_panic] -fn test_imported_constant_name_clash() { - run_program_in_files(2); -} - -#[test] -fn test_double_import_tolerance() { - run_program_in_files(3); -} - -#[test] -fn test_circular_import_tolerance() { - run_program_in_files(4); -} - -#[test] -#[should_panic] -fn test_no_main() { - run_program_in_files(5); -} - -#[test] -fn test_imports() { - run_program_in_files(6); -} - -#[test] -fn test_name_conflict() { - let program = r#" - fn main() { - a = b(); - b = a(); - assert a + b == 30; - return; - } - fn a() -> 1 { - return 10; - } - fn b() -> 1 { - return 20; - } - "#; - compile_and_run( - &ProgramSource::Raw(program.to_string()), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); -} - -// BUG here: - -// #[test] -// fn test_num_files() { -// let expected_num_files = 3; // program_6.snark imports foo.snark and bar.snark -// let manifest_dir = env!("CARGO_MANIFEST_DIR"); -// let path = format!("{manifest_dir}/tests/program_6.snark"); -// let bytecode = compile_program(&ProgramSource::Filepath(path)); -// assert_eq!(bytecode.filepaths.len(), expected_num_files); -// assert_eq!(bytecode.source_code.len(), expected_num_files); -// } - -// TODO BUG - -// #[test] -// fn bug() { -// let program = r#" -// fn main() { -// x = func(); -// return; -// } -// fn func() -> 1 { -// var a; -// if 0 == 0 { -// a = aux(); -// } -// return a; -// } - -// fn aux() inline -> 1 { -// return 1; -// } -// "#; -// compile_and_run(program.to_string(), (&[], &[]), false); -// } - -#[test] -fn test_2d_const_array() { - let program = r#" - const NESTED = [[1, 2], [3, 4, 5], [6]]; - fn main() { - // Test len() on nested arrays - assert len(NESTED) == 3; - assert len(NESTED[0]) == 2; - assert len(NESTED[1]) == 3; - assert len(NESTED[2]) == 1; - - // Test chained indexing - assert NESTED[0][0] == 1; - assert NESTED[0][1] == 2; - assert NESTED[1][0] == 3; - assert NESTED[1][2] == 5; - assert NESTED[2][0] == 6; - - return; - } - "#; - compile_and_run( - &ProgramSource::Raw(program.to_string()), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); -} - -#[test] -fn test_3d_const_array() { - let program = r#" - const DEEP = [[[1, 2], [3]], [[4, 5, 6]]]; - const ONE = 1; - fn main() { - assert len(DEEP) == 2; - assert len(DEEP[0]) == 2; - assert len(DEEP[0][0]) == 2; - assert len(DEEP[0][1]) == 1; - one = 1; - assert len(DEEP[ONE]) == one; - assert len(DEEP[1][0]) == 3; - - assert DEEP[0][0][0] == 1; - assert DEEP[0][0][1] == 2; - assert DEEP[0][1][0] == 3; - assert DEEP[1][0][0] == 4; - assert DEEP[1][0][2] == 6; - - return; - } - "#; - compile_and_run( - &ProgramSource::Raw(program.to_string()), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); -} - -#[test] -fn test_2d_nested_array_with_expressions() { - let program = r#" - const TWO = 2; - const ARR = [[1 + 1, TWO * 2], [3 + TWO]]; - const INCR_ARR = [[0, 1, 2], [1, 2, 3], [2, 3, 4], [3, 4, 5]]; - fn main() { - assert len(ARR) == 2; - assert ARR[0][0] == 2; - assert ARR[0][1] == 4; - assert ARR[1][0] == 5; - five = ARR[1][0]; - assert five == 5; - x = 2 + 3 * (ARR[0][0] + ARR[1][0] + 3)**2; - assert x == 302; - for i in 0..4 unroll { - for j in 0..3 unroll { - y = INCR_ARR[i][j]; - assert INCR_ARR[i][j] == i + j - INCR_ARR[i][j] + y; + format!("{manifest_dir}/tests/test_data") +} + +fn find_files(dir: &str, prefix: &str, suffix: &str) -> Vec { + let mut paths: Vec = std::fs::read_dir(dir) + .expect("Failed to read test data directory") + .filter_map(|entry| entry.ok()) + .filter_map(|entry| { + let path = entry.path(); + let filename = path.file_name()?.to_str()?; + if filename.starts_with(prefix) && filename.ends_with(suffix) { + Some(path.to_string_lossy().to_string()) + } else { + None } - } - return; - } - "#; - compile_and_run( - &ProgramSource::Raw(program.to_string()), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); -} - -#[test] -fn test_const_array_element_exponentiation() { - let program = r#" - const ARR = [[5]]; - fn main() { - x = ARR[0][0]**2; - assert x == 25; - return; - } - "#; - compile_and_run( - &ProgramSource::Raw(program.to_string()), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); + }) + .collect(); + paths.sort(); + paths } #[test] fn test_num_files() { - let expected_num_files = 3; // program_6.snark imports foo.snark and bar.snark - let manifest_dir = env!("CARGO_MANIFEST_DIR"); - let path = format!("{manifest_dir}/tests/program_6.snark"); + let expected_num_files = 3; // program_2.py imports foo.py and bar.py + let path = format!("{}/program_2.py", test_data_dir()); let bytecode = compile_program(&ProgramSource::Filepath(path)); assert_eq!(bytecode.filepaths.len(), expected_num_files); assert_eq!(bytecode.source_code.len(), expected_num_files); } #[test] -fn test_nested_function_call() { - let program = r#" - fn main() { - assert incr(incr(incr(1))) == 4; - x = add(incr(1), incr(2)); - assert x == 5; +fn test_all_errors() { + let test_dir = test_data_dir(); + let paths = find_files(&test_dir, "error_", ".py"); - assert incr_inline(incr_inline(incr_inline(1))) == 4; - y = add_inlined(incr_inline(1), add_inlined(incr_inline(2), incr_inline(2))); - assert y == 8; - - return; - } + assert!(!paths.is_empty(), "No error_*.py files found"); + println!("Found {} test error programs", paths.len()); - fn add(a, b) -> 1 { - return a + b; + for path in paths { + let result = try_compile_and_run(&ProgramSource::Filepath(path.clone()), (&[], &[]), false); + assert!(result.is_err(), "Expected error for {}, but it succeeded", path); } +} - fn incr(x) -> 1 { - return x + 1; - } +#[test] +fn test_all_programs() { + let test_dir = test_data_dir(); + let paths = find_files(&test_dir, "program_", ".py"); - fn incr_inline(x) inline -> 1 { - return x + 1; - } - + assert!(!paths.is_empty(), "No program_*.py files found"); + println!("Found {} test programs", paths.len()); - fn add_inlined(a, b) inline -> 1 { - c = malloc(1); - zero = 0; - c[zero] = a + b; - return c[0]; + for path in paths { + if let Err(err) = try_compile_and_run(&ProgramSource::Filepath(path.clone()), (&[], &[]), false) { + panic!("Program {} failed with error: {:?}", path, err); + } } - "#; - - compile_and_run( - &ProgramSource::Raw(program.to_string()), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); } #[test] -fn test_len_2d_array() { - let program = r#" - const ARR = [[1], [7, 3], [7]]; - const N = 2 + len(ARR[0]); - fn main() { - for i in 0..N unroll { - for j in 0..len(ARR[i]) unroll { - assert j * (j - 1) == 0; - } - } - return; +fn test_reserved_function_names() { + for name in RESERVED_FUNCTION_NAMES { + let program = format!("def main():\n return\ndef {name}():\n return"); + assert!( + try_compile_and_run(&ProgramSource::Raw(program), (&[], &[]), false).is_err(), + "Expected error when defining function with reserved name '{name}', but it succeeded" + ); } - "#; - compile_and_run( - &ProgramSource::Raw(program.to_string()), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); } #[test] -fn test_nested_matches() { - let program = r#" - fn main() { - assert test_func(0, 0) == 6; - assert test_func(1, 0) == 3; - return; - } - - fn test_func(a, b) -> 1 { - x = 1; - - var mut_x_2; - match a { - 0 => { - var mut_x_1; - mut_x_1 = x + 2; - match b { - 0 => { - mut_x_2 = mut_x_1 + 3; - } - } - } - 1 => { - mut_x_2 = x + 2; - } - } - - return mut_x_2; - } - "#; - compile_and_run( - &ProgramSource::Raw(program.to_string()), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); +fn debug_file_program() { + let index = 167; + let path = format!("{}/program_{}.py", test_data_dir(), index); + compile_and_run(&ProgramSource::Filepath(path), (&[], &[]), false); } #[test] -fn test_deeply_nested_match() { - // Test with 3 levels of nesting, multiple arms, and variables at each level +fn debug_str_program() { let program = r#" - fn main() { - // Test each combination with expected values - // (0,0,0): base=1000, local_a=5, local_b=8, inner_val=1008 - assert compute(0, 0, 0) == 1008; - // (0,0,1): base=1000, local_a=5, local_b=8, inner_val=1009 - assert compute(0, 0, 1) == 1009; - // (0,1,0): base=1000, local_a=5, local_b=12, inner_val=1012 - assert compute(0, 1, 0) == 1012; - // (0,1,1): base=1000, local_a=5, local_b=12, inner_val=1013 - assert compute(0, 1, 1) == 1013; - // (1,0,0): base=1000, local_a=16, local_b=36, inner_val=1036 - assert compute(1, 0, 0) == 1036; - // (1,0,1): base=1000, local_a=16, local_b=36, inner_val=1037 - assert compute(1, 0, 1) == 1037; - // (1,1,0): base=1000, local_a=16, local_b=46, inner_val=1046 - assert compute(1, 1, 0) == 1046; - // (1,1,1): base=1000, local_a=16, local_b=46, inner_val=1047 - assert compute(1, 1, 1) == 1047; - return; - } - - fn compute(a, b, c) -> 1 { - base = 1000; - var outer_val; - var mid_val; - var inner_val; - - match a { - 0 => { - outer_val = 5; - var local_a; - local_a = a + outer_val; // local_a = 5 - - match b { - 0 => { - mid_val = 3; - var local_b; - local_b = local_a + mid_val; // local_b = 8 - - match c { - 0 => { - inner_val = base + local_b + c; // 1000 + 8 + 0 = 1008 - } - 1 => { - inner_val = base + local_b + c; // 1000 + 8 + 1 = 1009 - } - } - } - 1 => { - mid_val = 7; - var local_b; - local_b = local_a + mid_val; // local_b = 12 - - match c { - 0 => { - inner_val = base + local_b + c; // 1000 + 12 + 0 = 1012 - } - 1 => { - inner_val = base + local_b + c; // 1000 + 12 + 1 = 1013 - } - } - } - } - } - 1 => { - outer_val = 15; - var local_a; - local_a = a + outer_val; // local_a = 16 - - match b { - 0 => { - mid_val = 20; - var local_b; - local_b = local_a + mid_val; // local_b = 36 - - match c { - 0 => { - inner_val = base + local_b + c; // 1000 + 36 + 0 = 1036 - } - 1 => { - inner_val = base + local_b + c; // 1000 + 36 + 1 = 1037 - } - } - } - 1 => { - mid_val = 30; - var local_b; - local_b = local_a + mid_val; // local_b = 46 - - match c { - 0 => { - inner_val = base + local_b + c; // 1000 + 46 + 0 = 1046 - } - 1 => { - inner_val = base + local_b + c; // 1000 + 46 + 1 = 1047 - } - } - } - } - } - } - - return inner_val; - } - "#; - compile_and_run( - &ProgramSource::Raw(program.to_string()), - (&[], &[]), - DEFAULT_NO_VEC_RUNTIME_MEMORY, - false, - ); +def main(): + return + "#; + compile_and_run(&ProgramSource::Raw(program.to_string()), (&[], &[]), false); } diff --git a/crates/lean_compiler/tests/program_5.snark b/crates/lean_compiler/tests/test_data/__init__.py similarity index 100% rename from crates/lean_compiler/tests/program_5.snark rename to crates/lean_compiler/tests/test_data/__init__.py diff --git a/crates/lean_compiler/tests/test_data/error_0.py b/crates/lean_compiler/tests/test_data/error_0.py new file mode 100644 index 00000000..18db8583 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/error_0.py @@ -0,0 +1,11 @@ +from snark_lib import * + +# Error: imported constant name clash (FOO defined in foo.py) +from misc.bar import * +from misc.foo import * + +FOO = 5 + + +def main(): + return diff --git a/crates/lean_compiler/tests/test_data/error_1.py b/crates/lean_compiler/tests/test_data/error_1.py new file mode 100644 index 00000000..be38bff5 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/error_1.py @@ -0,0 +1,9 @@ +from snark_lib import * + + +# Error: debug_assert(equality fails (10 * 2 != 25)) +def main(): + a = 10 + b = 25 + debug_assert(a * 2 == b) + return diff --git a/crates/lean_compiler/tests/test_data/error_10.py b/crates/lean_compiler/tests/test_data/error_10.py new file mode 100644 index 00000000..e2dbcabd --- /dev/null +++ b/crates/lean_compiler/tests/test_data/error_10.py @@ -0,0 +1,8 @@ +from snark_lib import * + +# Error: undefined import (file does not exist) +from asdfasdfadsfasdf import * + + +def main(): + return diff --git a/crates/lean_compiler/tests/test_data/error_11.py b/crates/lean_compiler/tests/test_data/error_11.py new file mode 100644 index 00000000..32218ea7 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/error_11.py @@ -0,0 +1,11 @@ +from snark_lib import * + + +# Error: wrong number of returned vars (expecting 2, got 1) +def main(): + a, b = f() + return + + +def f(): + return 0 diff --git a/crates/lean_compiler/tests/test_data/error_12.py b/crates/lean_compiler/tests/test_data/error_12.py new file mode 100644 index 00000000..06ed30e0 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/error_12.py @@ -0,0 +1,11 @@ +from snark_lib import * + + +# Error: wrong number of returned vars (expecting 1, got 2) +def main(): + a = f() + return + + +def f(): + return 0, 1 diff --git a/crates/lean_compiler/tests/test_data/error_13.py b/crates/lean_compiler/tests/test_data/error_13.py new file mode 100644 index 00000000..0450a5aa --- /dev/null +++ b/crates/lean_compiler/tests/test_data/error_13.py @@ -0,0 +1,14 @@ +from snark_lib import * + + +def main(): + a: Imu + a = 0 + a = a + 1 + if a == 1: + a = a + 10 + else: + a = a + 100 + a = a + 1000 + assert a == 11 + return diff --git a/crates/lean_compiler/tests/test_data/error_14.py b/crates/lean_compiler/tests/test_data/error_14.py new file mode 100644 index 00000000..61c9184e --- /dev/null +++ b/crates/lean_compiler/tests/test_data/error_14.py @@ -0,0 +1,12 @@ +from snark_lib import * + + +def main(): + a = 5 + b = 10 + a = swap(a, b) + 11 + return + + +def swap(a, b): + return b, a diff --git a/crates/lean_compiler/tests/test_data/error_15.py b/crates/lean_compiler/tests/test_data/error_15.py new file mode 100644 index 00000000..0d071150 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/error_15.py @@ -0,0 +1,13 @@ +from snark_lib import * + + +def main(): + a = 5 + b = 10 + a = swap(a, b) + 11 + return + + +@inline +def swap(a, b): + return b, a diff --git a/crates/lean_compiler/tests/test_data/error_17.py b/crates/lean_compiler/tests/test_data/error_17.py new file mode 100644 index 00000000..0c0e8381 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/error_17.py @@ -0,0 +1,12 @@ +from snark_lib import * + + +# Error: push to outer-scope vector inside else branch +def main(): + v = DynArray([1, 2, 3]) # Vector in outer scope + x = 5 + if x == 5: + x = 6 + else: + v.push(4) # Error: cannot push to outer-scope vector inside if/else + return diff --git a/crates/lean_compiler/tests/test_data/error_18.py b/crates/lean_compiler/tests/test_data/error_18.py new file mode 100644 index 00000000..86866ef7 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/error_18.py @@ -0,0 +1,9 @@ +from snark_lib import * + + +# Error: push to outer-scope vector inside non-unrolled loop +def main(): + v = DynArray([1, 2, 3]) # Vector in outer scope + for i in range(0, 5): + v.push(i) # Error: cannot push to outer-scope vector inside non-unrolled loop + return diff --git a/crates/lean_compiler/tests/test_data/error_19.py b/crates/lean_compiler/tests/test_data/error_19.py new file mode 100644 index 00000000..774778da --- /dev/null +++ b/crates/lean_compiler/tests/test_data/error_19.py @@ -0,0 +1,12 @@ +from snark_lib import * + + +# Error: vector passed to non-inlined function +def main(): + v = DynArray([1, 2, 3]) + process(v) # Error: vectors cannot be passed to non-inlined functions + return + + +def process(x): + return diff --git a/crates/lean_compiler/tests/test_data/error_2.py b/crates/lean_compiler/tests/test_data/error_2.py new file mode 100644 index 00000000..bdb76f8e --- /dev/null +++ b/crates/lean_compiler/tests/test_data/error_2.py @@ -0,0 +1,9 @@ +from snark_lib import * + + +# Error: debug_assert(less-than fails (30 < 20 is false)) +def main(): + a = 30 + b = 20 + debug_assert(a < b) + return diff --git a/crates/lean_compiler/tests/test_data/error_20.py b/crates/lean_compiler/tests/test_data/error_20.py new file mode 100644 index 00000000..9f089b33 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/error_20.py @@ -0,0 +1,9 @@ +from snark_lib import * + + +# timing matters +def main(): + v = DynArray([]) + print(v[0]) + v.push(10) + return diff --git a/crates/lean_compiler/tests/test_data/error_21.py b/crates/lean_compiler/tests/test_data/error_21.py new file mode 100644 index 00000000..c3ccc449 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/error_21.py @@ -0,0 +1,8 @@ +from snark_lib import * + + +# Error test: pop on empty vector +def main(): + v = DynArray([]) + v.pop() + return diff --git a/crates/lean_compiler/tests/test_data/error_22.py b/crates/lean_compiler/tests/test_data/error_22.py new file mode 100644 index 00000000..b478bcff --- /dev/null +++ b/crates/lean_compiler/tests/test_data/error_22.py @@ -0,0 +1,8 @@ +from snark_lib import * + + +# Error test: pop on empty nested vector +def main(): + v = DynArray([DynArray([])]) + v[0].pop() + return diff --git a/crates/lean_compiler/tests/test_data/error_23.py b/crates/lean_compiler/tests/test_data/error_23.py new file mode 100644 index 00000000..38945b13 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/error_23.py @@ -0,0 +1,9 @@ +from snark_lib import * + + +# Error test: pop from outer-scope vector in non-unroll loop +def main(): + v = DynArray([1, 2, 3]) + for i in range(0, 2): + v.pop() + return diff --git a/crates/lean_compiler/tests/test_data/error_24.py b/crates/lean_compiler/tests/test_data/error_24.py new file mode 100644 index 00000000..3d9075ed --- /dev/null +++ b/crates/lean_compiler/tests/test_data/error_24.py @@ -0,0 +1,6 @@ +from snark_lib import * +# Test: bad indentation - inconsistent indent after function definition +def main(): + x = 1 + y = 2 + return diff --git a/crates/lean_compiler/tests/test_data/error_25.py b/crates/lean_compiler/tests/test_data/error_25.py new file mode 100644 index 00000000..b578da3b --- /dev/null +++ b/crates/lean_compiler/tests/test_data/error_25.py @@ -0,0 +1,6 @@ +from snark_lib import * +# Test: bad indentation - dedent closes function too early, leaving code at program level +def main(): + x = 1 +y = 2 + return diff --git a/crates/lean_compiler/tests/test_data/error_26.py b/crates/lean_compiler/tests/test_data/error_26.py new file mode 100644 index 00000000..8dc10530 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/error_26.py @@ -0,0 +1,7 @@ +from snark_lib import * +# Test: bad indentation - dedent closes function and loop too early +def main(): + for i in unroll(0, 3): + x = i +y = 0 + return diff --git a/crates/lean_compiler/tests/test_data/error_27.py b/crates/lean_compiler/tests/test_data/error_27.py new file mode 100644 index 00000000..882cf6f2 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/error_27.py @@ -0,0 +1,6 @@ +from snark_lib import * +# Test: bad indentation - dedent closes function too early, leaving code at program level +def main(): + x = 1 + y = 2 + return diff --git a/crates/lean_compiler/tests/test_data/error_28.py b/crates/lean_compiler/tests/test_data/error_28.py new file mode 100644 index 00000000..7466e295 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/error_28.py @@ -0,0 +1,8 @@ +from snark_lib import * + + +# fmt: off +def main(): + for i in unroll(0, 3): + print(i) + return diff --git a/crates/lean_compiler/tests/test_data/error_29.py b/crates/lean_compiler/tests/test_data/error_29.py new file mode 100644 index 00000000..6b3a89bc --- /dev/null +++ b/crates/lean_compiler/tests/test_data/error_29.py @@ -0,0 +1,8 @@ +from snark_lib import * + + +# fmt: off +def main(): + if 8 == 9: + print(1) + return diff --git a/crates/lean_compiler/tests/test_data/error_3.py b/crates/lean_compiler/tests/test_data/error_3.py new file mode 100644 index 00000000..19d91600 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/error_3.py @@ -0,0 +1,9 @@ +from snark_lib import * + + +# Error: debug_assert(not-equal fails (10 != 10 is false)) +def main(): + a = 10 + b = 10 + debug_assert(a != b) + return diff --git a/crates/lean_compiler/tests/test_data/error_30.py b/crates/lean_compiler/tests/test_data/error_30.py new file mode 100644 index 00000000..88744814 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/error_30.py @@ -0,0 +1,11 @@ +from snark_lib import * + + +def main(): + x = get_value() + x = 42 + return + + +def get_value(): + return 42 diff --git a/crates/lean_compiler/tests/test_data/error_31.py b/crates/lean_compiler/tests/test_data/error_31.py new file mode 100644 index 00000000..e3094719 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/error_31.py @@ -0,0 +1,14 @@ +from snark_lib import * + + +# Error: Inconsistent return counts in function +def main(): + x = bad_func(0) + return + + +def bad_func(cond): + if cond == 0: + return 1 + else: + return 1, 2 diff --git a/crates/lean_compiler/tests/test_data/error_32.py b/crates/lean_compiler/tests/test_data/error_32.py new file mode 100644 index 00000000..27855779 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/error_32.py @@ -0,0 +1,6 @@ +from snark_lib import * +from misc.circular_import import * + + +def main(): + return diff --git a/crates/lean_compiler/tests/test_data/error_33.py b/crates/lean_compiler/tests/test_data/error_33.py new file mode 100644 index 00000000..95444332 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/error_33.py @@ -0,0 +1,9 @@ +from snark_lib import * + +def main(): + x, y = + some_function() + return + +def some_function(): + return 1, 2 diff --git a/crates/lean_compiler/tests/test_data/error_4.py b/crates/lean_compiler/tests/test_data/error_4.py new file mode 100644 index 00000000..ac852b9f --- /dev/null +++ b/crates/lean_compiler/tests/test_data/error_4.py @@ -0,0 +1,9 @@ +from snark_lib import * + +# Error: duplicate constant name (A defined twice) +A = 1 +A = 0 + + +def main(): + return diff --git a/crates/lean_compiler/tests/test_data/error_5.py b/crates/lean_compiler/tests/test_data/error_5.py new file mode 100644 index 00000000..baedb43f --- /dev/null +++ b/crates/lean_compiler/tests/test_data/error_5.py @@ -0,0 +1,15 @@ +from snark_lib import * + + +# Error: duplicate function name (a defined twice) +def a(): + return 0 + + +def a(): + return 1 + + +def main(): + a() + return diff --git a/crates/lean_compiler/tests/test_data/error_6.py b/crates/lean_compiler/tests/test_data/error_6.py new file mode 100644 index 00000000..3681fa0a --- /dev/null +++ b/crates/lean_compiler/tests/test_data/error_6.py @@ -0,0 +1,13 @@ +from snark_lib import * + +# Error: imported function name clash (bar defined in bar.py) +from misc.bar import * +from misc.foo import * + + +def bar(): + return + + +def main(): + return diff --git a/crates/lean_compiler/tests/test_data/error_7.py b/crates/lean_compiler/tests/test_data/error_7.py new file mode 100644 index 00000000..94c262a6 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/error_7.py @@ -0,0 +1,12 @@ +from snark_lib import * + + +# Error: inline functions with parameters: Mut are not supported +def main(): + return + + +@inline +def double(x: Mut): + x = x * 2 + return x diff --git a/crates/lean_compiler/tests/test_data/error_8.py b/crates/lean_compiler/tests/test_data/error_8.py new file mode 100644 index 00000000..62cc3e69 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/error_8.py @@ -0,0 +1,6 @@ +from snark_lib import * + + +# Error: no main function +def not_main(): + return diff --git a/crates/lean_compiler/tests/test_data/error_9.py b/crates/lean_compiler/tests/test_data/error_9.py new file mode 100644 index 00000000..305eefab --- /dev/null +++ b/crates/lean_compiler/tests/test_data/error_9.py @@ -0,0 +1,10 @@ +from snark_lib import * +# Error: function f has no return statement +def main(): + a = f() + return + +def f(): + +def g(): + return 0 \ No newline at end of file diff --git a/crates/lean_compiler/tests/test_data/misc/__init__.py b/crates/lean_compiler/tests/test_data/misc/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/crates/lean_compiler/tests/test_data/misc/bar.py b/crates/lean_compiler/tests/test_data/misc/bar.py new file mode 100644 index 00000000..ecaae8a5 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/misc/bar.py @@ -0,0 +1,5 @@ +from snark_lib import * + + +def bar(x): + return x * 2 diff --git a/crates/lean_compiler/tests/test_data/misc/circular_import.py b/crates/lean_compiler/tests/test_data/misc/circular_import.py new file mode 100644 index 00000000..7faea16d --- /dev/null +++ b/crates/lean_compiler/tests/test_data/misc/circular_import.py @@ -0,0 +1,2 @@ +from snark_lib import * +from misc.circular_import import * diff --git a/crates/lean_compiler/tests/test_data/misc/foo.py b/crates/lean_compiler/tests/test_data/misc/foo.py new file mode 100644 index 00000000..3dd0de82 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/misc/foo.py @@ -0,0 +1,4 @@ +from snark_lib import * +from misc.bar import * + +FOO = 3 diff --git a/crates/lean_compiler/tests/test_data/perf_constant_if_baseline.py b/crates/lean_compiler/tests/test_data/perf_constant_if_baseline.py new file mode 100644 index 00000000..4ca0f19b --- /dev/null +++ b/crates/lean_compiler/tests/test_data/perf_constant_if_baseline.py @@ -0,0 +1,37 @@ +from snark_lib import * +# Baseline program - equivalent to perf_constant_if_with_conditions.py +# after all constant if/else conditions are eliminated at compile time +# +# This is what the optimized version should compile down to + + +def main(): + result: Mut = 0 + + # From: if A == 10 { result = result + 1; } + result = result + 1 + + # From: if D == 100 { ... } else { result = result + 2; } + result = result + 2 + + # From: nested if A == 10, B == 20, C == 30 + result = result + 4 + + # From: if A == 10 { if B == 999 { ... } else { result = result + 8; } } + result = result + 8 + + # From: if A != 5 { result = result + 16; } + result = result + 16 + + # From: deeply nested (5 levels all true) + result = result + 32 + + # From: if-else-if chain where A == 10 + result = result + 64 + + # From: if B == 20 { if C == 999 { ... } else { if D == 5 { result = result + 128; } } } + result = result + 128 + + # Final result should be: 1 + 2 + 4 + 8 + 16 + 32 + 64 + 128 = 255 + assert result == 255 + return diff --git a/crates/lean_compiler/tests/test_data/perf_constant_if_with_conditions.py b/crates/lean_compiler/tests/test_data/perf_constant_if_with_conditions.py new file mode 100644 index 00000000..83e93f82 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/perf_constant_if_with_conditions.py @@ -0,0 +1,84 @@ +from snark_lib import * +# Complex test program with constant if/else conditions +# All conditions evaluate to constants at compile time and should be eliminated + +A = 10 +B = 20 +C = 30 +D = 5 +ZERO = 0 +ONE = 1 + +def main(): + result: Mut = 0 + + # Simple constant condition (true) + if A == 10: + result = result + 1 + + # Simple constant condition (false, no else) + if A == 999: + result = result + 1000 + + # Constant condition with else (false branch taken) + if D == 100: + result = result + 2000 + else: + result = result + 2 + + # Nested constant conditions (all true) + if A == 10: + if B == 20: + if C == 30: + result = result + 4 + + # Nested with mixed true/false (outer true, inner false) + if A == 10: + if B == 999: + result = result + 3000 + else: + result = result + 8 + + # Using != operator (true) + if A != 5: + result = result + 16 + + # Using != operator (false) + if A != 10: + result = result + 4000 + + # Deeply nested (5 levels) + if A == 10: + if B == 20: + if C == 30: + if D == 5: + if ONE == 1: + result = result + 32 + + # Chain of if-else-if with constants + if A == 1: + result = result + 5000 + else if A == 2: + result = result + 6000 + else if A == 10: + result = result + 64 + else: + result = result + 7000 + + # Nested false conditions (entire block should be eliminated) + if ZERO == 1: + if A == 10: + if B == 20: + result = result + 8000 + + # Complex: true outer, false inner with else + if B == 20: + if C == 999: + result = result + 9000 + else: + if D == 5: + result = result + 128 + + # Final result should be: 1 + 2 + 4 + 8 + 16 + 32 + 64 + 128 = 255 + assert result == 255 + return \ No newline at end of file diff --git a/crates/lean_compiler/tests/test_data/program_0.py b/crates/lean_compiler/tests/test_data/program_0.py new file mode 100644 index 00000000..ef5cd0b1 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_0.py @@ -0,0 +1,7 @@ +from snark_lib import * +from misc.foo import * +from misc.foo import * + + +def main(): + return diff --git a/crates/lean_compiler/tests/test_data/program_1.py b/crates/lean_compiler/tests/test_data/program_1.py new file mode 100644 index 00000000..60c96b1d --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_1.py @@ -0,0 +1,5 @@ +from snark_lib import * + + +def main(): + return diff --git a/crates/lean_compiler/tests/test_data/program_10.py b/crates/lean_compiler/tests/test_data/program_10.py new file mode 100644 index 00000000..f7dfc1f2 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_10.py @@ -0,0 +1,6 @@ +from snark_lib import * + + +def main(): + print(785 * 78 + 874 - 1) + return diff --git a/crates/lean_compiler/tests/test_data/program_100.py b/crates/lean_compiler/tests/test_data/program_100.py new file mode 100644 index 00000000..5eeac3b6 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_100.py @@ -0,0 +1,23 @@ +from snark_lib import * + + +def main(): + x: Imu + y: Imu + + cond = 1 + if cond == 1: + x = 10 + y = 20 + else: + x = 100 + y = 200 + + x2: Mut = x + y2: Mut = y + x2 = x2 + y2 # 10 + 20 = 30 + y2 = y2 - 5 # 20 - 5 = 15 + + assert x2 == 30 + assert y2 == 15 + return diff --git a/crates/lean_compiler/tests/test_data/program_101.py b/crates/lean_compiler/tests/test_data/program_101.py new file mode 100644 index 00000000..cbd3923d --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_101.py @@ -0,0 +1,18 @@ +from snark_lib import * +def main(): + x: Mut = 5 + + cond = 1 + if cond == 1: + x = x + 10 + else: + assert x == 15 + + y: Mut = 5 + cond2 = 0 + if cond2 == 1: + y = y + 10 + else: + assert y == 5 + + return \ No newline at end of file diff --git a/crates/lean_compiler/tests/test_data/program_102.py b/crates/lean_compiler/tests/test_data/program_102.py new file mode 100644 index 00000000..e9e0e407 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_102.py @@ -0,0 +1,21 @@ +from snark_lib import * + + +def main(): + a: Mut = 10 + b: Mut = 20 + + temp = a + a = b + b = temp + + assert a == 20 + assert b == 10 + + temp2 = a + a = b + b = temp2 + + assert a == 10 + assert b == 20 + return diff --git a/crates/lean_compiler/tests/test_data/program_103.py b/crates/lean_compiler/tests/test_data/program_103.py new file mode 100644 index 00000000..5e9b3549 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_103.py @@ -0,0 +1,21 @@ +from snark_lib import * + + +def main(): + arr = Array(5) + for i in unroll(0, 5): + arr[i] = i * 10 + + idx: Mut = 0 + val1 = arr[idx] # arr[0] = 0 + assert val1 == 0 + + idx = idx + 1 + val2 = arr[idx] # arr[1] = 10 + assert val2 == 10 + + idx = idx + 2 + val3 = arr[idx] # arr[3] = 30 + assert val3 == 30 + + return diff --git a/crates/lean_compiler/tests/test_data/program_104.py b/crates/lean_compiler/tests/test_data/program_104.py new file mode 100644 index 00000000..a54f03b2 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_104.py @@ -0,0 +1,28 @@ +from snark_lib import * + + +def main(): + a: Mut = 1 + b: Mut = 2 + + a, b = pair_incr(a, b) # a=2, b=3 + assert a == 2 + assert b == 3 + + a, b = pair_incr(a, b) # a=3, b=4 + assert a == 3 + assert b == 4 + + a, b = pair_swap(a, b) # a=4, b=3 + assert a == 4 + assert b == 3 + + return + + +def pair_incr(x, y): + return x + 1, y + 1 + + +def pair_swap(x, y): + return y, x diff --git a/crates/lean_compiler/tests/test_data/program_105.py b/crates/lean_compiler/tests/test_data/program_105.py new file mode 100644 index 00000000..0a59a33a --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_105.py @@ -0,0 +1,21 @@ +from snark_lib import * + + +def main(): + x: Mut = 0 + x = x + 1 # 1 + x = x + 1 # 2 + x = x + 1 # 3 + x = x + 1 # 4 + x = x + 1 # 5 + x = x + 1 # 6 + x = x + 1 # 7 + x = x + 1 # 8 + x = x + 1 # 9 + x = x + 1 # 10 + x = x * 2 # 20 + x = x + 1 # 21 + x = x + 1 # 22 + x = x + 1 # 23 + assert x == 23 + return diff --git a/crates/lean_compiler/tests/test_data/program_106.py b/crates/lean_compiler/tests/test_data/program_106.py new file mode 100644 index 00000000..dec8abe1 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_106.py @@ -0,0 +1,20 @@ +from snark_lib import * + + +def main(): + x: Mut = 0 + + x = x + 5 + if x == 5: + x = x + 10 + else: + x = x + 100 + assert x == 15 + + if x == 15: + x = x + 1 + else: + x = x + 1000 + assert x == 16 + + return diff --git a/crates/lean_compiler/tests/test_data/program_107.py b/crates/lean_compiler/tests/test_data/program_107.py new file mode 100644 index 00000000..4f5b6ce2 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_107.py @@ -0,0 +1,22 @@ +from snark_lib import * + + +def main(): + assert test_func(0, 1) == 111 + assert test_func(1, 0) == 200 + return + + +def test_func(sel, cond): + x: Mut = 100 + + match sel: + case 0: + x = x + 10 + case 1: + x = x + 100 + + if cond == 1: + x = x + 1 + + return x diff --git a/crates/lean_compiler/tests/test_data/program_108.py b/crates/lean_compiler/tests/test_data/program_108.py new file mode 100644 index 00000000..8750127b --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_108.py @@ -0,0 +1,17 @@ +from snark_lib import * + + +def main(): + assert test_func(0, 0) == 6 + return + + +def test_func(a, b): + x: Mut = 1 + match a: + case 0: + x = x + 2 + match b: + case 0: + x = x + 3 + return x diff --git a/crates/lean_compiler/tests/test_data/program_109.py b/crates/lean_compiler/tests/test_data/program_109.py new file mode 100644 index 00000000..b04f2608 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_109.py @@ -0,0 +1,21 @@ +from snark_lib import * + + +def main(): + assert test_func(0, 0) == 6 + return + + +def test_func(a, b): + x = 1 + + mut_x_2: Imu + match a: + case 0: + mut_x_1: Imu + mut_x_1 = x + 2 + match b: + case 0: + mut_x_2 = mut_x_1 + 3 + + return mut_x_2 diff --git a/crates/lean_compiler/tests/test_data/program_11.py b/crates/lean_compiler/tests/test_data/program_11.py new file mode 100644 index 00000000..e0b4f20b --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_11.py @@ -0,0 +1,13 @@ +from snark_lib import * + +ARR = [0, 1, 2, 3, 4] + + +def main(): + vector = DynArray([]) + for i in unroll(0, len(ARR)): + k = ARR[i] + vector.push(DynArray([])) + vector[k].push(k) + assert vector[k][0] == k + return diff --git a/crates/lean_compiler/tests/test_data/program_110.py b/crates/lean_compiler/tests/test_data/program_110.py new file mode 100644 index 00000000..71ddfd8b --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_110.py @@ -0,0 +1,100 @@ +from snark_lib import * +def main(): + sum1: Mut = 0 + sum2: Mut = 0 + count: Mut = 0 + + for i in unroll(0, 4): + for j in unroll(0, 3): + count = count + 1 + remainder = j % 2 + if remainder == 0: + sum1 = sum1 + i + j + else: + sum2 = sum2 + i * j + + assert count == 12 + assert sum1 == 20 + assert sum2 == 6 + + state: Mut = 0 + for phase in unroll(0, 5): + match phase: + case 0: + state = state + 1 + case 1: + state = state * 10 + case 2: + if state == 10: + state = state + 5 + else: + state = state + 1000 + case 3: + state = state * 2 + case 4: + state = state + 1 + assert state == 31 + + a: Mut = 5 + b: Mut = 10 + + for round in unroll(0, 3): + x, y = double_and_add(a, b) + a = x + b = y + assert a == 40 + assert b == 25 + + p: Mut = 1 + q: Mut = 2 + r: Mut = 3 + + outer_sel = 1 + if outer_sel == 0: + p = p + 100 + else if outer_sel == 1: + inner_sel = 2 + match inner_sel: + case 0: + q = q + 200 + case 1: + r = r + 300 + case 2: + deep_cond = 0 + if deep_cond == 0: + p = p * 10 + q = q * 10 + r = r * 10 + else: + p = p + 9999 + else: + r = r + 400 + + assert p == 10 + assert q == 20 + assert r == 30 + + result = complex_compute(3, 4, 5) + assert result == 47 + + fwd_val: Imu + cond = 1 + if cond == 0: + fwd_val = 100 + else: + fwd_val = 200 + fwd_val2: Mut = fwd_val + fwd_val2 = fwd_val2 + 50 + fwd_val2 = fwd_val2 * 2 + assert fwd_val2 == 500 + + return + +def double_and_add(x, y): + return x * 2, y + 5 + +def complex_compute(a, b, c): + sum = a + b + product = sum * c + extra = a * b + return product + extra \ No newline at end of file diff --git a/crates/lean_compiler/tests/test_data/program_111.py b/crates/lean_compiler/tests/test_data/program_111.py new file mode 100644 index 00000000..49dc08dd --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_111.py @@ -0,0 +1,107 @@ +from snark_lib import * +def main(): + fib_result = fib_iterative(10) + assert fib_result == 55 + + accum: Mut = 0 + for i in unroll(0, 5): + accum = step_accumulate(accum, i) + assert accum == 25 + + a, b, c = chain_compute(5, 3) + assert a == 11 + assert b == 3 + assert c == 39 + + result = nested_mut_params(100) + assert result == 106 + + state: Mut = 0 + for phase in unroll(0, 5): + state = state_machine_step(state, phase) + assert state == 151 + + x: Mut = 10 + y: Mut = 20 + + cond1 = 1 + if cond1 == 1: + x = x + y + y = y - 5 + else: + x = x * 2 + + cond2 = 0 + if cond2 == 1: + x = x * 100 + else: + y = y + x + + assert x == 30 + assert y == 45 + + sum_outer: Mut = 0 + sum_inner: Mut = 0 + for i in unroll(0, 3): + sum_outer = sum_outer + i + for j in unroll(0, 4): + sum_inner = sum_inner + j + assert sum_outer == 3 + assert sum_inner == 18 + + result8 = complex_chain(2, 3, 5) + assert result8 == 31 + + return + +def fib_iterative(n: Const): + prev: Mut = 0 + curr: Mut = 1 + for i in unroll(0, n): + if i == 0: + else: + next = prev + curr + prev = curr + curr = next + return curr + +def step_accumulate(acc, i): + return acc + i * 2 + 1 + +def step_compute(x, y): + sum = x + y + product = x * y + return sum, y, product + +def chain_compute(x, y): + a1, b1, c1 = step_compute(x, y) + a2, b2, c2 = step_compute(a1, b1) + return a2, b2, c1 + c2 + +def nested_mut_params(base: Mut): + for i in unroll(0, 3): + base = base + i * 2 + return base + +def state_machine_step(current_state, phase): + result: Imu + if phase == 0: + if current_state == 0: + result = 1 + else: + result = current_state + 1000 + else if phase == 1: + result = current_state + 11 + else if phase == 2: + result = current_state + 3 + else if phase == 3: + result = current_state * 10 + else: + result = current_state + 1 + return result + +def complex_chain(a, b, c): + sum = a + b + product1 = sum * c + product2 = a * b + return product1 + product2 \ No newline at end of file diff --git a/crates/lean_compiler/tests/test_data/program_112.py b/crates/lean_compiler/tests/test_data/program_112.py new file mode 100644 index 00000000..93ed1207 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_112.py @@ -0,0 +1,108 @@ +from snark_lib import * +def main(): + result1: Imu + outer_sel = 1 + match outer_sel: + case 0: + result1 = 100 + case 1: + inner_sel = 2 + match inner_sel: + case 0: + result1 = 200 + case 1: + result1 = 300 + case 2: + result1 = 456 + assert result1 == 456 + + counter: Imu + flag: Imu + + phase = 1 + if phase == 0: + counter = 0 + flag = 100 + else if phase == 1: + counter = 10 + flag = 200 + else: + counter = 100 + flag = 300 + + counter2: Mut = counter + flag2: Mut = flag + counter2 = counter2 + 5 + flag2 = flag2 * 2 + + assert counter2 == 15 + assert flag2 == 400 + + x: Imu + y: Imu + + init_sel = 0 + if init_sel == 0: + x = 5 + y = 10 + else: + x = 50 + y = 100 + + x2: Mut = x + y2: Mut = y + x2 = x2 * 2 + y2 = y2 + x2 + x2 = x2 + 1 + x2 = x2 * y2 + + assert x2 == 220 + assert y2 == 20 + + outcome: Imu + selector = 4 + match selector: + case 0: + outcome = compute_outcome(0, 0) + case 1: + outcome = compute_outcome(1, 1) + case 2: + outcome = compute_outcome(2, 4) + case 3: + outcome = compute_outcome(3, 9) + case 4: + outcome = compute_outcome(4, 16) + case 5: + outcome = compute_outcome(5, 25) + assert outcome == 84 + + p: Imu + q: Imu + r: Imu + + s1 = 1 + if s1 == 1: + p = 1 + else: + p = 10 + + s2 = 0 + if s2 == 1: + q = 100 + else: + q = p + 10 + + s3 = 1 + if s3 == 1: + r = p + q + 100 + else: + r = 999 + + assert p == 1 + assert q == 11 + assert r == 112 + + return + +def compute_outcome(a, b): + return a * b + a + b \ No newline at end of file diff --git a/crates/lean_compiler/tests/test_data/program_114.py b/crates/lean_compiler/tests/test_data/program_114.py new file mode 100644 index 00000000..fbad1ccb --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_114.py @@ -0,0 +1,138 @@ +from snark_lib import * +def main(): + x1, y1, z1 = initial_values() + assert x1 == 10 + assert y1 == 20 + assert z1 == 30 + + x2, y2, z2 = rotate_triple(x1, y1, z1) + assert x2 == 20 + assert y2 == 30 + assert z2 == 10 + + x3, y3, z3 = scale_triple(x2, y2, z2, 2) + assert x3 == 40 + assert y3 == 60 + assert z3 == 20 + + a, b = swap_pair(100, 200) + assert a == 200 + assert b == 100 + + arr = Array(20) + for i in unroll(0, 10): + arr[i] = i * 5 + + sum = sum_array_func(arr, 5) + assert sum == 50 + + result4 = complex_nested_compute(2, 1, 3) + assert result4 == 280 + + fwd_x: Imu + fwd_y: Imu + + mode = 2 + if mode == 0: + fwd_x = 1 + fwd_y = 1 + else if mode == 1: + fwd_x = 10 + fwd_y = 10 + else: + fwd_x = 100 + fwd_y = 200 + + fwd_x2: Mut = fwd_x + fwd_y2: Mut = fwd_y + fwd_x2 = fwd_x2 + fwd_y2 + fwd_y2 = fwd_x2 - 100 + + assert fwd_x2 == 300 + assert fwd_y2 == 200 + + result6 = chain_of_funcs(5) + assert result6 == 115 + + p1, q1 = first_pair(3, 4) + p2, q2 = second_pair(p1, q1) + p3, q3 = third_pair(p2, q2) + + assert p3 == 103 + assert q3 == 1596 + + return + +def initial_values(): + return 10, 20, 30 + +def rotate_triple(a, b, c): + return b, c, a + +def scale_triple(a, b, c, factor): + return a * factor, b * factor, c * factor + +def swap_pair(a, b): + return b, a + +def sum_array_func(arr, n: Const): + total: Mut = 0 + for i in unroll(0, n): + total = total + arr[i] + return total + +def complex_nested_compute(outer, inner, depth): + result: Imu + + if outer == 0: + result = 100 + else if outer == 1: + if inner == 0: + result = 110 + else: + result = 120 + else: + if inner == 0: + if depth == 0: + result = 200 + else if depth == 1: + result = 210 + else if depth == 2: + result = 220 + else: + result = 230 + else: + if depth == 0: + result = 250 + else if depth == 1: + result = 260 + else if depth == 2: + result = 270 + else: + result = 280 + + return result + +def chain_of_funcs(x): + y = step_one(x) + z = step_two(y) + w = step_three(z) + return w + +def step_one(n): + return n + 10 + +def step_two(n): + return n * 2 + +def step_three(n): + return n + 85 + +def first_pair(a, b): + return a + b, a * b + +def second_pair(a, b): + return a + b, a * b + +def third_pair(a, b): + return a + b, a * b \ No newline at end of file diff --git a/crates/lean_compiler/tests/test_data/program_117.py b/crates/lean_compiler/tests/test_data/program_117.py new file mode 100644 index 00000000..c379a5a1 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_117.py @@ -0,0 +1,79 @@ +from snark_lib import * + + +def main(): + result1 = multi_ops() + assert result1 == 42 + + result2 = double_use_expr() + assert result2 == 18 + + a, b, c = cascade_assign() + assert a == 3 + assert b == 6 + assert c == 18 + + result4 = mixed_mut_immut() + assert result4 == 35 + + result5 = expr_tree() + assert result5 == 100 + + return + + +def multi_ops(): + x: Mut = 2 + x = x + 3 # 5 + x = x * 4 # 20 + x = x + 2 # 22 + x = x * 2 # 44 + x = x - 2 # 42 + return x + + +def double_use_expr(): + x: Mut = 3 + y = x + x * 2 # 3 + 6 = 9 + x = y # x = 9 + z = x + x # 18 + return z + + +def cascade_assign(): + a: Mut = 1 + b: Mut = 2 + c: Mut = 3 + + a = a + 2 # a = 3 + b = a * 2 # b = 6 + c = b * 3 # c = 18 + + return a, b, c + + +def mixed_mut_immut(): + immut_x = 10 + y: Mut = 5 + + y = y + immut_x # 15 + y = y + immut_x # 25 + y = y + immut_x # 35 + + return y + + +def expr_tree(): + a: Mut = 2 + b: Mut = 3 + c: Mut = 5 + + a = a * b + c # 2*3 + 5 = 11 + b = a + c # 11 + 5 = 16 + c = a * b - (a + b) # 11*16 - 27 = 176 - 27 = 149 + + a = 10 + b = a + c = a * b # 100 + + return c diff --git a/crates/lean_compiler/tests/test_data/program_118.py b/crates/lean_compiler/tests/test_data/program_118.py new file mode 100644 index 00000000..d2b19edf --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_118.py @@ -0,0 +1,83 @@ +from snark_lib import * +def main(): + result1 = asymmetric_depth(0, 0, 0) + assert result1 == 1111 + + result2 = asymmetric_depth(0, 1, 1) + assert result2 == 1122 + + result3 = asymmetric_depth(1, 0, 0) # Shallow branch + assert result3 == 2000 + + result4 = unbalanced_modifications(0) + assert result4 == 25 + + result5 = unbalanced_modifications(1) + assert result5 == 110 + + result6 = empty_else(0) + assert result6 == 5 + + result7 = empty_else(1) + assert result7 == 15 + + result8 = long_else_if_chain(0) + assert result8 == 1 + + result9 = long_else_if_chain(3) + assert result9 == 4 + + result10 = long_else_if_chain(5) + assert result10 == 0 + + return + +def asymmetric_depth(outer, mid, inner): + x: Mut = 1000 + if outer == 0: + x = x + 100 + if mid == 0: + x = x + 10 + if inner == 0: + x = x + 1 + else: + x = x + 2 + else: + x = x + 20 + if inner == 0: + x = x + 1 + else: + x = x + 2 + else: + x = 2000 + return x + +def unbalanced_modifications(cond): + x: Mut = 5 + if cond == 0: + x = x + 5 # 10 + x = x * 2 # 20 + x = x + 5 # 25 + else: + x = 110 + return x + +def empty_else(cond): + x: Mut = 5 + if cond == 1: + x = x + 10 + return x + +def long_else_if_chain(n): + result: Mut = 0 + if n == 0: + result = 1 + else if n == 1: + result = 2 + else if n == 2: + result = 3 + else if n == 3: + result = 4 + else if n == 4: + result = 5 + return result \ No newline at end of file diff --git a/crates/lean_compiler/tests/test_data/program_119.py b/crates/lean_compiler/tests/test_data/program_119.py new file mode 100644 index 00000000..e9c29647 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_119.py @@ -0,0 +1,125 @@ +from snark_lib import * + + +def main(): + assert five_way(0) == 1 + assert five_way(1) == 2 + assert five_way(2) == 4 + assert five_way(3) == 8 + assert five_way(4) == 16 + + a0, b0, c0 = five_way_multi(0) + assert a0 == 1 + assert b0 == 0 + assert c0 == 0 + + a2, b2, c2 = five_way_multi(2) + assert a2 == 0 + assert b2 == 0 + assert c2 == 4 + + a4, b4, c4 = five_way_multi(4) + assert a4 == 0 + assert b4 == 16 + assert c4 == 0 + + result = five_way_postmerge(2) + assert result == 24 # 4 * 6 + + result2 = nested_five_way(1, 2) + assert result2 == 1020 + + return + + +def five_way(n): + x: Mut = 0 + match n: + case 0: + x = 1 + case 1: + x = 2 + case 2: + x = 4 + case 3: + x = 8 + case 4: + x = 16 + return x + + +def five_way_multi(n): + a: Mut = 0 + b: Mut = 0 + c: Mut = 0 + match n: + case 0: + a = 1 + case 1: + b = 2 + case 2: + c = 4 + case 3: + a = 8 + case 4: + b = 16 + return a, b, c + + +def five_way_postmerge(n): + x: Mut = 0 + match n: + case 0: + x = 1 + case 1: + x = 2 + case 2: + x = 4 + case 3: + x = 8 + case 4: + x = 16 + result = x * 6 + return result + + +def nested_five_way(outer, inner): + x: Mut = 1000 + match outer: + case 0: + match inner: + case 0: + x = x + 1 + case 1: + x = x + 2 + case 2: + x = x + 3 + case 3: + x = x + 4 + case 4: + x = x + 5 + case 1: + match inner: + case 0: + x = x + 10 + case 1: + x = x + 20 + case 2: + x = x + 20 + case 3: + x = x + 30 + case 4: + x = x + 40 + case 2: + match inner: + case 0: + x = x + 100 + case 1: + x = x + 200 + case 2: + x = x + 300 + case 3: + x = x + 400 + case 4: + x = x + 500 + return x diff --git a/crates/lean_compiler/tests/test_data/program_12.py b/crates/lean_compiler/tests/test_data/program_12.py new file mode 100644 index 00000000..7d6696e9 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_12.py @@ -0,0 +1,27 @@ +from snark_lib import * +N = 10 + +def main(): + arr = Array(N) + fill_array(arr) + print_array(arr) + return + +def fill_array(arr): + for i in range(0, N): + if i == 0: + arr[i] = 10 + else if i == 1: + arr[i] = 20 + else if i == 2: + arr[i] = 30 + else: + i_plus_one = i + 1 + arr[i] = i_plus_one + return + +def print_array(arr): + for i in range(0, N): + arr_i = arr[i] + print(arr_i) + return \ No newline at end of file diff --git a/crates/lean_compiler/tests/test_data/program_120.py b/crates/lean_compiler/tests/test_data/program_120.py new file mode 100644 index 00000000..eb89b7d5 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_120.py @@ -0,0 +1,49 @@ +from snark_lib import * + + +def main(): + final_state = state_machine(0) + assert final_state == 3 + + counter = counting_machine() + assert counter == 10 + + fib = fib_machine() + assert fib == 13 + + result = conditional_loop_machine() + assert result == 45 + + return + + +def state_machine(initial): + state: Mut = initial + for i in unroll(0, 3): + state = state + 1 + return state + + +def counting_machine(): + counter: Mut = 0 + for i in unroll(0, 5): + counter = counter + 1 + counter = counter + 1 + return counter + + +def fib_machine(): + a: Mut = 1 + b: Mut = 1 + for i in unroll(0, 5): + temp = a + b + a = b + b = temp + return b + + +def conditional_loop_machine(): + sum: Mut = 0 + for i in unroll(0, 10): + sum = sum + i + return sum diff --git a/crates/lean_compiler/tests/test_data/program_121.py b/crates/lean_compiler/tests/test_data/program_121.py new file mode 100644 index 00000000..d9080a7d --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_121.py @@ -0,0 +1,79 @@ +from snark_lib import * + + +def main(): + r1 = reuse_after_conditional(0) + assert r1 == 15 + + r2 = reuse_after_conditional(1) + assert r2 == 35 + + r3 = sequential_conditionals(0, 0) + assert r3 == 5 + + r4 = sequential_conditionals(1, 1) + assert r4 == 36 + + r5 = nested_read_write(0) + assert r5 == 20 + + r6 = nested_read_write(1) + assert r6 == 40 + + r7 = diamond_continue(0, 0) + assert r7 == 222 + + r8 = diamond_continue(1, 1) + assert r8 == 484 + + return + + +def reuse_after_conditional(cond): + x: Mut = 10 + if cond == 1: + x = x + 20 # x = 30 + result = x + 5 # 15 or 35 + return result + + +def sequential_conditionals(c1, c2): + x: Mut = 5 + + if c1 == 1: + x = x + 10 # 15 + + if c2 == 1: + x = x * 2 # 10 or 30 + x = x + 6 # 16 or 36 + + return x + + +def nested_read_write(cond): + x: Mut = 10 + if cond == 0: + y = x * 2 # Read x (10), compute 20 + x = y # Write x = 20 + else: + y = x * 3 # Read x (10), compute 30 + z = y + x # Read both (30 + 10 = 40) + x = z # Write x = 40 + return x + + +def diamond_continue(c1, c2): + x: Mut = 100 + if c1 == 0: + x = x + 10 # 110 + else: + x = x + 20 # 120 + x = x * 2 # 220 or 240 + + if c2 == 0: + x = x + 2 # 222 or 242 + else: + x = x * 2 # 440 or 480 + x = x + 4 # 444 or 484 + + return x diff --git a/crates/lean_compiler/tests/test_data/program_122.py b/crates/lean_compiler/tests/test_data/program_122.py new file mode 100644 index 00000000..d04c9085 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_122.py @@ -0,0 +1,85 @@ +from snark_lib import * +def main(): + r1 = partial_match_update(0) + assert r1 == 100 + + r2 = partial_match_update(1) + assert r2 == 10 + + r3 = partial_match_update(2) + assert r3 == 200 + + a1, b1, c1 = scattered_updates(0) + assert a1 == 1 + assert b1 == 0 + assert c1 == 0 + + a2, b2, c2 = scattered_updates(1) + assert a2 == 0 + assert b2 == 2 + assert c2 == 0 + + a3, b3, c3 = scattered_updates(2) + assert a3 == 0 + assert b3 == 0 + assert c3 == 3 + + r4 = sandwich_phi(0) + assert r4 == 60 + + r5 = sandwich_phi(1) + assert r5 == 80 + + r6 = loop_partial_match() + assert r6 == 10 # 1+2+3+4 + + return + +def partial_match_update(selector): + x: Mut = 10 + match selector: + case 0: + x = 100 # Modified + case 1: + case 2: + x = 200 # Modified + return x + +def scattered_updates(selector): + a: Mut = 0 + b: Mut = 0 + c: Mut = 0 + match selector: + case 0: + a = 1 + case 1: + b = 2 + case 2: + c = 3 + return a, b, c + +def sandwich_phi(cond): + x: Mut = 10 + x = x * 2 # Pre-branch: x = 20 + + if cond == 0: + x = x + 10 # 30 + else: + x = x + 20 # 40 + + x = x * 2 # Post-branch: 60 or 80 + return x + +def loop_partial_match(): + sum: Mut = 0 + for i in unroll(0, 4): + match i: + case 0: + sum = sum + 1 + case 1: + sum = sum + 2 + case 2: + sum = sum + 3 + case 3: + sum = sum + 4 + return sum \ No newline at end of file diff --git a/crates/lean_compiler/tests/test_data/program_123.py b/crates/lean_compiler/tests/test_data/program_123.py new file mode 100644 index 00000000..ff889e07 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_123.py @@ -0,0 +1,84 @@ +from snark_lib import * + + +def main(): + r1 = many_assigns() + assert r1 == 256 + + r2 = self_referential() + assert r2 == 120 + + a, b, c, d = quad_assign() + assert a == 10 + assert b == 20 + assert c == 30 + assert d == 40 + + r3 = interleaved_ops() + assert r3 == 100 + + r4 = dependency_chain() + assert r4 == 55 + + return + + +def many_assigns(): + x: Mut = 1 + x = 2 + x = 4 + x = 8 + x = 16 + x = 32 + x = 64 + x = 128 + x = 256 + return x + + +def self_referential(): + x: Mut = 1 + x = x * 2 # 2 + x = x * 3 # 6 + x = x * 4 # 24 + x = x * 5 # 120 + return x + + +def quad_assign(): + a: Mut = 0 + b: Mut = 0 + c: Mut = 0 + d: Mut = 0 + + a = 10 + b = 20 + c = 30 + d = 40 + + return a, b, c, d + + +def interleaved_ops(): + x: Mut = 5 + y: Mut = 10 + + temp = x + y # Read both: 15 + x = temp # x = 15 + temp2 = x * 2 # Read x: 30 + y = temp2 # y = 30 + x = y + x # x = 30 + 15 = 45 + y = x + y # y = 45 + 30 = 75 + x = y + 25 # x = 100 + + return x + + +def dependency_chain(): + x: Mut = 1 + x = x + 2 # 3 + x = x + 4 # 7 + x = x + 8 # 15 + x = x + 16 # 31 + x = x + 24 # 55 + return x diff --git a/crates/lean_compiler/tests/test_data/program_124.py b/crates/lean_compiler/tests/test_data/program_124.py new file mode 100644 index 00000000..d3be1241 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_124.py @@ -0,0 +1,92 @@ +from snark_lib import * + + +def main(): + r1 = five_deep_if(1, 1, 1, 1, 1) + assert r1 == 11111 + + r2 = five_deep_if(1, 1, 0, 0, 0) # a=1,b=1: 10000+1000 = 11000 + assert r2 == 11000 + + r3 = mixed_deep(0, 0, 0) + assert r3 == 111 + + r4 = mixed_deep(1, 1, 1) + assert r4 == 222 + + a, b = dual_mut_deep(0, 0) + assert a == 110 + assert b == 1 + + a2, b2 = dual_mut_deep(1, 1) + assert a2 == 1 + assert b2 == 220 + + return + + +def five_deep_if(a, b, c, d, e): + x: Mut = 0 + if a == 1: + x = x + 10000 + if b == 1: + x = x + 1000 + if c == 1: + x = x + 100 + if d == 1: + x = x + 10 + if e == 1: + x = x + 1 + return x + + +def mixed_deep(outer_match, inner_if, innermost_match): + x: Mut = 0 + match outer_match: + case 0: + x = x + 100 + if inner_if == 0: + match innermost_match: + case 0: + x = x + 11 + case 1: + x = x + 12 + else: + match innermost_match: + case 0: + x = x + 21 + case 1: + x = x + 22 + case 1: + x = x + 200 + if inner_if == 0: + match innermost_match: + case 0: + x = x + 11 + case 1: + x = x + 12 + else: + match innermost_match: + case 0: + x = x + 21 + case 1: + x = x + 22 + return x + + +def dual_mut_deep(c1, c2): + a: Mut = 1 + b: Mut = 1 + if c1 == 0: + a = a + 100 + if c2 == 0: + a = a + 9 + else: + a = a + 19 + else: + b = b + 200 + if c2 == 0: + b = b + 9 + else: + b = b + 19 + return a, b diff --git a/crates/lean_compiler/tests/test_data/program_125.py b/crates/lean_compiler/tests/test_data/program_125.py new file mode 100644 index 00000000..b48caecd --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_125.py @@ -0,0 +1,66 @@ +from snark_lib import * + + +def main(): + r1 = cross_level_mut() + assert r1 == 130 + + a, b = loop_match_multi_mut() + assert a == 9 # 1+3+5 = 9 + assert b == 6 # 2+4 = 6 + + r2 = triple_same_mut() + assert r2 == 492 + + r3 = outer_used_after_inner() + assert r3 == 45 + + return + + +def cross_level_mut(): + total: Mut = 0 + for i in unroll(0, 5): + local: Mut = i * 10 + for j in unroll(0, 3): + local = local + j + total = total + 1 + total = total + local + return total + + +def loop_match_multi_mut(): + sum_a: Mut = 0 + sum_b: Mut = 0 + for i in unroll(0, 5): + match i: + case 0: + sum_a = sum_a + 1 + case 1: + sum_b = sum_b + 2 + case 2: + sum_a = sum_a + 3 + case 3: + sum_b = sum_b + 4 + case 4: + sum_a = sum_a + 5 + return sum_a, sum_b + + +def triple_same_mut(): + x: Mut = 0 + for i in unroll(0, 3): + x = x + i * 100 + for j in unroll(0, 4): + x = x + j * 10 + for k in unroll(0, 2): + x = x + k + return x + + +def outer_used_after_inner(): + outer: Mut = 0 + for i in unroll(0, 5): + for j in unroll(0, 3): + outer = outer + i + j + return outer diff --git a/crates/lean_compiler/tests/test_data/program_126.py b/crates/lean_compiler/tests/test_data/program_126.py new file mode 100644 index 00000000..815fb1e7 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_126.py @@ -0,0 +1,51 @@ +from snark_lib import * + + +def main(): + result1 = match_on_loop_var() + assert result1 == 100 + + result2 = match_computed_selector() + assert result2 == 10 + + result3 = match_complex_selector() + assert result3 == 222 + + return + + +def match_on_loop_var(): + acc: Mut = 0 + for i in unroll(0, 2): + match i: + case 0: + acc = acc + i # acc = 0 + case 1: + acc = acc + 100 # acc = 100 + return acc + + +def match_computed_selector(): + acc: Mut = 0 + for i in unroll(0, 4): + selector = i % 2 # Use actual modulo + match selector: + case 0: + acc = acc + i + case 1: + acc = acc + i * 2 + return acc + + +def match_complex_selector(): + sum: Mut = 0 + for i in unroll(0, 6): + selector = i % 3 + match selector: + case 0: + sum = sum + 1 # i=0,3: sum += 2 + case 1: + sum = sum + 10 # i=1,4: sum += 20 + case 2: + sum = sum + 100 # i=2,5: sum += 200 + return sum diff --git a/crates/lean_compiler/tests/test_data/program_127.py b/crates/lean_compiler/tests/test_data/program_127.py new file mode 100644 index 00000000..c960998c --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_127.py @@ -0,0 +1,74 @@ +from snark_lib import * +def main(): + assert diamond_nested(0, 0) == 10 + assert diamond_nested(0, 1) == 20 + assert diamond_nested(1, 0) == 30 + assert diamond_nested(1, 1) == 40 + + assert sequential_diamonds(0, 0) == 103 + assert sequential_diamonds(0, 1) == 105 + assert sequential_diamonds(1, 0) == 203 + assert sequential_diamonds(1, 1) == 205 + + a, b, c = multi_var_diamond(0, 0) + assert a == 1 + assert b == 10 + assert c == 100 + + a2, b2, c2 = multi_var_diamond(1, 1) + assert a2 == 4 + assert b2 == 40 + assert c2 == 400 + + return + +def diamond_nested(outer_cond, inner_cond): + result: Mut = 0 + if outer_cond == 0: + if inner_cond == 0: + result = 10 + else: + result = 20 + else: + if inner_cond == 0: + result = 30 + else: + result = 40 + return result + +def sequential_diamonds(cond1, cond2): + x: Mut = 100 + + if cond1 == 0: + x = x + 1 + else: + x = x + 101 + + if cond2 == 0: + x = x + 2 + else: + x = x + 4 + + return x + +def multi_var_diamond(c1, c2): + a: Mut = 0 + b: Mut = 0 + c: Mut = 0 + + if c1 == 0: + a = 1 + b = 10 + c = 100 + else: + a = 2 + b = 20 + c = 200 + + if c2 == 0: + else: + a = a * 2 + b = b * 2 + c = c * 2 + + return a, b, c \ No newline at end of file diff --git a/crates/lean_compiler/tests/test_data/program_128.py b/crates/lean_compiler/tests/test_data/program_128.py new file mode 100644 index 00000000..409cb944 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_128.py @@ -0,0 +1,34 @@ +from snark_lib import * + + +def main(): + a: Mut = 5 + b: Mut = 10 + + a += 3 + assert a == 8 + + a -= 2 + assert a == 6 + + a *= 4 + assert a == 24 + + a /= 3 + assert a == 8 + + b += a + assert b == 18 + + b -= 3 + b *= 2 + assert b == 30 + + c: Mut = 100 + c /= 4 + c -= 5 + c += 10 + c *= 3 + assert c == 90 + + return diff --git a/crates/lean_compiler/tests/test_data/program_129.py b/crates/lean_compiler/tests/test_data/program_129.py new file mode 100644 index 00000000..4dfbbd8c --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_129.py @@ -0,0 +1,12 @@ +from snark_lib import * + + +def main(): + x: Mut = Array(1) + x[0] = 1 + for i in unroll(0, 5): + y = Array(1) + y[0] = x[0] + 1 + x = y + assert x[0] == 6 + return diff --git a/crates/lean_compiler/tests/test_data/program_13.py b/crates/lean_compiler/tests/test_data/program_13.py new file mode 100644 index 00000000..ed37bb8d --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_13.py @@ -0,0 +1,19 @@ +from snark_lib import * + + +def main(): + for i in range(0, 10): + for j in range(i, 10): + for k in range(j, 10): + sum, prod = compute_sum_and_product(i, j, k) + if sum == 10: + print(i, j, k, prod) + return + + +def compute_sum_and_product(a, b, c): + s1 = a + b + sum = s1 + c + p1 = a * b + product = p1 * c + return sum, product diff --git a/crates/lean_compiler/tests/test_data/program_130.py b/crates/lean_compiler/tests/test_data/program_130.py new file mode 100644 index 00000000..5749d8f1 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_130.py @@ -0,0 +1,20 @@ +from snark_lib import * +# Test: Mutable variables in non-unrolled loops +# This tests the automatic buffer transformation for mutable variables + + +def main(): + x: Mut = 0 + y: Mut = 3 + x += y + y += x + assert x == 3 + assert y == 6 + for i in range(4, 6): + x += i + x += y + y = i + y += x + assert x == 35 + assert y == 40 + return diff --git a/crates/lean_compiler/tests/test_data/program_131.py b/crates/lean_compiler/tests/test_data/program_131.py new file mode 100644 index 00000000..c36f8fce --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_131.py @@ -0,0 +1,11 @@ +from snark_lib import * +# Test: Simple mutable variable in non-unrolled loop +# Sum of 1 to 10 + + +def main(): + s: Mut = 0 + for i in range(1, 11): + s += i + assert s == 55 + return diff --git a/crates/lean_compiler/tests/test_data/program_132.py b/crates/lean_compiler/tests/test_data/program_132.py new file mode 100644 index 00000000..611ca57d --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_132.py @@ -0,0 +1,11 @@ +from snark_lib import * +# Test: Mutable variables with different operations in non-unrolled loop + + +def main(): + product: Mut = 1 + for i in range(1, 6): + product *= i + # 1 * 2 * 3 * 4 * 5 = 120 + assert product == 120 + return diff --git a/crates/lean_compiler/tests/test_data/program_133.py b/crates/lean_compiler/tests/test_data/program_133.py new file mode 100644 index 00000000..e25c7b34 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_133.py @@ -0,0 +1,16 @@ +from snark_lib import * +# Test: Nested non-unrolled loops with mutable variables +# Computes sum of i*j for i in 0..3, j in 0..4 + + +def main(): + total: Mut = 0 + for i in range(0, 3): + for j in range(0, 4): + total += i * j + # i=0: 0*0 + 0*1 + 0*2 + 0*3 = 0 + # i=1: 1*0 + 1*1 + 1*2 + 1*3 = 6 + # i=2: 2*0 + 2*1 + 2*2 + 2*3 = 12 + # total = 0 + 6 + 12 = 18 + assert total == 18 + return diff --git a/crates/lean_compiler/tests/test_data/program_134.py b/crates/lean_compiler/tests/test_data/program_134.py new file mode 100644 index 00000000..1d2a4363 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_134.py @@ -0,0 +1,23 @@ +from snark_lib import * +# Test: Conditionals inside non-unrolled loop with mutable variables +# Tests if/else branches that modify mutable variables differently + + +def main(): + a: Mut = 0 + b: Mut = 100 + for i in range(0, 5): + if i == 2: + a += 10 + b -= 50 + else: + a += 1 + b -= 1 + # i=0: a=1, b=99 + # i=1: a=2, b=98 + # i=2: a=12, b=48 + # i=3: a=13, b=47 + # i=4: a=14, b=46 + assert a == 14 + assert b == 46 + return diff --git a/crates/lean_compiler/tests/test_data/program_135.py b/crates/lean_compiler/tests/test_data/program_135.py new file mode 100644 index 00000000..1952905a --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_135.py @@ -0,0 +1,22 @@ +from snark_lib import * +# Test: Match statement inside non-unrolled loop with mutable variables + + +def main(): + score: Mut = 0 + for i in range(0, 4): + match i: + case 0: + score += 100 + case 1: + score += 50 + case 2: + score += 25 + case 3: + score += 10 + # i=0: score=100 + # i=1: score=150 + # i=2: score=175 + # i=3: score=185 + assert score == 185 + return diff --git a/crates/lean_compiler/tests/test_data/program_136.py b/crates/lean_compiler/tests/test_data/program_136.py new file mode 100644 index 00000000..07a3b515 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_136.py @@ -0,0 +1,20 @@ +from snark_lib import * +# Test: Complex nested loops with multiple mutable variables +# Outer loop updates one set of vars, inner loop updates another, +# and they interact with each other + + +def main(): + outer_sum: Mut = 0 + inner_count: Mut = 0 + for i in range(1, 4): + outer_sum += i * 10 + for j in range(0, i): + inner_count += 1 + outer_sum += j + # i=1: outer_sum=10, inner: j=0: inner_count=1, outer_sum=10 + # i=2: outer_sum=30, inner: j=0: inner_count=2, outer_sum=30; j=1: inner_count=3, outer_sum=31 + # i=3: outer_sum=61, inner: j=0: inner_count=4, outer_sum=61; j=1: inner_count=5, outer_sum=62; j=2: inner_count=6, outer_sum=64 + assert outer_sum == 64 + assert inner_count == 6 + return diff --git a/crates/lean_compiler/tests/test_data/program_137.py b/crates/lean_compiler/tests/test_data/program_137.py new file mode 100644 index 00000000..a2551f38 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_137.py @@ -0,0 +1,21 @@ +from snark_lib import * +# Test: Deeply nested conditionals inside non-unrolled loop + +def main(): + result: Mut = 0 + for i in range(0, 6): + if i == 0: + result += 1 + else if i == 1: + result += 2 + else if i == 2: + result += 4 + else if i == 3: + result += 8 + else if i == 4: + result += 16 + else: + result += 32 + # Powers of 2: 1 + 2 + 4 + 8 + 16 + 32 = 63 + assert result == 63 + return \ No newline at end of file diff --git a/crates/lean_compiler/tests/test_data/program_138.py b/crates/lean_compiler/tests/test_data/program_138.py new file mode 100644 index 00000000..f2977c49 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_138.py @@ -0,0 +1,17 @@ +from snark_lib import * +# Test: Mix of unrolled outer loop and non-unrolled inner loop with mutable vars + + +def main(): + total: Mut = 0 + for i in unroll(0, 3): + # Inner loop is non-unrolled, uses mutable variable + inner_sum: Mut = 0 + for j in range(0, 4): + inner_sum += j + i + total += inner_sum + # i=0: inner_sum = 0+1+2+3 = 6, total = 6 + # i=1: inner_sum = 1+2+3+4 = 10, total = 16 + # i=2: inner_sum = 2+3+4+5 = 14, total = 30 + assert total == 30 + return diff --git a/crates/lean_compiler/tests/test_data/program_139.py b/crates/lean_compiler/tests/test_data/program_139.py new file mode 100644 index 00000000..c49f1373 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_139.py @@ -0,0 +1,25 @@ +from snark_lib import * +# Test: Mutable variable with array operations inside non-unrolled loop + + +def main(): + arr = Array(5) + arr[0] = 10 + arr[1] = 20 + arr[2] = 30 + arr[3] = 40 + arr[4] = 50 + + sum: Mut = 0 + prev: Mut = 0 + for i in range(0, 5): + val = arr[i] + sum += val + # Track running difference + diff = val - prev + prev = val + # sum = 10 + 20 + 30 + 40 + 50 = 150 + # prev = 50 (last value) + assert sum == 150 + assert prev == 50 + return diff --git a/crates/lean_compiler/tests/test_data/program_14.py b/crates/lean_compiler/tests/test_data/program_14.py new file mode 100644 index 00000000..6a31656b --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_14.py @@ -0,0 +1,18 @@ +from snark_lib import * + + +def main(): + arr = Array(10) + arr[6] = 42 + arr[8] = 11 + sum_1 = func_1(arr[6], arr[8]) + assert sum_1 == 53 + return + + +@inline +def func_1(i, j): + for k in range(0, i): + for u in range(0, j): + assert k + u != 1000000 + return i + j diff --git a/crates/lean_compiler/tests/test_data/program_140.py b/crates/lean_compiler/tests/test_data/program_140.py new file mode 100644 index 00000000..ca8c7129 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_140.py @@ -0,0 +1,13 @@ +from snark_lib import * +# Test: Three levels of nested non-unrolled loops with mutable variable + + +def main(): + count: Mut = 0 + for i in range(0, 2): + for j in range(0, 3): + for k in range(0, 4): + count += 1 + # Total iterations: 2 * 3 * 4 = 24 + assert count == 24 + return diff --git a/crates/lean_compiler/tests/test_data/program_141.py b/crates/lean_compiler/tests/test_data/program_141.py new file mode 100644 index 00000000..1ef3b0a2 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_141.py @@ -0,0 +1,27 @@ +from snark_lib import * +# Test: Match with conditions inside non-unrolled loop + + +def main(): + a: Mut = 0 + b: Mut = 0 + for i in range(0, 3): + match i: + case 0: + a += 1 + if a == 1: + b += 10 + case 1: + a += 2 + if a == 3: + b += 20 + case 2: + a += 4 + if a == 7: + b += 40 + # i=0: a=1, b=10 + # i=1: a=3, b=30 + # i=2: a=7, b=70 + assert a == 7 + assert b == 70 + return diff --git a/crates/lean_compiler/tests/test_data/program_142.py b/crates/lean_compiler/tests/test_data/program_142.py new file mode 100644 index 00000000..b15fb77b --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_142.py @@ -0,0 +1,294 @@ +from snark_lib import * +# Comprehensive stress test for mutable variables in non-unrolled loops +# Tests: nested loops, conditionals, match, multiple mutable vars, edge cases + +def main(): + # ========================================================================= + # TEST 1: Triple nested loops with multiple interacting mutable variables + # ========================================================================= + a: Mut = 0 + b: Mut = 1 + c: Mut = 100 + for i in range(0, 3): + for j in range(0, 4): + for k in range(0, 2): + a += 1 + b += a + c -= 1 + # a = 3*4*2 = 24 increments = 24 + # b = 1 + 1 + 3 + 6 + 10 + 15 + 21 + 28 + 36 + 45 + 55 + 66 + 78 + 91 + 105 + 120 + 136 + 153 + 171 + 190 + 210 + 231 + 253 + 276 + 300 = 301 + # c = 100 - 24 = 76 + assert a == 24 + assert b == 301 + assert c == 76 + + # ========================================================================= + # TEST 2: Mutable variable modified differently in if/else branches + # ========================================================================= + x: Mut = 0 + y: Mut = 0 + for i in range(0, 8): + if i == 0: + x += 100 + y += 1 + else if i == 1: + x += 50 + y += 2 + else if i == 2: + x += 25 + y += 4 + else if i == 3: + x -= 10 + y += 8 + else: + x += i + y *= 2 + # i=0: x=100, y=1 + # i=1: x=150, y=3 + # i=2: x=175, y=7 + # i=3: x=165, y=15 + # i=4: x=169, y=30 + # i=5: x=174, y=60 + # i=6: x=180, y=120 + # i=7: x=187, y=240 + assert x == 187 + assert y == 240 + + # ========================================================================= + # TEST 3: Match statements with mutable variables in nested loop + # ========================================================================= + score: Mut = 0 + multiplier: Mut = 1 + for round in range(0, 3): + for action in range(0, 4): + match action: + case 0: + score += 10 * multiplier + case 1: + score += 5 * multiplier + multiplier += 1 + case 2: + score -= 2 * multiplier + case 3: + multiplier *= 2 + score += multiplier + # Round 0: action 0: score=10, mult=1 + # action 1: score=15, mult=2 + # action 2: score=11, mult=2 + # action 3: mult=4, score=15 + # Round 1: action 0: score=55, mult=4 + # action 1: score=75, mult=5 + # action 2: score=65, mult=5 + # action 3: mult=10, score=75 + # Round 2: action 0: score=175, mult=10 + # action 1: score=225, mult=11 + # action 2: score=203, mult=11 + # action 3: mult=22, score=225 + assert score == 225 + assert multiplier == 22 + + # ========================================================================= + # TEST 4: Loop with non-zero start index + # ========================================================================= + sum_from_5: Mut = 0 + for i in range(5, 10): + sum_from_5 += i + # 5 + 6 + 7 + 8 + 9 = 35 + assert sum_from_5 == 35 + + # ========================================================================= + # TEST 5: Single iteration loop (edge case) + # ========================================================================= + single: Mut = 42 + for i in range(7, 8): + single += i + assert single == 49 + + # ========================================================================= + # TEST 6: Mutable variable reassigned multiple times per iteration + # ========================================================================= + multi: Mut = 0 + for i in range(1, 5): + multi += i + multi *= 2 + multi -= 1 + multi += i + # i=1: multi = 0+1=1, *2=2, -1=1, +1=2 + # i=2: multi = 2+2=4, *2=8, -1=7, +2=9 + # i=3: multi = 9+3=12, *2=24, -1=23, +3=26 + # i=4: multi = 26+4=30, *2=60, -1=59, +4=63 + assert multi == 63 + + # ========================================================================= + # TEST 7: Mutable variables with array operations + # ========================================================================= + arr = Array(6) + arr[0] = 1 + arr[1] = 2 + arr[2] = 4 + arr[3] = 8 + arr[4] = 16 + arr[5] = 32 + + arr_sum: Mut = 0 + arr_prod: Mut = 1 + last_val: Mut = 0 + for idx in range(0, 6): + val = arr[idx] + arr_sum += val + arr_prod *= val + 1 + last_val = val + # sum = 1+2+4+8+16+32 = 63 + # prod = 2*3*5*9*17*33 = 151470 + # last_val = 32 + assert arr_sum == 63 + assert arr_prod == 151470 + assert last_val == 32 + + # ========================================================================= + # TEST 8: Nested conditionals inside nested loops + # ========================================================================= + complex: Mut = 0 + for i in range(0, 3): + for j in range(0, 3): + if i == j: + if i == 0: + complex += 100 + else if i == 1: + complex += 200 + else: + complex += 300 + else: + if i == 0: + complex += 1 + else: + complex += 2 + # i=0,j=0: i==j, i==0: +100 -> 100 + # i=0,j=1: i!=j, i==0: +1 -> 101 + # i=0,j=2: i!=j, i==0: +1 -> 102 + # i=1,j=0: i!=j, i!=0: +2 -> 104 + # i=1,j=1: i==j, i==1: +200 -> 304 + # i=1,j=2: i!=j, i!=0: +2 -> 306 + # i=2,j=0: i!=j, i!=0: +2 -> 308 + # i=2,j=1: i!=j, i!=0: +2 -> 310 + # i=2,j=2: i==j, i==2: +300 -> 610 + assert complex == 610 + + # ========================================================================= + # TEST 9: Function calls with mutable variables + # ========================================================================= + func_result: Mut = 0 + for i in range(1, 6): + increment = compute_increment(i) + func_result += increment + # compute_increment(1) = 1 + # compute_increment(2) = 4 + # compute_increment(3) = 9 + # compute_increment(4) = 16 + # compute_increment(5) = 25 + # sum = 1 + 4 + 9 + 16 + 25 = 55 + assert func_result == 55 + + # ========================================================================= + # TEST 10: Outer mutable modified by inner loop result + # ========================================================================= + outer_acc: Mut = 0 + for i in range(1, 4): + inner_acc: Mut = 0 + for j in range(0, i): + inner_acc += j + 1 + outer_acc += inner_acc * i + # i=1: inner_acc = 1, outer_acc = 1*1 = 1 + # i=2: inner_acc = 1+2 = 3, outer_acc = 1 + 3*2 = 7 + # i=3: inner_acc = 1+2+3 = 6, outer_acc = 7 + 6*3 = 25 + assert outer_acc == 25 + + # ========================================================================= + # TEST 11: Large number of iterations + # ========================================================================= + big_sum: Mut = 0 + for i in range(0, 100): + big_sum += 1 + assert big_sum == 100 + + # ========================================================================= + # TEST 12: Mutable with division and subtraction + # ========================================================================= + countdown: Mut = 1000 + steps: Mut = 0 + for i in range(1, 11): + countdown -= i * 10 + steps += 1 + # countdown = 1000 - 10 - 20 - 30 - 40 - 50 - 60 - 70 - 80 - 90 - 100 + # = 1000 - 550 = 450 + assert countdown == 450 + assert steps == 10 + + # ========================================================================= + # TEST 13: Mix of unrolled inner and non-unrolled outer + # ========================================================================= + mixed: Mut = 0 + for i in range(0, 4): + for j in unroll(0, 3): + mixed += i * 3 + j + # i=0: 0+1+2 = 3 + # i=1: 3+4+5 = 12 + # i=2: 6+7+8 = 21 + # i=3: 9+10+11 = 30 + # total = 3+12+21+30 = 66 + assert mixed == 66 + + # ========================================================================= + # TEST 14: Multiple mutable variables, some modified some not per iteration + # ========================================================================= + always: Mut = 0 + sometimes: Mut = 100 + rarely: Mut = 1000 + for i in range(0, 10): + always += 1 + if i == 3: + sometimes += 50 + if i == 7: + sometimes -= 25 + rarely += 500 + if i == 9: + rarely *= 2 + assert always == 10 + assert sometimes == 125 + assert rarely == 3000 + + # ========================================================================= + # TEST 15: Chained mutable dependencies in same iteration + # ========================================================================= + chain_a: Mut = 1 + chain_b: Mut = 0 + chain_c: Mut = 0 + for i in range(0, 5): + chain_a *= 2 + chain_b = chain_a + i + chain_c += chain_b + # i=0: a=2, b=2+0=2, c=0+2=2 + # i=1: a=4, b=4+1=5, c=2+5=7 + # i=2: a=8, b=8+2=10, c=7+10=17 + # i=3: a=16, b=16+3=19, c=17+19=36 + # i=4: a=32, b=32+4=36, c=36+36=72 + assert chain_a == 32 + assert chain_b == 36 + assert chain_c == 72 + + # ========================================================================= + # TEST 16: Zero-iteration loop (edge case - empty range) + # No iterations should occur for 5..5 + # ========================================================================= + zero_iter: Mut = 999 + for i in range(5, 5): + zero_iter = 0 + assert zero_iter == 999 + + # ========================================================================= + # All tests passed! + # ========================================================================= + return + +def compute_increment(n): + return n * n \ No newline at end of file diff --git a/crates/lean_compiler/tests/test_data/program_143.py b/crates/lean_compiler/tests/test_data/program_143.py new file mode 100644 index 00000000..3152abb6 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_143.py @@ -0,0 +1,215 @@ +from snark_lib import * +# Comprehensive test for nested inlined function calls +# Tests various scenarios where inline functions are nested within other calls + +def main(): + # Test 1: Basic nested inline calls - f(g(h(x))) + # incr(incr(incr(5))) = 8 + result1 = incr(incr(incr(5))) + assert result1 == 8 + + # Test 2: Nested inline calls as argument to print (builtin) + # This was the original bug: print(incr(incr(incr(5)))) + print(incr(incr(incr(5)))) + + # Test 3: Multiple inline calls in one expression + # double(3) + triple(2) = 6 + 6 = 12 + result3 = double(3) + triple(2) + assert result3 == 12 + + # Test 4: Nested inline calls in arithmetic expression + # incr(double(3)) * incr(triple(2)) = 7 * 7 = 49 + result4 = incr(double(3)) * incr(triple(2)) + assert result4 == 49 + + # Test 5: Multiple levels of nesting in arithmetic + # double(incr(triple(2))) + incr(double(incr(1))) + # = double(incr(6)) + incr(double(2)) + # = double(7) + incr(4) + # = 14 + 5 = 19 + result5 = double(incr(triple(2))) + incr(double(incr(1))) + assert result5 == 19 + + # Test 6: Inline functions calling other inline functions + # quad(3) = double(double(3)) = double(6) = 12 + result6 = quad(3) + assert result6 == 12 + + # Test 7: Deeply nested composition + # quad(incr(double(1))) = quad(incr(2)) = quad(3) = 12 + result7 = quad(incr(double(1))) + assert result7 == 12 + + # Test 8: Multiple inline calls as arguments to non-inline function + # add_three(incr(1), double(2), triple(1)) = add_three(2, 4, 3) = 9 + result8 = add_three(incr(1), double(2), triple(1)) + assert result8 == 9 + + # Test 9: Nested inline call as argument to non-inline function + # add_three(incr(incr(1)), double(incr(2)), triple(incr(0))) + # = add_three(3, 6, 3) = 12 + result9 = add_three(incr(incr(1)), double(incr(2)), triple(incr(0))) + assert result9 == 12 + + # Test 10: Print multiple nested inline calls + print(double(5), triple(5), quad(2)) + + # Test 11: Complex expression with multiple nested inline calls + # (incr(double(2)) + triple(incr(1))) * double(incr(incr(0))) + # = (incr(4) + triple(2)) * double(2) + # = (5 + 6) * 4 + # = 44 + result11 = (incr(double(2)) + triple(incr(1))) * double(incr(incr(0))) + assert result11 == 44 + + # Test 12: Inline in unrolled loop + sum: Mut = 0 + for i in unroll(0, 4): + sum = sum + incr(i) + # sum = incr(0) + incr(1) + incr(2) + incr(3) = 1 + 2 + 3 + 4 = 10 + assert sum == 10 + + # Test 13: Inline functions in if condition (comparison) + result13: Imu + if incr(incr(0)) == 2: + result13 = 100 + else: + result13 = 0 + assert result13 == 100 + + # Test 14: Nested inline calls in both sides of if condition + result14: Imu + if double(3) == triple(2): + result14 = 1 + else: + result14 = 0 + # double(3) = 6, triple(2) = 6, so they are equal + assert result14 == 1 + + # Test 15: Inline calls inside if/else branches + result15: Imu + if 1 == 1: + result15 = incr(incr(incr(10))) + else: + result15 = 0 + assert result15 == 13 + + # Test 16: Multiple nested inline calls in if condition + result16: Imu + if incr(double(incr(1))) == 5: + # incr(1) = 2, double(2) = 4, incr(4) = 5 + result16 = 200 + else: + result16 = 0 + assert result16 == 200 + + # Test 17: Inline call with != comparison + result17: Imu + if incr(5) != 5: + result17 = 300 + else: + result17 = 0 + assert result17 == 300 + + # Test 18: Assertion with inline functions + assert incr(incr(0)) == 2 + assert double(triple(2)) == 12 + assert quad(incr(1)) == 8 + + # Test 19: Debug assertion with inline functions + debug_assert(incr(5) == 6) + debug_assert(double(incr(2)) == 6) + + # Test 20: Inline in non-unrolled loop + arr = Array(4) + for i in range(0, 4): + arr[i] = incr(i) + assert arr[0] == 1 + assert arr[1] == 2 + assert arr[2] == 3 + assert arr[3] == 4 + + # Test 21: Nested inline calls in non-unrolled loop + arr2 = Array(3) + for i in range(0, 3): + arr2[i] = double(incr(i)) + # double(incr(0)) = double(1) = 2 + # double(incr(1)) = double(2) = 4 + # double(incr(2)) = double(3) = 6 + assert arr2[0] == 2 + assert arr2[1] == 4 + assert arr2[2] == 6 + + # Test 22: Mixing inline and non-inline in complex expression inside loop + sum23: Mut = 0 + for i in unroll(0, 3): + sum23 = sum23 + add_three(incr(i), double(i), triple(i)) + # i=0: add_three(1, 0, 0) = 1 + # i=1: add_three(2, 2, 3) = 7 + # i=2: add_three(3, 4, 6) = 13 + # total = 1 + 7 + 13 = 21 + assert sum23 == 21 + + # Test 24: Chained else-if with inline conditions + result24: Imu + x24 = 5 + if incr(x24) == 4: + result24 = 1 + else if incr(x24) == 5: + result24 = 2 + else if incr(x24) == 6: + result24 = 3 + else: + result24 = 0 + # incr(5) = 6, so third condition matches + assert result24 == 3 + + # Test 25: Inline call as argument to inline call to non-inline function + # add_three takes 3 args, but we nest inline calls in each position + result25 = add_three(quad(1), quad(incr(0)), incr(quad(1))) + # quad(1) = 4 + # quad(incr(0)) = quad(1) = 4 + # incr(quad(1)) = incr(4) = 5 + # add_three(4, 4, 5) = 13 + assert result25 == 13 + + return + +# Simple inline function: increment by 1 +@inline +def incr(a): + b = a + 1 + return b + +# Inline function: multiply by 2 +@inline +def double(x): + return x * 2 + +# Inline function: multiply by 3 +@inline +def triple(x): + if x == 78990: + return 236970 + else: + y: Mut = x + two: Imu + match y - x + 1: + case 0: + assert False + case 1: + two = 2 + for i in range(0, two): + y = y + x + return y + +# Inline function that calls another inline function +@inline +def quad(x): + if x == 78990: + return 157980 + return double(double(x)) + +# Non-inline function that takes multiple arguments +def add_three(a, b, c): + return a + b + c diff --git a/crates/lean_compiler/tests/test_data/program_144.py b/crates/lean_compiler/tests/test_data/program_144.py new file mode 100644 index 00000000..77dd1048 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_144.py @@ -0,0 +1,378 @@ +from snark_lib import * +# Comprehensive tests for mutable variables with early exits (panic/return) in branches. +# This tests the SSA transformation when branches end with assert False or return. +# Bug fix: ensure proper handling of mutable variable unification when some branches exit early. + +def main(): + # ========================================================================== + # TEST 1: Basic - panic in else branch (the original bug case) + # ========================================================================== + two: Imu + if 1 == 1: + two = 2 + else: + assert False + assert two == 2 + + # ========================================================================== + # TEST 2: panic in then branch + # ========================================================================== + three: Imu + if 1 != 1: + assert False + else: + three = 3 + assert three == 3 + + # ========================================================================== + # TEST 3: Multiple mutable variables, panic in else + # ========================================================================== + a: Imu + b: Imu + c: Imu + if 1 == 1: + a = 10 + b = 20 + c = 30 + else: + assert False + assert a == 10 + assert b == 20 + assert c == 30 + + # ========================================================================== + # TEST 4: Nested if with panic in inner else + # ========================================================================== + x: Imu + if 1 == 1: + if 2 == 2: + x = 42 + else: + assert False + else: + assert False + assert x == 42 + + # ========================================================================== + # TEST 5: Mutable modified = None in then, panic in else + # ========================================================================== + counter: Mut = 0 + if 1 == 1: + counter = counter + 5 + else: + assert False + assert counter == 5 + + # ========================================================================== + # TEST 6: Multiple modifications before panic check + # ========================================================================== + val: Mut = 1 + val = val * 2 + val = val + 3 + if val == 5: + val = val * 10 + else: + assert False + assert val == 50 + + # ========================================================================== + # TEST 7: Chain of else-if with panic in final else + # ========================================================================== + result: Imu + selector = 1 + if selector == 0: + result = 100 + else if selector == 1: + result = 200 + else if selector == 2: + result = 300 + else: + assert False + assert result == 200 + + # ========================================================================== + # TEST 8: Match with panic in one arm + # ========================================================================== + matched: Imu + tag = 1 + match tag: + case 0: + assert False + case 1: + matched = 111 + case 2: + assert False + assert matched == 111 + + # ========================================================================== + # TEST 9: Match where only one arm doesn't panic + # ========================================================================== + only_valid: Imu + tag2 = 2 + match tag2: + case 0: + assert False + case 1: + assert False + case 2: + only_valid = 222 + case 3: + assert False + assert only_valid == 222 + + # ========================================================================== + # TEST 10: Panic in deeply nested structure + # ========================================================================== + deep: Imu + if 1 == 1: + if 1 == 1: + if 1 == 1: + deep = 999 + else: + assert False + else: + assert False + else: + assert False + assert deep == 999 + + # ========================================================================== + # TEST 11: Mutable used = None after branch with panic + # ========================================================================== + acc: Mut = 0 + for i in unroll(0, 3): + if 1 == 1: + acc = acc + i + else: + assert False + assert acc == 3 + + # ========================================================================== + # TEST 12: Forward declared with = None panic in branch + # ========================================================================== + fwd: Imu + cond = 1 + if cond == 1: + fwd = 777 + else: + assert False + assert fwd == 777 + + # ========================================================================== + # TEST 13: Both mutable and immutable forward decl with panic + # ========================================================================== + imm: Imu + mtbl: Imu + flag = 0 + if flag == 0: + imm = 100 + mtbl = 200 + else: + assert False + mtbl2: Mut = mtbl + mtbl2 = mtbl2 + 50 + assert imm == 100 + assert mtbl2 == 250 + + # ========================================================================== + # TEST 14: Return in function branch (early exit) + # ========================================================================== + res14 = test_early_return(1) + assert res14 == 10 + res14b = test_early_return(0) + assert res14b == 20 + + # ========================================================================== + # TEST 15: Multiple mutable vars with return in branch + # ========================================================================== + r15a, r15b = test_multi_return(1) + assert r15a == 100 + assert r15b == 200 + + # ========================================================================== + # TEST 16: Mutable with = None panic in match, then more operations + # ========================================================================== + m16: Mut = 5 + sel16 = 0 + match sel16: + case 0: + m16 = m16 * 2 + case 1: + assert False + m16 = m16 + 3 + assert m16 == 13 + + # ========================================================================== + # TEST 17: Nested match with panic + # ========================================================================== + nested_match: Imu + outer = 1 + match outer: + case 0: + assert False + case 1: + inner = 0 + match inner: + case 0: + nested_match = 500 + case 1: + assert False + assert nested_match == 500 + + # ========================================================================== + # TEST 18: If inside match with panic + # ========================================================================== + if_in_match: Imu + m18_sel = 0 + match m18_sel: + case 0: + cond18 = 1 + if cond18 == 1: + if_in_match = 600 + else: + assert False + case 1: + assert False + assert if_in_match == 600 + + # ========================================================================== + # TEST 19: Match inside if with panic + # ========================================================================== + match_in_if: Imu + cond19 = 1 + if cond19 == 1: + tag19 = 1 + match tag19: + case 0: + assert False + case 1: + match_in_if = 700 + else: + assert False + assert match_in_if == 700 + + # ========================================================================== + # TEST 20: Panic after partial assignment + # ========================================================================== + partial: Imu + check = 0 + if check == 0: + partial_tmp: Mut = 1 + partial_tmp = partial_tmp + 1 + partial_tmp = partial_tmp * 2 + partial = partial_tmp + else: + partial = 999 + assert False + assert partial == 4 + + # ========================================================================== + # TEST 21: Unrolled loop with panic in branch at each iteration + # ========================================================================== + sum21: Mut = 0 + for i in unroll(0, 5): + expected = i + if i == expected: + sum21 = sum21 + i + else: + assert False + assert sum21 == 10 + + # ========================================================================== + # TEST 22: Function with mutable param and early return + # ========================================================================== + res22 = func_with_mut_param(5, 1) + assert res22 == 50 + + # ========================================================================== + # TEST 23: Multiple levels - if/match/if with panics + # ========================================================================== + multi_level: Imu + c1 = 1 + if c1 == 1: + s1 = 0 + match s1: + case 0: + c2 = 0 + if c2 == 0: + multi_level = 888 + else: + assert False + case 1: + assert False + else: + assert False + assert multi_level == 888 + + # ========================================================================== + # TEST 24: Panic in both outer branches but inner assigns + # ========================================================================== + inner_assigns: Imu + outer24 = 0 + match outer24: + case 0: + inner24 = 1 + if inner24 == 1: + inner_assigns = 1000 + else: + assert False + case 1: + assert False + assert inner_assigns == 1000 + + # ========================================================================== + # TEST 25: Complex - multiple vars, nested, with panics + # ========================================================================== + va: Imu + vb: Imu + vc: Imu + + outer25 = 1 + if outer25 == 1: + va = 1 + mid25 = 0 + match mid25: + case 0: + vb = 2 + inner25 = 1 + if inner25 == 1: + vc = 3 + else: + assert False + case 1: + assert False + else: + assert False + + total = va + vb + vc + assert total == 6 + + return + +# Helper function for TEST 14 +def test_early_return(flag): + result: Imu + if flag == 1: + result = 10 + else: + result = 20 + return result + +# Helper function for TEST 15 +def test_multi_return(flag): + a: Imu + b: Imu + if flag == 1: + a = 100 + b = 200 + else: + assert False + return a, b + +# Helper function for TEST 22 +def func_with_mut_param(x: Mut, flag): + if flag == 1: + x = x * 10 + else: + assert False + return x \ No newline at end of file diff --git a/crates/lean_compiler/tests/test_data/program_145.py b/crates/lean_compiler/tests/test_data/program_145.py new file mode 100644 index 00000000..867d12ee --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_145.py @@ -0,0 +1,117 @@ +from snark_lib import * +# Test for inlined functions with const arguments + +ARR = [1, 2, 3, 4, 5] +SIZES = [2, 4, 3, 5] +NESTED_ARR = [[2, 4, 6], [1, 3, 5]] + + +def main(): + # Test 1: Basic + result1 = sum_to_n(ARR[2]) + assert result1 == 3 + + # Test 2: Inside unroll loop + total: Mut = 0 + for i in unroll(0, 4): + total = total + sum_to_n(SIZES[i]) + assert total == 20 + + # Test 3: Complex expression as argument + result3 = sum_to_n(ARR[1] + ARR[2]) # sum_to_n(5) = 0+1+2+3+4 = 10 + assert result3 == 10 + + # Test 4: Inline function returning const array value as loop bound + bound = get_bound(1) # SIZES[1] = 4 + sum4: Mut = 0 + for j in unroll(0, bound): + sum4 = sum4 + j + assert sum4 == 6 # 0+1+2+3 = 6 + + # Test 5: Zero iterations edge case + result5 = sum_to_n(0) + assert result5 == 0 + + # Test 6: Single iteration + result6 = sum_to_n(1) + assert result6 == 0 # for i in range(0, 1): acc = 0 + + # Test 7: Large count + result7 = sum_to_n(11) # 0+1+2+...+10 = 55 + assert result7 == 55 + + # Test 8: Multiple independent calls in sequence + a = sum_to_n(3) # 0+1+2 = 3 + b = sum_to_n(4) # 0+1+2+3 = 6 + c = sum_to_n(5) # 0+1+2+3+4 = 10 + assert a + b + c == 19 + + # Test 9: Inline function with multiple loop iterations and arithmetic + result9 = product_factorial(4) # 1*2*3*4 = 24 + assert result9 == 24 + + # Test 10: Nested const array access + result10 = sum_to_n(NESTED_ARR[0][1]) # NESTED_ARR[0][1] = 4, sum = 6 + assert result10 == 6 + + # Test 11: Different inline functions with same pattern + for k in unroll(0, 3): + x = sum_to_n(ARR[k]) # ARR[0]=1, ARR[1]=2, ARR[2]=3 + y = sum_squared(ARR[k]) # 0^2 + 1^2 + ... + (n-1)^2 + print(x) + print(y) + + # Test 12: Inline call with expression involving loop variable and const + test12_total: Mut = 0 + for m in unroll(0, 3): + test12_total = test12_total + sum_to_n(ARR[m] + 1) + # ARR[0]+1=2 -> sum=1, ARR[1]+1=3 -> sum=3, ARR[2]+1=4 -> sum=6 + assert test12_total == 10 + + # Test 13: Const expression as inline argument + result13 = sum_to_n(2 + 3) # sum_to_n(5) = 10 + assert result13 == 10 + + # Test 14: Inline function with no loop (simple passthrough) + result14 = double_value(SIZES[2]) # SIZES[2]=3, result=6 + assert result14 == 6 + + return + + +# Product: 1 * 2 * ... * n +@inline +def product_factorial(n): + acc: Mut = 1 + for i in unroll(1, n + 1): + acc = acc * i + return acc + + +# Sum of squares: 0^2 + 1^2 + ... + (n-1)^2 +@inline +def sum_squared(n): + acc: Mut = 0 + for i in unroll(0, n): + acc = acc + i * i + return acc + + +# Returns element from const array +@inline +def get_bound(idx): + return SIZES[idx] + + +# Simple passthrough with arithmetic +@inline +def double_value(x): + return x * 2 + + +@inline +def sum_to_n(n): + acc: Mut = 0 + for i in unroll(0, n): + acc = acc + i + return acc diff --git a/crates/lean_compiler/tests/test_data/program_146.py b/crates/lean_compiler/tests/test_data/program_146.py new file mode 100644 index 00000000..a30d01dc --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_146.py @@ -0,0 +1,11 @@ +from snark_lib import * + + +# Test basic DynArray([]) creation and indexing +def main(): + v = DynArray([1, 2, 3]) + assert v[0] == 1 + assert v[1] == 2 + assert v[2] == 3 + assert len(v) == 3 + return diff --git a/crates/lean_compiler/tests/test_data/program_147.py b/crates/lean_compiler/tests/test_data/program_147.py new file mode 100644 index 00000000..4eaf9f56 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_147.py @@ -0,0 +1,14 @@ +from snark_lib import * + + +# Test .push() on vectors +def main(): + v = DynArray([1, 2, 3]) + assert len(v) == 3 + v.push(4) + assert len(v) == 4 + assert v[3] == 4 + v.push(5) + assert len(v) == 5 + assert v[4] == 5 + return diff --git a/crates/lean_compiler/tests/test_data/program_148.py b/crates/lean_compiler/tests/test_data/program_148.py new file mode 100644 index 00000000..1312850e --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_148.py @@ -0,0 +1,14 @@ +from snark_lib import * + + +# Test nested vectors +def main(): + v = DynArray([DynArray([1, 2]), DynArray([3, 4, 5])]) + assert len(v) == 2 + assert len(v[0]) == 2 + assert len(v[1]) == 3 + assert v[0][0] == 1 + assert v[0][1] == 2 + assert v[1][0] == 3 + assert v[1][2] == 5 + return diff --git a/crates/lean_compiler/tests/test_data/program_149.py b/crates/lean_compiler/tests/test_data/program_149.py new file mode 100644 index 00000000..1831734e --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_149.py @@ -0,0 +1,16 @@ +from snark_lib import * + + +# Test push with nested vectors +def main(): + v = DynArray([DynArray([1, 2])]) + assert len(v) == 1 + v.push(DynArray([3, 4])) + assert len(v) == 2 + assert v[1][0] == 3 + assert v[1][1] == 4 + v.push(DynArray([5, 6, 7])) + assert len(v) == 3 + assert len(v[2]) == 3 + assert v[2][2] == 7 + return diff --git a/crates/lean_compiler/tests/test_data/program_15.py b/crates/lean_compiler/tests/test_data/program_15.py new file mode 100644 index 00000000..70fb86e0 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_15.py @@ -0,0 +1,74 @@ +from snark_lib import * + +BE = 1 # base-extension +EE = 0 # extension-extension + + +def main(): + x = 1 + y = 2 + i, j, k = func_1(x, y) + assert i == 2 + assert j == 3 + assert k == 2130706432 + + g = Array(8) + h = Array(8) + for i in range(0, 8): + g[i] = i + for i in unroll(0, 8): + h[i] = i + assert_eq_1(g, h) + assert_eq_2(g, h) + assert_eq_3(g, h) + assert_eq_4(g, h) + assert_eq_5(g, h) + return + + +@inline +def func_1(a, b): + x = a * b + y = a + b + return x, y, a - b + + +def assert_eq_1(x, y): + x_ptr = x + y_ptr = y + for i in unroll(0, 4): + assert x_ptr[i] == y_ptr[i] + for i in range(4, 8): + assert x_ptr[i] == y_ptr[i] + return + + +@inline +def assert_eq_2(x, y): + x_ptr = x + y_ptr = y + for i in unroll(0, 4): + assert x_ptr[i] == y_ptr[i] + for i in range(4, 8): + assert x_ptr[i] == y_ptr[i] + return + + +@inline +def assert_eq_3(x, y): + u = x + 7 + assert_eq_1(u - 7, y * 7 / 7) + return + + +def assert_eq_4(x, y): + dot_product(x, ONE_VEC_PTR, y, 1, EE) + dot_product(x + 3, ONE_VEC_PTR, y + 3, 1, EE) + return + + +@inline +def assert_eq_5(x, y): + dot_product(x, ONE_VEC_PTR, y, 1, EE) + dot_product(x + 3, ONE_VEC_PTR, y + 3, 1, EE) + return diff --git a/crates/lean_compiler/tests/test_data/program_150.py b/crates/lean_compiler/tests/test_data/program_150.py new file mode 100644 index 00000000..cd05f6f9 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_150.py @@ -0,0 +1,15 @@ +from snark_lib import * + + +# Test vectors with unrolled loops +def main(): + v = DynArray([]) + for i in unroll(0, 5): + v.push(i * 2) + assert len(v) == 5 + assert v[0] == 0 + assert v[1] == 2 + assert v[2] == 4 + assert v[3] == 6 + assert v[4] == 8 + return diff --git a/crates/lean_compiler/tests/test_data/program_151.py b/crates/lean_compiler/tests/test_data/program_151.py new file mode 100644 index 00000000..f9bbd26c --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_151.py @@ -0,0 +1,14 @@ +from snark_lib import * + + +# Test vectors with expression elements +def main(): + x = 10 + y = 20 + v = DynArray([x, y, x + y, x * 2]) + assert len(v) == 4 + assert v[0] == 10 + assert v[1] == 20 + assert v[2] == 30 + assert v[3] == 20 + return diff --git a/crates/lean_compiler/tests/test_data/program_152.py b/crates/lean_compiler/tests/test_data/program_152.py new file mode 100644 index 00000000..eb9c9fd9 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_152.py @@ -0,0 +1,17 @@ +from snark_lib import * + + +# Test vectors with nested unrolled loops +def main(): + v = DynArray([]) + for i in unroll(0, 3): + for j in unroll(0, 2): + v.push(i * 10 + j) + assert len(v) == 6 + assert v[0] == 0 # i=0, j=0 + assert v[1] == 1 # i=0, j=1 + assert v[2] == 10 # i=1, j=0 + assert v[3] == 11 # i=1, j=1 + assert v[4] == 20 # i=2, j=0 + assert v[5] == 21 # i=2, j=1 + return diff --git a/crates/lean_compiler/tests/test_data/program_153.py b/crates/lean_compiler/tests/test_data/program_153.py new file mode 100644 index 00000000..664402cc --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_153.py @@ -0,0 +1,28 @@ +from snark_lib import * + + +# Test pushing nested vectors in unrolled loop +def main(): + v = DynArray([]) + for i in unroll(0, 3): + v.push(DynArray([i, i + 1, i + 2])) + assert len(v) == 3 + assert len(v[0]) == 3 + assert len(v[1]) == 3 + assert len(v[2]) == 3 + + # v[0] = [0, 1, 2] + assert v[0][0] == 0 + assert v[0][1] == 1 + assert v[0][2] == 2 + + # v[1] = [1, 2, 3] + assert v[1][0] == 1 + assert v[1][1] == 2 + assert v[1][2] == 3 + + # v[2] = [2, 3, 4] + assert v[2][0] == 2 + assert v[2][1] == 3 + assert v[2][2] == 4 + return diff --git a/crates/lean_compiler/tests/test_data/program_154.py b/crates/lean_compiler/tests/test_data/program_154.py new file mode 100644 index 00000000..8cb0c7fc --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_154.py @@ -0,0 +1,12 @@ +from snark_lib import * + + +# Test accessing vector elements inside unrolled loop +def main(): + v = DynArray([10, 20, 30, 40, 50]) + + sum: Mut = 0 + for i in unroll(0, 5): + sum = sum + v[i] + assert sum == 150 + return diff --git a/crates/lean_compiler/tests/test_data/program_155.py b/crates/lean_compiler/tests/test_data/program_155.py new file mode 100644 index 00000000..709840d2 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_155.py @@ -0,0 +1,17 @@ +from snark_lib import * + + +# Test building vector and then reading in separate unrolled loops +def main(): + # Build a vector of squares + squares = DynArray([]) + for i in unroll(0, 6): + squares.push(i * i) + + # Verify in a separate loop + for i in unroll(0, 6): + assert squares[i] == i * i + + # Also check len + assert len(squares) == 6 + return diff --git a/crates/lean_compiler/tests/test_data/program_156.py b/crates/lean_compiler/tests/test_data/program_156.py new file mode 100644 index 00000000..af062869 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_156.py @@ -0,0 +1,20 @@ +from snark_lib import * + + +# Test vector with expression using loop variable +def main(): + # Build Fibonacci-like sequence using vector + fib = DynArray([1, 1]) + for i in unroll(2, 8): + fib.push(fib[i - 1] + fib[i - 2]) + + assert len(fib) == 8 + assert fib[0] == 1 + assert fib[1] == 1 + assert fib[2] == 2 + assert fib[3] == 3 + assert fib[4] == 5 + assert fib[5] == 8 + assert fib[6] == 13 + assert fib[7] == 21 + return diff --git a/crates/lean_compiler/tests/test_data/program_157.py b/crates/lean_compiler/tests/test_data/program_157.py new file mode 100644 index 00000000..2181399f --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_157.py @@ -0,0 +1,31 @@ +from snark_lib import * + + +# Test pushing to nested vectors with indices +def main(): + # Create a vector of empty vectors + v = DynArray([DynArray([]), DynArray([]), DynArray([])]) + + # Push to nested vectors using indices + v[0].push(10) + v[0].push(20) + v[1].push(30) + v[2].push(40) + v[2].push(50) + v[2].push(60) + + # Verify structure + assert len(v) == 3 + assert len(v[0]) == 2 + assert len(v[1]) == 1 + assert len(v[2]) == 3 + + # Verify values + assert v[0][0] == 10 + assert v[0][1] == 20 + assert v[1][0] == 30 + assert v[2][0] == 40 + assert v[2][1] == 50 + assert v[2][2] == 60 + + return diff --git a/crates/lean_compiler/tests/test_data/program_158.py b/crates/lean_compiler/tests/test_data/program_158.py new file mode 100644 index 00000000..8f56540c --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_158.py @@ -0,0 +1,33 @@ +from snark_lib import * +# Test pushing to nested vectors in unrolled loops + +ARR = [3, 4, 5] + + +def main(): + # Create a 3-element vector of empty vectors + rows = DynArray([DynArray([]), DynArray([]), DynArray([])]) + + # Fill each row with its index repeated + for i in unroll(0, 3): + for j in unroll(0, 3): + rows[i].push(i + len(ARR) - 1 + ARR[0] - 3 + j - 2) + + # Verify structure + assert len(rows) == 3 + assert len(rows[0]) == 3 + assert len(rows[1]) == 3 + assert len(rows[2]) == 3 + + # Verify values: rows[i][j] == i + j + assert rows[0][0] == 0 + assert rows[0][1] == 1 + assert rows[0][2] == 2 + assert rows[1][0] == 1 + assert rows[1][1] == 2 + assert rows[1][len(ARR) - 1 + ARR[0] - 3] == 3 + assert rows[2][0] == 2 + assert rows[2][1] == 3 + assert rows[2][2] == 4 + + return diff --git a/crates/lean_compiler/tests/test_data/program_159.py b/crates/lean_compiler/tests/test_data/program_159.py new file mode 100644 index 00000000..1d86b9d7 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_159.py @@ -0,0 +1,21 @@ +from snark_lib import * + + +# Test: local vectors inside if/else branches are allowed +def main(): + x = 5 + if x == 5: + # Local vector in then branch - allowed + v = DynArray([1, 2, 3]) + v.push(4) + assert v[3] == 4 + assert len(v) == 4 + w = DynArray([]) + w.push(100) + assert w[0] == 100 + else: + # Different local vector in else branch - allowed (no clash, different control flow) + w = DynArray([10, 20]) + w.push(30) + assert w[2] == 30 + return diff --git a/crates/lean_compiler/tests/test_data/program_16.py b/crates/lean_compiler/tests/test_data/program_16.py new file mode 100644 index 00000000..7280a522 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_16.py @@ -0,0 +1,12 @@ +from snark_lib import * +def main(): + b = is_one() + c = b + return + +@inline +def is_one(): + if !!assume_bool(1): + return 1 + else: + return 0 diff --git a/crates/lean_compiler/tests/test_data/program_160.py b/crates/lean_compiler/tests/test_data/program_160.py new file mode 100644 index 00000000..1659f99d --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_160.py @@ -0,0 +1,15 @@ +from snark_lib import * + + +# Test: local vectors inside non-unrolled loops are allowed +# This just tests that local vector creation and push works inside a loop +def main(): + for i in range(0, 3): + # Local vector created fresh each iteration - allowed + v = DynArray([1, 2, 3]) + v.push(4) + # Use the vector within the same iteration + assert v[0] == 1 + assert v[3] == 4 + assert len(v) == 4 + return diff --git a/crates/lean_compiler/tests/test_data/program_161.py b/crates/lean_compiler/tests/test_data/program_161.py new file mode 100644 index 00000000..da7b127c --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_161.py @@ -0,0 +1,14 @@ +from snark_lib import * + +# Test: compile-time true condition allows push to outer-scope vector in then branch +FLAG = 1 + + +def main(): + v = DynArray([1, 2, 3]) + if FLAG == 1: + v.push(4) # OK: condition is compile-time true, branch is inlined + else: + v.push(5) + assert v[3] == 4 + return diff --git a/crates/lean_compiler/tests/test_data/program_162.py b/crates/lean_compiler/tests/test_data/program_162.py new file mode 100644 index 00000000..74268a82 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_162.py @@ -0,0 +1,14 @@ +from snark_lib import * + +# Test: compile-time false condition allows push to outer-scope vector in else branch +FLAG = 0 + + +def main(): + v = DynArray([1, 2, 3]) + if FLAG == 1: + v.push(4) + else: + v.push(5) # OK: condition is compile-time false, else branch is inlined + assert v[3] == 5 + return diff --git a/crates/lean_compiler/tests/test_data/program_163.py b/crates/lean_compiler/tests/test_data/program_163.py new file mode 100644 index 00000000..507acf7d --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_163.py @@ -0,0 +1,16 @@ +from snark_lib import * + +# Test: compile-time condition using const array access +ARR = [0, 1, 2] + + +def main(): + v = DynArray([]) + v.push(10) + if ARR[1] == 1: + v.push(20) # OK: ARR[1] == 1 is compile-time true + else: + v.push(30) + assert v[0] == 10 + assert v[1] == 20 + return diff --git a/crates/lean_compiler/tests/test_data/program_164.py b/crates/lean_compiler/tests/test_data/program_164.py new file mode 100644 index 00000000..140275e8 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_164.py @@ -0,0 +1,18 @@ +from snark_lib import * + +# Test: nested compile-time conditions +A = 1 +B = 2 + + +def main(): + v = DynArray([]) + if A == 1: + if B == 2: + v.push(100) # OK: both conditions are compile-time true + else: + v.push(200) + else: + v.push(300) + assert v[0] == 100 + return diff --git a/crates/lean_compiler/tests/test_data/program_165.py b/crates/lean_compiler/tests/test_data/program_165.py new file mode 100644 index 00000000..ef5bc8d7 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_165.py @@ -0,0 +1,98 @@ +from snark_lib import * +# Comprehensive test: nested unrolled loops, vectors with pushes in various scopes + + +def main(): + # === PART 1: Basic nested loops over 2D vector === + + outer = DynArray([DynArray([1, 2]), DynArray([10, 20, 30])]) + + total: Mut = 0 + for i in unroll(0, len(outer)): + row_sum: Mut = 0 + for j in unroll(0, len(outer[i])): + row_sum = row_sum + outer[i][j] + total = total + row_sum + assert total == 63 + + # === PART 2: Push new row to outer, iterate again === + + outer.push(DynArray([100, 200])) + assert len(outer) == 3 + + total: Mut = 0 + for i in unroll(0, len(outer)): + row_sum: Mut = 0 + for j in unroll(0, len(outer[i])): + row_sum = row_sum + outer[i][j] + total = total + row_sum + assert total == 363 + + # === PART 3: Multiple vectors cross product === + + v1 = DynArray([1, 2, 3]) + v2 = DynArray([10, 20]) + + cross_sum: Mut = 0 + for i in unroll(0, len(v1)): + for j in unroll(0, len(v2)): + cross_sum = cross_sum + v1[i] * v2[j] + assert cross_sum == 180 + + v2.push(30) + cross_sum: Mut = 0 + for i in unroll(0, len(v1)): + for j in unroll(0, len(v2)): + cross_sum = cross_sum + v1[i] * v2[j] + assert cross_sum == 360 + + # === PART 4: Accumulator reused without reset === + + data = DynArray([5, 10, 15, 20]) + + acc: Mut = 0 + for i in unroll(0, len(data)): + acc = acc + data[i] + assert acc == 50 + + for i in unroll(0, len(data)): + acc = acc + data[i] * data[i] + assert acc == 800 + + # === PART 5: if inside unrolled loop (compile-time condition) === + + data2 = DynArray([1, 2, 3, 4]) + acc2: Mut = 0 + for i in unroll(0, len(data2)): + acc2 = acc2 + data2[i] + if i == 2: + acc2 = acc2 * 2 + assert acc2 == 16 + + assert inlined() == 5 + + return + + +def inlined(): + v = DynArray([1, 2, 3]) + sum: Mut = 0 + for i in unroll(0, len(v)): + sum = sum + v[i] + debug_assert(sum == 6) + v.push(4) + assert len(v) == 4 + sum: Mut = 0 + for i in unroll(0, len(v)): + sum += v[i] + assert sum == 10 + w = DynArray([]) + for i in unroll(0, 5): + w.push(DynArray([])) + for j in unroll(0, i): + w[i].push(1) + sum: Mut = 0 + for j in unroll(0, len(w[i])): + sum += w[i][j] + assert sum == i + return len(w) diff --git a/crates/lean_compiler/tests/test_data/program_166.py b/crates/lean_compiler/tests/test_data/program_166.py new file mode 100644 index 00000000..f9c39236 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_166.py @@ -0,0 +1,39 @@ +from snark_lib import * + +EE = 0 +DIM = 5 + + +def main(): + v = DynArray([1, 2, 3]) + sum1: Mut = 0 + for i in unroll(0, len(v)): + sum1 = sum1 + v[i] + assert sum1 == 6 + v.push(4) + assert len(v) == 4 + sum2: Mut = 0 + for i in unroll(0, len(v)): + sum2 = sum2 + v[i] + assert sum2 == 10 + # Test nested vectors with len(w[i]) + w = DynArray([]) + for i in unroll(0, 5): + w.push(DynArray([])) + for j in unroll(0, i): + w[i].push(1) + assert len(w[i]) == i + assert len(w) == 5 + a = Array(DIM) + for i in unroll(0, DIM): + a[i] = 1 + w.push(DynArray([a])) + b = Array(DIM) + copy_5(w[5][0], b) + return + + +@inline +def copy_5(a, b): + dot_product(a, ONE_VEC_PTR, b, 1, EE) + return diff --git a/crates/lean_compiler/tests/test_data/program_167.py b/crates/lean_compiler/tests/test_data/program_167.py new file mode 100644 index 00000000..17ac8968 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_167.py @@ -0,0 +1,12 @@ +from snark_lib import * + +ARR = [1, 2, 3, 4, 5] + + +def main(): + x = (len(ARR) + ARR[2]) / ARR[3] + sum: Mut = 0 + for i in range(0, x): + sum += 1 + assert sum == 2 + return diff --git a/crates/lean_compiler/tests/test_data/program_168.py b/crates/lean_compiler/tests/test_data/program_168.py new file mode 100644 index 00000000..a529c6a0 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_168.py @@ -0,0 +1,86 @@ +from snark_lib import * +# Comprehensive test for vector.pop() + + +def main(): + # Basic pop on simple vector + v1 = DynArray([1, 2, 3, 4, 5]) + assert len(v1) == 5 + v1.pop() + assert len(v1) == 4 + v1.pop() + v1.pop() + assert len(v1) == 2 + # v1 should now be [1, 2] + assert v1[0] == 1 + assert v1[1] == 2 + + """ + multi line + comment + """ + + # Pop in unrolled loop + v2 = DynArray([10, 20, 30, 40, 50]) + for i in unroll(0, 3): + v2.pop() + assert len(v2) == 2 + assert v2[0] == 10 + assert v2[1] == 20 + + # Pop from nested vector + matrix = DynArray([DynArray([1, 2, 3]), DynArray([4, 5, 6, 7]), DynArray([8, 9])]) + assert len(matrix[0]) == 3 + assert len(matrix[1]) == 4 + matrix[1].pop() + assert len(matrix[1]) == 3 + matrix[0].pop() + matrix[0].pop() + assert len(matrix[0]) == 1 + assert matrix[0][0] == 1 + assert matrix[1][0] == 4 + assert matrix[1][1] == 5 + assert matrix[1][2] == 6 + + # Pop outer vector element + matrix.pop() + assert len(matrix) == 2 + + # Mix push and pop + v3 = DynArray([100]) + v3.push(200) + v3.push(300) + assert len(v3) == 3 + v3.pop() + assert len(v3) == 2 + v3.push(400) + assert len(v3) == 3 + assert v3[0] == 100 + assert v3[1] == 200 + assert v3[2] == 400 + + # Pop until one element remains + v4 = DynArray([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + for i in unroll(0, 9): + v4.pop() + assert len(v4) == 1 + assert v4[0] == 1 + + # Pop on nested vector with index expression + nested = DynArray([DynArray([DynArray([1, 2, 3])])]) + nested[0][0].pop() + assert len(nested[0][0]) == 2 + + # Build vector, pop some, then iterate + v5 = DynArray([]) + for i in unroll(0, 5): + v5.push(i * 10) + v5.pop() + v5.pop() + sum: Mut = 0 + for i in unroll(0, len(v5)): + sum = sum + v5[i] + # v5 = [0, 10, 20], sum = 30 + assert sum == 30 + + return diff --git a/crates/lean_compiler/tests/test_data/program_169.py b/crates/lean_compiler/tests/test_data/program_169.py new file mode 100644 index 00000000..2e142746 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_169.py @@ -0,0 +1,14 @@ +from snark_lib import * + + +def main(): + x = 1 + 2 + 3 + assert x == 6 + + y = foo(10, 20, 30) + assert y == 60 + return + + +def foo(a, b, c): + return a + b + c diff --git a/crates/lean_compiler/tests/test_data/program_17.py b/crates/lean_compiler/tests/test_data/program_17.py new file mode 100644 index 00000000..d9274d42 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_17.py @@ -0,0 +1,18 @@ +from snark_lib import * + + +def main(): + x = func() + return + + +def func(): + a: Imu + if 0 == 0: + a = aux() + return a + + +@inline +def aux(): + return 1 diff --git a/crates/lean_compiler/tests/test_data/program_170.py b/crates/lean_compiler/tests/test_data/program_170.py new file mode 100644 index 00000000..8af7976f --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_170.py @@ -0,0 +1,67 @@ +from snark_lib import * + + +def add_four(a, b, c, d): + return a + b + c + d + + +def multi_return(a, b): + return ( + a + 1, + b + 2, + a + b, + ) + + +def multi_line_params( + a, + b: Mut, + c: Const, +): + return a + b + c + + +def main(): + result = add_four(1, 2, 3, 4) + assert result == 10 + + arr = DynArray([1, 2, 3]) + assert arr[0] == 1 + assert arr[1] == 2 + assert arr[2] == 3 + + nested = add_four(1, add_four(10, 20, 30, 40), 2, 3) + assert nested == 106 + + x = 5 + y = 10 + z: Imu + if x + y == 15: + z = 1 + else: + z = 0 + assert z == 1 + + w: Imu + if x + y * 2 == 25: + w = 100 + else: + w = 0 + assert w == 100 + + r1, r2, r3 = multi_return(10, 20) + assert r1 == 11 + assert r2 == 22 + assert r3 == 30 + + assert r1 == 11 + assert r2 + r3 == 52 + + (s1, s2, s3) = multi_return(100, 200) + assert s1 == 101 + assert s2 == 202 + assert s3 == 300 + + mlp = multi_line_params(1, 2, 3) + assert mlp == 6 + return diff --git a/crates/lean_compiler/tests/test_data/program_18.py b/crates/lean_compiler/tests/test_data/program_18.py new file mode 100644 index 00000000..0f9ee7bd --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_18.py @@ -0,0 +1,40 @@ +from snark_lib import * + + +def main(): + for x in unroll(0, 3): + func_match(x) + for x in unroll(0, 2): + match x: + case 0: + y = 10 * (x + 8) + z = 10 * y + print(z) + case 1: + y = 10 * x + z = func_2(y) + print(z) + return + + +@inline +def func_match(x): + match x: + case 0: + print(41) + case 1: + y = func_1(x) + print(y + 1) + case 2: + y = 10 * x + print(y) + return + + +def func_1(x): + return x * x * x * x + + +@inline +def func_2(x): + return x * x * x * x * x * x diff --git a/crates/lean_compiler/tests/test_data/program_19.py b/crates/lean_compiler/tests/test_data/program_19.py new file mode 100644 index 00000000..124e8547 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_19.py @@ -0,0 +1,16 @@ +from snark_lib import * + + +def main(): + match 1: + case 0: + y = 90 + case 1: + y = 10 + z = func_2(y) + return + + +@inline +def func_2(x): + return x * x diff --git a/crates/lean_compiler/tests/test_data/program_2.py b/crates/lean_compiler/tests/test_data/program_2.py new file mode 100644 index 00000000..83276d95 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_2.py @@ -0,0 +1,9 @@ +from snark_lib import * +from misc.bar import * +from misc.foo import * + + +def main(): + x = bar(FOO) + assert x == 6 + return diff --git a/crates/lean_compiler/tests/test_data/program_20.py b/crates/lean_compiler/tests/test_data/program_20.py new file mode 100644 index 00000000..896cd214 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_20.py @@ -0,0 +1,16 @@ +from snark_lib import * + + +def main(): + y = compute_value(3) + print(y) + return + + +def compute_value(n: Const): + result = complex_computation(n, 5) + return result + + +def complex_computation(a: Const, b: Const): + return a * a + b * b diff --git a/crates/lean_compiler/tests/test_data/program_21.py b/crates/lean_compiler/tests/test_data/program_21.py new file mode 100644 index 00000000..0397a623 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_21.py @@ -0,0 +1,19 @@ +from snark_lib import * + + +def main(): + x = double(3) + y = quad(x) + print(y) + return + + +@inline +def double(a): + return a + a + + +@inline +def quad(b): + result = double(b) + return result + result diff --git a/crates/lean_compiler/tests/test_data/program_22.py b/crates/lean_compiler/tests/test_data/program_22.py new file mode 100644 index 00000000..e150345a --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_22.py @@ -0,0 +1,24 @@ +from snark_lib import * + + +def main(): + result = level_one(3) + print(result) + return + + +@inline +def level_one(x): + result = level_two(x) + return result + + +@inline +def level_two(y): + result = level_three(y) + return result + + +@inline +def level_three(z): + return z * z * z diff --git a/crates/lean_compiler/tests/test_data/program_23.py b/crates/lean_compiler/tests/test_data/program_23.py new file mode 100644 index 00000000..ac182cd3 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_23.py @@ -0,0 +1,16 @@ +from snark_lib import * + + +def main(): + f(1) + return + + +def f(n): + if 0 == 0: + res = Array(2) + res[1] = 0 + return + else: + res = Array(n * 1) + return diff --git a/crates/lean_compiler/tests/test_data/program_24.py b/crates/lean_compiler/tests/test_data/program_24.py new file mode 100644 index 00000000..64caae06 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_24.py @@ -0,0 +1,10 @@ +from snark_lib import * + + +def main(): + a = 10 + b = 20 + debug_assert(a * 2 == b) + debug_assert(a != b) + debug_assert(a < b) + return diff --git a/crates/lean_compiler/tests/test_data/program_25.py b/crates/lean_compiler/tests/test_data/program_25.py new file mode 100644 index 00000000..9dc3a67f --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_25.py @@ -0,0 +1,11 @@ +from snark_lib import * + + +def main(): + a = double(next_multiple_of(12, 8)) + assert a == 32 + return + + +def double(n: Const): + return next_multiple_of(n, n) * 2 diff --git a/crates/lean_compiler/tests/test_data/program_26.py b/crates/lean_compiler/tests/test_data/program_26.py new file mode 100644 index 00000000..0e98dad4 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_26.py @@ -0,0 +1,31 @@ +from snark_lib import * + +FIVE = 5 +ARR = [4, FIVE, 4 + 2, 3 * 2 + 1] + + +def main(): + for i in unroll(1, len(ARR)): + x = i + 4 + assert ARR[i] == x + four = 4 + assert len(ARR) == four + res = func(2) + six = 6 + assert res == six + nothing(ARR[0]) + mem_arr = Array(len(ARR)) + for i in unroll(0, len(ARR)): + mem_arr[i] = ARR[i] + for i in range(0, ARR[0]): + print(2 ** ARR[0]) + print(2 ** ARR[1]) + return + + +def func(x: Const): + return ARR[x] + + +def nothing(x): + return diff --git a/crates/lean_compiler/tests/test_data/program_27.py b/crates/lean_compiler/tests/test_data/program_27.py new file mode 100644 index 00000000..17de9b6e --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_27.py @@ -0,0 +1,11 @@ +from snark_lib import * + + +def main(): + x = Array(2) + x[0] = 3 + x[1] = 5 + for i in unroll(0, 2): + for j in range(0, x[i]): + print(i, j) + return diff --git a/crates/lean_compiler/tests/test_data/program_28.py b/crates/lean_compiler/tests/test_data/program_28.py new file mode 100644 index 00000000..9682965a --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_28.py @@ -0,0 +1,25 @@ +from snark_lib import * + + +def main(): + a = Array(10) + b = Array(10) + a[1], b[4] = get_two_values() + assert a[1] == 42 + assert b[4] == 99 + + i = 2 + j = 3 + a[i], b[j] = get_two_values() + assert a[2] == 42 + assert b[3] == 99 + + x, a[5] = get_two_values() + assert x == 42 + assert a[5] == 99 + + return + + +def get_two_values(): + return 42, 99 diff --git a/crates/lean_compiler/tests/test_data/program_29.py b/crates/lean_compiler/tests/test_data/program_29.py new file mode 100644 index 00000000..f428be81 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_29.py @@ -0,0 +1,18 @@ +from snark_lib import * + + +def main(): + arr = Array(20) + for i in range(0, 5): + arr[i * 2], arr[i * 2 + 1] = compute_pair(i) + assert arr[0] == 0 + assert arr[1] == 0 + assert arr[2] == 1 + assert arr[3] == 2 + assert arr[4] == 2 + assert arr[5] == 4 + return + + +def compute_pair(n): + return n, n * 2 diff --git a/crates/lean_compiler/tests/test_data/program_3.py b/crates/lean_compiler/tests/test_data/program_3.py new file mode 100644 index 00000000..5d35c7d8 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_3.py @@ -0,0 +1,14 @@ +from snark_lib import * + + +def main(): + a = f() + assert log2_ceil(13) == 4 + return + + +def f(): + if 1 == 1: + return 0 + else: + assert 5 == 7 diff --git a/crates/lean_compiler/tests/test_data/program_30.py b/crates/lean_compiler/tests/test_data/program_30.py new file mode 100644 index 00000000..0348faa6 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_30.py @@ -0,0 +1,34 @@ +from snark_lib import * + +ARR = [10, 100] + + +def main(): + buff = Array(3) + buff[0] = 0 + for i in unroll(0, 2): + res = f1(ARR[i]) + buff[i + 1] = res + assert buff[2] == 1390320454 + return + + +def f1(x: Const): + buff = Array(9) + buff[0] = 1 + for i in unroll(x, x + 4): + for j in unroll(i, i + 2): + index = (i - x) * 2 + (j - i) + res = f2(i, j) + buff[index + 1] = buff[index] * res + return buff[8] + + +def f2(x: Const, y: Const): + buff = Array(7) + buff[0] = 0 + for i in unroll(x, x + 2): + for j in unroll(i, i + 3): + index = (i - x) * 3 + (j - i) + buff[index + 1] = buff[index] + i + j + return buff[4] diff --git a/crates/lean_compiler/tests/test_data/program_31.py b/crates/lean_compiler/tests/test_data/program_31.py new file mode 100644 index 00000000..765ecfba --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_31.py @@ -0,0 +1,9 @@ +from snark_lib import * + +ARR = [10, 100] + + +def main(): + a = ARR[0] + assert a == 10 + return diff --git a/crates/lean_compiler/tests/test_data/program_32.py b/crates/lean_compiler/tests/test_data/program_32.py new file mode 100644 index 00000000..752dcf31 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_32.py @@ -0,0 +1,17 @@ +from snark_lib import * + + +def main(): + res = fib(8) + assert res == 21 + return + + +def fib(n: Const): + if n == 0: + return 0 + if n == 1: + return 1 + a = fib(saturating_sub(n, 1)) + b = fib(saturating_sub(n, 2)) + return a + b diff --git a/crates/lean_compiler/tests/test_data/program_33.py b/crates/lean_compiler/tests/test_data/program_33.py new file mode 100644 index 00000000..e790e085 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_33.py @@ -0,0 +1,16 @@ +from snark_lib import * + + +def main(): + a = b() + b = a() + assert a + b == 30 + return + + +def a(): + return 10 + + +def b(): + return 20 diff --git a/crates/lean_compiler/tests/test_data/program_34.py b/crates/lean_compiler/tests/test_data/program_34.py new file mode 100644 index 00000000..261ab1ae --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_34.py @@ -0,0 +1,18 @@ +from snark_lib import * + +NESTED = [[1, 2], [3, 4, 5], [6]] + + +def main(): + assert len(NESTED) == 3 + assert len(NESTED[0]) == 2 + assert len(NESTED[1]) == 3 + assert len(NESTED[2]) == 1 + + assert NESTED[0][0] == 1 + assert NESTED[0][1] == 2 + assert NESTED[1][0] == 3 + assert NESTED[1][2] == 5 + assert NESTED[2][0] == 6 + + return diff --git a/crates/lean_compiler/tests/test_data/program_35.py b/crates/lean_compiler/tests/test_data/program_35.py new file mode 100644 index 00000000..617e4d3d --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_35.py @@ -0,0 +1,22 @@ +from snark_lib import * + +DEEP = [[[1, 2], [3]], [[4, 5, 6]]] +ONE = 1 + + +def main(): + assert len(DEEP) == 2 + assert len(DEEP[0]) == 2 + assert len(DEEP[0][0]) == 2 + assert len(DEEP[0][1]) == 1 + one = 1 + assert len(DEEP[ONE]) == one + assert len(DEEP[1][0]) == 3 + + assert DEEP[0][0][0] == 1 + assert DEEP[0][0][1] == 2 + assert DEEP[0][1][0] == 3 + assert DEEP[1][0][0] == 4 + assert DEEP[1][0][2] == 6 + + return diff --git a/crates/lean_compiler/tests/test_data/program_36.py b/crates/lean_compiler/tests/test_data/program_36.py new file mode 100644 index 00000000..5be45973 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_36.py @@ -0,0 +1,21 @@ +from snark_lib import * + +TWO = 2 +ARR = [[1 + 1, TWO * 2], [3 + TWO]] +INCR_ARR = [[0, 1, 2], [1, 2, 3], [2, 3, 4], [3, 4, 5]] + + +def main(): + assert len(ARR) == 2 + assert ARR[0][0] == 2 + assert ARR[0][1] == 4 + assert ARR[1][0] == 5 + five = ARR[1][0] + assert five == 5 + x = 2 + 3 * (ARR[0][0] + ARR[1][0] + 3) ** 2 + assert x == 302 + for i in unroll(0, 4): + for j in unroll(0, 3): + y = INCR_ARR[i][j] + assert INCR_ARR[i][j] == i + j - INCR_ARR[i][j] + y + return diff --git a/crates/lean_compiler/tests/test_data/program_37.py b/crates/lean_compiler/tests/test_data/program_37.py new file mode 100644 index 00000000..29b2943e --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_37.py @@ -0,0 +1,9 @@ +from snark_lib import * + +ARR = [[5]] + + +def main(): + x = ARR[0][0] ** 2 + assert x == 25 + return diff --git a/crates/lean_compiler/tests/test_data/program_38.py b/crates/lean_compiler/tests/test_data/program_38.py new file mode 100644 index 00000000..6b3df3dd --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_38.py @@ -0,0 +1,34 @@ +from snark_lib import * + + +def main(): + assert incr(incr(incr(1))) == 4 + x = add(incr(1), incr(2)) + assert x == 5 + + assert incr_inlined(incr_inlined(incr_inlined(1))) == 4 + y = add_inlined(incr_inlined(1), add_inlined(incr_inlined(2), incr_inlined(2))) + assert y == 8 + + return + + +def add(a, b): + return a + b + + +def incr(x): + return x + 1 + + +@inline +def incr_inlined(x): + return x + 1 + + +@inline +def add_inlined(a, b): + c = Array(1) + zero = 0 + c[zero] = a + b + return c[0] diff --git a/crates/lean_compiler/tests/test_data/program_39.py b/crates/lean_compiler/tests/test_data/program_39.py new file mode 100644 index 00000000..40fc1da2 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_39.py @@ -0,0 +1,11 @@ +from snark_lib import * + +ARR = [[1], [7, 3], [7]] +N = 2 + len(ARR[0]) + + +def main(): + for i in unroll(0, N): + for j in unroll(0, len(ARR[i])): + assert j * (j - 1) == 0 + return diff --git a/crates/lean_compiler/tests/test_data/program_4.py b/crates/lean_compiler/tests/test_data/program_4.py new file mode 100644 index 00000000..92e4e26f --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_4.py @@ -0,0 +1,14 @@ +from snark_lib import * + + +def main(): + fibonacci(0, 1, 0, 30) + return + + +def fibonacci(a, b, i, n): + if i == n: + print(a) + return + fibonacci(b, a + b, i + 1, n) + return diff --git a/crates/lean_compiler/tests/test_data/program_40.py b/crates/lean_compiler/tests/test_data/program_40.py new file mode 100644 index 00000000..04054392 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_40.py @@ -0,0 +1,9 @@ +from snark_lib import * + + +def main(): + x: Mut = 1 + x = x + 1 + x = x + 1 + assert x == 3 + return diff --git a/crates/lean_compiler/tests/test_data/program_41.py b/crates/lean_compiler/tests/test_data/program_41.py new file mode 100644 index 00000000..d8155ac8 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_41.py @@ -0,0 +1,9 @@ +from snark_lib import * + + +def main(): + sum: Mut = 0 + for i in unroll(0, 5): + sum = sum + i + assert sum == 10 + return diff --git a/crates/lean_compiler/tests/test_data/program_42.py b/crates/lean_compiler/tests/test_data/program_42.py new file mode 100644 index 00000000..b2e592e5 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_42.py @@ -0,0 +1,12 @@ +from snark_lib import * + + +def main(): + x: Mut = 1 + cond = 1 + if cond == 1: + x = x + 10 + else: + x = x + 20 + assert x == 11 + return diff --git a/crates/lean_compiler/tests/test_data/program_43.py b/crates/lean_compiler/tests/test_data/program_43.py new file mode 100644 index 00000000..aed3a00b --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_43.py @@ -0,0 +1,13 @@ +from snark_lib import * + + +def main(): + result = increment_twice(5) + assert result == 7 + return + + +def increment_twice(x: Mut): + x = x + 1 + x = x + 1 + return x diff --git a/crates/lean_compiler/tests/test_data/program_44.py b/crates/lean_compiler/tests/test_data/program_44.py new file mode 100644 index 00000000..9173c510 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_44.py @@ -0,0 +1,10 @@ +from snark_lib import * +def main(): + a, b: Mut = get_two() + b = b + 1 + assert a == 10 + assert b == 21 + return + +def get_two(): + return 10, 20 \ No newline at end of file diff --git a/crates/lean_compiler/tests/test_data/program_45.py b/crates/lean_compiler/tests/test_data/program_45.py new file mode 100644 index 00000000..ca62bab3 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_45.py @@ -0,0 +1,8 @@ +from snark_lib import * + + +def main(): + x: Mut = 5 + x = x + 1 + assert x == 6 + return diff --git a/crates/lean_compiler/tests/test_data/program_46.py b/crates/lean_compiler/tests/test_data/program_46.py new file mode 100644 index 00000000..c69cb6f0 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_46.py @@ -0,0 +1,16 @@ +from snark_lib import * + + +def main(): + total: Mut = 0 + outer_sum: Mut = 0 + for i in unroll(0, 3): + outer_sum = outer_sum + i + inner_sum: Mut = 0 + for j in unroll(0, 4): + inner_sum = inner_sum + j + total = total + 1 + assert inner_sum == 6 + assert outer_sum == 3 + assert total == 12 + return diff --git a/crates/lean_compiler/tests/test_data/program_47.py b/crates/lean_compiler/tests/test_data/program_47.py new file mode 100644 index 00000000..258b9dab --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_47.py @@ -0,0 +1,20 @@ +from snark_lib import * + + +def main(): + count: Mut = 0 + sum_i: Mut = 0 + sum_j: Mut = 0 + sum_k: Mut = 0 + for i in unroll(0, 2): + sum_i = sum_i + i + for j in unroll(0, 3): + sum_j = sum_j + j + for k in unroll(0, 2): + sum_k = sum_k + k + count = count + 1 + assert count == 12 + assert sum_i == 1 + assert sum_j == 6 + assert sum_k == 6 + return diff --git a/crates/lean_compiler/tests/test_data/program_48.py b/crates/lean_compiler/tests/test_data/program_48.py new file mode 100644 index 00000000..e74475fc --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_48.py @@ -0,0 +1,17 @@ +from snark_lib import * + +N = 5 +M = 3 + + +def main(): + acc: Mut = 0 + for i in unroll(0, N): + acc = acc + M + assert acc == 15 + + product: Mut = 1 + for i in unroll(0, M): + product = product * 2 + assert product == 8 + return diff --git a/crates/lean_compiler/tests/test_data/program_49.py b/crates/lean_compiler/tests/test_data/program_49.py new file mode 100644 index 00000000..4742af40 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_49.py @@ -0,0 +1,28 @@ +from snark_lib import * + + +def main(): + a: Mut = 1 + b: Mut = 2 + a, b = swap(a, b) + assert a == 2 + assert b == 1 + + a, b = swap(a, b) + assert a == 1 + assert b == 2 + + c: Mut = compute(a, b) + assert c == 5 # 1 + 2*2 = 5 + c = compute(c, c) + assert c == 15 # 5 + 5*2 = 15 + return + + +def swap(x, y): + return y, x + + +def compute(x, y): + result = x + y * 2 + return result diff --git a/crates/lean_compiler/tests/test_data/program_5.py b/crates/lean_compiler/tests/test_data/program_5.py new file mode 100644 index 00000000..6b72944f --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_5.py @@ -0,0 +1,11 @@ +from snark_lib import * + + +def main(): + a = Array(1) + a[0] = 0 + for i in range(0, 1): + x = 1 + a[i] + for i in range(0, 1): + y = 1 + a[i] + return diff --git a/crates/lean_compiler/tests/test_data/program_50.py b/crates/lean_compiler/tests/test_data/program_50.py new file mode 100644 index 00000000..b44712ec --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_50.py @@ -0,0 +1,28 @@ +from snark_lib import * + + +def main(): + x: Mut = 0 + y: Mut = 0 + z: Mut = 0 + + cond1 = 1 + if cond1 == 1: + x = x + 10 + y = y + 20 + else: + x = x + 100 + z = z + 30 + assert x == 10 + assert y == 20 + assert z == 0 + + cond2 = 0 + if cond2 == 1: + x = x + 1 + else: + x = x + 2 + y = y + 3 + assert x == 12 + assert y == 23 + return diff --git a/crates/lean_compiler/tests/test_data/program_51.py b/crates/lean_compiler/tests/test_data/program_51.py new file mode 100644 index 00000000..11fdd72b --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_51.py @@ -0,0 +1,20 @@ +from snark_lib import * + + +def main(): + counter: Mut = 0 + + a = 1 + if a == 1: + counter = counter + 1 + else: + counter = counter + 1000 + assert counter == 1 + + b = 1 + if b == 1: + counter = counter + 10 + else: + counter = counter + 100 + assert counter == 11 + return diff --git a/crates/lean_compiler/tests/test_data/program_52.py b/crates/lean_compiler/tests/test_data/program_52.py new file mode 100644 index 00000000..30098556 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_52.py @@ -0,0 +1,28 @@ +from snark_lib import * + + +def main(): + a: Mut = 1 + b: Mut = 2 + c: Mut = 3 + + a = a + b # a = 3 + b = b + c # b = 5 + c = c + a # c = 6 + + a = a * 2 # a = 6 + b = b * 2 # b = 10 + c = c * 2 # c = 12 + + assert a == 6 + assert b == 10 + assert c == 12 + + a = b + c # a = 22 + b = c + a # b = 34 (uses new a) + c = a + b # c = 56 (uses new a and b) + + assert a == 22 + assert b == 34 + assert c == 56 + return diff --git a/crates/lean_compiler/tests/test_data/program_53.py b/crates/lean_compiler/tests/test_data/program_53.py new file mode 100644 index 00000000..9becfaca --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_53.py @@ -0,0 +1,30 @@ +from snark_lib import * + + +def main(): + fib_prev: Mut = 0 + fib_curr: Mut = 1 + + temp0 = fib_curr + fib_curr = fib_prev + fib_curr + fib_prev = temp0 + + temp1 = fib_curr + fib_curr = fib_prev + fib_curr + fib_prev = temp1 + + temp2 = fib_curr + fib_curr = fib_prev + fib_curr + fib_prev = temp2 + + temp3 = fib_curr + fib_curr = fib_prev + fib_curr + fib_prev = temp3 + + temp4 = fib_curr + fib_curr = fib_prev + fib_curr + fib_prev = temp4 + + assert fib_curr == 8 + assert fib_prev == 5 + return diff --git a/crates/lean_compiler/tests/test_data/program_54.py b/crates/lean_compiler/tests/test_data/program_54.py new file mode 100644 index 00000000..f8bb5230 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_54.py @@ -0,0 +1,19 @@ +from snark_lib import * + +N = 5 + + +def main(): + arr = Array(N) + + sum: Mut = 0 + for i in unroll(0, N): + arr[i] = i * 2 + sum = sum + arr[i] + assert sum == 20 + + product: Mut = 1 + for i in unroll(1, N): + product = product * arr[i] + assert product == 384 + return diff --git a/crates/lean_compiler/tests/test_data/program_55.py b/crates/lean_compiler/tests/test_data/program_55.py new file mode 100644 index 00000000..ac69163b --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_55.py @@ -0,0 +1,17 @@ +from snark_lib import * + + +def main(): + selector = 2 + match selector: + case 0: + result: Mut = 1 + assert result == 1 + case 1: + result: Mut = 10 + assert result == 10 + case 2: + result: Mut = 100 + result = result + 5 + assert result == 105 + return diff --git a/crates/lean_compiler/tests/test_data/program_56.py b/crates/lean_compiler/tests/test_data/program_56.py new file mode 100644 index 00000000..0eb90483 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_56.py @@ -0,0 +1,12 @@ +from snark_lib import * + +WEIGHTS = [1, 2, 3, 4, 5] +N = 5 + + +def main(): + weighted_sum: Mut = 0 + for i in unroll(0, N): + weighted_sum = weighted_sum + WEIGHTS[i] * (i + 1) + assert weighted_sum == 55 + return diff --git a/crates/lean_compiler/tests/test_data/program_57.py b/crates/lean_compiler/tests/test_data/program_57.py new file mode 100644 index 00000000..3d4965c4 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_57.py @@ -0,0 +1,28 @@ +from snark_lib import * + + +def main(): + x: Mut = 1 + x = step1(x) + x = step2(x) + x = step3(x) + assert x == 47 + return + + +def step1(n: Mut): + n = n * 2 + n = n + 1 + return n + + +def step2(n: Mut): + n = n * 3 + n = n + 2 + return n + + +def step3(n: Mut): + n = n * 4 + n = n + 3 + return n diff --git a/crates/lean_compiler/tests/test_data/program_58.py b/crates/lean_compiler/tests/test_data/program_58.py new file mode 100644 index 00000000..a8e0bda5 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_58.py @@ -0,0 +1,22 @@ +from snark_lib import * + + +def main(): + a: Mut = 5 + b: Mut = 3 + + a = a * a + b * b # 25 + 9 = 34 + assert a == 34 + + b = a - b * 2 # 34 - 6 = 28 + assert b == 28 + + c: Mut = a + b # 34 + 28 = 62 + assert c == 62 + + c = c + 8 # 70 + assert c == 70 + + c = c - 10 # 60 + assert c == 60 + return diff --git a/crates/lean_compiler/tests/test_data/program_59.py b/crates/lean_compiler/tests/test_data/program_59.py new file mode 100644 index 00000000..ab969273 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_59.py @@ -0,0 +1,16 @@ +from snark_lib import * + + +def main(): + total: Mut = 0 + for i in unroll(0, 5): + temp = add_one_pure(i) + total = total + temp + assert total == 15 + return + + +@inline +def add_one_pure(x): + result = x + 1 + return result diff --git a/crates/lean_compiler/tests/test_data/program_6.py b/crates/lean_compiler/tests/test_data/program_6.py new file mode 100644 index 00000000..a3eda9ae --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_6.py @@ -0,0 +1,8 @@ +from snark_lib import * + + +def main(): + a = Array(1) + a[0] = 0 + assert a[8 - 8] == 0 + return diff --git a/crates/lean_compiler/tests/test_data/program_60.py b/crates/lean_compiler/tests/test_data/program_60.py new file mode 100644 index 00000000..02be789b --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_60.py @@ -0,0 +1,21 @@ +from snark_lib import * + +N = 6 + + +def main(): + factorial: Mut = 1 + for i in unroll(1, N): + factorial = factorial * i + assert factorial == 120 + + sum_squares: Mut = 0 + for i in unroll(1, N): + sum_squares = sum_squares + i * i + assert sum_squares == 55 + + triangular: Mut = 0 + for i in unroll(1, N): + triangular = triangular + i + assert triangular == 15 + return diff --git a/crates/lean_compiler/tests/test_data/program_61.py b/crates/lean_compiler/tests/test_data/program_61.py new file mode 100644 index 00000000..91e0342b --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_61.py @@ -0,0 +1,26 @@ +from snark_lib import * + + +def main(): + a: Mut = 1 + b: Mut = 2 + c: Mut = 3 + + a, b = double_both(a, b) + assert a == 2 + assert b == 4 + + b, c = double_both(b, c) + assert b == 8 + assert c == 6 + + a, c = double_both(a, c) + assert a == 4 + assert c == 12 + + assert a + b + c == 24 + return + + +def double_both(x, y): + return x * 2, y * 2 diff --git a/crates/lean_compiler/tests/test_data/program_63.py b/crates/lean_compiler/tests/test_data/program_63.py new file mode 100644 index 00000000..c5dc5beb --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_63.py @@ -0,0 +1,16 @@ +from snark_lib import * + + +def main(): + even_sum: Mut = 0 + odd_sum: Mut = 0 + + for i in unroll(0, 6): + remainder = i % 2 + if remainder == 0: + even_sum = even_sum + i + else: + odd_sum = odd_sum + i + assert even_sum == 6 + assert odd_sum == 9 + return diff --git a/crates/lean_compiler/tests/test_data/program_64.py b/crates/lean_compiler/tests/test_data/program_64.py new file mode 100644 index 00000000..cfbb8ffd --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_64.py @@ -0,0 +1,26 @@ +from snark_lib import * +def main(): + for i in range(0, 6): + x: Mut = i + x = x + 1 + for j in range(0, 3): + y: Mut = x + 1 + y = y + j + if i == 10: + y = y - 1 + if j == 10000: + y = y - 2 + else if i != 1000: + y = y + 2 + if j == 10000: + y = y - 2 + else if i == 1000: + y = y + 2 + if j == 10000: + y = y - 2 + else if i != 1000: + y = y + 2 + else: + y = y + 2 + assert y == i + j + 6 + return \ No newline at end of file diff --git a/crates/lean_compiler/tests/test_data/program_65.py b/crates/lean_compiler/tests/test_data/program_65.py new file mode 100644 index 00000000..f3c88e47 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_65.py @@ -0,0 +1,16 @@ +from snark_lib import * +def main(): + assert func(0) == 11 + assert func(1) == 20 + assert func(2) == 10 + return + +def func(i): + x: Mut = 10 + match i: + case 0: + x = x + 1 + case 1: + x = x + 10 + case 2: + return x \ No newline at end of file diff --git a/crates/lean_compiler/tests/test_data/program_66.py b/crates/lean_compiler/tests/test_data/program_66.py new file mode 100644 index 00000000..6232911b --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_66.py @@ -0,0 +1,30 @@ +from snark_lib import * +def main(): + assert func(0) == 11 + assert func(1) == 40 + assert func(2) == 10 + return + +def func(i): + x: Mut = 10 + match i: + case 0: + x = x + 1 + case 1: + if 1 == 0: + x = x + 100 + else: + x = x + 10 + if 1 == 0: + + else: + x = x + 10 + if 1 == 1: + if 1 == 0: + + else: + x = x + 10 + else: + + case 2: + return x \ No newline at end of file diff --git a/crates/lean_compiler/tests/test_data/program_67.py b/crates/lean_compiler/tests/test_data/program_67.py new file mode 100644 index 00000000..ed80db93 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_67.py @@ -0,0 +1,8 @@ +from snark_lib import * + + +def main(): + mut_a: Imu + mut_a = 5 + assert mut_a == 5 + return diff --git a/crates/lean_compiler/tests/test_data/program_68.py b/crates/lean_compiler/tests/test_data/program_68.py new file mode 100644 index 00000000..13eb50c6 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_68.py @@ -0,0 +1,24 @@ +from snark_lib import * + + +def main(): + assert test_func(0, 0) == 6 + assert test_func(1, 0) == 3 + return + + +def test_func(a, b): + x = 1 + + mut_x_2: Imu + match a: + case 0: + mut_x_1: Imu + mut_x_1 = x + 2 + match b: + case 0: + mut_x_2 = mut_x_1 + 3 + case 1: + mut_x_2 = x + 2 + + return mut_x_2 diff --git a/crates/lean_compiler/tests/test_data/program_69.py b/crates/lean_compiler/tests/test_data/program_69.py new file mode 100644 index 00000000..4c19fda8 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_69.py @@ -0,0 +1,76 @@ +from snark_lib import * + + +def main(): + assert compute(0, 0, 0) == 1008 + assert compute(0, 0, 1) == 1009 + assert compute(0, 1, 0) == 1012 + assert compute(0, 1, 1) == 1013 + assert compute(1, 0, 0) == 1036 + assert compute(1, 0, 1) == 1037 + assert compute(1, 1, 0) == 1046 + assert compute(1, 1, 1) == 1047 + return + + +def compute(a, b, c): + base = 1000 + outer_val: Imu + mid_val: Imu + inner_val: Imu + + match a: + case 0: + outer_val = 5 + local_a: Imu + local_a = a + outer_val + + match b: + case 0: + mid_val = 3 + local_b: Imu + local_b = local_a + mid_val + + match c: + case 0: + inner_val = base + local_b + c + case 1: + inner_val = base + local_b + c + case 1: + mid_val = 7 + local_b: Imu + local_b = local_a + mid_val + + match c: + case 0: + inner_val = base + local_b + c + case 1: + inner_val = base + local_b + c + case 1: + outer_val = 15 + local_a: Imu + local_a = a + outer_val + + match b: + case 0: + mid_val = 20 + local_b: Imu + local_b = local_a + mid_val + + match c: + case 0: + inner_val = base + local_b + c + case 1: + inner_val = base + local_b + c + case 1: + mid_val = 30 + local_b: Imu + local_b = local_a + mid_val + + match c: + case 0: + inner_val = base + local_b + c + case 1: + inner_val = base + local_b + c + + return inner_val diff --git a/crates/lean_compiler/tests/test_data/program_7.py b/crates/lean_compiler/tests/test_data/program_7.py new file mode 100644 index 00000000..038a6cd0 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_7.py @@ -0,0 +1,11 @@ +from snark_lib import * + + +def main(): + for i in unroll(0, 5): + x = i + print(x) + for i in unroll(0, 3): + x = i + print(x) + return diff --git a/crates/lean_compiler/tests/test_data/program_70.py b/crates/lean_compiler/tests/test_data/program_70.py new file mode 100644 index 00000000..af5bd9fb --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_70.py @@ -0,0 +1,16 @@ +from snark_lib import * +def main(): + x: Mut = 5 + cond = 1 + if cond == 1: + x = x + 10 + else: + assert x == 15 + + y: Mut = 10 + cond2 = 0 + if cond2 == 1: + else: + y = y + 5 + assert y == 15 + return \ No newline at end of file diff --git a/crates/lean_compiler/tests/test_data/program_71.py b/crates/lean_compiler/tests/test_data/program_71.py new file mode 100644 index 00000000..31eac88d --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_71.py @@ -0,0 +1,25 @@ +from snark_lib import * + + +def main(): + x: Mut = 1 + cond = 1 + if cond == 1: + x = x + 1 + x = x + 1 + x = x + 1 + else: + x = x + 10 + assert x == 4 + + y: Mut = 1 + cond2 = 0 + if cond2 == 1: + y = y + 1 + else: + y = y + 1 + y = y + 1 + y = y + 1 + y = y + 1 + assert y == 5 + return diff --git a/crates/lean_compiler/tests/test_data/program_72.py b/crates/lean_compiler/tests/test_data/program_72.py new file mode 100644 index 00000000..fecf99d9 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_72.py @@ -0,0 +1,22 @@ +from snark_lib import * + + +def main(): + x: Mut = 0 + a = 1 + b = 1 + c = 1 + if a == 1: + x = x + 1 + if b == 1: + x = x + 10 + if c == 1: + x = x + 100 + else: + x = x + 200 + else: + x = x + 20 + else: + x = x + 1000 + assert x == 111 + return diff --git a/crates/lean_compiler/tests/test_data/program_73.py b/crates/lean_compiler/tests/test_data/program_73.py new file mode 100644 index 00000000..2db92911 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_73.py @@ -0,0 +1,20 @@ +from snark_lib import * + + +def main(): + a: Mut = 0 + b: Mut = 0 + c: Mut = 0 + + cond = 1 + if cond == 1: + a = a + 1 + b = b + 10 + else: + b = b + 20 + c = c + 100 + + assert a == 1 + assert b == 10 + assert c == 0 + return diff --git a/crates/lean_compiler/tests/test_data/program_74.py b/crates/lean_compiler/tests/test_data/program_74.py new file mode 100644 index 00000000..c9fe124e --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_74.py @@ -0,0 +1,21 @@ +from snark_lib import * + + +def main(): + assert test_func(1, 0) == 11 + assert test_func(1, 1) == 20 + assert test_func(0, 0) == 100 + return + + +def test_func(cond, selector): + x: Mut = 10 + if cond == 1: + match selector: + case 0: + x = x + 1 + case 1: + x = x + 10 + else: + x = x + 90 + return x diff --git a/crates/lean_compiler/tests/test_data/program_75.py b/crates/lean_compiler/tests/test_data/program_75.py new file mode 100644 index 00000000..3a9a7bb6 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_75.py @@ -0,0 +1,25 @@ +from snark_lib import * + + +def main(): + assert test_func(0, 1) == 11 + assert test_func(0, 0) == 20 + assert test_func(1, 1) == 110 + assert test_func(1, 0) == 200 + return + + +def test_func(selector, cond): + x: Mut = 10 + match selector: + case 0: + if cond == 1: + x = x + 1 + else: + x = x + 10 + case 1: + if cond == 1: + x = x + 100 + else: + x = x + 190 + return x diff --git a/crates/lean_compiler/tests/test_data/program_76.py b/crates/lean_compiler/tests/test_data/program_76.py new file mode 100644 index 00000000..ccf74758 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_76.py @@ -0,0 +1,13 @@ +from snark_lib import * + + +def main(): + x: Mut = 1 + cond = 1 + if cond == 1: + x = x + 10 + else: + x = x + 100 + x = x + 1000 # Should work on unified version + assert x == 1011 + return diff --git a/crates/lean_compiler/tests/test_data/program_77.py b/crates/lean_compiler/tests/test_data/program_77.py new file mode 100644 index 00000000..05f9e6a1 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_77.py @@ -0,0 +1,26 @@ +from snark_lib import * + + +def main(): + x: Mut = 0 + + cond1 = 1 + if cond1 == 1: + x = x + 1 + else: + x = x + 10 + + cond2 = 0 + if cond2 == 1: + x = x + 100 + else: + x = x + 200 + + cond3 = 1 + if cond3 == 1: + x = x + 1000 + else: + x = x + 2000 + + assert x == 1201 + return diff --git a/crates/lean_compiler/tests/test_data/program_78.py b/crates/lean_compiler/tests/test_data/program_78.py new file mode 100644 index 00000000..b05b2cb9 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_78.py @@ -0,0 +1,11 @@ +from snark_lib import * + + +def main(): + x: Mut = 5 + x = x + x # Should be 5 + 5 = 10, not 10 + 10 + assert x == 10 + + x = x * x # 10 * 10 = 100 + assert x == 100 + return diff --git a/crates/lean_compiler/tests/test_data/program_79.py b/crates/lean_compiler/tests/test_data/program_79.py new file mode 100644 index 00000000..3f150392 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_79.py @@ -0,0 +1,18 @@ +from snark_lib import * + + +def main(): + a: Mut = 10 + b: Mut = 20 + + cond = 1 + if cond == 1: + a = b + 1 # a = 21 + b = a + 1 # b = 22 (uses new a) + else: + b = a + 100 + a = b + 1 + + assert a == 21 + assert b == 22 + return diff --git a/crates/lean_compiler/tests/test_data/program_8.py b/crates/lean_compiler/tests/test_data/program_8.py new file mode 100644 index 00000000..55e315fb --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_8.py @@ -0,0 +1,17 @@ +from snark_lib import * + +BIG_ENDIAN = 0 +LITTLE_ENDIAN = 1 + + +def main(): + x = 2**20 - 1 + a = Array(31) + print(a) + hint_decompose_bits(x, a, 31, LITTLE_ENDIAN) + for i in range(0, 20): + debug_assert(a[i] == 1) + assert a[i] == 1 + for i in range(20, 31): + assert a[i] == 0 + return diff --git a/crates/lean_compiler/tests/test_data/program_80.py b/crates/lean_compiler/tests/test_data/program_80.py new file mode 100644 index 00000000..ced88d74 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_80.py @@ -0,0 +1,12 @@ +from snark_lib import * +def main(): + total: Mut = 0 + for i in unroll(0, 5): + if i == 2: + total = total + 100 + else if i == 4: + total = total + 1000 + else: + total = total + 1 + assert total == 1103 + return \ No newline at end of file diff --git a/crates/lean_compiler/tests/test_data/program_81.py b/crates/lean_compiler/tests/test_data/program_81.py new file mode 100644 index 00000000..7034b6f6 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_81.py @@ -0,0 +1,16 @@ +from snark_lib import * +def main(): + assert test_func(0) == 11 + assert test_func(1) == 10 # no mutation + assert test_func(2) == 30 + return + +def test_func(sel): + x: Mut = 10 + match sel: + case 0: + x = x + 1 + case 1: + case 2: + x = x + 20 + return x \ No newline at end of file diff --git a/crates/lean_compiler/tests/test_data/program_82.py b/crates/lean_compiler/tests/test_data/program_82.py new file mode 100644 index 00000000..c29c2157 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_82.py @@ -0,0 +1,23 @@ +from snark_lib import * + + +def main(): + assert test_func(0) == 11 + assert test_func(1) == 12 + assert test_func(2) == 13 + return + + +def test_func(sel): + x: Mut = 10 + match sel: + case 0: + x = x + 1 + case 1: + x = x + 1 + x = x + 1 + case 2: + x = x + 1 + x = x + 1 + x = x + 1 + return x diff --git a/crates/lean_compiler/tests/test_data/program_83.py b/crates/lean_compiler/tests/test_data/program_83.py new file mode 100644 index 00000000..aa423c9f --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_83.py @@ -0,0 +1,14 @@ +from snark_lib import * + + +def main(): + x: Imu + cond = 1 + if cond == 1: + x = 10 + else: + x = 20 + x2: Mut = x + x2 = x2 + 1 + assert x2 == 11 + return diff --git a/crates/lean_compiler/tests/test_data/program_84.py b/crates/lean_compiler/tests/test_data/program_84.py new file mode 100644 index 00000000..e462e5df --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_84.py @@ -0,0 +1,11 @@ +from snark_lib import * + + +def main(): + x: Mut = 3 + x = x * x + x # 3*3 + 3 = 12 + assert x == 12 + + x = (x + 1) * (x + 2) # 13 * 14 = 182 + assert x == 182 + return diff --git a/crates/lean_compiler/tests/test_data/program_85.py b/crates/lean_compiler/tests/test_data/program_85.py new file mode 100644 index 00000000..cdac8c29 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_85.py @@ -0,0 +1,29 @@ +from snark_lib import * + + +def main(): + assert test_func(0, 0) == 111 + assert test_func(0, 1) == 121 + assert test_func(1, 0) == 211 + assert test_func(1, 1) == 221 + return + + +def test_func(a, b): + x: Mut = 100 + match a: + case 0: + x = x + 10 + match b: + case 0: + x = x + 1 + case 1: + x = x + 11 + case 1: + x = x + 110 + match b: + case 0: + x = x + 1 + case 1: + x = x + 11 + return x diff --git a/crates/lean_compiler/tests/test_data/program_86.py b/crates/lean_compiler/tests/test_data/program_86.py new file mode 100644 index 00000000..8cd4004e --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_86.py @@ -0,0 +1,12 @@ +from snark_lib import * + + +def main(): + outer_sum: Mut = 0 + for i in unroll(0, 3): + inner_sum: Mut = 0 + for j in unroll(0, 4): + inner_sum = inner_sum + j + outer_sum = outer_sum + inner_sum + assert outer_sum == 18 + return diff --git a/crates/lean_compiler/tests/test_data/program_87.py b/crates/lean_compiler/tests/test_data/program_87.py new file mode 100644 index 00000000..8d4c12b1 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_87.py @@ -0,0 +1,17 @@ +from snark_lib import * + + +def main(): + x: Mut = 10 + y = 5 # immutable + + cond = 1 + if cond == 1: + x = x + y # 10 + 5 = 15 + z = x + y # 15 + 5 = 20, immutable + x = x + z # 15 + 20 = 35 + else: + x = x + 100 + + assert x == 35 + return diff --git a/crates/lean_compiler/tests/test_data/program_88.py b/crates/lean_compiler/tests/test_data/program_88.py new file mode 100644 index 00000000..dcf3af1d --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_88.py @@ -0,0 +1,16 @@ +from snark_lib import * +def main(): + assert test_func(0) == 11 + assert test_func(1) == 20 + assert test_func(2) == 30 + return + +def test_func(cond): + x: Mut = 10 + if cond == 0: + x = x + 1 + else if cond == 1: + x = x + 10 + else: + x = x + 20 + return x \ No newline at end of file diff --git a/crates/lean_compiler/tests/test_data/program_89.py b/crates/lean_compiler/tests/test_data/program_89.py new file mode 100644 index 00000000..a39882bd --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_89.py @@ -0,0 +1,27 @@ +from snark_lib import * + + +def main(): + x: Mut = 0 + a = 1 + b = 1 + c = 1 + d = 1 + if a == 1: + x = x + 1 + if b == 1: + x = x + 10 + if c == 1: + x = x + 100 + if d == 1: + x = x + 1000 + else: + x = x + 2000 + else: + x = x + 200 + else: + x = x + 20 + else: + x = x + 2 + assert x == 1111 + return diff --git a/crates/lean_compiler/tests/test_data/program_9.py b/crates/lean_compiler/tests/test_data/program_9.py new file mode 100644 index 00000000..a6857a3b --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_9.py @@ -0,0 +1,8 @@ +from snark_lib import * + + +def main(): + for i in unroll(0, 5): + for j in unroll(i, 2 * i): + print(i, j) + return diff --git a/crates/lean_compiler/tests/test_data/program_90.py b/crates/lean_compiler/tests/test_data/program_90.py new file mode 100644 index 00000000..50ba475f --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_90.py @@ -0,0 +1,15 @@ +from snark_lib import * + + +def main(): + total: Mut = 0 + for i in unroll(0, 3): + match i: + case 0: + total = total + 1 + case 1: + total = total + 10 + case 2: + total = total + 100 + assert total == 111 + return diff --git a/crates/lean_compiler/tests/test_data/program_91.py b/crates/lean_compiler/tests/test_data/program_91.py new file mode 100644 index 00000000..7a70901c --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_91.py @@ -0,0 +1,16 @@ +from snark_lib import * + + +def main(): + x: Mut = 1 + cond = 1 + if cond == 1: + x = x + 1 # 2 + x = x * 2 # 4 + x = x + 3 # 7 + x = x * 2 # 14 + x = x + 1 # 15 + else: + x = x + 100 + assert x == 15 + return diff --git a/crates/lean_compiler/tests/test_data/program_92.py b/crates/lean_compiler/tests/test_data/program_92.py new file mode 100644 index 00000000..5eb0d932 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_92.py @@ -0,0 +1,16 @@ +from snark_lib import * + + +def main(): + x: Mut = 5 + cond = 1 + if cond == 1: + x = compute(x, 3) + else: + x = compute(x, 10) + assert x == 18 + return + + +def compute(a, b): + return a * b + b diff --git a/crates/lean_compiler/tests/test_data/program_93.py b/crates/lean_compiler/tests/test_data/program_93.py new file mode 100644 index 00000000..4b35eb0b --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_93.py @@ -0,0 +1,11 @@ +from snark_lib import * + + +def main(): + outer_x: Mut = 0 + for i in unroll(0, 2): + x: Mut = 1 # fresh x each iteration + x = x + i + outer_x = outer_x + x + assert outer_x == 3 + return diff --git a/crates/lean_compiler/tests/test_data/program_94.py b/crates/lean_compiler/tests/test_data/program_94.py new file mode 100644 index 00000000..9bcd19ff --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_94.py @@ -0,0 +1,26 @@ +from snark_lib import * + + +def main(): + x: Mut = 0 + + a = 1 + if a == 1: + x = x + 1 + else: + x = x + 100 + + b = 0 + if b == 1: + x = x * 100 + else: + x = x * 2 + + c = 1 + if c == 1: + x = x + 10 + else: + x = x + 1000 + + assert x == 12 + return diff --git a/crates/lean_compiler/tests/test_data/program_95.py b/crates/lean_compiler/tests/test_data/program_95.py new file mode 100644 index 00000000..269aeb03 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_95.py @@ -0,0 +1,20 @@ +from snark_lib import * + + +def main(): + evens: Mut = 0 + odds: Mut = 0 + all: Mut = 0 + + for i in unroll(0, 6): + all = all + 1 + remainder = i % 2 + if remainder == 0: + evens = evens + i + else: + odds = odds + i + + assert evens == 6 + assert odds == 9 + assert all == 6 + return diff --git a/crates/lean_compiler/tests/test_data/program_96.py b/crates/lean_compiler/tests/test_data/program_96.py new file mode 100644 index 00000000..bfcb5e5d --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_96.py @@ -0,0 +1,28 @@ +from snark_lib import * +def main(): + assert test_func(0) == 10 + assert test_func(1) == 11 + assert test_func(2) == 12 + assert test_func(3) == 14 + assert test_func(4) == 18 + return + +def test_func(sel): + x: Mut = 10 + match sel: + case 0: + case 1: + x = x + 1 + case 2: + x = x + 1 + x = x + 1 + case 3: + x = x + 1 + x = x + 1 + x = x + 2 + case 4: + x = x + 1 + x = x + 1 + x = x + 2 + x = x + 4 + return x \ No newline at end of file diff --git a/crates/lean_compiler/tests/test_data/program_97.py b/crates/lean_compiler/tests/test_data/program_97.py new file mode 100644 index 00000000..2d8ef4f5 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_97.py @@ -0,0 +1,28 @@ +from snark_lib import * +def main(): + assert test_func(0, 0, 0) == 1000 + assert test_func(0, 0, 1) == 1001 + assert test_func(0, 1, 0) == 1010 + assert test_func(1, 0, 0) == 1100 + assert test_func(1, 1, 1) == 1111 + return + +def test_func(a, b, c): + x: Mut = 0 + if a == 0: + x = x + 1000 + if b == 0: + if c == 0: + else: + x = x + 1 + else: + x = x + 10 + if c == 1: + x = x + 1 + else: + x = x + 1100 + if b == 1: + x = x + 10 + if c == 1: + x = x + 1 + return x \ No newline at end of file diff --git a/crates/lean_compiler/tests/test_data/program_99.py b/crates/lean_compiler/tests/test_data/program_99.py new file mode 100644 index 00000000..40d61c71 --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_99.py @@ -0,0 +1,13 @@ +from snark_lib import * + + +def main(): + result = accumulate(5) + assert result == 8 + return + + +def accumulate(x: Mut): + for i in unroll(0, 3): + x = x + i + return x diff --git a/crates/lean_compiler/tests/test_data/program_range_check.py b/crates/lean_compiler/tests/test_data/program_range_check.py new file mode 100644 index 00000000..fce93afa --- /dev/null +++ b/crates/lean_compiler/tests/test_data/program_range_check.py @@ -0,0 +1,40 @@ +from snark_lib import * + + +def main(): + # Test range check with various values + x = 500 + assert x < 1000 + + y = 0 + assert y < 10 + + z = 999 + assert z < 1000 + + # Test with computed value + a = 100 + 200 + assert a < 400 + + # Test <= (becomes < bound+1) + b = 999 + assert b <= 999 + + c = 0 + assert c <= 0 + + # Test with non-constant bound + bound = 500 + d = 100 + assert d < bound + + # Test with computed bound + e = 50 + f = 200 + 100 # f = 300 + assert e < f + + for i in range(50, 100): + for j in range(i - 10, i + 1): + assert j <= i + + return diff --git a/crates/lean_compiler/tests/test_performance.rs b/crates/lean_compiler/tests/test_performance.rs new file mode 100644 index 00000000..ed99bd3a --- /dev/null +++ b/crates/lean_compiler/tests/test_performance.rs @@ -0,0 +1,32 @@ +use lean_compiler::*; +use lean_vm::*; + +fn test_data_dir() -> String { + let manifest_dir = env!("CARGO_MANIFEST_DIR"); + format!("{manifest_dir}/tests/test_data") +} + +/// Helper to get the number of cycles for a program file +fn get_cycle_count(path: &str) -> usize { + let bytecode = compile_program(&ProgramSource::Filepath(path.to_string())); + let result = try_execute_bytecode(&bytecode, (&[], &[]), false, &vec![]).unwrap(); + result.pcs.len() +} + +#[test] +fn test_constant_if_else_optimization() { + let path_with_conditions = format!("{}/perf_constant_if_with_conditions.py", test_data_dir()); + let path_baseline = format!("{}/perf_constant_if_baseline.py", test_data_dir()); + + let cycles_with_conditions = get_cycle_count(&path_with_conditions); + let cycles_baseline = get_cycle_count(&path_baseline); + + assert_eq!( + cycles_with_conditions, cycles_baseline, + "Constant if/else conditions should be eliminated at compile time.\n\ + Program with conditions: {} cycles\n\ + Baseline (no conditions): {} cycles\n\ + Expected equal cycle counts.", + cycles_with_conditions, cycles_baseline + ); +} diff --git a/crates/lean_compiler/zkDSL.md b/crates/lean_compiler/zkDSL.md new file mode 100644 index 00000000..71ee29f5 --- /dev/null +++ b/crates/lean_compiler/zkDSL.md @@ -0,0 +1,595 @@ +# zkDSL Language Reference + +## Program Structure + +``` +from snark_lib import * # Python compatibility (ignored by compiler) +from dir.file import * # imports (optional, Python-style) +NAME = value # constants (optional, uppercase by convention) +def main(): # entry point (required) + ... +def helper(): # other functions (optional) + ... +``` + +The `from snark_lib import *` line imports Python definitions for zkDSL primitives (Array, DynArray, Mut, Const, etc.), allowing `.py` files to be executed as normal Python scripts for testing. The zkDSL compiler ignores this import line. + +To run zkDSL files as Python scripts, run from the file's directory with PYTHONPATH pointing to the lean_compiler crate (for snark_lib.py): +```bash +export PYTHONPATH=/path/to/repo/crates/lean_compiler +cd crates/lean_compiler/tests/test_data +python program_0.py +``` + +## Constants + +Constants are declared at the top level (outside functions) using simple assignment. By convention, constant names are UPPERCASE. + +``` +X = 42 +ARR = [1, 2, 3] +NESTED = [[1, 2], [3]] +``` + +### Multi-Dimensional Const Arrays + +Const arrays can be nested to any depth, and inner arrays can have different lengths (ragged arrays). All const array values are resolved at compile time. + +``` +MATRIX = [[1, 2, 3], [4, 5], [6, 7, 8, 9]] # ragged 2D array +DEEP = [[[1, 2], [3]], [[4, 5, 6]]] # 3D array +``` + +**Accessing elements:** Use chained indexing with compile-time indices: +``` +x = MATRIX[0][2] # x = 3 +y = DEEP[1][0][1] # y = 5 +``` + +**Using `len()` on inner arrays:** The `len()` function can be applied to any level of a nested const array, including inner arrays accessed by index. This is particularly useful for iterating over ragged arrays where each row has a different length: + +``` +len(MATRIX) # 3 +len(MATRIX[0]) # 3 +len(DEEP[0][0]) # 2 +``` + +**Important:** When using `len()` on an inner array with a variable index (e.g., `len(ARR[i])`), the index must be a compile-time constant. This works inside `unroll` loops because the loop variable becomes a compile-time constant during unrolling. + +**Example: Iterating over a ragged 2D array:** +``` +MATRIX = [[1, 2, 3], [4, 5], [6, 7, 8, 9]] + +def main(): + total: Mut = 0 + for row in unroll(0, len(MATRIX)): + for col in unroll(0, len(MATRIX[row])): + total = total + MATRIX[row][col] + assert total == 45 # 1+2+3+4+5+6+7+8+9 + return +``` + +## Functions + +``` +def add(a, b): # return count is inferred from return statements + return a + b + +def swap(a, b): # multiple return values + return b, a + +def main(): + x, y = swap(1, 2) + return +``` + +The number of return values is automatically inferred from the `return` statements. All return statements in a function must return the same number of values. + +### Parameter Modifiers + +| Syntax | Meaning | +|--------|---------| +| `x` | immutable parameter | +| `x: Const` | compile-time value (enables `unroll` with dynamic bounds) | +| `x: Mut` | mutable within function body only | + +**All parameters are pass-by-value.** The `: Mut` modifier allows reassignment within the function, but changes are not visible to the caller. Use return values to communicate results. + +``` +def repeat(n: Const): # Const enables unroll + sum: Mut = 0 + for i in unroll(0, n): + sum = sum + i + return sum + +def double(x: Mut): # Mut allows local reassignment + x = x * 2 # only affects local copy + return x # must return to pass result back +``` + +### Inline Functions +Use the `@inline` decorator to mark functions for inlining at call sites: +``` +@inline +def square(x): + return x * x +``` +**Note:** Inline functions cannot have `: Mut` parameters. + +## Variables + +| Declaration | Mutability | Notes | +|-------------|------------|-------| +| `x = 10` | immutable | cannot be reassigned | +| `x: Mut = 10` | mutable | can be reassigned | +| `x: Imu` | immutable | forward declaration, assign exactly once later | +| `x: Mut` | mutable | forward declaration for mutable variable | + +### Forward Declarations + +Use `x: Imu` when a variable must be assigned in different branches: + +``` +result: Imu # immutable: assign exactly once +if cond == 1: + result = 10 +else: + result = 20 +# result cannot be reassigned after this +``` + +Use `x: Mut` when you need the variable to be mutable after assignment: + +``` +x: Mut +if cond == 1: + x = 10 +else: + x = 20 +x = x + 1 # OK: x was declared as mutable +``` + +### Tuple Assignments with Mutable Variables + +When a function returns multiple values and some need to be mutable, use forward declarations: + +``` +b: Mut # declare b as mutable +a, b, c = some_function() +# a and c are immutable, b is mutable +b = b + 1 # OK +# a = 5 # ERROR: a is immutable +``` + +This is useful when a function returns multiple values and only some need to be modified later. + +## Memory and Arrays + +``` +buffer = Array(16) # allocate 16 field elements +buffer[0] = 42 +x = buffer[5] + +matrix = Array(64) # 2D via manual indexing +matrix[row * 8 + col] = value + +ptr2 = ptr + 5 # pointer arithmetic +ptr2[0] = 100 # same as ptr[5] = 100 +``` + +**Memory is write-once.** Due to SSA constraints, each memory location can only hold one value. Writing to the same location multiple times is allowed, but all writes must produce the same value—otherwise a runtime error occurs. + +``` +arr = Array(3) +arr[0] = 10 # OK: first write +arr[0] = 10 # OK: same value +arr[0] = 20 # ERROR: different value at same location +``` + +Use `mut` variables when you need mutability, the compiler cannot handle mutability on hand-written allocated memory ("Array(...)"). + +## DynArray (Compile-Time Dynamic Arrays) + +DynArrays are compile-time constructs for building dynamic arrays. Unlike `Array`, DynArrays track structure at compile time—each element gets its own memory slot. + +``` +v = DynArray([1, 2, 3]) # create dynamic array +v.push(4) # append element +v.pop() # remove last element (does not return it) +x = v[2] # access (index must be compile-time constant) +n = len(v) # get length +``` + +### Nested DynArrays + +``` +matrix = DynArray([DynArray([1, 2]), DynArray([3, 4, 5])]) +matrix[1].push(6) # push to inner array +matrix[0].pop() # pop from inner array +x = matrix[0][0] # x = 1 +n = len(matrix[1]) # n = 4 +``` + +### Building DynArrays in Loops + +Use `unroll` loops to build arrays dynamically: + +``` +v = DynArray([]) +for i in unroll(0, 5): + v.push(i * i) # v = [0, 1, 4, 9, 16] +``` + +### Restrictions + +DynArrays are compile-time only. The compiler must know the exact structure at every point: + +1. **Indices must be compile-time constants** (literals or unroll loop variables) +2. **Push/pop to outer-scope arrays forbidden** inside `if/else`, `match`, or non-unrolled loops +3. **DynArrays cannot be passed to non-inlined functions** +4. **Pop on empty array is a compile error** + +``` +# OK: local array in branch +if cond == 1: + v = DynArray([1, 2]) + v.push(3) + +# ERROR: push to outer-scope array in branch +v = DynArray([1, 2]) +if cond == 1: + v.push(3) # compile error + +# OK: same variable name in different branches +if cond == 1: + v = DynArray([1]) +else: + v = DynArray([2, 3]) # different structure, but only one executes +``` + +## Control Flow + +### If/Else +``` +if x == 0: + y = 1 +else if x == 1: + y = 2 +else: + y = 3 +``` +Comparison operators: `==`, `!=` + +### Match +Patterns must be consecutive integers starting from 0: +``` +match value: + case 0: + result = 100 + case 1: + result = 200 + case 2: + result = 300 +``` + +### For Loops +``` +for i in range(0, 10): # standard loop + ... +for i in unroll(0, 4): # unrolled at compile time + ... +``` +Use `unroll` when bounds are const or compile-time expansion is needed. + +**Mutable variables in non-unrolled loops:** Mutable variables can be modified inside non-unrolled loops. The compiler automatically transforms these into buffer-based implementations: + +``` +sum: Mut = 0 +for i in range(1, 11): + sum += i +assert sum == 55 +``` + +Loops limitations: +- no "continue" or "break" are supported yet +- the "return" keyword is not supported inside the body of a normal (non-unrolled) loop (because under the hood normal loops are transformed into recursive functions) + +## Expressions + +### Arithmetic +- `+`, `-`, `*`, `/` (field operations): allowed at runtime +- `%` (modulo), `**` (exponentiation): only allowed at compile time + +### Compound Assignment +Syntactic sugar for updating mutable variables: +``` +x: Mut = 10 +x += 5 # equivalent to: x = x + 5 +x -= 3 # equivalent to: x = x - 3 +x *= 2 # equivalent to: x = x * 2 +x /= 4 # equivalent to: x = x / 4 +``` + +### Built-in Functions +Only allowed at compile time: + +``` +log2_ceil(x) # ceiling of log2 +next_multiple_of(x, n) # smallest multiple of n >= x +saturating_sub(a, b) # max(0, a - b) +len(array) # length of const array or vector +``` + +## Assertions + +``` +# constraint in proof +assert x == y +assert x != y +# unconditional failure (panic) +assert False +assert False, "error message" +# runtime check only (not constrained by the snark) +debug_assert(x == y) +debug_assert(x != y) +debug_assert(x < y) +``` + +## Comments + +``` +# Single-line comment + +""" +Multi-line comment +can span multiple lines +""" +``` + +## Imports + +``` +from utils import * # imports utils.py (relative to import root) +from dir.subdir.file import * # imports dir/subdir/file.py +``` + +## Built-in Constants + +``` +NONRESERVED_PROGRAM_INPUT_START # pointer to public input +ZERO_VEC_PTR # pre-initialized zeros +ONE_VEC_PTR # [1, 0, 0, ...] +``` + +## Precompiles + +### poseidon16 +``` +COMPRESSION = 1 # (output: 8 elements) (For now this is not a real permutation in the cryptographic sense, see Plonky3 PseudoCompression trait, but it will change in the future) +PERMUTATION = 0 # full permutation (output: 16 elements) + +poseidon16(left, right, output, mode) +``` +- `left`, `right`: pointers to 8 field elements each +- `output`: pointer to result (8 or 16 elements depending on mode) +- Used for Merkle tree hashing and Fiat-Shamir: +``` +poseidon16(leaf_a, leaf_b, parent_hash, COMPRESSION) +poseidon16(state, data, new_state, PERMUTATION) +``` + +### dot_product +``` +DIM = 5 # extension field degree +BE = 1 # base × extension +EE = 0 # extension × extension + +dot_product(a, b, result, length, mode) +``` +- `length`: number of elements in the dot product +- `b`: pointer to `length * DIM` field elements (extension elements) +- `result`: pointer to output (DIM=5 field elements) +- `mode`: + - `EE`: `a` points to `length * DIM` field elements (extension field elements) + - `BE`: `a` points to `length` field elements (base field elements) +``` +# Multiply two extension field elements (EE mode, length=1) +dot_product(x, y, z, 1, EE) # z = x * y + +# Copy extension element (multiply by [1,0,0,0,0]) +dot_product(src, ONE_VEC_PTR, dst, 1, EE) +``` + +## Debugging + +``` +print(value) +print(a, b, c) +``` + +## Example + +``` +SIZE = 8 + +def main(): + arr = Array(SIZE) + for i in unroll(0, SIZE): + arr[i] = i * i + sum = compute_sum(arr, SIZE) + assert sum == 140 + return + +def compute_sum(ptr, n: Const): + acc: Mut = 0 + for i in unroll(0, n): + acc = acc + ptr[i] + return acc +``` + +## Line Continuation + +Like Python, lines can be continued in two ways: + +### Implicit continuation (inside parentheses/brackets/braces) + +Expressions inside `()`, `[]`, or `{}` can span multiple lines without any special syntax: + +``` +result = function_call( + arg1, + arg2, + arg3 +) + +arr = DynArray([ + 1, + 2, + 3 +]) +``` + +### Explicit continuation with backslash + +Long lines can also be split using `\` at the end of a line: + +``` +x = very_long_function_name(arg1, \ + arg2, \ + arg3) + +y = 1 + 2 + \ + 3 + 4 +``` + +The `\` and following newline are replaced with a single space. Any whitespace after `\` and before the newline is ignored. + +## Tips + +1. Use `unroll` for small, fixed-size loops +2. Use `const` parameters when loop bounds depend on arguments +3. Use `mut` sparingly - immutable is easier to verify +4. Use `x: Imu` or `x: Mut` for forward-declaring variables that will be assigned in branches +5. Match patterns must be consecutive from 0 and exhaustive + +## Example: From high level syntactic sugar to minimal ISA, with read-only memory + +Take the following program: + +``` +def main(): + x: Mut = 0 + y: Mut = 3 + x += y + y += x + for i in range(4, 6): + x += i + x += y + y = i + y += x + assert x == 35 + assert y == 40 + return +``` + +First, we use buffers to handle mutable variables across (non-unrolled) loops. + +``` +def main(): + x: Mut = 0 + y: Mut = 3 + x += y + y += x + size = 6 - 4 + x_buff = Array(size + 1) + x_buff[0] = x + y_buff = Array(size + 1) + y_buff[0] = y + for i in range(4, 6): + buff_idx = i - 4 + x_body: Mut = x_buff[buff_idx] + y_body: Mut = y_buff[buff_idx] + x_body += i + x_body += y_body + y_body = i + y_body += x_body + next_idx = buff_idx + 1 + x_buff[next_idx] = x_body + y_buff[next_idx] = y_body + x = x_buff[size] + y = y_buff[size] + assert x == 35 + assert y == 40 + return +``` + +Then, use auxiliary variables to transform it into SSA form (Static Single-Assignment): + + +``` +def main(): + x = 0 + y = 3 + x2 = x + y + y2 = y + x2 + size = 6 - 4 + x_buff = Array(size + 1) + x_buff[0] = x2 + y_buff = Array(size + 1) + y_buff[0] = y2 + for i in range(4, 6): + buff_idx = i - 4 + x_body1 = x_buff[buff_idx] + y_body1 = y_buff[buff_idx] + x_body2 = x_body1 + i + x_body3 = x_body2 + y_body1 + y_body2 = i + y_body3 = y_body2 + x_body3 + next_idx = buff_idx + 1 + x_buff[next_idx] = x_body3 + y_buff[next_idx] = y_body3 + x3 = x_buff[size] + y3 = y_buff[size] + assert x3 == 35 + assert y3 == 40 + return +``` + +Finally, transform the loop into a recursive function: + +``` +def main(): + x = 0 + y = 3 + x2 = x + y + y2 = y + x2 + size = 6 - 4 + x_buff = Array(size + 1) + x_buff[0] = x2 + y_buff = Array(size + 1) + y_buff[0] = y2 + loop(4, x_buff, y_buff) + x3 = x_buff[size] + y3 = y_buff[size] + assert x3 == 35 + assert y3 == 40 + return + +def loop(i, x_buff, y_buff): + if i == 6: + return + else: + buff_idx = i - 4 + x_body1 = x_buff[buff_idx] + y_body1 = y_buff[buff_idx] + x_body2 = x_body1 + i + x_body3 = x_body2 + y_body1 + y_body2 = i + y_body3 = y_body2 + x_body3 + next_idx = buff_idx + 1 + x_buff[next_idx] = x_body3 + y_buff[next_idx] = y_body3 + loop(i + 1, x_buff, y_buff) + return +``` + diff --git a/crates/lean_prover/Cargo.toml b/crates/lean_prover/Cargo.toml index d43b4368..ed4593e7 100644 --- a/crates/lean_prover/Cargo.toml +++ b/crates/lean_prover/Cargo.toml @@ -14,20 +14,16 @@ xmss.workspace = true rand.workspace = true p3-poseidon2.workspace = true p3-koala-bear.workspace = true -p3-challenger.workspace = true -p3-air.workspace = true p3-symmetric.workspace = true p3-util.workspace = true whir-p3.workspace = true tracing.workspace = true air.workspace = true sub_protocols.workspace = true -lookup.workspace = true lean_vm.workspace = true lean_compiler.workspace = true witness_generation.workspace = true multilinear-toolkit.workspace = true -poseidon_circuit.workspace = true itertools.workspace = true [dev-dependencies] diff --git a/crates/lean_prover/src/common.rs b/crates/lean_prover/src/common.rs index e5cef782..b196fde7 100644 --- a/crates/lean_prover/src/common.rs +++ b/crates/lean_prover/src/common.rs @@ -1,47 +1,6 @@ -use std::collections::BTreeMap; - -use multilinear_toolkit::prelude::*; -use p3_koala_bear::{KOALABEAR_RC16_INTERNAL, KOALABEAR_RC24_INTERNAL}; -use poseidon_circuit::{PoseidonGKRLayers, default_cube_layers}; -use sub_protocols::ColDims; - use crate::*; use lean_vm::*; - -pub(crate) const N_COMMITED_CUBES_P16: usize = KOALABEAR_RC16_INTERNAL.len() - 2; -pub(crate) const N_COMMITED_CUBES_P24: usize = KOALABEAR_RC24_INTERNAL.len() - 2; - -pub(crate) fn get_base_dims( - log_public_memory: usize, - private_memory_len: usize, - (p16_gkr_layers, p24_gkr_layers): ( - &PoseidonGKRLayers<16, N_COMMITED_CUBES_P16>, - &PoseidonGKRLayers<24, N_COMMITED_CUBES_P24>, - ), - table_heights: &BTreeMap, -) -> Vec> { - let p16_default_cubes = default_cube_layers::(p16_gkr_layers); - let p24_default_cubes = default_cube_layers::(p24_gkr_layers); - - let mut dims = [ - vec![ - ColDims::padded_with_public_data(Some(log_public_memory), private_memory_len, F::ZERO), // memory - ], - p16_default_cubes - .iter() - .map(|&c| ColDims::padded(table_heights[&Table::poseidon16_core()].n_rows_non_padded_maxed(), c)) - .collect::>(), // commited cubes for poseidon16 - p24_default_cubes - .iter() - .map(|&c| ColDims::padded(table_heights[&Table::poseidon24_core()].n_rows_non_padded_maxed(), c)) - .collect::>(), - ] - .concat(); - for (table, height) in table_heights { - dims.extend(table.committed_dims(height.n_rows_non_padded_maxed())); - } - dims -} +use multilinear_toolkit::prelude::*; pub(crate) fn fold_bytecode(bytecode: &Bytecode, folding_challenges: &MultilinearPoint) -> Vec { let encoded_bytecode = padd_with_zero_to_next_power_of_two( @@ -54,12 +13,6 @@ pub(crate) fn fold_bytecode(bytecode: &Bytecode, folding_challenges: &Multilinea fold_multilinear_chunks(&encoded_bytecode, folding_challenges) } -pub(crate) fn initial_and_final_pc_conditions(log_n_cycles: usize) -> (Evaluation, Evaluation) { - let initial_pc_statement = Evaluation::new(EF::zero_vec(log_n_cycles), EF::from_usize(STARTING_PC)); - let final_pc_statement = Evaluation::new(vec![EF::ONE; log_n_cycles], EF::from_usize(ENDING_PC)); - (initial_pc_statement, final_pc_statement) -} - fn split_at(stmt: &MultiEvaluation, start: usize, end: usize) -> Vec> { vec![MultiEvaluation::new( stmt.point.clone(), diff --git a/crates/lean_prover/src/lib.rs b/crates/lean_prover/src/lib.rs index 9de8c003..e14d05bb 100644 --- a/crates/lean_prover/src/lib.rs +++ b/crates/lean_prover/src/lib.rs @@ -9,24 +9,52 @@ use witness_generation::*; mod common; pub mod prove_execution; +#[cfg(test)] +mod test_zkvm; pub mod verify_execution; -const UNIVARIATE_SKIPS: usize = 3; -const TWO_POW_UNIVARIATE_SKIPS: usize = 1 << UNIVARIATE_SKIPS; +pub use witness_generation::bytecode_to_multilinear_polynomial; -pub const LOG_SMALLEST_DECOMPOSITION_CHUNK: usize = 12; // TODO optimize +// Right now, hash digests = 8 koala-bear (p = 2^31 - 2^24 + 1, i.e. ≈ 30.98 bits per field element) +// so ≈ 123.92 bits of security against collisions +pub const SECURITY_BITS: usize = 123; // TODO 128 bits security (with Poseidon over 20 field elements) -const DOT_PRODUCT_UNIVARIATE_SKIPS: usize = 1; -const TWO_POW_DOT_PRODUCT_UNIVARIATE_SKIPS: usize = 1 << DOT_PRODUCT_UNIVARIATE_SKIPS; +// Provable security (no proximity gaps conjectures) +pub const SECURITY_REGIME: SecurityAssumption = SecurityAssumption::JohnsonBound; -pub fn whir_config_builder() -> WhirConfigBuilder { +pub const GRINDING_BITS: usize = 16; + +pub const STARTING_LOG_INV_RATE_BASE: usize = 2; + +pub const STARTING_LOG_INV_RATE_EXTENSION: usize = 3; + +#[derive(Debug)] +pub struct SnarkParams { + pub first_whir: WhirConfigBuilder, + pub second_whir: WhirConfigBuilder, +} + +impl Default for SnarkParams { + fn default() -> Self { + Self { + first_whir: whir_config_builder(STARTING_LOG_INV_RATE_BASE, 7, 5), + second_whir: whir_config_builder(STARTING_LOG_INV_RATE_EXTENSION, 4, 1), + } + } +} + +pub fn whir_config_builder( + starting_log_inv_rate: usize, + first_folding_factor: usize, + rs_domain_initial_reduction_factor: usize, +) -> WhirConfigBuilder { WhirConfigBuilder { - folding_factor: FoldingFactor::new(7, 4), - soundness_type: SecurityAssumption::CapacityBound, - pow_bits: 16, + folding_factor: FoldingFactor::new(first_folding_factor, 4), + soundness_type: SECURITY_REGIME, + pow_bits: GRINDING_BITS, max_num_variables_to_send_coeffs: 6, - rs_domain_initial_reduction_factor: 5, - security_level: 128, - starting_log_inv_rate: 1, + rs_domain_initial_reduction_factor, + security_level: SECURITY_BITS, + starting_log_inv_rate, } } diff --git a/crates/lean_prover/src/prove_execution.rs b/crates/lean_prover/src/prove_execution.rs index 2483dfc6..12da4c96 100644 --- a/crates/lean_prover/src/prove_execution.rs +++ b/crates/lean_prover/src/prove_execution.rs @@ -3,44 +3,44 @@ use std::collections::BTreeMap; use crate::common::*; use crate::*; use air::prove_air; -use itertools::Itertools; use lean_vm::*; -use lookup::{compute_pushforward, prove_gkr_quotient, prove_logup_star}; use multilinear_toolkit::prelude::*; -use p3_air::Air; -use p3_util::{log2_ceil_usize, log2_strict_usize}; -use poseidon_circuit::{PoseidonGKRLayers, prove_poseidon_gkr}; -use std::collections::VecDeque; + +use p3_util::log2_ceil_usize; use sub_protocols::*; use tracing::info_span; use utils::{build_prover_state, padd_with_zero_to_next_power_of_two}; -use whir_p3::{WhirConfig, WhirConfigBuilder, second_batched_whir_config_builder}; -use xmss::{Poseidon16History, Poseidon24History}; +use whir_p3::{SparseStatement, SparseValue, WhirConfig}; +use xmss::Poseidon16History; + +#[derive(Debug)] +pub struct ExecutionProof { + pub proof: Vec, + pub proof_size_fe: usize, + pub exec_summary: String, + pub first_whir_n_vars: usize, +} pub fn prove_execution( bytecode: &Bytecode, (public_input, private_input): (&[F], &[F]), - whir_config_builder: WhirConfigBuilder, - no_vec_runtime_memory: usize, // size of the "non-vectorized" runtime memory + poseidons_16_precomputed: &Poseidon16History, + params: &SnarkParams, vm_profiler: bool, - (poseidons_16_precomputed, poseidons_24_precomputed): (&Poseidon16History, &Poseidon24History), - merkle_path_hints: VecDeque>, -) -> (Proof, String) { +) -> ExecutionProof { let mut exec_summary = String::new(); let ExecutionTrace { traces, public_memory_size, - mut non_zero_memory_size, - mut memory, // padded with zeros to next power of two + non_zero_memory_size: _, // TODO use the information of the ending zeros for speedup + mut memory, // padded with zeros to next power of two } = info_span!("Witness generation").in_scope(|| { let mut execution_result = info_span!("Executing bytecode").in_scope(|| { execute_bytecode( bytecode, (public_input, private_input), - no_vec_runtime_memory, vm_profiler, - (poseidons_16_precomputed, poseidons_24_precomputed), - merkle_path_hints, + poseidons_16_precomputed, ) }); exec_summary = std::mem::take(&mut execution_result.summary); @@ -49,17 +49,13 @@ pub fn prove_execution( if memory.len() < 1 << MIN_LOG_MEMORY_SIZE { memory.resize(1 << MIN_LOG_MEMORY_SIZE, F::ZERO); - non_zero_memory_size = 1 << MIN_LOG_MEMORY_SIZE; } - let public_memory = &memory[..public_memory_size]; - let private_memory = &memory[public_memory_size..non_zero_memory_size]; - let log_public_memory = log2_strict_usize(public_memory.len()); - let mut prover_state = build_prover_state::(false); + let mut prover_state = build_prover_state(); prover_state.add_base_scalars( &[ - vec![private_memory.len()], - traces.values().map(|t| t.n_rows_non_padded()).collect::>(), + vec![log2_strict_usize(memory.len())], + traces.values().map(|t| t.log_n_rows).collect::>(), ] .concat() .into_iter() @@ -67,545 +63,274 @@ pub fn prove_execution( .collect::>(), ); - // only keep tables with non-zero rows - let traces: BTreeMap<_, _> = traces - .into_iter() - .filter(|(table, trace)| trace.n_rows_non_padded() > 0 || table == &Table::execution() || table.is_poseidon()) - .collect(); - - let p16_gkr_layers = PoseidonGKRLayers::<16, N_COMMITED_CUBES_P16>::build(Some(VECTOR_LEN)); - let p24_gkr_layers = PoseidonGKRLayers::<24, N_COMMITED_CUBES_P24>::build(None); - - let p16_witness = generate_poseidon_witness_helper( - &p16_gkr_layers, - &traces[&Table::poseidon16_core()], - POSEIDON_16_CORE_COL_INPUT_START, - Some(&traces[&Table::poseidon16_core()].base[POSEIDON_16_CORE_COL_COMPRESSION].clone()), - ); - let p24_witness = generate_poseidon_witness_helper( - &p24_gkr_layers, - &traces[&Table::poseidon24_core()], - POSEIDON_24_CORE_COL_INPUT_START, - None, - ); - - let commitmenent_extension_helper = traces - .iter() - .filter(|(table, _)| table.n_commited_columns_ef() > 0) - .map(|(table, trace)| { - ( - *table, - ExtensionCommitmentFromBaseProver::before_commitment( - table - .commited_columns_ef() - .iter() - .map(|&c| &trace.ext[c][..]) - .collect::>(), - ), - ) - }) - .collect::>(); - - let base_dims = get_base_dims( - log_public_memory, - private_memory.len(), - (&p16_gkr_layers, &p24_gkr_layers), - &traces.iter().map(|(table, trace)| (*table, trace.height)).collect(), - ); - - let mut base_pols = [ - vec![memory.as_slice()], - p16_witness - .committed_cubes - .iter() - .map(|s| FPacking::::unpack_slice(s)) - .collect::>(), - p24_witness - .committed_cubes - .iter() - .map(|s| FPacking::::unpack_slice(s)) - .collect::>(), - ] - .concat(); - for (table, trace) in &traces { - base_pols.extend(table.committed_columns(trace, commitmenent_extension_helper.get(table))); - } + // TODO parrallelize + let mut acc = F::zero_vec(memory.len()); + info_span!("Building memory access count").in_scope(|| { + for (table, trace) in &traces { + for lookup in table.lookups_f() { + for i in &trace.base[lookup.index] { + for j in 0..lookup.values.len() { + acc[i.to_usize() + j] += F::ONE; + } + } + } + for lookup in table.lookups_ef() { + for i in &trace.base[lookup.index] { + for j in 0..DIMENSION { + acc[i.to_usize() + j] += F::ONE; + } + } + } + } + }); // 1st Commitment - let packed_pcs_witness_base = packed_pcs_commit( - &whir_config_builder, - &base_pols, - &base_dims, - &mut prover_state, - LOG_SMALLEST_DECOMPOSITION_CHUNK, - ); - - let bus_challenge = prover_state.sample(); - let fingerprint_challenge = prover_state.sample(); - - let mut bus_quotients: BTreeMap = Default::default(); - let mut air_points: BTreeMap> = Default::default(); - let mut evals_f: BTreeMap> = Default::default(); - let mut evals_ef: BTreeMap> = Default::default(); - - for (table, trace) in &traces { - let (this_bus_quotient, this_air_point, this_evals_f, this_evals_ef) = - prove_bus_and_air(&mut prover_state, table, trace, bus_challenge, fingerprint_challenge); - bus_quotients.insert(*table, this_bus_quotient); - air_points.insert(*table, this_air_point); - evals_f.insert(*table, this_evals_f); - evals_ef.insert(*table, this_evals_ef); + let packed_pcs_witness_base = packed_pcs_commit(&mut prover_state, ¶ms.first_whir, &memory, &acc, &traces); + let first_whir_n_vars = packed_pcs_witness_base.packed_polynomial.by_ref().n_vars(); + + // logup (GKR) + let logup_c = prover_state.sample(); + prover_state.duplexing(); + let logup_alpha = prover_state.sample(); + prover_state.duplexing(); + + let logup_statements = prove_generic_logup(&mut prover_state, logup_c, logup_alpha, &memory, &acc, &traces); + let mut committed_statements: CommittedStatements = Default::default(); + for table in ALL_TABLES { + committed_statements.insert( + table, + vec![( + logup_statements.points[&table].clone(), + logup_statements.columns_values[&table].clone(), + )], + ); } - assert_eq!(bus_quotients.values().copied().sum::(), EF::ZERO); + let bus_beta = prover_state.sample(); + prover_state.duplexing(); + let air_alpha = prover_state.sample(); + let air_alpha_powers: Vec = air_alpha.powers().collect_n(max_air_constraints() + 1); + + for (table, trace) in traces.iter() { + let this_air_claims = prove_bus_and_air( + &mut prover_state, + table, + trace, + logup_c, + logup_alpha, + bus_beta, + air_alpha_powers.clone(), + &logup_statements.points[table], + logup_statements.bus_numerators_values[table], + logup_statements.bus_denominators_values[table], + ); + committed_statements.get_mut(table).unwrap().extend(this_air_claims); + } let bytecode_compression_challenges = MultilinearPoint(prover_state.sample_vec(log2_ceil_usize(N_INSTRUCTION_COLUMNS))); let folded_bytecode = fold_bytecode(bytecode, &bytecode_compression_challenges); - let bytecode_lookup_claim_1 = Evaluation::new( - air_points[&Table::execution()].clone(), - padd_with_zero_to_next_power_of_two(&evals_f[&Table::execution()][..N_INSTRUCTION_COLUMNS]) - .evaluate(&bytecode_compression_challenges), + let bytecode_air_entry = &mut committed_statements.get_mut(&Table::execution()).unwrap()[2]; + let bytecode_air_point = bytecode_air_entry.0.clone(); + let mut bytecode_air_values = vec![]; + for bytecode_col_index in N_COMMITTED_EXEC_COLUMNS..N_COMMITTED_EXEC_COLUMNS + N_INSTRUCTION_COLUMNS { + bytecode_air_values.push(bytecode_air_entry.1.remove(&bytecode_col_index).unwrap()); + } + + let bytecode_lookup_claim = Evaluation::new( + bytecode_air_point.clone(), + padd_with_zero_to_next_power_of_two(&bytecode_air_values).evaluate(&bytecode_compression_challenges), ); - let bytecode_poly_eq_point = eval_eq(&air_points[&Table::execution()]); - let bytecode_pushforward = compute_pushforward( - &traces[&Table::execution()].base[COL_INDEX_PC], + let bytecode_poly_eq_point = eval_eq(&bytecode_lookup_claim.point); + let bytecode_pushforward = MleOwned::Extension(compute_pushforward( + &traces[&Table::execution()].base[COL_PC], folded_bytecode.len(), &bytecode_poly_eq_point, - ); + )); - let normal_lookup_into_memory = NormalPackedLookupProver::step_1( - &mut prover_state, - &memory, - traces - .iter() - .flat_map(|(table, trace)| table.normal_lookup_index_columns_f(trace)) - .collect(), - traces - .iter() - .flat_map(|(table, trace)| table.normal_lookup_index_columns_ef(trace)) - .collect(), - traces - .iter() - .flat_map(|(table, trace)| vec![trace.n_rows_non_padded_maxed(); table.num_normal_lookups_f()]) - .collect(), - traces - .iter() - .flat_map(|(table, trace)| vec![trace.n_rows_non_padded_maxed(); table.num_normal_lookups_ef()]) - .collect(), - traces - .keys() - .flat_map(|table| table.normal_lookup_default_indexes_f()) - .collect(), - traces - .keys() - .flat_map(|table| table.normal_lookup_default_indexes_ef()) - .collect(), - traces - .iter() - .flat_map(|(table, trace)| table.normal_lookup_f_value_columns(trace)) - .collect(), - traces - .iter() - .flat_map(|(table, trace)| table.normal_lookup_ef_value_columns(trace)) - .collect(), - traces - .keys() - .flat_map(|table| table.normal_lookups_statements_f(&air_points[table], &evals_f[table])) - .collect(), - traces - .keys() - .flat_map(|table| table.normal_lookups_statements_ef(&air_points[table], &evals_ef[table])) - .collect(), - LOG_SMALLEST_DECOMPOSITION_CHUNK, - ); - - let vectorized_lookup_into_memory = VectorizedPackedLookupProver::<_, VECTOR_LEN>::step_1( - &mut prover_state, - &memory, - traces - .iter() - .flat_map(|(table, trace)| table.vector_lookup_index_columns(trace)) - .collect(), - traces - .iter() - .flat_map(|(table, trace)| vec![trace.n_rows_non_padded_maxed(); table.num_vector_lookups()]) - .collect(), - traces - .keys() - .flat_map(|table| table.vector_lookup_default_indexes()) - .collect(), - traces - .iter() - .flat_map(|(table, trace)| table.vector_lookup_values_columns(trace)) - .collect(), - traces - .keys() - .flat_map(|table| table.vectorized_lookups_statements(&air_points[table], &evals_f[table])) - .collect(), - LOG_SMALLEST_DECOMPOSITION_CHUNK, - ); - - // 2nd Commitment - let extension_pols = vec![ - normal_lookup_into_memory.pushforward_to_commit(), - vectorized_lookup_into_memory.pushforward_to_commit(), - bytecode_pushforward.as_slice(), - ]; - - let extension_dims = vec![ - ColDims::padded(non_zero_memory_size, EF::ZERO), // memory - ColDims::padded(non_zero_memory_size.div_ceil(VECTOR_LEN), EF::ZERO), // memory (folded) - ColDims::padded(bytecode.instructions.len(), EF::ZERO), // bytecode - ]; - - let packed_pcs_witness_extension = packed_pcs_commit( - &second_batched_whir_config_builder( - whir_config_builder.clone(), - packed_pcs_witness_base.packed_polynomial.by_ref().n_vars(), - num_packed_vars_for_dims::(&extension_dims, LOG_SMALLEST_DECOMPOSITION_CHUNK), - ), - &extension_pols, - &extension_dims, - &mut prover_state, - LOG_SMALLEST_DECOMPOSITION_CHUNK, - ); - - let mut normal_lookup_statements = normal_lookup_into_memory.step_2(&mut prover_state, non_zero_memory_size); - - let vectorized_lookup_statements = vectorized_lookup_into_memory.step_2(&mut prover_state, non_zero_memory_size); + let bytecode_pushforward_commitment = + WhirConfig::new(¶ms.second_whir, log2_ceil_usize(bytecode.instructions.len())) + .commit(&mut prover_state, &bytecode_pushforward); let bytecode_logup_star_statements = prove_logup_star( &mut prover_state, &MleRef::Extension(&folded_bytecode), - &traces[&Table::execution()].base[COL_INDEX_PC], - bytecode_lookup_claim_1.value, + &traces[&Table::execution()].base[COL_PC], + bytecode_lookup_claim.value, &bytecode_poly_eq_point, - &bytecode_pushforward, + &bytecode_pushforward.by_ref(), Some(bytecode.instructions.len()), ); - let memory_statements = vec![ - normal_lookup_statements.on_table.clone(), - vectorized_lookup_statements.on_table.clone(), + committed_statements.get_mut(&Table::execution()).unwrap().push(( + bytecode_logup_star_statements.on_indexes.point.clone(), + BTreeMap::from_iter([(COL_PC, bytecode_logup_star_statements.on_indexes.value)]), + )); + + let public_memory_random_point = MultilinearPoint(prover_state.sample_vec(log2_strict_usize(public_memory_size))); + prover_state.duplexing(); + let public_memory_eval = (&memory[..public_memory_size]).evaluate(&public_memory_random_point); + + let memory_acc_statements = vec![ + SparseStatement::new( + packed_pcs_witness_base.packed_n_vars, + logup_statements.memory_acc_point, + vec![ + SparseValue::new(0, logup_statements.value_memory), + SparseValue::new(1, logup_statements.value_acc), + ], + ), + SparseStatement::new( + packed_pcs_witness_base.packed_n_vars, + public_memory_random_point, + vec![SparseValue::new(0, public_memory_eval)], + ), ]; - let mut final_statements: BTreeMap>>> = Default::default(); - for table in traces.keys() { - final_statements.insert( - *table, - table.committed_statements_prover( - &mut prover_state, - &air_points[table], - &evals_f[table], - commitmenent_extension_helper.get(table), - &mut normal_lookup_statements.on_indexes_f, - &mut normal_lookup_statements.on_indexes_ef, - ), - ); - } - assert!(normal_lookup_statements.on_indexes_f.is_empty()); - assert!(normal_lookup_statements.on_indexes_ef.is_empty()); - - let p16_gkr = prove_poseidon_gkr( - &mut prover_state, - &p16_witness, - air_points[&Table::poseidon16_core()].0.clone(), - UNIVARIATE_SKIPS, - &p16_gkr_layers, - ); - assert_eq!(&p16_gkr.output_statements.point, &air_points[&Table::poseidon16_core()]); - assert_eq!( - &p16_gkr.output_statements.values, - &evals_f[&Table::poseidon16_core()][POSEIDON_16_CORE_COL_OUTPUT_START..][..16] - ); - - let p24_gkr = prove_poseidon_gkr( - &mut prover_state, - &p24_witness, - air_points[&Table::poseidon24_core()].0.clone(), - UNIVARIATE_SKIPS, - &p24_gkr_layers, - ); - assert_eq!(&p24_gkr.output_statements.point, &air_points[&Table::poseidon24_core()]); - assert_eq!( - &p24_gkr.output_statements.values[16..], - &evals_f[&Table::poseidon24_core()][POSEIDON_24_CORE_COL_OUTPUT_START..][..8] - ); - - { - let mut cursor = 0; - for table in traces.keys() { - for (statement, lookup) in vectorized_lookup_statements.on_indexes[cursor..] - .iter() - .zip(table.vector_lookups()) - { - final_statements.get_mut(table).unwrap()[lookup.index].extend(statement.clone()); - } - cursor += table.num_vector_lookups(); - } - } - - let (initial_pc_statement, final_pc_statement) = - initial_and_final_pc_conditions(traces[&Table::execution()].log_padded()); - - final_statements.get_mut(&Table::execution()).unwrap()[ExecutionTable.find_committed_column_index_f(COL_INDEX_PC)] - .extend(vec![ - bytecode_logup_star_statements.on_indexes.clone(), - initial_pc_statement, - final_pc_statement, - ]); - let statements_p16_core = final_statements.get_mut(&Table::poseidon16_core()).unwrap(); - for (stmts, gkr_value) in statements_p16_core[POSEIDON_16_CORE_COL_INPUT_START..][..16] - .iter_mut() - .zip(&p16_gkr.input_statements.values) - { - stmts.push(Evaluation::new(p16_gkr.input_statements.point.clone(), *gkr_value)); - } - statements_p16_core[POSEIDON_16_CORE_COL_COMPRESSION].push(p16_gkr.on_compression_selector.unwrap()); - - let statements_p24_core = final_statements.get_mut(&Table::poseidon24_core()).unwrap(); - for (stmts, gkr_value) in statements_p24_core[POSEIDON_24_CORE_COL_INPUT_START..][..24] - .iter_mut() - .zip(&p24_gkr.input_statements.values) - { - stmts.push(Evaluation::new(p24_gkr.input_statements.point.clone(), *gkr_value)); - } - - // First Opening - let mut all_base_statements = [ - vec![memory_statements], - encapsulate_vec(p16_gkr.cubes_statements.split()), - encapsulate_vec(p24_gkr.cubes_statements.split()), - ] - .concat(); - all_base_statements.extend(final_statements.into_values().flatten()); - - let global_statements_base = packed_pcs_global_statements_for_prover( - &base_pols, - &base_dims, - LOG_SMALLEST_DECOMPOSITION_CHUNK, - &all_base_statements, - &mut prover_state, - ); - - // Second Opening - let global_statements_extension = packed_pcs_global_statements_for_prover( - &extension_pols, - &extension_dims, - LOG_SMALLEST_DECOMPOSITION_CHUNK, - &[ - normal_lookup_statements.on_pushforward, - vectorized_lookup_statements.on_pushforward, - bytecode_logup_star_statements.on_pushforward, - ], - &mut prover_state, + let table_heights = traces.iter().map(|(table, trace)| (*table, trace.log_n_rows)).collect(); + let global_statements_base = packed_pcs_global_statements( + packed_pcs_witness_base.packed_n_vars, + log2_strict_usize(memory.len()), + memory_acc_statements, + &table_heights, + &committed_statements, ); WhirConfig::new( - whir_config_builder, + ¶ms.first_whir, packed_pcs_witness_base.packed_polynomial.by_ref().n_vars(), ) - .batch_prove( + .prove( &mut prover_state, global_statements_base, packed_pcs_witness_base.inner_witness, &packed_pcs_witness_base.packed_polynomial.by_ref(), - global_statements_extension, - packed_pcs_witness_extension.inner_witness, - &packed_pcs_witness_extension.packed_polynomial.by_ref(), ); - (prover_state.into_proof(), exec_summary) + WhirConfig::new(¶ms.second_whir, log2_ceil_usize(bytecode.instructions.len())).prove( + &mut prover_state, + bytecode_logup_star_statements + .on_pushforward + .into_iter() + .map(|smt| SparseStatement::dense(smt.point, smt.value)) + .collect::>(), + bytecode_pushforward_commitment, + &bytecode_pushforward.by_ref(), + ); + let proof_size_fe = prover_state.proof_size_fe(); + ExecutionProof { + proof: prover_state.into_proof(), + proof_size_fe, + exec_summary, + first_whir_n_vars, + } } +#[allow(clippy::too_many_arguments)] +#[allow(clippy::type_complexity)] fn prove_bus_and_air( - prover_state: &mut multilinear_toolkit::prelude::FSProver>, - t: &Table, + prover_state: &mut impl FSProver, + table: &Table, trace: &TableTrace, - bus_challenge: EF, - fingerprint_challenge: EF, -) -> (EF, MultilinearPoint, Vec, Vec) { - let n_buses = t.buses().len(); - let n_buses_padded = n_buses.next_power_of_two(); - let log_n_buses = log2_ceil_usize(n_buses); - let n_rows = trace.n_rows_padded(); - let log_n_rows = trace.log_padded(); - - assert!(n_buses > 0, "Table {} has no buses", t.name()); - - let mut numerators = F::zero_vec(n_buses_padded * n_rows); - for (bus, numerators_chunk) in t.buses().iter().zip(numerators.chunks_mut(n_rows)) { - match bus.selector { - BusSelector::Column(selector_col) => { - assert!(selector_col < trace.base.len()); - trace.base[selector_col] - .par_iter() - .zip(numerators_chunk) - .for_each(|(&selector, v)| { - *v = match bus.direction { - BusDirection::Pull => -selector, - BusDirection::Push => selector, - } - }); - } - BusSelector::ConstantOne => { - numerators_chunk.par_iter_mut().for_each(|v| { - *v = match bus.direction { - BusDirection::Pull => F::NEG_ONE, - BusDirection::Push => F::ONE, - } - }); - } + logup_c: EF, + logup_alpha: EF, + bus_beta: EF, + air_alpha_powers: Vec, + bus_point: &MultilinearPoint, + bus_numerator_value: EF, + bus_denominator_value: EF, +) -> Vec<(MultilinearPoint, BTreeMap)> { + let bus_final_value = bus_numerator_value + * match table.bus().direction { + BusDirection::Pull => EF::NEG_ONE, + BusDirection::Push => EF::ONE, } - } - - let mut denominators = unsafe { uninitialized_vec(n_buses_padded * n_rows) }; - for (bus, denomniators_chunk) in t.buses().iter().zip(denominators.chunks_exact_mut(n_rows)) { - denomniators_chunk.par_iter_mut().enumerate().for_each(|(i, v)| { - *v = bus_challenge - + finger_print( - match &bus.table { - BusTable::Constant(table) => table.embed(), - BusTable::Variable(col) => trace.base[*col][i], - }, - bus.data - .iter() - .map(|col| trace.base[*col][i]) - .collect::>() - .as_slice(), - fingerprint_challenge, - ); - }); - } - denominators[n_rows * n_buses..] - .par_iter_mut() - .for_each(|v| *v = EF::ONE); - - // TODO avoid embedding !! - let numerators_embedded = numerators.par_iter().copied().map(EF::from).collect::>(); - - // TODO avoid reallocation due to packing (pack directly when constructing) - let numerators_packed = pack_extension(&numerators_embedded); - let denominators_packed = pack_extension(&denominators); - let (mut quotient, bus_point_global, numerator_value_global, denominator_value_global) = - prove_gkr_quotient::<_, TWO_POW_UNIVARIATE_SKIPS>( - prover_state, - &MleGroupRef::ExtensionPacked(vec![&numerators_packed, &denominators_packed]), - ); + + bus_beta * (bus_denominator_value - logup_c); - let (bus_point, bus_selector_values, bus_data_values) = if n_buses == 1 { - // easy case - ( - bus_point_global, - vec![numerator_value_global], - vec![denominator_value_global], - ) - } else { - let uni_selectors = univariate_selectors::(UNIVARIATE_SKIPS); - - let sub_numerators_evals = numerators - .par_chunks_exact(1 << (log_n_rows - UNIVARIATE_SKIPS)) - .take(n_buses << UNIVARIATE_SKIPS) - .map(|chunk| chunk.evaluate(&MultilinearPoint(bus_point_global[1 + log_n_buses..].to_vec()))) - .collect::>(); - prover_state.add_extension_scalars(&sub_numerators_evals); - // sanity check: - assert_eq!( - numerator_value_global, - evaluate_univariate_multilinear::<_, _, _, false>( - &padd_with_zero_to_next_power_of_two(&sub_numerators_evals), - &bus_point_global[..1 + log_n_buses], - &uni_selectors, - None - ), - ); - - let sub_denominators_evals = denominators - .par_chunks_exact(1 << (log_n_rows - UNIVARIATE_SKIPS)) - .take(n_buses << UNIVARIATE_SKIPS) - .map(|chunk| chunk.evaluate(&MultilinearPoint(bus_point_global[1 + log_n_buses..].to_vec()))) - .collect::>(); - prover_state.add_extension_scalars(&sub_denominators_evals); - // sanity check: - assert_eq!( - denominator_value_global, - evaluate_univariate_multilinear::<_, _, _, false>( - &padd_to_next_power_of_two(&sub_denominators_evals, EF::ONE), - &bus_point_global[..1 + log_n_buses], - &uni_selectors, - None - ), - ); - - let epsilon = prover_state.sample(); - let bus_point = MultilinearPoint([vec![epsilon], bus_point_global[1 + log_n_buses..].to_vec()].concat()); - - let bus_selector_values = sub_numerators_evals - .chunks_exact(1 << UNIVARIATE_SKIPS) - .map(|chunk| evaluate_univariate_multilinear::<_, _, _, false>(chunk, &[epsilon], &uni_selectors, None)) - .collect(); - let bus_data_values = sub_denominators_evals - .chunks_exact(1 << UNIVARIATE_SKIPS) - .map(|chunk| evaluate_univariate_multilinear::<_, _, _, false>(chunk, &[epsilon], &uni_selectors, None)) - .collect(); - - (bus_point, bus_selector_values, bus_data_values) - }; - - let bus_beta = prover_state.sample(); - - let bus_final_values = bus_selector_values - .iter() - .zip_eq(&bus_data_values) - .zip_eq(&t.buses()) - .map(|((&bus_selector_value, &bus_data_value), bus)| { - bus_selector_value - * match bus.direction { - BusDirection::Pull => EF::NEG_ONE, - BusDirection::Push => EF::ONE, - } - + bus_beta * (bus_data_value - bus_challenge) - }) - .collect::>(); - - let bus_virtual_statement = MultiEvaluation::new(bus_point, bus_final_values); - - for bus in t.buses() { - quotient -= bus.padding_contribution(t, trace.padding_len(), bus_challenge, fingerprint_challenge); - } + let bus_virtual_statement = Evaluation::new(bus_point.clone(), bus_final_value); let extra_data = ExtraDataForBuses { - fingerprint_challenge_powers: fingerprint_challenge.powers().collect_n(max_bus_width()), - fingerprint_challenge_powers_packed: EFPacking::::from(fingerprint_challenge) - .powers() - .collect_n(max_bus_width()), + logup_alpha_powers: logup_alpha.powers().collect_n(max_bus_width()), + logup_alpha_powers_packed: EFPacking::::from(logup_alpha).powers().collect_n(max_bus_width()), bus_beta, bus_beta_packed: EFPacking::::from(bus_beta), - alpha_powers: vec![], // filled later + alpha_powers: air_alpha_powers, }; - let (air_point, evals_f, evals_ef) = info_span!("Table AIR proof", table = t.name()).in_scope(|| { + let air_claims = info_span!("AIR proof", table = table.name()).in_scope(|| { macro_rules! prove_air_for_table { ($t:expr) => { prove_air( prover_state, $t, extra_data, - UNIVARIATE_SKIPS, + 1, &trace.base[..$t.n_columns_f_air()], &trace.ext[..$t.n_columns_ef_air()], - &$t.air_padding_row_f(), - &$t.air_padding_row_ef(), Some(bus_virtual_statement), $t.n_columns_air() + $t.total_n_down_columns_air() > 5, // heuristic ) }; } - delegate_to_inner!(t => prove_air_for_table) + delegate_to_inner!(table => prove_air_for_table) }); - (quotient, air_point, evals_f, evals_ef) + let mut res = vec![]; + if let Some(down_point) = air_claims.down_point { + assert_eq!(air_claims.evals_f_on_down_columns.len(), table.n_down_columns_f()); + let mut down_evals = BTreeMap::new(); + for (value_f, col_index) in air_claims + .evals_f_on_down_columns + .iter() + .zip(table.down_column_indexes_f()) + { + down_evals.insert(col_index, *value_f); + } + + assert_eq!(air_claims.evals_ef_on_down_columns.len(), table.n_down_columns_ef()); + for (col_index, value) in table + .down_column_indexes_ef() + .into_iter() + .zip(air_claims.evals_ef_on_down_columns) + { + let transposed = transpose_slice_to_basis_coefficients::(&trace.ext[col_index]) + .iter() + .map(|base_col| base_col.evaluate(&down_point)) + .collect::>(); + assert_eq!(dot_product_with_base(&transposed), value); // sanity check + prover_state.add_extension_scalars(&transposed); + for (j, v) in transposed.iter().enumerate() { + let virtual_index = table.n_columns_f_air() + col_index * DIMENSION + j; + down_evals.insert(virtual_index, *v); + } + } + res.push((down_point, down_evals)); + } + + assert_eq!(air_claims.evals_f.len(), table.n_columns_f_air()); + assert_eq!(air_claims.evals_ef.len(), table.n_columns_ef_air()); + let mut evals = air_claims + .evals_f + .iter() + .copied() + .enumerate() + .collect::>(); + for (col_index, (value, col)) in air_claims.evals_ef.into_iter().zip(&trace.ext).enumerate() { + let transposed = transpose_slice_to_basis_coefficients::(col) + .iter() + .map(|base_col| base_col.evaluate(&air_claims.point)) + .collect::>(); + prover_state.add_extension_scalars(&transposed); + assert_eq!(dot_product_with_base(&transposed), value); // sanity check + for (j, v) in transposed.into_iter().enumerate() { + let virtual_index = table.n_columns_f_air() + col_index * DIMENSION + j; + evals.insert(virtual_index, v); + } + } + + res.push((air_claims.point.clone(), evals)); + + res } diff --git a/crates/lean_prover/src/test_zkvm.rs b/crates/lean_prover/src/test_zkvm.rs new file mode 100644 index 00000000..7e8971ae --- /dev/null +++ b/crates/lean_prover/src/test_zkvm.rs @@ -0,0 +1,184 @@ +use crate::{prove_execution::prove_execution, verify_execution::verify_execution}; +use lean_compiler::*; +use lean_vm::*; +use multilinear_toolkit::prelude::*; +use rand::{Rng, SeedableRng, rngs::StdRng}; +use utils::poseidon16_permute; + +#[test] +fn test_zk_vm_all_precompiles() { + test_zk_vm_all_precompiles_helper(false); +} + +#[test] +#[ignore] // slow test +fn test_zk_vm_fuzzing() { + test_zk_vm_all_precompiles_helper(true); +} + +fn test_zk_vm_all_precompiles_helper(fuzzing: bool) { + let program_str = r#" +DIM = 5 +COMPRESSION = 1 +PERMUTATION = 0 +N = 11 +VECTOR_LEN = 8 + +# Dot product precompile: +BE = 1 # base-extension +EE = 0 # extension-extension + +def main(): + pub_start = NONRESERVED_PROGRAM_INPUT_START + poseidon16(pub_start, pub_start + VECTOR_LEN, pub_start + 2 * VECTOR_LEN, PERMUTATION) + poseidon16(pub_start + 4 * VECTOR_LEN, pub_start + 5 * VECTOR_LEN, pub_start + 6 * VECTOR_LEN, COMPRESSION) + dot_product(pub_start + 88, pub_start + 88 + N, pub_start + 1000, N, BE) + dot_product(pub_start + 88 + N, pub_start + 88 + N * (DIM + 1), pub_start + 1000 + DIM, N, EE) + c: Mut = 0 + for i in range(0,100): + c += 1 + assert c == 100 + + return + +"#; + + const N: usize = 11; + + let mut rng = StdRng::seed_from_u64(0); + let mut public_input = F::zero_vec(1 << 13); + + let poseidon_16_perm_input: [F; 16] = rng.random(); + public_input[..16].copy_from_slice(&poseidon_16_perm_input); + public_input[16..32].copy_from_slice(&poseidon16_permute(poseidon_16_perm_input)); + + let poseidon_16_compress_input: [F; 16] = rng.random(); + public_input[32..48].copy_from_slice(&poseidon_16_compress_input); + public_input[48..56].copy_from_slice(&poseidon16_permute(poseidon_16_compress_input)[..8]); + + let poseidon_24_input: [F; 24] = rng.random(); + public_input[56..80].copy_from_slice(&poseidon_24_input); + + let dot_product_slice_base: [F; N] = rng.random(); + let dot_product_slice_ext_a: [EF; N] = rng.random(); + let dot_product_slice_ext_b: [EF; N] = rng.random(); + + public_input[88..][..N].copy_from_slice(&dot_product_slice_base); + public_input[88 + N..][..N * DIMENSION].copy_from_slice( + &dot_product_slice_ext_a + .iter() + .flat_map(|&x| x.as_basis_coefficients_slice().to_vec()) + .collect::>(), + ); + public_input[88 + N + N * DIMENSION..][..N * DIMENSION].copy_from_slice( + &dot_product_slice_ext_b + .iter() + .flat_map(|&x| x.as_basis_coefficients_slice().to_vec()) + .collect::>(), + ); + let dot_product_base_ext: EF = dot_product(dot_product_slice_ext_a.into_iter(), dot_product_slice_base.into_iter()); + let dot_product_ext_ext: EF = dot_product(dot_product_slice_ext_a.into_iter(), dot_product_slice_ext_b.into_iter()); + + public_input[1000..][..DIMENSION].copy_from_slice(dot_product_base_ext.as_basis_coefficients_slice()); + public_input[1000 + DIMENSION..][..DIMENSION].copy_from_slice(dot_product_ext_ext.as_basis_coefficients_slice()); + + let slice_a: [F; 3] = rng.random(); + let slice_b: [EF; 3] = rng.random(); + let poly_eq = MultilinearPoint(slice_b.to_vec()) + .eq_poly_outside(&MultilinearPoint(slice_a.iter().map(|&x| EF::from(x)).collect())); + public_input[1100..][..3].copy_from_slice(&slice_a); + public_input[1100 + 3..][..3 * DIMENSION].copy_from_slice( + slice_b + .iter() + .flat_map(|&x| x.as_basis_coefficients_slice().to_vec()) + .collect::>() + .as_slice(), + ); + public_input[1100 + 3 + 3 * DIMENSION..][..DIMENSION].copy_from_slice(poly_eq.as_basis_coefficients_slice()); + + test_zk_vm_helper(program_str, (&public_input, &[]), fuzzing); +} + +#[test] +fn test_small_memory() { + let program_str = r#" +def main(): + a = Array(1) + for i in unroll(0, 2**17): + a[0] = 1 * 2 + return +"#; + + test_zk_vm_helper(program_str, (&[], &[]), false); +} + +#[test] +fn test_prove_fibonacci() { + let n = std::env::var("FIB_N") + .unwrap_or("10000".to_string()) + .parse::() + .unwrap(); + let program_str = r#" +N = FIB_N_PLACEHOLDER +STEPS = 10000 # N should be a multiple of STEPS +N_STEPS = N / STEPS + +def main(): + x, y = fibonacci_step(0, 1, N_STEPS) + print(x) + return + +def fibonacci_step(a, b, steps_remaining): + if steps_remaining == 0: + return a, b + new_a, new_b = fibonacci_const(a, b, STEPS) + res_a, res_b = fibonacci_step(new_a, new_b, steps_remaining - 1) + return res_a, res_b + +def fibonacci_const(a, b, n: Const): + buff = Array(n + 2) + buff[0] = a + buff[1] = b + for j in unroll(2, n + 2): + buff[j] = buff[j - 1] + buff[j - 2] + return buff[n], buff[n + 1] +"#; + let program_str = program_str.replace("FIB_N_PLACEHOLDER", &n.to_string()); + + test_zk_vm_helper(&program_str, (&[F::ZERO; 1 << 14], &[]), false); +} + +fn test_zk_vm_helper(program_str: &str, (public_input, private_input): (&[F], &[F]), fuzzing: bool) { + if !fuzzing { + utils::init_tracing(); + } + let bytecode = compile_program(&ProgramSource::Raw(program_str.to_string())); + let time = std::time::Instant::now(); + let proof = prove_execution( + &bytecode, + (public_input, private_input), + &vec![], + &Default::default(), + false, + ); + let proof_time = time.elapsed(); + verify_execution(&bytecode, public_input, proof.proof.clone(), &Default::default()).unwrap(); + println!("{}", proof.exec_summary); + println!("Proof time: {:.3} s", proof_time.as_secs_f32()); + + if fuzzing { + println!("Starting fuzzing..."); + let mut percent = 0; + for i in 0..proof.proof.len() { + let new_percent = i * 100 / proof.proof.len(); + if new_percent != percent { + percent = new_percent; + println!("{}%", percent); + } + let mut fuzzed_proof = proof.proof.clone(); + fuzzed_proof[i] += F::ONE; + let verify_result = verify_execution(&bytecode, public_input, fuzzed_proof, &Default::default()); + assert!(verify_result.is_err(), "Fuzzing failed at index {}", i); + } + } +} diff --git a/crates/lean_prover/src/verify_execution.rs b/crates/lean_prover/src/verify_execution.rs index a5be0708..e5a0ac17 100644 --- a/crates/lean_prover/src/verify_execution.rs +++ b/crates/lean_prover/src/verify_execution.rs @@ -1,185 +1,130 @@ use std::collections::BTreeMap; -use crate::common::*; use crate::*; +use crate::{SnarkParams, common::*}; use air::verify_air; -use itertools::Itertools; use lean_vm::*; -use lookup::verify_gkr_quotient; -use lookup::verify_logup_star; use multilinear_toolkit::prelude::*; use p3_util::{log2_ceil_usize, log2_strict_usize}; -use poseidon_circuit::PoseidonGKRLayers; -use poseidon_circuit::verify_poseidon_gkr; +use sub_protocols::verify_logup_star; use sub_protocols::*; use utils::ToUsize; -use utils::build_challenger; -use whir_p3::WhirConfig; -use whir_p3::WhirConfigBuilder; -use whir_p3::second_batched_whir_config_builder; +use whir_p3::{SparseStatement, SparseValue, WhirConfig}; + +#[derive(Debug, Clone)] +pub struct ProofVerificationDetails { + pub log_memory: usize, + pub table_n_vars: BTreeMap, + pub first_quotient_gkr_n_vars: usize, + pub total_whir_statements_base: usize, +} pub fn verify_execution( bytecode: &Bytecode, public_input: &[F], - proof: Proof, - whir_config_builder: WhirConfigBuilder, -) -> Result<(), ProofError> { - let mut verifier_state = VerifierState::new(proof, build_challenger()); - - let p16_gkr_layers = PoseidonGKRLayers::<16, N_COMMITED_CUBES_P16>::build(Some(VECTOR_LEN)); - let p24_gkr_layers = PoseidonGKRLayers::<24, N_COMMITED_CUBES_P24>::build(None); + proof: Vec, + params: &SnarkParams, +) -> Result { + let mut verifier_state = VerifierState::::new(proof, get_poseidon16().clone()); + verifier_state.duplexing(); let dims = verifier_state .next_base_scalars_vec(1 + N_TABLES)? .into_iter() .map(|x| x.to_usize()) .collect::>(); - let private_memory_len = dims[0]; - let table_heights: BTreeMap = (0..N_TABLES) - .map(|i| (ALL_TABLES[i], TableHeight(dims[i + 1]))) - .collect(); - - // only keep tables with non-zero rows - let table_heights: BTreeMap<_, _> = table_heights - .into_iter() - .filter(|(table, height)| height.n_rows_non_padded() > 0 || table == &Table::execution() || table.is_poseidon()) - .collect(); + let log_memory = dims[0]; + let table_n_vars: BTreeMap = (0..N_TABLES).map(|i| (ALL_TABLES[i], dims[i + 1])).collect(); + for (table, &n_vars) in &table_n_vars { + if n_vars < MIN_LOG_N_ROWS_PER_TABLE { + return Err(ProofError::InvalidProof); + } + if n_vars + > MAX_LOG_N_ROWS_PER_TABLE + .iter() + .find(|(t, _)| t == table) + .map(|(_, m)| *m) + .unwrap() + { + return Err(ProofError::InvalidProof); + } + } + // check memory is bigger than any other table + if log_memory < *table_n_vars.values().max().unwrap() { + return Err(ProofError::InvalidProof); + } let public_memory = build_public_memory(public_input); - let log_public_memory = log2_strict_usize(public_memory.len()); - let log_memory = log2_ceil_usize(public_memory.len() + private_memory_len); - if !(MIN_LOG_MEMORY_SIZE..=MAX_LOG_MEMORY_SIZE).contains(&log_memory) { return Err(ProofError::InvalidProof); } - let base_dims = get_base_dims( - log_public_memory, - private_memory_len, - (&p16_gkr_layers, &p24_gkr_layers), - &table_heights, - ); - let parsed_commitment_base = packed_pcs_parse_commitment( - &whir_config_builder, - &mut verifier_state, - &base_dims, - LOG_SMALLEST_DECOMPOSITION_CHUNK, - )?; + let parsed_commitment_base = + packed_pcs_parse_commitment(¶ms.first_whir, &mut verifier_state, log_memory, &table_n_vars)?; - let bus_challenge = verifier_state.sample(); - let fingerprint_challenge = verifier_state.sample(); + let logup_c = verifier_state.sample(); + verifier_state.duplexing(); + let logup_alpha = verifier_state.sample(); + verifier_state.duplexing(); - let mut bus_quotients: BTreeMap = Default::default(); - let mut air_points: BTreeMap> = Default::default(); - let mut evals_f: BTreeMap> = Default::default(); - let mut evals_ef: BTreeMap> = Default::default(); + let logup_statements = verify_generic_logup(&mut verifier_state, logup_c, logup_alpha, log_memory, &table_n_vars)?; + let mut committed_statements: CommittedStatements = Default::default(); + for table in ALL_TABLES { + committed_statements.insert( + table, + vec![( + logup_statements.points[&table].clone(), + logup_statements.columns_values[&table].clone(), + )], + ); + } - for (table, height) in &table_heights { - let (this_bus_quotient, this_air_point, this_evals_f, this_evals_ef) = verify_bus_and_air( + let bus_beta = verifier_state.sample(); + verifier_state.duplexing(); + let air_alpha = verifier_state.sample(); + let air_alpha_powers: Vec = air_alpha.powers().collect_n(max_air_constraints() + 1); + + for (table, log_n_rows) in &table_n_vars { + let this_air_claims = verify_bus_and_air( &mut verifier_state, table, - *height, - bus_challenge, - fingerprint_challenge, + *log_n_rows, + logup_c, + logup_alpha, + bus_beta, + air_alpha_powers.clone(), + &logup_statements.points[table], + logup_statements.bus_numerators_values[table], + logup_statements.bus_denominators_values[table], )?; - bus_quotients.insert(*table, this_bus_quotient); - air_points.insert(*table, this_air_point); - evals_f.insert(*table, this_evals_f); - evals_ef.insert(*table, this_evals_ef); - } - - if bus_quotients.values().copied().sum::() != EF::ZERO { - return Err(ProofError::InvalidProof); + committed_statements.get_mut(table).unwrap().extend(this_air_claims); } let bytecode_compression_challenges = MultilinearPoint(verifier_state.sample_vec(log2_ceil_usize(N_INSTRUCTION_COLUMNS))); - let bytecode_lookup_claim_1 = Evaluation::new( - air_points[&Table::execution()].clone(), - padd_with_zero_to_next_power_of_two(&evals_f[&Table::execution()][..N_INSTRUCTION_COLUMNS]) - .evaluate(&bytecode_compression_challenges), - ); - - let normal_lookup_into_memory = NormalPackedLookupVerifier::step_1( - &mut verifier_state, - table_heights - .iter() - .flat_map(|(table, height)| vec![height.n_rows_non_padded_maxed(); table.num_normal_lookups_f()]) - .collect(), - table_heights - .iter() - .flat_map(|(table, height)| vec![height.n_rows_non_padded_maxed(); table.num_normal_lookups_ef()]) - .collect(), - table_heights - .keys() - .flat_map(|table| table.normal_lookup_default_indexes_f()) - .collect(), - table_heights - .keys() - .flat_map(|table| table.normal_lookup_default_indexes_ef()) - .collect(), - table_heights - .keys() - .flat_map(|table| table.normal_lookups_statements_f(&air_points[table], &evals_f[table])) - .collect(), - table_heights - .keys() - .flat_map(|table| table.normal_lookups_statements_ef(&air_points[table], &evals_ef[table])) - .collect(), - LOG_SMALLEST_DECOMPOSITION_CHUNK, - &public_memory, // we need to pass the first few values of memory, public memory is enough - )?; - - let vectorized_lookup_into_memory = VectorizedPackedLookupVerifier::<_, VECTOR_LEN>::step_1( - &mut verifier_state, - table_heights - .iter() - .flat_map(|(table, height)| vec![height.n_rows_non_padded_maxed(); table.num_vector_lookups()]) - .collect(), - table_heights - .keys() - .flat_map(|table| table.vector_lookup_default_indexes()) - .collect(), - table_heights - .keys() - .flat_map(|table| table.vectorized_lookups_statements(&air_points[table], &evals_f[table])) - .collect(), - LOG_SMALLEST_DECOMPOSITION_CHUNK, - &public_memory, // we need to pass the first few values of memory, public memory is enough - )?; - - let extension_dims = vec![ - ColDims::padded(public_memory.len() + private_memory_len, EF::ZERO), // memory pushwordard - ColDims::padded( - (public_memory.len() + private_memory_len).div_ceil(VECTOR_LEN), - EF::ZERO, - ), // memory (folded) pushwordard - ColDims::padded(bytecode.instructions.len(), EF::ZERO), // bytecode pushforward - ]; - - let parsed_commitment_extension = packed_pcs_parse_commitment( - &second_batched_whir_config_builder( - whir_config_builder.clone(), - parsed_commitment_base.num_variables, - num_packed_vars_for_dims::(&extension_dims, LOG_SMALLEST_DECOMPOSITION_CHUNK), - ), - &mut verifier_state, - &extension_dims, - LOG_SMALLEST_DECOMPOSITION_CHUNK, - )?; + let bytecode_air_entry = &mut committed_statements.get_mut(&Table::execution()).unwrap()[2]; + let bytecode_air_point = bytecode_air_entry.0.clone(); + let mut bytecode_air_values = vec![]; + for bytecode_col_index in N_COMMITTED_EXEC_COLUMNS..N_COMMITTED_EXEC_COLUMNS + N_INSTRUCTION_COLUMNS { + bytecode_air_values.push(bytecode_air_entry.1.remove(&bytecode_col_index).unwrap()); + } - let mut normal_lookup_statements = normal_lookup_into_memory.step_2(&mut verifier_state, log_memory)?; + let bytecode_lookup_claim = Evaluation::new( + bytecode_air_point.clone(), + padd_with_zero_to_next_power_of_two(&bytecode_air_values).evaluate(&bytecode_compression_challenges), + ); - let vectorized_lookup_statements = vectorized_lookup_into_memory.step_2(&mut verifier_state, log_memory)?; + let bytecode_pushforward_parsed_commitment = + WhirConfig::new(¶ms.second_whir, log2_ceil_usize(bytecode.instructions.len())) + .parse_commitment::(&mut verifier_state)?; let bytecode_logup_star_statements = verify_logup_star( &mut verifier_state, log2_ceil_usize(bytecode.instructions.len()), - table_heights[&Table::execution()].log_padded(), - &[bytecode_lookup_claim_1], - EF::ONE, + table_n_vars[&Table::execution()], + bytecode_lookup_claim, )?; let folded_bytecode = fold_bytecode(bytecode, &bytecode_compression_challenges); if folded_bytecode.evaluate(&bytecode_logup_star_statements.on_table.point) @@ -187,245 +132,162 @@ pub fn verify_execution( { return Err(ProofError::InvalidProof); } - let memory_statements = vec![ - normal_lookup_statements.on_table.clone(), - vectorized_lookup_statements.on_table.clone(), - ]; - - let mut final_statements: BTreeMap> = Default::default(); - for table in table_heights.keys() { - final_statements.insert( - *table, - table.committed_statements_verifier( - &mut verifier_state, - &air_points[table], - &evals_f[table], - &evals_ef[table], - &mut normal_lookup_statements.on_indexes_f, - &mut normal_lookup_statements.on_indexes_ef, - )?, - ); - } - assert!(normal_lookup_statements.on_indexes_f.is_empty()); - assert!(normal_lookup_statements.on_indexes_ef.is_empty()); - - let p16_gkr = verify_poseidon_gkr( - &mut verifier_state, - table_heights[&Table::poseidon16_core()].log_padded(), - &air_points[&Table::poseidon16_core()].0, - &p16_gkr_layers, - UNIVARIATE_SKIPS, - true, - ); - assert_eq!(&p16_gkr.output_statements.point, &air_points[&Table::poseidon16_core()]); - assert_eq!( - &p16_gkr.output_statements.values, - &evals_f[&Table::poseidon16_core()][POSEIDON_16_CORE_COL_OUTPUT_START..][..16] - ); - - let p24_gkr = verify_poseidon_gkr( - &mut verifier_state, - table_heights[&Table::poseidon24_core()].log_padded(), - &air_points[&Table::poseidon24_core()].0, - &p24_gkr_layers, - UNIVARIATE_SKIPS, - false, - ); - assert_eq!(&p24_gkr.output_statements.point, &air_points[&Table::poseidon24_core()]); - assert_eq!( - &p24_gkr.output_statements.values[16..], - &evals_f[&Table::poseidon24_core()][POSEIDON_24_CORE_COL_OUTPUT_START..][..8] - ); - - { - let mut cursor = 0; - for table in table_heights.keys() { - for (statement, lookup) in vectorized_lookup_statements.on_indexes[cursor..] - .iter() - .zip(table.vector_lookups()) - { - final_statements.get_mut(table).unwrap()[lookup.index].extend(statement.clone()); - } - cursor += table.num_vector_lookups(); - } - } - - let (initial_pc_statement, final_pc_statement) = - initial_and_final_pc_conditions(table_heights[&Table::execution()].log_padded()); - final_statements.get_mut(&Table::execution()).unwrap()[ExecutionTable.find_committed_column_index_f(COL_INDEX_PC)] - .extend(vec![ - bytecode_logup_star_statements.on_indexes.clone(), - initial_pc_statement, - final_pc_statement, - ]); - let statements_p16_core = final_statements.get_mut(&Table::poseidon16_core()).unwrap(); - for (stmts, gkr_value) in statements_p16_core[POSEIDON_16_CORE_COL_INPUT_START..][..16] - .iter_mut() - .zip(&p16_gkr.input_statements.values) - { - stmts.push(Evaluation::new(p16_gkr.input_statements.point.clone(), *gkr_value)); - } - statements_p16_core[POSEIDON_16_CORE_COL_COMPRESSION].push(p16_gkr.on_compression_selector.unwrap()); + committed_statements.get_mut(&Table::execution()).unwrap().push(( + bytecode_logup_star_statements.on_indexes.point.clone(), + BTreeMap::from_iter([(COL_PC, bytecode_logup_star_statements.on_indexes.value)]), + )); - let statements_p24_core = final_statements.get_mut(&Table::poseidon24_core()).unwrap(); - for (stmts, gkr_value) in statements_p24_core[POSEIDON_24_CORE_COL_INPUT_START..][..24] - .iter_mut() - .zip(&p24_gkr.input_statements.values) - { - stmts.push(Evaluation::new(p24_gkr.input_statements.point.clone(), *gkr_value)); - } + let public_memory_random_point = + MultilinearPoint(verifier_state.sample_vec(log2_strict_usize(public_memory.len()))); + verifier_state.duplexing(); + let public_memory_eval = public_memory.evaluate(&public_memory_random_point); - let mut all_base_statements = [ - vec![memory_statements], - encapsulate_vec(p16_gkr.cubes_statements.split()), - encapsulate_vec(p24_gkr.cubes_statements.split()), - ] - .concat(); - all_base_statements.extend(final_statements.into_values().flatten()); - let global_statements_base = packed_pcs_global_statements_for_verifier( - &base_dims, - LOG_SMALLEST_DECOMPOSITION_CHUNK, - &all_base_statements, - &mut verifier_state, - &[(0, public_memory.clone())].into_iter().collect(), - )?; + let memory_acc_statements = vec![ + SparseStatement::new( + parsed_commitment_base.num_variables, + logup_statements.memory_acc_point, + vec![ + SparseValue::new(0, logup_statements.value_memory), + SparseValue::new(1, logup_statements.value_acc), + ], + ), + SparseStatement::new( + parsed_commitment_base.num_variables, + public_memory_random_point, + vec![SparseValue::new(0, public_memory_eval)], + ), + ]; - let global_statements_extension = packed_pcs_global_statements_for_verifier( - &extension_dims, - LOG_SMALLEST_DECOMPOSITION_CHUNK, - &[ - normal_lookup_statements.on_pushforward, - vectorized_lookup_statements.on_pushforward, - bytecode_logup_star_statements.on_pushforward, - ], + let global_statements_base = packed_pcs_global_statements( + parsed_commitment_base.num_variables, + log_memory, + memory_acc_statements, + &table_n_vars, + &committed_statements, + ); + let total_whir_statements_base = global_statements_base.iter().map(|s| s.values.len()).sum(); + WhirConfig::new(¶ms.first_whir, parsed_commitment_base.num_variables).verify( &mut verifier_state, - &Default::default(), + &parsed_commitment_base, + global_statements_base, )?; - WhirConfig::new(whir_config_builder, parsed_commitment_base.num_variables).batch_verify( + WhirConfig::new(¶ms.second_whir, log2_ceil_usize(bytecode.instructions.len())).verify( &mut verifier_state, - &parsed_commitment_base, - global_statements_base, - &parsed_commitment_extension, - global_statements_extension, + &bytecode_pushforward_parsed_commitment, + bytecode_logup_star_statements + .on_pushforward + .into_iter() + .map(|smt| SparseStatement::dense(smt.point, smt.value)) + .collect::>(), )?; - Ok(()) + Ok(ProofVerificationDetails { + log_memory, + table_n_vars, + first_quotient_gkr_n_vars: logup_statements.total_n_vars, + total_whir_statements_base, + }) } +#[allow(clippy::too_many_arguments)] #[allow(clippy::type_complexity)] fn verify_bus_and_air( - verifier_state: &mut VerifierState, EF, impl FSChallenger>, - t: &Table, - table_height: TableHeight, - bus_challenge: EF, - fingerprint_challenge: EF, -) -> ProofResult<(EF, MultilinearPoint, Vec, Vec)> { - let n_buses = t.buses().len(); - let log_n_buses = log2_ceil_usize(n_buses); - let log_n_rows = table_height.log_padded(); - - assert!(n_buses > 0, "Table {} has no buses", t.name()); - - let (mut quotient, bus_point_global, numerator_value_global, denominator_value_global) = - verify_gkr_quotient::<_, TWO_POW_UNIVARIATE_SKIPS>(verifier_state, log_n_rows + log_n_buses)?; - - let (bus_point, bus_selector_values, bus_data_values) = if n_buses == 1 { - // easy case - ( - bus_point_global, - vec![numerator_value_global], - vec![denominator_value_global], - ) - } else { - let uni_selectors = univariate_selectors::(UNIVARIATE_SKIPS); - - let sub_numerators_evals = verifier_state.next_extension_scalars_vec(n_buses << UNIVARIATE_SKIPS)?; - assert_eq!( - numerator_value_global, - evaluate_univariate_multilinear::<_, _, _, false>( - &padd_with_zero_to_next_power_of_two(&sub_numerators_evals), - &bus_point_global[..1 + log_n_buses], - &uni_selectors, - None - ), - ); - - let sub_denominators_evals = verifier_state.next_extension_scalars_vec(n_buses << UNIVARIATE_SKIPS)?; - assert_eq!( - denominator_value_global, - evaluate_univariate_multilinear::<_, _, _, false>( - &padd_to_next_power_of_two(&sub_denominators_evals, EF::ONE), - &bus_point_global[..1 + log_n_buses], - &uni_selectors, - None - ), - ); - let epsilon = verifier_state.sample(); - let bus_point = MultilinearPoint([vec![epsilon], bus_point_global[1 + log_n_buses..].to_vec()].concat()); - - let bus_selector_values = sub_numerators_evals - .chunks_exact(1 << UNIVARIATE_SKIPS) - .map(|chunk| evaluate_univariate_multilinear::<_, _, _, false>(chunk, &[epsilon], &uni_selectors, None)) - .collect(); - let bus_data_values = sub_denominators_evals - .chunks_exact(1 << UNIVARIATE_SKIPS) - .map(|chunk| evaluate_univariate_multilinear::<_, _, _, false>(chunk, &[epsilon], &uni_selectors, None)) - .collect(); - - (bus_point, bus_selector_values, bus_data_values) - }; - - let bus_beta = verifier_state.sample(); - - let bus_final_values = bus_selector_values - .iter() - .zip_eq(&bus_data_values) - .zip_eq(&t.buses()) - .map(|((&bus_selector_value, &bus_data_value), bus)| { - bus_selector_value - * match bus.direction { - BusDirection::Pull => EF::NEG_ONE, - BusDirection::Push => EF::ONE, - } - + bus_beta * (bus_data_value - bus_challenge) - }) - .collect::>(); - - let bus_virtual_statement = MultiEvaluation::new(bus_point, bus_final_values); + verifier_state: &mut impl FSVerifier, + table: &Table, + log_n_nrows: usize, + logup_c: EF, + logup_alpha: EF, + bus_beta: EF, + air_alpha_powers: Vec, + bus_point: &MultilinearPoint, + bus_numerator_value: EF, + bus_denominator_value: EF, +) -> ProofResult, BTreeMap)>> { + let bus_final_value = bus_numerator_value + * match table.bus().direction { + BusDirection::Pull => EF::NEG_ONE, + BusDirection::Push => EF::ONE, + } + + bus_beta * (bus_denominator_value - logup_c); - for bus in t.buses() { - quotient -= bus.padding_contribution(t, table_height.padding_len(), bus_challenge, fingerprint_challenge); - } + let bus_virtual_statement = Evaluation::new(bus_point.clone(), bus_final_value); let extra_data = ExtraDataForBuses { - fingerprint_challenge_powers: fingerprint_challenge.powers().collect_n(max_bus_width()), - fingerprint_challenge_powers_packed: EFPacking::::from(fingerprint_challenge) - .powers() - .collect_n(max_bus_width()), + logup_alpha_powers: logup_alpha.powers().collect_n(max_bus_width()), + logup_alpha_powers_packed: EFPacking::::from(logup_alpha).powers().collect_n(max_bus_width()), bus_beta, bus_beta_packed: EFPacking::::from(bus_beta), - alpha_powers: vec![], // filled later + alpha_powers: air_alpha_powers, }; - let (air_point, evals_f, evals_ef) = { + let air_claims = { macro_rules! verify_air_for_table { ($t:expr) => { verify_air( verifier_state, $t, extra_data, - UNIVARIATE_SKIPS, - log_n_rows, - &t.air_padding_row_f(), - &t.air_padding_row_ef(), + 1, + log_n_nrows, Some(bus_virtual_statement), )? }; } - delegate_to_inner!(t => verify_air_for_table) + delegate_to_inner!(table => verify_air_for_table) }; - Ok((quotient, air_point, evals_f, evals_ef)) + let mut res = vec![]; + if let Some(down_point) = air_claims.down_point { + assert_eq!(air_claims.evals_f_on_down_columns.len(), table.n_down_columns_f()); + let mut down_evals = BTreeMap::new(); + for (value_f, col_index) in air_claims + .evals_f_on_down_columns + .iter() + .zip(table.down_column_indexes_f()) + { + down_evals.insert(col_index, *value_f); + } + + assert_eq!(air_claims.evals_ef_on_down_columns.len(), table.n_down_columns_ef()); + for (col_index, value) in table + .down_column_indexes_ef() + .into_iter() + .zip(air_claims.evals_ef_on_down_columns) + { + let transposed = verifier_state.next_extension_scalars_vec(DIMENSION)?; + if dot_product_with_base(&transposed) != value { + return Err(ProofError::InvalidProof); + } + for (j, v) in transposed.iter().enumerate() { + let virtual_index = table.n_columns_f_air() + col_index * DIMENSION + j; + down_evals.insert(virtual_index, *v); + } + } + res.push((down_point, down_evals)); + } + + assert_eq!(air_claims.evals_f.len(), table.n_columns_f_air()); + assert_eq!(air_claims.evals_ef.len(), table.n_columns_ef_air()); + let mut evals = air_claims + .evals_f + .iter() + .copied() + .enumerate() + .collect::>(); + for (col_index, value) in air_claims.evals_ef.into_iter().enumerate() { + let transposed = verifier_state.next_extension_scalars_vec(DIMENSION)?; + if dot_product_with_base(&transposed) != value { + return Err(ProofError::InvalidProof); + } + for (j, v) in transposed.into_iter().enumerate() { + let virtual_index = table.n_columns_f_air() + col_index * DIMENSION + j; + evals.insert(virtual_index, v); + } + } + + res.push((air_claims.point.clone(), evals)); + + Ok(res) } diff --git a/crates/lean_prover/tests/hash_chain.rs b/crates/lean_prover/tests/hash_chain.rs deleted file mode 100644 index 9a304f2b..00000000 --- a/crates/lean_prover/tests/hash_chain.rs +++ /dev/null @@ -1,85 +0,0 @@ -use lean_compiler::*; -use lean_prover::{prove_execution::prove_execution, verify_execution::verify_execution, whir_config_builder}; -use lean_vm::{F, execute_bytecode}; -use multilinear_toolkit::prelude::*; -use std::time::Instant; -use xmss::iterate_hash; - -#[test] -fn benchmark_poseidon_chain() { - let program_str = r#" - - const LOG_CHAIN_LENGTH = LOG_CHAIN_LENGTH_PLACEHOLDER; - const CHAIN_LENGTH = 2 ** LOG_CHAIN_LENGTH; - const COMPRESSION = 1; - const UNROLLED_STEPS = 2**7; - - fn main() { - - // current implem panics if some precompiles are not used... (TODO) - poseidon_24_null_hash_ptr = 5; - zero = 0; - for i in 0..2**9 { - poseidon24(0, 0, poseidon_24_null_hash_ptr); - dot_product_ee(0, 0, zero, 1); - } - - buff = malloc_vec(CHAIN_LENGTH + 1); - poseidon16(0, 0, buff, COMPRESSION); - - for i in 0..CHAIN_LENGTH / UNROLLED_STEPS { - offset = buff + i * UNROLLED_STEPS; - for j in 0..UNROLLED_STEPS unroll { - poseidon16(offset + j, 0, offset + (j + 1), COMPRESSION); - } - } - - buff_ptr = (buff + (CHAIN_LENGTH-1)) * 8; - public_input = public_input_start; - for i in 0..8 { - assert buff_ptr[i] == public_input[i]; - } - - return; - } - "# - .to_string(); - - const LOG_CHAIN_LENGTH: usize = 17; - const CHAIN_LENGTH: usize = 1 << LOG_CHAIN_LENGTH; - - let program_str = program_str.replace("LOG_CHAIN_LENGTH_PLACEHOLDER", &LOG_CHAIN_LENGTH.to_string()); - - let mut public_input = F::zero_vec(1 << 13); - public_input[0..8].copy_from_slice(&iterate_hash(&Default::default(), CHAIN_LENGTH)); - - let private_input = vec![]; - - utils::init_tracing(); - let bytecode = compile_program(&ProgramSource::Raw(program_str)); - let no_vec_runtime_memory = execute_bytecode( - &bytecode, - (&public_input, &private_input), - 1 << (3 + LOG_CHAIN_LENGTH), - false, - (&vec![], &vec![]), - Default::default(), - ) - .no_vec_runtime_memory; - - let time = Instant::now(); - let proof_data = prove_execution( - &bytecode, - (&public_input, &private_input), - whir_config_builder(), - no_vec_runtime_memory, - false, - (&vec![], &vec![]), // TODO poseidons precomputed - Default::default(), // TODO merkle path hints - ) - .0; - let vm_time = time.elapsed(); - verify_execution(&bytecode, &public_input, proof_data, whir_config_builder()).unwrap(); - - println!("VM proof time: {vm_time:?}"); -} diff --git a/crates/lean_prover/tests/test_zkvm.rs b/crates/lean_prover/tests/test_zkvm.rs deleted file mode 100644 index dbc43f44..00000000 --- a/crates/lean_prover/tests/test_zkvm.rs +++ /dev/null @@ -1,196 +0,0 @@ -use std::collections::VecDeque; - -use lean_compiler::*; -use lean_prover::{prove_execution::prove_execution, verify_execution::verify_execution, whir_config_builder}; -use lean_vm::*; -use multilinear_toolkit::prelude::*; -use rand::{Rng, SeedableRng, rngs::StdRng}; -use utils::{poseidon16_permute, poseidon24_permute}; - -#[test] -fn test_zk_vm_all_precompiles() { - let program_str = r#" - const DIM = 5; - const COMPRESSION = 1; - const PERMUTATION = 0; - const N = 11; - const MERKLE_HEIGHT_1 = 10; - const LEAF_POS_1 = 781; - const MERKLE_HEIGHT_2 = 15; - const LEAF_POS_2 = 178; - fn main() { - pub_start = public_input_start; - pub_start_vec = pub_start / 8; - - poseidon16(pub_start_vec, pub_start_vec + 1, pub_start_vec + 2, PERMUTATION); - poseidon16(pub_start_vec + 4, pub_start_vec + 5, pub_start_vec + 6, COMPRESSION); - poseidon24(pub_start_vec + 7, pub_start_vec + 9, pub_start_vec + 10); - dot_product_be(pub_start + 88, pub_start + 88 + N, pub_start + 1000, N); - dot_product_ee(pub_start + 88 + N, pub_start + 88 + N * (DIM + 1), pub_start + 1000 + DIM, N); - merkle_verify((pub_start + 2000) / 8, LEAF_POS_1, (pub_start + 2000 + 8) / 8, MERKLE_HEIGHT_1); - merkle_verify((pub_start + 2000 + 16) / 8, LEAF_POS_2, (pub_start + 2000 + 24) / 8, MERKLE_HEIGHT_2); - index_res_slice_hash = 10000; - slice_hash(5, 6, index_res_slice_hash, 3); - eq_poly_base_ext(pub_start + 1100, pub_start +1100 + 3, pub_start + 1100 + (DIM + 1) * 3, 3); - - return; - } - "#; - - const N: usize = 11; - - let mut rng = StdRng::seed_from_u64(0); - let mut public_input = F::zero_vec(1 << 13); - - let poseidon_16_perm_input: [F; 16] = rng.random(); - public_input[..16].copy_from_slice(&poseidon_16_perm_input); - public_input[16..32].copy_from_slice(&poseidon16_permute(poseidon_16_perm_input)); - - let poseidon_16_compress_input: [F; 16] = rng.random(); - public_input[32..48].copy_from_slice(&poseidon_16_compress_input); - public_input[48..56].copy_from_slice(&poseidon16_permute(poseidon_16_compress_input)[..8]); - - let poseidon_24_input: [F; 24] = rng.random(); - public_input[56..80].copy_from_slice(&poseidon_24_input); - public_input[80..88].copy_from_slice(&poseidon24_permute(poseidon_24_input)[16..]); - - let dot_product_slice_base: [F; N] = rng.random(); - let dot_product_slice_ext_a: [EF; N] = rng.random(); - let dot_product_slice_ext_b: [EF; N] = rng.random(); - - public_input[88..][..N].copy_from_slice(&dot_product_slice_base); - public_input[88 + N..][..N * DIMENSION].copy_from_slice( - &dot_product_slice_ext_a - .iter() - .flat_map(|&x| x.as_basis_coefficients_slice().to_vec()) - .collect::>(), - ); - public_input[88 + N + N * DIMENSION..][..N * DIMENSION].copy_from_slice( - &dot_product_slice_ext_b - .iter() - .flat_map(|&x| x.as_basis_coefficients_slice().to_vec()) - .collect::>(), - ); - let dot_product_base_ext: EF = dot_product(dot_product_slice_ext_a.into_iter(), dot_product_slice_base.into_iter()); - let dot_product_ext_ext: EF = dot_product(dot_product_slice_ext_a.into_iter(), dot_product_slice_ext_b.into_iter()); - - public_input[1000..][..DIMENSION].copy_from_slice(dot_product_base_ext.as_basis_coefficients_slice()); - public_input[1000 + DIMENSION..][..DIMENSION].copy_from_slice(dot_product_ext_ext.as_basis_coefficients_slice()); - - let slice_a: [F; 3] = rng.random(); - let slice_b: [EF; 3] = rng.random(); - let poly_eq = MultilinearPoint(slice_b.to_vec()) - .eq_poly_outside(&MultilinearPoint(slice_a.iter().map(|&x| EF::from(x)).collect())); - public_input[1100..][..3].copy_from_slice(&slice_a); - public_input[1100 + 3..][..3 * DIMENSION].copy_from_slice( - slice_b - .iter() - .flat_map(|&x| x.as_basis_coefficients_slice().to_vec()) - .collect::>() - .as_slice(), - ); - public_input[1100 + 3 + 3 * DIMENSION..][..DIMENSION].copy_from_slice(poly_eq.as_basis_coefficients_slice()); - - fn add_merkle_path( - rng: &mut StdRng, - public_input: &mut [F], - merkle_height: usize, - leaf_position: usize, - ) -> Vec<[F; 8]> { - let leaf: [F; VECTOR_LEN] = rng.random(); - public_input[..VECTOR_LEN].copy_from_slice(&leaf); - let mut merkle_path = Vec::new(); - let mut current_digest = leaf; - for i in 0..merkle_height { - let sibling: [F; VECTOR_LEN] = rng.random(); - merkle_path.push(sibling); - let (left, right) = if (leaf_position >> i) & 1 == 0 { - (current_digest, sibling) - } else { - (sibling, current_digest) - }; - current_digest = poseidon16_permute([left.to_vec(), right.to_vec()].concat().try_into().unwrap()) - [..VECTOR_LEN] - .try_into() - .unwrap(); - } - let root = current_digest; - public_input[VECTOR_LEN..][..VECTOR_LEN].copy_from_slice(&root); - merkle_path - } - - let merkle_path_1 = add_merkle_path(&mut rng, &mut public_input[2000..], 10, 781); - let merkle_path_2 = add_merkle_path(&mut rng, &mut public_input[2000 + 16..], 15, 178); - - let mut merkle_path_hints = VecDeque::new(); - merkle_path_hints.push_back(merkle_path_1); - merkle_path_hints.push_back(merkle_path_2); - - test_zk_vm_helper(program_str, (&public_input, &[]), 0, merkle_path_hints); -} - -#[test] -fn test_prove_fibonacci() { - let program_str = r#" - const N = FIB_N_PLACEHOLDER; - const STEPS = 10000; // N should be a multiple of STEPS - const N_STEPS = N / STEPS; - - fn main() { - x, y = fibonacci_step(0, 1, N_STEPS); - print(x); - return; - } - - fn fibonacci_step(a, b, steps_remaining) -> 2 { - if steps_remaining == 0 { - return a, b; - } - new_a, new_b = fibonacci_const(a, b, STEPS); - res_a, res_b = fibonacci_step(new_a, new_b, steps_remaining - 1); - return res_a, res_b; - } - - fn fibonacci_const(a, b, const n) -> 2 { - buff = malloc(n + 2); - buff[0] = a; - buff[1] = b; - for j in 2..n + 2 unroll { - buff[j] = buff[j - 1] + buff[j - 2]; - } - return buff[n], buff[n + 1]; - } - "#; - - let n = std::env::var("FIB_N") - .unwrap_or("10000".to_string()) - .parse::() - .unwrap(); - let program_str = program_str.replace("FIB_N_PLACEHOLDER", &n.to_string()); - - test_zk_vm_helper(&program_str, (&[F::ZERO; 1 << 14], &[]), 0, Default::default()); -} - -fn test_zk_vm_helper( - program_str: &str, - (public_input, private_input): (&[F], &[F]), - no_vec_runtime_memory: usize, - merkle_path_hints: VecDeque>, -) { - utils::init_tracing(); - let bytecode = compile_program(&ProgramSource::Raw(program_str.to_string())); - let time = std::time::Instant::now(); - let (proof, summary) = prove_execution( - &bytecode, - (public_input, private_input), - whir_config_builder(), - no_vec_runtime_memory, - false, - (&vec![], &vec![]), - merkle_path_hints, - ); - let proof_time = time.elapsed(); - verify_execution(&bytecode, public_input, proof, whir_config_builder()).unwrap(); - println!("{summary}"); - println!("Proof time: {:.3} s", proof_time.as_secs_f32()); -} diff --git a/crates/lean_prover/witness_generation/Cargo.toml b/crates/lean_prover/witness_generation/Cargo.toml index a5d86ea3..01cfdb45 100644 --- a/crates/lean_prover/witness_generation/Cargo.toml +++ b/crates/lean_prover/witness_generation/Cargo.toml @@ -14,18 +14,13 @@ xmss.workspace = true rand.workspace = true p3-poseidon2.workspace = true p3-koala-bear.workspace = true -p3-challenger.workspace = true -p3-air.workspace = true p3-symmetric.workspace = true p3-util.workspace = true whir-p3.workspace = true tracing.workspace = true air.workspace = true sub_protocols.workspace = true -lookup.workspace = true lean_vm.workspace = true lean_compiler.workspace = true -derive_more.workspace = true multilinear-toolkit.workspace = true -poseidon_circuit.workspace = true p3-monty-31.workspace = true \ No newline at end of file diff --git a/crates/lean_prover/witness_generation/src/dot_product.rs b/crates/lean_prover/witness_generation/src/dot_product.rs deleted file mode 100644 index 38859ac3..00000000 --- a/crates/lean_prover/witness_generation/src/dot_product.rs +++ /dev/null @@ -1,85 +0,0 @@ -use lean_vm::{DIMENSION, EF, WitnessDotProduct}; -use multilinear_toolkit::prelude::*; - -pub fn build_dot_product_columns( - witness: &[WitnessDotProduct], - min_n_rows: usize, -) -> (Vec>, usize) { - let ( - mut flag, - mut len, - mut index_a, - mut index_b, - mut index_res, - mut value_a, - mut value_b, - mut res, - mut computation, - ) = ( - Vec::new(), - Vec::new(), - Vec::new(), - Vec::new(), - Vec::new(), - Vec::new(), - Vec::new(), - Vec::new(), - Vec::new(), - ); - for dot_product in witness { - assert!(dot_product.len > 0); - - // computation - { - computation.extend(EF::zero_vec(dot_product.len)); - let new_size = computation.len(); - computation[new_size - 1] = - dot_product.slice_0[dot_product.len - 1] * dot_product.slice_1[dot_product.len - 1]; - for i in 0..dot_product.len - 1 { - computation[new_size - 2 - i] = computation[new_size - 1 - i] - + dot_product.slice_0[dot_product.len - 2 - i] - * dot_product.slice_1[dot_product.len - 2 - i]; - } - } - - flag.push(EF::ONE); - flag.extend(EF::zero_vec(dot_product.len - 1)); - len.extend(((1..=dot_product.len).rev()).map(EF::from_usize)); - index_a.extend( - (0..dot_product.len).map(|i| EF::from_usize(dot_product.addr_0 + i * DIMENSION)), - ); - index_b.extend( - (0..dot_product.len).map(|i| EF::from_usize(dot_product.addr_1 + i * DIMENSION)), - ); - index_res.extend(vec![EF::from_usize(dot_product.addr_res); dot_product.len]); - value_a.extend(dot_product.slice_0.clone()); - value_b.extend(dot_product.slice_1.clone()); - res.extend(vec![dot_product.res; dot_product.len]); - } - - let padding_len = flag.len().next_power_of_two().max(min_n_rows) - flag.len(); - flag.extend(vec![EF::ONE; padding_len]); - len.extend(vec![EF::ONE; padding_len]); - index_a.extend(EF::zero_vec(padding_len)); - index_b.extend(EF::zero_vec(padding_len)); - index_res.extend(EF::zero_vec(padding_len)); - value_a.extend(EF::zero_vec(padding_len)); - value_b.extend(EF::zero_vec(padding_len)); - res.extend(EF::zero_vec(padding_len)); - computation.extend(EF::zero_vec(padding_len)); - - ( - vec![ - flag, - len, - index_a, - index_b, - index_res, - value_a, - value_b, - res, - computation, - ], - padding_len, - ) -} diff --git a/crates/lean_prover/witness_generation/src/execution_trace.rs b/crates/lean_prover/witness_generation/src/execution_trace.rs index eaf55515..cbedbf18 100644 --- a/crates/lean_prover/witness_generation/src/execution_trace.rs +++ b/crates/lean_prover/witness_generation/src/execution_trace.rs @@ -12,20 +12,9 @@ pub struct ExecutionTrace { pub memory: Vec, // of length a multiple of public_memory_size } -pub fn get_execution_trace(bytecode: &Bytecode, mut execution_result: ExecutionResult) -> ExecutionTrace { +pub fn get_execution_trace(bytecode: &Bytecode, execution_result: ExecutionResult) -> ExecutionTrace { assert_eq!(execution_result.pcs.len(), execution_result.fps.len()); - // padding to make proof work even on small programs (TODO make this more elegant) - let min_cycles = 32 << MIN_LOG_N_ROWS_PER_TABLE; - if execution_result.pcs.len() < min_cycles { - execution_result - .pcs - .resize(min_cycles, *execution_result.pcs.last().unwrap()); - execution_result - .fps - .resize(min_cycles, *execution_result.fps.last().unwrap()); - } - let n_cycles = execution_result.pcs.len(); let memory = &execution_result.memory; let mut main_trace: [Vec; N_EXEC_AIR_COLUMNS + N_TEMPORARY_EXEC_COLUMNS] = @@ -44,64 +33,72 @@ pub fn get_execution_trace(bytecode: &Bytecode, mut execution_result: ExecutionR let field_repr = field_representation(instruction); let mut addr_a = F::ZERO; - if field_repr[COL_INDEX_FLAG_A].is_zero() { + if field_repr[instr_idx(COL_FLAG_A)].is_zero() { // flag_a == 0 - addr_a = F::from_usize(fp) + field_repr[0]; // fp + operand_a + addr_a = F::from_usize(fp) + field_repr[instr_idx(COL_OPERAND_A)]; // fp + operand_a } let value_a = memory.0[addr_a.to_usize()].unwrap(); let mut addr_b = F::ZERO; - if field_repr[COL_INDEX_FLAG_B].is_zero() { + if field_repr[instr_idx(COL_FLAG_B)].is_zero() { // flag_b == 0 - addr_b = F::from_usize(fp) + field_repr[1]; // fp + operand_b + addr_b = F::from_usize(fp) + field_repr[instr_idx(COL_OPERAND_B)]; // fp + operand_b } let value_b = memory.0[addr_b.to_usize()].unwrap(); let mut addr_c = F::ZERO; - if field_repr[COL_INDEX_FLAG_C].is_zero() { + if field_repr[instr_idx(COL_FLAG_C)].is_zero() { // flag_c == 0 - addr_c = F::from_usize(fp) + field_repr[2]; // fp + operand_c + addr_c = F::from_usize(fp) + field_repr[instr_idx(COL_OPERAND_C)]; // fp + operand_c } else if let Instruction::Deref { shift_1, .. } = instruction { let operand_c = F::from_usize(*shift_1); - assert_eq!(field_repr[2], operand_c); // debug purpose + assert_eq!(field_repr[instr_idx(COL_OPERAND_C)], operand_c); // debug purpose addr_c = value_a + operand_c; } let value_c = memory.0[addr_c.to_usize()].unwrap(); for (j, field) in field_repr.iter().enumerate() { - *trace_row[j] = *field; + *trace_row[j + N_COMMITTED_EXEC_COLUMNS] = *field; } - let nu_a = field_repr[COL_INDEX_FLAG_A] * field_repr[COL_INDEX_OPERAND_A] - + (F::ONE - field_repr[COL_INDEX_FLAG_A]) * value_a; - let nu_b = field_repr[COL_INDEX_FLAG_B] * field_repr[COL_INDEX_OPERAND_B] - + (F::ONE - field_repr[COL_INDEX_FLAG_B]) * value_b; - let nu_c = - field_repr[COL_INDEX_FLAG_C] * F::from_usize(fp) + (F::ONE - field_repr[COL_INDEX_FLAG_C]) * value_c; - *trace_row[COL_INDEX_EXEC_NU_A] = nu_a; - *trace_row[COL_INDEX_EXEC_NU_B] = nu_b; - *trace_row[COL_INDEX_EXEC_NU_C] = nu_c; - - *trace_row[COL_INDEX_MEM_VALUE_A] = value_a; - *trace_row[COL_INDEX_MEM_VALUE_B] = value_b; - *trace_row[COL_INDEX_MEM_VALUE_C] = value_c; - *trace_row[COL_INDEX_PC] = F::from_usize(pc); - *trace_row[COL_INDEX_FP] = F::from_usize(fp); - *trace_row[COL_INDEX_MEM_ADDRESS_A] = addr_a; - *trace_row[COL_INDEX_MEM_ADDRESS_B] = addr_b; - *trace_row[COL_INDEX_MEM_ADDRESS_C] = addr_c; + let nu_a = field_repr[instr_idx(COL_FLAG_A)] * field_repr[instr_idx(COL_OPERAND_A)] + + (F::ONE - field_repr[instr_idx(COL_FLAG_A)]) * value_a; + let nu_b = field_repr[instr_idx(COL_FLAG_B)] * field_repr[instr_idx(COL_OPERAND_B)] + + (F::ONE - field_repr[instr_idx(COL_FLAG_B)]) * value_b; + let nu_c = field_repr[instr_idx(COL_FLAG_C)] * F::from_usize(fp) + + (F::ONE - field_repr[instr_idx(COL_FLAG_C)]) * value_c; + *trace_row[COL_EXEC_NU_A] = nu_a; + *trace_row[COL_EXEC_NU_B] = nu_b; + *trace_row[COL_EXEC_NU_C] = nu_c; + + *trace_row[COL_MEM_VALUE_A] = value_a; + *trace_row[COL_MEM_VALUE_B] = value_b; + *trace_row[COL_MEM_VALUE_C] = value_c; + *trace_row[COL_PC] = F::from_usize(pc); + *trace_row[COL_FP] = F::from_usize(fp); + *trace_row[COL_MEM_ADDRESS_A] = addr_a; + *trace_row[COL_MEM_ADDRESS_B] = addr_b; + *trace_row[COL_MEM_ADDRESS_C] = addr_c; }); let mut memory_padded = memory.0.par_iter().map(|&v| v.unwrap_or(F::ZERO)).collect::>(); - memory_padded.resize(memory.0.len().next_power_of_two(), F::ZERO); + // IMPRTANT: memory size should always be >= number of VM cycles + let padded_memory_len = (memory.0.len().max(n_cycles).max(1 << MIN_LOG_N_ROWS_PER_TABLE)).next_power_of_two(); + memory_padded.resize(padded_memory_len, F::ZERO); let ExecutionResult { mut traces, .. } = execution_result; + let poseidon_trace = traces.get_mut(&Table::poseidon16()).unwrap(); + fill_trace_poseidon_16(&mut poseidon_trace.base); + + let dot_product_trace = traces.get_mut(&Table::dot_product()).unwrap(); + fill_trace_dot_product(dot_product_trace, &memory_padded); + traces.insert( Table::execution(), TableTrace { base: Vec::from(main_trace), ext: vec![], - height: Default::default(), + log_n_rows: log2_ceil_usize(n_cycles), }, ); for table in traces.keys().copied().collect::>() { @@ -125,8 +122,8 @@ fn padd_table(table: &Table, traces: &mut BTreeMap) { .enumerate() .for_each(|(i, col)| assert_eq!(col.len(), h, "column {}, table {}", i, table.name())); - trace.height = TableHeight(h); - let padding_len = trace.height.padding_len(); + trace.log_n_rows = log2_ceil_usize(h + 1).max(MIN_LOG_N_ROWS_PER_TABLE); + let padding_len = (1 << trace.log_n_rows) - h; let padding_row_f = table.padding_row_f(); trace.base.par_iter_mut().enumerate().for_each(|(i, col)| { col.extend(repeat_n(padding_row_f[i], padding_len)); diff --git a/crates/lean_prover/witness_generation/src/instruction_encoder.rs b/crates/lean_prover/witness_generation/src/instruction_encoder.rs index 92d3412a..4b883763 100644 --- a/crates/lean_prover/witness_generation/src/instruction_encoder.rs +++ b/crates/lean_prover/witness_generation/src/instruction_encoder.rs @@ -1,5 +1,6 @@ use lean_vm::*; use multilinear_toolkit::prelude::*; +use utils::padd_with_zero_to_next_power_of_two; pub fn field_representation(instr: &Instruction) -> [F; N_INSTRUCTION_COLUMNS] { let mut fields = [F::ZERO; N_INSTRUCTION_COLUMNS]; @@ -12,10 +13,10 @@ pub fn field_representation(instr: &Instruction) -> [F; N_INSTRUCTION_COLUMNS] { } => { match operation { Operation::Add => { - fields[COL_INDEX_ADD] = F::ONE; + fields[instr_idx(COL_ADD)] = F::ONE; } Operation::Mul => { - fields[COL_INDEX_MUL] = F::ONE; + fields[instr_idx(COL_MUL)] = F::ONE; } } @@ -24,25 +25,25 @@ pub fn field_representation(instr: &Instruction) -> [F; N_INSTRUCTION_COLUMNS] { set_nu_c(&mut fields, arg_c); } Instruction::Deref { shift_0, shift_1, res } => { - fields[COL_INDEX_DEREF] = F::ONE; - fields[COL_INDEX_FLAG_A] = F::ZERO; - fields[COL_INDEX_OPERAND_A] = F::from_usize(*shift_0); - fields[COL_INDEX_FLAG_C] = F::ONE; - fields[COL_INDEX_OPERAND_C] = F::from_usize(*shift_1); + fields[instr_idx(COL_DEREF)] = F::ONE; + fields[instr_idx(COL_FLAG_A)] = F::ZERO; + fields[instr_idx(COL_OPERAND_A)] = F::from_usize(*shift_0); + fields[instr_idx(COL_FLAG_C)] = F::ONE; + fields[instr_idx(COL_OPERAND_C)] = F::from_usize(*shift_1); match res { MemOrFpOrConstant::Constant(cst) => { - fields[COL_INDEX_AUX] = F::ONE; - fields[COL_INDEX_FLAG_B] = F::ONE; - fields[COL_INDEX_OPERAND_B] = *cst; + fields[instr_idx(COL_AUX_1)] = F::ONE; + fields[instr_idx(COL_FLAG_B)] = F::ONE; + fields[instr_idx(COL_OPERAND_B)] = *cst; } MemOrFpOrConstant::MemoryAfterFp { offset } => { - fields[COL_INDEX_AUX] = F::ONE; - fields[COL_INDEX_FLAG_B] = F::ZERO; - fields[COL_INDEX_OPERAND_B] = F::from_usize(*offset); + fields[instr_idx(COL_AUX_1)] = F::ONE; + fields[instr_idx(COL_FLAG_B)] = F::ZERO; + fields[instr_idx(COL_OPERAND_B)] = F::from_usize(*offset); } MemOrFpOrConstant::Fp => { - fields[COL_INDEX_AUX] = F::ZERO; - fields[COL_INDEX_FLAG_B] = F::ONE; + fields[instr_idx(COL_AUX_1)] = F::ZERO; + fields[instr_idx(COL_FLAG_B)] = F::ONE; } } } @@ -52,7 +53,7 @@ pub fn field_representation(instr: &Instruction) -> [F; N_INSTRUCTION_COLUMNS] { dest, updated_fp, } => { - fields[COL_INDEX_JUMP] = F::ONE; + fields[instr_idx(COL_JUMP)] = F::ONE; set_nu_a(&mut fields, condition); set_nu_b(&mut fields, dest); set_nu_c(&mut fields, updated_fp); @@ -62,14 +63,16 @@ pub fn field_representation(instr: &Instruction) -> [F; N_INSTRUCTION_COLUMNS] { arg_a, arg_b, arg_c, - aux, + aux_1, + aux_2, } => { - fields[COL_INDEX_IS_PRECOMPILE] = F::ONE; - fields[COL_INDEX_PRECOMPILE_INDEX] = table.embed(); + fields[instr_idx(COL_IS_PRECOMPILE)] = F::ONE; + fields[instr_idx(COL_PRECOMPILE_INDEX)] = table.embed(); set_nu_a(&mut fields, arg_a); set_nu_b(&mut fields, arg_b); set_nu_c(&mut fields, arg_c); - fields[COL_INDEX_AUX] = F::from_usize(*aux); + fields[instr_idx(COL_AUX_1)] = F::from_usize(*aux_1); + fields[instr_idx(COL_AUX_2)] = F::from_usize(*aux_2); } } fields @@ -78,12 +81,12 @@ pub fn field_representation(instr: &Instruction) -> [F; N_INSTRUCTION_COLUMNS] { fn set_nu_a(fields: &mut [F; N_INSTRUCTION_COLUMNS], a: &MemOrConstant) { match a { MemOrConstant::Constant(cst) => { - fields[COL_INDEX_FLAG_A] = F::ONE; - fields[COL_INDEX_OPERAND_A] = *cst; + fields[instr_idx(COL_FLAG_A)] = F::ONE; + fields[instr_idx(COL_OPERAND_A)] = *cst; } MemOrConstant::MemoryAfterFp { offset } => { - fields[COL_INDEX_FLAG_A] = F::ZERO; - fields[COL_INDEX_OPERAND_A] = F::from_usize(*offset); + fields[instr_idx(COL_FLAG_A)] = F::ZERO; + fields[instr_idx(COL_OPERAND_A)] = F::from_usize(*offset); } } } @@ -91,12 +94,12 @@ fn set_nu_a(fields: &mut [F; N_INSTRUCTION_COLUMNS], a: &MemOrConstant) { fn set_nu_b(fields: &mut [F; N_INSTRUCTION_COLUMNS], b: &MemOrConstant) { match b { MemOrConstant::Constant(cst) => { - fields[COL_INDEX_FLAG_B] = F::ONE; - fields[COL_INDEX_OPERAND_B] = *cst; + fields[instr_idx(COL_FLAG_B)] = F::ONE; + fields[instr_idx(COL_OPERAND_B)] = *cst; } MemOrConstant::MemoryAfterFp { offset } => { - fields[COL_INDEX_FLAG_B] = F::ZERO; - fields[COL_INDEX_OPERAND_B] = F::from_usize(*offset); + fields[instr_idx(COL_FLAG_B)] = F::ZERO; + fields[instr_idx(COL_OPERAND_B)] = F::from_usize(*offset); } } } @@ -104,11 +107,19 @@ fn set_nu_b(fields: &mut [F; N_INSTRUCTION_COLUMNS], b: &MemOrConstant) { fn set_nu_c(fields: &mut [F; N_INSTRUCTION_COLUMNS], c: &MemOrFp) { match c { MemOrFp::Fp => { - fields[COL_INDEX_FLAG_C] = F::ONE; + fields[instr_idx(COL_FLAG_C)] = F::ONE; } MemOrFp::MemoryAfterFp { offset } => { - fields[COL_INDEX_FLAG_C] = F::ZERO; - fields[COL_INDEX_OPERAND_C] = F::from_usize(*offset); + fields[instr_idx(COL_FLAG_C)] = F::ZERO; + fields[instr_idx(COL_OPERAND_C)] = F::from_usize(*offset); } } } + +pub fn bytecode_to_multilinear_polynomial(instructions: &[Instruction]) -> Vec { + let res = instructions + .par_iter() + .flat_map(|instr| padd_with_zero_to_next_power_of_two(&field_representation(instr))) + .collect::>(); + padd_with_zero_to_next_power_of_two(&res) +} diff --git a/crates/lean_prover/witness_generation/src/lib.rs b/crates/lean_prover/witness_generation/src/lib.rs index 6f62d70b..e665ab96 100644 --- a/crates/lean_prover/witness_generation/src/lib.rs +++ b/crates/lean_prover/witness_generation/src/lib.rs @@ -5,6 +5,3 @@ mod instruction_encoder; pub use execution_trace::*; pub use instruction_encoder::*; - -mod poseidon_tables; -pub use poseidon_tables::*; diff --git a/crates/lean_prover/witness_generation/src/poseidon_tables.rs b/crates/lean_prover/witness_generation/src/poseidon_tables.rs deleted file mode 100644 index c9f343f4..00000000 --- a/crates/lean_prover/witness_generation/src/poseidon_tables.rs +++ /dev/null @@ -1,29 +0,0 @@ -use std::array; - -use lean_vm::{F, TableTrace}; -use multilinear_toolkit::prelude::*; -use p3_koala_bear::{KoalaBearInternalLayerParameters, KoalaBearParameters}; -use p3_monty_31::InternalLayerBaseParameters; -use poseidon_circuit::{PoseidonGKRLayers, PoseidonWitness, generate_poseidon_witness}; -use tracing::instrument; - -#[instrument(skip_all)] -pub fn generate_poseidon_witness_helper( - layers: &PoseidonGKRLayers, - trace: &TableTrace, - start_index: usize, - compressions: Option<&[F]>, -) -> PoseidonWitness, WIDTH, N_COMMITED_CUBES> -where - KoalaBearInternalLayerParameters: InternalLayerBaseParameters, -{ - let inputs: [_; WIDTH] = array::from_fn(|i| &trace.base[start_index + i][..]); - let n_poseidons = inputs[0].len(); - assert!(n_poseidons.is_power_of_two()); - let inputs_packed: [_; WIDTH] = array::from_fn(|i| PFPacking::::pack_slice(inputs[i]).to_vec()); // TODO avoid cloning - generate_poseidon_witness::, WIDTH, N_COMMITED_CUBES>( - inputs_packed, - layers, - compressions.map(|c| FPacking::::pack_slice(c).to_vec()), // TODO avoid cloning - ) -} diff --git a/crates/lean_vm/Cargo.toml b/crates/lean_vm/Cargo.toml index 8e64174e..1f3b9bf3 100644 --- a/crates/lean_vm/Cargo.toml +++ b/crates/lean_vm/Cargo.toml @@ -15,17 +15,12 @@ xmss.workspace = true rand.workspace = true p3-poseidon2.workspace = true p3-koala-bear.workspace = true -p3-challenger.workspace = true -p3-air.workspace = true p3-symmetric.workspace = true p3-util.workspace = true whir-p3.workspace = true tracing.workspace = true air.workspace = true -sub_protocols.workspace = true -lookup.workspace = true thiserror.workspace = true -derive_more.workspace = true multilinear-toolkit.workspace = true num_enum.workspace = true itertools.workspace = true diff --git a/crates/lean_vm/src/core/constants.rs b/crates/lean_vm/src/core/constants.rs index 78d33418..bf118c4a 100644 --- a/crates/lean_vm/src/core/constants.rs +++ b/crates/lean_vm/src/core/constants.rs @@ -1,14 +1,25 @@ -/// Vector dimension for field operations +use crate::Table; + +/// Large field = extension field of degree DIMENSION over koala-bear pub const DIMENSION: usize = 5; -/// Logarithm of vector length -pub const LOG_VECTOR_LEN: usize = 3; -/// Vector length (2^LOG_VECTOR_LEN) -pub const VECTOR_LEN: usize = 1 << LOG_VECTOR_LEN; +pub const DIGEST_LEN: usize = 8; + +/// Minimum and maximum memory size (as powers of two) +pub const MIN_LOG_MEMORY_SIZE: usize = 16; +pub const MAX_LOG_MEMORY_SIZE: usize = 29; -/// Maximum memory size for VM runner +/// Maximum memory size for VM runner (specific to this implementation) pub const MAX_RUNNER_MEMORY_SIZE: usize = 1 << 24; +/// Minimum and maximum number of rows per table (as powers of two), both inclusive +pub const MIN_LOG_N_ROWS_PER_TABLE: usize = 8; // Zero padding will be added to each at least, if this minimum is not reached, (ensuring AIR / GKR work fine, with SIMD, without too much edge cases). Long term, we should find a more elegant solution. +pub const MAX_LOG_N_ROWS_PER_TABLE: [(Table, usize); 3] = [ + (Table::execution(), 29), // 3 lookups + (Table::dot_product(), 25), // 4 lookups + (Table::poseidon16(), 25), // 4 lookups +]; // No overflow in logup: (TODO triple check) 3.2^29 + 4.2^25 + 4.2^25 < p = 2^31 - 2^24 + 1 + /// Starting program counter pub const STARTING_PC: usize = 1; @@ -27,25 +38,22 @@ pub const ENDING_PC: usize = 0; /// reserved_area: reserved for special constants (size = 48 field elements) /// program_input: the input of the program we want to prove /// -/// [reserved_area] = [00000000] [00000000] [10000000] [poseidon_16(0) (16 field elements)] [poseidon_24(0) (8 last field elements)] +/// [reserved_area] = [00000000] [00000000] [10000] [01000] [00100] [00010] [00001] [poseidon_16(0) (16 field elements)] [private input start pointer] /// -/// Convention: vectorized pointer of size 2, pointing to 16 zeros +/// Convention: pointing to 16 zeros pub const ZERO_VEC_PTR: usize = 0; -/// Convention: vectorized pointer of size 1, pointing to 10000000 -pub const ONE_VEC_PTR: usize = 2; +/// Convention: pointing to [10000] [01000] [00100] [00010] [00001] +pub const EXTENSION_BASIS_PTR: usize = 2 * DIGEST_LEN; -/// Convention: vectorized pointer of size 2, = the 16 elements of poseidon_16(0) -pub const POSEIDON_16_NULL_HASH_PTR: usize = 3; +/// Convention: pointing to the 16 elements of poseidon_16(0) +pub const POSEIDON_16_NULL_HASH_PTR: usize = EXTENSION_BASIS_PTR + DIMENSION.pow(2); -/// Convention: vectorized pointer of size 1, = the last 8 elements of poseidon_24(0) -pub const POSEIDON_24_NULL_HASH_PTR: usize = 5; +/// Pointer to start of private input +pub const PRIVATE_INPUT_START_PTR: usize = POSEIDON_16_NULL_HASH_PTR + DIGEST_LEN * 2; /// Normal pointer to start of program input -pub const NONRESERVED_PROGRAM_INPUT_START: usize = 6 * 8; +pub const NONRESERVED_PROGRAM_INPUT_START: usize = PRIVATE_INPUT_START_PTR + 1; -/// Precompiles Indexes -pub const TABLE_INDEX_POSEIDONS_16: usize = 1; // should be != 0 -pub const TABLE_INDEX_POSEIDONS_24: usize = 2; -pub const TABLE_INDEX_DOT_PRODUCTS: usize = 3; -pub const TABLE_INDEX_MULTILINEAR_EVAL: usize = 4; +/// The first element of basis corresponds to one +pub const ONE_VEC_PTR: usize = EXTENSION_BASIS_PTR; diff --git a/crates/lean_vm/src/core/label.rs b/crates/lean_vm/src/core/label.rs index 1ff8c0ef..c2adcb4a 100644 --- a/crates/lean_vm/src/core/label.rs +++ b/crates/lean_vm/src/core/label.rs @@ -1,4 +1,4 @@ -use crate::core::SourceLineNumber; +use crate::SourceLocation; /// Structured label for bytecode locations #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] @@ -12,16 +12,14 @@ pub enum Label { If { id: usize, kind: IfKind, - line_number: SourceLineNumber, + location: SourceLocation, }, /// Match statement end: @match_end_{id} MatchEnd(usize), /// Return from function call: @return_from_call_{id} - ReturnFromCall(usize, SourceLineNumber), + ReturnFromCall(usize, SourceLocation), /// Loop definition: @loop_{id}_line_{line_number} - Loop(usize, SourceLineNumber), - /// Built-in memory symbols - BuiltinSymbol(BuiltinSymbol), + Loop(usize, SourceLocation), /// Auxiliary variables during compilation AuxVar { kind: AuxKind, id: usize }, /// Custom/raw label for backwards compatibility or special cases @@ -40,11 +38,11 @@ pub enum IfKind { #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] pub enum BuiltinSymbol { - /// @public_input_start + /// @NONRESERVED_PROGRAM_INPUT_START PublicInputStart, - /// @pointer_to_zero_vector + /// @ZERO_VEC_PTR PointerToZeroVector, - /// @pointer_to_one_vector + /// @ONE_VEC_PTR PointerToOneVector, } @@ -71,17 +69,16 @@ impl std::fmt::Display for Label { match self { Self::Function(name) => write!(f, "@function_{name}"), Self::EndProgram => write!(f, "@end_program"), - Self::If { id, kind, line_number } => match kind { - IfKind::If => write!(f, "@if_{id}_line_{line_number}"), - IfKind::Else => write!(f, "@else_{id}_line_{line_number}"), - IfKind::End => write!(f, "@if_else_end_{id}_line_{line_number}"), + Self::If { id, kind, location } => match kind { + IfKind::If => write!(f, "@if_{id}_line_{}", location.line_number), + IfKind::Else => write!(f, "@else_{id}_line_{}", location.line_number), + IfKind::End => write!(f, "@if_else_end_{id}_line_{}", location.line_number), }, Self::MatchEnd(id) => write!(f, "@match_end_{id}"), Self::ReturnFromCall(id, line_number) => { write!(f, "@return_from_call_{id}_line_{line_number}") } Self::Loop(id, line_number) => write!(f, "@loop_{id}_line_{line_number}"), - Self::BuiltinSymbol(symbol) => write!(f, "{symbol}"), Self::AuxVar { kind, id } => match kind { AuxKind::AuxVar => write!(f, "@aux_var_{id}"), AuxKind::ArrayAux => write!(f, "@arr_aux_{id}"), @@ -98,42 +95,32 @@ impl std::fmt::Display for Label { } } -impl std::fmt::Display for BuiltinSymbol { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::PublicInputStart => write!(f, "@public_input_start"), - Self::PointerToZeroVector => write!(f, "@pointer_to_zero_vector"), - Self::PointerToOneVector => write!(f, "@pointer_to_one_vector"), - } - } -} - impl Label { pub fn function(name: impl Into) -> Self { Self::Function(name.into()) } - pub fn if_label(id: usize, line_number: SourceLineNumber) -> Self { + pub fn if_label(id: usize, location: SourceLocation) -> Self { Self::If { id, kind: IfKind::If, - line_number, + location, } } - pub fn else_label(id: usize, line_number: SourceLineNumber) -> Self { + pub fn else_label(id: usize, location: SourceLocation) -> Self { Self::If { id, kind: IfKind::Else, - line_number, + location, } } - pub fn if_else_end(id: usize, line_number: SourceLineNumber) -> Self { + pub fn if_else_end(id: usize, location: SourceLocation) -> Self { Self::If { id, kind: IfKind::End, - line_number, + location, } } @@ -141,12 +128,12 @@ impl Label { Self::MatchEnd(id) } - pub fn return_from_call(id: usize, line_number: SourceLineNumber) -> Self { - Self::ReturnFromCall(id, line_number) + pub fn return_from_call(id: usize, location: SourceLocation) -> Self { + Self::ReturnFromCall(id, location) } - pub fn loop_label(id: usize, line_number: SourceLineNumber) -> Self { - Self::Loop(id, line_number) + pub fn loop_label(id: usize, location: SourceLocation) -> Self { + Self::Loop(id, location) } pub fn aux_var(id: usize) -> Self { diff --git a/crates/lean_vm/src/core/types.rs b/crates/lean_vm/src/core/types.rs index 2948f3b5..3123b298 100644 --- a/crates/lean_vm/src/core/types.rs +++ b/crates/lean_vm/src/core/types.rs @@ -1,6 +1,8 @@ -use derive_more::Display; use p3_koala_bear::{KoalaBear, QuinticExtensionFieldKB}; -use std::cmp::Ordering; +use std::{ + cmp::Ordering, + fmt::{Display, Formatter}, +}; /// Base field type for VM operations pub type F = KoalaBear; @@ -24,13 +26,18 @@ pub type FunctionName = String; pub type FileId = usize; /// Location in source code -#[derive(Display, Hash, PartialEq, Eq, Debug, Clone, Copy)] -#[display("{}:{}", file_id, line_number)] +#[derive(Hash, PartialEq, Eq, Debug, Clone, Copy)] pub struct SourceLocation { pub file_id: FileId, pub line_number: SourceLineNumber, } +impl Display for SourceLocation { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "file_id: {}, line: {}", self.file_id, self.line_number) + } +} + fn cmp_source_location(a: &SourceLocation, b: &SourceLocation) -> Ordering { match a.file_id.cmp(&b.file_id) { Ordering::Less => Ordering::Less, diff --git a/crates/lean_vm/src/diagnostics/error.rs b/crates/lean_vm/src/diagnostics/error.rs index bb1d62bc..b9593737 100644 --- a/crates/lean_vm/src/diagnostics/error.rs +++ b/crates/lean_vm/src/diagnostics/error.rs @@ -31,8 +31,11 @@ pub enum RunnerError { #[error("Program counter out of bounds")] PCOutOfBounds, - #[error("DebugAssert failed: {0} at line {1}")] + #[error("DebugAssert failed: {0} at {1}")] DebugAssertFailed(String, SourceLocation), + + #[error("Invalid dot product")] + InvalidDotProduct, } pub type VMResult = Result; diff --git a/crates/lean_vm/src/execution/memory.rs b/crates/lean_vm/src/execution/memory.rs index 7ece75cd..be2b1cc7 100644 --- a/crates/lean_vm/src/execution/memory.rs +++ b/crates/lean_vm/src/execution/memory.rs @@ -1,11 +1,8 @@ //! Memory management for the VM -use crate::core::{DIMENSION, EF, F, MAX_RUNNER_MEMORY_SIZE, VECTOR_LEN}; +use crate::core::{DIMENSION, EF, F, MAX_RUNNER_MEMORY_SIZE}; use crate::diagnostics::RunnerError; use multilinear_toolkit::prelude::*; -pub const MIN_LOG_MEMORY_SIZE: usize = 16; -pub const MAX_LOG_MEMORY_SIZE: usize = 29; - /// VM memory implementation with sparse allocation #[derive(Debug, Clone, Default)] pub struct Memory(pub Vec>); @@ -32,6 +29,17 @@ impl Memory { .ok_or(RunnerError::UndefinedMemory(index)) } + /// Reads a single value from a memory address, returning ZERO if undefined or out of bounds. + /// Used for range check hint resolution where undefined memory is acceptable. + pub fn get_or_zero(&self, index: usize) -> F { + self.0.get(index).copied().flatten().unwrap_or(F::ZERO) + } + + /// Returns true if a memory address is defined + pub fn is_defined(&self, index: usize) -> bool { + self.0.get(index).copied().flatten().is_some() + } + /// Sets a value at a memory address /// /// Returns an error if the address is already set to a different value @@ -62,11 +70,6 @@ impl Memory { self.0.len() } - /// Get a vector from vectorized memory - pub fn get_vector(&self, index: usize) -> Result<[F; VECTOR_LEN], RunnerError> { - Ok(self.get_vectorized_slice(index, 1)?.try_into().unwrap()) - } - /// Get an extension field element from memory pub fn get_ef_element(&self, index: usize) -> Result { // index: non vectorized pointer @@ -77,12 +80,6 @@ impl Memory { Ok(EF::from_basis_coefficients_slice(&coeffs).unwrap()) } - pub fn get_vectorized_slice(&self, index: usize, len: usize) -> Result, RunnerError> { - let start = index * VECTOR_LEN; - let total_len = len * VECTOR_LEN; - (0..total_len).map(|i| self.get(start + i)).collect() - } - /// Get a continuous slice of extension field elements pub fn get_continuous_slice_of_ef_elements( &self, @@ -100,16 +97,14 @@ impl Memory { Ok(()) } - /// Set a vector in vectorized memory - pub fn set_vector(&mut self, index: usize, value: [F; VECTOR_LEN]) -> Result<(), RunnerError> { - for (i, v) in value.iter().enumerate() { - let idx = VECTOR_LEN * index + i; - self.set(idx, *v)?; - } - Ok(()) + pub fn get_slice(&self, start: usize, len: usize) -> Result, RunnerError> { + (0..len).map(|i| self.get(start + i)).collect() } - pub fn slice(&self, start: usize, len: usize) -> Result, RunnerError> { - (0..len).map(|i| self.get(start + i)).collect() + pub fn set_slice(&mut self, start: usize, values: &[F]) -> Result<(), RunnerError> { + for (i, v) in values.iter().enumerate() { + self.set(start + i, *v)?; + } + Ok(()) } } diff --git a/crates/lean_vm/src/execution/runner.rs b/crates/lean_vm/src/execution/runner.rs index e3e89644..ec393de9 100644 --- a/crates/lean_vm/src/execution/runner.rs +++ b/crates/lean_vm/src/execution/runner.rs @@ -1,20 +1,20 @@ //! VM execution runner use crate::core::{ - DIMENSION, F, FileId, NONRESERVED_PROGRAM_INPUT_START, ONE_VEC_PTR, POSEIDON_16_NULL_HASH_PTR, - POSEIDON_24_NULL_HASH_PTR, VECTOR_LEN, ZERO_VEC_PTR, + DIGEST_LEN, DIMENSION, F, FileId, NONRESERVED_PROGRAM_INPUT_START, POSEIDON_16_NULL_HASH_PTR, ZERO_VEC_PTR, }; use crate::diagnostics::{ExecutionResult, MemoryProfile, RunnerError, memory_profiling_report}; use crate::execution::{ExecutionHistory, Memory}; use crate::isa::Bytecode; use crate::isa::instruction::InstructionContext; use crate::{ - ALL_TABLES, CodeAddress, ENDING_PC, HintExecutionContext, N_TABLES, STARTING_PC, SourceLocation, Table, TableTrace, + ALL_TABLES, CodeAddress, ENDING_PC, EXTENSION_BASIS_PTR, HintExecutionContext, N_TABLES, PRIVATE_INPUT_START_PTR, + STARTING_PC, SourceLocation, Table, TableTrace, }; use multilinear_toolkit::prelude::*; -use std::collections::{BTreeMap, BTreeSet, VecDeque}; -use utils::{poseidon16_permute, poseidon24_permute, pretty_integer}; -use xmss::{Poseidon16History, Poseidon24History}; +use std::collections::{BTreeMap, BTreeSet}; +use utils::{ToUsize, poseidon16_permute, pretty_integer}; +use xmss::Poseidon16History; /// Number of instructions to show in stack trace const STACK_TRACE_INSTRUCTIONS: usize = 5000; @@ -27,37 +27,33 @@ pub fn build_public_memory(public_input: &[F]) -> Vec { public_memory[NONRESERVED_PROGRAM_INPUT_START..][..public_input.len()].copy_from_slice(public_input); // "zero" vector - let zero_start = ZERO_VEC_PTR * VECTOR_LEN; - for slot in public_memory.iter_mut().skip(zero_start).take(2 * VECTOR_LEN) { + let zero_start = ZERO_VEC_PTR; + for slot in public_memory.iter_mut().skip(zero_start).take(2 * DIGEST_LEN) { *slot = F::ZERO; } - // "one" vector - public_memory[ONE_VEC_PTR * VECTOR_LEN] = F::ONE; - let one_start = ONE_VEC_PTR * VECTOR_LEN + 1; - for slot in public_memory.iter_mut().skip(one_start).take(VECTOR_LEN - 1) { - *slot = F::ZERO; + // extension basis + for i in 0..DIMENSION { + let mut vec = F::zero_vec(DIMENSION); + vec[i] = F::ONE; + public_memory[EXTENSION_BASIS_PTR + i * DIMENSION..][..DIMENSION].copy_from_slice(&vec); } - public_memory[POSEIDON_16_NULL_HASH_PTR * VECTOR_LEN..(POSEIDON_16_NULL_HASH_PTR + 2) * VECTOR_LEN] - .copy_from_slice(&poseidon16_permute([F::ZERO; 16])); - public_memory[POSEIDON_24_NULL_HASH_PTR * VECTOR_LEN..(POSEIDON_24_NULL_HASH_PTR + 1) * VECTOR_LEN] - .copy_from_slice(&poseidon24_permute([F::ZERO; 24])[16..]); + public_memory[POSEIDON_16_NULL_HASH_PTR..][..2 * DIGEST_LEN].copy_from_slice(&poseidon16_permute([F::ZERO; 16])); + public_memory[PRIVATE_INPUT_START_PTR] = F::from_usize(public_memory_len); public_memory } -/// Execute bytecode with the given inputs and execution context +/// Execute bytecode with the given inputs and execution context, returning a Result /// /// This is the main VM execution entry point that processes bytecode instructions /// and generates execution traces with witness data. -pub fn execute_bytecode( +pub fn try_execute_bytecode( bytecode: &Bytecode, (public_input, private_input): (&[F], &[F]), - no_vec_runtime_memory: usize, // size of the "non-vectorized" runtime memory profiling: bool, - (poseidons_16_precomputed, poseidons_24_precomputed): (&Poseidon16History, &Poseidon24History), - merkle_path_hints: VecDeque>, -) -> ExecutionResult { + poseidons_16_precomputed: &Poseidon16History, +) -> Result { let mut std_out = String::new(); let mut instruction_history = ExecutionHistory::new(); let result = execute_bytecode_helper( @@ -65,15 +61,13 @@ pub fn execute_bytecode( (public_input, private_input), &mut std_out, &mut instruction_history, - no_vec_runtime_memory, profiling, - (poseidons_16_precomputed, poseidons_24_precomputed), - merkle_path_hints, + poseidons_16_precomputed, ) - .unwrap_or_else(|(last_pc, err)| { + .map_err(|(last_pc, err)| { let lines_history = &instruction_history.lines; let latest_instructions = &lines_history[lines_history.len().saturating_sub(STACK_TRACE_INSTRUCTIONS)..]; - println!( + eprintln!( "\n{}", crate::diagnostics::pretty_stack_trace( &bytecode.source_code, @@ -84,13 +78,13 @@ pub fn execute_bytecode( ) ); if !std_out.is_empty() { - println!("╔══════════════════════════════════════════════════════════════╗"); - println!("║ STD-OUT ║"); - println!("╚══════════════════════════════════════════════════════════════╝\n"); - print!("{std_out}"); + eprintln!("╔══════════════════════════════════════════════════════════════╗"); + eprintln!("║ STD-OUT ║"); + eprintln!("╚══════════════════════════════════════════════════════════════╝\n"); + eprint!("{std_out}"); } - panic!("Error during bytecode execution: {err}"); - }); + err + })?; if profiling { print_line_cycle_counts(instruction_history, &bytecode.filepaths); print_instruction_cycle_counts(bytecode, result.pcs.clone()); @@ -98,7 +92,25 @@ pub fn execute_bytecode( print!("{}", memory_profiling_report(mem_profile)); } } - result + Ok(result) +} + +/// Execute bytecode with the given inputs and execution context +/// +/// Panics on execution errors. Use `try_execute_bytecode` for error handling. +pub fn execute_bytecode( + bytecode: &Bytecode, + (public_input, private_input): (&[F], &[F]), + profiling: bool, + poseidons_16_precomputed: &Poseidon16History, +) -> ExecutionResult { + try_execute_bytecode( + bytecode, + (public_input, private_input), + profiling, + poseidons_16_precomputed, + ) + .unwrap_or_else(|err| panic!("Error during bytecode execution: {err}")) } fn print_line_cycle_counts(history: ExecutionHistory, filepaths: &BTreeMap) { @@ -139,6 +151,45 @@ fn print_instruction_cycle_counts(bytecode: &Bytecode, pcs: Vec) { println!(); } +/// Resolve pending deref hints in correct order +/// +/// Each constraint has form: memory[target_addr] = memory[memory[src_addr]] +/// Order matters because some src addresses might point to targets of other hints. +/// We iteratively resolve constraints until no more progress, then fill remaining with 0. +fn resolve_deref_hints(memory: &mut Memory, pending: &[(usize, usize)]) { + let mut resolved: BTreeSet = BTreeSet::new(); + + loop { + let mut made_progress = false; + + for &(target_addr, src_addr) in pending { + if resolved.contains(&target_addr) { + continue; + } + let Some(addr) = memory.0.get(src_addr).copied().flatten() else { + continue; + }; + let Some(value) = memory.0.get(addr.to_usize()).copied().flatten() else { + continue; + }; + memory.set(target_addr, value).unwrap(); + resolved.insert(target_addr); + made_progress = true; + } + + if !made_progress { + break; + } + } + + // Fill any remaining unresolved targets with 0 + for &(target_addr, _src_addr) in pending { + if !resolved.contains(&target_addr) { + let _ = memory.set(target_addr, F::ZERO); + } + } +} + /// Helper function that performs the actual bytecode execution #[allow(clippy::too_many_arguments)] // TODO fn execute_bytecode_helper( @@ -146,10 +197,8 @@ fn execute_bytecode_helper( (public_input, private_input): (&[F], &[F]), std_out: &mut String, instruction_history: &mut ExecutionHistory, - no_vec_runtime_memory: usize, profiling: bool, - (poseidons_16_precomputed, poseidons_24_precomputed): (&Poseidon16History, &Poseidon24History), - mut merkle_path_hints: VecDeque>, + poseidons_16_precomputed: &Poseidon16History, ) -> Result { // set public memory let mut memory = Memory::new(build_public_memory(public_input)); @@ -172,23 +221,19 @@ fn execute_bytecode_helper( fp = fp.next_multiple_of(DIMENSION); let initial_ap = fp + bytecode.starting_frame_memory; - let initial_ap_vec = (initial_ap + no_vec_runtime_memory).next_multiple_of(VECTOR_LEN) / VECTOR_LEN; let mut pc = STARTING_PC; let mut ap = initial_ap; - let mut ap_vec = initial_ap_vec; let mut cpu_cycles = 0; let mut last_checkpoint_cpu_cycles = 0; let mut checkpoint_ap = initial_ap; - let mut checkpoint_ap_vec = ap_vec; let mut pcs = Vec::new(); let mut fps = Vec::new(); let mut n_poseidon16_precomputed_used = 0; - let mut n_poseidon24_precomputed_used = 0; let mut traces = BTreeMap::from_iter((0..N_TABLES).map(|i| (ALL_TABLES[i], TableTrace::new(&ALL_TABLES[i])))); @@ -200,6 +245,9 @@ fn execute_bytecode_helper( let mut counter_hint = 0; let mut cpu_cycles_before_new_line = 0; + // Pending deref hints: (target_addr, src_addr) constraints to resolve at end + let mut pending_deref_hints: Vec<(usize, usize)> = Vec::new(); + while pc != ENDING_PC { if pc >= bytecode.instructions.len() { return Err((pc, RunnerError::PCOutOfBounds)); @@ -214,10 +262,8 @@ fn execute_bytecode_helper( for hint in bytecode.hints.get(&pc).unwrap_or(&vec![]) { let mut hint_ctx = HintExecutionContext { memory: &mut memory, - private_input_start: public_memory_size, fp, ap: &mut ap, - ap_vec: &mut ap_vec, counter_hint: &mut counter_hint, std_out, instruction_history, @@ -225,9 +271,9 @@ fn execute_bytecode_helper( cpu_cycles, last_checkpoint_cpu_cycles: &mut last_checkpoint_cpu_cycles, checkpoint_ap: &mut checkpoint_ap, - checkpoint_ap_vec: &mut checkpoint_ap_vec, profiling, memory_profile: &mut mem_profile, + pending_deref_hints: &mut pending_deref_hints, }; hint.execute_hint(&mut hint_ctx).map_err(|e| (pc, e))?; } @@ -244,26 +290,23 @@ fn execute_bytecode_helper( deref_counts: &mut deref_counts, jump_counts: &mut jump_counts, poseidon16_precomputed: poseidons_16_precomputed, - poseidon24_precomputed: poseidons_24_precomputed, - merkle_path_hints: &mut merkle_path_hints, n_poseidon16_precomputed_used: &mut n_poseidon16_precomputed_used, - n_poseidon24_precomputed_used: &mut n_poseidon24_precomputed_used, }; instruction .execute_instruction(&mut instruction_ctx) .map_err(|e| (pc, e))?; } + // Resolve pending deref hints in correct order + // Constraint: memory[target_addr] = memory[memory[src_addr]] + // Order matters because some src addresses might point to targets of other hints + resolve_deref_hints(&mut memory, &pending_deref_hints); + assert_eq!( n_poseidon16_precomputed_used, poseidons_16_precomputed.len(), "Warning: not all precomputed Poseidon16 were used" ); - assert_eq!( - n_poseidon24_precomputed_used, - poseidons_24_precomputed.len(), - "Warning: not all precomputed Poseidon24 were used" - ); assert_eq!(pc, ENDING_PC); pcs.push(pc); @@ -304,12 +347,7 @@ fn execute_bytecode_helper( "Private input size: {}\n", pretty_integer(private_input.len()) )); - summary.push_str(&format!( - "Runtime memory: {} ({:.2}% vec) (no vec mem: {})\n", - pretty_integer(runtime_memory_size), - (VECTOR_LEN * (ap_vec - initial_ap_vec)) as f64 / runtime_memory_size as f64 * 100.0, - no_vec_runtime_memory - )); + summary.push_str(&format!("Runtime memory: {}\n", pretty_integer(runtime_memory_size),)); let used_memory_cells = memory .0 .iter() @@ -327,21 +365,14 @@ fn execute_bytecode_helper( pretty_integer(n_poseidon16_precomputed_used), pretty_integer(poseidons_16_precomputed.len()) )); - summary.push_str(&format!( - "Poseidon2_24 precomputed used: {}/{}\n", - pretty_integer(n_poseidon24_precomputed_used), - pretty_integer(poseidons_24_precomputed.len()) - )); summary.push('\n'); - if traces[&Table::poseidon16_core()].base[0].len() + traces[&Table::poseidon24_core()].base[0].len() > 0 { + if !traces[&Table::poseidon16()].base[0].is_empty() { summary.push_str(&format!( - "Poseidon2_16 calls: {}, Poseidon2_24 calls: {}, (1 poseidon per {} instructions)\n", - pretty_integer(traces[&Table::poseidon16_core()].base[0].len()), - pretty_integer(traces[&Table::poseidon24_core()].base[0].len()), - cpu_cycles - / (traces[&Table::poseidon16_core()].base[0].len() + traces[&Table::poseidon24_core()].base[0].len()) + "Poseidon2_16 calls: {} (1 poseidon per {} instructions)\n", + pretty_integer(traces[&Table::poseidon16()].base[0].len()), + cpu_cycles / traces[&Table::poseidon16()].base[0].len() )); } // if !dot_products.is_empty() { diff --git a/crates/lean_vm/src/isa/hint.rs b/crates/lean_vm/src/isa/hint.rs index 2607d7ac..7625aa11 100644 --- a/crates/lean_vm/src/isa/hint.rs +++ b/crates/lean_vm/src/isa/hint.rs @@ -1,4 +1,4 @@ -use crate::core::{F, LOG_VECTOR_LEN, Label, SourceLocation, VECTOR_LEN}; +use crate::core::{F, Label, SourceLocation}; use crate::diagnostics::{MemoryObject, MemoryObjectType, MemoryProfile, RunnerError}; use crate::execution::{ExecutionHistory, Memory}; use crate::isa::operands::MemOrConstant; @@ -8,7 +8,7 @@ use std::fmt::{Display, Formatter}; use std::hash::Hash; use std::ops::Range; use strum::IntoEnumIterator; -use utils::{ToUsize, pretty_integer}; +use utils::{ToUsize, pretty_integer, to_big_endian_in_field, to_little_endian_in_field}; /// VM hints provide execution guidance and debugging information, but does not appear /// in the verified bytecode. @@ -29,10 +29,6 @@ pub enum Hint { offset: usize, /// The requested memory size size: MemOrConstant, - /// Whether memory should be vectorized (aligned) - vectorized: bool, - /// Length for vectorized memory allocation - vectorized_len: usize, }, /// Print debug information during execution Print { @@ -41,9 +37,6 @@ pub enum Hint { /// Values to print content: Vec, }, - PrivateInputStart { - res_offset: usize, - }, /// Report source code location for debugging LocationReport { /// Source code location @@ -61,6 +54,17 @@ pub enum Hint { /// Assert a boolean expression for debugging purposes DebugAssert(BooleanExpr, SourceLocation), Custom(CustomHint, Vec), + /// Deref hint for range checks - records a constraint to be resolved at end of execution + /// Constraint: memory[fp + offset_target] = memory[memory[fp + offset_src]] + /// The runner resolves all these constraints at the end, in the correct order. + DerefHint { + offset_src: usize, + offset_target: usize, + }, + /// Panic hint with optional error message (for debugging) + Panic { + message: Option, + }, } #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, strum::EnumIter)] @@ -85,7 +89,7 @@ impl CustomHint { pub fn n_args_range(&self) -> Range { match self { Self::DecomposeBitsXMSS => 3..usize::MAX, - Self::DecomposeBits => 2..3, + Self::DecomposeBits => 4..5, } } @@ -110,11 +114,22 @@ impl CustomHint { } Self::DecomposeBits => { let to_decompose = args[0].read_value(ctx.memory, ctx.fp)?.to_usize(); - let mut memory_index = args[1].read_value(ctx.memory, ctx.fp)?.to_usize(); - for i in 0..F::bits() { - let bit = F::from_bool(to_decompose & (1 << i) != 0); - ctx.memory.set(memory_index, bit)?; - memory_index += 1; + let memory_index = args[1].read_value(ctx.memory, ctx.fp)?.to_usize(); + let num_bits = args[2].read_value(ctx.memory, ctx.fp)?.to_usize(); + let endianness = args[3].read_value(ctx.memory, ctx.fp)?.to_usize(); + assert!( + endianness == 0 || endianness == 1, + "Invalid endianness for DecomposeBits hint" + ); + assert!(num_bits <= F::bits()); + if endianness == 0 { + // Big-endian + ctx.memory + .set_slice(memory_index, &to_big_endian_in_field::(to_decompose, num_bits))? + } else { + // Little-endian + ctx.memory + .set_slice(memory_index, &to_little_endian_in_field::(to_decompose, num_bits))? } } } @@ -131,6 +146,7 @@ pub enum Boolean { Equal, Different, LessThan, + LessOrEqual, } #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] @@ -144,10 +160,8 @@ pub struct BooleanExpr { #[derive(Debug)] pub struct HintExecutionContext<'a> { pub memory: &'a mut Memory, - pub private_input_start: usize, // normal pointer pub fp: usize, pub ap: &'a mut usize, - pub ap_vec: &'a mut usize, pub counter_hint: &'a mut usize, pub std_out: &'a mut String, pub instruction_history: &'a mut ExecutionHistory, @@ -155,9 +169,12 @@ pub struct HintExecutionContext<'a> { pub cpu_cycles: usize, pub last_checkpoint_cpu_cycles: &'a mut usize, pub checkpoint_ap: &'a mut usize, - pub checkpoint_ap_vec: &'a mut usize, pub profiling: bool, pub memory_profile: &'a mut MemoryProfile, + /// Pending deref hints: (target_addr, src_addr) + /// Constraint: memory[target_addr] = memory[memory[src_addr]] + /// Resolved at end of execution in correct order. + pub pending_deref_hints: &'a mut Vec<(usize, usize)>, } impl Hint { @@ -169,52 +186,22 @@ impl Hint { function_name, offset, size, - vectorized, - vectorized_len, } => { let size = size.read_value(ctx.memory, ctx.fp)?.to_usize(); - if *vectorized { - assert!(*vectorized_len >= LOG_VECTOR_LEN, "TODO"); - - // padding: - while !(*ctx.ap_vec * VECTOR_LEN).is_multiple_of(1 << *vectorized_len) { - *ctx.ap_vec += 1; - } - let allocation_start_addr = *ctx.ap_vec; - ctx.memory.set( - ctx.fp + *offset, - F::from_usize(allocation_start_addr >> (*vectorized_len - LOG_VECTOR_LEN)), - )?; - let size_vectors = size << (*vectorized_len - LOG_VECTOR_LEN); - let size_words = size_vectors * VECTOR_LEN; - *ctx.ap_vec += size_vectors; + let allocation_start_addr = *ctx.ap; + ctx.memory.set(ctx.fp + *offset, F::from_usize(allocation_start_addr))?; + *ctx.ap += size; - if ctx.profiling { - ctx.memory_profile.objects.insert( - allocation_start_addr * VECTOR_LEN, - MemoryObject { - object_type: MemoryObjectType::VectorHeapObject, - function_name: function_name.clone(), - size: size_words, - }, - ); - } - } else { - let allocation_start_addr = *ctx.ap; - ctx.memory.set(ctx.fp + *offset, F::from_usize(allocation_start_addr))?; - *ctx.ap += size; - - if ctx.profiling { - ctx.memory_profile.objects.insert( - allocation_start_addr, - MemoryObject { - object_type: MemoryObjectType::NonVectorHeapObject, - function_name: function_name.clone(), - size, - }, - ); - } + if ctx.profiling { + ctx.memory_profile.objects.insert( + allocation_start_addr, + MemoryObject { + object_type: MemoryObjectType::NonVectorHeapObject, + function_name: function_name.clone(), + size, + }, + ); } } Self::Custom(hint, args) => { @@ -237,19 +224,16 @@ impl Hint { } else { assert_eq!(values.len(), 2); let new_no_vec_memory = *ctx.ap - *ctx.checkpoint_ap; - let new_vec_memory = (*ctx.ap_vec - *ctx.checkpoint_ap_vec) * VECTOR_LEN; *ctx.std_out += &format!( - "[CHECKPOINT {}] new CPU cycles: {}, new runtime memory: {} ({:.1}% vec)\n", + "[CHECKPOINT {}] new CPU cycles: {}, new runtime memory: {}\n", values[1], pretty_integer(ctx.cpu_cycles - *ctx.last_checkpoint_cpu_cycles), - pretty_integer(new_no_vec_memory + new_vec_memory), - new_vec_memory as f64 / (new_no_vec_memory + new_vec_memory) as f64 * 100.0 + pretty_integer(new_no_vec_memory), ); } *ctx.last_checkpoint_cpu_cycles = ctx.cpu_cycles; *ctx.checkpoint_ap = *ctx.ap; - *ctx.checkpoint_ap_vec = *ctx.ap_vec; } let line_info = line_info.replace(';', ""); @@ -262,10 +246,6 @@ impl Hint { .push(*ctx.cpu_cycles_before_new_line); *ctx.cpu_cycles_before_new_line = 0; } - Self::PrivateInputStart { res_offset } => { - ctx.memory - .set(ctx.fp + *res_offset, F::from_usize(ctx.private_input_start))?; - } Self::Label { .. } => {} Self::StackFrame { label, size } => { if ctx.profiling { @@ -286,6 +266,7 @@ impl Hint { Boolean::Equal => left == right, Boolean::Different => left != right, Boolean::LessThan => left < right, + Boolean::LessOrEqual => left <= right, }; if !condition_holds { return Err(RunnerError::DebugAssertFailed( @@ -294,6 +275,20 @@ impl Hint { )); } } + Self::DerefHint { + offset_src, + offset_target, + } => { + // Record a deref constraint: memory[target_addr] = memory[memory[src_addr]] + let src_addr = ctx.fp + offset_src; + let target_addr = ctx.fp + offset_target; + ctx.pending_deref_hints.push((target_addr, src_addr)); + } + Self::Panic { message } => { + if let Some(msg) = message { + *ctx.std_out += &format!("[PANIC] {}\n", msg); + } + } } Ok(()) } @@ -306,17 +301,8 @@ impl Display for Hint { function_name: _, offset, size, - vectorized, - vectorized_len, } => { - if *vectorized { - write!(f, "m[fp + {offset}] = request_memory_vec({size}, {vectorized_len})") - } else { - write!(f, "m[fp + {offset}] = request_memory({size})") - } - } - Self::PrivateInputStart { res_offset } => { - write!(f, "m[fp + {res_offset}] = private_input_start()") + write!(f, "m[fp + {offset}] = request_memory({size})") } Self::Custom(hint, args) => { let decomposed = &args[0]; @@ -356,9 +342,19 @@ impl Display for Hint { Self::StackFrame { label, size } => { write!(f, "stack frame for {label} size {size}") } - Self::DebugAssert(bool_expr, line_number) => { - write!(f, "debug_assert {bool_expr} at line {line_number}") + Self::DebugAssert(bool_expr, location) => { + write!(f, "debug_assert {bool_expr} at {location:?}") + } + Self::DerefHint { + offset_src, + offset_target, + } => { + write!(f, "m[fp + {offset_target}] = m[m[fp + {offset_src}]]") } + Self::Panic { message } => match message { + Some(msg) => write!(f, "panic: \"{msg}\""), + None => write!(f, "panic"), + }, } } } @@ -375,6 +371,7 @@ impl Display for Boolean { Self::Equal => write!(f, "=="), Self::Different => write!(f, "!="), Self::LessThan => write!(f, "<"), + Self::LessOrEqual => write!(f, "<="), } } } diff --git a/crates/lean_vm/src/isa/instruction.rs b/crates/lean_vm/src/isa/instruction.rs index 87629925..135fa767 100644 --- a/crates/lean_vm/src/isa/instruction.rs +++ b/crates/lean_vm/src/isa/instruction.rs @@ -8,7 +8,7 @@ use crate::execution::Memory; use crate::tables::TableT; use crate::{Table, TableTrace}; use multilinear_toolkit::prelude::*; -use std::collections::{BTreeMap, VecDeque}; +use std::collections::BTreeMap; use std::fmt::{Display, Formatter}; use utils::ToUsize; @@ -54,7 +54,8 @@ pub enum Instruction { arg_a: MemOrConstant, arg_b: MemOrConstant, arg_c: MemOrFp, - aux: usize, + aux_1: usize, + aux_2: usize, }, } @@ -71,10 +72,7 @@ pub struct InstructionContext<'a> { pub deref_counts: &'a mut usize, pub jump_counts: &'a mut usize, pub poseidon16_precomputed: &'a [([F; 16], [F; 16])], - pub poseidon24_precomputed: &'a [([F; 24], [F; 8])], - pub merkle_path_hints: &'a mut VecDeque>, pub n_poseidon16_precomputed_used: &'a mut usize, - pub n_poseidon24_precomputed_used: &'a mut usize, } impl Instruction { @@ -132,10 +130,13 @@ impl Instruction { if res.is_value_unknown(ctx.memory, *ctx.fp) { let memory_address_res = res.memory_address(*ctx.fp)?; let ptr = ctx.memory.get(*ctx.fp + shift_0)?; - let value = ctx.memory.get(ptr.to_usize() + shift_1)?; - ctx.memory.set(memory_address_res, value)?; + if let Ok(value) = ctx.memory.get(ptr.to_usize() + shift_1) { + ctx.memory.set(memory_address_res, value)?; + } else { + // Do nothing, we are probably in a range check, will be resolved later + } } else { - let value = res.read_value(ctx.memory, *ctx.fp)?; + let value = res.read_value(ctx.memory, *ctx.fp).unwrap(); let ptr = ctx.memory.get(*ctx.fp + shift_0)?; ctx.memory.set(ptr.to_usize() + shift_1, value)?; } @@ -168,13 +169,15 @@ impl Instruction { arg_a, arg_b, arg_c, - aux: size, + aux_1, + aux_2, } => { table.execute( arg_a.read_value(ctx.memory, *ctx.fp)?, arg_b.read_value(ctx.memory, *ctx.fp)?, arg_c.read_value(ctx.memory, *ctx.fp)?, - *size, + *aux_1, + *aux_2, ctx, )?; @@ -215,9 +218,10 @@ impl Display for Instruction { arg_a, arg_b, arg_c, - aux, + aux_1, + aux_2, } => { - write!(f, "{}({arg_a}, {arg_b}, {arg_c}, {aux})", table.name()) + write!(f, "{}({arg_a}, {arg_b}, {arg_c}, {aux_1}, {aux_2})", table.name()) } } } diff --git a/crates/lean_vm/src/tables/dot_product/air.rs b/crates/lean_vm/src/tables/dot_product/air.rs index 04b354bf..0cdcdbc6 100644 --- a/crates/lean_vm/src/tables/dot_product/air.rs +++ b/crates/lean_vm/src/tables/dot_product/air.rs @@ -2,12 +2,11 @@ use crate::{ DIMENSION, EF, ExtraDataForBuses, TableT, eval_virtual_bus_column, tables::dot_product::DotProductPrecompile, }; use multilinear_toolkit::prelude::*; -use p3_air::{Air, AirBuilder}; /* (DIMENSION = 5) -| Flag | Len | IndexA | IndexB | IndexRes | ValueA | ValueB | Res | Computation | +| Start | Len | IndexA | IndexB | IndexRes | ValueA | ValueB | Res | Computation | | --------- | --- | ------ | ------ | -------- | ------------- | ----------- | ------------------ | ---------------------------------------- | | 1 | 4 | 90 | 211 | 74 | m[90] | m[211..216] | m[74..79] = r3 | r3 = m[90] x m[211..216] + r2 | | 0 | 3 | 91 | 216 | 74 | m[91] | m[216..221] | m[74..79] | r2 = m[91] x m[216..221] + r1 | @@ -21,49 +20,40 @@ use p3_air::{Air, AirBuilder}; */ // F columns -pub(super) const COL_FLAG: usize = 0; -pub(super) const COL_LEN: usize = 1; -pub(super) const COL_INDEX_A: usize = 2; -pub(super) const COL_INDEX_B: usize = 3; -pub(super) const COL_INDEX_RES: usize = 4; - +pub(super) const DOT_COL_IS_BE: usize = 0; +pub(super) const DOT_COL_FLAG: usize = 1; +pub(super) const DOT_COL_START: usize = 2; +pub(super) const DOT_COL_LEN: usize = 3; +pub const DOT_COL_A: usize = 4; +pub(super) const DOT_COL_B: usize = 5; +pub(super) const DOT_COL_RES: usize = 6; +pub const DOT_COL_VALUE_A_F: usize = 7; // EF columns -pub(super) const COL_VALUE_B: usize = 0; -pub(super) const COL_VALUE_RES: usize = 1; -pub(super) const COL_COMPUTATION: usize = 2; - -pub(super) const fn dot_product_air_col_value_a(be: bool) -> usize { - if be { 5 } else { 3 } -} - -pub(super) const fn dot_product_air_n_cols_f(be: bool) -> usize { - if be { 6 } else { 5 } -} - -pub(super) const fn dot_product_air_n_cols_ef(be: bool) -> usize { - if be { 3 } else { 4 } -} +pub const DOT_COL_VALUE_A_EF: usize = 0; +pub(super) const DOT_COL_VALUE_B: usize = 1; +pub(super) const DOT_COL_VALUE_RES: usize = 2; +pub(super) const DOT_COL_COMPUTATION: usize = 3; -impl Air for DotProductPrecompile { +impl Air for DotProductPrecompile { type ExtraData = ExtraDataForBuses; fn n_columns_f_air(&self) -> usize { - dot_product_air_n_cols_f(BE) + 8 } fn n_columns_ef_air(&self) -> usize { - dot_product_air_n_cols_ef(BE) + 4 } - fn degree(&self) -> usize { - 2 + fn degree_air(&self) -> usize { + 3 } fn n_constraints(&self) -> usize { - 8 + 15 // TODO: update } fn down_column_indexes_f(&self) -> Vec { - vec![COL_FLAG, COL_LEN, COL_INDEX_A, COL_INDEX_B] + vec![DOT_COL_START, DOT_COL_IS_BE, DOT_COL_LEN, DOT_COL_A, DOT_COL_B] } fn down_column_indexes_ef(&self) -> Vec { - vec![COL_COMPUTATION] + vec![DOT_COL_COMPUTATION] } #[inline] @@ -73,49 +63,74 @@ impl Air for DotProductPrecompile { let down_f = builder.down_f(); let down_ef = builder.down_ef(); - let flag = up_f[COL_FLAG].clone(); - let len = up_f[COL_LEN].clone(); - let index_a = up_f[COL_INDEX_A].clone(); - let index_b = up_f[COL_INDEX_B].clone(); - let index_res = up_f[COL_INDEX_RES].clone(); - let value_a = if BE { - AB::EF::from(up_f[dot_product_air_col_value_a(BE)].clone()) // TODO embdding overhead - } else { - up_ef[dot_product_air_col_value_a(BE)].clone() - }; - - let value_b = up_ef[COL_VALUE_B].clone(); - let res = up_ef[COL_VALUE_RES].clone(); - let computation = up_ef[COL_COMPUTATION].clone(); - - let flag_down = down_f[0].clone(); - let len_down = down_f[1].clone(); - let index_a_down = down_f[2].clone(); - let index_b_down = down_f[3].clone(); + let is_be = up_f[DOT_COL_IS_BE].clone(); + let flag = up_f[DOT_COL_FLAG].clone(); + let start = up_f[DOT_COL_START].clone(); + let len = up_f[DOT_COL_LEN].clone(); + let index_a = up_f[DOT_COL_A].clone(); + let index_b = up_f[DOT_COL_B].clone(); + let index_res = up_f[DOT_COL_RES].clone(); + let value_a_f = up_f[DOT_COL_VALUE_A_F].clone(); + + let value_a_ef = up_ef[DOT_COL_VALUE_A_EF].clone(); + let value_b = up_ef[DOT_COL_VALUE_B].clone(); + let res = up_ef[DOT_COL_VALUE_RES].clone(); + let computation = up_ef[DOT_COL_COMPUTATION].clone(); + + let start_down = down_f[0].clone(); + let is_be_down = down_f[1].clone(); + let len_down = down_f[2].clone(); + let index_a_down = down_f[3].clone(); + let index_b_down = down_f[4].clone(); let computation_down = down_ef[0].clone(); - builder.eval_virtual_column(eval_virtual_bus_column::( - extra_data, - AB::F::from_usize(self.identifier().index()), - flag.clone(), - &[index_a.clone(), index_b.clone(), index_res.clone(), len.clone()], - )); - + if BUS { + builder.eval_virtual_column(eval_virtual_bus_column::( + extra_data, + AB::F::from_usize(self.table().index()), + flag.clone(), + &[ + index_a.clone(), + index_b.clone(), + index_res.clone(), + len.clone(), + is_be.clone(), + ], + )); + } else { + builder.declare_values(&[ + index_a.clone(), + index_b.clone(), + index_res.clone(), + len.clone(), + is_be.clone(), + ]); + } + + let is_ee = AB::F::ONE - is_be.clone(); + + builder.assert_bool(start.clone()); builder.assert_bool(flag.clone()); + builder.assert_zero(flag.clone() * (AB::F::ONE - start.clone())); + builder.assert_bool(is_be.clone()); + + let mode_switch = (is_be_down.clone() - is_be.clone()).square(); + builder.assert_zero(mode_switch.clone() * (AB::F::ONE - start_down.clone())); + let value_a = AB::EF::from(is_be.clone() * value_a_f.clone()) + value_a_ef.clone() * is_ee.clone(); let product_up = value_b * value_a; - let not_flag_down = AB::F::ONE - flag_down.clone(); + let not_flag_down = AB::F::ONE - start_down.clone(); builder.assert_eq_ef( computation.clone(), product_up.clone() + computation_down * not_flag_down.clone(), ); builder.assert_zero(not_flag_down.clone() * (len.clone() - (len_down + AB::F::ONE))); - builder.assert_zero(flag_down * (len - AB::F::ONE)); - let index_a_increment = AB::F::from_usize(if BE { 1 } else { DIMENSION }); + builder.assert_zero(start_down * (len - AB::F::ONE)); + let index_a_increment = AB::F::from_usize(DIMENSION) * is_ee.clone() + is_be.clone(); builder.assert_zero(not_flag_down.clone() * (index_a - (index_a_down - index_a_increment))); builder.assert_zero(not_flag_down * (index_b - (index_b_down - AB::F::from_usize(DIMENSION)))); - builder.assert_zero_ef((computation - res) * flag); + builder.assert_zero_ef((computation - res) * start); } } diff --git a/crates/lean_vm/src/tables/dot_product/exec.rs b/crates/lean_vm/src/tables/dot_product/exec.rs index 4b25d40a..2333c72b 100644 --- a/crates/lean_vm/src/tables/dot_product/exec.rs +++ b/crates/lean_vm/src/tables/dot_product/exec.rs @@ -16,7 +16,7 @@ pub(super) fn exec_dot_product_be( ) -> Result<(), RunnerError> { assert!(size > 0); - let slice_0 = memory.slice(ptr_arg_0.to_usize(), size)?; + let slice_0 = memory.get_slice(ptr_arg_0.to_usize(), size)?; let slice_1 = memory.get_continuous_slice_of_ef_elements(ptr_arg_1.to_usize(), size)?; let dot_product_result = dot_product::(slice_1.iter().copied(), slice_0.iter().copied()); @@ -24,7 +24,7 @@ pub(super) fn exec_dot_product_be( { { - let computation = &mut trace.ext[COL_COMPUTATION]; + let computation = &mut trace.ext[DOT_COL_COMPUTATION]; computation.extend(EF::zero_vec(size)); let new_size = computation.len(); computation[new_size - 1] = slice_1[size - 1] * slice_0[size - 1]; @@ -34,15 +34,19 @@ pub(super) fn exec_dot_product_be( } } - trace.base[COL_FLAG].push(F::ONE); - trace.base[COL_FLAG].extend(F::zero_vec(size - 1)); - trace.base[COL_LEN].extend(((1..=size).rev()).map(F::from_usize)); - trace.base[COL_INDEX_A].extend((0..size).map(|i| F::from_usize(ptr_arg_0.to_usize() + i))); - trace.base[COL_INDEX_B].extend((0..size).map(|i| F::from_usize(ptr_arg_1.to_usize() + i * DIMENSION))); - trace.base[COL_INDEX_RES].extend(vec![F::from_usize(ptr_res.to_usize()); size]); - trace.base[dot_product_air_col_value_a(true)].extend(slice_0); - trace.ext[COL_VALUE_B].extend(slice_1); - trace.ext[COL_VALUE_RES].extend(vec![dot_product_result; size]); + trace.base[DOT_COL_IS_BE].extend(std::iter::repeat_n(F::from_bool(true), size)); + trace.base[DOT_COL_FLAG].push(F::ONE); + trace.base[DOT_COL_FLAG].extend(F::zero_vec(size - 1)); + trace.base[DOT_COL_START].push(F::ONE); + trace.base[DOT_COL_START].extend(F::zero_vec(size - 1)); + trace.base[DOT_COL_LEN].extend(((1..=size).rev()).map(F::from_usize)); + trace.base[DOT_COL_A].extend((0..size).map(|i| F::from_usize(ptr_arg_0.to_usize() + i))); + trace.base[DOT_COL_B].extend((0..size).map(|i| F::from_usize(ptr_arg_1.to_usize() + i * DIMENSION))); + trace.base[DOT_COL_RES].extend(vec![F::from_usize(ptr_res.to_usize()); size]); + trace.ext[DOT_COL_VALUE_B].extend(slice_1); + trace.ext[DOT_COL_VALUE_RES].extend(vec![dot_product_result; size]); + + // trace.base[COL_VALUE_A_F] and trace.ext[COL_VALUE_A_EF] are filled later } Ok(()) @@ -58,24 +62,55 @@ pub(super) fn exec_dot_product_ee( ) -> Result<(), RunnerError> { assert!(size > 0); - let slice_0 = memory.get_continuous_slice_of_ef_elements(ptr_arg_0.to_usize(), size)?; - - let (slice_1, dot_product_result) = if ptr_arg_1.to_usize() == ONE_VEC_PTR * VECTOR_LEN { + let (slice_0, slice_1, dot_product_result) = if ptr_arg_1.to_usize() == ONE_VEC_PTR { if size != 1 { unimplemented!("weird use case"); } - (vec![EF::ONE], slice_0[0]) + if ptr_res.to_usize() == ZERO_VEC_PTR { + memory.set_ef_element(ptr_arg_0.to_usize(), EF::ZERO)?; + (vec![EF::ZERO], vec![EF::ONE], EF::ZERO) + } else { + let slice_0 = memory.get_continuous_slice_of_ef_elements(ptr_arg_0.to_usize(), size)?; + let res = slice_0[0]; + memory.set_ef_element(ptr_res.to_usize(), res)?; + (slice_0, vec![EF::ONE], res) + } } else { - let slice_1 = memory.get_continuous_slice_of_ef_elements(ptr_arg_1.to_usize(), size)?; - let dot_product_result = dot_product::(slice_1.iter().copied(), slice_0.iter().copied()); - (slice_1, dot_product_result) + match ( + memory.get_continuous_slice_of_ef_elements(ptr_arg_0.to_usize(), size), + memory.get_continuous_slice_of_ef_elements(ptr_arg_1.to_usize(), size), + memory.get_ef_element(ptr_res.to_usize()), + ) { + (Ok(s0), Ok(s1), Ok(res)) => { + if dot_product::(s0.iter().copied(), s1.iter().copied()) != res { + return Err(RunnerError::InvalidDotProduct); + } + (s0, s1, res) + } + (Ok(s0), Ok(s1), Err(_)) => { + let dot_product_result = dot_product::(s0.iter().copied(), s1.iter().copied()); + memory.set_ef_element(ptr_res.to_usize(), dot_product_result)?; + (s0, s1, dot_product_result) + } + (Err(_), Ok(s1), Ok(res)) if size == 1 => { + let div = res / s1[0]; + memory.set_ef_element(ptr_arg_0.to_usize(), div)?; + (vec![div], s1, res) + } + (Ok(s0), Err(_), Ok(res)) if size == 1 => { + let div = res / s0[0]; + memory.set_ef_element(ptr_arg_1.to_usize(), div)?; + (s0, vec![div], res) + } + _ => { + return Err(RunnerError::InvalidDotProduct); + } + } }; - memory.set_ef_element(ptr_res.to_usize(), dot_product_result)?; - { { - let computation = &mut trace.ext[COL_COMPUTATION]; + let computation = &mut trace.ext[DOT_COL_COMPUTATION]; computation.extend(EF::zero_vec(size)); let new_size = computation.len(); computation[new_size - 1] = slice_1[size - 1] * slice_0[size - 1]; @@ -85,16 +120,35 @@ pub(super) fn exec_dot_product_ee( } } - trace.base[COL_FLAG].push(F::ONE); - trace.base[COL_FLAG].extend(F::zero_vec(size - 1)); - trace.base[COL_LEN].extend(((1..=size).rev()).map(F::from_usize)); - trace.base[COL_INDEX_A].extend((0..size).map(|i| F::from_usize(ptr_arg_0.to_usize() + i * DIMENSION))); - trace.base[COL_INDEX_B].extend((0..size).map(|i| F::from_usize(ptr_arg_1.to_usize() + i * DIMENSION))); - trace.base[COL_INDEX_RES].extend(vec![F::from_usize(ptr_res.to_usize()); size]); - trace.ext[dot_product_air_col_value_a(false)].extend(slice_0); - trace.ext[COL_VALUE_B].extend(slice_1); - trace.ext[COL_VALUE_RES].extend(vec![dot_product_result; size]); + trace.base[DOT_COL_IS_BE].extend(std::iter::repeat_n(F::from_bool(false), size)); + trace.base[DOT_COL_FLAG].push(F::ONE); + trace.base[DOT_COL_FLAG].extend(F::zero_vec(size - 1)); + trace.base[DOT_COL_START].push(F::ONE); + trace.base[DOT_COL_START].extend(F::zero_vec(size - 1)); + trace.base[DOT_COL_LEN].extend(((1..=size).rev()).map(F::from_usize)); + trace.base[DOT_COL_A].extend((0..size).map(|i| F::from_usize(ptr_arg_0.to_usize() + i * DIMENSION))); + trace.base[DOT_COL_B].extend((0..size).map(|i| F::from_usize(ptr_arg_1.to_usize() + i * DIMENSION))); + trace.base[DOT_COL_RES].extend(vec![F::from_usize(ptr_res.to_usize()); size]); + trace.ext[DOT_COL_VALUE_B].extend(slice_1); + trace.ext[DOT_COL_VALUE_RES].extend(vec![dot_product_result; size]); + + // trace.base[COL_VALUE_A_F] and trace.ext[COL_VALUE_A_EF] are filled later } Ok(()) } + +pub fn fill_trace_dot_product(trace: &mut TableTrace, memory: &[F]) { + assert!(trace.base[DOT_COL_VALUE_A_F].is_empty()); + assert!(trace.ext[DOT_COL_VALUE_A_EF].is_empty()); + trace.base[DOT_COL_VALUE_A_F] = F::zero_vec(trace.base[DOT_COL_A].len()); + trace.ext[DOT_COL_VALUE_A_EF] = EF::zero_vec(trace.base[DOT_COL_A].len()); + for i in 0..trace.base[DOT_COL_A].len() { + // TODO parallelize + let addr = trace.base[DOT_COL_A][i].to_usize(); + let value_f = memory[addr]; + let value_ef = EF::from_basis_coefficients_slice(&memory[addr..][..DIMENSION]).unwrap(); + trace.base[DOT_COL_VALUE_A_F][i] = value_f; + trace.ext[DOT_COL_VALUE_A_EF][i] = value_ef; + } +} diff --git a/crates/lean_vm/src/tables/dot_product/mod.rs b/crates/lean_vm/src/tables/dot_product/mod.rs index 982e9aa0..0c3a4d8e 100644 --- a/crates/lean_vm/src/tables/dot_product/mod.rs +++ b/crates/lean_vm/src/tables/dot_product/mod.rs @@ -7,92 +7,69 @@ use multilinear_toolkit::prelude::*; mod air; use air::*; mod exec; +pub use exec::fill_trace_dot_product; /// Dot product between 2 vectors in the extension field EF. #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct DotProductPrecompile; // BE = true for base-extension, false for extension-extension +pub struct DotProductPrecompile; // BE = true for base-extension, false for extension-extension -impl TableT for DotProductPrecompile { +impl TableT for DotProductPrecompile { fn name(&self) -> &'static str { - if BE { "dot_product_be" } else { "dot_product_ee" } + "dot_product" } - fn identifier(&self) -> Table { - if BE { - Table::dot_product_be() - } else { - Table::dot_product_ee() - } + fn table(&self) -> Table { + Table::dot_product() } - fn commited_columns_f(&self) -> Vec { - vec![COL_FLAG, COL_LEN, COL_INDEX_A, COL_INDEX_B, COL_INDEX_RES] - } - - fn commited_columns_ef(&self) -> Vec { - vec![COL_COMPUTATION] - } - - fn normal_lookups_f(&self) -> Vec { - if BE { - vec![LookupIntoMemory { - index: COL_INDEX_A, - values: dot_product_air_col_value_a(BE), - }] - } else { - vec![] - } + fn lookups_f(&self) -> Vec { + vec![LookupIntoMemory { + index: DOT_COL_A, + values: vec![DOT_COL_VALUE_A_F], + }] } - fn normal_lookups_ef(&self) -> Vec { - let mut res = vec![ + fn lookups_ef(&self) -> Vec { + vec![ ExtensionFieldLookupIntoMemory { - index: COL_INDEX_B, - values: COL_VALUE_B, + index: DOT_COL_A, + values: DOT_COL_VALUE_A_EF, }, ExtensionFieldLookupIntoMemory { - index: COL_INDEX_RES, - values: COL_VALUE_RES, + index: DOT_COL_B, + values: DOT_COL_VALUE_B, }, - ]; - if !BE { - res.insert( - 0, - ExtensionFieldLookupIntoMemory { - index: COL_INDEX_A, - values: dot_product_air_col_value_a(BE), - }, - ); - } - res - } - - fn vector_lookups(&self) -> Vec { - vec![] + ExtensionFieldLookupIntoMemory { + index: DOT_COL_RES, + values: DOT_COL_VALUE_RES, + }, + ] } - fn buses(&self) -> Vec { - vec![Bus { - table: BusTable::Constant(self.identifier()), + fn bus(&self) -> Bus { + Bus { + table: BusTable::Constant(self.table()), direction: BusDirection::Pull, - selector: BusSelector::Column(COL_FLAG), - data: vec![COL_INDEX_A, COL_INDEX_B, COL_INDEX_RES, COL_LEN], - }] + selector: DOT_COL_FLAG, + data: vec![DOT_COL_A, DOT_COL_B, DOT_COL_RES, DOT_COL_LEN, DOT_COL_IS_BE], + } } fn padding_row_f(&self) -> Vec { [ vec![ - F::ONE, // StartFlag - F::ONE, // Len + F::ZERO, // Is BE + F::ZERO, // Flag + F::ONE, // Start + F::ONE, // Len ], - vec![F::ZERO; dot_product_air_n_cols_f(BE) - 2], + vec![F::ZERO; self.n_columns_f_air() - 4], ] .concat() } fn padding_row_ef(&self) -> Vec { - vec![EF::ZERO; dot_product_air_n_cols_ef(BE)] + vec![EF::ZERO; self.n_columns_ef_air()] } #[inline(always)] @@ -101,14 +78,17 @@ impl TableT for DotProductPrecompile { arg_a: F, arg_b: F, arg_c: F, - aux: usize, + size: usize, + is_be: usize, ctx: &mut InstructionContext<'_>, ) -> Result<(), RunnerError> { - let trace = ctx.traces.get_mut(&self.identifier()).unwrap(); - if BE { - exec_dot_product_be(arg_a, arg_b, arg_c, aux, ctx.memory, trace) + assert!(is_be == 0 || is_be == 1); + let is_be = is_be == 1; + let trace = ctx.traces.get_mut(&self.table()).unwrap(); + if is_be { + exec_dot_product_be(arg_a, arg_b, arg_c, size, ctx.memory, trace) } else { - exec_dot_product_ee(arg_a, arg_b, arg_c, aux, ctx.memory, trace) + exec_dot_product_ee(arg_a, arg_b, arg_c, size, ctx.memory, trace) } } } diff --git a/crates/lean_vm/src/tables/eq_poly_base_ext/air.rs b/crates/lean_vm/src/tables/eq_poly_base_ext/air.rs deleted file mode 100644 index b41760f9..00000000 --- a/crates/lean_vm/src/tables/eq_poly_base_ext/air.rs +++ /dev/null @@ -1,95 +0,0 @@ -use crate::{ - DIMENSION, EF, ExtraDataForBuses, TableT, eval_virtual_bus_column, - tables::eq_poly_base_ext::EqPolyBaseExtPrecompile, -}; -use multilinear_toolkit::prelude::*; -use p3_air::{Air, AirBuilder}; - -// F columns -pub(super) const COL_FLAG: usize = 0; -pub(super) const COL_LEN: usize = 1; -pub(super) const COL_INDEX_A: usize = 2; -pub(super) const COL_INDEX_B: usize = 3; -pub(super) const COL_INDEX_RES: usize = 4; -pub(super) const COL_VALUE_A: usize = 5; - -// EF columns -pub(super) const COL_VALUE_B: usize = 0; -pub(super) const COL_VALUE_RES: usize = 1; -pub(super) const COL_COMPUTATION: usize = 2; - -pub(super) const N_COLS_F: usize = 6; -pub(super) const N_COLS_EF: usize = 3; - -impl Air for EqPolyBaseExtPrecompile { - type ExtraData = ExtraDataForBuses; - - fn n_columns_f_air(&self) -> usize { - N_COLS_F - } - fn n_columns_ef_air(&self) -> usize { - N_COLS_EF - } - fn degree(&self) -> usize { - 4 - } - fn n_constraints(&self) -> usize { - 8 - } - fn down_column_indexes_f(&self) -> Vec { - vec![COL_FLAG, COL_LEN, COL_INDEX_A, COL_INDEX_B] - } - fn down_column_indexes_ef(&self) -> Vec { - vec![COL_COMPUTATION] - } - - #[inline] - fn eval(&self, builder: &mut AB, extra_data: &Self::ExtraData) { - let up_f = builder.up_f(); - let up_ef = builder.up_ef(); - let down_f = builder.down_f(); - let down_ef = builder.down_ef(); - - let flag = up_f[COL_FLAG].clone(); - let len = up_f[COL_LEN].clone(); - let index_a = up_f[COL_INDEX_A].clone(); - let index_b = up_f[COL_INDEX_B].clone(); - let index_res = up_f[COL_INDEX_RES].clone(); - let value_a = up_f[COL_VALUE_A].clone(); - - let value_b = up_ef[COL_VALUE_B].clone(); - let res = up_ef[COL_VALUE_RES].clone(); - let computation = up_ef[COL_COMPUTATION].clone(); - - let flag_down = down_f[0].clone(); - let len_down = down_f[1].clone(); - let index_a_down = down_f[2].clone(); - let index_b_down = down_f[3].clone(); - - let computation_down = down_ef[0].clone(); - - builder.eval_virtual_column(eval_virtual_bus_column::( - extra_data, - AB::F::from_usize(self.identifier().index()), - flag.clone(), - &[index_a.clone(), index_b.clone(), index_res.clone(), len.clone()], - )); - - builder.assert_bool(flag.clone()); - - let product_up = - value_b.clone() * value_a.clone() + (AB::EF::ONE - value_b.clone()) * (AB::F::ONE - value_a.clone()); - let not_flag_down = AB::F::ONE - flag_down.clone(); - builder.assert_eq_ef( - computation.clone(), - product_up.clone() * (computation_down * not_flag_down.clone() + flag_down.clone()), - ); - builder.assert_zero(not_flag_down.clone() * (len.clone() - (len_down + AB::F::ONE))); - builder.assert_zero(flag_down * (len - AB::F::ONE)); - let index_a_increment = AB::F::ONE; - builder.assert_zero(not_flag_down.clone() * (index_a - (index_a_down - index_a_increment))); - builder.assert_zero(not_flag_down * (index_b - (index_b_down - AB::F::from_usize(DIMENSION)))); - - builder.assert_zero_ef((computation - res) * flag); - } -} diff --git a/crates/lean_vm/src/tables/eq_poly_base_ext/exec.rs b/crates/lean_vm/src/tables/eq_poly_base_ext/exec.rs deleted file mode 100644 index 898588e8..00000000 --- a/crates/lean_vm/src/tables/eq_poly_base_ext/exec.rs +++ /dev/null @@ -1,46 +0,0 @@ -use crate::EF; -use crate::F; -use crate::Memory; -use crate::RunnerError; -use crate::TableTrace; -use crate::tables::eq_poly_base_ext::*; -use utils::ToUsize; - -pub(super) fn exec_eq_poly_base_ext( - ptr_arg_0: F, - ptr_arg_1: F, - ptr_res: F, - size: usize, - memory: &mut Memory, - trace: &mut TableTrace, -) -> Result<(), RunnerError> { - assert!(size > 0); - - let slice_0 = memory.slice(ptr_arg_0.to_usize(), size)?; - let slice_1 = memory.get_continuous_slice_of_ef_elements(ptr_arg_1.to_usize(), size)?; - - let computation = &mut trace.ext[COL_COMPUTATION]; - computation.extend(EF::zero_vec(size)); - let new_size = computation.len(); - computation[new_size - 1] = - slice_1[size - 1] * slice_0[size - 1] + (EF::ONE - slice_1[size - 1]) * (F::ONE - slice_0[size - 1]); - for i in 0..size - 1 { - computation[new_size - 2 - i] = computation[new_size - 1 - i] - * (slice_1[size - 2 - i] * slice_0[size - 2 - i] - + (EF::ONE - slice_1[size - 2 - i]) * (F::ONE - slice_0[size - 2 - i])); - } - let final_result = computation[new_size - size]; - memory.set_ef_element(ptr_res.to_usize(), final_result)?; - - trace.base[COL_FLAG].push(F::ONE); - trace.base[COL_FLAG].extend(F::zero_vec(size - 1)); - trace.base[COL_LEN].extend(((1..=size).rev()).map(F::from_usize)); - trace.base[COL_INDEX_A].extend((0..size).map(|i| F::from_usize(ptr_arg_0.to_usize() + i))); - trace.base[COL_INDEX_B].extend((0..size).map(|i| F::from_usize(ptr_arg_1.to_usize() + i * DIMENSION))); - trace.base[COL_INDEX_RES].extend(vec![F::from_usize(ptr_res.to_usize()); size]); - trace.base[COL_VALUE_A].extend(slice_0); - trace.ext[COL_VALUE_B].extend(slice_1); - trace.ext[COL_VALUE_RES].extend(vec![final_result; size]); - - Ok(()) -} diff --git a/crates/lean_vm/src/tables/eq_poly_base_ext/mod.rs b/crates/lean_vm/src/tables/eq_poly_base_ext/mod.rs deleted file mode 100644 index d58ac8f9..00000000 --- a/crates/lean_vm/src/tables/eq_poly_base_ext/mod.rs +++ /dev/null @@ -1,93 +0,0 @@ -use crate::{InstructionContext, tables::eq_poly_base_ext::exec::exec_eq_poly_base_ext, *}; -use multilinear_toolkit::prelude::*; - -mod air; -use air::*; -mod exec; - -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct EqPolyBaseExtPrecompile; - -impl TableT for EqPolyBaseExtPrecompile { - fn name(&self) -> &'static str { - "eq_poly_base_ext" - } - - fn identifier(&self) -> Table { - Table::eq_poly_base_ext() - } - - fn commited_columns_f(&self) -> Vec { - vec![COL_FLAG, COL_LEN, COL_INDEX_A, COL_INDEX_B, COL_INDEX_RES] - } - - fn commited_columns_ef(&self) -> Vec { - vec![COL_COMPUTATION] - } - - fn normal_lookups_f(&self) -> Vec { - vec![LookupIntoMemory { - index: COL_INDEX_A, - values: COL_VALUE_A, - }] - } - - fn normal_lookups_ef(&self) -> Vec { - vec![ - ExtensionFieldLookupIntoMemory { - index: COL_INDEX_B, - values: COL_VALUE_B, - }, - ExtensionFieldLookupIntoMemory { - index: COL_INDEX_RES, - values: COL_VALUE_RES, - }, - ] - } - - fn vector_lookups(&self) -> Vec { - vec![] - } - - fn buses(&self) -> Vec { - vec![Bus { - table: BusTable::Constant(self.identifier()), - direction: BusDirection::Pull, - selector: BusSelector::Column(COL_FLAG), - data: vec![COL_INDEX_A, COL_INDEX_B, COL_INDEX_RES, COL_LEN], - }] - } - - fn padding_row_f(&self) -> Vec { - [vec![ - F::ONE, // StartFlag - F::ONE, // Len - F::ZERO, // Index A - F::ZERO, // Index B - F::from_usize(ONE_VEC_PTR * VECTOR_LEN), // Index Res - F::ZERO, // Value A - ]] - .concat() - } - - fn padding_row_ef(&self) -> Vec { - vec![ - EF::ZERO, // Value B - EF::ONE, // Value Res - EF::ONE, // Computation - ] - } - - #[inline(always)] - fn execute( - &self, - arg_a: F, - arg_b: F, - arg_c: F, - aux: usize, - ctx: &mut InstructionContext<'_>, - ) -> Result<(), RunnerError> { - let trace = ctx.traces.get_mut(&self.identifier()).unwrap(); - exec_eq_poly_base_ext(arg_a, arg_b, arg_c, aux, ctx.memory, trace) - } -} diff --git a/crates/lean_vm/src/tables/execution/air.rs b/crates/lean_vm/src/tables/execution/air.rs index a4136907..8d843f27 100644 --- a/crates/lean_vm/src/tables/execution/air.rs +++ b/crates/lean_vm/src/tables/execution/air.rs @@ -1,45 +1,44 @@ use multilinear_toolkit::prelude::*; -use p3_air::{Air, AirBuilder}; use crate::{EF, ExecutionTable, ExtraDataForBuses, eval_virtual_bus_column}; -pub const N_INSTRUCTION_COLUMNS: usize = 13; -pub const N_COMMITTED_EXEC_COLUMNS: usize = 5; -pub const N_MEMORY_VALUE_COLUMNS: usize = 3; // virtual (lookup into memory, with logup*) -pub const N_EXEC_AIR_COLUMNS: usize = N_INSTRUCTION_COLUMNS + N_COMMITTED_EXEC_COLUMNS + N_MEMORY_VALUE_COLUMNS; - -// Instruction columns -pub const COL_INDEX_OPERAND_A: usize = 0; -pub const COL_INDEX_OPERAND_B: usize = 1; -pub const COL_INDEX_OPERAND_C: usize = 2; -pub const COL_INDEX_FLAG_A: usize = 3; -pub const COL_INDEX_FLAG_B: usize = 4; -pub const COL_INDEX_FLAG_C: usize = 5; -pub const COL_INDEX_ADD: usize = 6; -pub const COL_INDEX_MUL: usize = 7; -pub const COL_INDEX_DEREF: usize = 8; -pub const COL_INDEX_JUMP: usize = 9; -pub const COL_INDEX_AUX: usize = 10; -pub const COL_INDEX_IS_PRECOMPILE: usize = 11; -pub const COL_INDEX_PRECOMPILE_INDEX: usize = 12; - -// Execution columns -pub const COL_INDEX_MEM_VALUE_A: usize = 13; // virtual with logup* -pub const COL_INDEX_MEM_VALUE_B: usize = 14; // virtual with logup* -pub const COL_INDEX_MEM_VALUE_C: usize = 15; // virtual with logup* -pub const COL_INDEX_PC: usize = 16; -pub const COL_INDEX_FP: usize = 17; -pub const COL_INDEX_MEM_ADDRESS_A: usize = 18; -pub const COL_INDEX_MEM_ADDRESS_B: usize = 19; -pub const COL_INDEX_MEM_ADDRESS_C: usize = 20; +pub const N_COMMITTED_EXEC_COLUMNS: usize = 8; +pub const N_INSTRUCTION_COLUMNS: usize = 14; +pub const N_EXEC_AIR_COLUMNS: usize = N_INSTRUCTION_COLUMNS + N_COMMITTED_EXEC_COLUMNS; + +// Committed columns (IMPORTANT: they must be the first columns) +pub const COL_PC: usize = 0; +pub const COL_FP: usize = 1; +pub const COL_MEM_ADDRESS_A: usize = 2; +pub const COL_MEM_ADDRESS_B: usize = 3; +pub const COL_MEM_ADDRESS_C: usize = 4; +pub const COL_MEM_VALUE_A: usize = 5; +pub const COL_MEM_VALUE_B: usize = 6; +pub const COL_MEM_VALUE_C: usize = 7; + +// Decoded instruction columns (lookup into bytecode with logup*) +pub const COL_OPERAND_A: usize = 8; +pub const COL_OPERAND_B: usize = 9; +pub const COL_OPERAND_C: usize = 10; +pub const COL_FLAG_A: usize = 11; +pub const COL_FLAG_B: usize = 12; +pub const COL_FLAG_C: usize = 13; +pub const COL_ADD: usize = 14; +pub const COL_MUL: usize = 15; +pub const COL_DEREF: usize = 16; +pub const COL_JUMP: usize = 17; +pub const COL_AUX_1: usize = 18; +pub const COL_AUX_2: usize = 19; +pub const COL_IS_PRECOMPILE: usize = 20; +pub const COL_PRECOMPILE_INDEX: usize = 21; // Temporary columns (stored to avoid duplicate computations) pub const N_TEMPORARY_EXEC_COLUMNS: usize = 3; -pub const COL_INDEX_EXEC_NU_A: usize = 21; -pub const COL_INDEX_EXEC_NU_B: usize = 22; -pub const COL_INDEX_EXEC_NU_C: usize = 23; +pub const COL_EXEC_NU_A: usize = 22; +pub const COL_EXEC_NU_B: usize = 23; +pub const COL_EXEC_NU_C: usize = 24; -impl Air for ExecutionTable { +impl Air for ExecutionTable { type ExtraData = ExtraDataForBuses; fn n_columns_f_air(&self) -> usize { @@ -48,11 +47,11 @@ impl Air for ExecutionTable { fn n_columns_ef_air(&self) -> usize { 0 } - fn degree(&self) -> usize { + fn degree_air(&self) -> usize { 5 } fn down_column_indexes_f(&self) -> Vec { - vec![COL_INDEX_PC, COL_INDEX_FP] + vec![COL_PC, COL_FP] } fn down_column_indexes_ef(&self) -> Vec { vec![] @@ -70,34 +69,31 @@ impl Air for ExecutionTable { let next_fp = down[1].clone(); let (operand_a, operand_b, operand_c) = ( - up[COL_INDEX_OPERAND_A].clone(), - up[COL_INDEX_OPERAND_B].clone(), - up[COL_INDEX_OPERAND_C].clone(), + up[COL_OPERAND_A].clone(), + up[COL_OPERAND_B].clone(), + up[COL_OPERAND_C].clone(), ); - let (flag_a, flag_b, flag_c) = ( - up[COL_INDEX_FLAG_A].clone(), - up[COL_INDEX_FLAG_B].clone(), - up[COL_INDEX_FLAG_C].clone(), - ); - let add = up[COL_INDEX_ADD].clone(); - let mul = up[COL_INDEX_MUL].clone(); - let deref = up[COL_INDEX_DEREF].clone(); - let jump = up[COL_INDEX_JUMP].clone(); - let aux = up[COL_INDEX_AUX].clone(); - let is_precompile = up[COL_INDEX_IS_PRECOMPILE].clone(); - let precompile_index = up[COL_INDEX_PRECOMPILE_INDEX].clone(); + let (flag_a, flag_b, flag_c) = (up[COL_FLAG_A].clone(), up[COL_FLAG_B].clone(), up[COL_FLAG_C].clone()); + let add = up[COL_ADD].clone(); + let mul = up[COL_MUL].clone(); + let deref = up[COL_DEREF].clone(); + let jump = up[COL_JUMP].clone(); + let aux_1 = up[COL_AUX_1].clone(); + let aux_2 = up[COL_AUX_2].clone(); + let is_precompile = up[COL_IS_PRECOMPILE].clone(); + let precompile_index = up[COL_PRECOMPILE_INDEX].clone(); let (value_a, value_b, value_c) = ( - up[COL_INDEX_MEM_VALUE_A].clone(), - up[COL_INDEX_MEM_VALUE_B].clone(), - up[COL_INDEX_MEM_VALUE_C].clone(), + up[COL_MEM_VALUE_A].clone(), + up[COL_MEM_VALUE_B].clone(), + up[COL_MEM_VALUE_C].clone(), ); - let pc = up[COL_INDEX_PC].clone(); - let fp = up[COL_INDEX_FP].clone(); + let pc = up[COL_PC].clone(); + let fp = up[COL_FP].clone(); let (addr_a, addr_b, addr_c) = ( - up[COL_INDEX_MEM_ADDRESS_A].clone(), - up[COL_INDEX_MEM_ADDRESS_B].clone(), - up[COL_INDEX_MEM_ADDRESS_C].clone(), + up[COL_MEM_ADDRESS_A].clone(), + up[COL_MEM_ADDRESS_B].clone(), + up[COL_MEM_ADDRESS_C].clone(), ); let flag_a_minus_one = flag_a.clone() - AB::F::ONE; @@ -114,12 +110,16 @@ impl Air for ExecutionTable { let pc_plus_one = pc + AB::F::ONE; let nu_a_minus_one = nu_a.clone() - AB::F::ONE; - builder.eval_virtual_column(eval_virtual_bus_column::( - extra_data, - precompile_index.clone(), - is_precompile.clone(), - &[nu_a.clone(), nu_b.clone(), nu_c.clone(), aux.clone()], - )); + if BUS { + builder.eval_virtual_column(eval_virtual_bus_column::( + extra_data, + precompile_index.clone(), + is_precompile.clone(), + &[nu_a.clone(), nu_b.clone(), nu_c.clone(), aux_1.clone(), aux_2.clone()], + )); + } else { + builder.declare_values(&[nu_a.clone(), nu_b.clone(), nu_c.clone(), aux_1.clone(), aux_2.clone()]); + } builder.assert_zero(flag_a_minus_one * (addr_a.clone() - fp_plus_operand_a)); builder.assert_zero(flag_b_minus_one * (addr_b.clone() - fp_plus_operand_b)); @@ -129,8 +129,8 @@ impl Air for ExecutionTable { builder.assert_zero(mul * (nu_b.clone() - nu_a.clone() * nu_c.clone())); builder.assert_zero(deref.clone() * (addr_c.clone() - (value_a.clone() + operand_c.clone()))); - builder.assert_zero(deref.clone() * aux.clone() * (value_c.clone() - nu_b.clone())); - builder.assert_zero(deref.clone() * (aux.clone() - AB::F::ONE) * (value_c.clone() - fp.clone())); + builder.assert_zero(deref.clone() * aux_1.clone() * (value_c.clone() - nu_b.clone())); + builder.assert_zero(deref.clone() * (aux_1.clone() - AB::F::ONE) * (value_c.clone() - fp.clone())); builder.assert_zero((jump.clone() - AB::F::ONE) * (next_pc.clone() - pc_plus_one.clone())); builder.assert_zero((jump.clone() - AB::F::ONE) * (next_fp.clone() - fp.clone())); @@ -142,3 +142,7 @@ impl Air for ExecutionTable { builder.assert_zero(jump.clone() * nu_a_minus_one.clone() * (next_fp.clone() - fp.clone())); } } + +pub const fn instr_idx(col_index_in_air: usize) -> usize { + col_index_in_air - N_COMMITTED_EXEC_COLUMNS +} diff --git a/crates/lean_vm/src/tables/execution/mod.rs b/crates/lean_vm/src/tables/execution/mod.rs index a6b95703..d0c799e4 100644 --- a/crates/lean_vm/src/tables/execution/mod.rs +++ b/crates/lean_vm/src/tables/execution/mod.rs @@ -5,83 +5,64 @@ mod air; pub use air::*; #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct ExecutionTable; +pub struct ExecutionTable; -impl TableT for ExecutionTable { +impl TableT for ExecutionTable { fn name(&self) -> &'static str { "execution" } - fn identifier(&self) -> Table { + fn table(&self) -> Table { Table::execution() } - fn n_columns_f_total(&self) -> usize { - N_EXEC_AIR_COLUMNS + N_TEMPORARY_EXEC_COLUMNS - } - - fn commited_columns_f(&self) -> Vec { - vec![ - COL_INDEX_PC, - COL_INDEX_FP, - COL_INDEX_MEM_ADDRESS_A, - COL_INDEX_MEM_ADDRESS_B, - COL_INDEX_MEM_ADDRESS_C, - ] + fn is_execution_table(&self) -> bool { + true } - fn commited_columns_ef(&self) -> Vec { - vec![] + fn n_columns_f_total(&self) -> usize { + N_EXEC_AIR_COLUMNS + N_TEMPORARY_EXEC_COLUMNS } - fn normal_lookups_f(&self) -> Vec { + fn lookups_f(&self) -> Vec { vec![ LookupIntoMemory { - index: COL_INDEX_MEM_ADDRESS_A, - values: COL_INDEX_MEM_VALUE_A, + index: COL_MEM_ADDRESS_A, + values: vec![COL_MEM_VALUE_A], }, LookupIntoMemory { - index: COL_INDEX_MEM_ADDRESS_B, - values: COL_INDEX_MEM_VALUE_B, + index: COL_MEM_ADDRESS_B, + values: vec![COL_MEM_VALUE_B], }, LookupIntoMemory { - index: COL_INDEX_MEM_ADDRESS_C, - values: COL_INDEX_MEM_VALUE_C, + index: COL_MEM_ADDRESS_C, + values: vec![COL_MEM_VALUE_C], }, ] } - fn normal_lookups_ef(&self) -> Vec { - vec![] - } - - fn vector_lookups(&self) -> Vec { + fn lookups_ef(&self) -> Vec { vec![] } - fn buses(&self) -> Vec { - vec![Bus { - table: BusTable::Variable(COL_INDEX_PRECOMPILE_INDEX), + fn bus(&self) -> Bus { + Bus { + table: BusTable::Variable(COL_PRECOMPILE_INDEX), direction: BusDirection::Push, - selector: BusSelector::Column(COL_INDEX_IS_PRECOMPILE), - data: vec![ - COL_INDEX_EXEC_NU_A, - COL_INDEX_EXEC_NU_B, - COL_INDEX_EXEC_NU_C, - COL_INDEX_AUX, - ], - }] + selector: COL_IS_PRECOMPILE, + data: vec![COL_EXEC_NU_A, COL_EXEC_NU_B, COL_EXEC_NU_C, COL_AUX_1, COL_AUX_2], + } } fn padding_row_f(&self) -> Vec { let mut padding_row = vec![F::ZERO; N_EXEC_AIR_COLUMNS + N_TEMPORARY_EXEC_COLUMNS]; - padding_row[COL_INDEX_PC] = F::from_usize(ENDING_PC); - padding_row[COL_INDEX_JUMP] = F::ONE; - padding_row[COL_INDEX_FLAG_A] = F::ONE; - padding_row[COL_INDEX_OPERAND_A] = F::ONE; - padding_row[COL_INDEX_FLAG_B] = F::ONE; - padding_row[COL_INDEX_FLAG_C] = F::ONE; - padding_row[COL_INDEX_EXEC_NU_A] = F::ONE; // because at the end of program, we always jump (looping at pc=0, so condition = nu_a = 1) + padding_row[COL_PC] = F::from_usize(ENDING_PC); + padding_row[COL_JUMP] = F::ONE; + padding_row[COL_FLAG_A] = F::ONE; + padding_row[COL_OPERAND_A] = F::ONE; + padding_row[COL_FLAG_B] = F::ONE; + padding_row[COL_FLAG_C] = F::ONE; + padding_row[COL_EXEC_NU_A] = F::ONE; // because at the end of program, we always jump (looping at pc=0, so condition = nu_a = 1) padding_row } @@ -90,7 +71,7 @@ impl TableT for ExecutionTable { } #[inline(always)] - fn execute(&self, _: F, _: F, _: F, _: usize, _: &mut InstructionContext<'_>) -> Result<(), RunnerError> { + fn execute(&self, _: F, _: F, _: F, _: usize, _: usize, _: &mut InstructionContext<'_>) -> Result<(), RunnerError> { unreachable!() } } diff --git a/crates/lean_vm/src/tables/merkle/mod.rs b/crates/lean_vm/src/tables/merkle/mod.rs deleted file mode 100644 index 03ba9657..00000000 --- a/crates/lean_vm/src/tables/merkle/mod.rs +++ /dev/null @@ -1,333 +0,0 @@ -use std::array; - -use crate::*; -use multilinear_toolkit::prelude::*; -use p3_air::Air; -use utils::{ToUsize, get_poseidon_16_of_zero, poseidon16_permute, to_big_endian_in_field}; - -// Does not support height = 1 (minimum height is 2) - -// "committed" columns -const COL_FLAG: ColIndex = 0; -const COL_INDEX_LEAF: ColIndex = 1; // vectorized pointer -const COL_LEAF_POSITION: ColIndex = 2; // (between 0 and 2^height - 1) -const COL_INDEX_ROOT: ColIndex = 3; // vectorized pointer -const COL_HEIGHT: ColIndex = 4; // merkle tree height - -const COL_ZERO: ColIndex = 5; // always equal to 0, TODO remove this -const COL_ONE: ColIndex = 6; // always equal to 1, TODO remove this -const COL_IS_LEFT: ColIndex = 7; // boolean, whether the current node is a left child -const COL_LOOKUP_MEM_INDEX: ColIndex = 8; // = COL_INDEX_LEAF if flag = 1, otherwise = COL_INDEX_ROOT - -const INITIAL_COLS_DATA_LEFT: ColIndex = 9; -const INITIAL_COLS_DATA_RIGHT: ColIndex = INITIAL_COLS_DATA_LEFT + VECTOR_LEN; -const INITIAL_COLS_DATA_RES: ColIndex = INITIAL_COLS_DATA_RIGHT + VECTOR_LEN; - -// "virtual" columns (vectorized lookups into memory) -const COL_LOOKUP_MEM_VALUES: ColIndex = INITIAL_COLS_DATA_RES + VECTOR_LEN; - -const TOTAL_N_COLS: usize = COL_LOOKUP_MEM_VALUES + VECTOR_LEN; - -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct MerklePrecompile; - -impl TableT for MerklePrecompile { - fn name(&self) -> &'static str { - "merkle_verify" - } - - fn identifier(&self) -> Table { - Table::merkle() - } - - fn commited_columns_f(&self) -> Vec { - (0..COL_LOOKUP_MEM_VALUES).collect() - } - - fn commited_columns_ef(&self) -> Vec { - vec![] - } - - fn normal_lookups_f(&self) -> Vec { - vec![] - } - - fn normal_lookups_ef(&self) -> Vec { - vec![] - } - - fn vector_lookups(&self) -> Vec { - vec![VectorLookupIntoMemory { - index: COL_LOOKUP_MEM_INDEX, - values: array::from_fn(|i| COL_LOOKUP_MEM_VALUES + i), - }] - } - - fn buses(&self) -> Vec { - vec![ - Bus { - table: BusTable::Constant(self.identifier()), - direction: BusDirection::Pull, - selector: BusSelector::Column(COL_FLAG), - data: vec![COL_INDEX_LEAF, COL_LEAF_POSITION, COL_INDEX_ROOT, COL_HEIGHT], - }, - Bus { - table: BusTable::Constant(Table::poseidon16_core()), - direction: BusDirection::Push, - selector: BusSelector::ConstantOne, - data: [ - vec![COL_ONE], // Compression - (INITIAL_COLS_DATA_LEFT..INITIAL_COLS_DATA_LEFT + 8).collect::>(), - (INITIAL_COLS_DATA_RIGHT..INITIAL_COLS_DATA_RIGHT + 8).collect::>(), - (INITIAL_COLS_DATA_RES..INITIAL_COLS_DATA_RES + 8).collect::>(), - vec![COL_ZERO; VECTOR_LEN], // Padding - ] - .concat(), - }, - ] - } - - fn padding_row_f(&self) -> Vec { - let default_root = get_poseidon_16_of_zero()[..VECTOR_LEN].to_vec(); - [ - vec![ - F::ONE, // flag - F::ZERO, // index_leaf - F::ZERO, // leaf_position - F::from_usize(POSEIDON_16_NULL_HASH_PTR), // index_root - F::ONE, - F::ZERO, // col_zero - F::ONE, // col_one - F::ZERO, // is_left - F::from_usize(ZERO_VEC_PTR), // lookup_mem_index - ], - vec![F::ZERO; VECTOR_LEN], // data_left - vec![F::ZERO; VECTOR_LEN], // data_right - default_root.clone(), // data_res - vec![F::ZERO; VECTOR_LEN], // lookup_mem_values - ] - .concat() - } - - fn padding_row_ef(&self) -> Vec { - vec![] - } - - #[inline(always)] - fn execute( - &self, - index_leaf: F, - leaf_position: F, - index_root: F, - height: usize, - ctx: &mut InstructionContext<'_>, - ) -> Result<(), RunnerError> { - assert!(height >= 2); - - let leaf_position = leaf_position.to_usize(); - assert!(height > 0); - assert!(leaf_position < (1 << height)); - - let auth_path = ctx.merkle_path_hints.pop_front().unwrap(); - assert_eq!(auth_path.len(), height); - let mut leaf_position_bools = to_big_endian_in_field::(!leaf_position, height); - leaf_position_bools.reverse(); // little-endian - - let leaf = ctx.memory.get_vector(index_leaf.to_usize())?; - - { - let trace = &mut ctx.traces.get_mut(&self.identifier()).unwrap().base; - trace[COL_FLAG].extend([vec![F::ONE], vec![F::ZERO; height - 1]].concat()); - trace[COL_INDEX_LEAF].extend(vec![index_leaf; height]); - trace[COL_LEAF_POSITION].extend((0..height).map(|d| F::from_usize(leaf_position >> d))); - trace[COL_INDEX_ROOT].extend(vec![index_root; height]); - trace[COL_HEIGHT].extend((1..=height).rev().map(F::from_usize)); - trace[COL_ZERO].extend(vec![F::ZERO; height]); - trace[COL_ONE].extend(vec![F::ONE; height]); - trace[COL_IS_LEFT].extend(leaf_position_bools); - trace[COL_LOOKUP_MEM_INDEX].extend([vec![index_leaf], vec![index_root; height - 1]].concat()); - } - - let mut current_hash = leaf; - for (d, neightbour) in auth_path.iter().enumerate() { - let trace = &mut ctx.traces.get_mut(&self.identifier()).unwrap().base; - - let is_left = (leaf_position >> d) & 1 == 0; - - // TODO precompute (in parallel + SIMD) poseidons - - let (data_left, data_right) = if is_left { - (current_hash, *neightbour) - } else { - (*neightbour, current_hash) - }; - for i in 0..VECTOR_LEN { - trace[INITIAL_COLS_DATA_LEFT + i].push(data_left[i]); - trace[INITIAL_COLS_DATA_RIGHT + i].push(data_right[i]); - } - - let mut input = [F::ZERO; VECTOR_LEN * 2]; - input[..VECTOR_LEN].copy_from_slice(&data_left); - input[VECTOR_LEN..].copy_from_slice(&data_right); - - let output = match ctx.poseidon16_precomputed.get(*ctx.n_poseidon16_precomputed_used) { - Some(precomputed) if precomputed.0 == input => { - *ctx.n_poseidon16_precomputed_used += 1; - precomputed.1 - } - _ => poseidon16_permute(input), - }; - - current_hash = output[..VECTOR_LEN].try_into().unwrap(); - for i in 0..VECTOR_LEN { - trace[INITIAL_COLS_DATA_RES + i].push(current_hash[i]); - } - - add_poseidon_16_core_row(ctx.traces, 1, input, current_hash, [F::ZERO; VECTOR_LEN], true); - } - let root = current_hash; - ctx.memory.set_vector(index_root.to_usize(), root)?; - - let trace = &mut ctx.traces.get_mut(&self.identifier()).unwrap().base; - for i in 0..VECTOR_LEN { - trace[COL_LOOKUP_MEM_VALUES + i].extend([vec![leaf[i]], vec![root[i]; height - 1]].concat()); - } - - Ok(()) - } -} - -impl Air for MerklePrecompile { - type ExtraData = ExtraDataForBuses; - fn n_columns_f_air(&self) -> usize { - TOTAL_N_COLS - } - fn n_columns_ef_air(&self) -> usize { - 0 - } - fn degree(&self) -> usize { - 3 - } - fn down_column_indexes_f(&self) -> Vec { - (0..TOTAL_N_COLS - 2 * VECTOR_LEN).collect() - } - fn down_column_indexes_ef(&self) -> Vec { - vec![] - } - fn n_constraints(&self) -> usize { - 12 + 5 * VECTOR_LEN - } - fn eval(&self, builder: &mut AB, extra_data: &Self::ExtraData) { - let up = builder.up_f(); - let flag = up[COL_FLAG].clone(); - let index_leaf = up[COL_INDEX_LEAF].clone(); - let leaf_position = up[COL_LEAF_POSITION].clone(); - let index_root = up[COL_INDEX_ROOT].clone(); - let height = up[COL_HEIGHT].clone(); - let col_zero = up[COL_ZERO].clone(); - let col_one = up[COL_ONE].clone(); - let is_left = up[COL_IS_LEFT].clone(); - let lookup_index = up[COL_LOOKUP_MEM_INDEX].clone(); - let data_left: [_; VECTOR_LEN] = array::from_fn(|i| up[INITIAL_COLS_DATA_LEFT + i].clone()); - let data_right: [_; VECTOR_LEN] = array::from_fn(|i| up[INITIAL_COLS_DATA_RIGHT + i].clone()); - let data_res: [_; VECTOR_LEN] = array::from_fn(|i| up[INITIAL_COLS_DATA_RES + i].clone()); - let lookup_values: [_; VECTOR_LEN] = array::from_fn(|i| up[COL_LOOKUP_MEM_VALUES + i].clone()); - - let down = builder.down_f(); - let flag_down = down[0].clone(); - let index_leaf_down = down[1].clone(); - let leaf_position_down = down[2].clone(); - let index_root_down = down[3].clone(); - let height_down = down[4].clone(); - let _col_zero_down = down[5].clone(); - let _col_one_down = down[6].clone(); - let is_left_down = down[7].clone(); - let _lookup_index_down = down[8].clone(); - let data_left_down: [_; VECTOR_LEN] = array::from_fn(|i| down[9 + i].clone()); - let data_right_down: [_; VECTOR_LEN] = array::from_fn(|i| down[9 + VECTOR_LEN + i].clone()); - - let mut core_bus_data = [AB::F::ZERO; 1 + 2 * 16]; - core_bus_data[0] = col_one.clone(); // Compression - core_bus_data[1..9].clone_from_slice(&data_left); - core_bus_data[9..17].clone_from_slice(&data_right); - core_bus_data[17..25].clone_from_slice(&data_res); - core_bus_data[25..].clone_from_slice(&vec![col_zero.clone(); VECTOR_LEN]); - - builder.eval_virtual_column(eval_virtual_bus_column::( - extra_data, - AB::F::from_usize(self.identifier().index()), - flag.clone(), - &[ - index_leaf.clone(), - leaf_position.clone(), - index_root.clone(), - height.clone(), - ], - )); - - builder.eval_virtual_column(eval_virtual_bus_column::( - extra_data, - AB::F::from_usize(Table::poseidon16_core().index()), - AB::F::ONE, - &core_bus_data, - )); - - // TODO double check constraints - - builder.assert_eq(col_one.clone(), AB::F::ONE); - builder.assert_eq(col_zero.clone(), AB::F::ZERO); - - builder.assert_bool(flag.clone()); - builder.assert_bool(is_left.clone()); - - let not_flag = AB::F::ONE - flag.clone(); - let not_flag_down = AB::F::ONE - flag_down.clone(); - let is_right = AB::F::ONE - is_left.clone(); - let is_right_down = AB::F::ONE - is_left_down.clone(); - - builder.assert_eq( - lookup_index.clone(), - flag.clone() * index_leaf.clone() + not_flag.clone() * index_root.clone(), - ); - - // Parameters should not change as long as the flag has not been switched back to 1: - builder.assert_zero(not_flag_down.clone() * (index_leaf_down.clone() - index_leaf.clone())); - builder.assert_zero(not_flag_down.clone() * (index_root_down.clone() - index_root.clone())); - - // decrease height by 1 each step - builder.assert_zero(not_flag_down.clone() * (height_down.clone() + AB::F::ONE - height.clone())); - - builder.assert_zero( - not_flag_down.clone() - * ((leaf_position_down.clone() * AB::F::TWO + is_right.clone()) - leaf_position.clone()), - ); - - // start (bottom of the tree) - let starts_and_is_left = flag.clone() * is_left.clone(); - for i in 0..VECTOR_LEN { - builder.assert_zero(starts_and_is_left.clone() * (data_left[i].clone() - lookup_values[i].clone())); - } - let starts_and_is_right = flag.clone() * is_right.clone(); - for i in 0..VECTOR_LEN { - builder.assert_zero(starts_and_is_right.clone() * (data_right[i].clone() - lookup_values[i].clone())); - } - - // transition (interior nodes) - let transition_left = not_flag_down.clone() * is_left_down.clone(); - for i in 0..VECTOR_LEN { - builder.assert_zero(transition_left.clone() * (data_left_down[i].clone() - data_res[i].clone())); - } - let transition_right = not_flag_down.clone() * is_right_down.clone(); - for i in 0..VECTOR_LEN { - builder.assert_zero(transition_right.clone() * (data_right_down[i].clone() - data_res[i].clone())); - } - - // end (top of the tree) - builder.assert_zero(flag_down.clone() * (height.clone() - AB::F::ONE)); // at last step, height should be 1 - builder.assert_zero(flag_down.clone() * leaf_position.clone() * (AB::F::ONE - leaf_position.clone())); // at last step, leaf position should be boolean - for i in 0..VECTOR_LEN { - builder - .assert_zero(not_flag.clone() * flag_down.clone() * (data_res[i].clone() - lookup_values[i].clone())); - } - } -} diff --git a/crates/lean_vm/src/tables/mod.rs b/crates/lean_vm/src/tables/mod.rs index 4f08ebde..62fcd766 100644 --- a/crates/lean_vm/src/tables/mod.rs +++ b/crates/lean_vm/src/tables/mod.rs @@ -4,9 +4,6 @@ pub use dot_product::*; mod poseidon_16; pub use poseidon_16::*; -mod poseidon_24; -pub use poseidon_24::*; - mod table_enum; pub use table_enum::*; @@ -16,14 +13,5 @@ pub use table_trait::*; mod execution; pub use execution::*; -mod merkle; -pub use merkle::*; - -mod slice_hash; -pub use slice_hash::*; - -mod eq_poly_base_ext; -pub use eq_poly_base_ext::*; - mod utils; pub(crate) use utils::*; diff --git a/crates/lean_vm/src/tables/poseidon_16/core.rs b/crates/lean_vm/src/tables/poseidon_16/core.rs deleted file mode 100644 index 15619c33..00000000 --- a/crates/lean_vm/src/tables/poseidon_16/core.rs +++ /dev/null @@ -1,151 +0,0 @@ -use std::collections::BTreeMap; - -use crate::*; -use multilinear_toolkit::prelude::*; -use p3_air::Air; -use utils::get_poseidon_16_of_zero; - -const POSEIDON_16_CORE_COL_FLAG: ColIndex = 0; -pub const POSEIDON_16_CORE_COL_COMPRESSION: ColIndex = 1; -pub const POSEIDON_16_CORE_COL_INPUT_START: ColIndex = 2; -// virtual via GKR -pub const POSEIDON_16_CORE_COL_OUTPUT_START: ColIndex = POSEIDON_16_CORE_COL_INPUT_START + 16; -// intermediate columns ("commited cubes") are not handled here - -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct Poseidon16CorePrecompile; - -impl TableT for Poseidon16CorePrecompile { - fn name(&self) -> &'static str { - "poseidon16_core" - } - - fn identifier(&self) -> Table { - Table::poseidon16_core() - } - - fn n_columns_f_total(&self) -> usize { - 2 + 16 * 2 - } - - fn commited_columns_f(&self) -> Vec { - [ - vec![POSEIDON_16_CORE_COL_FLAG, POSEIDON_16_CORE_COL_COMPRESSION], - (POSEIDON_16_CORE_COL_INPUT_START..POSEIDON_16_CORE_COL_INPUT_START + 16).collect::>(), - ] - .concat() - // (committed cubes are handled elsewhere) - } - - fn commited_columns_ef(&self) -> Vec { - vec![] - } - - fn normal_lookups_f(&self) -> Vec { - vec![] - } - - fn normal_lookups_ef(&self) -> Vec { - vec![] - } - - fn vector_lookups(&self) -> Vec { - vec![] - } - - fn buses(&self) -> Vec { - vec![Bus { - table: BusTable::Constant(self.identifier()), - direction: BusDirection::Pull, - selector: BusSelector::Column(POSEIDON_16_CORE_COL_FLAG), - data: [ - vec![POSEIDON_16_CORE_COL_COMPRESSION], - (POSEIDON_16_CORE_COL_INPUT_START..POSEIDON_16_CORE_COL_INPUT_START + 16).collect::>(), - (POSEIDON_16_CORE_COL_OUTPUT_START..POSEIDON_16_CORE_COL_OUTPUT_START + 16).collect::>(), - ] - .concat(), - }] - } - - fn padding_row_f(&self) -> Vec { - let mut poseidon_of_zero = *get_poseidon_16_of_zero(); - if POSEIDON_16_DEFAULT_COMPRESSION { - poseidon_of_zero[8..].fill(F::ZERO); - } - [ - vec![F::ZERO, F::from_bool(POSEIDON_16_DEFAULT_COMPRESSION)], - vec![F::ZERO; 16], - poseidon_of_zero.to_vec(), - ] - .concat() - } - - fn padding_row_ef(&self) -> Vec { - vec![] - } - - #[inline(always)] - fn execute(&self, _: F, _: F, _: F, _: usize, _: &mut InstructionContext<'_>) -> Result<(), RunnerError> { - unreachable!() - } -} - -impl Air for Poseidon16CorePrecompile { - type ExtraData = ExtraDataForBuses; - fn n_columns_f_air(&self) -> usize { - 2 + 16 * 2 - } - fn n_columns_ef_air(&self) -> usize { - 0 - } - fn degree(&self) -> usize { - 1 - } - fn down_column_indexes_f(&self) -> Vec { - vec![] - } - fn down_column_indexes_ef(&self) -> Vec { - vec![] - } - fn n_constraints(&self) -> usize { - 1 - } - fn eval(&self, builder: &mut AB, extra_data: &Self::ExtraData) { - let up = builder.up_f(); - let flag = up[POSEIDON_16_CORE_COL_FLAG].clone(); - let mut data = [AB::F::ZERO; 1 + 2 * 16]; - data[0] = up[POSEIDON_16_CORE_COL_COMPRESSION].clone(); - data[1..17].clone_from_slice(&up[POSEIDON_16_CORE_COL_INPUT_START..][..16]); - data[17..33].clone_from_slice(&up[POSEIDON_16_CORE_COL_OUTPUT_START..][..16]); - - builder.eval_virtual_column(eval_virtual_bus_column::( - extra_data, - AB::F::from_usize(self.identifier().index()), - flag.clone(), - &data, - )); - } -} - -pub fn add_poseidon_16_core_row( - traces: &mut BTreeMap, - multiplicity: usize, - input: [F; 16], - res_a: [F; 8], - res_b: [F; 8], - is_compression: bool, -) { - let trace = traces.get_mut(&Table::poseidon16_core()).unwrap(); - - trace.base[POSEIDON_16_CORE_COL_FLAG].push(F::from_usize(multiplicity)); - trace.base[POSEIDON_16_CORE_COL_COMPRESSION].push(F::from_bool(is_compression)); - for (i, value) in input.iter().enumerate() { - trace.base[POSEIDON_16_CORE_COL_INPUT_START + i].push(*value); - } - for (i, value) in res_a.iter().enumerate() { - trace.base[POSEIDON_16_CORE_COL_OUTPUT_START + i].push(*value); - } - for (i, value) in res_b.iter().enumerate() { - trace.base[POSEIDON_16_CORE_COL_OUTPUT_START + 8 + i].push(*value); - } -} diff --git a/crates/lean_vm/src/tables/poseidon_16/from_memory.rs b/crates/lean_vm/src/tables/poseidon_16/from_memory.rs deleted file mode 100644 index 445cf04f..00000000 --- a/crates/lean_vm/src/tables/poseidon_16/from_memory.rs +++ /dev/null @@ -1,246 +0,0 @@ -use std::array; - -use crate::*; -use multilinear_toolkit::prelude::*; -use p3_air::Air; -use utils::{ToUsize, get_poseidon_16_of_zero, poseidon16_permute}; - -const POSEIDON_16_MEM_COL_FLAG: ColIndex = 0; -const POSEIDON_16_MEM_COL_INDEX_RES: ColIndex = 1; -const POSEIDON_16_MEM_COL_INDEX_RES_BIS: ColIndex = 2; // = if compressed { 0 } else { POSEIDON_16_COL_INDEX_RES + 1 } -const POSEIDON_16_MEM_COL_COMPRESSION: ColIndex = 3; -const POSEIDON_16_MEM_COL_INDEX_A: ColIndex = 4; -const POSEIDON_16_MEM_COL_INDEX_B: ColIndex = 5; -const POSEIDON_16_MEM_COL_INPUT_START: ColIndex = 6; -const POSEIDON_16_MEM_COL_OUTPUT_START: ColIndex = POSEIDON_16_MEM_COL_INPUT_START + 16; -// intermediate columns ("commited cubes") are not handled here - -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct Poseidon16MemPrecompile; - -impl TableT for Poseidon16MemPrecompile { - fn name(&self) -> &'static str { - "poseidon16" - } - - fn identifier(&self) -> Table { - Table::poseidon16_mem() - } - - fn n_columns_f_total(&self) -> usize { - 6 + 16 * 2 - } - - fn commited_columns_f(&self) -> Vec { - vec![ - POSEIDON_16_MEM_COL_FLAG, - POSEIDON_16_MEM_COL_INDEX_RES, - POSEIDON_16_MEM_COL_INDEX_RES_BIS, - POSEIDON_16_MEM_COL_COMPRESSION, - POSEIDON_16_MEM_COL_INDEX_A, - POSEIDON_16_MEM_COL_INDEX_B, - ] // (committed cubes are handled elsewhere) - } - - fn commited_columns_ef(&self) -> Vec { - vec![] - } - - fn normal_lookups_f(&self) -> Vec { - vec![] - } - - fn normal_lookups_ef(&self) -> Vec { - vec![] - } - - fn vector_lookups(&self) -> Vec { - vec![ - VectorLookupIntoMemory { - index: POSEIDON_16_MEM_COL_INDEX_A, - values: array::from_fn(|i| POSEIDON_16_MEM_COL_INPUT_START + i), - }, - VectorLookupIntoMemory { - index: POSEIDON_16_MEM_COL_INDEX_B, - values: array::from_fn(|i| POSEIDON_16_MEM_COL_INPUT_START + VECTOR_LEN + i), - }, - VectorLookupIntoMemory { - index: POSEIDON_16_MEM_COL_INDEX_RES, - values: array::from_fn(|i| POSEIDON_16_MEM_COL_OUTPUT_START + i), - }, - VectorLookupIntoMemory { - index: POSEIDON_16_MEM_COL_INDEX_RES_BIS, - values: array::from_fn(|i| POSEIDON_16_MEM_COL_OUTPUT_START + VECTOR_LEN + i), - }, - ] - } - - fn buses(&self) -> Vec { - vec![ - Bus { - table: BusTable::Constant(self.identifier()), - direction: BusDirection::Pull, - selector: BusSelector::Column(POSEIDON_16_MEM_COL_FLAG), - data: vec![ - POSEIDON_16_MEM_COL_INDEX_A, - POSEIDON_16_MEM_COL_INDEX_B, - POSEIDON_16_MEM_COL_INDEX_RES, - POSEIDON_16_MEM_COL_COMPRESSION, - ], - }, - Bus { - table: BusTable::Constant(Table::poseidon16_core()), - direction: BusDirection::Push, - selector: BusSelector::ConstantOne, - data: [ - vec![POSEIDON_16_MEM_COL_COMPRESSION], - (POSEIDON_16_MEM_COL_INPUT_START..POSEIDON_16_MEM_COL_INPUT_START + 16).collect::>(), - (POSEIDON_16_MEM_COL_OUTPUT_START..POSEIDON_16_MEM_COL_OUTPUT_START + 16) - .collect::>(), - ] - .concat(), - }, - ] - } - - fn padding_row_f(&self) -> Vec { - let mut poseidon_of_zero = *get_poseidon_16_of_zero(); - if POSEIDON_16_DEFAULT_COMPRESSION { - poseidon_of_zero[8..].fill(F::ZERO); - } - [ - vec![ - F::ZERO, - F::from_usize(POSEIDON_16_NULL_HASH_PTR), - F::from_usize(if POSEIDON_16_DEFAULT_COMPRESSION { - ZERO_VEC_PTR - } else { - 1 + POSEIDON_16_NULL_HASH_PTR - }), - F::from_bool(POSEIDON_16_DEFAULT_COMPRESSION), - F::from_usize(ZERO_VEC_PTR), - F::from_usize(ZERO_VEC_PTR), - ], - vec![F::ZERO; 16], - poseidon_of_zero.to_vec(), - ] - .concat() - } - - fn padding_row_ef(&self) -> Vec { - vec![] - } - - #[inline(always)] - fn execute( - &self, - arg_a: F, - arg_b: F, - index_res_a: F, - is_compression: usize, - ctx: &mut InstructionContext<'_>, - ) -> Result<(), RunnerError> { - assert!(is_compression == 0 || is_compression == 1); - let is_compression = is_compression == 1; - let trace = ctx.traces.get_mut(&self.identifier()).unwrap(); - - let arg0 = ctx.memory.get_vector(arg_a.to_usize())?; - let arg1 = ctx.memory.get_vector(arg_b.to_usize())?; - - let mut input = [F::ZERO; VECTOR_LEN * 2]; - input[..VECTOR_LEN].copy_from_slice(&arg0); - input[VECTOR_LEN..].copy_from_slice(&arg1); - - let output = match ctx.poseidon16_precomputed.get(*ctx.n_poseidon16_precomputed_used) { - Some(precomputed) if precomputed.0 == input => { - *ctx.n_poseidon16_precomputed_used += 1; - precomputed.1 - } - _ => poseidon16_permute(input), - }; - - let res_a: [F; VECTOR_LEN] = output[..VECTOR_LEN].try_into().unwrap(); - let (index_res_b, res_b): (F, [F; VECTOR_LEN]) = if is_compression { - (F::from_usize(ZERO_VEC_PTR), [F::ZERO; VECTOR_LEN]) - } else { - (index_res_a + F::ONE, output[VECTOR_LEN..].try_into().unwrap()) - }; - - ctx.memory.set_vector(index_res_a.to_usize(), res_a)?; - ctx.memory.set_vector(index_res_b.to_usize(), res_b)?; - - trace.base[POSEIDON_16_MEM_COL_FLAG].push(F::ONE); - trace.base[POSEIDON_16_MEM_COL_INDEX_A].push(arg_a); - trace.base[POSEIDON_16_MEM_COL_INDEX_B].push(arg_b); - trace.base[POSEIDON_16_MEM_COL_INDEX_RES].push(index_res_a); - trace.base[POSEIDON_16_MEM_COL_INDEX_RES_BIS].push(index_res_b); - trace.base[POSEIDON_16_MEM_COL_COMPRESSION].push(F::from_bool(is_compression)); - for (i, value) in input.iter().enumerate() { - trace.base[POSEIDON_16_MEM_COL_INPUT_START + i].push(*value); - } - for (i, value) in res_a.iter().enumerate() { - trace.base[POSEIDON_16_MEM_COL_OUTPUT_START + i].push(*value); - } - for (i, value) in res_b.iter().enumerate() { - trace.base[POSEIDON_16_MEM_COL_OUTPUT_START + 8 + i].push(*value); - } - - add_poseidon_16_core_row(ctx.traces, 1, input, res_a, res_b, is_compression); - - Ok(()) - } -} - -impl Air for Poseidon16MemPrecompile { - type ExtraData = ExtraDataForBuses; - fn n_columns_f_air(&self) -> usize { - 6 + 16 * 2 - } - fn n_columns_ef_air(&self) -> usize { - 0 - } - fn degree(&self) -> usize { - 2 - } - fn down_column_indexes_f(&self) -> Vec { - vec![] - } - fn down_column_indexes_ef(&self) -> Vec { - vec![] - } - fn n_constraints(&self) -> usize { - 5 - } - fn eval(&self, builder: &mut AB, extra_data: &Self::ExtraData) { - let up = builder.up_f(); - let flag = up[POSEIDON_16_MEM_COL_FLAG].clone(); - let index_res = up[POSEIDON_16_MEM_COL_INDEX_RES].clone(); - let index_res_bis = up[POSEIDON_16_MEM_COL_INDEX_RES_BIS].clone(); - let compression = up[POSEIDON_16_MEM_COL_COMPRESSION].clone(); - let index_a = up[POSEIDON_16_MEM_COL_INDEX_A].clone(); - let index_b = up[POSEIDON_16_MEM_COL_INDEX_B].clone(); - - let mut core_bus_data = [AB::F::ZERO; 1 + 2 * 16]; - core_bus_data[0] = up[POSEIDON_16_MEM_COL_COMPRESSION].clone(); - core_bus_data[1..17].clone_from_slice(&up[POSEIDON_16_MEM_COL_INPUT_START..][..16]); - core_bus_data[17..33].clone_from_slice(&up[POSEIDON_16_MEM_COL_OUTPUT_START..][..16]); - - builder.eval_virtual_column(eval_virtual_bus_column::( - extra_data, - AB::F::from_usize(self.identifier().index()), - flag.clone(), - &[index_a.clone(), index_b.clone(), index_res.clone(), compression.clone()], - )); - - builder.eval_virtual_column(eval_virtual_bus_column::( - extra_data, - AB::F::from_usize(Table::poseidon16_core().index()), - AB::F::ONE, - &core_bus_data, - )); - - builder.assert_bool(flag.clone()); - builder.assert_bool(compression.clone()); - builder.assert_eq(index_res_bis, (index_res + AB::F::ONE) * (AB::F::ONE - compression)); - } -} diff --git a/crates/lean_vm/src/tables/poseidon_16/mod.rs b/crates/lean_vm/src/tables/poseidon_16/mod.rs index b35d4130..9b426bc1 100644 --- a/crates/lean_vm/src/tables/poseidon_16/mod.rs +++ b/crates/lean_vm/src/tables/poseidon_16/mod.rs @@ -1,7 +1,356 @@ -mod core; -pub use core::*; +use p3_poseidon2::GenericPoseidon2LinearLayers; +use std::any::TypeId; -mod from_memory; -pub use from_memory::*; +use crate::{tables::poseidon_16::trace_gen::default_poseidon_row, *}; +use multilinear_toolkit::prelude::{symbolic::SymbolicExpression, *}; +use p3_koala_bear::{ + GenericPoseidon2LinearLayersKoalaBear, KOALABEAR_RC16_EXTERNAL_FINAL, KOALABEAR_RC16_EXTERNAL_INITIAL, + KOALABEAR_RC16_INTERNAL, KoalaBear, +}; +use utils::{ToUsize, poseidon16_permute}; + +mod trace_gen; +pub use trace_gen::fill_trace_poseidon_16; pub const POSEIDON_16_DEFAULT_COMPRESSION: bool = true; + +pub(super) const WIDTH: usize = 16; +const HALF_INITIAL_FULL_ROUNDS: usize = KOALABEAR_RC16_EXTERNAL_INITIAL.len() / 2; +const PARTIAL_ROUNDS: usize = KOALABEAR_RC16_INTERNAL.len(); +const HALF_FINAL_FULL_ROUNDS: usize = KOALABEAR_RC16_EXTERNAL_FINAL.len() / 2; + +pub const POSEIDON_16_COL_FLAG: ColIndex = 0; +pub const POSEIDON_16_COL_A: ColIndex = 1; +pub const POSEIDON_16_COL_B: ColIndex = 2; +pub const POSEIDON_16_COL_COMPRESSION: ColIndex = 3; +pub const POSEIDON_16_COL_RES: ColIndex = 4; +pub const POSEIDON_16_COL_RES_BIS: ColIndex = 5; // = if compressed { 0 } else { POSEIDON_16_COL_RES + 1 } +pub const POSEIDON_16_COL_INPUT_START: ColIndex = 6; +const POSEIDON_16_COL_OUTPUT_START: ColIndex = num_cols_poseidon_16() - 16; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct Poseidon16Precompile; + +impl TableT for Poseidon16Precompile { + fn name(&self) -> &'static str { + "poseidon16" + } + + fn table(&self) -> Table { + Table::poseidon16() + } + + fn lookups_f(&self) -> Vec { + vec![ + LookupIntoMemory { + index: POSEIDON_16_COL_A, + values: (POSEIDON_16_COL_INPUT_START..POSEIDON_16_COL_INPUT_START + DIGEST_LEN).collect(), + }, + LookupIntoMemory { + index: POSEIDON_16_COL_B, + values: (POSEIDON_16_COL_INPUT_START + DIGEST_LEN..POSEIDON_16_COL_INPUT_START + DIGEST_LEN * 2) + .collect(), + }, + LookupIntoMemory { + index: POSEIDON_16_COL_RES, + values: (POSEIDON_16_COL_OUTPUT_START..POSEIDON_16_COL_OUTPUT_START + DIGEST_LEN).collect(), + }, + LookupIntoMemory { + index: POSEIDON_16_COL_RES_BIS, + values: (POSEIDON_16_COL_OUTPUT_START + DIGEST_LEN..POSEIDON_16_COL_OUTPUT_START + DIGEST_LEN * 2) + .collect(), + }, + ] + } + + fn lookups_ef(&self) -> Vec { + vec![] + } + + fn bus(&self) -> Bus { + Bus { + table: BusTable::Constant(self.table()), + direction: BusDirection::Pull, + selector: POSEIDON_16_COL_FLAG, + data: vec![ + POSEIDON_16_COL_A, + POSEIDON_16_COL_B, + POSEIDON_16_COL_RES, + POSEIDON_16_COL_COMPRESSION, + ], + } + } + + fn padding_row_f(&self) -> Vec { + default_poseidon_row() + } + + fn padding_row_ef(&self) -> Vec { + vec![] + } + + #[inline(always)] + fn execute( + &self, + arg_a: F, + arg_b: F, + index_res_a: F, + is_compression: usize, + _: usize, + ctx: &mut InstructionContext<'_>, + ) -> Result<(), RunnerError> { + assert!(is_compression == 0 || is_compression == 1); + let is_compression = is_compression == 1; + let trace = ctx.traces.get_mut(&self.table()).unwrap(); + + let arg0 = ctx.memory.get_slice(arg_a.to_usize(), DIGEST_LEN)?; + let arg1 = ctx.memory.get_slice(arg_b.to_usize(), DIGEST_LEN)?; + + let mut input = [F::ZERO; DIGEST_LEN * 2]; + input[..DIGEST_LEN].copy_from_slice(&arg0); + input[DIGEST_LEN..].copy_from_slice(&arg1); + + let output = match ctx.poseidon16_precomputed.get(*ctx.n_poseidon16_precomputed_used) { + Some(precomputed) if precomputed.0 == input => { + *ctx.n_poseidon16_precomputed_used += 1; + precomputed.1 + } + _ => poseidon16_permute(input), + }; + + let res_a: [F; DIGEST_LEN] = output[..DIGEST_LEN].try_into().unwrap(); + let (index_res_b, res_b): (F, [F; DIGEST_LEN]) = if is_compression { + (F::from_usize(ZERO_VEC_PTR), [F::ZERO; DIGEST_LEN]) + } else { + ( + index_res_a + F::from_usize(DIGEST_LEN), + output[DIGEST_LEN..].try_into().unwrap(), + ) + }; + + ctx.memory.set_slice(index_res_a.to_usize(), &res_a)?; + ctx.memory.set_slice(index_res_b.to_usize(), &res_b)?; + + trace.base[POSEIDON_16_COL_FLAG].push(F::ONE); + trace.base[POSEIDON_16_COL_A].push(arg_a); + trace.base[POSEIDON_16_COL_B].push(arg_b); + trace.base[POSEIDON_16_COL_RES].push(index_res_a); + trace.base[POSEIDON_16_COL_RES_BIS].push(index_res_b); + trace.base[POSEIDON_16_COL_COMPRESSION].push(F::from_bool(is_compression)); + for (i, value) in input.iter().enumerate() { + trace.base[POSEIDON_16_COL_INPUT_START + i].push(*value); + } + + // the rest of the trace is filled at the end of the execution (to get parallelism + SIMD) + + Ok(()) + } +} + +impl Air for Poseidon16Precompile { + type ExtraData = ExtraDataForBuses; + fn n_columns_f_air(&self) -> usize { + num_cols_poseidon_16() + } + fn n_columns_ef_air(&self) -> usize { + 0 + } + fn degree_air(&self) -> usize { + 10 + } + fn down_column_indexes_f(&self) -> Vec { + vec![] + } + fn down_column_indexes_ef(&self) -> Vec { + vec![] + } + fn n_constraints(&self) -> usize { + BUS as usize + 87 + } + fn eval(&self, builder: &mut AB, extra_data: &Self::ExtraData) { + let cols: Poseidon2Cols = { + let up = builder.up_f(); + let (prefix, shorts, suffix) = unsafe { up.align_to::>() }; + debug_assert!(prefix.is_empty(), "Alignment should match"); + debug_assert!(suffix.is_empty(), "Alignment should match"); + debug_assert_eq!(shorts.len(), 1); + unsafe { std::ptr::read(&shorts[0]) } + }; + + if BUS { + builder.eval_virtual_column(eval_virtual_bus_column::( + extra_data, + AB::F::from_usize(self.table().index()), + cols.flag.clone(), + &[ + cols.index_a.clone(), + cols.index_b.clone(), + cols.index_res.clone(), + cols.compress.clone(), + ], + )); + } else { + builder.declare_values(&[ + cols.index_a.clone(), + cols.index_b.clone(), + cols.index_res.clone(), + cols.compress.clone(), + ]); + } + + builder.assert_bool(cols.flag.clone()); + builder.assert_bool(cols.compress.clone()); + builder.assert_eq( + cols.index_res_bis.clone(), + (cols.index_res.clone() + AB::F::from_usize(DIGEST_LEN)) * (AB::F::ONE - cols.compress.clone()), + ); + + eval(builder, &cols) + } +} + +#[repr(C)] +#[derive(Debug)] +pub(super) struct Poseidon2Cols { + pub flag: T, + pub index_a: T, + pub index_b: T, + pub compress: T, + pub index_res: T, + pub index_res_bis: T, + + pub inputs: [T; WIDTH], + pub beginning_full_rounds: [[T; WIDTH]; HALF_INITIAL_FULL_ROUNDS], + pub partial_rounds: [T; PARTIAL_ROUNDS], + pub ending_full_rounds: [[T; WIDTH]; HALF_FINAL_FULL_ROUNDS], +} + +fn eval(builder: &mut AB, local: &Poseidon2Cols) { + let mut state: [_; WIDTH] = local.inputs.clone().map(|x| x); + + GenericPoseidon2LinearLayersKoalaBear::external_linear_layer(&mut state); + + for round in 0..HALF_INITIAL_FULL_ROUNDS { + eval_2_full_rounds( + &mut state, + &local.beginning_full_rounds[round], + &KOALABEAR_RC16_EXTERNAL_INITIAL[2 * round], + &KOALABEAR_RC16_EXTERNAL_INITIAL[2 * round + 1], + builder, + ); + } + + for (round, cst) in KOALABEAR_RC16_INTERNAL.iter().enumerate().take(PARTIAL_ROUNDS) { + eval_partial_round(&mut state, &local.partial_rounds[round], *cst, builder); + } + + for round in 0..HALF_FINAL_FULL_ROUNDS - 1 { + eval_2_full_rounds( + &mut state, + &local.ending_full_rounds[round], + &KOALABEAR_RC16_EXTERNAL_FINAL[2 * round], + &KOALABEAR_RC16_EXTERNAL_FINAL[2 * round + 1], + builder, + ); + } + + eval_last_2_full_rounds( + &mut state, + &local.ending_full_rounds[HALF_FINAL_FULL_ROUNDS - 1], + &KOALABEAR_RC16_EXTERNAL_FINAL[2 * (HALF_FINAL_FULL_ROUNDS - 1)], + &KOALABEAR_RC16_EXTERNAL_FINAL[2 * (HALF_FINAL_FULL_ROUNDS - 1) + 1], + local.compress.clone(), + builder, + ); +} + +pub const fn num_cols_poseidon_16() -> usize { + size_of::>() +} + +#[inline] +fn eval_2_full_rounds( + state: &mut [AB::F; WIDTH], + post_full_round: &[AB::F; WIDTH], + round_constants_1: &[F; WIDTH], + round_constants_2: &[F; WIDTH], + builder: &mut AB, +) { + for (s, r) in state.iter_mut().zip(round_constants_1.iter()) { + add_koala_bear(s, *r); + *s = s.cube(); + } + GenericPoseidon2LinearLayersKoalaBear::external_linear_layer(state); + for (s, r) in state.iter_mut().zip(round_constants_2.iter()) { + add_koala_bear(s, *r); + *s = s.cube(); + } + GenericPoseidon2LinearLayersKoalaBear::external_linear_layer(state); + for (state_i, post_i) in state.iter_mut().zip(post_full_round) { + builder.assert_eq(state_i.clone(), post_i.clone()); + *state_i = post_i.clone(); + } +} + +#[inline] +fn eval_last_2_full_rounds( + state: &mut [AB::F; WIDTH], + post_full_round: &[AB::F; WIDTH], + round_constants_1: &[F; WIDTH], + round_constants_2: &[F; WIDTH], + compress: AB::F, + builder: &mut AB, +) { + for (s, r) in state.iter_mut().zip(round_constants_1.iter()) { + add_koala_bear(s, *r); + *s = s.cube(); + } + GenericPoseidon2LinearLayersKoalaBear::external_linear_layer(state); + for (s, r) in state.iter_mut().zip(round_constants_2.iter()) { + add_koala_bear(s, *r); + *s = s.cube(); + } + GenericPoseidon2LinearLayersKoalaBear::external_linear_layer(state); + for (state_i, post_i) in state.iter_mut().zip(post_full_round).take(WIDTH / 2) { + builder.assert_eq(state_i.clone(), post_i.clone()); + *state_i = post_i.clone(); + } + for (state_i, post_i) in state.iter_mut().zip(post_full_round).skip(WIDTH / 2) { + builder.assert_eq(state_i.clone() * -(compress.clone() - AB::F::ONE), post_i.clone()); + *state_i = post_i.clone(); + } +} + +#[inline] +fn eval_partial_round( + state: &mut [AB::F; WIDTH], + post_partial_round: &AB::F, + round_constant: F, + builder: &mut AB, +) { + add_koala_bear(&mut state[0], round_constant); + state[0] = state[0].cube(); + + builder.assert_eq(state[0].clone(), post_partial_round.clone()); + state[0] = post_partial_round.clone(); + + GenericPoseidon2LinearLayersKoalaBear::internal_linear_layer(state); +} + +#[inline(always)] +fn add_koala_bear(a: &mut A, value: F) { + if TypeId::of::() == TypeId::of::() { + *unsafe { std::mem::transmute::<&mut A, &mut F>(a) } += value; + } else if TypeId::of::() == TypeId::of::() { + *unsafe { std::mem::transmute::<&mut A, &mut EF>(a) } += value; + } else if TypeId::of::() == TypeId::of::>() { + *unsafe { std::mem::transmute::<&mut A, &mut FPacking>(a) } += value; + } else if TypeId::of::() == TypeId::of::>() { + *unsafe { std::mem::transmute::<&mut A, &mut EFPacking>(a) } += FPacking::::from(value); + } else if TypeId::of::() == TypeId::of::>() { + *unsafe { std::mem::transmute::<&mut A, &mut SymbolicExpression>(a) } += value; + } else { + dbg!(std::any::type_name::()); + unreachable!() + } +} diff --git a/crates/lean_vm/src/tables/poseidon_16/trace_gen.rs b/crates/lean_vm/src/tables/poseidon_16/trace_gen.rs new file mode 100644 index 00000000..4330ccd5 --- /dev/null +++ b/crates/lean_vm/src/tables/poseidon_16/trace_gen.rs @@ -0,0 +1,136 @@ +use p3_poseidon2::GenericPoseidon2LinearLayers; +use tracing::instrument; + +use crate::{ + DIGEST_LEN, F, POSEIDON_16_DEFAULT_COMPRESSION, POSEIDON_16_NULL_HASH_PTR, ZERO_VEC_PTR, + tables::{Poseidon2Cols, WIDTH, num_cols_poseidon_16}, +}; +use multilinear_toolkit::prelude::*; +use p3_koala_bear::{ + GenericPoseidon2LinearLayersKoalaBear, KOALABEAR_RC16_EXTERNAL_FINAL, KOALABEAR_RC16_EXTERNAL_INITIAL, + KOALABEAR_RC16_INTERNAL, KoalaBear, +}; + +#[instrument(name = "generate Poseidon2 trace", skip_all)] +pub fn fill_trace_poseidon_16(trace: &mut [Vec]) { + let n = trace.iter().map(|col| col.len()).max().unwrap(); + for col in trace.iter_mut() { + if col.len() != n { + col.resize(n, F::ZERO); + } + } + + let m = n - (n % packing_width::()); + let trace_packed: Vec<_> = trace.iter().map(|col| FPacking::::pack_slice(&col[..m])).collect(); + + // fill the packed rows + (0..n / packing_width::()).into_par_iter().for_each(|i| { + let ptrs: Vec<*mut FPacking> = trace_packed + .iter() + .map(|col| unsafe { (col.as_ptr() as *mut FPacking).add(i) }) + .collect(); + let perm: &mut Poseidon2Cols<&mut FPacking> = + unsafe { &mut *(ptrs.as_ptr() as *mut Poseidon2Cols<&mut FPacking>) }; + + generate_trace_rows_for_perm(perm); + }); + + // fill the remaining rows (non packed) + for i in m..n { + let ptrs: Vec<*mut F> = trace + .iter() + .map(|col| unsafe { (col.as_ptr() as *mut F).add(i) }) + .collect(); + let perm: &mut Poseidon2Cols<&mut F> = unsafe { &mut *(ptrs.as_ptr() as *mut Poseidon2Cols<&mut F>) }; + generate_trace_rows_for_perm(perm); + } +} + +pub fn default_poseidon_row() -> Vec { + let mut row = vec![F::ZERO; num_cols_poseidon_16()]; + let ptrs: [*mut F; num_cols_poseidon_16()] = std::array::from_fn(|i| unsafe { row.as_mut_ptr().add(i) }); + + let perm: &mut Poseidon2Cols<&mut F> = unsafe { &mut *(ptrs.as_ptr() as *mut Poseidon2Cols<&mut F>) }; + perm.inputs.iter_mut().for_each(|x| **x = F::ZERO); + *perm.flag = F::ZERO; + *perm.index_a = F::from_usize(ZERO_VEC_PTR); + *perm.index_b = F::from_usize(ZERO_VEC_PTR); + *perm.index_res = F::from_usize(POSEIDON_16_NULL_HASH_PTR); + *perm.index_res_bis = if POSEIDON_16_DEFAULT_COMPRESSION { + F::from_usize(ZERO_VEC_PTR) + } else { + F::from_usize(POSEIDON_16_NULL_HASH_PTR + DIGEST_LEN) + }; + *perm.compress = F::from_bool(POSEIDON_16_DEFAULT_COMPRESSION); + + generate_trace_rows_for_perm(perm); + row +} +fn generate_trace_rows_for_perm + Copy>(perm: &mut Poseidon2Cols<&mut F>) { + let mut state: [F; WIDTH] = std::array::from_fn(|i| *perm.inputs[i]); + + GenericPoseidon2LinearLayersKoalaBear::external_linear_layer(&mut state); + + for (full_round, constants) in perm + .beginning_full_rounds + .iter_mut() + .zip(KOALABEAR_RC16_EXTERNAL_INITIAL.chunks_exact(2)) + { + generate_full_round(&mut state, full_round, &constants[0], &constants[1]); + } + + for (partial_round, constant) in perm.partial_rounds.iter_mut().zip(&KOALABEAR_RC16_INTERNAL) { + generate_partial_round(&mut state, partial_round, *constant); + } + + for (full_round, constants) in perm + .ending_full_rounds + .iter_mut() + .zip(KOALABEAR_RC16_EXTERNAL_FINAL.chunks_exact(2)) + { + generate_full_round(&mut state, full_round, &constants[0], &constants[1]); + } + + perm.ending_full_rounds.last_mut().unwrap()[8..16] + .iter_mut() + .for_each(|x| { + **x = (F::ONE - *perm.compress) * **x; + }); +} + +#[inline] +fn generate_full_round + Copy>( + state: &mut [F; WIDTH], + post_full_round: &mut [&mut F; WIDTH], + round_constants_1: &[KoalaBear; WIDTH], + round_constants_2: &[KoalaBear; WIDTH], +) { + // Combine addition of round constants and S-box application in a single loop + for (state_i, const_i) in state.iter_mut().zip(round_constants_1) { + *state_i += *const_i; + *state_i = state_i.cube(); + } + GenericPoseidon2LinearLayersKoalaBear::external_linear_layer(state); + + for (state_i, const_i) in state.iter_mut().zip(round_constants_2.iter()) { + *state_i += *const_i; + *state_i = state_i.cube(); + } + GenericPoseidon2LinearLayersKoalaBear::external_linear_layer(state); + + post_full_round.iter_mut().zip(*state).for_each(|(post, x)| { + **post = x; + }); +} + +#[inline] +fn generate_partial_round + Copy>( + state: &mut [F; WIDTH], + post_partial_round: &mut F, + round_constant: KoalaBear, +) { + state[0] += round_constant; + state[0] = state[0].cube(); + *post_partial_round = state[0]; + GenericPoseidon2LinearLayersKoalaBear::internal_linear_layer(state); +} diff --git a/crates/lean_vm/src/tables/poseidon_24/core.rs b/crates/lean_vm/src/tables/poseidon_24/core.rs deleted file mode 100644 index 8263bc8b..00000000 --- a/crates/lean_vm/src/tables/poseidon_24/core.rs +++ /dev/null @@ -1,138 +0,0 @@ -use std::collections::BTreeMap; - -use crate::*; -use multilinear_toolkit::prelude::*; -use p3_air::Air; -use utils::get_poseidon_24_of_zero; - -const POSEIDON_24_CORE_COL_FLAG: ColIndex = 0; -pub const POSEIDON_24_CORE_COL_INPUT_START: ColIndex = 1; -// virtual via GKR -pub const POSEIDON_24_CORE_COL_OUTPUT_START: ColIndex = POSEIDON_24_CORE_COL_INPUT_START + 24; -// intermediate columns ("commited cubes") are not handled here - -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct Poseidon24CorePrecompile; - -impl TableT for Poseidon24CorePrecompile { - fn name(&self) -> &'static str { - "poseidon24_core" - } - - fn identifier(&self) -> Table { - Table::poseidon24_core() - } - - fn n_columns_f_total(&self) -> usize { - 1 + 24 + 8 - } - - fn commited_columns_f(&self) -> Vec { - [ - vec![POSEIDON_24_CORE_COL_FLAG], - (POSEIDON_24_CORE_COL_INPUT_START..POSEIDON_24_CORE_COL_INPUT_START + 24).collect::>(), - ] - .concat() - // (committed cubes are handled elsewhere) - } - - fn commited_columns_ef(&self) -> Vec { - vec![] - } - - fn normal_lookups_f(&self) -> Vec { - vec![] - } - - fn normal_lookups_ef(&self) -> Vec { - vec![] - } - - fn vector_lookups(&self) -> Vec { - vec![] - } - - fn buses(&self) -> Vec { - vec![Bus { - table: BusTable::Constant(self.identifier()), - direction: BusDirection::Pull, - selector: BusSelector::Column(POSEIDON_24_CORE_COL_FLAG), - data: [ - (POSEIDON_24_CORE_COL_INPUT_START..POSEIDON_24_CORE_COL_INPUT_START + 24).collect::>(), - (POSEIDON_24_CORE_COL_OUTPUT_START..POSEIDON_24_CORE_COL_OUTPUT_START + 8).collect::>(), - ] - .concat(), - }] - } - - fn padding_row_f(&self) -> Vec { - [ - vec![F::ZERO], - vec![F::ZERO; 24], - get_poseidon_24_of_zero()[16..].to_vec(), - ] - .concat() - } - - fn padding_row_ef(&self) -> Vec { - vec![] - } - - #[inline(always)] - fn execute(&self, _: F, _: F, _: F, _: usize, _: &mut InstructionContext<'_>) -> Result<(), RunnerError> { - unreachable!() - } -} - -impl Air for Poseidon24CorePrecompile { - type ExtraData = ExtraDataForBuses; - fn n_columns_f_air(&self) -> usize { - 1 + 24 + 8 - } - fn n_columns_ef_air(&self) -> usize { - 0 - } - fn degree(&self) -> usize { - 2 - } - fn down_column_indexes_f(&self) -> Vec { - vec![] - } - fn down_column_indexes_ef(&self) -> Vec { - vec![] - } - fn n_constraints(&self) -> usize { - 1 - } - fn eval(&self, builder: &mut AB, extra_data: &Self::ExtraData) { - let up = builder.up_f(); - let flag = up[POSEIDON_24_CORE_COL_FLAG].clone(); - let mut data = [AB::F::ZERO; 24 + 8]; - data[0..24].clone_from_slice(&up[POSEIDON_24_CORE_COL_INPUT_START..][..24]); - data[24..32].clone_from_slice(&up[POSEIDON_24_CORE_COL_OUTPUT_START..][..8]); - - builder.eval_virtual_column(eval_virtual_bus_column::( - extra_data, - AB::F::from_usize(self.identifier().index()), - flag.clone(), - &data, - )); - } -} - -pub fn add_poseidon_24_core_row( - traces: &mut BTreeMap, - multiplicity: usize, - input: [F; 24], - res: [F; 8], -) { - let trace = traces.get_mut(&Table::poseidon24_core()).unwrap(); - - trace.base[POSEIDON_24_CORE_COL_FLAG].push(F::from_usize(multiplicity)); - for (i, value) in input.iter().enumerate() { - trace.base[POSEIDON_24_CORE_COL_INPUT_START + i].push(*value); - } - for (i, value) in res.iter().enumerate() { - trace.base[POSEIDON_24_CORE_COL_OUTPUT_START + i].push(*value); - } -} diff --git a/crates/lean_vm/src/tables/poseidon_24/from_memory.rs b/crates/lean_vm/src/tables/poseidon_24/from_memory.rs deleted file mode 100644 index d8ff548a..00000000 --- a/crates/lean_vm/src/tables/poseidon_24/from_memory.rs +++ /dev/null @@ -1,220 +0,0 @@ -use crate::*; -use multilinear_toolkit::prelude::*; -use p3_air::Air; -use std::array; -use utils::{ToUsize, get_poseidon_24_of_zero, poseidon24_permute}; - -const POSEIDON_24_MEM_COL_FLAG: ColIndex = 0; -const POSEIDON_24_MEM_COL_INDEX_A: ColIndex = 1; -const POSEIDON_24_MEM_COL_INDEX_A_BIS: ColIndex = 2; -const POSEIDON_24_MEM_COL_INDEX_B: ColIndex = 3; -const POSEIDON_24_MEM_COL_INDEX_RES: ColIndex = 4; -const POSEIDON_24_MEM_COL_INPUT_START: ColIndex = 5; -const POSEIDON_24_MEM_COL_OUTPUT_START: ColIndex = POSEIDON_24_MEM_COL_INPUT_START + 24; -// intermediate columns ("commited cubes") are not handled here - -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct Poseidon24MemPrecompile; - -impl TableT for Poseidon24MemPrecompile { - fn name(&self) -> &'static str { - "poseidon24" - } - - fn identifier(&self) -> Table { - Table::poseidon24_mem() - } - - fn n_columns_f_total(&self) -> usize { - 5 + 24 + 8 - } - - fn commited_columns_f(&self) -> Vec { - vec![ - POSEIDON_24_MEM_COL_FLAG, - POSEIDON_24_MEM_COL_INDEX_A, - POSEIDON_24_MEM_COL_INDEX_A_BIS, - POSEIDON_24_MEM_COL_INDEX_B, - POSEIDON_24_MEM_COL_INDEX_RES, - ] // indexes only here (committed cubes are handled elsewhere) - } - - fn commited_columns_ef(&self) -> Vec { - vec![] - } - - fn normal_lookups_f(&self) -> Vec { - vec![] - } - - fn normal_lookups_ef(&self) -> Vec { - vec![] - } - - fn vector_lookups(&self) -> Vec { - vec![ - VectorLookupIntoMemory { - index: POSEIDON_24_MEM_COL_INDEX_A, - values: array::from_fn(|i| POSEIDON_24_MEM_COL_INPUT_START + i), - }, - VectorLookupIntoMemory { - index: POSEIDON_24_MEM_COL_INDEX_A_BIS, - values: array::from_fn(|i| POSEIDON_24_MEM_COL_INPUT_START + VECTOR_LEN + i), - }, - VectorLookupIntoMemory { - index: POSEIDON_24_MEM_COL_INDEX_B, - values: array::from_fn(|i| POSEIDON_24_MEM_COL_INPUT_START + 2 * VECTOR_LEN + i), - }, - VectorLookupIntoMemory { - index: POSEIDON_24_MEM_COL_INDEX_RES, - values: array::from_fn(|i| POSEIDON_24_MEM_COL_OUTPUT_START + i), - }, - ] - } - - fn buses(&self) -> Vec { - vec![ - Bus { - table: BusTable::Constant(self.identifier()), - direction: BusDirection::Pull, - selector: BusSelector::Column(POSEIDON_24_MEM_COL_FLAG), - data: vec![ - POSEIDON_24_MEM_COL_INDEX_A, - POSEIDON_24_MEM_COL_INDEX_B, - POSEIDON_24_MEM_COL_INDEX_RES, - ], - }, - Bus { - table: BusTable::Constant(Table::poseidon24_core()), - direction: BusDirection::Push, - selector: BusSelector::ConstantOne, - data: [ - (POSEIDON_24_MEM_COL_INPUT_START..POSEIDON_24_MEM_COL_INPUT_START + 24).collect::>(), - (POSEIDON_24_MEM_COL_OUTPUT_START..POSEIDON_24_MEM_COL_OUTPUT_START + 8).collect::>(), - ] - .concat(), - }, - ] - } - - fn padding_row_f(&self) -> Vec { - [ - vec![ - F::ZERO, - F::from_usize(ZERO_VEC_PTR), - F::from_usize(ZERO_VEC_PTR + 1), - F::from_usize(ZERO_VEC_PTR), - F::from_usize(POSEIDON_24_NULL_HASH_PTR), - ], - vec![F::ZERO; 24], - get_poseidon_24_of_zero()[16..].to_vec(), - ] - .concat() - } - - fn padding_row_ef(&self) -> Vec { - vec![] - } - - #[inline(always)] - fn execute( - &self, - arg_a: F, - arg_b: F, - res: F, - aux: usize, - ctx: &mut InstructionContext<'_>, - ) -> Result<(), RunnerError> { - assert_eq!(aux, 0); // no aux for poseidon24 - let trace = ctx.traces.get_mut(&self.identifier()).unwrap(); - - let arg0 = ctx.memory.get_vector(arg_a.to_usize())?; - let arg1 = ctx.memory.get_vector(1 + arg_a.to_usize())?; - let arg2 = ctx.memory.get_vector(arg_b.to_usize())?; - - let mut input = [F::ZERO; VECTOR_LEN * 3]; - input[..VECTOR_LEN].copy_from_slice(&arg0); - input[VECTOR_LEN..2 * VECTOR_LEN].copy_from_slice(&arg1); - input[2 * VECTOR_LEN..].copy_from_slice(&arg2); - - let output = match ctx.poseidon24_precomputed.get(*ctx.n_poseidon24_precomputed_used) { - Some(precomputed) if precomputed.0 == input => { - *ctx.n_poseidon24_precomputed_used += 1; - precomputed.1 - } - _ => { - let output = poseidon24_permute(input); - output[2 * VECTOR_LEN..].try_into().unwrap() - } - }; - - ctx.memory.set_vector(res.to_usize(), output)?; - - trace.base[POSEIDON_24_MEM_COL_FLAG].push(F::ONE); - trace.base[POSEIDON_24_MEM_COL_INDEX_A].push(arg_a); - trace.base[POSEIDON_24_MEM_COL_INDEX_A_BIS].push(arg_a + F::ONE); - trace.base[POSEIDON_24_MEM_COL_INDEX_B].push(arg_b); - trace.base[POSEIDON_24_MEM_COL_INDEX_RES].push(res); - for (i, value) in input.iter().enumerate() { - trace.base[POSEIDON_24_MEM_COL_INPUT_START + i].push(*value); - } - for (i, value) in output.iter().enumerate() { - trace.base[POSEIDON_24_MEM_COL_OUTPUT_START + i].push(*value); - } - - add_poseidon_24_core_row(ctx.traces, 1, input, output); - - Ok(()) - } -} - -impl Air for Poseidon24MemPrecompile { - type ExtraData = ExtraDataForBuses; - fn n_columns_f_air(&self) -> usize { - 5 + 24 + 8 - } - fn n_columns_ef_air(&self) -> usize { - 0 - } - fn degree(&self) -> usize { - 2 - } - fn down_column_indexes_f(&self) -> Vec { - vec![] - } - fn down_column_indexes_ef(&self) -> Vec { - vec![] - } - fn n_constraints(&self) -> usize { - 4 - } - fn eval(&self, builder: &mut AB, extra_data: &Self::ExtraData) { - let up = builder.up_f(); - let flag = up[POSEIDON_24_MEM_COL_FLAG].clone(); - let index_res = up[POSEIDON_24_MEM_COL_INDEX_RES].clone(); - let index_input_a = up[POSEIDON_24_MEM_COL_INDEX_A].clone(); - let index_input_a_bis = up[POSEIDON_24_MEM_COL_INDEX_A_BIS].clone(); - let index_b = up[POSEIDON_24_MEM_COL_INDEX_B].clone(); - - let mut core_bus_data = [AB::F::ZERO; 24 + 8]; - core_bus_data[0..24].clone_from_slice(&up[POSEIDON_24_MEM_COL_INPUT_START..][..24]); - core_bus_data[24..32].clone_from_slice(&up[POSEIDON_24_MEM_COL_OUTPUT_START..][..8]); - - builder.eval_virtual_column(eval_virtual_bus_column::( - extra_data, - AB::F::from_usize(self.identifier().index()), - flag.clone(), - &[index_input_a.clone(), index_b, index_res, AB::F::ZERO], - )); - - builder.eval_virtual_column(eval_virtual_bus_column::( - extra_data, - AB::F::from_usize(Table::poseidon24_core().index()), - AB::F::ONE, - &core_bus_data, - )); - - builder.assert_bool(flag); - builder.assert_eq(index_input_a_bis, index_input_a + AB::F::ONE); - } -} diff --git a/crates/lean_vm/src/tables/poseidon_24/mod.rs b/crates/lean_vm/src/tables/poseidon_24/mod.rs deleted file mode 100644 index 96e662e4..00000000 --- a/crates/lean_vm/src/tables/poseidon_24/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -mod core; -pub use core::*; - -mod from_memory; -pub use from_memory::*; diff --git a/crates/lean_vm/src/tables/slice_hash/mod.rs b/crates/lean_vm/src/tables/slice_hash/mod.rs deleted file mode 100644 index 16156c68..00000000 --- a/crates/lean_vm/src/tables/slice_hash/mod.rs +++ /dev/null @@ -1,292 +0,0 @@ -use std::array; - -use crate::*; -use multilinear_toolkit::prelude::*; -use p3_air::Air; -use utils::{ToUsize, get_poseidon_24_of_zero, poseidon24_permute}; - -// Does not support len = 1 (minimum len is 2) - -// "committed" columns -const COL_FLAG: ColIndex = 0; -const COL_INDEX_SEED: ColIndex = 1; // vectorized pointer -const COL_INDEX_START: ColIndex = 2; // vectorized pointer -const COL_INDEX_START_BIS: ColIndex = 3; // = COL_INDEX_START + 1 -const COL_INDEX_RES: ColIndex = 4; // vectorized pointer -const COL_LEN: ColIndex = 5; - -const COL_LOOKUP_MEM_INDEX_SEED_OR_RES: ColIndex = 6; // = COL_INDEX_START if flag = 1, otherwise = COL_INDEX_RES -const INITIAL_COLS_DATA_RIGHT: ColIndex = 7; -const INITIAL_COLS_DATA_RES: ColIndex = INITIAL_COLS_DATA_RIGHT + VECTOR_LEN; - -// "virtual" columns (vectorized lookups into memory) -const COL_LOOKUP_MEM_VALUES_SEED_OR_RES: ColIndex = INITIAL_COLS_DATA_RES + VECTOR_LEN; // 8 columns -const COL_LOOKUP_MEM_VALUES_LEFT: ColIndex = COL_LOOKUP_MEM_VALUES_SEED_OR_RES + VECTOR_LEN; // 16 columns - -const TOTAL_N_COLS: usize = COL_LOOKUP_MEM_VALUES_LEFT + 2 * VECTOR_LEN; - -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct SliceHashPrecompile; - -impl TableT for SliceHashPrecompile { - fn name(&self) -> &'static str { - "slice_hash" - } - - fn identifier(&self) -> Table { - Table::slice_hash() - } - - fn commited_columns_f(&self) -> Vec { - (0..COL_LOOKUP_MEM_VALUES_SEED_OR_RES).collect() - } - - fn commited_columns_ef(&self) -> Vec { - vec![] - } - - fn normal_lookups_f(&self) -> Vec { - vec![] - } - - fn normal_lookups_ef(&self) -> Vec { - vec![] - } - - fn vector_lookups(&self) -> Vec { - vec![ - VectorLookupIntoMemory { - index: COL_LOOKUP_MEM_INDEX_SEED_OR_RES, - values: array::from_fn(|i| COL_LOOKUP_MEM_VALUES_SEED_OR_RES + i), - }, - VectorLookupIntoMemory { - index: COL_INDEX_START, - values: array::from_fn(|i| COL_LOOKUP_MEM_VALUES_LEFT + i), - }, - VectorLookupIntoMemory { - index: COL_INDEX_START_BIS, - values: array::from_fn(|i| COL_LOOKUP_MEM_VALUES_LEFT + VECTOR_LEN + i), - }, - ] - } - - fn buses(&self) -> Vec { - vec![ - Bus { - table: BusTable::Constant(self.identifier()), - direction: BusDirection::Pull, - selector: BusSelector::Column(COL_FLAG), - data: vec![COL_INDEX_SEED, COL_INDEX_START, COL_INDEX_RES, COL_LEN], - }, - Bus { - table: BusTable::Constant(Table::poseidon24_core()), - direction: BusDirection::Push, - selector: BusSelector::ConstantOne, - data: [ - (COL_LOOKUP_MEM_VALUES_LEFT..COL_LOOKUP_MEM_VALUES_LEFT + 16).collect::>(), - (INITIAL_COLS_DATA_RIGHT..INITIAL_COLS_DATA_RIGHT + 8).collect::>(), - (INITIAL_COLS_DATA_RES..INITIAL_COLS_DATA_RES + 8).collect::>(), - ] - .concat(), - }, - ] - } - - fn padding_row_f(&self) -> Vec { - let default_hash = get_poseidon_24_of_zero()[2 * VECTOR_LEN..].to_vec(); - [ - vec![ - F::ONE, // flag - F::from_usize(ZERO_VEC_PTR), // index seed - F::from_usize(ZERO_VEC_PTR), // index_start - F::from_usize(ZERO_VEC_PTR + 1), // index_start_bis - F::from_usize(ZERO_VEC_PTR), // index_res - F::ONE, // len - F::from_usize(ZERO_VEC_PTR), // COL_LOOKUP_MEM_INDEX_SEED_OR_RES - ], - vec![F::ZERO; VECTOR_LEN], // INITIAL_COLS_DATA_RIGHT - default_hash, // INITIAL_COLS_DATA_RES - vec![F::ZERO; VECTOR_LEN], // COL_LOOKUP_MEM_VALUES_SEED_OR_RES - vec![F::ZERO; VECTOR_LEN * 2], // COL_LOOKUP_MEM_VALUES_LEFT - ] - .concat() - } - - fn padding_row_ef(&self) -> Vec { - vec![] - } - - #[inline(always)] - fn execute( - &self, - index_seed: F, - index_start: F, - index_res: F, - len: usize, - ctx: &mut InstructionContext<'_>, - ) -> Result<(), RunnerError> { - assert!(len >= 2); - - let seed = ctx.memory.get_vector(index_seed.to_usize())?; - let mut cap = seed; - for i in 0..len { - let index = index_start.to_usize() + i * 2; - - let mut input = [F::ZERO; VECTOR_LEN * 3]; - input[..VECTOR_LEN].copy_from_slice(&ctx.memory.get_vector(index)?); - input[VECTOR_LEN..VECTOR_LEN * 2].copy_from_slice(&ctx.memory.get_vector(index + 1)?); - input[VECTOR_LEN * 2..].copy_from_slice(&cap); - // let output: [F; VECTOR_LEN] = poseidon24_permute(input)[VECTOR_LEN * 2..].try_into().unwrap(); - - let output = match ctx.poseidon24_precomputed.get(*ctx.n_poseidon24_precomputed_used) { - Some(precomputed) if precomputed.0 == input => { - *ctx.n_poseidon24_precomputed_used += 1; - precomputed.1 - } - _ => poseidon24_permute(input)[VECTOR_LEN * 2..].try_into().unwrap(), - }; - - let trace = &mut ctx.traces.get_mut(&self.identifier()).unwrap().base; - - for j in 0..VECTOR_LEN * 2 { - trace[COL_LOOKUP_MEM_VALUES_LEFT + j].push(input[j]); - } - for j in 0..VECTOR_LEN { - trace[INITIAL_COLS_DATA_RIGHT + j].push(cap[j]); - } - for j in 0..VECTOR_LEN { - trace[INITIAL_COLS_DATA_RES + j].push(output[j]); - } - - add_poseidon_24_core_row(ctx.traces, 1, input, output); - - cap = output; - } - let trace = &mut ctx.traces.get_mut(&self.identifier()).unwrap().base; - - let final_res = cap; - ctx.memory.set_vector(index_res.to_usize(), final_res)?; - - trace[COL_FLAG].extend([vec![F::ONE], vec![F::ZERO; len - 1]].concat()); - trace[COL_INDEX_SEED].extend(vec![index_seed; len]); - trace[COL_INDEX_START].extend((0..len).map(|i| index_start + F::from_usize(i * 2))); - trace[COL_INDEX_START_BIS].extend((0..len).map(|i| index_start + F::from_usize(i * 2 + 1))); - trace[COL_INDEX_RES].extend(vec![index_res; len]); - trace[COL_LEN].extend((1..=len).rev().map(F::from_usize)); - trace[COL_LOOKUP_MEM_INDEX_SEED_OR_RES].extend([vec![index_seed], vec![index_res; len - 1]].concat()); - for i in 0..VECTOR_LEN { - trace[COL_LOOKUP_MEM_VALUES_SEED_OR_RES + i].extend([vec![seed[i]], vec![final_res[i]; len - 1]].concat()); - } - - Ok(()) - } -} - -impl Air for SliceHashPrecompile { - type ExtraData = ExtraDataForBuses; - fn n_columns_f_air(&self) -> usize { - TOTAL_N_COLS - } - fn n_columns_ef_air(&self) -> usize { - 0 - } - fn degree(&self) -> usize { - 3 - } - fn down_column_indexes_f(&self) -> Vec { - (0..INITIAL_COLS_DATA_RES).collect() - } - fn down_column_indexes_ef(&self) -> Vec { - vec![] - } - fn n_constraints(&self) -> usize { - 8 + 5 * VECTOR_LEN - } - fn eval(&self, builder: &mut AB, extra_data: &Self::ExtraData) { - let up = builder.up_f(); - let flag = up[COL_FLAG].clone(); - let index_seed = up[COL_INDEX_SEED].clone(); - let index_start = up[COL_INDEX_START].clone(); - let index_start_bis = up[COL_INDEX_START_BIS].clone(); - let index_res = up[COL_INDEX_RES].clone(); - let len = up[COL_LEN].clone(); - let lookup_index_seed_or_res = up[COL_LOOKUP_MEM_INDEX_SEED_OR_RES].clone(); - let data_right: [_; VECTOR_LEN] = array::from_fn(|i| up[INITIAL_COLS_DATA_RIGHT + i].clone()); - let data_res: [_; VECTOR_LEN] = array::from_fn(|i| up[INITIAL_COLS_DATA_RES + i].clone()); - let data_seed_or_res_lookup_values: [_; VECTOR_LEN] = - array::from_fn(|i| up[COL_LOOKUP_MEM_VALUES_SEED_OR_RES + i].clone()); - - let down = builder.down_f(); - let flag_down = down[0].clone(); - let index_seed_down = down[1].clone(); - let index_start_down = down[2].clone(); - let _index_start_bis_down = down[3].clone(); - let index_res_down = down[4].clone(); - let len_down = down[5].clone(); - let _lookup_index_seed_or_res_down = down[6].clone(); - let data_right_down: [_; VECTOR_LEN] = array::from_fn(|i| down[7 + i].clone()); - - let mut core_bus_data = [AB::F::ZERO; 24 + 8]; - core_bus_data[0..16].clone_from_slice(&up[COL_LOOKUP_MEM_VALUES_LEFT..][..16]); - core_bus_data[16..24].clone_from_slice(&up[INITIAL_COLS_DATA_RIGHT..][..8]); - core_bus_data[24..32].clone_from_slice(&up[INITIAL_COLS_DATA_RES..][..8]); - - builder.eval_virtual_column(eval_virtual_bus_column::( - extra_data, - AB::F::from_usize(self.identifier().index()), - flag.clone(), - &[index_seed.clone(), index_start.clone(), index_res.clone(), len.clone()], - )); - - builder.eval_virtual_column(eval_virtual_bus_column::( - extra_data, - AB::F::from_usize(Table::poseidon24_core().index()), - AB::F::ONE, - &core_bus_data, - )); - - // TODO double check constraints - - builder.assert_bool(flag.clone()); - - let not_flag = AB::F::ONE - flag.clone(); - let not_flag_down = AB::F::ONE - flag_down.clone(); - - builder.assert_eq( - lookup_index_seed_or_res.clone(), - flag.clone() * index_seed.clone() + not_flag.clone() * index_res.clone(), - ); - - // index_start_bis = index_start + 1 - builder.assert_eq(index_start_bis.clone(), index_start.clone() + AB::F::ONE); - - // Parameters should not change as long as the flag has not been switched back to 1: - builder.assert_zero(not_flag_down.clone() * (index_seed_down.clone() - index_seed.clone())); - builder.assert_zero(not_flag_down.clone() * (index_res_down.clone() - index_res.clone())); - - builder.assert_zero(not_flag_down.clone() * (index_start_down.clone() - (index_start.clone() + AB::F::TWO))); - - // decrease len by 1 each step - builder.assert_zero(not_flag_down.clone() * (len_down.clone() + AB::F::ONE - len.clone())); - - // start: ingest the seed - for i in 0..VECTOR_LEN { - builder.assert_zero(flag.clone() * (data_right[i].clone() - data_seed_or_res_lookup_values[i].clone())); - } - - // transition - for i in 0..VECTOR_LEN { - builder.assert_zero(not_flag_down.clone() * (data_res[i].clone() - data_right_down[i].clone())); - } - - // end - builder.assert_zero(flag_down.clone() * (len.clone() - AB::F::ONE)); // at last step, len should be 1 - for i in 0..VECTOR_LEN { - builder.assert_zero( - not_flag.clone() - * flag_down.clone() - * (data_res[i].clone() - data_seed_or_res_lookup_values[i].clone()), - ); - } - } -} diff --git a/crates/lean_vm/src/tables/table_enum.rs b/crates/lean_vm/src/tables/table_enum.rs index 4de9a519..a74acd6b 100644 --- a/crates/lean_vm/src/tables/table_enum.rs +++ b/crates/lean_vm/src/tables/table_enum.rs @@ -1,35 +1,17 @@ -use multilinear_toolkit::prelude::{PF, PrimeCharacteristicRing}; -use p3_air::Air; +use multilinear_toolkit::prelude::*; +use utils::MEMORY_TABLE_INDEX; use crate::*; -pub const N_TABLES: usize = 10; -pub const ALL_TABLES: [Table; N_TABLES] = [ - Table::execution(), - Table::dot_product_be(), - Table::dot_product_ee(), - Table::poseidon16_core(), - Table::poseidon16_mem(), - Table::poseidon24_core(), - Table::poseidon24_mem(), - Table::merkle(), - Table::slice_hash(), - Table::eq_poly_base_ext(), -]; +pub const N_TABLES: usize = 3; +pub const ALL_TABLES: [Table; N_TABLES] = [Table::execution(), Table::dot_product(), Table::poseidon16()]; #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] #[repr(usize)] pub enum Table { - Execution(ExecutionTable), - DotProductBE(DotProductPrecompile), - DotProductEE(DotProductPrecompile), - Poseidon16Core(Poseidon16CorePrecompile), - Poseidon16Mem(Poseidon16MemPrecompile), - Poseidon24Core(Poseidon24CorePrecompile), - Poseidon24Mem(Poseidon24MemPrecompile), - Merkle(MerklePrecompile), - SliceHash(SliceHashPrecompile), - EqPolyBaseExt(EqPolyBaseExtPrecompile), + Execution(ExecutionTable), + DotProduct(DotProductPrecompile), + Poseidon16(Poseidon16Precompile), } #[macro_export] @@ -37,31 +19,17 @@ macro_rules! delegate_to_inner { // Existing pattern for method calls ($self:expr, $method:ident $(, $($arg:expr),*)?) => { match $self { - Self::DotProductBE(p) => p.$method($($($arg),*)?), - Self::DotProductEE(p) => p.$method($($($arg),*)?), - Self::Poseidon16Core(p) => p.$method($($($arg),*)?), - Self::Poseidon16Mem(p) => p.$method($($($arg),*)?), - Self::Poseidon24Core(p) => p.$method($($($arg),*)?), - Self::Poseidon24Mem(p) => p.$method($($($arg),*)?), + Self::DotProduct(p) => p.$method($($($arg),*)?), + Self::Poseidon16(p) => p.$method($($($arg),*)?), Self::Execution(p) => p.$method($($($arg),*)?), - Self::Merkle(p) => p.$method($($($arg),*)?), - Self::SliceHash(p) => p.$method($($($arg),*)?), - Self::EqPolyBaseExt(p) => p.$method($($($arg),*)?), } }; // New pattern for applying a macro to the inner value ($self:expr => $macro_name:ident) => { match $self { - Table::DotProductBE(p) => $macro_name!(p), - Table::DotProductEE(p) => $macro_name!(p), - Table::Poseidon16Core(p) => $macro_name!(p), - Table::Poseidon16Mem(p) => $macro_name!(p), - Table::Poseidon24Core(p) => $macro_name!(p), - Table::Poseidon24Mem(p) => $macro_name!(p), + Table::DotProduct(p) => $macro_name!(p), + Table::Poseidon16(p) => $macro_name!(p), Table::Execution(p) => $macro_name!(p), - Table::Merkle(p) => $macro_name!(p), - Table::SliceHash(p) => $macro_name!(p), - Table::EqPolyBaseExt(p) => $macro_name!(p), } }; } @@ -70,44 +38,17 @@ impl Table { pub const fn execution() -> Self { Self::Execution(ExecutionTable) } - pub const fn dot_product_be() -> Self { - Self::DotProductBE(DotProductPrecompile::) + pub const fn dot_product() -> Self { + Self::DotProduct(DotProductPrecompile) } - pub const fn dot_product_ee() -> Self { - Self::DotProductEE(DotProductPrecompile::) - } - pub const fn poseidon16_core() -> Self { - Self::Poseidon16Core(Poseidon16CorePrecompile) - } - pub const fn poseidon16_mem() -> Self { - Self::Poseidon16Mem(Poseidon16MemPrecompile) - } - pub const fn poseidon24_core() -> Self { - Self::Poseidon24Core(Poseidon24CorePrecompile) - } - pub const fn poseidon24_mem() -> Self { - Self::Poseidon24Mem(Poseidon24MemPrecompile) - } - pub const fn merkle() -> Self { - Self::Merkle(MerklePrecompile) - } - pub const fn slice_hash() -> Self { - Self::SliceHash(SliceHashPrecompile) - } - pub const fn eq_poly_base_ext() -> Self { - Self::EqPolyBaseExt(EqPolyBaseExtPrecompile) + pub const fn poseidon16() -> Self { + Self::Poseidon16(Poseidon16Precompile) } pub fn embed(&self) -> PF { PF::from_usize(self.index()) } pub const fn index(&self) -> usize { - unsafe { *(self as *const Self as *const usize) } - } - pub fn is_poseidon(&self) -> bool { - matches!( - self, - Table::Poseidon16Core(_) | Table::Poseidon16Mem(_) | Table::Poseidon24Core(_) | Table::Poseidon24Mem(_) - ) + unsafe { *(self as *const Self as *const usize) + MEMORY_TABLE_INDEX + 1 } } } @@ -115,26 +56,20 @@ impl TableT for Table { fn name(&self) -> &'static str { delegate_to_inner!(self, name) } - fn identifier(&self) -> Table { - delegate_to_inner!(self, identifier) - } - fn commited_columns_f(&self) -> Vec { - delegate_to_inner!(self, commited_columns_f) + fn table(&self) -> Table { + delegate_to_inner!(self, table) } - fn commited_columns_ef(&self) -> Vec { - delegate_to_inner!(self, commited_columns_ef) + fn lookups_f(&self) -> Vec { + delegate_to_inner!(self, lookups_f) } - fn normal_lookups_f(&self) -> Vec { - delegate_to_inner!(self, normal_lookups_f) + fn lookups_ef(&self) -> Vec { + delegate_to_inner!(self, lookups_ef) } - fn normal_lookups_ef(&self) -> Vec { - delegate_to_inner!(self, normal_lookups_ef) + fn is_execution_table(&self) -> bool { + delegate_to_inner!(self, is_execution_table) } - fn vector_lookups(&self) -> Vec { - delegate_to_inner!(self, vector_lookups) - } - fn buses(&self) -> Vec { - delegate_to_inner!(self, buses) + fn bus(&self) -> Bus { + delegate_to_inner!(self, bus) } fn padding_row_f(&self) -> Vec> { delegate_to_inner!(self, padding_row_f) @@ -147,10 +82,11 @@ impl TableT for Table { arg_a: F, arg_b: F, arg_c: F, - aux: usize, + aux_1: usize, + aux_2: usize, ctx: &mut InstructionContext<'_>, ) -> Result<(), RunnerError> { - delegate_to_inner!(self, execute, arg_a, arg_b, arg_c, aux, ctx) + delegate_to_inner!(self, execute, arg_a, arg_b, arg_c, aux_1, aux_2, ctx) } fn n_columns_f_total(&self) -> usize { delegate_to_inner!(self, n_columns_f_total) @@ -162,8 +98,8 @@ impl TableT for Table { impl Air for Table { type ExtraData = (); - fn degree(&self) -> usize { - delegate_to_inner!(self, degree) + fn degree_air(&self) -> usize { + delegate_to_inner!(self, degree_air) } fn n_columns_f_air(&self) -> usize { delegate_to_inner!(self, n_columns_f_air) @@ -180,17 +116,17 @@ impl Air for Table { fn down_column_indexes_ef(&self) -> Vec { delegate_to_inner!(self, down_column_indexes_ef) } - fn eval(&self, _: &mut AB, _: &Self::ExtraData) { + fn eval(&self, _: &mut AB, _: &Self::ExtraData) { unreachable!() } } pub fn max_bus_width() -> usize { - 1 + ALL_TABLES - .iter() - .map(|table| table.buses().iter().map(|bus| bus.data.len()).max().unwrap()) - .max() - .unwrap() + 1 + ALL_TABLES.iter().map(|table| table.bus().data.len()).max().unwrap() +} + +pub fn max_air_constraints() -> usize { + ALL_TABLES.iter().map(|table| table.n_constraints()).max().unwrap() } #[cfg(test)] @@ -200,7 +136,7 @@ mod tests { #[test] fn test_table_indices() { for (i, table) in ALL_TABLES.iter().enumerate() { - assert_eq!(table.index(), i); + assert_eq!(table.index(), i + MEMORY_TABLE_INDEX + 1); } } } diff --git a/crates/lean_vm/src/tables/table_trait.rs b/crates/lean_vm/src/tables/table_trait.rs index f9068e49..4243090b 100644 --- a/crates/lean_vm/src/tables/table_trait.rs +++ b/crates/lean_vm/src/tables/table_trait.rs @@ -1,25 +1,18 @@ -use crate::{EF, F, InstructionContext, RunnerError, Table, VECTOR_LEN}; +use crate::{DIMENSION, EF, F, InstructionContext, N_COMMITTED_EXEC_COLUMNS, RunnerError, Table}; use multilinear_toolkit::prelude::*; -use p3_air::Air; -use std::{any::TypeId, array, mem::transmute}; -use utils::ToUsize; -use sub_protocols::{ - ColDims, ExtensionCommitmentFromBaseProver, ExtensionCommitmentFromBaseVerifier, committed_dims_extension_from_base, -}; - -// Zero padding will be added to each at least, if this minimum is not reached -// (ensuring AIR / GKR work fine, with SIMD, without too much edge cases) -// Long term, we should find a more elegant solution. -pub const MIN_LOG_N_ROWS_PER_TABLE: usize = 8; -pub const MIN_N_ROWS_PER_TABLE: usize = 1 << MIN_LOG_N_ROWS_PER_TABLE; +use std::{any::TypeId, cmp::Reverse, collections::BTreeMap, mem::transmute}; +use utils::VarCount; pub type ColIndex = usize; +pub type CommittedStatements = BTreeMap, BTreeMap)>>; + #[derive(Debug)] pub struct LookupIntoMemory { pub index: ColIndex, // should be in base field columns - pub values: ColIndex, + /// For (i, col_index) in values.iter().enumerate(), For j in 0..num_rows, columns_f[col_index][j] = memory[index[j] + i] + pub values: Vec, } #[derive(Debug)] @@ -28,12 +21,6 @@ pub struct ExtensionFieldLookupIntoMemory { pub values: ColIndex, } -#[derive(Debug)] -pub struct VectorLookupIntoMemory { - pub index: ColIndex, // should be in base field columns - pub values: [ColIndex; 8], -} - #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum BusDirection { Pull, @@ -49,17 +36,11 @@ impl BusDirection { } } -#[derive(Debug)] -pub enum BusSelector { - Column(ColIndex), - ConstantOne, -} - #[derive(Debug)] pub struct Bus { pub direction: BusDirection, pub table: BusTable, - pub selector: BusSelector, + pub selector: ColIndex, pub data: Vec, // For now, we only supports F (base field) columns as bus data } @@ -69,33 +50,11 @@ pub enum BusTable { Variable(ColIndex), } -#[derive(Debug, Clone, Copy, Default)] -pub struct TableHeight(pub usize); - -impl TableHeight { - pub fn n_rows_non_padded(self) -> usize { - self.0 - } - pub fn n_rows_non_padded_maxed(self) -> usize { - self.0.max(MIN_N_ROWS_PER_TABLE) - } - pub fn n_rows_padded(self) -> usize { - self.0.next_power_of_two().max(MIN_N_ROWS_PER_TABLE) - } - pub fn padding_len(self) -> usize { - self.n_rows_padded() - self.0 - } - pub fn log_padded(self) -> usize { - log2_strict_usize(self.n_rows_padded()) - } -} - -#[derive(Debug, Default, derive_more::Deref)] +#[derive(Debug, Default)] pub struct TableTrace { pub base: Vec>, pub ext: Vec>, - #[deref] - pub height: TableHeight, + pub log_n_rows: VarCount, } impl TableTrace { @@ -103,16 +62,22 @@ impl TableTrace { Self { base: vec![Vec::new(); air.n_columns_f_total()], ext: vec![Vec::new(); air.n_columns_ef_total()], - height: TableHeight::default(), // filled later + log_n_rows: 0, // filled later } } } -#[derive(Debug)] +pub fn sort_tables_by_height(table_heights: &BTreeMap) -> Vec<(Table, usize)> { + let mut tables_heights_sorted = table_heights.clone().into_iter().collect::>(); + tables_heights_sorted.sort_by_key(|&(_, h)| Reverse(h)); + tables_heights_sorted +} + +#[derive(Debug, Default)] pub struct ExtraDataForBuses>> { // GKR quotient challenges - pub fingerprint_challenge_powers: Vec, - pub fingerprint_challenge_powers_packed: Vec>, + pub logup_alpha_powers: Vec, + pub logup_alpha_powers_packed: Vec>, pub bus_beta: EF, pub bus_beta_packed: EFPacking, pub alpha_powers: Vec, @@ -133,17 +98,12 @@ impl AlphaPowers for ExtraDataForBuses { impl>> ExtraDataForBuses { pub fn transmute_bus_data(&self) -> (&Vec, &NewEF) { if TypeId::of::() == TypeId::of::() { - unsafe { - transmute::<(&Vec, &EF), (&Vec, &NewEF)>(( - &self.fingerprint_challenge_powers, - &self.bus_beta, - )) - } + unsafe { transmute::<(&Vec, &EF), (&Vec, &NewEF)>((&self.logup_alpha_powers, &self.bus_beta)) } } else { assert_eq!(TypeId::of::(), TypeId::of::>()); unsafe { transmute::<(&Vec>, &EFPacking), (&Vec, &NewEF)>(( - &self.fingerprint_challenge_powers_packed, + &self.logup_alpha_powers_packed, &self.bus_beta_packed, )) } @@ -155,14 +115,10 @@ impl>> ExtraDataForBuses { /// (Some columns may not appear in the AIR) pub trait TableT: Air { fn name(&self) -> &'static str; - fn identifier(&self) -> Table; - fn commited_columns_f(&self) -> Vec; - /// the first committed column in the extension starts at index 0 - fn commited_columns_ef(&self) -> Vec; - fn normal_lookups_f(&self) -> Vec; - fn normal_lookups_ef(&self) -> Vec; - fn vector_lookups(&self) -> Vec; - fn buses(&self) -> Vec; + fn table(&self) -> Table; + fn lookups_f(&self) -> Vec; + fn lookups_ef(&self) -> Vec; + fn bus(&self) -> Bus; fn padding_row_f(&self) -> Vec; fn padding_row_ef(&self) -> Vec; fn execute( @@ -170,7 +126,8 @@ pub trait TableT: Air { arg_a: F, arg_b: F, arg_c: F, - aux: usize, + aux_1: usize, + aux_2: usize, ctx: &mut InstructionContext<'_>, ) -> Result<(), RunnerError>; @@ -183,272 +140,61 @@ pub trait TableT: Air { self.n_columns_ef_air() } - fn air_padding_row_f(&self) -> Vec { - // only the shited_columns - self.down_column_indexes_f() - .into_iter() - .map(|i| self.padding_row_f()[i]) - .collect() + fn is_execution_table(&self) -> bool { + false } - fn air_padding_row_ef(&self) -> Vec { - // only the shited_columns - self.down_column_indexes_ef() - .into_iter() - .map(|i| self.padding_row_ef()[i]) - .collect() - } - fn committed_dims(&self, n_rows: usize) -> Vec> { - let mut dims = self - .commited_columns_f() - .iter() - .map(|&c| ColDims::padded(n_rows, self.padding_row_f()[c])) - .collect::>(); - dims.extend(committed_dims_extension_from_base( - n_rows, - self.commited_columns_ef() - .iter() - .map(|&c| self.padding_row_ef()[c]) - .collect(), - )); - dims - } - fn num_commited_columns_f(&self) -> usize { - self.commited_columns_f().len() - } - fn n_commited_columns_ef(&self) -> usize { - self.commited_columns_ef().len() - } - fn committed_statements_prover( - &self, - prover_state: &mut FSProver>, - air_point: &MultilinearPoint, - air_values_f: &[EF], - ext_commitment_helper: Option<&ExtensionCommitmentFromBaseProver>, - normal_lookup_statements_on_indexes_f: &mut Vec>>, - normal_lookup_statements_on_indexes_ef: &mut Vec>>, - ) -> Vec>> { - assert_eq!(air_values_f.len(), self.n_columns_f_air()); - let mut statements = self - .commited_columns_f() - .iter() - .map(|&c| vec![Evaluation::new(air_point.clone(), air_values_f[c])]) - .collect::>(); - if let Some(ext_commitment_helper) = ext_commitment_helper { - statements.extend(ext_commitment_helper.after_commitment(prover_state, air_point)); - } - - for lookup in self.normal_lookups_f() { - statements[self.find_committed_column_index_f(lookup.index)] - .extend(normal_lookup_statements_on_indexes_f.remove(0)); - } - for lookup in self.normal_lookups_ef() { - statements[self.find_committed_column_index_f(lookup.index)] - .extend(normal_lookup_statements_on_indexes_ef.remove(0)); + fn n_commited_columns_f(&self) -> usize { + if self.is_execution_table() { + N_COMMITTED_EXEC_COLUMNS + } else { + self.n_columns_f_air() } - - statements } - fn committed_statements_verifier( - &self, - verifier_state: &mut FSVerifier>, - air_point: &MultilinearPoint, - air_values_f: &[EF], - air_values_ef: &[EF], - normal_lookup_statements_on_indexes_f: &mut Vec>>, - normal_lookup_statements_on_indexes_ef: &mut Vec>>, - ) -> ProofResult>>> { - assert_eq!(air_values_f.len(), self.n_columns_f_air()); - assert_eq!(air_values_ef.len(), self.n_columns_ef_air()); - - let mut statements = self - .commited_columns_f() - .iter() - .map(|&c| vec![Evaluation::new(air_point.clone(), air_values_f[c])]) - .collect::>(); - - if self.n_commited_columns_ef() > 0 { - statements.extend(ExtensionCommitmentFromBaseVerifier::after_commitment( - verifier_state, - &MultiEvaluation::new( - air_point.clone(), - self.commited_columns_ef() - .iter() - .map(|&c| air_values_ef[c]) - .collect::>(), - ), - )?); - } - for lookup in self.normal_lookups_f() { - statements[self.find_committed_column_index_f(lookup.index)] - .extend(normal_lookup_statements_on_indexes_f.remove(0)); - } - for lookup in self.normal_lookups_ef() { - statements[self.find_committed_column_index_f(lookup.index)] - .extend(normal_lookup_statements_on_indexes_ef.remove(0)); - } - Ok(statements) - } - fn normal_lookups_statements_f( - &self, - air_point: &MultilinearPoint, - air_values_f: &[EF], - ) -> Vec>> { - assert_eq!(air_values_f.len(), self.n_columns_f_air()); - let mut statements = Vec::new(); - for lookup in self.normal_lookups_f() { - statements.push(vec![Evaluation::new(air_point.clone(), air_values_f[lookup.values])]); - } - statements - } - fn normal_lookups_statements_ef( - &self, - air_point: &MultilinearPoint, - air_values_ef: &[EF], - ) -> Vec>> { - assert_eq!(air_values_ef.len(), self.n_columns_ef_air()); - let mut statements = Vec::new(); - for lookup in self.normal_lookups_ef() { - statements.push(vec![Evaluation::new(air_point.clone(), air_values_ef[lookup.values])]); - } - statements + fn n_commited_columns_ef(&self) -> usize { + self.n_columns_ef_air() } - fn vectorized_lookups_statements( - &self, - air_point: &MultilinearPoint, - air_values_f: &[EF], - ) -> Vec>> { - assert_eq!(air_values_f.len(), self.n_columns_f_air()); - let mut statements = Vec::new(); - for lookup in self.vector_lookups() { - statements.push(vec![MultiEvaluation::new( - air_point.clone(), - lookup.values.map(|col| air_values_f[col]).to_vec(), - )]); - } - statements + + fn n_commited_columns(&self) -> usize { + self.n_commited_columns_ef() * DIMENSION + self.n_commited_columns_f() } - fn committed_columns<'a>( - &self, - trace: &'a TableTrace, - computation_ext_to_base_helper: Option<&'a ExtensionCommitmentFromBaseProver>, - ) -> Vec<&'a [F]> { - // base field committed columns - let mut cols = self - .commited_columns_f() + + fn commited_air_values(&self, air_evals: &[EF]) -> BTreeMap { + // the intermidiate columns are not commited + // (they correspond to decoded instructions, in execution table, obtained via logup* into the bytecode) + air_evals .iter() - .map(|&c| &trace.base[c][..]) - .collect::>(); - // convert extension field committed columns to base field - if let Some(computation_ext_to_base_helper) = computation_ext_to_base_helper { - cols.extend( - computation_ext_to_base_helper - .sub_columns_to_commit - .iter() - .map(Vec::as_slice), - ); - } - cols + .copied() + .enumerate() + .filter(|(i, _)| *i < self.n_commited_columns_f() || *i >= self.n_columns_f_air()) + .collect::>() } - fn normal_lookup_index_columns_f<'a>(&'a self, trace: &'a TableTrace) -> Vec<&'a [F]> { - self.normal_lookups_f() + + fn lookup_index_columns_f<'a>(&'a self, trace: &'a TableTrace) -> Vec<&'a [F]> { + self.lookups_f() .iter() .map(|lookup| &trace.base[lookup.index][..]) .collect() } - fn normal_lookup_index_columns_ef<'a>(&'a self, trace: &'a TableTrace) -> Vec<&'a [F]> { - self.normal_lookups_ef() + fn lookup_index_columns_ef<'a>(&'a self, trace: &'a TableTrace) -> Vec<&'a [F]> { + self.lookups_ef() .iter() .map(|lookup| &trace.base[lookup.index][..]) .collect() } - fn num_normal_lookups_f(&self) -> usize { - self.normal_lookups_f().len() - } - fn num_normal_lookups_ef(&self) -> usize { - self.normal_lookups_ef().len() - } - fn num_vector_lookups(&self) -> usize { - self.vector_lookups().len() - } - fn vector_lookup_index_columns<'a>(&self, trace: &'a TableTrace) -> Vec<&'a [F]> { + fn lookup_f_value_columns<'a>(&self, trace: &'a TableTrace) -> Vec> { let mut cols = Vec::new(); - for lookup in self.vector_lookups() { - cols.push(&trace.base[lookup.index][..]); + for lookup in self.lookups_f() { + cols.push(lookup.values.iter().map(|&c| &trace.base[c][..]).collect()); } cols } - fn normal_lookup_f_value_columns<'a>(&self, trace: &'a TableTrace) -> Vec<&'a [F]> { + fn lookup_ef_value_columns<'a>(&self, trace: &'a TableTrace) -> Vec<&'a [EF]> { let mut cols = Vec::new(); - for lookup in self.normal_lookups_f() { - cols.push(&trace.base[lookup.values][..]); - } - cols - } - fn normal_lookup_ef_value_columns<'a>(&self, trace: &'a TableTrace) -> Vec<&'a [EF]> { - let mut cols = Vec::new(); - for lookup in self.normal_lookups_ef() { + for lookup in self.lookups_ef() { cols.push(&trace.ext[lookup.values][..]); } cols } - fn vector_lookup_values_columns<'a>(&self, trace: &'a TableTrace) -> Vec<[&'a [F]; VECTOR_LEN]> { - let mut cols = Vec::new(); - for lookup in self.vector_lookups() { - cols.push(array::from_fn(|i| &trace.base[lookup.values[i]][..])); - } - cols - } - fn normal_lookup_default_indexes_f(&self) -> Vec { - let mut default_indexes = Vec::new(); - for lookup in self.normal_lookups_f() { - default_indexes.push(self.padding_row_f()[lookup.index].to_usize()); - } - default_indexes - } - fn normal_lookup_default_indexes_ef(&self) -> Vec { - let mut default_indexes = Vec::new(); - for lookup in self.normal_lookups_ef() { - default_indexes.push(self.padding_row_f()[lookup.index].to_usize()); - } - default_indexes - } - fn vector_lookup_default_indexes(&self) -> Vec { - let mut default_indexes = Vec::new(); - for lookup in self.vector_lookups() { - default_indexes.push(self.padding_row_f()[lookup.index].to_usize()); - } - default_indexes - } - fn find_committed_column_index_f(&self, col: ColIndex) -> usize { - self.commited_columns_f().iter().position(|&c| c == col).unwrap() - } -} - -pub fn finger_print>(table: F, data: &[F], challenge: EF) -> EF { - dot_product::(challenge.powers().skip(1), data.iter().copied()) + table -} - -impl Bus { - pub fn padding_contribution( - &self, - table: &T, - padding: usize, - bus_challenge: EF, - fingerprint_challenge: EF, - ) -> EF { - let padding_row_f = table.padding_row_f(); - let default_selector = match &self.selector { - BusSelector::ConstantOne => F::ONE, - BusSelector::Column(col) => padding_row_f[*col], - }; - let default_table = match &self.table { - BusTable::Constant(t) => F::from_usize(t.index()), - BusTable::Variable(col) => padding_row_f[*col], - }; - let default_data = self.data.iter().map(|&col| padding_row_f[col]).collect::>(); - EF::from(default_selector * self.direction.to_field_flag() * F::from_usize(padding)) - / (bus_challenge + finger_print(default_table, &default_data, fingerprint_challenge)) - } } diff --git a/crates/lean_vm/src/tables/utils.rs b/crates/lean_vm/src/tables/utils.rs index 222decf7..96b5002b 100644 --- a/crates/lean_vm/src/tables/utils.rs +++ b/crates/lean_vm/src/tables/utils.rs @@ -1,23 +1,22 @@ -use multilinear_toolkit::prelude::{ExtensionField, PF}; -use p3_air::AirBuilder; +use multilinear_toolkit::prelude::*; use crate::ExtraDataForBuses; pub(crate) fn eval_virtual_bus_column>>( extra_data: &ExtraDataForBuses, - precompile_index: AB::F, + bus_index: AB::F, flag: AB::F, data: &[AB::F], ) -> AB::EF { - let (fingerprint_challenge_powers, bus_beta) = extra_data.transmute_bus_data::(); + let (logup_alpha_powers, bus_beta) = extra_data.transmute_bus_data::(); - assert!(data.len() < fingerprint_challenge_powers.len()); - (fingerprint_challenge_powers[1..] + assert!(data.len() < logup_alpha_powers.len()); + (logup_alpha_powers[1..] .iter() .zip(data) .map(|(c, d)| c.clone() * d.clone()) .sum::() - + precompile_index) + + bus_index) * bus_beta.clone() + flag } diff --git a/crates/lookup/Cargo.toml b/crates/lookup/Cargo.toml deleted file mode 100644 index 12354a7f..00000000 --- a/crates/lookup/Cargo.toml +++ /dev/null @@ -1,17 +0,0 @@ -[package] -name = "lookup" -version.workspace = true -edition.workspace = true - -[lints] -workspace = true - -[dependencies] -utils.workspace = true -p3-koala-bear.workspace = true -rand.workspace = true -whir-p3.workspace = true -p3-challenger.workspace = true -tracing.workspace = true -p3-util.workspace = true -multilinear-toolkit.workspace = true diff --git a/crates/lookup/src/lib.rs b/crates/lookup/src/lib.rs deleted file mode 100644 index 02e9ad03..00000000 --- a/crates/lookup/src/lib.rs +++ /dev/null @@ -1,14 +0,0 @@ -/* -Logup* (Lev Soukhanov) - -https://eprint.iacr.org/2025/946.pdf - -*/ - -mod quotient_gkr; -pub use quotient_gkr::*; - -mod logup_star; -pub use logup_star::*; - -pub(crate) const MIN_VARS_FOR_PACKING: usize = 8; diff --git a/crates/lookup/src/quotient_gkr.rs b/crates/lookup/src/quotient_gkr.rs deleted file mode 100644 index 33a287a9..00000000 --- a/crates/lookup/src/quotient_gkr.rs +++ /dev/null @@ -1,340 +0,0 @@ -use std::array; - -use multilinear_toolkit::prelude::*; -use tracing::instrument; -use utils::{FSProver, FSVerifier}; - -use crate::MIN_VARS_FOR_PACKING; - -/* -GKR to compute sum of fractions. -*/ - -#[instrument(skip_all)] -pub fn prove_gkr_quotient>, const N_GROUPS: usize>( - prover_state: &mut FSProver>, - numerators_and_denominators: &MleGroupRef<'_, EF>, -) -> (EF, MultilinearPoint, EF, EF) { - assert!(N_GROUPS.is_power_of_two() && N_GROUPS >= 2); - assert_eq!(numerators_and_denominators.n_columns(), 2); - let mut layers: Vec> = vec![split_mle_group(numerators_and_denominators, N_GROUPS / 2).into()]; - - loop { - let prev_layer: MleGroup<'_, EF> = layers.last().unwrap().by_ref().into(); - let prev_layer = if prev_layer.is_packed() && prev_layer.n_vars() < MIN_VARS_FOR_PACKING { - prev_layer.by_ref().unpack().as_owned_or_clone().into() - } else { - prev_layer - }; - if prev_layer.n_vars() == 1 { - break; - } - layers.push(sum_quotients(prev_layer.by_ref(), N_GROUPS).into()); - } - - let last_layer = layers.pop().unwrap(); - let last_layer = last_layer.as_owned_or_clone().as_extension().unwrap(); - - assert_eq!(last_layer.len(), N_GROUPS); - let last_nums_and_dens: [[_; 2]; N_GROUPS] = array::from_fn(|i| last_layer[i].to_vec().try_into().unwrap()); - for nd in &last_nums_and_dens { - prover_state.add_extension_scalars(nd); - } - let quotient = (0..N_GROUPS / 2) - .map(|i| last_nums_and_dens[i][0] / last_nums_and_dens[i + N_GROUPS / 2][0]) - .sum::() - + (0..N_GROUPS / 2) - .map(|i| last_nums_and_dens[i][1] / last_nums_and_dens[i + N_GROUPS / 2][1]) - .sum::(); - - let mut point = MultilinearPoint(vec![prover_state.sample()]); - let mut claims = last_nums_and_dens - .iter() - .map(|nd| nd.evaluate(&point)) - .collect::>(); - - for layer in layers[1..].iter().rev() { - (point, claims) = prove_gkr_quotient_step::<_, N_GROUPS>(prover_state, layer.by_ref(), &point, claims, false); - } - (point, claims) = prove_gkr_quotient_step::<_, N_GROUPS>(prover_state, layers[0].by_ref(), &point, claims, true); - assert_eq!(claims.len(), 2); - (quotient, point, claims[0], claims[1]) -} - -fn prove_gkr_quotient_step>, const N_GROUPS: usize>( - prover_state: &mut FSProver>, - numerators_and_denominators: MleGroupRef<'_, EF>, - claim_point: &MultilinearPoint, - claims: Vec, - univariate_skip: bool, -) -> (MultilinearPoint, Vec) { - let prev_numerators_and_denominators_split = match numerators_and_denominators { - MleGroupRef::Extension(numerators_and_denominators) => { - MleGroupRef::Extension(split_chunks(&numerators_and_denominators, 2)) - } - MleGroupRef::ExtensionPacked(numerators_and_denominators) => { - MleGroupRef::ExtensionPacked(split_chunks(&numerators_and_denominators, 2)) - } - _ => unreachable!(), - }; - - let alpha = prover_state.sample(); - - let (mut next_point, inner_evals, _) = sumcheck_prove::( - 1, - prev_numerators_and_denominators_split, - None, - &GKRQuotientComputation:: {}, - &alpha.powers().take(N_GROUPS).collect(), - Some((claim_point.0.clone(), None)), - false, - prover_state, - dot_product(claims.iter().copied(), alpha.powers()), - N_GROUPS != 2, - ); - - prover_state.add_extension_scalars(&inner_evals); - let beta = prover_state.sample(); - - let next_claims = if univariate_skip { - let selectors = univariate_selectors(log2_strict_usize(N_GROUPS)); - vec![ - evaluate_univariate_multilinear::<_, _, _, false>(&inner_evals[..N_GROUPS], &[beta], &selectors, None), - evaluate_univariate_multilinear::<_, _, _, false>(&inner_evals[N_GROUPS..], &[beta], &selectors, None), - ] - } else { - inner_evals - .chunks_exact(2) - .map(|chunk| chunk.evaluate(&MultilinearPoint(vec![beta]))) - .collect::>() - }; - - next_point.0.insert(0, beta); - - (next_point, next_claims) -} - -pub fn verify_gkr_quotient>, const N_GROUPS: usize>( - verifier_state: &mut FSVerifier>, - n_vars: usize, -) -> Result<(EF, MultilinearPoint, EF, EF), ProofError> { - let last_nums_and_dens: [[_; 2]; N_GROUPS] = array::from_fn(|_| { - verifier_state.next_extension_scalars_const().unwrap() // TODO avoid unwrap - }); - - let quotient = (0..N_GROUPS / 2) - .map(|i| last_nums_and_dens[i][0] / last_nums_and_dens[i + N_GROUPS / 2][0]) - .sum::() - + (0..N_GROUPS / 2) - .map(|i| last_nums_and_dens[i][1] / last_nums_and_dens[i + N_GROUPS / 2][1]) - .sum::(); - - let mut point = MultilinearPoint(vec![verifier_state.sample()]); - let mut claims = last_nums_and_dens - .iter() - .map(|nd| nd.evaluate(&point)) - .collect::>(); - - for i in 1..n_vars - log2_strict_usize(N_GROUPS) { - (point, claims) = verify_gkr_quotient_step::<_, N_GROUPS>(verifier_state, i, &point, claims, false)?; - } - (point, claims) = verify_gkr_quotient_step::<_, N_GROUPS>( - verifier_state, - n_vars - log2_strict_usize(N_GROUPS), - &point, - claims, - true, - )?; - assert_eq!(claims.len(), 2); - - Ok((quotient, point, claims[0], claims[1])) -} - -fn verify_gkr_quotient_step>, const N_GROUPS: usize>( - verifier_state: &mut FSVerifier>, - n_vars: usize, - point: &MultilinearPoint, - claims: Vec, - univariate_skip: bool, -) -> Result<(MultilinearPoint, Vec), ProofError> { - assert_eq!(claims.len(), N_GROUPS); - let alpha = verifier_state.sample(); - - let (retrieved_quotient, postponed) = sumcheck_verify(verifier_state, n_vars, 3)?; - - if retrieved_quotient != dot_product(claims.iter().copied(), alpha.powers()) { - return Err(ProofError::InvalidProof); - } - - let inner_evals = verifier_state.next_extension_scalars_vec(N_GROUPS * 2)?; - - if postponed.value - != point.eq_poly_outside(&postponed.point) - * as SumcheckComputation>::eval_extension( - &Default::default(), - &inner_evals, - &[], - &alpha.powers().take(N_GROUPS).collect(), - ) - { - return Err(ProofError::InvalidProof); - } - - let beta = verifier_state.sample(); - let next_claims = if univariate_skip { - let selectors = univariate_selectors(log2_strict_usize(N_GROUPS)); - vec![ - evaluate_univariate_multilinear::<_, _, _, false>(&inner_evals[..N_GROUPS], &[beta], &selectors, None), - evaluate_univariate_multilinear::<_, _, _, false>(&inner_evals[N_GROUPS..], &[beta], &selectors, None), - ] - } else { - inner_evals - .chunks_exact(2) - .map(|chunk| chunk.evaluate(&MultilinearPoint(vec![beta]))) - .collect::>() - }; - let mut next_point = postponed.point.clone(); - next_point.0.insert(0, beta); - - Ok((next_point, next_claims)) -} - -fn sum_quotients>>( - numerators_and_denominators: MleGroupRef<'_, EF>, - n_groups: usize, -) -> MleGroupOwned { - match numerators_and_denominators { - MleGroupRef::ExtensionPacked(numerators_and_denominators) => { - MleGroupOwned::ExtensionPacked(sum_quotients_helper(&numerators_and_denominators, n_groups)) - } - MleGroupRef::Extension(numerators_and_denominators) => { - MleGroupOwned::Extension(sum_quotients_helper(&numerators_and_denominators, n_groups)) - } - _ => unreachable!(), - } -} - -fn sum_quotients_helper( - numerators_and_denominators: &[&[F]], - n_groups: usize, -) -> Vec> { - assert_eq!(numerators_and_denominators.len(), n_groups); - let n = numerators_and_denominators[0].len(); - assert!(n.is_power_of_two() && n >= 2, "n = {n}"); - let mut new_numerators = Vec::new(); - let mut new_denominators = Vec::new(); - let (prev_numerators, prev_denominators) = numerators_and_denominators.split_at(n_groups / 2); - for i in 0..n_groups / 2 { - let (new_num, new_den) = sum_quotients_2_by_2::(prev_numerators[i], prev_denominators[i]); - new_numerators.push(new_num); - new_denominators.push(new_den); - } - new_numerators.extend(new_denominators); - new_numerators -} -fn sum_quotients_2_by_2( - numerators: &[F], - denominators: &[F], -) -> (Vec, Vec) { - let n = numerators.len(); - let new_n = n / 2; - let mut new_numerators = unsafe { uninitialized_vec(new_n) }; - let mut new_denominators = unsafe { uninitialized_vec(new_n) }; - new_numerators - .par_iter_mut() - .zip(new_denominators.par_iter_mut()) - .enumerate() - .for_each(|(i, (num, den))| { - let my_numerators: [_; 2] = [numerators[i], numerators[i + new_n]]; - let my_denominators: [_; 2] = [denominators[i], denominators[i + new_n]]; - *num = my_numerators[0] * my_denominators[1] + my_numerators[1] * my_denominators[0]; - *den = my_denominators[0] * my_denominators[1]; - }); - (new_numerators, new_denominators) -} -fn split_mle_group<'a, EF: ExtensionField>>( - polys: &'a MleGroupRef<'a, EF>, - n_groups: usize, -) -> MleGroupRef<'a, EF> { - match polys { - MleGroupRef::Extension(polys) => MleGroupRef::Extension(split_chunks(polys, n_groups)), - MleGroupRef::ExtensionPacked(polys) => MleGroupRef::ExtensionPacked(split_chunks(polys, n_groups)), - _ => unreachable!(), - } -} - -fn split_chunks<'a, A>(numerators_and_denominators: &[&'a [A]], num_groups: usize) -> Vec<&'a [A]> { - let n = numerators_and_denominators[0].len(); - assert!(n.is_power_of_two() && n >= num_groups); - assert!(num_groups.is_power_of_two()); - - let mut res = Vec::new(); - for slice in numerators_and_denominators { - assert_eq!(slice.len(), n); - res.extend(split_at_many( - slice, - &(1..num_groups).map(|i| i * n / num_groups).collect::>(), - )); - } - res -} - -#[cfg(test)] -mod tests { - use std::time::Instant; - - use super::*; - use p3_koala_bear::QuinticExtensionFieldKB; - use rand::{Rng, SeedableRng, rngs::StdRng}; - use utils::{build_prover_state, build_verifier_state, init_tracing}; - - type EF = QuinticExtensionFieldKB; - - fn sum_all_quotients(nums: &[EF], den: &[EF]) -> EF { - nums.iter().zip(den.iter()).map(|(&n, &d)| n / d).sum() - } - - const N_GROUPS: usize = 8; - - #[test] - fn test_gkr_quotient() { - let log_n = 22; - let n = 1 << log_n; - init_tracing(); - - let mut rng = StdRng::seed_from_u64(0); - - let numerators = (0..n).map(|_| rng.random()).collect::>(); - let c: EF = rng.random(); - let denominators_indexes = (0..n) - .map(|_| PF::::from_usize(rng.random_range(..n))) - .collect::>(); - let denominators = denominators_indexes.iter().map(|&i| c - i).collect::>(); - let real_quotient = sum_all_quotients(&numerators, &denominators); - let mut prover_state = build_prover_state(false); - - let time = Instant::now(); - let prover_statements = prove_gkr_quotient::( - &mut prover_state, - &MleGroupRef::ExtensionPacked(vec![&pack_extension(&numerators), &pack_extension(&denominators)]), - ); - println!("Proving time: {:?}", time.elapsed()); - - let mut verifier_state = build_verifier_state(prover_state); - - let verifier_statements = verify_gkr_quotient::(&mut verifier_state, log_n).unwrap(); - assert_eq!(&verifier_statements, &prover_statements); - let (retrieved_quotient, claim_point, claim_num, claim_den) = verifier_statements; - - assert_eq!(retrieved_quotient, real_quotient); - let selectors = univariate_selectors::>(log2_strict_usize(N_GROUPS)); - assert_eq!( - evaluate_univariate_multilinear::<_, _, _, true>(&numerators, &claim_point, &selectors, None), - claim_num - ); - assert_eq!( - evaluate_univariate_multilinear::<_, _, _, true>(&denominators, &claim_point, &selectors, None), - claim_den - ); - } -} diff --git a/crates/poseidon_circuit/Cargo.toml b/crates/poseidon_circuit/Cargo.toml deleted file mode 100644 index 74625f71..00000000 --- a/crates/poseidon_circuit/Cargo.toml +++ /dev/null @@ -1,20 +0,0 @@ -[package] -name = "poseidon_circuit" -version.workspace = true -edition.workspace = true - -[lints] -workspace = true - -[dependencies] -tracing.workspace = true -utils.workspace = true -# p3-util.workspace = true -multilinear-toolkit.workspace = true -p3-koala-bear.workspace = true -p3-poseidon2.workspace = true -p3-monty-31.workspace = true -rand.workspace = true -whir-p3.workspace = true -sub_protocols.workspace = true - diff --git a/crates/poseidon_circuit/src/gkr_layers/batch_partial_rounds.rs b/crates/poseidon_circuit/src/gkr_layers/batch_partial_rounds.rs deleted file mode 100644 index 4cecb47e..00000000 --- a/crates/poseidon_circuit/src/gkr_layers/batch_partial_rounds.rs +++ /dev/null @@ -1,122 +0,0 @@ -use std::array; - -use multilinear_toolkit::prelude::*; -use p3_koala_bear::{GenericPoseidon2LinearLayersKoalaBear, KoalaBearInternalLayerParameters, KoalaBearParameters}; -use p3_monty_31::InternalLayerBaseParameters; -use p3_poseidon2::GenericPoseidon2LinearLayers; - -use crate::{EF, F}; - -#[derive(Debug)] -pub struct BatchPartialRounds { - pub constants: [F; N_COMMITED_CUBES], - pub last_constant: F, -} - -impl SumcheckComputation - for BatchPartialRounds -where - KoalaBearInternalLayerParameters: InternalLayerBaseParameters, - EF: ExtensionField>, -{ - type ExtraData = Vec; - - fn degree(&self) -> usize { - 3 - } - - #[inline(always)] - fn eval_base(&self, point: &[PF], _: &[EF], alpha_powers: &Self::ExtraData) -> EF { - self.my_eval::>(point, alpha_powers) - } - - #[inline(always)] - fn eval_extension(&self, point: &[EF], _: &[EF], alpha_powers: &Self::ExtraData) -> EF { - self.my_eval::(point, alpha_powers) - } - - #[inline(always)] - fn eval_packed_base( - &self, - point: &[FPacking], - _: &[EFPacking], - alpha_powers: &Self::ExtraData, - ) -> EFPacking { - debug_assert_eq!(point.len(), WIDTH + N_COMMITED_CUBES); - debug_assert_eq!(alpha_powers.len(), WIDTH + N_COMMITED_CUBES); - - let mut res = EFPacking::::ZERO; - let mut buff: [FPacking; WIDTH] = array::from_fn(|j| point[j]); - for (i, &constant) in self.constants.iter().enumerate() { - let computed_cube = (buff[0] + constant).cube(); - res += EFPacking::::from(alpha_powers[WIDTH + i]) * computed_cube; - buff[0] = point[WIDTH + i]; // commited cube - GenericPoseidon2LinearLayersKoalaBear::internal_linear_layer(&mut buff); - } - - buff[0] = (buff[0] + self.last_constant).cube(); - GenericPoseidon2LinearLayersKoalaBear::internal_linear_layer(&mut buff); - for i in 0..WIDTH { - res += EFPacking::::from(alpha_powers[i]) * buff[i]; - } - res - } - - #[inline(always)] - fn eval_packed_extension( - &self, - point: &[EFPacking], - _: &[EFPacking], - alpha_powers: &Self::ExtraData, - ) -> EFPacking { - debug_assert_eq!(point.len(), WIDTH + N_COMMITED_CUBES); - debug_assert_eq!(alpha_powers.len(), WIDTH + N_COMMITED_CUBES); - - let mut res = EFPacking::::ZERO; - let mut buff: [EFPacking; WIDTH] = array::from_fn(|j| point[j]); - for (i, &constant) in self.constants.iter().enumerate() { - let computed_cube = (buff[0] + PFPacking::::from(constant)).cube(); - res += computed_cube * alpha_powers[WIDTH + i]; - buff[0] = point[WIDTH + i]; // commited cube - GenericPoseidon2LinearLayersKoalaBear::internal_linear_layer(&mut buff); - } - - buff[0] = (buff[0] + PFPacking::::from(self.last_constant)).cube(); - GenericPoseidon2LinearLayersKoalaBear::internal_linear_layer(&mut buff); - for i in 0..WIDTH { - res += buff[i] * alpha_powers[i]; - } - res - } -} - -impl BatchPartialRounds -where - KoalaBearInternalLayerParameters: InternalLayerBaseParameters, - EF: ExtensionField>, -{ - #[inline(always)] - fn my_eval>>(&self, point: &[NF], alpha_powers: &[EF]) -> EF - where - EF: ExtensionField, - { - debug_assert_eq!(point.len(), WIDTH + N_COMMITED_CUBES); - debug_assert_eq!(alpha_powers.len(), WIDTH + N_COMMITED_CUBES); - - let mut res = EF::ZERO; - let mut buff: [NF; WIDTH] = array::from_fn(|j| point[j]); - for (i, &constant) in self.constants.iter().enumerate() { - let computed_cube = (buff[0] + constant).cube(); - res += alpha_powers[WIDTH + i] * computed_cube; - buff[0] = point[WIDTH + i]; // commited cube - GenericPoseidon2LinearLayersKoalaBear::internal_linear_layer(&mut buff); - } - - buff[0] = (buff[0] + self.last_constant).cube(); - GenericPoseidon2LinearLayersKoalaBear::internal_linear_layer(&mut buff); - for i in 0..WIDTH { - res += alpha_powers[i] * buff[i]; - } - res - } -} diff --git a/crates/poseidon_circuit/src/gkr_layers/compression.rs b/crates/poseidon_circuit/src/gkr_layers/compression.rs deleted file mode 100644 index d8a55f63..00000000 --- a/crates/poseidon_circuit/src/gkr_layers/compression.rs +++ /dev/null @@ -1,90 +0,0 @@ -use multilinear_toolkit::prelude::*; - -use crate::EF; - -#[derive(Debug)] -pub struct CompressionComputation { - pub compressed_output: usize, -} - -impl SumcheckComputation for CompressionComputation -where - EF: ExtensionField>, -{ - type ExtraData = Vec; - - fn degree(&self) -> usize { - 2 - } - - #[inline(always)] - fn eval_base(&self, point: &[PF], _: &[EF], alpha_powers: &Self::ExtraData) -> EF { - self.my_eval::>(point, alpha_powers) - } - - #[inline(always)] - fn eval_extension(&self, point: &[EF], _: &[EF], alpha_powers: &Self::ExtraData) -> EF { - self.my_eval::(point, alpha_powers) - } - - #[inline(always)] - fn eval_packed_base( - &self, - point: &[PFPacking], - _: &[EFPacking], - alpha_powers: &Self::ExtraData, - ) -> EFPacking { - debug_assert_eq!(point.len(), WIDTH + 1); - let mut res = EFPacking::::ZERO; - let compressed = point[WIDTH]; - for i in 0..self.compressed_output { - res += EFPacking::::from(alpha_powers[i]) * point[i]; - } - for i in self.compressed_output..WIDTH { - res += EFPacking::::from(alpha_powers[i]) * point[i] * (PFPacking::::ONE - compressed); - } - - res - } - - #[inline(always)] - fn eval_packed_extension( - &self, - point: &[EFPacking], - _: &[EFPacking], - alpha_powers: &Self::ExtraData, - ) -> EFPacking { - debug_assert_eq!(point.len(), WIDTH + 1); - let mut res = EFPacking::::ZERO; - let compressed = point[WIDTH]; - for i in 0..self.compressed_output { - res += point[i] * alpha_powers[i]; - } - for i in self.compressed_output..WIDTH { - res += point[i] * (EFPacking::::ONE - compressed) * alpha_powers[i]; - } - - res - } -} - -impl CompressionComputation { - #[inline(always)] - fn my_eval> + ExtensionField, NF: ExtensionField>>( - &self, - point: &[NF], - alpha_powers: &[EF], - ) -> EF { - debug_assert_eq!(point.len(), WIDTH + 1); - let mut res = EF::ZERO; - let compressed = point[WIDTH]; - for i in 0..self.compressed_output { - res += alpha_powers[i] * point[i]; - } - for i in self.compressed_output..WIDTH { - res += alpha_powers[i] * point[i] * (EF::ONE - compressed); - } - - res - } -} diff --git a/crates/poseidon_circuit/src/gkr_layers/full_round.rs b/crates/poseidon_circuit/src/gkr_layers/full_round.rs deleted file mode 100644 index 9b5dda76..00000000 --- a/crates/poseidon_circuit/src/gkr_layers/full_round.rs +++ /dev/null @@ -1,79 +0,0 @@ -use multilinear_toolkit::prelude::*; -use p3_koala_bear::{KoalaBearInternalLayerParameters, KoalaBearParameters}; -use p3_monty_31::InternalLayerBaseParameters; - -use crate::EF; - -#[derive(Debug)] -pub struct FullRoundComputation {} - -impl SumcheckComputation for FullRoundComputation -where - KoalaBearInternalLayerParameters: InternalLayerBaseParameters, - EF: ExtensionField>, -{ - type ExtraData = Vec; - - fn degree(&self) -> usize { - 3 - } - - #[inline(always)] - fn eval_base(&self, point: &[PF], _: &[EF], alpha_powers: &Self::ExtraData) -> EF { - self.my_eval::>(point, alpha_powers) - } - - #[inline(always)] - fn eval_extension(&self, point: &[EF], _: &[EF], alpha_powers: &Self::ExtraData) -> EF { - self.my_eval::(point, alpha_powers) - } - - #[inline(always)] - fn eval_packed_base( - &self, - point: &[PFPacking], - _: &[EFPacking], - alpha_powers: &Self::ExtraData, - ) -> EFPacking { - debug_assert_eq!(point.len(), WIDTH); - let mut res = EFPacking::::ZERO; - for i in 0..WIDTH { - res += EFPacking::::from(alpha_powers[i]) * point[i].cube(); - } - res - } - - #[inline(always)] - fn eval_packed_extension( - &self, - point: &[EFPacking], - _: &[EFPacking], - alpha_powers: &Self::ExtraData, - ) -> EFPacking { - debug_assert_eq!(point.len(), WIDTH); - let mut res = EFPacking::::ZERO; - for i in 0..WIDTH { - res += point[i].cube() * alpha_powers[i]; - } - res - } -} - -impl FullRoundComputation -where - KoalaBearInternalLayerParameters: InternalLayerBaseParameters, - EF: ExtensionField>, -{ - #[inline(always)] - fn my_eval>>(&self, point: &[NF], alpha_powers: &[EF]) -> EF - where - EF: ExtensionField, - { - debug_assert_eq!(point.len(), WIDTH); - let mut res = EF::ZERO; - for i in 0..WIDTH { - res += alpha_powers[i] * point[i].cube(); - } - res - } -} diff --git a/crates/poseidon_circuit/src/gkr_layers/mod.rs b/crates/poseidon_circuit/src/gkr_layers/mod.rs deleted file mode 100644 index b5ddff97..00000000 --- a/crates/poseidon_circuit/src/gkr_layers/mod.rs +++ /dev/null @@ -1,80 +0,0 @@ -mod full_round; -pub use full_round::*; - -mod partial_round; -pub use partial_round::*; - -mod batch_partial_rounds; -pub use batch_partial_rounds::*; - -mod compression; -pub use compression::*; - -use p3_koala_bear::{ - KOALABEAR_RC16_EXTERNAL_FINAL, KOALABEAR_RC16_EXTERNAL_INITIAL, KOALABEAR_RC16_INTERNAL, - KOALABEAR_RC24_EXTERNAL_FINAL, KOALABEAR_RC24_EXTERNAL_INITIAL, KOALABEAR_RC24_INTERNAL, -}; - -use crate::F; - -#[derive(Debug)] -pub struct PoseidonGKRLayers { - pub initial_full_rounds: Vec<[F; WIDTH]>, - pub batch_partial_rounds: Option>, - pub partial_rounds_remaining: Vec, - pub final_full_rounds: Vec<[F; WIDTH]>, - pub compressed_output: Option, -} - -impl PoseidonGKRLayers { - pub fn build(compressed_output: Option) -> Self { - match WIDTH { - 16 => unsafe { - Self::build_generic( - &*(&KOALABEAR_RC16_EXTERNAL_INITIAL as *const [[F; 16]] as *const [[F; WIDTH]]), - &KOALABEAR_RC16_INTERNAL, - &*(&KOALABEAR_RC16_EXTERNAL_FINAL as *const [[F; 16]] as *const [[F; WIDTH]]), - compressed_output, - ) - }, - 24 => unsafe { - Self::build_generic( - &*(&KOALABEAR_RC24_EXTERNAL_INITIAL as *const [[F; 24]] as *const [[F; WIDTH]]), - &KOALABEAR_RC24_INTERNAL, - &*(&KOALABEAR_RC24_EXTERNAL_FINAL as *const [[F; 24]] as *const [[F; WIDTH]]), - compressed_output, - ) - }, - _ => panic!("Only Poseidon 16 and 24 are supported currently"), - } - } - - fn build_generic( - initial_constants: &[[F; WIDTH]], - internal_constants: &[F], - final_constants: &[[F; WIDTH]], - compressed_output: Option, - ) -> Self { - assert!(N_COMMITED_CUBES < internal_constants.len() - 1); // TODO we could go up to internal_constants.len() in theory - let initial_full_rounds = initial_constants.to_vec(); - let (batch_partial_rounds, partial_rounds_remaining) = if N_COMMITED_CUBES == 0 { - (None, internal_constants.to_vec()) - } else { - ( - Some(BatchPartialRounds { - constants: internal_constants[..N_COMMITED_CUBES].try_into().unwrap(), - last_constant: internal_constants[N_COMMITED_CUBES], - }), - internal_constants[N_COMMITED_CUBES + 1..].to_vec(), - ) - }; - let final_full_rounds = final_constants.to_vec(); - Self { - initial_full_rounds, - batch_partial_rounds, - partial_rounds_remaining, - final_full_rounds, - compressed_output, - } - } -} diff --git a/crates/poseidon_circuit/src/gkr_layers/partial_round.rs b/crates/poseidon_circuit/src/gkr_layers/partial_round.rs deleted file mode 100644 index 1430cd0e..00000000 --- a/crates/poseidon_circuit/src/gkr_layers/partial_round.rs +++ /dev/null @@ -1,76 +0,0 @@ -use multilinear_toolkit::prelude::*; - -use crate::{EF, F}; - -#[derive(Debug)] -pub struct PartialRoundComputation; - -impl SumcheckComputation for PartialRoundComputation -where - EF: ExtensionField>, -{ - type ExtraData = Vec; - - fn degree(&self) -> usize { - 3 - } - - #[inline(always)] - fn eval_base(&self, point: &[PF], _: &[EF], alpha_powers: &Self::ExtraData) -> EF { - self.my_eval::>(point, alpha_powers) - } - - #[inline(always)] - fn eval_extension(&self, point: &[EF], _: &[EF], alpha_powers: &Self::ExtraData) -> EF { - self.my_eval::(point, alpha_powers) - } - - #[inline(always)] - fn eval_packed_base( - &self, - point: &[FPacking], - _: &[EFPacking], - alpha_powers: &Self::ExtraData, - ) -> EFPacking { - debug_assert_eq!(point.len(), WIDTH); - let mut res = EFPacking::::from(point[0].cube()); - for i in 1..WIDTH { - res += EFPacking::::from(alpha_powers[i]) * point[i]; - } - res - } - - #[inline(always)] - fn eval_packed_extension( - &self, - point: &[EFPacking], - _: &[EFPacking], - alpha_powers: &Self::ExtraData, - ) -> EFPacking { - debug_assert_eq!(point.len(), WIDTH); - let mut res = point[0].cube(); - for i in 1..WIDTH { - res += point[i] * alpha_powers[i]; - } - res - } -} - -impl PartialRoundComputation -where - EF: ExtensionField>, -{ - #[inline(always)] - fn my_eval> + ExtensionField, NF: ExtensionField>>( - &self, - point: &[NF], - alpha_powers: &[EF], - ) -> EF { - debug_assert_eq!(point.len(), WIDTH); - let mut res = EF::from(point[0].cube()); - for i in 1..WIDTH { - res += alpha_powers[i] * point[i]; - } - res - } -} diff --git a/crates/poseidon_circuit/src/lib.rs b/crates/poseidon_circuit/src/lib.rs deleted file mode 100644 index 965d238f..00000000 --- a/crates/poseidon_circuit/src/lib.rs +++ /dev/null @@ -1,33 +0,0 @@ -#![cfg_attr(not(test), warn(unused_crate_dependencies))] - -use multilinear_toolkit::prelude::{Evaluation, MultiEvaluation}; -use p3_koala_bear::{KoalaBear, QuinticExtensionFieldKB}; - -mod prove; -pub use prove::*; - -mod verify; -pub use verify::*; - -mod utils; -pub use utils::*; - -mod witness_gen; -pub use witness_gen::*; - -pub mod tests; - -pub mod gkr_layers; -pub use gkr_layers::*; - -pub(crate) type F = KoalaBear; -pub(crate) type EF = QuinticExtensionFieldKB; - -/// remain to be proven -#[derive(Debug, Clone)] -pub struct GKRPoseidonResult { - pub output_statements: MultiEvaluation, // of length width - pub input_statements: MultiEvaluation, // of length width - pub cubes_statements: MultiEvaluation, // of length n_committed_cubes - pub on_compression_selector: Option>, // univariate_skips = 1 here (TODO dont do this) -} diff --git a/crates/poseidon_circuit/src/prove.rs b/crates/poseidon_circuit/src/prove.rs deleted file mode 100644 index 1b7f1d09..00000000 --- a/crates/poseidon_circuit/src/prove.rs +++ /dev/null @@ -1,332 +0,0 @@ -use crate::{ - CompressionComputation, EF, F, FullRoundComputation, PartialRoundComputation, PoseidonWitness, - gkr_layers::{BatchPartialRounds, PoseidonGKRLayers}, -}; -use crate::{GKRPoseidonResult, build_poseidon_inv_matrices}; -use multilinear_toolkit::prelude::*; -use p3_koala_bear::{KoalaBearInternalLayerParameters, KoalaBearParameters}; -use p3_monty_31::InternalLayerBaseParameters; -use tracing::{info_span, instrument}; -use utils::fold_multilinear_chunks; - -#[instrument(skip_all)] -pub fn prove_poseidon_gkr( - prover_state: &mut FSProver>, - witness: &PoseidonWitness, WIDTH, N_COMMITED_CUBES>, - output_point: Vec, - univariate_skips: usize, - layers: &PoseidonGKRLayers, -) -> GKRPoseidonResult -where - KoalaBearInternalLayerParameters: InternalLayerBaseParameters, -{ - let mut point = output_point.clone(); - let selectors = univariate_selectors::(univariate_skips); - let (inv_mds_matrix, inv_light_matrix) = build_poseidon_inv_matrices::(); - - assert_eq!(point.len(), log2_strict_usize(witness.n_poseidons())); - - let (output_claims, mut claims) = info_span!("computing output claims").in_scope(|| { - let eq_poly = eval_eq(&point[univariate_skips..]); - let inner_evals = match &witness.compression { - Some((_, compressed_output)) => compressed_output, - None => &witness.output_layer, - } - .par_iter() - .map(|poly| { - FPacking::::unpack_slice(poly) - .chunks_exact(eq_poly.len()) - .map(|chunk| dot_product(eq_poly.iter().copied(), chunk.iter().copied())) - .collect::>() - }) - .collect::>(); - for evals in &inner_evals { - prover_state.add_extension_scalars(evals); - } - let alpha = prover_state.sample(); - let selectors_at_alpha = selectors - .iter() - .map(|selector| selector.evaluate(alpha)) - .collect::>(); - - let mut output_claims = vec![]; - let mut claims = vec![]; - for evals in inner_evals { - output_claims.push(evals.evaluate(&MultilinearPoint(point[..univariate_skips].to_vec()))); - claims.push(dot_product(selectors_at_alpha.iter().copied(), evals.into_iter())) - } - point = [vec![alpha], point[univariate_skips..].to_vec()].concat(); - (output_claims, claims) - }); - - let on_compression_selector = if let Some((compression_indicator, _)) = &witness.compression { - (point, claims) = prove_gkr_round( - prover_state, - &CompressionComputation:: { - compressed_output: layers.compressed_output.unwrap(), - }, - &witness - .output_layer - .iter() - .chain(std::iter::once(compression_indicator)) - .map(Vec::as_slice) - .collect::>(), - &point, - &claims, - univariate_skips, - ); - - let inner_evals = fold_multilinear_chunks( - FPacking::::unpack_slice(compression_indicator), - &MultilinearPoint(point[1..].to_vec()), - ); - prover_state.add_extension_scalars(&inner_evals); - let _ = claims.pop().unwrap(); // remove compression claim - let epsilons = prover_state.sample_vec(univariate_skips); - let new_point = MultilinearPoint([epsilons.clone(), point[1..].to_vec()].concat()); - let new_eval = inner_evals.evaluate(&MultilinearPoint(epsilons)); - - Some(Evaluation::new(new_point, new_eval)) - } else { - None - }; - - for (layer, full_round_constants) in witness.final_full_layers.iter().zip(&layers.final_full_rounds).rev() { - claims = apply_matrix(&inv_mds_matrix, &claims); - - (point, claims) = prove_gkr_round( - prover_state, - &FullRoundComputation {}, - &layer.iter().map(Vec::as_slice).collect::>(), - &point, - &claims, - univariate_skips, - ); - - for (claim, c) in claims.iter_mut().zip(full_round_constants) { - *claim -= *c; - } - } - - for (input_layers, partial_round_constant) in witness - .remaining_partial_round_layers - .iter() - .zip(&layers.partial_rounds_remaining) - .rev() - { - claims = apply_matrix(&inv_light_matrix, &claims); - - (point, claims) = prove_gkr_round( - prover_state, - &PartialRoundComputation:: {}, - input_layers, - &point, - &claims, - univariate_skips, - ); - claims[0] -= *partial_round_constant; - } - - if let Some(batch_partial_round_input) = &witness.batch_partial_round_input { - (point, claims) = prove_batch_internal_round( - prover_state, - batch_partial_round_input, - &witness.committed_cubes, - layers.batch_partial_rounds.as_ref().unwrap(), - &point, - &claims, - &selectors, - ); - } - - let pcs_point_for_cubes = point.clone(); - - claims = claims[..WIDTH].to_vec(); - - for (layer, full_round_constants) in witness - .initial_full_layers - .iter() - .zip(&layers.initial_full_rounds) - .rev() - { - claims = apply_matrix(&inv_mds_matrix, &claims); - - (point, claims) = prove_gkr_round( - prover_state, - &FullRoundComputation {}, - &layer.iter().map(Vec::as_slice).collect::>(), - &point, - &claims, - univariate_skips, - ); - - for (claim, c) in claims.iter_mut().zip(full_round_constants) { - *claim -= *c; - } - } - - claims = apply_matrix(&inv_mds_matrix, &claims); - let _ = claims; - - let pcs_point_for_inputs = point.clone(); - - let input_statements = inner_evals_on_commited_columns( - prover_state, - &pcs_point_for_inputs, - univariate_skips, - &witness.input_layer, - ); - let cubes_statements = if N_COMMITED_CUBES == 0 { - Default::default() - } else { - inner_evals_on_commited_columns( - prover_state, - &pcs_point_for_cubes, - univariate_skips, - &witness.committed_cubes, - ) - }; - - let output_statements = MultiEvaluation::new(output_point, output_claims); - GKRPoseidonResult { - output_statements, - input_statements, - cubes_statements, - on_compression_selector, - } -} - -// #[instrument(skip_all)] -fn prove_gkr_round> + 'static>( - prover_state: &mut FSProver>, - computation: &SC, - input_layers: &[impl AsRef<[PFPacking]>], - claim_point: &[EF], - output_claims: &[EF], - univariate_skips: usize, -) -> (Vec, Vec) { - let batching_scalar = prover_state.sample(); - let batching_scalars_powers = batching_scalar.powers().collect_n(output_claims.len()); - let batched_claim: EF = dot_product(output_claims.iter().copied(), batching_scalars_powers.iter().copied()); - - let (sumcheck_point, sumcheck_inner_evals, sumcheck_final_sum) = sumcheck_prove( - univariate_skips, - MleGroupRef::BasePacked(input_layers.iter().map(|l| l.as_ref()).collect()), - None, - computation, - &batching_scalars_powers, - Some((claim_point.to_vec(), None)), - false, - prover_state, - batched_claim, - true, - ); - - // sanity check - debug_assert_eq!( - computation.eval_extension(&sumcheck_inner_evals, &[], &batching_scalars_powers) - * eq_poly_with_skip(&sumcheck_point, claim_point, univariate_skips), - sumcheck_final_sum - ); - - prover_state.add_extension_scalars(&sumcheck_inner_evals); - - (sumcheck_point.0, sumcheck_inner_evals) -} - -#[instrument(skip_all)] -fn prove_batch_internal_round( - prover_state: &mut FSProver>, - input_layers: &[Vec>], - committed_cubes: &[Vec>], - computation: &BatchPartialRounds, - claim_point: &[EF], - output_claims: &[EF], - selectors: &[DensePolynomial], -) -> (Vec, Vec) -where - KoalaBearInternalLayerParameters: InternalLayerBaseParameters, -{ - assert_eq!(input_layers.len(), WIDTH); - assert_eq!(committed_cubes.len(), N_COMMITED_CUBES); - let univariate_skips = log2_strict_usize(selectors.len()); - - let cubes_evals = info_span!("computing cube evals").in_scope(|| { - batch_evaluate_univariate_multilinear( - &committed_cubes - .iter() - .map(|l| PFPacking::::unpack_slice(l)) - .collect::>(), - claim_point, - selectors, - ) - }); - - prover_state.add_extension_scalars(&cubes_evals); - - let batching_scalar = prover_state.sample(); - let batched_claim: EF = dot_product( - output_claims.iter().chain(&cubes_evals).copied(), - batching_scalar.powers(), - ); - let batching_scalars_powers = batching_scalar.powers().collect_n(WIDTH + N_COMMITED_CUBES); - - let (sumcheck_point, sumcheck_inner_evals, sumcheck_final_sum) = sumcheck_prove( - univariate_skips, - MleGroupRef::BasePacked( - input_layers - .iter() - .chain(committed_cubes.iter()) - .map(Vec::as_slice) - .collect(), - ), - None, - computation, - &batching_scalars_powers, - Some((claim_point.to_vec(), None)), - false, - prover_state, - batched_claim, - true, - ); - - // sanity check - debug_assert_eq!( - computation.eval_extension(&sumcheck_inner_evals, &[], &batching_scalars_powers) - * eq_poly_with_skip(&sumcheck_point, claim_point, univariate_skips), - sumcheck_final_sum - ); - - prover_state.add_extension_scalars(&sumcheck_inner_evals); - - (sumcheck_point.0, sumcheck_inner_evals) -} - -fn inner_evals_on_commited_columns( - prover_state: &mut FSProver>, - point: &[EF], - univariate_skips: usize, - columns: &[Vec>], -) -> MultiEvaluation { - let eq_mle = eval_eq_packed(&point[1..]); - let inner_evals = columns - .par_iter() - .map(|col| { - col.chunks_exact(eq_mle.len()) - .map(|chunk| { - let ef_sum = dot_product::, _, _>(eq_mle.iter().copied(), chunk.iter().copied()); - as PackedFieldExtension>::to_ext_iter([ef_sum]).sum::() - }) - .collect::>() - }) - .flatten() - .collect::>(); - prover_state.add_extension_scalars(&inner_evals); - let mut values_to_prove = vec![]; - let pcs_batching_scalars_inputs = prover_state.sample_vec(univariate_skips); - let point_to_prove = MultilinearPoint([pcs_batching_scalars_inputs.clone(), point[1..].to_vec()].concat()); - for col_inner_evals in inner_evals.chunks_exact(1 << univariate_skips) { - values_to_prove.push(col_inner_evals.evaluate(&MultilinearPoint(pcs_batching_scalars_inputs.clone()))); - } - MultiEvaluation::new(point_to_prove, values_to_prove) -} diff --git a/crates/poseidon_circuit/src/tests.rs b/crates/poseidon_circuit/src/tests.rs deleted file mode 100644 index f55e35bc..00000000 --- a/crates/poseidon_circuit/src/tests.rs +++ /dev/null @@ -1,327 +0,0 @@ -use multilinear_toolkit::prelude::*; -use p3_koala_bear::{KoalaBear, KoalaBearInternalLayerParameters, KoalaBearParameters, QuinticExtensionFieldKB}; -use p3_monty_31::InternalLayerBaseParameters; -use rand::{Rng, SeedableRng, rngs::StdRng}; -use std::{array, time::Instant}; -use sub_protocols::{ - ColDims, packed_pcs_commit, packed_pcs_global_statements_for_prover, packed_pcs_global_statements_for_verifier, - packed_pcs_parse_commitment, -}; -use utils::{ - build_prover_state, build_verifier_state, init_tracing, poseidon16_permute_mut, poseidon24_permute_mut, - transposed_par_iter_mut, -}; -use whir_p3::{FoldingFactor, SecurityAssumption, WhirConfig, WhirConfigBuilder, precompute_dft_twiddles}; - -use crate::{ - GKRPoseidonResult, default_cube_layers, generate_poseidon_witness, gkr_layers::PoseidonGKRLayers, - prove_poseidon_gkr, verify_poseidon_gkr, -}; - -type F = KoalaBear; -type EF = QuinticExtensionFieldKB; - -const COMPRESSION_OUTPUT_WIDTH: usize = 8; - -#[test] -fn test_poseidon_benchmark() { - run_poseidon_benchmark::<16, 0, 3>(12, false, false); - run_poseidon_benchmark::<16, 0, 3>(12, true, false); - run_poseidon_benchmark::<16, 16, 3>(12, false, false); - run_poseidon_benchmark::<16, 16, 3>(12, true, false); -} - -pub fn run_poseidon_benchmark( - log_n_poseidons: usize, - compress: bool, - tracing: bool, -) where - KoalaBearInternalLayerParameters: InternalLayerBaseParameters, -{ - if tracing { - init_tracing(); - } - precompute_dft_twiddles::(1 << 24); - - let whir_config_builder = WhirConfigBuilder { - folding_factor: FoldingFactor::new(7, 4), - soundness_type: SecurityAssumption::CapacityBound, - pow_bits: 16, - max_num_variables_to_send_coeffs: 6, - rs_domain_initial_reduction_factor: 5, - security_level: 128, - starting_log_inv_rate: 1, - }; - let whir_n_vars = log_n_poseidons + log2_ceil_usize(WIDTH + N_COMMITED_CUBES); - let whir_config = WhirConfig::new(whir_config_builder.clone(), whir_n_vars); - - let mut rng = StdRng::seed_from_u64(0); - let n_poseidons = 1 << log_n_poseidons; - let n_compressions = if compress { n_poseidons / 3 } else { 0 }; - - let perm_inputs = (0..n_poseidons).map(|_| rng.random()).collect::>(); - let input: [_; WIDTH] = array::from_fn(|i| perm_inputs.par_iter().map(|x| x[i]).collect::>()); - let input_packed: [_; WIDTH] = array::from_fn(|i| PFPacking::::pack_slice(&input[i]).to_vec()); - - let layers = PoseidonGKRLayers::::build(compress.then_some(COMPRESSION_OUTPUT_WIDTH)); - - let default_cubes = default_cube_layers::(&layers); - - let input_col_dims = vec![ColDims::padded(n_poseidons, F::ZERO); WIDTH]; - let cubes_col_dims = default_cubes - .iter() - .map(|&v| ColDims::padded(n_poseidons, v)) - .collect::>(); - let committed_col_dims = [input_col_dims, cubes_col_dims].concat(); - - let log_smallest_decomposition_chunk = 0; // unused because everything is a power of 2 - - let (mut verifier_state, proof_size_pcs, proof_size_gkr, output_layer, prover_duration, output_statements_prover) = { - // ---------------------------------------------------- PROVER ---------------------------------------------------- - - let prover_time = Instant::now(); - - let witness = generate_poseidon_witness::, WIDTH, N_COMMITED_CUBES>( - input_packed, - &layers, - if compress { - Some( - PFPacking::::pack_slice( - &[ - vec![F::ZERO; n_poseidons - n_compressions], - vec![F::ONE; n_compressions], - ] - .concat(), - ) - .to_vec(), - ) - } else { - None - }, - ); - - let mut prover_state = build_prover_state::(false); - - let committed_polys = witness - .input_layer - .iter() - .chain(&witness.committed_cubes) - .map(|s| PFPacking::::unpack_slice(s)) - .collect::>(); - - let pcs_commitment_witness = packed_pcs_commit( - &whir_config_builder, - &committed_polys, - &committed_col_dims, - &mut prover_state, - log_smallest_decomposition_chunk, - ); - - let claim_point = prover_state.sample_vec(log_n_poseidons); - - let GKRPoseidonResult { - output_statements, - input_statements, - cubes_statements, - on_compression_selector, - } = prove_poseidon_gkr( - &mut prover_state, - &witness, - claim_point.clone(), - UNIVARIATE_SKIPS, - &layers, - ); - assert_eq!(&output_statements.point.0, &claim_point); - if let Some(on_compression_selector) = on_compression_selector { - assert_eq!( - on_compression_selector.value, - mle_of_zeros_then_ones((1 << log_n_poseidons) - n_compressions, &on_compression_selector.point,) - ); - } - - // PCS opening - let mut pcs_statements = vec![]; - for meval in [input_statements, cubes_statements] { - for v in meval.values { - pcs_statements.push(vec![Evaluation { - point: meval.point.clone(), - value: v, - }]); - } - } - - let proof_size_gkr = prover_state.proof_size(); - - let global_statements = packed_pcs_global_statements_for_prover( - &committed_polys, - &committed_col_dims, - log_smallest_decomposition_chunk, - &pcs_statements, - &mut prover_state, - ); - whir_config.prove( - &mut prover_state, - global_statements, - pcs_commitment_witness.inner_witness, - &pcs_commitment_witness.packed_polynomial.by_ref(), - ); - - let prover_duration = prover_time.elapsed(); - - let proof_size_pcs = prover_state.proof_size() - proof_size_gkr; - ( - build_verifier_state(prover_state), - proof_size_pcs, - proof_size_gkr, - match compress { - false => witness.output_layer, - true => witness.compression.unwrap().1, - }, - prover_duration, - output_statements, - ) - }; - - let verifier_time = Instant::now(); - - let output_statements_verifier = { - // ---------------------------------------------------- VERIFIER ---------------------------------------------------- - - let parsed_pcs_commitment = packed_pcs_parse_commitment( - &whir_config_builder, - &mut verifier_state, - &committed_col_dims, - log_smallest_decomposition_chunk, - ) - .unwrap(); - - let output_claim_point = verifier_state.sample_vec(log_n_poseidons); - - let GKRPoseidonResult { - output_statements, - input_statements, - cubes_statements, - on_compression_selector, - } = verify_poseidon_gkr( - &mut verifier_state, - log_n_poseidons, - &output_claim_point, - &layers, - UNIVARIATE_SKIPS, - compress, - ); - assert_eq!(&output_statements.point.0, &output_claim_point); - - if let Some(on_compression_selector) = on_compression_selector { - assert_eq!( - on_compression_selector.value, - mle_of_zeros_then_ones((1 << log_n_poseidons) - n_compressions, &on_compression_selector.point,) - ); - } - - // PCS verification - let mut pcs_statements = vec![]; - for meval in [input_statements, cubes_statements] { - for v in meval.values { - pcs_statements.push(vec![Evaluation { - point: meval.point.clone(), - value: v, - }]); - } - } - - let global_statements = packed_pcs_global_statements_for_verifier( - &committed_col_dims, - log_smallest_decomposition_chunk, - &pcs_statements, - &mut verifier_state, - &Default::default(), - ) - .unwrap(); - - whir_config - .verify::(&mut verifier_state, &parsed_pcs_commitment, global_statements) - .unwrap(); - output_statements - }; - let verifier_duration = verifier_time.elapsed(); - - let mut data_to_hash = input.clone(); - let plaintext_time = Instant::now(); - transposed_par_iter_mut(&mut data_to_hash).for_each(|row| { - if WIDTH == 16 { - let mut buff = array::from_fn(|j| *row[j]); - poseidon16_permute_mut(&mut buff); - for j in 0..WIDTH { - *row[j] = buff[j]; - } - } else if WIDTH == 24 { - let mut buff = array::from_fn(|j| *row[j]); - poseidon24_permute_mut(&mut buff); - for j in 0..WIDTH { - *row[j] = buff[j]; - } - } else { - panic!("Unsupported WIDTH"); - } - }); - let plaintext_duration = plaintext_time.elapsed(); - - // sanity check: ensure the plaintext poseidons matches the last GKR layer: - if compress { - output_layer - .iter() - .enumerate() - .take(COMPRESSION_OUTPUT_WIDTH) - .for_each(|(i, layer)| { - assert_eq!(PFPacking::::unpack_slice(layer), data_to_hash[i]); - }); - output_layer - .iter() - .enumerate() - .skip(COMPRESSION_OUTPUT_WIDTH) - .for_each(|(i, layer)| { - assert_eq!( - &PFPacking::::unpack_slice(layer)[..n_poseidons - n_compressions], - &data_to_hash[i][..n_poseidons - n_compressions] - ); - assert!( - PFPacking::::unpack_slice(layer)[n_poseidons - n_compressions..] - .iter() - .all(|&x| x.is_zero()) - ); - }); - } else { - output_layer.iter().enumerate().for_each(|(i, layer)| { - assert_eq!(PFPacking::::unpack_slice(layer), data_to_hash[i]); - }); - } - assert_eq!(&output_statements_prover, &output_statements_verifier); - assert_eq!( - &output_statements_verifier.values, - &output_layer - .iter() - .map(|layer| PFPacking::::unpack_slice(layer).evaluate(&output_statements_verifier.point)) - .collect::>() - ); - - println!("2^{log_n_poseidons} Poseidon2"); - println!( - "Plaintext (no proof) time: {:.3}s ({:.2}M Poseidons / s)", - plaintext_duration.as_secs_f64(), - n_poseidons as f64 / (plaintext_duration.as_secs_f64() * 1e6) - ); - println!( - "Prover time: {:.3}s ({:.2}M Poseidons / s, {:.1}x slower than plaintext)", - prover_duration.as_secs_f64(), - n_poseidons as f64 / (prover_duration.as_secs_f64() * 1e6), - prover_duration.as_secs_f64() / plaintext_duration.as_secs_f64() - ); - println!( - "Proof size: GKR = {:.1} KiB, PCS = {:.1} KiB . Total = {:.1} KiB (available optimizations: GKR = 40%, PCS = 15%)", - (proof_size_gkr * F::bits()) as f64 / (8.0 * 1024.0), - (proof_size_pcs * F::bits()) as f64 / (8.0 * 1024.0), - ((proof_size_gkr + proof_size_pcs) * F::bits()) as f64 / (8.0 * 1024.0), - ); - println!("Verifier time: {}ms", verifier_duration.as_millis()); -} diff --git a/crates/poseidon_circuit/src/utils.rs b/crates/poseidon_circuit/src/utils.rs deleted file mode 100644 index a610b926..00000000 --- a/crates/poseidon_circuit/src/utils.rs +++ /dev/null @@ -1,33 +0,0 @@ -use std::array; - -use multilinear_toolkit::prelude::*; -use p3_koala_bear::{GenericPoseidon2LinearLayersKoalaBear, KoalaBearInternalLayerParameters, KoalaBearParameters}; -use p3_monty_31::InternalLayerBaseParameters; -use p3_poseidon2::GenericPoseidon2LinearLayers; -use tracing::instrument; - -use crate::F; - -#[instrument(skip_all)] -pub fn build_poseidon_inv_matrices() -> ([[F; WIDTH]; WIDTH], [[F; WIDTH]; WIDTH]) -where - KoalaBearInternalLayerParameters: InternalLayerBaseParameters, -{ - let mut mds_matrix: [[F; WIDTH]; WIDTH] = array::from_fn(|_| array::from_fn(|_| F::ZERO)); - for (i, row) in mds_matrix.iter_mut().enumerate() { - row[i] = F::ONE; - GenericPoseidon2LinearLayersKoalaBear::external_linear_layer(row); - } - mds_matrix = transpose_matrix(&mds_matrix); - let inv_mds_matrix = inverse_matrix(&mds_matrix); - - let mut light_matrix: [[F; WIDTH]; WIDTH] = array::from_fn(|_| array::from_fn(|_| F::ZERO)); - for (i, row) in light_matrix.iter_mut().enumerate() { - row[i] = F::ONE; - GenericPoseidon2LinearLayersKoalaBear::internal_linear_layer(row); - } - light_matrix = transpose_matrix(&light_matrix); - let inv_light_matrix = inverse_matrix(&light_matrix); - - (inv_mds_matrix, inv_light_matrix) -} diff --git a/crates/poseidon_circuit/src/verify.rs b/crates/poseidon_circuit/src/verify.rs deleted file mode 100644 index e3e8eca6..00000000 --- a/crates/poseidon_circuit/src/verify.rs +++ /dev/null @@ -1,232 +0,0 @@ -use multilinear_toolkit::prelude::*; -use p3_koala_bear::{KoalaBearInternalLayerParameters, KoalaBearParameters}; -use p3_monty_31::InternalLayerBaseParameters; - -use crate::{ - CompressionComputation, EF, F, FullRoundComputation, GKRPoseidonResult, PartialRoundComputation, - build_poseidon_inv_matrices, gkr_layers::PoseidonGKRLayers, -}; - -pub fn verify_poseidon_gkr( - verifier_state: &mut FSVerifier>, - log_n_poseidons: usize, - output_claim_point: &[EF], - layers: &PoseidonGKRLayers, - univariate_skips: usize, - compression: bool, -) -> GKRPoseidonResult -where - KoalaBearInternalLayerParameters: InternalLayerBaseParameters, -{ - let selectors = univariate_selectors::(univariate_skips); - let (inv_mds_matrix, inv_light_matrix) = build_poseidon_inv_matrices::(); - - let mut output_claims = vec![]; - let mut claims = vec![]; - - let mut point = { - let inner_evals = (0..WIDTH) - .map(|_| { - verifier_state - .next_extension_scalars_vec(1 << univariate_skips) - .unwrap() - }) - .collect::>(); - let alpha = verifier_state.sample(); - let selectors_at_alpha = selectors - .iter() - .map(|selector| selector.evaluate(alpha)) - .collect::>(); - for evals in inner_evals { - output_claims.push(evals.evaluate(&MultilinearPoint(output_claim_point[..univariate_skips].to_vec()))); - claims.push(dot_product(selectors_at_alpha.iter().copied(), evals.into_iter())) - } - [vec![alpha], output_claim_point[univariate_skips..].to_vec()].concat() - }; - - let on_compression_selector = if compression { - (point, claims) = verify_gkr_round( - verifier_state, - &CompressionComputation:: { - compressed_output: layers.compressed_output.unwrap(), - }, - log_n_poseidons, - &point, - &claims, - univariate_skips, - WIDTH + 1, - ); - - let inner_evals = verifier_state - .next_extension_scalars_vec(1 << univariate_skips) - .unwrap(); - let recomputed_value = - evaluate_univariate_multilinear::<_, _, _, false>(&inner_evals, &[point[0]], &selectors, None); - assert_eq!(claims.pop().unwrap(), recomputed_value); - let epsilons = verifier_state.sample_vec(univariate_skips); - let new_point = MultilinearPoint([epsilons.clone(), point[1..].to_vec()].concat()); - let new_eval = inner_evals.evaluate(&MultilinearPoint(epsilons)); - - Some(Evaluation::new(new_point, new_eval)) - } else { - None - }; - - for full_round_constants in layers.final_full_rounds.iter().rev() { - claims = apply_matrix(&inv_mds_matrix, &claims); - - (point, claims) = verify_gkr_round( - verifier_state, - &FullRoundComputation {}, - log_n_poseidons, - &point, - &claims, - univariate_skips, - WIDTH, - ); - - for (claim, c) in claims.iter_mut().zip(full_round_constants) { - *claim -= *c; - } - } - - for partial_round_constant in layers.partial_rounds_remaining.iter().rev() { - claims = apply_matrix(&inv_light_matrix, &claims); - - (point, claims) = verify_gkr_round( - verifier_state, - &PartialRoundComputation:: {}, - log_n_poseidons, - &point, - &claims, - univariate_skips, - WIDTH, - ); - - claims[0] -= *partial_round_constant; - } - - let mut pcs_point_for_cubes = vec![]; - let mut pcs_evals_for_cubes = vec![]; - if N_COMMITED_CUBES > 0 { - let claimed_cubes_evals = verifier_state.next_extension_scalars_vec(N_COMMITED_CUBES).unwrap(); - - (point, claims) = verify_gkr_round( - verifier_state, - layers.batch_partial_rounds.as_ref().unwrap(), - log_n_poseidons, - &point, - &[claims, claimed_cubes_evals.clone()].concat(), - univariate_skips, - WIDTH + N_COMMITED_CUBES, - ); - - pcs_point_for_cubes = point.clone(); - pcs_evals_for_cubes = claims[WIDTH..].to_vec(); - - claims = claims[..WIDTH].to_vec(); - } - - for full_round_constants in layers.initial_full_rounds.iter().rev() { - claims = apply_matrix(&inv_mds_matrix, &claims); - - (point, claims) = verify_gkr_round( - verifier_state, - &FullRoundComputation {}, - log_n_poseidons, - &point, - &claims, - univariate_skips, - WIDTH, - ); - - for (claim, c) in claims.iter_mut().zip(full_round_constants) { - *claim -= *c; - } - } - - claims = apply_matrix(&inv_mds_matrix, &claims); - - let pcs_point_for_inputs = point.clone(); - let pcs_evals_for_inputs = claims; - - let input_statements = verify_inner_evals_on_commited_columns( - verifier_state, - &pcs_point_for_inputs, - &pcs_evals_for_inputs, - &selectors, - ); - - let cubes_statements = if N_COMMITED_CUBES == 0 { - Default::default() - } else { - verify_inner_evals_on_commited_columns(verifier_state, &pcs_point_for_cubes, &pcs_evals_for_cubes, &selectors) - }; - - let output_statements = MultiEvaluation::new(MultilinearPoint(output_claim_point.to_vec()), output_claims); - GKRPoseidonResult { - output_statements, - input_statements, - cubes_statements, - on_compression_selector, - } -} - -fn verify_gkr_round>>( - verifier_state: &mut FSVerifier>, - computation: &SC, - log_n_poseidons: usize, - claim_point: &[EF], - output_claims: &[EF], - univariate_skips: usize, - n_inputs: usize, -) -> (Vec, Vec) { - let batching_scalar = verifier_state.sample(); - let batching_scalars_powers = batching_scalar.powers().collect_n(output_claims.len()); - let batched_claim: EF = dot_product(output_claims.iter().copied(), batching_scalar.powers()); - - let (retrieved_batched_claim, sumcheck_postponed_claim) = sumcheck_verify_with_univariate_skip( - verifier_state, - computation.degree() + 1, - log_n_poseidons, - univariate_skips, - ) - .unwrap(); - - assert_eq!(retrieved_batched_claim, batched_claim); - - let sumcheck_inner_evals = verifier_state.next_extension_scalars_vec(n_inputs).unwrap(); - assert_eq!( - computation.eval_extension(&sumcheck_inner_evals, &[], &batching_scalars_powers) - * eq_poly_with_skip(&sumcheck_postponed_claim.point, claim_point, univariate_skips), - sumcheck_postponed_claim.value - ); - - (sumcheck_postponed_claim.point.0, sumcheck_inner_evals) -} - -fn verify_inner_evals_on_commited_columns( - verifier_state: &mut FSVerifier>, - point: &[EF], - claimed_evals: &[EF], - selectors: &[DensePolynomial], -) -> MultiEvaluation { - let univariate_skips = log2_strict_usize(selectors.len()); - let inner_evals_inputs = verifier_state - .next_extension_scalars_vec(claimed_evals.len() << univariate_skips) - .unwrap(); - let pcs_batching_scalars_inputs = verifier_state.sample_vec(univariate_skips); - let mut values_to_verif = vec![]; - let point_to_verif = MultilinearPoint([pcs_batching_scalars_inputs.clone(), point[1..].to_vec()].concat()); - for (&eval, col_inner_evals) in claimed_evals - .iter() - .zip(inner_evals_inputs.chunks_exact(1 << univariate_skips)) - { - assert_eq!( - eval, - evaluate_univariate_multilinear::<_, _, _, false>(col_inner_evals, &point[..1], selectors, None) - ); - values_to_verif.push(col_inner_evals.evaluate(&MultilinearPoint(pcs_batching_scalars_inputs.clone()))); - } - MultiEvaluation::new(point_to_verif, values_to_verif) -} diff --git a/crates/poseidon_circuit/src/witness_gen.rs b/crates/poseidon_circuit/src/witness_gen.rs deleted file mode 100644 index 98eba6b8..00000000 --- a/crates/poseidon_circuit/src/witness_gen.rs +++ /dev/null @@ -1,246 +0,0 @@ -use std::array; - -use multilinear_toolkit::prelude::*; -use p3_koala_bear::GenericPoseidon2LinearLayersKoalaBear; -use p3_koala_bear::KoalaBearInternalLayerParameters; -use p3_koala_bear::KoalaBearParameters; -use p3_monty_31::InternalLayerBaseParameters; -use p3_poseidon2::GenericPoseidon2LinearLayers; -use utils::transposed_par_iter_mut; - -use crate::F; -use crate::gkr_layers::BatchPartialRounds; -use crate::gkr_layers::PoseidonGKRLayers; - -#[derive(Debug, Hash)] -pub struct PoseidonWitness { - pub input_layer: [Vec; WIDTH], // input of the permutation - pub initial_full_layers: Vec<[Vec; WIDTH]>, // just before cubing - pub batch_partial_round_input: Option<[Vec; WIDTH]>, // again, the input of the batch (partial) round - pub committed_cubes: [Vec; N_COMMITED_CUBES], // the cubes commited in the batch (partial) rounds - pub remaining_partial_round_layers: Vec<[Vec; WIDTH]>, // the input of each remaining partial round, just before cubing the first element - pub final_full_layers: Vec<[Vec; WIDTH]>, // just before cubing - pub output_layer: [Vec; WIDTH], // output of the permutation - pub compression: Option<(Vec, [Vec; WIDTH])>, // compression indicator column, compressed output -} - -impl PoseidonWitness, WIDTH, N_COMMITED_CUBES> { - pub fn n_poseidons(&self) -> usize { - self.input_layer[0].len() * packing_width::() - } -} - -pub fn generate_poseidon_witness( - input: [Vec; WIDTH], - layers: &PoseidonGKRLayers, - compression: Option>, -) -> PoseidonWitness -where - A: Algebra + Copy + Send + Sync, - KoalaBearInternalLayerParameters: InternalLayerBaseParameters, -{ - let mut initial_full_layers = vec![apply_full_round::<_, _, false, true, true>( - &input, - &layers.initial_full_rounds[0], - )]; - for constants in &layers.initial_full_rounds[1..] { - initial_full_layers.push(apply_full_round::<_, _, true, true, true>( - initial_full_layers.last().unwrap(), - constants, - )); - } - - let layer = apply_full_round::<_, _, true, true, false>( - initial_full_layers.last().unwrap(), - &[F::ZERO; WIDTH], // unused - ); - - let (batch_partial_round_layer, mut next_layer, committed_cubes) = if N_COMMITED_CUBES == 0 { - (None, layer, vec![].try_into().unwrap()) - } else { - let (next_layer, committed_cubes) = - apply_batch_partial_rounds(&layer, layers.batch_partial_rounds.as_ref().unwrap()); - (Some(layer), next_layer, committed_cubes) - }; - - next_layer[0] = next_layer[0] - .par_iter() - .map(|&val| val + layers.partial_rounds_remaining[0]) - .collect(); - let mut remaining_partial_round_layers = vec![next_layer]; - for &constant in &layers.partial_rounds_remaining[1..] { - remaining_partial_round_layers.push(apply_partial_round( - remaining_partial_round_layers.last().unwrap(), - Some(constant), - )); - } - - let mut final_full_layers = vec![apply_full_round::<_, _, false, false, true>( - &apply_partial_round(remaining_partial_round_layers.last().unwrap(), None), - &layers.final_full_rounds[0], - )]; - for constants in &layers.final_full_rounds[1..] { - final_full_layers.push(apply_full_round::<_, _, true, true, true>( - final_full_layers.last().unwrap(), - constants, - )); - } - - let output_layer = apply_full_round::<_, _, true, true, false>( - final_full_layers.last().unwrap(), - &[F::ZERO; WIDTH], // unused - ); - - let compression = compression.map(|indicator| { - let compressed_output = (0..WIDTH) - .into_par_iter() - .map(|col_idx| { - if col_idx < layers.compressed_output.unwrap() { - output_layer[col_idx].clone() - } else { - output_layer[col_idx] - .iter() - .zip(&indicator) - .map(|(out, ind)| *out * (A::ONE - *ind)) - .collect::>() - } - }) - .collect::>() - .try_into() - .unwrap(); - (indicator, compressed_output) - }); - - PoseidonWitness { - input_layer: input, - initial_full_layers, - batch_partial_round_input: batch_partial_round_layer, - committed_cubes, - remaining_partial_round_layers, - final_full_layers, - output_layer, - compression, - } -} - -// #[instrument(skip_all)] -fn apply_full_round( - input_layers: &[Vec; WIDTH], - constants: &[F; WIDTH], -) -> [Vec; WIDTH] -where - A: Algebra + Copy + Send + Sync, - KoalaBearInternalLayerParameters: InternalLayerBaseParameters, -{ - let mut output_layers: [_; WIDTH] = array::from_fn(|_| A::zero_vec(input_layers[0].len())); - transposed_par_iter_mut(&mut output_layers) - .enumerate() - .for_each(|(row_index, output_row)| { - let mut buff: [A; WIDTH] = array::from_fn(|j| input_layers[j][row_index]); - if CUBE { - for v in &mut buff { - *v = v.cube(); - } - } - if MDS { - GenericPoseidon2LinearLayersKoalaBear::external_linear_layer(&mut buff); - } - if ADD_CONSTANTS { - buff.iter_mut().enumerate().for_each(|(j, val)| { - *val += constants[j]; - }); - } - for j in 0..WIDTH { - *output_row[j] = buff[j]; - } - }); - output_layers -} - -// #[instrument(skip_all)] -fn apply_partial_round( - input_layers: &[Vec], - partial_round_constant: Option, -) -> [Vec; WIDTH] -where - A: Algebra + Copy + Send + Sync, - KoalaBearInternalLayerParameters: InternalLayerBaseParameters, -{ - // cube single, light matrix mul, add single constant - let mut output_layers: [_; WIDTH] = array::from_fn(|_| A::zero_vec(input_layers[0].len())); - transposed_par_iter_mut(&mut output_layers) - .enumerate() - .for_each(|(row_index, output_row)| { - let mut buff = [A::ZERO; WIDTH]; - buff[0] = input_layers[0][row_index].cube(); - for j in 1..WIDTH { - buff[j] = input_layers[j][row_index]; - } - GenericPoseidon2LinearLayersKoalaBear::internal_linear_layer(&mut buff); - if let Some(constant) = partial_round_constant { - buff[0] += constant; - } - for j in 0..WIDTH { - *output_row[j] = buff[j]; - } - }); - output_layers -} - -// #[instrument(skip_all)] -fn apply_batch_partial_rounds( - input_layers: &[Vec], - rounds: &BatchPartialRounds, -) -> ([Vec; WIDTH], [Vec; N_COMMITED_CUBES]) -where - A: Algebra + Copy + Send + Sync, - KoalaBearInternalLayerParameters: InternalLayerBaseParameters, -{ - let mut output_layers: [_; WIDTH] = array::from_fn(|_| A::zero_vec(input_layers[0].len())); - let mut cubes: [_; N_COMMITED_CUBES] = array::from_fn(|_| A::zero_vec(input_layers[0].len())); - transposed_par_iter_mut(&mut output_layers) - .zip(transposed_par_iter_mut(&mut cubes)) - .enumerate() - .for_each(|(row_index, (output_row, cubes))| { - let mut buff: [A; WIDTH] = array::from_fn(|j| input_layers[j][row_index]); - for (i, &constant) in rounds.constants.iter().enumerate() { - *cubes[i] = (buff[0] + constant).cube(); - buff[0] = *cubes[i]; - GenericPoseidon2LinearLayersKoalaBear::internal_linear_layer(&mut buff); - } - buff[0] = (buff[0] + rounds.last_constant).cube(); - GenericPoseidon2LinearLayersKoalaBear::internal_linear_layer(&mut buff); - for j in 0..WIDTH { - *output_row[j] = buff[j]; - } - }); - (output_layers, cubes) -} - -pub fn default_cube_layers( - layers: &PoseidonGKRLayers, -) -> [A; N_COMMITED_CUBES] -where - A: Algebra + Copy + Send + Sync, - KoalaBearInternalLayerParameters: InternalLayerBaseParameters, -{ - if N_COMMITED_CUBES == 0 { - return vec![].try_into().unwrap(); - } - - generate_poseidon_witness::( - array::from_fn(|_| vec![A::ZERO]), - layers, - if layers.compressed_output.is_some() { - Some(vec![A::ZERO]) - } else { - None - }, - ) - .committed_cubes - .iter() - .map(|v| v[0]) - .collect::>() - .try_into() - .unwrap() -} diff --git a/crates/rec_aggregation/Cargo.toml b/crates/rec_aggregation/Cargo.toml index cf99d34a..07c2dbd4 100644 --- a/crates/rec_aggregation/Cargo.toml +++ b/crates/rec_aggregation/Cargo.toml @@ -12,15 +12,12 @@ xmss.workspace = true rand.workspace = true p3-poseidon2.workspace = true p3-koala-bear.workspace = true -p3-challenger.workspace = true -p3-air.workspace = true p3-symmetric.workspace = true p3-util.workspace = true whir-p3.workspace = true tracing.workspace = true air.workspace = true sub_protocols.workspace = true -lookup.workspace = true lean_vm.workspace = true serde_json.workspace = true lean_compiler.workspace = true diff --git a/crates/rec_aggregation/__init__.py b/crates/rec_aggregation/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/crates/rec_aggregation/fiat_shamir.py b/crates/rec_aggregation/fiat_shamir.py new file mode 100644 index 00000000..fe2614b8 --- /dev/null +++ b/crates/rec_aggregation/fiat_shamir.py @@ -0,0 +1,116 @@ +from snark_lib import * +# FIAT SHAMIR layout: 17 field elements +# 0..8 -> first half of sponge state +# 8..16 -> second half of sponge state +# 16 -> transcript pointer + +from utils import * + + +def fs_new(transcript_ptr): + fs_state = Array(17) + set_to_16_zeros(fs_state) + fs_state[16] = transcript_ptr + duplexed = duplexing(fs_state) + return duplexed + + +def duplexing(fs): + new_fs = Array(17) + poseidon16(fs, fs + 8, new_fs, PERMUTATION) + new_fs[16] = fs[16] + return new_fs + + +def fs_grinding(fs, bits): + if bits == 0: + return fs # no grinding + left = Array(8) + grinding_witness = read_memory(fs[16]) + left[0] = grinding_witness + set_to_7_zeros(left + 1) + + fs_after_poseidon = Array(17) + poseidon16(left, fs + 8, fs_after_poseidon, PERMUTATION) + fs_after_poseidon[16] = fs[16] + 1 # one element read from transcript + + sampled = fs_after_poseidon[0] + _, sampled_low_bits_value = checked_decompose_bits(sampled, bits) + assert sampled_low_bits_value == 0 + + fs_duplexed = duplexing(fs_after_poseidon) + + return fs_duplexed + + +def fs_sample_ef(fs): + return fs + + +def fs_hint(fs, n): + # return the updated fiat-shamir, and a pointer to n field elements from the transcript + + transcript_ptr = fs[16] + new_fs = Array(17) + copy_16(fs, new_fs) + new_fs[16] = fs[16] + n # advance transcript pointer + return new_fs, transcript_ptr + + +def fs_receive_chunks(fs, n_chunks: Const): + # each chunk = 8 field elements + new_fs = Array(1 + 16 * n_chunks) + transcript_ptr = fs[16] + new_fs[16 * n_chunks] = transcript_ptr + 8 * n_chunks # advance transcript pointer + + poseidon16(transcript_ptr, fs + 8, new_fs, PERMUTATION) + for i in unroll(1, n_chunks): + poseidon16( + transcript_ptr + i * 8, + new_fs + ((i - 1) * 16 + 8), + new_fs + i * 16, + PERMUTATION, + ) + return new_fs + 16 * (n_chunks - 1), transcript_ptr + + +def fs_receive_ef(fs, n: Const): + new_fs, ef_ptr = fs_receive_chunks(fs, next_multiple_of(n * DIM, 8) / 8) + for i in unroll(n * DIM, next_multiple_of(n * DIM, 8)): + assert ef_ptr[i] == 0 + return new_fs, ef_ptr + + +def fs_print_state(fs_state): + for i in unroll(0, 17): + print(i, fs_state[i]) + return + + +def sample_bits_const(fs: Mut, n_samples: Const, K): + # return the updated fiat-shamir, and a pointer to n pointers, each pointing to 31 (boolean) field elements, + sampled_bits = Array(n_samples) + for i in unroll(0, (next_multiple_of(n_samples, 8) / 8) - 1): + for j in unroll(0, 8): + bits, _ = checked_decompose_bits(fs[j], K) + sampled_bits[i * 8 + j] = bits + fs = duplexing(fs) + # Last batch (may be partial) + for j in unroll(0, 8 - ((8 - (n_samples % 8)) % 8)): + bits, _ = checked_decompose_bits(fs[j], K) + sampled_bits[((next_multiple_of(n_samples, 8) / 8) - 1) * 8 + j] = bits + return duplexing(fs), sampled_bits + + +def sample_bits_dynamic(fs_state, n_samples, K): + new_fs_state: Imu + sampled_bits: Imu + for r in unroll(0, N_ROUNDS_BASE + 1): + if n_samples == NUM_QUERIES_BASE[r]: + new_fs_state, sampled_bits = sample_bits_const(fs_state, NUM_QUERIES_BASE[r], K) + return new_fs_state, sampled_bits + for r in unroll(0, N_ROUNDS_EXT + 1): + if n_samples == NUM_QUERIES_EXT[r]: + new_fs_state, sampled_bits = sample_bits_const(fs_state, NUM_QUERIES_EXT[r], K) + return new_fs_state, sampled_bits + assert False, "sample_bits_dynamic called with unsupported n_samples" diff --git a/crates/rec_aggregation/fiat_shamir.snark.py b/crates/rec_aggregation/fiat_shamir.snark.py new file mode 100644 index 00000000..7973b05b --- /dev/null +++ b/crates/rec_aggregation/fiat_shamir.snark.py @@ -0,0 +1,115 @@ +# FIAT SHAMIR layout: 17 field elements +# 0..8 -> first half of sponge state +# 8..16 -> second half of sponge state +# 16 -> transcript pointer + +from utils import * + + +def fs_new(transcript_ptr): + fs_state = Array(17) + set_to_16_zeros(fs_state) + fs_state[16] = transcript_ptr + duplexed = duplexing(fs_state) + return duplexed + + +def duplexing(fs): + new_fs = Array(17) + poseidon16(fs, fs + 8, new_fs, PERMUTATION) + new_fs[16] = fs[16] + return new_fs + + +def fs_grinding(fs, bits): + if bits == 0: + return fs # no grinding + left = Array(8) + grinding_witness = read_memory(fs[16]) + left[0] = grinding_witness + set_to_7_zeros(left + 1) + + fs_after_poseidon = Array(17) + poseidon16(left, fs + 8, fs_after_poseidon, PERMUTATION) + fs_after_poseidon[16] = fs[16] + 1 # one element read from transcript + + sampled = fs_after_poseidon[0] + _, sampled_low_bits_value = checked_decompose_bits(sampled, bits) + assert sampled_low_bits_value == 0 + + fs_duplexed = duplexing(fs_after_poseidon) + + return fs_duplexed + + +def fs_sample_ef(fs): + return fs + + +def fs_hint(fs, n): + # return the updated fiat-shamir, and a pointer to n field elements from the transcript + + transcript_ptr = fs[16] + new_fs = Array(17) + copy_16(fs, new_fs) + new_fs[16] = fs[16] + n # advance transcript pointer + return new_fs, transcript_ptr + + +def fs_receive_chunks(fs, n_chunks: Const): + # each chunk = 8 field elements + new_fs = Array(1 + 16 * n_chunks) + transcript_ptr = fs[16] + new_fs[16 * n_chunks] = transcript_ptr + 8 * n_chunks # advance transcript pointer + + poseidon16(transcript_ptr, fs + 8, new_fs, PERMUTATION) + for i in unroll(1, n_chunks): + poseidon16( + transcript_ptr + i * 8, + new_fs + ((i - 1) * 16 + 8), + new_fs + i * 16, + PERMUTATION, + ) + return new_fs + 16 * (n_chunks - 1), transcript_ptr + + +def fs_receive_ef(fs, n: Const): + new_fs, ef_ptr = fs_receive_chunks(fs, next_multiple_of(n * DIM, 8) / 8) + for i in unroll(n * DIM, next_multiple_of(n * DIM, 8)): + assert ef_ptr[i] == 0 + return new_fs, ef_ptr + + +def fs_print_state(fs_state): + for i in unroll(0, 17): + print(i, fs_state[i]) + return + + +def sample_bits_const(fs: Mut, n_samples: Const, K): + # return the updated fiat-shamir, and a pointer to n pointers, each pointing to 31 (boolean) field elements, + sampled_bits = Array(n_samples) + for i in unroll(0, (next_multiple_of(n_samples, 8) / 8) - 1): + for j in unroll(0, 8): + bits, _ = checked_decompose_bits(fs[j], K) + sampled_bits[i * 8 + j] = bits + fs = duplexing(fs) + # Last batch (may be partial) + for j in unroll(0, 8 - ((8 - (n_samples % 8)) % 8)): + bits, _ = checked_decompose_bits(fs[j], K) + sampled_bits[((next_multiple_of(n_samples, 8) / 8) - 1) * 8 + j] = bits + return duplexing(fs), sampled_bits + + +def sample_bits_dynamic(fs_state, n_samples, K): + new_fs_state: Imu + sampled_bits: Imu + for r in unroll(0, N_ROUNDS_BASE + 1): + if n_samples == NUM_QUERIES_BASE[r]: + new_fs_state, sampled_bits = sample_bits_const(fs_state, NUM_QUERIES_BASE[r], K) + return new_fs_state, sampled_bits + for r in unroll(0, N_ROUNDS_EXT + 1): + if n_samples == NUM_QUERIES_EXT[r]: + new_fs_state, sampled_bits = sample_bits_const(fs_state, NUM_QUERIES_EXT[r], K) + return new_fs_state, sampled_bits + assert False, "sample_bits_dynamic called with unsupported n_samples" diff --git a/crates/rec_aggregation/hashing.py b/crates/rec_aggregation/hashing.py new file mode 100644 index 00000000..bf67a9ef --- /dev/null +++ b/crates/rec_aggregation/hashing.py @@ -0,0 +1,128 @@ +from snark_lib import * + +COMPRESSION = 1 +PERMUTATION = 0 + +DIM = 5 # extension degree +VECTOR_LEN = 8 + +MERKLE_HEIGHTS_BASE = MERKLE_HEIGHTS_BASE_PLACEHOLDER +MERKLE_HEIGHTS_EXT = MERKLE_HEIGHTS_EXT_PLACEHOLDER +NUM_QUERIES_BASE = NUM_QUERIES_BASE_PLACEHOLDER +NUM_QUERIES_EXT = NUM_QUERIES_EXT_PLACEHOLDER +N_ROUNDS_BASE = len(NUM_QUERIES_BASE) - 1 +N_ROUNDS_EXT = len(NUM_QUERIES_EXT) - 1 + + +def batch_hash_slice(num_queries, all_data_to_hash, all_resulting_hashes, len): + if len == DIM * 2: + batch_hash_slice_const(num_queries, all_data_to_hash, all_resulting_hashes, DIM * 2) + return + if len == 16: + batch_hash_slice_const(num_queries, all_data_to_hash, all_resulting_hashes, 16) + return + if len == 1: + batch_hash_slice_const(num_queries, all_data_to_hash, all_resulting_hashes, 1) + return + assert False, "batch_hash_slice called with unsupported len" + + +def batch_hash_slice_const(num_queries, all_data_to_hash, all_resulting_hashes, len: Const): + for i in range(0, num_queries): + data = all_data_to_hash[i] + res = slice_hash(ZERO_VEC_PTR, data, len) + all_resulting_hashes[i] = res + return + + +def slice_hash(seed, data, len: Const): + states = Array(len * VECTOR_LEN) + poseidon16(ZERO_VEC_PTR, data, states, COMPRESSION) + state_indexes = Array(len) + state_indexes[0] = states + for j in unroll(1, len): + state_indexes[j] = state_indexes[j - 1] + VECTOR_LEN + poseidon16(state_indexes[j - 1], data + j * VECTOR_LEN, state_indexes[j], COMPRESSION) + return state_indexes[len - 1] + + +def merkle_verif_batch(n_paths, merkle_paths, leaves_digests, leave_positions, root, height, num_queries): + for i in unroll(0, N_ROUNDS_BASE + 1): + if height + num_queries * 1000 == MERKLE_HEIGHTS_BASE[i] + NUM_QUERIES_BASE[i] * 1000: + merkle_verif_batch_const( + NUM_QUERIES_BASE[i], + merkle_paths, + leaves_digests, + leave_positions, + root, + MERKLE_HEIGHTS_BASE[i], + ) + return + for i in unroll(0, N_ROUNDS_EXT + 1): + if height + num_queries * 1000 == MERKLE_HEIGHTS_EXT[i] + NUM_QUERIES_EXT[i] * 1000: + merkle_verif_batch_const( + NUM_QUERIES_EXT[i], + merkle_paths, + leaves_digests, + leave_positions, + root, + MERKLE_HEIGHTS_EXT[i], + ) + return + print(12345555) + print(height) + assert False + + +def merkle_verif_batch_const(n_paths: Const, merkle_paths, leaves_digests, leave_positions, root, height: Const): + # n_paths: F + # leaves_digests: pointer to a slice of n_paths vectorized pointers, each pointing to 1 chunk of 8 field elements + # leave_positions: pointer to a slice of n_paths field elements (each < 2^height) + # root: vectorized pointer to 1 chunk of 8 field elements + # height: F + + for i in unroll(0, n_paths): + merkle_verify( + leaves_digests[i], + merkle_paths + (i * height) * VECTOR_LEN, + leave_positions[i], + root, + height, + ) + + return + + +def merkle_verify(leaf_digest, merkle_path, leaf_position_bits, root, height: Const): + states = Array(height * VECTOR_LEN) + + # First merkle round + match leaf_position_bits[0]: + case 0: + poseidon16(leaf_digest, merkle_path, states, COMPRESSION) + case 1: + poseidon16(merkle_path, leaf_digest, states, COMPRESSION) + + # Remaining merkle rounds + state_indexes = Array(height) + state_indexes[0] = states + for j in unroll(1, height): + state_indexes[j] = state_indexes[j - 1] + VECTOR_LEN + # Warning: this works only if leaf_position_bits[i] is known to be boolean: + match leaf_position_bits[j]: + case 0: + poseidon16( + state_indexes[j - 1], + merkle_path + j * VECTOR_LEN, + state_indexes[j], + COMPRESSION, + ) + case 1: + poseidon16( + merkle_path + j * VECTOR_LEN, + state_indexes[j - 1], + state_indexes[j], + COMPRESSION, + ) + copy_8(state_indexes[height - 1], root) + return diff --git a/crates/rec_aggregation/recursion.py b/crates/rec_aggregation/recursion.py new file mode 100644 index 00000000..26578128 --- /dev/null +++ b/crates/rec_aggregation/recursion.py @@ -0,0 +1,685 @@ +from snark_lib import * +from whir import * + +N_TABLES = N_TABLES_PLACEHOLDER +MIN_LOG_N_ROWS_PER_TABLE = MIN_LOG_N_ROWS_PER_TABLE_PLACEHOLDER +MAX_LOG_N_ROWS_PER_TABLE = MAX_LOG_N_ROWS_PER_TABLE_PLACEHOLDER +MIN_LOG_MEMORY_SIZE = MIN_LOG_MEMORY_SIZE_PLACEHOLDER +MAX_LOG_MEMORY_SIZE = MAX_LOG_MEMORY_SIZE_PLACEHOLDER +N_VARS_FIRST_GKR = N_VARS_FIRST_GKR_PLACEHOLDER +MAX_BUS_WIDTH = MAX_BUS_WIDTH_PLACEHOLDER +MAX_NUM_AIR_CONSTRAINTS = MAX_NUM_AIR_CONSTRAINTS_PLACEHOLDER +MEMORY_TABLE_INDEX = MEMORY_TABLE_INDEX_PLACEHOLDER + +LOOKUPS_F_INDEXES = LOOKUPS_F_INDEXES_PLACEHOLDER # [[_; ?]; N_TABLES] +LOOKUPS_F_VALUES = LOOKUPS_F_VALUES_PLACEHOLDER # [[[_; ?]; ?]; N_TABLES] + +LOOKUPS_EF_INDEXES = LOOKUPS_EF_INDEXES_PLACEHOLDER # [[_; ?]; N_TABLES] +LOOKUPS_EF_VALUES = LOOKUPS_EF_VALUES_PLACEHOLDER # [[_; ?]; N_TABLES] + +NUM_COLS_F_AIR = NUM_COLS_F_AIR_PLACEHOLDER +NUM_COLS_EF_AIR = NUM_COLS_EF_AIR_PLACEHOLDER + +NUM_COLS_F_COMMITED = NUM_COLS_F_COMMITED_PLACEHOLDER + +EXECUTION_TABLE_INDEX = EXECUTION_TABLE_INDEX_PLACEHOLDER +AIR_DEGREES = AIR_DEGREES_PLACEHOLDER # [_; N_TABLES] +N_AIR_COLUMNS_F = N_AIR_COLUMNS_F_PLACEHOLDER # [_; N_TABLES] +N_AIR_COLUMNS_EF = N_AIR_COLUMNS_EF_PLACEHOLDER # [_; N_TABLES] +AIR_DOWN_COLUMNS_F = AIR_DOWN_COLUMNS_F_PLACEHOLDER # [[_; ?]; N_TABLES] +AIR_DOWN_COLUMNS_EF = AIR_DOWN_COLUMNS_EF_PLACEHOLDER # [[_; _]; N_TABLES] + +NUM_BYTECODE_INSTRUCTIONS = NUM_BYTECODE_INSTRUCTIONS_PLACEHOLDER +N_COMMITTED_EXEC_COLUMNS = N_COMMITTED_EXEC_COLUMNS_PLACEHOLDER + +GUEST_BYTECODE_LEN = GUEST_BYTECODE_LEN_PLACEHOLDER +COL_PC = COL_PC_PLACEHOLDER +TOTAL_WHIR_STATEMENTS_BASE = TOTAL_WHIR_STATEMENTS_BASE_PLACEHOLDER +STARTING_PC = STARTING_PC_PLACEHOLDER +ENDING_PC = ENDING_PC_PLACEHOLDER +NONRESERVED_PROGRAM_INPUT_START = NONRESERVED_PROGRAM_INPUT_START_PLACEHOLDER + + +def main(): + mem = 0 + priv_start = mem[PRIVATE_INPUT_START_PTR] + proof_size = priv_start[0] + outer_public_memory_log_size = priv_start[1] + outer_public_memory_size = powers_of_two(outer_public_memory_log_size) + n_recursions = priv_start[2] + outer_public_memory = priv_start + 3 + proofs_start = outer_public_memory + outer_public_memory_size + for i in range(0, n_recursions): + proof_transcript = proofs_start + i * proof_size + recursion(outer_public_memory_log_size, outer_public_memory, proof_transcript) + return + + +def recursion(outer_public_memory_log_size, outer_public_memory, proof_transcript): + fs: Mut = fs_new(proof_transcript) + + # table dims + debug_assert(N_TABLES + 1 < VECTOR_LEN) # (because duplex only once bellow) + fs, mem_and_table_dims = fs_receive_chunks(fs, 1) + for i in unroll(N_TABLES + 1, 8): + assert mem_and_table_dims[i] == 0 + log_memory = mem_and_table_dims[0] + table_dims = mem_and_table_dims + 1 + + for i in unroll(0, N_TABLES): + n_vars_for_table = table_dims[i] + assert n_vars_for_table <= log_memory + assert MIN_LOG_N_ROWS_PER_TABLE <= n_vars_for_table + assert n_vars_for_table <= MAX_LOG_N_ROWS_PER_TABLE[i] + assert MIN_LOG_MEMORY_SIZE <= log_memory + assert log_memory <= MAX_LOG_MEMORY_SIZE + + # parse 1st whir commitment + fs, whir_base_root, whir_base_ood_points, whir_base_ood_evals = parse_whir_commitment_const(fs, NUM_OOD_COMMIT_BASE) + + logup_c = fs_sample_ef(fs) + fs = duplexing(fs) + logup_alpha = fs_sample_ef(fs) + fs = duplexing(fs) + + # GENRIC LOGUP + + fs, quotient_gkr, point_gkr, numerators_value, denominators_value = verify_gkr_quotient(fs, N_VARS_FIRST_GKR) + set_to_5_zeros(quotient_gkr) + + logup_alpha_powers = powers(logup_alpha, MAX_BUS_WIDTH) + + memory_and_acc_prefix = multilinear_location_prefix(0, N_VARS_FIRST_GKR - log_memory, point_gkr) + + fs, value_acc = fs_receive_ef(fs, 1) + fs, value_memory = fs_receive_ef(fs, 1) + + retrieved_numerators_value: Mut = opposite_extension_ret(mul_extension_ret(memory_and_acc_prefix, value_acc)) + + value_index = mle_of_01234567_etc(point_gkr + (N_VARS_FIRST_GKR - log_memory) * DIM, log_memory) + fingerprint_memory = fingerprint_2(MEMORY_TABLE_INDEX, value_index, value_memory, logup_alpha_powers) + retrieved_denominators_value: Mut = mul_extension_ret( + memory_and_acc_prefix, sub_extension_ret(logup_c, fingerprint_memory) + ) + + offset: Mut = powers_of_two(log_memory) + bus_numerators_values = DynArray([]) + bus_denominators_values = DynArray([]) + pcs_points = DynArray([]) # [[_; N]; N_TABLES] + for i in unroll(0, N_TABLES): + pcs_points.push(DynArray([])) + pcs_values = DynArray([]) # [[[[] or [_]; num cols]; N]; N_TABLES] + for i in unroll(0, N_TABLES): + pcs_values.push(DynArray([])) + for table_index in unroll(0, N_TABLES): + # I] Bus (data flow between tables) + + log_n_rows = table_dims[table_index] + n_rows = powers_of_two(log_n_rows) + inner_point = point_gkr + (N_VARS_FIRST_GKR - log_n_rows) * DIM + pcs_points[table_index].push(inner_point) + + prefix = multilinear_location_prefix(offset / n_rows, N_VARS_FIRST_GKR - log_n_rows, point_gkr) + + fs, eval_on_selector = fs_receive_ef(fs, 1) + retrieved_numerators_value = add_extension_ret( + retrieved_numerators_value, mul_extension_ret(prefix, eval_on_selector) + ) + + fs, eval_on_data = fs_receive_ef(fs, 1) + retrieved_denominators_value = add_extension_ret( + retrieved_denominators_value, mul_extension_ret(prefix, eval_on_data) + ) + + bus_numerators_values.push(eval_on_selector) + + bus_denominators_values.push(eval_on_data) + + offset += n_rows + + # II] Lookup into memory + + pcs_values[table_index].push(DynArray([])) + total_num_cols = NUM_COLS_F_AIR[table_index] + DIM * NUM_COLS_EF_AIR[table_index] + for col in unroll(0, total_num_cols): + pcs_values[table_index][0].push(DynArray([])) + + for lookup_f_index in unroll(0, len(LOOKUPS_F_INDEXES[table_index])): + col_index = LOOKUPS_F_INDEXES[table_index][lookup_f_index] + fs, index_eval = fs_receive_ef(fs, 1) + debug_assert(len(pcs_values[table_index][0][col_index]) == 0) + pcs_values[table_index][0][col_index].push(index_eval) + for i in unroll(0, len(LOOKUPS_F_VALUES[table_index][lookup_f_index])): + fs, value_eval = fs_receive_ef(fs, 1) + col_index = LOOKUPS_F_VALUES[table_index][lookup_f_index][i] + debug_assert(len(pcs_values[table_index][0][col_index]) == 0) + pcs_values[table_index][0][col_index].push(value_eval) + + pref = multilinear_location_prefix( + offset / n_rows, N_VARS_FIRST_GKR - log_n_rows, point_gkr + ) # TODO there is some duplication here + retrieved_numerators_value = add_extension_ret(retrieved_numerators_value, pref) + fingerp = fingerprint_2( + MEMORY_TABLE_INDEX, + add_base_extension_ret(i, index_eval), + value_eval, + logup_alpha_powers, + ) + retrieved_denominators_value = add_extension_ret( + retrieved_denominators_value, + mul_extension_ret(pref, sub_extension_ret(logup_c, fingerp)), + ) + + offset += n_rows + + for lookup_ef_index in unroll(0, len(LOOKUPS_EF_INDEXES[table_index])): + col_index = LOOKUPS_EF_INDEXES[table_index][lookup_ef_index] + fs, index_eval = fs_receive_ef(fs, 1) + if len(pcs_values[table_index][0][col_index]) == 0: + pcs_values[table_index][0][col_index].push(index_eval) + else: + # assert equal + copy_5(index_eval, pcs_values[table_index][0][col_index][0]) + + for i in unroll(0, DIM): + fs, value_eval = fs_receive_ef(fs, 1) + pref = multilinear_location_prefix( + offset / n_rows, N_VARS_FIRST_GKR - log_n_rows, point_gkr + ) # TODO there is some duplication here + retrieved_numerators_value = add_extension_ret(retrieved_numerators_value, pref) + fingerp = fingerprint_2( + MEMORY_TABLE_INDEX, + add_base_extension_ret(i, index_eval), + value_eval, + logup_alpha_powers, + ) + retrieved_denominators_value = add_extension_ret( + retrieved_denominators_value, + mul_extension_ret(pref, sub_extension_ret(logup_c, fingerp)), + ) + + global_index = ( + NUM_COLS_F_COMMITED[table_index] + LOOKUPS_EF_VALUES[table_index][lookup_ef_index] * DIM + i + ) + debug_assert(len(pcs_values[table_index][0][global_index]) == 0) + pcs_values[table_index][0][global_index].push(value_eval) + + offset += n_rows + + retrieved_denominators_value = add_extension_ret( + retrieved_denominators_value, + mle_of_zeros_then_ones(point_gkr, offset, N_VARS_FIRST_GKR), + ) + + copy_5(retrieved_numerators_value, numerators_value) + copy_5(retrieved_denominators_value, denominators_value) + + memory_acc_point = point_gkr + (N_VARS_FIRST_GKR - log_memory) * DIM + + # END OF GENERIC LOGUP + + # VERIFY BUS AND AIR + + bus_beta = fs_sample_ef(fs) + fs = duplexing(fs) + + air_alpha = fs_sample_ef(fs) + air_alpha_powers = powers_const(air_alpha, MAX_NUM_AIR_CONSTRAINTS + 1) + + for table_index in unroll(0, N_TABLES): + log_n_rows = table_dims[table_index] + bus_numerator_value = bus_numerators_values[table_index] + bus_denominator_value = bus_denominators_values[table_index] + total_num_cols = NUM_COLS_F_AIR[table_index] + DIM * NUM_COLS_EF_AIR[table_index] + + bus_final_value: Mut = bus_numerator_value + if table_index != EXECUTION_TABLE_INDEX - 1: # -1 because shift due to memory + bus_final_value = opposite_extension_ret(bus_final_value) + bus_final_value = add_extension_ret( + bus_final_value, + mul_extension_ret(bus_beta, sub_extension_ret(bus_denominator_value, logup_c)), + ) + + zerocheck_challenges = pcs_points[table_index][0] + + fs, outer_point, outer_eval = sumcheck_verify(fs, log_n_rows, bus_final_value, AIR_DEGREES[table_index] + 1) + + n_up_columns_f = N_AIR_COLUMNS_F[table_index] + n_up_columns_ef = N_AIR_COLUMNS_EF[table_index] + n_down_columns_f = len(AIR_DOWN_COLUMNS_F[table_index]) + n_down_columns_ef = len(AIR_DOWN_COLUMNS_EF[table_index]) + n_up_columns = n_up_columns_f + n_up_columns_ef + n_down_columns = n_down_columns_f + n_down_columns_ef + fs, inner_evals = fs_receive_ef(fs, n_up_columns + n_down_columns) + + air_constraints_eval = evaluate_air_constraints( + table_index, inner_evals, air_alpha_powers, bus_beta, logup_alpha_powers + ) + expected_outer_eval = mul_extension_ret( + air_constraints_eval, + eq_mle_extension(zerocheck_challenges, outer_point, log_n_rows), + ) + copy_5(expected_outer_eval, outer_eval) + + if len(AIR_DOWN_COLUMNS_F[table_index]) != 0: + batching_scalar = fs_sample_ef(fs) + batching_scalar_powers = powers_const(batching_scalar, n_down_columns) + evals_down_f = inner_evals + n_up_columns_f * DIM + evals_down_ef = inner_evals + (n_up_columns_f + n_down_columns_f + n_up_columns_ef) * DIM + inner_sum: Mut = dot_product_ret(evals_down_f, batching_scalar_powers, n_down_columns_f, EE) + if n_down_columns_ef != 0: + inner_sum = add_extension_ret( + inner_sum, + dot_product_ret( + evals_down_ef, + batching_scalar_powers + n_down_columns_f, + n_down_columns_ef, + EE, + ), + ) + + fs, inner_point, inner_value = sumcheck_verify(fs, log_n_rows, inner_sum, 2) + + matrix_down_sc_eval = next_mle(outer_point, inner_point, log_n_rows) + + fs, evals_f_on_down_columns = fs_receive_ef(fs, n_down_columns_f) + batched_col_down_sc_eval: Mut = dot_product_ret( + evals_f_on_down_columns, batching_scalar_powers, n_down_columns_f, EE + ) + + evals_ef_on_down_columns: Imu + if n_down_columns_ef != 0: + fs, evals_ef_on_down_columns = fs_receive_ef(fs, n_down_columns_ef) + batched_col_down_sc_eval = add_extension_ret( + batched_col_down_sc_eval, + dot_product_ret( + evals_ef_on_down_columns, + batching_scalar_powers + n_down_columns_f, + n_down_columns_ef, + EE, + ), + ) + + copy_5( + inner_value, + mul_extension_ret(batched_col_down_sc_eval, matrix_down_sc_eval), + ) + + pcs_points[table_index].push(inner_point) + pcs_values[table_index].push(DynArray([])) + last_index = len(pcs_values[table_index]) - 1 + for _ in unroll(0, total_num_cols): + pcs_values[table_index][last_index].push(DynArray([])) + for i in unroll(0, n_down_columns_f): + pcs_values[table_index][last_index][AIR_DOWN_COLUMNS_F[table_index][i]].push( + evals_f_on_down_columns + i * DIM + ) + for i in unroll(0, n_down_columns_ef): + fs, transposed = fs_receive_ef(fs, DIM) + copy_5( + evals_ef_on_down_columns + i * DIM, + dot_product_with_the_base_vectors(transposed), + ) + for j in unroll(0, DIM): + virtual_col_index = n_up_columns_f + AIR_DOWN_COLUMNS_EF[table_index][i] * DIM + j + pcs_values[table_index][last_index][virtual_col_index].push(transposed + j * DIM) + + pcs_points[table_index].push(outer_point) + pcs_values[table_index].push(DynArray([])) + last_index_2 = len(pcs_values[table_index]) - 1 + for _ in unroll(0, total_num_cols): + pcs_values[table_index][last_index_2].push(DynArray([])) + for i in unroll(0, n_up_columns_f): + pcs_values[table_index][last_index_2][i].push(inner_evals + i * DIM) + + for i in unroll(0, n_up_columns_ef): + fs, transposed = fs_receive_ef(fs, DIM) + copy_5( + inner_evals + (n_up_columns_f + n_down_columns_f + i) * DIM, + dot_product_with_the_base_vectors(transposed), + ) + for j in unroll(0, DIM): + virtual_col_index = n_up_columns_f + i * DIM + j + pcs_values[table_index][last_index_2][virtual_col_index].push(transposed + j * DIM) + + log_num_instrs = log2_ceil(NUM_BYTECODE_INSTRUCTIONS) + bytecode_compression_challenges = Array(DIM * log_num_instrs) + for i in unroll(0, log_num_instrs): + copy_5(fs_sample_ef(fs), bytecode_compression_challenges + i * DIM) # TODO avoid duplication + if i != log_num_instrs - 1: + fs = duplexing(fs) + + bytecode_air_values = Array(DIM * 2**log_num_instrs) + for i in unroll(0, NUM_BYTECODE_INSTRUCTIONS): + col = N_COMMITTED_EXEC_COLUMNS + i + copy_5( + pcs_values[EXECUTION_TABLE_INDEX - 1][2][col][0], + bytecode_air_values + i * DIM, + ) + pcs_values[EXECUTION_TABLE_INDEX - 1][2][col].pop() + for i in unroll(NUM_BYTECODE_INSTRUCTIONS, 2**log_num_instrs): + set_to_5_zeros(bytecode_air_values + i * DIM) + + bytecode_air_point = pcs_points[EXECUTION_TABLE_INDEX - 1][2] + bytecode_lookup_claim = dot_product_ret( + bytecode_air_values, + poly_eq_extension(bytecode_compression_challenges, log_num_instrs), + 2**log_num_instrs, + EE, + ) + + fs, whir_ext_root, whir_ext_ood_points, whir_ext_ood_evals = parse_whir_commitment_const(fs, NUM_OOD_COMMIT_EXT) + + # VERIFY LOGUP* + + log_table_len = log2_ceil(GUEST_BYTECODE_LEN) + log_n_cycles = table_dims[EXECUTION_TABLE_INDEX - 1] + fs, ls_sumcheck_point, ls_sumcheck_value = sumcheck_verify(fs, log_table_len, bytecode_lookup_claim, 2) + fs, table_eval = fs_receive_ef(fs, 1) + fs, pushforward_eval = fs_receive_ef(fs, 1) + mul_extension(table_eval, pushforward_eval, ls_sumcheck_value) + + ls_c = fs_sample_ef(fs) + + fs, quotient_left, claim_point_left, claim_num_left, eval_c_minus_indexes = verify_gkr_quotient(fs, log_n_cycles) + fs, quotient_right, claim_point_right, pushforward_final_eval, claim_den_right = verify_gkr_quotient( + fs, log_table_len + ) + + copy_5(quotient_left, quotient_right) + + copy_5( + eq_mle_extension(claim_point_left, bytecode_air_point, log_n_cycles), + claim_num_left, + ) + copy_5( + sub_extension_ret(ls_c, mle_of_01234567_etc(claim_point_right, log_table_len)), + claim_den_right, + ) + + # logupstar statements: + ls_on_indexes_point = claim_point_left + ls_on_indexes_eval = sub_extension_ret(ls_c, eval_c_minus_indexes) + ls_on_table_point = ls_sumcheck_point + ls_on_table_eval = table_eval + ls_on_pushforward_point_1 = ls_sumcheck_point + ls_on_pushforward_eval_1 = pushforward_eval + ls_on_pushforward_point_2 = claim_point_right + ls_on_pushforward_eval_2 = pushforward_final_eval + + # TODO evaluate the folded bytecode + + pcs_points[EXECUTION_TABLE_INDEX - 1].push(ls_on_indexes_point) + pcs_values[EXECUTION_TABLE_INDEX - 1].push(DynArray([])) + last_len = len(pcs_values[EXECUTION_TABLE_INDEX - 1]) - 1 + total_exec_cols = NUM_COLS_F_AIR[EXECUTION_TABLE_INDEX - 1] + DIM * NUM_COLS_EF_AIR[EXECUTION_TABLE_INDEX - 1] + for _ in unroll(0, total_exec_cols): + pcs_values[EXECUTION_TABLE_INDEX - 1][last_len].push(DynArray([])) + pcs_values[EXECUTION_TABLE_INDEX - 1][last_len][COL_PC].push(ls_on_indexes_eval) + + # verify the outer public memory is well constructed (with the conventions) + for i in unroll(0, next_multiple_of(NONRESERVED_PROGRAM_INPUT_START, DIM) / DIM): + copy_5(i * DIM, outer_public_memory + i * DIM) + + public_memory_random_point = Array(outer_public_memory_log_size * DIM) + for i in range(0, outer_public_memory_log_size): + copy_5(fs_sample_ef(fs), public_memory_random_point + i * DIM) + fs = duplexing(fs) + poly_eq_public_mem = poly_eq_extension_dynamic(public_memory_random_point, outer_public_memory_log_size) + public_memory_eval = Array(DIM) + dot_product_be_dynamic( + outer_public_memory, + poly_eq_public_mem, + public_memory_eval, + powers_of_two(outer_public_memory_log_size), + ) + + # WHIR BASE + combination_randomness_gen: Mut = fs_sample_ef(fs) + combination_randomness_powers: Mut = powers_const( + combination_randomness_gen, NUM_OOD_COMMIT_BASE + TOTAL_WHIR_STATEMENTS_BASE + ) + whir_sum: Mut = dot_product_ret(whir_base_ood_evals, combination_randomness_powers, NUM_OOD_COMMIT_BASE, EE) + curr_randomness: Mut = combination_randomness_powers + NUM_OOD_COMMIT_BASE * DIM + + whir_sum = add_extension_ret(mul_extension_ret(value_memory, curr_randomness), whir_sum) + curr_randomness += DIM + whir_sum = add_extension_ret(mul_extension_ret(value_acc, curr_randomness), whir_sum) + curr_randomness += DIM + whir_sum = add_extension_ret(mul_extension_ret(public_memory_eval, curr_randomness), whir_sum) + curr_randomness += DIM + + whir_sum = add_extension_ret(mul_extension_ret(embed_in_ef(STARTING_PC), curr_randomness), whir_sum) + curr_randomness += DIM + whir_sum = add_extension_ret(mul_extension_ret(embed_in_ef(ENDING_PC), curr_randomness), whir_sum) + curr_randomness += DIM + + for table_index in unroll(0, N_TABLES): + debug_assert(len(pcs_points[table_index]) == len(pcs_values[table_index])) + for i in unroll(0, len(pcs_values[table_index])): + for j in unroll(0, len(pcs_values[table_index][i])): + debug_assert(len(pcs_values[table_index][i][j]) < 2) + if len(pcs_values[table_index][i][j]) == 1: + whir_sum = add_extension_ret( + mul_extension_ret(pcs_values[table_index][i][j][0], curr_randomness), + whir_sum, + ) + curr_randomness += DIM + + folding_randomness_global: Mut + s: Mut + final_value: Mut + end_sum: Mut + fs, folding_randomness_global, s, final_value, end_sum = whir_open_base( + fs, + whir_base_root, + whir_base_ood_points, + combination_randomness_powers, + whir_sum, + ) + + curr_randomness = combination_randomness_powers + NUM_OOD_COMMIT_BASE * DIM + + eq_memory_acc_point = eq_mle_extension( + folding_randomness_global + (N_VARS_BASE - log_memory) * DIM, + memory_acc_point, + log_memory, + ) + prefix_mem = multilinear_location_prefix(0, N_VARS_BASE - log_memory, folding_randomness_global) + s = add_extension_ret( + s, + mul_extension_ret(mul_extension_ret(curr_randomness, prefix_mem), eq_memory_acc_point), + ) + curr_randomness += DIM + + prefix_acc = multilinear_location_prefix(1, N_VARS_BASE - log_memory, folding_randomness_global) + s = add_extension_ret( + s, + mul_extension_ret(mul_extension_ret(curr_randomness, prefix_acc), eq_memory_acc_point), + ) + curr_randomness += DIM + + eq_pub_mem = eq_mle_extension( + folding_randomness_global + (N_VARS_BASE - outer_public_memory_log_size) * DIM, + public_memory_random_point, + outer_public_memory_log_size, + ) + prefix_pub_mem = multilinear_location_prefix( + 0, N_VARS_BASE - outer_public_memory_log_size, folding_randomness_global + ) + s = add_extension_ret( + s, + mul_extension_ret(mul_extension_ret(curr_randomness, prefix_pub_mem), eq_pub_mem), + ) + curr_randomness += DIM + + offset = powers_of_two(log_memory) * 2 # memory and acc + + prefix_pc_start = multilinear_location_prefix( + offset + COL_PC * powers_of_two(log_n_cycles), + N_VARS_BASE, + folding_randomness_global, + ) + s = add_extension_ret(s, mul_extension_ret(curr_randomness, prefix_pc_start)) + curr_randomness += DIM + + prefix_pc_end = multilinear_location_prefix( + offset + (COL_PC + 1) * powers_of_two(log_n_cycles) - 1, + N_VARS_BASE, + folding_randomness_global, + ) + s = add_extension_ret(s, mul_extension_ret(curr_randomness, prefix_pc_end)) + curr_randomness += DIM + + for table_index in unroll(0, N_TABLES): + log_n_rows = table_dims[table_index] + n_rows = powers_of_two(log_n_rows) + total_num_cols = NUM_COLS_F_AIR[table_index] + DIM * NUM_COLS_EF_AIR[table_index] + for i in unroll(0, len(pcs_points[table_index])): + point = pcs_points[table_index][i] + eq_factor = eq_mle_extension( + point, + folding_randomness_global + (N_VARS_BASE - log_n_rows) * DIM, + log_n_rows, + ) + for j in unroll(0, total_num_cols): + if len(pcs_values[table_index][i][j]) == 1: + prefix = multilinear_location_prefix( + offset / n_rows + j, + N_VARS_BASE - log_n_rows, + folding_randomness_global, + ) + s = add_extension_ret( + s, + mul_extension_ret(mul_extension_ret(curr_randomness, prefix), eq_factor), + ) + curr_randomness += DIM + num_commited_cols: Imu + if table_index == EXECUTION_TABLE_INDEX - 1: + num_commited_cols = N_COMMITTED_EXEC_COLUMNS + else: + num_commited_cols = total_num_cols + offset += n_rows * num_commited_cols + + copy_5(mul_extension_ret(s, final_value), end_sum) + + # WHIR EXT (Pushforward) + combination_randomness_gen = fs_sample_ef(fs) + combination_randomness_powers = powers_const(combination_randomness_gen, NUM_OOD_COMMIT_EXT + 2) + whir_sum = dot_product_ret(whir_ext_ood_evals, combination_randomness_powers, NUM_OOD_COMMIT_EXT, EE) + whir_sum = add_extension_ret( + whir_sum, + mul_extension_ret( + combination_randomness_powers + NUM_OOD_COMMIT_EXT * DIM, + ls_on_pushforward_eval_1, + ), + ) + whir_sum = add_extension_ret( + whir_sum, + mul_extension_ret( + combination_randomness_powers + (NUM_OOD_COMMIT_EXT + 1) * DIM, + ls_on_pushforward_eval_2, + ), + ) + fs, folding_randomness_global, s, final_value, end_sum = whir_open_ext( + fs, whir_ext_root, whir_ext_ood_points, combination_randomness_powers, whir_sum + ) + + # Last TODO = Opening on the guest bytecode, but there are multiple ways to handle this + + return + + +def multilinear_location_prefix(offset, n_vars, point): + bits = checked_decompose_bits_small_value(offset, n_vars) + res = eq_mle_base_extension(bits, point, n_vars) + return res + + +def fingerprint_2(table_index, data_1, data_2, alpha_powers): + buff = Array(DIM * 2) + copy_5(data_1, buff) + copy_5(data_2, buff + DIM) + res: Mut = dot_product_ret(buff, alpha_powers + DIM, 2, EE) + res = add_base_extension_ret(table_index, res) + return res + + +def verify_gkr_quotient(fs: Mut, n_vars): + fs, nums = fs_receive_ef(fs, 2) + fs, denoms = fs_receive_ef(fs, 2) + + q1 = div_extension_ret(nums, denoms) + q2 = div_extension_ret(nums + DIM, denoms + DIM) + quotient = add_extension_ret(q1, q2) + + points = Array(n_vars) + claims_num = Array(n_vars) + claims_den = Array(n_vars) + + points[0] = fs_sample_ef(fs) + fs = duplexing(fs) + + point_poly_eq = poly_eq_extension(points[0], 1) + + first_claim_num = dot_product_ret(nums, point_poly_eq, 2, EE) + first_claim_den = dot_product_ret(denoms, point_poly_eq, 2, EE) + claims_num[0] = first_claim_num + claims_den[0] = first_claim_den + + for i in range(1, n_vars): + fs, points[i], claims_num[i], claims_den[i] = verify_gkr_quotient_step( + fs, i, points[i - 1], claims_num[i - 1], claims_den[i - 1] + ) + + return ( + fs, + quotient, + points[n_vars - 1], + claims_num[n_vars - 1], + claims_den[n_vars - 1], + ) + + +def verify_gkr_quotient_step(fs: Mut, n_vars, point, claim_num, claim_den): + alpha = fs_sample_ef(fs) + alpha_mul_claim_den = mul_extension_ret(alpha, claim_den) + num_plus_alpha_mul_claim_den = add_extension_ret(claim_num, alpha_mul_claim_den) + postponed_point = Array((n_vars + 1) * DIM) + fs, postponed_value = sumcheck_verify_helper(fs, n_vars, num_plus_alpha_mul_claim_den, 3, postponed_point + DIM) + fs, inner_evals = fs_receive_ef(fs, 4) + a_num = inner_evals + b_num = inner_evals + DIM + a_den = inner_evals + 2 * DIM + b_den = inner_evals + 3 * DIM + sum_num, sum_den = sum_2_ef_fractions(a_num, a_den, b_num, b_den) + sum_den_mul_alpha = mul_extension_ret(sum_den, alpha) + sum_num_plus_sum_den_mul_alpha = add_extension_ret(sum_num, sum_den_mul_alpha) + eq_factor = eq_mle_extension(point, postponed_point + DIM, n_vars) + mul_extension(sum_num_plus_sum_den_mul_alpha, eq_factor, postponed_value) + + beta = fs_sample_ef(fs) + fs = duplexing(fs) + point_poly_eq = poly_eq_extension(beta, 1) + new_claim_num = dot_product_ret(inner_evals, point_poly_eq, 2, EE) + new_claim_den = dot_product_ret(inner_evals + 2 * DIM, point_poly_eq, 2, EE) + + copy_5(beta, postponed_point) + + return fs, postponed_point, new_claim_num, new_claim_den + + +def evaluate_air_constraints(table_index, inner_evals, air_alpha_powers, bus_beta, bus_alpha_powers): + res: Imu + debug_assert(table_index < 3) + match table_index: + case 0: + res = evaluate_air_constraints_table_0(inner_evals, air_alpha_powers, bus_beta, bus_alpha_powers) + case 1: + res = evaluate_air_constraints_table_1(inner_evals, air_alpha_powers, bus_beta, bus_alpha_powers) + case 2: + res = evaluate_air_constraints_table_2(inner_evals, air_alpha_powers, bus_beta, bus_alpha_powers) + return res + + +EVALUATE_AIR_FUNCTIONS_PLACEHOLDER diff --git a/crates/rec_aggregation/src/lib.rs b/crates/rec_aggregation/src/lib.rs index a4cdd421..1f3d901e 100644 --- a/crates/rec_aggregation/src/lib.rs +++ b/crates/rec_aggregation/src/lib.rs @@ -1,4 +1,4 @@ #![cfg_attr(not(test), allow(unused_crate_dependencies))] -pub mod whir_recursion; +pub mod recursion; pub mod xmss_aggregate; diff --git a/crates/rec_aggregation/src/recursion.rs b/crates/rec_aggregation/src/recursion.rs new file mode 100644 index 00000000..0f4aabdb --- /dev/null +++ b/crates/rec_aggregation/src/recursion.rs @@ -0,0 +1,577 @@ +use std::collections::{BTreeMap, HashMap}; +use std::path::Path; +use std::rc::Rc; +use std::time::Instant; + +use lean_compiler::{CompilationFlags, ProgramSource, compile_program, compile_program_with_flags}; +use lean_prover::prove_execution::prove_execution; +use lean_prover::verify_execution::verify_execution; +use lean_prover::{STARTING_LOG_INV_RATE_BASE, STARTING_LOG_INV_RATE_EXTENSION, SnarkParams, whir_config_builder}; +use lean_vm::*; +use multilinear_toolkit::prelude::symbolic::{ + SymbolicExpression, SymbolicOperation, get_symbolic_constraints_and_bus_data_values, +}; +use multilinear_toolkit::prelude::*; +use utils::{Counter, MEMORY_TABLE_INDEX}; +use whir_p3::{WhirConfig, precompute_dft_twiddles}; + +pub fn run_recursion_benchmark(count: usize, tracing: bool) { + if tracing { + utils::init_tracing(); + } + let filepath = Path::new(env!("CARGO_MANIFEST_DIR")) + .join("recursion.py") + .to_str() + .unwrap() + .to_string(); + + let snark_params = SnarkParams { + first_whir: whir_config_builder(STARTING_LOG_INV_RATE_BASE, 3, 1), + second_whir: whir_config_builder(STARTING_LOG_INV_RATE_EXTENSION, 4, 1), + }; + let program_to_prove = r#" +DIM = 5 +COMPRESSION = 1 +PERMUTATION = 0 +POSEIDON_OF_ZERO = POSEIDON_OF_ZERO_PLACEHOLDER +# Dot product precompile: +BE = 1 # base-extension +EE = 0 # extension-extension + +def main(): + for i in range(0, 1000): + null_ptr = ZERO_VEC_PTR # pointer to zero vector + poseidon_of_zero = POSEIDON_OF_ZERO + poseidon16(null_ptr, null_ptr, poseidon_of_zero, PERMUTATION) + poseidon16(null_ptr, null_ptr, poseidon_of_zero, COMPRESSION) + dot_product(null_ptr, null_ptr, null_ptr, 2, BE) + dot_product(null_ptr, null_ptr, null_ptr, 2, EE) + x: Mut = 0 + n = 10 + for j in range(0, n): + x += j + assert x == n * (n - 1) / 2 + n = 100000 + x = 0 + sum: Mut = x[0] + for i in unroll(0, n): + sum += i + assert sum == n * (n - 1) / 2 + return +"# + .replace("POSEIDON_OF_ZERO_PLACEHOLDER", &POSEIDON_16_NULL_HASH_PTR.to_string()); + let bytecode_to_prove = compile_program(&ProgramSource::Raw(program_to_prove.to_string())); + precompute_dft_twiddles::(1 << 24); + let outer_public_input = vec![]; + let outer_private_input = vec![]; + let proof_to_prove = prove_execution( + &bytecode_to_prove, + (&outer_public_input, &outer_private_input), + &vec![], + &snark_params, + false, + ); + let verif_details = verify_execution(&bytecode_to_prove, &[], proof_to_prove.proof.clone(), &snark_params).unwrap(); + + let base_whir = WhirConfig::::new(&snark_params.first_whir, proof_to_prove.first_whir_n_vars); + let ext_whir = WhirConfig::::new( + &snark_params.second_whir, + log2_ceil_usize(bytecode_to_prove.instructions.len()), + ); + + // let guest_program_commitment = { + // let mut prover_state = build_prover_state(); + // let polynomial = MleOwned::Base(bytecode_to_multilinear_polynomial(&bytecode_to_prove.instructions)); + // let witness = ext_whir.commit(&mut prover_state, &polynomial); + // let commitment_transcript = prover_state.proof().to_vec(); + // assert_eq!(commitment_transcript.len(), ext_whir.committment_ood_samples * DIMENSION + VECTOR_LEN); + // }; + + let mut replacements = whir_recursion_placeholder_replacements(&base_whir, true); + replacements.extend(whir_recursion_placeholder_replacements(&ext_whir, false)); + + assert!( + verif_details.log_memory >= verif_details.table_n_vars[&Table::execution()] + && verif_details + .table_n_vars + .values() + .collect::>() + .windows(2) + .all(|w| w[0] >= w[1]), + "TODO a more general recursion program", + ); + assert_eq!( + verif_details.table_n_vars.keys().copied().collect::>(), + vec![Table::execution(), Table::dot_product(), Table::poseidon16()] + ); + + // VM recursion parameters (different from WHIR) + replacements.insert( + "N_VARS_FIRST_GKR_PLACEHOLDER".to_string(), + verif_details.first_quotient_gkr_n_vars.to_string(), + ); + replacements.insert("N_TABLES_PLACEHOLDER".to_string(), N_TABLES.to_string()); + replacements.insert( + "MIN_LOG_N_ROWS_PER_TABLE_PLACEHOLDER".to_string(), + MIN_LOG_N_ROWS_PER_TABLE.to_string(), + ); + let mut max_log_n_rows_per_table = MAX_LOG_N_ROWS_PER_TABLE.to_vec(); + max_log_n_rows_per_table.sort_by_key(|(table, _)| table.index()); + max_log_n_rows_per_table.dedup(); + assert_eq!(max_log_n_rows_per_table.len(), N_TABLES); + replacements.insert( + "MAX_LOG_N_ROWS_PER_TABLE_PLACEHOLDER".to_string(), + format!( + "[{}]", + max_log_n_rows_per_table + .iter() + .map(|(_, v)| v.to_string()) + .collect::>() + .join(", ") + ), + ); + replacements.insert( + "MIN_LOG_MEMORY_SIZE_PLACEHOLDER".to_string(), + MIN_LOG_MEMORY_SIZE.to_string(), + ); + replacements.insert( + "MAX_LOG_MEMORY_SIZE_PLACEHOLDER".to_string(), + MAX_LOG_MEMORY_SIZE.to_string(), + ); + replacements.insert("MAX_BUS_WIDTH_PLACEHOLDER".to_string(), max_bus_width().to_string()); + replacements.insert( + "MEMORY_TABLE_INDEX_PLACEHOLDER".to_string(), + MEMORY_TABLE_INDEX.to_string(), + ); + replacements.insert( + "GUEST_BYTECODE_LEN_PLACEHOLDER".to_string(), + bytecode_to_prove.instructions.len().to_string(), + ); + replacements.insert("COL_PC_PLACEHOLDER".to_string(), COL_PC.to_string()); + replacements.insert( + "NONRESERVED_PROGRAM_INPUT_START_PLACEHOLDER".to_string(), + NONRESERVED_PROGRAM_INPUT_START.to_string(), + ); + + let mut lookup_f_indexes_str = vec![]; + let mut lookup_f_values_str = vec![]; + let mut lookup_ef_indexes_str = vec![]; + let mut lookup_ef_values_str = vec![]; + let mut num_cols_f_air = vec![]; + let mut num_cols_ef_air = vec![]; + let mut num_cols_f_committed = vec![]; + let mut air_degrees = vec![]; + let mut n_air_columns_f = vec![]; + let mut n_air_columns_ef = vec![]; + let mut air_down_columns_f = vec![]; + let mut air_down_columns_ef = vec![]; + for table in ALL_TABLES { + let this_look_f_indexes_str = table + .lookups_f() + .iter() + .map(|lookup_f| lookup_f.index.to_string()) + .collect::>(); + let this_look_ef_indexes_str = table + .lookups_ef() + .iter() + .map(|lookup_ef| lookup_ef.index.to_string()) + .collect::>(); + lookup_f_indexes_str.push(format!("[{}]", this_look_f_indexes_str.join(", "))); + lookup_ef_indexes_str.push(format!("[{}]", this_look_ef_indexes_str.join(", "))); + num_cols_f_air.push(table.n_columns_f_air().to_string()); + num_cols_ef_air.push(table.n_columns_ef_air().to_string()); + num_cols_f_committed.push(table.n_commited_columns_f().to_string()); + let this_lookup_f_values_str = table + .lookups_f() + .iter() + .map(|lookup_f| { + format!( + "[{}]", + lookup_f + .values + .iter() + .map(|v| v.to_string()) + .collect::>() + .join(", ") + ) + }) + .collect::>(); + let this_lookup_ef_values_str = table + .lookups_ef() + .iter() + .map(|lookup_ef| lookup_ef.values.to_string()) + .collect::>(); + lookup_f_values_str.push(format!("[{}]", this_lookup_f_values_str.join(", "))); + lookup_ef_values_str.push(format!("[{}]", this_lookup_ef_values_str.join(", "))); + air_degrees.push(table.degree_air().to_string()); + n_air_columns_f.push(table.n_columns_f_air().to_string()); + n_air_columns_ef.push(table.n_columns_ef_air().to_string()); + air_down_columns_f.push(format!( + "[{}]", + table + .down_column_indexes_f() + .iter() + .map(|v| v.to_string()) + .collect::>() + .join(", ") + )); + air_down_columns_ef.push(format!( + "[{}]", + table + .down_column_indexes_ef() + .iter() + .map(|v| v.to_string()) + .collect::>() + .join(", ") + )); + } + replacements.insert( + "LOOKUPS_F_INDEXES_PLACEHOLDER".to_string(), + format!("[{}]", lookup_f_indexes_str.join(", ")), + ); + replacements.insert( + "LOOKUPS_F_VALUES_PLACEHOLDER".to_string(), + format!("[{}]", lookup_f_values_str.join(", ")), + ); + replacements.insert( + "NUM_COLS_F_AIR_PLACEHOLDER".to_string(), + format!("[{}]", num_cols_f_air.join(", ")), + ); + replacements.insert( + "NUM_COLS_EF_AIR_PLACEHOLDER".to_string(), + format!("[{}]", num_cols_ef_air.join(", ")), + ); + replacements.insert( + "NUM_COLS_F_COMMITED_PLACEHOLDER".to_string(), + format!("[{}]", num_cols_f_committed.join(", ")), + ); + replacements.insert( + "LOOKUPS_EF_INDEXES_PLACEHOLDER".to_string(), + format!("[{}]", lookup_ef_indexes_str.join(", ")), + ); + replacements.insert( + "LOOKUPS_EF_VALUES_PLACEHOLDER".to_string(), + format!("[{}]", lookup_ef_values_str.join(", ")), + ); + replacements.insert( + "EXECUTION_TABLE_INDEX_PLACEHOLDER".to_string(), + Table::execution().index().to_string(), + ); + replacements.insert( + "MAX_NUM_AIR_CONSTRAINTS_PLACEHOLDER".to_string(), + max_air_constraints().to_string(), + ); + replacements.insert( + "AIR_DEGREES_PLACEHOLDER".to_string(), + format!("[{}]", air_degrees.join(", ")), + ); + replacements.insert( + "N_AIR_COLUMNS_F_PLACEHOLDER".to_string(), + format!("[{}]", n_air_columns_f.join(", ")), + ); + replacements.insert( + "N_AIR_COLUMNS_EF_PLACEHOLDER".to_string(), + format!("[{}]", n_air_columns_ef.join(", ")), + ); + replacements.insert( + "AIR_DOWN_COLUMNS_F_PLACEHOLDER".to_string(), + format!("[{}]", air_down_columns_f.join(", ")), + ); + replacements.insert( + "AIR_DOWN_COLUMNS_EF_PLACEHOLDER".to_string(), + format!("[{}]", air_down_columns_ef.join(", ")), + ); + replacements.insert( + "EVALUATE_AIR_FUNCTIONS_PLACEHOLDER".to_string(), + all_air_evals_in_zk_dsl(), + ); + replacements.insert( + "NUM_BYTECODE_INSTRUCTIONS_PLACEHOLDER".to_string(), + N_INSTRUCTION_COLUMNS.to_string(), + ); + replacements.insert( + "N_COMMITTED_EXEC_COLUMNS_PLACEHOLDER".to_string(), + N_COMMITTED_EXEC_COLUMNS.to_string(), + ); + replacements.insert( + "TOTAL_WHIR_STATEMENTS_BASE_PLACEHOLDER".to_string(), + verif_details.total_whir_statements_base.to_string(), + ); + replacements.insert("STARTING_PC_PLACEHOLDER".to_string(), STARTING_PC.to_string()); + replacements.insert("ENDING_PC_PLACEHOLDER".to_string(), ENDING_PC.to_string()); + + let inner_public_input = vec![]; + let outer_public_memory = build_public_memory(&outer_public_input); + let mut inner_private_input = vec![ + F::from_usize(proof_to_prove.proof.len()), + F::from_usize(log2_strict_usize(outer_public_memory.len())), + F::from_usize(count), + ]; + inner_private_input.extend(outer_public_memory); + for _ in 0..count { + inner_private_input.extend(proof_to_prove.proof.to_vec()); + } + + let recursion_bytecode = + compile_program_with_flags(&ProgramSource::Filepath(filepath), CompilationFlags { replacements }); + + let time = Instant::now(); + + let recursion_proof = prove_execution( + &recursion_bytecode, + (&inner_public_input, &inner_private_input), + &vec![], // TODO precompute poseidons + &Default::default(), + false, + ); + let proving_time = time.elapsed(); + verify_execution( + &recursion_bytecode, + &inner_public_input, + recursion_proof.proof, + &Default::default(), + ) + .unwrap(); + println!( + "(Outer proof: 2**{} memory, 2**{} cycles, 2**{} dot_product_rows, 2**{} poseidons)", + verif_details.log_memory, + verif_details.table_n_vars[&Table::execution()], + verif_details.table_n_vars[&Table::dot_product()], + verif_details.table_n_vars[&Table::poseidon16()], + ); + println!("{}", recursion_proof.exec_summary); + println!( + "{}->1 recursion proving time: {} ms (1->1: {} ms), proof size: {} KiB (not optimized)", + count, + proving_time.as_millis(), + proving_time.as_millis() / count as u128, + recursion_proof.proof_size_fe * F::bits() / (8 * 1024) + ); +} + +pub(crate) fn whir_recursion_placeholder_replacements( + whir_config: &WhirConfig, + base: bool, +) -> BTreeMap { + let mut num_queries = vec![]; + let mut ood_samples = vec![]; + let mut grinding_bits = vec![]; + let merkle_heights = (0..=whir_config.n_rounds()) + .map(|r| whir_config.merkle_tree_height(r).to_string()) + .collect::>(); + let mut folding_factors = vec![]; + for round in &whir_config.round_parameters { + num_queries.push(round.num_queries.to_string()); + ood_samples.push(round.ood_samples.to_string()); + grinding_bits.push(round.pow_bits.to_string()); + folding_factors.push(round.folding_factor.to_string()); + } + folding_factors.push(whir_config.final_round_config().folding_factor.to_string()); + grinding_bits.push(whir_config.final_pow_bits.to_string()); + num_queries.push(whir_config.final_queries.to_string()); + + let end = if base { "_BASE_PLACEHOLDER" } else { "_EXT_PLACEHOLDER" }; + let mut replacements = BTreeMap::new(); + replacements.insert( + format!("MERKLE_HEIGHTS{}", end), + format!("[{}]", merkle_heights.join(", ")), + ); + replacements.insert(format!("NUM_QUERIES{}", end), format!("[{}]", num_queries.join(", "))); + replacements.insert( + format!("NUM_OOD_COMMIT{}", end), + whir_config.committment_ood_samples.to_string(), + ); + replacements.insert(format!("NUM_OODS{}", end), format!("[{}]", ood_samples.join(", "))); + replacements.insert( + format!("GRINDING_BITS{}", end), + format!("[{}]", grinding_bits.join(", ")), + ); + replacements.insert( + format!("FOLDING_FACTORS{}", end), + format!("[{}]", folding_factors.join(", ")), + ); + replacements.insert(format!("N_VARS{}", end), whir_config.num_variables.to_string()); + replacements.insert( + format!("LOG_INV_RATE{}", end), + whir_config.starting_log_inv_rate.to_string(), + ); + replacements.insert( + format!("FINAL_VARS{}", end), + whir_config.n_vars_of_final_polynomial().to_string(), + ); + replacements.insert( + format!("FIRST_RS_REDUCTION_FACTOR{}", end), + whir_config.rs_domain_initial_reduction_factor.to_string(), + ); + replacements +} + +fn all_air_evals_in_zk_dsl() -> String { + let mut res = String::new(); + res += &air_eval_in_zk_dsl(ExecutionTable:: {}); + res += &air_eval_in_zk_dsl(DotProductPrecompile:: {}); + res += &air_eval_in_zk_dsl(Poseidon16Precompile:: {}); + res +} + +const AIR_INNER_VALUES_VAR: &str = "inner_evals"; + +fn air_eval_in_zk_dsl(table: T) -> String +where + T::ExtraData: Default, +{ + let (constraints, bus_data) = get_symbolic_constraints_and_bus_data_values::(&table); + let mut vars_counter = Counter::new(); + let mut cache: HashMap<*const (), String> = HashMap::new(); + + let mut res = format!( + "def evaluate_air_constraints_table_{}({}, air_alpha_powers, bus_beta, bus_alpha_powers):\n", + table.table().index() - 1, + AIR_INNER_VALUES_VAR + ); + + let mut constraints_evals = vec![]; + for constraint in &constraints { + constraints_evals.push(write_down_air_constraint_eval( + constraint, + &mut cache, + &mut res, + &mut vars_counter, + )); + } + + // first: bus data + let table_index = match table.bus().table { + BusTable::Constant(c) => format!("embed_in_ef({})", c.index()), + BusTable::Variable(col) => format!("{} + DIM * {}", AIR_INNER_VALUES_VAR, col), + }; + let flag = format!("{} + DIM * {}", AIR_INNER_VALUES_VAR, table.bus().selector); + res += &format!("\n buff = Array(DIM * {})", bus_data.len()); + for (i, data) in bus_data.iter().enumerate() { + let data_str = write_down_air_constraint_eval(data, &mut cache, &mut res, &mut vars_counter); + res += &format!("\n copy_5({}, buff + DIM * {})", data_str, i); + } + res += &format!( + "\n bus_res: Mut = dot_product_ret(buff, bus_alpha_powers + DIM, {}, EE)", + bus_data.len() + ); + res += &format!("\n bus_res = add_extension_ret({}, bus_res)", table_index); + res += "\n bus_res = mul_extension_ret(bus_res, bus_beta)"; + res += &format!("\n sum: Mut = add_extension_ret(bus_res, {})", flag); + + for (index, constraint_eval) in constraints_evals.iter().enumerate() { + res += format!( + "\n sum = add_extension_ret(sum, mul_extension_ret(air_alpha_powers + {} * DIM, {}))", + index + 1, + constraint_eval + ) + .as_str(); + } + + res += "\n return sum"; + res += "\n"; + res +} + +fn write_down_air_constraint_eval( + constraint: &SymbolicExpression, + cache: &mut HashMap<*const (), String>, + res: &mut String, + vars_counter: &mut Counter, +) -> String { + match constraint { + SymbolicExpression::Constant(_) => { + unreachable!() + } + SymbolicExpression::Variable(v) => { + format!("{} + DIM * {}", AIR_INNER_VALUES_VAR, v.index) + } + SymbolicExpression::Operation(operation) => { + let key = Rc::as_ptr(operation) as *const (); + if let Some(var_name) = cache.get(&key) { + return var_name.clone(); + } + let (op, args) = &**operation; + + let new_var = match *op { + SymbolicOperation::Neg => { + assert_eq!(args.len(), 1); + let arg_str = write_down_air_constraint_eval(&args[0], cache, res, vars_counter); + let aux_var = format!("aux_{}", vars_counter.get_next()); + res.push_str(&format!("\n {} = opposite_extension_ret({})", aux_var, arg_str)); + return aux_var; + } + SymbolicOperation::Add => handle_operation_on_two( + args, + cache, + res, + vars_counter, + ("add_base_extension_ret", "add_base_extension_ret", "add_extension_ret"), + true, + ), + SymbolicOperation::Sub => handle_operation_on_two( + args, + cache, + res, + vars_counter, + ("sub_base_extension_ret", "sub_extension_base_ret", "sub_extension_ret"), + false, + ), + SymbolicOperation::Mul => handle_operation_on_two( + args, + cache, + res, + vars_counter, + ("mul_base_extension_ret", "mul_base_extension_ret", "mul_extension_ret"), + true, + ), + }; + assert!(!cache.contains_key(&key)); + cache.insert(key, new_var.clone()); + new_var + } + } +} + +fn handle_operation_on_two( + args: &[SymbolicExpression], + cache: &mut HashMap<*const (), String>, + res: &mut String, + vars_counter: &mut Counter, + (be_func, eb_func, ee_func): (&str, &str, &str), + switch_args: bool, +) -> String { + assert_eq!(args.len(), 2); + if let SymbolicExpression::Constant(c1) = args[0] { + let arg2_str = write_down_air_constraint_eval(&args[1], cache, res, vars_counter); + let aux_var = format!("aux_{}", vars_counter.get_next()); + res.push_str(&format!("\n {} = {}({}, {})", aux_var, be_func, c1, arg2_str)); + return aux_var; + } + if let SymbolicExpression::Constant(c2) = args[1] { + let arg1_str = write_down_air_constraint_eval(&args[0], cache, res, vars_counter); + let aux_var = format!("aux_{}", vars_counter.get_next()); + let (term0, term1) = if switch_args { + (c2.to_string(), arg1_str) + } else { + (arg1_str, c2.to_string()) + }; + res.push_str(&format!("\n {} = {}({}, {})", aux_var, eb_func, term0, term1)); + return aux_var; + } + let arg1_str = write_down_air_constraint_eval(&args[0], cache, res, vars_counter); + let arg2_str = write_down_air_constraint_eval(&args[1], cache, res, vars_counter); + let aux_var = format!("aux_{}", vars_counter.get_next()); + res.push_str(&format!("\n {} = {}({}, {})", aux_var, ee_func, arg1_str, arg2_str)); + aux_var +} + +#[test] +fn display_all_air_evals_in_zk_dsl() { + println!("{}", all_air_evals_in_zk_dsl()); +} + +#[test] +fn test_end2end_recursion() { + run_recursion_benchmark(1, false); +} diff --git a/crates/rec_aggregation/src/whir_recursion.rs b/crates/rec_aggregation/src/whir_recursion.rs deleted file mode 100644 index 3fbcbfee..00000000 --- a/crates/rec_aggregation/src/whir_recursion.rs +++ /dev/null @@ -1,180 +0,0 @@ -use std::collections::BTreeMap; -use std::collections::VecDeque; -use std::path::Path; -use std::time::Instant; - -use lean_compiler::CompilationFlags; -use lean_compiler::ProgramSource; -use lean_compiler::compile_program_with_flags; -use lean_prover::prove_execution::prove_execution; -use lean_prover::verify_execution::verify_execution; -use lean_prover::whir_config_builder; -use lean_vm::*; -use multilinear_toolkit::prelude::*; -use rand::Rng; -use rand::SeedableRng; -use rand::rngs::StdRng; -use utils::build_challenger; -use utils::{build_prover_state, padd_with_zero_to_next_multiple_of}; -use whir_p3::{FoldingFactor, SecurityAssumption, WhirConfig, WhirConfigBuilder, precompute_dft_twiddles}; - -const NUM_VARIABLES: usize = 25; - -pub fn run_whir_recursion_benchmark(tracing: bool, n_recursions: usize) { - let src_file = Path::new(env!("CARGO_MANIFEST_DIR")).join("whir_recursion.snark"); - let filepath = src_file.to_str().unwrap().to_string(); - let recursion_config_builder = WhirConfigBuilder { - max_num_variables_to_send_coeffs: 6, - security_level: 128, - pow_bits: 17, - folding_factor: FoldingFactor::new(7, 4), - soundness_type: SecurityAssumption::CapacityBound, - starting_log_inv_rate: 2, - rs_domain_initial_reduction_factor: 3, - }; - - let mut replacements = BTreeMap::new(); - replacements.insert("N_RECURSIONS_PLACEHOLDER".to_string(), n_recursions.to_string()); - - let mut recursion_config = WhirConfig::::new(recursion_config_builder.clone(), NUM_VARIABLES); - - // TODO remove overriding this - { - recursion_config.committment_ood_samples = 1; - for round in &mut recursion_config.round_parameters { - round.ood_samples = 1; - } - } - - assert_eq!(recursion_config.committment_ood_samples, 1); - // println!("Whir parameters: {}", params.to_string()); - for (i, round) in recursion_config.round_parameters.iter().enumerate() { - replacements.insert(format!("NUM_QUERIES_{i}_PLACEHOLDER"), round.num_queries.to_string()); - replacements.insert(format!("GRINDING_BITS_{i}_PLACEHOLDER"), round.pow_bits.to_string()); - } - - replacements.insert( - format!("NUM_QUERIES_{}_PLACEHOLDER", recursion_config.n_rounds()), - recursion_config.final_queries.to_string(), - ); - replacements.insert( - format!("GRINDING_BITS_{}_PLACEHOLDER", recursion_config.n_rounds()), - recursion_config.final_pow_bits.to_string(), - ); - replacements.insert("N_VARS_PLACEHOLDER".to_string(), NUM_VARIABLES.to_string()); - replacements.insert( - "LOG_INV_RATE_PLACEHOLDER".to_string(), - recursion_config_builder.starting_log_inv_rate.to_string(), - ); - assert_eq!(recursion_config.n_rounds(), 3); // this is hardcoded in the program above - for round in 0..=recursion_config.n_rounds() { - replacements.insert( - format!("FOLDING_FACTOR_{round}_PLACEHOLDER"), - recursion_config_builder.folding_factor.at_round(round).to_string(), - ); - } - replacements.insert( - "RS_REDUCTION_FACTOR_0_PLACEHOLDER".to_string(), - recursion_config_builder.rs_domain_initial_reduction_factor.to_string(), - ); - - let mut rng = StdRng::seed_from_u64(0); - let polynomial = MleOwned::Base((0..1 << NUM_VARIABLES).map(|_| rng.random()).collect::>()); - - let point = MultilinearPoint::((0..NUM_VARIABLES).map(|_| rng.random()).collect()); - - let mut statement = Vec::new(); - let eval = polynomial.evaluate(&point); - statement.push(Evaluation::new(point.clone(), eval)); - - let mut prover_state = build_prover_state(true); - - precompute_dft_twiddles::(1 << 24); - - let witness = recursion_config.commit(&mut prover_state, &polynomial); - recursion_config.prove(&mut prover_state, statement.clone(), witness, &polynomial.by_ref()); - let whir_proof = prover_state.into_proof(); - - { - let mut verifier_state = VerifierState::new(whir_proof.clone(), build_challenger()); - let parsed_commitment = recursion_config.parse_commitment::(&mut verifier_state).unwrap(); - recursion_config - .verify(&mut verifier_state, &parsed_commitment, statement) - .unwrap(); - } - - let commitment_size = 16; - let mut public_input = whir_proof.proof_data[..commitment_size].to_vec(); - public_input.extend(padd_with_zero_to_next_multiple_of( - &point - .iter() - .flat_map(|x| >::as_basis_coefficients_slice(x).to_vec()) - .collect::>(), - VECTOR_LEN, - )); - public_input.extend(padd_with_zero_to_next_multiple_of( - >::as_basis_coefficients_slice(&eval), - VECTOR_LEN, - )); - - public_input.extend(whir_proof.proof_data[commitment_size..].to_vec()); - - assert!(public_input.len().is_multiple_of(VECTOR_LEN)); - replacements.insert( - "WHIR_PROOF_SIZE_PLACEHOLDER".to_string(), - (public_input.len() / VECTOR_LEN).to_string(), - ); - - public_input = std::iter::repeat_n(public_input, n_recursions).flatten().collect(); - - if tracing { - utils::init_tracing(); - } - - let bytecode = compile_program_with_flags(&ProgramSource::Filepath(filepath), CompilationFlags { replacements }); - - let mut merkle_path_hints = VecDeque::new(); - for _ in 0..n_recursions { - merkle_path_hints.extend(whir_proof.merkle_hints.clone()); - } - - // in practice we will precompute all the possible values - // (depending on the number of recursions + the number of xmss signatures) - // (or even better: find a linear relation) - let no_vec_runtime_memory = execute_bytecode( - &bytecode, - (&public_input, &[]), - 1 << 20, - false, - (&vec![], &vec![]), // TODO - merkle_path_hints.clone(), - ) - .no_vec_runtime_memory; - - let time = Instant::now(); - - let (proof, summary) = prove_execution( - &bytecode, - (&public_input, &[]), - whir_config_builder(), - no_vec_runtime_memory, - false, - (&vec![], &vec![]), // TODO precompute poseidons - merkle_path_hints, - ); - let proof_size = proof.proof_size; - let proving_time = time.elapsed(); - verify_execution(&bytecode, &public_input, proof, whir_config_builder()).unwrap(); - - println!("{summary}"); - println!( - "Proving time: {} ms / WHIR recursion, proof size: {} KiB (not optimized)", - proving_time.as_millis() / n_recursions as u128, - proof_size * F::bits() / (8 * 1024) - ); -} - -#[test] -fn test_whir_recursion() { - run_whir_recursion_benchmark(false, 1); -} diff --git a/crates/rec_aggregation/src/xmss_aggregate.rs b/crates/rec_aggregation/src/xmss_aggregate.rs index 71e19c7a..7dbacea9 100644 --- a/crates/rec_aggregation/src/xmss_aggregate.rs +++ b/crates/rec_aggregation/src/xmss_aggregate.rs @@ -1,23 +1,22 @@ use lean_compiler::*; -use lean_prover::{LOG_SMALLEST_DECOMPOSITION_CHUNK, whir_config_builder}; -use lean_prover::{prove_execution::prove_execution, verify_execution::verify_execution}; +use lean_prover::{SnarkParams, prove_execution::prove_execution, verify_execution::verify_execution}; use lean_vm::*; use multilinear_toolkit::prelude::*; use rand::{Rng, SeedableRng, rngs::StdRng}; -use std::collections::VecDeque; use std::path::Path; use std::sync::OnceLock; use std::time::Instant; use tracing::{info_span, instrument}; +use utils::to_little_endian_in_field; use whir_p3::precompute_dft_twiddles; use xmss::{ - Poseidon16History, Poseidon24History, V, XMSS_MAX_LOG_LIFETIME, XMSS_MIN_LOG_LIFETIME, XmssPublicKey, - XmssSignature, xmss_generate_phony_signatures, xmss_verify_with_poseidon_trace, + Poseidon16History, V, XMSS_MAX_LOG_LIFETIME, XmssPublicKey, XmssSignature, xmss_generate_phony_signatures, + xmss_verify_with_poseidon_trace, }; -static XMSS_AGGREGATION_PROGRAM: OnceLock = OnceLock::new(); +static XMSS_AGGREGATION_PROGRAM: OnceLock = OnceLock::new(); -fn get_xmss_aggregation_program() -> &'static XmssAggregationProgram { +fn get_xmss_aggregation_program() -> &'static Bytecode { XMSS_AGGREGATION_PROGRAM.get_or_init(compile_xmss_aggregation_program) } @@ -25,40 +24,28 @@ pub fn xmss_setup_aggregation_program() { let _ = get_xmss_aggregation_program(); } -// vectorized -fn xmss_sig_size_in_memory() -> usize { - 1 + V -} - fn build_public_input(xmss_pub_keys: &[XmssPublicKey], message_hash: [F; 8], slot: u64) -> Vec { - let mut public_input = message_hash.to_vec(); + let mut public_input = vec![F::from_usize(xmss_pub_keys.len())]; + public_input.extend(message_hash.to_vec()); public_input.extend(xmss_pub_keys.iter().flat_map(|pk| pk.merkle_root)); public_input.extend(xmss_pub_keys.iter().map(|pk| F::from_usize(pk.log_lifetime))); - public_input.extend( - xmss_pub_keys - .iter() - .map(|pk| F::from_u64(slot.checked_sub(pk.first_slot).unwrap())), // index in merkle tree - ); - - let min_public_input_size = (1 << LOG_SMALLEST_DECOMPOSITION_CHUNK) - NONRESERVED_PROGRAM_INPUT_START; - public_input.extend(F::zero_vec(min_public_input_size.saturating_sub(public_input.len()))); - public_input.splice( - 0..0, - [ - vec![ - F::from_usize(xmss_pub_keys.len()), - F::from_usize(xmss_sig_size_in_memory()), - ], - vec![F::ZERO; 6], - ] - .concat(), - ); + for pk in xmss_pub_keys { + let index_in_merkle_tree = slot.checked_sub(pk.first_slot).unwrap() as usize; + public_input.extend(to_little_endian_in_field::( + index_in_merkle_tree, + XMSS_MAX_LOG_LIFETIME, + )); + } + let mut acc = F::ZERO; + for pk in xmss_pub_keys { + public_input.push(acc); + acc += F::from_usize((1 + V + pk.log_lifetime) * DIGEST_LEN); // size of the signature + } public_input } -fn build_private_input(all_signatures: &[XmssSignature]) -> (Vec, VecDeque>) { +fn build_private_input(all_signatures: &[XmssSignature]) -> Vec { let mut private_input = vec![]; - let mut merkle_path_hints = VecDeque::>::new(); for signature in all_signatures { let initial_private_input_len = private_input.len(); private_input.extend(signature.wots_signature.randomness.to_vec()); @@ -69,82 +56,24 @@ fn build_private_input(all_signatures: &[XmssSignature]) -> (Vec, VecDeque, -} - -impl XmssAggregationProgram { - pub fn compute_non_vec_memory(&self, log_lifetimes: &[usize]) -> usize { - log_lifetimes - .iter() - .map(|&ll| self.no_vec_mem_per_log_lifetime[ll - XMSS_MIN_LOG_LIFETIME]) - .sum::() - + self.default_no_vec_mem + assert!(sig_size.is_multiple_of(DIGEST_LEN)); } + private_input } #[instrument(skip_all)] -fn compile_xmss_aggregation_program() -> XmssAggregationProgram { - let src_file = Path::new(env!("CARGO_MANIFEST_DIR")).join("xmss_aggregate.snark"); - let filepath = src_file.to_str().unwrap().to_string(); - let bytecode = compile_program(&ProgramSource::Filepath(filepath.clone())); - let default_no_vec_mem = exec_phony_xmss(&bytecode, &[]).no_vec_runtime_memory; - let mut no_vec_mem_per_log_lifetime = vec![]; - for log_lifetime in XMSS_MIN_LOG_LIFETIME..=XMSS_MAX_LOG_LIFETIME { - let no_vec_mem = exec_phony_xmss(&bytecode, &[log_lifetime]).no_vec_runtime_memory; - no_vec_mem_per_log_lifetime.push(no_vec_mem.checked_sub(default_no_vec_mem).unwrap()); - } - - let res = XmssAggregationProgram { - bytecode, - default_no_vec_mem, - no_vec_mem_per_log_lifetime, - }; - - let n_sanity_checks = 50; - let mut rng = rand::rng(); - for _ in 0..n_sanity_checks { - let n_sigs = rng.random_range(1..=25); - let log_lifetimes = (0..n_sigs) - .map(|_| rng.random_range(XMSS_MIN_LOG_LIFETIME..=XMSS_MAX_LOG_LIFETIME)) - .collect::>(); - let result = exec_phony_xmss(&res.bytecode, &log_lifetimes); - assert_eq!( - result.no_vec_runtime_memory, - res.compute_non_vec_memory(&log_lifetimes), - "inconsistent no-vec memory for log_lifetimes : {log_lifetimes:?}: non linear formula, TODO", - ); - } - res -} - -fn exec_phony_xmss(bytecode: &Bytecode, log_lifetimes: &[usize]) -> ExecutionResult { - let mut rng = StdRng::seed_from_u64(0); - let message_hash: [F; 8] = rng.random(); - let slot = 1111; - let (xmss_pub_keys, all_signatures) = xmss_generate_phony_signatures(log_lifetimes, message_hash, slot); - let public_input = build_public_input(&xmss_pub_keys, message_hash, slot); - let (private_input, merkle_path_hints) = build_private_input(&all_signatures); - execute_bytecode( - bytecode, - (&public_input, &private_input), - 1 << 21, - false, - (&vec![], &vec![]), - merkle_path_hints, - ) +fn compile_xmss_aggregation_program() -> Bytecode { + let filepath = Path::new(env!("CARGO_MANIFEST_DIR")) + .join("xmss_aggregate.py") + .to_str() + .unwrap() + .to_string(); + compile_program(&ProgramSource::Filepath(filepath)) } pub fn run_xmss_benchmark(log_lifetimes: &[usize], tracing: bool) { @@ -159,7 +88,6 @@ pub fn run_xmss_benchmark(log_lifetimes: &[usize], tracing: bool) { let slot = 1111; let (xmss_pub_keys, all_signatures) = xmss_generate_phony_signatures(log_lifetimes, message_hash, slot); - let time = Instant::now(); let (proof_data, n_field_elements_in_proof, summary) = xmss_aggregate_signatures_helper(&xmss_pub_keys, &all_signatures, message_hash, slot).unwrap(); @@ -203,26 +131,23 @@ fn xmss_aggregate_signatures_helper( let program = get_xmss_aggregation_program(); - let (poseidons_16_precomputed, poseidons_24_precomputed) = - precompute_poseidons(xmss_pub_keys, all_signatures, &message_hash) - .ok_or(XmssAggregateError::InvalidSigature)?; + let poseidons_16_precomputed = precompute_poseidons(xmss_pub_keys, all_signatures, &message_hash) + .ok_or(XmssAggregateError::InvalidSigature)?; let public_input = build_public_input(xmss_pub_keys, message_hash, slot); - let (private_input, merkle_path_hints) = build_private_input(all_signatures); + let private_input = build_private_input(all_signatures); - let (proof, summary) = prove_execution( - &program.bytecode, + let proof = prove_execution( + program, (&public_input, &private_input), - whir_config_builder(), - program.compute_non_vec_memory(&xmss_pub_keys.iter().map(|pk| pk.log_lifetime).collect::>()), + &poseidons_16_precomputed, + &SnarkParams::default(), false, - (&poseidons_16_precomputed, &poseidons_24_precomputed), - merkle_path_hints, ); - let proof_bytes = info_span!("Proof serialization").in_scope(|| bincode::serialize(&proof).unwrap()); + let proof_bytes = info_span!("Proof serialization").in_scope(|| bincode::serialize(&proof.proof).unwrap()); - Ok((proof_bytes, proof.proof_size, summary)) + Ok((proof_bytes, proof.proof_size_fe, proof.exec_summary)) } pub fn xmss_verify_aggregated_signatures( @@ -240,7 +165,7 @@ pub fn xmss_verify_aggregated_signatures( let public_input = build_public_input(xmss_pub_keys, message_hash, slot); - verify_execution(&program.bytecode, &public_input, proof, whir_config_builder()) + verify_execution(program, &public_input, proof, &SnarkParams::default()).map(|_| ()) } #[instrument(skip_all)] @@ -248,7 +173,7 @@ fn precompute_poseidons( xmss_pub_keys: &[XmssPublicKey], all_signatures: &[XmssSignature], message_hash: &[F; 8], -) -> Option<(Poseidon16History, Poseidon24History)> { +) -> Option { assert_eq!(xmss_pub_keys.len(), all_signatures.len()); let traces = xmss_pub_keys .par_iter() @@ -256,16 +181,7 @@ fn precompute_poseidons( .map(|(pub_key, sig)| xmss_verify_with_poseidon_trace(pub_key, message_hash, sig)) .collect::, _>>() .ok()?; - Some(( - traces - .par_iter() - .flat_map(|(poseidon_16_trace, _)| poseidon_16_trace.to_vec()) - .collect(), - traces - .par_iter() - .flat_map(|(_, poseidon_24_trace)| poseidon_24_trace.to_vec()) - .collect(), - )) + Some(traces.into_par_iter().flatten().collect()) } #[test] @@ -273,7 +189,7 @@ fn test_xmss_aggregate() { let n_xmss = 10; let mut rng = StdRng::seed_from_u64(0); let log_lifetimes = (0..n_xmss) - .map(|_| rng.random_range(XMSS_MIN_LOG_LIFETIME..=XMSS_MAX_LOG_LIFETIME)) + .map(|_| rng.random_range(xmss::XMSS_MIN_LOG_LIFETIME..=XMSS_MAX_LOG_LIFETIME)) .collect::>(); run_xmss_benchmark(&log_lifetimes, false); } diff --git a/crates/rec_aggregation/utils.py b/crates/rec_aggregation/utils.py new file mode 100644 index 00000000..27af5238 --- /dev/null +++ b/crates/rec_aggregation/utils.py @@ -0,0 +1,838 @@ +from snark_lib import * +from hashing import * + +F_BITS = 31 # koala-bear = 31 bits + +TWO_ADICITY = 24 +ROOT = 1791270792 # of order 2^TWO_ADICITY + +# Dot product precompile: +BE = 1 # base-extension +EE = 0 # extension-extension + +# bit decomposition hint +BIG_ENDIAN = 0 +LITTLE_ENDIAN = 1 + + +def powers(alpha, n): + # alpha: EF + # n: F + + res = Array(n * DIM) + set_to_one(res) + for i in range(0, n - 1): + mul_extension(res + i * DIM, alpha, res + (i + 1) * DIM) + return res + + +def powers_const(alpha, n: Const): + # alpha: EF + # n: F + + res = Array(n * DIM) + set_to_one(res) + for i in unroll(0, n - 1): + mul_extension(res + i * DIM, alpha, res + (i + 1) * DIM) + return res + + +def unit_root_pow_dynamic(domain_size, index_bits): + # index_bits is a pointer to domain_size bits + res: Imu + debug_assert(domain_size < 26) + match domain_size: + case 0: + _ = 0 # unreachable + case 1: + res = unit_root_pow_const(1, index_bits) + case 2: + res = unit_root_pow_const(2, index_bits) + case 3: + res = unit_root_pow_const(3, index_bits) + case 4: + res = unit_root_pow_const(4, index_bits) + case 5: + res = unit_root_pow_const(5, index_bits) + case 6: + res = unit_root_pow_const(6, index_bits) + case 7: + res = unit_root_pow_const(7, index_bits) + case 8: + res = unit_root_pow_const(8, index_bits) + case 9: + res = unit_root_pow_const(9, index_bits) + case 10: + res = unit_root_pow_const(10, index_bits) + case 11: + res = unit_root_pow_const(11, index_bits) + case 12: + res = unit_root_pow_const(12, index_bits) + case 13: + res = unit_root_pow_const(13, index_bits) + case 14: + res = unit_root_pow_const(14, index_bits) + case 15: + res = unit_root_pow_const(15, index_bits) + case 16: + res = unit_root_pow_const(16, index_bits) + case 17: + res = unit_root_pow_const(17, index_bits) + case 18: + res = unit_root_pow_const(18, index_bits) + case 19: + res = unit_root_pow_const(19, index_bits) + case 20: + res = unit_root_pow_const(20, index_bits) + case 21: + res = unit_root_pow_const(21, index_bits) + case 22: + res = unit_root_pow_const(22, index_bits) + case 23: + res = unit_root_pow_const(23, index_bits) + case 24: + res = unit_root_pow_const(24, index_bits) + case 25: + res = unit_root_pow_const(25, index_bits) + return res + + +def unit_root_pow_const(domain_size: Const, index_bits): + prod: Mut = (index_bits[0] * ROOT ** (2 ** (TWO_ADICITY - domain_size))) + (1 - index_bits[0]) + for i in unroll(1, domain_size): + prod *= (index_bits[i] * ROOT ** (2 ** (TWO_ADICITY - domain_size + i))) + (1 - index_bits[i]) + return prod + + +def poly_eq_extension_dynamic(point, n): + debug_assert(n < 8) + res: Imu + match n: + case 0: + res = ONE_VEC_PTR + case 1: + res = poly_eq_extension(point, 1) + case 2: + res = poly_eq_extension(point, 2) + case 3: + res = poly_eq_extension(point, 3) + case 4: + res = poly_eq_extension(point, 4) + case 5: + res = poly_eq_extension(point, 5) + case 6: + res = poly_eq_extension(point, 6) + case 7: + res = poly_eq_extension(point, 7) + return res + + +def poly_eq_extension(point, n: Const): + # Example: for n = 2: eq(x, y) = [(1 - x)(1 - y), (1 - x)y, x(1 - y), xy] + + res = Array((2 ** (n + 1) - 1) * DIM) + set_to_one(res) + + for s in unroll(0, n): + p = point + (n - 1 - s) * DIM + for i in unroll(0, 2**s): + mul_extension(p, res + (2**s - 1 + i) * DIM, res + (2 ** (s + 1) - 1 + 2**s + i) * DIM) + sub_extension( + res + (2**s - 1 + i) * DIM, + res + (2 ** (s + 1) - 1 + 2**s + i) * DIM, + res + (2 ** (s + 1) - 1 + i) * DIM, + ) + return res + (2**n - 1) * DIM + + +def poly_eq_base(point, n: Const): + # Example: for n = 2: eq(x, y) = [(1 - x)(1 - y), (1 - x)y, x(1 - y), xy] + + res = Array((2 ** (n + 1) - 1)) + res[0] = 1 + for s in unroll(0, n): + p = point[n - 1 - s] + for i in unroll(0, 2**s): + res[2 ** (s + 1) - 1 + 2**s + i] = p * res[2**s - 1 + i] + res[2 ** (s + 1) - 1 + i] = res[2**s - 1 + i] - res[2 ** (s + 1) - 1 + 2**s + i] + return res + (2**n - 1) + + +def pow(a, b): + if b == 0: + return 1 # a^0 = 1 + else: + p = pow(a, b - 1) + return a * p + + +def eq_mle_extension(a, b, n): + buff = Array(n * DIM) + + for i in range(0, n): + shift = i * DIM + ai = a + shift + bi = b + shift + buffi = buff + shift + ab = mul_extension_ret(ai, bi) + buffi[0] = 1 + 2 * ab[0] - ai[0] - bi[0] + for j in unroll(1, DIM): + buffi[j] = 2 * ab[j] - ai[j] - bi[j] + + current_prod: Mut = buff + for i in range(0, n - 1): + next_prod = Array(DIM) + mul_extension(current_prod, buff + (i + 1) * DIM, next_prod) + current_prod = next_prod + + return current_prod + + +def eq_mle_base_extension(a, b, n): + res: Imu + debug_assert(n < 26) + match n: + case 0: + _ = 0 # unreachable + case 1: + res = eq_mle_extension_base_const(a, b, 1) + case 2: + res = eq_mle_extension_base_const(a, b, 2) + case 3: + res = eq_mle_extension_base_const(a, b, 3) + case 4: + res = eq_mle_extension_base_const(a, b, 4) + case 5: + res = eq_mle_extension_base_const(a, b, 5) + case 6: + res = eq_mle_extension_base_const(a, b, 6) + case 7: + res = eq_mle_extension_base_const(a, b, 7) + case 8: + res = eq_mle_extension_base_const(a, b, 8) + case 9: + res = eq_mle_extension_base_const(a, b, 9) + case 10: + res = eq_mle_extension_base_const(a, b, 10) + case 11: + res = eq_mle_extension_base_const(a, b, 11) + case 12: + res = eq_mle_extension_base_const(a, b, 12) + case 13: + res = eq_mle_extension_base_const(a, b, 13) + case 14: + res = eq_mle_extension_base_const(a, b, 14) + case 15: + res = eq_mle_extension_base_const(a, b, 15) + case 16: + res = eq_mle_extension_base_const(a, b, 16) + case 17: + res = eq_mle_extension_base_const(a, b, 17) + case 18: + res = eq_mle_extension_base_const(a, b, 18) + case 19: + res = eq_mle_extension_base_const(a, b, 19) + case 20: + res = eq_mle_extension_base_const(a, b, 20) + case 21: + res = eq_mle_extension_base_const(a, b, 21) + case 22: + res = eq_mle_extension_base_const(a, b, 22) + case 23: + res = eq_mle_extension_base_const(a, b, 23) + case 24: + res = eq_mle_extension_base_const(a, b, 24) + case 25: + res = eq_mle_extension_base_const(a, b, 25) + return res + + +def eq_mle_extension_base_const(a, b, n: Const): + # a: base + # b: extension + + buff = Array(n * DIM) + + for i in unroll(0, n): + ai = a[i] + bi = b + i * DIM + buffi = buff + i * DIM + ai_double = ai * 2 + buffi[0] = 1 + ai_double * bi[0] - ai - bi[0] + for j in unroll(1, DIM): + buffi[j] = ai_double * bi[j] - bi[j] + + prods = Array(n * DIM) + copy_5(buff, prods) + for i in unroll(0, n - 1): + mul_extension(prods + i * DIM, buff + (i + 1) * DIM, prods + (i + 1) * DIM) + return prods + (n - 1) * DIM + + +def expand_from_univariate_base(alpha, n): + res: Imu + debug_assert(n < 23) + match n: + case 0: + _ = 0 # unreachable + case 1: + res = expand_from_univariate_base_const(alpha, 1) + case 2: + res = expand_from_univariate_base_const(alpha, 2) + case 3: + res = expand_from_univariate_base_const(alpha, 3) + case 4: + res = expand_from_univariate_base_const(alpha, 4) + case 5: + res = expand_from_univariate_base_const(alpha, 5) + case 6: + res = expand_from_univariate_base_const(alpha, 6) + case 7: + res = expand_from_univariate_base_const(alpha, 7) + case 8: + res = expand_from_univariate_base_const(alpha, 8) + case 9: + res = expand_from_univariate_base_const(alpha, 9) + case 10: + res = expand_from_univariate_base_const(alpha, 10) + case 11: + res = expand_from_univariate_base_const(alpha, 11) + case 12: + res = expand_from_univariate_base_const(alpha, 12) + case 13: + res = expand_from_univariate_base_const(alpha, 13) + case 14: + res = expand_from_univariate_base_const(alpha, 14) + case 15: + res = expand_from_univariate_base_const(alpha, 15) + case 16: + res = expand_from_univariate_base_const(alpha, 16) + case 17: + res = expand_from_univariate_base_const(alpha, 17) + case 18: + res = expand_from_univariate_base_const(alpha, 18) + case 19: + res = expand_from_univariate_base_const(alpha, 19) + case 20: + res = expand_from_univariate_base_const(alpha, 20) + case 21: + res = expand_from_univariate_base_const(alpha, 21) + case 22: + res = expand_from_univariate_base_const(alpha, 22) + return res + + +def expand_from_univariate_base_const(alpha, n: Const): + # "expand_from_univariate" + # alpha: F + + res = Array(n) + current: Mut = alpha + for i in unroll(0, n): + res[i] = current + current *= current + return res + + +def expand_from_univariate_ext(alpha, n): + res = Array(n * DIM) + copy_5(alpha, res) + for i in range(0, n - 1): + mul_extension(res + i * DIM, res + i * DIM, res + (i + 1) * DIM) + return res + + +def dot_product_be_dynamic(a, b, res, n): + for i in unroll(6, 10): + if n == 2**i: + dot_product(a, b, res, 2**i, BE) + return + assert False, "dot_product_be_dynamic called with unsupported n" + + +def dot_product_ee_dynamic(a, b, res, n): + if n == 16: + dot_product(a, b, res, 16, EE) + return + if n == 1: + dot_product(a, b, res, 1, EE) + return + if n == 2: + dot_product(a, b, res, 2, EE) + return + + for i in unroll(0, N_ROUNDS_BASE + 1): + if n == NUM_QUERIES_BASE[i]: + dot_product(a, b, res, NUM_QUERIES_BASE[i], EE) + return + if n == NUM_QUERIES_BASE[i] + 1: + dot_product(a, b, res, NUM_QUERIES_BASE[i] + 1, EE) + return + for i in unroll(0, N_ROUNDS_EXT + 1): + if n == NUM_QUERIES_EXT[i]: + dot_product(a, b, res, NUM_QUERIES_EXT[i], EE) + return + if n == NUM_QUERIES_EXT[i] + 1: + dot_product(a, b, res, NUM_QUERIES_EXT[i] + 1, EE) + return + + assert False, "dot_product_ee_dynamic called with unsupported n" + + +def mle_of_01234567_etc(point, n): + if n == 0: + return ZERO_VEC_PTR + else: + e = mle_of_01234567_etc(point + DIM, n - 1) + a = one_minus_self_extension_ret(point) + b = mul_extension_ret(a, e) + power_of_2 = powers_of_two(n - 1) + c = add_base_extension_ret(power_of_2, e) + d = mul_extension_ret(point, c) + res = add_extension_ret(b, d) + return res + + +def powers_of_two(n): + debug_assert(n < 32) + res: Imu + match n: + case 0: + res = 0 + 2**0 + case 1: + res = 0 + 2**1 + case 2: + res = 0 + 2**2 + case 3: + res = 0 + 2**3 + case 4: + res = 0 + 2**4 + case 5: + res = 0 + 2**5 + case 6: + res = 0 + 2**6 + case 7: + res = 0 + 2**7 + case 8: + res = 0 + 2**8 + case 9: + res = 0 + 2**9 + case 10: + res = 0 + 2**10 + case 11: + res = 0 + 2**11 + case 12: + res = 0 + 2**12 + case 13: + res = 0 + 2**13 + case 14: + res = 0 + 2**14 + case 15: + res = 0 + 2**15 + case 16: + res = 0 + 2**16 + case 17: + res = 0 + 2**17 + case 18: + res = 0 + 2**18 + case 19: + res = 0 + 2**19 + case 20: + res = 0 + 2**20 + case 21: + res = 0 + 2**21 + case 22: + res = 0 + 2**22 + case 23: + res = 0 + 2**23 + case 24: + res = 0 + 2**24 + case 25: + res = 0 + 2**25 + case 26: + res = 0 + 2**26 + case 27: + res = 0 + 2**27 + case 28: + res = 0 + 2**28 + case 29: + res = 0 + 2**29 + case 30: + res = 0 + 2**30 + case 31: + res = 0 + 2**31 + return res + + +@inline +def mul_extension_ret(a, b): + return dot_product_ret(a, b, 1, EE) + + +@inline +def mul_extension(a, b, c): + dot_product(a, b, c, 1, EE) + return + + +@inline +def add_extension_ret(a, b): + # TODO if a and b are adjacent we can do it in one cycle using the dot_product precompile + c = Array(DIM) + for i in unroll(0, DIM): + c[i] = a[i] + b[i] + return c + + +@inline +def add_extension(a, b, c): + # TODO if a and b are adjacent we can do it in one cycle using the dot_product precompile + for i in unroll(0, DIM): + c[i] = a[i] + b[i] + return + + +@inline +def one_minus_self_extension_ret(a): + res = Array(DIM) + res[0] = 1 - a[0] + for i in unroll(1, DIM): + res[i] = 0 - a[i] + return res + + +@inline +def opposite_extension_ret(a): + # todo use dot_product precompile + res = Array(DIM) + for i in unroll(0, DIM): + res[i] = 0 - a[i] + return res + + +@inline +def add_base_extension_ret(a, b): + # a: base + # b: extension + res = Array(DIM) + res[0] = a + b[0] + for i in unroll(1, DIM): + res[i] = b[i] + return res + + +@inline +def mul_base_extension_ret(a, b): + # a: base + # b: extension + + # TODO: use dot_product_be + + res = Array(DIM) + for i in unroll(0, DIM): + res[i] = a * b[i] + return res + + +def div_extension_ret(n, d): + quotient = Array(DIM) + dot_product(d, quotient, n, 1, EE) + return quotient + + +@inline +def sub_extension(a, b, c): + # TODO if a and b are adjacent we can do it in one cycle using the dot_product precompile + for i in unroll(0, DIM): + c[i] = a[i] - b[i] + return + + +@inline +def sub_base_extension_ret(a, b): + # a: base + # b: extension + # return a - b + res = Array(DIM) + res[0] = a - b[0] + for i in unroll(1, DIM): + res[i] = 0 - b[i] + return res + + +@inline +def sub_extension_base_ret(a, b): + # a: extension + # b: base + # return a - b + res = Array(DIM) + res[0] = a[0] - b + for i in unroll(1, DIM): + res[i] = a[i] + return res + + +@inline +def sub_extension_ret(a, b): + # TODO if a and b are adjacent we can do it in one cycle using the dot_product precompile + c = Array(DIM) + for i in unroll(0, DIM): + c[i] = a[i] - b[i] + return c + + +@inline +def copy_5(a, b): + dot_product(a, ONE_VEC_PTR, b, 1, EE) + return + + +@inline +def set_to_5_zeros(a): + zero_ptr = ZERO_VEC_PTR + dot_product(a, ONE_VEC_PTR, zero_ptr, 1, EE) + return + + +@inline +def set_to_7_zeros(a): + zero_ptr = ZERO_VEC_PTR + dot_product(a, ONE_VEC_PTR, zero_ptr, 1, EE) + a[5] = 0 + a[6] = 0 + return + + +@inline +def set_to_16_zeros(a): + zero_ptr = ZERO_VEC_PTR + dot_product(a, ONE_VEC_PTR, zero_ptr, 1, EE) + dot_product(a + 5, ONE_VEC_PTR, zero_ptr, 1, EE) + dot_product(a + 10, ONE_VEC_PTR, zero_ptr, 1, EE) + a[15] = 0 + return + + +@inline +def copy_8(a, b): + dot_product(a, ONE_VEC_PTR, b, 1, EE) + assert a[5] == b[5] + assert a[6] == b[6] + assert a[7] == b[7] + return + + +@inline +def copy_16(a, b): + dot_product(a, ONE_VEC_PTR, b, 1, EE) + dot_product(a + 5, ONE_VEC_PTR, b + 5, 1, EE) + dot_product(a + 10, ONE_VEC_PTR, b + 10, 1, EE) + a[15] = b[15] + return + + +def copy_many_ef(a, b, n): + for i in range(0, n): + dot_product(a + i * DIM, ONE_VEC_PTR, b + i * DIM, 1, EE) + return + + +@inline +def set_to_one(a): + a[0] = 1 + for i in unroll(1, DIM): + a[i] = 0 + return + + +def print_ef(a): + for i in unroll(0, DIM): + print(a[i]) + return + + +def print_vec(a): + for i in unroll(0, VECTOR_LEN): + print(a[i]) + return + + +def print_many(a, n): + for i in range(0, n): + print(a[i]) + return + + +def next_multiple_of_8(a: Const): + return a + (8 - (a % 8)) % 8 + + +@inline +def read_memory(ptr): + mem = 0 + return mem[ptr] + + +def univariate_polynomial_eval(coeffs, point, degree: Const): + powers = powers(point, degree + 1) # TODO use a parameter: Const version + res = Array(DIM) + dot_product(coeffs, powers, res, degree + 1, EE) + return res + + +def sum_2_ef_fractions(a_num, a_den, b_num, b_den): + common_den = mul_extension_ret(a_den, b_den) + a_num_mul_b_den = mul_extension_ret(a_num, b_den) + b_num_mul_a_den = mul_extension_ret(b_num, a_den) + sum_num = add_extension_ret(a_num_mul_b_den, b_num_mul_a_den) + return sum_num, common_den + + +# p = 2^31 - 2^24 + 1 +# in binary: p = 1111111000000000000000000000001 +# p - 1 = 1111111000000000000000000000000 +# p - 2 = 1111110111111111111111111111111 +# p - 3 = 1111110111111111111111111111110 +# ... +# Any field element (< p) is either: +# - 1111111 | 00...00 +# - not(1111111) | xx...xx +def checked_decompose_bits(a, k): + # return a pointer to the 31 bits of a + # .. and the partial value, reading the first K bits (with k <= 24) + bits = Array(F_BITS) + hint_decompose_bits(a, bits, F_BITS, LITTLE_ENDIAN) + + for i in unroll(0, F_BITS): + assert bits[i] * (1 - bits[i]) == 0 + partial_sums_24 = Array(24) + sum_24: Mut = bits[0] + partial_sums_24[0] = sum_24 + for i in unroll(1, 24): + sum_24 += bits[i] * 2**i + partial_sums_24[i] = sum_24 + sum_7: Mut = bits[24] + for i in unroll(1, 7): + sum_7 += bits[24 + i] * 2**i + if sum_7 == 127: + assert sum_24 == 0 + + assert a == sum_24 + sum_7 * 2**24 + partial_sum = partial_sums_24[k - 1] + return bits, partial_sum + + +def checked_decompose_bits_small_value(to_decompose, n_bits): + bits = Array(n_bits) + hint_decompose_bits(to_decompose, bits, n_bits, BIG_ENDIAN) + sum: Mut = bits[n_bits - 1] + power_of_2: Mut = 1 + for i in range(1, n_bits): + power_of_2 *= 2 + sum += bits[n_bits - 1 - i] * power_of_2 + assert to_decompose == sum + return bits + + +@inline +def dot_product_ret(a, b, n, mode): + res = Array(DIM) + dot_product(a, b, res, n, mode) + return res + + +def mle_of_zeros_then_ones(point, n_zeros, n_vars): + if n_vars == 0: + res = Array(DIM) + res[0] = 1 - n_zeros + for i in unroll(1, DIM): + res[i] = 0 + return res + + n_values = powers_of_two(n_vars) + debug_assert(n_zeros <= n_values) + + if n_zeros == n_values: + return ZERO_VEC_PTR + + bits, _ = checked_decompose_bits(n_zeros, 0) + + res: Mut = Array(DIM) + set_to_one(res) + + for i in range(0, n_vars): + p = point + (n_vars - 1 - i) * DIM + if bits[i] == 0: + one_minus_p = one_minus_self_extension_ret(p) + tmp = mul_extension_ret(one_minus_p, res) + res = add_extension_ret(tmp, p) + else: + res = mul_extension_ret(p, res) + return res + + +@inline +def embed_in_ef(f): + res = Array(DIM) + res[0] = f + for i in unroll(1, DIM): + res[i] = 0 + return res + + +def next_mle(x, y, n): + # x and y are pointers to n elements of extension field + + # Build eq_prefix[0..n+1] where eq_prefix[i] = prod_{j=i} (x[j] * (1-y[j])) + low_suffix = Array((n + 1) * DIM) + set_to_one(low_suffix + n * DIM) + for i in range(0, n): + idx = n - 1 - i + xi = x + idx * DIM + yi = y + idx * DIM + one_minus_y = one_minus_self_extension_ret(yi) + x_one_minus_y = mul_extension_ret(xi, one_minus_y) + mul_extension(low_suffix + (idx + 1) * DIM, x_one_minus_y, low_suffix + idx * DIM) + + # Compute sum = Σ_{arr=0..n} (eq_prefix[arr] * (1-x[arr]) * y[arr] * low_suffix[arr+1]) + sum: Mut = ZERO_VEC_PTR + for arr in range(0, n): + x_arr = x + arr * DIM + y_arr = y + arr * DIM + one_minus_x = one_minus_self_extension_ret(x_arr) + carry = mul_extension_ret(one_minus_x, y_arr) + eq_carry = mul_extension_ret(eq_prefix + arr * DIM, carry) + term = mul_extension_ret(eq_carry, low_suffix + (arr + 1) * DIM) + sum = add_extension_ret(sum, term) + + # Compute prod = product of all x[i] * product of all y[i] + prod: Mut = Array(DIM) + set_to_one(prod) + for i in range(0, n): + prod = mul_extension_ret(prod, x + i * DIM) + for i in range(0, n): + prod = mul_extension_ret(prod, y + i * DIM) + + result = add_extension_ret(sum, prod) + return result + + +@inline +def dot_product_with_the_base_vectors(slice): + # slice: pointer to DIM extension field elements + # cf constants.rs: by convention, [10000] [01000] [00100] [00010] [00001] is harcoded in memory, starting at ONE_VEC_PTR + return dot_product_ret(slice, ONE_VEC_PTR, 1, EE) diff --git a/crates/rec_aggregation/utils.snark.py b/crates/rec_aggregation/utils.snark.py new file mode 100644 index 00000000..297daa7b --- /dev/null +++ b/crates/rec_aggregation/utils.snark.py @@ -0,0 +1,837 @@ +from hashing import * + +F_BITS = 31 # koala-bear = 31 bits + +TWO_ADICITY = 24 +ROOT = 1791270792 # of order 2^TWO_ADICITY + +# Dot product precompile: +BE = 1 # base-extension +EE = 0 # extension-extension + +# bit decomposition hint +BIG_ENDIAN = 0 +LITTLE_ENDIAN = 1 + + +def powers(alpha, n): + # alpha: EF + # n: F + + res = Array(n * DIM) + set_to_one(res) + for i in range(0, n - 1): + mul_extension(res + i * DIM, alpha, res + (i + 1) * DIM) + return res + + +def powers_const(alpha, n: Const): + # alpha: EF + # n: F + + res = Array(n * DIM) + set_to_one(res) + for i in unroll(0, n - 1): + mul_extension(res + i * DIM, alpha, res + (i + 1) * DIM) + return res + + +def unit_root_pow_dynamic(domain_size, index_bits): + # index_bits is a pointer to domain_size bits + res: Imu + debug_assert(domain_size < 26) + match domain_size: + case 0: + _ = 0 # unreachable + case 1: + res = unit_root_pow_const(1, index_bits) + case 2: + res = unit_root_pow_const(2, index_bits) + case 3: + res = unit_root_pow_const(3, index_bits) + case 4: + res = unit_root_pow_const(4, index_bits) + case 5: + res = unit_root_pow_const(5, index_bits) + case 6: + res = unit_root_pow_const(6, index_bits) + case 7: + res = unit_root_pow_const(7, index_bits) + case 8: + res = unit_root_pow_const(8, index_bits) + case 9: + res = unit_root_pow_const(9, index_bits) + case 10: + res = unit_root_pow_const(10, index_bits) + case 11: + res = unit_root_pow_const(11, index_bits) + case 12: + res = unit_root_pow_const(12, index_bits) + case 13: + res = unit_root_pow_const(13, index_bits) + case 14: + res = unit_root_pow_const(14, index_bits) + case 15: + res = unit_root_pow_const(15, index_bits) + case 16: + res = unit_root_pow_const(16, index_bits) + case 17: + res = unit_root_pow_const(17, index_bits) + case 18: + res = unit_root_pow_const(18, index_bits) + case 19: + res = unit_root_pow_const(19, index_bits) + case 20: + res = unit_root_pow_const(20, index_bits) + case 21: + res = unit_root_pow_const(21, index_bits) + case 22: + res = unit_root_pow_const(22, index_bits) + case 23: + res = unit_root_pow_const(23, index_bits) + case 24: + res = unit_root_pow_const(24, index_bits) + case 25: + res = unit_root_pow_const(25, index_bits) + return res + + +def unit_root_pow_const(domain_size: Const, index_bits): + prod: Mut = (index_bits[0] * ROOT ** (2 ** (TWO_ADICITY - domain_size))) + (1 - index_bits[0]) + for i in unroll(1, domain_size): + prod *= (index_bits[i] * ROOT ** (2 ** (TWO_ADICITY - domain_size + i))) + (1 - index_bits[i]) + return prod + + +def poly_eq_extension_dynamic(point, n): + debug_assert(n < 8) + res: Imu + match n: + case 0: + res = ONE_VEC_PTR + case 1: + res = poly_eq_extension(point, 1) + case 2: + res = poly_eq_extension(point, 2) + case 3: + res = poly_eq_extension(point, 3) + case 4: + res = poly_eq_extension(point, 4) + case 5: + res = poly_eq_extension(point, 5) + case 6: + res = poly_eq_extension(point, 6) + case 7: + res = poly_eq_extension(point, 7) + return res + + +def poly_eq_extension(point, n: Const): + # Example: for n = 2: eq(x, y) = [(1 - x)(1 - y), (1 - x)y, x(1 - y), xy] + + res = Array((2 ** (n + 1) - 1) * DIM) + set_to_one(res) + + for s in unroll(0, n): + p = point + (n - 1 - s) * DIM + for i in unroll(0, 2**s): + mul_extension(p, res + (2**s - 1 + i) * DIM, res + (2 ** (s + 1) - 1 + 2**s + i) * DIM) + sub_extension( + res + (2**s - 1 + i) * DIM, + res + (2 ** (s + 1) - 1 + 2**s + i) * DIM, + res + (2 ** (s + 1) - 1 + i) * DIM, + ) + return res + (2**n - 1) * DIM + + +def poly_eq_base(point, n: Const): + # Example: for n = 2: eq(x, y) = [(1 - x)(1 - y), (1 - x)y, x(1 - y), xy] + + res = Array((2 ** (n + 1) - 1)) + res[0] = 1 + for s in unroll(0, n): + p = point[n - 1 - s] + for i in unroll(0, 2**s): + res[2 ** (s + 1) - 1 + 2**s + i] = p * res[2**s - 1 + i] + res[2 ** (s + 1) - 1 + i] = res[2**s - 1 + i] - res[2 ** (s + 1) - 1 + 2**s + i] + return res + (2**n - 1) + + +def pow(a, b): + if b == 0: + return 1 # a^0 = 1 + else: + p = pow(a, b - 1) + return a * p + + +def eq_mle_extension(a, b, n): + buff = Array(n * DIM) + + for i in range(0, n): + shift = i * DIM + ai = a + shift + bi = b + shift + buffi = buff + shift + ab = mul_extension_ret(ai, bi) + buffi[0] = 1 + 2 * ab[0] - ai[0] - bi[0] + for j in unroll(1, DIM): + buffi[j] = 2 * ab[j] - ai[j] - bi[j] + + current_prod: Mut = buff + for i in range(0, n - 1): + next_prod = Array(DIM) + mul_extension(current_prod, buff + (i + 1) * DIM, next_prod) + current_prod = next_prod + + return current_prod + + +def eq_mle_base_extension(a, b, n): + res: Imu + debug_assert(n < 26) + match n: + case 0: + _ = 0 # unreachable + case 1: + res = eq_mle_extension_base_const(a, b, 1) + case 2: + res = eq_mle_extension_base_const(a, b, 2) + case 3: + res = eq_mle_extension_base_const(a, b, 3) + case 4: + res = eq_mle_extension_base_const(a, b, 4) + case 5: + res = eq_mle_extension_base_const(a, b, 5) + case 6: + res = eq_mle_extension_base_const(a, b, 6) + case 7: + res = eq_mle_extension_base_const(a, b, 7) + case 8: + res = eq_mle_extension_base_const(a, b, 8) + case 9: + res = eq_mle_extension_base_const(a, b, 9) + case 10: + res = eq_mle_extension_base_const(a, b, 10) + case 11: + res = eq_mle_extension_base_const(a, b, 11) + case 12: + res = eq_mle_extension_base_const(a, b, 12) + case 13: + res = eq_mle_extension_base_const(a, b, 13) + case 14: + res = eq_mle_extension_base_const(a, b, 14) + case 15: + res = eq_mle_extension_base_const(a, b, 15) + case 16: + res = eq_mle_extension_base_const(a, b, 16) + case 17: + res = eq_mle_extension_base_const(a, b, 17) + case 18: + res = eq_mle_extension_base_const(a, b, 18) + case 19: + res = eq_mle_extension_base_const(a, b, 19) + case 20: + res = eq_mle_extension_base_const(a, b, 20) + case 21: + res = eq_mle_extension_base_const(a, b, 21) + case 22: + res = eq_mle_extension_base_const(a, b, 22) + case 23: + res = eq_mle_extension_base_const(a, b, 23) + case 24: + res = eq_mle_extension_base_const(a, b, 24) + case 25: + res = eq_mle_extension_base_const(a, b, 25) + return res + + +def eq_mle_extension_base_const(a, b, n: Const): + # a: base + # b: extension + + buff = Array(n * DIM) + + for i in unroll(0, n): + ai = a[i] + bi = b + i * DIM + buffi = buff + i * DIM + ai_double = ai * 2 + buffi[0] = 1 + ai_double * bi[0] - ai - bi[0] + for j in unroll(1, DIM): + buffi[j] = ai_double * bi[j] - bi[j] + + prods = Array(n * DIM) + copy_5(buff, prods) + for i in unroll(0, n - 1): + mul_extension(prods + i * DIM, buff + (i + 1) * DIM, prods + (i + 1) * DIM) + return prods + (n - 1) * DIM + + +def expand_from_univariate_base(alpha, n): + res: Imu + debug_assert(n < 23) + match n: + case 0: + _ = 0 # unreachable + case 1: + res = expand_from_univariate_base_const(alpha, 1) + case 2: + res = expand_from_univariate_base_const(alpha, 2) + case 3: + res = expand_from_univariate_base_const(alpha, 3) + case 4: + res = expand_from_univariate_base_const(alpha, 4) + case 5: + res = expand_from_univariate_base_const(alpha, 5) + case 6: + res = expand_from_univariate_base_const(alpha, 6) + case 7: + res = expand_from_univariate_base_const(alpha, 7) + case 8: + res = expand_from_univariate_base_const(alpha, 8) + case 9: + res = expand_from_univariate_base_const(alpha, 9) + case 10: + res = expand_from_univariate_base_const(alpha, 10) + case 11: + res = expand_from_univariate_base_const(alpha, 11) + case 12: + res = expand_from_univariate_base_const(alpha, 12) + case 13: + res = expand_from_univariate_base_const(alpha, 13) + case 14: + res = expand_from_univariate_base_const(alpha, 14) + case 15: + res = expand_from_univariate_base_const(alpha, 15) + case 16: + res = expand_from_univariate_base_const(alpha, 16) + case 17: + res = expand_from_univariate_base_const(alpha, 17) + case 18: + res = expand_from_univariate_base_const(alpha, 18) + case 19: + res = expand_from_univariate_base_const(alpha, 19) + case 20: + res = expand_from_univariate_base_const(alpha, 20) + case 21: + res = expand_from_univariate_base_const(alpha, 21) + case 22: + res = expand_from_univariate_base_const(alpha, 22) + return res + + +def expand_from_univariate_base_const(alpha, n: Const): + # "expand_from_univariate" + # alpha: F + + res = Array(n) + current: Mut = alpha + for i in unroll(0, n): + res[i] = current + current *= current + return res + + +def expand_from_univariate_ext(alpha, n): + res = Array(n * DIM) + copy_5(alpha, res) + for i in range(0, n - 1): + mul_extension(res + i * DIM, res + i * DIM, res + (i + 1) * DIM) + return res + + +def dot_product_be_dynamic(a, b, res, n): + for i in unroll(6, 10): + if n == 2**i: + dot_product(a, b, res, 2**i, BE) + return + assert False, "dot_product_be_dynamic called with unsupported n" + + +def dot_product_ee_dynamic(a, b, res, n): + if n == 16: + dot_product(a, b, res, 16, EE) + return + if n == 1: + dot_product(a, b, res, 1, EE) + return + if n == 2: + dot_product(a, b, res, 2, EE) + return + + for i in unroll(0, N_ROUNDS_BASE + 1): + if n == NUM_QUERIES_BASE[i]: + dot_product(a, b, res, NUM_QUERIES_BASE[i], EE) + return + if n == NUM_QUERIES_BASE[i] + 1: + dot_product(a, b, res, NUM_QUERIES_BASE[i] + 1, EE) + return + for i in unroll(0, N_ROUNDS_EXT + 1): + if n == NUM_QUERIES_EXT[i]: + dot_product(a, b, res, NUM_QUERIES_EXT[i], EE) + return + if n == NUM_QUERIES_EXT[i] + 1: + dot_product(a, b, res, NUM_QUERIES_EXT[i] + 1, EE) + return + + assert False, "dot_product_ee_dynamic called with unsupported n" + + +def mle_of_01234567_etc(point, n): + if n == 0: + return ZERO_VEC_PTR + else: + e = mle_of_01234567_etc(point + DIM, n - 1) + a = one_minus_self_extension_ret(point) + b = mul_extension_ret(a, e) + power_of_2 = powers_of_two(n - 1) + c = add_base_extension_ret(power_of_2, e) + d = mul_extension_ret(point, c) + res = add_extension_ret(b, d) + return res + + +def powers_of_two(n): + debug_assert(n < 32) + res: Imu + match n: + case 0: + res = 0 + 2**0 + case 1: + res = 0 + 2**1 + case 2: + res = 0 + 2**2 + case 3: + res = 0 + 2**3 + case 4: + res = 0 + 2**4 + case 5: + res = 0 + 2**5 + case 6: + res = 0 + 2**6 + case 7: + res = 0 + 2**7 + case 8: + res = 0 + 2**8 + case 9: + res = 0 + 2**9 + case 10: + res = 0 + 2**10 + case 11: + res = 0 + 2**11 + case 12: + res = 0 + 2**12 + case 13: + res = 0 + 2**13 + case 14: + res = 0 + 2**14 + case 15: + res = 0 + 2**15 + case 16: + res = 0 + 2**16 + case 17: + res = 0 + 2**17 + case 18: + res = 0 + 2**18 + case 19: + res = 0 + 2**19 + case 20: + res = 0 + 2**20 + case 21: + res = 0 + 2**21 + case 22: + res = 0 + 2**22 + case 23: + res = 0 + 2**23 + case 24: + res = 0 + 2**24 + case 25: + res = 0 + 2**25 + case 26: + res = 0 + 2**26 + case 27: + res = 0 + 2**27 + case 28: + res = 0 + 2**28 + case 29: + res = 0 + 2**29 + case 30: + res = 0 + 2**30 + case 31: + res = 0 + 2**31 + return res + + +@inline +def mul_extension_ret(a, b): + return dot_product_ret(a, b, 1, EE) + + +@inline +def mul_extension(a, b, c): + dot_product(a, b, c, 1, EE) + return + + +@inline +def add_extension_ret(a, b): + # TODO if a and b are adjacent we can do it in one cycle using the dot_product precompile + c = Array(DIM) + for i in unroll(0, DIM): + c[i] = a[i] + b[i] + return c + + +@inline +def add_extension(a, b, c): + # TODO if a and b are adjacent we can do it in one cycle using the dot_product precompile + for i in unroll(0, DIM): + c[i] = a[i] + b[i] + return + + +@inline +def one_minus_self_extension_ret(a): + res = Array(DIM) + res[0] = 1 - a[0] + for i in unroll(1, DIM): + res[i] = 0 - a[i] + return res + + +@inline +def opposite_extension_ret(a): + # todo use dot_product precompile + res = Array(DIM) + for i in unroll(0, DIM): + res[i] = 0 - a[i] + return res + + +@inline +def add_base_extension_ret(a, b): + # a: base + # b: extension + res = Array(DIM) + res[0] = a + b[0] + for i in unroll(1, DIM): + res[i] = b[i] + return res + + +@inline +def mul_base_extension_ret(a, b): + # a: base + # b: extension + + # TODO: use dot_product_be + + res = Array(DIM) + for i in unroll(0, DIM): + res[i] = a * b[i] + return res + + +def div_extension_ret(n, d): + quotient = Array(DIM) + dot_product(d, quotient, n, 1, EE) + return quotient + + +@inline +def sub_extension(a, b, c): + # TODO if a and b are adjacent we can do it in one cycle using the dot_product precompile + for i in unroll(0, DIM): + c[i] = a[i] - b[i] + return + + +@inline +def sub_base_extension_ret(a, b): + # a: base + # b: extension + # return a - b + res = Array(DIM) + res[0] = a - b[0] + for i in unroll(1, DIM): + res[i] = 0 - b[i] + return res + + +@inline +def sub_extension_base_ret(a, b): + # a: extension + # b: base + # return a - b + res = Array(DIM) + res[0] = a[0] - b + for i in unroll(1, DIM): + res[i] = a[i] + return res + + +@inline +def sub_extension_ret(a, b): + # TODO if a and b are adjacent we can do it in one cycle using the dot_product precompile + c = Array(DIM) + for i in unroll(0, DIM): + c[i] = a[i] - b[i] + return c + + +@inline +def copy_5(a, b): + dot_product(a, ONE_VEC_PTR, b, 1, EE) + return + + +@inline +def set_to_5_zeros(a): + zero_ptr = ZERO_VEC_PTR + dot_product(a, ONE_VEC_PTR, zero_ptr, 1, EE) + return + + +@inline +def set_to_7_zeros(a): + zero_ptr = ZERO_VEC_PTR + dot_product(a, ONE_VEC_PTR, zero_ptr, 1, EE) + a[5] = 0 + a[6] = 0 + return + + +@inline +def set_to_16_zeros(a): + zero_ptr = ZERO_VEC_PTR + dot_product(a, ONE_VEC_PTR, zero_ptr, 1, EE) + dot_product(a + 5, ONE_VEC_PTR, zero_ptr, 1, EE) + dot_product(a + 10, ONE_VEC_PTR, zero_ptr, 1, EE) + a[15] = 0 + return + + +@inline +def copy_8(a, b): + dot_product(a, ONE_VEC_PTR, b, 1, EE) + assert a[5] == b[5] + assert a[6] == b[6] + assert a[7] == b[7] + return + + +@inline +def copy_16(a, b): + dot_product(a, ONE_VEC_PTR, b, 1, EE) + dot_product(a + 5, ONE_VEC_PTR, b + 5, 1, EE) + dot_product(a + 10, ONE_VEC_PTR, b + 10, 1, EE) + a[15] = b[15] + return + + +def copy_many_ef(a, b, n): + for i in range(0, n): + dot_product(a + i * DIM, ONE_VEC_PTR, b + i * DIM, 1, EE) + return + + +@inline +def set_to_one(a): + a[0] = 1 + for i in unroll(1, DIM): + a[i] = 0 + return + + +def print_ef(a): + for i in unroll(0, DIM): + print(a[i]) + return + + +def print_vec(a): + for i in unroll(0, VECTOR_LEN): + print(a[i]) + return + + +def print_many(a, n): + for i in range(0, n): + print(a[i]) + return + + +def next_multiple_of_8(a: Const): + return a + (8 - (a % 8)) % 8 + + +@inline +def read_memory(ptr): + mem = 0 + return mem[ptr] + + +def univariate_polynomial_eval(coeffs, point, degree: Const): + powers = powers(point, degree + 1) # TODO use a parameter: Const version + res = Array(DIM) + dot_product(coeffs, powers, res, degree + 1, EE) + return res + + +def sum_2_ef_fractions(a_num, a_den, b_num, b_den): + common_den = mul_extension_ret(a_den, b_den) + a_num_mul_b_den = mul_extension_ret(a_num, b_den) + b_num_mul_a_den = mul_extension_ret(b_num, a_den) + sum_num = add_extension_ret(a_num_mul_b_den, b_num_mul_a_den) + return sum_num, common_den + + +# p = 2^31 - 2^24 + 1 +# in binary: p = 1111111000000000000000000000001 +# p - 1 = 1111111000000000000000000000000 +# p - 2 = 1111110111111111111111111111111 +# p - 3 = 1111110111111111111111111111110 +# ... +# Any field element (< p) is either: +# - 1111111 | 00...00 +# - not(1111111) | xx...xx +def checked_decompose_bits(a, k): + # return a pointer to the 31 bits of a + # .. and the partial value, reading the first K bits (with k <= 24) + bits = Array(F_BITS) + hint_decompose_bits(a, bits, F_BITS, LITTLE_ENDIAN) + + for i in unroll(0, F_BITS): + assert bits[i] * (1 - bits[i]) == 0 + partial_sums_24 = Array(24) + sum_24: Mut = bits[0] + partial_sums_24[0] = sum_24 + for i in unroll(1, 24): + sum_24 += bits[i] * 2**i + partial_sums_24[i] = sum_24 + sum_7: Mut = bits[24] + for i in unroll(1, 7): + sum_7 += bits[24 + i] * 2**i + if sum_7 == 127: + assert sum_24 == 0 + + assert a == sum_24 + sum_7 * 2**24 + partial_sum = partial_sums_24[k - 1] + return bits, partial_sum + + +def checked_decompose_bits_small_value(to_decompose, n_bits): + bits = Array(n_bits) + hint_decompose_bits(to_decompose, bits, n_bits, BIG_ENDIAN) + sum: Mut = bits[n_bits - 1] + power_of_2: Mut = 1 + for i in range(1, n_bits): + power_of_2 *= 2 + sum += bits[n_bits - 1 - i] * power_of_2 + assert to_decompose == sum + return bits + + +@inline +def dot_product_ret(a, b, n, mode): + res = Array(DIM) + dot_product(a, b, res, n, mode) + return res + + +def mle_of_zeros_then_ones(point, n_zeros, n_vars): + if n_vars == 0: + res = Array(DIM) + res[0] = 1 - n_zeros + for i in unroll(1, DIM): + res[i] = 0 + return res + + n_values = powers_of_two(n_vars) + debug_assert(n_zeros <= n_values) + + if n_zeros == n_values: + return ZERO_VEC_PTR + + bits, _ = checked_decompose_bits(n_zeros, 0) + + res: Mut = Array(DIM) + set_to_one(res) + + for i in range(0, n_vars): + p = point + (n_vars - 1 - i) * DIM + if bits[i] == 0: + one_minus_p = one_minus_self_extension_ret(p) + tmp = mul_extension_ret(one_minus_p, res) + res = add_extension_ret(tmp, p) + else: + res = mul_extension_ret(p, res) + return res + + +@inline +def embed_in_ef(f): + res = Array(DIM) + res[0] = f + for i in unroll(1, DIM): + res[i] = 0 + return res + + +def next_mle(x, y, n): + # x and y are pointers to n elements of extension field + + # Build eq_prefix[0..n+1] where eq_prefix[i] = prod_{j=i} (x[j] * (1-y[j])) + low_suffix = Array((n + 1) * DIM) + set_to_one(low_suffix + n * DIM) + for i in range(0, n): + idx = n - 1 - i + xi = x + idx * DIM + yi = y + idx * DIM + one_minus_y = one_minus_self_extension_ret(yi) + x_one_minus_y = mul_extension_ret(xi, one_minus_y) + mul_extension(low_suffix + (idx + 1) * DIM, x_one_minus_y, low_suffix + idx * DIM) + + # Compute sum = Σ_{arr=0..n} (eq_prefix[arr] * (1-x[arr]) * y[arr] * low_suffix[arr+1]) + sum: Mut = ZERO_VEC_PTR + for arr in range(0, n): + x_arr = x + arr * DIM + y_arr = y + arr * DIM + one_minus_x = one_minus_self_extension_ret(x_arr) + carry = mul_extension_ret(one_minus_x, y_arr) + eq_carry = mul_extension_ret(eq_prefix + arr * DIM, carry) + term = mul_extension_ret(eq_carry, low_suffix + (arr + 1) * DIM) + sum = add_extension_ret(sum, term) + + # Compute prod = product of all x[i] * product of all y[i] + prod: Mut = Array(DIM) + set_to_one(prod) + for i in range(0, n): + prod = mul_extension_ret(prod, x + i * DIM) + for i in range(0, n): + prod = mul_extension_ret(prod, y + i * DIM) + + result = add_extension_ret(sum, prod) + return result + + +@inline +def dot_product_with_the_base_vectors(slice): + # slice: pointer to DIM extension field elements + # cf constants.rs: by convention, [10000] [01000] [00100] [00010] [00001] is harcoded in memory, starting at ONE_VEC_PTR + return dot_product_ret(slice, ONE_VEC_PTR, 1, EE) diff --git a/crates/rec_aggregation/whir.py b/crates/rec_aggregation/whir.py new file mode 100644 index 00000000..818299d0 --- /dev/null +++ b/crates/rec_aggregation/whir.py @@ -0,0 +1,473 @@ +from snark_lib import * +from fiat_shamir import * + +N_VARS_BASE = N_VARS_BASE_PLACEHOLDER +LOG_INV_RATE_BASE = LOG_INV_RATE_BASE_PLACEHOLDER +FOLDING_FACTORS_BASE = FOLDING_FACTORS_BASE_PLACEHOLDER +FINAL_VARS_BASE = FINAL_VARS_BASE_PLACEHOLDER +FIRST_RS_REDUCTION_FACTOR_BASE = FIRST_RS_REDUCTION_FACTOR_BASE_PLACEHOLDER +NUM_OOD_COMMIT_BASE = NUM_OOD_COMMIT_BASE_PLACEHOLDER +NUM_OODS_BASE = NUM_OODS_BASE_PLACEHOLDER +GRINDING_BITS_BASE = GRINDING_BITS_BASE_PLACEHOLDER + +N_VARS_EXT = N_VARS_EXT_PLACEHOLDER +LOG_INV_RATE_EXT = LOG_INV_RATE_EXT_PLACEHOLDER +FOLDING_FACTORS_EXT = FOLDING_FACTORS_EXT_PLACEHOLDER +FINAL_VARS_EXT = FINAL_VARS_EXT_PLACEHOLDER +FIRST_RS_REDUCTION_FACTOR_EXT = FIRST_RS_REDUCTION_FACTOR_EXT_PLACEHOLDER +NUM_OOD_COMMIT_EXT = NUM_OOD_COMMIT_EXT_PLACEHOLDER +NUM_OODS_EXT = NUM_OODS_EXT_PLACEHOLDER +GRINDING_BITS_EXT = GRINDING_BITS_EXT_PLACEHOLDER + + +def whir_open_base( + fs: Mut, + root: Mut, + ood_points_commit, + combination_randomness_powers_0, + claimed_sum: Mut, +): + all_folding_randomness = Array(N_ROUNDS_BASE + 2) + all_ood_points = Array(N_ROUNDS_BASE) + all_circle_values = Array(N_ROUNDS_BASE + 1) + all_combination_randomness_powers = Array(N_ROUNDS_BASE) + + domain_sz: Mut = N_VARS_BASE + LOG_INV_RATE_BASE + for r in unroll(0, N_ROUNDS_BASE): + is_first_round: Imu + if r == 0: + is_first_round = 1 + else: + is_first_round = 0 + ( + fs, + all_folding_randomness[r], + all_ood_points[r], + root, + all_circle_values[r], + all_combination_randomness_powers[r], + claimed_sum, + ) = whir_round( + fs, + root, + FOLDING_FACTORS_BASE[r], + 2 ** FOLDING_FACTORS_BASE[r], + is_first_round, + NUM_QUERIES_BASE[r], + domain_sz, + claimed_sum, + GRINDING_BITS_BASE[r], + NUM_OODS_BASE[r], + ) + if r == 0: + domain_sz -= FIRST_RS_REDUCTION_FACTOR_BASE + else: + domain_sz -= 1 + + fs, all_folding_randomness[N_ROUNDS_BASE], claimed_sum = sumcheck_verify( + fs, FOLDING_FACTORS_BASE[N_ROUNDS_BASE], claimed_sum, 2 + ) + + fs, final_coeffcients = fs_receive_ef(fs, 2**FINAL_VARS_BASE) + + fs, all_circle_values[N_ROUNDS_BASE], final_folds = sample_stir_indexes_and_fold( + fs, + NUM_QUERIES_BASE[N_ROUNDS_BASE], + 0, + FOLDING_FACTORS_BASE[N_ROUNDS_BASE], + 2 ** FOLDING_FACTORS_BASE[N_ROUNDS_BASE], + domain_sz, + root, + all_folding_randomness[N_ROUNDS_BASE], + GRINDING_BITS_BASE[N_ROUNDS_BASE], + ) + + final_circle_values = all_circle_values[N_ROUNDS_BASE] + for i in range(0, NUM_QUERIES_BASE[N_ROUNDS_BASE]): + powers_of_2_rev = expand_from_univariate_base_const(final_circle_values[i], FINAL_VARS_BASE) + poly_eq = poly_eq_base(powers_of_2_rev, FINAL_VARS_BASE) + final_pol_evaluated_on_circle = Array(DIM) + dot_product( + poly_eq, + final_coeffcients, + final_pol_evaluated_on_circle, + 2**FINAL_VARS_BASE, + BE, + ) + copy_5(final_pol_evaluated_on_circle, final_folds + i * DIM) + + fs, all_folding_randomness[N_ROUNDS_BASE + 1], end_sum = sumcheck_verify(fs, FINAL_VARS_BASE, claimed_sum, 2) + + folding_randomness_global = Array(N_VARS_BASE * DIM) + + start: Mut = folding_randomness_global + for i in unroll(0, N_ROUNDS_BASE + 1): + for j in unroll(0, FOLDING_FACTORS_BASE[i]): + copy_5(all_folding_randomness[i] + j * DIM, start + j * DIM) + start += FOLDING_FACTORS_BASE[i] * DIM + for j in unroll(0, FINAL_VARS_BASE): + copy_5(all_folding_randomness[N_ROUNDS_BASE + 1] + j * DIM, start + j * DIM) + + all_ood_recovered_evals = Array(NUM_OOD_COMMIT_BASE * DIM) + for i in unroll(0, NUM_OOD_COMMIT_BASE): + expanded_from_univariate = expand_from_univariate_ext(ood_points_commit + i * DIM, N_VARS_BASE) + ood_rec = eq_mle_extension(expanded_from_univariate, folding_randomness_global, N_VARS_BASE) + copy_5(ood_rec, all_ood_recovered_evals + i * DIM) + s: Mut = dot_product_ret( + all_ood_recovered_evals, + combination_randomness_powers_0, + NUM_OOD_COMMIT_BASE, + EE, + ) + + n_vars: Mut = N_VARS_BASE + my_folding_randomness: Mut = folding_randomness_global + for i in unroll(0, N_ROUNDS_BASE): + n_vars -= FOLDING_FACTORS_BASE[i] + my_ood_recovered_evals = Array(NUM_OODS_BASE[i] * DIM) + combination_randomness_powers = all_combination_randomness_powers[i] + my_folding_randomness += FOLDING_FACTORS_BASE[i] * DIM + for j in unroll(0, NUM_OODS_BASE[i]): + expanded_from_univariate = expand_from_univariate_ext(all_ood_points[i] + j * DIM, n_vars) + ood_rec = eq_mle_extension(expanded_from_univariate, my_folding_randomness, n_vars) + copy_5(ood_rec, my_ood_recovered_evals + j * DIM) + summed_ood = Array(DIM) + dot_product_ee_dynamic( + my_ood_recovered_evals, + combination_randomness_powers, + summed_ood, + NUM_OODS_BASE[i], + ) + + s6s = Array((NUM_QUERIES_BASE[i]) * DIM) + circle_value_i = all_circle_values[i] + for j in range(0, NUM_QUERIES_BASE[i]): # unroll ? + expanded_from_univariate = expand_from_univariate_base(circle_value_i[j], n_vars) + temp = eq_mle_base_extension(expanded_from_univariate, my_folding_randomness, n_vars) + copy_5(temp, s6s + j * DIM) + s7 = dot_product_ret( + s6s, + combination_randomness_powers + NUM_OODS_BASE[i] * DIM, + NUM_QUERIES_BASE[i], + EE, + ) + s = add_extension_ret(s, s7) + s = add_extension_ret(summed_ood, s) + poly_eq_final = poly_eq_extension(all_folding_randomness[N_ROUNDS_BASE + 1], FINAL_VARS_BASE) + final_value = dot_product_ret(poly_eq_final, final_coeffcients, 2**FINAL_VARS_BASE, EE) + # copy_5(mul_extension_ret(s, final_value), end_sum); + + fs = duplexing(fs) + + return fs, folding_randomness_global, s, final_value, end_sum + + +def whir_open_ext( + fs: Mut, + root: Mut, + ood_points_commit, + combination_randomness_powers_0, + claimed_sum: Mut, +): + all_folding_randomness = Array(N_ROUNDS_EXT + 2) + all_ood_points = Array(N_ROUNDS_EXT) + all_circle_values = Array(N_ROUNDS_EXT + 1) + all_combination_randomness_powers = Array(N_ROUNDS_EXT) + + domain_sz: Mut = N_VARS_EXT + LOG_INV_RATE_EXT + for r in unroll(0, N_ROUNDS_EXT): + ( + fs, + all_folding_randomness[r], + all_ood_points[r], + root, + all_circle_values[r], + all_combination_randomness_powers[r], + claimed_sum, + ) = whir_round( + fs, + root, + FOLDING_FACTORS_EXT[r], + 2 ** FOLDING_FACTORS_EXT[r], + 0, + NUM_QUERIES_EXT[r], + domain_sz, + claimed_sum, + GRINDING_BITS_EXT[r], + NUM_OODS_EXT[r], + ) + if r == 0: + domain_sz -= FIRST_RS_REDUCTION_FACTOR_EXT + else: + domain_sz -= 1 + + fs, all_folding_randomness[N_ROUNDS_EXT], claimed_sum = sumcheck_verify( + fs, FOLDING_FACTORS_EXT[N_ROUNDS_EXT], claimed_sum, 2 + ) + + fs, final_coeffcients = fs_receive_ef(fs, 2**FINAL_VARS_EXT) + + fs, all_circle_values[N_ROUNDS_EXT], final_folds = sample_stir_indexes_and_fold( + fs, + NUM_QUERIES_EXT[N_ROUNDS_EXT], + 0, + FOLDING_FACTORS_EXT[N_ROUNDS_EXT], + 2 ** FOLDING_FACTORS_EXT[N_ROUNDS_EXT], + domain_sz, + root, + all_folding_randomness[N_ROUNDS_EXT], + GRINDING_BITS_EXT[N_ROUNDS_EXT], + ) + + final_circle_values = all_circle_values[N_ROUNDS_EXT] + for i in range(0, NUM_QUERIES_EXT[N_ROUNDS_EXT]): + powers_of_2_rev = expand_from_univariate_base_const(final_circle_values[i], FINAL_VARS_EXT) + poly_eq = poly_eq_base(powers_of_2_rev, FINAL_VARS_EXT) + final_pol_evaluated_on_circle = Array(DIM) + dot_product( + poly_eq, + final_coeffcients, + final_pol_evaluated_on_circle, + 2**FINAL_VARS_EXT, + BE, + ) + copy_5(final_pol_evaluated_on_circle, final_folds + i * DIM) + + fs, all_folding_randomness[N_ROUNDS_EXT + 1], end_sum = sumcheck_verify(fs, FINAL_VARS_EXT, claimed_sum, 2) + + folding_randomness_global = Array(N_VARS_EXT * DIM) + + start: Mut = folding_randomness_global + for i in unroll(0, N_ROUNDS_EXT + 1): + for j in unroll(0, FOLDING_FACTORS_EXT[i]): + copy_5(all_folding_randomness[i] + j * DIM, start + j * DIM) + start += FOLDING_FACTORS_EXT[i] * DIM + for j in unroll(0, FINAL_VARS_EXT): + copy_5(all_folding_randomness[N_ROUNDS_EXT + 1] + j * DIM, start + j * DIM) + + all_ood_recovered_evals = Array(NUM_OOD_COMMIT_EXT * DIM) + for i in unroll(0, NUM_OOD_COMMIT_EXT): + expanded_from_univariate = expand_from_univariate_ext(ood_points_commit + i * DIM, N_VARS_EXT) + ood_rec = eq_mle_extension(expanded_from_univariate, folding_randomness_global, N_VARS_EXT) + copy_5(ood_rec, all_ood_recovered_evals + i * DIM) + s: Mut = dot_product_ret(all_ood_recovered_evals, combination_randomness_powers_0, NUM_OOD_COMMIT_EXT, EE) + + n_vars: Mut = N_VARS_EXT + my_folding_randomness: Mut = folding_randomness_global + for i in unroll(0, N_ROUNDS_EXT): + n_vars -= FOLDING_FACTORS_EXT[i] + my_ood_recovered_evals = Array(NUM_OODS_EXT[i] * DIM) + combination_randomness_powers = all_combination_randomness_powers[i] + my_folding_randomness += FOLDING_FACTORS_EXT[i] * DIM + for j in unroll(0, NUM_OODS_EXT[i]): + expanded_from_univariate = expand_from_univariate_ext(all_ood_points[i] + j * DIM, n_vars) + ood_rec = eq_mle_extension(expanded_from_univariate, my_folding_randomness, n_vars) + copy_5(ood_rec, my_ood_recovered_evals + j * DIM) + summed_ood = Array(DIM) + dot_product_ee_dynamic( + my_ood_recovered_evals, + combination_randomness_powers, + summed_ood, + NUM_OODS_EXT[i], + ) + + s6s = Array((NUM_QUERIES_EXT[i]) * DIM) + circle_value_i = all_circle_values[i] + for j in range(0, NUM_QUERIES_EXT[i]): # unroll ? + expanded_from_univariate = expand_from_univariate_base(circle_value_i[j], n_vars) + temp = eq_mle_base_extension(expanded_from_univariate, my_folding_randomness, n_vars) + copy_5(temp, s6s + j * DIM) + s7 = dot_product_ret( + s6s, + combination_randomness_powers + NUM_OODS_EXT[i] * DIM, + NUM_QUERIES_EXT[i], + EE, + ) + s = add_extension_ret(s, s7) + s = add_extension_ret(summed_ood, s) + poly_eq_final = poly_eq_extension(all_folding_randomness[N_ROUNDS_EXT + 1], FINAL_VARS_EXT) + final_value = dot_product_ret(poly_eq_final, final_coeffcients, 2**FINAL_VARS_EXT, EE) + # copy_5(mul_extension_ret(s, final_value), end_sum); + + fs = duplexing(fs) + + return fs, folding_randomness_global, s, final_value, end_sum + + +def sumcheck_verify(fs: Mut, n_steps, claimed_sum, degree: Const): + challenges = Array(n_steps * DIM) + fs, new_claimed_sum = sumcheck_verify_helper(fs, n_steps, claimed_sum, degree, challenges) + return fs, challenges, new_claimed_sum + + +def sumcheck_verify_helper(fs: Mut, n_steps, claimed_sum: Mut, degree: Const, challenges): + for sc_round in range(0, n_steps): + fs, poly = fs_receive_ef(fs, degree + 1) + sum_over_boolean_hypercube = polynomial_sum_at_0_and_1(poly, degree) + copy_5(sum_over_boolean_hypercube, claimed_sum) + rand = fs_sample_ef(fs) + claimed_sum = univariate_polynomial_eval(poly, rand, degree) + copy_5(rand, challenges + sc_round * DIM) + + return fs, claimed_sum + + +def sample_stir_indexes_and_fold( + fs: Mut, + num_queries, + merkle_leaves_in_basefield, + folding_factor, + two_pow_folding_factor, + domain_size, + prev_root, + folding_randomness, + grinding_bits, +): + folded_domain_size = domain_size - folding_factor + + fs = fs_grinding(fs, grinding_bits) + fs, stir_challenges_indexes = sample_bits_dynamic(fs, num_queries, folded_domain_size) + + answers = Array( + num_queries + ) # a vector of pointers, each pointing to `two_pow_folding_factor` field elements (base if first rounds, extension otherwise) + + n_chunks_per_answer: Imu + # the number of chunk of 8 field elements per merkle leaf opened + if merkle_leaves_in_basefield == 1: + n_chunks_per_answer = two_pow_folding_factor + else: + n_chunks_per_answer = two_pow_folding_factor * DIM + + for i in range(0, num_queries): + fs, answer = fs_hint(fs, n_chunks_per_answer) + answers[i] = answer + + leaf_hashes = Array(num_queries) # a vector of vectorized pointers, each pointing to 1 chunk of 8 field elements + batch_hash_slice(num_queries, answers, leaf_hashes, n_chunks_per_answer / VECTOR_LEN) + + fs, merkle_paths = fs_hint(fs, folded_domain_size * num_queries * VECTOR_LEN) + + # Merkle verification + merkle_verif_batch( + num_queries, + merkle_paths, + leaf_hashes, + stir_challenges_indexes, + prev_root, + folded_domain_size, + num_queries, + ) + + folds = Array(num_queries * DIM) + + poly_eq = poly_eq_extension_dynamic(folding_randomness, folding_factor) + + if merkle_leaves_in_basefield == 1: + for i in range(0, num_queries): + dot_product(answers[i], poly_eq, folds + i * DIM, 2 ** FOLDING_FACTORS_BASE[0], BE) + else: + for i in range(0, num_queries): + dot_product_ee_dynamic(answers[i], poly_eq, folds + i * DIM, two_pow_folding_factor) + + circle_values = Array(num_queries) # ROOT^each_stir_index + for i in range(0, num_queries): + stir_index_bits = stir_challenges_indexes[i] + circle_value = unit_root_pow_dynamic(folded_domain_size, stir_index_bits) + circle_values[i] = circle_value + + return fs, circle_values, folds + + +def whir_round( + fs: Mut, + prev_root, + folding_factor, + two_pow_folding_factor, + merkle_leaves_in_basefield, + num_queries, + domain_size, + claimed_sum, + grinding_bits, + num_ood, +): + fs, folding_randomness, new_claimed_sum_a = sumcheck_verify(fs, folding_factor, claimed_sum, 2) + + fs, root, ood_points, ood_evals = parse_commitment(fs, num_ood) + + fs, circle_values, folds = sample_stir_indexes_and_fold( + fs, + num_queries, + merkle_leaves_in_basefield, + folding_factor, + two_pow_folding_factor, + domain_size, + prev_root, + folding_randomness, + grinding_bits, + ) + + combination_randomness_gen = fs_sample_ef(fs) + + combination_randomness_powers = powers(combination_randomness_gen, num_queries + num_ood) + + claimed_sum_0 = Array(DIM) + dot_product_ee_dynamic(ood_evals, combination_randomness_powers, claimed_sum_0, num_ood) + + claimed_sum_1 = Array(DIM) + dot_product_ee_dynamic(folds, combination_randomness_powers + num_ood * DIM, claimed_sum_1, num_queries) + + new_claimed_sum_b = add_extension_ret(claimed_sum_0, claimed_sum_1) + + final_sum = add_extension_ret(new_claimed_sum_a, new_claimed_sum_b) + + return ( + fs, + folding_randomness, + ood_points, + root, + circle_values, + combination_randomness_powers, + final_sum, + ) + + +def polynomial_sum_at_0_and_1(coeffs, degree: Const): + debug_assert(1 < degree) + + res = Array(DIM * (1 + degree)) + add_extension(coeffs, coeffs, res) # constant coefficient is doubled + for i in unroll(0, degree): + add_extension(res + i * DIM, coeffs + (i + 1) * DIM, res + (i + 1) * DIM) # TODO use the dot_product precompile + return res + degree * DIM + + +def parse_commitment(fs: Mut, num_ood): + root: Imu + ood_points: Imu + ood_evals: Imu + debug_assert(num_ood < 4) + debug_assert(num_ood != 0) + match num_ood: + case 0: + _ = 0 # unreachable + case 1: + fs, root, ood_points, ood_evals = parse_whir_commitment_const(fs, 1) + case 2: + fs, root, ood_points, ood_evals = parse_whir_commitment_const(fs, 2) + case 3: + fs, root, ood_points, ood_evals = parse_whir_commitment_const(fs, 3) + case 4: + fs, root, ood_points, ood_evals = parse_whir_commitment_const(fs, 4) + return fs, root, ood_points, ood_evals + + +def parse_whir_commitment_const(fs: Mut, num_ood: Const): + fs, root = fs_receive_chunks(fs, 1) + ood_points = Array(num_ood * DIM) + for i in unroll(0, num_ood): + ood_point = fs_sample_ef(fs) + copy_5(ood_point, ood_points + i * DIM) + fs = duplexing(fs) + fs, ood_evals = fs_receive_ef(fs, num_ood) + return fs, root, ood_points, ood_evals diff --git a/crates/rec_aggregation/whir_recursion.snark b/crates/rec_aggregation/whir_recursion.snark deleted file mode 100644 index 84bf963d..00000000 --- a/crates/rec_aggregation/whir_recursion.snark +++ /dev/null @@ -1,1095 +0,0 @@ - -// 1 OOD QUERY PER ROUND -> TODO -// 0 GRINDING IN SUMCHECK -> TODO - -const COMPRESSION = 1; -const PERMUTATION = 0; - -const F_BITS = 31; // koala-bear = 31 bits -const DIM = 5; // extension degree - -const N_VARS = N_VARS_PLACEHOLDER; -const LOG_INV_RATE = LOG_INV_RATE_PLACEHOLDER; -const N_ROUNDS = 3; // TODO make it a parameter - -const FOLDING_FACTOR_0 = FOLDING_FACTOR_0_PLACEHOLDER; -const FOLDING_FACTOR_1 = FOLDING_FACTOR_1_PLACEHOLDER; -const FOLDING_FACTOR_2 = FOLDING_FACTOR_2_PLACEHOLDER; -const FOLDING_FACTOR_3 = FOLDING_FACTOR_3_PLACEHOLDER; - -const FINAL_VARS = N_VARS - (FOLDING_FACTOR_0 + FOLDING_FACTOR_1 + FOLDING_FACTOR_2 + FOLDING_FACTOR_3); - -const RS_REDUCTION_FACTOR_0 = RS_REDUCTION_FACTOR_0_PLACEHOLDER; -const RS_REDUCTION_FACTOR_1 = 1; -const RS_REDUCTION_FACTOR_2 = 1; -const RS_REDUCTION_FACTOR_3 = 1; - -const NUM_QUERIES_0 = NUM_QUERIES_0_PLACEHOLDER; -const NUM_QUERIES_1 = NUM_QUERIES_1_PLACEHOLDER; -const NUM_QUERIES_2 = NUM_QUERIES_2_PLACEHOLDER; -const NUM_QUERIES_3 = NUM_QUERIES_3_PLACEHOLDER; - -const GRINDING_BITS_0 = GRINDING_BITS_0_PLACEHOLDER; -const GRINDING_BITS_1 = GRINDING_BITS_1_PLACEHOLDER; -const GRINDING_BITS_2 = GRINDING_BITS_2_PLACEHOLDER; -const GRINDING_BITS_3 = GRINDING_BITS_3_PLACEHOLDER; - -const MERKLE_HEIGHT_0 = N_VARS + LOG_INV_RATE - FOLDING_FACTOR_0; -const MERKLE_HEIGHT_1 = N_VARS + LOG_INV_RATE - FOLDING_FACTOR_1 - RS_REDUCTION_FACTOR_0; -const MERKLE_HEIGHT_2 = MERKLE_HEIGHT_1 - RS_REDUCTION_FACTOR_1; -const MERKLE_HEIGHT_3 = MERKLE_HEIGHT_2 - RS_REDUCTION_FACTOR_2; - -const TWO_ADICITY = 24; -const ROOT = 1791270792; // of order 2^TWO_ADICITY - -const N_RECURSIONS = N_RECURSIONS_PLACEHOLDER; -const WHIR_PROOF_SIZE = WHIR_PROOF_SIZE_PLACEHOLDER; // vectorized - -fn main() { - for i in 0..N_RECURSIONS unroll { - whir_recursion((public_input_start / 8) + i * WHIR_PROOF_SIZE); - } - return; -} - -fn whir_recursion(transcript_start) { - fs_state = fs_new(transcript_start); - - fs_state_0, root_0, ood_point_0, ood_eval_0 = parse_commitment(fs_state); - - // In the future point / eval will come from the PIOP - point_vector_len = next_multiple_of_8(N_VARS * DIM); - fs_state_1, pcs_point_vec = fs_hint(fs_state_0, point_vector_len / 8); - pcs_point = pcs_point_vec * 8; - fs_state_3, pcs_eval_vec = fs_hint(fs_state_1, 1); - pcs_eval = pcs_eval_vec * 8; - - fs_state_4, combination_randomness_gen_0 = fs_sample_ef(fs_state_3); - - claimed_sum_side = mul_extension_ret(combination_randomness_gen_0, pcs_eval); - claimed_sum_0 = add_extension_ret(ood_eval_0, claimed_sum_side); - domain_size_0 = N_VARS + LOG_INV_RATE; - fs_state_5, folding_randomness_1, ood_point_1, root_1, circle_values_1, combination_randomness_powers_1, claimed_sum_1 = - whir_round(fs_state_4, root_0, FOLDING_FACTOR_0, 2**FOLDING_FACTOR_0, 1, NUM_QUERIES_0, domain_size_0, claimed_sum_0, GRINDING_BITS_0); - - domain_size_1 = domain_size_0 - RS_REDUCTION_FACTOR_0; - fs_state_6, folding_randomness_2, ood_point_2, root_2, circle_values_2, combination_randomness_powers_2, claimed_sum_2 = - whir_round(fs_state_5, root_1, FOLDING_FACTOR_1, 2**FOLDING_FACTOR_1, 0, NUM_QUERIES_1, domain_size_1, claimed_sum_1, GRINDING_BITS_1); - - domain_size_2 = domain_size_1 - RS_REDUCTION_FACTOR_1; - fs_state_7, folding_randomness_3, ood_point_3, root_3, circle_values_3, combination_randomness_powers_3, claimed_sum_3 = - whir_round(fs_state_6, root_2, FOLDING_FACTOR_2, 2**FOLDING_FACTOR_2, 0, NUM_QUERIES_2, domain_size_2, claimed_sum_2, GRINDING_BITS_2); - - domain_size_3 = domain_size_2 - RS_REDUCTION_FACTOR_2; - fs_state_8, folding_randomness_4, final_claimed_sum = sumcheck(fs_state_7, FOLDING_FACTOR_3, claimed_sum_3); - - fs_state_9, final_coeffcients_unpacked = fs_receive(fs_state_8, 2**FINAL_VARS); - final_coeffcients = malloc(2**FINAL_VARS * DIM); - for i in 0..2**FINAL_VARS { - assert_eq_extension((final_coeffcients_unpacked + i) * 8, final_coeffcients + i*DIM); - } - - fs_state_10, final_circle_values, final_folds = - sample_stir_indexes_and_fold(fs_state_9, NUM_QUERIES_3, 0, FOLDING_FACTOR_3, 2**FOLDING_FACTOR_3, domain_size_3, root_3, folding_randomness_4, GRINDING_BITS_3); - - for i in 0..NUM_QUERIES_3 { - powers_of_2_rev = expand_from_univariate_const(final_circle_values[i], FINAL_VARS); - poly_eq = poly_eq_base(powers_of_2_rev, FINAL_VARS); - final_pol_evaluated_on_circle = malloc(DIM); - dot_product_be(poly_eq, final_coeffcients, final_pol_evaluated_on_circle, 2**FINAL_VARS); - assert_eq_extension(final_pol_evaluated_on_circle, final_folds + i*DIM); - } - - fs_state_11, folding_randomness_5, end_sum = sumcheck(fs_state_10, FINAL_VARS, final_claimed_sum); - - folding_randomness_global = malloc(N_VARS * DIM); - - ffs = malloc(N_ROUNDS + 2); - ffs[0] = FOLDING_FACTOR_0; ffs[1] = FOLDING_FACTOR_1; ffs[2] = FOLDING_FACTOR_2; ffs[3] = FOLDING_FACTOR_3; ffs[4] = FINAL_VARS; - frs = malloc(N_ROUNDS + 2); - frs[0] = folding_randomness_1; frs[1] = folding_randomness_2; frs[2] = folding_randomness_3; frs[3] = folding_randomness_4; frs[4] = folding_randomness_5; - ffs_sums = malloc(N_ROUNDS + 3); - ffs_sums[0] = 0; - for i in 0..N_ROUNDS + 2 { - ffs_sums[i + 1] = ffs_sums[i] + ffs[i]; - } - for i in 0..N_ROUNDS + 2 { - start = folding_randomness_global + ffs_sums[i] * DIM; - for j in 0..ffs[i] { - assert_eq_extension(frs[i] + j*DIM, start + j*DIM); - } - } - - ood_0_expanded_from_univariate = expand_from_univariate(ood_point_0, N_VARS); - s0 = eq_mle_extension(ood_0_expanded_from_univariate, folding_randomness_global, N_VARS); - s1 = eq_mle_extension(pcs_point, folding_randomness_global, N_VARS); - s3 = mul_extension_ret(s1, combination_randomness_gen_0); - s4 = add_extension_ret(s0, s3); - - weight_sums = malloc(N_ROUNDS + 1); - weight_sums[0] = s4; - - ood_points = malloc(N_ROUNDS + 1); ood_points[0] = ood_point_0; ood_points[1] = ood_point_1; ood_points[2] = ood_point_2; ood_points[3] = ood_point_3; - num_queries = malloc(N_ROUNDS + 1); num_queries[0] = NUM_QUERIES_0; num_queries[1] = NUM_QUERIES_1; num_queries[2] = NUM_QUERIES_2; num_queries[3] = NUM_QUERIES_3; - circle_values = malloc(N_ROUNDS + 1); circle_values[0] = circle_values_1; circle_values[1] = circle_values_2; circle_values[2] = circle_values_3; circle_values[3] = final_circle_values; - combination_randomness_powers = malloc(N_ROUNDS); combination_randomness_powers[0] = combination_randomness_powers_1; combination_randomness_powers[1] = combination_randomness_powers_2; combination_randomness_powers[2] = combination_randomness_powers_3; - - for i in 0..N_ROUNDS { - ood_expanded_from_univariate = expand_from_univariate(ood_points[i + 1], N_VARS - ffs_sums[i+1]); - s5 = eq_mle_extension(ood_expanded_from_univariate, folding_randomness_global + ffs_sums[i+1]*DIM, N_VARS - ffs_sums[i+1]); - s6s = malloc((num_queries[i] + 1) * DIM); - assert_eq_extension(s5, s6s); - circle_value_i = circle_values[i]; - for j in 0..num_queries[i] { - expanded_from_univariate = expand_from_univariate_dynamic(circle_value_i[j], N_VARS - ffs_sums[i+1]); - temp = eq_mle_extension_base_dynamic(expanded_from_univariate, folding_randomness_global + ffs_sums[i+1]*DIM, N_VARS - ffs_sums[i+1]); - assert_eq_extension(temp, s6s + (j + 1) * DIM); - } - s7 = malloc(DIM); - dot_product_ee_dynamic(s6s, combination_randomness_powers[i], s7, num_queries[i] + 1); - wsum = add_extension_ret(weight_sums[i], s7); - weight_sums[i+1] = wsum; - } - evaluation_of_weights = weight_sums[N_ROUNDS]; // not good - poly_eq_final = poly_eq_extension(folding_randomness_5, FINAL_VARS, 2**FINAL_VARS); - final_value = malloc(DIM); - dot_product_ee(poly_eq_final, final_coeffcients, final_value, 2**FINAL_VARS); - evaluation_of_weights_times_final_value = mul_extension_ret(evaluation_of_weights, final_value); - assert_eq_extension(evaluation_of_weights_times_final_value, end_sum); - return; -} - -fn eq_mle_extension(a, b, n) -> 1 { - - buff = malloc(n*DIM); - - for i in 0..n { - shift = i * DIM; - ai = a + shift; - bi = b + shift; - buffi = buff + shift; - ab = mul_extension_ret(ai, bi); - buffi[0] = 1 + 2 * ab[0] - ai[0] - bi[0]; - for j in 1..DIM unroll { - buffi[j] = 2 * ab[j] - ai[j] - bi[j]; - } - } - - prods = malloc(n*DIM); - assert_eq_extension(buff, prods); - for i in 0..n - 1 { - mul_extension(prods + i*DIM, buff + (i + 1)*DIM, prods + (i + 1)*DIM); - } - - return prods + (n - 1) * DIM; -} - -fn eq_mle_extension_base_dynamic(a, b, n) -> 1 { - res = malloc(DIM); - match n { - 0 => { } // unreachable - 1 => { eq_poly_base_ext(a, b, res, 1); } - 2 => { eq_poly_base_ext(a, b, res, 2); } - 3 => { eq_poly_base_ext(a, b, res, 3); } - 4 => { eq_poly_base_ext(a, b, res, 4); } - 5 => { eq_poly_base_ext(a, b, res, 5); } - 6 => { eq_poly_base_ext(a, b, res, 6); } - 7 => { eq_poly_base_ext(a, b, res, 7); } - 8 => { eq_poly_base_ext(a, b, res, 8); } - 9 => { eq_poly_base_ext(a, b, res, 9); } - 10 => { eq_poly_base_ext(a, b, res, 10); } - 11 => { eq_poly_base_ext(a, b, res, 11); } - 12 => { eq_poly_base_ext(a, b, res, 12); } - 13 => { eq_poly_base_ext(a, b, res, 13); } - 14 => { eq_poly_base_ext(a, b, res, 14); } - 15 => { eq_poly_base_ext(a, b, res, 15); } - 16 => { eq_poly_base_ext(a, b, res, 16); } - 17 => { eq_poly_base_ext(a, b, res, 17); } - 18 => { eq_poly_base_ext(a, b, res, 18); } - 19 => { eq_poly_base_ext(a, b, res, 19); } - 20 => { eq_poly_base_ext(a, b, res, 20); } - 21 => { eq_poly_base_ext(a, b, res, 21); } - 22 => { eq_poly_base_ext(a, b, res, 22); } - } - return res; -} - -fn expand_from_univariate_dynamic(alpha, n) -> 1 { - var res; - match n { - 0 => { } // unreachable - 1 => { res = expand_from_univariate_const(alpha, 1); } - 2 => { res = expand_from_univariate_const(alpha, 2); } - 3 => { res = expand_from_univariate_const(alpha, 3); } - 4 => { res = expand_from_univariate_const(alpha, 4); } - 5 => { res = expand_from_univariate_const(alpha, 5); } - 6 => { res = expand_from_univariate_const(alpha, 6); } - 7 => { res = expand_from_univariate_const(alpha, 7); } - 8 => { res = expand_from_univariate_const(alpha, 8); } - 9 => { res = expand_from_univariate_const(alpha, 9); } - 10 => { res = expand_from_univariate_const(alpha, 10); } - 11 => { res = expand_from_univariate_const(alpha, 11); } - 12 => { res = expand_from_univariate_const(alpha, 12); } - 13 => { res = expand_from_univariate_const(alpha, 13); } - 14 => { res = expand_from_univariate_const(alpha, 14); } - 15 => { res = expand_from_univariate_const(alpha, 15); } - 16 => { res = expand_from_univariate_const(alpha, 16); } - 17 => { res = expand_from_univariate_const(alpha, 17); } - 18 => { res = expand_from_univariate_const(alpha, 18); } - 19 => { res = expand_from_univariate_const(alpha, 19); } - 20 => { res = expand_from_univariate_const(alpha, 20); } - 21 => { res = expand_from_univariate_const(alpha, 21); } - 22 => { res = expand_from_univariate_const(alpha, 22); } - } - return res; -} - -fn expand_from_univariate_const(alpha, const n) -> 1 { - // "expand_from_univariate" - // alpha: F - - res = malloc(n); - res[0] = alpha; - for i in 0..n-1 unroll { - res[i+1] = res[i] * res[i]; - } - return res; -} - -fn expand_from_univariate(alpha, n) -> 1 { - res = malloc(n*DIM); - assert_eq_extension(alpha, res); - for i in 0..n-1 { - mul_extension(res + i*DIM, res + i*DIM, res + (i + 1)*DIM); - } - return res; -} - -fn sumcheck(fs_state, n_steps, claimed_sum) -> 3 { - - fs_states_a = malloc(n_steps + 1); - fs_states_a[0] = fs_state; - - claimed_sums = malloc(n_steps + 1); - claimed_sums[0] = claimed_sum; - - folding_randomness = malloc(n_steps * DIM); - - for sc_round in 0..n_steps { - fs_state_5, poly = fs_receive_ef(fs_states_a[sc_round], 3); - sum_over_boolean_hypercube = degree_two_polynomial_sum_at_0_and_1(poly); - assert_eq_extension(sum_over_boolean_hypercube, claimed_sums[sc_round]); - fs_state_6, rand = fs_sample_ef(fs_state_5); - fs_states_a[sc_round + 1] = fs_state_6; - new_claimed_sum = degree_two_polynomial_eval(poly, rand); - claimed_sums[sc_round + 1] = new_claimed_sum; - assert_eq_extension(rand, folding_randomness + sc_round * DIM); - } - - new_state = fs_states_a[n_steps]; - new_claimed_sum = claimed_sums[n_steps]; - - return new_state, folding_randomness, new_claimed_sum; -} - -fn sample_stir_indexes_and_fold(fs_state, num_queries, merkle_leaves_in_basefield, folding_factor, two_pow_folding_factor, domain_size, prev_root, folding_randomness, grinding_bits) -> 3 { - - folded_domain_size = domain_size - folding_factor; - - fs_state_8 = fs_grinding(fs_state, grinding_bits); - fs_state_9, stir_challenges_indexes = sample_bits_dynamic(fs_state_8, num_queries, folded_domain_size); - - answers = malloc(num_queries); // a vector of vectorized pointers, each pointing to `two_pow_folding_factor` field elements (base if first rounds, extension otherwise) - fs_states_b = malloc(num_queries + 1); - fs_states_b[0] = fs_state_9; - - var n_chunks_per_answer; - // the number of chunk of 8 field elements per merkle leaf opened - if merkle_leaves_in_basefield == 1 { - n_chunks_per_answer = two_pow_folding_factor / 8; // "/ 8" because initial merkle leaves are in the basefield - } else { - n_chunks_per_answer = two_pow_folding_factor * DIM / 8; - } - - for i in 0..num_queries { - new_fs_state, answer = fs_hint(fs_states_b[i], n_chunks_per_answer); - fs_states_b[i + 1] = new_fs_state; - answers[i] = answer; - } - fs_state_10 = fs_states_b[num_queries]; - - leaf_hashes = malloc(num_queries); // a vector of vectorized pointers, each pointing to 1 chunk of 8 field elements - batch_hash_slice_dynamic(num_queries, answers, leaf_hashes, n_chunks_per_answer); - - // Merkle verification - merkle_verif_batch_dynamic(num_queries, leaf_hashes, stir_challenges_indexes + num_queries, prev_root, folded_domain_size); - - folds = malloc(num_queries * DIM); - - poly_eq = poly_eq_extension(folding_randomness, folding_factor, two_pow_folding_factor); - - if merkle_leaves_in_basefield == 1 { - for i in 0..num_queries { - dot_product_be(answers[i] * 8, poly_eq, folds + i*DIM, 2**FOLDING_FACTOR_0); - } - } else { - for i in 0..num_queries { - dot_product_ee_dynamic(answers[i] * 8, poly_eq, folds + i*DIM, two_pow_folding_factor); - } - } - - circle_values = malloc(num_queries); // ROOT^each_stir_index - for i in 0..num_queries { - stir_index_bits = stir_challenges_indexes[i]; - circle_value = unit_root_pow_dynamic(folded_domain_size, stir_index_bits); - circle_values[i] = circle_value; - } - - return fs_state_10, circle_values, folds; -} - -fn batch_hash_slice_dynamic(num_queries, all_data_to_hash, all_resulting_hashes, len) { - if len == DIM * 2 { - batch_hash_slice_const(num_queries, all_data_to_hash, all_resulting_hashes, DIM); - return; - } - if len == 16 { - batch_hash_slice_const(num_queries, all_data_to_hash, all_resulting_hashes, 8); - return; - } - TODO_batch_hash_slice_dynamic = len; - print(77777123); - print(len); - panic(); -} - -fn batch_hash_slice_const(num_queries, all_data_to_hash, all_resulting_hashes, const half_len) { - for i in 0..num_queries { - data = all_data_to_hash[i]; - res = malloc_vec(1); - slice_hash(pointer_to_zero_vector, data, res, half_len); - all_resulting_hashes[i] = res; - } - return; -} - -fn merkle_verif_batch_dynamic(n_paths, leaves_digests, leave_positions, root, height) { - if height == MERKLE_HEIGHT_0 { - merkle_verif_batch_const(n_paths, leaves_digests, leave_positions, root, MERKLE_HEIGHT_0); - return; - } else if height == MERKLE_HEIGHT_1 { - merkle_verif_batch_const(n_paths, leaves_digests, leave_positions, root, MERKLE_HEIGHT_1); - return; - } else if height == MERKLE_HEIGHT_2 { - merkle_verif_batch_const(n_paths, leaves_digests, leave_positions, root, MERKLE_HEIGHT_2); - return; - } else if height == MERKLE_HEIGHT_3 { - merkle_verif_batch_const(n_paths, leaves_digests, leave_positions, root, MERKLE_HEIGHT_3); - return; - } - - print(12345555); - print(height); - panic(); -} - -fn merkle_verif_batch_const(n_paths, leaves_digests, leave_positions, root, const height) { - // n_paths: F - // leaves_digests: pointer to a slice of n_paths vectorized pointers, each pointing to 1 chunk of 8 field elements - // leave_positions: pointer to a slice of n_paths field elements (each < 2^height) - // root: vectorized pointer to 1 chunk of 8 field elements - // height: F - - for i in 0..n_paths { - merkle_verify(leaves_digests[i], leave_positions[i], root, height); - } - - return; -} - - -fn whir_round(fs_state, prev_root, folding_factor, two_pow_folding_factor, merkle_leaves_in_basefield, num_queries, domain_size, claimed_sum, grinding_bits) -> 7 { - fs_state_7, folding_randomness, new_claimed_sum_a = sumcheck(fs_state, folding_factor, claimed_sum); - - fs_state_8, root, ood_point, ood_eval = parse_commitment(fs_state_7); - - fs_state_11, circle_values, folds = - sample_stir_indexes_and_fold(fs_state_8, num_queries, merkle_leaves_in_basefield, folding_factor, two_pow_folding_factor, domain_size, prev_root, folding_randomness, grinding_bits); - - fs_state_12, combination_randomness_gen = fs_sample_ef(fs_state_11); - - combination_randomness_powers = powers(combination_randomness_gen, num_queries + 1); // "+ 1" because of one OOD sample - - claimed_sum_supplement_side = malloc(5); - dot_product_ee_dynamic(folds, combination_randomness_powers + DIM, claimed_sum_supplement_side, num_queries); - - claimed_sum_supplement = add_extension_ret(claimed_sum_supplement_side, ood_eval); - new_claimed_sum_b = add_extension_ret(claimed_sum_supplement, new_claimed_sum_a); - - return fs_state_12, folding_randomness, ood_point, root, circle_values, combination_randomness_powers, new_claimed_sum_b; -} - -fn copy_chunk(src, dst) { - // src: pointer to 8 F - // dst: pointer to 8 F - for i in 0..8 unroll { dst[i] = src[i]; } - return; -} - -fn copy_chunk_vec(src, dst) { - zero = 0; // TODO - add_extension(src, zero, dst); - return; -} - -fn powers(alpha, n) -> 1 { - // alpha: EF - // n: F - - res = malloc(n * DIM); - set_to_one(res); - for i in 0..n - 1 { - mul_extension(res + i*DIM, alpha, res + (i + 1)*DIM); - } - return res; -} - -fn unit_root_pow_dynamic(domain_size, index_bits) -> 1 { - // index_bits is a pointer to domain_size bits - var res; - match domain_size { - 0 => { } // unreachable - 1 => { res = unit_root_pow_const(1, index_bits); } - 2 => { res = unit_root_pow_const(2, index_bits); } - 3 => { res = unit_root_pow_const(3, index_bits); } - 4 => { res = unit_root_pow_const(4, index_bits); } - 5 => { res = unit_root_pow_const(5, index_bits); } - 6 => { res = unit_root_pow_const(6, index_bits); } - 7 => { res = unit_root_pow_const(7, index_bits); } - 8 => { res = unit_root_pow_const(8, index_bits); } - 9 => { res = unit_root_pow_const(9, index_bits); } - 10 => { res = unit_root_pow_const(10, index_bits); } - 11 => { res = unit_root_pow_const(11, index_bits); } - 12 => { res = unit_root_pow_const(12, index_bits); } - 13 => { res = unit_root_pow_const(13, index_bits); } - 14 => { res = unit_root_pow_const(14, index_bits); } - 15 => { res = unit_root_pow_const(15, index_bits); } - 16 => { res = unit_root_pow_const(16, index_bits); } - 17 => { res = unit_root_pow_const(17, index_bits); } - 18 => { res = unit_root_pow_const(18, index_bits); } - 19 => { res = unit_root_pow_const(19, index_bits); } - 20 => { res = unit_root_pow_const(20, index_bits); } - 21 => { res = unit_root_pow_const(21, index_bits); } - 22 => { res = unit_root_pow_const(22, index_bits); } - } - return res; -} - -fn unit_root_pow_const(const domain_size, index_bits) -> 1 { - prods = malloc(domain_size); - prods[0] = ((index_bits[0] * ROOT**(2**(TWO_ADICITY - domain_size))) + (1 - index_bits[0])); - for i in 1..domain_size unroll { - prods[i] = prods[i - 1] * ((index_bits[i] * ROOT**(2**(TWO_ADICITY - domain_size + i))) + (1 - index_bits[i])); - } - return prods[domain_size - 1]; -} - -fn dot_product_ee_dynamic(a, b, res, n) { - if n == 16 { - dot_product_ee(a, b, res, 16); - return; - } else { - if n == NUM_QUERIES_0 { - dot_product_ee(a, b, res, NUM_QUERIES_0); - return; - } else { - dot_product_ee_dynamic_helper_1(a, b, res, n); - return; - } - } - -} - -fn dot_product_ee_dynamic_helper_1(a, b, res, n) { - if n == NUM_QUERIES_1 { - dot_product_ee(a, b, res, NUM_QUERIES_1); - return; - } else { - if n == NUM_QUERIES_2 { - dot_product_ee(a, b, res, NUM_QUERIES_2); - return; - } else { - dot_product_ee_dynamic_helper_3(a, b, res, n); - return; - } - } -} - -fn dot_product_ee_dynamic_helper_3(a, b, res, n) { - if n == NUM_QUERIES_3 { - dot_product_ee(a, b, res, NUM_QUERIES_3); - return; - } else { - if n == NUM_QUERIES_0 + 1 { - dot_product_ee(a, b, res, NUM_QUERIES_0 + 1); - return; - } else { - dot_product_ee_dynamic_helper_4(a, b, res, n); - return; - } - } -} - -fn dot_product_ee_dynamic_helper_4(a, b, res, n) { - if n == NUM_QUERIES_1 + 1 { - dot_product_ee(a, b, res, NUM_QUERIES_1 + 1); - return; - } else { - if n == NUM_QUERIES_2 + 1 { - dot_product_ee(a, b, res, NUM_QUERIES_2 + 1); - return; - } else { - if n == NUM_QUERIES_3 + 1 { - dot_product_ee(a, b, res, NUM_QUERIES_3 + 1); - return; - } - - TODO_dot_product_ee_dynamic = 0; - print(TODO_dot_product_ee_dynamic, n); - panic(); - } - } -} - -fn poly_eq_extension(point, n, two_pow_n) -> 1 { - // Example: for n = 2: eq(x, y) = [(1 - x)(1 - y), (1 - x)y, x(1 - y), xy] - - if n == 0 { - res = malloc(DIM); - set_to_one(res); - return res; - } else { - res = malloc(two_pow_n * DIM); - - inner_res = poly_eq_extension(point + DIM, n - 1, two_pow_n / 2); - - two_pow_n_minus_1 = two_pow_n / 2; - - for i in 0..two_pow_n_minus_1 { - mul_extension(point, inner_res + i*DIM, res + (two_pow_n_minus_1 + i) * DIM); - sub_extension(inner_res + i*DIM, res + (two_pow_n_minus_1 + i) * DIM, res + i*DIM); - } - - return res; - } -} - -fn poly_eq_base(point, n) -> 1 { - var res; - match n { - 0 => { } // unreachable - 1 => { res = poly_eq_base_1(point); } - 2 => { res = poly_eq_base_2(point); } - 3 => { res = poly_eq_base_3(point); } - 4 => { res = poly_eq_base_4(point); } - 5 => { res = poly_eq_base_5(point); } - 6 => { res = poly_eq_base_6(point); } - 7 => { res = poly_eq_base_7(point); } - } - return res; -} - -fn poly_eq_base_7(point) -> 1 { - // n = 7 - // return a (normal) pointer to 2^n base field elements, corresponding to the "equality polynomial" at point - // Example: for n = 2: eq(x, y) = [(1 - x)(1 - y), (1 - x)y, x(1 - y), xy] - - res = malloc(128); - - inner_res = poly_eq_base_6(point + 1); - - for i in 0..64 unroll { - res[64 + i] = inner_res[i] * point[0]; - res[i] = inner_res[i] - res[64 + i]; - } - - return res; -} - -fn poly_eq_base_6(point) -> 1 { - // n = 6 - res = malloc(64); - - inner_res = poly_eq_base_5(point + 1); - - for i in 0..32 unroll { - res[32 + i] = inner_res[i] * point[0]; - res[i] = inner_res[i] - res[32 + i]; - } - - return res; -} - -fn poly_eq_base_5(point) -> 1 { - // n = 5 - res = malloc(32); - - inner_res = poly_eq_base_4(point + 1); - - for i in 0..16 unroll { - res[16 + i] = inner_res[i] * point[0]; - res[i] = inner_res[i] - res[16 + i]; - } - - return res; -} - -fn poly_eq_base_4(point) -> 1 { - // n = 4 - res = malloc(16); - - inner_res = poly_eq_base_3(point + 1); - - for i in 0..8 unroll { - res[8 + i] = inner_res[i] * point[0]; - res[i] = inner_res[i] - res[8 + i]; - } - - return res; -} - -fn poly_eq_base_3(point) -> 1 { - // n = 3 - res = malloc(8); - - inner_res = poly_eq_base_2(point + 1); - - for i in 0..4 unroll { - res[4 + i] = inner_res[i] * point[0]; - res[i] = inner_res[i] - res[4 + i]; - } - - return res; -} - -fn poly_eq_base_2(point) -> 1 { - // n = 2 - res = malloc(4); - - inner_res = poly_eq_base_1(point + 1); - - for i in 0..2 unroll { - res[2 + i] = inner_res[i] * point[0]; - res[i] = inner_res[i] - res[2 + i]; - } - - return res; -} - -fn poly_eq_base_1(point) -> 1 { - // n = 1 - // Base case: eq(x) = [1 - x, x] - res = malloc(2); - - res[1] = point[0]; - res[0] = 1 - res[1]; - - return res; -} - - -fn pow(a, b) -> 1 { - if b == 0 { - return 1; // a^0 = 1 - } else { - p = pow(a, b - 1); - return a * p; - } -} - -fn sample_bits_dynamic(fs_state, n_samples, K) -> 2 { - var new_fs_state; - var sampled_bits; - if n_samples == NUM_QUERIES_0 { - new_fs_state, sampled_bits = sample_bits_const(fs_state, NUM_QUERIES_0, K); - return new_fs_state, sampled_bits; - } else if n_samples == NUM_QUERIES_1 { - new_fs_state, sampled_bits = sample_bits_const(fs_state, NUM_QUERIES_1, K); - return new_fs_state, sampled_bits; - } else if n_samples == NUM_QUERIES_2 { - new_fs_state, sampled_bits = sample_bits_const(fs_state, NUM_QUERIES_2, K); - return new_fs_state, sampled_bits; - } else if n_samples == NUM_QUERIES_3 { - new_fs_state, sampled_bits = sample_bits_const(fs_state, NUM_QUERIES_3, K); - return new_fs_state, sampled_bits; - } - print(n_samples); - print(999333); - panic(); -} - -fn sample_bits_const(fs_state, const n_samples, K) -> 2 { - // return the updated fs_state, and a pointer to n pointers, each pointing to 31 (boolean) field elements, - // ... followed by the n corresponding sampled field elements (where we only look at the first K bits) - samples = malloc(n_samples); - new_fs_state = fs_sample_helper(fs_state, n_samples, samples); - sampled_bits = malloc(n_samples * 2); - for i in 0..n_samples unroll { - bits, partial_sum = checked_decompose_bits(samples[i], K); - sampled_bits[i] = bits; - sampled_bits[n_samples + i] = partial_sum; - } - - return new_fs_state, sampled_bits; -} - -// p = 2^31 - 2^24 + 1 -// in binary: p = 1111111000000000000000000000001 -// p - 1 = 1111111000000000000000000000000 -// p - 2 = 1111110111111111111111111111111 -// p - 3 = 1111110111111111111111111111110 -// ... -// Any field element (< p) is either: -// - 1111111 | 00...00 -// - not(1111111) | xx...xx -fn checked_decompose_bits(a, k) -> 2 { - // return a pointer to the 31 bits of a - // .. and the partial value, reading the first K bits (with k <= 24) - bits = malloc(F_BITS); - hint_decompose_bits(a, bits); - - for i in 0..F_BITS unroll { - assert bits[i] * (1 - bits[i]) == 0; - } - sums_24_first_bits = malloc(24); - sums_24_first_bits[0] = bits[0]; - for i in 1..24 unroll { - sums_24_first_bits[i] = sums_24_first_bits[i - 1] + bits[i] * 2**i; - } - sums_7_last_bits = malloc(7); - sums_7_last_bits[0] = bits[24]; - for i in 1..7 unroll { - sums_7_last_bits[i] = sums_7_last_bits[i - 1] + bits[24 + i] * 2**i; - } - if sums_7_last_bits[6] == 127 { - assert sums_24_first_bits[23] == 0; - } - - assert a == sums_24_first_bits[23] + sums_7_last_bits[6] * 2**24; - partial_sum = sums_24_first_bits[k - 1]; - return bits, partial_sum; -} - -fn degree_two_polynomial_sum_at_0_and_1(coeffs) -> 1 { - // coeffs is a normal pointer to 3 consecutive EF element - // return a normal pointer to 1 ef element - a = add_extension_ret(coeffs, coeffs); - b = add_extension_ret(a, coeffs + DIM); - c = add_extension_ret(b, coeffs + (DIM * 2)); - return c; -} - -fn degree_two_polynomial_eval(coeffs, point) -> 1 { - // coefs: normal pointer to 3 consecutive EF element - // point: normal pointer to 1 EF element - // return a normal pointer to 1 EF element - point_squared = mul_extension_ret(point, point); - a_xx = mul_extension_ret(coeffs + DIM * 2, point_squared); - b_x = mul_extension_ret(coeffs + DIM, point); - c = coeffs; - res_0 = add_extension_ret(a_xx, b_x); - res_1 = add_extension_ret(res_0, c); - return res_1; -} - -fn parse_commitment(fs_state) -> 4 { - fs_state_1, root = fs_receive(fs_state, 1); // vectorized pointer of len 1 - fs_state_2, ood_point = fs_sample_ef(fs_state_1); - fs_state_3, ood_eval = fs_receive_ef(fs_state_2, 1); - return fs_state_3, root, ood_point, ood_eval; -} - -// FIAT SHAMIR layout: -// 0 -> transcript (vectorized pointer) -// 1 -> vectorized pointer to first half of sponge state -// 2 -> vectorized pointer to second half of sponge state -// 3 -> output_buffer_size - -fn fs_new(transcript) -> 1 { - // transcript is a (vectorized) pointer - // TODO domain separator - fs_state = malloc(4); - fs_state[0] = transcript; - fs_state[1] = pointer_to_zero_vector; // first half of sponge state - fs_state[2] = pointer_to_zero_vector; // second half of sponge state - fs_state[3] = 0; // output buffer size - - return fs_state; -} - -fn fs_grinding(fs_state, bits) -> 1 { - // WARNING: should not be called 2 times in a row without duplexing in between - - if bits == 0 { - return fs_state; // no grinding - } - - transcript_ptr = fs_state[0] * 8; - l_ptr = fs_state[1] * 8; - - new_l = malloc_vec(1); - new_l_ptr = new_l * 8; - new_l_ptr[0] = transcript_ptr[0]; - for i in 1..8 unroll { - new_l_ptr[i] = l_ptr[i]; - } - - l_r_updated = malloc_vec(2); - poseidon16(new_l, fs_state[2], l_r_updated, PERMUTATION); - new_fs_state = malloc(4); - new_fs_state[0] = fs_state[0] + 1; // read one 1 chunk of 8 field elements (7 are useless) - new_fs_state[1] = l_r_updated; - new_fs_state[2] = l_r_updated + 1; - new_fs_state[3] = 7; // output_buffer_size - - l_updated_ptr = l_r_updated* 8; - sampled = l_updated_ptr[7]; - _, sampled_low_bits_value = checked_decompose_bits(sampled, bits); - assert sampled_low_bits_value == 0; - return new_fs_state; -} - -fn less_than_8(a) inline -> 1 { - // TODO range check - if a * (a - 1) * (a - 2) * (a - 3) * (a - 4) * (a - 5) * (a - 6) * (a - 7) == 0 { - return 1; // a < 8 - } else { - return 0; // a >= 8 - } -} - -fn fs_sample_ef(fs_state) -> 2 { - // return the updated fs_state, and a normal pointer to 1 EF element - res = malloc(DIM); - new_fs_state = fs_sample_helper(fs_state, DIM, res); - return new_fs_state, res; -} - -fn fs_sample_helper(fs_state, n, res) -> 1 { - // return the updated fs_state - // fill res with n field elements - - output_buffer_size = fs_state[3]; - output_buffer_ptr = fs_state[1] * 8; - - for i in 0..n { - if output_buffer_size - i == 0 { - break; - } - res[i] = output_buffer_ptr[output_buffer_size - 1 - i]; - } - - finished = less_than_8(output_buffer_size - n); - if finished == 1 { - // no duplexing - new_fs_state = malloc(4); - new_fs_state[0] = fs_state[0]; - new_fs_state[1] = fs_state[1]; - new_fs_state[2] = fs_state[2]; - new_fs_state[3] = output_buffer_size - n; - return new_fs_state; - } else { - // duplexing - l_r = malloc_vec(2); - poseidon16(fs_state[1], fs_state[2], l_r, PERMUTATION); - new_fs_state = malloc(4); - new_fs_state[0] = fs_state[0]; - new_fs_state[1] = l_r; - new_fs_state[2] = l_r + 1; - new_fs_state[3] = 8; // output_buffer_size - - remaining = n - output_buffer_size; - if remaining == 0 { - return new_fs_state; - } - - shifted_res = res + output_buffer_size; - final_res = fs_sample_helper(new_fs_state, remaining, shifted_res); - return final_res; - } -} - -fn fs_hint(fs_state, n) -> 2 { - // return the updated fs_state, and a vectorized pointer to n chunk of 8 field elements - - res = fs_state[0]; - new_fs_state = malloc(4); - new_fs_state[0] = res + n; - new_fs_state[1] = fs_state[1]; - new_fs_state[2] = fs_state[2]; - new_fs_state[3] = fs_state[3]; - return new_fs_state, res; -} - -fn fs_receive_ef(fs_state, n) -> 2 { - var final_fs_state; - var res; - match n { - 0 => { } // unreachable - 1 => { final_fs_state, res = fs_receive_ef_const(fs_state, 1); } - 2 => { final_fs_state, res = fs_receive_ef_const(fs_state, 2); } - 3 => { final_fs_state, res = fs_receive_ef_const(fs_state, 3); } - 4 => { final_fs_state, res = fs_receive_ef_const(fs_state, 4); } - 5 => { final_fs_state, res = fs_receive_ef_const(fs_state, 5); } - 6 => { final_fs_state, res = fs_receive_ef_const(fs_state, 6); } - 7 => { final_fs_state, res = fs_receive_ef_const(fs_state, 7); } - 8 => { final_fs_state, res = fs_receive_ef_const(fs_state, 8); } - 9 => { final_fs_state, res = fs_receive_ef_const(fs_state, 9); } - 10 => { final_fs_state, res = fs_receive_ef_const(fs_state, 10); } - } - return final_fs_state, res; -} - - -fn fs_receive_ef_const(fs_state, const n) -> 2 { - // return the updated fs_state, and a (normal) pointer to n consecutive EF elements - final_fs_state = fs_observe(fs_state, n); - res = malloc(n * DIM); - // TODO optimize with dot_product - for i in 0..n unroll { - ptr = (fs_state[0] + i) * 8; - for j in 0..DIM unroll { - res[i * DIM + j] = ptr[j]; - } - for j in DIM..8 unroll { - assert ptr[j] == 0; - } - } - - return final_fs_state, res; -} - -fn fs_receive(fs_state, n) -> 2 { - // return the updated fs_state, and a vectorized pointer to n chunk of 8 field elements - - res = fs_state[0]; - final_fs_state = fs_observe(fs_state, n); - return final_fs_state, res; -} - -fn fs_observe(fs_state, n) -> 1 { - // observe n chunk of 8 field elements from the transcript - // and return the updated fs_state - // duplexing - l_r = malloc_vec(2); - poseidon16(fs_state[0], fs_state[2], l_r, PERMUTATION); - new_fs_state = malloc(4); - new_fs_state[0] = fs_state[0] + 1; - new_fs_state[1] = l_r; - new_fs_state[2] = l_r + 1; - new_fs_state[3] = 8; // output_buffer_size - - if n == 1 { - return new_fs_state; - } else { - final_fs_state = fs_observe(new_fs_state, n - 1); - return final_fs_state; - } -} - -fn fs_print_state(fs_state) { - left = fs_state[1] * 8; - for i in 0..8 { - print(left[i]); - } - right = fs_state[2] * 8; - for i in 0..8 { - print(right[i]); - } - return; -} - -fn mul_extension_ret(a, b) inline -> 1 { - c = malloc(DIM); - dot_product_ee(a, b, c, 1); - return c; -} - -fn mul_extension(a, b, c) inline { - dot_product_ee(a, b, c, 1); - return; -} - -fn add_extension_ret(a, b) inline -> 1 { - // TODO if a and b are adjacent we can do it in one cycle using the dot_product precompile - c = malloc(DIM); - for i in 0..DIM unroll { - c[i] = a[i] + b[i]; - } - return c; -} - -fn add_extension(a, b, c) inline { - // TODO if a and b are adjacent we can do it in one cycle using the dot_product precompile - for i in 0..DIM unroll { - c[i] = a[i] + b[i]; - } - return; -} - -fn sub_extension(a, b, c) inline { - // TODO if a and b are adjacent we can do it in one cycle using the dot_product precompile - for i in 0..DIM unroll { - c[i] = a[i] - b[i]; - } - return; -} - -fn assert_eq_extension(a, b) inline { - dot_product_ee(a, pointer_to_one_vector * 8, b, 1); - return; -} - -// TODO improve -fn assert_eq_vec(a, b) inline { - a_ptr = a * 8; - b_ptr = b * 8; - dot_product_ee(a_ptr, pointer_to_one_vector * 8, b_ptr, 1); - dot_product_ee(a_ptr + (8 - DIM), pointer_to_one_vector * 8, b_ptr + (8 - DIM), 1); - return; -} - -fn set_to_one(a) inline { - a[0] = 1; - for i in 1..DIM unroll { a[i] = 0; } - return; -} - -fn print_vec(a) { - print_many_vec(a, 1); - return; -} - -fn print_ef(a) { - for i in 0..DIM unroll { - print(a[i]); - } - return; -} - -fn print_many_vec(a, n) { - print_many(a * 8, n * 8); - return; -} - -fn print_many(a, n) { - for i in 0..n { - print(a[i]); - } - return; -} - -fn next_multiple_of_8(const a) -> 1 { - return a + (8 - (a % 8)) % 8; -} - diff --git a/crates/rec_aggregation/xmss_aggregate.py b/crates/rec_aggregation/xmss_aggregate.py new file mode 100644 index 00000000..ea6101c5 --- /dev/null +++ b/crates/rec_aggregation/xmss_aggregate.py @@ -0,0 +1,248 @@ +from snark_lib import * + +COMPRESSION = 1 +PERMUTATION = 0 + +V = 66 +W = 4 +TARGET_SUM = 118 +MAX_LOG_LIFETIME = 30 + +V_HALF = V / 2 # V should be even + +VECTOR_LEN = 8 + +# Dot product precompile: +BE = 1 # base-extension +EE = 0 # extension-extension + + +def main(): + NONRESERVED_PROGRAM_INPUT_START_ = NONRESERVED_PROGRAM_INPUT_START + n_signatures = NONRESERVED_PROGRAM_INPUT_START_[0] + message_hash = NONRESERVED_PROGRAM_INPUT_START + 1 + all_public_keys = message_hash + VECTOR_LEN + all_log_lifetimes = all_public_keys + n_signatures * VECTOR_LEN + all_merkle_indexes = all_log_lifetimes + n_signatures + sig_sizes = all_merkle_indexes + n_signatures * MAX_LOG_LIFETIME + + mem = 0 + signatures_start = mem[PRIVATE_INPUT_START_PTR] + + for i in range(0, n_signatures): + xmss_public_key = all_public_keys + i * VECTOR_LEN + signature = signatures_start + sig_sizes[i] + log_lifetime = all_log_lifetimes[i] + merkle_index = all_merkle_indexes + i * MAX_LOG_LIFETIME + xmss_public_key_recovered = xmss_recover_pub_key(message_hash, signature, log_lifetime, merkle_index) + assert_eq_vec(xmss_public_key, xmss_public_key_recovered) + return + + +def xmss_recover_pub_key(message_hash, signature, log_lifetime, merkle_index): + # signature: randomness | chain_tips + # return the hashed xmss public key + randomness = signature + chain_tips = signature + VECTOR_LEN + merkle_path = chain_tips + V * VECTOR_LEN + + # 1) We encode message_hash + randomness into the d-th layer of the hypercube + + compressed = Array(VECTOR_LEN) + poseidon16(message_hash, randomness, compressed, COMPRESSION) + compressed_vals = Array(6) + dot_product(compressed, ONE_VEC_PTR, compressed_vals, 1, EE) + compressed_vals[5] = compressed[5] + + encoding = Array(12 * 6) + remaining = Array(6) + + hint_decompose_bits_xmss( + encoding, + remaining, + compressed_vals[0], + compressed_vals[1], + compressed_vals[2], + compressed_vals[3], + compressed_vals[4], + compressed_vals[5], + ) + + # check that the decomposition is correct + for i in unroll(0, 6): + for j in unroll(0, 12): + assert encoding[i * 12 + j] <= 3 + + assert remaining[i] <= 2**7 - 2 + + partial_sum: Mut = remaining[i] * 2**24 + for j in unroll(1, 13): + partial_sum += encoding[i * 12 + (j - 1)] * 4 ** (j - 1) + assert partial_sum == compressed_vals[i] + + # we need to check the target sum + target_sum: Mut = encoding[0] + for i in unroll(1, V): + target_sum += encoding[i] + assert target_sum == TARGET_SUM + + public_key = Array(V * VECTOR_LEN) + + # This is a trick to avoid the compiler to allocate memory "on stack". + # (Heap allocation is better here, to keep the memmory use of the different "match arms" balanced) + vector_len = VECTOR_LEN + + for i in unroll(0, V): + match encoding[i]: + case 0: + var_1 = chain_tips + i * VECTOR_LEN + var_2 = public_key + i * VECTOR_LEN + var_3 = Array(vector_len) + var_4 = Array(vector_len) + poseidon16(var_1, ZERO_VEC_PTR, var_3, COMPRESSION) + poseidon16(var_3, ZERO_VEC_PTR, var_4, COMPRESSION) + poseidon16(var_4, ZERO_VEC_PTR, var_2, COMPRESSION) + case 1: + var_3 = Array(vector_len) + var_1 = chain_tips + i * VECTOR_LEN + var_2 = public_key + i * VECTOR_LEN + poseidon16(var_1, ZERO_VEC_PTR, var_3, COMPRESSION) + poseidon16(var_3, ZERO_VEC_PTR, var_2, COMPRESSION) + case 2: + var_1 = chain_tips + i * VECTOR_LEN + var_2 = public_key + i * VECTOR_LEN + poseidon16(var_1, ZERO_VEC_PTR, var_2, COMPRESSION) + case 3: + var_1 = chain_tips + (i * VECTOR_LEN) + var_2 = public_key + (i * VECTOR_LEN) + var_3 = var_1 + 3 + var_4 = var_2 + 3 + dot_product(var_1, ONE_VEC_PTR, var_2, 1, EE) + dot_product(var_3, ONE_VEC_PTR, var_4, 1, EE) + + wots_pubkey_hashed = slice_hash(ZERO_VEC_PTR, public_key, V_HALF) + + debug_assert(log_lifetime < MAX_LOG_LIFETIME + 1) + + merkle_root: Imu + match log_lifetime: + case 0: + merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 0) + case 1: + merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 1) + case 2: + merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 2) + case 3: + merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 3) + case 4: + merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 4) + case 5: + merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 5) + case 6: + merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 6) + case 7: + merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 7) + case 8: + merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 8) + case 9: + merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 9) + case 10: + merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 10) + case 11: + merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 11) + case 12: + merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 12) + case 13: + merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 13) + case 14: + merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 14) + case 15: + merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 15) + case 16: + merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 16) + case 17: + merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 17) + case 18: + merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 18) + case 19: + merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 19) + case 20: + merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 20) + case 21: + merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 21) + case 22: + merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 22) + case 23: + merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 23) + case 24: + merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 24) + case 25: + merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 25) + case 26: + merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 26) + case 27: + merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 27) + case 28: + merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 28) + case 29: + merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 29) + case 30: + merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 30) + case 31: + merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 31) + case 32: + merkle_root = merkle_verify(wots_pubkey_hashed, merkle_path, merkle_index, 32) + + return merkle_root + + +def merkle_verify(leaf_digest, merkle_path, leaf_position_bits, height: Const): + states = Array(height * VECTOR_LEN) + + # First merkle round + match leaf_position_bits[0]: + case 0: + poseidon16(leaf_digest, merkle_path, states, COMPRESSION) + case 1: + poseidon16(merkle_path, leaf_digest, states, COMPRESSION) + + # Remaining merkle rounds + state_indexes = Array(height) + state_indexes[0] = states + for j in unroll(1, height): + state_indexes[j] = state_indexes[j - 1] + VECTOR_LEN + # Warning: this works only if leaf_position_bits[i] is known to be boolean: + match leaf_position_bits[j]: + case 0: + poseidon16( + state_indexes[j - 1], + merkle_path + j * VECTOR_LEN, + state_indexes[j], + COMPRESSION, + ) + case 1: + poseidon16( + merkle_path + j * VECTOR_LEN, + state_indexes[j - 1], + state_indexes[j], + COMPRESSION, + ) + return state_indexes[height - 1] + + +def slice_hash(seed, data, half_len: Const): + states = Array(half_len * 2 * VECTOR_LEN) + poseidon16(ZERO_VEC_PTR, data, states, COMPRESSION) + state_indexes = Array(half_len * 2) + state_indexes[0] = states + for j in unroll(1, (half_len * 2)): + state_indexes[j] = state_indexes[j - 1] + VECTOR_LEN + poseidon16(state_indexes[j - 1], data + j * VECTOR_LEN, state_indexes[j], COMPRESSION) + return state_indexes[half_len * 2 - 1] + + +@inline +def assert_eq_vec(x, y): + dot_product(x, ONE_VEC_PTR, y, 1, EE) + dot_product(x + 3, ONE_VEC_PTR, y + 3, 1, EE) + return diff --git a/crates/rec_aggregation/xmss_aggregate.snark b/crates/rec_aggregation/xmss_aggregate.snark deleted file mode 100644 index 0eb8f47d..00000000 --- a/crates/rec_aggregation/xmss_aggregate.snark +++ /dev/null @@ -1,182 +0,0 @@ -const COMPRESSION = 1; -const PERMUTATION = 0; - -const V = 66; -const W = 4; -const TARGET_SUM = 118; - -const V_HALF = V / 2; // V should be even - -fn main() { - public_input_start_ = public_input_start; - n_signatures = public_input_start_[0]; - sig_size = public_input_start_[1]; // vectorized - message_hash = public_input_start / 8 + 1; - all_public_keys = message_hash + 1; - all_log_lifetimes = (all_public_keys + n_signatures) * 8; - all_merkle_indexes = all_log_lifetimes + n_signatures; - - signatures_start_no_vec = private_input_start(); - signatures_start = signatures_start_no_vec / 8; - for i in 0..n_signatures { - xmss_public_key = all_public_keys + i; - signature = signatures_start + i * sig_size; - log_lifetime = all_log_lifetimes[i]; - merkle_index = all_merkle_indexes[i]; - xmss_public_key_recovered = xmss_recover_pub_key(message_hash, signature, log_lifetime, merkle_index); - assert_eq_vec(xmss_public_key, xmss_public_key_recovered); - } - return; -} - -fn xmss_recover_pub_key(message_hash, signature, log_lifetime, merkle_index) -> 1 { - // message_hash: vectorized pointers (of length 1) - // signature: vectorized pointer = randomness | chain_tips - // return a vectorized pointer (of length 1), the hashed xmss public key - randomness = signature; // vectorized - chain_tips = signature + 1; // vectorized - - // 1) We encode message_hash + randomness into the d-th layer of the hypercube - - compressed = malloc_vec(1); - poseidon16(message_hash, randomness, compressed, COMPRESSION); - compressed_ptr = compressed * 8; - compressed_vals = malloc(6); - dot_product_ee(compressed_ptr, pointer_to_one_vector * 8, compressed_vals, 1); - compressed_vals[5] = compressed_ptr[5]; - - encoding = malloc(12 * 6); - remaining = malloc(6); - - hint_decompose_bits_xmss(encoding, remaining, compressed_vals[0], compressed_vals[1], compressed_vals[2], compressed_vals[3], compressed_vals[4], compressed_vals[5]); - - // check that the decomposition is correct - for i in 0..6 unroll { - for j in 0..12 unroll { - // TODO Implem range check (https://github.com/leanEthereum/leanMultisig/issues/52) - // For now we use dummy instructions to replicate exactly the cost - - // assert encoding[i * 12 + j] < 4; - dummy_0 = 88888888; - assert dummy_0 == 88888888; - assert dummy_0 == 88888888; - assert dummy_0 == 88888888; - } - - // assert remaining[i] < 2^7 - 1; - dummy_1 = 88888888; - dummy_2 = 88888888; - dummy_3 = 88888888; - assert dummy_1 == 88888888; - assert dummy_2 == 88888888; - assert dummy_3 == 88888888; - - partial_sums = malloc(13); - partial_sums[0] = remaining[i] * 2**24; - for j in 1..13 unroll { - partial_sums[j] = partial_sums[j - 1] + encoding[i * 12 + (j-1)] * 4**(j-1); - } - assert partial_sums[12] == compressed_vals[i]; - } - - // we need to check the target sum - sums = malloc(V); - sums[0] = encoding[0]; - for i in 1..V unroll { - sums[i] = sums[i - 1] + encoding[i]; - } - assert sums[V - 1] == TARGET_SUM; - - public_key = malloc_vec(V); - - chain_tips_ptr = 8 * chain_tips; - public_key_ptr = 8 * public_key; - - for i in 0..V unroll { - match encoding[i] { - 0 => { - var_1 = chain_tips + i; - var_2 = public_key + i; - var_3 = malloc_vec(1); - var_4 = malloc_vec(1); - poseidon16(var_1, pointer_to_zero_vector, var_3, COMPRESSION); - poseidon16(var_3, pointer_to_zero_vector, var_4, COMPRESSION); - poseidon16(var_4, pointer_to_zero_vector, var_2, COMPRESSION); - } - 1 => { - var_3 = malloc_vec(1); - var_1 = chain_tips + i; - var_2 = public_key + i; - poseidon16(var_1, pointer_to_zero_vector, var_3, COMPRESSION); - poseidon16(var_3, pointer_to_zero_vector, var_2, COMPRESSION); - } - 2 => { - var_1 = chain_tips + i; - var_2 = public_key + i; - poseidon16(var_1, pointer_to_zero_vector, var_2, COMPRESSION); - } - 3 => { - var_1 = chain_tips_ptr + (i * 8); - var_2 = public_key_ptr + (i * 8); - var_3 = var_1 + 3; - var_4 = var_2 + 3; - dot_product_ee(var_1, pointer_to_one_vector * 8, var_2, 1); - dot_product_ee(var_3, pointer_to_one_vector * 8, var_4, 1); - } - } - } - - wots_pubkey_hashed = malloc_vec(1); - slice_hash(pointer_to_zero_vector, public_key, wots_pubkey_hashed, V_HALF); - - debug_assert log_lifetime < 33; - - merkle_root = malloc_vec(1); - match log_lifetime { - 0 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 0); } - 1 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 1); } - 2 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 2); } - 3 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 3); } - 4 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 4); } - 5 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 5); } - 6 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 6); } - 7 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 7); } - 8 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 8); } - 9 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 9); } - 10 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 10); } - 11 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 11); } - 12 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 12); } - 13 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 13); } - 14 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 14); } - 15 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 15); } - 16 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 16); } - 17 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 17); } - 18 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 18); } - 19 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 19); } - 20 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 20); } - 21 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 21); } - 22 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 22); } - 23 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 23); } - 24 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 24); } - 25 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 25); } - 26 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 26); } - 27 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 27); } - 28 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 28); } - 29 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 29); } - 30 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 30); } - 31 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 31); } - 32 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 32); } - } - - - return merkle_root; -} - -fn assert_eq_vec(x, y) inline { - // x and y are vectorized pointer of len 1 each - ptr_x = x * 8; - ptr_y = y * 8; - dot_product_ee(ptr_x, pointer_to_one_vector * 8, ptr_y, 1); - dot_product_ee(ptr_x + 3, pointer_to_one_vector * 8, ptr_y + 3, 1); - return; -} diff --git a/crates/sub_protocols/Cargo.toml b/crates/sub_protocols/Cargo.toml index 2f51c4ce..51d8314d 100644 --- a/crates/sub_protocols/Cargo.toml +++ b/crates/sub_protocols/Cargo.toml @@ -10,10 +10,9 @@ workspace = true tracing.workspace = true utils.workspace = true whir-p3.workspace = true -derive_more.workspace = true p3-util.workspace = true +lean_vm.workspace = true multilinear-toolkit.workspace = true -lookup.workspace = true [dev-dependencies] p3-koala-bear.workspace = true diff --git a/crates/sub_protocols/src/commit_extension_from_base.rs b/crates/sub_protocols/src/commit_extension_from_base.rs deleted file mode 100644 index cc677fc8..00000000 --- a/crates/sub_protocols/src/commit_extension_from_base.rs +++ /dev/null @@ -1,81 +0,0 @@ -use crate::ColDims; -use multilinear_toolkit::prelude::*; -use utils::dot_product_with_base; -use utils::transpose_slice_to_basis_coefficients; - -/// Commit extension field columns with a PCS allowing to commit in the base field - -#[derive(Debug)] -pub struct ExtensionCommitmentFromBaseProver>> { - pub sub_columns_to_commit: Vec>>, -} - -pub fn committed_dims_extension_from_base>>( - non_zero_height: usize, - default_values: Vec, -) -> Vec>> { - default_values - .into_iter() - .flat_map(|default_value| { - EF::as_basis_coefficients_slice(&default_value) - .iter() - .map(|&d| ColDims::padded(non_zero_height, d)) - .collect::>() - }) - .collect() -} - -impl>> ExtensionCommitmentFromBaseProver { - pub fn before_commitment(extension_columns: Vec<&[EF]>) -> Self { - let mut sub_columns_to_commit = Vec::new(); - for extension_column in extension_columns { - sub_columns_to_commit.extend(transpose_slice_to_basis_coefficients::, EF>(extension_column)); - } - Self { sub_columns_to_commit } - } - - pub fn after_commitment( - &self, - prover_state: &mut FSProver>, - evaluation_point: &MultilinearPoint, - ) -> Vec>> { - let sub_evals = self - .sub_columns_to_commit - .par_iter() - .map(|sub_column| sub_column.evaluate(evaluation_point)) - .collect::>(); - - prover_state.add_extension_scalars(&sub_evals); - - sub_evals - .iter() - .map(|sub_value| vec![Evaluation::new(evaluation_point.clone(), *sub_value)]) - .collect::>() - } -} - -#[derive(Debug)] -pub struct ExtensionCommitmentFromBaseVerifier {} - -impl ExtensionCommitmentFromBaseVerifier { - pub fn after_commitment>>( - verifier_state: &mut FSVerifier>, - claim: &MultiEvaluation, - ) -> ProofResult>>> { - let sub_evals = verifier_state.next_extension_scalars_vec(EF::DIMENSION * claim.num_values())?; - - let mut statements_remaning_to_verify = Vec::new(); - for (chunk, claim_value) in sub_evals.chunks_exact(EF::DIMENSION).zip(&claim.values) { - if dot_product_with_base(chunk) != *claim_value { - return Err(ProofError::InvalidProof); - } - statements_remaning_to_verify.extend( - chunk - .iter() - .map(|&sub_value| vec![Evaluation::new(claim.point.clone(), sub_value)]), - ); - } - - Ok(statements_remaning_to_verify) - } -} diff --git a/crates/sub_protocols/src/generic_logup.rs b/crates/sub_protocols/src/generic_logup.rs new file mode 100644 index 00000000..28d471e4 --- /dev/null +++ b/crates/sub_protocols/src/generic_logup.rs @@ -0,0 +1,420 @@ +use crate::{prove_gkr_quotient, verify_gkr_quotient}; +use lean_vm::BusDirection; +use lean_vm::BusTable; +use lean_vm::ColIndex; +use lean_vm::DIMENSION; +use lean_vm::EF; +use lean_vm::F; +use lean_vm::Table; +use lean_vm::TableT; +use lean_vm::TableTrace; +use lean_vm::max_bus_width; +use lean_vm::sort_tables_by_height; +use multilinear_toolkit::prelude::*; +use std::collections::BTreeMap; +use utils::MEMORY_TABLE_INDEX; +use utils::VarCount; +use utils::VecOrSlice; +use utils::finger_print; +use utils::from_end; +use utils::mle_of_01234567_etc; +use utils::to_big_endian_in_field; +use utils::transpose_slice_to_basis_coefficients; + +#[derive(Debug, PartialEq, Hash, Clone)] +pub struct GenericLogupStatements { + pub memory_acc_point: MultilinearPoint, + pub value_memory: EF, + pub value_acc: EF, + pub bus_numerators_values: BTreeMap, + pub bus_denominators_values: BTreeMap, + pub points: BTreeMap>, + pub columns_values: BTreeMap>, + // Used in recursion + pub total_n_vars: usize, +} + +#[allow(clippy::too_many_arguments)] +pub fn prove_generic_logup( + prover_state: &mut impl FSProver, + c: EF, + alpha: EF, + memory: &[F], + acc: &[F], + traces: &BTreeMap, +) -> GenericLogupStatements { + assert!(memory[0].is_zero()); + assert!(memory.len().is_power_of_two()); + assert_eq!(memory.len(), acc.len()); + assert!(memory.len() >= traces.values().map(|t| 1 << t.log_n_rows).max().unwrap()); + + let tables_heights = traces.iter().map(|(table, trace)| (*table, trace.log_n_rows)).collect(); + let tables_heights_sorted = sort_tables_by_height(&tables_heights); + + let total_n_vars = compute_total_n_vars( + log2_strict_usize(memory.len()), + &tables_heights_sorted.iter().cloned().collect(), + ); + let mut numerators = EF::zero_vec(1 << total_n_vars); + let mut denominators = EF::zero_vec(1 << total_n_vars); + + let alpha_powers = alpha.powers().collect_n(max_bus_width()); + + // Memory: ... + numerators[..memory.len()] + .par_iter_mut() + .zip(acc) // TODO embedding overhead + .for_each(|(num, a)| *num = EF::from(-*a)); // Note the negative sign here + denominators[..memory.len()] + .par_iter_mut() + .zip(memory.par_iter().enumerate()) + .for_each(|(denom, (i, &mem_value))| { + *denom = c - finger_print( + F::from_usize(MEMORY_TABLE_INDEX), + &[F::from_usize(i), mem_value], + &alpha_powers, + ) + }); + + // ... Rest of the tables: + let mut offset = memory.len(); + for (table, _) in &tables_heights_sorted { + let trace = &traces[table]; + let log_n_rows = trace.log_n_rows; + + // I] Bus (data flow between tables) + + let bus = table.bus(); + numerators[offset..][..1 << log_n_rows] + .par_iter_mut() + .zip(&trace.base[bus.selector]) + .for_each(|(num, selector)| { + *num = EF::from(match bus.direction { + BusDirection::Pull => -*selector, + BusDirection::Push => *selector, + }) + }); // TODO embedding overhead + denominators[offset..][..1 << log_n_rows] + .par_iter_mut() + .enumerate() + .for_each(|(i, denom)| { + *denom = { + c + finger_print( + match &bus.table { + BusTable::Constant(table) => table.embed(), + BusTable::Variable(col) => trace.base[*col][i], + }, + bus.data + .iter() + .map(|col| trace.base[*col][i]) + .collect::>() + .as_slice(), + &alpha_powers, + ) + } + }); + + offset += 1 << log_n_rows; + + // II] Lookup into memory + + let mut value_columns_f = Vec::>::new(); + for cols_f in table.lookup_f_value_columns(trace) { + value_columns_f.push(cols_f.iter().map(|s| VecOrSlice::Slice(s)).collect()); + } + let mut value_columns_ef = Vec::>::new(); + for col_ef in table.lookup_ef_value_columns(trace) { + value_columns_ef.push( + transpose_slice_to_basis_coefficients::, EF>(col_ef) + .into_iter() + .map(VecOrSlice::Vec) + .collect(), + ); + } + for (index_columns, value_columns) in [ + (table.lookup_index_columns_f(trace), &value_columns_f), + (table.lookup_index_columns_ef(trace), &value_columns_ef), + ] { + for (col_index, col_values) in index_columns.iter().zip(value_columns) { + numerators[offset..][..col_values.len() << log_n_rows] + .par_iter_mut() + .for_each(|num| { + *num = EF::ONE; + }); // TODO embedding overhead + denominators[offset..][..col_values.len() << log_n_rows] + .par_chunks_exact_mut(1 << log_n_rows) + .enumerate() + .for_each(|(i, denom_chunk)| { + let i_field = F::from_usize(i); + denom_chunk.par_iter_mut().enumerate().for_each(|(j, denom)| { + let index = col_index[j] + i_field; + let mem_value = col_values[i].as_slice()[j]; + *denom = + c - finger_print(F::from_usize(MEMORY_TABLE_INDEX), &[index, mem_value], &alpha_powers) + }); + }); + offset += col_values.len() << log_n_rows; + } + } + } + + assert_eq!(log2_ceil_usize(offset), total_n_vars); + tracing::info!("Logup data: {} = 2^{:.2}", offset, (offset as f64).log2()); + + denominators[offset..].par_iter_mut().for_each(|d| *d = EF::ONE); // padding + + // TODO pack directly + let numerators_packed = MleRef::Extension(&numerators).pack(); + let denominators_packed = MleRef::Extension(&denominators).pack(); + + let (sum, claim_point_gkr, numerators_value, denominators_value) = + prove_gkr_quotient(prover_state, &numerators_packed.by_ref(), &denominators_packed.by_ref()); + + let _ = (numerators_value, denominators_value); // TODO use it to avoid some computation below + + // sanity check + assert_eq!(sum, EF::ZERO); + + // Memory: ... + let memory_acc_point = MultilinearPoint(from_end(&claim_point_gkr, log2_strict_usize(memory.len())).to_vec()); + let value_acc = acc.evaluate(&memory_acc_point); + prover_state.add_extension_scalar(value_acc); + + let value_memory = memory.evaluate(&memory_acc_point); + prover_state.add_extension_scalar(value_memory); + + // ... Rest of the tables: + let mut points = BTreeMap::new(); + let mut bus_numerators_values = BTreeMap::new(); + let mut bus_denominators_values = BTreeMap::new(); + let mut columns_values = BTreeMap::new(); + let mut offset = memory.len(); + for (table, _) in &tables_heights_sorted { + let trace = &traces[table]; + let log_n_rows = trace.log_n_rows; + + let inner_point = MultilinearPoint(from_end(&claim_point_gkr, log_n_rows).to_vec()); + points.insert(*table, inner_point.clone()); + + // I] Bus (data flow between tables) + + let eval_on_selector = + trace.base[table.bus().selector].evaluate(&inner_point) * table.bus().direction.to_field_flag(); + prover_state.add_extension_scalar(eval_on_selector); + + let eval_on_data = (&denominators[offset..][..1 << log_n_rows]).evaluate(&inner_point); + prover_state.add_extension_scalar(eval_on_data); + + bus_numerators_values.insert(*table, eval_on_selector); + bus_denominators_values.insert(*table, eval_on_data); + + // II] Lookup into memory + + let mut table_values = BTreeMap::::new(); + for lookup_f in table.lookups_f() { + let index_eval = trace.base[lookup_f.index].evaluate(&inner_point); + prover_state.add_extension_scalar(index_eval); + assert!(!table_values.contains_key(&lookup_f.index)); + table_values.insert(lookup_f.index, index_eval); + + for col_index in &lookup_f.values { + let value_eval = trace.base[*col_index].evaluate(&inner_point); + prover_state.add_extension_scalar(value_eval); + assert!(!table_values.contains_key(col_index)); + table_values.insert(*col_index, value_eval); + } + } + + for lookup_ef in table.lookups_ef() { + let index_eval = trace.base[lookup_ef.index].evaluate(&inner_point); + prover_state.add_extension_scalar(index_eval); + assert_eq!(table_values.get(&lookup_ef.index).unwrap_or(&index_eval), &index_eval); + table_values.insert(lookup_ef.index, index_eval); + + let col_ef = &trace.ext[lookup_ef.values]; + + for (i, col) in transpose_slice_to_basis_coefficients::, EF>(col_ef) + .iter() + .enumerate() + { + let value_eval = col.evaluate(&inner_point); + prover_state.add_extension_scalar(value_eval); + let global_index = table.n_commited_columns_f() + lookup_ef.values * DIMENSION + i; + assert!(!table_values.contains_key(&global_index)); + table_values.insert(global_index, value_eval); + } + } + points.insert(*table, inner_point); + columns_values.insert(*table, table_values); + + offset += offset_for_table(table, log_n_rows); + } + + GenericLogupStatements { + memory_acc_point, + value_memory, + value_acc, + bus_numerators_values, + bus_denominators_values, + points, + columns_values, + total_n_vars, + } +} + +#[allow(clippy::too_many_arguments)] +pub fn verify_generic_logup( + verifier_state: &mut impl FSVerifier, + c: EF, + alpha: EF, + log_memory: usize, + table_log_n_rows: &BTreeMap, +) -> ProofResult { + let tables_heights_sorted = sort_tables_by_height(table_log_n_rows); + + let total_n_vars = compute_total_n_vars(log_memory, &tables_heights_sorted.iter().cloned().collect()); + + let (sum, point_gkr, numerators_value, denominators_value) = verify_gkr_quotient(verifier_state, total_n_vars)?; + + if sum != EF::ZERO { + return Err(ProofError::InvalidProof); + } + + let alpha_powers = alpha.powers().collect_n(max_bus_width()); + + let mut retrieved_numerators_value = EF::ZERO; + let mut retrieved_denominators_value = EF::ZERO; + + // Memory ... + let memory_acc_point = MultilinearPoint(from_end(&point_gkr, log_memory).to_vec()); + let bits = to_big_endian_in_field::(0, total_n_vars - log_memory); + let pref = + MultilinearPoint(bits).eq_poly_outside(&MultilinearPoint(point_gkr[..total_n_vars - log_memory].to_vec())); + + let value_acc = verifier_state.next_extension_scalar()?; + retrieved_numerators_value -= pref * value_acc; + + let value_memory = verifier_state.next_extension_scalar()?; + let value_index = mle_of_01234567_etc(&memory_acc_point); + retrieved_denominators_value += pref + * (c - finger_print( + F::from_usize(MEMORY_TABLE_INDEX), + &[value_index, value_memory], + &alpha_powers, + )); + + // ... Rest of the tables: + let mut points = BTreeMap::new(); + let mut bus_numerators_values = BTreeMap::new(); + let mut bus_denominators_values = BTreeMap::new(); + let mut columns_values = BTreeMap::new(); + let mut offset = 1 << log_memory; + for &(table, log_n_rows) in &tables_heights_sorted { + let n_missing_vars = total_n_vars - log_n_rows; + let inner_point = MultilinearPoint(from_end(&point_gkr, log_n_rows).to_vec()); + let missing_point = MultilinearPoint(point_gkr[..n_missing_vars].to_vec()); + + points.insert(table, inner_point.clone()); + + // I] Bus (data flow between tables) + + let eval_on_selector = verifier_state.next_extension_scalar()?; + + let bits = to_big_endian_in_field::(offset >> log_n_rows, n_missing_vars); + let pref = MultilinearPoint(bits).eq_poly_outside(&missing_point); + retrieved_numerators_value += pref * eval_on_selector; + + let eval_on_data = verifier_state.next_extension_scalar()?; + retrieved_denominators_value += pref * eval_on_data; + + bus_numerators_values.insert(table, eval_on_selector); + bus_denominators_values.insert(table, eval_on_data); + + offset += 1 << log_n_rows; + + // II] Lookup into memory + + let mut table_values = BTreeMap::::new(); + for lookup_f in table.lookups_f() { + let index_eval = verifier_state.next_extension_scalar()?; + assert!(!table_values.contains_key(&lookup_f.index)); + table_values.insert(lookup_f.index, index_eval); + + for (i, col_index) in lookup_f.values.iter().enumerate() { + let value_eval = verifier_state.next_extension_scalar()?; + assert!(!table_values.contains_key(col_index)); + table_values.insert(*col_index, value_eval); + + let bits = to_big_endian_in_field::(offset >> log_n_rows, n_missing_vars); + let pref = MultilinearPoint(bits).eq_poly_outside(&missing_point); + retrieved_numerators_value += pref; + retrieved_denominators_value += pref + * (c - finger_print( + F::from_usize(MEMORY_TABLE_INDEX), + &[index_eval + F::from_usize(i), value_eval], + &alpha_powers, + )); + offset += 1 << log_n_rows; + } + } + + for lookup_ef in table.lookups_ef() { + let index_eval = verifier_state.next_extension_scalar()?; + assert_eq!(table_values.get(&lookup_ef.index).unwrap_or(&index_eval), &index_eval); + table_values.insert(lookup_ef.index, index_eval); + + for i in 0..DIMENSION { + let value_eval = verifier_state.next_extension_scalar()?; + + let bits = to_big_endian_in_field::(offset >> log_n_rows, n_missing_vars); + let pref = MultilinearPoint(bits).eq_poly_outside(&missing_point); + retrieved_numerators_value += pref; + retrieved_denominators_value += pref + * (c - finger_print( + F::from_usize(MEMORY_TABLE_INDEX), + &[index_eval + F::from_usize(i), value_eval], + &alpha_powers, + )); + let global_index = table.n_commited_columns_f() + lookup_ef.values * DIMENSION + i; + assert!(!table_values.contains_key(&global_index)); + table_values.insert(global_index, value_eval); + offset += 1 << log_n_rows; + } + } + columns_values.insert(table, table_values); + } + + retrieved_denominators_value += mle_of_zeros_then_ones(offset, &point_gkr); // to compensate for the final padding: XYZ111111...1 + if retrieved_numerators_value != numerators_value { + return Err(ProofError::InvalidProof); + } + if retrieved_denominators_value != denominators_value { + return Err(ProofError::InvalidProof); + } + + Ok(GenericLogupStatements { + memory_acc_point, + value_memory, + value_acc, + bus_numerators_values, + bus_denominators_values, + points, + columns_values, + total_n_vars, + }) +} + +fn offset_for_table(table: &Table, log_n_rows: usize) -> usize { + let num_cols = + table.lookups_f().iter().map(|l| l.values.len()).sum::() + table.lookups_ef().len() * DIMENSION + 1; // +1 for the bus + num_cols << log_n_rows +} + +fn compute_total_n_vars(log_memory: usize, tables_heights: &BTreeMap) -> usize { + let total_len = (1 << log_memory) + + tables_heights + .iter() + .map(|(table, log_n_rows)| offset_for_table(table, *log_n_rows)) + .sum::(); + log2_ceil_usize(total_len) +} diff --git a/crates/sub_protocols/src/generic_packed_lookup.rs b/crates/sub_protocols/src/generic_packed_lookup.rs deleted file mode 100644 index 4e815573..00000000 --- a/crates/sub_protocols/src/generic_packed_lookup.rs +++ /dev/null @@ -1,352 +0,0 @@ -use lookup::compute_pushforward; -use lookup::prove_logup_star; -use lookup::verify_logup_star; -use multilinear_toolkit::prelude::*; -use std::any::TypeId; -use utils::VecOrSlice; -use utils::{FSProver, assert_eq_many}; - -use crate::{ColDims, MultilinearChunks, packed_pcs_global_statements_for_prover}; - -#[derive(Debug)] -pub struct GenericPackedLookupProver<'a, TF: Field, EF: ExtensionField + ExtensionField>> { - // inputs - pub(crate) table: VecOrSlice<'a, TF>, - pub(crate) index_columns: Vec<&'a [PF]>, - - // outputs - pub(crate) n_cols_per_group: Vec, - pub(crate) chunks: MultilinearChunks, - pub(crate) packed_lookup_indexes: Vec>, - pub(crate) poly_eq_point: Vec, - pub(crate) pushforward: Vec, // to be committed - pub(crate) batched_value: EF, -} - -#[derive(Debug, PartialEq)] -pub struct PackedLookupStatements { - pub on_table: Evaluation, - pub on_pushforward: Vec>, - pub on_indexes: Vec>>, // contain sparse points (TODO take advantage of it) -} - -impl<'a, TF: Field, EF: ExtensionField + ExtensionField>> GenericPackedLookupProver<'a, TF, EF> -where - PF: PrimeField64, -{ - pub fn pushforward_to_commit(&self) -> &[EF] { - &self.pushforward - } - - // before committing to the pushforward - #[allow(clippy::too_many_arguments)] - pub fn step_1( - prover_state: &mut FSProver>, - table: VecOrSlice<'a, TF>, // table[0] is assumed to be zero - index_columns: Vec<&'a [PF]>, - heights: Vec, - default_indexes: Vec, - value_columns: Vec>>, // value_columns[i][j] = (index_columns[i] + j)*table (using the notation of https://eprint.iacr.org/2025/946) - statements: Vec>>, - log_smallest_decomposition_chunk: usize, - ) -> Self { - let table_ref = table.as_slice(); - assert!(table_ref[0].is_zero()); - assert!(table_ref.len().is_power_of_two()); - assert_eq_many!( - index_columns.len(), - heights.len(), - default_indexes.len(), - value_columns.len(), - statements.len() - ); - value_columns.iter().zip(&statements).for_each(|(cols, evals)| { - assert_eq!(cols.len(), evals[0].num_values()); - }); - let n_groups = value_columns.len(); - let n_cols_per_group = value_columns.iter().map(|cols| cols.len()).collect::>(); - - let flatened_value_columns = value_columns - .iter() - .flat_map(|cols| cols.iter().map(|col| col.as_slice())) - .collect::>(); - - let mut all_dims = vec![]; - for (i, (default_index, height)) in default_indexes.iter().zip(heights.iter()).enumerate() { - for col_index in 0..n_cols_per_group[i] { - all_dims.push(ColDims::padded(*height, table_ref[col_index + default_index])); - } - } - - let (_packed_lookup_values, chunks) = crate::compute_multilinear_chunks_and_apply( - &flatened_value_columns, - &all_dims, - log_smallest_decomposition_chunk, - ); - - let packed_statements = packed_pcs_global_statements_for_prover( - &flatened_value_columns, - &all_dims, - log_smallest_decomposition_chunk, - &expand_multi_evals(&statements), - prover_state, - ); - - let mut missing_shifted_index_cols = vec![vec![]; n_groups]; - for (i, index_col) in index_columns.iter().enumerate() { - for j in 1..n_cols_per_group[i] { - let shifted_col = index_col - .par_iter() - .map(|&x| x + PF::::from_usize(j)) - .collect::>>(); - missing_shifted_index_cols[i].push(shifted_col); - } - } - let mut all_index_cols_ref = vec![]; - for (i, index_col) in index_columns.iter().enumerate() { - all_index_cols_ref.push(*index_col); - for shifted_col in &missing_shifted_index_cols[i] { - all_index_cols_ref.push(shifted_col.as_slice()); - } - } - - let packed_lookup_indexes = chunks.apply(&all_index_cols_ref); - - let batching_scalar = prover_state.sample(); - - let mut poly_eq_point = EF::zero_vec(1 << chunks.packed_n_vars); - for (alpha_power, statement) in batching_scalar.powers().zip(&packed_statements) { - compute_sparse_eval_eq(&statement.point, &mut poly_eq_point, alpha_power); - } - let pushforward = compute_pushforward(&packed_lookup_indexes, table_ref.len(), &poly_eq_point); - - let batched_value: EF = batching_scalar - .powers() - .zip(&packed_statements) - .map(|(alpha_power, statement)| alpha_power * statement.value) - .sum(); - - Self { - table, - index_columns, - n_cols_per_group, - batched_value, - packed_lookup_indexes, - poly_eq_point, - pushforward, - chunks, - } - } - - // after committing to the pushforward - pub fn step_2( - &self, - prover_state: &mut FSProver>, - non_zero_memory_size: usize, - ) -> PackedLookupStatements { - let table = if TypeId::of::() == TypeId::of::>() { - MleRef::Base(unsafe { std::mem::transmute::<&[TF], &[PF]>(self.table.as_slice()) }) - } else if TypeId::of::() == TypeId::of::() { - MleRef::Extension(unsafe { std::mem::transmute::<&[TF], &[EF]>(self.table.as_slice()) }) - } else { - panic!(); - }; - let logup_star_statements = prove_logup_star( - prover_state, - &table, - &self.packed_lookup_indexes, - self.batched_value, - &self.poly_eq_point, - &self.pushforward, - Some(non_zero_memory_size), - ); - - let mut value_on_packed_indexes = EF::ZERO; - let mut offset = 0; - let mut index_statements_to_prove = vec![]; - for (i, n_cols) in self.n_cols_per_group.iter().enumerate() { - let my_chunks = &self.chunks[offset..offset + n_cols]; - offset += n_cols; - - assert!(my_chunks.iter().all(|col_chunks| { - col_chunks - .iter() - .zip(my_chunks[0].iter()) - .all(|(c1, c2)| c1.offset_in_original == c2.offset_in_original && c1.n_vars == c2.n_vars) - })); - let mut inner_statements = vec![]; - let mut inner_evals = vec![]; - for chunk in &my_chunks[0] { - let sparse_point = MultilinearPoint( - [ - chunk.bits_offset_in_original(), - logup_star_statements.on_indexes.point[self.chunks.packed_n_vars - chunk.n_vars..].to_vec(), - ] - .concat(), - ); - let eval = self.index_columns[i].evaluate_sparse(&sparse_point); - inner_evals.push(eval); - inner_statements.push(Evaluation::new(sparse_point, eval)); - } - prover_state.add_extension_scalars(&inner_evals); - index_statements_to_prove.push(inner_statements); - - for (col_index, chunks_for_col) in my_chunks.iter().enumerate() { - for (&inner_eval, chunk) in inner_evals.iter().zip(chunks_for_col) { - let missing_vars = self.chunks.packed_n_vars - chunk.n_vars; - value_on_packed_indexes += (inner_eval + PF::::from_usize(col_index)) - * MultilinearPoint(logup_star_statements.on_indexes.point[..missing_vars].to_vec()) - .eq_poly_outside(&MultilinearPoint( - chunk.bits_offset_in_packed(self.chunks.packed_n_vars), - )); - } - } - } - // sanity check - assert_eq!(value_on_packed_indexes, logup_star_statements.on_indexes.value); - - PackedLookupStatements { - on_table: logup_star_statements.on_table, - on_pushforward: logup_star_statements.on_pushforward, - on_indexes: index_statements_to_prove, - } - } -} - -#[derive(Debug)] -pub struct GenericPackedLookupVerifier>> { - n_cols_per_group: Vec, - chunks: MultilinearChunks, - batching_scalar: EF, - packed_statements: Vec>, -} - -impl>> GenericPackedLookupVerifier -where - PF: PrimeField64, -{ - // before receiving the commitment to the pushforward - pub fn step_1>>( - verifier_state: &mut FSVerifier>, - heights: Vec, - default_indexes: Vec, - statements: Vec>>, - log_smallest_decomposition_chunk: usize, - table_initial_values: &[TF], - ) -> ProofResult - where - EF: ExtensionField, - { - let n_cols_per_group = statements - .iter() - .map(|evals| evals[0].num_values()) - .collect::>(); - let mut all_dims = vec![]; - for (i, (default_index, height)) in default_indexes.iter().zip(heights.iter()).enumerate() { - for col_index in 0..n_cols_per_group[i] { - all_dims.push(ColDims::padded( - *height, - table_initial_values[col_index + default_index], - )); - } - } - - let packed_statements = crate::packed_pcs_global_statements_for_verifier( - &all_dims, - log_smallest_decomposition_chunk, - &expand_multi_evals(&statements), - verifier_state, - &Default::default(), - )?; - let chunks = MultilinearChunks::compute(&all_dims, log_smallest_decomposition_chunk); - - let batching_scalar = verifier_state.sample(); - - Ok(Self { - n_cols_per_group, - chunks, - batching_scalar, - packed_statements, - }) - } - - // after receiving the commitment to the pushforward - pub fn step_2( - &self, - verifier_state: &mut FSVerifier>, - log_memory_size: usize, - ) -> ProofResult> { - let logup_star_statements = verify_logup_star( - verifier_state, - log_memory_size, - self.chunks.packed_n_vars, - &self.packed_statements, - self.batching_scalar, - ) - .unwrap(); - - let mut value_on_packed_indexes = EF::ZERO; - let mut offset = 0; - let mut index_statements_to_verify = vec![]; - for n_cols in &self.n_cols_per_group { - let my_chunks = &self.chunks[offset..offset + n_cols]; - offset += n_cols; - - // sanity check - assert!(my_chunks.iter().all(|col_chunks| { - col_chunks - .iter() - .zip(my_chunks[0].iter()) - .all(|(c1, c2)| c1.offset_in_original == c2.offset_in_original && c1.n_vars == c2.n_vars) - })); - let mut inner_statements = vec![]; - let inner_evals = verifier_state.next_extension_scalars_vec(my_chunks[0].len())?; - for (chunk, &eval) in my_chunks[0].iter().zip(&inner_evals) { - let sparse_point = MultilinearPoint( - [ - chunk.bits_offset_in_original(), - logup_star_statements.on_indexes.point[self.chunks.packed_n_vars - chunk.n_vars..].to_vec(), - ] - .concat(), - ); - inner_statements.push(Evaluation::new(sparse_point, eval)); - } - index_statements_to_verify.push(inner_statements); - - for (col_index, chunks_for_col) in my_chunks.iter().enumerate() { - for (&inner_eval, chunk) in inner_evals.iter().zip(chunks_for_col) { - let missing_vars = self.chunks.packed_n_vars - chunk.n_vars; - value_on_packed_indexes += (inner_eval + PF::::from_usize(col_index)) - * MultilinearPoint(logup_star_statements.on_indexes.point[..missing_vars].to_vec()) - .eq_poly_outside(&MultilinearPoint( - chunk.bits_offset_in_packed(self.chunks.packed_n_vars), - )); - } - } - } - if value_on_packed_indexes != logup_star_statements.on_indexes.value { - return Err(ProofError::InvalidProof); - } - - Ok(PackedLookupStatements { - on_table: logup_star_statements.on_table, - on_pushforward: logup_star_statements.on_pushforward, - on_indexes: index_statements_to_verify, - }) - } -} - -fn expand_multi_evals(statements: &[Vec>]) -> Vec>> { - statements - .iter() - .flat_map(|multi_evals| { - let mut evals = vec![vec![]; multi_evals[0].num_values()]; - for meval in multi_evals { - for (i, &v) in meval.values.iter().enumerate() { - evals[i].push(Evaluation::new(meval.point.clone(), v)); - } - } - evals - }) - .collect::>() -} diff --git a/crates/sub_protocols/src/lib.rs b/crates/sub_protocols/src/lib.rs index 809f23af..32ffb821 100644 --- a/crates/sub_protocols/src/lib.rs +++ b/crates/sub_protocols/src/lib.rs @@ -1,14 +1,13 @@ -mod generic_packed_lookup; -pub use generic_packed_lookup::*; +mod generic_logup; +pub use generic_logup::*; mod packed_pcs; pub use packed_pcs::*; -mod commit_extension_from_base; -pub use commit_extension_from_base::*; +mod quotient_gkr; +pub use quotient_gkr::*; -mod normal_packed_lookup; -pub use normal_packed_lookup::*; +mod logup_star; +pub use logup_star::*; -mod vectorized_packed_lookup; -pub use vectorized_packed_lookup::*; +pub(crate) const MIN_VARS_FOR_PACKING: usize = 8; diff --git a/crates/lookup/src/logup_star.rs b/crates/sub_protocols/src/logup_star.rs similarity index 79% rename from crates/lookup/src/logup_star.rs rename to crates/sub_protocols/src/logup_star.rs index 64bdcf0d..cbf34e19 100644 --- a/crates/lookup/src/logup_star.rs +++ b/crates/sub_protocols/src/logup_star.rs @@ -6,10 +6,9 @@ https://eprint.iacr.org/2025/946.pdf */ use multilinear_toolkit::prelude::*; -use utils::ToUsize; +use utils::{ToUsize, mle_of_01234567_etc}; use tracing::{info_span, instrument}; -use utils::{FSProver, FSVerifier}; use crate::{ MIN_VARS_FOR_PACKING, @@ -25,12 +24,12 @@ pub struct LogupStarStatements { #[instrument(skip_all)] pub fn prove_logup_star( - prover_state: &mut FSProver>, + prover_state: &mut impl FSProver, table: &MleRef<'_, EF>, indexes: &[PF], claimed_value: EF, poly_eq_point: &[EF], - pushforward: &[EF], // already commited + pushforward: &MleRef<'_, EF>, // already commited max_index: Option, ) -> LogupStarStatements where @@ -51,7 +50,7 @@ where let (poly_eq_point_packed, pushforward_packed, table_packed) = info_span!("packing").in_scope(|| { ( MleRef::Extension(poly_eq_point).pack_if(packing), - MleRef::Extension(pushforward).pack_if(packing), + pushforward.pack_if(packing), table.pack_if(packing), ) }); @@ -93,9 +92,10 @@ where .collect::>(); let c_minus_indexes_packed = MleRef::Extension(&c_minus_indexes).pack_if(packing); - let (_, claim_point_left, _, eval_c_minus_indexes) = prove_gkr_quotient::<_, 2>( + let (_, claim_point_left, _, eval_c_minus_indexes) = prove_gkr_quotient( prover_state, - &MleGroupRef::merge(&[&poly_eq_point_packed.by_ref(), &c_minus_indexes_packed.by_ref()]), + &poly_eq_point_packed.by_ref(), + &c_minus_indexes_packed.by_ref(), ); let c_minus_increments = MleRef::Extension( @@ -105,9 +105,10 @@ where .collect::>(), ); let c_minus_increments_packed = c_minus_increments.pack_if(packing); - let (_, claim_point_right, pushforward_final_eval, _) = prove_gkr_quotient::<_, 2>( + let (_, claim_point_right, pushforward_final_eval, _) = prove_gkr_quotient( prover_state, - &MleGroupRef::merge(&[&pushforward_packed.by_ref(), &c_minus_increments_packed.by_ref()]), + &pushforward_packed.by_ref(), + &c_minus_increments_packed.by_ref(), ); let on_indexes = Evaluation::new(claim_point_left, c - eval_c_minus_indexes); @@ -123,11 +124,10 @@ where } pub fn verify_logup_star( - verifier_state: &mut FSVerifier>, + verifier_state: &mut impl FSVerifier, log_table_len: usize, log_indexes_len: usize, - claims: &[Evaluation], - alpha: EF, // batching challenge + claim: Evaluation, ) -> Result, ProofError> where EF: ExtensionField>, @@ -135,7 +135,7 @@ where { let (sum, postponed) = sumcheck_verify(verifier_state, log_table_len, 2).map_err(|_| ProofError::InvalidProof)?; - if sum != claims.iter().zip(alpha.powers()).map(|(c, a)| c.value * a).sum::() { + if sum != claim.value { return Err(ProofError::InvalidProof); } @@ -152,35 +152,22 @@ where let c = verifier_state.sample(); let (quotient_left, claim_point_left, claim_num_left, eval_c_minus_indexes) = - verify_gkr_quotient::<_, 2>(verifier_state, log_indexes_len)?; + verify_gkr_quotient(verifier_state, log_indexes_len)?; let (quotient_right, claim_point_right, pushforward_final_eval, claim_den_right) = - verify_gkr_quotient::<_, 2>(verifier_state, log_table_len)?; + verify_gkr_quotient(verifier_state, log_table_len)?; if quotient_left != quotient_right { return Err(ProofError::InvalidProof); } let on_indexes = Evaluation::new(claim_point_left.clone(), c - eval_c_minus_indexes); - if claim_num_left - != claims - .iter() - .zip(alpha.powers()) - .map(|(claim, a)| claim_point_left.eq_poly_outside(&claim.point) * a) - .sum::() - { + if claim_num_left != claim_point_left.eq_poly_outside(&claim.point) { return Err(ProofError::InvalidProof); } on_pushforward.push(Evaluation::new(claim_point_right.clone(), pushforward_final_eval)); - let big_endian_mle = claim_point_right - .iter() - .rev() - .enumerate() - .map(|(i, &p)| p * EF::TWO.exp_u64(i as u64)) - .sum::(); - - if claim_den_right != c - big_endian_mle { + if claim_den_right != c - mle_of_01234567_etc(&claim_point_right) { return Err(ProofError::InvalidProof); } @@ -210,8 +197,6 @@ pub fn compute_pushforward>( #[cfg(test)] mod tests { - use std::time::Instant; - use super::*; use p3_koala_bear::{KoalaBear, QuinticExtensionFieldKB}; use rand::{Rng, SeedableRng, rngs::StdRng}; @@ -258,7 +243,7 @@ mod tests { let point = MultilinearPoint((0..log_indexes_len).map(|_| rng.random()).collect::>()); - let mut prover_state = build_prover_state(false); + let mut prover_state = build_prover_state(); let eval = values.evaluate(&point); let time = std::time::Instant::now(); @@ -272,18 +257,18 @@ mod tests { &commited_indexes, claim.value, &poly_eq_point, - &pushforward, + &MleRef::Extension(&pushforward), Some(max_index), ); println!("Proving logup_star took {} ms", time.elapsed().as_millis()); - let last_prover_state = prover_state.challenger().state(); + let last_prover_state = prover_state.state(); let mut verifier_state = build_verifier_state(prover_state); let verifier_statements = - verify_logup_star(&mut verifier_state, log_table_len, log_indexes_len, &[claim], EF::ONE).unwrap(); + verify_logup_star(&mut verifier_state, log_table_len, log_indexes_len, claim).unwrap(); assert_eq!(&verifier_statements, &prover_statements); - assert_eq!(last_prover_state, verifier_state.challenger().state()); + assert_eq!(last_prover_state, verifier_state.state()); assert_eq!( indexes.evaluate(&verifier_statements.on_indexes.point), @@ -296,19 +281,5 @@ mod tests { for eval in &verifier_statements.on_pushforward { assert_eq!(pushforward.evaluate(&eval.point), eval.value); } - - { - let n_muls = 16; - let slice = (0..(table_length + indexes_len) / packing_width::()) - .map(|_| rng.random()) - .collect::>>(); - let time = Instant::now(); - let sum = slice - .par_iter() - .map(|x| (0..n_muls).map(|_| *x).product::>()) - .sum::>(); - assert!(sum != EFPacking::::ONE); - println!("Optimal time we can hope for: {} ms", time.elapsed().as_millis()); - } } } diff --git a/crates/sub_protocols/src/normal_packed_lookup.rs b/crates/sub_protocols/src/normal_packed_lookup.rs deleted file mode 100644 index 71770317..00000000 --- a/crates/sub_protocols/src/normal_packed_lookup.rs +++ /dev/null @@ -1,220 +0,0 @@ -use multilinear_toolkit::prelude::*; -use utils::FSProver; -use utils::VecOrSlice; -use utils::assert_eq_many; -use utils::dot_product_with_base; -use utils::transpose_slice_to_basis_coefficients; - -use crate::GenericPackedLookupProver; -use crate::GenericPackedLookupVerifier; - -#[derive(Debug)] -pub struct NormalPackedLookupProver<'a, EF: ExtensionField>> { - generic: GenericPackedLookupProver<'a, PF, EF>, - n_cols_f: usize, -} - -#[derive(Debug, PartialEq)] -pub struct NormalPackedLookupStatements { - pub on_table: Evaluation, - pub on_pushforward: Vec>, - pub on_indexes_f: Vec>>, // contain sparse points (TODO take advantage of it) - pub on_indexes_ef: Vec>>, // contain sparse points (TODO take advantage of it) -} - -impl<'a, EF: ExtensionField>> NormalPackedLookupProver<'a, EF> -where - PF: PrimeField64, -{ - pub fn pushforward_to_commit(&self) -> &[EF] { - self.generic.pushforward_to_commit() - } - - // before committing to the pushforward - #[allow(clippy::too_many_arguments)] - pub fn step_1( - prover_state: &mut FSProver>, - table: &'a [PF], // table[0] is assumed to be zero - index_columns_f: Vec<&'a [PF]>, - index_columns_ef: Vec<&'a [PF]>, - heights_f: Vec, - heights_ef: Vec, - default_indexes_f: Vec, - default_indexes_ef: Vec, - value_columns_f: Vec<&'a [PF]>, - value_columns_ef: Vec<&'a [EF]>, - statements_f: Vec>>, - statements_ef: Vec>>, - log_smallest_decomposition_chunk: usize, - ) -> Self { - assert_eq_many!( - index_columns_f.len(), - heights_f.len(), - default_indexes_f.len(), - value_columns_f.len(), - statements_f.len() - ); - assert_eq_many!( - index_columns_ef.len(), - heights_ef.len(), - default_indexes_ef.len(), - value_columns_ef.len(), - statements_ef.len() - ); - let n_cols_f = value_columns_f.len(); - - let mut all_value_columns = vec![]; - for col_f in value_columns_f { - all_value_columns.push(vec![VecOrSlice::Slice(col_f)]); - } - for col_ef in &value_columns_ef { - all_value_columns.push( - transpose_slice_to_basis_coefficients::, EF>(col_ef) - .into_iter() - .map(VecOrSlice::Vec) - .collect(), - ); - } - - let mut multi_eval_statements = vec![]; - for eval_group in &statements_f { - multi_eval_statements.push( - eval_group - .iter() - .map(|e| MultiEvaluation::new(e.point.clone(), vec![e.value])) - .collect(), - ); - } - - for (eval_group, extension_column_split) in statements_ef.iter().zip(&all_value_columns[n_cols_f..]) { - let mut multi_evals = vec![]; - for eval in eval_group { - let sub_evals = extension_column_split - .par_iter() - .map(|slice| slice.as_slice().evaluate(&eval.point)) - .collect::>(); - // sanity check: - assert_eq!(dot_product_with_base(&sub_evals), eval.value); - - prover_state.add_extension_scalars(&sub_evals); - multi_evals.push(MultiEvaluation::new(eval.point.clone(), sub_evals)); - } - - multi_eval_statements.push(multi_evals); - } - - let index_columns = [index_columns_f, index_columns_ef].concat(); - let heights = [heights_f, heights_ef].concat(); - let default_indexes = [default_indexes_f, default_indexes_ef].concat(); - - let generic = GenericPackedLookupProver::step_1( - prover_state, - VecOrSlice::Slice(table), - index_columns, - heights, - default_indexes, - all_value_columns, - multi_eval_statements, - log_smallest_decomposition_chunk, - ); - - Self { generic, n_cols_f } - } - - // after committing to the pushforward - pub fn step_2( - &self, - prover_state: &mut FSProver>, - non_zero_memory_size: usize, - ) -> NormalPackedLookupStatements { - let res = self.generic.step_2(prover_state, non_zero_memory_size); - NormalPackedLookupStatements { - on_table: res.on_table, - on_pushforward: res.on_pushforward, - on_indexes_f: res.on_indexes[..self.n_cols_f].to_vec(), - on_indexes_ef: res.on_indexes[self.n_cols_f..].to_vec(), - } - } -} - -#[derive(Debug)] -pub struct NormalPackedLookupVerifier>> { - generic: GenericPackedLookupVerifier, - n_cols_f: usize, -} - -impl>> NormalPackedLookupVerifier -where - PF: PrimeField64, -{ - // before receiving the commitment to the pushforward - #[allow(clippy::too_many_arguments)] - pub fn step_1>>( - verifier_state: &mut FSVerifier>, - heights_f: Vec, - heights_ef: Vec, - default_indexes_f: Vec, - default_indexes_ef: Vec, - statements_f: Vec>>, - statements_ef: Vec>>, - log_smallest_decomposition_chunk: usize, - table_initial_values: &[TF], - ) -> ProofResult - where - EF: ExtensionField, - { - assert_eq_many!(heights_f.len(), default_indexes_f.len(), statements_f.len()); - assert_eq_many!(heights_ef.len(), default_indexes_ef.len(), statements_ef.len()); - let n_cols_f = statements_f.len(); - - let mut multi_eval_statements = vec![]; - for eval_group in &statements_f { - multi_eval_statements.push( - eval_group - .iter() - .map(|e| MultiEvaluation::new(e.point.clone(), vec![e.value])) - .collect(), - ); - } - for eval_group in &statements_ef { - let mut multi_evals = vec![]; - for eval in eval_group { - let sub_evals = - verifier_state.next_extension_scalars_vec(>>::DIMENSION)?; - if dot_product_with_base(&sub_evals) != eval.value { - return Err(ProofError::InvalidProof); - } - multi_evals.push(MultiEvaluation::new(eval.point.clone(), sub_evals)); - } - - multi_eval_statements.push(multi_evals); - } - let heights = [heights_f, heights_ef].concat(); - let default_indexes = [default_indexes_f, default_indexes_ef].concat(); - let generic = GenericPackedLookupVerifier::step_1( - verifier_state, - heights, - default_indexes, - multi_eval_statements, - log_smallest_decomposition_chunk, - table_initial_values, - )?; - - Ok(Self { generic, n_cols_f }) - } - - // after receiving the commitment to the pushforward - pub fn step_2( - &self, - verifier_state: &mut FSVerifier>, - log_memory_size: usize, - ) -> ProofResult> { - let res = self.generic.step_2(verifier_state, log_memory_size)?; - Ok(NormalPackedLookupStatements { - on_table: res.on_table, - on_pushforward: res.on_pushforward, - on_indexes_f: res.on_indexes[..self.n_cols_f].to_vec(), - on_indexes_ef: res.on_indexes[self.n_cols_f..].to_vec(), - }) - } -} diff --git a/crates/sub_protocols/src/packed_pcs.rs b/crates/sub_protocols/src/packed_pcs.rs index 1362e6f3..b8919ee0 100644 --- a/crates/sub_protocols/src/packed_pcs.rs +++ b/crates/sub_protocols/src/packed_pcs.rs @@ -1,694 +1,128 @@ -use std::{any::TypeId, cmp::Reverse, collections::BTreeMap}; - +use lean_vm::{COL_PC, CommittedStatements, ENDING_PC, STARTING_PC, sort_tables_by_height}; +use lean_vm::{EF, F, Table, TableT, TableTrace}; use multilinear_toolkit::prelude::*; -use p3_util::{log2_ceil_usize, log2_strict_usize}; +use p3_util::log2_ceil_usize; +use std::collections::BTreeMap; use tracing::instrument; -use utils::{ - FSProver, FSVerifier, from_end, multilinear_eval_constants_at_right, to_big_endian_bits, to_big_endian_in_field, -}; +use utils::{VarCount, transpose_slice_to_basis_coefficients}; use whir_p3::*; -#[derive(Debug, Clone)] -pub struct Chunk { - pub original_poly_index: usize, - pub original_n_vars: usize, - pub n_vars: usize, - pub offset_in_original: usize, - pub public_data: bool, - pub offset_in_packed: Option, -} - -impl Chunk { - pub fn bits_offset_in_original(&self) -> Vec { - to_big_endian_in_field( - self.offset_in_original >> self.n_vars, - self.original_n_vars - self.n_vars, - ) - } - pub fn bits_offset_in_packed(&self, packed_n_vars: usize) -> Vec { - to_big_endian_in_field( - self.offset_in_packed.unwrap() >> self.n_vars, - packed_n_vars - self.n_vars, - ) - } - fn global_point_for_statement(&self, point: &[F], packed_n_vars: usize) -> MultilinearPoint { - MultilinearPoint([self.bits_offset_in_packed(packed_n_vars), point.to_vec()].concat()) - } -} - -/* -General layout: [public data][committed data][repeated value] (the thing has length 2^n_vars, public data is a power of two) -*/ -#[derive(Debug, Clone, Copy)] -pub struct ColDims { - pub n_vars: usize, - pub log_public_data_size: Option, - pub committed_size: usize, - pub default_value: F, -} - -impl ColDims { - pub fn full(n_vars: usize) -> Self { - Self { - n_vars, - log_public_data_size: None, - committed_size: 1 << n_vars, - default_value: F::ZERO, - } - } - - pub fn padded(committed_size: usize, default_value: F) -> Self { - Self::padded_with_public_data(None, committed_size, default_value) - } - - pub fn padded_with_public_data( - log_public_data_size: Option, - committed_size: usize, - default_value: F, - ) -> Self { - let public_data_size = log_public_data_size.map_or(0, |l| 1 << l); - let n_vars = log2_ceil_usize(public_data_size + committed_size); - Self { - n_vars, - log_public_data_size, - committed_size, - default_value, - } - } -} - -fn split_in_chunks( - poly_index: usize, - dims: &ColDims, - log_smallest_decomposition_chunk: usize, -) -> Vec { - let mut offset_in_original = 0; - let mut res = Vec::new(); - if let Some(log_public) = dims.log_public_data_size { - assert!( - log_public >= log_smallest_decomposition_chunk, - "poly {poly_index}: {log_public} < {log_smallest_decomposition_chunk}" - ); - res.push(Chunk { - original_poly_index: poly_index, - original_n_vars: dims.n_vars, - n_vars: log_public, - offset_in_original, - public_data: true, - offset_in_packed: None, - }); - offset_in_original += 1 << log_public; - } - let mut remaining = dims.committed_size; - - loop { - let mut chunk_size = if remaining.next_power_of_two() - remaining <= 1 << log_smallest_decomposition_chunk { - log2_ceil_usize(remaining) - } else { - remaining.ilog2() as usize - }; - if let Some(log_public) = dims.log_public_data_size { - chunk_size = chunk_size.min(log_public); - } - - res.push(Chunk { - original_poly_index: poly_index, - original_n_vars: dims.n_vars, - n_vars: chunk_size, - offset_in_original, - public_data: false, - offset_in_packed: None, - }); - offset_in_original += 1 << chunk_size; - if remaining <= 1 << chunk_size { - return res; - } - remaining -= 1 << chunk_size; - } -} - -pub fn num_packed_vars_for_dims(dims: &[ColDims], log_smallest_decomposition_chunk: usize) -> usize { - MultilinearChunks::compute(dims, log_smallest_decomposition_chunk).packed_n_vars -} - #[derive(Debug)] -pub struct MultiCommitmentWitness>> { +pub struct MultiCommitmentWitness { + pub packed_n_vars: usize, pub inner_witness: Witness, pub packed_polynomial: MleOwned, } -#[derive(Debug, derive_more::Deref)] -pub struct MultilinearChunks { - #[deref] - pub chunks_decomposition: Vec>, - pub packed_n_vars: usize, -} - -impl MultilinearChunks { - pub fn compute(dims: &[ColDims], log_smallest_decomposition_chunk: usize) -> Self { - let mut all_chunks = Vec::new(); - for (i, dim) in dims.iter().enumerate() { - all_chunks.extend(split_in_chunks(i, dim, log_smallest_decomposition_chunk)); - } - all_chunks.sort_by_key(|c| (Reverse(c.public_data), Reverse(c.n_vars))); - - let mut offset_in_packed = 0; - let mut chunks_decomposition: BTreeMap<_, Vec<_>> = BTreeMap::new(); - for chunk in &mut all_chunks { - if !chunk.public_data { - chunk.offset_in_packed = Some(offset_in_packed); - offset_in_packed += 1 << chunk.n_vars; - } - chunks_decomposition - .entry(chunk.original_poly_index) - .or_default() - .push(chunk.clone()); +#[instrument(skip_all)] +pub fn packed_pcs_global_statements( + packed_n_vars: usize, + memory_n_vars: usize, + memory_acc_statements: Vec>, + tables_heights: &BTreeMap, + committed_statements: &CommittedStatements, +) -> Vec> { + assert_eq!(tables_heights.len(), committed_statements.len()); + + let tables_heights_sorted = sort_tables_by_height(tables_heights); + + let mut global_statements = memory_acc_statements; + let mut offset = 2 << memory_n_vars; + + for (table, n_vars) in tables_heights_sorted { + if table.is_execution_table() { + // Important: ensure both initial and final PC conditions are correct + global_statements.push(SparseStatement::unique_value( + packed_n_vars, + offset + (COL_PC << n_vars), + EF::from_usize(STARTING_PC), + )); + global_statements.push(SparseStatement::unique_value( + packed_n_vars, + offset + ((COL_PC + 1) << n_vars) - 1, + EF::from_usize(ENDING_PC), + )); } - let packed_n_vars = log2_ceil_usize( - all_chunks - .iter() - .filter(|c| !c.public_data) - .map(|c| 1 << c.n_vars) - .sum::(), - ); - let chunks_decomposition = chunks_decomposition.values().cloned().collect::>(); - assert_eq!(chunks_decomposition.len(), dims.len()); - Self { - chunks_decomposition, - packed_n_vars, + for (point, col_statements) in &committed_statements[&table] { + global_statements.push(SparseStatement::new( + packed_n_vars, + point.clone(), + col_statements + .iter() + .map(|(&col_index, &value)| SparseValue::new((offset >> n_vars) + col_index, value)) + .collect(), + )); } + offset += table.n_commited_columns() << n_vars; } - - pub fn apply(&self, polynomials: &[&[F]]) -> Vec - where - F: Field, - { - let packed_polynomial = F::zero_vec(1 << self.packed_n_vars); // TODO avoid this huge cloning of all witness data - self.iter() - .flatten() - .collect::>() - .par_iter() - .filter(|chunk| !chunk.public_data) - .for_each(|chunk| { - let start = chunk.offset_in_packed.unwrap(); - let end = start + (1 << chunk.n_vars); - let original_poly = &polynomials[chunk.original_poly_index]; - unsafe { - let slice = - std::slice::from_raw_parts_mut((packed_polynomial.as_ptr() as *mut F).add(start), end - start); - slice.copy_from_slice( - &original_poly[chunk.offset_in_original..chunk.offset_in_original + (1 << chunk.n_vars)], - ); - } - }); - - packed_polynomial - } -} - -#[instrument(skip_all)] -pub fn compute_multilinear_chunks_and_apply( - polynomials: &[&[F]], - dims: &[ColDims], - log_smallest_decomposition_chunk: usize, -) -> (Vec, MultilinearChunks) -where - F: Field, -{ - assert_eq!(polynomials.len(), dims.len()); - for (i, (poly, dim)) in polynomials.iter().zip(dims.iter()).enumerate() { - assert_eq!( - poly.len(), - 1 << dim.n_vars, - "poly {} has {} vars, but dim should be {}", - i, - log2_strict_usize(poly.len()), - dim.n_vars - ); - } - let chunks = MultilinearChunks::compute(dims, log_smallest_decomposition_chunk); - - { - // logging - let total_commited_data: usize = dims.iter().map(|d| d.committed_size).sum(); - let packed_commited_data: usize = chunks - .iter() - .flatten() - .filter(|c| !c.public_data) - .map(|c| 1 << c.n_vars) - .sum(); - - tracing::info!( - "Total committed data (full granularity): {} = 2^{:.3} | packed to 2^{:.3} -> 2^{}", - total_commited_data, - (total_commited_data as f64).log2(), - (packed_commited_data as f64).log2(), - chunks.packed_n_vars - ); - } - - let packed_polynomial = chunks.apply(polynomials); - - (packed_polynomial, chunks) + global_statements } #[instrument(skip_all)] -pub fn packed_pcs_commit( +pub fn packed_pcs_commit( + prover_state: &mut impl FSProver, whir_config_builder: &WhirConfigBuilder, - polynomials: &[&[F]], - dims: &[ColDims], - prover_state: &mut FSProver>, - log_smallest_decomposition_chunk: usize, -) -> MultiCommitmentWitness -where - F: Field + TwoAdicField + ExtensionField>, - PF: TwoAdicField, - EF: ExtensionField + TwoAdicField + ExtensionField>, -{ - let (packed_polynomial, _chunks_decomposition) = - compute_multilinear_chunks_and_apply::(polynomials, dims, log_smallest_decomposition_chunk); - let packed_n_vars = log2_strict_usize(packed_polynomial.len()); + memory: &[F], + acc: &[F], + traces: &BTreeMap, +) -> MultiCommitmentWitness { + assert_eq!(memory.len(), acc.len()); + let tables_heights = traces.iter().map(|(table, trace)| (*table, trace.log_n_rows)).collect(); + let tables_heights_sorted = sort_tables_by_height(&tables_heights); + + let packed_n_vars = compute_total_n_vars( + log2_strict_usize(memory.len()), + &tables_heights_sorted.iter().cloned().collect(), + ); + let mut packed_polynomial = F::zero_vec(1 << packed_n_vars); // TODO avoid cloning all witness data + packed_polynomial[..memory.len()].copy_from_slice(memory); + let mut offset = memory.len(); + packed_polynomial[offset..offset + acc.len()].copy_from_slice(acc); + offset += acc.len(); + for (table, log_n_rows) in &tables_heights_sorted { + let n_rows = 1 << *log_n_rows; + for col_index_f in 0..table.n_commited_columns_f() { + let col = &traces[table].base[col_index_f]; + packed_polynomial[offset..offset + n_rows].copy_from_slice(&col[..n_rows]); + offset += n_rows; + } + for col_index_ef in 0..table.n_commited_columns_ef() { + let col = &traces[table].ext[col_index_ef]; + let transposed = transpose_slice_to_basis_coefficients(col); + for basis_col in transposed { + packed_polynomial[offset..offset + n_rows].copy_from_slice(&basis_col); + offset += n_rows; + } + } + } + assert_eq!(log2_ceil_usize(offset), packed_n_vars); + tracing::info!("packed PCS data: {} = 2^{:.2}", offset, (offset as f64).log2()); - let mle = if TypeId::of::() == TypeId::of::>() { - MleOwned::Base(unsafe { std::mem::transmute::, Vec>>(packed_polynomial) }) - } else if TypeId::of::() == TypeId::of::() { - MleOwned::ExtensionPacked(pack_extension(&unsafe { - std::mem::transmute::, Vec>(packed_polynomial) - })) // TODO this is innefficient (this transposes everything...) - } else { - panic!("Unsupported field type for packed PCS: {}", std::any::type_name::()); - }; + let packed_polynomial = MleOwned::Base(packed_polynomial); - let inner_witness = WhirConfig::new(whir_config_builder.clone(), packed_n_vars).commit(prover_state, &mle); + let inner_witness = WhirConfig::new(whir_config_builder, packed_n_vars).commit(prover_state, &packed_polynomial); MultiCommitmentWitness { + packed_n_vars, inner_witness, - packed_polynomial: mle, - } -} - -#[instrument(skip_all)] -pub fn packed_pcs_global_statements_for_prover + ExtensionField>>( - polynomials: &[&[F]], - dims: &[ColDims], - log_smallest_decomposition_chunk: usize, - statements_per_polynomial: &[Vec>], - prover_state: &mut FSProver>, -) -> Vec> { - // TODO: - // - cache the "eq" poly, and then use dot product - // - current packing is not optimal in the end: can lead to [16][4][2][2] (instead of [16][8]) - - let all_chunks = MultilinearChunks::compute(dims, log_smallest_decomposition_chunk); - - let statements_flattened = statements_per_polynomial - .iter() - .enumerate() - .flat_map(|(poly_index, poly_statements)| poly_statements.iter().map(move |statement| (poly_index, statement))) - .collect::>(); - - let sub_packed_statements_and_evals_to_send = statements_flattened - .par_iter() - .map(|(poly_index, statement)| { - let dim = &dims[*poly_index]; - let pol = polynomials[*poly_index]; - - let chunks = &all_chunks[*poly_index]; - assert!(!chunks.is_empty()); - let mut sub_packed_statements = Vec::new(); - let mut evals_to_send = Vec::new(); - if chunks.len() == 1 { - assert!(!chunks[0].public_data, "TODO"); - assert_eq!(chunks[0].n_vars, statement.point.0.len(), "poly: {poly_index}"); - assert!( - chunks[0] - .offset_in_packed - .unwrap() - .is_multiple_of(1 << chunks[0].n_vars) - ); - - sub_packed_statements.push(Evaluation::new( - chunks[0].global_point_for_statement(&statement.point, all_chunks.packed_n_vars), - statement.value, - )); - } else { - let initial_booleans = statement - .point - .iter() - .take_while(|&&x| x == EF::ZERO || x == EF::ONE) - .map(|&x| x == EF::ONE) - .collect::>(); - let mut all_chunk_evals = Vec::new(); - - // skip the first one, we will deduce it (if it's not public) - // TODO do we really need to parallelize this? - chunks[1..] - .par_iter() - .map(|chunk| { - let missing_vars = statement.point.0.len() - chunk.n_vars; - - let offset_in_original_booleans = - to_big_endian_bits(chunk.offset_in_original >> chunk.n_vars, missing_vars); - - if !initial_booleans.is_empty() - && initial_booleans.len() < offset_in_original_booleans.len() - && initial_booleans == offset_in_original_booleans[..initial_booleans.len()] - { - tracing::warn!("TODO: sparse statement accroos mutiple chunks"); - } - - if initial_booleans.len() >= offset_in_original_booleans.len() { - if initial_booleans[..missing_vars] != offset_in_original_booleans { - // this chunk is not concerned by this sparse evaluation - return (None, EF::ZERO); - } else { - // the evaluation only depends on this chunk, no need to recompute and = statement.value - return (None, statement.value); - } - } - - let sub_point = MultilinearPoint(statement.point.0[missing_vars..].to_vec()); - let sub_value = (&pol - [chunk.offset_in_original..chunk.offset_in_original + (1 << chunk.n_vars)]) - .evaluate_sparse(&sub_point); // `evaluate_sparse` because sometime (typically due to packed lookup protocol, the original statement is already sparse) - ( - Some(Evaluation::new( - chunk.global_point_for_statement(&sub_point, all_chunks.packed_n_vars), - sub_value, - )), - sub_value, - ) - }) - .collect::>() - .into_iter() - .for_each(|(statement, sub_value)| { - if let Some(statement) = statement { - evals_to_send.push(statement.value); - sub_packed_statements.push(statement); - } - all_chunk_evals.push(sub_value); - }); - - let initial_missing_vars = statement.point.0.len() - chunks[0].n_vars; - let initial_offset_in_original_booleans = - to_big_endian_bits(chunks[0].offset_in_original >> chunks[0].n_vars, initial_missing_vars); - if initial_booleans.len() < initial_offset_in_original_booleans.len() // if the statement only concern the first chunk, no need to send more data - && dim.log_public_data_size.is_none() - // if the first value is public, no need to recompute it - { - let retrieved_eval = compute_multilinear_value_from_chunks( - &chunks[1..], - &all_chunk_evals, - &statement.point, - 1 << chunks[0].n_vars, - dim.default_value, - ); - - let initial_missing_vars = statement.point.0.len() - chunks[0].n_vars; - let initial_sub_value = (statement.value - retrieved_eval) - / MultilinearPoint(statement.point.0[..initial_missing_vars].to_vec()) - .eq_poly_outside(&MultilinearPoint(chunks[0].bits_offset_in_original())); - let initial_sub_point = MultilinearPoint(statement.point.0[initial_missing_vars..].to_vec()); - - let initial_packed_point = - chunks[0].global_point_for_statement(&initial_sub_point, all_chunks.packed_n_vars); - sub_packed_statements.insert(0, Evaluation::new(initial_packed_point, initial_sub_value)); - evals_to_send.insert(0, initial_sub_value); - } - } - (sub_packed_statements, evals_to_send) - }) - .collect::>(); - - let mut packed_statements = Vec::new(); - for (sub_packed_statements, evals_to_send) in sub_packed_statements_and_evals_to_send { - packed_statements.extend(sub_packed_statements); - prover_state.add_extension_scalars(&evals_to_send); + packed_polynomial, } - packed_statements } -pub fn packed_pcs_parse_commitment< - F: Field + TwoAdicField, - EF: ExtensionField + TwoAdicField + ExtensionField>, ->( +pub fn packed_pcs_parse_commitment( whir_config_builder: &WhirConfigBuilder, - verifier_state: &mut FSVerifier>, - dims: &[ColDims], - log_smallest_decomposition_chunk: usize, -) -> Result, ProofError> -where - PF: TwoAdicField, -{ - let all_chunks = MultilinearChunks::compute(dims, log_smallest_decomposition_chunk); - WhirConfig::new(whir_config_builder.clone(), all_chunks.packed_n_vars).parse_commitment(verifier_state) + verifier_state: &mut impl FSVerifier, + log_memory: usize, + tables_heights: &BTreeMap, +) -> Result, ProofError> { + let packed_n_vars = compute_total_n_vars(log_memory, tables_heights); + WhirConfig::new(whir_config_builder, packed_n_vars).parse_commitment(verifier_state) } -pub fn packed_pcs_global_statements_for_verifier + ExtensionField>>( - dims: &[ColDims], - log_smallest_decomposition_chunk: usize, - statements_per_polynomial: &[Vec>], - verifier_state: &mut FSVerifier>, - public_data: &BTreeMap>, // poly_index -> public data slice (power of 2) -) -> Result>, ProofError> { - assert_eq!(dims.len(), statements_per_polynomial.len()); - let all_chunks = MultilinearChunks::compute(dims, log_smallest_decomposition_chunk); - let mut packed_statements = Vec::new(); - for (poly_index, statements) in statements_per_polynomial.iter().enumerate() { - let dim = &dims[poly_index]; - let has_public_data = dim.log_public_data_size.is_some(); - let chunks = &all_chunks[poly_index]; - assert!(!chunks.is_empty()); - for statement in statements { - if chunks.len() == 1 { - assert!(!chunks[0].public_data, "TODO"); - assert_eq!(chunks[0].n_vars, statement.point.0.len()); - assert!( - chunks[0] - .offset_in_packed - .unwrap() - .is_multiple_of(1 << chunks[0].n_vars) - ); - packed_statements.push(Evaluation::new( - chunks[0].global_point_for_statement(&statement.point, all_chunks.packed_n_vars), - statement.value, - )); - } else { - let initial_booleans = statement - .point - .iter() - .take_while(|&&x| x == EF::ZERO || x == EF::ONE) - .map(|&x| x == EF::ONE) - .collect::>(); - let mut sub_values = vec![]; - if has_public_data { - sub_values.push( - public_data[&poly_index] - .evaluate(&MultilinearPoint(from_end(&statement.point, chunks[0].n_vars).to_vec())), - ); - } - for chunk in chunks { - if chunk.public_data { - continue; - } - let missing_vars = statement.point.0.len() - chunk.n_vars; - let offset_in_original_booleans = - to_big_endian_bits(chunk.offset_in_original >> chunk.n_vars, missing_vars); - - if initial_booleans.len() >= offset_in_original_booleans.len() { - if initial_booleans[..missing_vars] != offset_in_original_booleans { - // this chunk is not concerned by this sparse evaluation - sub_values.push(EF::ZERO); - } else { - // the evaluation only depends on this chunk, no need to recompute and = statement.value - sub_values.push(statement.value); - } - } else { - let sub_value = verifier_state.next_extension_scalar()?; - sub_values.push(sub_value); - let sub_point = MultilinearPoint(statement.point.0[missing_vars..].to_vec()); - packed_statements.push(Evaluation::new( - chunk.global_point_for_statement(&sub_point, all_chunks.packed_n_vars), - sub_value, - )); - } - } - // consistency check - if statement.value - != compute_multilinear_value_from_chunks( - chunks, - &sub_values, - &statement.point, - 0, - dim.default_value, - ) - { - return Err(ProofError::InvalidProof); - } - } - } - } - Ok(packed_statements) -} - -fn compute_multilinear_value_from_chunks>( - chunks: &[Chunk], - evals_per_chunk: &[EF], - point: &[EF], - size_of_first_chunk_mising: usize, - default_value: F, -) -> EF { - assert_eq!(chunks.len(), evals_per_chunk.len()); - let mut eval = EF::ZERO; - - let mut chunk_offset_sums = size_of_first_chunk_mising; - for (chunk, &sub_value) in chunks.iter().zip(evals_per_chunk) { - let missing_vars = point.len() - chunk.n_vars; - eval += sub_value - * MultilinearPoint(point[..missing_vars].to_vec()) - .eq_poly_outside(&MultilinearPoint(chunk.bits_offset_in_original())); - chunk_offset_sums += 1 << chunk.n_vars; - } - eval += multilinear_eval_constants_at_right(chunk_offset_sums, point) * default_value; - eval -} - -#[cfg(test)] -mod tests { - use p3_koala_bear::{KoalaBear, QuinticExtensionFieldKB}; - use p3_util::log2_strict_usize; - use rand::{Rng, SeedableRng, rngs::StdRng}; - use utils::{build_prover_state, build_verifier_state}; - - use super::*; - - type F = KoalaBear; - type EF = QuinticExtensionFieldKB; - - #[test] - fn test_packed_pcs() { - let whir_config_builder = WhirConfigBuilder { - folding_factor: FoldingFactor::new(4, 4), - soundness_type: SecurityAssumption::CapacityBound, - pow_bits: 13, - max_num_variables_to_send_coeffs: 6, - rs_domain_initial_reduction_factor: 1, - security_level: 75, - starting_log_inv_rate: 1, - }; - - let mut rng = StdRng::seed_from_u64(0); - let log_smallest_decomposition_chunk = 4; - let committed_length_lengths_and_default_value_and_log_public_data: [(usize, F, Option); _] = [ - (916, F::from_usize(8), Some(5)), - (854, F::from_usize(0), Some(7)), - (854, F::from_usize(1), Some(5)), - (16, F::from_usize(0), Some(5)), - (1127, F::from_usize(0), Some(6)), - (595, F::from_usize(3), Some(6)), - (17, F::from_usize(0), None), - (95, F::from_usize(3), None), - (256, F::from_usize(8), None), - (1088, F::from_usize(9), None), - (512, F::from_usize(0), None), - (256, F::from_usize(8), Some(6)), - (1088, F::from_usize(9), Some(5)), - (512, F::from_usize(0), Some(5)), - (754, F::from_usize(4), Some(5)), - (1023, F::from_usize(7), Some(5)), - (2025, F::from_usize(11), Some(8)), - (16, F::from_usize(8), None), - (854, F::from_usize(0), None), - (854, F::from_usize(1), None), - (16, F::from_usize(0), None), - (754, F::from_usize(4), None), - (1023, F::from_usize(7), None), - (2025, F::from_usize(15), None), - (600, F::from_usize(100), None), - ]; - let mut public_data = BTreeMap::new(); - let mut polynomials = Vec::new(); - let mut dims = Vec::new(); - let mut statements_per_polynomial = Vec::new(); - for (pol_index, &(committed_length, default_value, log_public_data)) in - committed_length_lengths_and_default_value_and_log_public_data - .iter() - .enumerate() - { - let mut poly = (0..committed_length + log_public_data.map_or(0, |l| 1 << l)) - .map(|_| rng.random()) - .collect::>(); - poly.resize(poly.len().next_power_of_two(), default_value); - if let Some(log_public) = log_public_data { - public_data.insert(pol_index, poly[..1 << log_public].to_vec()); - } - let n_vars = log2_strict_usize(poly.len()); - let n_points = rng.random_range(1..5); - let mut statements = Vec::new(); - for _ in 0..n_points { - let point = MultilinearPoint((0..n_vars).map(|_| rng.random()).collect::>()); - let value = poly.evaluate(&point); - statements.push(Evaluation { point, value }); - } - polynomials.push(poly); - dims.push(ColDims { - n_vars, - log_public_data_size: log_public_data, - committed_size: committed_length, - default_value, - }); - statements_per_polynomial.push(statements); - } - statements_per_polynomial - .last_mut() - .unwrap() - .push(Evaluation::new(vec![EF::ONE; 10], EF::from_usize(100))); - - let mut prover_state = build_prover_state(false); - precompute_dft_twiddles::(1 << 24); - - let polynomials_ref = polynomials.iter().map(|p| p.as_slice()).collect::>(); - let witness = packed_pcs_commit( - &whir_config_builder, - &polynomials_ref, - &dims, - &mut prover_state, - log_smallest_decomposition_chunk, - ); - - let packed_statements = packed_pcs_global_statements_for_prover( - &polynomials_ref, - &dims, - log_smallest_decomposition_chunk, - &statements_per_polynomial, - &mut prover_state, - ); - let num_variables = witness.packed_polynomial.by_ref().n_vars(); - WhirConfig::new(whir_config_builder.clone(), num_variables).prove( - &mut prover_state, - packed_statements, - witness.inner_witness, - &witness.packed_polynomial.by_ref(), - ); - - let mut verifier_state = build_verifier_state(prover_state); - - let parsed_commitment = packed_pcs_parse_commitment( - &whir_config_builder, - &mut verifier_state, - &dims, - log_smallest_decomposition_chunk, - ) - .unwrap(); - let packed_statements = packed_pcs_global_statements_for_verifier( - &dims, - log_smallest_decomposition_chunk, - &statements_per_polynomial, - &mut verifier_state, - &public_data, - ) - .unwrap(); - WhirConfig::new(whir_config_builder, num_variables) - .verify(&mut verifier_state, &parsed_commitment, packed_statements) - .unwrap(); - } +fn compute_total_n_vars(log_memory: usize, tables_heights: &BTreeMap) -> usize { + let total_len = (2 << log_memory) + + tables_heights + .iter() + .map(|(table, log_n_rows)| table.n_commited_columns() << log_n_rows) + .sum::(); + log2_ceil_usize(total_len) } diff --git a/crates/sub_protocols/src/quotient_gkr.rs b/crates/sub_protocols/src/quotient_gkr.rs new file mode 100644 index 00000000..b7c2b18a --- /dev/null +++ b/crates/sub_protocols/src/quotient_gkr.rs @@ -0,0 +1,260 @@ +use multilinear_toolkit::prelude::*; +use tracing::instrument; + +use crate::MIN_VARS_FOR_PACKING; + +/* +GKR to compute sum of fractions. +*/ + +#[instrument(skip_all)] +pub fn prove_gkr_quotient>>( + prover_state: &mut impl FSProver, + numerators: &MleRef<'_, EF>, + denominators: &MleRef<'_, EF>, +) -> (EF, MultilinearPoint, EF, EF) { + assert!(numerators.is_packed() == denominators.is_packed()); + let mut layers: Vec<(Mle<'_, EF>, Mle<'_, EF>)> = + vec![(numerators.soft_clone().into(), denominators.soft_clone().into())]; + + loop { + let (mut prev_numerators, mut prev_denominators) = layers.last().cloned().unwrap(); + if prev_numerators.is_packed() && prev_numerators.n_vars() < MIN_VARS_FOR_PACKING { + (prev_numerators, prev_denominators) = ( + prev_numerators.unpack().as_owned_or_clone().into(), + prev_denominators.unpack().as_owned_or_clone().into(), + ) + } + if prev_numerators.n_vars() == 1 { + break; + } + let (new_numerators, new_denominators) = sum_quotients(prev_numerators.by_ref(), prev_denominators.by_ref()); + layers.push((new_numerators.into(), new_denominators.into())); + } + + let (last_numerators, last_denominators) = layers.pop().unwrap(); + let last_numerators = last_numerators.as_owned().unwrap(); + let last_numerators = last_numerators.as_extension().unwrap(); + let last_denominators = last_denominators.as_owned().unwrap(); + let last_denominators = last_denominators.as_extension().unwrap(); + prover_state.add_extension_scalars(last_numerators); + prover_state.add_extension_scalars(last_denominators); + let quotient = last_numerators[0] / last_denominators[0] + last_numerators[1] / last_denominators[1]; + + let mut point = MultilinearPoint(vec![prover_state.sample()]); + prover_state.duplexing(); + let mut claims = vec![last_numerators.evaluate(&point), last_denominators.evaluate(&point)]; + + for (nums, denoms) in layers.iter().rev() { + (point, claims) = prove_gkr_quotient_step(prover_state, nums, denoms, &point, claims); + } + assert_eq!(claims.len(), 2); + (quotient, point, claims[0], claims[1]) +} + +fn prove_gkr_quotient_step>>( + prover_state: &mut impl FSProver, + numerators: &Mle<'_, EF>, + denominators: &Mle<'_, EF>, + claim_point: &MultilinearPoint, + claims: Vec, +) -> (MultilinearPoint, Vec) { + let prev_numerators_and_denominators_split = match (numerators.by_ref(), denominators.by_ref()) { + (MleRef::ExtensionPacked(numerators), MleRef::ExtensionPacked(denominators)) => { + let (left_nums, right_nums) = numerators.split_at(numerators.len() / 2); + let (left_dens, right_dens) = denominators.split_at(denominators.len() / 2); + MleGroupRef::ExtensionPacked(vec![left_nums, right_nums, left_dens, right_dens]) + } + (MleRef::Extension(numerators), MleRef::Extension(denominators)) => { + let (left_nums, right_nums) = numerators.split_at(numerators.len() / 2); + let (left_dens, right_dens) = denominators.split_at(denominators.len() / 2); + MleGroupRef::Extension(vec![left_nums, right_nums, left_dens, right_dens]) + } + _ => unreachable!(), + }; + + let alpha = prover_state.sample(); + + let (mut next_point, inner_evals, _) = sumcheck_prove::( + 1, + prev_numerators_and_denominators_split, + None, + &GKRQuotientComputation::<2> {}, + &alpha.powers().take(2).collect(), + Some((claim_point.0.clone(), None)), + false, + prover_state, + dot_product(claims.iter().copied(), alpha.powers()), + false, + ); + + prover_state.add_extension_scalars(&inner_evals); + let beta = prover_state.sample(); + prover_state.duplexing(); + + let next_claims = inner_evals + .chunks_exact(2) + .map(|chunk| chunk.evaluate(&MultilinearPoint(vec![beta]))) + .collect::>(); + + next_point.0.insert(0, beta); + + (next_point, next_claims) +} + +pub fn verify_gkr_quotient>>( + verifier_state: &mut impl FSVerifier, + n_vars: usize, +) -> Result<(EF, MultilinearPoint, EF, EF), ProofError> { + let last_nums = verifier_state.next_extension_scalars_vec(2)?; + let last_dens = verifier_state.next_extension_scalars_vec(2)?; + + let quotient = last_nums[0] / last_dens[0] + last_nums[1] / last_dens[1]; + + let mut point = MultilinearPoint(vec![verifier_state.sample()]); + verifier_state.duplexing(); + let mut claims_num = last_nums.evaluate(&point); + let mut claims_den = last_dens.evaluate(&point); + for i in 1..n_vars { + (point, claims_num, claims_den) = verify_gkr_quotient_step(verifier_state, i, &point, claims_num, claims_den)?; + } + + Ok((quotient, point, claims_num, claims_den)) +} + +fn verify_gkr_quotient_step>>( + verifier_state: &mut impl FSVerifier, + n_vars: usize, + point: &MultilinearPoint, + claims_num: EF, + claims_den: EF, +) -> Result<(MultilinearPoint, EF, EF), ProofError> { + let alpha = verifier_state.sample(); + + let (retrieved_quotient, postponed) = sumcheck_verify(verifier_state, n_vars, 3)?; + + if retrieved_quotient != claims_num + alpha * claims_den { + return Err(ProofError::InvalidProof); + } + + let inner_evals = verifier_state.next_extension_scalars_vec(4)?; + + if postponed.value + != point.eq_poly_outside(&postponed.point) + * as SumcheckComputation>::eval_extension( + &Default::default(), + &inner_evals, + &[], + &alpha.powers().take(2).collect(), + ) + { + return Err(ProofError::InvalidProof); + } + + let beta = verifier_state.sample(); + verifier_state.duplexing(); + + let next_claims_numerators = (&inner_evals[..2]).evaluate(&MultilinearPoint(vec![beta])); + let next_claims_denominators = (&inner_evals[2..]).evaluate(&MultilinearPoint(vec![beta])); + let mut next_point = postponed.point.clone(); + next_point.0.insert(0, beta); + + Ok((next_point, next_claims_numerators, next_claims_denominators)) +} + +fn sum_quotients>>( + numerators: MleRef<'_, EF>, + denominators: MleRef<'_, EF>, +) -> (MleOwned, MleOwned) { + match (numerators, denominators) { + (MleRef::ExtensionPacked(numerators), MleRef::ExtensionPacked(denominators)) => { + let (new_numerators, new_denominators) = sum_quotients_2_by_2(numerators, denominators); + ( + MleOwned::ExtensionPacked(new_numerators), + MleOwned::ExtensionPacked(new_denominators), + ) + } + (MleRef::Extension(numerators), MleRef::Extension(denominators)) => { + let (new_numerators, new_denominators) = sum_quotients_2_by_2(numerators, denominators); + ( + MleOwned::Extension(new_numerators), + MleOwned::Extension(new_denominators), + ) + } + _ => unreachable!(), + } +} +fn sum_quotients_2_by_2( + numerators: &[F], + denominators: &[F], +) -> (Vec, Vec) { + let n = numerators.len(); + let new_n = n / 2; + let mut new_numerators = unsafe { uninitialized_vec(new_n) }; + let mut new_denominators = unsafe { uninitialized_vec(new_n) }; + new_numerators + .par_iter_mut() + .zip(new_denominators.par_iter_mut()) + .enumerate() + .for_each(|(i, (num, den))| { + let my_numerators: [_; 2] = [numerators[i], numerators[i + new_n]]; + let my_denominators: [_; 2] = [denominators[i], denominators[i + new_n]]; + *num = my_numerators[0] * my_denominators[1] + my_numerators[1] * my_denominators[0]; + *den = my_denominators[0] * my_denominators[1]; + }); + (new_numerators, new_denominators) +} + +#[cfg(test)] +mod tests { + use std::time::Instant; + + use super::*; + use p3_koala_bear::QuinticExtensionFieldKB; + use rand::{Rng, SeedableRng, rngs::StdRng}; + use utils::{build_prover_state, build_verifier_state, init_tracing}; + + type EF = QuinticExtensionFieldKB; + + fn sum_all_quotients(nums: &[EF], den: &[EF]) -> EF { + nums.iter().zip(den.iter()).map(|(&n, &d)| n / d).sum() + } + + const N_GROUPS: usize = 2; + + #[test] + fn test_gkr_quotient() { + let log_n = 13; + let n = 1 << log_n; + init_tracing(); + + let mut rng = StdRng::seed_from_u64(0); + + let numerators = (0..n).map(|_| rng.random()).collect::>(); + let c: EF = rng.random(); + let denominators_indexes = (0..n) + .map(|_| PF::::from_usize(rng.random_range(..n))) + .collect::>(); + let denominators = denominators_indexes.iter().map(|&i| c - i).collect::>(); + let real_quotient = sum_all_quotients(&numerators, &denominators); + let mut prover_state = build_prover_state(); + + let time = Instant::now(); + let prover_statements = prove_gkr_quotient::( + &mut prover_state, + &MleRef::ExtensionPacked(&pack_extension(&numerators)), + &MleRef::ExtensionPacked(&pack_extension(&denominators)), + ); + println!("Proving time: {:?}", time.elapsed()); + + let mut verifier_state = build_verifier_state(prover_state); + + let verifier_statements = verify_gkr_quotient::(&mut verifier_state, log_n).unwrap(); + assert_eq!(&verifier_statements, &prover_statements); + let (retrieved_quotient, claim_point, claim_num, claim_den) = verifier_statements; + + assert_eq!(retrieved_quotient, real_quotient); + assert_eq!(numerators.evaluate(&claim_point), claim_num); + assert_eq!(denominators.evaluate(&claim_point), claim_den); + } +} diff --git a/crates/sub_protocols/src/vectorized_packed_lookup.rs b/crates/sub_protocols/src/vectorized_packed_lookup.rs deleted file mode 100644 index 8fbb10ba..00000000 --- a/crates/sub_protocols/src/vectorized_packed_lookup.rs +++ /dev/null @@ -1,159 +0,0 @@ -use multilinear_toolkit::prelude::*; -use utils::FSProver; -use utils::VecOrSlice; -use utils::fold_multilinear_chunks; - -use crate::GenericPackedLookupProver; -use crate::GenericPackedLookupVerifier; -use crate::PackedLookupStatements; - -#[derive(Debug)] -pub struct VectorizedPackedLookupProver<'a, EF: ExtensionField>, const VECTOR_LEN: usize> { - generic: GenericPackedLookupProver<'a, EF, EF>, - folding_scalars: MultilinearPoint, -} - -impl<'a, EF: ExtensionField>, const VECTOR_LEN: usize> VectorizedPackedLookupProver<'a, EF, VECTOR_LEN> -where - PF: PrimeField64, -{ - pub fn pushforward_to_commit(&self) -> &[EF] { - self.generic.pushforward_to_commit() - } - - // before committing to the pushforward - #[allow(clippy::too_many_arguments)] - pub fn step_1( - prover_state: &mut FSProver>, - table: &'a [PF], // table[0] is assumed to be zero - index_columns: Vec<&'a [PF]>, - heights: Vec, - default_indexes: Vec, - value_columns: Vec<[&'a [PF]; VECTOR_LEN]>, - statements: Vec>>, - log_smallest_decomposition_chunk: usize, - ) -> Self { - let folding_scalars = MultilinearPoint(prover_state.sample_vec(log2_strict_usize(VECTOR_LEN))); - let folded_table = fold_multilinear_chunks(table, &folding_scalars); - - let folding_poly_eq = eval_eq(&folding_scalars); - let folded_value_columns = value_columns - .par_iter() - .map(|cols| { - let n = cols[0].len(); - assert!(cols.iter().all(|c| c.len() == n)); - assert!(n.is_power_of_two()); - vec![VecOrSlice::Vec( - (0..n) - .into_par_iter() - .map(|i| { - folding_poly_eq - .iter() - .enumerate() - .map(|(j, &coeff)| coeff * cols[j][i]) - .sum::() - }) - .collect::>(), - )] - }) - .collect::>(); - - let generic = GenericPackedLookupProver::<'_, EF, EF>::step_1( - prover_state, - VecOrSlice::Vec(folded_table), - index_columns, - heights, - default_indexes, - folded_value_columns, - get_folded_statements(statements, &folding_scalars), - log_smallest_decomposition_chunk, - ); - - Self { - generic, - folding_scalars, - } - } - - // after committing to the pushforward - pub fn step_2( - &self, - prover_state: &mut FSProver>, - non_zero_memory_size: usize, - ) -> PackedLookupStatements { - let mut statements = self - .generic - .step_2(prover_state, non_zero_memory_size.div_ceil(VECTOR_LEN)); - statements.on_table.point.extend(self.folding_scalars.0.clone()); - statements - } -} - -#[derive(Debug)] -pub struct VectorizedPackedLookupVerifier>, const VECTOR_LEN: usize> { - generic: GenericPackedLookupVerifier, - folding_scalars: MultilinearPoint, -} - -impl>, const VECTOR_LEN: usize> VectorizedPackedLookupVerifier -where - PF: PrimeField64, -{ - // before receiving the commitment to the pushforward - pub fn step_1( - verifier_state: &mut FSVerifier>, - heights: Vec, - default_indexes: Vec, - statements: Vec>>, - log_smallest_decomposition_chunk: usize, - table_initial_values: &[PF], - ) -> ProofResult { - let folding_scalars = MultilinearPoint(verifier_state.sample_vec(log2_strict_usize(VECTOR_LEN))); - let folded_table_initial_values = fold_multilinear_chunks( - &table_initial_values[..(table_initial_values.len() / VECTOR_LEN) * VECTOR_LEN], - &folding_scalars, - ); - - let generic = GenericPackedLookupVerifier::step_1::( - verifier_state, - heights, - default_indexes, - get_folded_statements(statements, &folding_scalars), - log_smallest_decomposition_chunk, - &folded_table_initial_values, - )?; - - Ok(Self { - generic, - folding_scalars, - }) - } - - // after receiving the commitment to the pushforward - pub fn step_2( - &self, - verifier_state: &mut FSVerifier>, - log_memory_size: usize, - ) -> ProofResult> { - let mut statements = self - .generic - .step_2(verifier_state, log_memory_size - log2_strict_usize(VECTOR_LEN))?; - statements.on_table.point.extend(self.folding_scalars.0.clone()); - Ok(statements) - } -} - -fn get_folded_statements( - statements: Vec>>, - folding_scalars: &MultilinearPoint, -) -> Vec>> { - statements - .iter() - .map(|sub_statements| { - sub_statements - .iter() - .map(|meval| MultiEvaluation::new(meval.point.clone(), vec![meval.values.evaluate(folding_scalars)])) - .collect::>() - }) - .collect::>() -} diff --git a/crates/sub_protocols/tests/test_generic_packed_lookup.rs b/crates/sub_protocols/tests/test_generic_packed_lookup.rs deleted file mode 100644 index 2c4e166f..00000000 --- a/crates/sub_protocols/tests/test_generic_packed_lookup.rs +++ /dev/null @@ -1,117 +0,0 @@ -use multilinear_toolkit::prelude::*; -use p3_koala_bear::{KoalaBear, QuinticExtensionFieldKB}; -use p3_util::log2_ceil_usize; -use rand::{Rng, SeedableRng, rngs::StdRng}; -use sub_protocols::{GenericPackedLookupProver, GenericPackedLookupVerifier}; -use utils::{ToUsize, VecOrSlice, assert_eq_many, build_prover_state, build_verifier_state}; - -type F = KoalaBear; -type EF = QuinticExtensionFieldKB; -const LOG_SMALLEST_DECOMPOSITION_CHUNK: usize = 5; - -#[test] -fn test_generic_packed_lookup() { - let non_zero_memory_size: usize = 37412; - let lookups_height_and_cols: Vec<(usize, usize)> = vec![(4587, 1), (1234, 3), (9411, 1), (7890, 2)]; - let default_indexes = vec![7, 11, 0, 2]; - let n_statements = [1, 5, 2, 1]; - assert_eq_many!(lookups_height_and_cols.len(), default_indexes.len(), n_statements.len()); - - let mut rng = StdRng::seed_from_u64(0); - let mut memory = F::zero_vec(non_zero_memory_size.next_power_of_two()); - for mem in memory.iter_mut().take(non_zero_memory_size).skip(1) { - *mem = rng.random(); - } - - let mut all_indexe_columns = vec![]; - let mut all_value_columns = vec![]; - let mut all_statements = vec![]; - for (i, (n_lines, n_cols)) in lookups_height_and_cols.iter().enumerate() { - let mut indexes = vec![F::from_usize(default_indexes[i]); n_lines.next_power_of_two()]; - for idx in indexes.iter_mut().take(*n_lines) { - *idx = F::from_usize(rng.random_range(0..non_zero_memory_size)); - } - all_indexe_columns.push(indexes); - let indexes = all_indexe_columns.last().unwrap(); - - let mut columns = vec![]; - for col_index in 0..*n_cols { - let mut col = F::zero_vec(n_lines.next_power_of_two()); - for i in 0..n_lines.next_power_of_two() { - col[i] = memory[indexes[i].to_usize() + col_index]; - } - columns.push(col); - } - let mut statements = vec![]; - for _ in 0..n_statements[i] { - let point = MultilinearPoint::::random(&mut rng, log2_ceil_usize(*n_lines)); - let values = columns.iter().map(|col| col.evaluate(&point)).collect::>(); - statements.push(MultiEvaluation::new(point, values)); - } - all_statements.push(statements); - all_value_columns.push(columns); - } - - let mut prover_state = build_prover_state(false); - - let packed_lookup_prover = GenericPackedLookupProver::step_1( - &mut prover_state, - VecOrSlice::Slice(&memory), - all_indexe_columns.iter().map(Vec::as_slice).collect(), - lookups_height_and_cols.iter().map(|(h, _)| *h).collect(), - default_indexes.clone(), - all_value_columns - .iter() - .map(|cols| cols.iter().map(|s| VecOrSlice::Slice(s)).collect()) - .collect(), - all_statements.clone(), - LOG_SMALLEST_DECOMPOSITION_CHUNK, - ); - - // phony commitment to pushforward - prover_state.hint_extension_scalars(packed_lookup_prover.pushforward_to_commit()); - - let remaining_claims_to_prove = packed_lookup_prover.step_2(&mut prover_state, non_zero_memory_size); - - let mut verifier_state = build_verifier_state(prover_state); - - let packed_lookup_verifier = GenericPackedLookupVerifier::step_1( - &mut verifier_state, - lookups_height_and_cols.iter().map(|(h, _)| *h).collect(), - default_indexes, - all_statements, - LOG_SMALLEST_DECOMPOSITION_CHUNK, - &memory[..100], - ) - .unwrap(); - - // receive commitment to pushforward - let pushforward = verifier_state - .receive_hint_extension_scalars(non_zero_memory_size.next_power_of_two()) - .unwrap(); - - let remaining_claims_to_verify = packed_lookup_verifier - .step_2(&mut verifier_state, log2_ceil_usize(non_zero_memory_size)) - .unwrap(); - - assert_eq!(&remaining_claims_to_prove, &remaining_claims_to_verify); - - assert_eq!( - memory.evaluate(&remaining_claims_to_verify.on_table.point), - remaining_claims_to_verify.on_table.value - ); - for pusforward_statement in &remaining_claims_to_verify.on_pushforward { - assert_eq!( - pushforward.evaluate(&pusforward_statement.point), - pusforward_statement.value - ); - } - for (index_col, index_statements) in all_indexe_columns - .iter() - .zip(remaining_claims_to_verify.on_indexes.iter()) - { - for statement in index_statements { - assert_eq!(index_col.evaluate(&statement.point), statement.value); - } - } -} diff --git a/crates/sub_protocols/tests/test_normal_packed_lookup.rs b/crates/sub_protocols/tests/test_normal_packed_lookup.rs deleted file mode 100644 index 458509a1..00000000 --- a/crates/sub_protocols/tests/test_normal_packed_lookup.rs +++ /dev/null @@ -1,161 +0,0 @@ -use multilinear_toolkit::prelude::*; -use p3_koala_bear::{KoalaBear, QuinticExtensionFieldKB}; -use p3_util::log2_ceil_usize; -use rand::{Rng, SeedableRng, rngs::StdRng}; -use sub_protocols::{NormalPackedLookupProver, NormalPackedLookupVerifier}; -use utils::{ToUsize, build_prover_state, build_verifier_state, collect_refs}; - -type F = KoalaBear; -type EF = QuinticExtensionFieldKB; -const LOG_SMALLEST_DECOMPOSITION_CHUNK: usize = 5; - -#[test] -fn test_normal_packed_lookup() { - let non_zero_memory_size: usize = 37412; - let cols_heights_f: Vec = vec![785, 1022, 4751]; - let cols_heights_ef: Vec = vec![2088, 110]; - let default_indexes_f = vec![7, 11, 0]; - let default_indexes_ef = vec![2, 3]; - let n_statements_f = vec![1, 5, 2]; - let n_statements_ef = vec![3, 4]; - - let mut rng = StdRng::seed_from_u64(0); - let mut memory = F::zero_vec(non_zero_memory_size.next_power_of_two()); - for mem in memory.iter_mut().take(non_zero_memory_size).skip(1) { - *mem = rng.random(); - } - - let mut all_indexe_columns_f = vec![]; - let mut all_indexe_columns_ef = vec![]; - for (i, height) in cols_heights_f.iter().enumerate() { - let mut indexes = vec![F::from_usize(default_indexes_f[i]); height.next_power_of_two()]; - for idx in indexes.iter_mut().take(*height) { - *idx = F::from_usize(rng.random_range(0..non_zero_memory_size)); - } - all_indexe_columns_f.push(indexes); - } - for (i, height) in cols_heights_ef.iter().enumerate() { - let mut indexes = vec![F::from_usize(default_indexes_ef[i]); height.next_power_of_two()]; - for idx in indexes.iter_mut().take(*height) { - *idx = - F::from_usize(rng.random_range(0..non_zero_memory_size - >>::DIMENSION)); - } - all_indexe_columns_ef.push(indexes); - } - - let mut value_columns_f = vec![]; - for base_col in &all_indexe_columns_f { - let mut values = vec![]; - for index in base_col { - values.push(memory[index.to_usize()]); - } - value_columns_f.push(values); - } - let mut value_columns_ef = vec![]; - for ext_col in &all_indexe_columns_ef { - let mut values = vec![]; - for index in ext_col { - values.push(QuinticExtensionFieldKB::from_basis_coefficients_fn(|i| { - memory[index.to_usize() + i] - })); - } - value_columns_ef.push(values); - } - - let mut all_statements_f = vec![]; - for (value_col_f, n_statements) in value_columns_f.iter().zip(&n_statements_f) { - let mut statements = vec![]; - for _ in 0..*n_statements { - let point = MultilinearPoint::::random(&mut rng, log2_strict_usize(value_col_f.len())); - let value = value_col_f.evaluate(&point); - statements.push(Evaluation::new(point, value)); - } - all_statements_f.push(statements); - } - let mut all_statements_ef = vec![]; - for (value_col_ef, n_statements) in value_columns_ef.iter().zip(&n_statements_ef) { - let mut statements = vec![]; - for _ in 0..*n_statements { - let point = MultilinearPoint::::random(&mut rng, log2_strict_usize(value_col_ef.len())); - let value = value_col_ef.evaluate(&point); - statements.push(Evaluation::new(point, value)); - } - all_statements_ef.push(statements); - } - - let mut prover_state = build_prover_state(false); - - let packed_lookup_prover = NormalPackedLookupProver::step_1( - &mut prover_state, - &memory, - collect_refs(&all_indexe_columns_f), - collect_refs(&all_indexe_columns_ef), - cols_heights_f.clone(), - cols_heights_ef.clone(), - default_indexes_f.clone(), - default_indexes_ef.clone(), - collect_refs(&value_columns_f), - collect_refs(&value_columns_ef), - all_statements_f.clone(), - all_statements_ef.clone(), - LOG_SMALLEST_DECOMPOSITION_CHUNK, - ); - - // phony commitment to pushforward - prover_state.hint_extension_scalars(packed_lookup_prover.pushforward_to_commit()); - - let remaining_claims_to_prove = packed_lookup_prover.step_2(&mut prover_state, non_zero_memory_size); - - let mut verifier_state = build_verifier_state(prover_state); - - let packed_lookup_verifier = NormalPackedLookupVerifier::step_1( - &mut verifier_state, - cols_heights_f, - cols_heights_ef, - default_indexes_f, - default_indexes_ef, - all_statements_f, - all_statements_ef, - LOG_SMALLEST_DECOMPOSITION_CHUNK, - &memory[..100], - ) - .unwrap(); - - // receive commitment to pushforward - let pushforward = verifier_state - .receive_hint_extension_scalars(non_zero_memory_size.next_power_of_two()) - .unwrap(); - - let remaining_claims_to_verify = packed_lookup_verifier - .step_2(&mut verifier_state, log2_ceil_usize(non_zero_memory_size)) - .unwrap(); - - assert_eq!(&remaining_claims_to_prove, &remaining_claims_to_verify); - - assert_eq!( - memory.evaluate(&remaining_claims_to_verify.on_table.point), - remaining_claims_to_verify.on_table.value - ); - for pusforward_statement in &remaining_claims_to_verify.on_pushforward { - assert_eq!( - pushforward.evaluate(&pusforward_statement.point), - pusforward_statement.value - ); - } - for (index_col, index_statements) in all_indexe_columns_f - .iter() - .zip(remaining_claims_to_verify.on_indexes_f.iter()) - { - for statement in index_statements { - assert_eq!(index_col.evaluate(&statement.point), statement.value); - } - } - for (index_col, index_statements) in all_indexe_columns_ef - .iter() - .zip(remaining_claims_to_verify.on_indexes_ef.iter()) - { - for statement in index_statements { - assert_eq!(index_col.evaluate(&statement.point), statement.value); - } - } -} diff --git a/crates/sub_protocols/tests/test_vectorized_packed_lookup.rs b/crates/sub_protocols/tests/test_vectorized_packed_lookup.rs deleted file mode 100644 index 8721962e..00000000 --- a/crates/sub_protocols/tests/test_vectorized_packed_lookup.rs +++ /dev/null @@ -1,123 +0,0 @@ -use std::array; - -use multilinear_toolkit::prelude::*; -use p3_koala_bear::{KoalaBear, QuinticExtensionFieldKB}; -use p3_util::log2_ceil_usize; -use rand::{Rng, SeedableRng, rngs::StdRng}; -use sub_protocols::{VectorizedPackedLookupProver, VectorizedPackedLookupVerifier}; -use utils::{ToUsize, assert_eq_many, build_prover_state, build_verifier_state}; - -type F = KoalaBear; -type EF = QuinticExtensionFieldKB; -const LOG_SMALLEST_DECOMPOSITION_CHUNK: usize = 5; - -const VECTOR_LEN: usize = 8; - -#[test] -fn test_vectorized_packed_lookup() { - let non_zero_memory_size: usize = 37412; - let cols_heights: Vec = vec![785, 1022, 4751]; - let default_indexes = vec![7, 11, 0]; - let n_statements = vec![1, 5, 2]; - assert_eq_many!(cols_heights.len(), default_indexes.len(), n_statements.len()); - - let mut rng = StdRng::seed_from_u64(0); - let mut memory = F::zero_vec(non_zero_memory_size.next_power_of_two()); - for mem in memory.iter_mut().take(non_zero_memory_size).skip(VECTOR_LEN) { - *mem = rng.random(); - } - - let mut all_indexe_columns = vec![]; - for (i, height) in cols_heights.iter().enumerate() { - let mut indexes = vec![F::from_usize(default_indexes[i]); height.next_power_of_two()]; - for idx in indexes.iter_mut().take(*height) { - *idx = F::from_usize(rng.random_range(0..non_zero_memory_size / VECTOR_LEN)); - } - all_indexe_columns.push(indexes); - } - - let mut all_value_columns = vec![]; - for index_col in &all_indexe_columns { - let mut values: [Vec; VECTOR_LEN] = Default::default(); - for index in index_col { - for i in 0..VECTOR_LEN { - values[i].push(memory[index.to_usize() * VECTOR_LEN + i]); - } - } - all_value_columns.push(values); - } - - let mut all_statements = vec![]; - for (value_cols, n_statements) in all_value_columns.iter().zip(&n_statements) { - let mut statements = vec![]; - for _ in 0..*n_statements { - let point = MultilinearPoint::::random(&mut rng, log2_strict_usize(value_cols[0].len())); - let values = value_cols.iter().map(|col| col.evaluate(&point)).collect::>(); - statements.push(MultiEvaluation::new(point, values)); - } - all_statements.push(statements); - } - - let mut prover_state = build_prover_state(false); - - let packed_lookup_prover = VectorizedPackedLookupProver::step_1( - &mut prover_state, - &memory, - all_indexe_columns.iter().map(Vec::as_slice).collect(), - cols_heights.clone(), - default_indexes.clone(), - all_value_columns - .iter() - .map(|v| array::from_fn::<_, VECTOR_LEN, _>(|i| v[i].as_slice())) - .collect(), - all_statements.clone(), - LOG_SMALLEST_DECOMPOSITION_CHUNK, - ); - - // phony commitment to pushforward - prover_state.hint_extension_scalars(packed_lookup_prover.pushforward_to_commit()); - - let remaining_claims_to_prove = packed_lookup_prover.step_2(&mut prover_state, non_zero_memory_size); - - let mut verifier_state = build_verifier_state(prover_state); - - let packed_lookup_verifier = VectorizedPackedLookupVerifier::<_, VECTOR_LEN>::step_1( - &mut verifier_state, - cols_heights, - default_indexes, - all_statements, - LOG_SMALLEST_DECOMPOSITION_CHUNK, - &memory[..100], - ) - .unwrap(); - - // receive commitment to pushforward - let pushforward = verifier_state - .receive_hint_extension_scalars(non_zero_memory_size.next_power_of_two() / VECTOR_LEN) - .unwrap(); - - let remaining_claims_to_verify = packed_lookup_verifier - .step_2(&mut verifier_state, log2_ceil_usize(non_zero_memory_size)) - .unwrap(); - - assert_eq!(&remaining_claims_to_prove, &remaining_claims_to_verify); - - assert_eq!( - memory.evaluate(&remaining_claims_to_verify.on_table.point), - remaining_claims_to_verify.on_table.value - ); - for pusforward_statement in &remaining_claims_to_verify.on_pushforward { - assert_eq!( - pushforward.evaluate(&pusforward_statement.point), - pusforward_statement.value - ); - } - for (index_col, index_statements) in all_indexe_columns - .iter() - .zip(remaining_claims_to_verify.on_indexes.iter()) - { - for statement in index_statements { - assert_eq!(index_col.evaluate(&statement.point), statement.value); - } - } -} diff --git a/crates/utils/Cargo.toml b/crates/utils/Cargo.toml index fa686f7a..27e4ecd6 100644 --- a/crates/utils/Cargo.toml +++ b/crates/utils/Cargo.toml @@ -7,8 +7,6 @@ edition.workspace = true workspace = true [dependencies] -p3-air.workspace = true -p3-challenger.workspace = true p3-koala-bear.workspace = true tracing-forest.workspace = true p3-symmetric.workspace = true diff --git a/crates/utils/src/constraints_checker.rs b/crates/utils/src/constraints_checker.rs index 0750f672..6f03fd9c 100644 --- a/crates/utils/src/constraints_checker.rs +++ b/crates/utils/src/constraints_checker.rs @@ -1,5 +1,3 @@ -use p3_air::AirBuilder; - use multilinear_toolkit::prelude::*; /* diff --git a/crates/utils/src/misc.rs b/crates/utils/src/misc.rs index 480b7d4c..12d33dc8 100644 --- a/crates/utils/src/misc.rs +++ b/crates/utils/src/misc.rs @@ -53,6 +53,12 @@ pub fn to_little_endian_bits(value: usize, bit_count: usize) -> Vec { res } +pub fn to_little_endian_in_field(value: usize, bit_count: usize) -> Vec { + let mut res = to_big_endian_in_field::(value, bit_count); + res.reverse(); + res +} + #[macro_export] macro_rules! assert_eq_many { ($first:expr, $($rest:expr),+ $(,)?) => { @@ -126,3 +132,22 @@ pub fn encapsulate_vec(v: Vec) -> Vec> { pub fn collect_refs(vecs: &[Vec]) -> Vec<&[T]> { vecs.iter().map(Vec::as_slice).collect() } + +pub fn collect_inner_refs(vecs: &[Vec>]) -> Vec> { + vecs.iter().map(|v| collect_refs(v)).collect() +} + +#[derive(Debug, Clone, Default)] +pub struct Counter(usize); + +impl Counter { + pub fn get_next(&mut self) -> usize { + let val = self.0; + self.0 += 1; + val + } + + pub fn new() -> Self { + Self(0) + } +} diff --git a/crates/utils/src/multilinear.rs b/crates/utils/src/multilinear.rs index b4d50991..fd516733 100644 --- a/crates/utils/src/multilinear.rs +++ b/crates/utils/src/multilinear.rs @@ -107,6 +107,27 @@ pub fn fold_multilinear_chunks>( .collect() } +pub fn mle_of_01234567_etc(point: &[F]) -> F { + if point.is_empty() { + F::ZERO + } else { + let e = mle_of_01234567_etc(&point[1..]); + (F::ONE - point[0]) * e + point[0] * (e + F::from_usize(1 << (point.len() - 1))) + } +} + +/// table = 0 is reversed for memory +pub const MEMORY_TABLE_INDEX: usize = 0; + +pub fn finger_print>, EF: ExtensionField + ExtensionField>( + table: F, + data: &[IF], + alpha_powers: &[EF], +) -> EF { + assert!(alpha_powers.len() > data.len()); + dot_product::(alpha_powers[1..].iter().copied(), data.iter().copied()) + table +} + #[cfg(test)] mod tests { use p3_koala_bear::{KoalaBear, QuinticExtensionFieldKB}; @@ -147,4 +168,17 @@ mod tests { assert_eq!(eval, pol.evaluate(&MultilinearPoint(point.clone()))); } } + + #[test] + fn test_mle_of_01234567_etc() { + let n_vars = 10; + let mut rng = StdRng::seed_from_u64(0); + let point = (0..n_vars).map(|_| rng.random()).collect::>(); + let eval = mle_of_01234567_etc(&point); + let mut pol = F::zero_vec(1 << n_vars); + for (i, p) in pol.iter_mut().enumerate().take(1 << n_vars) { + *p = F::from_usize(i); + } + assert_eq!(eval, pol.evaluate(&MultilinearPoint(point))); + } } diff --git a/crates/utils/src/poseidon2.rs b/crates/utils/src/poseidon2.rs index e940ad5f..59beed64 100644 --- a/crates/utils/src/poseidon2.rs +++ b/crates/utils/src/poseidon2.rs @@ -3,9 +3,6 @@ use std::sync::OnceLock; use p3_koala_bear::KOALABEAR_RC16_EXTERNAL_FINAL; use p3_koala_bear::KOALABEAR_RC16_EXTERNAL_INITIAL; use p3_koala_bear::KOALABEAR_RC16_INTERNAL; -use p3_koala_bear::KOALABEAR_RC24_EXTERNAL_FINAL; -use p3_koala_bear::KOALABEAR_RC24_EXTERNAL_INITIAL; -use p3_koala_bear::KOALABEAR_RC24_INTERNAL; use p3_koala_bear::KoalaBear; use p3_koala_bear::Poseidon2KoalaBear; use p3_poseidon2::ExternalLayerConstants; @@ -18,15 +15,11 @@ pub const QUARTER_FULL_ROUNDS_16: usize = 2; pub const HALF_FULL_ROUNDS_16: usize = 4; pub const PARTIAL_ROUNDS_16: usize = 20; -pub const QUARTER_FULL_ROUNDS_24: usize = 2; -pub const HALF_FULL_ROUNDS_24: usize = 4; -pub const PARTIAL_ROUNDS_24: usize = 23; - static POSEIDON_16_INSTANCE: OnceLock = OnceLock::new(); static POSEIDON_16_OF_ZERO: OnceLock<[KoalaBear; 16]> = OnceLock::new(); #[inline(always)] -pub(crate) fn get_poseidon16() -> &'static Poseidon16 { +pub fn get_poseidon16() -> &'static Poseidon16 { POSEIDON_16_INSTANCE.get_or_init(|| { let external_constants = ExternalLayerConstants::new( KOALABEAR_RC16_EXTERNAL_INITIAL.to_vec(), @@ -50,32 +43,3 @@ pub fn poseidon16_permute(input: [KoalaBear; 16]) -> [KoalaBear; 16] { pub fn poseidon16_permute_mut(input: &mut [KoalaBear; 16]) { get_poseidon16().permute_mut(input); } - -#[inline(always)] -pub fn poseidon24_permute(input: [KoalaBear; 24]) -> [KoalaBear; 24] { - get_poseidon24().permute(input) -} - -#[inline(always)] -pub fn poseidon24_permute_mut(input: &mut [KoalaBear; 24]) { - get_poseidon24().permute_mut(input); -} - -static POSEIDON_24_INSTANCE: OnceLock = OnceLock::new(); -static POSEIDON_24_OF_ZERO: OnceLock<[KoalaBear; 24]> = OnceLock::new(); - -#[inline(always)] -pub(crate) fn get_poseidon24() -> &'static Poseidon24 { - POSEIDON_24_INSTANCE.get_or_init(|| { - let external_constants = ExternalLayerConstants::new( - KOALABEAR_RC24_EXTERNAL_INITIAL.to_vec(), - KOALABEAR_RC24_EXTERNAL_FINAL.to_vec(), - ); - Poseidon24::new(external_constants, KOALABEAR_RC24_INTERNAL.to_vec()) - }) -} - -#[inline(always)] -pub fn get_poseidon_24_of_zero() -> &'static [KoalaBear; 24] { - POSEIDON_24_OF_ZERO.get_or_init(|| poseidon24_permute([KoalaBear::default(); 24])) -} diff --git a/crates/utils/src/wrappers.rs b/crates/utils/src/wrappers.rs index 29030d98..0c2aca6d 100644 --- a/crates/utils/src/wrappers.rs +++ b/crates/utils/src/wrappers.rs @@ -1,27 +1,23 @@ use multilinear_toolkit::prelude::*; -use p3_challenger::DuplexChallenger; -use p3_koala_bear::KoalaBear; +use p3_koala_bear::QuinticExtensionFieldKB; use crate::Poseidon16; use crate::get_poseidon16; -pub type FSProver = ProverState, EF, Challenger>; -pub type FSVerifier = VerifierState, EF, Challenger>; +pub type VarCount = usize; -pub type MyChallenger = DuplexChallenger; - -pub fn build_challenger() -> MyChallenger { - MyChallenger::new(get_poseidon16().clone()) -} - -pub fn build_prover_state>(padding: bool) -> ProverState { - ProverState::new(build_challenger(), padding) +pub fn build_prover_state() -> ProverState { + let mut prover_state = ProverState::new(get_poseidon16().clone()); + prover_state.duplexing(); + prover_state } -pub fn build_verifier_state>( - prover_state: ProverState, -) -> VerifierState { - VerifierState::new(prover_state.into_proof(), build_challenger()) +pub fn build_verifier_state( + prover_state: ProverState, +) -> VerifierState { + let mut verifier_state = VerifierState::new(prover_state.into_proof(), get_poseidon16().clone()); + verifier_state.duplexing(); + verifier_state } pub trait ToUsize { diff --git a/crates/xmss/src/lib.rs b/crates/xmss/src/lib.rs index 3c2e8ee2..3dbae626 100644 --- a/crates/xmss/src/lib.rs +++ b/crates/xmss/src/lib.rs @@ -1,8 +1,12 @@ +/* +Toy (unsecure) XMSS, intended for benchmark only. +Production-grade XMSS SOON. +*/ + #![cfg_attr(not(test), warn(unused_crate_dependencies))] use p3_koala_bear::KoalaBear; - mod wots; -use utils::{poseidon16_permute, poseidon24_permute}; +use utils::poseidon16_permute; pub use wots::*; mod xmss; pub use xmss::*; @@ -22,7 +26,6 @@ pub const XMSS_MIN_LOG_LIFETIME: usize = 2; pub const XMSS_MAX_LOG_LIFETIME: usize = 30; pub type Poseidon16History = Vec<([F; 16], [F; 16])>; -pub type Poseidon24History = Vec<([F; 24], [F; 8])>; fn poseidon16_compress(a: &Digest, b: &Digest) -> Digest { poseidon16_permute([*a, *b].concat().try_into().unwrap())[0..8] @@ -36,21 +39,3 @@ fn poseidon16_compress_with_trace(a: &Digest, b: &Digest, poseidon_16_trace: &mu poseidon_16_trace.push((input, output)); output[0..8].try_into().unwrap() } - -fn poseidon24_compress(a: &Digest, b: &Digest, c: &Digest) -> Digest { - poseidon24_permute([*a, *b, *c].concat().try_into().unwrap())[16..24] - .try_into() - .unwrap() -} - -fn poseidon24_compress_with_trace( - a: &Digest, - b: &Digest, - c: &Digest, - poseidon_24_trace: &mut Vec<([F; 24], [F; 8])>, -) -> Digest { - let input: [F; 24] = [*a, *b, *c].concat().try_into().unwrap(); - let output = poseidon24_permute(input)[16..24].try_into().unwrap(); - poseidon_24_trace.push((input, output)); - output -} diff --git a/crates/xmss/src/wots.rs b/crates/xmss/src/wots.rs index 90931e4c..7e55f575 100644 --- a/crates/xmss/src/wots.rs +++ b/crates/xmss/src/wots.rs @@ -68,10 +68,9 @@ impl WotsPublicKey { self.hash_with_poseidon_trace(&mut Vec::new()) } - pub fn hash_with_poseidon_trace(&self, poseidon_24_trace: &mut Poseidon24History) -> Digest { - assert!(V.is_multiple_of(2), "V must be even for hashing pairs."); - self.0.chunks_exact(2).fold(Digest::default(), |digest, chunk| { - poseidon24_compress_with_trace(&chunk[0], &chunk[1], &digest, poseidon_24_trace) + pub fn hash_with_poseidon_trace(&self, poseidon_16_trace: &mut Poseidon16History) -> Digest { + self.0.iter().fold(Digest::default(), |digest, chunk| { + poseidon16_compress_with_trace(&digest, chunk, poseidon_16_trace) }) } } diff --git a/crates/xmss/src/xmss.rs b/crates/xmss/src/xmss.rs index 927513a0..bcce0743 100644 --- a/crates/xmss/src/xmss.rs +++ b/crates/xmss/src/xmss.rs @@ -164,7 +164,7 @@ pub fn xmss_verify_with_poseidon_trace( pub_key: &XmssPublicKey, message_hash: &Digest, signature: &XmssSignature, -) -> Result<(Poseidon16History, Poseidon24History), XmssVerifyError> { +) -> Result { if signature.slot < pub_key.first_slot { return Err(XmssVerifyError::SlotTooEarly); } @@ -173,13 +173,12 @@ pub fn xmss_verify_with_poseidon_trace( return Err(XmssVerifyError::SlotTooLate); } let mut poseidon_16_trace = Vec::new(); - let mut poseidon_24_trace = Vec::new(); let wots_public_key = signature .wots_signature .recover_public_key_with_poseidon_trace(message_hash, &signature.wots_signature, &mut poseidon_16_trace) .ok_or(XmssVerifyError::InvalidWots)?; // merkle root verification - let mut current_hash = wots_public_key.hash_with_poseidon_trace(&mut poseidon_24_trace); + let mut current_hash = wots_public_key.hash_with_poseidon_trace(&mut poseidon_16_trace); if signature.merkle_proof.len() != pub_key.log_lifetime { return Err(XmssVerifyError::InvalidMerklePath); } @@ -192,7 +191,7 @@ pub fn xmss_verify_with_poseidon_trace( } } if current_hash == pub_key.merkle_root { - Ok((poseidon_16_trace, poseidon_24_trace)) + Ok(poseidon_16_trace) } else { Err(XmssVerifyError::InvalidMerklePath) } diff --git a/docs/Whirlaway.pdf b/docs/Whirlaway.pdf deleted file mode 100644 index 87f909cb..00000000 Binary files a/docs/Whirlaway.pdf and /dev/null differ diff --git a/docs/XMSS_trivial_encoding.pdf b/docs/XMSS_trivial_encoding.pdf deleted file mode 100644 index 3d2b4369..00000000 Binary files a/docs/XMSS_trivial_encoding.pdf and /dev/null differ diff --git a/docs/benchmark_graphs/graphs/raw_poseidons.svg b/docs/benchmark_graphs/graphs/raw_poseidons.svg deleted file mode 100644 index 7e569814..00000000 --- a/docs/benchmark_graphs/graphs/raw_poseidons.svg +++ /dev/null @@ -1,1636 +0,0 @@ - - - - - - - - 2025-12-04T22:06:18.732680 - image/svg+xml - - - Matplotlib v3.10.5, https://matplotlib.org/ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/docs/benchmark_graphs/graphs/recursive_whir_opening.svg b/docs/benchmark_graphs/graphs/recursive_whir_opening.svg deleted file mode 100644 index df50e22d..00000000 --- a/docs/benchmark_graphs/graphs/recursive_whir_opening.svg +++ /dev/null @@ -1,1957 +0,0 @@ - - - - - - - - 2025-12-05T09:04:06.823546 - image/svg+xml - - - Matplotlib v3.10.5, https://matplotlib.org/ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/docs/benchmark_graphs/graphs/xmss_aggregated.svg b/docs/benchmark_graphs/graphs/xmss_aggregated.svg deleted file mode 100644 index 811e5210..00000000 --- a/docs/benchmark_graphs/graphs/xmss_aggregated.svg +++ /dev/null @@ -1,1676 +0,0 @@ - - - - - - - - 2025-12-05T09:04:06.893738 - image/svg+xml - - - Matplotlib v3.10.5, https://matplotlib.org/ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/docs/benchmark_graphs/graphs/xmss_aggregated_overhead.svg b/docs/benchmark_graphs/graphs/xmss_aggregated_overhead.svg deleted file mode 100644 index 581edb58..00000000 --- a/docs/benchmark_graphs/graphs/xmss_aggregated_overhead.svg +++ /dev/null @@ -1,1780 +0,0 @@ - - - - - - - - 2025-12-05T09:04:06.953527 - image/svg+xml - - - Matplotlib v3.10.5, https://matplotlib.org/ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/docs/benchmark_graphs/main.py b/docs/benchmark_graphs/main.py deleted file mode 100644 index 3c2eaaf3..00000000 --- a/docs/benchmark_graphs/main.py +++ /dev/null @@ -1,212 +0,0 @@ -import matplotlib.pyplot as plt -import matplotlib.dates as mdates -from matplotlib.ticker import ScalarFormatter, LogLocator -from datetime import datetime, timedelta - -# uv run python main.py - -N_DAYS_SHOWN = 130 - -plt.rcParams.update({ - 'font.size': 12, # Base font size - 'axes.titlesize': 14, # Title font size - 'axes.labelsize': 12, # X and Y label font size - 'xtick.labelsize': 12, # X tick label font size - 'ytick.labelsize': 12, # Y tick label font size - 'legend.fontsize': 12, # Legend font size -}) - - -def create_duration_graph(data, target=None, target_label=None, title="", y_legend="", file="", labels=None, log_scale=False): - if labels is None: - labels = ["Series 1"] - - # Number of curves based on tuple length - num_curves = len(data[0]) - 1 if data else 1 # -1 for the date - - dates = [] - values = [[] for _ in range(num_curves)] - - for item in data: - dates.append(datetime.strptime(item[0], '%Y-%m-%d')) - for i in range(num_curves): - values[i].append(item[i + 1]) - - colors = ['#2E86AB', "#FF0000", '#28A745', "#FF7B00", "#9E01FF"] - markers = ['o', 's', 's', '^', '^'] - - _, ax = plt.subplots(figsize=(10, 6)) - - all_values = [] - - for i in range(num_curves): - if i >= len(labels): - break # No label provided for this curve - - # Filter out None values - dates_filtered = [d for d, v in zip(dates, values[i]) if v is not None] - values_filtered = [v for v in values[i] if v is not None] - - if values_filtered: # Only plot if there's data - ax.plot(dates_filtered, values_filtered, - marker=markers[i % len(markers)], - linewidth=2, - markersize=8, - color=colors[i % len(colors)], - label=labels[i]) - all_values.extend(values_filtered) - - min_date = min(dates) - max_date = max(dates) - date_range = max_date - min_date - if date_range < timedelta(days=N_DAYS_SHOWN): - max_date = min_date + timedelta(days=N_DAYS_SHOWN) - ax.set_xlim(min_date - timedelta(days=1), max_date + timedelta(days=1)) - - ax.xaxis.set_major_locator(mdates.WeekdayLocator(interval=1)) - ax.xaxis.set_major_formatter(mdates.DateFormatter('%m/%d')) - - plt.setp(ax.xaxis.get_majorticklabels(), rotation=50, ha='right') - - if target is not None and target_label is not None: - ax.axhline(y=target, color='#555555', linestyle='--', - linewidth=2, label=target_label) - - ax.set_ylabel(y_legend, fontsize=12) - ax.set_title(title, fontsize=16, pad=15) - ax.grid(True, alpha=0.3, which='both') - ax.legend() - - if log_scale: - ax.set_yscale('log') - - ax.yaxis.set_major_locator(LogLocator(base=10.0, numticks=15)) - ax.yaxis.set_minor_locator(LogLocator(base=10.0, subs=range(1, 10), numticks=100)) - - ax.yaxis.set_major_formatter(ScalarFormatter()) - ax.yaxis.set_minor_formatter(ScalarFormatter()) - ax.yaxis.get_major_formatter().set_scientific(False) - ax.yaxis.get_minor_formatter().set_scientific(False) - - ax.tick_params(axis='y', which='minor', labelsize=10) - else: - if all_values: - max_value = max(all_values) - if target is not None: - max_value = max(max_value, target) - ax.set_ylim(0, max_value * 1.1) - - plt.tight_layout() - plt.savefig(f'graphs/{file}.svg', format='svg', bbox_inches='tight') - - -if __name__ == "__main__": - - create_duration_graph( - data=[ - ('2025-08-27', 85000, None,), - ('2025-08-30', 95000, None), - ('2025-09-09', 108000, None), - ('2025-09-14', 108000, None), - ('2025-09-28', 125000, None), - ('2025-10-01', 185000, None), - ('2025-10-12', 195000, None), - ('2025-10-13', 205000, None), - ('2025-10-18', 210000, 620_000), - ('2025-10-27', 610_000, 1_250_000), - ('2025-10-29', 650_000, 1_300_000), - ], - target=1_500_000, - target_label="Target (1.5M Poseidon2 / s)", - title="Raw Poseidon2", - y_legend="Poseidons proven / s", - file="raw_poseidons", - labels=["i9-12900H", "m4-max"], - log_scale=False - ) - - create_duration_graph( - data=[ - ('2025-08-27', 2.7, None, None, None, None), - ('2025-09-07', 1.4, None, None, None, None), - ('2025-09-09', 1.32, None, None, None, None), - ('2025-09-10', 0.970, None, None, None, None), - ('2025-09-14', 0.825, None, None, None, None), - ('2025-09-28', 0.725, None, None, None, None), - ('2025-10-01', 0.685, None, None, None, None), - ('2025-10-03', 0.647, None, None, None, None), - ('2025-10-12', 0.569, None, None, None, None), - ('2025-10-13', 0.521, None, None, None, None), - ('2025-10-18', 0.411, 0.320, None, None, None), - ('2025-10-27', 0.425, 0.330, None, None, None), - ('2025-11-15', 0.417, 0.330, None, None, None), - ('2025-12-04', None, 0.220, 0.097, 0.320, 0.130), - ('2025-12-05', None, 0.217, 0.092, 0.305, 0.125), - ], - target=0.1, - target_label="Target (0.1 s)", - title="Recursive WHIR opening (log scale)", - y_legend="Proving time (s)", - file="recursive_whir_opening", - labels=["i9-12900H | 1-to-1", "m4-max | 1-to-1", "m4-max | n-to-1", "m4-max | lean-vm-simple | 1-to-1", "m4-max | lean-vm-simple | n-to-1"], - log_scale=True - ) - - create_duration_graph( - data=[ - ('2025-08-27', 35, None, None), - ('2025-09-02', 37, None, None), - ('2025-09-03', 53, None, None), - ('2025-09-09', 62, None, None), - ('2025-09-10', 76, None, None), - ('2025-09-14', 107, None, None), - ('2025-09-28', 137, None, None), - ('2025-10-01', 172, None, None), - ('2025-10-03', 177, None, None), - ('2025-10-07', 193, None, None), - ('2025-10-12', 214, None, None), - ('2025-10-13', 234, None, None), - ('2025-10-18', 255, 465, None), - ('2025-10-27', 314, 555, None), - ('2025-11-02', 350, 660, None), - ('2025-11-15', 380, 720, None), - ('2025-12-04', None, 940, 755), - ('2025-12-05', None, 971, 813), - ], - target=1000, - target_label="Target (1000 XMSS/s)", - title="number of XMSS aggregated / s", - y_legend="", - file="xmss_aggregated", - labels=["i9-12900H", "m4-max", "m4-max | lean-vm-simple"] - ) - - create_duration_graph( - data=[ - ('2025-08-27', 14.2 / 0.92, None, None), - ('2025-09-02', 13.5 / 0.82, None, None), - ('2025-09-03', 9.4 / 0.82, None, None), - ('2025-09-09', 8.02 / 0.72, None, None), - ('2025-09-10', 6.53 / 0.72, None, None), - ('2025-09-14', 4.65 / 0.72, None, None), - ('2025-09-28', 3.63 / 0.63, None, None), - ('2025-10-01', 2.9 / 0.42, None, None), - ('2025-10-03', 2.81 / 0.42, None, None), - ('2025-10-07', 2.59 / 0.42, None, None), - ('2025-10-12', 2.33 / 0.40, None, None), - ('2025-10-13', 2.13 / 0.38, None, None), - ('2025-10-18', 1.96 / 0.37, 1.07 / 0.12, None), - ('2025-10-27', (610_000 / 157) / 314, (1_250_000 / 157) / 555, None), - ('2025-10-29', (650_000 / 157) / 314, (1_300_000 / 157) / 555, None), - ('2025-11-02', (650_000 / 157) / 350, (1_300_000 / 157) / 660, None), - ('2025-11-15', (650_000 / 157) / 380, (1_300_000 / 157) / 720, None), - ('2025-12-04', None, (1_300_000 / 157) / 940, (1_300_000 / 157) / 755), - ('2025-12-05', None, (1_300_000 / 157) / 971, (1_300_000 / 157) / 813), - ], - target=2, - target_label="Target (2x)", - title="XMSS aggregated: zkVM overhead vs raw Poseidons", - y_legend="", - file="xmss_aggregated_overhead", - labels=["i9-12900H", "m4-max", "m4-max | lean-vm-simple"] - ) \ No newline at end of file diff --git a/docs/benchmark_graphs/pyproject.toml b/docs/benchmark_graphs/pyproject.toml deleted file mode 100644 index 41fdcd46..00000000 --- a/docs/benchmark_graphs/pyproject.toml +++ /dev/null @@ -1,7 +0,0 @@ -[project] -name = "benchmark_graphs" -version = "0.1.0" -requires-python = ">=3.12" -dependencies = [ - "matplotlib>=3.10.5", -] diff --git a/docs/benchmark_graphs/uv.lock b/docs/benchmark_graphs/uv.lock deleted file mode 100644 index d03b2c1f..00000000 --- a/docs/benchmark_graphs/uv.lock +++ /dev/null @@ -1,424 +0,0 @@ -version = 1 -revision = 3 -requires-python = ">=3.12" - -[[package]] -name = "benchmark-graphs" -version = "0.1.0" -source = { virtual = "." } -dependencies = [ - { name = "matplotlib" }, -] - -[package.metadata] -requires-dist = [{ name = "matplotlib", specifier = ">=3.10.5" }] - -[[package]] -name = "contourpy" -version = "1.3.3" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "numpy" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/58/01/1253e6698a07380cd31a736d248a3f2a50a7c88779a1813da27503cadc2a/contourpy-1.3.3.tar.gz", hash = "sha256:083e12155b210502d0bca491432bb04d56dc3432f95a979b429f2848c3dbe880", size = 13466174, upload-time = "2025-07-26T12:03:12.549Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/be/45/adfee365d9ea3d853550b2e735f9d66366701c65db7855cd07621732ccfc/contourpy-1.3.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:b08a32ea2f8e42cf1d4be3169a98dd4be32bafe4f22b6c4cb4ba810fa9e5d2cb", size = 293419, upload-time = "2025-07-26T12:01:21.16Z" }, - { url = "https://files.pythonhosted.org/packages/53/3e/405b59cfa13021a56bba395a6b3aca8cec012b45bf177b0eaf7a202cde2c/contourpy-1.3.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:556dba8fb6f5d8742f2923fe9457dbdd51e1049c4a43fd3986a0b14a1d815fc6", size = 273979, upload-time = "2025-07-26T12:01:22.448Z" }, - { url = "https://files.pythonhosted.org/packages/d4/1c/a12359b9b2ca3a845e8f7f9ac08bdf776114eb931392fcad91743e2ea17b/contourpy-1.3.3-cp312-cp312-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:92d9abc807cf7d0e047b95ca5d957cf4792fcd04e920ca70d48add15c1a90ea7", size = 332653, upload-time = "2025-07-26T12:01:24.155Z" }, - { url = "https://files.pythonhosted.org/packages/63/12/897aeebfb475b7748ea67b61e045accdfcf0d971f8a588b67108ed7f5512/contourpy-1.3.3-cp312-cp312-manylinux_2_26_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:b2e8faa0ed68cb29af51edd8e24798bb661eac3bd9f65420c1887b6ca89987c8", size = 379536, upload-time = "2025-07-26T12:01:25.91Z" }, - { url = "https://files.pythonhosted.org/packages/43/8a/a8c584b82deb248930ce069e71576fc09bd7174bbd35183b7943fb1064fd/contourpy-1.3.3-cp312-cp312-manylinux_2_26_s390x.manylinux_2_28_s390x.whl", hash = "sha256:626d60935cf668e70a5ce6ff184fd713e9683fb458898e4249b63be9e28286ea", size = 384397, upload-time = "2025-07-26T12:01:27.152Z" }, - { url = "https://files.pythonhosted.org/packages/cc/8f/ec6289987824b29529d0dfda0d74a07cec60e54b9c92f3c9da4c0ac732de/contourpy-1.3.3-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4d00e655fcef08aba35ec9610536bfe90267d7ab5ba944f7032549c55a146da1", size = 362601, upload-time = "2025-07-26T12:01:28.808Z" }, - { url = "https://files.pythonhosted.org/packages/05/0a/a3fe3be3ee2dceb3e615ebb4df97ae6f3828aa915d3e10549ce016302bd1/contourpy-1.3.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:451e71b5a7d597379ef572de31eeb909a87246974d960049a9848c3bc6c41bf7", size = 1331288, upload-time = "2025-07-26T12:01:31.198Z" }, - { url = "https://files.pythonhosted.org/packages/33/1d/acad9bd4e97f13f3e2b18a3977fe1b4a37ecf3d38d815333980c6c72e963/contourpy-1.3.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:459c1f020cd59fcfe6650180678a9993932d80d44ccde1fa1868977438f0b411", size = 1403386, upload-time = "2025-07-26T12:01:33.947Z" }, - { url = "https://files.pythonhosted.org/packages/cf/8f/5847f44a7fddf859704217a99a23a4f6417b10e5ab1256a179264561540e/contourpy-1.3.3-cp312-cp312-win32.whl", hash = "sha256:023b44101dfe49d7d53932be418477dba359649246075c996866106da069af69", size = 185018, upload-time = "2025-07-26T12:01:35.64Z" }, - { url = "https://files.pythonhosted.org/packages/19/e8/6026ed58a64563186a9ee3f29f41261fd1828f527dd93d33b60feca63352/contourpy-1.3.3-cp312-cp312-win_amd64.whl", hash = "sha256:8153b8bfc11e1e4d75bcb0bff1db232f9e10b274e0929de9d608027e0d34ff8b", size = 226567, upload-time = "2025-07-26T12:01:36.804Z" }, - { url = "https://files.pythonhosted.org/packages/d1/e2/f05240d2c39a1ed228d8328a78b6f44cd695f7ef47beb3e684cf93604f86/contourpy-1.3.3-cp312-cp312-win_arm64.whl", hash = "sha256:07ce5ed73ecdc4a03ffe3e1b3e3c1166db35ae7584be76f65dbbe28a7791b0cc", size = 193655, upload-time = "2025-07-26T12:01:37.999Z" }, - { url = "https://files.pythonhosted.org/packages/68/35/0167aad910bbdb9599272bd96d01a9ec6852f36b9455cf2ca67bd4cc2d23/contourpy-1.3.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:177fb367556747a686509d6fef71d221a4b198a3905fe824430e5ea0fda54eb5", size = 293257, upload-time = "2025-07-26T12:01:39.367Z" }, - { url = "https://files.pythonhosted.org/packages/96/e4/7adcd9c8362745b2210728f209bfbcf7d91ba868a2c5f40d8b58f54c509b/contourpy-1.3.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:d002b6f00d73d69333dac9d0b8d5e84d9724ff9ef044fd63c5986e62b7c9e1b1", size = 274034, upload-time = "2025-07-26T12:01:40.645Z" }, - { url = "https://files.pythonhosted.org/packages/73/23/90e31ceeed1de63058a02cb04b12f2de4b40e3bef5e082a7c18d9c8ae281/contourpy-1.3.3-cp313-cp313-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:348ac1f5d4f1d66d3322420f01d42e43122f43616e0f194fc1c9f5d830c5b286", size = 334672, upload-time = "2025-07-26T12:01:41.942Z" }, - { url = "https://files.pythonhosted.org/packages/ed/93/b43d8acbe67392e659e1d984700e79eb67e2acb2bd7f62012b583a7f1b55/contourpy-1.3.3-cp313-cp313-manylinux_2_26_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:655456777ff65c2c548b7c454af9c6f33f16c8884f11083244b5819cc214f1b5", size = 381234, upload-time = "2025-07-26T12:01:43.499Z" }, - { url = "https://files.pythonhosted.org/packages/46/3b/bec82a3ea06f66711520f75a40c8fc0b113b2a75edb36aa633eb11c4f50f/contourpy-1.3.3-cp313-cp313-manylinux_2_26_s390x.manylinux_2_28_s390x.whl", hash = "sha256:644a6853d15b2512d67881586bd03f462c7ab755db95f16f14d7e238f2852c67", size = 385169, upload-time = "2025-07-26T12:01:45.219Z" }, - { url = "https://files.pythonhosted.org/packages/4b/32/e0f13a1c5b0f8572d0ec6ae2f6c677b7991fafd95da523159c19eff0696a/contourpy-1.3.3-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4debd64f124ca62069f313a9cb86656ff087786016d76927ae2cf37846b006c9", size = 362859, upload-time = "2025-07-26T12:01:46.519Z" }, - { url = "https://files.pythonhosted.org/packages/33/71/e2a7945b7de4e58af42d708a219f3b2f4cff7386e6b6ab0a0fa0033c49a9/contourpy-1.3.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:a15459b0f4615b00bbd1e91f1b9e19b7e63aea7483d03d804186f278c0af2659", size = 1332062, upload-time = "2025-07-26T12:01:48.964Z" }, - { url = "https://files.pythonhosted.org/packages/12/fc/4e87ac754220ccc0e807284f88e943d6d43b43843614f0a8afa469801db0/contourpy-1.3.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ca0fdcd73925568ca027e0b17ab07aad764be4706d0a925b89227e447d9737b7", size = 1403932, upload-time = "2025-07-26T12:01:51.979Z" }, - { url = "https://files.pythonhosted.org/packages/a6/2e/adc197a37443f934594112222ac1aa7dc9a98faf9c3842884df9a9d8751d/contourpy-1.3.3-cp313-cp313-win32.whl", hash = "sha256:b20c7c9a3bf701366556e1b1984ed2d0cedf999903c51311417cf5f591d8c78d", size = 185024, upload-time = "2025-07-26T12:01:53.245Z" }, - { url = "https://files.pythonhosted.org/packages/18/0b/0098c214843213759692cc638fce7de5c289200a830e5035d1791d7a2338/contourpy-1.3.3-cp313-cp313-win_amd64.whl", hash = "sha256:1cadd8b8969f060ba45ed7c1b714fe69185812ab43bd6b86a9123fe8f99c3263", size = 226578, upload-time = "2025-07-26T12:01:54.422Z" }, - { url = "https://files.pythonhosted.org/packages/8a/9a/2f6024a0c5995243cd63afdeb3651c984f0d2bc727fd98066d40e141ad73/contourpy-1.3.3-cp313-cp313-win_arm64.whl", hash = "sha256:fd914713266421b7536de2bfa8181aa8c699432b6763a0ea64195ebe28bff6a9", size = 193524, upload-time = "2025-07-26T12:01:55.73Z" }, - { url = "https://files.pythonhosted.org/packages/c0/b3/f8a1a86bd3298513f500e5b1f5fd92b69896449f6cab6a146a5d52715479/contourpy-1.3.3-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:88df9880d507169449d434c293467418b9f6cbe82edd19284aa0409e7fdb933d", size = 306730, upload-time = "2025-07-26T12:01:57.051Z" }, - { url = "https://files.pythonhosted.org/packages/3f/11/4780db94ae62fc0c2053909b65dc3246bd7cecfc4f8a20d957ad43aa4ad8/contourpy-1.3.3-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:d06bb1f751ba5d417047db62bca3c8fde202b8c11fb50742ab3ab962c81e8216", size = 287897, upload-time = "2025-07-26T12:01:58.663Z" }, - { url = "https://files.pythonhosted.org/packages/ae/15/e59f5f3ffdd6f3d4daa3e47114c53daabcb18574a26c21f03dc9e4e42ff0/contourpy-1.3.3-cp313-cp313t-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e4e6b05a45525357e382909a4c1600444e2a45b4795163d3b22669285591c1ae", size = 326751, upload-time = "2025-07-26T12:02:00.343Z" }, - { url = "https://files.pythonhosted.org/packages/0f/81/03b45cfad088e4770b1dcf72ea78d3802d04200009fb364d18a493857210/contourpy-1.3.3-cp313-cp313t-manylinux_2_26_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:ab3074b48c4e2cf1a960e6bbeb7f04566bf36b1861d5c9d4d8ac04b82e38ba20", size = 375486, upload-time = "2025-07-26T12:02:02.128Z" }, - { url = "https://files.pythonhosted.org/packages/0c/ba/49923366492ffbdd4486e970d421b289a670ae8cf539c1ea9a09822b371a/contourpy-1.3.3-cp313-cp313t-manylinux_2_26_s390x.manylinux_2_28_s390x.whl", hash = "sha256:6c3d53c796f8647d6deb1abe867daeb66dcc8a97e8455efa729516b997b8ed99", size = 388106, upload-time = "2025-07-26T12:02:03.615Z" }, - { url = "https://files.pythonhosted.org/packages/9f/52/5b00ea89525f8f143651f9f03a0df371d3cbd2fccd21ca9b768c7a6500c2/contourpy-1.3.3-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:50ed930df7289ff2a8d7afeb9603f8289e5704755c7e5c3bbd929c90c817164b", size = 352548, upload-time = "2025-07-26T12:02:05.165Z" }, - { url = "https://files.pythonhosted.org/packages/32/1d/a209ec1a3a3452d490f6b14dd92e72280c99ae3d1e73da74f8277d4ee08f/contourpy-1.3.3-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:4feffb6537d64b84877da813a5c30f1422ea5739566abf0bd18065ac040e120a", size = 1322297, upload-time = "2025-07-26T12:02:07.379Z" }, - { url = "https://files.pythonhosted.org/packages/bc/9e/46f0e8ebdd884ca0e8877e46a3f4e633f6c9c8c4f3f6e72be3fe075994aa/contourpy-1.3.3-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:2b7e9480ffe2b0cd2e787e4df64270e3a0440d9db8dc823312e2c940c167df7e", size = 1391023, upload-time = "2025-07-26T12:02:10.171Z" }, - { url = "https://files.pythonhosted.org/packages/b9/70/f308384a3ae9cd2209e0849f33c913f658d3326900d0ff5d378d6a1422d2/contourpy-1.3.3-cp313-cp313t-win32.whl", hash = "sha256:283edd842a01e3dcd435b1c5116798d661378d83d36d337b8dde1d16a5fc9ba3", size = 196157, upload-time = "2025-07-26T12:02:11.488Z" }, - { url = "https://files.pythonhosted.org/packages/b2/dd/880f890a6663b84d9e34a6f88cded89d78f0091e0045a284427cb6b18521/contourpy-1.3.3-cp313-cp313t-win_amd64.whl", hash = "sha256:87acf5963fc2b34825e5b6b048f40e3635dd547f590b04d2ab317c2619ef7ae8", size = 240570, upload-time = "2025-07-26T12:02:12.754Z" }, - { url = "https://files.pythonhosted.org/packages/80/99/2adc7d8ffead633234817ef8e9a87115c8a11927a94478f6bb3d3f4d4f7d/contourpy-1.3.3-cp313-cp313t-win_arm64.whl", hash = "sha256:3c30273eb2a55024ff31ba7d052dde990d7d8e5450f4bbb6e913558b3d6c2301", size = 199713, upload-time = "2025-07-26T12:02:14.4Z" }, - { url = "https://files.pythonhosted.org/packages/72/8b/4546f3ab60f78c514ffb7d01a0bd743f90de36f0019d1be84d0a708a580a/contourpy-1.3.3-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:fde6c716d51c04b1c25d0b90364d0be954624a0ee9d60e23e850e8d48353d07a", size = 292189, upload-time = "2025-07-26T12:02:16.095Z" }, - { url = "https://files.pythonhosted.org/packages/fd/e1/3542a9cb596cadd76fcef413f19c79216e002623158befe6daa03dbfa88c/contourpy-1.3.3-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:cbedb772ed74ff5be440fa8eee9bd49f64f6e3fc09436d9c7d8f1c287b121d77", size = 273251, upload-time = "2025-07-26T12:02:17.524Z" }, - { url = "https://files.pythonhosted.org/packages/b1/71/f93e1e9471d189f79d0ce2497007731c1e6bf9ef6d1d61b911430c3db4e5/contourpy-1.3.3-cp314-cp314-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:22e9b1bd7a9b1d652cd77388465dc358dafcd2e217d35552424aa4f996f524f5", size = 335810, upload-time = "2025-07-26T12:02:18.9Z" }, - { url = "https://files.pythonhosted.org/packages/91/f9/e35f4c1c93f9275d4e38681a80506b5510e9327350c51f8d4a5a724d178c/contourpy-1.3.3-cp314-cp314-manylinux_2_26_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:a22738912262aa3e254e4f3cb079a95a67132fc5a063890e224393596902f5a4", size = 382871, upload-time = "2025-07-26T12:02:20.418Z" }, - { url = "https://files.pythonhosted.org/packages/b5/71/47b512f936f66a0a900d81c396a7e60d73419868fba959c61efed7a8ab46/contourpy-1.3.3-cp314-cp314-manylinux_2_26_s390x.manylinux_2_28_s390x.whl", hash = "sha256:afe5a512f31ee6bd7d0dda52ec9864c984ca3d66664444f2d72e0dc4eb832e36", size = 386264, upload-time = "2025-07-26T12:02:21.916Z" }, - { url = "https://files.pythonhosted.org/packages/04/5f/9ff93450ba96b09c7c2b3f81c94de31c89f92292f1380261bd7195bea4ea/contourpy-1.3.3-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f64836de09927cba6f79dcd00fdd7d5329f3fccc633468507079c829ca4db4e3", size = 363819, upload-time = "2025-07-26T12:02:23.759Z" }, - { url = "https://files.pythonhosted.org/packages/3e/a6/0b185d4cc480ee494945cde102cb0149ae830b5fa17bf855b95f2e70ad13/contourpy-1.3.3-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:1fd43c3be4c8e5fd6e4f2baeae35ae18176cf2e5cced681cca908addf1cdd53b", size = 1333650, upload-time = "2025-07-26T12:02:26.181Z" }, - { url = "https://files.pythonhosted.org/packages/43/d7/afdc95580ca56f30fbcd3060250f66cedbde69b4547028863abd8aa3b47e/contourpy-1.3.3-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:6afc576f7b33cf00996e5c1102dc2a8f7cc89e39c0b55df93a0b78c1bd992b36", size = 1404833, upload-time = "2025-07-26T12:02:28.782Z" }, - { url = "https://files.pythonhosted.org/packages/e2/e2/366af18a6d386f41132a48f033cbd2102e9b0cf6345d35ff0826cd984566/contourpy-1.3.3-cp314-cp314-win32.whl", hash = "sha256:66c8a43a4f7b8df8b71ee1840e4211a3c8d93b214b213f590e18a1beca458f7d", size = 189692, upload-time = "2025-07-26T12:02:30.128Z" }, - { url = "https://files.pythonhosted.org/packages/7d/c2/57f54b03d0f22d4044b8afb9ca0e184f8b1afd57b4f735c2fa70883dc601/contourpy-1.3.3-cp314-cp314-win_amd64.whl", hash = "sha256:cf9022ef053f2694e31d630feaacb21ea24224be1c3ad0520b13d844274614fd", size = 232424, upload-time = "2025-07-26T12:02:31.395Z" }, - { url = "https://files.pythonhosted.org/packages/18/79/a9416650df9b525737ab521aa181ccc42d56016d2123ddcb7b58e926a42c/contourpy-1.3.3-cp314-cp314-win_arm64.whl", hash = "sha256:95b181891b4c71de4bb404c6621e7e2390745f887f2a026b2d99e92c17892339", size = 198300, upload-time = "2025-07-26T12:02:32.956Z" }, - { url = "https://files.pythonhosted.org/packages/1f/42/38c159a7d0f2b7b9c04c64ab317042bb6952b713ba875c1681529a2932fe/contourpy-1.3.3-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:33c82d0138c0a062380332c861387650c82e4cf1747aaa6938b9b6516762e772", size = 306769, upload-time = "2025-07-26T12:02:34.2Z" }, - { url = "https://files.pythonhosted.org/packages/c3/6c/26a8205f24bca10974e77460de68d3d7c63e282e23782f1239f226fcae6f/contourpy-1.3.3-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:ea37e7b45949df430fe649e5de8351c423430046a2af20b1c1961cae3afcda77", size = 287892, upload-time = "2025-07-26T12:02:35.807Z" }, - { url = "https://files.pythonhosted.org/packages/66/06/8a475c8ab718ebfd7925661747dbb3c3ee9c82ac834ccb3570be49d129f4/contourpy-1.3.3-cp314-cp314t-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d304906ecc71672e9c89e87c4675dc5c2645e1f4269a5063b99b0bb29f232d13", size = 326748, upload-time = "2025-07-26T12:02:37.193Z" }, - { url = "https://files.pythonhosted.org/packages/b4/a3/c5ca9f010a44c223f098fccd8b158bb1cb287378a31ac141f04730dc49be/contourpy-1.3.3-cp314-cp314t-manylinux_2_26_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:ca658cd1a680a5c9ea96dc61cdbae1e85c8f25849843aa799dfd3cb370ad4fbe", size = 375554, upload-time = "2025-07-26T12:02:38.894Z" }, - { url = "https://files.pythonhosted.org/packages/80/5b/68bd33ae63fac658a4145088c1e894405e07584a316738710b636c6d0333/contourpy-1.3.3-cp314-cp314t-manylinux_2_26_s390x.manylinux_2_28_s390x.whl", hash = "sha256:ab2fd90904c503739a75b7c8c5c01160130ba67944a7b77bbf36ef8054576e7f", size = 388118, upload-time = "2025-07-26T12:02:40.642Z" }, - { url = "https://files.pythonhosted.org/packages/40/52/4c285a6435940ae25d7410a6c36bda5145839bc3f0beb20c707cda18b9d2/contourpy-1.3.3-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b7301b89040075c30e5768810bc96a8e8d78085b47d8be6e4c3f5a0b4ed478a0", size = 352555, upload-time = "2025-07-26T12:02:42.25Z" }, - { url = "https://files.pythonhosted.org/packages/24/ee/3e81e1dd174f5c7fefe50e85d0892de05ca4e26ef1c9a59c2a57e43b865a/contourpy-1.3.3-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:2a2a8b627d5cc6b7c41a4beff6c5ad5eb848c88255fda4a8745f7e901b32d8e4", size = 1322295, upload-time = "2025-07-26T12:02:44.668Z" }, - { url = "https://files.pythonhosted.org/packages/3c/b2/6d913d4d04e14379de429057cd169e5e00f6c2af3bb13e1710bcbdb5da12/contourpy-1.3.3-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:fd6ec6be509c787f1caf6b247f0b1ca598bef13f4ddeaa126b7658215529ba0f", size = 1391027, upload-time = "2025-07-26T12:02:47.09Z" }, - { url = "https://files.pythonhosted.org/packages/93/8a/68a4ec5c55a2971213d29a9374913f7e9f18581945a7a31d1a39b5d2dfe5/contourpy-1.3.3-cp314-cp314t-win32.whl", hash = "sha256:e74a9a0f5e3fff48fb5a7f2fd2b9b70a3fe014a67522f79b7cca4c0c7e43c9ae", size = 202428, upload-time = "2025-07-26T12:02:48.691Z" }, - { url = "https://files.pythonhosted.org/packages/fa/96/fd9f641ffedc4fa3ace923af73b9d07e869496c9cc7a459103e6e978992f/contourpy-1.3.3-cp314-cp314t-win_amd64.whl", hash = "sha256:13b68d6a62db8eafaebb8039218921399baf6e47bf85006fd8529f2a08ef33fc", size = 250331, upload-time = "2025-07-26T12:02:50.137Z" }, - { url = "https://files.pythonhosted.org/packages/ae/8c/469afb6465b853afff216f9528ffda78a915ff880ed58813ba4faf4ba0b6/contourpy-1.3.3-cp314-cp314t-win_arm64.whl", hash = "sha256:b7448cb5a725bb1e35ce88771b86fba35ef418952474492cf7c764059933ff8b", size = 203831, upload-time = "2025-07-26T12:02:51.449Z" }, -] - -[[package]] -name = "cycler" -version = "0.12.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/a9/95/a3dbbb5028f35eafb79008e7522a75244477d2838f38cbb722248dabc2a8/cycler-0.12.1.tar.gz", hash = "sha256:88bb128f02ba341da8ef447245a9e138fae777f6a23943da4540077d3601eb1c", size = 7615, upload-time = "2023-10-07T05:32:18.335Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/e7/05/c19819d5e3d95294a6f5947fb9b9629efb316b96de511b418c53d245aae6/cycler-0.12.1-py3-none-any.whl", hash = "sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30", size = 8321, upload-time = "2023-10-07T05:32:16.783Z" }, -] - -[[package]] -name = "fonttools" -version = "4.59.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/11/7f/29c9c3fe4246f6ad96fee52b88d0dc3a863c7563b0afc959e36d78b965dc/fonttools-4.59.1.tar.gz", hash = "sha256:74995b402ad09822a4c8002438e54940d9f1ecda898d2bb057729d7da983e4cb", size = 3534394, upload-time = "2025-08-14T16:28:14.266Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/ac/fe/6e069cc4cb8881d164a9bd956e9df555bc62d3eb36f6282e43440200009c/fonttools-4.59.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:43ab814bbba5f02a93a152ee61a04182bb5809bd2bc3609f7822e12c53ae2c91", size = 2769172, upload-time = "2025-08-14T16:26:45.729Z" }, - { url = "https://files.pythonhosted.org/packages/b9/98/ec4e03f748fefa0dd72d9d95235aff6fef16601267f4a2340f0e16b9330f/fonttools-4.59.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:4f04c3ffbfa0baafcbc550657cf83657034eb63304d27b05cff1653b448ccff6", size = 2337281, upload-time = "2025-08-14T16:26:47.921Z" }, - { url = "https://files.pythonhosted.org/packages/8b/b1/890360a7e3d04a30ba50b267aca2783f4c1364363797e892e78a4f036076/fonttools-4.59.1-cp312-cp312-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:d601b153e51a5a6221f0d4ec077b6bfc6ac35bfe6c19aeaa233d8990b2b71726", size = 4909215, upload-time = "2025-08-14T16:26:49.682Z" }, - { url = "https://files.pythonhosted.org/packages/8a/ec/2490599550d6c9c97a44c1e36ef4de52d6acf742359eaa385735e30c05c4/fonttools-4.59.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c735e385e30278c54f43a0d056736942023c9043f84ee1021eff9fd616d17693", size = 4951958, upload-time = "2025-08-14T16:26:51.616Z" }, - { url = "https://files.pythonhosted.org/packages/d1/40/bd053f6f7634234a9b9805ff8ae4f32df4f2168bee23cafd1271ba9915a9/fonttools-4.59.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:1017413cdc8555dce7ee23720da490282ab7ec1cf022af90a241f33f9a49afc4", size = 4894738, upload-time = "2025-08-14T16:26:53.836Z" }, - { url = "https://files.pythonhosted.org/packages/ac/a1/3cd12a010d288325a7cfcf298a84825f0f9c29b01dee1baba64edfe89257/fonttools-4.59.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:5c6d8d773470a5107052874341ed3c487c16ecd179976d81afed89dea5cd7406", size = 5045983, upload-time = "2025-08-14T16:26:56.153Z" }, - { url = "https://files.pythonhosted.org/packages/a2/af/8a2c3f6619cc43cf87951405337cc8460d08a4e717bb05eaa94b335d11dc/fonttools-4.59.1-cp312-cp312-win32.whl", hash = "sha256:2a2d0d33307f6ad3a2086a95dd607c202ea8852fa9fb52af9b48811154d1428a", size = 2203407, upload-time = "2025-08-14T16:26:58.165Z" }, - { url = "https://files.pythonhosted.org/packages/8e/f2/a19b874ddbd3ebcf11d7e25188ef9ac3f68b9219c62263acb34aca8cde05/fonttools-4.59.1-cp312-cp312-win_amd64.whl", hash = "sha256:0b9e4fa7eaf046ed6ac470f6033d52c052481ff7a6e0a92373d14f556f298dc0", size = 2251561, upload-time = "2025-08-14T16:27:00.646Z" }, - { url = "https://files.pythonhosted.org/packages/19/5e/94a4d7f36c36e82f6a81e0064d148542e0ad3e6cf51fc5461ca128f3658d/fonttools-4.59.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:89d9957b54246c6251345297dddf77a84d2c19df96af30d2de24093bbdf0528b", size = 2760192, upload-time = "2025-08-14T16:27:03.024Z" }, - { url = "https://files.pythonhosted.org/packages/ee/a5/f50712fc33ef9d06953c660cefaf8c8fe4b8bc74fa21f44ee5e4f9739439/fonttools-4.59.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:8156b11c0d5405810d216f53907bd0f8b982aa5f1e7e3127ab3be1a4062154ff", size = 2332694, upload-time = "2025-08-14T16:27:04.883Z" }, - { url = "https://files.pythonhosted.org/packages/e9/a2/5a9fc21c354bf8613215ce233ab0d933bd17d5ff4c29693636551adbc7b3/fonttools-4.59.1-cp313-cp313-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:8387876a8011caec52d327d5e5bca705d9399ec4b17afb8b431ec50d47c17d23", size = 4889254, upload-time = "2025-08-14T16:27:07.02Z" }, - { url = "https://files.pythonhosted.org/packages/2d/e5/54a6dc811eba018d022ca2e8bd6f2969291f9586ccf9a22a05fc55f91250/fonttools-4.59.1-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:fb13823a74b3a9204a8ed76d3d6d5ec12e64cc5bc44914eb9ff1cdac04facd43", size = 4949109, upload-time = "2025-08-14T16:27:09.3Z" }, - { url = "https://files.pythonhosted.org/packages/db/15/b05c72a248a95bea0fd05fbd95acdf0742945942143fcf961343b7a3663a/fonttools-4.59.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:e1ca10da138c300f768bb68e40e5b20b6ecfbd95f91aac4cc15010b6b9d65455", size = 4888428, upload-time = "2025-08-14T16:27:11.514Z" }, - { url = "https://files.pythonhosted.org/packages/63/71/c7d6840f858d695adc0c4371ec45e3fb1c8e060b276ba944e2800495aca4/fonttools-4.59.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:2beb5bfc4887a3130f8625349605a3a45fe345655ce6031d1bac11017454b943", size = 5032668, upload-time = "2025-08-14T16:27:13.872Z" }, - { url = "https://files.pythonhosted.org/packages/90/54/57be4aca6f1312e2bc4d811200dd822325794e05bdb26eeff0976edca651/fonttools-4.59.1-cp313-cp313-win32.whl", hash = "sha256:419f16d750d78e6d704bfe97b48bba2f73b15c9418f817d0cb8a9ca87a5b94bf", size = 2201832, upload-time = "2025-08-14T16:27:16.126Z" }, - { url = "https://files.pythonhosted.org/packages/fc/1f/1899a6175a5f900ed8730a0d64f53ca1b596ed7609bfda033cf659114258/fonttools-4.59.1-cp313-cp313-win_amd64.whl", hash = "sha256:c536f8a852e8d3fa71dde1ec03892aee50be59f7154b533f0bf3c1174cfd5126", size = 2250673, upload-time = "2025-08-14T16:27:18.033Z" }, - { url = "https://files.pythonhosted.org/packages/15/07/f6ba82c22f118d9985c37fea65d8d715ca71300d78b6c6e90874dc59f11d/fonttools-4.59.1-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:d5c3bfdc9663f3d4b565f9cb3b8c1efb3e178186435b45105bde7328cfddd7fe", size = 2758606, upload-time = "2025-08-14T16:27:20.064Z" }, - { url = "https://files.pythonhosted.org/packages/3a/81/84aa3d0ce27b0112c28b67b637ff7a47cf401cf5fbfee6476e4bc9777580/fonttools-4.59.1-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:ea03f1da0d722fe3c2278a05957e6550175571a4894fbf9d178ceef4a3783d2b", size = 2330187, upload-time = "2025-08-14T16:27:22.42Z" }, - { url = "https://files.pythonhosted.org/packages/17/41/b3ba43f78afb321e2e50232c87304c8d0f5ab39b64389b8286cc39cdb824/fonttools-4.59.1-cp314-cp314-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:57a3708ca6bfccb790f585fa6d8f29432ec329618a09ff94c16bcb3c55994643", size = 4832020, upload-time = "2025-08-14T16:27:24.214Z" }, - { url = "https://files.pythonhosted.org/packages/67/b1/3af871c7fb325a68938e7ce544ca48bfd2c6bb7b357f3c8252933b29100a/fonttools-4.59.1-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:729367c91eb1ee84e61a733acc485065a00590618ca31c438e7dd4d600c01486", size = 4930687, upload-time = "2025-08-14T16:27:26.484Z" }, - { url = "https://files.pythonhosted.org/packages/c5/4f/299fc44646b30d9ef03ffaa78b109c7bd32121f0d8f10009ee73ac4514bc/fonttools-4.59.1-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:8f8ef66ac6db450193ed150e10b3b45dde7aded10c5d279968bc63368027f62b", size = 4875794, upload-time = "2025-08-14T16:27:28.887Z" }, - { url = "https://files.pythonhosted.org/packages/90/cf/a0a3d763ab58f5f81ceff104ddb662fd9da94248694862b9c6cbd509fdd5/fonttools-4.59.1-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:075f745d539a998cd92cb84c339a82e53e49114ec62aaea8307c80d3ad3aef3a", size = 4985780, upload-time = "2025-08-14T16:27:30.858Z" }, - { url = "https://files.pythonhosted.org/packages/72/c5/ba76511aaae143d89c29cd32ce30bafb61c477e8759a1590b8483f8065f8/fonttools-4.59.1-cp314-cp314-win32.whl", hash = "sha256:c2b0597522d4c5bb18aa5cf258746a2d4a90f25878cbe865e4d35526abd1b9fc", size = 2205610, upload-time = "2025-08-14T16:27:32.578Z" }, - { url = "https://files.pythonhosted.org/packages/a9/65/b250e69d6caf35bc65cddbf608be0662d741c248f2e7503ab01081fc267e/fonttools-4.59.1-cp314-cp314-win_amd64.whl", hash = "sha256:e9ad4ce044e3236f0814c906ccce8647046cc557539661e35211faadf76f283b", size = 2255376, upload-time = "2025-08-14T16:27:34.653Z" }, - { url = "https://files.pythonhosted.org/packages/11/f3/0bc63a23ac0f8175e23d82f85d6ee693fbd849de7ad739f0a3622182ad29/fonttools-4.59.1-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:652159e8214eb4856e8387ebcd6b6bd336ee258cbeb639c8be52005b122b9609", size = 2826546, upload-time = "2025-08-14T16:27:36.783Z" }, - { url = "https://files.pythonhosted.org/packages/e9/46/a3968205590e068fdf60e926be329a207782576cb584d3b7dcd2d2844957/fonttools-4.59.1-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:43d177cd0e847ea026fedd9f099dc917da136ed8792d142298a252836390c478", size = 2359771, upload-time = "2025-08-14T16:27:39.678Z" }, - { url = "https://files.pythonhosted.org/packages/b8/ff/d14b4c283879e8cb57862d9624a34fe6522b6fcdd46ccbfc58900958794a/fonttools-4.59.1-cp314-cp314t-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:e54437651e1440ee53a95e6ceb6ee440b67a3d348c76f45f4f48de1a5ecab019", size = 4831575, upload-time = "2025-08-14T16:27:41.885Z" }, - { url = "https://files.pythonhosted.org/packages/9c/04/a277d9a584a49d98ca12d3b2c6663bdf333ae97aaa83bd0cdabf7c5a6c84/fonttools-4.59.1-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6065fdec8ff44c32a483fd44abe5bcdb40dd5e2571a5034b555348f2b3a52cea", size = 5069962, upload-time = "2025-08-14T16:27:44.284Z" }, - { url = "https://files.pythonhosted.org/packages/16/6f/3d2ae69d96c4cdee6dfe7598ca5519a1514487700ca3d7c49c5a1ad65308/fonttools-4.59.1-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:42052b56d176f8b315fbc09259439c013c0cb2109df72447148aeda677599612", size = 4942926, upload-time = "2025-08-14T16:27:46.523Z" }, - { url = "https://files.pythonhosted.org/packages/0c/d3/c17379e0048d03ce26b38e4ab0e9a98280395b00529e093fe2d663ac0658/fonttools-4.59.1-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:bcd52eaa5c4c593ae9f447c1d13e7e4a00ca21d755645efa660b6999425b3c88", size = 4958678, upload-time = "2025-08-14T16:27:48.555Z" }, - { url = "https://files.pythonhosted.org/packages/8c/3f/c5543a1540abdfb4d375e3ebeb84de365ab9b153ec14cb7db05f537dd1e7/fonttools-4.59.1-cp314-cp314t-win32.whl", hash = "sha256:02e4fdf27c550dded10fe038a5981c29f81cb9bc649ff2eaa48e80dab8998f97", size = 2266706, upload-time = "2025-08-14T16:27:50.556Z" }, - { url = "https://files.pythonhosted.org/packages/3e/99/85bff6e674226bc8402f983e365f07e76d990e7220ba72bcc738fef52391/fonttools-4.59.1-cp314-cp314t-win_amd64.whl", hash = "sha256:412a5fd6345872a7c249dac5bcce380393f40c1c316ac07f447bc17d51900922", size = 2329994, upload-time = "2025-08-14T16:27:52.36Z" }, - { url = "https://files.pythonhosted.org/packages/0f/64/9d606e66d498917cd7a2ff24f558010d42d6fd4576d9dd57f0bd98333f5a/fonttools-4.59.1-py3-none-any.whl", hash = "sha256:647db657073672a8330608970a984d51573557f328030566521bc03415535042", size = 1130094, upload-time = "2025-08-14T16:28:12.048Z" }, -] - -[[package]] -name = "kiwisolver" -version = "1.4.9" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/5c/3c/85844f1b0feb11ee581ac23fe5fce65cd049a200c1446708cc1b7f922875/kiwisolver-1.4.9.tar.gz", hash = "sha256:c3b22c26c6fd6811b0ae8363b95ca8ce4ea3c202d3d0975b2914310ceb1bcc4d", size = 97564, upload-time = "2025-08-10T21:27:49.279Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/86/c9/13573a747838aeb1c76e3267620daa054f4152444d1f3d1a2324b78255b5/kiwisolver-1.4.9-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:ac5a486ac389dddcc5bef4f365b6ae3ffff2c433324fb38dd35e3fab7c957999", size = 123686, upload-time = "2025-08-10T21:26:10.034Z" }, - { url = "https://files.pythonhosted.org/packages/51/ea/2ecf727927f103ffd1739271ca19c424d0e65ea473fbaeea1c014aea93f6/kiwisolver-1.4.9-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:f2ba92255faa7309d06fe44c3a4a97efe1c8d640c2a79a5ef728b685762a6fd2", size = 66460, upload-time = "2025-08-10T21:26:11.083Z" }, - { url = "https://files.pythonhosted.org/packages/5b/5a/51f5464373ce2aeb5194508298a508b6f21d3867f499556263c64c621914/kiwisolver-1.4.9-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4a2899935e724dd1074cb568ce7ac0dce28b2cd6ab539c8e001a8578eb106d14", size = 64952, upload-time = "2025-08-10T21:26:12.058Z" }, - { url = "https://files.pythonhosted.org/packages/70/90/6d240beb0f24b74371762873e9b7f499f1e02166a2d9c5801f4dbf8fa12e/kiwisolver-1.4.9-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f6008a4919fdbc0b0097089f67a1eb55d950ed7e90ce2cc3e640abadd2757a04", size = 1474756, upload-time = "2025-08-10T21:26:13.096Z" }, - { url = "https://files.pythonhosted.org/packages/12/42/f36816eaf465220f683fb711efdd1bbf7a7005a2473d0e4ed421389bd26c/kiwisolver-1.4.9-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:67bb8b474b4181770f926f7b7d2f8c0248cbcb78b660fdd41a47054b28d2a752", size = 1276404, upload-time = "2025-08-10T21:26:14.457Z" }, - { url = "https://files.pythonhosted.org/packages/2e/64/bc2de94800adc830c476dce44e9b40fd0809cddeef1fde9fcf0f73da301f/kiwisolver-1.4.9-cp312-cp312-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:2327a4a30d3ee07d2fbe2e7933e8a37c591663b96ce42a00bc67461a87d7df77", size = 1294410, upload-time = "2025-08-10T21:26:15.73Z" }, - { url = "https://files.pythonhosted.org/packages/5f/42/2dc82330a70aa8e55b6d395b11018045e58d0bb00834502bf11509f79091/kiwisolver-1.4.9-cp312-cp312-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:7a08b491ec91b1d5053ac177afe5290adacf1f0f6307d771ccac5de30592d198", size = 1343631, upload-time = "2025-08-10T21:26:17.045Z" }, - { url = "https://files.pythonhosted.org/packages/22/fd/f4c67a6ed1aab149ec5a8a401c323cee7a1cbe364381bb6c9c0d564e0e20/kiwisolver-1.4.9-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:d8fc5c867c22b828001b6a38d2eaeb88160bf5783c6cb4a5e440efc981ce286d", size = 2224963, upload-time = "2025-08-10T21:26:18.737Z" }, - { url = "https://files.pythonhosted.org/packages/45/aa/76720bd4cb3713314677d9ec94dcc21ced3f1baf4830adde5bb9b2430a5f/kiwisolver-1.4.9-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:3b3115b2581ea35bb6d1f24a4c90af37e5d9b49dcff267eeed14c3893c5b86ab", size = 2321295, upload-time = "2025-08-10T21:26:20.11Z" }, - { url = "https://files.pythonhosted.org/packages/80/19/d3ec0d9ab711242f56ae0dc2fc5d70e298bb4a1f9dfab44c027668c673a1/kiwisolver-1.4.9-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:858e4c22fb075920b96a291928cb7dea5644e94c0ee4fcd5af7e865655e4ccf2", size = 2487987, upload-time = "2025-08-10T21:26:21.49Z" }, - { url = "https://files.pythonhosted.org/packages/39/e9/61e4813b2c97e86b6fdbd4dd824bf72d28bcd8d4849b8084a357bc0dd64d/kiwisolver-1.4.9-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:ed0fecd28cc62c54b262e3736f8bb2512d8dcfdc2bcf08be5f47f96bf405b145", size = 2291817, upload-time = "2025-08-10T21:26:22.812Z" }, - { url = "https://files.pythonhosted.org/packages/a0/41/85d82b0291db7504da3c2defe35c9a8a5c9803a730f297bd823d11d5fb77/kiwisolver-1.4.9-cp312-cp312-win_amd64.whl", hash = "sha256:f68208a520c3d86ea51acf688a3e3002615a7f0238002cccc17affecc86a8a54", size = 73895, upload-time = "2025-08-10T21:26:24.37Z" }, - { url = "https://files.pythonhosted.org/packages/e2/92/5f3068cf15ee5cb624a0c7596e67e2a0bb2adee33f71c379054a491d07da/kiwisolver-1.4.9-cp312-cp312-win_arm64.whl", hash = "sha256:2c1a4f57df73965f3f14df20b80ee29e6a7930a57d2d9e8491a25f676e197c60", size = 64992, upload-time = "2025-08-10T21:26:25.732Z" }, - { url = "https://files.pythonhosted.org/packages/31/c1/c2686cda909742ab66c7388e9a1a8521a59eb89f8bcfbee28fc980d07e24/kiwisolver-1.4.9-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:a5d0432ccf1c7ab14f9949eec60c5d1f924f17c037e9f8b33352fa05799359b8", size = 123681, upload-time = "2025-08-10T21:26:26.725Z" }, - { url = "https://files.pythonhosted.org/packages/ca/f0/f44f50c9f5b1a1860261092e3bc91ecdc9acda848a8b8c6abfda4a24dd5c/kiwisolver-1.4.9-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:efb3a45b35622bb6c16dbfab491a8f5a391fe0e9d45ef32f4df85658232ca0e2", size = 66464, upload-time = "2025-08-10T21:26:27.733Z" }, - { url = "https://files.pythonhosted.org/packages/2d/7a/9d90a151f558e29c3936b8a47ac770235f436f2120aca41a6d5f3d62ae8d/kiwisolver-1.4.9-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:1a12cf6398e8a0a001a059747a1cbf24705e18fe413bc22de7b3d15c67cffe3f", size = 64961, upload-time = "2025-08-10T21:26:28.729Z" }, - { url = "https://files.pythonhosted.org/packages/e9/e9/f218a2cb3a9ffbe324ca29a9e399fa2d2866d7f348ec3a88df87fc248fc5/kiwisolver-1.4.9-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:b67e6efbf68e077dd71d1a6b37e43e1a99d0bff1a3d51867d45ee8908b931098", size = 1474607, upload-time = "2025-08-10T21:26:29.798Z" }, - { url = "https://files.pythonhosted.org/packages/d9/28/aac26d4c882f14de59041636292bc838db8961373825df23b8eeb807e198/kiwisolver-1.4.9-cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5656aa670507437af0207645273ccdfee4f14bacd7f7c67a4306d0dcaeaf6eed", size = 1276546, upload-time = "2025-08-10T21:26:31.401Z" }, - { url = "https://files.pythonhosted.org/packages/8b/ad/8bfc1c93d4cc565e5069162f610ba2f48ff39b7de4b5b8d93f69f30c4bed/kiwisolver-1.4.9-cp313-cp313-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:bfc08add558155345129c7803b3671cf195e6a56e7a12f3dde7c57d9b417f525", size = 1294482, upload-time = "2025-08-10T21:26:32.721Z" }, - { url = "https://files.pythonhosted.org/packages/da/f1/6aca55ff798901d8ce403206d00e033191f63d82dd708a186e0ed2067e9c/kiwisolver-1.4.9-cp313-cp313-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:40092754720b174e6ccf9e845d0d8c7d8e12c3d71e7fc35f55f3813e96376f78", size = 1343720, upload-time = "2025-08-10T21:26:34.032Z" }, - { url = "https://files.pythonhosted.org/packages/d1/91/eed031876c595c81d90d0f6fc681ece250e14bf6998c3d7c419466b523b7/kiwisolver-1.4.9-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:497d05f29a1300d14e02e6441cf0f5ee81c1ff5a304b0d9fb77423974684e08b", size = 2224907, upload-time = "2025-08-10T21:26:35.824Z" }, - { url = "https://files.pythonhosted.org/packages/e9/ec/4d1925f2e49617b9cca9c34bfa11adefad49d00db038e692a559454dfb2e/kiwisolver-1.4.9-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:bdd1a81a1860476eb41ac4bc1e07b3f07259e6d55bbf739b79c8aaedcf512799", size = 2321334, upload-time = "2025-08-10T21:26:37.534Z" }, - { url = "https://files.pythonhosted.org/packages/43/cb/450cd4499356f68802750c6ddc18647b8ea01ffa28f50d20598e0befe6e9/kiwisolver-1.4.9-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:e6b93f13371d341afee3be9f7c5964e3fe61d5fa30f6a30eb49856935dfe4fc3", size = 2488313, upload-time = "2025-08-10T21:26:39.191Z" }, - { url = "https://files.pythonhosted.org/packages/71/67/fc76242bd99f885651128a5d4fa6083e5524694b7c88b489b1b55fdc491d/kiwisolver-1.4.9-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:d75aa530ccfaa593da12834b86a0724f58bff12706659baa9227c2ccaa06264c", size = 2291970, upload-time = "2025-08-10T21:26:40.828Z" }, - { url = "https://files.pythonhosted.org/packages/75/bd/f1a5d894000941739f2ae1b65a32892349423ad49c2e6d0771d0bad3fae4/kiwisolver-1.4.9-cp313-cp313-win_amd64.whl", hash = "sha256:dd0a578400839256df88c16abddf9ba14813ec5f21362e1fe65022e00c883d4d", size = 73894, upload-time = "2025-08-10T21:26:42.33Z" }, - { url = "https://files.pythonhosted.org/packages/95/38/dce480814d25b99a391abbddadc78f7c117c6da34be68ca8b02d5848b424/kiwisolver-1.4.9-cp313-cp313-win_arm64.whl", hash = "sha256:d4188e73af84ca82468f09cadc5ac4db578109e52acb4518d8154698d3a87ca2", size = 64995, upload-time = "2025-08-10T21:26:43.889Z" }, - { url = "https://files.pythonhosted.org/packages/e2/37/7d218ce5d92dadc5ebdd9070d903e0c7cf7edfe03f179433ac4d13ce659c/kiwisolver-1.4.9-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:5a0f2724dfd4e3b3ac5a82436a8e6fd16baa7d507117e4279b660fe8ca38a3a1", size = 126510, upload-time = "2025-08-10T21:26:44.915Z" }, - { url = "https://files.pythonhosted.org/packages/23/b0/e85a2b48233daef4b648fb657ebbb6f8367696a2d9548a00b4ee0eb67803/kiwisolver-1.4.9-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:1b11d6a633e4ed84fc0ddafd4ebfd8ea49b3f25082c04ad12b8315c11d504dc1", size = 67903, upload-time = "2025-08-10T21:26:45.934Z" }, - { url = "https://files.pythonhosted.org/packages/44/98/f2425bc0113ad7de24da6bb4dae1343476e95e1d738be7c04d31a5d037fd/kiwisolver-1.4.9-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:61874cdb0a36016354853593cffc38e56fc9ca5aa97d2c05d3dcf6922cd55a11", size = 66402, upload-time = "2025-08-10T21:26:47.101Z" }, - { url = "https://files.pythonhosted.org/packages/98/d8/594657886df9f34c4177cc353cc28ca7e6e5eb562d37ccc233bff43bbe2a/kiwisolver-1.4.9-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:60c439763a969a6af93b4881db0eed8fadf93ee98e18cbc35bc8da868d0c4f0c", size = 1582135, upload-time = "2025-08-10T21:26:48.665Z" }, - { url = "https://files.pythonhosted.org/packages/5c/c6/38a115b7170f8b306fc929e166340c24958347308ea3012c2b44e7e295db/kiwisolver-1.4.9-cp313-cp313t-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:92a2f997387a1b79a75e7803aa7ded2cfbe2823852ccf1ba3bcf613b62ae3197", size = 1389409, upload-time = "2025-08-10T21:26:50.335Z" }, - { url = "https://files.pythonhosted.org/packages/bf/3b/e04883dace81f24a568bcee6eb3001da4ba05114afa622ec9b6fafdc1f5e/kiwisolver-1.4.9-cp313-cp313t-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:a31d512c812daea6d8b3be3b2bfcbeb091dbb09177706569bcfc6240dcf8b41c", size = 1401763, upload-time = "2025-08-10T21:26:51.867Z" }, - { url = "https://files.pythonhosted.org/packages/9f/80/20ace48e33408947af49d7d15c341eaee69e4e0304aab4b7660e234d6288/kiwisolver-1.4.9-cp313-cp313t-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:52a15b0f35dad39862d376df10c5230155243a2c1a436e39eb55623ccbd68185", size = 1453643, upload-time = "2025-08-10T21:26:53.592Z" }, - { url = "https://files.pythonhosted.org/packages/64/31/6ce4380a4cd1f515bdda976a1e90e547ccd47b67a1546d63884463c92ca9/kiwisolver-1.4.9-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:a30fd6fdef1430fd9e1ba7b3398b5ee4e2887783917a687d86ba69985fb08748", size = 2330818, upload-time = "2025-08-10T21:26:55.051Z" }, - { url = "https://files.pythonhosted.org/packages/fa/e9/3f3fcba3bcc7432c795b82646306e822f3fd74df0ee81f0fa067a1f95668/kiwisolver-1.4.9-cp313-cp313t-musllinux_1_2_ppc64le.whl", hash = "sha256:cc9617b46837c6468197b5945e196ee9ca43057bb7d9d1ae688101e4e1dddf64", size = 2419963, upload-time = "2025-08-10T21:26:56.421Z" }, - { url = "https://files.pythonhosted.org/packages/99/43/7320c50e4133575c66e9f7dadead35ab22d7c012a3b09bb35647792b2a6d/kiwisolver-1.4.9-cp313-cp313t-musllinux_1_2_s390x.whl", hash = "sha256:0ab74e19f6a2b027ea4f845a78827969af45ce790e6cb3e1ebab71bdf9f215ff", size = 2594639, upload-time = "2025-08-10T21:26:57.882Z" }, - { url = "https://files.pythonhosted.org/packages/65/d6/17ae4a270d4a987ef8a385b906d2bdfc9fce502d6dc0d3aea865b47f548c/kiwisolver-1.4.9-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:dba5ee5d3981160c28d5490f0d1b7ed730c22470ff7f6cc26cfcfaacb9896a07", size = 2391741, upload-time = "2025-08-10T21:26:59.237Z" }, - { url = "https://files.pythonhosted.org/packages/2a/8f/8f6f491d595a9e5912971f3f863d81baddccc8a4d0c3749d6a0dd9ffc9df/kiwisolver-1.4.9-cp313-cp313t-win_arm64.whl", hash = "sha256:0749fd8f4218ad2e851e11cc4dc05c7cbc0cbc4267bdfdb31782e65aace4ee9c", size = 68646, upload-time = "2025-08-10T21:27:00.52Z" }, - { url = "https://files.pythonhosted.org/packages/6b/32/6cc0fbc9c54d06c2969faa9c1d29f5751a2e51809dd55c69055e62d9b426/kiwisolver-1.4.9-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:9928fe1eb816d11ae170885a74d074f57af3a0d65777ca47e9aeb854a1fba386", size = 123806, upload-time = "2025-08-10T21:27:01.537Z" }, - { url = "https://files.pythonhosted.org/packages/b2/dd/2bfb1d4a4823d92e8cbb420fe024b8d2167f72079b3bb941207c42570bdf/kiwisolver-1.4.9-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:d0005b053977e7b43388ddec89fa567f43d4f6d5c2c0affe57de5ebf290dc552", size = 66605, upload-time = "2025-08-10T21:27:03.335Z" }, - { url = "https://files.pythonhosted.org/packages/f7/69/00aafdb4e4509c2ca6064646cba9cd4b37933898f426756adb2cb92ebbed/kiwisolver-1.4.9-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:2635d352d67458b66fd0667c14cb1d4145e9560d503219034a18a87e971ce4f3", size = 64925, upload-time = "2025-08-10T21:27:04.339Z" }, - { url = "https://files.pythonhosted.org/packages/43/dc/51acc6791aa14e5cb6d8a2e28cefb0dc2886d8862795449d021334c0df20/kiwisolver-1.4.9-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:767c23ad1c58c9e827b649a9ab7809fd5fd9db266a9cf02b0e926ddc2c680d58", size = 1472414, upload-time = "2025-08-10T21:27:05.437Z" }, - { url = "https://files.pythonhosted.org/packages/3d/bb/93fa64a81db304ac8a246f834d5094fae4b13baf53c839d6bb6e81177129/kiwisolver-1.4.9-cp314-cp314-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:72d0eb9fba308b8311685c2268cf7d0a0639a6cd027d8128659f72bdd8a024b4", size = 1281272, upload-time = "2025-08-10T21:27:07.063Z" }, - { url = "https://files.pythonhosted.org/packages/70/e6/6df102916960fb8d05069d4bd92d6d9a8202d5a3e2444494e7cd50f65b7a/kiwisolver-1.4.9-cp314-cp314-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:f68e4f3eeca8fb22cc3d731f9715a13b652795ef657a13df1ad0c7dc0e9731df", size = 1298578, upload-time = "2025-08-10T21:27:08.452Z" }, - { url = "https://files.pythonhosted.org/packages/7c/47/e142aaa612f5343736b087864dbaebc53ea8831453fb47e7521fa8658f30/kiwisolver-1.4.9-cp314-cp314-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:d84cd4061ae292d8ac367b2c3fa3aad11cb8625a95d135fe93f286f914f3f5a6", size = 1345607, upload-time = "2025-08-10T21:27:10.125Z" }, - { url = "https://files.pythonhosted.org/packages/54/89/d641a746194a0f4d1a3670fb900d0dbaa786fb98341056814bc3f058fa52/kiwisolver-1.4.9-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:a60ea74330b91bd22a29638940d115df9dc00af5035a9a2a6ad9399ffb4ceca5", size = 2230150, upload-time = "2025-08-10T21:27:11.484Z" }, - { url = "https://files.pythonhosted.org/packages/aa/6b/5ee1207198febdf16ac11f78c5ae40861b809cbe0e6d2a8d5b0b3044b199/kiwisolver-1.4.9-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:ce6a3a4e106cf35c2d9c4fa17c05ce0b180db622736845d4315519397a77beaf", size = 2325979, upload-time = "2025-08-10T21:27:12.917Z" }, - { url = "https://files.pythonhosted.org/packages/fc/ff/b269eefd90f4ae14dcc74973d5a0f6d28d3b9bb1afd8c0340513afe6b39a/kiwisolver-1.4.9-cp314-cp314-musllinux_1_2_s390x.whl", hash = "sha256:77937e5e2a38a7b48eef0585114fe7930346993a88060d0bf886086d2aa49ef5", size = 2491456, upload-time = "2025-08-10T21:27:14.353Z" }, - { url = "https://files.pythonhosted.org/packages/fc/d4/10303190bd4d30de547534601e259a4fbf014eed94aae3e5521129215086/kiwisolver-1.4.9-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:24c175051354f4a28c5d6a31c93906dc653e2bf234e8a4bbfb964892078898ce", size = 2294621, upload-time = "2025-08-10T21:27:15.808Z" }, - { url = "https://files.pythonhosted.org/packages/28/e0/a9a90416fce5c0be25742729c2ea52105d62eda6c4be4d803c2a7be1fa50/kiwisolver-1.4.9-cp314-cp314-win_amd64.whl", hash = "sha256:0763515d4df10edf6d06a3c19734e2566368980d21ebec439f33f9eb936c07b7", size = 75417, upload-time = "2025-08-10T21:27:17.436Z" }, - { url = "https://files.pythonhosted.org/packages/1f/10/6949958215b7a9a264299a7db195564e87900f709db9245e4ebdd3c70779/kiwisolver-1.4.9-cp314-cp314-win_arm64.whl", hash = "sha256:0e4e2bf29574a6a7b7f6cb5fa69293b9f96c928949ac4a53ba3f525dffb87f9c", size = 66582, upload-time = "2025-08-10T21:27:18.436Z" }, - { url = "https://files.pythonhosted.org/packages/ec/79/60e53067903d3bc5469b369fe0dfc6b3482e2133e85dae9daa9527535991/kiwisolver-1.4.9-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:d976bbb382b202f71c67f77b0ac11244021cfa3f7dfd9e562eefcea2df711548", size = 126514, upload-time = "2025-08-10T21:27:19.465Z" }, - { url = "https://files.pythonhosted.org/packages/25/d1/4843d3e8d46b072c12a38c97c57fab4608d36e13fe47d47ee96b4d61ba6f/kiwisolver-1.4.9-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:2489e4e5d7ef9a1c300a5e0196e43d9c739f066ef23270607d45aba368b91f2d", size = 67905, upload-time = "2025-08-10T21:27:20.51Z" }, - { url = "https://files.pythonhosted.org/packages/8c/ae/29ffcbd239aea8b93108de1278271ae764dfc0d803a5693914975f200596/kiwisolver-1.4.9-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:e2ea9f7ab7fbf18fffb1b5434ce7c69a07582f7acc7717720f1d69f3e806f90c", size = 66399, upload-time = "2025-08-10T21:27:21.496Z" }, - { url = "https://files.pythonhosted.org/packages/a1/ae/d7ba902aa604152c2ceba5d352d7b62106bedbccc8e95c3934d94472bfa3/kiwisolver-1.4.9-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:b34e51affded8faee0dfdb705416153819d8ea9250bbbf7ea1b249bdeb5f1122", size = 1582197, upload-time = "2025-08-10T21:27:22.604Z" }, - { url = "https://files.pythonhosted.org/packages/f2/41/27c70d427eddb8bc7e4f16420a20fefc6f480312122a59a959fdfe0445ad/kiwisolver-1.4.9-cp314-cp314t-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d8aacd3d4b33b772542b2e01beb50187536967b514b00003bdda7589722d2a64", size = 1390125, upload-time = "2025-08-10T21:27:24.036Z" }, - { url = "https://files.pythonhosted.org/packages/41/42/b3799a12bafc76d962ad69083f8b43b12bf4fe78b097b12e105d75c9b8f1/kiwisolver-1.4.9-cp314-cp314t-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:7cf974dd4e35fa315563ac99d6287a1024e4dc2077b8a7d7cd3d2fb65d283134", size = 1402612, upload-time = "2025-08-10T21:27:25.773Z" }, - { url = "https://files.pythonhosted.org/packages/d2/b5/a210ea073ea1cfaca1bb5c55a62307d8252f531beb364e18aa1e0888b5a0/kiwisolver-1.4.9-cp314-cp314t-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:85bd218b5ecfbee8c8a82e121802dcb519a86044c9c3b2e4aef02fa05c6da370", size = 1453990, upload-time = "2025-08-10T21:27:27.089Z" }, - { url = "https://files.pythonhosted.org/packages/5f/ce/a829eb8c033e977d7ea03ed32fb3c1781b4fa0433fbadfff29e39c676f32/kiwisolver-1.4.9-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:0856e241c2d3df4efef7c04a1e46b1936b6120c9bcf36dd216e3acd84bc4fb21", size = 2331601, upload-time = "2025-08-10T21:27:29.343Z" }, - { url = "https://files.pythonhosted.org/packages/e0/4b/b5e97eb142eb9cd0072dacfcdcd31b1c66dc7352b0f7c7255d339c0edf00/kiwisolver-1.4.9-cp314-cp314t-musllinux_1_2_ppc64le.whl", hash = "sha256:9af39d6551f97d31a4deebeac6f45b156f9755ddc59c07b402c148f5dbb6482a", size = 2422041, upload-time = "2025-08-10T21:27:30.754Z" }, - { url = "https://files.pythonhosted.org/packages/40/be/8eb4cd53e1b85ba4edc3a9321666f12b83113a178845593307a3e7891f44/kiwisolver-1.4.9-cp314-cp314t-musllinux_1_2_s390x.whl", hash = "sha256:bb4ae2b57fc1d8cbd1cf7b1d9913803681ffa903e7488012be5b76dedf49297f", size = 2594897, upload-time = "2025-08-10T21:27:32.803Z" }, - { url = "https://files.pythonhosted.org/packages/99/dd/841e9a66c4715477ea0abc78da039832fbb09dac5c35c58dc4c41a407b8a/kiwisolver-1.4.9-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:aedff62918805fb62d43a4aa2ecd4482c380dc76cd31bd7c8878588a61bd0369", size = 2391835, upload-time = "2025-08-10T21:27:34.23Z" }, - { url = "https://files.pythonhosted.org/packages/0c/28/4b2e5c47a0da96896fdfdb006340ade064afa1e63675d01ea5ac222b6d52/kiwisolver-1.4.9-cp314-cp314t-win_amd64.whl", hash = "sha256:1fa333e8b2ce4d9660f2cda9c0e1b6bafcfb2457a9d259faa82289e73ec24891", size = 79988, upload-time = "2025-08-10T21:27:35.587Z" }, - { url = "https://files.pythonhosted.org/packages/80/be/3578e8afd18c88cdf9cb4cffde75a96d2be38c5a903f1ed0ceec061bd09e/kiwisolver-1.4.9-cp314-cp314t-win_arm64.whl", hash = "sha256:4a48a2ce79d65d363597ef7b567ce3d14d68783d2b2263d98db3d9477805ba32", size = 70260, upload-time = "2025-08-10T21:27:36.606Z" }, -] - -[[package]] -name = "matplotlib" -version = "3.10.5" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "contourpy" }, - { name = "cycler" }, - { name = "fonttools" }, - { name = "kiwisolver" }, - { name = "numpy" }, - { name = "packaging" }, - { name = "pillow" }, - { name = "pyparsing" }, - { name = "python-dateutil" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/43/91/f2939bb60b7ebf12478b030e0d7f340247390f402b3b189616aad790c366/matplotlib-3.10.5.tar.gz", hash = "sha256:352ed6ccfb7998a00881692f38b4ca083c691d3e275b4145423704c34c909076", size = 34804044, upload-time = "2025-07-31T18:09:33.805Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/66/1e/c6f6bcd882d589410b475ca1fc22e34e34c82adff519caf18f3e6dd9d682/matplotlib-3.10.5-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:00b6feadc28a08bd3c65b2894f56cf3c94fc8f7adcbc6ab4516ae1e8ed8f62e2", size = 8253056, upload-time = "2025-07-31T18:08:05.385Z" }, - { url = "https://files.pythonhosted.org/packages/53/e6/d6f7d1b59413f233793dda14419776f5f443bcccb2dfc84b09f09fe05dbe/matplotlib-3.10.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ee98a5c5344dc7f48dc261b6ba5d9900c008fc12beb3fa6ebda81273602cc389", size = 8110131, upload-time = "2025-07-31T18:08:07.293Z" }, - { url = "https://files.pythonhosted.org/packages/66/2b/bed8a45e74957549197a2ac2e1259671cd80b55ed9e1fe2b5c94d88a9202/matplotlib-3.10.5-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:a17e57e33de901d221a07af32c08870ed4528db0b6059dce7d7e65c1122d4bea", size = 8669603, upload-time = "2025-07-31T18:08:09.064Z" }, - { url = "https://files.pythonhosted.org/packages/7e/a7/315e9435b10d057f5e52dfc603cd353167ae28bb1a4e033d41540c0067a4/matplotlib-3.10.5-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:97b9d6443419085950ee4a5b1ee08c363e5c43d7176e55513479e53669e88468", size = 9508127, upload-time = "2025-07-31T18:08:10.845Z" }, - { url = "https://files.pythonhosted.org/packages/7f/d9/edcbb1f02ca99165365d2768d517898c22c6040187e2ae2ce7294437c413/matplotlib-3.10.5-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:ceefe5d40807d29a66ae916c6a3915d60ef9f028ce1927b84e727be91d884369", size = 9566926, upload-time = "2025-07-31T18:08:13.186Z" }, - { url = "https://files.pythonhosted.org/packages/3b/d9/6dd924ad5616c97b7308e6320cf392c466237a82a2040381163b7500510a/matplotlib-3.10.5-cp312-cp312-win_amd64.whl", hash = "sha256:c04cba0f93d40e45b3c187c6c52c17f24535b27d545f757a2fffebc06c12b98b", size = 8107599, upload-time = "2025-07-31T18:08:15.116Z" }, - { url = "https://files.pythonhosted.org/packages/0e/f3/522dc319a50f7b0279fbe74f86f7a3506ce414bc23172098e8d2bdf21894/matplotlib-3.10.5-cp312-cp312-win_arm64.whl", hash = "sha256:a41bcb6e2c8e79dc99c5511ae6f7787d2fb52efd3d805fff06d5d4f667db16b2", size = 7978173, upload-time = "2025-07-31T18:08:21.518Z" }, - { url = "https://files.pythonhosted.org/packages/8d/05/4f3c1f396075f108515e45cb8d334aff011a922350e502a7472e24c52d77/matplotlib-3.10.5-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:354204db3f7d5caaa10e5de74549ef6a05a4550fdd1c8f831ab9bca81efd39ed", size = 8253586, upload-time = "2025-07-31T18:08:23.107Z" }, - { url = "https://files.pythonhosted.org/packages/2f/2c/e084415775aac7016c3719fe7006cdb462582c6c99ac142f27303c56e243/matplotlib-3.10.5-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:b072aac0c3ad563a2b3318124756cb6112157017f7431626600ecbe890df57a1", size = 8110715, upload-time = "2025-07-31T18:08:24.675Z" }, - { url = "https://files.pythonhosted.org/packages/52/1b/233e3094b749df16e3e6cd5a44849fd33852e692ad009cf7de00cf58ddf6/matplotlib-3.10.5-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:d52fd5b684d541b5a51fb276b2b97b010c75bee9aa392f96b4a07aeb491e33c7", size = 8669397, upload-time = "2025-07-31T18:08:26.778Z" }, - { url = "https://files.pythonhosted.org/packages/e8/ec/03f9e003a798f907d9f772eed9b7c6a9775d5bd00648b643ebfb88e25414/matplotlib-3.10.5-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ee7a09ae2f4676276f5a65bd9f2bd91b4f9fbaedf49f40267ce3f9b448de501f", size = 9508646, upload-time = "2025-07-31T18:08:28.848Z" }, - { url = "https://files.pythonhosted.org/packages/91/e7/c051a7a386680c28487bca27d23b02d84f63e3d2a9b4d2fc478e6a42e37e/matplotlib-3.10.5-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ba6c3c9c067b83481d647af88b4e441d532acdb5ef22178a14935b0b881188f4", size = 9567424, upload-time = "2025-07-31T18:08:30.726Z" }, - { url = "https://files.pythonhosted.org/packages/36/c2/24302e93ff431b8f4173ee1dd88976c8d80483cadbc5d3d777cef47b3a1c/matplotlib-3.10.5-cp313-cp313-win_amd64.whl", hash = "sha256:07442d2692c9bd1cceaa4afb4bbe5b57b98a7599de4dabfcca92d3eea70f9ebe", size = 8107809, upload-time = "2025-07-31T18:08:33.928Z" }, - { url = "https://files.pythonhosted.org/packages/0b/33/423ec6a668d375dad825197557ed8fbdb74d62b432c1ed8235465945475f/matplotlib-3.10.5-cp313-cp313-win_arm64.whl", hash = "sha256:48fe6d47380b68a37ccfcc94f009530e84d41f71f5dae7eda7c4a5a84aa0a674", size = 7978078, upload-time = "2025-07-31T18:08:36.764Z" }, - { url = "https://files.pythonhosted.org/packages/51/17/521fc16ec766455c7bb52cc046550cf7652f6765ca8650ff120aa2d197b6/matplotlib-3.10.5-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:3b80eb8621331449fc519541a7461987f10afa4f9cfd91afcd2276ebe19bd56c", size = 8295590, upload-time = "2025-07-31T18:08:38.521Z" }, - { url = "https://files.pythonhosted.org/packages/f8/12/23c28b2c21114c63999bae129fce7fd34515641c517ae48ce7b7dcd33458/matplotlib-3.10.5-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:47a388908e469d6ca2a6015858fa924e0e8a2345a37125948d8e93a91c47933e", size = 8158518, upload-time = "2025-07-31T18:08:40.195Z" }, - { url = "https://files.pythonhosted.org/packages/81/f8/aae4eb25e8e7190759f3cb91cbeaa344128159ac92bb6b409e24f8711f78/matplotlib-3.10.5-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:8b6b49167d208358983ce26e43aa4196073b4702858670f2eb111f9a10652b4b", size = 8691815, upload-time = "2025-07-31T18:08:42.238Z" }, - { url = "https://files.pythonhosted.org/packages/d0/ba/450c39ebdd486bd33a359fc17365ade46c6a96bf637bbb0df7824de2886c/matplotlib-3.10.5-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8a8da0453a7fd8e3da114234ba70c5ba9ef0e98f190309ddfde0f089accd46ea", size = 9522814, upload-time = "2025-07-31T18:08:44.914Z" }, - { url = "https://files.pythonhosted.org/packages/89/11/9c66f6a990e27bb9aa023f7988d2d5809cb98aa39c09cbf20fba75a542ef/matplotlib-3.10.5-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:52c6573dfcb7726a9907b482cd5b92e6b5499b284ffacb04ffbfe06b3e568124", size = 9573917, upload-time = "2025-07-31T18:08:47.038Z" }, - { url = "https://files.pythonhosted.org/packages/b3/69/8b49394de92569419e5e05e82e83df9b749a0ff550d07631ea96ed2eb35a/matplotlib-3.10.5-cp313-cp313t-win_amd64.whl", hash = "sha256:a23193db2e9d64ece69cac0c8231849db7dd77ce59c7b89948cf9d0ce655a3ce", size = 8181034, upload-time = "2025-07-31T18:08:48.943Z" }, - { url = "https://files.pythonhosted.org/packages/47/23/82dc435bb98a2fc5c20dffcac8f0b083935ac28286413ed8835df40d0baa/matplotlib-3.10.5-cp313-cp313t-win_arm64.whl", hash = "sha256:56da3b102cf6da2776fef3e71cd96fcf22103a13594a18ac9a9b31314e0be154", size = 8023337, upload-time = "2025-07-31T18:08:50.791Z" }, - { url = "https://files.pythonhosted.org/packages/ac/e0/26b6cfde31f5383503ee45dcb7e691d45dadf0b3f54639332b59316a97f8/matplotlib-3.10.5-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:96ef8f5a3696f20f55597ffa91c28e2e73088df25c555f8d4754931515512715", size = 8253591, upload-time = "2025-07-31T18:08:53.254Z" }, - { url = "https://files.pythonhosted.org/packages/c1/89/98488c7ef7ea20ea659af7499628c240a608b337af4be2066d644cfd0a0f/matplotlib-3.10.5-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:77fab633e94b9da60512d4fa0213daeb76d5a7b05156840c4fd0399b4b818837", size = 8112566, upload-time = "2025-07-31T18:08:55.116Z" }, - { url = "https://files.pythonhosted.org/packages/52/67/42294dfedc82aea55e1a767daf3263aacfb5a125f44ba189e685bab41b6f/matplotlib-3.10.5-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:27f52634315e96b1debbfdc5c416592edcd9c4221bc2f520fd39c33db5d9f202", size = 9513281, upload-time = "2025-07-31T18:08:56.885Z" }, - { url = "https://files.pythonhosted.org/packages/e7/68/f258239e0cf34c2cbc816781c7ab6fca768452e6bf1119aedd2bd4a882a3/matplotlib-3.10.5-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:525f6e28c485c769d1f07935b660c864de41c37fd716bfa64158ea646f7084bb", size = 9780873, upload-time = "2025-07-31T18:08:59.241Z" }, - { url = "https://files.pythonhosted.org/packages/89/64/f4881554006bd12e4558bd66778bdd15d47b00a1f6c6e8b50f6208eda4b3/matplotlib-3.10.5-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:1f5f3ec4c191253c5f2b7c07096a142c6a1c024d9f738247bfc8e3f9643fc975", size = 9568954, upload-time = "2025-07-31T18:09:01.244Z" }, - { url = "https://files.pythonhosted.org/packages/06/f8/42779d39c3f757e1f012f2dda3319a89fb602bd2ef98ce8faf0281f4febd/matplotlib-3.10.5-cp314-cp314-win_amd64.whl", hash = "sha256:707f9c292c4cd4716f19ab8a1f93f26598222cd931e0cd98fbbb1c5994bf7667", size = 8237465, upload-time = "2025-07-31T18:09:03.206Z" }, - { url = "https://files.pythonhosted.org/packages/cf/f8/153fd06b5160f0cd27c8b9dd797fcc9fb56ac6a0ebf3c1f765b6b68d3c8a/matplotlib-3.10.5-cp314-cp314-win_arm64.whl", hash = "sha256:21a95b9bf408178d372814de7baacd61c712a62cae560b5e6f35d791776f6516", size = 8108898, upload-time = "2025-07-31T18:09:05.231Z" }, - { url = "https://files.pythonhosted.org/packages/9a/ee/c4b082a382a225fe0d2a73f1f57cf6f6f132308805b493a54c8641006238/matplotlib-3.10.5-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:a6b310f95e1102a8c7c817ef17b60ee5d1851b8c71b63d9286b66b177963039e", size = 8295636, upload-time = "2025-07-31T18:09:07.306Z" }, - { url = "https://files.pythonhosted.org/packages/30/73/2195fa2099718b21a20da82dfc753bf2af58d596b51aefe93e359dd5915a/matplotlib-3.10.5-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:94986a242747a0605cb3ff1cb98691c736f28a59f8ffe5175acaeb7397c49a5a", size = 8158575, upload-time = "2025-07-31T18:09:09.083Z" }, - { url = "https://files.pythonhosted.org/packages/f6/e9/a08cdb34618a91fa08f75e6738541da5cacde7c307cea18ff10f0d03fcff/matplotlib-3.10.5-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1ff10ea43288f0c8bab608a305dc6c918cc729d429c31dcbbecde3b9f4d5b569", size = 9522815, upload-time = "2025-07-31T18:09:11.191Z" }, - { url = "https://files.pythonhosted.org/packages/4e/bb/34d8b7e0d1bb6d06ef45db01dfa560d5a67b1c40c0b998ce9ccde934bb09/matplotlib-3.10.5-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f6adb644c9d040ffb0d3434e440490a66cf73dbfa118a6f79cd7568431f7a012", size = 9783514, upload-time = "2025-07-31T18:09:13.307Z" }, - { url = "https://files.pythonhosted.org/packages/12/09/d330d1e55dcca2e11b4d304cc5227f52e2512e46828d6249b88e0694176e/matplotlib-3.10.5-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:4fa40a8f98428f789a9dcacd625f59b7bc4e3ef6c8c7c80187a7a709475cf592", size = 9573932, upload-time = "2025-07-31T18:09:15.335Z" }, - { url = "https://files.pythonhosted.org/packages/eb/3b/f70258ac729aa004aca673800a53a2b0a26d49ca1df2eaa03289a1c40f81/matplotlib-3.10.5-cp314-cp314t-win_amd64.whl", hash = "sha256:95672a5d628b44207aab91ec20bf59c26da99de12b88f7e0b1fb0a84a86ff959", size = 8322003, upload-time = "2025-07-31T18:09:17.416Z" }, - { url = "https://files.pythonhosted.org/packages/5b/60/3601f8ce6d76a7c81c7f25a0e15fde0d6b66226dd187aa6d2838e6374161/matplotlib-3.10.5-cp314-cp314t-win_arm64.whl", hash = "sha256:2efaf97d72629e74252e0b5e3c46813e9eeaa94e011ecf8084a971a31a97f40b", size = 8153849, upload-time = "2025-07-31T18:09:19.673Z" }, -] - -[[package]] -name = "numpy" -version = "2.3.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/37/7d/3fec4199c5ffb892bed55cff901e4f39a58c81df9c44c280499e92cad264/numpy-2.3.2.tar.gz", hash = "sha256:e0486a11ec30cdecb53f184d496d1c6a20786c81e55e41640270130056f8ee48", size = 20489306, upload-time = "2025-07-24T21:32:07.553Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/00/6d/745dd1c1c5c284d17725e5c802ca4d45cfc6803519d777f087b71c9f4069/numpy-2.3.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:bc3186bea41fae9d8e90c2b4fb5f0a1f5a690682da79b92574d63f56b529080b", size = 20956420, upload-time = "2025-07-24T20:28:18.002Z" }, - { url = "https://files.pythonhosted.org/packages/bc/96/e7b533ea5740641dd62b07a790af5d9d8fec36000b8e2d0472bd7574105f/numpy-2.3.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2f4f0215edb189048a3c03bd5b19345bdfa7b45a7a6f72ae5945d2a28272727f", size = 14184660, upload-time = "2025-07-24T20:28:39.522Z" }, - { url = "https://files.pythonhosted.org/packages/2b/53/102c6122db45a62aa20d1b18c9986f67e6b97e0d6fbc1ae13e3e4c84430c/numpy-2.3.2-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:8b1224a734cd509f70816455c3cffe13a4f599b1bf7130f913ba0e2c0b2006c0", size = 5113382, upload-time = "2025-07-24T20:28:48.544Z" }, - { url = "https://files.pythonhosted.org/packages/2b/21/376257efcbf63e624250717e82b4fae93d60178f09eb03ed766dbb48ec9c/numpy-2.3.2-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:3dcf02866b977a38ba3ec10215220609ab9667378a9e2150615673f3ffd6c73b", size = 6647258, upload-time = "2025-07-24T20:28:59.104Z" }, - { url = "https://files.pythonhosted.org/packages/91/ba/f4ebf257f08affa464fe6036e13f2bf9d4642a40228781dc1235da81be9f/numpy-2.3.2-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:572d5512df5470f50ada8d1972c5f1082d9a0b7aa5944db8084077570cf98370", size = 14281409, upload-time = "2025-07-24T20:40:30.298Z" }, - { url = "https://files.pythonhosted.org/packages/59/ef/f96536f1df42c668cbacb727a8c6da7afc9c05ece6d558927fb1722693e1/numpy-2.3.2-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8145dd6d10df13c559d1e4314df29695613575183fa2e2d11fac4c208c8a1f73", size = 16641317, upload-time = "2025-07-24T20:40:56.625Z" }, - { url = "https://files.pythonhosted.org/packages/f6/a7/af813a7b4f9a42f498dde8a4c6fcbff8100eed00182cc91dbaf095645f38/numpy-2.3.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:103ea7063fa624af04a791c39f97070bf93b96d7af7eb23530cd087dc8dbe9dc", size = 16056262, upload-time = "2025-07-24T20:41:20.797Z" }, - { url = "https://files.pythonhosted.org/packages/8b/5d/41c4ef8404caaa7f05ed1cfb06afe16a25895260eacbd29b4d84dff2920b/numpy-2.3.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:fc927d7f289d14f5e037be917539620603294454130b6de200091e23d27dc9be", size = 18579342, upload-time = "2025-07-24T20:41:50.753Z" }, - { url = "https://files.pythonhosted.org/packages/a1/4f/9950e44c5a11636f4a3af6e825ec23003475cc9a466edb7a759ed3ea63bd/numpy-2.3.2-cp312-cp312-win32.whl", hash = "sha256:d95f59afe7f808c103be692175008bab926b59309ade3e6d25009e9a171f7036", size = 6320610, upload-time = "2025-07-24T20:42:01.551Z" }, - { url = "https://files.pythonhosted.org/packages/7c/2f/244643a5ce54a94f0a9a2ab578189c061e4a87c002e037b0829dd77293b6/numpy-2.3.2-cp312-cp312-win_amd64.whl", hash = "sha256:9e196ade2400c0c737d93465327d1ae7c06c7cb8a1756121ebf54b06ca183c7f", size = 12786292, upload-time = "2025-07-24T20:42:20.738Z" }, - { url = "https://files.pythonhosted.org/packages/54/cd/7b5f49d5d78db7badab22d8323c1b6ae458fbf86c4fdfa194ab3cd4eb39b/numpy-2.3.2-cp312-cp312-win_arm64.whl", hash = "sha256:ee807923782faaf60d0d7331f5e86da7d5e3079e28b291973c545476c2b00d07", size = 10194071, upload-time = "2025-07-24T20:42:36.657Z" }, - { url = "https://files.pythonhosted.org/packages/1c/c0/c6bb172c916b00700ed3bf71cb56175fd1f7dbecebf8353545d0b5519f6c/numpy-2.3.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:c8d9727f5316a256425892b043736d63e89ed15bbfe6556c5ff4d9d4448ff3b3", size = 20949074, upload-time = "2025-07-24T20:43:07.813Z" }, - { url = "https://files.pythonhosted.org/packages/20/4e/c116466d22acaf4573e58421c956c6076dc526e24a6be0903219775d862e/numpy-2.3.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:efc81393f25f14d11c9d161e46e6ee348637c0a1e8a54bf9dedc472a3fae993b", size = 14177311, upload-time = "2025-07-24T20:43:29.335Z" }, - { url = "https://files.pythonhosted.org/packages/78/45/d4698c182895af189c463fc91d70805d455a227261d950e4e0f1310c2550/numpy-2.3.2-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:dd937f088a2df683cbb79dda9a772b62a3e5a8a7e76690612c2737f38c6ef1b6", size = 5106022, upload-time = "2025-07-24T20:43:37.999Z" }, - { url = "https://files.pythonhosted.org/packages/9f/76/3e6880fef4420179309dba72a8c11f6166c431cf6dee54c577af8906f914/numpy-2.3.2-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:11e58218c0c46c80509186e460d79fbdc9ca1eb8d8aee39d8f2dc768eb781089", size = 6640135, upload-time = "2025-07-24T20:43:49.28Z" }, - { url = "https://files.pythonhosted.org/packages/34/fa/87ff7f25b3c4ce9085a62554460b7db686fef1e0207e8977795c7b7d7ba1/numpy-2.3.2-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5ad4ebcb683a1f99f4f392cc522ee20a18b2bb12a2c1c42c3d48d5a1adc9d3d2", size = 14278147, upload-time = "2025-07-24T20:44:10.328Z" }, - { url = "https://files.pythonhosted.org/packages/1d/0f/571b2c7a3833ae419fe69ff7b479a78d313581785203cc70a8db90121b9a/numpy-2.3.2-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:938065908d1d869c7d75d8ec45f735a034771c6ea07088867f713d1cd3bbbe4f", size = 16635989, upload-time = "2025-07-24T20:44:34.88Z" }, - { url = "https://files.pythonhosted.org/packages/24/5a/84ae8dca9c9a4c592fe11340b36a86ffa9fd3e40513198daf8a97839345c/numpy-2.3.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:66459dccc65d8ec98cc7df61307b64bf9e08101f9598755d42d8ae65d9a7a6ee", size = 16053052, upload-time = "2025-07-24T20:44:58.872Z" }, - { url = "https://files.pythonhosted.org/packages/57/7c/e5725d99a9133b9813fcf148d3f858df98511686e853169dbaf63aec6097/numpy-2.3.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:a7af9ed2aa9ec5950daf05bb11abc4076a108bd3c7db9aa7251d5f107079b6a6", size = 18577955, upload-time = "2025-07-24T20:45:26.714Z" }, - { url = "https://files.pythonhosted.org/packages/ae/11/7c546fcf42145f29b71e4d6f429e96d8d68e5a7ba1830b2e68d7418f0bbd/numpy-2.3.2-cp313-cp313-win32.whl", hash = "sha256:906a30249315f9c8e17b085cc5f87d3f369b35fedd0051d4a84686967bdbbd0b", size = 6311843, upload-time = "2025-07-24T20:49:24.444Z" }, - { url = "https://files.pythonhosted.org/packages/aa/6f/a428fd1cb7ed39b4280d057720fed5121b0d7754fd2a9768640160f5517b/numpy-2.3.2-cp313-cp313-win_amd64.whl", hash = "sha256:c63d95dc9d67b676e9108fe0d2182987ccb0f11933c1e8959f42fa0da8d4fa56", size = 12782876, upload-time = "2025-07-24T20:49:43.227Z" }, - { url = "https://files.pythonhosted.org/packages/65/85/4ea455c9040a12595fb6c43f2c217257c7b52dd0ba332c6a6c1d28b289fe/numpy-2.3.2-cp313-cp313-win_arm64.whl", hash = "sha256:b05a89f2fb84d21235f93de47129dd4f11c16f64c87c33f5e284e6a3a54e43f2", size = 10192786, upload-time = "2025-07-24T20:49:59.443Z" }, - { url = "https://files.pythonhosted.org/packages/80/23/8278f40282d10c3f258ec3ff1b103d4994bcad78b0cba9208317f6bb73da/numpy-2.3.2-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:4e6ecfeddfa83b02318f4d84acf15fbdbf9ded18e46989a15a8b6995dfbf85ab", size = 21047395, upload-time = "2025-07-24T20:45:58.821Z" }, - { url = "https://files.pythonhosted.org/packages/1f/2d/624f2ce4a5df52628b4ccd16a4f9437b37c35f4f8a50d00e962aae6efd7a/numpy-2.3.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:508b0eada3eded10a3b55725b40806a4b855961040180028f52580c4729916a2", size = 14300374, upload-time = "2025-07-24T20:46:20.207Z" }, - { url = "https://files.pythonhosted.org/packages/f6/62/ff1e512cdbb829b80a6bd08318a58698867bca0ca2499d101b4af063ee97/numpy-2.3.2-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:754d6755d9a7588bdc6ac47dc4ee97867271b17cee39cb87aef079574366db0a", size = 5228864, upload-time = "2025-07-24T20:46:30.58Z" }, - { url = "https://files.pythonhosted.org/packages/7d/8e/74bc18078fff03192d4032cfa99d5a5ca937807136d6f5790ce07ca53515/numpy-2.3.2-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:a9f66e7d2b2d7712410d3bc5684149040ef5f19856f20277cd17ea83e5006286", size = 6737533, upload-time = "2025-07-24T20:46:46.111Z" }, - { url = "https://files.pythonhosted.org/packages/19/ea/0731efe2c9073ccca5698ef6a8c3667c4cf4eea53fcdcd0b50140aba03bc/numpy-2.3.2-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:de6ea4e5a65d5a90c7d286ddff2b87f3f4ad61faa3db8dabe936b34c2275b6f8", size = 14352007, upload-time = "2025-07-24T20:47:07.1Z" }, - { url = "https://files.pythonhosted.org/packages/cf/90/36be0865f16dfed20f4bc7f75235b963d5939707d4b591f086777412ff7b/numpy-2.3.2-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a3ef07ec8cbc8fc9e369c8dcd52019510c12da4de81367d8b20bc692aa07573a", size = 16701914, upload-time = "2025-07-24T20:47:32.459Z" }, - { url = "https://files.pythonhosted.org/packages/94/30/06cd055e24cb6c38e5989a9e747042b4e723535758e6153f11afea88c01b/numpy-2.3.2-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:27c9f90e7481275c7800dc9c24b7cc40ace3fdb970ae4d21eaff983a32f70c91", size = 16132708, upload-time = "2025-07-24T20:47:58.129Z" }, - { url = "https://files.pythonhosted.org/packages/9a/14/ecede608ea73e58267fd7cb78f42341b3b37ba576e778a1a06baffbe585c/numpy-2.3.2-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:07b62978075b67eee4065b166d000d457c82a1efe726cce608b9db9dd66a73a5", size = 18651678, upload-time = "2025-07-24T20:48:25.402Z" }, - { url = "https://files.pythonhosted.org/packages/40/f3/2fe6066b8d07c3685509bc24d56386534c008b462a488b7f503ba82b8923/numpy-2.3.2-cp313-cp313t-win32.whl", hash = "sha256:c771cfac34a4f2c0de8e8c97312d07d64fd8f8ed45bc9f5726a7e947270152b5", size = 6441832, upload-time = "2025-07-24T20:48:37.181Z" }, - { url = "https://files.pythonhosted.org/packages/0b/ba/0937d66d05204d8f28630c9c60bc3eda68824abde4cf756c4d6aad03b0c6/numpy-2.3.2-cp313-cp313t-win_amd64.whl", hash = "sha256:72dbebb2dcc8305c431b2836bcc66af967df91be793d63a24e3d9b741374c450", size = 12927049, upload-time = "2025-07-24T20:48:56.24Z" }, - { url = "https://files.pythonhosted.org/packages/e9/ed/13542dd59c104d5e654dfa2ac282c199ba64846a74c2c4bcdbc3a0f75df1/numpy-2.3.2-cp313-cp313t-win_arm64.whl", hash = "sha256:72c6df2267e926a6d5286b0a6d556ebe49eae261062059317837fda12ddf0c1a", size = 10262935, upload-time = "2025-07-24T20:49:13.136Z" }, - { url = "https://files.pythonhosted.org/packages/c9/7c/7659048aaf498f7611b783e000c7268fcc4dcf0ce21cd10aad7b2e8f9591/numpy-2.3.2-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:448a66d052d0cf14ce9865d159bfc403282c9bc7bb2a31b03cc18b651eca8b1a", size = 20950906, upload-time = "2025-07-24T20:50:30.346Z" }, - { url = "https://files.pythonhosted.org/packages/80/db/984bea9d4ddf7112a04cfdfb22b1050af5757864cfffe8e09e44b7f11a10/numpy-2.3.2-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:546aaf78e81b4081b2eba1d105c3b34064783027a06b3ab20b6eba21fb64132b", size = 14185607, upload-time = "2025-07-24T20:50:51.923Z" }, - { url = "https://files.pythonhosted.org/packages/e4/76/b3d6f414f4eca568f469ac112a3b510938d892bc5a6c190cb883af080b77/numpy-2.3.2-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:87c930d52f45df092f7578889711a0768094debf73cfcde105e2d66954358125", size = 5114110, upload-time = "2025-07-24T20:51:01.041Z" }, - { url = "https://files.pythonhosted.org/packages/9e/d2/6f5e6826abd6bca52392ed88fe44a4b52aacb60567ac3bc86c67834c3a56/numpy-2.3.2-cp314-cp314-macosx_14_0_x86_64.whl", hash = "sha256:8dc082ea901a62edb8f59713c6a7e28a85daddcb67454c839de57656478f5b19", size = 6642050, upload-time = "2025-07-24T20:51:11.64Z" }, - { url = "https://files.pythonhosted.org/packages/c4/43/f12b2ade99199e39c73ad182f103f9d9791f48d885c600c8e05927865baf/numpy-2.3.2-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:af58de8745f7fa9ca1c0c7c943616c6fe28e75d0c81f5c295810e3c83b5be92f", size = 14296292, upload-time = "2025-07-24T20:51:33.488Z" }, - { url = "https://files.pythonhosted.org/packages/5d/f9/77c07d94bf110a916b17210fac38680ed8734c236bfed9982fd8524a7b47/numpy-2.3.2-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fed5527c4cf10f16c6d0b6bee1f89958bccb0ad2522c8cadc2efd318bcd545f5", size = 16638913, upload-time = "2025-07-24T20:51:58.517Z" }, - { url = "https://files.pythonhosted.org/packages/9b/d1/9d9f2c8ea399cc05cfff8a7437453bd4e7d894373a93cdc46361bbb49a7d/numpy-2.3.2-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:095737ed986e00393ec18ec0b21b47c22889ae4b0cd2d5e88342e08b01141f58", size = 16071180, upload-time = "2025-07-24T20:52:22.827Z" }, - { url = "https://files.pythonhosted.org/packages/4c/41/82e2c68aff2a0c9bf315e47d61951099fed65d8cb2c8d9dc388cb87e947e/numpy-2.3.2-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:b5e40e80299607f597e1a8a247ff8d71d79c5b52baa11cc1cce30aa92d2da6e0", size = 18576809, upload-time = "2025-07-24T20:52:51.015Z" }, - { url = "https://files.pythonhosted.org/packages/14/14/4b4fd3efb0837ed252d0f583c5c35a75121038a8c4e065f2c259be06d2d8/numpy-2.3.2-cp314-cp314-win32.whl", hash = "sha256:7d6e390423cc1f76e1b8108c9b6889d20a7a1f59d9a60cac4a050fa734d6c1e2", size = 6366410, upload-time = "2025-07-24T20:56:44.949Z" }, - { url = "https://files.pythonhosted.org/packages/11/9e/b4c24a6b8467b61aced5c8dc7dcfce23621baa2e17f661edb2444a418040/numpy-2.3.2-cp314-cp314-win_amd64.whl", hash = "sha256:b9d0878b21e3918d76d2209c924ebb272340da1fb51abc00f986c258cd5e957b", size = 12918821, upload-time = "2025-07-24T20:57:06.479Z" }, - { url = "https://files.pythonhosted.org/packages/0e/0f/0dc44007c70b1007c1cef86b06986a3812dd7106d8f946c09cfa75782556/numpy-2.3.2-cp314-cp314-win_arm64.whl", hash = "sha256:2738534837c6a1d0c39340a190177d7d66fdf432894f469728da901f8f6dc910", size = 10477303, upload-time = "2025-07-24T20:57:22.879Z" }, - { url = "https://files.pythonhosted.org/packages/8b/3e/075752b79140b78ddfc9c0a1634d234cfdbc6f9bbbfa6b7504e445ad7d19/numpy-2.3.2-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:4d002ecf7c9b53240be3bb69d80f86ddbd34078bae04d87be81c1f58466f264e", size = 21047524, upload-time = "2025-07-24T20:53:22.086Z" }, - { url = "https://files.pythonhosted.org/packages/fe/6d/60e8247564a72426570d0e0ea1151b95ce5bd2f1597bb878a18d32aec855/numpy-2.3.2-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:293b2192c6bcce487dbc6326de5853787f870aeb6c43f8f9c6496db5b1781e45", size = 14300519, upload-time = "2025-07-24T20:53:44.053Z" }, - { url = "https://files.pythonhosted.org/packages/4d/73/d8326c442cd428d47a067070c3ac6cc3b651a6e53613a1668342a12d4479/numpy-2.3.2-cp314-cp314t-macosx_14_0_arm64.whl", hash = "sha256:0a4f2021a6da53a0d580d6ef5db29947025ae8b35b3250141805ea9a32bbe86b", size = 5228972, upload-time = "2025-07-24T20:53:53.81Z" }, - { url = "https://files.pythonhosted.org/packages/34/2e/e71b2d6dad075271e7079db776196829019b90ce3ece5c69639e4f6fdc44/numpy-2.3.2-cp314-cp314t-macosx_14_0_x86_64.whl", hash = "sha256:9c144440db4bf3bb6372d2c3e49834cc0ff7bb4c24975ab33e01199e645416f2", size = 6737439, upload-time = "2025-07-24T20:54:04.742Z" }, - { url = "https://files.pythonhosted.org/packages/15/b0/d004bcd56c2c5e0500ffc65385eb6d569ffd3363cb5e593ae742749b2daa/numpy-2.3.2-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f92d6c2a8535dc4fe4419562294ff957f83a16ebdec66df0805e473ffaad8bd0", size = 14352479, upload-time = "2025-07-24T20:54:25.819Z" }, - { url = "https://files.pythonhosted.org/packages/11/e3/285142fcff8721e0c99b51686426165059874c150ea9ab898e12a492e291/numpy-2.3.2-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:cefc2219baa48e468e3db7e706305fcd0c095534a192a08f31e98d83a7d45fb0", size = 16702805, upload-time = "2025-07-24T20:54:50.814Z" }, - { url = "https://files.pythonhosted.org/packages/33/c3/33b56b0e47e604af2c7cd065edca892d180f5899599b76830652875249a3/numpy-2.3.2-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:76c3e9501ceb50b2ff3824c3589d5d1ab4ac857b0ee3f8f49629d0de55ecf7c2", size = 16133830, upload-time = "2025-07-24T20:55:17.306Z" }, - { url = "https://files.pythonhosted.org/packages/6e/ae/7b1476a1f4d6a48bc669b8deb09939c56dd2a439db1ab03017844374fb67/numpy-2.3.2-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:122bf5ed9a0221b3419672493878ba4967121514b1d7d4656a7580cd11dddcbf", size = 18652665, upload-time = "2025-07-24T20:55:46.665Z" }, - { url = "https://files.pythonhosted.org/packages/14/ba/5b5c9978c4bb161034148ade2de9db44ec316fab89ce8c400db0e0c81f86/numpy-2.3.2-cp314-cp314t-win32.whl", hash = "sha256:6f1ae3dcb840edccc45af496f312528c15b1f79ac318169d094e85e4bb35fdf1", size = 6514777, upload-time = "2025-07-24T20:55:57.66Z" }, - { url = "https://files.pythonhosted.org/packages/eb/46/3dbaf0ae7c17cdc46b9f662c56da2054887b8d9e737c1476f335c83d33db/numpy-2.3.2-cp314-cp314t-win_amd64.whl", hash = "sha256:087ffc25890d89a43536f75c5fe8770922008758e8eeeef61733957041ed2f9b", size = 13111856, upload-time = "2025-07-24T20:56:17.318Z" }, - { url = "https://files.pythonhosted.org/packages/c1/9e/1652778bce745a67b5fe05adde60ed362d38eb17d919a540e813d30f6874/numpy-2.3.2-cp314-cp314t-win_arm64.whl", hash = "sha256:092aeb3449833ea9c0bf0089d70c29ae480685dd2377ec9cdbbb620257f84631", size = 10544226, upload-time = "2025-07-24T20:56:34.509Z" }, -] - -[[package]] -name = "packaging" -version = "25.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/a1/d4/1fc4078c65507b51b96ca8f8c3ba19e6a61c8253c72794544580a7b6c24d/packaging-25.0.tar.gz", hash = "sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f", size = 165727, upload-time = "2025-04-19T11:48:59.673Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484", size = 66469, upload-time = "2025-04-19T11:48:57.875Z" }, -] - -[[package]] -name = "pillow" -version = "11.3.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f3/0d/d0d6dea55cd152ce3d6767bb38a8fc10e33796ba4ba210cbab9354b6d238/pillow-11.3.0.tar.gz", hash = "sha256:3828ee7586cd0b2091b6209e5ad53e20d0649bbe87164a459d0676e035e8f523", size = 47113069, upload-time = "2025-07-01T09:16:30.666Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/40/fe/1bc9b3ee13f68487a99ac9529968035cca2f0a51ec36892060edcc51d06a/pillow-11.3.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:fdae223722da47b024b867c1ea0be64e0df702c5e0a60e27daad39bf960dd1e4", size = 5278800, upload-time = "2025-07-01T09:14:17.648Z" }, - { url = "https://files.pythonhosted.org/packages/2c/32/7e2ac19b5713657384cec55f89065fb306b06af008cfd87e572035b27119/pillow-11.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:921bd305b10e82b4d1f5e802b6850677f965d8394203d182f078873851dada69", size = 4686296, upload-time = "2025-07-01T09:14:19.828Z" }, - { url = "https://files.pythonhosted.org/packages/8e/1e/b9e12bbe6e4c2220effebc09ea0923a07a6da1e1f1bfbc8d7d29a01ce32b/pillow-11.3.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:eb76541cba2f958032d79d143b98a3a6b3ea87f0959bbe256c0b5e416599fd5d", size = 5871726, upload-time = "2025-07-03T13:10:04.448Z" }, - { url = "https://files.pythonhosted.org/packages/8d/33/e9200d2bd7ba00dc3ddb78df1198a6e80d7669cce6c2bdbeb2530a74ec58/pillow-11.3.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:67172f2944ebba3d4a7b54f2e95c786a3a50c21b88456329314caaa28cda70f6", size = 7644652, upload-time = "2025-07-03T13:10:10.391Z" }, - { url = "https://files.pythonhosted.org/packages/41/f1/6f2427a26fc683e00d985bc391bdd76d8dd4e92fac33d841127eb8fb2313/pillow-11.3.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:97f07ed9f56a3b9b5f49d3661dc9607484e85c67e27f3e8be2c7d28ca032fec7", size = 5977787, upload-time = "2025-07-01T09:14:21.63Z" }, - { url = "https://files.pythonhosted.org/packages/e4/c9/06dd4a38974e24f932ff5f98ea3c546ce3f8c995d3f0985f8e5ba48bba19/pillow-11.3.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:676b2815362456b5b3216b4fd5bd89d362100dc6f4945154ff172e206a22c024", size = 6645236, upload-time = "2025-07-01T09:14:23.321Z" }, - { url = "https://files.pythonhosted.org/packages/40/e7/848f69fb79843b3d91241bad658e9c14f39a32f71a301bcd1d139416d1be/pillow-11.3.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:3e184b2f26ff146363dd07bde8b711833d7b0202e27d13540bfe2e35a323a809", size = 6086950, upload-time = "2025-07-01T09:14:25.237Z" }, - { url = "https://files.pythonhosted.org/packages/0b/1a/7cff92e695a2a29ac1958c2a0fe4c0b2393b60aac13b04a4fe2735cad52d/pillow-11.3.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:6be31e3fc9a621e071bc17bb7de63b85cbe0bfae91bb0363c893cbe67247780d", size = 6723358, upload-time = "2025-07-01T09:14:27.053Z" }, - { url = "https://files.pythonhosted.org/packages/26/7d/73699ad77895f69edff76b0f332acc3d497f22f5d75e5360f78cbcaff248/pillow-11.3.0-cp312-cp312-win32.whl", hash = "sha256:7b161756381f0918e05e7cb8a371fff367e807770f8fe92ecb20d905d0e1c149", size = 6275079, upload-time = "2025-07-01T09:14:30.104Z" }, - { url = "https://files.pythonhosted.org/packages/8c/ce/e7dfc873bdd9828f3b6e5c2bbb74e47a98ec23cc5c74fc4e54462f0d9204/pillow-11.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:a6444696fce635783440b7f7a9fc24b3ad10a9ea3f0ab66c5905be1c19ccf17d", size = 6986324, upload-time = "2025-07-01T09:14:31.899Z" }, - { url = "https://files.pythonhosted.org/packages/16/8f/b13447d1bf0b1f7467ce7d86f6e6edf66c0ad7cf44cf5c87a37f9bed9936/pillow-11.3.0-cp312-cp312-win_arm64.whl", hash = "sha256:2aceea54f957dd4448264f9bf40875da0415c83eb85f55069d89c0ed436e3542", size = 2423067, upload-time = "2025-07-01T09:14:33.709Z" }, - { url = "https://files.pythonhosted.org/packages/1e/93/0952f2ed8db3a5a4c7a11f91965d6184ebc8cd7cbb7941a260d5f018cd2d/pillow-11.3.0-cp313-cp313-ios_13_0_arm64_iphoneos.whl", hash = "sha256:1c627742b539bba4309df89171356fcb3cc5a9178355b2727d1b74a6cf155fbd", size = 2128328, upload-time = "2025-07-01T09:14:35.276Z" }, - { url = "https://files.pythonhosted.org/packages/4b/e8/100c3d114b1a0bf4042f27e0f87d2f25e857e838034e98ca98fe7b8c0a9c/pillow-11.3.0-cp313-cp313-ios_13_0_arm64_iphonesimulator.whl", hash = "sha256:30b7c02f3899d10f13d7a48163c8969e4e653f8b43416d23d13d1bbfdc93b9f8", size = 2170652, upload-time = "2025-07-01T09:14:37.203Z" }, - { url = "https://files.pythonhosted.org/packages/aa/86/3f758a28a6e381758545f7cdb4942e1cb79abd271bea932998fc0db93cb6/pillow-11.3.0-cp313-cp313-ios_13_0_x86_64_iphonesimulator.whl", hash = "sha256:7859a4cc7c9295f5838015d8cc0a9c215b77e43d07a25e460f35cf516df8626f", size = 2227443, upload-time = "2025-07-01T09:14:39.344Z" }, - { url = "https://files.pythonhosted.org/packages/01/f4/91d5b3ffa718df2f53b0dc109877993e511f4fd055d7e9508682e8aba092/pillow-11.3.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:ec1ee50470b0d050984394423d96325b744d55c701a439d2bd66089bff963d3c", size = 5278474, upload-time = "2025-07-01T09:14:41.843Z" }, - { url = "https://files.pythonhosted.org/packages/f9/0e/37d7d3eca6c879fbd9dba21268427dffda1ab00d4eb05b32923d4fbe3b12/pillow-11.3.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:7db51d222548ccfd274e4572fdbf3e810a5e66b00608862f947b163e613b67dd", size = 4686038, upload-time = "2025-07-01T09:14:44.008Z" }, - { url = "https://files.pythonhosted.org/packages/ff/b0/3426e5c7f6565e752d81221af9d3676fdbb4f352317ceafd42899aaf5d8a/pillow-11.3.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:2d6fcc902a24ac74495df63faad1884282239265c6839a0a6416d33faedfae7e", size = 5864407, upload-time = "2025-07-03T13:10:15.628Z" }, - { url = "https://files.pythonhosted.org/packages/fc/c1/c6c423134229f2a221ee53f838d4be9d82bab86f7e2f8e75e47b6bf6cd77/pillow-11.3.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f0f5d8f4a08090c6d6d578351a2b91acf519a54986c055af27e7a93feae6d3f1", size = 7639094, upload-time = "2025-07-03T13:10:21.857Z" }, - { url = "https://files.pythonhosted.org/packages/ba/c9/09e6746630fe6372c67c648ff9deae52a2bc20897d51fa293571977ceb5d/pillow-11.3.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c37d8ba9411d6003bba9e518db0db0c58a680ab9fe5179f040b0463644bc9805", size = 5973503, upload-time = "2025-07-01T09:14:45.698Z" }, - { url = "https://files.pythonhosted.org/packages/d5/1c/a2a29649c0b1983d3ef57ee87a66487fdeb45132df66ab30dd37f7dbe162/pillow-11.3.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:13f87d581e71d9189ab21fe0efb5a23e9f28552d5be6979e84001d3b8505abe8", size = 6642574, upload-time = "2025-07-01T09:14:47.415Z" }, - { url = "https://files.pythonhosted.org/packages/36/de/d5cc31cc4b055b6c6fd990e3e7f0f8aaf36229a2698501bcb0cdf67c7146/pillow-11.3.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:023f6d2d11784a465f09fd09a34b150ea4672e85fb3d05931d89f373ab14abb2", size = 6084060, upload-time = "2025-07-01T09:14:49.636Z" }, - { url = "https://files.pythonhosted.org/packages/d5/ea/502d938cbaeec836ac28a9b730193716f0114c41325db428e6b280513f09/pillow-11.3.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:45dfc51ac5975b938e9809451c51734124e73b04d0f0ac621649821a63852e7b", size = 6721407, upload-time = "2025-07-01T09:14:51.962Z" }, - { url = "https://files.pythonhosted.org/packages/45/9c/9c5e2a73f125f6cbc59cc7087c8f2d649a7ae453f83bd0362ff7c9e2aee2/pillow-11.3.0-cp313-cp313-win32.whl", hash = "sha256:a4d336baed65d50d37b88ca5b60c0fa9d81e3a87d4a7930d3880d1624d5b31f3", size = 6273841, upload-time = "2025-07-01T09:14:54.142Z" }, - { url = "https://files.pythonhosted.org/packages/23/85/397c73524e0cd212067e0c969aa245b01d50183439550d24d9f55781b776/pillow-11.3.0-cp313-cp313-win_amd64.whl", hash = "sha256:0bce5c4fd0921f99d2e858dc4d4d64193407e1b99478bc5cacecba2311abde51", size = 6978450, upload-time = "2025-07-01T09:14:56.436Z" }, - { url = "https://files.pythonhosted.org/packages/17/d2/622f4547f69cd173955194b78e4d19ca4935a1b0f03a302d655c9f6aae65/pillow-11.3.0-cp313-cp313-win_arm64.whl", hash = "sha256:1904e1264881f682f02b7f8167935cce37bc97db457f8e7849dc3a6a52b99580", size = 2423055, upload-time = "2025-07-01T09:14:58.072Z" }, - { url = "https://files.pythonhosted.org/packages/dd/80/a8a2ac21dda2e82480852978416cfacd439a4b490a501a288ecf4fe2532d/pillow-11.3.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:4c834a3921375c48ee6b9624061076bc0a32a60b5532b322cc0ea64e639dd50e", size = 5281110, upload-time = "2025-07-01T09:14:59.79Z" }, - { url = "https://files.pythonhosted.org/packages/44/d6/b79754ca790f315918732e18f82a8146d33bcd7f4494380457ea89eb883d/pillow-11.3.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:5e05688ccef30ea69b9317a9ead994b93975104a677a36a8ed8106be9260aa6d", size = 4689547, upload-time = "2025-07-01T09:15:01.648Z" }, - { url = "https://files.pythonhosted.org/packages/49/20/716b8717d331150cb00f7fdd78169c01e8e0c219732a78b0e59b6bdb2fd6/pillow-11.3.0-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:1019b04af07fc0163e2810167918cb5add8d74674b6267616021ab558dc98ced", size = 5901554, upload-time = "2025-07-03T13:10:27.018Z" }, - { url = "https://files.pythonhosted.org/packages/74/cf/a9f3a2514a65bb071075063a96f0a5cf949c2f2fce683c15ccc83b1c1cab/pillow-11.3.0-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f944255db153ebb2b19c51fe85dd99ef0ce494123f21b9db4877ffdfc5590c7c", size = 7669132, upload-time = "2025-07-03T13:10:33.01Z" }, - { url = "https://files.pythonhosted.org/packages/98/3c/da78805cbdbee9cb43efe8261dd7cc0b4b93f2ac79b676c03159e9db2187/pillow-11.3.0-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1f85acb69adf2aaee8b7da124efebbdb959a104db34d3a2cb0f3793dbae422a8", size = 6005001, upload-time = "2025-07-01T09:15:03.365Z" }, - { url = "https://files.pythonhosted.org/packages/6c/fa/ce044b91faecf30e635321351bba32bab5a7e034c60187fe9698191aef4f/pillow-11.3.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:05f6ecbeff5005399bb48d198f098a9b4b6bdf27b8487c7f38ca16eeb070cd59", size = 6668814, upload-time = "2025-07-01T09:15:05.655Z" }, - { url = "https://files.pythonhosted.org/packages/7b/51/90f9291406d09bf93686434f9183aba27b831c10c87746ff49f127ee80cb/pillow-11.3.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:a7bc6e6fd0395bc052f16b1a8670859964dbd7003bd0af2ff08342eb6e442cfe", size = 6113124, upload-time = "2025-07-01T09:15:07.358Z" }, - { url = "https://files.pythonhosted.org/packages/cd/5a/6fec59b1dfb619234f7636d4157d11fb4e196caeee220232a8d2ec48488d/pillow-11.3.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:83e1b0161c9d148125083a35c1c5a89db5b7054834fd4387499e06552035236c", size = 6747186, upload-time = "2025-07-01T09:15:09.317Z" }, - { url = "https://files.pythonhosted.org/packages/49/6b/00187a044f98255225f172de653941e61da37104a9ea60e4f6887717e2b5/pillow-11.3.0-cp313-cp313t-win32.whl", hash = "sha256:2a3117c06b8fb646639dce83694f2f9eac405472713fcb1ae887469c0d4f6788", size = 6277546, upload-time = "2025-07-01T09:15:11.311Z" }, - { url = "https://files.pythonhosted.org/packages/e8/5c/6caaba7e261c0d75bab23be79f1d06b5ad2a2ae49f028ccec801b0e853d6/pillow-11.3.0-cp313-cp313t-win_amd64.whl", hash = "sha256:857844335c95bea93fb39e0fa2726b4d9d758850b34075a7e3ff4f4fa3aa3b31", size = 6985102, upload-time = "2025-07-01T09:15:13.164Z" }, - { url = "https://files.pythonhosted.org/packages/f3/7e/b623008460c09a0cb38263c93b828c666493caee2eb34ff67f778b87e58c/pillow-11.3.0-cp313-cp313t-win_arm64.whl", hash = "sha256:8797edc41f3e8536ae4b10897ee2f637235c94f27404cac7297f7b607dd0716e", size = 2424803, upload-time = "2025-07-01T09:15:15.695Z" }, - { url = "https://files.pythonhosted.org/packages/73/f4/04905af42837292ed86cb1b1dabe03dce1edc008ef14c473c5c7e1443c5d/pillow-11.3.0-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:d9da3df5f9ea2a89b81bb6087177fb1f4d1c7146d583a3fe5c672c0d94e55e12", size = 5278520, upload-time = "2025-07-01T09:15:17.429Z" }, - { url = "https://files.pythonhosted.org/packages/41/b0/33d79e377a336247df6348a54e6d2a2b85d644ca202555e3faa0cf811ecc/pillow-11.3.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:0b275ff9b04df7b640c59ec5a3cb113eefd3795a8df80bac69646ef699c6981a", size = 4686116, upload-time = "2025-07-01T09:15:19.423Z" }, - { url = "https://files.pythonhosted.org/packages/49/2d/ed8bc0ab219ae8768f529597d9509d184fe8a6c4741a6864fea334d25f3f/pillow-11.3.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:0743841cabd3dba6a83f38a92672cccbd69af56e3e91777b0ee7f4dba4385632", size = 5864597, upload-time = "2025-07-03T13:10:38.404Z" }, - { url = "https://files.pythonhosted.org/packages/b5/3d/b932bb4225c80b58dfadaca9d42d08d0b7064d2d1791b6a237f87f661834/pillow-11.3.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:2465a69cf967b8b49ee1b96d76718cd98c4e925414ead59fdf75cf0fd07df673", size = 7638246, upload-time = "2025-07-03T13:10:44.987Z" }, - { url = "https://files.pythonhosted.org/packages/09/b5/0487044b7c096f1b48f0d7ad416472c02e0e4bf6919541b111efd3cae690/pillow-11.3.0-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:41742638139424703b4d01665b807c6468e23e699e8e90cffefe291c5832b027", size = 5973336, upload-time = "2025-07-01T09:15:21.237Z" }, - { url = "https://files.pythonhosted.org/packages/a8/2d/524f9318f6cbfcc79fbc004801ea6b607ec3f843977652fdee4857a7568b/pillow-11.3.0-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:93efb0b4de7e340d99057415c749175e24c8864302369e05914682ba642e5d77", size = 6642699, upload-time = "2025-07-01T09:15:23.186Z" }, - { url = "https://files.pythonhosted.org/packages/6f/d2/a9a4f280c6aefedce1e8f615baaa5474e0701d86dd6f1dede66726462bbd/pillow-11.3.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:7966e38dcd0fa11ca390aed7c6f20454443581d758242023cf36fcb319b1a874", size = 6083789, upload-time = "2025-07-01T09:15:25.1Z" }, - { url = "https://files.pythonhosted.org/packages/fe/54/86b0cd9dbb683a9d5e960b66c7379e821a19be4ac5810e2e5a715c09a0c0/pillow-11.3.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:98a9afa7b9007c67ed84c57c9e0ad86a6000da96eaa638e4f8abe5b65ff83f0a", size = 6720386, upload-time = "2025-07-01T09:15:27.378Z" }, - { url = "https://files.pythonhosted.org/packages/e7/95/88efcaf384c3588e24259c4203b909cbe3e3c2d887af9e938c2022c9dd48/pillow-11.3.0-cp314-cp314-win32.whl", hash = "sha256:02a723e6bf909e7cea0dac1b0e0310be9d7650cd66222a5f1c571455c0a45214", size = 6370911, upload-time = "2025-07-01T09:15:29.294Z" }, - { url = "https://files.pythonhosted.org/packages/2e/cc/934e5820850ec5eb107e7b1a72dd278140731c669f396110ebc326f2a503/pillow-11.3.0-cp314-cp314-win_amd64.whl", hash = "sha256:a418486160228f64dd9e9efcd132679b7a02a5f22c982c78b6fc7dab3fefb635", size = 7117383, upload-time = "2025-07-01T09:15:31.128Z" }, - { url = "https://files.pythonhosted.org/packages/d6/e9/9c0a616a71da2a5d163aa37405e8aced9a906d574b4a214bede134e731bc/pillow-11.3.0-cp314-cp314-win_arm64.whl", hash = "sha256:155658efb5e044669c08896c0c44231c5e9abcaadbc5cd3648df2f7c0b96b9a6", size = 2511385, upload-time = "2025-07-01T09:15:33.328Z" }, - { url = "https://files.pythonhosted.org/packages/1a/33/c88376898aff369658b225262cd4f2659b13e8178e7534df9e6e1fa289f6/pillow-11.3.0-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:59a03cdf019efbfeeed910bf79c7c93255c3d54bc45898ac2a4140071b02b4ae", size = 5281129, upload-time = "2025-07-01T09:15:35.194Z" }, - { url = "https://files.pythonhosted.org/packages/1f/70/d376247fb36f1844b42910911c83a02d5544ebd2a8bad9efcc0f707ea774/pillow-11.3.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:f8a5827f84d973d8636e9dc5764af4f0cf2318d26744b3d902931701b0d46653", size = 4689580, upload-time = "2025-07-01T09:15:37.114Z" }, - { url = "https://files.pythonhosted.org/packages/eb/1c/537e930496149fbac69efd2fc4329035bbe2e5475b4165439e3be9cb183b/pillow-11.3.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ee92f2fd10f4adc4b43d07ec5e779932b4eb3dbfbc34790ada5a6669bc095aa6", size = 5902860, upload-time = "2025-07-03T13:10:50.248Z" }, - { url = "https://files.pythonhosted.org/packages/bd/57/80f53264954dcefeebcf9dae6e3eb1daea1b488f0be8b8fef12f79a3eb10/pillow-11.3.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:c96d333dcf42d01f47b37e0979b6bd73ec91eae18614864622d9b87bbd5bbf36", size = 7670694, upload-time = "2025-07-03T13:10:56.432Z" }, - { url = "https://files.pythonhosted.org/packages/70/ff/4727d3b71a8578b4587d9c276e90efad2d6fe0335fd76742a6da08132e8c/pillow-11.3.0-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4c96f993ab8c98460cd0c001447bff6194403e8b1d7e149ade5f00594918128b", size = 6005888, upload-time = "2025-07-01T09:15:39.436Z" }, - { url = "https://files.pythonhosted.org/packages/05/ae/716592277934f85d3be51d7256f3636672d7b1abfafdc42cf3f8cbd4b4c8/pillow-11.3.0-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:41342b64afeba938edb034d122b2dda5db2139b9a4af999729ba8818e0056477", size = 6670330, upload-time = "2025-07-01T09:15:41.269Z" }, - { url = "https://files.pythonhosted.org/packages/e7/bb/7fe6cddcc8827b01b1a9766f5fdeb7418680744f9082035bdbabecf1d57f/pillow-11.3.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:068d9c39a2d1b358eb9f245ce7ab1b5c3246c7c8c7d9ba58cfa5b43146c06e50", size = 6114089, upload-time = "2025-07-01T09:15:43.13Z" }, - { url = "https://files.pythonhosted.org/packages/8b/f5/06bfaa444c8e80f1a8e4bff98da9c83b37b5be3b1deaa43d27a0db37ef84/pillow-11.3.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:a1bc6ba083b145187f648b667e05a2534ecc4b9f2784c2cbe3089e44868f2b9b", size = 6748206, upload-time = "2025-07-01T09:15:44.937Z" }, - { url = "https://files.pythonhosted.org/packages/f0/77/bc6f92a3e8e6e46c0ca78abfffec0037845800ea38c73483760362804c41/pillow-11.3.0-cp314-cp314t-win32.whl", hash = "sha256:118ca10c0d60b06d006be10a501fd6bbdfef559251ed31b794668ed569c87e12", size = 6377370, upload-time = "2025-07-01T09:15:46.673Z" }, - { url = "https://files.pythonhosted.org/packages/4a/82/3a721f7d69dca802befb8af08b7c79ebcab461007ce1c18bd91a5d5896f9/pillow-11.3.0-cp314-cp314t-win_amd64.whl", hash = "sha256:8924748b688aa210d79883357d102cd64690e56b923a186f35a82cbc10f997db", size = 7121500, upload-time = "2025-07-01T09:15:48.512Z" }, - { url = "https://files.pythonhosted.org/packages/89/c7/5572fa4a3f45740eaab6ae86fcdf7195b55beac1371ac8c619d880cfe948/pillow-11.3.0-cp314-cp314t-win_arm64.whl", hash = "sha256:79ea0d14d3ebad43ec77ad5272e6ff9bba5b679ef73375ea760261207fa8e0aa", size = 2512835, upload-time = "2025-07-01T09:15:50.399Z" }, -] - -[[package]] -name = "pyparsing" -version = "3.2.3" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/bb/22/f1129e69d94ffff626bdb5c835506b3a5b4f3d070f17ea295e12c2c6f60f/pyparsing-3.2.3.tar.gz", hash = "sha256:b9c13f1ab8b3b542f72e28f634bad4de758ab3ce4546e4301970ad6fa77c38be", size = 1088608, upload-time = "2025-03-25T05:01:28.114Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/05/e7/df2285f3d08fee213f2d041540fa4fc9ca6c2d44cf36d3a035bf2a8d2bcc/pyparsing-3.2.3-py3-none-any.whl", hash = "sha256:a749938e02d6fd0b59b356ca504a24982314bb090c383e3cf201c95ef7e2bfcf", size = 111120, upload-time = "2025-03-25T05:01:24.908Z" }, -] - -[[package]] -name = "python-dateutil" -version = "2.9.0.post0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "six" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/66/c0/0c8b6ad9f17a802ee498c46e004a0eb49bc148f2fd230864601a86dcf6db/python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3", size = 342432, upload-time = "2024-03-01T18:36:20.211Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892, upload-time = "2024-03-01T18:36:18.57Z" }, -] - -[[package]] -name = "six" -version = "1.17.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/94/e7/b2c673351809dca68a0e064b6af791aa332cf192da575fd474ed7d6f16a2/six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81", size = 34031, upload-time = "2024-12-04T17:35:28.174Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050, upload-time = "2024-12-04T17:35:26.475Z" }, -] diff --git a/minimal_zkVM.pdf b/minimal_zkVM.pdf index 5478c89d..60d3102f 100644 Binary files a/minimal_zkVM.pdf and b/minimal_zkVM.pdf differ diff --git a/misc/bibliography.bib b/misc/bibliography.bib new file mode 100644 index 00000000..1a9baac3 --- /dev/null +++ b/misc/bibliography.bib @@ -0,0 +1,153 @@ +@article{whir, + author = {Gal Arnon and Alessandro Chiesa and Giacomo Fenzi and Eylon Yogev}, + title = {{WHIR}: Reed–Solomon Proximity Testing with Super-Fast Verification}, + howpublished = {Cryptology {ePrint} Archive, Paper 2024/1586}, + year = {2024}, + url = {https://eprint.iacr.org/2024/1586} +} +@article{fri_binius, + author = {Benjamin E. Diamond and Jim Posen}, + title = {Polylogarithmic Proofs for Multilinears over Binary Towers}, + howpublished = {Cryptology {ePrint} Archive, Paper 2024/504}, + year = {2024}, + url = {https://eprint.iacr.org/2024/504} +} +@article{ccs, + author = {Srinath Setty and Justin Thaler and Riad Wahby}, + title = {Customizable constraint systems for succinct arguments}, + howpublished = {Cryptology {ePrint} Archive, Paper 2023/552}, + year = {2023}, + url = {https://eprint.iacr.org/2023/552} +} +@article{simple_multivariate_AIR, + author = {William Borgeaud}, + title = {A simple multivariate AIR argument inspired by SuperSpartan}, + year = {2023}, + url = {https://solvable.group/posts/super-air/} +} +@article{hyperplonk, + author = {Binyi Chen and Benedikt Bünz and Dan Boneh and Zhenfei Zhang}, + title = {{HyperPlonk}: Plonk with Linear-Time Prover and High-Degree Custom Gates}, + howpublished = {Cryptology {ePrint} Archive, Paper 2022/1355}, + year = {2022}, + url = {https://eprint.iacr.org/2022/1355} +} +@article{univariate_skip, + author = {Angus Gruen}, + title = {Some Improvements for the {PIOP} for {ZeroCheck}}, + howpublished = {Cryptology {ePrint} Archive, Paper 2024/108}, + year = {2024}, + url = {https://eprint.iacr.org/2024/108} +} +@misc{eth_stark, + author = {StarkWare}, + title = {{ethSTARK} Documentation}, + howpublished = {Cryptology {ePrint} Archive, Paper 2021/582}, + year = {2021}, + url = {https://eprint.iacr.org/2021/582} +} + + +@misc{ethereum_signatures, + author = {Justin Drake and Dmitry Khovratovich and Mikhail Kudinov and Benedikt Wagner}, + title = {Hash-Based Multi-Signatures for Post-Quantum Ethereum}, + howpublished = {Cryptology {ePrint} Archive, Paper 2025/055}, + year = {2025}, + doi = {10.62056/aey7qjp10}, + url = {https://eprint.iacr.org/2025/055} +} + +@misc{proximity_gaps_rs_codes, + author = {Eli Ben-Sasson and Dan Carmon and Yuval Ishai and Swastik Kopparty and Shubhangi Saraf}, + title = {Proximity Gaps for Reed-Solomon Codes}, + howpublished = {Cryptology {ePrint} Archive, Paper 2020/654}, + year = {2020}, + url = {https://eprint.iacr.org/2020/654} +} + +@misc{cairo, + author = {Lior Goldberg and Shahar Papini and Michael Riabzev}, + title = {Cairo – a Turing-complete {STARK}-friendly {CPU} architecture}, + howpublished = {Cryptology {ePrint} Archive, Paper 2021/1063}, + year = {2021}, + url = {https://eprint.iacr.org/2021/1063} +} + +@misc{spice, + author = {Srinath Setty and Sebastian Angel and Trinabh Gupta and Jonathan Lee}, + title = {Proving the correct execution of concurrent services in zero-knowledge}, + howpublished = {Cryptology {ePrint} Archive, Paper 2018/907}, + year = {2018}, + url = {https://eprint.iacr.org/2018/907} +} +@misc{top_hypercube, + author = {Dmitry Khovratovich and Mikhail Kudinov and Benedikt Wagner}, + title = {At the Top of the Hypercube -- Better Size-Time Tradeoffs for Hash-Based Signatures}, + howpublished = {Cryptology {ePrint} Archive, Paper 2025/889}, + year = {2025}, + url = {https://eprint.iacr.org/2025/889} +} + +@misc{poseidon2, + author = {Lorenzo Grassi and Dmitry Khovratovich and Markus Schofnegger}, + title = {Poseidon2: A Faster Version of the Poseidon Hash Function}, + howpublished = {Cryptology {ePrint} Archive, Paper 2023/323}, + year = {2023}, + url = {https://eprint.iacr.org/2023/323} +} + +@misc{logup_star, + author = {Lev Soukhanov}, + title = {Logup*: faster, cheaper logup argument for small-table indexed lookups}, + howpublished = {Cryptology {ePrint} Archive, Paper 2025/946}, + year = {2025}, + url = {https://eprint.iacr.org/2025/946} +} + +@misc{LeanSig, + author = {Justin Drake and Dmitry Khovratovich and Mikhail Kudinov and Benedikt Wagner}, + title = {Technical Note: {LeanSig} for Post-Quantum Ethereum}, + howpublished = {Cryptology {ePrint} Archive, Paper 2025/1332}, + year = {2025}, + url = {https://eprint.iacr.org/2025/1332} +} + +@misc{jagged_pcs, + author = {Tamir Hemo and Kevin Jue and Eugene Rabinovich and Gyumin Roh and Ron D. Rothblum}, + title = {Jagged Polynomial Commitments (or: How to Stack Multilinears)}, + howpublished = {Cryptology {ePrint} Archive, Paper 2025/917}, + year = {2025}, + url = {https://eprint.iacr.org/2025/917} +} + +@misc{proximity_gaps_rs_codes_2, + author = {Eli Ben-Sasson and Dan Carmon and Ulrich Haböck and Swastik Kopparty and Shubhangi Saraf}, + title = {On Proximity Gaps for Reed–Solomon Codes}, + howpublished = {Cryptology {ePrint} Archive, Paper 2025/2055}, + year = {2025}, + url = {https://eprint.iacr.org/2025/2055} +} + +@misc{logup, + author = {Ulrich Haböck}, + title = {Multivariate lookups based on logarithmic derivatives}, + howpublished = {Cryptology {ePrint} Archive, Paper 2022/1530}, + year = {2022}, + url = {https://eprint.iacr.org/2022/1530} +} + +@misc{logup_gkr, + author = {Shahar Papini and Ulrich Haböck}, + title = {Improving logarithmic derivative lookups using {GKR}}, + howpublished = {Cryptology {ePrint} Archive, Paper 2023/1284}, + year = {2023}, + url = {https://eprint.iacr.org/2023/1284} +} + +@misc{openvm, + author = {OPENVM CONTRIBUTORS}, + title = {OPENVM WHITEPAPER}, + year = {2025}, + url = {https://openvm.dev/whitepaper.pdf} +} + diff --git a/misc/images/AggMerge.png b/misc/images/AggMerge.png new file mode 100644 index 00000000..c1fdade5 Binary files /dev/null and b/misc/images/AggMerge.png differ diff --git a/misc/images/banner.svg b/misc/images/banner.svg new file mode 100644 index 00000000..8f186144 --- /dev/null +++ b/misc/images/banner.svg @@ -0,0 +1,172 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/misc/images/memory.png b/misc/images/memory.png new file mode 100644 index 00000000..5f661199 Binary files /dev/null and b/misc/images/memory.png differ diff --git a/misc/images/memory_layout.png b/misc/images/memory_layout.png new file mode 100644 index 00000000..732da7a9 Binary files /dev/null and b/misc/images/memory_layout.png differ diff --git a/misc/minimal_zkVM.tex b/misc/minimal_zkVM.tex new file mode 100644 index 00000000..ac80d6d9 --- /dev/null +++ b/misc/minimal_zkVM.tex @@ -0,0 +1,616 @@ +\documentclass{article} + +\usepackage[english]{babel} + +\usepackage[letterpaper,top=2cm,bottom=2cm,left=3cm,right=3cm,marginparwidth=1.75cm]{geometry} + +% Useful packages +\usepackage{amsmath} +\usepackage{amssymb} +\usepackage{xcolor} +\usepackage{graphicx} +\usepackage{mathtools} +\usepackage{framed} +\usepackage{tikz} +\usepackage[colorlinks=true, allcolors=blue]{hyperref} +\usepackage{xcolor} +\usepackage{colortbl} +\usepackage{booktabs} +\usepackage{tikz} +\usepackage{tcolorbox} +\usepackage{algorithm} +\usepackage{algpseudocode} +\usepackage{amsthm} +\usetikzlibrary{positioning, arrows.meta} + +\theoremstyle{definition} +\newtheorem{exmp}{Example}[section] + +\definecolor{sigblue}{RGB}{214,228,247} +\definecolor{siggreen}{RGB}{208,235,203} +\definecolor{sigyellow}{RGB}{255,245,200} + +\newcommand{\Fp}{\mathbb F_p} +\newcommand{\Fq}{\mathbb F_q} +\newcommand{\offdest}{\text{off}_{\text{dest}}} +\newcommand{\offopzero}{\text{off}_{\text{op0}}} +\newcommand{\offopone}{\text{off}_{\text{op1}}} +\newcommand*{\logeq}{\ratio\Leftrightarrow} + +\algnewcommand\algorithmicpublicinput{\textbf{Public input:}} +\algnewcommand\publicinput{\item[\algorithmicpublicinput]} + +\algnewcommand\algorithmicprivateinput{\textbf{Private input:}} +\algnewcommand\privateinput{\item[\algorithmicprivateinput]} + +\newtheorem{lemma}{Lemma} + +\title{Minimal zkVM for Lean Ethereum (draft 0.5.0)} +\date{} +\begin{document} + +\maketitle + +\section{What is the goal of this zkVM?} + +Replacing the BLS signature scheme with a Post-Quantum alternative. One approach is to use stateful hash-based signatures, XMSS, as explained in \cite{ethereum_signatures}, \cite{top_hypercube} and \cite{LeanSig}, and to use a hash-based SNARK to handle aggregation. +A candidate hash function is Poseidon2 \cite{poseidon2}. + +We want to be able to: + +\begin{itemize} + \item \textbf{Aggregate} XMSS signatures + \item \textbf{Merge} those aggregate signatures +\end{itemize} + +The latter involves recursively verifying a SNARK. Both tasks mainly require to prove a lot of hashes. A minimal zkVM (inspired by Cairo \cite{cairo}) is useful as glue to handle all the logic. + +Aggregate / Merge can be unified in a single program, which is the only one the zkVM has to prove (see \ref{fig1} for a visual interpretation): + +\begin{algorithm} +\caption{AggregateMerge} +\begin{algorithmic}[1] +\publicinput \textbf{pub\_keys} (of size $n$), \textbf{bitfield} ($k$ ones, $n-k$ zeros), \textbf{msg} (the encoding of the signed message) +\privateinput $s > 0$, \textbf{sub\_bitfields} (of size $s$), \textbf{aggregate\_proofs} (of size $s - 1$), \textbf{signatures} + + +\Comment{Bitfield consistency} + +\State Check: \textbf{bitfield} $= \bigcup_{i=0}^{s-1}$ +\textbf{sub\_bitfields}[i] + + +\State \Comment{Verify the first $s-1$ sub\_bitfields using aggregate\_proofs:} + +\For{$i \gets 0$ to $s-2$} + \State inner\_public\_input $\gets$ (\textbf{pub\_keys}, \textbf{sub\_bitfields}[i], \textbf{msg}) + \State \textit{snark\_verify}("AggregateMerge", inner\_public\_input, \textbf{aggregate\_proofs}[i]) + % \State \Comment{\textbf{snark\_verify}(program, public input, proof)} +\EndFor + +\State \Comment{Verify the last sub\_bitfields using signatures} + +\State $k \gets 0$ +\For{$i \gets 0$ to $n-1$} + \If{\textbf{sub\_bitfields}[s-1][i] = 1} + \State \textit{signature\_verify}(\textbf{msg}, \textbf{pub\_keys}[i], \textbf{signatures}[k]) + \State $k \gets k + 1$ + \EndIf +\EndFor + +\end{algorithmic} +\end{algorithm} + + + +\begin{figure}[h] +\label{fig1} +\caption{AggregateMerge visualized.} +\centering +\includegraphics[scale=0.6]{images/AggMerge.png} +\end{figure} + + +\section{VM specification} +\subsection{Field} + +\fbox{KoalaBear prime: $p = 2^{31} - 2^{24} + 1$} + +\vspace{3mm} + +Advantages: +\begin{itemize} + \item small field $\xrightarrow{}$ less Poseidon rounds + \item $x \xrightarrow{} x^3$ is an automorphism of $\Fp^*$, meaning efficient S-box for Poseidon2 (in BabyBear, it's degree $7$) + \item $< 2^{31}$ $\xrightarrow{}$ the sum of 2 field elements can be stored in an u32 +\end{itemize} + +The small 2-addicity (24) is not a limiting factor in WHIR, thanks to the use of an interleaved Reed Solomon code. + +% For instance, in a degree 8 extension, with rate = $1/2$, FRI/WHIR can commit to a polynomial of degree $2^{26}$ (FFT on a domain of size $2^{27}$), which is to equivalent to $2^{29}$ KoalaBear elements after RingSwitching \cite{fri_binius} $\approx 2 $ GiB of data (enough to prove the validity of $2$ million Poseidon2). + + +\vspace{3mm} + +\fbox{Extension field: of degree 5} : enables 128 bits of security in WHIR, with the Johnson bound, thanks to the latest result of \cite{proximity_gaps_rs_codes_2}. + +\subsection{Memory} + +\begin{itemize} + \item Read-Only Memory + \item Word = KoalaBear field element + \item Memory size: $M = 2^m$ with $16 \leq m \leq 29$ ($m$ depends on the execution and is communicated at the beginning of the proof). + \item The first $M' = 2^{m'}$ memory cells hold the "public input", on which both prover and verifier must agree. + This enables to pass the arguments that the leanISA program receives as input (in our case: message to sign and XMSS public keys). +\end{itemize} + +% \begin{figure}[h] +% \caption{Memory structure} +% \centering +% \label{memory} +% \includegraphics[scale=0.7]{images/memory.png} +% \end{figure} + +\subsection{Registers} + +% As in Cairo: +\begin{itemize} + \item \fbox{pc: program counter} + % \item \textbf{ap}: allocation pointer = Points to the first untouched memory cell + \item \fbox{fp: frame pointer} : points to the start of the current stack +\end{itemize} + + +\textbf{Difference with Cairo: no "ap" register} (allocation pointer). + + +\subsection{Instruction Set Architecture} + +Notations: +\begin{itemize} + \item $\alpha$, $\beta$ and $\gamma$ represent parameters of the instructions (immediate value operands) + \item $\textbf{m}[i]$ represents the value of the memory at index $i \in \Fp$, with $i < M$ ($M$: memory size). Any out-of-bound memory access ($i \geq M$) is invalid. + \item $\begin{cases} A \\ B \end{cases}$ When using the instruction, either $A$ or $B$ can be used, but not both simultaneously. +\end{itemize} + + +\subsubsection{ADD / MUL} + +$a + c = b$ or $a \cdot c = b$ with: + +\begin{align*} +a &= +\begin{cases} +\alpha \\ +\textbf{m}[\text{fp} + \alpha] +\end{cases} +& +b &= +\begin{cases} +\beta \\ +\textbf{m}[\text{fp} + \beta] +\end{cases} +& +c &= +\begin{cases} +\text{fp} \\ +\textbf{m}[\text{fp} + \gamma] +\end{cases} +\end{align*} + + +\subsubsection{DEREF} + +$$ +\textbf{m}[\textbf{m}[\text{fp} + \alpha] + \beta] = +\begin{cases} +& \gamma \\ +& \textbf{m}[\text{fp} + \gamma] \\ +& \text{fp} + +\end{cases} +$$ + +\subsubsection{JUMP (Conditional)} + +\begin{align*} +\text{condition} &= +\begin{cases} +\alpha \\ +\textbf{m}[\text{fp} + \alpha] +\end{cases} \in \{0, 1\} +& +\text{dest} &= +\begin{cases} +\beta \\ +\textbf{m}[\text{fp} + \beta] +\end{cases} +& +\text{next(fp)} &= +\begin{cases} +\text{fp} \\ +\textbf{m}[\text{fp} + \gamma] +\end{cases} +\end{align*} + +$$ +\text{next(pc)} = +\begin{cases} +\text{dest} & \text{if condition} = 1 \\ +\text{pc} + 1 & \text{if condition} = 0 +\end{cases} +$$ + +\subsubsection{2 Precompiles} + + + +\begin{enumerate} + \item \textbf{POSEIDON}: Poseidon2 permutation over 16 field elements. + \item \textbf{DOT\_PRODUCT}: Computes a dot product between either: + \begin{itemize} + \item two slices of extension field elements + \item one slice of base field elements, and one slice of extension field elements + \end{itemize} +\end{enumerate} + + +\subsection{ISA programming} + +\subsubsection{Functions} + +\begin{enumerate} + \item Each function has a deterministic memory footprint: the length of the continuous frame in memory that is allocated for each of the its calls. + \item At runtime, each time we call our function, we receive via a memory cell a hint pointing to a free range of memory. We then store the current values of pc / fp at the start of this newly allocated frame, alongside the function's arguments, we can then jump, to enter the function bytecode, and modify fp with the hinted value. The intuition here is that the verifier does not care where the new memory frame will be placed (we use a read-only memory, so we cannot overwrite previous frames). In practice, the prover that runs the program would need to keep the value of the allocation pointer "ap", in order to adequately allocate new memory frames, but there is no need to keep track of it from the versifier's perspective. +\end{enumerate} + + +\begin{figure}[h] +\caption{Memory layout of a function call} +\centering +\includegraphics[scale=0.4]{images/memory_layout.png} +\end{figure} + +\subsubsection{Loops}\label{loops} + +We suggest to unroll loops when the number of iterations is low, and known at compile time. +The remaining loops are transformed into recursive functions (by the leanISA compiler). + +\subsubsection{Range checks} + +\fbox{It's possible to check that a given memory cell is smaller than some value $t$ (for $t \leq 2^{16})$ in 3 cycles.} + +We denote by \textbf{m}[fp + $x$] the memory cell for which we want to ensure \textbf{m}[fp + $x$] $< t$. +We also denote by \textbf{m}[fp + $i$], \textbf{m}[fp + $j$] and \textbf{m}[fp + $k$] 3 auxiliary memory cells (that have not been used yet). +\begin{enumerate} + \item \textbf{m}[\textbf{m}[fp + $x$]] = \textbf{m}[fp + $i$] (using DEREF, this ensures \textbf{m}[fp + $x$] $ < M$, the memory size) + \item \textbf{m}[\textbf{m}[fp + $x$]] + \textbf{m}[fp + $j$] = (t-1) (using ADD) + \item \textbf{m}[\textbf{m}[fp + $j$]] = \textbf{m}[fp + $k$] (using DEREF, this ensures $t - 1 - $ \textbf{m}[fp + $x$] $ < M$) +\end{enumerate} + +Given $t \leq 2^{16} \leq M$, \textbf{m}[fp + $x$] $ < M$, $t - 1 - $ \textbf{m}[fp + $x$] $< M$, and $M \leq 2^{29} < p / 2$, we have: \textbf{m}[fp + $x$] $ < t$. + +Note: From the point of view of the prover running the program, some hints are necessary (filling the values of \textbf{m}[fp + $i$] and \textbf{m}[fp + $k$] must be done at end of execution). + +This idea was pointed out by D. Khovratovich, and is an unplanned use of the DEREF instruction. + +\begin{exmp} + Let's say we want to write a function with 2 arguments $x = \textbf{m}$[fp + 2] and $y = \textbf{m}$[fp + 3] ($\textbf{m}$[fp + 0] and $\textbf{m}$[fp + 1] are used, by convention, to store the caller's pc and fp, to return to the previous context at the end of the function), which perform the following: + + \begin{enumerate} + \item assert(x $<$ 10) + \item z := x*y + 100 + \item assert(z $<$ 1000) + \end{enumerate} + + Which can be compiled to: + + \begin{enumerate} + \item \textbf{m}[fp + 4] = \textbf{m}[$\big[$fp + 2$\big]$] // check that $x$ is "small" + \item \textbf{m}[fp + 2] + \textbf{m}[fp + 5] = 9 // compute $9 - x$ + \item \textbf{m}[fp + 6] = \textbf{m}[$\big[$fp + 5$\big]$] // check that $9 - x$ is "small" + \item \textbf{m}[fp + 7] = \textbf{m}[fp + 2] * \textbf{m}[fp + 3] // compute $x.y$ + \item \textbf{m}[fp + 8] = \textbf{m}[fp + 7] + 100 // compute $z = x.y + 100$ + \item \textbf{m}[fp + 9] = \textbf{m}[$\big[$fp + 8$\big]$] // check that $z$ is "small" + \item \textbf{m}[fp + 8] + \textbf{m}[fp + 10] = 999 // compute $999 - z$ + \item \textbf{m}[fp + 11] = \textbf{m}[$\big[$fp + 10$\big]$] // check that $999 - z$ is "small" + \item JUMP with next(pc) = \textbf{m}[fp + 0], next(fp) = \textbf{m}[fp + 1], condition = 1 // return + \end{enumerate} + + +\end{exmp} + + +\subsubsection{Switch statements} + +Suppose we want a different logic depending on the value $x$ of a given memory cell, where $x$ is known to be $< k$ (if the value $x$ comes from a "hint", don't forget to range-check it). + +Each of the $k$ different value leads to a different branch at runtime, represented by a block of code. We want to jump to the correct block of code depending on $x$. +One efficient implementation consists in placing our blocks of code at regular intervals, and to jump to a $a+ b.x$, where $a$ is the offset of the first block of code (in case $x = 0$), and $b$ is the distance between two consecutive blocks. +\newline +\newline +Example: During XMSS verification, for each of the $v$ chains, we need to hash a pre-image, a number of times depending on the encoding, but known to be $< w$. Here $k = w$, and the $i-th$ block of code we could jump to will execute $i$ times the hash function (unrolled loop). + +\section{Proving system} + +\subsection{Execution table} + +\subsubsection{Reduced commitment via logup*} + +In Cairo each instruction is encoded with 15 boolean flags, and 3 offsets. In the execution trace, this leads to committing to 18 field elements at each instruction. + +We can significantly reduce the commitments cost using logup*\cite{logup_star}. In the the execution table, we only need to commit to the pc column, and all the flags / offsets describing the current instruction can be fetched by an indexed lookup argument (for which logup* drastically reduces commitment costs). + +\subsubsection{Commitment} + +\fbox{At each cycle, we commit to 8 (base) field elements:} + +\begin{itemize} + \item pc (program counter) + \item fp (frame pointer) + % \item jump (non zero when a jump occurs) + \item $\text{addr}_A$, $\text{addr}_B$, $\text{addr}_C$ + \item $\text{value}_A = \textbf{m}[\text{addr}_A]$, $\text{value}_B = \textbf{m}[\text{addr}_B]$, $\text{value}_C = \textbf{m}[\text{addr}_C]$ +\end{itemize} + + +\subsubsection{Instruction Encoding} + +Each instruction is described by 14 field elements: + +\begin{itemize} + \item 3 operands ($\in \Fp$): $\text{operand}_A$, $\text{operand}_B$, $\text{operand}_C$ + \item 3 associated flags ($\in \{0, 1\}$): $\text{flag}_A$, $\text{flag}_B$, $\text{flag}_C$ + \item 6 opcode flags ($\in \{0, 1\}$): ADD, MUL, DEREF, JUMP, IS\_PRECOMPILE, PRECOMPILE\_INDEX + \item 2 multi-purpose operands: AUX\_1, AUX\_2 +\end{itemize} + + +\subsubsection{AIR transition constraints} + +We use \fbox{transition constraints of degree 5}, but it's always possible to make them quadratic with additional columns in the execution table. + +\vspace{5mm} + +We define the following quantities: +\begin{itemize} + \item $\nu_A = \text{flag}_A \cdot \text{operand}_A + (1 - \text{flag}_A) \cdot \text{value}_A$ + \item $\nu_B = \text{flag}_B \cdot \text{operand}_B + (1 - \text{flag}_B) \cdot \text{value}_B$ + \item $\nu_C = \text{flag}_C \cdot \text{fp} + (1 - \text{flag}_C) \cdot \text{value}_C$ +\end{itemize} + +With the associated constraints: $\forall X \in \{A, B, C\}: (1 -\text{flag}_X) \cdot (\text{address}_X - (\text{fp} + \text{operand}_X)) = 0$ + +\vspace{3mm}\centerline{\rule{10cm}{0.4pt}}\vspace{5mm} + +For addition and multiplication: +\begin{itemize} + \item $\text{ADD} \cdot(\nu_B - (\nu_A + \nu_C)) = 0$ + \item $\text{MUL} \cdot(\nu_B - \nu_A \cdot \nu_C) = 0$ +\end{itemize} + +\vspace{3mm}\centerline{\rule{10cm}{0.4pt}}\vspace{5mm} + +When DEREF $= 1$, set $\text{flag}_A = 0$, $\text{flag}_C = 1$ and: +$$ +\textbf{m}[\textbf{m}[\text{fp} + \alpha] + \beta] = +\left\{ +\begin{array}{lcl} +\gamma & \xrightarrow{} & \text{AUX} = 1, \text{ flag}_B = 1 \\ +\textbf{m}[\text{fp} + \gamma] & \xrightarrow{} & \text{AUX} = 1, \text{ flag}_B = 0 \\ +\text{fp} & \xrightarrow{} & \text{AUX} = 0 \text{ (flag}_B = 1) +\end{array} +\right. +$$ + +\begin{itemize} + \item $\text{DEREF} \cdot (\text{addr}_C - (\text{value}_A + \text{operand}_C)) = 0$ + \item $\text{DEREF} \cdot \text{AUX} \cdot (\text{value}_C - \nu_B) = 0$ + \item $\text{DEREF} \cdot (1 - \text{AUX}) \cdot (\text{value}_C - \text{fp}) = 0$ +\end{itemize} + +\vspace{3mm}\centerline{\rule{10cm}{0.4pt}}\vspace{5mm} + +When there is no jump: +\begin{itemize} + \item $(1 - \text{JUMP}) \cdot ( \text{next(pc)} - (\text{pc} + 1)) = 0$ + \item $(1 - \text{JUMP}) \cdot (\text{next(fp)} - \text{fp}) = 0$ +\end{itemize} + +\vspace{3mm} + +When JUMP $= 1$, the condition is represented by $\nu_A$: +\begin{itemize} + \item $\text{JUMP} \cdot \nu_A \cdot (1 - \nu_A) = 0$ + \item $\text{JUMP} \cdot \nu_A \cdot ( \text{next(pc)} - \nu_C) = 0$ + \item $\text{JUMP} \cdot \nu_A \cdot ( \text{next(fp)} - \nu_B) = 0$ + \item $\text{JUMP} \cdot (1 - \nu_A) \cdot ( \text{next(pc)} - (\text{pc} + 1)) = 0$ + \item $\text{JUMP} \cdot (1 - \nu_A) \cdot (\text{next(fp)} - \text{fp}) = 0$ +\end{itemize} + +Note: the constraint $\text{JUMP} \cdot \nu_A \cdot (1 - \nu_A) = 0$ could be removed, as long as it's correctly enforced in the bytecode. + +\subsection{Data flow between tables / memory} + + +\begin{lemma} + Let $a_0, a_1, \ldots, a_{n-1}$ be pairwise distinct poles in $\Fq$, and let $m_0, m_1, \ldots, m_{n-1}$ be an associated list of multiplicities in $\{0, 1, \dots, p - 1\}$. Consider the rational function: +$$P(X) = \sum_{i=0}^{n-1} \frac{m_i}{X - a_i}$$ + +Except with probability $n/q$, if $P(\alpha) = 0$ for a random $\alpha \in \Fq$, then all multiplicities $m_i = 0$. + +\end{lemma} + +\subsubsection{Indexed Lookup into Memory} + +We use logup \cite{logup}, in its indexed form, to allow tables to perform lookups into the read-only memory. + + +\vspace{3mm} + +Let $\mathcal{T}$ denote the set of all tables in the system. For each table $T \in \mathcal{T}$ with $H_T$ rows, let $n_T$ denote the number of memory lookups. Each lookup $i < n_T$ consists of an \textbf{index column} $\text{col}_{\text{index},T,i}$ and a \textbf{value column} $\text{col}_{\text{val},T,i}$. + +\vspace{3mm} + +The rule to enforce is the following: +$$\forall T \in \mathcal{T}, \forall i < n_T, \forall j < H_T: \quad \text{col}_{\text{val},T,i}(j) = \textbf{m}[\text{col}_{\text{index},T,i}(j)]$$ + +Implicitly, we must also have $\text{col}_{\text{index},T,i}(j) < M$ (the memory size). + +\vspace{3mm} + +The prover initially commits to a multilinear polynomial $\textit{acc}$, having the same size as the memory, such that (in the honest case) for every $k < M$: +$$\textit{acc}[k] = \sum_{T \in \mathcal{T}} \sum_{i < n_T} \left| \{ j < H_T : \text{col}_{\text{index},T,i}(j) = k \} \right|$$ +i.e., $\textit{acc}[k]$ represents the total number of times address $k$ is accessed by the lookups across all tables. + +\vspace{3mm} + +The verifier sends a random challenge $\alpha \in \Fq$ (TODO quantify soundness error). Let $N = \sum_{T \in \mathcal{T}} n_T \cdot H_T$ be the total number of memory lookups. Assuming $N < p$ (to avoid overflow), the indexed lookup into memory is valid, except with probability $(N + M)/q$, if for a randomly sampled $X \in \Fq$: + +$$\sum_{T \in \mathcal{T}} \sum_{i>(); run_xmss_benchmark(&log_lifetimes, tracing); } - Cli::Recursion { tracing, count } => { - run_whir_recursion_benchmark(tracing, count); + Cli::Recursion { n, tracing } => { + run_recursion_benchmark(n, tracing); } Cli::Poseidon { log_n_perms: log_count, tracing, } => { - run_poseidon_benchmark::<16, 16, 3>(log_count, false, tracing); + benchmark_prove_poseidon_16(log_count, tracing); } } } diff --git a/src/prove_poseidons.rs b/src/prove_poseidons.rs new file mode 100644 index 00000000..87a6aafc --- /dev/null +++ b/src/prove_poseidons.rs @@ -0,0 +1,154 @@ +use air::{check_air_validity, prove_air, verify_air}; +use lean_vm::{ + EF, ExtraDataForBuses, F, POSEIDON_16_COL_A, POSEIDON_16_COL_B, POSEIDON_16_COL_COMPRESSION, POSEIDON_16_COL_FLAG, + POSEIDON_16_COL_INPUT_START, POSEIDON_16_COL_RES, POSEIDON_16_COL_RES_BIS, POSEIDON_16_DEFAULT_COMPRESSION, + POSEIDON_16_NULL_HASH_PTR, Poseidon16Precompile, ZERO_VEC_PTR, fill_trace_poseidon_16, num_cols_poseidon_16, +}; +use multilinear_toolkit::prelude::*; +use rand::{Rng, SeedableRng, rngs::StdRng}; +use utils::{ + build_prover_state, build_verifier_state, collect_refs, init_tracing, padd_with_zero_to_next_power_of_two, +}; +use whir_p3::{FoldingFactor, SecurityAssumption, SparseStatement, WhirConfig, WhirConfigBuilder}; + +const WIDTH: usize = 16; +const UNIVARIATE_SKIPS: usize = 3; + +#[test] +fn test_benchmark_air_poseidon_16() { + benchmark_prove_poseidon_16(11, false); +} + +#[allow(clippy::too_many_lines)] +pub fn benchmark_prove_poseidon_16(log_n_rows: usize, tracing: bool) { + if tracing { + init_tracing(); + } + let n_rows = 1 << log_n_rows; + let mut rng = StdRng::seed_from_u64(0); + let mut trace = vec![vec![F::ZERO; n_rows]; num_cols_poseidon_16()]; + for t in trace.iter_mut().skip(POSEIDON_16_COL_INPUT_START).take(WIDTH) { + *t = (0..n_rows).map(|_| rng.random()).collect(); + } + trace[POSEIDON_16_COL_FLAG] = (0..n_rows).map(|_| F::ONE).collect(); + trace[POSEIDON_16_COL_RES] = (0..n_rows).map(|_| F::from_usize(POSEIDON_16_NULL_HASH_PTR)).collect(); + trace[POSEIDON_16_COL_RES_BIS] = (0..n_rows).map(|_| F::from_usize(ZERO_VEC_PTR)).collect(); + trace[POSEIDON_16_COL_COMPRESSION] = (0..n_rows) + .map(|_| F::from_bool(POSEIDON_16_DEFAULT_COMPRESSION)) + .collect(); + trace[POSEIDON_16_COL_A] = (0..n_rows).map(|_| F::from_usize(ZERO_VEC_PTR)).collect(); + trace[POSEIDON_16_COL_B] = (0..n_rows).map(|_| F::from_usize(ZERO_VEC_PTR)).collect(); + fill_trace_poseidon_16(&mut trace); + + let whir_config = WhirConfigBuilder { + folding_factor: FoldingFactor::new(7, 4), + soundness_type: SecurityAssumption::JohnsonBound, + pow_bits: 16, + max_num_variables_to_send_coeffs: 6, + rs_domain_initial_reduction_factor: 5, + security_level: 123, + starting_log_inv_rate: 1, + }; + + let air = Poseidon16Precompile::; + + check_air_validity( + &air, + &ExtraDataForBuses::default(), + &collect_refs(&trace), + &[] as &[&[EF]], + &[], + &[], + ) + .unwrap(); + + let mut prover_state = build_prover_state(); + + let packed_n_vars = log2_ceil_usize(num_cols_poseidon_16() << log_n_rows); + let whir_config = WhirConfig::new(&whir_config, packed_n_vars); + + let time = std::time::Instant::now(); + + { + let mut commitmed_pol = F::zero_vec((num_cols_poseidon_16() << log_n_rows).next_power_of_two()); + for (i, col) in trace.iter().enumerate() { + commitmed_pol[i << log_n_rows..(i + 1) << log_n_rows].copy_from_slice(col); + } + let committed_pol = MleOwned::Base(commitmed_pol); + let witness = whir_config.commit(&mut prover_state, &committed_pol); + + let alpha = prover_state.sample(); + prover_state.duplexing(); + let air_alpha_powers: Vec = alpha.powers().collect_n(air.n_constraints() + 1); + let extra_data = ExtraDataForBuses { + alpha_powers: air_alpha_powers, + ..Default::default() + }; + + let air_claims = prove_air::( + &mut prover_state, + &air, + extra_data, + UNIVARIATE_SKIPS, + &collect_refs(&trace), + &[] as &[&[EF]], + None, + true, + ); + assert!(air_claims.down_point.is_none()); + assert_eq!(air_claims.evals_f.len(), air.n_columns_air()); + + let betas = prover_state.sample_vec(log2_ceil_usize(num_cols_poseidon_16())); + prover_state.duplexing(); + let packed_point = MultilinearPoint([betas.clone(), air_claims.point.0].concat()); + let packed_eval = padd_with_zero_to_next_power_of_two(&air_claims.evals_f).evaluate(&MultilinearPoint(betas)); + + whir_config.prove( + &mut prover_state, + vec![SparseStatement::dense(packed_point, packed_eval)], + witness, + &committed_pol.by_ref(), + ); + } + + println!( + "{} Poseidons / s", + (n_rows as f64 / time.elapsed().as_secs_f64()) as usize + ); + + { + let mut verifier_state = build_verifier_state(prover_state); + + let parsed_commitment = whir_config.parse_commitment::(&mut verifier_state).unwrap(); + + let alpha = verifier_state.sample(); + verifier_state.duplexing(); + let air_alpha_powers: Vec = alpha.powers().collect_n(air.n_constraints() + 1); + let extra_data = ExtraDataForBuses { + alpha_powers: air_alpha_powers, + ..Default::default() + }; + let air_claims = verify_air( + &mut verifier_state, + &air, + extra_data, + UNIVARIATE_SKIPS, + log2_ceil_usize(n_rows), + None, + ) + .unwrap(); + + let betas = verifier_state.sample_vec(log2_ceil_usize(num_cols_poseidon_16())); + verifier_state.duplexing(); + let packed_point = MultilinearPoint([betas.clone(), air_claims.point.0].concat()); + let packed_eval = padd_with_zero_to_next_power_of_two(&air_claims.evals_f).evaluate(&MultilinearPoint(betas)); + + whir_config + .verify( + &mut verifier_state, + &parsed_commitment, + vec![SparseStatement::dense(packed_point, packed_eval)], + ) + .unwrap(); + } +}