From 43a421d7185948648e49b62a20daf59aa5168e66 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Wed, 13 Aug 2025 05:19:56 -0700 Subject: [PATCH 01/12] Extend `Poseidon2` Chip for `MULTI_OBSERVE` (#5) --- crates/vm/src/metrics/cycle_tracker/mod.rs | 47 ++- crates/vm/src/metrics/mod.rs | 2 +- .../native/circuit/src/poseidon2/air.rs | 290 +++++++++++++++++- .../native/circuit/src/poseidon2/chip.rs | 6 +- .../native/circuit/src/poseidon2/columns.rs | 50 ++- .../native/circuit/src/poseidon2/tests.rs | 4 +- extensions/native/circuit/tests/ext.rs | 7 +- .../native/compiler/src/asm/compiler.rs | 15 + .../native/compiler/src/asm/instruction.rs | 8 + .../src/constraints/halo2/compiler.rs | 4 +- .../native/compiler/src/conversion/mod.rs | 12 + extensions/native/compiler/src/ir/builder.rs | 4 + .../native/compiler/src/ir/instructions.rs | 10 + extensions/native/compiler/src/ir/poseidon.rs | 43 +++ extensions/native/compiler/src/lib.rs | 1 + .../native/recursion/src/challenger/duplex.rs | 4 +- .../native/recursion/tests/recursion.rs | 154 ++++++++++ 17 files changed, 633 insertions(+), 28 deletions(-) diff --git a/crates/vm/src/metrics/cycle_tracker/mod.rs b/crates/vm/src/metrics/cycle_tracker/mod.rs index 3d989bc44b..451eb59ba3 100644 --- a/crates/vm/src/metrics/cycle_tracker/mod.rs +++ b/crates/vm/src/metrics/cycle_tracker/mod.rs @@ -1,7 +1,18 @@ +/// Stats for a nested span in the execution segment that is tracked by the [`CycleTracker`]. +#[derive(Clone, Debug, Default)] +pub struct SpanInfo { + /// The name of the span. + pub tag: String, + /// The cycle count at which the span starts. + pub start: usize, +} + #[derive(Clone, Debug, Default)] pub struct CycleTracker { /// Stack of span names, with most recent at the end - stack: Vec, + stack: Vec, + /// Depth of the stack. + depth: usize, } impl CycleTracker { @@ -10,29 +21,41 @@ impl CycleTracker { } pub fn top(&self) -> Option<&String> { - self.stack.last() + match self.stack.last() { + Some(span) => Some(&span.tag), + _ => None + } } /// Starts a new cycle tracker span for the given name. - /// If a span already exists for the given name, it ends the existing span and pushes a new one - /// to the vec. - pub fn start(&mut self, mut name: String) { + /// If a span already exists for the given name, it ends the existing span and pushes a new one to the vec. + pub fn start(&mut self, mut name: String, cycles_count: usize) { // hack to remove "CT-" prefix if name.starts_with("CT-") { name = name.split_off(3); } - self.stack.push(name); + self.stack.push(SpanInfo { + tag: name.clone(), + start: cycles_count, + }); + let padding = "│ ".repeat(self.depth); + tracing::info!("{}┌╴{}", padding, name); + self.depth += 1; } /// Ends the cycle tracker span for the given name. /// If no span exists for the given name, it panics. - pub fn end(&mut self, mut name: String) { + pub fn end(&mut self, mut name: String, cycles_count: usize) { // hack to remove "CT-" prefix if name.starts_with("CT-") { name = name.split_off(3); } - let stack_top = self.stack.pop(); - assert_eq!(stack_top.unwrap(), name, "Stack top does not match name"); + let SpanInfo { tag, start } = self.stack.pop().unwrap(); + assert_eq!(tag, name, "Stack top does not match name"); + self.depth -= 1; + let padding = "│ ".repeat(self.depth); + let span_cycles = cycles_count - start; + tracing::info!("{}└╴{} cycles", padding, span_cycles); } /// Ends the current cycle tracker span. @@ -42,7 +65,11 @@ impl CycleTracker { /// Get full name of span with all parent names separated by ";" in flamegraph format pub fn get_full_name(&self) -> String { - self.stack.join(";") + self.stack + .iter() + .map(|span_info| span_info.tag.clone()) + .collect::>() + .join(";") } } diff --git a/crates/vm/src/metrics/mod.rs b/crates/vm/src/metrics/mod.rs index 698ba1f1e8..7af679b3cd 100644 --- a/crates/vm/src/metrics/mod.rs +++ b/crates/vm/src/metrics/mod.rs @@ -224,7 +224,7 @@ impl VmMetrics { .map(|(_, func)| (*func).clone()) .unwrap(); if pc == self.current_fn.start { - self.cycle_tracker.start(self.current_fn.name.clone()); + self.cycle_tracker.start(self.current_fn.name.clone(), 0); } else { while let Some(name) = self.cycle_tracker.top() { if name == &self.current_fn.name { diff --git a/extensions/native/circuit/src/poseidon2/air.rs b/extensions/native/circuit/src/poseidon2/air.rs index adf2c09a62..373995bce9 100644 --- a/extensions/native/circuit/src/poseidon2/air.rs +++ b/extensions/native/circuit/src/poseidon2/air.rs @@ -2,13 +2,13 @@ use std::{array::from_fn, borrow::Borrow, sync::Arc}; use openvm_circuit::{ arch::{ExecutionBridge, ExecutionState}, - system::memory::{offline_checker::MemoryBridge, MemoryAddress}, + system::memory::{offline_checker::MemoryBridge, MemoryAddress, CHUNK}, }; use openvm_circuit_primitives::utils::not; use openvm_instructions::LocalOpcode; use openvm_native_compiler::{ conversion::AS, - Poseidon2Opcode::{COMP_POS2, PERM_POS2}, + Poseidon2Opcode::{COMP_POS2, PERM_POS2, MULTI_OBSERVE}, VerifyBatchOpcode::VERIFY_BATCH, }; use openvm_poseidon2_air::{ @@ -27,9 +27,8 @@ use crate::poseidon2::{ chip::{NUM_INITIAL_READS, NUM_SIMPLE_ACCESSES}, columns::{ InsideRowSpecificCols, NativePoseidon2Cols, SimplePoseidonSpecificCols, - TopLevelSpecificCols, + TopLevelSpecificCols, MultiObserveCols, }, - CHUNK, }; #[derive(Clone, Debug)] @@ -90,6 +89,7 @@ impl Air incorporate_sibling, inside_row, simple, + multi_observe_row, end_inside_row, end_top_level, start_top_level, @@ -117,7 +117,8 @@ impl Air builder.assert_bool(incorporate_sibling); builder.assert_bool(inside_row); builder.assert_bool(simple); - let enabled = incorporate_row + incorporate_sibling + inside_row + simple; + builder.assert_bool(multi_observe_row); + let enabled = incorporate_row + incorporate_sibling + inside_row + simple + multi_observe_row; builder.assert_bool(enabled.clone()); builder.assert_bool(end_inside_row); builder.when(end_inside_row).assert_one(inside_row); @@ -698,6 +699,285 @@ impl Air &write_data_2, ) .eval(builder, simple * is_permute); + + //// multi_observe contraints + let multi_observe_specific: &MultiObserveCols = + specific[..MultiObserveCols::::width()].borrow(); + let next_multi_observe_specific: &MultiObserveCols = + next.specific[..MultiObserveCols::::width()].borrow(); + let &MultiObserveCols { + pc, + final_timestamp_increment, + state_ptr, + input_ptr, + init_pos, + len, + is_first, + is_last, + curr_len, + start_idx, + end_idx, + aux_after_start, + aux_before_end, + read_data, + write_data, + data, + should_permute, + read_sponge_state, + write_sponge_state, + write_final_idx, + final_idx, + input_register_1, + input_register_2, + input_register_3, + output_register + } = multi_observe_specific; + + builder + .when(multi_observe_row) + .assert_bool(is_first); + builder + .when(multi_observe_row) + .assert_bool(is_last); + builder + .when(multi_observe_row) + .assert_bool(should_permute); + + self.execution_bridge + .execute_and_increment_pc( + AB::F::from_canonical_usize(MULTI_OBSERVE.global_opcode().as_usize()), + [ + output_register.into(), + input_register_1.into(), + input_register_2.into(), + self.address_space.into(), + self.address_space.into(), + input_register_3.into(), + ], + ExecutionState::new(pc, very_first_timestamp), + final_timestamp_increment, + ) + .eval(builder, multi_observe_row * is_first); + + self.memory_bridge + .read( + MemoryAddress::new(self.address_space, output_register), + [state_ptr], + very_first_timestamp, + &read_data[0], + ) + .eval(builder, multi_observe_row * is_first); + + self.memory_bridge + .read( + MemoryAddress::new(self.address_space, input_register_1), + [init_pos], + very_first_timestamp + AB::F::ONE, + &read_data[1], + ) + .eval(builder, multi_observe_row * is_first); + + self.memory_bridge + .read( + MemoryAddress::new(self.address_space, input_register_2), + [input_ptr], + very_first_timestamp + AB::F::TWO, + &read_data[2], + ) + .eval(builder, multi_observe_row * is_first); + + self.memory_bridge + .read( + MemoryAddress::new(self.address_space, input_register_3), + [len], + very_first_timestamp + AB::F::from_canonical_usize(3), + &read_data[3], + ) + .eval(builder, multi_observe_row * is_first); + + for i in 0..CHUNK { + let i_var = AB::F::from_canonical_usize(i); + self.memory_bridge + .read( + MemoryAddress::new(self.address_space, input_ptr + curr_len + i_var - start_idx), + [data[i]], + start_timestamp + i_var * AB::F::TWO - start_idx * AB::F::TWO, + &read_data[i] + ) + .eval(builder, aux_after_start[i] * aux_before_end[i]); + + self.memory_bridge + .write( + MemoryAddress::new( + self.address_space, + state_ptr + i_var, + ), + [data[i]], + start_timestamp + i_var * AB::F::TWO - start_idx * AB::F::TWO + AB::F::ONE, + &write_data[i], + ) + .eval(builder, aux_after_start[i] * aux_before_end[i]); + } + + for i in 0..(CHUNK - 1) { + builder + .when(aux_after_start[i]) + .assert_one(aux_after_start[i + 1]); + } + + for i in 1..CHUNK { + builder + .when(aux_before_end[i]) + .assert_one(aux_before_end[i - 1]); + } + + builder + .when(multi_observe_row) + .when(not(is_first)) + .assert_eq( + aux_after_start[0] + + aux_after_start[1] + + aux_after_start[2] + + aux_after_start[3] + + aux_after_start[4] + + aux_after_start[5] + + aux_after_start[6] + + aux_after_start[7], + AB::Expr::from_canonical_usize(CHUNK) - start_idx.into() + ); + + builder + .when(multi_observe_row) + .when(not(is_first)) + .assert_eq( + aux_before_end[0] + + aux_before_end[1] + + aux_before_end[2] + + aux_before_end[3] + + aux_before_end[4] + + aux_before_end[5] + + aux_before_end[6] + + aux_before_end[7], + end_idx + ); + + let full_sponge_input = from_fn::<_, {CHUNK * 2}, _>(|i| local.inner.inputs[i]); + let full_sponge_output = from_fn::<_, {CHUNK * 2}, _>(|i| local.inner.ending_full_rounds[BABY_BEAR_POSEIDON2_HALF_FULL_ROUNDS - 1].post[i]); + + self.memory_bridge + .read( + MemoryAddress::new( + self.address_space, + state_ptr, + ), + full_sponge_input, + start_timestamp + end_idx * AB::F::TWO - start_idx * AB::F::TWO, + &read_sponge_state, + ) + .eval(builder, multi_observe_row * should_permute); + + self.memory_bridge + .write( + MemoryAddress::new( + self.address_space, + state_ptr + ), + full_sponge_output, + start_timestamp + end_idx * AB::F::TWO - start_idx * AB::F::TWO + AB::F::ONE, + &write_sponge_state, + ) + .eval(builder, multi_observe_row * should_permute); + + self.memory_bridge + .write( + MemoryAddress::new( + self.address_space, + input_register_1, + ), + [final_idx], + start_timestamp + is_first * AB::F::from_canonical_usize(4) + (end_idx - start_idx) * AB::F::TWO + should_permute * AB::F::TWO, + &write_final_idx + ) + .eval(builder, multi_observe_row * is_last); + + // Field transitions + builder + .when(next.multi_observe_row) + .when(not(next_multi_observe_specific.is_first)) + .assert_eq(next_multi_observe_specific.curr_len, multi_observe_specific.curr_len + end_idx - start_idx); + + // Boundary conditions + builder + .when(multi_observe_row) + .when(is_first) + .assert_zero(curr_len); + + builder + .when(multi_observe_row) + .when(is_last) + .assert_eq(curr_len + (end_idx - start_idx), len); + + builder + .when(next.multi_observe_row) + .when(not(next_multi_observe_specific.is_first)) + .assert_one(multi_observe_row); + + builder + .when(multi_observe_row) + .when(not(is_last)) + .assert_one(next.multi_observe_row); + + // Field consistency + builder + .when(next.multi_observe_row) + .when(not(next_multi_observe_specific.is_first)) + .assert_eq(state_ptr, next_multi_observe_specific.state_ptr); + + builder + .when(next.multi_observe_row) + .when(not(next_multi_observe_specific.is_first)) + .assert_eq(input_ptr, next_multi_observe_specific.input_ptr); + + builder + .when(next.multi_observe_row) + .when(not(next_multi_observe_specific.is_first)) + .assert_eq(init_pos, next_multi_observe_specific.init_pos); + + builder + .when(next.multi_observe_row) + .when(not(next_multi_observe_specific.is_first)) + .assert_eq(len, next_multi_observe_specific.len); + + builder + .when(next.multi_observe_row) + .when(not(next_multi_observe_specific.is_first)) + .assert_eq(input_register_1, next_multi_observe_specific.input_register_1); + + builder + .when(next.multi_observe_row) + .when(not(next_multi_observe_specific.is_first)) + .assert_eq(input_register_2, next_multi_observe_specific.input_register_2); + + builder + .when(next.multi_observe_row) + .when(not(next_multi_observe_specific.is_first)) + .assert_eq(input_register_3, next_multi_observe_specific.input_register_3); + + builder + .when(next.multi_observe_row) + .when(not(next_multi_observe_specific.is_first)) + .assert_eq(output_register, next_multi_observe_specific.output_register); + + // Timestamp constraints + builder + .when(next.multi_observe_row) + .when(not(next_multi_observe_specific.is_first)) + .assert_eq(very_first_timestamp, next.very_first_timestamp); + + builder + .when(next.multi_observe_row) + .when(not(next_multi_observe_specific.is_first)) + .assert_eq(next.start_timestamp, start_timestamp + is_first * AB::F::from_canonical_usize(4) + (end_idx - start_idx) * AB::F::TWO + should_permute * AB::F::TWO); } } diff --git a/extensions/native/circuit/src/poseidon2/chip.rs b/extensions/native/circuit/src/poseidon2/chip.rs index 331cb9dbd0..01a77d0c65 100644 --- a/extensions/native/circuit/src/poseidon2/chip.rs +++ b/extensions/native/circuit/src/poseidon2/chip.rs @@ -12,7 +12,7 @@ use openvm_circuit::{ use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode}; use openvm_native_compiler::{ conversion::AS, - Poseidon2Opcode::{COMP_POS2, PERM_POS2}, + Poseidon2Opcode::{COMP_POS2, PERM_POS2, MULTI_OBSERVE}, VerifyBatchOpcode::VERIFY_BATCH, }; use openvm_poseidon2_air::{Poseidon2Config, Poseidon2SubChip, Poseidon2SubCols}; @@ -659,7 +659,9 @@ where String::from("PERM_POS2") } else if opcode == COMP_POS2.global_opcode().as_usize() { String::from("COMP_POS2") - } else { + } else if opcode == MULTI_OBSERVE.global_opcode().as_usize() { + String::from("MULTI_OBSERVE") + }else { unreachable!("unsupported opcode: {}", opcode) } } diff --git a/extensions/native/circuit/src/poseidon2/columns.rs b/extensions/native/circuit/src/poseidon2/columns.rs index 6c47c23245..fe0fce881a 100644 --- a/extensions/native/circuit/src/poseidon2/columns.rs +++ b/extensions/native/circuit/src/poseidon2/columns.rs @@ -28,6 +28,8 @@ pub struct NativePoseidon2Cols { pub inside_row: T, /// Indicates that this row is a simple row. pub simple: T, + /// Indicates that this row is a multi_observe row. + pub multi_observe_row: T, /// Indicates the last row in an inside-row block. pub end_inside_row: T, @@ -60,15 +62,16 @@ pub struct NativePoseidon2Cols { /// indicates that cell `i + 1` inside a chunk is exhausted. pub is_exhausted: [T; CHUNK - 1], - pub specific: [T; max3( + pub specific: [T; max4( TopLevelSpecificCols::::width(), InsideRowSpecificCols::::width(), SimplePoseidonSpecificCols::::width(), + MultiObserveCols::::width(), )], } -const fn max3(a: usize, b: usize, c: usize) -> usize { - const_max(a, const_max(b, c)) +const fn max4(a: usize, b: usize, c: usize, d: usize) -> usize { + const_max(a, const_max(b, const_max(c, d))) } #[repr(C)] #[derive(AlignedBorrow)] @@ -200,3 +203,44 @@ pub struct SimplePoseidonSpecificCols { pub write_data_1: MemoryWriteAuxCols, pub write_data_2: MemoryWriteAuxCols, } + +#[repr(C)] +#[derive(AlignedBorrow, Copy, Clone)] +pub struct MultiObserveCols { + // Program states + pub pc: T, + pub final_timestamp_increment: T, + + // Initial reads from registers + pub state_ptr: T, + pub input_ptr: T, + pub init_pos: T, + pub len: T, + + pub is_first: T, + pub is_last: T, + pub curr_len: T, + pub start_idx: T, + pub end_idx: T, + pub aux_after_start: [T; CHUNK], + pub aux_before_end: [T; CHUNK], + + // Transcript observation + pub read_data: [MemoryReadAuxCols; CHUNK], + pub write_data: [MemoryWriteAuxCols; CHUNK], + pub data: [T; CHUNK], + + // Permutation + pub should_permute: T, + pub read_sponge_state: MemoryReadAuxCols, + pub write_sponge_state: MemoryWriteAuxCols, + + // Final write back and registers + pub write_final_idx: MemoryWriteAuxCols, + pub final_idx: T, + + pub input_register_1: T, + pub input_register_2: T, + pub input_register_3: T, + pub output_register: T, +} \ No newline at end of file diff --git a/extensions/native/circuit/src/poseidon2/tests.rs b/extensions/native/circuit/src/poseidon2/tests.rs index 197def47b8..bce4358451 100644 --- a/extensions/native/circuit/src/poseidon2/tests.rs +++ b/extensions/native/circuit/src/poseidon2/tests.rs @@ -467,7 +467,8 @@ fn tester_with_random_poseidon2_ops(num_ops: usize) -> VmChipTester { tester.write(e, lhs, data_left); tester.write(e, lhs + CHUNK, data_right); - } + }, + MULTI_OBSERVE => {} } tester.execute(&mut harness.executor, &mut harness.arena, &instruction); @@ -484,6 +485,7 @@ fn tester_with_random_poseidon2_ops(num_ops: usize) -> VmChipTester {} } } tester.build().load(harness).finalize() diff --git a/extensions/native/circuit/tests/ext.rs b/extensions/native/circuit/tests/ext.rs index 5da70cb53b..70584fd926 100644 --- a/extensions/native/circuit/tests/ext.rs +++ b/extensions/native/circuit/tests/ext.rs @@ -35,7 +35,7 @@ fn test_ext2felt() { } #[test] -fn test_ext_from_base_slice() { +fn test_ext_from_base_vec() { const D: usize = 4; type F = BabyBear; type EF = BinomialExtensionField; @@ -52,8 +52,9 @@ fn test_ext_from_base_slice() { let val = EF::from_base_slice(base_slice); let expected: Ext<_, _> = builder.constant(val); - let felts = base_slice.map(|e| builder.constant::>(e)); - let actual = builder.ext_from_base_slice(&felts); + let felts = base_slice.map(|e| builder.constant::>(e)).to_vec(); + let actual = builder.uninit(); + builder.ext_from_base_vec(actual, felts); builder.assert_ext_eq(actual, expected); builder.halt(); diff --git a/extensions/native/compiler/src/asm/compiler.rs b/extensions/native/compiler/src/asm/compiler.rs index a09c8a217e..0e0db9cff9 100644 --- a/extensions/native/compiler/src/asm/compiler.rs +++ b/extensions/native/compiler/src/asm/compiler.rs @@ -489,6 +489,12 @@ impl + TwoAdicField> AsmCo DslIr::HintBitsF(var, len) => { self.push(AsmInstruction::HintBits(var.fp(), len), debug_info); } + DslIr::Poseidon2MultiObserve(dst, init_pos, arr_ptr, len) => { + self.push( + AsmInstruction::Poseidon2MultiObserve(dst.fp(), init_pos.fp(), arr_ptr.fp(), len.get_var().fp()), + debug_info, + ); + }, DslIr::Poseidon2PermuteBabyBear(dst, src) => match (dst, src) { (Array::Dyn(dst, _), Array::Dyn(src, _)) => self.push( AsmInstruction::Poseidon2Permute(dst.fp(), src.fp()), @@ -617,6 +623,15 @@ impl + TwoAdicField> AsmCo debug_info, ); } + DslIr::ExtFromBaseVec(ext, base_vec) => { + assert_eq!(base_vec.len(), EF::D); + for (i, base) in base_vec.into_iter().enumerate() { + self.push( + AsmInstruction::CopyF(ext.fp() + (i as i32), base.fp()), + debug_info.clone(), + ); + } + } _ => unimplemented!(), } } diff --git a/extensions/native/compiler/src/asm/instruction.rs b/extensions/native/compiler/src/asm/instruction.rs index 1aa5ea8527..cd4990b08b 100644 --- a/extensions/native/compiler/src/asm/instruction.rs +++ b/extensions/native/compiler/src/asm/instruction.rs @@ -110,6 +110,11 @@ pub enum AsmInstruction { /// Halt. Halt, + /// Absorbs multiple base elements into a duplex transcript with Poseidon2 permutation + /// (sponge_state, init_pos, arr_ptr, len) + /// Returns the final index position of hash sponge + Poseidon2MultiObserve(i32, i32, i32, i32), + /// Perform a Poseidon2 permutation on state starting at address `lhs` /// and store new state at `rhs`. /// (a, b) are pointers to (lhs, rhs). @@ -334,6 +339,9 @@ impl> AsmInstruction { AsmInstruction::Trap => write!(f, "trap"), AsmInstruction::Halt => write!(f, "halt"), AsmInstruction::HintBits(src, len) => write!(f, "hint_bits ({})fp, {}", src, len), + AsmInstruction::Poseidon2MultiObserve(dst, init_pos, arr, len) => { + write!(f, "poseidon2_multi_observe ({})fp, ({})fp ({})fp ({})fp", dst, init_pos, arr, len) + } AsmInstruction::Poseidon2Permute(dst, lhs) => { write!(f, "poseidon2_permute ({})fp, ({})fp", dst, lhs) } diff --git a/extensions/native/compiler/src/constraints/halo2/compiler.rs b/extensions/native/compiler/src/constraints/halo2/compiler.rs index fd75d526d2..3bc7efcf04 100644 --- a/extensions/native/compiler/src/constraints/halo2/compiler.rs +++ b/extensions/native/compiler/src/constraints/halo2/compiler.rs @@ -493,11 +493,11 @@ impl Halo2ConstraintCompiler { } DslIr::CycleTrackerStart(_name) => { #[cfg(feature = "metrics")] - cell_tracker.start(_name); + cell_tracker.start(_name, 0); } DslIr::CycleTrackerEnd(_name) => { #[cfg(feature = "metrics")] - cell_tracker.end(_name); + cell_tracker.end(_name, 0); } DslIr::CircuitPublish(val, index) => { public_values[index] = vars[&val.0]; diff --git a/extensions/native/compiler/src/conversion/mod.rs b/extensions/native/compiler/src/conversion/mod.rs index af4e5080fb..82ee912703 100644 --- a/extensions/native/compiler/src/conversion/mod.rs +++ b/extensions/native/compiler/src/conversion/mod.rs @@ -441,6 +441,18 @@ fn convert_instruction>( AS::Native, AS::Native, )], + AsmInstruction::Poseidon2MultiObserve(dst, init, arr, len) => vec![ + Instruction { + opcode: options.opcode_with_offset(Poseidon2Opcode::MULTI_OBSERVE), + a: i32_f(dst), + b: i32_f(init), + c: i32_f(arr), + d: AS::Native.to_field(), + e: AS::Native.to_field(), + f: i32_f(len), + g: F::ZERO, + } + ], AsmInstruction::Poseidon2Compress(dst, src1, src2) => vec![inst( options.opcode_with_offset(Poseidon2Opcode::COMP_POS2), i32_f(dst), diff --git a/extensions/native/compiler/src/ir/builder.rs b/extensions/native/compiler/src/ir/builder.rs index 966c0db21c..64c62e64fa 100644 --- a/extensions/native/compiler/src/ir/builder.rs +++ b/extensions/native/compiler/src/ir/builder.rs @@ -620,6 +620,10 @@ impl Builder { self.witness_space.get(id.value()).unwrap() } + pub fn ext_from_base_vec(&mut self, ext: Ext, base_vec: Vec>) { + self.push(DslIr::ExtFromBaseVec(ext, base_vec)); + } + /// Throws an error. pub fn error(&mut self) { self.operations.trace_push(DslIr::Error()); diff --git a/extensions/native/compiler/src/ir/instructions.rs b/extensions/native/compiler/src/ir/instructions.rs index fa4d8b9931..13f5c4a653 100644 --- a/extensions/native/compiler/src/ir/instructions.rs +++ b/extensions/native/compiler/src/ir/instructions.rs @@ -208,6 +208,13 @@ pub enum DslIr { /// Permutes an array of Bn254 elements using Poseidon2 (output = p2_permute(array)). Should /// only be used when target is a circuit. CircuitPoseidon2Permute([Var; 3]), + /// Absorbs an array of baby bear elements into a duplex transcript with Poseidon2 permutations (output = p2_multi_observe(array, els)). + Poseidon2MultiObserve( + Ptr, // sponge_state + Var, // initial input_ptr position + Ptr, // input array (base elements) + Usize, // len of els + ), // Miscellaneous instructions. /// Prints a variable. @@ -241,6 +248,9 @@ pub enum DslIr { /// Operation to halt the program. Should be the last instruction in the program. Halt, + /// Packs a vector of felts into an ext. + ExtFromBaseVec(Ext, Vec>), + // Public inputs for circuits. /// Publish a field element as the ith public value. Should only be used when target is a /// circuit. diff --git a/extensions/native/compiler/src/ir/poseidon.rs b/extensions/native/compiler/src/ir/poseidon.rs index 12ec526c89..ee9bc0d87d 100644 --- a/extensions/native/compiler/src/ir/poseidon.rs +++ b/extensions/native/compiler/src/ir/poseidon.rs @@ -1,6 +1,8 @@ use openvm_native_compiler_derive::iter_zip; use openvm_stark_backend::p3_field::FieldAlgebra; +use crate::ir::Variable; + use super::{Array, ArrayLike, Builder, Config, DslIr, Ext, Felt, MemIndex, Ptr, Usize, Var}; pub const DIGEST_SIZE: usize = 8; @@ -8,6 +10,47 @@ pub const HASH_RATE: usize = 8; pub const PERMUTATION_WIDTH: usize = 16; impl Builder { + /// Extends native VM ability to observe multiple base elements in one opcode operation + /// Absorbs elements sequentially at the RATE portion of sponge state and performs as many permutations as necessary. + /// Returns the index position of the next input_ptr. + /// + /// [Reference](https://docs.rs/p3-poseidon2/latest/p3_poseidon2/struct.Poseidon2.html) + pub fn poseidon2_multi_observe( + &mut self, + sponge_state: &Array>, + input_ptr: Ptr, + arr: &Array>, + ) -> Usize { + let buffer_size: Var = Var::uninit(self); + self.assign(&buffer_size, C::N::from_canonical_usize(HASH_RATE)); + + match sponge_state { + Array::Fixed(_) => { + panic!("Poseidon2 permutation is not allowed on fixed arrays"); + } + Array::Dyn(sponge_ptr, _) => { + match arr { + Array::Fixed(_) => { + panic!("Base elements input must be dynamic"); + } + Array::Dyn(ptr, len) => { + let init_pos: Var = Var::uninit(self); + self.assign(&init_pos, input_ptr.address - sponge_ptr.address); + + self.operations.push(DslIr::Poseidon2MultiObserve( + *sponge_ptr, + init_pos, + *ptr, + len.clone(), + )); + + Usize::Var(init_pos) + } + } + } + } + } + /// Applies the Poseidon2 permutation to the given array. /// /// [Reference](https://docs.rs/p3-poseidon2/latest/p3_poseidon2/struct.Poseidon2.html) diff --git a/extensions/native/compiler/src/lib.rs b/extensions/native/compiler/src/lib.rs index ef28b37139..66c786fbd9 100644 --- a/extensions/native/compiler/src/lib.rs +++ b/extensions/native/compiler/src/lib.rs @@ -184,6 +184,7 @@ pub enum NativePhantom { pub enum Poseidon2Opcode { PERM_POS2, COMP_POS2, + MULTI_OBSERVE, } /// Opcodes for FRI opening proofs. diff --git a/extensions/native/recursion/src/challenger/duplex.rs b/extensions/native/recursion/src/challenger/duplex.rs index 7c0cd4dd88..2d45d896be 100644 --- a/extensions/native/recursion/src/challenger/duplex.rs +++ b/extensions/native/recursion/src/challenger/duplex.rs @@ -101,7 +101,9 @@ impl DuplexChallengerVariable { let b = self.sample(builder); let c = self.sample(builder); let d = self.sample(builder); - builder.ext_from_base_slice(&[a, b, c, d]) + let ext = builder.uninit(); + builder.ext_from_base_vec(ext, vec![a, b, c, d]); + ext } fn sample_bits(&self, builder: &mut Builder, nb_bits: RVar) -> Array> diff --git a/extensions/native/recursion/tests/recursion.rs b/extensions/native/recursion/tests/recursion.rs index 6322614784..6bcb913f57 100644 --- a/extensions/native/recursion/tests/recursion.rs +++ b/extensions/native/recursion/tests/recursion.rs @@ -166,3 +166,157 @@ fn test_fibonacci_program_halo2_verify() { FriParameters::new_for_testing(LOG_BLOWUP), ); } + +#[test] +fn test_multi_observe() { + let mut builder = AsmBuilder::>::default(); + + build_test_program(&mut builder); + + // Fill in test program logic + builder.halt(); + + let compilation_options = CompilerOptions::default().with_cycle_tracker(); + let mut compiler = AsmCompiler::new(compilation_options.word_size); + compiler.build(builder.operations); + let asm_code = compiler.code(); + + // let program = Program::from_instructions(&instructions); + let program: Program<_> = convert_program(asm_code, compilation_options); + + let poseidon2_max_constraint_degree = 3; + + let fri_params = if matches!(std::env::var("OPENVM_FAST_TEST"), Ok(x) if &x == "1") { + FriParameters { + // max constraint degree = 2^log_blowup + 1 + log_blowup: 1, + log_final_poly_len: 0, + num_queries: 2, + proof_of_work_bits: 0, + } + } else { + standard_fri_params_with_100_bits_conjectured_security(1) + }; + + let engine = BabyBearPoseidon2Engine::new(fri_params); + let mut config = NativeConfig::aggregation(0, poseidon2_max_constraint_degree); + config.system.memory_config.max_access_adapter_n = 16; + + let vm = VirtualMachine::new(engine, config); + + let pk = vm.keygen(); + let result = vm.execute_and_generate(program, vec![]).unwrap(); + let proofs = vm.prove(&pk, result); + for proof in proofs { + verify_single(&vm.engine, &pk.get_vk(), &proof).expect("Verification failed"); + } +} + +fn build_test_program( + builder: &mut Builder, +) { + let sample_lens: Vec = vec![10, 2, 0, 3, 20]; + + let mut rng = create_seeded_rng(); + let challenger = DuplexChallengerVariable::new(builder); + + for l in sample_lens { + let sample_input: Array> = builder.dyn_array(l); + builder.range(0, l).for_each(|idx_vec, builder| { + let f_u32: u32 = rng.gen_range(1..1 << 30); + builder.set(&sample_input, idx_vec[0], C::F::from_canonical_u32(f_u32)); + }); + + let next_input_ptr = builder.poseidon2_multi_observe(&challenger.sponge_state, challenger.input_ptr, &sample_input); + + builder.assign( + &challenger.input_ptr, + challenger.io_empty_ptr + next_input_ptr.clone(), + ); + builder.if_ne(next_input_ptr, Usize::from(0)).then_or_else( + |builder| { + builder.assign(&challenger.output_ptr, challenger.io_empty_ptr); + }, + |builder| { + builder.assign(&challenger.output_ptr, challenger.io_full_ptr); + }, + ); + } +} + +#[test] +fn test_multi_observe() { + let mut builder = AsmBuilder::>::default(); + + build_test_program(&mut builder); + + // Fill in test program logic + builder.halt(); + + let compilation_options = CompilerOptions::default().with_cycle_tracker(); + let mut compiler = AsmCompiler::new(compilation_options.word_size); + compiler.build(builder.operations); + let asm_code = compiler.code(); + + // let program = Program::from_instructions(&instructions); + let program: Program<_> = convert_program(asm_code, compilation_options); + + let poseidon2_max_constraint_degree = 3; + + let fri_params = if matches!(std::env::var("OPENVM_FAST_TEST"), Ok(x) if &x == "1") { + FriParameters { + // max constraint degree = 2^log_blowup + 1 + log_blowup: 1, + log_final_poly_len: 0, + num_queries: 2, + proof_of_work_bits: 0, + } + } else { + standard_fri_params_with_100_bits_conjectured_security(1) + }; + + let engine = BabyBearPoseidon2Engine::new(fri_params); + let mut config = NativeConfig::aggregation(0, poseidon2_max_constraint_degree); + config.system.memory_config.max_access_adapter_n = 16; + + let vm = VirtualMachine::new(engine, config); + + let pk = vm.keygen(); + let result = vm.execute_and_generate(program, vec![]).unwrap(); + let proofs = vm.prove(&pk, result); + for proof in proofs { + verify_single(&vm.engine, &pk.get_vk(), &proof).expect("Verification failed"); + } +} + +fn build_test_program( + builder: &mut Builder, +) { + let sample_lens: Vec = vec![10, 2, 0, 3, 20]; + + let mut rng = create_seeded_rng(); + let challenger = DuplexChallengerVariable::new(builder); + + for l in sample_lens { + let sample_input: Array> = builder.dyn_array(l); + builder.range(0, l).for_each(|idx_vec, builder| { + let f_u32: u32 = rng.gen_range(1..1 << 30); + builder.set(&sample_input, idx_vec[0], C::F::from_canonical_u32(f_u32)); + }); + + let next_input_ptr = builder.poseidon2_multi_observe(&challenger.sponge_state, challenger.input_ptr, &sample_input); + + builder.assign( + &challenger.input_ptr, + challenger.io_empty_ptr + next_input_ptr.clone(), + ); + builder.if_ne(next_input_ptr, Usize::from(0)).then_or_else( + |builder| { + builder.assign(&challenger.output_ptr, challenger.io_empty_ptr); + }, + |builder| { + builder.assign(&challenger.output_ptr, challenger.io_full_ptr); + }, + ); + } +} From 5cf662556f7f40ebdb54ed1d700950c6b4f978af Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Wed, 19 Nov 2025 15:22:55 +0800 Subject: [PATCH 02/12] fix --- .../native/recursion/tests/recursion.rs | 127 ++++-------------- 1 file changed, 27 insertions(+), 100 deletions(-) diff --git a/extensions/native/recursion/tests/recursion.rs b/extensions/native/recursion/tests/recursion.rs index 6bcb913f57..e3a5fd0f3b 100644 --- a/extensions/native/recursion/tests/recursion.rs +++ b/extensions/native/recursion/tests/recursion.rs @@ -1,16 +1,22 @@ use itertools::Itertools; use openvm_circuit::{ arch::{ - instructions::program::Program, PreflightExecutionOutput, PreflightExecutor, VmBuilder, - VmCircuitConfig, VmExecutionConfig, + PreflightExecutionOutput, PreflightExecutor, VmBuilder, VmCircuitConfig, VmExecutionConfig, instructions::program::Program }, - utils::TestStarkEngine, + utils::{TestStarkEngine, air_test_impl}, }; use openvm_native_circuit::{ execute_program_with_config, test_native_config, NativeBuilder, NativeConfig, }; -use openvm_native_compiler::{asm::AsmBuilder, ir::Felt}; -use openvm_native_recursion::testing_utils::inner::run_recursive_test; +use openvm_native_compiler::{ + asm::{AsmBuilder, AsmCompiler}, + conversion::{convert_program, CompilerOptions}, + ir::{Array, Builder, Config, Felt}, + prelude::Usize, +}; +use openvm_native_recursion::{ + challenger::duplex::DuplexChallengerVariable, testing_utils::inner::run_recursive_test, +}; use openvm_stark_backend::{ config::{Domain, StarkGenericConfig, Val}, p3_commit::PolynomialSpace, @@ -24,12 +30,14 @@ use openvm_stark_backend::{ use openvm_stark_sdk::{ config::{ baby_bear_poseidon2::{BabyBearPoseidon2Config, BabyBearPoseidon2Engine}, + fri_params::standard_fri_params_with_100_bits_conjectured_security, FriParameters, }, engine::StarkFriEngine, p3_baby_bear::BabyBear, - utils::ProofInputForTest, + utils::{create_seeded_rng, ProofInputForTest}, }; +use rand::Rng; fn fibonacci_program(a: u32, b: u32, n: u32) -> Program { type F = BabyBear; @@ -169,13 +177,12 @@ fn test_fibonacci_program_halo2_verify() { #[test] fn test_multi_observe() { - let mut builder = AsmBuilder::>::default(); + type F = BabyBear; + type EF = BinomialExtensionField; + let mut builder = AsmBuilder::::default(); build_test_program(&mut builder); - // Fill in test program logic - builder.halt(); - let compilation_options = CompilerOptions::default().with_cycle_tracker(); let mut compiler = AsmCompiler::new(compilation_options.word_size); compiler.build(builder.operations); @@ -185,88 +192,11 @@ fn test_multi_observe() { let program: Program<_> = convert_program(asm_code, compilation_options); let poseidon2_max_constraint_degree = 3; - - let fri_params = if matches!(std::env::var("OPENVM_FAST_TEST"), Ok(x) if &x == "1") { - FriParameters { - // max constraint degree = 2^log_blowup + 1 - log_blowup: 1, - log_final_poly_len: 0, - num_queries: 2, - proof_of_work_bits: 0, - } - } else { - standard_fri_params_with_100_bits_conjectured_security(1) - }; - - let engine = BabyBearPoseidon2Engine::new(fri_params); - let mut config = NativeConfig::aggregation(0, poseidon2_max_constraint_degree); - config.system.memory_config.max_access_adapter_n = 16; - - let vm = VirtualMachine::new(engine, config); - - let pk = vm.keygen(); - let result = vm.execute_and_generate(program, vec![]).unwrap(); - let proofs = vm.prove(&pk, result); - for proof in proofs { - verify_single(&vm.engine, &pk.get_vk(), &proof).expect("Verification failed"); - } -} - -fn build_test_program( - builder: &mut Builder, -) { - let sample_lens: Vec = vec![10, 2, 0, 3, 20]; - - let mut rng = create_seeded_rng(); - let challenger = DuplexChallengerVariable::new(builder); - - for l in sample_lens { - let sample_input: Array> = builder.dyn_array(l); - builder.range(0, l).for_each(|idx_vec, builder| { - let f_u32: u32 = rng.gen_range(1..1 << 30); - builder.set(&sample_input, idx_vec[0], C::F::from_canonical_u32(f_u32)); - }); - let next_input_ptr = builder.poseidon2_multi_observe(&challenger.sponge_state, challenger.input_ptr, &sample_input); - - builder.assign( - &challenger.input_ptr, - challenger.io_empty_ptr + next_input_ptr.clone(), - ); - builder.if_ne(next_input_ptr, Usize::from(0)).then_or_else( - |builder| { - builder.assign(&challenger.output_ptr, challenger.io_empty_ptr); - }, - |builder| { - builder.assign(&challenger.output_ptr, challenger.io_full_ptr); - }, - ); - } -} - -#[test] -fn test_multi_observe() { - let mut builder = AsmBuilder::>::default(); - - build_test_program(&mut builder); - - // Fill in test program logic - builder.halt(); - - let compilation_options = CompilerOptions::default().with_cycle_tracker(); - let mut compiler = AsmCompiler::new(compilation_options.word_size); - compiler.build(builder.operations); - let asm_code = compiler.code(); - - // let program = Program::from_instructions(&instructions); - let program: Program<_> = convert_program(asm_code, compilation_options); - - let poseidon2_max_constraint_degree = 3; - let fri_params = if matches!(std::env::var("OPENVM_FAST_TEST"), Ok(x) if &x == "1") { FriParameters { // max constraint degree = 2^log_blowup + 1 - log_blowup: 1, + log_blowup: 1, log_final_poly_len: 0, num_queries: 2, proof_of_work_bits: 0, @@ -275,23 +205,15 @@ fn test_multi_observe() { standard_fri_params_with_100_bits_conjectured_security(1) }; - let engine = BabyBearPoseidon2Engine::new(fri_params); let mut config = NativeConfig::aggregation(0, poseidon2_max_constraint_degree); config.system.memory_config.max_access_adapter_n = 16; - let vm = VirtualMachine::new(engine, config); + let vb = NativeBuilder::default(); + air_test_impl::(fri_params, vb, config, program, vec![], 1, true).unwrap(); - let pk = vm.keygen(); - let result = vm.execute_and_generate(program, vec![]).unwrap(); - let proofs = vm.prove(&pk, result); - for proof in proofs { - verify_single(&vm.engine, &pk.get_vk(), &proof).expect("Verification failed"); - } } -fn build_test_program( - builder: &mut Builder, -) { +fn build_test_program(builder: &mut Builder) { let sample_lens: Vec = vec![10, 2, 0, 3, 20]; let mut rng = create_seeded_rng(); @@ -304,7 +226,11 @@ fn build_test_program( builder.set(&sample_input, idx_vec[0], C::F::from_canonical_u32(f_u32)); }); - let next_input_ptr = builder.poseidon2_multi_observe(&challenger.sponge_state, challenger.input_ptr, &sample_input); + let next_input_ptr = builder.poseidon2_multi_observe( + &challenger.sponge_state, + challenger.input_ptr, + &sample_input, + ); builder.assign( &challenger.input_ptr, @@ -319,4 +245,5 @@ fn build_test_program( }, ); } + builder.halt(); } From d994e784926565136bc94164bd338a76e0383f4d Mon Sep 17 00:00:00 2001 From: xkx Date: Mon, 24 Nov 2025 22:05:18 +0800 Subject: [PATCH 03/12] Feat: support `multi_observe` (#12) * fix1 * fix2 * execution wip * fix 3 * fix 4 --- .../native/circuit/src/extension/mod.rs | 1 + .../native/circuit/src/poseidon2/air.rs | 116 ++++---- .../native/circuit/src/poseidon2/chip.rs | 275 +++++++++++++++++- .../native/circuit/src/poseidon2/columns.rs | 12 +- .../native/circuit/src/poseidon2/execution.rs | 191 +++++++++++- .../native/recursion/tests/recursion.rs | 11 +- 6 files changed, 533 insertions(+), 73 deletions(-) diff --git a/extensions/native/circuit/src/extension/mod.rs b/extensions/native/circuit/src/extension/mod.rs index 98b2fe774d..a86cdb1bd2 100644 --- a/extensions/native/circuit/src/extension/mod.rs +++ b/extensions/native/circuit/src/extension/mod.rs @@ -165,6 +165,7 @@ impl VmExecutionExtension for Native { VerifyBatchOpcode::VERIFY_BATCH.global_opcode(), Poseidon2Opcode::PERM_POS2.global_opcode(), Poseidon2Opcode::COMP_POS2.global_opcode(), + Poseidon2Opcode::MULTI_OBSERVE.global_opcode(), ], )?; diff --git a/extensions/native/circuit/src/poseidon2/air.rs b/extensions/native/circuit/src/poseidon2/air.rs index 373995bce9..adf01695e4 100644 --- a/extensions/native/circuit/src/poseidon2/air.rs +++ b/extensions/native/circuit/src/poseidon2/air.rs @@ -8,7 +8,7 @@ use openvm_circuit_primitives::utils::not; use openvm_instructions::LocalOpcode; use openvm_native_compiler::{ conversion::AS, - Poseidon2Opcode::{COMP_POS2, PERM_POS2, MULTI_OBSERVE}, + Poseidon2Opcode::{COMP_POS2, MULTI_OBSERVE, PERM_POS2}, VerifyBatchOpcode::VERIFY_BATCH, }; use openvm_poseidon2_air::{ @@ -26,8 +26,8 @@ use openvm_stark_backend::{ use crate::poseidon2::{ chip::{NUM_INITIAL_READS, NUM_SIMPLE_ACCESSES}, columns::{ - InsideRowSpecificCols, NativePoseidon2Cols, SimplePoseidonSpecificCols, - TopLevelSpecificCols, MultiObserveCols, + InsideRowSpecificCols, MultiObserveCols, NativePoseidon2Cols, SimplePoseidonSpecificCols, + TopLevelSpecificCols, }, }; @@ -118,7 +118,8 @@ impl Air builder.assert_bool(inside_row); builder.assert_bool(simple); builder.assert_bool(multi_observe_row); - let enabled = incorporate_row + incorporate_sibling + inside_row + simple + multi_observe_row; + let enabled = + incorporate_row + incorporate_sibling + inside_row + simple + multi_observe_row; builder.assert_bool(enabled.clone()); builder.assert_bool(end_inside_row); builder.when(end_inside_row).assert_one(inside_row); @@ -730,18 +731,12 @@ impl Air input_register_1, input_register_2, input_register_3, - output_register + output_register, } = multi_observe_specific; - builder - .when(multi_observe_row) - .assert_bool(is_first); - builder - .when(multi_observe_row) - .assert_bool(is_last); - builder - .when(multi_observe_row) - .assert_bool(should_permute); + builder.when(multi_observe_row).assert_bool(is_first); + builder.when(multi_observe_row).assert_bool(is_last); + builder.when(multi_observe_row).assert_bool(should_permute); self.execution_bridge .execute_and_increment_pc( @@ -799,19 +794,19 @@ impl Air let i_var = AB::F::from_canonical_usize(i); self.memory_bridge .read( - MemoryAddress::new(self.address_space, input_ptr + curr_len + i_var - start_idx), + MemoryAddress::new( + self.address_space, + input_ptr + curr_len + i_var - start_idx, + ), [data[i]], start_timestamp + i_var * AB::F::TWO - start_idx * AB::F::TWO, - &read_data[i] + &read_data[i], ) .eval(builder, aux_after_start[i] * aux_before_end[i]); - + self.memory_bridge .write( - MemoryAddress::new( - self.address_space, - state_ptr + i_var, - ), + MemoryAddress::new(self.address_space, state_ptr + i_var), [data[i]], start_timestamp + i_var * AB::F::TWO - start_idx * AB::F::TWO + AB::F::ONE, &write_data[i], @@ -835,15 +830,15 @@ impl Air .when(multi_observe_row) .when(not(is_first)) .assert_eq( - aux_after_start[0] - + aux_after_start[1] - + aux_after_start[2] - + aux_after_start[3] - + aux_after_start[4] - + aux_after_start[5] - + aux_after_start[6] - + aux_after_start[7], - AB::Expr::from_canonical_usize(CHUNK) - start_idx.into() + aux_after_start[0] + + aux_after_start[1] + + aux_after_start[2] + + aux_after_start[3] + + aux_after_start[4] + + aux_after_start[5] + + aux_after_start[6] + + aux_after_start[7], + AB::Expr::from_canonical_usize(CHUNK) - start_idx.into(), ); builder @@ -851,43 +846,40 @@ impl Air .when(not(is_first)) .assert_eq( aux_before_end[0] - + aux_before_end[1] - + aux_before_end[2] - + aux_before_end[3] - + aux_before_end[4] - + aux_before_end[5] - + aux_before_end[6] - + aux_before_end[7], - end_idx + + aux_before_end[1] + + aux_before_end[2] + + aux_before_end[3] + + aux_before_end[4] + + aux_before_end[5] + + aux_before_end[6] + + aux_before_end[7], + end_idx, ); - let full_sponge_input = from_fn::<_, {CHUNK * 2}, _>(|i| local.inner.inputs[i]); - let full_sponge_output = from_fn::<_, {CHUNK * 2}, _>(|i| local.inner.ending_full_rounds[BABY_BEAR_POSEIDON2_HALF_FULL_ROUNDS - 1].post[i]); + let full_sponge_input = from_fn::<_, { CHUNK * 2 }, _>(|i| local.inner.inputs[i]); + let full_sponge_output = from_fn::<_, { CHUNK * 2 }, _>(|i| { + local.inner.ending_full_rounds[BABY_BEAR_POSEIDON2_HALF_FULL_ROUNDS - 1].post[i] + }); self.memory_bridge .read( - MemoryAddress::new( - self.address_space, - state_ptr, - ), + MemoryAddress::new(self.address_space, state_ptr), full_sponge_input, - start_timestamp + end_idx * AB::F::TWO - start_idx * AB::F::TWO, + start_timestamp + (end_idx - start_idx) * AB::F::TWO, &read_sponge_state, ) .eval(builder, multi_observe_row * should_permute); - + self.memory_bridge .write( - MemoryAddress::new( - self.address_space, - state_ptr - ), + MemoryAddress::new(self.address_space, state_ptr), full_sponge_output, - start_timestamp + end_idx * AB::F::TWO - start_idx * AB::F::TWO + AB::F::ONE, + start_timestamp + (end_idx - start_idx) * AB::F::TWO + AB::F::ONE, &write_sponge_state, ) .eval(builder, multi_observe_row * should_permute); + /* self.memory_bridge .write( MemoryAddress::new( @@ -899,13 +891,14 @@ impl Air &write_final_idx ) .eval(builder, multi_observe_row * is_last); + */ // Field transitions builder .when(next.multi_observe_row) .when(not(next_multi_observe_specific.is_first)) .assert_eq(next_multi_observe_specific.curr_len, multi_observe_specific.curr_len + end_idx - start_idx); - + // Boundary conditions builder .when(multi_observe_row) @@ -927,7 +920,7 @@ impl Air .when(not(is_last)) .assert_one(next.multi_observe_row); - // Field consistency + // Fields remain same across same instance builder .when(next.multi_observe_row) .when(not(next_multi_observe_specific.is_first)) @@ -951,17 +944,26 @@ impl Air builder .when(next.multi_observe_row) .when(not(next_multi_observe_specific.is_first)) - .assert_eq(input_register_1, next_multi_observe_specific.input_register_1); + .assert_eq( + input_register_1, + next_multi_observe_specific.input_register_1, + ); builder .when(next.multi_observe_row) .when(not(next_multi_observe_specific.is_first)) - .assert_eq(input_register_2, next_multi_observe_specific.input_register_2); + .assert_eq( + input_register_2, + next_multi_observe_specific.input_register_2, + ); builder .when(next.multi_observe_row) .when(not(next_multi_observe_specific.is_first)) - .assert_eq(input_register_3, next_multi_observe_specific.input_register_3); + .assert_eq( + input_register_3, + next_multi_observe_specific.input_register_3, + ); builder .when(next.multi_observe_row) @@ -974,10 +976,12 @@ impl Air .when(not(next_multi_observe_specific.is_first)) .assert_eq(very_first_timestamp, next.very_first_timestamp); + /* builder .when(next.multi_observe_row) .when(not(next_multi_observe_specific.is_first)) .assert_eq(next.start_timestamp, start_timestamp + is_first * AB::F::from_canonical_usize(4) + (end_idx - start_idx) * AB::F::TWO + should_permute * AB::F::TWO); + */ } } diff --git a/extensions/native/circuit/src/poseidon2/chip.rs b/extensions/native/circuit/src/poseidon2/chip.rs index 01a77d0c65..2ccfbe93a3 100644 --- a/extensions/native/circuit/src/poseidon2/chip.rs +++ b/extensions/native/circuit/src/poseidon2/chip.rs @@ -12,7 +12,7 @@ use openvm_circuit::{ use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode}; use openvm_native_compiler::{ conversion::AS, - Poseidon2Opcode::{COMP_POS2, PERM_POS2, MULTI_OBSERVE}, + Poseidon2Opcode::{COMP_POS2, MULTI_OBSERVE, PERM_POS2}, VerifyBatchOpcode::VERIFY_BATCH, }; use openvm_poseidon2_air::{Poseidon2Config, Poseidon2SubChip, Poseidon2SubCols}; @@ -25,7 +25,7 @@ use openvm_stark_backend::{ use crate::poseidon2::{ columns::{ - InsideRowSpecificCols, NativePoseidon2Cols, SimplePoseidonSpecificCols, + InsideRowSpecificCols, MultiObserveCols, NativePoseidon2Cols, SimplePoseidonSpecificCols, TopLevelSpecificCols, }, CHUNK, @@ -644,6 +644,181 @@ where if !self.optimistic { assert_eq!(commit, root); } + } else if instruction.opcode == MULTI_OBSERVE.global_opcode() { + let &Instruction { + a: state_ptr_register, + b: init_pos_register, + c: input_ptr_register, + d: register_address_space, + e: data_address_space, + f: len_register, + .. + } = instruction; + + assert_eq!( + register_address_space, + F::from_canonical_u32(AS::Native as u32) + ); + assert_eq!(data_address_space, F::from_canonical_u32(AS::Native as u32)); + + let [init_pos]: [F; 1] = + memory_read_native(state.memory.data(), init_pos_register.as_canonical_u32()); + let [input_len]: [F; 1] = + memory_read_native(state.memory.data(), len_register.as_canonical_u32()); + + let mut len = input_len.as_canonical_u32() as usize; + let mut pos = init_pos.as_canonical_u32() as usize; + let mut chunks: Vec<(usize, usize)> = vec![]; + + const NUM_HEAD_ACCESSES: usize = 4; + let mut final_timestamp_inc = NUM_HEAD_ACCESSES; + while len > 0 { + if len >= (CHUNK - pos) { + chunks.push((pos.clone(), CHUNK.clone())); + len -= CHUNK - pos; + final_timestamp_inc += 2 * (CHUNK - pos + 1); + pos = 0; + } else { + chunks.push((pos.clone(), pos + len)); + final_timestamp_inc += 2 * len; + len = 0; + pos = pos + len; + } + } + + let allocated_rows = arena + .alloc(MultiRowLayout::new(NativePoseidon2Metadata { + num_rows: 1 + chunks.len(), + })) + .0; + let head_cols = &mut allocated_rows[0]; + let head_multi_observe_cols: &mut MultiObserveCols = + head_cols.specific[..MultiObserveCols::::width()].borrow_mut(); + + let [state_ptr]: [F; 1] = tracing_read_native_helper( + state.memory, + state_ptr_register.as_canonical_u32(), + head_multi_observe_cols.read_data[0].as_mut(), + ); + let [init_pos]: [F; 1] = tracing_read_native_helper( + state.memory, + init_pos_register.as_canonical_u32(), + head_multi_observe_cols.read_data[1].as_mut(), + ); + let [input_ptr]: [F; 1] = tracing_read_native_helper( + state.memory, + input_ptr_register.as_canonical_u32(), + head_multi_observe_cols.read_data[2].as_mut(), + ); + let [input_len]: [F; 1] = tracing_read_native_helper( + state.memory, + len_register.as_canonical_u32(), + head_multi_observe_cols.read_data[3].as_mut(), + ); + + let input_ptr_u32 = input_ptr.as_canonical_u32(); + let state_ptr_u32 = state_ptr.as_canonical_u32(); + + let init_timestamp = F::from_canonical_u32(init_timestamp_u32); + + for (i, cols) in allocated_rows.iter_mut().enumerate() { + let multi_observe_cols: &mut MultiObserveCols = + cols.specific[..MultiObserveCols::::width()].borrow_mut(); + multi_observe_cols.input_register_1 = init_pos_register; + multi_observe_cols.input_register_2 = input_ptr_register; + multi_observe_cols.input_register_3 = len_register; + multi_observe_cols.output_register = state_ptr_register; + multi_observe_cols.init_pos = init_pos; + multi_observe_cols.input_ptr = input_ptr; + multi_observe_cols.state_ptr = state_ptr; + multi_observe_cols.len = input_len; + + cols.multi_observe_row = F::ONE; + cols.very_first_timestamp = init_timestamp; + + if i == 0 { + // head row + cols.inner.export = F::from_canonical_u32(1 + chunks.len() as u32); + multi_observe_cols.pc = F::from_canonical_u32(*state.pc); + multi_observe_cols.final_timestamp_increment = + F::from_canonical_usize(final_timestamp_inc); + multi_observe_cols.is_first = F::ONE; + multi_observe_cols.is_last = F::ZERO; + multi_observe_cols.curr_len = F::ZERO; + multi_observe_cols.should_permute = F::ZERO; + } + } + + let mut input_idx: usize = 0; + let mut cur_timestamp = init_timestamp_u32 + NUM_HEAD_ACCESSES as u32; + let num_chunks = chunks.len(); + for (i, ((chunk_start, chunk_end), cols)) in chunks + .into_iter() + .zip(allocated_rows.iter_mut().skip(1)) + .enumerate() + { + let multi_observe_cols: &mut MultiObserveCols = + cols.specific[..MultiObserveCols::::width()].borrow_mut(); + + cols.start_timestamp = F::from_canonical_u32(cur_timestamp); + + multi_observe_cols.start_idx = F::from_canonical_usize(chunk_start); + multi_observe_cols.end_idx = F::from_canonical_usize(chunk_end); + + multi_observe_cols.is_first = F::ZERO; + multi_observe_cols.is_last = if i == num_chunks - 1 { + F::ONE + } else { + F::ZERO + }; + multi_observe_cols.curr_len = F::from_canonical_usize(input_idx); + + for j in chunk_start..CHUNK { + multi_observe_cols.aux_after_start[j] = F::ONE; + } + for j in 0..chunk_end { + multi_observe_cols.aux_before_end[j] = F::ONE; + } + for j in chunk_start..chunk_end { + let n_f: [F; 1] = tracing_read_native_helper( + state.memory, + input_ptr_u32 + input_idx as u32, + multi_observe_cols.read_data[j].as_mut(), + ); + tracing_write_native_inplace( + state.memory, + state_ptr_u32 + j as u32, + n_f, + &mut multi_observe_cols.write_data[j], + ); + multi_observe_cols.data[j] = n_f[0]; + input_idx += 1; + cur_timestamp += 2; + } + + if chunk_end >= CHUNK { + multi_observe_cols.should_permute = F::ONE; + let permutation_input: [F; 16] = tracing_read_native_helper( + state.memory, + state_ptr_u32, + multi_observe_cols.read_sponge_state.as_mut(), + ); + cols.inner.inputs.clone_from_slice(&permutation_input); + let output = self.subchip.permute(permutation_input); + tracing_write_native_inplace( + state.memory, + state_ptr_u32, + std::array::from_fn(|i| output[i]), + &mut multi_observe_cols.write_sponge_state, + ); + cur_timestamp += 2; + } else { + multi_observe_cols.should_permute = F::ZERO; + let sponge_state: [F; 16] = + memory_read_native(state.memory.data(), state_ptr_u32); + cols.inner.inputs.clone_from_slice(&sponge_state); + } + } } else { unreachable!() } @@ -661,7 +836,7 @@ where String::from("COMP_POS2") } else if opcode == MULTI_OBSERVE.global_opcode().as_usize() { String::from("MULTI_OBSERVE") - }else { + } else { unreachable!("unsupported opcode: {}", opcode) } } @@ -688,6 +863,10 @@ impl TraceFiller let (curr, rest) = if cols.simple.is_one() { row_idx += 1; row_slice.split_at_mut(width) + } else if cols.multi_observe_row.is_one() { + let total_num_row = cols.inner.export.as_canonical_u32() as usize; + row_idx += total_num_row; + row_slice.split_at_mut(total_num_row * width) } else { let num_non_inside_row = cols.inner.export.as_canonical_u32() as usize; let start = (num_non_inside_row - 1) * width; @@ -704,6 +883,8 @@ impl TraceFiller let cols: &NativePoseidon2Cols = chunk_slice[..width].borrow(); if cols.simple.is_one() { self.fill_simple_chunk(mem_helper, chunk_slice); + } else if cols.multi_observe_row.is_one() { + self.fill_multi_observe_chunk(mem_helper, chunk_slice); } else { self.fill_verify_batch_chunk(mem_helper, chunk_slice); } @@ -961,6 +1142,94 @@ impl NativePoseidon2Filler, + chunk_slice: &mut [F], + ) { + let inner_width = self.subchip.air.width(); + let width = NativePoseidon2Cols::::width(); + let head_cols: &mut NativePoseidon2Cols = + chunk_slice[..width].borrow_mut(); + let num_rows = head_cols.inner.export.as_canonical_u32() as usize; + + let head_multi_observe_cols: &mut MultiObserveCols = + head_cols.specific[..MultiObserveCols::::width()].borrow_mut(); + let start_timestamp_u32 = head_cols.very_first_timestamp.as_canonical_u32(); + + // state_ptr, init_pos, input_ptr, len + mem_fill_helper( + mem_helper, + start_timestamp_u32, + head_multi_observe_cols.read_data[0].as_mut(), + ); + mem_fill_helper( + mem_helper, + start_timestamp_u32 + 1, + head_multi_observe_cols.read_data[1].as_mut(), + ); + mem_fill_helper( + mem_helper, + start_timestamp_u32 + 2, + head_multi_observe_cols.read_data[2].as_mut(), + ); + mem_fill_helper( + mem_helper, + start_timestamp_u32 + 3, + head_multi_observe_cols.read_data[3].as_mut(), + ); + + // generate permutation traces for each row + for row_idx in 0..num_rows { + let cols: &NativePoseidon2Cols = chunk_slice + [row_idx * width..(row_idx + 1) * width] + .as_ref() + .borrow(); + let inner_cols = &self.subchip.generate_trace(vec![cols.inner.inputs]).values; + chunk_slice[row_idx * width..(row_idx + 1) * width][..inner_width] + .copy_from_slice(inner_cols); + } + + for row_idx in 1..num_rows { + let cols: &mut NativePoseidon2Cols = + chunk_slice[row_idx * width..(row_idx + 1) * width].borrow_mut(); + let multi_observe_cols: &mut MultiObserveCols = + cols.specific[..MultiObserveCols::::width()].borrow_mut(); + + let mut start_timestamp_u32 = cols.start_timestamp.as_canonical_u32(); + let chunk_start = multi_observe_cols.start_idx.as_canonical_u32(); + let chunk_end = multi_observe_cols.end_idx.as_canonical_u32(); + + for j in chunk_start..chunk_end { + mem_fill_helper( + mem_helper, + start_timestamp_u32, + multi_observe_cols.read_data[j as usize].as_mut(), + ); + mem_fill_helper( + mem_helper, + start_timestamp_u32 + 1, + multi_observe_cols.write_data[j as usize].as_mut(), + ); + + start_timestamp_u32 += 2; + } + + if chunk_end >= CHUNK as u32 { + mem_fill_helper( + mem_helper, + start_timestamp_u32, + multi_observe_cols.read_sponge_state.as_mut(), + ); + mem_fill_helper( + mem_helper, + start_timestamp_u32 + 1, + multi_observe_cols.write_sponge_state.as_mut(), + ); + } + } + } + #[inline(always)] fn poseidon2_output_from_trace(inner: &Poseidon2SubCols) -> &[F; 2 * CHUNK] { &inner.ending_full_rounds.last().unwrap().post diff --git a/extensions/native/circuit/src/poseidon2/columns.rs b/extensions/native/circuit/src/poseidon2/columns.rs index fe0fce881a..934378dfbe 100644 --- a/extensions/native/circuit/src/poseidon2/columns.rs +++ b/extensions/native/circuit/src/poseidon2/columns.rs @@ -212,10 +212,15 @@ pub struct MultiObserveCols { pub final_timestamp_increment: T, // Initial reads from registers + // They are same across same instance of multi_observe pub state_ptr: T, pub input_ptr: T, pub init_pos: T, pub len: T, + pub input_register_1: T, + pub input_register_2: T, + pub input_register_3: T, + pub output_register: T, pub is_first: T, pub is_last: T, @@ -238,9 +243,4 @@ pub struct MultiObserveCols { // Final write back and registers pub write_final_idx: MemoryWriteAuxCols, pub final_idx: T, - - pub input_register_1: T, - pub input_register_2: T, - pub input_register_3: T, - pub output_register: T, -} \ No newline at end of file +} diff --git a/extensions/native/circuit/src/poseidon2/execution.rs b/extensions/native/circuit/src/poseidon2/execution.rs index 20889e4186..41f4827a67 100644 --- a/extensions/native/circuit/src/poseidon2/execution.rs +++ b/extensions/native/circuit/src/poseidon2/execution.rs @@ -5,10 +5,12 @@ use std::{ use openvm_circuit::{arch::*, system::memory::online::GuestMemory}; use openvm_circuit_primitives::AlignedBytesBorrow; -use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode}; +use openvm_instructions::{ + instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode, NATIVE_AS, +}; use openvm_native_compiler::{ conversion::AS, - Poseidon2Opcode::{COMP_POS2, PERM_POS2}, + Poseidon2Opcode::{COMP_POS2, MULTI_OBSERVE, PERM_POS2}, VerifyBatchOpcode::VERIFY_BATCH, }; use openvm_poseidon2_air::Poseidon2SubChip; @@ -29,6 +31,16 @@ struct Pos2PreCompute<'a, F: Field, const SBOX_REGISTERS: usize> { input_register_2: u32, } +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct MultiObservePreCompute<'a, F: Field, const SBOX_REGISTERS: usize> { + subchip: &'a Poseidon2SubChip, + pub init_pos_register: u32, + pub input_ptr_register: u32, + pub len_register: u32, + pub state_ptr_register: u32, +} + #[derive(AlignedBytesBorrow, Clone)] #[repr(C)] struct VerifyBatchPreCompute<'a, F: Field, const SBOX_REGISTERS: usize> { @@ -87,6 +99,51 @@ impl<'a, F: PrimeField32, const SBOX_REGISTERS: usize> NativePoseidon2Executor, + multi_observe_data: &mut MultiObservePreCompute<'a, F, SBOX_REGISTERS>, + ) -> Result<(), StaticProgramError> { + let &Instruction { + opcode, + a, + b, + c, + d, + e, + f, + .. + } = inst; + + if opcode != MULTI_OBSERVE.global_opcode() { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + + let a = a.as_canonical_u32(); + let b = b.as_canonical_u32(); + let c = c.as_canonical_u32(); + let d = d.as_canonical_u32(); + let e = e.as_canonical_u32(); + let f = f.as_canonical_u32(); + + if d != AS::Native as u32 { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + if e != AS::Native as u32 { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + + multi_observe_data.subchip = &self.subchip; + multi_observe_data.state_ptr_register = a; + multi_observe_data.init_pos_register = b; + multi_observe_data.input_ptr_register = c; + multi_observe_data.len_register = f; + + Ok(()) + } + #[inline(always)] fn pre_compute_verify_batch_impl( &'a self, @@ -142,6 +199,7 @@ impl<'a, F: PrimeField32, const SBOX_REGISTERS: usize> NativePoseidon2Executor) } + } else if $opcode == MULTI_OBSERVE.global_opcode() { + let multi_observe_data: &mut MultiObservePreCompute = + $data.borrow_mut(); + $executor.pre_compute_multi_observe_impl($pc, $inst, multi_observe_data)?; + Ok($execute_multi_observe_impl::<_, _, SBOX_REGISTERS>) } else { let verify_batch_data: &mut VerifyBatchPreCompute = $data.borrow_mut(); @@ -166,13 +229,18 @@ macro_rules! dispatch1 { }; } +fn max3(a: usize, b: usize, c: usize) -> usize { + std::cmp::max(a, std::cmp::max(b, c)) +} + impl Executor for NativePoseidon2Executor { #[inline(always)] fn pre_compute_size(&self) -> usize { - std::cmp::max( + max3( size_of::>(), + size_of::>(), size_of::>(), ) } @@ -187,6 +255,7 @@ impl Executor ) -> Result, StaticProgramError> { dispatch1!( execute_pos2_e1_impl, + execute_multi_observe_e1_impl, execute_verify_batch_e1_impl, self, inst.opcode, @@ -205,6 +274,7 @@ impl Executor ) -> Result, StaticProgramError> { dispatch1!( execute_pos2_e1_handler, + execute_multi_observe_e1_handler, execute_verify_batch_e1_handler, self, inst.opcode, @@ -218,6 +288,7 @@ impl Executor macro_rules! dispatch2 { ( $execute_pos2_impl:ident, + $execute_multi_observe_impl:ident, $execute_verify_batch_impl:ident, $executor:ident, $opcode:expr, @@ -237,6 +308,13 @@ macro_rules! dispatch2 { } else { Ok($execute_pos2_impl::<_, _, SBOX_REGISTERS, false>) } + } else if $opcode == MULTI_OBSERVE.global_opcode() { + let pre_compute: &mut E2PreCompute> = + $data.borrow_mut(); + pre_compute.chip_idx = $chip_idx as u32; + + $executor.pre_compute_multi_observe_impl($pc, $inst, &mut pre_compute.data)?; + Ok($execute_multi_observe_impl::<_, _, SBOX_REGISTERS>) } else { let pre_compute: &mut E2PreCompute> = $data.borrow_mut(); @@ -270,6 +348,7 @@ impl MeteredExecutor ) -> Result, StaticProgramError> { dispatch2!( execute_pos2_e2_impl, + execute_multi_observe_e2_impl, execute_verify_batch_e2_impl, self, inst.opcode, @@ -290,6 +369,7 @@ impl MeteredExecutor ) -> Result, StaticProgramError> { dispatch2!( execute_pos2_e2_handler, + execute_multi_observe_e2_handler, execute_verify_batch_e2_handler, self, inst.opcode, @@ -345,6 +425,49 @@ unsafe fn execute_pos2_e2_impl< .on_height_change(pre_compute.chip_idx as usize, height); } +#[create_handler] +#[inline(always)] +unsafe fn execute_multi_observe_e1_impl< + F: PrimeField32, + CTX: ExecutionCtxTrait, + const SBOX_REGISTERS: usize, +>( + pre_compute: &[u8], + instret: &mut u64, + pc: &mut u32, + _arg: u64, + exec_state: &mut VmExecState, +) { + let pre_compute: &MultiObservePreCompute = pre_compute.borrow(); + execute_multi_observe_e12_impl::<_, _, SBOX_REGISTERS>(pre_compute, instret, pc, exec_state); +} + +#[create_handler] +#[inline(always)] +unsafe fn execute_multi_observe_e2_impl< + F: PrimeField32, + CTX: MeteredExecutionCtxTrait, + const SBOX_REGISTERS: usize, +>( + pre_compute: &[u8], + instret: &mut u64, + pc: &mut u32, + _arg: u64, + exec_state: &mut VmExecState, +) { + let pre_compute: &E2PreCompute> = + pre_compute.borrow(); + let height = execute_multi_observe_e12_impl::<_, _, SBOX_REGISTERS>( + &pre_compute.data, + instret, + pc, + exec_state, + ); + exec_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, height); +} + #[create_handler] #[inline(always)] unsafe fn execute_verify_batch_e1_impl< @@ -452,6 +575,68 @@ unsafe fn execute_pos2_e12_impl< 1 } +#[inline(always)] +unsafe fn execute_multi_observe_e12_impl< + F: PrimeField32, + CTX: ExecutionCtxTrait, + const SBOX_REGISTERS: usize, +>( + pre_compute: &MultiObservePreCompute, + instret: &mut u64, + pc: &mut u32, + exec_state: &mut VmExecState, +) -> u32 { + let subchip = pre_compute.subchip; + + let [sponge_ptr]: [F; 1] = + exec_state.vm_read(AS::Native as u32, pre_compute.state_ptr_register); + let [init_pos]: [F; 1] = exec_state.vm_read(AS::Native as u32, pre_compute.init_pos_register); + let [input_ptr]: [F; 1] = exec_state.vm_read(AS::Native as u32, pre_compute.input_ptr_register); + let [len]: [F; 1] = exec_state.vm_read(AS::Native as u32, pre_compute.len_register); + + let mut len = len.as_canonical_u32() as usize; + let mut pos = init_pos.as_canonical_u32() as usize; + let input_ptr_u32 = input_ptr.as_canonical_u32(); + let sponge_ptr_u32 = sponge_ptr.as_canonical_u32(); + let mut height = 0; + + // split input into chunks s.t. each chunk fills the RATE portion of sponge state + let mut observation_chunks: Vec<(usize, usize)> = vec![]; + while len > 0 { + if len >= (CHUNK - pos) { + observation_chunks.push((pos, CHUNK)); + len -= CHUNK - pos; + pos = 0; + } else { + observation_chunks.push((pos, pos + len)); + len = 0; + pos = pos + len; + } + } + + height += 1; + let mut input_idx = 0; + + for (chunk_start, chunk_end) in observation_chunks { + for j in chunk_start..chunk_end { + let [n_f]: [F; 1] = exec_state.vm_read(NATIVE_AS as u32, input_ptr_u32 + input_idx); + exec_state.vm_write(NATIVE_AS as u32, sponge_ptr_u32 + (j as u32), &[n_f]); + input_idx += 1; + } + if chunk_end == CHUNK { + let mut p2_input: [F; CHUNK * 2] = exec_state.vm_read(NATIVE_AS as u32, sponge_ptr_u32); + subchip.permute_mut(&mut p2_input); + exec_state.vm_write(NATIVE_AS as u32, sponge_ptr_u32, &p2_input); + } + + height += 1; + } + *pc = pc.wrapping_add(DEFAULT_PC_STEP); + *instret += 1; + + height +} + #[inline(always)] unsafe fn execute_verify_batch_e12_impl< F: PrimeField32, diff --git a/extensions/native/recursion/tests/recursion.rs b/extensions/native/recursion/tests/recursion.rs index e3a5fd0f3b..a147d161f0 100644 --- a/extensions/native/recursion/tests/recursion.rs +++ b/extensions/native/recursion/tests/recursion.rs @@ -1,9 +1,10 @@ use itertools::Itertools; use openvm_circuit::{ arch::{ - PreflightExecutionOutput, PreflightExecutor, VmBuilder, VmCircuitConfig, VmExecutionConfig, instructions::program::Program + instructions::program::Program, PreflightExecutionOutput, PreflightExecutor, VmBuilder, + VmCircuitConfig, VmExecutionConfig, }, - utils::{TestStarkEngine, air_test_impl}, + utils::{air_test_impl, TestStarkEngine}, }; use openvm_native_circuit::{ execute_program_with_config, test_native_config, NativeBuilder, NativeConfig, @@ -209,12 +210,12 @@ fn test_multi_observe() { config.system.memory_config.max_access_adapter_n = 16; let vb = NativeBuilder::default(); - air_test_impl::(fri_params, vb, config, program, vec![], 1, true).unwrap(); - + air_test_impl::(fri_params, vb, config, program, vec![], 1, true) + .unwrap(); } fn build_test_program(builder: &mut Builder) { - let sample_lens: Vec = vec![10, 2, 0, 3, 20]; + let sample_lens: Vec = vec![10, 2, 1, 3, 20]; let mut rng = create_seeded_rng(); let challenger = DuplexChallengerVariable::new(builder); From 065d43c0efbad8b368f540b6c546de1e7315e936 Mon Sep 17 00:00:00 2001 From: xkx Date: Tue, 25 Nov 2025 11:21:10 +0800 Subject: [PATCH 04/12] remove read_sponge_state columns (#13) --- .../native/circuit/src/poseidon2/air.rs | 30 +++++++++++-------- .../native/circuit/src/poseidon2/chip.rs | 26 ++++------------ .../native/circuit/src/poseidon2/columns.rs | 1 - 3 files changed, 23 insertions(+), 34 deletions(-) diff --git a/extensions/native/circuit/src/poseidon2/air.rs b/extensions/native/circuit/src/poseidon2/air.rs index adf01695e4..5f40d41d32 100644 --- a/extensions/native/circuit/src/poseidon2/air.rs +++ b/extensions/native/circuit/src/poseidon2/air.rs @@ -1,5 +1,6 @@ use std::{array::from_fn, borrow::Borrow, sync::Arc}; +use itertools::Itertools; use openvm_circuit::{ arch::{ExecutionBridge, ExecutionState}, system::memory::{offline_checker::MemoryBridge, MemoryAddress, CHUNK}, @@ -724,7 +725,6 @@ impl Air write_data, data, should_permute, - read_sponge_state, write_sponge_state, write_final_idx, final_idx, @@ -856,29 +856,30 @@ impl Air end_idx, ); - let full_sponge_input = from_fn::<_, { CHUNK * 2 }, _>(|i| local.inner.inputs[i]); let full_sponge_output = from_fn::<_, { CHUNK * 2 }, _>(|i| { local.inner.ending_full_rounds[BABY_BEAR_POSEIDON2_HALF_FULL_ROUNDS - 1].post[i] }); - self.memory_bridge - .read( - MemoryAddress::new(self.address_space, state_ptr), - full_sponge_input, - start_timestamp + (end_idx - start_idx) * AB::F::TWO, - &read_sponge_state, - ) - .eval(builder, multi_observe_row * should_permute); - self.memory_bridge .write( MemoryAddress::new(self.address_space, state_ptr), full_sponge_output, - start_timestamp + (end_idx - start_idx) * AB::F::TWO + AB::F::ONE, + start_timestamp + (end_idx - start_idx) * AB::F::TWO, &write_sponge_state, ) .eval(builder, multi_observe_row * should_permute); + // enforce that prev_data is permutation input + write_sponge_state + .prev_data() + .iter() + .zip_eq(local.inner.inputs.iter()) + .for_each(|(a, b)| { + builder + .when(multi_observe_row * should_permute) + .assert_eq(*a, *b); + }); + /* self.memory_bridge .write( @@ -897,7 +898,10 @@ impl Air builder .when(next.multi_observe_row) .when(not(next_multi_observe_specific.is_first)) - .assert_eq(next_multi_observe_specific.curr_len, multi_observe_specific.curr_len + end_idx - start_idx); + .assert_eq( + next_multi_observe_specific.curr_len, + multi_observe_specific.curr_len + end_idx - start_idx, + ); // Boundary conditions builder diff --git a/extensions/native/circuit/src/poseidon2/chip.rs b/extensions/native/circuit/src/poseidon2/chip.rs index 2ccfbe93a3..05b73bd8c4 100644 --- a/extensions/native/circuit/src/poseidon2/chip.rs +++ b/extensions/native/circuit/src/poseidon2/chip.rs @@ -676,7 +676,7 @@ where if len >= (CHUNK - pos) { chunks.push((pos.clone(), CHUNK.clone())); len -= CHUNK - pos; - final_timestamp_inc += 2 * (CHUNK - pos + 1); + final_timestamp_inc += 2 * (CHUNK - pos) + 1; pos = 0; } else { chunks.push((pos.clone(), pos + len)); @@ -766,11 +766,7 @@ where multi_observe_cols.end_idx = F::from_canonical_usize(chunk_end); multi_observe_cols.is_first = F::ZERO; - multi_observe_cols.is_last = if i == num_chunks - 1 { - F::ONE - } else { - F::ZERO - }; + multi_observe_cols.is_last = if i == num_chunks - 1 { F::ONE } else { F::ZERO }; multi_observe_cols.curr_len = F::from_canonical_usize(input_idx); for j in chunk_start..CHUNK { @@ -796,13 +792,10 @@ where cur_timestamp += 2; } + let permutation_input: [F; 16] = + memory_read_native(state.memory.data(), state_ptr_u32); if chunk_end >= CHUNK { multi_observe_cols.should_permute = F::ONE; - let permutation_input: [F; 16] = tracing_read_native_helper( - state.memory, - state_ptr_u32, - multi_observe_cols.read_sponge_state.as_mut(), - ); cols.inner.inputs.clone_from_slice(&permutation_input); let output = self.subchip.permute(permutation_input); tracing_write_native_inplace( @@ -811,12 +804,10 @@ where std::array::from_fn(|i| output[i]), &mut multi_observe_cols.write_sponge_state, ); - cur_timestamp += 2; + cur_timestamp += 1; } else { multi_observe_cols.should_permute = F::ZERO; - let sponge_state: [F; 16] = - memory_read_native(state.memory.data(), state_ptr_u32); - cols.inner.inputs.clone_from_slice(&sponge_state); + cols.inner.inputs.clone_from_slice(&permutation_input); } } } else { @@ -1219,11 +1210,6 @@ impl NativePoseidon2Filler { // Permutation pub should_permute: T, - pub read_sponge_state: MemoryReadAuxCols, pub write_sponge_state: MemoryWriteAuxCols, // Final write back and registers From efb33c5ed9946ed168795ccb130fbf7dfd91393a Mon Sep 17 00:00:00 2001 From: xkx Date: Tue, 25 Nov 2025 12:31:37 +0800 Subject: [PATCH 05/12] Fix overspilling of `MULTI_OBSERVE` constraints due to indicator insufficiency (#7) (#14) Co-authored-by: Ray Gao --- extensions/native/circuit/src/poseidon2/air.rs | 13 +++++++++++-- extensions/native/circuit/src/poseidon2/chip.rs | 1 + extensions/native/circuit/src/poseidon2/columns.rs | 1 + extensions/native/recursion/tests/recursion.rs | 12 ++++++++++-- 4 files changed, 23 insertions(+), 4 deletions(-) diff --git a/extensions/native/circuit/src/poseidon2/air.rs b/extensions/native/circuit/src/poseidon2/air.rs index 5f40d41d32..d68ddf15f3 100644 --- a/extensions/native/circuit/src/poseidon2/air.rs +++ b/extensions/native/circuit/src/poseidon2/air.rs @@ -721,6 +721,7 @@ impl Air end_idx, aux_after_start, aux_before_end, + aux_read_enabled, read_data, write_data, data, @@ -802,7 +803,7 @@ impl Air start_timestamp + i_var * AB::F::TWO - start_idx * AB::F::TWO, &read_data[i], ) - .eval(builder, aux_after_start[i] * aux_before_end[i]); + .eval(builder, multi_observe_row * aux_read_enabled[i]); self.memory_bridge .write( @@ -811,21 +812,29 @@ impl Air start_timestamp + i_var * AB::F::TWO - start_idx * AB::F::TWO + AB::F::ONE, &write_data[i], ) - .eval(builder, aux_after_start[i] * aux_before_end[i]); + .eval(builder, multi_observe_row * aux_read_enabled[i]); } for i in 0..(CHUNK - 1) { builder + .when(multi_observe_row) .when(aux_after_start[i]) .assert_one(aux_after_start[i + 1]); } for i in 1..CHUNK { builder + .when(multi_observe_row) .when(aux_before_end[i]) .assert_one(aux_before_end[i - 1]); } + for i in 0..CHUNK { + builder + .when(multi_observe_row) + .assert_eq(aux_after_start[i] * aux_before_end[i], aux_read_enabled[i]); + } + builder .when(multi_observe_row) .when(not(is_first)) diff --git a/extensions/native/circuit/src/poseidon2/chip.rs b/extensions/native/circuit/src/poseidon2/chip.rs index 05b73bd8c4..a5a1d0d7d4 100644 --- a/extensions/native/circuit/src/poseidon2/chip.rs +++ b/extensions/native/circuit/src/poseidon2/chip.rs @@ -781,6 +781,7 @@ where input_ptr_u32 + input_idx as u32, multi_observe_cols.read_data[j].as_mut(), ); + multi_observe_cols.aux_read_enabled[j] = F::ONE; tracing_write_native_inplace( state.memory, state_ptr_u32 + j as u32, diff --git a/extensions/native/circuit/src/poseidon2/columns.rs b/extensions/native/circuit/src/poseidon2/columns.rs index df2ad6eef3..f710efc370 100644 --- a/extensions/native/circuit/src/poseidon2/columns.rs +++ b/extensions/native/circuit/src/poseidon2/columns.rs @@ -229,6 +229,7 @@ pub struct MultiObserveCols { pub end_idx: T, pub aux_after_start: [T; CHUNK], pub aux_before_end: [T; CHUNK], + pub aux_read_enabled: [T; CHUNK], // Transcript observation pub read_data: [MemoryReadAuxCols; CHUNK], diff --git a/extensions/native/recursion/tests/recursion.rs b/extensions/native/recursion/tests/recursion.rs index a147d161f0..4e3bba92e7 100644 --- a/extensions/native/recursion/tests/recursion.rs +++ b/extensions/native/recursion/tests/recursion.rs @@ -16,7 +16,8 @@ use openvm_native_compiler::{ prelude::Usize, }; use openvm_native_recursion::{ - challenger::duplex::DuplexChallengerVariable, testing_utils::inner::run_recursive_test, + challenger::{duplex::DuplexChallengerVariable, CanObserveVariable}, + testing_utils::inner::run_recursive_test, }; use openvm_stark_backend::{ config::{Domain, StarkGenericConfig, Val}, @@ -218,7 +219,14 @@ fn build_test_program(builder: &mut Builder) { let sample_lens: Vec = vec![10, 2, 1, 3, 20]; let mut rng = create_seeded_rng(); - let challenger = DuplexChallengerVariable::new(builder); + let mut challenger = DuplexChallengerVariable::new(builder); + + // Observe a setup label + let label_f: Vec = vec![128, 3098, 192, 394, 1662, 928, 374, 281, 598, 182, 475, 729]; + for n in label_f { + let f: Felt = builder.constant(C::F::from_canonical_u64(n)); + challenger.observe(builder, f); + } for l in sample_lens { let sample_input: Array> = builder.dyn_array(l); From 3957406a5423d8d49f150bc16b26c7ea4dd3e4d9 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Tue, 25 Nov 2025 12:36:20 +0800 Subject: [PATCH 06/12] fmt --- crates/vm/src/metrics/cycle_tracker/mod.rs | 5 +- .../native/circuit/src/poseidon2/tests.rs | 2 +- .../native/compiler/src/asm/compiler.rs | 9 +++- .../native/compiler/src/asm/instruction.rs | 6 ++- .../native/compiler/src/ir/instructions.rs | 11 +++-- extensions/native/compiler/src/ir/poseidon.rs | 47 +++++++++---------- 6 files changed, 44 insertions(+), 36 deletions(-) diff --git a/crates/vm/src/metrics/cycle_tracker/mod.rs b/crates/vm/src/metrics/cycle_tracker/mod.rs index 451eb59ba3..2d569c9774 100644 --- a/crates/vm/src/metrics/cycle_tracker/mod.rs +++ b/crates/vm/src/metrics/cycle_tracker/mod.rs @@ -23,12 +23,13 @@ impl CycleTracker { pub fn top(&self) -> Option<&String> { match self.stack.last() { Some(span) => Some(&span.tag), - _ => None + _ => None, } } /// Starts a new cycle tracker span for the given name. - /// If a span already exists for the given name, it ends the existing span and pushes a new one to the vec. + /// If a span already exists for the given name, it ends the existing span and pushes a new one + /// to the vec. pub fn start(&mut self, mut name: String, cycles_count: usize) { // hack to remove "CT-" prefix if name.starts_with("CT-") { diff --git a/extensions/native/circuit/src/poseidon2/tests.rs b/extensions/native/circuit/src/poseidon2/tests.rs index bce4358451..1a4270dd8a 100644 --- a/extensions/native/circuit/src/poseidon2/tests.rs +++ b/extensions/native/circuit/src/poseidon2/tests.rs @@ -467,7 +467,7 @@ fn tester_with_random_poseidon2_ops(num_ops: usize) -> VmChipTester { tester.write(e, lhs, data_left); tester.write(e, lhs + CHUNK, data_right); - }, + } MULTI_OBSERVE => {} } diff --git a/extensions/native/compiler/src/asm/compiler.rs b/extensions/native/compiler/src/asm/compiler.rs index 0e0db9cff9..42e32f3a7d 100644 --- a/extensions/native/compiler/src/asm/compiler.rs +++ b/extensions/native/compiler/src/asm/compiler.rs @@ -491,10 +491,15 @@ impl + TwoAdicField> AsmCo } DslIr::Poseidon2MultiObserve(dst, init_pos, arr_ptr, len) => { self.push( - AsmInstruction::Poseidon2MultiObserve(dst.fp(), init_pos.fp(), arr_ptr.fp(), len.get_var().fp()), + AsmInstruction::Poseidon2MultiObserve( + dst.fp(), + init_pos.fp(), + arr_ptr.fp(), + len.get_var().fp(), + ), debug_info, ); - }, + } DslIr::Poseidon2PermuteBabyBear(dst, src) => match (dst, src) { (Array::Dyn(dst, _), Array::Dyn(src, _)) => self.push( AsmInstruction::Poseidon2Permute(dst.fp(), src.fp()), diff --git a/extensions/native/compiler/src/asm/instruction.rs b/extensions/native/compiler/src/asm/instruction.rs index cd4990b08b..ae0875b83a 100644 --- a/extensions/native/compiler/src/asm/instruction.rs +++ b/extensions/native/compiler/src/asm/instruction.rs @@ -340,7 +340,11 @@ impl> AsmInstruction { AsmInstruction::Halt => write!(f, "halt"), AsmInstruction::HintBits(src, len) => write!(f, "hint_bits ({})fp, {}", src, len), AsmInstruction::Poseidon2MultiObserve(dst, init_pos, arr, len) => { - write!(f, "poseidon2_multi_observe ({})fp, ({})fp ({})fp ({})fp", dst, init_pos, arr, len) + write!( + f, + "poseidon2_multi_observe ({})fp, ({})fp ({})fp ({})fp", + dst, init_pos, arr, len + ) } AsmInstruction::Poseidon2Permute(dst, lhs) => { write!(f, "poseidon2_permute ({})fp, ({})fp", dst, lhs) diff --git a/extensions/native/compiler/src/ir/instructions.rs b/extensions/native/compiler/src/ir/instructions.rs index 13f5c4a653..3b30a45ad6 100644 --- a/extensions/native/compiler/src/ir/instructions.rs +++ b/extensions/native/compiler/src/ir/instructions.rs @@ -208,12 +208,13 @@ pub enum DslIr { /// Permutes an array of Bn254 elements using Poseidon2 (output = p2_permute(array)). Should /// only be used when target is a circuit. CircuitPoseidon2Permute([Var; 3]), - /// Absorbs an array of baby bear elements into a duplex transcript with Poseidon2 permutations (output = p2_multi_observe(array, els)). + /// Absorbs an array of baby bear elements into a duplex transcript with Poseidon2 permutations + /// (output = p2_multi_observe(array, els)). Poseidon2MultiObserve( - Ptr, // sponge_state - Var, // initial input_ptr position - Ptr, // input array (base elements) - Usize, // len of els + Ptr, // sponge_state + Var, // initial input_ptr position + Ptr, // input array (base elements) + Usize, // len of els ), // Miscellaneous instructions. diff --git a/extensions/native/compiler/src/ir/poseidon.rs b/extensions/native/compiler/src/ir/poseidon.rs index ee9bc0d87d..c82bbaec38 100644 --- a/extensions/native/compiler/src/ir/poseidon.rs +++ b/extensions/native/compiler/src/ir/poseidon.rs @@ -1,9 +1,8 @@ use openvm_native_compiler_derive::iter_zip; use openvm_stark_backend::p3_field::FieldAlgebra; -use crate::ir::Variable; - use super::{Array, ArrayLike, Builder, Config, DslIr, Ext, Felt, MemIndex, Ptr, Usize, Var}; +use crate::ir::Variable; pub const DIGEST_SIZE: usize = 8; pub const HASH_RATE: usize = 8; @@ -11,13 +10,13 @@ pub const PERMUTATION_WIDTH: usize = 16; impl Builder { /// Extends native VM ability to observe multiple base elements in one opcode operation - /// Absorbs elements sequentially at the RATE portion of sponge state and performs as many permutations as necessary. - /// Returns the index position of the next input_ptr. - /// + /// Absorbs elements sequentially at the RATE portion of sponge state and performs as many + /// permutations as necessary. Returns the index position of the next input_ptr. + /// /// [Reference](https://docs.rs/p3-poseidon2/latest/p3_poseidon2/struct.Poseidon2.html) pub fn poseidon2_multi_observe( &mut self, - sponge_state: &Array>, + sponge_state: &Array>, input_ptr: Ptr, arr: &Array>, ) -> Usize { @@ -28,26 +27,24 @@ impl Builder { Array::Fixed(_) => { panic!("Poseidon2 permutation is not allowed on fixed arrays"); } - Array::Dyn(sponge_ptr, _) => { - match arr { - Array::Fixed(_) => { - panic!("Base elements input must be dynamic"); - } - Array::Dyn(ptr, len) => { - let init_pos: Var = Var::uninit(self); - self.assign(&init_pos, input_ptr.address - sponge_ptr.address); - - self.operations.push(DslIr::Poseidon2MultiObserve( - *sponge_ptr, - init_pos, - *ptr, - len.clone(), - )); - - Usize::Var(init_pos) - } + Array::Dyn(sponge_ptr, _) => match arr { + Array::Fixed(_) => { + panic!("Base elements input must be dynamic"); } - } + Array::Dyn(ptr, len) => { + let init_pos: Var = Var::uninit(self); + self.assign(&init_pos, input_ptr.address - sponge_ptr.address); + + self.operations.push(DslIr::Poseidon2MultiObserve( + *sponge_ptr, + init_pos, + *ptr, + len.clone(), + )); + + Usize::Var(init_pos) + } + }, } } From 1451be17b58006dec1f0ba154a966e312caf7b41 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Tue, 25 Nov 2025 12:42:22 +0800 Subject: [PATCH 07/12] clippy --- extensions/native/circuit/src/poseidon2/chip.rs | 6 +++--- extensions/native/circuit/src/poseidon2/execution.rs | 10 +++++----- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/extensions/native/circuit/src/poseidon2/chip.rs b/extensions/native/circuit/src/poseidon2/chip.rs index a5a1d0d7d4..3a6ccf094a 100644 --- a/extensions/native/circuit/src/poseidon2/chip.rs +++ b/extensions/native/circuit/src/poseidon2/chip.rs @@ -674,15 +674,15 @@ where let mut final_timestamp_inc = NUM_HEAD_ACCESSES; while len > 0 { if len >= (CHUNK - pos) { - chunks.push((pos.clone(), CHUNK.clone())); + chunks.push((pos, CHUNK)); len -= CHUNK - pos; final_timestamp_inc += 2 * (CHUNK - pos) + 1; pos = 0; } else { - chunks.push((pos.clone(), pos + len)); + chunks.push((pos, pos + len)); final_timestamp_inc += 2 * len; len = 0; - pos = pos + len; + pos += len; } } diff --git a/extensions/native/circuit/src/poseidon2/execution.rs b/extensions/native/circuit/src/poseidon2/execution.rs index 41f4827a67..7e8d9a8aea 100644 --- a/extensions/native/circuit/src/poseidon2/execution.rs +++ b/extensions/native/circuit/src/poseidon2/execution.rs @@ -610,7 +610,7 @@ unsafe fn execute_multi_observe_e12_impl< } else { observation_chunks.push((pos, pos + len)); len = 0; - pos = pos + len; + pos += len; } } @@ -619,14 +619,14 @@ unsafe fn execute_multi_observe_e12_impl< for (chunk_start, chunk_end) in observation_chunks { for j in chunk_start..chunk_end { - let [n_f]: [F; 1] = exec_state.vm_read(NATIVE_AS as u32, input_ptr_u32 + input_idx); - exec_state.vm_write(NATIVE_AS as u32, sponge_ptr_u32 + (j as u32), &[n_f]); + let [n_f]: [F; 1] = exec_state.vm_read(NATIVE_AS, input_ptr_u32 + input_idx); + exec_state.vm_write(NATIVE_AS, sponge_ptr_u32 + (j as u32), &[n_f]); input_idx += 1; } if chunk_end == CHUNK { - let mut p2_input: [F; CHUNK * 2] = exec_state.vm_read(NATIVE_AS as u32, sponge_ptr_u32); + let mut p2_input: [F; CHUNK * 2] = exec_state.vm_read(NATIVE_AS, sponge_ptr_u32); subchip.permute_mut(&mut p2_input); - exec_state.vm_write(NATIVE_AS as u32, sponge_ptr_u32, &p2_input); + exec_state.vm_write(NATIVE_AS, sponge_ptr_u32, &p2_input); } height += 1; From c006da1c8aaf55b4b2bc2aaf90e16884b7b440ec Mon Sep 17 00:00:00 2001 From: xkx Date: Tue, 25 Nov 2025 19:06:54 +0800 Subject: [PATCH 08/12] Feat: tracegen for `multi_observe` (#15) * wip * finish --- Cargo.lock | 1 + .../circuit/cuda/include/native/poseidon2.cuh | 33 +++++++++- .../native/circuit/cuda/src/poseidon2.cu | 65 ++++++++++++++++++- .../native/circuit/src/poseidon2/cuda.rs | 3 + extensions/native/recursion/Cargo.toml | 3 +- .../native/recursion/tests/recursion.rs | 16 +++++ 6 files changed, 118 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9cb5dfbba2..ad540352cf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5840,6 +5840,7 @@ dependencies = [ "metrics", "once_cell", "openvm-circuit", + "openvm-cuda-backend", "openvm-native-circuit", "openvm-native-compiler", "openvm-native-compiler-derive", diff --git a/extensions/native/circuit/cuda/include/native/poseidon2.cuh b/extensions/native/circuit/cuda/include/native/poseidon2.cuh index 737406839f..40f0e4ad43 100644 --- a/extensions/native/circuit/cuda/include/native/poseidon2.cuh +++ b/extensions/native/circuit/cuda/include/native/poseidon2.cuh @@ -62,12 +62,43 @@ template struct SimplePoseidonSpecificCols { MemoryWriteAuxCols write_data_2; }; +template struct MultiObserveCols { + T pc; + T final_timestamp_increment; + T state_ptr; + T input_ptr; + T init_pos; + T len; + T input_register_1; + T input_register_2; + T input_register_3; + T output_register; + T is_first; + T is_last; + T curr_len; + T start_idx; + T end_idx; + T aux_after_start[CHUNK]; + T aux_before_end[CHUNK]; + T aux_read_enabled[CHUNK]; + MemoryReadAuxCols read_data[CHUNK]; + MemoryWriteAuxCols write_data[CHUNK]; + T data[CHUNK]; + T should_permute; + MemoryWriteAuxCols write_sponge_state; + MemoryWriteAuxCols write_final_idx; + T final_idx; +}; + template constexpr T constexpr_max(T a, T b) { return a > b ? a : b; } constexpr size_t COL_SPECIFIC_WIDTH = constexpr_max( sizeof(TopLevelSpecificCols), constexpr_max( sizeof(InsideRowSpecificCols), - sizeof(SimplePoseidonSpecificCols) + constexpr_max( + sizeof(SimplePoseidonSpecificCols), + sizeof(MultiObserveCols) + ) ) ); diff --git a/extensions/native/circuit/cuda/src/poseidon2.cu b/extensions/native/circuit/cuda/src/poseidon2.cu index 9779b601e4..ece788b0a8 100644 --- a/extensions/native/circuit/cuda/src/poseidon2.cu +++ b/extensions/native/circuit/cuda/src/poseidon2.cu @@ -22,6 +22,7 @@ template struct NativePoseidon2Cols { T incorporate_sibling; T inside_row; T simple; + T multi_observe_row; T end_inside_row; T end_top_level; @@ -38,7 +39,7 @@ template struct NativePoseidon2Cols { }; __device__ void mem_fill_base( - MemoryAuxColsFactory mem_helper, + MemoryAuxColsFactory &mem_helper, uint32_t timestamp, RowSlice base_aux ) { @@ -58,6 +59,8 @@ template struct Poseidon2Wrapper { ) { if (row[COL_INDEX(Cols, simple)] == Fp::one()) { fill_simple_chunk(row, range_checker, timestamp_max_bits); + } else if (row[COL_INDEX(Cols, multi_observe_row)] == Fp::one()) { + fill_multi_observe_chunk(row, range_checker, timestamp_max_bits); } else { fill_verify_batch_chunk(row, range_checker, timestamp_max_bits); } @@ -335,6 +338,66 @@ template struct Poseidon2Wrapper { } } } + + __device__ static void fill_multi_observe_chunk( + RowSlice row, + VariableRangeChecker range_checker, + uint32_t timestamp_max_bits + ) { + MemoryAuxColsFactory mem_helper(range_checker, timestamp_max_bits); + Poseidon2Row head_row(row); + uint32_t num_rows = head_row.export_col()[0].asUInt32(); + + for (uint32_t idx = 0; idx < num_rows; ++idx) { + RowSlice curr_row = row.shift_row(idx); + fill_inner(curr_row); + fill_multi_observe_specific(curr_row, mem_helper); + } + } + + __device__ static void fill_multi_observe_specific( + RowSlice row, + MemoryAuxColsFactory &mem_helper + ) { + RowSlice specific = row.slice_from(COL_INDEX(Cols, specific)); + if (specific[COL_INDEX(MultiObserveCols, is_first)] == Fp::one()) { + uint32_t very_start_timestamp = + row[COL_INDEX(Cols, very_first_timestamp)].asUInt32(); + for (uint32_t i = 0; i < 4; ++i) { + mem_fill_base( + mem_helper, + very_start_timestamp + i, + specific.slice_from(COL_INDEX(MultiObserveCols, read_data[i].base)) + ); + } + } else { + uint32_t start_timestamp = row[COL_INDEX(Cols, start_timestamp)].asUInt32(); + uint32_t chunk_start = + specific[COL_INDEX(MultiObserveCols, start_idx)].asUInt32(); + uint32_t chunk_end = + specific[COL_INDEX(MultiObserveCols, end_idx)].asUInt32(); + for (uint32_t j = chunk_start; j < chunk_end; ++j) { + mem_fill_base( + mem_helper, + start_timestamp, + specific.slice_from(COL_INDEX(MultiObserveCols, read_data[j].base)) + ); + mem_fill_base( + mem_helper, + start_timestamp + 1, + specific.slice_from(COL_INDEX(MultiObserveCols, write_data[j].base)) + ); + start_timestamp += 2; + } + if (chunk_end >= CHUNK) { + mem_fill_base( + mem_helper, + start_timestamp, + specific.slice_from(COL_INDEX(MultiObserveCols, write_sponge_state.base)) + ); + } + } + } }; template diff --git a/extensions/native/circuit/src/poseidon2/cuda.rs b/extensions/native/circuit/src/poseidon2/cuda.rs index 107589ef49..0425cfda18 100644 --- a/extensions/native/circuit/src/poseidon2/cuda.rs +++ b/extensions/native/circuit/src/poseidon2/cuda.rs @@ -53,6 +53,9 @@ impl Chip chunk_start.push(row_idx as u32); if cols.simple.is_one() { row_idx += 1; + } else if cols.multi_observe_row.is_one() { + let num_rows = cols.inner.export.as_canonical_u32() as usize; + row_idx += num_rows; } else { let num_non_inside_row = cols.inner.export.as_canonical_u32() as usize; let non_inside_start = start + (num_non_inside_row - 1) * width; diff --git a/extensions/native/recursion/Cargo.toml b/extensions/native/recursion/Cargo.toml index a8efd69ab3..f47f263693 100644 --- a/extensions/native/recursion/Cargo.toml +++ b/extensions/native/recursion/Cargo.toml @@ -8,6 +8,7 @@ repository.workspace = true [dependencies] openvm-stark-backend = { workspace = true } +openvm-cuda-backend = { workspace = true, optional = true } openvm-native-circuit = { workspace = true, features = ["test-utils"] } openvm-native-compiler = { workspace = true } openvm-native-compiler-derive = { workspace = true } @@ -58,4 +59,4 @@ parallel = ["openvm-stark-backend/parallel"] mimalloc = ["openvm-stark-backend/mimalloc"] jemalloc = ["openvm-stark-backend/jemalloc"] nightly-features = ["openvm-circuit/nightly-features"] -cuda = ["openvm-circuit/cuda", "openvm-native-circuit/cuda"] +cuda = ["openvm-circuit/cuda", "openvm-native-circuit/cuda", "dep:openvm-cuda-backend"] diff --git a/extensions/native/recursion/tests/recursion.rs b/extensions/native/recursion/tests/recursion.rs index 4e3bba92e7..3b78b6734b 100644 --- a/extensions/native/recursion/tests/recursion.rs +++ b/extensions/native/recursion/tests/recursion.rs @@ -6,6 +6,8 @@ use openvm_circuit::{ }, utils::{air_test_impl, TestStarkEngine}, }; +#[cfg(feature = "cuda")] +use openvm_cuda_backend::engine::GpuBabyBearPoseidon2Engine; use openvm_native_circuit::{ execute_program_with_config, test_native_config, NativeBuilder, NativeConfig, }; @@ -211,8 +213,22 @@ fn test_multi_observe() { config.system.memory_config.max_access_adapter_n = 16; let vb = NativeBuilder::default(); + #[cfg(not(feature = "cuda"))] air_test_impl::(fri_params, vb, config, program, vec![], 1, true) .unwrap(); + #[cfg(feature = "cuda")] + { + air_test_impl::( + fri_params, + vb, + config, + program, + vec![], + 1, + true, + ) + .unwrap(); + } } fn build_test_program(builder: &mut Builder) { From d73de3eb777f3f3ad235d8b8183d3b5d1b9b2d16 Mon Sep 17 00:00:00 2001 From: xkx Date: Wed, 26 Nov 2025 14:28:27 +0800 Subject: [PATCH 09/12] Fix: write final pos back (#16) * write final pos back * apply change for gpu * fix that multi_observe works as observe_slice * larger test case --- .../circuit/cuda/include/native/poseidon2.cuh | 1 - .../native/circuit/cuda/src/poseidon2.cu | 8 ++++ .../native/circuit/src/poseidon2/air.rs | 30 ++++++++++----- .../native/circuit/src/poseidon2/chip.rs | 18 +++++++++ .../native/circuit/src/poseidon2/columns.rs | 1 - .../native/circuit/src/poseidon2/execution.rs | 11 +++++- extensions/native/compiler/src/ir/poseidon.rs | 1 + .../native/recursion/src/challenger/duplex.rs | 18 +++++++++ .../native/recursion/tests/recursion.rs | 38 +++++-------------- 9 files changed, 86 insertions(+), 40 deletions(-) diff --git a/extensions/native/circuit/cuda/include/native/poseidon2.cuh b/extensions/native/circuit/cuda/include/native/poseidon2.cuh index 40f0e4ad43..206c0e16c0 100644 --- a/extensions/native/circuit/cuda/include/native/poseidon2.cuh +++ b/extensions/native/circuit/cuda/include/native/poseidon2.cuh @@ -87,7 +87,6 @@ template struct MultiObserveCols { T should_permute; MemoryWriteAuxCols write_sponge_state; MemoryWriteAuxCols write_final_idx; - T final_idx; }; template constexpr T constexpr_max(T a, T b) { return a > b ? a : b; } diff --git a/extensions/native/circuit/cuda/src/poseidon2.cu b/extensions/native/circuit/cuda/src/poseidon2.cu index ece788b0a8..59c65626ab 100644 --- a/extensions/native/circuit/cuda/src/poseidon2.cu +++ b/extensions/native/circuit/cuda/src/poseidon2.cu @@ -395,6 +395,14 @@ template struct Poseidon2Wrapper { start_timestamp, specific.slice_from(COL_INDEX(MultiObserveCols, write_sponge_state.base)) ); + start_timestamp += 1; + } + if (specific[COL_INDEX(MultiObserveCols, is_last)] == Fp::one()) { + mem_fill_base( + mem_helper, + start_timestamp, + specific.slice_from(COL_INDEX(MultiObserveCols, write_final_idx.base)) + ); } } } diff --git a/extensions/native/circuit/src/poseidon2/air.rs b/extensions/native/circuit/src/poseidon2/air.rs index d68ddf15f3..baf18b06a3 100644 --- a/extensions/native/circuit/src/poseidon2/air.rs +++ b/extensions/native/circuit/src/poseidon2/air.rs @@ -728,7 +728,6 @@ impl Air should_permute, write_sponge_state, write_final_idx, - final_idx, input_register_1, input_register_2, input_register_3, @@ -830,6 +829,16 @@ impl Air } for i in 0..CHUNK { + builder + .when(multi_observe_row) + .assert_bool(aux_after_start[i]); + builder + .when(multi_observe_row) + .assert_bool(aux_before_end[i]); + builder + .when(multi_observe_row) + .when(is_first) + .assert_zero(aux_read_enabled[i]); builder .when(multi_observe_row) .assert_eq(aux_after_start[i] * aux_before_end[i], aux_read_enabled[i]); @@ -889,19 +898,22 @@ impl Air .assert_eq(*a, *b); }); - /* + builder + .when(multi_observe_row) + .when(aux_read_enabled[CHUNK - 1]) + .assert_one(should_permute); + + // final_idx = aux_read_enabled[CHUNK-1] * 0 + (1 - aux_read_enabled[CHUNK-1]) * end_idx + let final_idx = aux_read_enabled[CHUNK - 1] * AB::Expr::ZERO + + (AB::Expr::ONE - aux_read_enabled[CHUNK - 1]) * end_idx; self.memory_bridge .write( - MemoryAddress::new( - self.address_space, - input_register_1, - ), + MemoryAddress::new(self.address_space, input_register_1), [final_idx], - start_timestamp + is_first * AB::F::from_canonical_usize(4) + (end_idx - start_idx) * AB::F::TWO + should_permute * AB::F::TWO, - &write_final_idx + start_timestamp + (end_idx - start_idx) * AB::F::TWO + should_permute, + &write_final_idx, ) .eval(builder, multi_observe_row * is_last); - */ // Field transitions builder diff --git a/extensions/native/circuit/src/poseidon2/chip.rs b/extensions/native/circuit/src/poseidon2/chip.rs index 3a6ccf094a..aecff9f10f 100644 --- a/extensions/native/circuit/src/poseidon2/chip.rs +++ b/extensions/native/circuit/src/poseidon2/chip.rs @@ -685,6 +685,7 @@ where pos += len; } } + final_timestamp_inc += 1; // write back to init_pos_register let allocated_rows = arena .alloc(MultiRowLayout::new(NativePoseidon2Metadata { @@ -810,6 +811,15 @@ where multi_observe_cols.should_permute = F::ZERO; cols.inner.inputs.clone_from_slice(&permutation_input); } + if i == num_chunks - 1 { + let final_idx = F::from_canonical_usize(chunk_end % CHUNK); + tracing_write_native_inplace( + state.memory, + init_pos_register.as_canonical_u32(), + [final_idx], + &mut multi_observe_cols.write_final_idx, + ); + } } } else { unreachable!() @@ -1213,6 +1223,14 @@ impl NativePoseidon2Filler { // Final write back and registers pub write_final_idx: MemoryWriteAuxCols, - pub final_idx: T, } diff --git a/extensions/native/circuit/src/poseidon2/execution.rs b/extensions/native/circuit/src/poseidon2/execution.rs index 7e8d9a8aea..a0c1fc72a2 100644 --- a/extensions/native/circuit/src/poseidon2/execution.rs +++ b/extensions/native/circuit/src/poseidon2/execution.rs @@ -331,8 +331,9 @@ impl MeteredExecutor { #[inline(always)] fn metered_pre_compute_size(&self) -> usize { - std::cmp::max( + max3( size_of::>>(), + size_of::>>(), size_of::>>(), ) } @@ -613,6 +614,7 @@ unsafe fn execute_multi_observe_e12_impl< pos += len; } } + let final_idx = observation_chunks.last().map(|(_, end)| *end % CHUNK); height += 1; let mut input_idx = 0; @@ -631,6 +633,13 @@ unsafe fn execute_multi_observe_e12_impl< height += 1; } + if let Some(final_idx) = final_idx { + exec_state.vm_write::( + NATIVE_AS, + pre_compute.init_pos_register, + &[F::from_canonical_usize(final_idx)], + ); + } *pc = pc.wrapping_add(DEFAULT_PC_STEP); *instret += 1; diff --git a/extensions/native/compiler/src/ir/poseidon.rs b/extensions/native/compiler/src/ir/poseidon.rs index c82bbaec38..6d32f89409 100644 --- a/extensions/native/compiler/src/ir/poseidon.rs +++ b/extensions/native/compiler/src/ir/poseidon.rs @@ -42,6 +42,7 @@ impl Builder { len.clone(), )); + // automatically updated by Poseidon2MultiObserve operation Usize::Var(init_pos) } }, diff --git a/extensions/native/recursion/src/challenger/duplex.rs b/extensions/native/recursion/src/challenger/duplex.rs index 2d45d896be..440b14ec59 100644 --- a/extensions/native/recursion/src/challenger/duplex.rs +++ b/extensions/native/recursion/src/challenger/duplex.rs @@ -77,6 +77,24 @@ impl DuplexChallengerVariable { } } + // Observes multiple elements from an array. + // This is equivalent to calling `observe` multiple times, but more efficient. + pub fn observe_slice_opt(&self, builder: &mut Builder, arr: &Array>) { + builder.if_ne(arr.len(), Usize::from(0)).then(|builder| { + let next_pos = builder.poseidon2_multi_observe(&self.sponge_state, self.input_ptr, arr); + + builder.assign(&self.input_ptr, self.io_empty_ptr + next_pos.clone()); + builder.if_ne(next_pos, Usize::from(0)).then_or_else( + |builder| { + builder.assign(&self.output_ptr, self.io_empty_ptr); + }, + |builder| { + builder.assign(&self.output_ptr, self.io_full_ptr); + }, + ); + }); + } + fn sample(&self, builder: &mut Builder) -> Felt { builder .if_ne(self.input_ptr.address, self.io_empty_ptr.address) diff --git a/extensions/native/recursion/tests/recursion.rs b/extensions/native/recursion/tests/recursion.rs index 3b78b6734b..68e033845b 100644 --- a/extensions/native/recursion/tests/recursion.rs +++ b/extensions/native/recursion/tests/recursion.rs @@ -15,10 +15,9 @@ use openvm_native_compiler::{ asm::{AsmBuilder, AsmCompiler}, conversion::{convert_program, CompilerOptions}, ir::{Array, Builder, Config, Felt}, - prelude::Usize, }; use openvm_native_recursion::{ - challenger::{duplex::DuplexChallengerVariable, CanObserveVariable}, + challenger::{duplex::DuplexChallengerVariable, CanObserveVariable, CanSampleVariable}, testing_utils::inner::run_recursive_test, }; use openvm_stark_backend::{ @@ -192,7 +191,6 @@ fn test_multi_observe() { compiler.build(builder.operations); let asm_code = compiler.code(); - // let program = Program::from_instructions(&instructions); let program: Program<_> = convert_program(asm_code, compilation_options); let poseidon2_max_constraint_degree = 3; @@ -232,17 +230,12 @@ fn test_multi_observe() { } fn build_test_program(builder: &mut Builder) { - let sample_lens: Vec = vec![10, 2, 1, 3, 20]; + let sample_lens: Vec = vec![10, 2, 1, 0, 3, 20, 200, 400]; let mut rng = create_seeded_rng(); - let mut challenger = DuplexChallengerVariable::new(builder); - // Observe a setup label - let label_f: Vec = vec![128, 3098, 192, 394, 1662, 928, 374, 281, 598, 182, 475, 729]; - for n in label_f { - let f: Felt = builder.constant(C::F::from_canonical_u64(n)); - challenger.observe(builder, f); - } + let mut c1 = DuplexChallengerVariable::new(builder); + let mut c2 = DuplexChallengerVariable::new(builder); for l in sample_lens { let sample_input: Array> = builder.dyn_array(l); @@ -251,24 +244,13 @@ fn build_test_program(builder: &mut Builder) { builder.set(&sample_input, idx_vec[0], C::F::from_canonical_u32(f_u32)); }); - let next_input_ptr = builder.poseidon2_multi_observe( - &challenger.sponge_state, - challenger.input_ptr, - &sample_input, - ); + c1.observe_slice_opt(builder, &sample_input); + c2.observe_slice(builder, sample_input); + + let e1 = c1.sample(builder); + let e2 = c2.sample(builder); - builder.assign( - &challenger.input_ptr, - challenger.io_empty_ptr + next_input_ptr.clone(), - ); - builder.if_ne(next_input_ptr, Usize::from(0)).then_or_else( - |builder| { - builder.assign(&challenger.output_ptr, challenger.io_empty_ptr); - }, - |builder| { - builder.assign(&challenger.output_ptr, challenger.io_full_ptr); - }, - ); + builder.assert_felt_eq(e1, e2); } builder.halt(); } From 102bd1f5026f5efcbbee8f086b4dd0ed9749e6ab Mon Sep 17 00:00:00 2001 From: xkx Date: Sun, 30 Nov 2025 20:51:52 +0800 Subject: [PATCH 10/12] Pick #8 (#17) * wip * wip2 * wip3 * wip4 * wip5 * replace variable position by name * wip6 * wip7 * clippy * wip8 * wip9 * wip10 * wip11 * wip12 * wip13 * wip14 * wip15 * clippy --- .gitignore | 3 + crates/sdk/src/prover/agg.rs | 6 +- .../native/circuit/src/extension/mod.rs | 20 +- .../circuit/src/field_extension/core.rs | 4 +- extensions/native/circuit/src/fri/mod.rs | 2 +- extensions/native/circuit/src/lib.rs | 2 + .../native/circuit/src/poseidon2/chip.rs | 41 +- extensions/native/circuit/src/sumcheck/air.rs | 591 ++++++++++++++++++ .../native/circuit/src/sumcheck/chip.rs | 579 +++++++++++++++++ .../native/circuit/src/sumcheck/columns.rs | 135 ++++ .../native/circuit/src/sumcheck/cuda.rs | 1 + .../native/circuit/src/sumcheck/execution.rs | 345 ++++++++++ extensions/native/circuit/src/sumcheck/mod.rs | 11 + extensions/native/circuit/src/utils.rs | 27 + .../native/compiler/src/asm/compiler.rs | 12 + .../native/compiler/src/asm/instruction.rs | 16 + .../native/compiler/src/conversion/mod.rs | 16 +- .../native/compiler/src/ir/instructions.rs | 25 + extensions/native/compiler/src/ir/mod.rs | 1 + extensions/native/compiler/src/ir/sumcheck.rs | 48 ++ extensions/native/compiler/src/lib.rs | 15 + extensions/native/recursion/tests/sumcheck.rs | 431 +++++++++++++ 22 files changed, 2292 insertions(+), 39 deletions(-) create mode 100644 extensions/native/circuit/src/sumcheck/air.rs create mode 100644 extensions/native/circuit/src/sumcheck/chip.rs create mode 100644 extensions/native/circuit/src/sumcheck/columns.rs create mode 100644 extensions/native/circuit/src/sumcheck/cuda.rs create mode 100644 extensions/native/circuit/src/sumcheck/execution.rs create mode 100644 extensions/native/circuit/src/sumcheck/mod.rs create mode 100644 extensions/native/compiler/src/ir/sumcheck.rs create mode 100644 extensions/native/recursion/tests/sumcheck.rs diff --git a/.gitignore b/.gitignore index c6e6aa2049..c87b73f198 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,9 @@ Cargo.lock **/.env .DS_Store +# Log outputs +*.log + .cache/ rustc-* diff --git a/crates/sdk/src/prover/agg.rs b/crates/sdk/src/prover/agg.rs index 5e22562675..fd4491ddec 100644 --- a/crates/sdk/src/prover/agg.rs +++ b/crates/sdk/src/prover/agg.rs @@ -27,12 +27,12 @@ where E: StarkFriEngine, NativeBuilder: VmBuilder, { - leaf_prover: VmInstance, - leaf_controller: LeafProvingController, + pub leaf_prover: VmInstance, + pub leaf_controller: LeafProvingController, pub internal_prover: VmInstance, #[cfg(feature = "evm-prove")] - root_prover: RootVerifierLocalProver, + pub root_prover: RootVerifierLocalProver, pub num_children_internal: usize, pub max_internal_wrapper_layers: usize, } diff --git a/extensions/native/circuit/src/extension/mod.rs b/extensions/native/circuit/src/extension/mod.rs index a86cdb1bd2..9f3e2035ad 100644 --- a/extensions/native/circuit/src/extension/mod.rs +++ b/extensions/native/circuit/src/extension/mod.rs @@ -17,7 +17,8 @@ use openvm_instructions::{program::DEFAULT_PC_STEP, LocalOpcode, PhantomDiscrimi use openvm_native_compiler::{ CastfOpcode, FieldArithmeticOpcode, FieldExtensionOpcode, FriOpcode, NativeBranchEqualOpcode, NativeJalOpcode, NativeLoadStore4Opcode, NativeLoadStoreOpcode, NativePhantom, - NativeRangeCheckOpcode, Poseidon2Opcode, VerifyBatchOpcode, BLOCK_LOAD_STORE_SIZE, + NativeRangeCheckOpcode, Poseidon2Opcode, SumcheckOpcode, VerifyBatchOpcode, + BLOCK_LOAD_STORE_SIZE, }; use openvm_poseidon2_air::Poseidon2Config; use openvm_rv32im_circuit::BranchEqualCoreAir; @@ -61,6 +62,10 @@ use crate::{ chip::{NativePoseidon2Executor, NativePoseidon2Filler}, NativePoseidon2Chip, }, + sumcheck::{ + air::NativeSumcheckAir, + chip::{NativeSumcheckChip, NativeSumcheckExecutor, NativeSumcheckFiller}, + }, }; cfg_if::cfg_if! { @@ -94,6 +99,7 @@ pub enum NativeExecutor { FieldExtension(FieldExtensionExecutor), FriReducedOpening(FriReducedOpeningExecutor), VerifyBatch(NativePoseidon2Executor), + TowerVerify(NativeSumcheckExecutor), } impl VmExecutionExtension for Native { @@ -169,6 +175,12 @@ impl VmExecutionExtension for Native { ], )?; + let tower_verify = NativeSumcheckExecutor::new(); + inventory.add_executor( + tower_verify, + [SumcheckOpcode::SUMCHECK_LAYER_EVAL.global_opcode()], + )?; + inventory.add_phantom_sub_executor( NativeHintInputSubEx, PhantomDiscriminant(NativePhantom::HintInput as u16), @@ -262,6 +274,9 @@ where ); inventory.add_air(verify_batch); + let tower_evaluate = NativeSumcheckAir::new(exec_bridge, memory_bridge); + inventory.add_air(tower_evaluate); + Ok(()) } } @@ -342,6 +357,9 @@ where ); inventory.add_executor_chip(poseidon2); + let tower_verify = NativeSumcheckChip::new(NativeSumcheckFiller::new(), mem_helper.clone()); + inventory.add_executor_chip(tower_verify); + Ok(()) } } diff --git a/extensions/native/circuit/src/field_extension/core.rs b/extensions/native/circuit/src/field_extension/core.rs index 5afaf74af5..a7d535a14b 100644 --- a/extensions/native/circuit/src/field_extension/core.rs +++ b/extensions/native/circuit/src/field_extension/core.rs @@ -254,10 +254,10 @@ pub(crate) struct FieldExtension; impl FieldExtension { pub(crate) fn add(x: [V; EXT_DEG], y: [V; EXT_DEG]) -> [E; EXT_DEG] where - V: Copy, + V: Clone, V: Add, { - array::from_fn(|i| x[i] + y[i]) + array::from_fn(|i| x[i].clone() + y[i].clone()) } pub(crate) fn subtract(x: [V; EXT_DEG], y: [V; EXT_DEG]) -> [E; EXT_DEG] diff --git a/extensions/native/circuit/src/fri/mod.rs b/extensions/native/circuit/src/fri/mod.rs index 1e1ec65cb8..4a0d3847d5 100644 --- a/extensions/native/circuit/src/fri/mod.rs +++ b/extensions/native/circuit/src/fri/mod.rs @@ -542,7 +542,7 @@ fn assert_array_eq, I2: Into, const } } -fn elem_to_ext(elem: F) -> [F; EXT_DEG] { +pub fn elem_to_ext(elem: F) -> [F; EXT_DEG] { let mut ret = [F::ZERO; EXT_DEG]; ret[0] = elem; ret diff --git a/extensions/native/circuit/src/lib.rs b/extensions/native/circuit/src/lib.rs index b5db4f0010..ce257c9c22 100644 --- a/extensions/native/circuit/src/lib.rs +++ b/extensions/native/circuit/src/lib.rs @@ -42,9 +42,11 @@ mod fri; mod jal_rangecheck; mod loadstore; mod poseidon2; +mod sumcheck; mod extension; pub use extension::*; +pub use field_extension::EXT_DEG; mod utils; #[cfg(any(test, feature = "test-utils"))] diff --git a/extensions/native/circuit/src/poseidon2/chip.rs b/extensions/native/circuit/src/poseidon2/chip.rs index aecff9f10f..770efc7307 100644 --- a/extensions/native/circuit/src/poseidon2/chip.rs +++ b/extensions/native/circuit/src/poseidon2/chip.rs @@ -3,10 +3,8 @@ use std::borrow::{Borrow, BorrowMut}; use openvm_circuit::{ arch::*, system::{ - memory::{offline_checker::MemoryBaseAuxCols, online::TracingMemory, MemoryAuxColsFactory}, - native_adapter::util::{ - memory_read_native, tracing_read_native, tracing_write_native_inplace, - }, + memory::{online::TracingMemory, MemoryAuxColsFactory}, + native_adapter::util::{memory_read_native, tracing_write_native_inplace}, }, }; use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode}; @@ -23,12 +21,16 @@ use openvm_stark_backend::{ p3_maybe_rayon::prelude::{IntoParallelIterator, ParallelSliceMut, *}, }; -use crate::poseidon2::{ - columns::{ - InsideRowSpecificCols, MultiObserveCols, NativePoseidon2Cols, SimplePoseidonSpecificCols, - TopLevelSpecificCols, +use crate::{ + mem_fill_helper, + poseidon2::{ + columns::{ + InsideRowSpecificCols, MultiObserveCols, NativePoseidon2Cols, + SimplePoseidonSpecificCols, TopLevelSpecificCols, + }, + CHUNK, }, - CHUNK, + tracing_read_native_helper, }; #[derive(Clone)] @@ -1240,24 +1242,3 @@ impl NativePoseidon2Filler( - memory: &mut TracingMemory, - ptr: u32, - base_aux: &mut MemoryBaseAuxCols, -) -> [F; BLOCK_SIZE] { - let mut prev_ts = 0; - let ret = tracing_read_native(memory, ptr, &mut prev_ts); - base_aux.set_prev(F::from_canonical_u32(prev_ts)); - ret -} - -/// Fill `MemoryBaseAuxCols`, assuming that the `prev_timestamp` is already set in `base_aux`. -fn mem_fill_helper( - mem_helper: &MemoryAuxColsFactory, - timestamp: u32, - base_aux: &mut MemoryBaseAuxCols, -) { - let prev_ts = base_aux.prev_timestamp.as_canonical_u32(); - mem_helper.fill(prev_ts, timestamp, base_aux); -} diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs new file mode 100644 index 0000000000..1cae3847ca --- /dev/null +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -0,0 +1,591 @@ +use std::borrow::Borrow; + +use openvm_circuit::{ + arch::{ExecutionBridge, ExecutionState}, + system::memory::{offline_checker::MemoryBridge, MemoryAddress}, +}; +use openvm_circuit_primitives::utils::{and, assert_array_eq, not}; +use openvm_instructions::{LocalOpcode, NATIVE_AS}; +use openvm_native_compiler::SumcheckOpcode::SUMCHECK_LAYER_EVAL; +use openvm_stark_backend::{ + interaction::InteractionBuilder, + p3_air::{Air, AirBuilder, BaseAir}, + p3_field::{Field, FieldAlgebra}, + p3_matrix::Matrix, + rap::{BaseAirWithPublicValues, PartitionedBaseAir}, +}; + +use crate::{ + field_extension::{FieldExtension, EXT_DEG}, + sumcheck::{ + chip::CONTEXT_ARR_BASE_LEN, + columns::{HeaderSpecificCols, LogupSpecificCols, NativeSumcheckCols, ProdSpecificCols}, + }, +}; + +#[derive(Clone, Debug)] +pub struct NativeSumcheckAir { + pub execution_bridge: ExecutionBridge, + pub memory_bridge: MemoryBridge, +} + +impl NativeSumcheckAir { + pub fn new(execution_bridge: ExecutionBridge, memory_bridge: MemoryBridge) -> Self { + Self { + execution_bridge, + memory_bridge, + } + } +} + +impl BaseAir for NativeSumcheckAir { + fn width(&self) -> usize { + NativeSumcheckCols::::width() + } +} + +impl BaseAirWithPublicValues for NativeSumcheckAir {} + +impl PartitionedBaseAir for NativeSumcheckAir {} + +impl Air for NativeSumcheckAir { + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + let local = main.row_slice(0); + let local: &NativeSumcheckCols = (*local).borrow(); + let next = main.row_slice(1); + let next: &NativeSumcheckCols = (*next).borrow(); + let native_as = AB::F::from_canonical_u32(NATIVE_AS); + + let &NativeSumcheckCols { + // Row indicators + header_row, + prod_row, + logup_row, + is_end, + + prod_continued, + logup_continued, + // What type of evaluation is performed + // mainly for reducing constraint degree + prod_in_round_evaluation, + prod_next_round_evaluation, + logup_in_round_evaluation, + logup_next_round_evaluation, + + // Indicates whether the round evaluations should be added to the accumulator + prod_acc, + logup_acc, + + // Timestamps + first_timestamp, + start_timestamp, + last_timestamp, + + // Results from reading registers + register_ptrs, + ctx, + prod_nested_len, + logup_nested_len, + + // Challenges + alpha, + challenges, + + curr_prod_n, + curr_logup_n, + + max_round, + within_round_limit, + should_acc, + eval_acc, + specific, + } = local; + + let [round, num_prod_spec, num_logup_spec, _prod_spec_inner_len, prod_spec_inner_inner_len, _logup_spec_inner_len, logup_spec_inner_inner_len, in_round] = + ctx; + builder.assert_bool(header_row); + builder.assert_bool(prod_row); + builder.assert_bool(logup_row); + builder.assert_bool(within_round_limit); + builder.assert_bool(prod_in_round_evaluation); + builder.assert_bool(logup_in_round_evaluation); + + let enabled = header_row + prod_row + logup_row; + let next_enabled = next.header_row + next.prod_row + next.logup_row; + builder.assert_bool(enabled.clone()); + + builder + .when_transition() + .assert_eq(prod_row * next.prod_row, prod_continued); + builder + .when_transition() + .assert_eq(logup_row * next.logup_row, logup_continued); + // TODO: handle last row properly + + builder.when_transition().assert_eq::( + prod_row * next.header_row + + logup_row * next.header_row + + not::(next_enabled), + is_end.into(), + ); + + // TODO: within_round_limit = true => round < max_round + + // Randomness transition + let alpha1: [_; EXT_DEG] = challenges[0..EXT_DEG].try_into().unwrap(); + let c1: [_; EXT_DEG] = challenges[EXT_DEG..{ EXT_DEG * 2 }].try_into().unwrap(); + let c2: [_; EXT_DEG] = challenges[{ EXT_DEG * 2 }..{ EXT_DEG * 3 }] + .try_into() + .unwrap(); + let alpha2: [_; EXT_DEG] = challenges[{ EXT_DEG * 3 }..{ EXT_DEG * 4 }] + .try_into() + .unwrap(); + let next_alpha1: [_; EXT_DEG] = next.challenges[0..EXT_DEG].try_into().unwrap(); + + // Carry along columns + assert_array_eq( + &mut builder.when(next.prod_row + next.logup_row), + register_ptrs, + next.register_ptrs, + ); + assert_array_eq( + &mut builder.when(next.prod_row + next.logup_row), + ctx, + next.ctx, + ); + // c1, c2 remain the same + assert_array_eq::<_, _, _, { EXT_DEG * 2 }>( + &mut builder.when(next.prod_row + next.logup_row), + challenges[EXT_DEG..(EXT_DEG * 3)].try_into().expect(""), + next.challenges[EXT_DEG..(EXT_DEG * 3)] + .try_into() + .expect(""), + ); + assert_array_eq( + &mut builder.when(next.prod_row + next.logup_row), + alpha, + next.alpha, + ); + builder + .when(next.prod_row + next.logup_row) + .assert_eq(prod_nested_len, next.prod_nested_len); + builder + .when(next.prod_row + next.logup_row) + .assert_eq(logup_nested_len, next.logup_nested_len); + + //////////////////////////////////////////////////////////////// + // Row transitions from current to next row + // The basic pattern is + // header_row -> prod_row -> ... -> prod_row + // -> logup_row -> ... -> logup_row + //////////////////////////////////////////////////////////////// + + // (curr_prod_n, curr_logup_n) start at 0 + builder.when(header_row).assert_zero(curr_prod_n); + builder + .when(header_row + prod_row) + .assert_zero(curr_logup_n); + builder + .when(next.prod_row) + .assert_eq(curr_prod_n + AB::F::ONE, next.curr_prod_n); + builder + .when(next.logup_row) + .assert_eq(curr_logup_n + AB::F::ONE, next.curr_logup_n); + // if header row is followed by another header row + // then num_prod_spec and num_logup_spec should be zero + builder + .when(header_row) + .when(next.header_row) + .assert_zero(num_prod_spec); + builder + .when(header_row) + .when(next.header_row) + .assert_zero(num_logup_spec); + // if header row is followed by a logup row, + // then num_prod_spec should be zero + builder + .when(header_row) + .when(next.logup_row) + .assert_zero(num_prod_spec); + builder + .when(prod_row) + .when(next.logup_row) + .assert_eq(curr_prod_n, num_prod_spec); + builder + .when(logup_row) + .when(next.header_row) + .assert_eq(curr_logup_n, num_logup_spec); + + // Timestamp transition + builder + .when(header_row) + .when(next.prod_row + next.logup_row) + .assert_eq( + next.start_timestamp, + start_timestamp + AB::F::from_canonical_usize(7), + ); + builder + .when(prod_row) + .when(next.prod_row + next.logup_row) + .assert_eq( + next.start_timestamp, + start_timestamp + AB::F::ONE + within_round_limit * AB::F::TWO, + ); + builder + .when(logup_row) + .when(next.prod_row + next.logup_row) + .assert_eq( + next.start_timestamp, + start_timestamp + AB::F::ONE + within_round_limit * AB::F::from_canonical_usize(3), + ); + + // Termination condition + assert_array_eq( + &mut builder.when::(is_end.into()), + eval_acc, + [AB::F::ZERO; 4], + ); + + // Randomness transition + assert_array_eq( + &mut builder.when(and(header_row, next.prod_row + next.logup_row)), + next.challenges[0..EXT_DEG].try_into().unwrap(), + [AB::F::ONE, AB::F::ZERO, AB::F::ZERO, AB::F::ZERO], + ); + assert_array_eq::<_, _, _, { EXT_DEG }>(&mut builder.when(header_row), alpha, alpha1); + let prod_next_alpha = FieldExtension::multiply(alpha1, alpha); + assert_array_eq::<_, _, _, { EXT_DEG }>( + &mut builder.when(prod_continued), + prod_next_alpha, + next_alpha1, + ); + // alpha1 = alpha_numerator, alpha2 = alpha_denominator for logup row + let alpha_denominator = FieldExtension::multiply(alpha1, alpha); + assert_array_eq::<_, _, _, { EXT_DEG }>( + &mut builder.when(logup_row), + alpha_denominator, + alpha2, + ); + let logup_next_alpha = FieldExtension::multiply(alpha2, alpha); + assert_array_eq::<_, _, _, { EXT_DEG }>( + &mut builder.when(logup_continued), + logup_next_alpha, + next_alpha1, + ); + + /////////////////////////////////////// + // Header + /////////////////////////////////////// + let header_row_specific: &HeaderSpecificCols = + specific[..HeaderSpecificCols::::width()].borrow(); + let registers = header_row_specific.registers; + + self.execution_bridge + .execute_and_increment_pc( + AB::Expr::from_canonical_usize(SUMCHECK_LAYER_EVAL.global_opcode().as_usize()), + [ + registers[4].into(), + registers[0].into(), + registers[1].into(), + native_as.into(), + native_as.into(), + registers[2].into(), + registers[3].into(), + ], + ExecutionState::new(header_row_specific.pc, first_timestamp), + last_timestamp - first_timestamp, + ) + .eval(builder, header_row); + + // Read registers + for i in 0..5usize { + self.memory_bridge + .read( + MemoryAddress::new(native_as, registers[i]), + [register_ptrs[i]], + first_timestamp + AB::F::from_canonical_usize(i), + &header_row_specific.read_records[i], + ) + .eval(builder, header_row); + } + + // Read ctx + self.memory_bridge + .read( + MemoryAddress::new(native_as, register_ptrs[0]), + ctx, + first_timestamp + AB::F::from_canonical_usize(5), + &header_row_specific.read_records[5], + ) + .eval(builder, header_row); + + // Read challenges + self.memory_bridge + .read( + MemoryAddress::new(native_as, register_ptrs[1]), + challenges, + first_timestamp + AB::F::from_canonical_usize(6), + &header_row_specific.read_records[6], + ) + .eval(builder, header_row); + + // Write final result + self.memory_bridge + .write( + MemoryAddress::new(native_as, register_ptrs[4]), + eval_acc, + last_timestamp - AB::F::ONE, + &header_row_specific.write_records, + ) + .eval(builder, header_row); + + /////////////////////////////////////// + // Prod spec evaluation + /////////////////////////////////////// + let prod_row_specific: &ProdSpecificCols = + specific[..ProdSpecificCols::::width()].borrow(); + let next_prod_row_specific: &ProdSpecificCols = + next.specific[..ProdSpecificCols::::width()].borrow(); + + self.memory_bridge + .read( + MemoryAddress::new( + native_as, + register_ptrs[0] + + AB::F::from_canonical_usize(CONTEXT_ARR_BASE_LEN) + + (curr_prod_n - AB::F::ONE), + ), // curr_prod_n starts at 1. + [max_round], + start_timestamp, + &prod_row_specific.read_records[0], + ) + .eval(builder, prod_row); + + // prod_row * within_round_limit = + // prod_in_round_evaluation + prod_next_round_evaluation + builder + .when(prod_in_round_evaluation + prod_next_round_evaluation) + .assert_eq( + prod_row_specific.data_ptr, + (prod_nested_len * (curr_prod_n - AB::F::ONE) + prod_spec_inner_inner_len * round) + * AB::F::from_canonical_usize(EXT_DEG), + ); + builder.assert_eq( + prod_row * within_round_limit * in_round, + prod_in_round_evaluation, + ); + builder.assert_eq( + prod_row * within_round_limit * not(in_round), + prod_next_round_evaluation, + ); + builder.assert_eq(prod_row * should_acc, prod_acc); + + self.memory_bridge + .read( + MemoryAddress::new(native_as, register_ptrs[2] + prod_row_specific.data_ptr), + prod_row_specific.p, + start_timestamp + AB::F::ONE, + &prod_row_specific.read_records[1], + ) + .eval(builder, prod_row * within_round_limit); + + let p1: [AB::Var; EXT_DEG] = prod_row_specific.p[0..EXT_DEG].try_into().unwrap(); + let p2: [AB::Var; EXT_DEG] = prod_row_specific.p[EXT_DEG..(EXT_DEG * 2)] + .try_into() + .unwrap(); + + self.memory_bridge + .write( + MemoryAddress::new( + native_as, + register_ptrs[4] + curr_prod_n * AB::F::from_canonical_usize(EXT_DEG), + ), + prod_row_specific.p_evals, + start_timestamp + AB::F::TWO, + &prod_row_specific.write_record, + ) + .eval(builder, prod_row * within_round_limit); + + // Calculate evaluations + let next_round_p_evals = FieldExtension::add( + FieldExtension::multiply::(p1, c1), + FieldExtension::multiply::(p2, c2), + ); + let in_round_p_evals = FieldExtension::multiply::(p1, p2); + assert_array_eq::<_, _, _, { EXT_DEG }>( + &mut builder.when(prod_in_round_evaluation), + in_round_p_evals, + prod_row_specific.p_evals, + ); + assert_array_eq::<_, _, _, { EXT_DEG }>( + &mut builder.when(prod_next_round_evaluation), + next_round_p_evals, + prod_row_specific.p_evals, + ); + + // TODO: add constraint on should_acc + + // Accumulate `eval_rlc` into global accumulator `eval_acc` + // when round < max_round - 2 + let eval_rlc = + FieldExtension::multiply::(prod_row_specific.p_evals, alpha1); + assert_array_eq::<_, _, _, { EXT_DEG }>( + &mut builder.when(prod_acc), + prod_row_specific.eval_rlc, + eval_rlc, + ); + assert_array_eq::<_, _, _, { EXT_DEG }>( + &mut builder.when(next.prod_acc), + FieldExtension::add(next.eval_acc, next_prod_row_specific.eval_rlc), + eval_acc, + ); + + /////////////////////////////////////// + // Logup spec evaluation + /////////////////////////////////////// + let logup_row_specific: &LogupSpecificCols = + specific[..LogupSpecificCols::::width()].borrow(); + let next_logup_row_specfic: &LogupSpecificCols = + next.specific[..LogupSpecificCols::::width()].borrow(); + + self.memory_bridge + .read( + MemoryAddress::new( + native_as, + register_ptrs[0] + + AB::F::from_canonical_usize(EXT_DEG * 2) + + num_prod_spec + + (curr_logup_n - AB::F::ONE), + ), // curr_logup_n starts at 1. + [max_round], + start_timestamp, + &logup_row_specific.read_records[0], + ) + .eval(builder, logup_row); + + // logup_row * within_round_limit = + // logup_in_round_evaluation + logup_next_round_evaluation + builder + .when(logup_in_round_evaluation + logup_next_round_evaluation) + .assert_eq( + logup_row_specific.data_ptr, + (logup_nested_len * (curr_logup_n - AB::F::ONE) + + logup_spec_inner_inner_len * round) + * AB::F::from_canonical_usize(EXT_DEG), + ); + builder.assert_eq( + logup_row * within_round_limit * in_round, + logup_in_round_evaluation, + ); + builder.assert_eq( + logup_row * within_round_limit * not(in_round), + logup_next_round_evaluation, + ); + builder.assert_eq(logup_row * should_acc, logup_acc); + + self.memory_bridge + .read( + MemoryAddress::new(native_as, register_ptrs[3] + logup_row_specific.data_ptr), + logup_row_specific.pq, + start_timestamp + AB::F::ONE, + &logup_row_specific.read_records[1], + ) + .eval(builder, logup_row * within_round_limit); + + let p1: [_; EXT_DEG] = logup_row_specific.pq[0..EXT_DEG].try_into().unwrap(); + let p2: [_; EXT_DEG] = logup_row_specific.pq[EXT_DEG..(EXT_DEG * 2)] + .try_into() + .unwrap(); + let q1: [_; EXT_DEG] = logup_row_specific.pq[(EXT_DEG * 2)..{ EXT_DEG * 3 }] + .try_into() + .unwrap(); + let q2: [_; EXT_DEG] = logup_row_specific.pq[(EXT_DEG * 3)..(EXT_DEG * 4)] + .try_into() + .unwrap(); + + // write p_evals + self.memory_bridge + .write( + MemoryAddress::new( + native_as, + register_ptrs[4] + + (num_prod_spec + curr_logup_n) * AB::F::from_canonical_usize(EXT_DEG), + ), + logup_row_specific.p_evals, + start_timestamp + AB::F::TWO, + &logup_row_specific.write_records[0], + ) + .eval(builder, logup_row * within_round_limit); + + // write q_evals + self.memory_bridge + .write( + MemoryAddress::new( + native_as, + register_ptrs[4] + + (num_prod_spec + num_logup_spec + curr_logup_n) + * AB::F::from_canonical_usize(EXT_DEG), + ), + logup_row_specific.q_evals, + start_timestamp + AB::F::from_canonical_usize(3), + &logup_row_specific.write_records[1], + ) + .eval(builder, logup_row * within_round_limit); + + // Calculate evaluations + let next_round_p_evals = FieldExtension::add( + FieldExtension::multiply::(p1, c1), + FieldExtension::multiply::(p2, c2), + ); + let in_round_p_evals = FieldExtension::add( + FieldExtension::multiply::(p1, q2), + FieldExtension::multiply::(p2, q1), + ); + assert_array_eq::<_, _, _, { EXT_DEG }>( + &mut builder.when(logup_in_round_evaluation), + in_round_p_evals, + logup_row_specific.p_evals, + ); + assert_array_eq::<_, _, _, { EXT_DEG }>( + &mut builder.when(logup_next_round_evaluation), + next_round_p_evals, + logup_row_specific.p_evals, + ); + + let next_round_q_evals = FieldExtension::add( + FieldExtension::multiply::(q1, c1), + FieldExtension::multiply::(q2, c2), + ); + let in_round_q_evals = FieldExtension::multiply::(q1, q2); + assert_array_eq::<_, _, _, { EXT_DEG }>( + &mut builder.when(logup_in_round_evaluation), + in_round_q_evals, + logup_row_specific.q_evals, + ); + assert_array_eq::<_, _, _, { EXT_DEG }>( + &mut builder.when(logup_next_round_evaluation), + next_round_q_evals, + logup_row_specific.q_evals, + ); + + // Accumulate evaluation + let eval_rlc = FieldExtension::add( + FieldExtension::multiply::(logup_row_specific.p_evals, alpha1), + FieldExtension::multiply::(logup_row_specific.q_evals, alpha2), + ); + assert_array_eq::<_, _, _, { EXT_DEG }>( + &mut builder.when(logup_acc), + logup_row_specific.eval_rlc, + eval_rlc, + ); + + // Accumulate into global accumulator `eval_acc` + // when round < max_round - 2 + assert_array_eq::<_, _, _, { EXT_DEG }>( + &mut builder.when(next.logup_acc), + FieldExtension::add(next.eval_acc, next_logup_row_specfic.eval_rlc), + eval_acc, + ); + } +} diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs new file mode 100644 index 0000000000..17e99c442a --- /dev/null +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -0,0 +1,579 @@ +use std::borrow::BorrowMut; + +use openvm_circuit::{ + arch::{ + CustomBorrow, ExecutionError, MultiRowLayout, MultiRowMetadata, PreflightExecutor, + RecordArena, TraceFiller, VmChipWrapper, VmStateMut, + }, + system::{ + memory::{online::TracingMemory, MemoryAuxColsFactory}, + native_adapter::util::{memory_read_native, tracing_write_native_inplace}, + }, +}; +use openvm_instructions::{ + instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode, NATIVE_AS, +}; +use openvm_native_compiler::SumcheckOpcode::SUMCHECK_LAYER_EVAL; +use openvm_stark_backend::p3_field::PrimeField32; + +use crate::{ + field_extension::{FieldExtension, EXT_DEG}, + fri::elem_to_ext, + mem_fill_helper, + sumcheck::columns::{ + HeaderSpecificCols, LogupSpecificCols, NativeSumcheckCols, ProdSpecificCols, + }, + tracing_read_native_helper, +}; + +pub(crate) const CONTEXT_ARR_BASE_LEN: usize = EXT_DEG * 2; +pub(crate) const CURRENT_LAYER_MODE: u32 = 1; +pub(crate) const NEXT_LAYER_MODE: u32 = 0; + +pub(crate) fn calculate_3d_ext_idx( + inner_inner_len: u32, + inner_len: u32, + outer_idx: u32, + inner_idx: u32, + inner_inner_idx: u32, +) -> u32 { + (inner_inner_len * inner_len * outer_idx + inner_inner_len * inner_idx + inner_inner_idx) + * EXT_DEG as u32 +} + +#[derive(Debug, Clone, Default)] +pub struct NativeSumcheckMetadata { + num_rows: usize, +} + +impl MultiRowMetadata for NativeSumcheckMetadata { + #[inline(always)] + fn get_num_rows(&self) -> usize { + self.num_rows + } +} + +type NativeSumcheckRecordLayout = MultiRowLayout; + +pub struct NativeSumcheckRecordMut<'a, F>(&'a mut [NativeSumcheckCols]); + +impl<'a, F: PrimeField32> + CustomBorrow<'a, NativeSumcheckRecordMut<'a, F>, NativeSumcheckRecordLayout> for [u8] +{ + fn custom_borrow( + &'a mut self, + layout: NativeSumcheckRecordLayout, + ) -> NativeSumcheckRecordMut<'a, F> { + // SAFETY: + // - align_to_mut() ensures proper alignment for NativeSumcheckCols + // - Layout guarantees sufficient length for num_rows records + // - Slice bounds validated by taking only num_rows elements + let arr = unsafe { self.align_to_mut::>().1 }; + NativeSumcheckRecordMut(&mut arr[..layout.metadata.num_rows]) + } + + unsafe fn extract_layout(&self) -> NativeSumcheckRecordLayout { + // Each instruction record consists solely of some number of contiguously + // stored NativeSumcheckCols<...> structs, each of which corresponds to a + // single trace row. Trace fillers don't actually need to know how many rows + // each instruction uses, and can thus treat each NativePoseidon2Cols<...> + // as a single record. + NativeSumcheckRecordLayout { + metadata: NativeSumcheckMetadata { num_rows: 1 }, + } + } +} + +#[derive(derive_new::new, Copy, Clone)] +pub struct NativeSumcheckExecutor; + +#[derive(derive_new::new)] +pub struct NativeSumcheckFiller; + +pub type NativeSumcheckChip = VmChipWrapper; + +impl Default for NativeSumcheckExecutor { + fn default() -> Self { + Self::new() + } +} + +impl PreflightExecutor for NativeSumcheckExecutor +where + F: PrimeField32, + for<'buf> RA: RecordArena<'buf, NativeSumcheckRecordLayout, NativeSumcheckRecordMut<'buf, F>>, +{ + fn execute( + &self, + state: VmStateMut, + instruction: &Instruction, + ) -> Result<(), ExecutionError> { + let &Instruction { + opcode: op, + a: r_evals_reg, + b: ctx_reg, + c: challenges_reg, + d: data_address_space, + e: register_address_space, + f: prod_evals_reg, + g: logup_evals_reg, + } = instruction; + + // This opcode supports two modes of operation: + // 1. calculate the expected evaluation of two types of sumchecks for the current round + // a. product sumcheck: v' = v[0] * v[1] + // b. logup sumcheck: p'= p[0] * q[1] + p[1] * q[0] and q'= q[0] * q[1]. + // 2. calculate the expected value of next layer: + // a. product sumcheck: v[r] = eq(0,r) * v[0] + eq(1,r) * v[1] + // b. logup sumcheck: p[r] = eq(0,r) * p[0] + eq(1,r) * p[1] + // and q[r] = eq(0,r) * q[0] + eq(1,r) * q[1] + assert_eq!(op, SUMCHECK_LAYER_EVAL.global_opcode()); + assert_eq!(data_address_space.as_canonical_u32(), NATIVE_AS); + assert_eq!(register_address_space.as_canonical_u32(), NATIVE_AS); + + let [ctx_ptr]: [F; 1] = memory_read_native(state.memory.data(), ctx_reg.as_canonical_u32()); + let ctx: [u32; 8] = memory_read_native(state.memory.data(), ctx_ptr.as_canonical_u32()) + .map(|x: F| x.as_canonical_u32()); + + let [round, num_prod_spec, num_logup_spec, prod_specs_inner_len, prod_specs_inner_inner_len, logup_specs_inner_len, logup_specs_inner_inner_len, mode] = + ctx; + // allocate n rows + let num_rows = (1 + num_prod_spec + num_logup_spec) as usize; + let rows = state + .ctx + .alloc(MultiRowLayout::new(NativeSumcheckMetadata { num_rows })) + .0; + + let mut cur_timestamp = state.memory.timestamp(); + // head row + let head_row: &mut NativeSumcheckCols = &mut rows[0]; + let head_specific: &mut HeaderSpecificCols = + head_row.specific[..HeaderSpecificCols::::width()].borrow_mut(); + + head_row.header_row = F::ONE; + head_row.first_timestamp = F::from_canonical_u32(cur_timestamp); + head_row.start_timestamp = F::from_canonical_u32(cur_timestamp); + + head_specific.pc = F::from_canonical_u32(*state.pc); + + head_specific.registers[0] = ctx_reg; + head_specific.registers[1] = challenges_reg; + head_specific.registers[2] = prod_evals_reg; + head_specific.registers[3] = logup_evals_reg; + head_specific.registers[4] = r_evals_reg; + + // read pointers + let [ctx_ptr]: [F; 1] = tracing_read_native_helper( + state.memory, + ctx_reg.as_canonical_u32(), + head_specific.read_records[0].as_mut(), + ); + let [challenges_ptr]: [F; 1] = tracing_read_native_helper( + state.memory, + challenges_reg.as_canonical_u32(), + head_specific.read_records[1].as_mut(), + ); + let [prod_evals_ptr]: [F; 1] = tracing_read_native_helper( + state.memory, + prod_evals_reg.as_canonical_u32(), + head_specific.read_records[2].as_mut(), + ); + let [logup_evals_ptr]: [F; 1] = tracing_read_native_helper( + state.memory, + logup_evals_reg.as_canonical_u32(), + head_specific.read_records[3].as_mut(), + ); + let [r_evals_ptr]: [F; 1] = tracing_read_native_helper( + state.memory, + r_evals_reg.as_canonical_u32(), + head_specific.read_records[4].as_mut(), + ); + + let ctx: [F; CONTEXT_ARR_BASE_LEN] = tracing_read_native_helper( + state.memory, + ctx_ptr.as_canonical_u32(), + head_specific.read_records[5].as_mut(), + ); + + let challenges: [F; EXT_DEG * 4] = tracing_read_native_helper( + state.memory, + challenges_ptr.as_canonical_u32(), + head_specific.read_records[6].as_mut(), + ); + cur_timestamp += 7; // 5 register reads + ctx read + challenges read + head_row.challenges.copy_from_slice(&challenges); + + // challenges = [alpha, c1=r, c2=1-r] + let alpha: [F; 4] = challenges[0..EXT_DEG].try_into().unwrap(); + let c1: [F; 4] = challenges[EXT_DEG..(EXT_DEG * 2)].try_into().unwrap(); + let c2: [F; 4] = challenges[(EXT_DEG * 2)..(EXT_DEG * 3)].try_into().unwrap(); + + let mut eval_acc = elem_to_ext(F::from_canonical_u32(0)); + let mut alpha_acc = elem_to_ext(F::from_canonical_u32(1)); + + // all rows share same register values, ctx, challenges + for row in rows.iter_mut() { + // c1, c2 are same during the entire execution + row.challenges[EXT_DEG..3 * EXT_DEG].copy_from_slice(&challenges[EXT_DEG..3 * EXT_DEG]); + row.alpha = alpha; + row.ctx = ctx; + row.prod_nested_len = + F::from_canonical_u32(prod_specs_inner_len * prod_specs_inner_inner_len); + row.logup_nested_len = + F::from_canonical_u32(logup_specs_inner_len * logup_specs_inner_inner_len); + row.register_ptrs[0] = ctx_ptr; + row.register_ptrs[1] = challenges_ptr; + row.register_ptrs[2] = prod_evals_ptr; + row.register_ptrs[3] = logup_evals_ptr; + row.register_ptrs[4] = r_evals_ptr; + } + + // product rows + for (i, prod_row) in rows + .iter_mut() + .skip(1) + .take(num_prod_spec as usize) + .enumerate() + { + let prod_specific: &mut ProdSpecificCols = + prod_row.specific[..ProdSpecificCols::::width()].borrow_mut(); + + prod_row.prod_row = F::ONE; + prod_row.prod_continued = if i < (num_prod_spec - 1) as usize { + F::ONE + } else { + F::ZERO + }; + prod_row.curr_prod_n = F::from_canonical_usize(i + 1); // curr_prod_n starts from 1 + prod_row.start_timestamp = F::from_canonical_u32(cur_timestamp); + + // read max_round + let [max_round]: [F; 1] = tracing_read_native_helper( + state.memory, + ctx_ptr.as_canonical_u32() + (CONTEXT_ARR_BASE_LEN + i) as u32, + prod_specific.read_records[0].as_mut(), + ); + cur_timestamp += 1; + + prod_row.challenges[0..EXT_DEG].copy_from_slice(&alpha_acc); + prod_row.max_round = max_round; + + let max_round = max_round.as_canonical_u32(); + // round starts from 0 + if round < max_round - 1 { + prod_row.within_round_limit = F::ONE; + let start = calculate_3d_ext_idx( + prod_specs_inner_inner_len, + prod_specs_inner_len, + i as u32, + round, + 0, + ); + prod_specific.data_ptr = F::from_canonical_u32(start); + + // read p1, p2 + let ps: [F; EXT_DEG * 2] = tracing_read_native_helper( + state.memory, + prod_evals_ptr.as_canonical_u32() + start, + prod_specific.read_records[1].as_mut(), + ); + let p1: [F; EXT_DEG] = ps[0..EXT_DEG].try_into().unwrap(); + let p2: [F; EXT_DEG] = ps[EXT_DEG..(EXT_DEG * 2)].try_into().unwrap(); + + prod_specific.p = ps; + + // compute expected eval + let eval = match mode { + NEXT_LAYER_MODE => FieldExtension::add( + FieldExtension::multiply(p1, c1), + FieldExtension::multiply(p2, c2), + ), + CURRENT_LAYER_MODE => FieldExtension::multiply(p1, p2), + _ => unreachable!("mode should be {CURRENT_LAYER_MODE} or {NEXT_LAYER_MODE}"), + }; + prod_specific.p_evals = eval; + + match mode { + NEXT_LAYER_MODE => { + prod_row.prod_next_round_evaluation = F::ONE; + } + CURRENT_LAYER_MODE => { + prod_row.prod_in_round_evaluation = F::ONE; + } + _ => unreachable!("mode should be {CURRENT_LAYER_MODE} or {NEXT_LAYER_MODE}"), + } + + // write p eval + tracing_write_native_inplace( + state.memory, + r_evals_ptr.as_canonical_u32() + (1 + i as u32) * (EXT_DEG as u32), + eval, + &mut prod_specific.write_record, + ); + cur_timestamp += 2; + + let eval_rlc = FieldExtension::multiply(alpha_acc, eval); + prod_specific.eval_rlc = eval_rlc; + + if mode == NEXT_LAYER_MODE && round + 1 < max_round - 1 { + eval_acc = FieldExtension::add(eval_acc, eval_rlc); + prod_row.should_acc = F::ONE; + prod_row.prod_acc = F::ONE; + prod_row.eval_acc = eval_acc; + } + } + + alpha_acc = FieldExtension::multiply(alpha_acc, alpha); + } + + // logup rows + for (i, logup_row) in rows.iter_mut().skip(1 + num_prod_spec as usize).enumerate() { + let logup_specific: &mut LogupSpecificCols = + logup_row.specific[..LogupSpecificCols::::width()].borrow_mut(); + + logup_row.logup_row = F::ONE; + logup_row.logup_continued = if i < (num_logup_spec - 1) as usize { + F::ONE + } else { + F::ZERO + }; + logup_row.curr_logup_n = F::from_canonical_usize(i + 1); // curr_logup_n starts from 1 + logup_row.start_timestamp = F::from_canonical_u32(cur_timestamp); + + let [max_round]: [F; 1] = tracing_read_native_helper( + state.memory, + ctx_ptr.as_canonical_u32() + num_prod_spec + (CONTEXT_ARR_BASE_LEN + i) as u32, + logup_specific.read_records[0].as_mut(), + ); + logup_row.max_round = max_round; + cur_timestamp += 1; + + let alpha_numerator = alpha_acc; + let alpha_denominator = FieldExtension::multiply(alpha_acc, alpha); + logup_row.challenges[0..EXT_DEG].copy_from_slice(&alpha_acc); + logup_row.challenges[3 * EXT_DEG..(4 * EXT_DEG)].copy_from_slice(&alpha_denominator); + + let max_round = max_round.as_canonical_u32(); + if round < max_round - 1 { + logup_row.within_round_limit = F::ONE; + let start = calculate_3d_ext_idx( + logup_specs_inner_inner_len, + logup_specs_inner_len, + i as u32, + round, + 0, + ); + logup_specific.data_ptr = F::from_canonical_u32(start); + + // read p1, p2, q1, q2 + let pqs: [F; EXT_DEG * 4] = tracing_read_native_helper( + state.memory, + logup_evals_ptr.as_canonical_u32() + start, + logup_specific.read_records[1].as_mut(), + ); + let p1: [F; EXT_DEG] = pqs[0..EXT_DEG].try_into().unwrap(); + let p2: [F; EXT_DEG] = pqs[EXT_DEG..(EXT_DEG * 2)].try_into().unwrap(); + let q1: [F; EXT_DEG] = pqs[(EXT_DEG * 2)..(EXT_DEG * 3)].try_into().unwrap(); + let q2: [F; EXT_DEG] = pqs[(EXT_DEG * 3)..(EXT_DEG * 4)].try_into().unwrap(); + + logup_specific.pq = pqs; + + // compute expected evals + let p_eval = match mode { + NEXT_LAYER_MODE => FieldExtension::add( + FieldExtension::multiply(p1, c1), + FieldExtension::multiply(p2, c2), + ), + CURRENT_LAYER_MODE => FieldExtension::add( + FieldExtension::multiply(p1, q2), + FieldExtension::multiply(p2, q1), + ), + _ => unreachable!("mode should be {CURRENT_LAYER_MODE} or {NEXT_LAYER_MODE}"), + }; + let q_eval = match mode { + NEXT_LAYER_MODE => FieldExtension::add( + FieldExtension::multiply(q1, c1), + FieldExtension::multiply(q2, c2), + ), + CURRENT_LAYER_MODE => FieldExtension::multiply(q1, q2), + _ => unreachable!("mode should be {CURRENT_LAYER_MODE} or {NEXT_LAYER_MODE}"), + }; + + match mode { + NEXT_LAYER_MODE => { + logup_row.logup_next_round_evaluation = F::ONE; + } + CURRENT_LAYER_MODE => { + logup_row.logup_in_round_evaluation = F::ONE; + } + _ => unreachable!("mode should be {CURRENT_LAYER_MODE} or {NEXT_LAYER_MODE}"), + } + + logup_specific.p_evals = p_eval; + logup_specific.q_evals = q_eval; + + // write p_eval + tracing_write_native_inplace( + state.memory, + r_evals_ptr.as_canonical_u32() + + (1 + num_prod_spec + i as u32) * (EXT_DEG as u32), + p_eval, + &mut logup_specific.write_records[0], + ); + // write q_eval + tracing_write_native_inplace( + state.memory, + r_evals_ptr.as_canonical_u32() + + (1 + num_prod_spec + num_logup_spec + i as u32) * (EXT_DEG as u32), + q_eval, + &mut logup_specific.write_records[1], + ); + cur_timestamp += 3; // 1 read, 2 writes + + let eval_rlc = FieldExtension::add( + FieldExtension::multiply(alpha_numerator, p_eval), + FieldExtension::multiply(alpha_denominator, q_eval), + ); + logup_specific.eval_rlc = eval_rlc; + if mode == NEXT_LAYER_MODE && round + 1 < max_round - 1 { + eval_acc = FieldExtension::add(eval_acc, eval_rlc); + logup_row.should_acc = F::ONE; + logup_row.logup_acc = F::ONE; + logup_row.eval_acc = eval_acc; + } + } + + alpha_acc = FieldExtension::multiply(alpha_denominator, alpha); + } + + if let Some(last_row) = rows.last_mut() { + last_row.is_end = F::ONE; + } + + let head_row = &mut rows[0]; + head_row.last_timestamp = F::from_canonical_u32(cur_timestamp + 1); + + let head_specific: &mut HeaderSpecificCols = + head_row.specific[..HeaderSpecificCols::::width()].borrow_mut(); + + tracing_write_native_inplace( + state.memory, + r_evals_ptr.as_canonical_u32(), + eval_acc, + &mut head_specific.write_records, + ); + + for row in rows.iter_mut() { + if row.header_row == F::ONE { + row.eval_acc = eval_acc; + } else if row.prod_row == F::ONE { + let specific: &mut ProdSpecificCols = + row.specific[..ProdSpecificCols::::width()].borrow_mut(); + + eval_acc = FieldExtension::subtract(eval_acc, specific.eval_rlc); + row.eval_acc = eval_acc; + } else if row.logup_row == F::ONE { + let specific: &mut LogupSpecificCols = + row.specific[..LogupSpecificCols::::width()].borrow_mut(); + eval_acc = FieldExtension::subtract(eval_acc, specific.eval_rlc); + row.eval_acc = eval_acc; + } + } + assert_eq!(eval_acc, elem_to_ext(F::from_canonical_u32(0)),); + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + Ok(()) + } + + // GKR layered IOP for product and logup relations + fn get_opcode_name(&self, opcode: usize) -> String { + assert_eq!(opcode, SUMCHECK_LAYER_EVAL.global_opcode().as_usize()); + String::from("SUMCHECK_LAYER_EVAL") + } +} + +impl TraceFiller for NativeSumcheckFiller { + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let cols: &mut NativeSumcheckCols = row_slice.borrow_mut(); + let start_timestamp = cols.start_timestamp.as_canonical_u32(); + let last_timestamp = cols.last_timestamp.as_canonical_u32(); + + if cols.header_row == F::ONE { + let header: &mut HeaderSpecificCols = + cols.specific[..HeaderSpecificCols::::width()].borrow_mut(); + + for i in 0..7usize { + mem_fill_helper( + mem_helper, + start_timestamp + i as u32, + header.read_records[i].as_mut(), + ); + } + mem_fill_helper( + mem_helper, + last_timestamp - 1, + header.write_records.as_mut(), + ); + } else if cols.prod_row == F::ONE { + let prod_row_specific: &mut ProdSpecificCols = + cols.specific[..ProdSpecificCols::::width()].borrow_mut(); + + // read max_round + mem_fill_helper( + mem_helper, + start_timestamp, + prod_row_specific.read_records[0].as_mut(), + ); + if cols.within_round_limit == F::ONE { + // read p1, p2 + mem_fill_helper( + mem_helper, + start_timestamp + 1, + prod_row_specific.read_records[1].as_mut(), + ); + // write p_eval + mem_fill_helper( + mem_helper, + start_timestamp + 2, + prod_row_specific.write_record.as_mut(), + ); + } + } else if cols.logup_row == F::ONE { + let logup_row_specific: &mut LogupSpecificCols = + cols.specific[..LogupSpecificCols::::width()].borrow_mut(); + + // read max_round + mem_fill_helper( + mem_helper, + start_timestamp, + logup_row_specific.read_records[0].as_mut(), + ); + if cols.within_round_limit == F::ONE { + // read p1, p2, q1, q2 + mem_fill_helper( + mem_helper, + start_timestamp + 1, + logup_row_specific.read_records[1].as_mut(), + ); + // write p_eval + mem_fill_helper( + mem_helper, + start_timestamp + 2, + logup_row_specific.write_records[0].as_mut(), + ); + // write q_eval + mem_fill_helper( + mem_helper, + start_timestamp + 3, + logup_row_specific.write_records[1].as_mut(), + ); + } + } + } + + fn fill_dummy_trace_row(&self, row_slice: &mut [F]) { + let cols: &mut NativeSumcheckCols = row_slice.borrow_mut(); + + cols.is_end = F::ONE; + } +} diff --git a/extensions/native/circuit/src/sumcheck/columns.rs b/extensions/native/circuit/src/sumcheck/columns.rs new file mode 100644 index 0000000000..b3e6bf4f25 --- /dev/null +++ b/extensions/native/circuit/src/sumcheck/columns.rs @@ -0,0 +1,135 @@ +use openvm_circuit::system::memory::offline_checker::{MemoryReadAuxCols, MemoryWriteAuxCols}; +use openvm_circuit_primitives_derive::AlignedBorrow; + +use crate::{field_extension::EXT_DEG, utils::const_max}; + +const fn max3(a: usize, b: usize, c: usize) -> usize { + const_max(a, const_max(b, c)) +} + +#[repr(C)] +#[derive(AlignedBorrow)] +pub struct NativeSumcheckCols { + /// Indicates that this row is the header for a layer sum operation + pub header_row: T, + /// Indicates that this row is a step for prod_spec in the layer sum operation + pub prod_row: T, + /// Indicates that this row is a step for logup_spec in the layer sum operation + pub logup_row: T, + /// Indicates that this row is the end of the entire layer sum operation + pub is_end: T, + + pub prod_continued: T, + pub logup_continued: T, + + /// Indicates what type of evaluation constraints should be applied + pub prod_in_round_evaluation: T, + pub prod_next_round_evaluation: T, + pub logup_in_round_evaluation: T, + pub logup_next_round_evaluation: T, + + /// Indicates if evaluations are accumulated + pub prod_acc: T, + pub logup_acc: T, + + /// Timestamps + pub first_timestamp: T, + pub start_timestamp: T, + pub last_timestamp: T, + + // Register values + pub register_ptrs: [T; 5], + + // Context variables + // [ + // round, + // num_prod_spec, + // num_logup_spec, + // prod_spec_inner_len, + // prod_spec_inner_inner_len, + // logup_spec_inner_len, + // logup_spec_inner_inner_len, + // in_layer, + // ] + pub ctx: [T; EXT_DEG * 2], + + pub prod_nested_len: T, + pub logup_nested_len: T, + + pub curr_prod_n: T, + pub curr_logup_n: T, + + pub alpha: [T; EXT_DEG], + // alpha1, c1, c2, alpha2 (for logup rows) + pub challenges: [T; EXT_DEG * 4], + + // Specific to each row + pub max_round: T, + // Is this round within max_round + pub within_round_limit: T, + // Should the evaluation be accumualted + pub should_acc: T, + + // The current final evaluation accumulator. Extension element. + pub eval_acc: [T; EXT_DEG], + + // /// 1. For header row, 5 registers, ctx, challenges + // /// 2. For the rest: max_variables, p1, p2, q1, q2 + // pub read_records: [MemoryReadAuxCols; 7], + // /// 1. For header row, write final result + // /// 2. For prod rows: write prod_evals + // /// 3. For logup rows: write q_evals, p_evals + // pub write_records: [MemoryWriteAuxCols; 2], + pub specific: [T; max3( + HeaderSpecificCols::::width(), + ProdSpecificCols::::width(), + LogupSpecificCols::::width(), + )], +} + +#[repr(C)] +#[derive(AlignedBorrow)] +pub struct HeaderSpecificCols { + pub pc: T, + pub registers: [T; 5], + /// 5 register reads + ctx read + challenges read + pub read_records: [MemoryReadAuxCols; 7], + /// Write the final evaluation + pub write_records: MemoryWriteAuxCols, +} + +#[repr(C)] +#[derive(AlignedBorrow)] +pub struct ProdSpecificCols { + /// Pointer + pub data_ptr: T, + /// 2 extension elements + pub p: [T; EXT_DEG * 2], + /// read max varibale and 2 p values + pub read_records: [MemoryReadAuxCols; 2], + /// Calculated p evals + pub p_evals: [T; EXT_DEG], + /// write p_evals + pub write_record: MemoryWriteAuxCols, + /// p_evals * alpha^i + pub eval_rlc: [T; EXT_DEG], +} + +#[repr(C)] +#[derive(AlignedBorrow)] +pub struct LogupSpecificCols { + /// Pointer + pub data_ptr: T, + /// 4 extension elements + pub pq: [T; EXT_DEG * 4], + /// read max variable and 4 values: p1, p2, q1, q2 + pub read_records: [MemoryReadAuxCols; 2], + /// Calculated p evals + pub p_evals: [T; EXT_DEG], + /// Calculated q evals + pub q_evals: [T; EXT_DEG], + /// write both p_evals and q_evals + pub write_records: [MemoryWriteAuxCols; 2], + /// Evaluation for the accumulator + pub eval_rlc: [T; EXT_DEG], +} diff --git a/extensions/native/circuit/src/sumcheck/cuda.rs b/extensions/native/circuit/src/sumcheck/cuda.rs new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/extensions/native/circuit/src/sumcheck/cuda.rs @@ -0,0 +1 @@ + diff --git a/extensions/native/circuit/src/sumcheck/execution.rs b/extensions/native/circuit/src/sumcheck/execution.rs new file mode 100644 index 0000000000..7202e57b00 --- /dev/null +++ b/extensions/native/circuit/src/sumcheck/execution.rs @@ -0,0 +1,345 @@ +use std::{ + borrow::{Borrow, BorrowMut}, + mem::size_of, +}; + +use openvm_circuit::{arch::*, system::memory::online::GuestMemory}; +use openvm_circuit_primitives::AlignedBytesBorrow; +use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, NATIVE_AS}; +use openvm_stark_backend::p3_field::PrimeField32; + +use crate::{ + field_extension::{FieldExtension, EXT_DEG}, + fri::elem_to_ext, + sumcheck::chip::{ + calculate_3d_ext_idx, NativeSumcheckExecutor, CONTEXT_ARR_BASE_LEN, CURRENT_LAYER_MODE, + NEXT_LAYER_MODE, + }, +}; + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct NativeSumcheckPreCompute { + r_evals_reg: u32, + ctx_reg: u32, + challenges_reg: u32, + prod_evals_reg: u32, + logup_evals_reg: u32, +} + +impl NativeSumcheckExecutor { + #[inline(always)] + fn pre_compute_impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut NativeSumcheckPreCompute, + ) -> Result<(), StaticProgramError> { + let &Instruction { + a, + b, + c, + d, + e, + f, + g, + .. + } = inst; + + let r_evals_reg = a.as_canonical_u32(); + let ctx_reg = b.as_canonical_u32(); + let challenges_reg = c.as_canonical_u32(); + let prod_evals_reg = f.as_canonical_u32(); + let logup_evals_reg = g.as_canonical_u32(); + + if d.as_canonical_u32() != NATIVE_AS { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + if e.as_canonical_u32() != NATIVE_AS { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + + *data = NativeSumcheckPreCompute { + r_evals_reg, + ctx_reg, + challenges_reg, + prod_evals_reg, + logup_evals_reg, + }; + + Ok(()) + } +} + +impl Executor for NativeSumcheckExecutor +where + F: PrimeField32, +{ + #[cfg(feature = "tco")] + fn handler( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: ExecutionCtxTrait, + { + let pre_compute: &mut NativeSumcheckPreCompute = data.borrow_mut(); + + self.pre_compute_impl(pc, inst, pre_compute)?; + + let fn_ptr = execute_e1_handler; + Ok(fn_ptr) + } + + #[inline(always)] + fn pre_compute_size(&self) -> usize { + size_of::() + } + + #[cfg(not(feature = "tco"))] + #[inline(always)] + fn pre_compute( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let pre_compute: &mut NativeSumcheckPreCompute = data.borrow_mut(); + + self.pre_compute_impl(pc, inst, pre_compute)?; + + let fn_ptr = execute_e1_impl; + Ok(fn_ptr) + } +} + +impl MeteredExecutor for NativeSumcheckExecutor +where + F: PrimeField32, +{ + #[inline(always)] + fn metered_pre_compute_size(&self) -> usize { + size_of::>() + } + + #[cfg(not(feature = "tco"))] + #[inline(always)] + fn metered_pre_compute( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let pre_compute: &mut E2PreCompute = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + + self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; + + let fn_ptr = execute_e2_impl; + Ok(fn_ptr) + } + + #[cfg(feature = "tco")] + fn metered_handler( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let pre_compute: &mut E2PreCompute = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + + self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; + + let fn_ptr = execute_e2_handler; + Ok(fn_ptr) + } +} + +#[create_handler] +#[inline(always)] +unsafe fn execute_e1_impl( + pre_compute: &[u8], + instret: &mut u64, + pc: &mut u32, + _instret_end: u64, + exec_state: &mut VmExecState, +) { + let pre_compute: &NativeSumcheckPreCompute = pre_compute.borrow(); + execute_e12_impl(pre_compute, instret, pc, exec_state); +} + +#[create_handler] +#[inline(always)] +unsafe fn execute_e2_impl( + pre_compute: &[u8], + instret: &mut u64, + pc: &mut u32, + _arg: u64, + exec_state: &mut VmExecState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + let height = execute_e12_impl(&pre_compute.data, instret, pc, exec_state); + exec_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, height); +} + +#[inline(always)] +unsafe fn execute_e12_impl( + pre_compute: &NativeSumcheckPreCompute, + instret: &mut u64, + pc: &mut u32, + exec_state: &mut VmExecState, +) -> u32 { + let [r_evals_ptr]: [F; 1] = exec_state.vm_read(NATIVE_AS, pre_compute.r_evals_reg); + let [ctx_ptr]: [F; 1] = exec_state.vm_read(NATIVE_AS, pre_compute.ctx_reg); + let [challenges_ptr]: [F; 1] = exec_state.vm_read(NATIVE_AS, pre_compute.challenges_reg); + let [prod_evals_ptr]: [F; 1] = exec_state.vm_read(NATIVE_AS, pre_compute.prod_evals_reg); + let [logup_evals_ptr]: [F; 1] = exec_state.vm_read(NATIVE_AS, pre_compute.logup_evals_reg); + + let r_evals_ptr_u32 = r_evals_ptr.as_canonical_u32(); + let ctx_ptr_u32 = ctx_ptr.as_canonical_u32(); + let logup_evals_ptr = logup_evals_ptr.as_canonical_u32(); + let prod_evals_ptr = prod_evals_ptr.as_canonical_u32(); + + let ctx: [u32; 8] = exec_state + .vm_read(NATIVE_AS, ctx_ptr_u32) + .map(|x: F| x.as_canonical_u32()); + let [round, num_prod_spec, num_logup_spec, prod_specs_inner_len, prod_specs_inner_inner_len, logup_specs_inner_len, logup_specs_inner_inner_len, mode] = + ctx; + let challenges: [F; EXT_DEG * 4] = + exec_state.vm_read(NATIVE_AS, challenges_ptr.as_canonical_u32()); + let alpha: [F; EXT_DEG] = challenges[0..EXT_DEG].try_into().unwrap(); + let c1: [F; EXT_DEG] = challenges[EXT_DEG..EXT_DEG * 2].try_into().unwrap(); + let c2: [F; EXT_DEG] = challenges[EXT_DEG * 2..EXT_DEG * 3].try_into().unwrap(); + + let mut height = 1; + let mut alpha_acc = elem_to_ext(F::ONE); + let mut eval_acc = elem_to_ext(F::ZERO); + + let prod_offset = ctx_ptr_u32 + CONTEXT_ARR_BASE_LEN as u32; + for i in 0..num_prod_spec { + let [max_round]: [u32; 1] = exec_state + .vm_read(NATIVE_AS, prod_offset + i) + .map(|x: F| x.as_canonical_u32()); + + let start = calculate_3d_ext_idx( + prod_specs_inner_inner_len, + prod_specs_inner_len, + i, + round, + 0, + ); + + if round < max_round - 1 { + let ps: [F; EXT_DEG * 2] = exec_state.vm_read(NATIVE_AS, prod_evals_ptr + start); + let p1: [F; EXT_DEG] = ps[0..EXT_DEG].try_into().unwrap(); + let p2: [F; EXT_DEG] = ps[EXT_DEG..EXT_DEG * 2].try_into().unwrap(); + + let eval = match mode { + CURRENT_LAYER_MODE => FieldExtension::multiply(p1, p2), + NEXT_LAYER_MODE => FieldExtension::add( + FieldExtension::multiply(p1, c1), + FieldExtension::multiply(p2, c2), + ), + _ => unreachable!("mode can only be {CURRENT_LAYER_MODE} or {NEXT_LAYER_MODE}"), + }; + + exec_state.vm_write(NATIVE_AS, r_evals_ptr_u32 + (1 + i) * EXT_DEG as u32, &eval); + + if mode == NEXT_LAYER_MODE && round + 1 < max_round - 1 { + // update eval_acc + eval_acc = FieldExtension::add(eval_acc, FieldExtension::multiply(alpha_acc, eval)); + } + } + + // update alpha_acc + alpha_acc = FieldExtension::multiply(alpha_acc, alpha); + height += 1; + } + + let logup_offset = ctx_ptr_u32 + CONTEXT_ARR_BASE_LEN as u32 + num_prod_spec; + for i in 0..num_logup_spec { + // read max_round + let [max_round]: [u32; 1] = exec_state + .vm_read(NATIVE_AS, logup_offset + i) + .map(|x: F| x.as_canonical_u32()); + let start = calculate_3d_ext_idx( + logup_specs_inner_len, + logup_specs_inner_inner_len, + i, + round, + 0, + ); + + let alpha_denominator = FieldExtension::multiply(alpha_acc, alpha); + let alpha_numerator = alpha_acc; + + if round < max_round - 1 { + // read logup_evals + let pqs: [F; EXT_DEG * 4] = exec_state.vm_read(NATIVE_AS, logup_evals_ptr + start); + let p1: [F; EXT_DEG] = pqs[0..EXT_DEG].try_into().unwrap(); + let p2: [F; EXT_DEG] = pqs[EXT_DEG..EXT_DEG * 2].try_into().unwrap(); + let q1: [F; EXT_DEG] = pqs[EXT_DEG * 2..EXT_DEG * 3].try_into().unwrap(); + let q2: [F; EXT_DEG] = pqs[EXT_DEG * 3..EXT_DEG * 4].try_into().unwrap(); + + // compute p_eval and q_eval + let p_eval = match mode { + CURRENT_LAYER_MODE => FieldExtension::add( + FieldExtension::multiply(p1, q2), + FieldExtension::multiply(p2, q1), + ), + NEXT_LAYER_MODE => FieldExtension::add( + FieldExtension::multiply(p1, c1), + FieldExtension::multiply(p2, c2), + ), + _ => unreachable!("mode can only be {CURRENT_LAYER_MODE} or {NEXT_LAYER_MODE}"), + }; + let q_eval = match mode { + CURRENT_LAYER_MODE => FieldExtension::multiply(q1, q2), + NEXT_LAYER_MODE => FieldExtension::add( + FieldExtension::multiply(q1, c1), + FieldExtension::multiply(q2, c2), + ), + _ => unreachable!("mode can only be {CURRENT_LAYER_MODE} or {NEXT_LAYER_MODE}"), + }; + + // write eval to r_evals + exec_state.vm_write( + NATIVE_AS, + r_evals_ptr_u32 + (1 + num_prod_spec + i) * EXT_DEG as u32, + &p_eval, + ); + exec_state.vm_write( + NATIVE_AS, + r_evals_ptr_u32 + (1 + num_prod_spec + num_logup_spec + i) * EXT_DEG as u32, + &q_eval, + ); + + let eval_rlc = FieldExtension::add( + FieldExtension::multiply(alpha_numerator, p_eval), + FieldExtension::multiply(alpha_denominator, q_eval), + ); + if mode == NEXT_LAYER_MODE && round + 1 < max_round - 1 { + // update eval_acc + eval_acc = FieldExtension::add(eval_acc, eval_rlc); + } + } + + // update alpha_acc + alpha_acc = FieldExtension::multiply(alpha_denominator, alpha); + height += 1; + } + + *pc += DEFAULT_PC_STEP; + *instret += 1; + + exec_state.vm_write(NATIVE_AS, r_evals_ptr_u32, &eval_acc); + // return height delta + height +} diff --git a/extensions/native/circuit/src/sumcheck/mod.rs b/extensions/native/circuit/src/sumcheck/mod.rs new file mode 100644 index 0000000000..8c35a0a7aa --- /dev/null +++ b/extensions/native/circuit/src/sumcheck/mod.rs @@ -0,0 +1,11 @@ +pub mod air; +pub mod chip; +mod columns; + +#[cfg(feature = "cuda")] +mod cuda; +#[cfg(feature = "cuda")] +pub use cuda::*; + +// mod tests; +mod execution; diff --git a/extensions/native/circuit/src/utils.rs b/extensions/native/circuit/src/utils.rs index c38e656e74..3d05656f16 100644 --- a/extensions/native/circuit/src/utils.rs +++ b/extensions/native/circuit/src/utils.rs @@ -1,9 +1,36 @@ +use openvm_circuit::system::{ + memory::{offline_checker::MemoryBaseAuxCols, online::TracingMemory, MemoryAuxColsFactory}, + native_adapter::util::tracing_read_native, +}; +use p3_field::PrimeField32; + pub(crate) const CASTF_MAX_BITS: usize = 30; pub(crate) const fn const_max(a: usize, b: usize) -> usize { [a, b][(a < b) as usize] } +/// Fill `MemoryBaseAuxCols`, assuming that the `prev_timestamp` is already set in `base_aux`. +pub(crate) fn mem_fill_helper( + mem_helper: &MemoryAuxColsFactory, + timestamp: u32, + base_aux: &mut MemoryBaseAuxCols, +) { + let prev_ts = base_aux.prev_timestamp.as_canonical_u32(); + mem_helper.fill(prev_ts, timestamp, base_aux); +} + +pub(crate) fn tracing_read_native_helper( + memory: &mut TracingMemory, + ptr: u32, + base_aux: &mut MemoryBaseAuxCols, +) -> [F; BLOCK_SIZE] { + let mut prev_ts = 0; + let ret = tracing_read_native(memory, ptr, &mut prev_ts); + base_aux.set_prev(F::from_canonical_u32(prev_ts)); + ret +} + /// Testing framework #[cfg(any(test, feature = "test-utils"))] pub mod test_utils { diff --git a/extensions/native/compiler/src/asm/compiler.rs b/extensions/native/compiler/src/asm/compiler.rs index 42e32f3a7d..c80615ca7a 100644 --- a/extensions/native/compiler/src/asm/compiler.rs +++ b/extensions/native/compiler/src/asm/compiler.rs @@ -637,6 +637,18 @@ impl + TwoAdicField> AsmCo ); } } + DslIr::SumcheckLayerEval(input_ctx, challenges, prod_ptr, logup_ptr, r_ptr) => { + self.push( + AsmInstruction::SumcheckLayerEval( + input_ctx.fp(), + challenges.fp(), + prod_ptr.fp(), + logup_ptr.fp(), + r_ptr.fp(), + ), + debug_info, + ); + } _ => unimplemented!(), } } diff --git a/extensions/native/compiler/src/asm/instruction.rs b/extensions/native/compiler/src/asm/instruction.rs index ae0875b83a..3d498c00f4 100644 --- a/extensions/native/compiler/src/asm/instruction.rs +++ b/extensions/native/compiler/src/asm/instruction.rs @@ -171,6 +171,15 @@ pub enum AsmInstruction { CycleTrackerStart(), CycleTrackerEnd(), + + // Native opcode for calculating sumcheck layer evaluation + // SumcheckLayerEval(reg_a, reg_b, reg_c, ... , reg_f, reg_g) + // - reg_a: Output ptr for next layer's evaluations + // - reg_b: Context variables + // - reg_c: Challenge values (alpha, coeff) + // - reg_g: GKR product IOP evaluations + // - reg_f: GKR logup IOP evaluations + SumcheckLayerEval(i32, i32, i32, i32, i32), } impl> AsmInstruction { @@ -407,6 +416,13 @@ impl> AsmInstruction { AsmInstruction::RangeCheck(fp, lo_bits, hi_bits) => { write!(f, "range_check_fp ({})fp, ({}), ({})", fp, lo_bits, hi_bits) } + AsmInstruction::SumcheckLayerEval(ctx, cs, p_ptr, l_ptr, r_ptr) => { + write!( + f, + "sumcheck_layer_eval ({})fp, ({})fp, ({})fp, ({})fp, ({})fp", + ctx, cs, p_ptr, l_ptr, r_ptr + ) + } } } } diff --git a/extensions/native/compiler/src/conversion/mod.rs b/extensions/native/compiler/src/conversion/mod.rs index 82ee912703..e71608c190 100644 --- a/extensions/native/compiler/src/conversion/mod.rs +++ b/extensions/native/compiler/src/conversion/mod.rs @@ -12,7 +12,7 @@ use crate::{ asm::{AsmInstruction, AssemblyCode}, FieldArithmeticOpcode, FieldExtensionOpcode, FriOpcode, NativeBranchEqualOpcode, NativeJalOpcode, NativeLoadStore4Opcode, NativeLoadStoreOpcode, NativePhantom, - NativeRangeCheckOpcode, Poseidon2Opcode, VerifyBatchOpcode, + NativeRangeCheckOpcode, Poseidon2Opcode, SumcheckOpcode, VerifyBatchOpcode, }; #[derive(Clone, Copy, Debug, Serialize, Deserialize)] @@ -535,7 +535,19 @@ fn convert_instruction>( // Here it just requires a 0 AS::Immediate, )] - } + }, + AsmInstruction::SumcheckLayerEval(ctx, cs, p_ptr, l_ptr, r_ptr) => vec![ + Instruction { + opcode: options.opcode_with_offset(SumcheckOpcode::SUMCHECK_LAYER_EVAL), + a: i32_f(r_ptr), + b: i32_f(ctx), + c: i32_f(cs), + d: AS::Native.to_field(), + e: AS::Native.to_field(), + f: i32_f(p_ptr), + g: i32_f(l_ptr), + } + ], }; let debug_infos = vec![debug_info; instructions.len()]; diff --git a/extensions/native/compiler/src/ir/instructions.rs b/extensions/native/compiler/src/ir/instructions.rs index 3b30a45ad6..78347283d5 100644 --- a/extensions/native/compiler/src/ir/instructions.rs +++ b/extensions/native/compiler/src/ir/instructions.rs @@ -320,6 +320,31 @@ pub enum DslIr { CycleTrackerStart(String), /// End the cycle tracker used by a block of code annotated by the string input. CycleTrackerEnd(String), + + /// Native operation for calculating a sumcheck layer's evaluation + /// This op supports two modes: + /// 1. for computing expected evaluation for current layer, output = [ \sum_i alpha^i * + /// prod[i][0] * prod[i][1] + \sum_j alpha^(2j) * (logup_q[i][0] * logup_q[i][1] + alpha* + /// logup_p[i][0] * logup_q[i][1] + alpha * logup_p[i][1] * logup_q[i][0] ]; + /// + /// 2. for computing expected evaluation of next layer, output[1+i] = eq(0,r)*p[i][0] + eq(1,r) + /// * p[i][1]. + SumcheckLayerEval( + Ptr, // Context variables: + // 0. round, + // 1. number of product + // 2. number of logup + // 3. (3D array description) prod_specs_eval inner length + // 4. (3D array description) prod_specs_eval inner_inner length + // 5. (3D array description) logup_spec_eval inner length + // 6. (3D array description) logup_spec_eval inner length + // 7. Operational mode indicator + // 8+. usize-type variables indicating maximum rounds + Ptr, // Challenges: alpha, coeffs + Ptr, // prod_specs_eval + Ptr, // logup_specs_eval + Ptr, // output + ), } impl Default for DslIr { diff --git a/extensions/native/compiler/src/ir/mod.rs b/extensions/native/compiler/src/ir/mod.rs index 47e901cd3a..f708318c34 100644 --- a/extensions/native/compiler/src/ir/mod.rs +++ b/extensions/native/compiler/src/ir/mod.rs @@ -18,6 +18,7 @@ mod instructions; mod poseidon; mod ptr; mod select; +mod sumcheck; mod symbolic; mod types; mod utils; diff --git a/extensions/native/compiler/src/ir/sumcheck.rs b/extensions/native/compiler/src/ir/sumcheck.rs new file mode 100644 index 0000000000..0237fd6740 --- /dev/null +++ b/extensions/native/compiler/src/ir/sumcheck.rs @@ -0,0 +1,48 @@ +use super::{Array, Builder, Config, DslIr, Ext, Usize}; + +impl Builder { + /// Extends native VM ability to calculate the evaluation for a sumcheck layer + /// This opcode supports two modes (indicated by a context variable): + /// 1. calculate the expected evaluation of two types of sumchecks (prod, logup) + /// 2. calculate the expected value of next layer p[r] = eq(0,r)*p[0] + eq(1,r)*p[1] + /// + /// Context variables + /// + /// 0: round, + /// 1: number of product + /// 2. number of logup + /// 3. (3D array description) prod_specs_eval inner length + /// 4. (3D array description) prod_specs_eval inner_inner length + /// 5. (3D array description) logup_spec_eval inner length + /// 6. (3D array description) logup_spec_eval inner length + /// 7. Operational mode indicator + /// 8+ Additional usize-type variables indicating maximum rounds + /// + /// Output + /// + /// 1. for computing expected evaluation, output = [ \sum_i alpha^i * prod[i][0] * prod[i][1] + + /// \sum_j alpha^(2j) * (logup_q[i][0] * logup_q[i][1] + alpha* logup_p[i][0] * logup_q[i][1] + /// + alpha * logup_p[i][1] * logup_q[i][0] ]; + /// + /// 2. for computing expected eval of next layer, output[1+i] = eq(0,r)*p[i][0] + eq(1,r) * + /// p[i][1]. + pub fn sumcheck_layer_eval( + &mut self, + input_ctx: &Array>, // Context variables + challenges: &Array>, // Challenges + prod_specs_eval: &Array>, /* GKR product IOP evaluations. Flattened + * from 3D array. */ + logup_specs_eval: &Array>, /* GKR logup IOP evaluations. Flattened + * from 3D array. */ + r_evals: &Array>, /* Next layer's evaluations (pointer used for + * storing opcode output) */ + ) { + self.operations.push(DslIr::SumcheckLayerEval( + input_ctx.ptr(), + challenges.ptr(), + prod_specs_eval.ptr(), + logup_specs_eval.ptr(), + r_evals.ptr(), + )); + } +} diff --git a/extensions/native/compiler/src/lib.rs b/extensions/native/compiler/src/lib.rs index 66c786fbd9..efb45b0159 100644 --- a/extensions/native/compiler/src/lib.rs +++ b/extensions/native/compiler/src/lib.rs @@ -212,3 +212,18 @@ pub enum VerifyBatchOpcode { /// per column polynomial, per opening point VERIFY_BATCH, } + +/// Opcodes for sumcheck. +#[derive( + Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, EnumCount, EnumIter, FromRepr, LocalOpcode, +)] +#[opcode_offset = 0x180] +#[repr(usize)] +#[allow(non_camel_case_types)] +pub enum SumcheckOpcode { + /// Compute the expected evaluation for each layer in the tower structure that GKR product IOP + /// and logup IOP uses Supports two modes of operation: + /// 1. Calculate current layer's expected evaluation + /// 2. Calculate next layer's evaluation + SUMCHECK_LAYER_EVAL, +} diff --git a/extensions/native/recursion/tests/sumcheck.rs b/extensions/native/recursion/tests/sumcheck.rs new file mode 100644 index 0000000000..a4039028bc --- /dev/null +++ b/extensions/native/recursion/tests/sumcheck.rs @@ -0,0 +1,431 @@ +use openvm_circuit::arch::instructions::program::Program; +#[cfg(not(feature = "cuda"))] +use openvm_circuit::utils::air_test_impl; +use openvm_native_circuit::{NativeBuilder, NativeConfig, EXT_DEG}; +use openvm_native_compiler::{ + asm::{AsmBuilder, AsmCompiler}, + conversion::{convert_program, CompilerOptions}, + ir::{Ext, Usize}, + prelude::*, +}; +use openvm_stark_backend::p3_field::{ + extension::BinomialExtensionField, FieldAlgebra, FieldExtensionAlgebra, +}; +use openvm_stark_sdk::{ + config::{ + baby_bear_poseidon2::BabyBearPoseidon2Engine, + fri_params::standard_fri_params_with_100_bits_conjectured_security, FriParameters, + }, + p3_baby_bear::BabyBear, +}; + +pub type F = BabyBear; +pub type E = BinomialExtensionField; + +#[test] +fn test_sumcheck_layer_eval() { + let mut builder = AsmBuilder::>::default(); + + build_test_program(&mut builder); + + let compilation_options = CompilerOptions::default().with_cycle_tracker(); + let mut compiler = AsmCompiler::new(compilation_options.word_size); + compiler.build(builder.operations); + let asm_code = compiler.code(); + + let program: Program<_> = convert_program(asm_code, compilation_options); + let sumcheck_max_constraint_degree = 3; + let fri_params = if matches!(std::env::var("OPENVM_FAST_TEST"), Ok(x) if &x == "1") { + FriParameters { + // max constraint degree = 2^log_blowup + 1 + log_blowup: 1, + log_final_poly_len: 0, + num_queries: 2, + proof_of_work_bits: 0, + } + } else { + standard_fri_params_with_100_bits_conjectured_security(1) + }; + + let mut config = NativeConfig::aggregation(0, sumcheck_max_constraint_degree); + config.system.memory_config.max_access_adapter_n = 16; + + let vb = NativeBuilder::default(); + #[cfg(not(feature = "cuda"))] + air_test_impl::(fri_params, vb, config, program, vec![], 1, true) + .unwrap(); + #[cfg(feature = "cuda")] + { + air_test_impl::( + fri_params, + vb, + config, + program, + vec![], + 1, + true, + ) + .unwrap(); + } +} + +fn build_test_program(builder: &mut Builder) { + let ctx_u32s = [3u32, 6, 5, 8, 2, 8, 4, 0, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9]; + let ctx: Array> = builder.dyn_array(ctx_u32s.len()); + for (idx, n) in ctx_u32s.into_iter().enumerate() { + builder.set(&ctx, idx, Usize::from(n as usize)); + } + + #[rustfmt::skip] + let challenges_u32s = [ + 548478283u32, 456436544, 1716290291, 791326976, + 1829717553, 1422025771, 1917123958, 727015942, + 183548369, 591240150, 96141963, 1286249979, + 0, 0, 0, 0, + ]; + let challenges: Array> = builder.dyn_array(challenges_u32s.len() / EXT_DEG); + for (idx, n) in challenges_u32s.chunks(EXT_DEG).enumerate() { + let e: Ext = builder.constant(C::EF::from_base_slice(&[ + C::F::from_canonical_u32(n[0]), + C::F::from_canonical_u32(n[1]), + C::F::from_canonical_u32(n[2]), + C::F::from_canonical_u32(n[3]), + ])); + + builder.set(&challenges, idx, e); + } + + #[rustfmt::skip] + let prod_spec_eval_u32s = [ + 1538906710u32, 637535518, 1753132406, 1395236651, + 278806441, 1722910382, 1475548665, 1117874675, + 1578586709, 1826764884, 384068476, 1852240363, + 707958906, 1960944944, 183554399, 1259273357, + 227285124, 243066436, 1718037317, 369721963, + 1752968006, 1061013677, 775617499, 1464907431, + 544300429, 871461966, 135151545, 1343592602, + 1622220528, 643966158, 3932580, 434948358, + 540553922, 1446502052, 153298741, 1191216273, + 265936762, 1463035257, 1237633339, 1797346310, + 1355791584, 389527741, 1741650463, 1728913415, + 1825739540, 1790924136, 460776743, 29536554, + 6842036, 252495270, 1968285155, 299467416, + 49085744, 1499815729, 1098802236, 644489275, + 1827273105, 1888401527, 390077051, 565528894, + 1366177188, 67441791, 958486301, 402056716, + 590379691, 462035406, 633459131, 843304872, + 584100013, 1932496508, 250656031, 146983915, + 1835173157, 939973454, 1844873638, 1916054832, + 1601784696, 167251717, 409107688, 1062925788, + 1291319514, 1790529531, 495655592, 1093359708, + 790197205, 674458164, 195988318, 399764452, + 106865258, 967050329, 350035523, 1109292118, + 1815460301, 281986036, 900636603, 1121197008, + 1228976590, 1879998708, 1924332706, 434695844, + 1159360621, 471397106, 473371067, 1009065094, + 1320176846, 168020789, 1265321929, 1901808675, + 223657700, 1480150183, 1779968584, 144416591, + 304407746, 1864498679, 1482460119, 1554376965, + 1479261548, 1657723043, 1039345063, 1053923521, + 442080513, 1964082352, 691664908, 1941008321, + 1007729002, 860529393, 849697342, 754485488, + 584295923, 1072251466, 1105105254, 996079746, + 1305909868, 1348028973, 122275988, 464050036, + 692807777, 1098809324, 397235220, 596459886, + 1663209783, 720230826, 1422510715, 1760654694, + 544197700, 1417744567, 1938716517, 1571826328, + 1591430185, 1173137446, 175285007, 1541718596, + 1715958587, 1429966110, 583013357, 1667787861, + 109891172, 668253167, 161783842, 296183397, + 1681897325, 1054396117, 264741948, 464026995, + 1907686022, 1532786783, 394869458, 1766734740, + 136047179, 536856195, 376188855, 700633625, + 515518419, 531043483, 60673499, 556496527, + 1743028981, 873954569, 1371062291, 632169731, + 1353239206, 526507035, 1894490088, 589441599, + 1610487168, 1074160583, 366366374, 247602990, + 1535354896, 894493713, 1555870413, 1389854934, + 1897251683, 1525812801, 675621735, 697919636, + 1690274072, 1466810921, 1221110784, 1741995587, + 1877169764, 390876982, 1794129810, 297662156, + 144295349, 417037264, 1290835727, 1654971513, + 1674131303, 1625667423, 1471248832, 1676797844, + 1172916558, 1707775403, 423725211, 1643279661, + 1695774264, 378140395, 1517569394, 1666625392, + 1803981250, 439036260, 247966130, 709534816, + 361144100, 1546096548, 1240886454, 1898161518, + 843262057, 1709259464, 1301015977, 1997626928, + 677153173, 1606710353, 1216038070, 435565562, + 98686333, 1773787396, 267051994, 99395396, + 545509105, 782289675, 1289865975, 1707775075, + 1158993015, 1506576588, 993215179, 1523099397, + 923914455, 1895162386, 284489994, 1444139016, + 1943825680, 466202724, 1632522710, 1384015062, + 723147188, 1284031324, 1430481515, 341213007, + 171192499, 1061688239, 808927167, 83182639, + 759209907, 1728321272, 976049976, 1652071995, + 1002877840, 69880246, 1095135165, 677588420, + 1384715290, 829619452, 170122781, 1958173727, + 13389238, 789379698, 1883383039, 1279195174, + 1618672336, 1192839317, 1348311124, 758896285, + 1939775389, 684108413, 1838340479, 1332232130, + 1070486028, 549228790, 868851698, 1678207843, + 1754321489, 637000403, 647901906, 45343322, + 1768524074, 1167955205, 1816497210, 1609414096, + 1985231742, 1540534482, 232730819, 232221968, + 1509637836, 1480860627, 884647789, 1096458024, + 163721583, 1248032262, 436419506, 1737102298, + 651105860, 452298073, 1064372507, 1792838683, + 619243471, 860127631, 721724708, 950768433, + 279913448, 339693210, 47730422, 1952683911, + 1316500770, 675944216, 386902809, 619333956, + 1194800389, 43989936, 1944372656, 666045666, + 1155873844, 522696968, 58874730, 1497238023, + 421619994, 1980672127, 1657191856, 1913792631, + 1784663131, 1118400672, 1828104993, 1637808383, + 414755472, 775410449, 747132157, 136820101, + 1082674285, 93190395, 357955402, 335652723, + 1192102705, 480365232, 1354935730, 1391829361, + 966662991, 1601510445, 569528575, 545490940, + 1753711688, 807025222, 580374183, 587718008, + 977546290, 1055719519, 1157107032, 562799608, + 859466927, 840450024, 815325134, 936576801, + 1010587056, 246624382, 1808049797, 1098183398, + 1005077390, 772432546, 1976629565, 1003772218, + 1655315418, 1767931114, 982008720, 785023351, + ]; + + let prod_spec_evals: Array> = + builder.dyn_array(prod_spec_eval_u32s.len() / EXT_DEG); + for (idx, n) in prod_spec_eval_u32s.chunks(EXT_DEG).enumerate() { + let e: Ext = builder.constant(C::EF::from_base_slice(&[ + C::F::from_canonical_u32(n[0]), + C::F::from_canonical_u32(n[1]), + C::F::from_canonical_u32(n[2]), + C::F::from_canonical_u32(n[3]), + ])); + + builder.set(&prod_spec_evals, idx, e); + } + + #[rustfmt::skip] + let logup_spec_eval_u32s = [ + 1522353967u32, 457603397, 421847521, 1352563318, + 1746817766, 737872688, 1087008622, 1850835028, + 456475558, 892966330, 638163666, 148568548, + 678863061, 1334386850, 1896333039, 154585769, + 433618446, 1186936470, 970218722, 1213827097, + 1798557019, 861757965, 119285527, 395360622, + 226164366, 1330279872, 66561048, 785421608, + 1950755756, 1559889596, 348449876, 1090789452, + 257578851, 273164442, 1644906, 295600924, + 1187949602, 1168249609, 469763604, 60929061, + 291163036, 403842501, 1421902433, 1700188477, + 1046093370, 921059131, 1638991894, 464012042, + 96905857, 1370999592, 271896041, 13595534, + 1489760970, 1650552701, 133367846, 25680377, + 377631580, 652729291, 645763356, 426747355, + 482475486, 1877299223, 103226636, 1333832358, + 1399609097, 458536972, 976248802, 1109365280, + 515164588, 1579426417, 1601829549, 607169702, + 852817956, 1980537127, 134138338, 913344050, + 737880920, 476360275, 61624034, 1610624252, + 264461991, 546933535, 937769429, 293346965, + 1522058041, 1012551797, 994330314, 23333322, + 1969510890, 974351570, 2012030621, 120742000, + 450250620, 180547360, 642746933, 1815029950, + 629489142, 1176992624, 723354779, 572648755, + 1218615348, 648847054, 351903235, 723149764, + 248065753, 243829448, 1283393001, 1912627886, + 581641342, 702465306, 205969758, 1061911274, + 1, 0, 0, 0, + 1, 0, 0, 0, + 1703043252, 1467887451, 1714319214, 907866644, + 1542426838, 742609036, 1814393459, 448706641, + 1960340767, 46490834, 186512520, 363973095, + 846448854, 463742343, 2012517527, 40473617, + 9472552, 263483342, 105738598, 586389136, + 254290990, 625150844, 960233097, 1488303724, + 1700231692, 1471714612, 1540211186, 1590246915, + 945341972, 1343225515, 179976237, 34857822, + 276912528, 984309272, 1277293398, 1520924162, + 1823117694, 604836357, 1460812009, 600052559, + 970469338, 1771022707, 181855831, 1445947220, + 467514809, 1514677498, 947030389, 170390653, + 415409007, 1601463730, 204153427, 904614278, + 1855419512, 2009471607, 1352607379, 576586082, + 1343812879, 1176377580, 1166188815, 1592289048, + 761793881, 1529621462, 193034837, 344011596, + 1669461833, 1356800025, 314186361, 586497329, + 1832810846, 1288092861, 1619454491, 732529408, + 737934269, 909504928, 769680420, 1437893101, + 1727002258, 1618231110, 535125583, 153412473, + 1917760929, 588586507, 564531165, 1790797737, + 1666283994, 1366948884, 117673690, 476470378, + 2012274032, 1951406668, 1739767532, 1273142151, + 1591812317, 1900205312, 1912608761, 1734766024, + 1265002082, 1450462894, 749810837, 1329222552, + 745081805, 1231519431, 1420957967, 883846107, + 1995463911, 407795592, 161655852, 125886157, + 995318920, 484905024, 284135318, 551493419, + 406742309, 1089024446, 637339867, 1858138403, + 1230680117, 187078889, 1929517480, 1125646261, + 1, 0, 0, 0, + 1, 0, 0, 0, + 1610035932, 462442436, 831412555, 44798862, + 1748147276, 1911945531, 1329343740, 971894393, + 362147969, 1583335926, 1528700112, 426908674, + 847905883, 447889090, 1050883911, 1883537469, + 1487501632, 964178870, 1818828551, 1980840799, + 340372118, 1697179193, 215113037, 1893217470, + 1138628493, 1788052486, 443362955, 1349213730, + 589553425, 562526667, 1006040406, 1194546769, + 1831034644, 612004157, 730213913, 1068905440, + 371983982, 502900790, 802785198, 822377635, + 1477528437, 501356237, 684668525, 1306043781, + 621032592, 1971342708, 1411586583, 733418745, + 186045462, 1559301855, 323758310, 453170140, + 498381240, 976247416, 631213663, 898017829, + 501459603, 609703046, 1379288251, 177682695, + 912381595, 121915494, 1137416430, 504054388, + 1138277238, 1603388253, 1838013301, 1700271853, + 20488607, 58775264, 217974275, 979141729, + 53136584, 1331566240, 1460303356, 525812787, + 718385521, 1477919263, 1663622276, 1089788203, + 1204483837, 54225863, 290660186, 1441441958, + 134168813, 349638823, 1867912015, 1579183319, + 55528656, 1602973359, 194297109, 949763297, + 101931919, 242300116, 1610052257, 1351823848, + 174522860, 776955925, 1706962365, 808187490, + 1487253852, 431806906, 213982593, 1170647308, + 1776840400, 295916317, 378708073, 381270341, + 457494568, 705823997, 1407301442, 1693003013, + 700310785, 1349874247, 1284363817, 1566253815, + 1014298154, 215294365, 1070968678, 871641358, + 1, 0, 0, 0, + 1, 0, 0, 0, + 1302679751, 1121894357, 368587356, 1564724097, + 733815591, 2012670011, 1146780092, 1439780227, + 1801628424, 838692317, 932318853, 213634365, + 155292454, 1644317110, 1599846194, 978829059, + 1282095862, 1780431647, 527412087, 1024583705, + 804423802, 951808322, 689345230, 180304167, + 1784562773, 1514653374, 2009396440, 1143778943, + 235299446, 1553017484, 475425117, 758292254, + 716575432, 517083432, 1728864125, 418010549, + 43202592, 507659742, 433077118, 1268144019, + 1462778342, 1928073362, 1330130180, 1749624351, + 827401013, 1236194147, 1875519726, 1437946791, + 607293265, 309229599, 1009445595, 1725229718, + 1436309341, 1952606463, 943149111, 291680468, + 1989684076, 1944713370, 1285294139, 399758737, + 1572979232, 213817406, 214840530, 184898060, + 1483844295, 1536616777, 494816009, 217625163, + 529448032, 786640964, 1766471731, 1424140424, + 1721961711, 740275169, 169908711, 913969302, + 1359358267, 1328322971, 593228769, 771095186, + 801680440, 450930656, 1796349530, 1824428677, + 1111258504, 1741666629, 1098430204, 1792001884, + 1679003061, 590088446, 647614538, 1324461639, + 818996796, 229187928, 74288115, 1158900266, + 1512606270, 1381672753, 785927403, 493453164, + 425259497, 1367873539, 931023744, 221202218, + 669580668, 424996238, 1840425275, 1873362670, + 967642716, 263556335, 578560519, 1558449223, + 607579284, 1724012378, 333582342, 1195784167, + 1419727276, 199294290, 138807165, 1061030752, + 1, 0, 0, 0, + 1, 0, 0, 0, + 776332180, 1333076185, 1855163818, 1897408938, + 799274251, 950452503, 691904988, 1205387466, + 659107883, 434394982, 129587940, 639018629, + 659238594, 1957584892, 864291238, 589178070, + 1267157231, 48925338, 200093884, 1953762869, + 1227617341, 1471420621, 193077633, 1007876111, + 228491220, 1377349503, 1889411060, 1807513892, + 1593042934, 1240864695, 1472870721, 583021932, + 598239104, 1862008818, 1811242869, 780768026, + 520870395, 292016292, 322246659, 868240490, + 1715620331, 1183509209, 2010262726, 1003957251, + 264895455, 307755941, 201990485, 1662471178, + 1643997923, 1573129362, 277821143, 388834470, + 943361405, 1449402196, 614413575, 1504113993, + 1860552739, 1755127315, 1734129760, 1232115188, + 803035456, 360488092, 271342171, 1269544258, + 290642673, 660703582, 986842267, 870891877, + 454573044, 1999346236, 701614601, 820253867, + 883282765, 137247873, 1727164949, 1320585493, + 1738664600, 1900116905, 472215154, 1114994489, + 104218174, 1694603079, 771486383, 935361143, + 92277671, 881040480, 925124484, 1464396527, + 100625197, 65290355, 1001454341, 134627585, + 58629702, 1541542242, 568583607, 1706262052, + 530687550, 1303187245, 1010302462, 264001857, + 789816678, 561378226, 827432508, 801307507, + 1613508315, 1650822853, 1603502703, 439320335, + 15283580, 1244486577, 254345266, 1745653280, + 1648250354, 1528271018, 528366563, 1078707735, + 1430767759, 1890467731, 2001894083, 799949326, + 1, 0, 0, 0, + 1, 0, 0, 0, + 1341839494, 1092219735, 755644898, 966729319, + 1914277278, 1545367697, 1765189119, 1693413008, + ]; + + let logup_spec_evals: Array> = + builder.dyn_array(logup_spec_eval_u32s.len() / EXT_DEG); + for (idx, n) in logup_spec_eval_u32s.chunks(EXT_DEG).enumerate() { + let e: Ext = builder.constant(C::EF::from_base_slice(&[ + C::F::from_canonical_u32(n[0]), + C::F::from_canonical_u32(n[1]), + C::F::from_canonical_u32(n[2]), + C::F::from_canonical_u32(n[3]), + ])); + + builder.set(&logup_spec_evals, idx, e); + } + + #[rustfmt::skip] + let r_evals_u32s = [ + 941378355u32, 1078920879, 696738840, 496039492, + 1555445457, 184545404, 905938226, 1847966044, + 1024875886, 1782716223, 1625644635, 266865456, + 465953066, 1663531470, 757423849, 1957075986, + 1919693393, 839104130, 127480221, 1527842912, + 918650796, 921462354, 575456073, 696646705, + 1585912361, 258186488, 353168830, 1111094691, + 1401166558, 1905942163, 1923083163, 393037255, + 1042127700, 1126793296, 895794165, 1124924482, + 1324266058, 722406365, 1963838171, 968504459, + 1934378800, 714588691, 6465911, 1168379648, + 903786009, 1326035939, 518289228, 418998914, + 1513133474, 1578096058, 617547414, 1658315126, + 68556894, 1697802593, 1346510664, 1709381671, + 345062962, 1254089535, 1002281845, 1882822096, + 700581748, 1431345304, 489112954, 98435728, + 1799886007, 479788390, 223111065, 631662309, + ]; + + let next_layer_evals: Array> = + builder.dyn_array(r_evals_u32s.len() / EXT_DEG); + for (idx, n) in r_evals_u32s.chunks(EXT_DEG).enumerate() { + let e: Ext = builder.constant(C::EF::from_base_slice(&[ + C::F::from_canonical_u32(n[0]), + C::F::from_canonical_u32(n[1]), + C::F::from_canonical_u32(n[2]), + C::F::from_canonical_u32(n[3]), + ])); + + builder.set(&next_layer_evals, idx, e); + } + + builder.sumcheck_layer_eval( + &ctx, + &challenges, + &prod_spec_evals, + &logup_spec_evals, + &next_layer_evals, + ); + + builder.halt(); +} From 30e9eff582a1be608a08bbd9546b708958cd5399 Mon Sep 17 00:00:00 2001 From: xkx Date: Sun, 30 Nov 2025 22:54:56 +0800 Subject: [PATCH 11/12] tracegen (#18) --- .../circuit/cuda/include/native/sumcheck.cuh | 85 ++++++++++++ .../circuit/cuda/include/native/utils.cuh | 13 ++ .../native/circuit/cuda/src/poseidon2.cu | 10 +- .../native/circuit/cuda/src/sumcheck.cu | 126 ++++++++++++++++++ extensions/native/circuit/src/cuda_abi.rs | 38 ++++++ .../native/circuit/src/extension/cuda.rs | 5 + .../native/circuit/src/sumcheck/chip.rs | 14 +- .../native/circuit/src/sumcheck/cuda.rs | 56 ++++++++ extensions/native/recursion/tests/sumcheck.rs | 13 +- 9 files changed, 342 insertions(+), 18 deletions(-) create mode 100644 extensions/native/circuit/cuda/include/native/sumcheck.cuh create mode 100644 extensions/native/circuit/cuda/include/native/utils.cuh create mode 100644 extensions/native/circuit/cuda/src/sumcheck.cu diff --git a/extensions/native/circuit/cuda/include/native/sumcheck.cuh b/extensions/native/circuit/cuda/include/native/sumcheck.cuh new file mode 100644 index 0000000000..052dc03fd5 --- /dev/null +++ b/extensions/native/circuit/cuda/include/native/sumcheck.cuh @@ -0,0 +1,85 @@ +#pragma once + +#include "primitives/constants.h" +#include "system/memory/offline_checker.cuh" + +using namespace native; + +template struct HeaderSpecificCols { + T pc; + T registers[5]; + MemoryReadAuxCols read_records[7]; + MemoryWriteAuxCols write_records; +}; + +template struct ProdSpecificCols { + T data_ptr; + T p[EXT_DEG * 2]; + MemoryReadAuxCols read_records[2]; + T p_evals[EXT_DEG]; + MemoryWriteAuxCols write_record; + T eval_rlc[EXT_DEG]; +}; + +template struct LogupSpecificCols { + T data_ptr; + T pq[EXT_DEG * 4]; + MemoryReadAuxCols read_records[2]; + T p_evals[EXT_DEG]; + T q_evals[EXT_DEG]; + MemoryWriteAuxCols write_records[2]; + T eval_rlc[EXT_DEG]; +}; + +template constexpr T constexpr_max(T a, T b) { + return a > b ? a : b; +} + +constexpr size_t COL_SPECIFIC_WIDTH = constexpr_max( + sizeof(HeaderSpecificCols), + constexpr_max(sizeof(ProdSpecificCols), sizeof(LogupSpecificCols)) +); + +template struct NativeSumcheckCols { + T header_row; + T prod_row; + T logup_row; + T is_end; + + T prod_continued; + T logup_continued; + + T prod_in_round_evaluation; + T prod_next_round_evaluation; + T logup_in_round_evaluation; + T logup_next_round_evaluation; + + T prod_acc; + T logup_acc; + + T first_timestamp; + T start_timestamp; + T last_timestamp; + + T register_ptrs[5]; + + T ctx[EXT_DEG * 2]; + + T prod_nested_len; + T logup_nested_len; + + T curr_prod_n; + T curr_logup_n; + + T alpha[EXT_DEG]; + T challenges[EXT_DEG * 4]; + + T max_round; + T within_round_limit; + T should_acc; + + T eval_acc[EXT_DEG]; + + T specific[COL_SPECIFIC_WIDTH]; +}; + diff --git a/extensions/native/circuit/cuda/include/native/utils.cuh b/extensions/native/circuit/cuda/include/native/utils.cuh new file mode 100644 index 0000000000..f217350959 --- /dev/null +++ b/extensions/native/circuit/cuda/include/native/utils.cuh @@ -0,0 +1,13 @@ +#pragma once + +#include "primitives/trace_access.h" +#include "system/memory/controller.cuh" + +__device__ __forceinline__ void mem_fill_base( + MemoryAuxColsFactory &mem_helper, + uint32_t timestamp, + RowSlice base_aux +) { + uint32_t prev = base_aux[COL_INDEX(MemoryBaseAuxCols, prev_timestamp)].asUInt32(); + mem_helper.fill(base_aux, prev, timestamp); +} diff --git a/extensions/native/circuit/cuda/src/poseidon2.cu b/extensions/native/circuit/cuda/src/poseidon2.cu index 59c65626ab..fdbe0d3ce5 100644 --- a/extensions/native/circuit/cuda/src/poseidon2.cu +++ b/extensions/native/circuit/cuda/src/poseidon2.cu @@ -2,6 +2,7 @@ #include "poseidon2-air/columns.cuh" #include "poseidon2-air/params.cuh" #include "poseidon2-air/tracegen.cuh" +#include "native/utils.cuh" #include "primitives/trace_access.h" #include "system/memory/controller.cuh" @@ -38,15 +39,6 @@ template struct NativePoseidon2Cols { T specific[COL_SPECIFIC_WIDTH]; }; -__device__ void mem_fill_base( - MemoryAuxColsFactory &mem_helper, - uint32_t timestamp, - RowSlice base_aux -) { - uint32_t prev = base_aux[COL_INDEX(MemoryBaseAuxCols, prev_timestamp)].asUInt32(); - mem_helper.fill(base_aux, prev, timestamp); -} - template struct Poseidon2Wrapper { template using Cols = NativePoseidon2Cols; using Poseidon2Row = diff --git a/extensions/native/circuit/cuda/src/sumcheck.cu b/extensions/native/circuit/cuda/src/sumcheck.cu new file mode 100644 index 0000000000..99c365135b --- /dev/null +++ b/extensions/native/circuit/cuda/src/sumcheck.cu @@ -0,0 +1,126 @@ +#include "launcher.cuh" +#include "native/sumcheck.cuh" +#include "native/utils.cuh" +#include "primitives/trace_access.h" +#include "system/memory/controller.cuh" + +using namespace native; + +__device__ void fill_sumcheck_specific(RowSlice row, MemoryAuxColsFactory &mem_helper) { + RowSlice specific = row.slice_from(COL_INDEX(NativeSumcheckCols, specific)); + uint32_t start_timestamp = row[COL_INDEX(NativeSumcheckCols, start_timestamp)].asUInt32(); + + if (row[COL_INDEX(NativeSumcheckCols, header_row)] == Fp::one()) { + for (uint32_t i = 0; i < 7; ++i) { + mem_fill_base( + mem_helper, + start_timestamp + i, + specific.slice_from(COL_INDEX(HeaderSpecificCols, read_records[i].base)) + ); + } + uint32_t last_timestamp = row[COL_INDEX(NativeSumcheckCols, last_timestamp)].asUInt32(); + mem_fill_base( + mem_helper, + last_timestamp - 1, + specific.slice_from(COL_INDEX(HeaderSpecificCols, write_records.base)) + ); + } else if (row[COL_INDEX(NativeSumcheckCols, prod_row)] == Fp::one()) { + mem_fill_base( + mem_helper, + start_timestamp, + specific.slice_from(COL_INDEX(ProdSpecificCols, read_records[0].base)) + ); + if (row[COL_INDEX(NativeSumcheckCols, within_round_limit)] == Fp::one()) { + mem_fill_base( + mem_helper, + start_timestamp + 1, + specific.slice_from(COL_INDEX(ProdSpecificCols, read_records[1].base)) + ); + mem_fill_base( + mem_helper, + start_timestamp + 2, + specific.slice_from(COL_INDEX(ProdSpecificCols, write_record.base)) + ); + } + } else if (row[COL_INDEX(NativeSumcheckCols, logup_row)] == Fp::one()) { + mem_fill_base( + mem_helper, + start_timestamp, + specific.slice_from(COL_INDEX(LogupSpecificCols, read_records[0].base)) + ); + if (row[COL_INDEX(NativeSumcheckCols, within_round_limit)] == Fp::one()) { + mem_fill_base( + mem_helper, + start_timestamp + 1, + specific.slice_from(COL_INDEX(LogupSpecificCols, read_records[1].base)) + ); + mem_fill_base( + mem_helper, + start_timestamp + 2, + specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[0].base)) + ); + mem_fill_base( + mem_helper, + start_timestamp + 3, + specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[1].base)) + ); + } + } +} + +__global__ void native_sumcheck_tracegen( + Fp *trace, + size_t height, + size_t width, + const Fp *records, + size_t rows_used, + uint32_t *range_checker_ptr, + uint32_t range_checker_num_bins, + uint32_t timestamp_max_bits +) { + uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= height) { + return; + } + + RowSlice row(trace + idx, height); + if (idx < rows_used) { + const Fp *record = records + idx * width; + for (uint32_t col = 0; col < width; ++col) { + row[col] = record[col]; + } + MemoryAuxColsFactory mem_helper( + VariableRangeChecker(range_checker_ptr, range_checker_num_bins), timestamp_max_bits + ); + fill_sumcheck_specific(row, mem_helper); + } else { + row.fill_zero(0, width); + COL_WRITE_VALUE(row, NativeSumcheckCols, is_end, Fp::one()); + } +} + +extern "C" int _native_sumcheck_tracegen( + Fp *d_trace, + size_t height, + size_t width, + const Fp *d_records, + size_t rows_used, + uint32_t *d_range_checker, + uint32_t range_checker_num_bins, + uint32_t timestamp_max_bits +) { + assert((height & (height - 1)) == 0); + assert(width == sizeof(NativeSumcheckCols)); + auto [grid, block] = kernel_launch_params(height); + native_sumcheck_tracegen<<>>( + d_trace, + height, + width, + d_records, + rows_used, + d_range_checker, + range_checker_num_bins, + timestamp_max_bits + ); + return CHECK_KERNEL(); +} diff --git a/extensions/native/circuit/src/cuda_abi.rs b/extensions/native/circuit/src/cuda_abi.rs index ad1a454d7b..5de9124f0d 100644 --- a/extensions/native/circuit/src/cuda_abi.rs +++ b/extensions/native/circuit/src/cuda_abi.rs @@ -235,6 +235,44 @@ pub mod poseidon2_cuda { } } +pub mod sumcheck_cuda { + use super::*; + + extern "C" { + pub fn _native_sumcheck_tracegen( + d_trace: *mut F, + height: usize, + width: usize, + d_records: *const F, + rows_used: usize, + d_range_checker: *mut u32, + range_checker_max_bins: u32, + timestamp_max_bits: u32, + ) -> i32; + } + + pub unsafe fn tracegen( + d_trace: &DeviceBuffer, + height: usize, + width: usize, + d_records: &DeviceBuffer, + rows_used: usize, + d_range_checker: &DeviceBuffer, + timestamp_max_bits: u32, + ) -> Result<(), CudaError> { + CudaError::from_result(_native_sumcheck_tracegen( + d_trace.as_mut_ptr(), + height, + width, + d_records.as_ptr(), + rows_used, + d_range_checker.as_mut_ptr() as *mut u32, + d_range_checker.len() as u32, + timestamp_max_bits, + )) + } +} + pub mod native_loadstore_cuda { use super::*; diff --git a/extensions/native/circuit/src/extension/cuda.rs b/extensions/native/circuit/src/extension/cuda.rs index 9a433fce11..765ce8d6cc 100644 --- a/extensions/native/circuit/src/extension/cuda.rs +++ b/extensions/native/circuit/src/extension/cuda.rs @@ -17,6 +17,7 @@ use crate::{ jal_rangecheck::{JalRangeCheckAir, JalRangeCheckGpu}, loadstore::{NativeLoadStoreAir, NativeLoadStoreChipGpu}, poseidon2::{air::NativePoseidon2Air, NativePoseidon2ChipGpu}, + sumcheck::{air::NativeSumcheckAir, NativeSumcheckChipGpu}, CastFExtension, GpuBackend, Native, }; @@ -75,6 +76,10 @@ impl VmProverExtension let poseidon2 = NativePoseidon2ChipGpu::<1>::new(range_checker.clone(), timestamp_max_bits); inventory.add_executor_chip(poseidon2); + inventory.next_air::()?; + let sumcheck = NativeSumcheckChipGpu::new(range_checker.clone(), timestamp_max_bits); + inventory.add_executor_chip(sumcheck); + Ok(()) } } diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs index 17e99c442a..a33d286cd1 100644 --- a/extensions/native/circuit/src/sumcheck/chip.rs +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -3,7 +3,7 @@ use std::borrow::BorrowMut; use openvm_circuit::{ arch::{ CustomBorrow, ExecutionError, MultiRowLayout, MultiRowMetadata, PreflightExecutor, - RecordArena, TraceFiller, VmChipWrapper, VmStateMut, + RecordArena, SizedRecord, TraceFiller, VmChipWrapper, VmStateMut, }, system::{ memory::{online::TracingMemory, MemoryAuxColsFactory}, @@ -76,7 +76,7 @@ impl<'a, F: PrimeField32> // Each instruction record consists solely of some number of contiguously // stored NativeSumcheckCols<...> structs, each of which corresponds to a // single trace row. Trace fillers don't actually need to know how many rows - // each instruction uses, and can thus treat each NativePoseidon2Cols<...> + // each instruction uses, and can thus treat each NativeSumcheckCols<...> // as a single record. NativeSumcheckRecordLayout { metadata: NativeSumcheckMetadata { num_rows: 1 }, @@ -84,6 +84,16 @@ impl<'a, F: PrimeField32> } } +impl SizedRecord for NativeSumcheckRecordMut<'_, F> { + fn size(layout: &NativeSumcheckRecordLayout) -> usize { + layout.metadata.num_rows * size_of::>() + } + + fn alignment(_layout: &NativeSumcheckRecordLayout) -> usize { + align_of::>() + } +} + #[derive(derive_new::new, Copy, Clone)] pub struct NativeSumcheckExecutor; diff --git a/extensions/native/circuit/src/sumcheck/cuda.rs b/extensions/native/circuit/src/sumcheck/cuda.rs index 8b13789179..60aba15b95 100644 --- a/extensions/native/circuit/src/sumcheck/cuda.rs +++ b/extensions/native/circuit/src/sumcheck/cuda.rs @@ -1 +1,57 @@ +use std::{mem::size_of, slice::from_raw_parts, sync::Arc}; +use derive_new::new; +use openvm_circuit::{arch::DenseRecordArena, utils::next_power_of_two_or_zero}; +use openvm_circuit_primitives::var_range::VariableRangeCheckerChipGPU; +use openvm_cuda_backend::{ + base::DeviceMatrix, chip::get_empty_air_proving_ctx, prover_backend::GpuBackend, types::F, +}; +use openvm_cuda_common::copy::MemCopyH2D; +use openvm_stark_backend::{prover::types::AirProvingContext, Chip}; + +use super::columns::NativeSumcheckCols; +use crate::cuda_abi::sumcheck_cuda; + +#[derive(new)] +pub struct NativeSumcheckChipGpu { + pub range_checker: Arc, + pub timestamp_max_bits: usize, +} + +impl Chip for NativeSumcheckChipGpu { + fn generate_proving_ctx(&self, arena: DenseRecordArena) -> AirProvingContext { + let records = arena.allocated(); + if records.is_empty() { + return get_empty_air_proving_ctx::(); + } + + let width = NativeSumcheckCols::::width(); + let record_size = width * size_of::(); + assert_eq!(records.len() % record_size, 0); + + let height = records.len() / record_size; + let padded_height = next_power_of_two_or_zero(height); + let trace = DeviceMatrix::::with_capacity(padded_height, width); + + let record_slice = unsafe { + let ptr = records.as_ptr(); + from_raw_parts(ptr as *const F, records.len() / size_of::()) + }; + let d_records = record_slice.to_device().unwrap(); + + unsafe { + sumcheck_cuda::tracegen( + trace.buffer(), + padded_height, + width, + &d_records, + height, + &self.range_checker.count, + self.timestamp_max_bits as u32, + ) + .unwrap(); + } + + AirProvingContext::simple_no_pis(trace) + } +} diff --git a/extensions/native/recursion/tests/sumcheck.rs b/extensions/native/recursion/tests/sumcheck.rs index a4039028bc..494d82c03a 100644 --- a/extensions/native/recursion/tests/sumcheck.rs +++ b/extensions/native/recursion/tests/sumcheck.rs @@ -1,6 +1,6 @@ -use openvm_circuit::arch::instructions::program::Program; -#[cfg(not(feature = "cuda"))] -use openvm_circuit::utils::air_test_impl; +use openvm_circuit::{arch::instructions::program::Program, utils::air_test_impl}; +#[cfg(feature = "cuda")] +use openvm_cuda_backend::engine::GpuBabyBearPoseidon2Engine; use openvm_native_circuit::{NativeBuilder, NativeConfig, EXT_DEG}; use openvm_native_compiler::{ asm::{AsmBuilder, AsmCompiler}, @@ -11,11 +11,10 @@ use openvm_native_compiler::{ use openvm_stark_backend::p3_field::{ extension::BinomialExtensionField, FieldAlgebra, FieldExtensionAlgebra, }; +#[cfg(not(feature = "cuda"))] +use openvm_stark_sdk::config::baby_bear_poseidon2::BabyBearPoseidon2Engine; use openvm_stark_sdk::{ - config::{ - baby_bear_poseidon2::BabyBearPoseidon2Engine, - fri_params::standard_fri_params_with_100_bits_conjectured_security, FriParameters, - }, + config::{fri_params::standard_fri_params_with_100_bits_conjectured_security, FriParameters}, p3_baby_bear::BabyBear, }; From 097e353b718628c7256cc1cac4b7a07101e7c887 Mon Sep 17 00:00:00 2001 From: xkx Date: Tue, 2 Dec 2025 21:45:39 +0800 Subject: [PATCH 12/12] Fix: sumcheck unit test failure (#19) * fix sumcheck unit test * remove hardcoded constants * fix --- .../native/circuit/src/sumcheck/chip.rs | 14 +- .../native/circuit/src/sumcheck/execution.rs | 8 +- extensions/native/recursion/tests/sumcheck.rs | 419 +++++------------- 3 files changed, 120 insertions(+), 321 deletions(-) diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs index a33d286cd1..bb7cfa7080 100644 --- a/extensions/native/circuit/src/sumcheck/chip.rs +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -325,7 +325,8 @@ where let eval_rlc = FieldExtension::multiply(alpha_acc, eval); prod_specific.eval_rlc = eval_rlc; - if mode == NEXT_LAYER_MODE && round + 1 < max_round - 1 { + let to_next_round = if mode == NEXT_LAYER_MODE { 1 } else { 0 }; + if round + to_next_round < max_round - 1 { eval_acc = FieldExtension::add(eval_acc, eval_rlc); prod_row.should_acc = F::ONE; prod_row.prod_acc = F::ONE; @@ -445,7 +446,8 @@ where FieldExtension::multiply(alpha_denominator, q_eval), ); logup_specific.eval_rlc = eval_rlc; - if mode == NEXT_LAYER_MODE && round + 1 < max_round - 1 { + let to_next_round = if mode == NEXT_LAYER_MODE { 1 } else { 0 }; + if round + to_next_round < max_round - 1 { eval_acc = FieldExtension::add(eval_acc, eval_rlc); logup_row.should_acc = F::ONE; logup_row.logup_acc = F::ONE; @@ -480,12 +482,16 @@ where let specific: &mut ProdSpecificCols = row.specific[..ProdSpecificCols::::width()].borrow_mut(); - eval_acc = FieldExtension::subtract(eval_acc, specific.eval_rlc); + if row.should_acc == F::ONE { + eval_acc = FieldExtension::subtract(eval_acc, specific.eval_rlc); + } row.eval_acc = eval_acc; } else if row.logup_row == F::ONE { let specific: &mut LogupSpecificCols = row.specific[..LogupSpecificCols::::width()].borrow_mut(); - eval_acc = FieldExtension::subtract(eval_acc, specific.eval_rlc); + if row.should_acc == F::ONE { + eval_acc = FieldExtension::subtract(eval_acc, specific.eval_rlc); + } row.eval_acc = eval_acc; } } diff --git a/extensions/native/circuit/src/sumcheck/execution.rs b/extensions/native/circuit/src/sumcheck/execution.rs index 7202e57b00..a475bf9e49 100644 --- a/extensions/native/circuit/src/sumcheck/execution.rs +++ b/extensions/native/circuit/src/sumcheck/execution.rs @@ -252,7 +252,8 @@ unsafe fn execute_e12_impl( exec_state.vm_write(NATIVE_AS, r_evals_ptr_u32 + (1 + i) * EXT_DEG as u32, &eval); - if mode == NEXT_LAYER_MODE && round + 1 < max_round - 1 { + let to_next_round = if mode == NEXT_LAYER_MODE { 1 } else { 0 }; + if round + to_next_round < max_round - 1 { // update eval_acc eval_acc = FieldExtension::add(eval_acc, FieldExtension::multiply(alpha_acc, eval)); } @@ -270,8 +271,8 @@ unsafe fn execute_e12_impl( .vm_read(NATIVE_AS, logup_offset + i) .map(|x: F| x.as_canonical_u32()); let start = calculate_3d_ext_idx( - logup_specs_inner_len, logup_specs_inner_inner_len, + logup_specs_inner_len, i, round, 0, @@ -325,7 +326,8 @@ unsafe fn execute_e12_impl( FieldExtension::multiply(alpha_numerator, p_eval), FieldExtension::multiply(alpha_denominator, q_eval), ); - if mode == NEXT_LAYER_MODE && round + 1 < max_round - 1 { + let to_next_round = if mode == NEXT_LAYER_MODE { 1 } else { 0 }; + if round + to_next_round < max_round - 1 { // update eval_acc eval_acc = FieldExtension::add(eval_acc, eval_rlc); } diff --git a/extensions/native/recursion/tests/sumcheck.rs b/extensions/native/recursion/tests/sumcheck.rs index 494d82c03a..a500ee6aac 100644 --- a/extensions/native/recursion/tests/sumcheck.rs +++ b/extensions/native/recursion/tests/sumcheck.rs @@ -1,3 +1,5 @@ +use std::iter::{once, repeat_n}; + use openvm_circuit::{arch::instructions::program::Program, utils::air_test_impl}; #[cfg(feature = "cuda")] use openvm_cuda_backend::engine::GpuBabyBearPoseidon2Engine; @@ -17,6 +19,7 @@ use openvm_stark_sdk::{ config::{fri_params::standard_fri_params_with_100_bits_conjectured_security, FriParameters}, p3_baby_bear::BabyBear, }; +use rand::{thread_rng, RngCore}; pub type F = BabyBear; pub type E = BinomialExtensionField; @@ -68,8 +71,36 @@ fn test_sumcheck_layer_eval() { } } +fn new_rand_ext(rng: &mut R) -> C::EF { + C::EF::from_base_slice(&[ + C::F::from_canonical_u32(rng.next_u32()), + C::F::from_canonical_u32(rng.next_u32()), + C::F::from_canonical_u32(rng.next_u32()), + C::F::from_canonical_u32(rng.next_u32()), + ]) +} + fn build_test_program(builder: &mut Builder) { - let ctx_u32s = [3u32, 6, 5, 8, 2, 8, 4, 0, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9]; + let mut rng = thread_rng(); + // 6 prod specs in 8 layers, 5 logup specs in 8 layers + let round = 3; + let num_prod_specs = 6; + let num_logup_specs = 5; + let num_layers = 8; + let mode = 1; // current_layer + + let mut ctx_u32s = vec![ + round, + num_prod_specs, + num_logup_specs, + num_layers, + 2, + num_layers, + 4, + mode, + ]; + ctx_u32s.extend(repeat_n(num_layers, num_prod_specs + num_logup_specs)); + let ctx: Array> = builder.dyn_array(ctx_u32s.len()); for (idx, n) in ctx_u32s.into_iter().enumerate() { builder.set(&ctx, idx, Usize::from(n as usize)); @@ -94,330 +125,85 @@ fn build_test_program(builder: &mut Builder) { builder.set(&challenges, idx, e); } - #[rustfmt::skip] - let prod_spec_eval_u32s = [ - 1538906710u32, 637535518, 1753132406, 1395236651, - 278806441, 1722910382, 1475548665, 1117874675, - 1578586709, 1826764884, 384068476, 1852240363, - 707958906, 1960944944, 183554399, 1259273357, - 227285124, 243066436, 1718037317, 369721963, - 1752968006, 1061013677, 775617499, 1464907431, - 544300429, 871461966, 135151545, 1343592602, - 1622220528, 643966158, 3932580, 434948358, - 540553922, 1446502052, 153298741, 1191216273, - 265936762, 1463035257, 1237633339, 1797346310, - 1355791584, 389527741, 1741650463, 1728913415, - 1825739540, 1790924136, 460776743, 29536554, - 6842036, 252495270, 1968285155, 299467416, - 49085744, 1499815729, 1098802236, 644489275, - 1827273105, 1888401527, 390077051, 565528894, - 1366177188, 67441791, 958486301, 402056716, - 590379691, 462035406, 633459131, 843304872, - 584100013, 1932496508, 250656031, 146983915, - 1835173157, 939973454, 1844873638, 1916054832, - 1601784696, 167251717, 409107688, 1062925788, - 1291319514, 1790529531, 495655592, 1093359708, - 790197205, 674458164, 195988318, 399764452, - 106865258, 967050329, 350035523, 1109292118, - 1815460301, 281986036, 900636603, 1121197008, - 1228976590, 1879998708, 1924332706, 434695844, - 1159360621, 471397106, 473371067, 1009065094, - 1320176846, 168020789, 1265321929, 1901808675, - 223657700, 1480150183, 1779968584, 144416591, - 304407746, 1864498679, 1482460119, 1554376965, - 1479261548, 1657723043, 1039345063, 1053923521, - 442080513, 1964082352, 691664908, 1941008321, - 1007729002, 860529393, 849697342, 754485488, - 584295923, 1072251466, 1105105254, 996079746, - 1305909868, 1348028973, 122275988, 464050036, - 692807777, 1098809324, 397235220, 596459886, - 1663209783, 720230826, 1422510715, 1760654694, - 544197700, 1417744567, 1938716517, 1571826328, - 1591430185, 1173137446, 175285007, 1541718596, - 1715958587, 1429966110, 583013357, 1667787861, - 109891172, 668253167, 161783842, 296183397, - 1681897325, 1054396117, 264741948, 464026995, - 1907686022, 1532786783, 394869458, 1766734740, - 136047179, 536856195, 376188855, 700633625, - 515518419, 531043483, 60673499, 556496527, - 1743028981, 873954569, 1371062291, 632169731, - 1353239206, 526507035, 1894490088, 589441599, - 1610487168, 1074160583, 366366374, 247602990, - 1535354896, 894493713, 1555870413, 1389854934, - 1897251683, 1525812801, 675621735, 697919636, - 1690274072, 1466810921, 1221110784, 1741995587, - 1877169764, 390876982, 1794129810, 297662156, - 144295349, 417037264, 1290835727, 1654971513, - 1674131303, 1625667423, 1471248832, 1676797844, - 1172916558, 1707775403, 423725211, 1643279661, - 1695774264, 378140395, 1517569394, 1666625392, - 1803981250, 439036260, 247966130, 709534816, - 361144100, 1546096548, 1240886454, 1898161518, - 843262057, 1709259464, 1301015977, 1997626928, - 677153173, 1606710353, 1216038070, 435565562, - 98686333, 1773787396, 267051994, 99395396, - 545509105, 782289675, 1289865975, 1707775075, - 1158993015, 1506576588, 993215179, 1523099397, - 923914455, 1895162386, 284489994, 1444139016, - 1943825680, 466202724, 1632522710, 1384015062, - 723147188, 1284031324, 1430481515, 341213007, - 171192499, 1061688239, 808927167, 83182639, - 759209907, 1728321272, 976049976, 1652071995, - 1002877840, 69880246, 1095135165, 677588420, - 1384715290, 829619452, 170122781, 1958173727, - 13389238, 789379698, 1883383039, 1279195174, - 1618672336, 1192839317, 1348311124, 758896285, - 1939775389, 684108413, 1838340479, 1332232130, - 1070486028, 549228790, 868851698, 1678207843, - 1754321489, 637000403, 647901906, 45343322, - 1768524074, 1167955205, 1816497210, 1609414096, - 1985231742, 1540534482, 232730819, 232221968, - 1509637836, 1480860627, 884647789, 1096458024, - 163721583, 1248032262, 436419506, 1737102298, - 651105860, 452298073, 1064372507, 1792838683, - 619243471, 860127631, 721724708, 950768433, - 279913448, 339693210, 47730422, 1952683911, - 1316500770, 675944216, 386902809, 619333956, - 1194800389, 43989936, 1944372656, 666045666, - 1155873844, 522696968, 58874730, 1497238023, - 421619994, 1980672127, 1657191856, 1913792631, - 1784663131, 1118400672, 1828104993, 1637808383, - 414755472, 775410449, 747132157, 136820101, - 1082674285, 93190395, 357955402, 335652723, - 1192102705, 480365232, 1354935730, 1391829361, - 966662991, 1601510445, 569528575, 545490940, - 1753711688, 807025222, 580374183, 587718008, - 977546290, 1055719519, 1157107032, 562799608, - 859466927, 840450024, 815325134, 936576801, - 1010587056, 246624382, 1808049797, 1098183398, - 1005077390, 772432546, 1976629565, 1003772218, - 1655315418, 1767931114, 982008720, 785023351, - ]; - - let prod_spec_evals: Array> = - builder.dyn_array(prod_spec_eval_u32s.len() / EXT_DEG); - for (idx, n) in prod_spec_eval_u32s.chunks(EXT_DEG).enumerate() { - let e: Ext = builder.constant(C::EF::from_base_slice(&[ - C::F::from_canonical_u32(n[0]), - C::F::from_canonical_u32(n[1]), - C::F::from_canonical_u32(n[2]), - C::F::from_canonical_u32(n[3]), - ])); + let num_prod_evals = num_prod_specs * num_layers * 2; + let prod_spec_evals: Array> = builder.dyn_array(num_prod_evals); + for idx in 0..num_prod_evals { + let e: Ext = builder.constant(new_rand_ext::(&mut rng)); builder.set(&prod_spec_evals, idx, e); } - #[rustfmt::skip] - let logup_spec_eval_u32s = [ - 1522353967u32, 457603397, 421847521, 1352563318, - 1746817766, 737872688, 1087008622, 1850835028, - 456475558, 892966330, 638163666, 148568548, - 678863061, 1334386850, 1896333039, 154585769, - 433618446, 1186936470, 970218722, 1213827097, - 1798557019, 861757965, 119285527, 395360622, - 226164366, 1330279872, 66561048, 785421608, - 1950755756, 1559889596, 348449876, 1090789452, - 257578851, 273164442, 1644906, 295600924, - 1187949602, 1168249609, 469763604, 60929061, - 291163036, 403842501, 1421902433, 1700188477, - 1046093370, 921059131, 1638991894, 464012042, - 96905857, 1370999592, 271896041, 13595534, - 1489760970, 1650552701, 133367846, 25680377, - 377631580, 652729291, 645763356, 426747355, - 482475486, 1877299223, 103226636, 1333832358, - 1399609097, 458536972, 976248802, 1109365280, - 515164588, 1579426417, 1601829549, 607169702, - 852817956, 1980537127, 134138338, 913344050, - 737880920, 476360275, 61624034, 1610624252, - 264461991, 546933535, 937769429, 293346965, - 1522058041, 1012551797, 994330314, 23333322, - 1969510890, 974351570, 2012030621, 120742000, - 450250620, 180547360, 642746933, 1815029950, - 629489142, 1176992624, 723354779, 572648755, - 1218615348, 648847054, 351903235, 723149764, - 248065753, 243829448, 1283393001, 1912627886, - 581641342, 702465306, 205969758, 1061911274, - 1, 0, 0, 0, - 1, 0, 0, 0, - 1703043252, 1467887451, 1714319214, 907866644, - 1542426838, 742609036, 1814393459, 448706641, - 1960340767, 46490834, 186512520, 363973095, - 846448854, 463742343, 2012517527, 40473617, - 9472552, 263483342, 105738598, 586389136, - 254290990, 625150844, 960233097, 1488303724, - 1700231692, 1471714612, 1540211186, 1590246915, - 945341972, 1343225515, 179976237, 34857822, - 276912528, 984309272, 1277293398, 1520924162, - 1823117694, 604836357, 1460812009, 600052559, - 970469338, 1771022707, 181855831, 1445947220, - 467514809, 1514677498, 947030389, 170390653, - 415409007, 1601463730, 204153427, 904614278, - 1855419512, 2009471607, 1352607379, 576586082, - 1343812879, 1176377580, 1166188815, 1592289048, - 761793881, 1529621462, 193034837, 344011596, - 1669461833, 1356800025, 314186361, 586497329, - 1832810846, 1288092861, 1619454491, 732529408, - 737934269, 909504928, 769680420, 1437893101, - 1727002258, 1618231110, 535125583, 153412473, - 1917760929, 588586507, 564531165, 1790797737, - 1666283994, 1366948884, 117673690, 476470378, - 2012274032, 1951406668, 1739767532, 1273142151, - 1591812317, 1900205312, 1912608761, 1734766024, - 1265002082, 1450462894, 749810837, 1329222552, - 745081805, 1231519431, 1420957967, 883846107, - 1995463911, 407795592, 161655852, 125886157, - 995318920, 484905024, 284135318, 551493419, - 406742309, 1089024446, 637339867, 1858138403, - 1230680117, 187078889, 1929517480, 1125646261, - 1, 0, 0, 0, - 1, 0, 0, 0, - 1610035932, 462442436, 831412555, 44798862, - 1748147276, 1911945531, 1329343740, 971894393, - 362147969, 1583335926, 1528700112, 426908674, - 847905883, 447889090, 1050883911, 1883537469, - 1487501632, 964178870, 1818828551, 1980840799, - 340372118, 1697179193, 215113037, 1893217470, - 1138628493, 1788052486, 443362955, 1349213730, - 589553425, 562526667, 1006040406, 1194546769, - 1831034644, 612004157, 730213913, 1068905440, - 371983982, 502900790, 802785198, 822377635, - 1477528437, 501356237, 684668525, 1306043781, - 621032592, 1971342708, 1411586583, 733418745, - 186045462, 1559301855, 323758310, 453170140, - 498381240, 976247416, 631213663, 898017829, - 501459603, 609703046, 1379288251, 177682695, - 912381595, 121915494, 1137416430, 504054388, - 1138277238, 1603388253, 1838013301, 1700271853, - 20488607, 58775264, 217974275, 979141729, - 53136584, 1331566240, 1460303356, 525812787, - 718385521, 1477919263, 1663622276, 1089788203, - 1204483837, 54225863, 290660186, 1441441958, - 134168813, 349638823, 1867912015, 1579183319, - 55528656, 1602973359, 194297109, 949763297, - 101931919, 242300116, 1610052257, 1351823848, - 174522860, 776955925, 1706962365, 808187490, - 1487253852, 431806906, 213982593, 1170647308, - 1776840400, 295916317, 378708073, 381270341, - 457494568, 705823997, 1407301442, 1693003013, - 700310785, 1349874247, 1284363817, 1566253815, - 1014298154, 215294365, 1070968678, 871641358, - 1, 0, 0, 0, - 1, 0, 0, 0, - 1302679751, 1121894357, 368587356, 1564724097, - 733815591, 2012670011, 1146780092, 1439780227, - 1801628424, 838692317, 932318853, 213634365, - 155292454, 1644317110, 1599846194, 978829059, - 1282095862, 1780431647, 527412087, 1024583705, - 804423802, 951808322, 689345230, 180304167, - 1784562773, 1514653374, 2009396440, 1143778943, - 235299446, 1553017484, 475425117, 758292254, - 716575432, 517083432, 1728864125, 418010549, - 43202592, 507659742, 433077118, 1268144019, - 1462778342, 1928073362, 1330130180, 1749624351, - 827401013, 1236194147, 1875519726, 1437946791, - 607293265, 309229599, 1009445595, 1725229718, - 1436309341, 1952606463, 943149111, 291680468, - 1989684076, 1944713370, 1285294139, 399758737, - 1572979232, 213817406, 214840530, 184898060, - 1483844295, 1536616777, 494816009, 217625163, - 529448032, 786640964, 1766471731, 1424140424, - 1721961711, 740275169, 169908711, 913969302, - 1359358267, 1328322971, 593228769, 771095186, - 801680440, 450930656, 1796349530, 1824428677, - 1111258504, 1741666629, 1098430204, 1792001884, - 1679003061, 590088446, 647614538, 1324461639, - 818996796, 229187928, 74288115, 1158900266, - 1512606270, 1381672753, 785927403, 493453164, - 425259497, 1367873539, 931023744, 221202218, - 669580668, 424996238, 1840425275, 1873362670, - 967642716, 263556335, 578560519, 1558449223, - 607579284, 1724012378, 333582342, 1195784167, - 1419727276, 199294290, 138807165, 1061030752, - 1, 0, 0, 0, - 1, 0, 0, 0, - 776332180, 1333076185, 1855163818, 1897408938, - 799274251, 950452503, 691904988, 1205387466, - 659107883, 434394982, 129587940, 639018629, - 659238594, 1957584892, 864291238, 589178070, - 1267157231, 48925338, 200093884, 1953762869, - 1227617341, 1471420621, 193077633, 1007876111, - 228491220, 1377349503, 1889411060, 1807513892, - 1593042934, 1240864695, 1472870721, 583021932, - 598239104, 1862008818, 1811242869, 780768026, - 520870395, 292016292, 322246659, 868240490, - 1715620331, 1183509209, 2010262726, 1003957251, - 264895455, 307755941, 201990485, 1662471178, - 1643997923, 1573129362, 277821143, 388834470, - 943361405, 1449402196, 614413575, 1504113993, - 1860552739, 1755127315, 1734129760, 1232115188, - 803035456, 360488092, 271342171, 1269544258, - 290642673, 660703582, 986842267, 870891877, - 454573044, 1999346236, 701614601, 820253867, - 883282765, 137247873, 1727164949, 1320585493, - 1738664600, 1900116905, 472215154, 1114994489, - 104218174, 1694603079, 771486383, 935361143, - 92277671, 881040480, 925124484, 1464396527, - 100625197, 65290355, 1001454341, 134627585, - 58629702, 1541542242, 568583607, 1706262052, - 530687550, 1303187245, 1010302462, 264001857, - 789816678, 561378226, 827432508, 801307507, - 1613508315, 1650822853, 1603502703, 439320335, - 15283580, 1244486577, 254345266, 1745653280, - 1648250354, 1528271018, 528366563, 1078707735, - 1430767759, 1890467731, 2001894083, 799949326, - 1, 0, 0, 0, - 1, 0, 0, 0, - 1341839494, 1092219735, 755644898, 966729319, - 1914277278, 1545367697, 1765189119, 1693413008, - ]; - - let logup_spec_evals: Array> = - builder.dyn_array(logup_spec_eval_u32s.len() / EXT_DEG); - for (idx, n) in logup_spec_eval_u32s.chunks(EXT_DEG).enumerate() { - let e: Ext = builder.constant(C::EF::from_base_slice(&[ - C::F::from_canonical_u32(n[0]), - C::F::from_canonical_u32(n[1]), - C::F::from_canonical_u32(n[2]), - C::F::from_canonical_u32(n[3]), - ])); + let num_logup_evals = num_logup_specs * num_layers * 4; + let logup_spec_evals: Array> = builder.dyn_array(num_logup_evals); + for idx in 0..num_logup_evals { + let e: Ext = builder.constant(new_rand_ext::(&mut rng)); builder.set(&logup_spec_evals, idx, e); } - #[rustfmt::skip] - let r_evals_u32s = [ - 941378355u32, 1078920879, 696738840, 496039492, - 1555445457, 184545404, 905938226, 1847966044, - 1024875886, 1782716223, 1625644635, 266865456, - 465953066, 1663531470, 757423849, 1957075986, - 1919693393, 839104130, 127480221, 1527842912, - 918650796, 921462354, 575456073, 696646705, - 1585912361, 258186488, 353168830, 1111094691, - 1401166558, 1905942163, 1923083163, 393037255, - 1042127700, 1126793296, 895794165, 1124924482, - 1324266058, 722406365, 1963838171, 968504459, - 1934378800, 714588691, 6465911, 1168379648, - 903786009, 1326035939, 518289228, 418998914, - 1513133474, 1578096058, 617547414, 1658315126, - 68556894, 1697802593, 1346510664, 1709381671, - 345062962, 1254089535, 1002281845, 1882822096, - 700581748, 1431345304, 489112954, 98435728, - 1799886007, 479788390, 223111065, 631662309, - ]; - - let next_layer_evals: Array> = - builder.dyn_array(r_evals_u32s.len() / EXT_DEG); - for (idx, n) in r_evals_u32s.chunks(EXT_DEG).enumerate() { - let e: Ext = builder.constant(C::EF::from_base_slice(&[ - C::F::from_canonical_u32(n[0]), - C::F::from_canonical_u32(n[1]), - C::F::from_canonical_u32(n[2]), - C::F::from_canonical_u32(n[3]), - ])); + let alpha = builder.get(&challenges, 0); + let c1 = builder.get(&challenges, 1); + let c2 = builder.get(&challenges, 2); + + let alpha_acc: Ext = builder.constant(C::EF::ONE); + let eval_acc: Ext = builder.constant(C::EF::ZERO); + + let mut p_evals = vec![]; + for i in 0..num_prod_specs { + let start = num_layers * 2 * i + 2 * round; + let p1 = builder.get(&prod_spec_evals, start); + let p2 = builder.get(&prod_spec_evals, start + 1); + let p_eval: Ext = if mode == 1 { + // current layer + builder.eval(p1 * p2) + } else { + // next layer + builder.eval(p1 * c1 + p2 * c2) + }; + p_evals.push(p_eval); + let eval_rlc: Ext = builder.eval(alpha_acc * p_eval); + builder.assign(&eval_acc, eval_acc + eval_rlc); + builder.assign(&alpha_acc, alpha_acc * alpha); + } - builder.set(&next_layer_evals, idx, e); + let mut logup_p_evals = vec![]; + let mut logup_q_evals = vec![]; + for i in 0..num_logup_specs { + let start = num_layers * 4 * i + 4 * round; + let p1 = builder.get(&logup_spec_evals, start); + let p2 = builder.get(&logup_spec_evals, start + 1); + let q1 = builder.get(&logup_spec_evals, start + 2); + let q2 = builder.get(&logup_spec_evals, start + 3); + let p_eval: Ext = if mode == 1 { + builder.eval(p1 * q2 + p2 * q1) + } else { + builder.eval(p1 * c1 + p2 * c2) + }; + let q_eval: Ext = if mode == 1 { + builder.eval(q1 * q2) + } else { + builder.eval(q1 * c1 + q2 * c2) + }; + + logup_p_evals.push(p_eval); + logup_q_evals.push(q_eval); + + let alpha_denominator: Ext = builder.eval(alpha_acc * alpha); + let eval_rlc: Ext = + builder.eval(alpha_acc * p_eval + alpha_denominator * q_eval); + + builder.assign(&eval_acc, eval_acc + eval_rlc); + builder.assign(&alpha_acc, alpha_acc * alpha * alpha); } + let r_evals = once(eval_acc) + .chain(p_evals.into_iter()) + .chain(logup_p_evals.into_iter()) + .chain(logup_q_evals.into_iter()) + .collect::>(); + + let next_layer_evals: Array> = builder.dyn_array(r_evals.len()); + builder.sumcheck_layer_eval( &ctx, &challenges, @@ -426,5 +212,10 @@ fn build_test_program(builder: &mut Builder) { &next_layer_evals, ); + for (idx, e) in r_evals.into_iter().enumerate() { + let next_eval = builder.get(&next_layer_evals, idx); + builder.assert_ext_eq(next_eval, e); + } + builder.halt(); }