diff --git a/.gitignore b/.gitignore index c6e6aa2049..c87b73f198 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,9 @@ Cargo.lock **/.env .DS_Store +# Log outputs +*.log + .cache/ rustc-* diff --git a/Cargo.lock b/Cargo.lock index 9cb5dfbba2..ad540352cf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5840,6 +5840,7 @@ dependencies = [ "metrics", "once_cell", "openvm-circuit", + "openvm-cuda-backend", "openvm-native-circuit", "openvm-native-compiler", "openvm-native-compiler-derive", diff --git a/crates/sdk/src/prover/agg.rs b/crates/sdk/src/prover/agg.rs index 5e22562675..fd4491ddec 100644 --- a/crates/sdk/src/prover/agg.rs +++ b/crates/sdk/src/prover/agg.rs @@ -27,12 +27,12 @@ where E: StarkFriEngine, NativeBuilder: VmBuilder, { - leaf_prover: VmInstance, - leaf_controller: LeafProvingController, + pub leaf_prover: VmInstance, + pub leaf_controller: LeafProvingController, pub internal_prover: VmInstance, #[cfg(feature = "evm-prove")] - root_prover: RootVerifierLocalProver, + pub root_prover: RootVerifierLocalProver, pub num_children_internal: usize, pub max_internal_wrapper_layers: usize, } diff --git a/crates/vm/src/metrics/cycle_tracker/mod.rs b/crates/vm/src/metrics/cycle_tracker/mod.rs index 3d989bc44b..2d569c9774 100644 --- a/crates/vm/src/metrics/cycle_tracker/mod.rs +++ b/crates/vm/src/metrics/cycle_tracker/mod.rs @@ -1,7 +1,18 @@ +/// Stats for a nested span in the execution segment that is tracked by the [`CycleTracker`]. +#[derive(Clone, Debug, Default)] +pub struct SpanInfo { + /// The name of the span. + pub tag: String, + /// The cycle count at which the span starts. + pub start: usize, +} + #[derive(Clone, Debug, Default)] pub struct CycleTracker { /// Stack of span names, with most recent at the end - stack: Vec<String>, + stack: Vec<SpanInfo>, + /// Depth of the stack. + depth: usize, } impl CycleTracker { @@ -10,29 +21,42 @@ impl CycleTracker { } pub fn top(&self) -> Option<&String> { - self.stack.last() + match self.stack.last() { + Some(span) => Some(&span.tag), + _ => None, + } } /// Starts a new cycle tracker span for the given name. /// If a span already exists for the given name, it ends the existing span and pushes a new one /// to the vec. - pub fn start(&mut self, mut name: String) { + pub fn start(&mut self, mut name: String, cycles_count: usize) { // hack to remove "CT-" prefix if name.starts_with("CT-") { name = name.split_off(3); } - self.stack.push(name); + self.stack.push(SpanInfo { + tag: name.clone(), + start: cycles_count, + }); + let padding = "│ ".repeat(self.depth); + tracing::info!("{}┌╴{}", padding, name); + self.depth += 1; } /// Ends the cycle tracker span for the given name. /// If no span exists for the given name, it panics. - pub fn end(&mut self, mut name: String) { + pub fn end(&mut self, mut name: String, cycles_count: usize) { // hack to remove "CT-" prefix if name.starts_with("CT-") { name = name.split_off(3); } - let stack_top = self.stack.pop(); - assert_eq!(stack_top.unwrap(), name, "Stack top does not match name"); + let SpanInfo { tag, start } = self.stack.pop().unwrap(); + assert_eq!(tag, name, "Stack top does not match name"); + self.depth -= 1; + let padding = "│ ".repeat(self.depth); + let span_cycles = cycles_count - start; + tracing::info!("{}└╴{} cycles", padding, span_cycles); } /// Ends the current cycle tracker span.
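For context on the span rendering introduced above, here is a minimal, self-contained sketch of how the new `depth` counter and `SpanInfo::start` combine to print a flamegraph-style tree of cycle counts. This is illustrative only: `DemoTracker` is a hypothetical stand-in for `CycleTracker`, and `println!` stands in for `tracing::info!` so the sketch runs without the tracing crate.

// Stand-alone illustration of the nested-span output format above.
struct DemoTracker {
    stack: Vec<(String, usize)>, // (tag, start cycle), like SpanInfo
    depth: usize,
}

impl DemoTracker {
    fn start(&mut self, name: &str, cycles_count: usize) {
        self.stack.push((name.to_string(), cycles_count));
        // One "│ " of padding per nesting level, then a box-drawing opener.
        println!("{}┌╴{}", "│ ".repeat(self.depth), name);
        self.depth += 1;
    }

    fn end(&mut self, name: &str, cycles_count: usize) {
        let (tag, start) = self.stack.pop().unwrap();
        assert_eq!(tag, name, "Stack top does not match name");
        self.depth -= 1;
        // Elapsed cycles for the span being closed.
        println!("{}└╴{} cycles", "│ ".repeat(self.depth), cycles_count - start);
    }
}

fn main() {
    let mut t = DemoTracker { stack: vec![], depth: 0 };
    t.start("outer", 0);
    t.start("inner", 100);
    t.end("inner", 350); // prints: │ └╴250 cycles
    t.end("outer", 900); // prints: └╴900 cycles
}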
@@ -42,7 +66,11 @@ impl CycleTracker { /// Get full name of span with all parent names separated by ";" in flamegraph format pub fn get_full_name(&self) -> String { - self.stack.join(";") + self.stack + .iter() + .map(|span_info| span_info.tag.clone()) + .collect::<Vec<_>>() + .join(";") } } diff --git a/crates/vm/src/metrics/mod.rs b/crates/vm/src/metrics/mod.rs index 698ba1f1e8..7af679b3cd 100644 --- a/crates/vm/src/metrics/mod.rs +++ b/crates/vm/src/metrics/mod.rs @@ -224,7 +224,7 @@ impl VmMetrics { .map(|(_, func)| (*func).clone()) .unwrap(); if pc == self.current_fn.start { - self.cycle_tracker.start(self.current_fn.name.clone()); + self.cycle_tracker.start(self.current_fn.name.clone(), 0); } else { while let Some(name) = self.cycle_tracker.top() { if name == &self.current_fn.name { diff --git a/extensions/native/circuit/cuda/include/native/poseidon2.cuh b/extensions/native/circuit/cuda/include/native/poseidon2.cuh index 737406839f..206c0e16c0 100644 --- a/extensions/native/circuit/cuda/include/native/poseidon2.cuh +++ b/extensions/native/circuit/cuda/include/native/poseidon2.cuh @@ -62,12 +62,42 @@ template <typename T> struct SimplePoseidonSpecificCols { MemoryWriteAuxCols write_data_2; }; +template <typename T> struct MultiObserveCols { + T pc; + T final_timestamp_increment; + T state_ptr; + T input_ptr; + T init_pos; + T len; + T input_register_1; + T input_register_2; + T input_register_3; + T output_register; + T is_first; + T is_last; + T curr_len; + T start_idx; + T end_idx; + T aux_after_start[CHUNK]; + T aux_before_end[CHUNK]; + T aux_read_enabled[CHUNK]; + MemoryReadAuxCols<T> read_data[CHUNK]; + MemoryWriteAuxCols<T, 1> write_data[CHUNK]; + T data[CHUNK]; + T should_permute; + MemoryWriteAuxCols<T, 2 * CHUNK> write_sponge_state; + MemoryWriteAuxCols<T, 1> write_final_idx; +}; + template <typename T> constexpr T constexpr_max(T a, T b) { return a > b ? a : b; } constexpr size_t COL_SPECIFIC_WIDTH = constexpr_max( sizeof(TopLevelSpecificCols<uint8_t>), constexpr_max( sizeof(InsideRowSpecificCols<uint8_t>), - sizeof(SimplePoseidonSpecificCols<uint8_t>) + constexpr_max( + sizeof(SimplePoseidonSpecificCols<uint8_t>), + sizeof(MultiObserveCols<uint8_t>) + ) ) ); diff --git a/extensions/native/circuit/cuda/include/native/sumcheck.cuh b/extensions/native/circuit/cuda/include/native/sumcheck.cuh new file mode 100644 index 0000000000..052dc03fd5 --- /dev/null +++ b/extensions/native/circuit/cuda/include/native/sumcheck.cuh @@ -0,0 +1,85 @@ +#pragma once + +#include "primitives/constants.h" +#include "system/memory/offline_checker.cuh" + +using namespace native; + +template <typename T> struct HeaderSpecificCols { + T pc; + T registers[5]; + MemoryReadAuxCols<T> read_records[7]; + MemoryWriteAuxCols<T, EXT_DEG> write_records; +}; + +template <typename T> struct ProdSpecificCols { + T data_ptr; + T p[EXT_DEG * 2]; + MemoryReadAuxCols<T> read_records[2]; + T p_evals[EXT_DEG]; + MemoryWriteAuxCols<T, EXT_DEG> write_record; + T eval_rlc[EXT_DEG]; +}; + +template <typename T> struct LogupSpecificCols { + T data_ptr; + T pq[EXT_DEG * 4]; + MemoryReadAuxCols<T> read_records[2]; + T p_evals[EXT_DEG]; + T q_evals[EXT_DEG]; + MemoryWriteAuxCols<T, EXT_DEG> write_records[2]; + T eval_rlc[EXT_DEG]; +}; + +template <typename T> constexpr T constexpr_max(T a, T b) { + return a > b ?
a : b; +} + +constexpr size_t COL_SPECIFIC_WIDTH = constexpr_max( + sizeof(HeaderSpecificCols<uint8_t>), + constexpr_max(sizeof(ProdSpecificCols<uint8_t>), sizeof(LogupSpecificCols<uint8_t>)) +); + +template <typename T> struct NativeSumcheckCols { + T header_row; + T prod_row; + T logup_row; + T is_end; + + T prod_continued; + T logup_continued; + + T prod_in_round_evaluation; + T prod_next_round_evaluation; + T logup_in_round_evaluation; + T logup_next_round_evaluation; + + T prod_acc; + T logup_acc; + + T first_timestamp; + T start_timestamp; + T last_timestamp; + + T register_ptrs[5]; + + T ctx[EXT_DEG * 2]; + + T prod_nested_len; + T logup_nested_len; + + T curr_prod_n; + T curr_logup_n; + + T alpha[EXT_DEG]; + T challenges[EXT_DEG * 4]; + + T max_round; + T within_round_limit; + T should_acc; + + T eval_acc[EXT_DEG]; + + T specific[COL_SPECIFIC_WIDTH]; +}; + diff --git a/extensions/native/circuit/cuda/include/native/utils.cuh b/extensions/native/circuit/cuda/include/native/utils.cuh new file mode 100644 index 0000000000..f217350959 --- /dev/null +++ b/extensions/native/circuit/cuda/include/native/utils.cuh @@ -0,0 +1,13 @@ +#pragma once + +#include "primitives/trace_access.h" +#include "system/memory/controller.cuh" + +__device__ __forceinline__ void mem_fill_base( + MemoryAuxColsFactory &mem_helper, + uint32_t timestamp, + RowSlice base_aux +) { + uint32_t prev = base_aux[COL_INDEX(MemoryBaseAuxCols, prev_timestamp)].asUInt32(); + mem_helper.fill(base_aux, prev, timestamp); +} diff --git a/extensions/native/circuit/cuda/src/poseidon2.cu b/extensions/native/circuit/cuda/src/poseidon2.cu index 9779b601e4..fdbe0d3ce5 100644 --- a/extensions/native/circuit/cuda/src/poseidon2.cu +++ b/extensions/native/circuit/cuda/src/poseidon2.cu @@ -2,6 +2,7 @@ #include "poseidon2-air/columns.cuh" #include "poseidon2-air/params.cuh" #include "poseidon2-air/tracegen.cuh" +#include "native/utils.cuh" #include "primitives/trace_access.h" #include "system/memory/controller.cuh" @@ -22,6 +23,7 @@ template <typename T> struct NativePoseidon2Cols { T incorporate_sibling; T inside_row; T simple; + T multi_observe_row; T end_inside_row; T end_top_level; @@ -37,15 +39,6 @@ template <typename T> struct NativePoseidon2Cols { T specific[COL_SPECIFIC_WIDTH]; }; -__device__ void mem_fill_base( - MemoryAuxColsFactory mem_helper, - uint32_t timestamp, - RowSlice base_aux -) { - uint32_t prev = base_aux[COL_INDEX(MemoryBaseAuxCols, prev_timestamp)].asUInt32(); - mem_helper.fill(base_aux, prev, timestamp); -} - template struct Poseidon2Wrapper { template <typename T> using Cols = NativePoseidon2Cols<T>; using Poseidon2Row = @@ -58,6 +51,8 @@ template struct Poseidon2Wrapper { ) { if (row[COL_INDEX(Cols, simple)] == Fp::one()) { fill_simple_chunk(row, range_checker, timestamp_max_bits); + } else if (row[COL_INDEX(Cols, multi_observe_row)] == Fp::one()) { + fill_multi_observe_chunk(row, range_checker, timestamp_max_bits); } else { fill_verify_batch_chunk(row, range_checker, timestamp_max_bits); } @@ -335,6 +330,74 @@ template struct Poseidon2Wrapper { } } } + + __device__ static void fill_multi_observe_chunk( + RowSlice row, + VariableRangeChecker range_checker, + uint32_t timestamp_max_bits + ) { + MemoryAuxColsFactory mem_helper(range_checker, timestamp_max_bits); + Poseidon2Row head_row(row); + uint32_t num_rows = head_row.export_col()[0].asUInt32(); + + for (uint32_t idx = 0; idx < num_rows; ++idx) { + RowSlice curr_row = row.shift_row(idx); + fill_inner(curr_row); + fill_multi_observe_specific(curr_row, mem_helper); + } + } + + __device__ static void fill_multi_observe_specific( + RowSlice row,
MemoryAuxColsFactory &mem_helper + ) { + RowSlice specific = row.slice_from(COL_INDEX(Cols, specific)); + if (specific[COL_INDEX(MultiObserveCols, is_first)] == Fp::one()) { + uint32_t very_start_timestamp = + row[COL_INDEX(Cols, very_first_timestamp)].asUInt32(); + for (uint32_t i = 0; i < 4; ++i) { + mem_fill_base( + mem_helper, + very_start_timestamp + i, + specific.slice_from(COL_INDEX(MultiObserveCols, read_data[i].base)) + ); + } + } else { + uint32_t start_timestamp = row[COL_INDEX(Cols, start_timestamp)].asUInt32(); + uint32_t chunk_start = + specific[COL_INDEX(MultiObserveCols, start_idx)].asUInt32(); + uint32_t chunk_end = + specific[COL_INDEX(MultiObserveCols, end_idx)].asUInt32(); + for (uint32_t j = chunk_start; j < chunk_end; ++j) { + mem_fill_base( + mem_helper, + start_timestamp, + specific.slice_from(COL_INDEX(MultiObserveCols, read_data[j].base)) + ); + mem_fill_base( + mem_helper, + start_timestamp + 1, + specific.slice_from(COL_INDEX(MultiObserveCols, write_data[j].base)) + ); + start_timestamp += 2; + } + if (chunk_end >= CHUNK) { + mem_fill_base( + mem_helper, + start_timestamp, + specific.slice_from(COL_INDEX(MultiObserveCols, write_sponge_state.base)) + ); + start_timestamp += 1; + } + if (specific[COL_INDEX(MultiObserveCols, is_last)] == Fp::one()) { + mem_fill_base( + mem_helper, + start_timestamp, + specific.slice_from(COL_INDEX(MultiObserveCols, write_final_idx.base)) + ); + } + } + } }; template diff --git a/extensions/native/circuit/cuda/src/sumcheck.cu b/extensions/native/circuit/cuda/src/sumcheck.cu new file mode 100644 index 0000000000..99c365135b --- /dev/null +++ b/extensions/native/circuit/cuda/src/sumcheck.cu @@ -0,0 +1,126 @@ +#include "launcher.cuh" +#include "native/sumcheck.cuh" +#include "native/utils.cuh" +#include "primitives/trace_access.h" +#include "system/memory/controller.cuh" + +using namespace native; + +__device__ void fill_sumcheck_specific(RowSlice row, MemoryAuxColsFactory &mem_helper) { + RowSlice specific = row.slice_from(COL_INDEX(NativeSumcheckCols, specific)); + uint32_t start_timestamp = row[COL_INDEX(NativeSumcheckCols, start_timestamp)].asUInt32(); + + if (row[COL_INDEX(NativeSumcheckCols, header_row)] == Fp::one()) { + for (uint32_t i = 0; i < 7; ++i) { + mem_fill_base( + mem_helper, + start_timestamp + i, + specific.slice_from(COL_INDEX(HeaderSpecificCols, read_records[i].base)) + ); + } + uint32_t last_timestamp = row[COL_INDEX(NativeSumcheckCols, last_timestamp)].asUInt32(); + mem_fill_base( + mem_helper, + last_timestamp - 1, + specific.slice_from(COL_INDEX(HeaderSpecificCols, write_records.base)) + ); + } else if (row[COL_INDEX(NativeSumcheckCols, prod_row)] == Fp::one()) { + mem_fill_base( + mem_helper, + start_timestamp, + specific.slice_from(COL_INDEX(ProdSpecificCols, read_records[0].base)) + ); + if (row[COL_INDEX(NativeSumcheckCols, within_round_limit)] == Fp::one()) { + mem_fill_base( + mem_helper, + start_timestamp + 1, + specific.slice_from(COL_INDEX(ProdSpecificCols, read_records[1].base)) + ); + mem_fill_base( + mem_helper, + start_timestamp + 2, + specific.slice_from(COL_INDEX(ProdSpecificCols, write_record.base)) + ); + } + } else if (row[COL_INDEX(NativeSumcheckCols, logup_row)] == Fp::one()) { + mem_fill_base( + mem_helper, + start_timestamp, + specific.slice_from(COL_INDEX(LogupSpecificCols, read_records[0].base)) + ); + if (row[COL_INDEX(NativeSumcheckCols, within_round_limit)] == Fp::one()) { + mem_fill_base( + mem_helper, + start_timestamp + 1, + specific.slice_from(COL_INDEX(LogupSpecificCols, 
read_records[1].base)) + ); + mem_fill_base( + mem_helper, + start_timestamp + 2, + specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[0].base)) + ); + mem_fill_base( + mem_helper, + start_timestamp + 3, + specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[1].base)) + ); + } + } +} + +__global__ void native_sumcheck_tracegen( + Fp *trace, + size_t height, + size_t width, + const Fp *records, + size_t rows_used, + uint32_t *range_checker_ptr, + uint32_t range_checker_num_bins, + uint32_t timestamp_max_bits +) { + uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= height) { + return; + } + + RowSlice row(trace + idx, height); + if (idx < rows_used) { + const Fp *record = records + idx * width; + for (uint32_t col = 0; col < width; ++col) { + row[col] = record[col]; + } + MemoryAuxColsFactory mem_helper( + VariableRangeChecker(range_checker_ptr, range_checker_num_bins), timestamp_max_bits + ); + fill_sumcheck_specific(row, mem_helper); + } else { + row.fill_zero(0, width); + COL_WRITE_VALUE(row, NativeSumcheckCols, is_end, Fp::one()); + } +} + +extern "C" int _native_sumcheck_tracegen( + Fp *d_trace, + size_t height, + size_t width, + const Fp *d_records, + size_t rows_used, + uint32_t *d_range_checker, + uint32_t range_checker_num_bins, + uint32_t timestamp_max_bits +) { + assert((height & (height - 1)) == 0); + assert(width == sizeof(NativeSumcheckCols<uint8_t>)); + auto [grid, block] = kernel_launch_params(height); + native_sumcheck_tracegen<<<grid, block>>>( + d_trace, + height, + width, + d_records, + rows_used, + d_range_checker, + range_checker_num_bins, + timestamp_max_bits + ); + return CHECK_KERNEL(); +} diff --git a/extensions/native/circuit/src/cuda_abi.rs b/extensions/native/circuit/src/cuda_abi.rs index ad1a454d7b..5de9124f0d 100644 --- a/extensions/native/circuit/src/cuda_abi.rs +++ b/extensions/native/circuit/src/cuda_abi.rs @@ -235,6 +235,44 @@ pub mod poseidon2_cuda { } } +pub mod sumcheck_cuda { + use super::*; + + extern "C" { + pub fn _native_sumcheck_tracegen( + d_trace: *mut F, + height: usize, + width: usize, + d_records: *const F, + rows_used: usize, + d_range_checker: *mut u32, + range_checker_max_bins: u32, + timestamp_max_bits: u32, + ) -> i32; + } + + pub unsafe fn tracegen( + d_trace: &DeviceBuffer<F>, + height: usize, + width: usize, + d_records: &DeviceBuffer<F>, + rows_used: usize, + d_range_checker: &DeviceBuffer<F>, + timestamp_max_bits: u32, + ) -> Result<(), CudaError> { + CudaError::from_result(_native_sumcheck_tracegen( + d_trace.as_mut_ptr(), + height, + width, + d_records.as_ptr(), + rows_used, + d_range_checker.as_mut_ptr() as *mut u32, + d_range_checker.len() as u32, + timestamp_max_bits, + )) + } +} + pub mod native_loadstore_cuda { use super::*; diff --git a/extensions/native/circuit/src/extension/cuda.rs b/extensions/native/circuit/src/extension/cuda.rs index 9a433fce11..765ce8d6cc 100644 --- a/extensions/native/circuit/src/extension/cuda.rs +++ b/extensions/native/circuit/src/extension/cuda.rs @@ -17,6 +17,7 @@ use crate::{ jal_rangecheck::{JalRangeCheckAir, JalRangeCheckGpu}, loadstore::{NativeLoadStoreAir, NativeLoadStoreChipGpu}, poseidon2::{air::NativePoseidon2Air, NativePoseidon2ChipGpu}, + sumcheck::{air::NativeSumcheckAir, NativeSumcheckChipGpu}, CastFExtension, GpuBackend, Native, }; @@ -75,6 +76,10 @@ impl VmProverExtension let poseidon2 = NativePoseidon2ChipGpu::<1>::new(range_checker.clone(), timestamp_max_bits); inventory.add_executor_chip(poseidon2); + inventory.next_air::<NativeSumcheckAir>()?; + let sumcheck =
NativeSumcheckChipGpu::new(range_checker.clone(), timestamp_max_bits); + inventory.add_executor_chip(sumcheck); + Ok(()) } } diff --git a/extensions/native/circuit/src/extension/mod.rs b/extensions/native/circuit/src/extension/mod.rs index 98b2fe774d..9f3e2035ad 100644 --- a/extensions/native/circuit/src/extension/mod.rs +++ b/extensions/native/circuit/src/extension/mod.rs @@ -17,7 +17,8 @@ use openvm_instructions::{program::DEFAULT_PC_STEP, LocalOpcode, PhantomDiscrimi use openvm_native_compiler::{ CastfOpcode, FieldArithmeticOpcode, FieldExtensionOpcode, FriOpcode, NativeBranchEqualOpcode, NativeJalOpcode, NativeLoadStore4Opcode, NativeLoadStoreOpcode, NativePhantom, - NativeRangeCheckOpcode, Poseidon2Opcode, VerifyBatchOpcode, BLOCK_LOAD_STORE_SIZE, + NativeRangeCheckOpcode, Poseidon2Opcode, SumcheckOpcode, VerifyBatchOpcode, + BLOCK_LOAD_STORE_SIZE, }; use openvm_poseidon2_air::Poseidon2Config; use openvm_rv32im_circuit::BranchEqualCoreAir; @@ -61,6 +62,10 @@ use crate::{ chip::{NativePoseidon2Executor, NativePoseidon2Filler}, NativePoseidon2Chip, }, + sumcheck::{ + air::NativeSumcheckAir, + chip::{NativeSumcheckChip, NativeSumcheckExecutor, NativeSumcheckFiller}, + }, }; cfg_if::cfg_if! { @@ -94,6 +99,7 @@ pub enum NativeExecutor { FieldExtension(FieldExtensionExecutor), FriReducedOpening(FriReducedOpeningExecutor), VerifyBatch(NativePoseidon2Executor), + TowerVerify(NativeSumcheckExecutor), } impl VmExecutionExtension for Native { @@ -165,9 +171,16 @@ impl VmExecutionExtension for Native { VerifyBatchOpcode::VERIFY_BATCH.global_opcode(), Poseidon2Opcode::PERM_POS2.global_opcode(), Poseidon2Opcode::COMP_POS2.global_opcode(), + Poseidon2Opcode::MULTI_OBSERVE.global_opcode(), ], )?; + let tower_verify = NativeSumcheckExecutor::new(); + inventory.add_executor( + tower_verify, + [SumcheckOpcode::SUMCHECK_LAYER_EVAL.global_opcode()], + )?; + inventory.add_phantom_sub_executor( NativeHintInputSubEx, PhantomDiscriminant(NativePhantom::HintInput as u16), @@ -261,6 +274,9 @@ where ); inventory.add_air(verify_batch); + let tower_evaluate = NativeSumcheckAir::new(exec_bridge, memory_bridge); + inventory.add_air(tower_evaluate); + Ok(()) } } @@ -341,6 +357,9 @@ where ); inventory.add_executor_chip(poseidon2); + let tower_verify = NativeSumcheckChip::new(NativeSumcheckFiller::new(), mem_helper.clone()); + inventory.add_executor_chip(tower_verify); + Ok(()) } } diff --git a/extensions/native/circuit/src/field_extension/core.rs b/extensions/native/circuit/src/field_extension/core.rs index 5afaf74af5..a7d535a14b 100644 --- a/extensions/native/circuit/src/field_extension/core.rs +++ b/extensions/native/circuit/src/field_extension/core.rs @@ -254,10 +254,10 @@ pub(crate) struct FieldExtension; impl FieldExtension { pub(crate) fn add<V, E>(x: [V; EXT_DEG], y: [V; EXT_DEG]) -> [E; EXT_DEG] where - V: Copy, + V: Clone, V: Add<V, Output = E>, { - array::from_fn(|i| x[i] + y[i]) + array::from_fn(|i| x[i].clone() + y[i].clone()) } pub(crate) fn subtract<V, E>(x: [V; EXT_DEG], y: [V; EXT_DEG]) -> [E; EXT_DEG] diff --git a/extensions/native/circuit/src/fri/mod.rs b/extensions/native/circuit/src/fri/mod.rs index 1e1ec65cb8..4a0d3847d5 100644 --- a/extensions/native/circuit/src/fri/mod.rs +++ b/extensions/native/circuit/src/fri/mod.rs @@ -542,7 +542,7 @@ fn assert_array_eq, I2: Into, const } } -fn elem_to_ext<F: Field>(elem: F) -> [F; EXT_DEG] { +pub fn elem_to_ext<F: Field>(elem: F) -> [F; EXT_DEG] { let mut ret = [F::ZERO; EXT_DEG]; ret[0] = elem; ret } diff --git a/extensions/native/circuit/src/lib.rs b/extensions/native/circuit/src/lib.rs index
b5db4f0010..ce257c9c22 100644 --- a/extensions/native/circuit/src/lib.rs +++ b/extensions/native/circuit/src/lib.rs @@ -42,9 +42,11 @@ mod fri; mod jal_rangecheck; mod loadstore; mod poseidon2; +mod sumcheck; mod extension; pub use extension::*; +pub use field_extension::EXT_DEG; mod utils; #[cfg(any(test, feature = "test-utils"))] diff --git a/extensions/native/circuit/src/poseidon2/air.rs b/extensions/native/circuit/src/poseidon2/air.rs index adf2c09a62..baf18b06a3 100644 --- a/extensions/native/circuit/src/poseidon2/air.rs +++ b/extensions/native/circuit/src/poseidon2/air.rs @@ -1,14 +1,15 @@ use std::{array::from_fn, borrow::Borrow, sync::Arc}; +use itertools::Itertools; use openvm_circuit::{ arch::{ExecutionBridge, ExecutionState}, - system::memory::{offline_checker::MemoryBridge, MemoryAddress}, + system::memory::{offline_checker::MemoryBridge, MemoryAddress, CHUNK}, }; use openvm_circuit_primitives::utils::not; use openvm_instructions::LocalOpcode; use openvm_native_compiler::{ conversion::AS, - Poseidon2Opcode::{COMP_POS2, PERM_POS2}, + Poseidon2Opcode::{COMP_POS2, MULTI_OBSERVE, PERM_POS2}, VerifyBatchOpcode::VERIFY_BATCH, }; use openvm_poseidon2_air::{ @@ -26,10 +27,9 @@ use openvm_stark_backend::{ use crate::poseidon2::{ chip::{NUM_INITIAL_READS, NUM_SIMPLE_ACCESSES}, columns::{ - InsideRowSpecificCols, NativePoseidon2Cols, SimplePoseidonSpecificCols, + InsideRowSpecificCols, MultiObserveCols, NativePoseidon2Cols, SimplePoseidonSpecificCols, TopLevelSpecificCols, }, - CHUNK, }; #[derive(Clone, Debug)] @@ -90,6 +90,7 @@ impl Air incorporate_sibling, inside_row, simple, + multi_observe_row, end_inside_row, end_top_level, start_top_level, @@ -117,7 +118,9 @@ impl Air builder.assert_bool(incorporate_sibling); builder.assert_bool(inside_row); builder.assert_bool(simple); - let enabled = incorporate_row + incorporate_sibling + inside_row + simple; + builder.assert_bool(multi_observe_row); + let enabled = + incorporate_row + incorporate_sibling + inside_row + simple + multi_observe_row; builder.assert_bool(enabled.clone()); builder.assert_bool(end_inside_row); builder.when(end_inside_row).assert_one(inside_row); @@ -698,6 +701,312 @@ impl Air &write_data_2, ) .eval(builder, simple * is_permute); + + //// multi_observe constraints + let multi_observe_specific: &MultiObserveCols<AB::Var> = + specific[..MultiObserveCols::<AB::Var>::width()].borrow(); + let next_multi_observe_specific: &MultiObserveCols<AB::Var> = + next.specific[..MultiObserveCols::<AB::Var>::width()].borrow(); + let &MultiObserveCols { + pc, + final_timestamp_increment, + state_ptr, + input_ptr, + init_pos, + len, + is_first, + is_last, + curr_len, + start_idx, + end_idx, + aux_after_start, + aux_before_end, + aux_read_enabled, + read_data, + write_data, + data, + should_permute, + write_sponge_state, + write_final_idx, + input_register_1, + input_register_2, + input_register_3, + output_register, + } = multi_observe_specific; + + builder.when(multi_observe_row).assert_bool(is_first); + builder.when(multi_observe_row).assert_bool(is_last); + builder.when(multi_observe_row).assert_bool(should_permute); + + self.execution_bridge + .execute_and_increment_pc( + AB::F::from_canonical_usize(MULTI_OBSERVE.global_opcode().as_usize()), + [ + output_register.into(), + input_register_1.into(), + input_register_2.into(), + self.address_space.into(), + self.address_space.into(), + input_register_3.into(), + ], + ExecutionState::new(pc, very_first_timestamp), + final_timestamp_increment, + ) + .eval(builder, multi_observe_row * is_first); + + self.memory_bridge + .read(
+ MemoryAddress::new(self.address_space, output_register), + [state_ptr], + very_first_timestamp, + &read_data[0], + ) + .eval(builder, multi_observe_row * is_first); + + self.memory_bridge + .read( + MemoryAddress::new(self.address_space, input_register_1), + [init_pos], + very_first_timestamp + AB::F::ONE, + &read_data[1], + ) + .eval(builder, multi_observe_row * is_first); + + self.memory_bridge + .read( + MemoryAddress::new(self.address_space, input_register_2), + [input_ptr], + very_first_timestamp + AB::F::TWO, + &read_data[2], + ) + .eval(builder, multi_observe_row * is_first); + + self.memory_bridge + .read( + MemoryAddress::new(self.address_space, input_register_3), + [len], + very_first_timestamp + AB::F::from_canonical_usize(3), + &read_data[3], + ) + .eval(builder, multi_observe_row * is_first); + + for i in 0..CHUNK { + let i_var = AB::F::from_canonical_usize(i); + self.memory_bridge + .read( + MemoryAddress::new( + self.address_space, + input_ptr + curr_len + i_var - start_idx, + ), + [data[i]], + start_timestamp + i_var * AB::F::TWO - start_idx * AB::F::TWO, + &read_data[i], + ) + .eval(builder, multi_observe_row * aux_read_enabled[i]); + + self.memory_bridge + .write( + MemoryAddress::new(self.address_space, state_ptr + i_var), + [data[i]], + start_timestamp + i_var * AB::F::TWO - start_idx * AB::F::TWO + AB::F::ONE, + &write_data[i], + ) + .eval(builder, multi_observe_row * aux_read_enabled[i]); + } + + for i in 0..(CHUNK - 1) { + builder + .when(multi_observe_row) + .when(aux_after_start[i]) + .assert_one(aux_after_start[i + 1]); + } + + for i in 1..CHUNK { + builder + .when(multi_observe_row) + .when(aux_before_end[i]) + .assert_one(aux_before_end[i - 1]); + } + + for i in 0..CHUNK { + builder + .when(multi_observe_row) + .assert_bool(aux_after_start[i]); + builder + .when(multi_observe_row) + .assert_bool(aux_before_end[i]); + builder + .when(multi_observe_row) + .when(is_first) + .assert_zero(aux_read_enabled[i]); + builder + .when(multi_observe_row) + .assert_eq(aux_after_start[i] * aux_before_end[i], aux_read_enabled[i]); + } + + builder + .when(multi_observe_row) + .when(not(is_first)) + .assert_eq( + aux_after_start[0] + + aux_after_start[1] + + aux_after_start[2] + + aux_after_start[3] + + aux_after_start[4] + + aux_after_start[5] + + aux_after_start[6] + + aux_after_start[7], + AB::Expr::from_canonical_usize(CHUNK) - start_idx.into(), + ); + + builder + .when(multi_observe_row) + .when(not(is_first)) + .assert_eq( + aux_before_end[0] + + aux_before_end[1] + + aux_before_end[2] + + aux_before_end[3] + + aux_before_end[4] + + aux_before_end[5] + + aux_before_end[6] + + aux_before_end[7], + end_idx, + ); + + let full_sponge_output = from_fn::<_, { CHUNK * 2 }, _>(|i| { + local.inner.ending_full_rounds[BABY_BEAR_POSEIDON2_HALF_FULL_ROUNDS - 1].post[i] + }); + + self.memory_bridge + .write( + MemoryAddress::new(self.address_space, state_ptr), + full_sponge_output, + start_timestamp + (end_idx - start_idx) * AB::F::TWO, + &write_sponge_state, + ) + .eval(builder, multi_observe_row * should_permute); + + // enforce that prev_data is permutation input + write_sponge_state + .prev_data() + .iter() + .zip_eq(local.inner.inputs.iter()) + .for_each(|(a, b)| { + builder + .when(multi_observe_row * should_permute) + .assert_eq(*a, *b); + }); + + builder + .when(multi_observe_row) + .when(aux_read_enabled[CHUNK - 1]) + .assert_one(should_permute); + + // final_idx = aux_read_enabled[CHUNK-1] * 0 + (1 - aux_read_enabled[CHUNK-1]) * end_idx + let final_idx = 
aux_read_enabled[CHUNK - 1] * AB::Expr::ZERO + + (AB::Expr::ONE - aux_read_enabled[CHUNK - 1]) * end_idx; + self.memory_bridge + .write( + MemoryAddress::new(self.address_space, input_register_1), + [final_idx], + start_timestamp + (end_idx - start_idx) * AB::F::TWO + should_permute, + &write_final_idx, + ) + .eval(builder, multi_observe_row * is_last); + + // Field transitions + builder + .when(next.multi_observe_row) + .when(not(next_multi_observe_specific.is_first)) + .assert_eq( + next_multi_observe_specific.curr_len, + multi_observe_specific.curr_len + end_idx - start_idx, + ); + + // Boundary conditions + builder + .when(multi_observe_row) + .when(is_first) + .assert_zero(curr_len); + + builder + .when(multi_observe_row) + .when(is_last) + .assert_eq(curr_len + (end_idx - start_idx), len); + + builder + .when(next.multi_observe_row) + .when(not(next_multi_observe_specific.is_first)) + .assert_one(multi_observe_row); + + builder + .when(multi_observe_row) + .when(not(is_last)) + .assert_one(next.multi_observe_row); + + // Fields remain same across same instance + builder + .when(next.multi_observe_row) + .when(not(next_multi_observe_specific.is_first)) + .assert_eq(state_ptr, next_multi_observe_specific.state_ptr); + + builder + .when(next.multi_observe_row) + .when(not(next_multi_observe_specific.is_first)) + .assert_eq(input_ptr, next_multi_observe_specific.input_ptr); + + builder + .when(next.multi_observe_row) + .when(not(next_multi_observe_specific.is_first)) + .assert_eq(init_pos, next_multi_observe_specific.init_pos); + + builder + .when(next.multi_observe_row) + .when(not(next_multi_observe_specific.is_first)) + .assert_eq(len, next_multi_observe_specific.len); + + builder + .when(next.multi_observe_row) + .when(not(next_multi_observe_specific.is_first)) + .assert_eq( + input_register_1, + next_multi_observe_specific.input_register_1, + ); + + builder + .when(next.multi_observe_row) + .when(not(next_multi_observe_specific.is_first)) + .assert_eq( + input_register_2, + next_multi_observe_specific.input_register_2, + ); + + builder + .when(next.multi_observe_row) + .when(not(next_multi_observe_specific.is_first)) + .assert_eq( + input_register_3, + next_multi_observe_specific.input_register_3, + ); + + builder + .when(next.multi_observe_row) + .when(not(next_multi_observe_specific.is_first)) + .assert_eq(output_register, next_multi_observe_specific.output_register); + + // Timestamp constraints + builder + .when(next.multi_observe_row) + .when(not(next_multi_observe_specific.is_first)) + .assert_eq(very_first_timestamp, next.very_first_timestamp); + + /* + builder + .when(next.multi_observe_row) + .when(not(next_multi_observe_specific.is_first)) + .assert_eq(next.start_timestamp, start_timestamp + is_first * AB::F::from_canonical_usize(4) + (end_idx - start_idx) * AB::F::TWO + should_permute * AB::F::TWO); + */ } } diff --git a/extensions/native/circuit/src/poseidon2/chip.rs b/extensions/native/circuit/src/poseidon2/chip.rs index 331cb9dbd0..770efc7307 100644 --- a/extensions/native/circuit/src/poseidon2/chip.rs +++ b/extensions/native/circuit/src/poseidon2/chip.rs @@ -3,16 +3,14 @@ use std::borrow::{Borrow, BorrowMut}; use openvm_circuit::{ arch::*, system::{ - memory::{offline_checker::MemoryBaseAuxCols, online::TracingMemory, MemoryAuxColsFactory}, - native_adapter::util::{ - memory_read_native, tracing_read_native, tracing_write_native_inplace, - }, + memory::{online::TracingMemory, MemoryAuxColsFactory}, + native_adapter::util::{memory_read_native, 
tracing_write_native_inplace}, }, }; use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode}; use openvm_native_compiler::{ conversion::AS, - Poseidon2Opcode::{COMP_POS2, PERM_POS2}, + Poseidon2Opcode::{COMP_POS2, MULTI_OBSERVE, PERM_POS2}, VerifyBatchOpcode::VERIFY_BATCH, }; use openvm_poseidon2_air::{Poseidon2Config, Poseidon2SubChip, Poseidon2SubCols}; @@ -23,12 +21,16 @@ use openvm_stark_backend::{ p3_maybe_rayon::prelude::{IntoParallelIterator, ParallelSliceMut, *}, }; -use crate::poseidon2::{ - columns::{ - InsideRowSpecificCols, NativePoseidon2Cols, SimplePoseidonSpecificCols, - TopLevelSpecificCols, +use crate::{ + mem_fill_helper, + poseidon2::{ + columns::{ + InsideRowSpecificCols, MultiObserveCols, NativePoseidon2Cols, + SimplePoseidonSpecificCols, TopLevelSpecificCols, + }, + CHUNK, }, - CHUNK, + tracing_read_native_helper, }; #[derive(Clone)] @@ -644,6 +646,183 @@ where if !self.optimistic { assert_eq!(commit, root); } + } else if instruction.opcode == MULTI_OBSERVE.global_opcode() { + let &Instruction { + a: state_ptr_register, + b: init_pos_register, + c: input_ptr_register, + d: register_address_space, + e: data_address_space, + f: len_register, + .. + } = instruction; + + assert_eq!( + register_address_space, + F::from_canonical_u32(AS::Native as u32) + ); + assert_eq!(data_address_space, F::from_canonical_u32(AS::Native as u32)); + + let [init_pos]: [F; 1] = + memory_read_native(state.memory.data(), init_pos_register.as_canonical_u32()); + let [input_len]: [F; 1] = + memory_read_native(state.memory.data(), len_register.as_canonical_u32()); + + let mut len = input_len.as_canonical_u32() as usize; + let mut pos = init_pos.as_canonical_u32() as usize; + let mut chunks: Vec<(usize, usize)> = vec![]; + + const NUM_HEAD_ACCESSES: usize = 4; + let mut final_timestamp_inc = NUM_HEAD_ACCESSES; + while len > 0 { + if len >= (CHUNK - pos) { + chunks.push((pos, CHUNK)); + len -= CHUNK - pos; + final_timestamp_inc += 2 * (CHUNK - pos) + 1; + pos = 0; + } else { + chunks.push((pos, pos + len)); + final_timestamp_inc += 2 * len; + pos += len; + len = 0; + } + } + final_timestamp_inc += 1; // write back to init_pos_register + + let allocated_rows = arena + .alloc(MultiRowLayout::new(NativePoseidon2Metadata { + num_rows: 1 + chunks.len(), + })) + .0; + let head_cols = &mut allocated_rows[0]; + let head_multi_observe_cols: &mut MultiObserveCols<F> = + head_cols.specific[..MultiObserveCols::<F>::width()].borrow_mut(); + + let [state_ptr]: [F; 1] = tracing_read_native_helper( + state.memory, + state_ptr_register.as_canonical_u32(), + head_multi_observe_cols.read_data[0].as_mut(), + ); + let [init_pos]: [F; 1] = tracing_read_native_helper( + state.memory, + init_pos_register.as_canonical_u32(), + head_multi_observe_cols.read_data[1].as_mut(), + ); + let [input_ptr]: [F; 1] = tracing_read_native_helper( + state.memory, + input_ptr_register.as_canonical_u32(), + head_multi_observe_cols.read_data[2].as_mut(), + ); + let [input_len]: [F; 1] = tracing_read_native_helper( + state.memory, + len_register.as_canonical_u32(), + head_multi_observe_cols.read_data[3].as_mut(), + ); + + let input_ptr_u32 = input_ptr.as_canonical_u32(); + let state_ptr_u32 = state_ptr.as_canonical_u32(); + + let init_timestamp = F::from_canonical_u32(init_timestamp_u32); + + for (i, cols) in allocated_rows.iter_mut().enumerate() { + let multi_observe_cols: &mut MultiObserveCols<F> = + cols.specific[..MultiObserveCols::<F>::width()].borrow_mut(); + multi_observe_cols.input_register_1 =
init_pos_register; + multi_observe_cols.input_register_2 = input_ptr_register; + multi_observe_cols.input_register_3 = len_register; + multi_observe_cols.output_register = state_ptr_register; + multi_observe_cols.init_pos = init_pos; + multi_observe_cols.input_ptr = input_ptr; + multi_observe_cols.state_ptr = state_ptr; + multi_observe_cols.len = input_len; + + cols.multi_observe_row = F::ONE; + cols.very_first_timestamp = init_timestamp; + + if i == 0 { + // head row + cols.inner.export = F::from_canonical_u32(1 + chunks.len() as u32); + multi_observe_cols.pc = F::from_canonical_u32(*state.pc); + multi_observe_cols.final_timestamp_increment = + F::from_canonical_usize(final_timestamp_inc); + multi_observe_cols.is_first = F::ONE; + multi_observe_cols.is_last = F::ZERO; + multi_observe_cols.curr_len = F::ZERO; + multi_observe_cols.should_permute = F::ZERO; + } + } + + let mut input_idx: usize = 0; + let mut cur_timestamp = init_timestamp_u32 + NUM_HEAD_ACCESSES as u32; + let num_chunks = chunks.len(); + for (i, ((chunk_start, chunk_end), cols)) in chunks + .into_iter() + .zip(allocated_rows.iter_mut().skip(1)) + .enumerate() + { + let multi_observe_cols: &mut MultiObserveCols<F> = + cols.specific[..MultiObserveCols::<F>::width()].borrow_mut(); + + cols.start_timestamp = F::from_canonical_u32(cur_timestamp); + + multi_observe_cols.start_idx = F::from_canonical_usize(chunk_start); + multi_observe_cols.end_idx = F::from_canonical_usize(chunk_end); + + multi_observe_cols.is_first = F::ZERO; + multi_observe_cols.is_last = if i == num_chunks - 1 { F::ONE } else { F::ZERO }; + multi_observe_cols.curr_len = F::from_canonical_usize(input_idx); + + for j in chunk_start..CHUNK { + multi_observe_cols.aux_after_start[j] = F::ONE; + } + for j in 0..chunk_end { + multi_observe_cols.aux_before_end[j] = F::ONE; + } + for j in chunk_start..chunk_end { + let n_f: [F; 1] = tracing_read_native_helper( + state.memory, + input_ptr_u32 + input_idx as u32, + multi_observe_cols.read_data[j].as_mut(), + ); + multi_observe_cols.aux_read_enabled[j] = F::ONE; + tracing_write_native_inplace( + state.memory, + state_ptr_u32 + j as u32, + n_f, + &mut multi_observe_cols.write_data[j], + ); + multi_observe_cols.data[j] = n_f[0]; + input_idx += 1; + cur_timestamp += 2; + } + + let permutation_input: [F; 16] = + memory_read_native(state.memory.data(), state_ptr_u32); + if chunk_end >= CHUNK { + multi_observe_cols.should_permute = F::ONE; + cols.inner.inputs.clone_from_slice(&permutation_input); + let output = self.subchip.permute(permutation_input); + tracing_write_native_inplace( + state.memory, + state_ptr_u32, + std::array::from_fn(|i| output[i]), + &mut multi_observe_cols.write_sponge_state, + ); + cur_timestamp += 1; + } else { + multi_observe_cols.should_permute = F::ZERO; + cols.inner.inputs.clone_from_slice(&permutation_input); + } + if i == num_chunks - 1 { + let final_idx = F::from_canonical_usize(chunk_end % CHUNK); + tracing_write_native_inplace( + state.memory, + init_pos_register.as_canonical_u32(), + [final_idx], + &mut multi_observe_cols.write_final_idx, + ); + } + } } else { unreachable!() } @@ -659,6 +838,8 @@ where String::from("PERM_POS2") } else if opcode == COMP_POS2.global_opcode().as_usize() { String::from("COMP_POS2") + } else if opcode == MULTI_OBSERVE.global_opcode().as_usize() { + String::from("MULTI_OBSERVE") } else { unreachable!("unsupported opcode: {}", opcode) } @@ -686,6 +867,10 @@ impl TraceFiller let (curr, rest) = if cols.simple.is_one() { row_idx += 1; row_slice.split_at_mut(width) + } else if
cols.multi_observe_row.is_one() { + let total_num_row = cols.inner.export.as_canonical_u32() as usize; + row_idx += total_num_row; + row_slice.split_at_mut(total_num_row * width) } else { let num_non_inside_row = cols.inner.export.as_canonical_u32() as usize; let start = (num_non_inside_row - 1) * width; @@ -702,6 +887,8 @@ impl TraceFiller let cols: &NativePoseidon2Cols<F, SBOX_REGISTERS> = chunk_slice[..width].borrow(); if cols.simple.is_one() { self.fill_simple_chunk(mem_helper, chunk_slice); + } else if cols.multi_observe_row.is_one() { + self.fill_multi_observe_chunk(mem_helper, chunk_slice); } else { self.fill_verify_batch_chunk(mem_helper, chunk_slice); } @@ -959,29 +1146,99 @@ impl NativePoseidon2Filler + fn fill_multi_observe_chunk( + &self, + mem_helper: &MemoryAuxColsFactory<F>, + chunk_slice: &mut [F], + ) { + let inner_width = self.subchip.air.width(); + let width = NativePoseidon2Cols::<F, SBOX_REGISTERS>::width(); + let head_cols: &mut NativePoseidon2Cols<F, SBOX_REGISTERS> = + chunk_slice[..width].borrow_mut(); + let num_rows = head_cols.inner.export.as_canonical_u32() as usize; + + let head_multi_observe_cols: &mut MultiObserveCols<F> = + head_cols.specific[..MultiObserveCols::<F>::width()].borrow_mut(); + let start_timestamp_u32 = head_cols.very_first_timestamp.as_canonical_u32(); + + // state_ptr, init_pos, input_ptr, len + mem_fill_helper( + mem_helper, + start_timestamp_u32, + head_multi_observe_cols.read_data[0].as_mut(), + ); + mem_fill_helper( + mem_helper, + start_timestamp_u32 + 1, + head_multi_observe_cols.read_data[1].as_mut(), + ); + mem_fill_helper( + mem_helper, + start_timestamp_u32 + 2, + head_multi_observe_cols.read_data[2].as_mut(), + ); + mem_fill_helper( + mem_helper, + start_timestamp_u32 + 3, + head_multi_observe_cols.read_data[3].as_mut(), + ); + + // generate permutation traces for each row + for row_idx in 0..num_rows { + let cols: &NativePoseidon2Cols<F, SBOX_REGISTERS> = chunk_slice + [row_idx * width..(row_idx + 1) * width] + .as_ref() + .borrow(); + let inner_cols = &self.subchip.generate_trace(vec![cols.inner.inputs]).values; + chunk_slice[row_idx * width..(row_idx + 1) * width][..inner_width] + .copy_from_slice(inner_cols); + } + + for row_idx in 1..num_rows { + let cols: &mut NativePoseidon2Cols<F, SBOX_REGISTERS> = + chunk_slice[row_idx * width..(row_idx + 1) * width].borrow_mut(); + let multi_observe_cols: &mut MultiObserveCols<F> = + cols.specific[..MultiObserveCols::<F>::width()].borrow_mut(); + + let mut start_timestamp_u32 = cols.start_timestamp.as_canonical_u32(); + let chunk_start = multi_observe_cols.start_idx.as_canonical_u32(); + let chunk_end = multi_observe_cols.end_idx.as_canonical_u32(); + + for j in chunk_start..chunk_end { + mem_fill_helper( + mem_helper, + start_timestamp_u32, + multi_observe_cols.read_data[j as usize].as_mut(), + ); + mem_fill_helper( + mem_helper, + start_timestamp_u32 + 1, + multi_observe_cols.write_data[j as usize].as_mut(), + ); + + start_timestamp_u32 += 2; + } + + if chunk_end >= CHUNK as u32 { + mem_fill_helper( + mem_helper, + start_timestamp_u32, + multi_observe_cols.write_sponge_state.as_mut(), + ); + start_timestamp_u32 += 1; + } + if row_idx == num_rows - 1 { + mem_fill_helper( + mem_helper, + start_timestamp_u32, + multi_observe_cols.write_final_idx.as_mut(), + ); + } + } + } + #[inline(always)] fn poseidon2_output_from_trace(inner: &Poseidon2SubCols<F, SBOX_REGISTERS>) -> &[F; 2 * CHUNK] { &inner.ending_full_rounds.last().unwrap().post } } - -fn tracing_read_native_helper<F: PrimeField32, const BLOCK_SIZE: usize>( - memory: &mut TracingMemory, - ptr: u32, - base_aux: &mut MemoryBaseAuxCols<F>, -) -> [F; BLOCK_SIZE] { - let mut prev_ts = 0; - let ret = tracing_read_native(memory, ptr, &mut prev_ts);
base_aux.set_prev(F::from_canonical_u32(prev_ts)); - ret -} - -/// Fill `MemoryBaseAuxCols`, assuming that the `prev_timestamp` is already set in `base_aux`. -fn mem_fill_helper<F: PrimeField32>( - mem_helper: &MemoryAuxColsFactory<F>, - timestamp: u32, - base_aux: &mut MemoryBaseAuxCols<F>, -) { - let prev_ts = base_aux.prev_timestamp.as_canonical_u32(); - mem_helper.fill(prev_ts, timestamp, base_aux); -} diff --git a/extensions/native/circuit/src/poseidon2/columns.rs b/extensions/native/circuit/src/poseidon2/columns.rs index 6c47c23245..abb8db54a2 100644 --- a/extensions/native/circuit/src/poseidon2/columns.rs +++ b/extensions/native/circuit/src/poseidon2/columns.rs @@ -28,6 +28,8 @@ pub struct NativePoseidon2Cols<T, const SBOX_REGISTERS: usize> { pub inside_row: T, /// Indicates that this row is a simple row. pub simple: T, + /// Indicates that this row is a multi_observe row. + pub multi_observe_row: T, /// Indicates the last row in an inside-row block. pub end_inside_row: T, @@ -60,15 +62,16 @@ pub struct NativePoseidon2Cols<T, const SBOX_REGISTERS: usize> { /// indicates that cell `i + 1` inside a chunk is exhausted. pub is_exhausted: [T; CHUNK - 1], - pub specific: [T; max3( + pub specific: [T; max4( TopLevelSpecificCols::<T>::width(), InsideRowSpecificCols::<T>::width(), SimplePoseidonSpecificCols::<T>::width(), + MultiObserveCols::<T>::width(), )], } -const fn max3(a: usize, b: usize, c: usize) -> usize { - const_max(a, const_max(b, c)) +const fn max4(a: usize, b: usize, c: usize, d: usize) -> usize { + const_max(a, const_max(b, const_max(c, d))) } #[repr(C)] #[derive(AlignedBorrow)] @@ -200,3 +203,43 @@ pub struct SimplePoseidonSpecificCols<T> { pub write_data_1: MemoryWriteAuxCols<T, CHUNK>, pub write_data_2: MemoryWriteAuxCols<T, CHUNK>, } + +#[repr(C)] +#[derive(AlignedBorrow, Copy, Clone)] +pub struct MultiObserveCols<T> { + // Program states + pub pc: T, + pub final_timestamp_increment: T, + + // Initial reads from registers + // They are same across same instance of multi_observe + pub state_ptr: T, + pub input_ptr: T, + pub init_pos: T, + pub len: T, + pub input_register_1: T, + pub input_register_2: T, + pub input_register_3: T, + pub output_register: T, + + pub is_first: T, + pub is_last: T, + pub curr_len: T, + pub start_idx: T, + pub end_idx: T, + pub aux_after_start: [T; CHUNK], + pub aux_before_end: [T; CHUNK], + pub aux_read_enabled: [T; CHUNK], + + // Transcript observation + pub read_data: [MemoryReadAuxCols<T>; CHUNK], + pub write_data: [MemoryWriteAuxCols<T, 1>; CHUNK], + pub data: [T; CHUNK], + + // Permutation + pub should_permute: T, + pub write_sponge_state: MemoryWriteAuxCols<T, { 2 * CHUNK }>, + + // Final write back and registers + pub write_final_idx: MemoryWriteAuxCols<T, 1>, +} diff --git a/extensions/native/circuit/src/poseidon2/cuda.rs b/extensions/native/circuit/src/poseidon2/cuda.rs index 107589ef49..0425cfda18 100644 --- a/extensions/native/circuit/src/poseidon2/cuda.rs +++ b/extensions/native/circuit/src/poseidon2/cuda.rs @@ -53,6 +53,9 @@ impl Chip chunk_start.push(row_idx as u32); if cols.simple.is_one() { row_idx += 1; + } else if cols.multi_observe_row.is_one() { + let num_rows = cols.inner.export.as_canonical_u32() as usize; + row_idx += num_rows; } else { let num_non_inside_row = cols.inner.export.as_canonical_u32() as usize; let non_inside_start = start + (num_non_inside_row - 1) * width; diff --git a/extensions/native/circuit/src/poseidon2/execution.rs b/extensions/native/circuit/src/poseidon2/execution.rs index 20889e4186..a0c1fc72a2 100644 --- a/extensions/native/circuit/src/poseidon2/execution.rs +++ b/extensions/native/circuit/src/poseidon2/execution.rs @@ -5,10 +5,12 @@ use std::{ use
openvm_circuit::{arch::*, system::memory::online::GuestMemory}; use openvm_circuit_primitives::AlignedBytesBorrow; -use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode}; +use openvm_instructions::{ + instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode, NATIVE_AS, +}; use openvm_native_compiler::{ conversion::AS, - Poseidon2Opcode::{COMP_POS2, PERM_POS2}, + Poseidon2Opcode::{COMP_POS2, MULTI_OBSERVE, PERM_POS2}, VerifyBatchOpcode::VERIFY_BATCH, }; use openvm_poseidon2_air::Poseidon2SubChip; @@ -29,6 +31,16 @@ struct Pos2PreCompute<'a, F: Field, const SBOX_REGISTERS: usize> { input_register_2: u32, } +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct MultiObservePreCompute<'a, F: Field, const SBOX_REGISTERS: usize> { + subchip: &'a Poseidon2SubChip<F, SBOX_REGISTERS>, + pub init_pos_register: u32, + pub input_ptr_register: u32, + pub len_register: u32, + pub state_ptr_register: u32, +} + #[derive(AlignedBytesBorrow, Clone)] #[repr(C)] struct VerifyBatchPreCompute<'a, F: Field, const SBOX_REGISTERS: usize> { @@ -87,6 +99,51 @@ impl<'a, F: PrimeField32, const SBOX_REGISTERS: usize> NativePoseidon2Executor + #[inline(always)] + fn pre_compute_multi_observe_impl( + &'a self, + pc: u32, + inst: &Instruction<F>, + multi_observe_data: &mut MultiObservePreCompute<'a, F, SBOX_REGISTERS>, + ) -> Result<(), StaticProgramError> { + let &Instruction { + opcode, + a, + b, + c, + d, + e, + f, + .. + } = inst; + + if opcode != MULTI_OBSERVE.global_opcode() { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + + let a = a.as_canonical_u32(); + let b = b.as_canonical_u32(); + let c = c.as_canonical_u32(); + let d = d.as_canonical_u32(); + let e = e.as_canonical_u32(); + let f = f.as_canonical_u32(); + + if d != AS::Native as u32 { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + if e != AS::Native as u32 { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + + multi_observe_data.subchip = &self.subchip; + multi_observe_data.state_ptr_register = a; + multi_observe_data.init_pos_register = b; + multi_observe_data.input_ptr_register = c; + multi_observe_data.len_register = f; + + Ok(()) + } + #[inline(always)] fn pre_compute_verify_batch_impl( &'a self, @@ -142,6 +199,7 @@ impl<'a, F: PrimeField32, const SBOX_REGISTERS: usize> NativePoseidon2Executor) } + } else if $opcode == MULTI_OBSERVE.global_opcode() { + let multi_observe_data: &mut MultiObservePreCompute<F, SBOX_REGISTERS> = + $data.borrow_mut(); + $executor.pre_compute_multi_observe_impl($pc, $inst, multi_observe_data)?; + Ok($execute_multi_observe_impl::<_, _, SBOX_REGISTERS>) } else { let verify_batch_data: &mut VerifyBatchPreCompute<F, SBOX_REGISTERS> = $data.borrow_mut(); @@ -166,13 +229,18 @@ macro_rules! dispatch1 { }; } +fn max3(a: usize, b: usize, c: usize) -> usize { + std::cmp::max(a, std::cmp::max(b, c)) +} + impl<F: PrimeField32, const SBOX_REGISTERS: usize> Executor<F> for NativePoseidon2Executor<SBOX_REGISTERS> { #[inline(always)] fn pre_compute_size(&self) -> usize { - std::cmp::max( + max3( size_of::<Pos2PreCompute<F, SBOX_REGISTERS>>(), + size_of::<MultiObservePreCompute<F, SBOX_REGISTERS>>(), size_of::<VerifyBatchPreCompute<F, SBOX_REGISTERS>>(), ) } @@ -187,6 +255,7 @@ impl Executor ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError> { dispatch1!( execute_pos2_e1_impl, + execute_multi_observe_e1_impl, execute_verify_batch_e1_impl, self, inst.opcode, @@ -205,6 +274,7 @@ impl Executor ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError> { dispatch1!( execute_pos2_e1_handler, + execute_multi_observe_e1_handler, execute_verify_batch_e1_handler, self, inst.opcode, @@ -218,6 +288,7 @@ impl Executor macro_rules! dispatch2 { ( $execute_pos2_impl:ident, + $execute_multi_observe_impl:ident, $execute_verify_batch_impl:ident, $executor:ident, $opcode:expr, @@ -237,6 +308,13 @@ macro_rules!
dispatch2 { } else { Ok($execute_pos2_impl::<_, _, SBOX_REGISTERS, false>) } + } else if $opcode == MULTI_OBSERVE.global_opcode() { + let pre_compute: &mut E2PreCompute<MultiObservePreCompute<F, SBOX_REGISTERS>> = + $data.borrow_mut(); + pre_compute.chip_idx = $chip_idx as u32; + + $executor.pre_compute_multi_observe_impl($pc, $inst, &mut pre_compute.data)?; + Ok($execute_multi_observe_impl::<_, _, SBOX_REGISTERS>) } else { let pre_compute: &mut E2PreCompute<VerifyBatchPreCompute<F, SBOX_REGISTERS>> = $data.borrow_mut(); @@ -253,8 +331,9 @@ impl MeteredExecutor { #[inline(always)] fn metered_pre_compute_size(&self) -> usize { - std::cmp::max( + max3( size_of::<E2PreCompute<Pos2PreCompute<F, SBOX_REGISTERS>>>(), + size_of::<E2PreCompute<MultiObservePreCompute<F, SBOX_REGISTERS>>>(), size_of::<E2PreCompute<VerifyBatchPreCompute<F, SBOX_REGISTERS>>>(), ) } @@ -270,6 +349,7 @@ impl MeteredExecutor ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError> { dispatch2!( execute_pos2_e2_impl, + execute_multi_observe_e2_impl, execute_verify_batch_e2_impl, self, inst.opcode, @@ -290,6 +370,7 @@ impl MeteredExecutor ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError> { dispatch2!( execute_pos2_e2_handler, + execute_multi_observe_e2_handler, execute_verify_batch_e2_handler, self, inst.opcode, @@ -345,6 +426,49 @@ unsafe fn execute_pos2_e2_impl< .on_height_change(pre_compute.chip_idx as usize, height); } +#[create_handler] +#[inline(always)] +unsafe fn execute_multi_observe_e1_impl< + F: PrimeField32, + CTX: ExecutionCtxTrait, + const SBOX_REGISTERS: usize, +>( + pre_compute: &[u8], + instret: &mut u64, + pc: &mut u32, + _arg: u64, + exec_state: &mut VmExecState<F, GuestMemory, CTX>, +) { + let pre_compute: &MultiObservePreCompute<F, SBOX_REGISTERS> = pre_compute.borrow(); + execute_multi_observe_e12_impl::<_, _, SBOX_REGISTERS>(pre_compute, instret, pc, exec_state); +} + +#[create_handler] +#[inline(always)] +unsafe fn execute_multi_observe_e2_impl< + F: PrimeField32, + CTX: MeteredExecutionCtxTrait, + const SBOX_REGISTERS: usize, +>( + pre_compute: &[u8], + instret: &mut u64, + pc: &mut u32, + _arg: u64, + exec_state: &mut VmExecState<F, GuestMemory, CTX>, +) { + let pre_compute: &E2PreCompute<MultiObservePreCompute<F, SBOX_REGISTERS>> = + pre_compute.borrow(); + let height = execute_multi_observe_e12_impl::<_, _, SBOX_REGISTERS>( + &pre_compute.data, + instret, + pc, + exec_state, + ); + exec_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, height); +} + #[create_handler] #[inline(always)] unsafe fn execute_verify_batch_e1_impl< @@ -452,6 +576,76 @@ unsafe fn execute_pos2_e12_impl< 1 } +#[inline(always)] +unsafe fn execute_multi_observe_e12_impl< + F: PrimeField32, + CTX: ExecutionCtxTrait, + const SBOX_REGISTERS: usize, +>( + pre_compute: &MultiObservePreCompute<F, SBOX_REGISTERS>, + instret: &mut u64, + pc: &mut u32, + exec_state: &mut VmExecState<F, GuestMemory, CTX>, +) -> u32 { + let subchip = pre_compute.subchip; + + let [sponge_ptr]: [F; 1] = + exec_state.vm_read(AS::Native as u32, pre_compute.state_ptr_register); + let [init_pos]: [F; 1] = exec_state.vm_read(AS::Native as u32, pre_compute.init_pos_register); + let [input_ptr]: [F; 1] = exec_state.vm_read(AS::Native as u32, pre_compute.input_ptr_register); + let [len]: [F; 1] = exec_state.vm_read(AS::Native as u32, pre_compute.len_register); + + let mut len = len.as_canonical_u32() as usize; + let mut pos = init_pos.as_canonical_u32() as usize; + let input_ptr_u32 = input_ptr.as_canonical_u32(); + let sponge_ptr_u32 = sponge_ptr.as_canonical_u32(); + let mut height = 0; + + // split input into chunks s.t.
each chunk fills the RATE portion of sponge state + let mut observation_chunks: Vec<(usize, usize)> = vec![]; + while len > 0 { + if len >= (CHUNK - pos) { + observation_chunks.push((pos, CHUNK)); + len -= CHUNK - pos; + pos = 0; + } else { + observation_chunks.push((pos, pos + len)); + pos += len; + len = 0; + } + } + let final_idx = observation_chunks.last().map(|(_, end)| *end % CHUNK); + + height += 1; + let mut input_idx = 0; + + for (chunk_start, chunk_end) in observation_chunks { + for j in chunk_start..chunk_end { + let [n_f]: [F; 1] = exec_state.vm_read(NATIVE_AS, input_ptr_u32 + input_idx); + exec_state.vm_write(NATIVE_AS, sponge_ptr_u32 + (j as u32), &[n_f]); + input_idx += 1; + } + if chunk_end == CHUNK { + let mut p2_input: [F; CHUNK * 2] = exec_state.vm_read(NATIVE_AS, sponge_ptr_u32); + subchip.permute_mut(&mut p2_input); + exec_state.vm_write(NATIVE_AS, sponge_ptr_u32, &p2_input); + } + + height += 1; + } + if let Some(final_idx) = final_idx { + exec_state.vm_write::<F, 1>( + NATIVE_AS, + pre_compute.init_pos_register, + &[F::from_canonical_usize(final_idx)], + ); + } + *pc = pc.wrapping_add(DEFAULT_PC_STEP); + *instret += 1; + + height +} + #[inline(always)] unsafe fn execute_verify_batch_e12_impl< F: PrimeField32, diff --git a/extensions/native/circuit/src/poseidon2/tests.rs b/extensions/native/circuit/src/poseidon2/tests.rs index 197def47b8..1a4270dd8a 100644 --- a/extensions/native/circuit/src/poseidon2/tests.rs +++ b/extensions/native/circuit/src/poseidon2/tests.rs @@ -468,6 +468,7 @@ fn tester_with_random_poseidon2_ops(num_ops: usize) -> VmChipTester {} } tester.execute(&mut harness.executor, &mut harness.arena, &instruction); @@ -484,6 +485,7 @@ fn tester_with_random_poseidon2_ops(num_ops: usize) -> VmChipTester {} } } tester.build().load(harness).finalize() diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs new file mode 100644 index 0000000000..1cae3847ca --- /dev/null +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -0,0 +1,591 @@ +use std::borrow::Borrow; + +use openvm_circuit::{ + arch::{ExecutionBridge, ExecutionState}, + system::memory::{offline_checker::MemoryBridge, MemoryAddress}, +}; +use openvm_circuit_primitives::utils::{and, assert_array_eq, not}; +use openvm_instructions::{LocalOpcode, NATIVE_AS}; +use openvm_native_compiler::SumcheckOpcode::SUMCHECK_LAYER_EVAL; +use openvm_stark_backend::{ + interaction::InteractionBuilder, + p3_air::{Air, AirBuilder, BaseAir}, + p3_field::{Field, FieldAlgebra}, + p3_matrix::Matrix, + rap::{BaseAirWithPublicValues, PartitionedBaseAir}, +}; + +use crate::{ + field_extension::{FieldExtension, EXT_DEG}, + sumcheck::{ + chip::CONTEXT_ARR_BASE_LEN, + columns::{HeaderSpecificCols, LogupSpecificCols, NativeSumcheckCols, ProdSpecificCols}, + }, +}; + +#[derive(Clone, Debug)] +pub struct NativeSumcheckAir { + pub execution_bridge: ExecutionBridge, + pub memory_bridge: MemoryBridge, +} + +impl NativeSumcheckAir { + pub fn new(execution_bridge: ExecutionBridge, memory_bridge: MemoryBridge) -> Self { + Self { + execution_bridge, + memory_bridge, + } + } +} + +impl<F: Field> BaseAir<F> for NativeSumcheckAir { + fn width(&self) -> usize { + NativeSumcheckCols::<F>::width() + } +} + +impl<F: Field> BaseAirWithPublicValues<F> for NativeSumcheckAir {} + +impl<F: Field> PartitionedBaseAir<F> for NativeSumcheckAir {} + +impl<AB: InteractionBuilder> Air<AB> for NativeSumcheckAir { + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + let local = main.row_slice(0); + let local: &NativeSumcheckCols<AB::Var> = (*local).borrow(); + let next =
main.row_slice(1); + let next: &NativeSumcheckCols<AB::Var> = (*next).borrow(); + let native_as = AB::F::from_canonical_u32(NATIVE_AS); + + let &NativeSumcheckCols { + // Row indicators + header_row, + prod_row, + logup_row, + is_end, + + prod_continued, + logup_continued, + // What type of evaluation is performed + // mainly for reducing constraint degree + prod_in_round_evaluation, + prod_next_round_evaluation, + logup_in_round_evaluation, + logup_next_round_evaluation, + + // Indicates whether the round evaluations should be added to the accumulator + prod_acc, + logup_acc, + + // Timestamps + first_timestamp, + start_timestamp, + last_timestamp, + + // Results from reading registers + register_ptrs, + ctx, + prod_nested_len, + logup_nested_len, + + // Challenges + alpha, + challenges, + + curr_prod_n, + curr_logup_n, + + max_round, + within_round_limit, + should_acc, + eval_acc, + specific, + } = local; + + let [round, num_prod_spec, num_logup_spec, _prod_spec_inner_len, prod_spec_inner_inner_len, _logup_spec_inner_len, logup_spec_inner_inner_len, in_round] = + ctx; + builder.assert_bool(header_row); + builder.assert_bool(prod_row); + builder.assert_bool(logup_row); + builder.assert_bool(within_round_limit); + builder.assert_bool(prod_in_round_evaluation); + builder.assert_bool(logup_in_round_evaluation); + + let enabled = header_row + prod_row + logup_row; + let next_enabled = next.header_row + next.prod_row + next.logup_row; + builder.assert_bool(enabled.clone()); + + builder + .when_transition() + .assert_eq(prod_row * next.prod_row, prod_continued); + builder + .when_transition() + .assert_eq(logup_row * next.logup_row, logup_continued); + // TODO: handle last row properly + + builder.when_transition().assert_eq::<AB::Expr, AB::Expr>( + prod_row * next.header_row + + logup_row * next.header_row + + not::<AB::Expr>(next_enabled), + is_end.into(), + ); + + // TODO: within_round_limit = true => round < max_round + + // Randomness transition + let alpha1: [_; EXT_DEG] = challenges[0..EXT_DEG].try_into().unwrap(); + let c1: [_; EXT_DEG] = challenges[EXT_DEG..{ EXT_DEG * 2 }].try_into().unwrap(); + let c2: [_; EXT_DEG] = challenges[{ EXT_DEG * 2 }..{ EXT_DEG * 3 }] + .try_into() + .unwrap(); + let alpha2: [_; EXT_DEG] = challenges[{ EXT_DEG * 3 }..{ EXT_DEG * 4 }] + .try_into() + .unwrap(); + let next_alpha1: [_; EXT_DEG] = next.challenges[0..EXT_DEG].try_into().unwrap(); + + // Carry along columns + assert_array_eq( + &mut builder.when(next.prod_row + next.logup_row), + register_ptrs, + next.register_ptrs, + ); + assert_array_eq( + &mut builder.when(next.prod_row + next.logup_row), + ctx, + next.ctx, + ); + // c1, c2 remain the same + assert_array_eq::<_, _, _, { EXT_DEG * 2 }>( + &mut builder.when(next.prod_row + next.logup_row), + challenges[EXT_DEG..(EXT_DEG * 3)].try_into().expect(""), + next.challenges[EXT_DEG..(EXT_DEG * 3)] + .try_into() + .expect(""), + ); + assert_array_eq( + &mut builder.when(next.prod_row + next.logup_row), + alpha, + next.alpha, + ); + builder + .when(next.prod_row + next.logup_row) + .assert_eq(prod_nested_len, next.prod_nested_len); + builder + .when(next.prod_row + next.logup_row) + .assert_eq(logup_nested_len, next.logup_nested_len); + + //////////////////////////////////////////////////////////////// + // Row transitions from current to next row + // The basic pattern is + // header_row -> prod_row -> ... -> prod_row + // -> logup_row -> ...
-> logup_row + //////////////////////////////////////////////////////////////// + + // (curr_prod_n, curr_logup_n) start at 0 + builder.when(header_row).assert_zero(curr_prod_n); + builder + .when(header_row + prod_row) + .assert_zero(curr_logup_n); + builder + .when(next.prod_row) + .assert_eq(curr_prod_n + AB::F::ONE, next.curr_prod_n); + builder + .when(next.logup_row) + .assert_eq(curr_logup_n + AB::F::ONE, next.curr_logup_n); + // if header row is followed by another header row + // then num_prod_spec and num_logup_spec should be zero + builder + .when(header_row) + .when(next.header_row) + .assert_zero(num_prod_spec); + builder + .when(header_row) + .when(next.header_row) + .assert_zero(num_logup_spec); + // if header row is followed by a logup row, + // then num_prod_spec should be zero + builder + .when(header_row) + .when(next.logup_row) + .assert_zero(num_prod_spec); + builder + .when(prod_row) + .when(next.logup_row) + .assert_eq(curr_prod_n, num_prod_spec); + builder + .when(logup_row) + .when(next.header_row) + .assert_eq(curr_logup_n, num_logup_spec); + + // Timestamp transition + builder + .when(header_row) + .when(next.prod_row + next.logup_row) + .assert_eq( + next.start_timestamp, + start_timestamp + AB::F::from_canonical_usize(7), + ); + builder + .when(prod_row) + .when(next.prod_row + next.logup_row) + .assert_eq( + next.start_timestamp, + start_timestamp + AB::F::ONE + within_round_limit * AB::F::TWO, + ); + builder + .when(logup_row) + .when(next.prod_row + next.logup_row) + .assert_eq( + next.start_timestamp, + start_timestamp + AB::F::ONE + within_round_limit * AB::F::from_canonical_usize(3), + ); + + // Termination condition + assert_array_eq( + &mut builder.when::(is_end.into()), + eval_acc, + [AB::F::ZERO; 4], + ); + + // Randomness transition + assert_array_eq( + &mut builder.when(and(header_row, next.prod_row + next.logup_row)), + next.challenges[0..EXT_DEG].try_into().unwrap(), + [AB::F::ONE, AB::F::ZERO, AB::F::ZERO, AB::F::ZERO], + ); + assert_array_eq::<_, _, _, { EXT_DEG }>(&mut builder.when(header_row), alpha, alpha1); + let prod_next_alpha = FieldExtension::multiply(alpha1, alpha); + assert_array_eq::<_, _, _, { EXT_DEG }>( + &mut builder.when(prod_continued), + prod_next_alpha, + next_alpha1, + ); + // alpha1 = alpha_numerator, alpha2 = alpha_denominator for logup row + let alpha_denominator = FieldExtension::multiply(alpha1, alpha); + assert_array_eq::<_, _, _, { EXT_DEG }>( + &mut builder.when(logup_row), + alpha_denominator, + alpha2, + ); + let logup_next_alpha = FieldExtension::multiply(alpha2, alpha); + assert_array_eq::<_, _, _, { EXT_DEG }>( + &mut builder.when(logup_continued), + logup_next_alpha, + next_alpha1, + ); + + /////////////////////////////////////// + // Header + /////////////////////////////////////// + let header_row_specific: &HeaderSpecificCols = + specific[..HeaderSpecificCols::::width()].borrow(); + let registers = header_row_specific.registers; + + self.execution_bridge + .execute_and_increment_pc( + AB::Expr::from_canonical_usize(SUMCHECK_LAYER_EVAL.global_opcode().as_usize()), + [ + registers[4].into(), + registers[0].into(), + registers[1].into(), + native_as.into(), + native_as.into(), + registers[2].into(), + registers[3].into(), + ], + ExecutionState::new(header_row_specific.pc, first_timestamp), + last_timestamp - first_timestamp, + ) + .eval(builder, header_row); + + // Read registers + for i in 0..5usize { + self.memory_bridge + .read( + MemoryAddress::new(native_as, registers[i]), + [register_ptrs[i]], + 
first_timestamp + AB::F::from_canonical_usize(i), + &header_row_specific.read_records[i], + ) + .eval(builder, header_row); + } + + // Read ctx + self.memory_bridge + .read( + MemoryAddress::new(native_as, register_ptrs[0]), + ctx, + first_timestamp + AB::F::from_canonical_usize(5), + &header_row_specific.read_records[5], + ) + .eval(builder, header_row); + + // Read challenges + self.memory_bridge + .read( + MemoryAddress::new(native_as, register_ptrs[1]), + challenges, + first_timestamp + AB::F::from_canonical_usize(6), + &header_row_specific.read_records[6], + ) + .eval(builder, header_row); + + // Write final result + self.memory_bridge + .write( + MemoryAddress::new(native_as, register_ptrs[4]), + eval_acc, + last_timestamp - AB::F::ONE, + &header_row_specific.write_records, + ) + .eval(builder, header_row); + + /////////////////////////////////////// + // Prod spec evaluation + /////////////////////////////////////// + let prod_row_specific: &ProdSpecificCols = + specific[..ProdSpecificCols::::width()].borrow(); + let next_prod_row_specific: &ProdSpecificCols = + next.specific[..ProdSpecificCols::::width()].borrow(); + + self.memory_bridge + .read( + MemoryAddress::new( + native_as, + register_ptrs[0] + + AB::F::from_canonical_usize(CONTEXT_ARR_BASE_LEN) + + (curr_prod_n - AB::F::ONE), + ), // curr_prod_n starts at 1. + [max_round], + start_timestamp, + &prod_row_specific.read_records[0], + ) + .eval(builder, prod_row); + + // prod_row * within_round_limit = + // prod_in_round_evaluation + prod_next_round_evaluation + builder + .when(prod_in_round_evaluation + prod_next_round_evaluation) + .assert_eq( + prod_row_specific.data_ptr, + (prod_nested_len * (curr_prod_n - AB::F::ONE) + prod_spec_inner_inner_len * round) + * AB::F::from_canonical_usize(EXT_DEG), + ); + builder.assert_eq( + prod_row * within_round_limit * in_round, + prod_in_round_evaluation, + ); + builder.assert_eq( + prod_row * within_round_limit * not(in_round), + prod_next_round_evaluation, + ); + builder.assert_eq(prod_row * should_acc, prod_acc); + + self.memory_bridge + .read( + MemoryAddress::new(native_as, register_ptrs[2] + prod_row_specific.data_ptr), + prod_row_specific.p, + start_timestamp + AB::F::ONE, + &prod_row_specific.read_records[1], + ) + .eval(builder, prod_row * within_round_limit); + + let p1: [AB::Var; EXT_DEG] = prod_row_specific.p[0..EXT_DEG].try_into().unwrap(); + let p2: [AB::Var; EXT_DEG] = prod_row_specific.p[EXT_DEG..(EXT_DEG * 2)] + .try_into() + .unwrap(); + + self.memory_bridge + .write( + MemoryAddress::new( + native_as, + register_ptrs[4] + curr_prod_n * AB::F::from_canonical_usize(EXT_DEG), + ), + prod_row_specific.p_evals, + start_timestamp + AB::F::TWO, + &prod_row_specific.write_record, + ) + .eval(builder, prod_row * within_round_limit); + + // Calculate evaluations + let next_round_p_evals = FieldExtension::add( + FieldExtension::multiply::(p1, c1), + FieldExtension::multiply::(p2, c2), + ); + let in_round_p_evals = FieldExtension::multiply::(p1, p2); + assert_array_eq::<_, _, _, { EXT_DEG }>( + &mut builder.when(prod_in_round_evaluation), + in_round_p_evals, + prod_row_specific.p_evals, + ); + assert_array_eq::<_, _, _, { EXT_DEG }>( + &mut builder.when(prod_next_round_evaluation), + next_round_p_evals, + prod_row_specific.p_evals, + ); + + // TODO: add constraint on should_acc + + // Accumulate `eval_rlc` into global accumulator `eval_acc` + // when round < max_round - 2 + let eval_rlc = + FieldExtension::multiply::(prod_row_specific.p_evals, alpha1); + assert_array_eq::<_, _, _, 
{ EXT_DEG }>( + &mut builder.when(prod_acc), + prod_row_specific.eval_rlc, + eval_rlc, + ); + assert_array_eq::<_, _, _, { EXT_DEG }>( + &mut builder.when(next.prod_acc), + FieldExtension::add(next.eval_acc, next_prod_row_specific.eval_rlc), + eval_acc, + ); + + /////////////////////////////////////// + // Logup spec evaluation + /////////////////////////////////////// + let logup_row_specific: &LogupSpecificCols = + specific[..LogupSpecificCols::::width()].borrow(); + let next_logup_row_specfic: &LogupSpecificCols = + next.specific[..LogupSpecificCols::::width()].borrow(); + + self.memory_bridge + .read( + MemoryAddress::new( + native_as, + register_ptrs[0] + + AB::F::from_canonical_usize(EXT_DEG * 2) + + num_prod_spec + + (curr_logup_n - AB::F::ONE), + ), // curr_logup_n starts at 1. + [max_round], + start_timestamp, + &logup_row_specific.read_records[0], + ) + .eval(builder, logup_row); + + // logup_row * within_round_limit = + // logup_in_round_evaluation + logup_next_round_evaluation + builder + .when(logup_in_round_evaluation + logup_next_round_evaluation) + .assert_eq( + logup_row_specific.data_ptr, + (logup_nested_len * (curr_logup_n - AB::F::ONE) + + logup_spec_inner_inner_len * round) + * AB::F::from_canonical_usize(EXT_DEG), + ); + builder.assert_eq( + logup_row * within_round_limit * in_round, + logup_in_round_evaluation, + ); + builder.assert_eq( + logup_row * within_round_limit * not(in_round), + logup_next_round_evaluation, + ); + builder.assert_eq(logup_row * should_acc, logup_acc); + + self.memory_bridge + .read( + MemoryAddress::new(native_as, register_ptrs[3] + logup_row_specific.data_ptr), + logup_row_specific.pq, + start_timestamp + AB::F::ONE, + &logup_row_specific.read_records[1], + ) + .eval(builder, logup_row * within_round_limit); + + let p1: [_; EXT_DEG] = logup_row_specific.pq[0..EXT_DEG].try_into().unwrap(); + let p2: [_; EXT_DEG] = logup_row_specific.pq[EXT_DEG..(EXT_DEG * 2)] + .try_into() + .unwrap(); + let q1: [_; EXT_DEG] = logup_row_specific.pq[(EXT_DEG * 2)..{ EXT_DEG * 3 }] + .try_into() + .unwrap(); + let q2: [_; EXT_DEG] = logup_row_specific.pq[(EXT_DEG * 3)..(EXT_DEG * 4)] + .try_into() + .unwrap(); + + // write p_evals + self.memory_bridge + .write( + MemoryAddress::new( + native_as, + register_ptrs[4] + + (num_prod_spec + curr_logup_n) * AB::F::from_canonical_usize(EXT_DEG), + ), + logup_row_specific.p_evals, + start_timestamp + AB::F::TWO, + &logup_row_specific.write_records[0], + ) + .eval(builder, logup_row * within_round_limit); + + // write q_evals + self.memory_bridge + .write( + MemoryAddress::new( + native_as, + register_ptrs[4] + + (num_prod_spec + num_logup_spec + curr_logup_n) + * AB::F::from_canonical_usize(EXT_DEG), + ), + logup_row_specific.q_evals, + start_timestamp + AB::F::from_canonical_usize(3), + &logup_row_specific.write_records[1], + ) + .eval(builder, logup_row * within_round_limit); + + // Calculate evaluations + let next_round_p_evals = FieldExtension::add( + FieldExtension::multiply::(p1, c1), + FieldExtension::multiply::(p2, c2), + ); + let in_round_p_evals = FieldExtension::add( + FieldExtension::multiply::(p1, q2), + FieldExtension::multiply::(p2, q1), + ); + assert_array_eq::<_, _, _, { EXT_DEG }>( + &mut builder.when(logup_in_round_evaluation), + in_round_p_evals, + logup_row_specific.p_evals, + ); + assert_array_eq::<_, _, _, { EXT_DEG }>( + &mut builder.when(logup_next_round_evaluation), + next_round_p_evals, + logup_row_specific.p_evals, + ); + + let next_round_q_evals = FieldExtension::add( + 
FieldExtension::multiply::(q1, c1), + FieldExtension::multiply::(q2, c2), + ); + let in_round_q_evals = FieldExtension::multiply::(q1, q2); + assert_array_eq::<_, _, _, { EXT_DEG }>( + &mut builder.when(logup_in_round_evaluation), + in_round_q_evals, + logup_row_specific.q_evals, + ); + assert_array_eq::<_, _, _, { EXT_DEG }>( + &mut builder.when(logup_next_round_evaluation), + next_round_q_evals, + logup_row_specific.q_evals, + ); + + // Accumulate evaluation + let eval_rlc = FieldExtension::add( + FieldExtension::multiply::(logup_row_specific.p_evals, alpha1), + FieldExtension::multiply::(logup_row_specific.q_evals, alpha2), + ); + assert_array_eq::<_, _, _, { EXT_DEG }>( + &mut builder.when(logup_acc), + logup_row_specific.eval_rlc, + eval_rlc, + ); + + // Accumulate into global accumulator `eval_acc` + // when round < max_round - 2 + assert_array_eq::<_, _, _, { EXT_DEG }>( + &mut builder.when(next.logup_acc), + FieldExtension::add(next.eval_acc, next_logup_row_specfic.eval_rlc), + eval_acc, + ); + } +} diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs new file mode 100644 index 0000000000..bb7cfa7080 --- /dev/null +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -0,0 +1,595 @@ +use std::borrow::BorrowMut; + +use openvm_circuit::{ + arch::{ + CustomBorrow, ExecutionError, MultiRowLayout, MultiRowMetadata, PreflightExecutor, + RecordArena, SizedRecord, TraceFiller, VmChipWrapper, VmStateMut, + }, + system::{ + memory::{online::TracingMemory, MemoryAuxColsFactory}, + native_adapter::util::{memory_read_native, tracing_write_native_inplace}, + }, +}; +use openvm_instructions::{ + instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode, NATIVE_AS, +}; +use openvm_native_compiler::SumcheckOpcode::SUMCHECK_LAYER_EVAL; +use openvm_stark_backend::p3_field::PrimeField32; + +use crate::{ + field_extension::{FieldExtension, EXT_DEG}, + fri::elem_to_ext, + mem_fill_helper, + sumcheck::columns::{ + HeaderSpecificCols, LogupSpecificCols, NativeSumcheckCols, ProdSpecificCols, + }, + tracing_read_native_helper, +}; + +pub(crate) const CONTEXT_ARR_BASE_LEN: usize = EXT_DEG * 2; +pub(crate) const CURRENT_LAYER_MODE: u32 = 1; +pub(crate) const NEXT_LAYER_MODE: u32 = 0; + +pub(crate) fn calculate_3d_ext_idx( + inner_inner_len: u32, + inner_len: u32, + outer_idx: u32, + inner_idx: u32, + inner_inner_idx: u32, +) -> u32 { + (inner_inner_len * inner_len * outer_idx + inner_inner_len * inner_idx + inner_inner_idx) + * EXT_DEG as u32 +} + +#[derive(Debug, Clone, Default)] +pub struct NativeSumcheckMetadata { + num_rows: usize, +} + +impl MultiRowMetadata for NativeSumcheckMetadata { + #[inline(always)] + fn get_num_rows(&self) -> usize { + self.num_rows + } +} + +type NativeSumcheckRecordLayout = MultiRowLayout; + +pub struct NativeSumcheckRecordMut<'a, F>(&'a mut [NativeSumcheckCols]); + +impl<'a, F: PrimeField32> + CustomBorrow<'a, NativeSumcheckRecordMut<'a, F>, NativeSumcheckRecordLayout> for [u8] +{ + fn custom_borrow( + &'a mut self, + layout: NativeSumcheckRecordLayout, + ) -> NativeSumcheckRecordMut<'a, F> { + // SAFETY: + // - align_to_mut() ensures proper alignment for NativeSumcheckCols + // - Layout guarantees sufficient length for num_rows records + // - Slice bounds validated by taking only num_rows elements + let arr = unsafe { self.align_to_mut::>().1 }; + NativeSumcheckRecordMut(&mut arr[..layout.metadata.num_rows]) + } + + unsafe fn extract_layout(&self) -> NativeSumcheckRecordLayout { + // Each instruction record 
consists solely of some number of contiguously + // stored NativeSumcheckCols<...> structs, each of which corresponds to a + // single trace row. Trace fillers don't actually need to know how many rows + // each instruction uses, and can thus treat each NativeSumcheckCols<...> + // as a single record. + NativeSumcheckRecordLayout { + metadata: NativeSumcheckMetadata { num_rows: 1 }, + } + } +} + +impl SizedRecord for NativeSumcheckRecordMut<'_, F> { + fn size(layout: &NativeSumcheckRecordLayout) -> usize { + layout.metadata.num_rows * size_of::>() + } + + fn alignment(_layout: &NativeSumcheckRecordLayout) -> usize { + align_of::>() + } +} + +#[derive(derive_new::new, Copy, Clone)] +pub struct NativeSumcheckExecutor; + +#[derive(derive_new::new)] +pub struct NativeSumcheckFiller; + +pub type NativeSumcheckChip = VmChipWrapper; + +impl Default for NativeSumcheckExecutor { + fn default() -> Self { + Self::new() + } +} + +impl PreflightExecutor for NativeSumcheckExecutor +where + F: PrimeField32, + for<'buf> RA: RecordArena<'buf, NativeSumcheckRecordLayout, NativeSumcheckRecordMut<'buf, F>>, +{ + fn execute( + &self, + state: VmStateMut, + instruction: &Instruction, + ) -> Result<(), ExecutionError> { + let &Instruction { + opcode: op, + a: r_evals_reg, + b: ctx_reg, + c: challenges_reg, + d: data_address_space, + e: register_address_space, + f: prod_evals_reg, + g: logup_evals_reg, + } = instruction; + + // This opcode supports two modes of operation: + // 1. calculate the expected evaluation of two types of sumchecks for the current round + // a. product sumcheck: v' = v[0] * v[1] + // b. logup sumcheck: p'= p[0] * q[1] + p[1] * q[0] and q'= q[0] * q[1]. + // 2. calculate the expected value of next layer: + // a. product sumcheck: v[r] = eq(0,r) * v[0] + eq(1,r) * v[1] + // b. 
logup sumcheck: p[r] = eq(0,r) * p[0] + eq(1,r) * p[1] + // and q[r] = eq(0,r) * q[0] + eq(1,r) * q[1] + assert_eq!(op, SUMCHECK_LAYER_EVAL.global_opcode()); + assert_eq!(data_address_space.as_canonical_u32(), NATIVE_AS); + assert_eq!(register_address_space.as_canonical_u32(), NATIVE_AS); + + let [ctx_ptr]: [F; 1] = memory_read_native(state.memory.data(), ctx_reg.as_canonical_u32()); + let ctx: [u32; 8] = memory_read_native(state.memory.data(), ctx_ptr.as_canonical_u32()) + .map(|x: F| x.as_canonical_u32()); + + let [round, num_prod_spec, num_logup_spec, prod_specs_inner_len, prod_specs_inner_inner_len, logup_specs_inner_len, logup_specs_inner_inner_len, mode] = + ctx; + // allocate n rows + let num_rows = (1 + num_prod_spec + num_logup_spec) as usize; + let rows = state + .ctx + .alloc(MultiRowLayout::new(NativeSumcheckMetadata { num_rows })) + .0; + + let mut cur_timestamp = state.memory.timestamp(); + // head row + let head_row: &mut NativeSumcheckCols = &mut rows[0]; + let head_specific: &mut HeaderSpecificCols = + head_row.specific[..HeaderSpecificCols::::width()].borrow_mut(); + + head_row.header_row = F::ONE; + head_row.first_timestamp = F::from_canonical_u32(cur_timestamp); + head_row.start_timestamp = F::from_canonical_u32(cur_timestamp); + + head_specific.pc = F::from_canonical_u32(*state.pc); + + head_specific.registers[0] = ctx_reg; + head_specific.registers[1] = challenges_reg; + head_specific.registers[2] = prod_evals_reg; + head_specific.registers[3] = logup_evals_reg; + head_specific.registers[4] = r_evals_reg; + + // read pointers + let [ctx_ptr]: [F; 1] = tracing_read_native_helper( + state.memory, + ctx_reg.as_canonical_u32(), + head_specific.read_records[0].as_mut(), + ); + let [challenges_ptr]: [F; 1] = tracing_read_native_helper( + state.memory, + challenges_reg.as_canonical_u32(), + head_specific.read_records[1].as_mut(), + ); + let [prod_evals_ptr]: [F; 1] = tracing_read_native_helper( + state.memory, + prod_evals_reg.as_canonical_u32(), + head_specific.read_records[2].as_mut(), + ); + let [logup_evals_ptr]: [F; 1] = tracing_read_native_helper( + state.memory, + logup_evals_reg.as_canonical_u32(), + head_specific.read_records[3].as_mut(), + ); + let [r_evals_ptr]: [F; 1] = tracing_read_native_helper( + state.memory, + r_evals_reg.as_canonical_u32(), + head_specific.read_records[4].as_mut(), + ); + + let ctx: [F; CONTEXT_ARR_BASE_LEN] = tracing_read_native_helper( + state.memory, + ctx_ptr.as_canonical_u32(), + head_specific.read_records[5].as_mut(), + ); + + let challenges: [F; EXT_DEG * 4] = tracing_read_native_helper( + state.memory, + challenges_ptr.as_canonical_u32(), + head_specific.read_records[6].as_mut(), + ); + cur_timestamp += 7; // 5 register reads + ctx read + challenges read + head_row.challenges.copy_from_slice(&challenges); + + // challenges = [alpha, c1=r, c2=1-r] + let alpha: [F; 4] = challenges[0..EXT_DEG].try_into().unwrap(); + let c1: [F; 4] = challenges[EXT_DEG..(EXT_DEG * 2)].try_into().unwrap(); + let c2: [F; 4] = challenges[(EXT_DEG * 2)..(EXT_DEG * 3)].try_into().unwrap(); + + let mut eval_acc = elem_to_ext(F::from_canonical_u32(0)); + let mut alpha_acc = elem_to_ext(F::from_canonical_u32(1)); + + // all rows share same register values, ctx, challenges + for row in rows.iter_mut() { + // c1, c2 are same during the entire execution + row.challenges[EXT_DEG..3 * EXT_DEG].copy_from_slice(&challenges[EXT_DEG..3 * EXT_DEG]); + row.alpha = alpha; + row.ctx = ctx; + row.prod_nested_len = + F::from_canonical_u32(prod_specs_inner_len * 
prod_specs_inner_inner_len); + row.logup_nested_len = + F::from_canonical_u32(logup_specs_inner_len * logup_specs_inner_inner_len); + row.register_ptrs[0] = ctx_ptr; + row.register_ptrs[1] = challenges_ptr; + row.register_ptrs[2] = prod_evals_ptr; + row.register_ptrs[3] = logup_evals_ptr; + row.register_ptrs[4] = r_evals_ptr; + } + + // product rows + for (i, prod_row) in rows + .iter_mut() + .skip(1) + .take(num_prod_spec as usize) + .enumerate() + { + let prod_specific: &mut ProdSpecificCols = + prod_row.specific[..ProdSpecificCols::::width()].borrow_mut(); + + prod_row.prod_row = F::ONE; + prod_row.prod_continued = if i < (num_prod_spec - 1) as usize { + F::ONE + } else { + F::ZERO + }; + prod_row.curr_prod_n = F::from_canonical_usize(i + 1); // curr_prod_n starts from 1 + prod_row.start_timestamp = F::from_canonical_u32(cur_timestamp); + + // read max_round + let [max_round]: [F; 1] = tracing_read_native_helper( + state.memory, + ctx_ptr.as_canonical_u32() + (CONTEXT_ARR_BASE_LEN + i) as u32, + prod_specific.read_records[0].as_mut(), + ); + cur_timestamp += 1; + + prod_row.challenges[0..EXT_DEG].copy_from_slice(&alpha_acc); + prod_row.max_round = max_round; + + let max_round = max_round.as_canonical_u32(); + // round starts from 0 + if round < max_round - 1 { + prod_row.within_round_limit = F::ONE; + let start = calculate_3d_ext_idx( + prod_specs_inner_inner_len, + prod_specs_inner_len, + i as u32, + round, + 0, + ); + prod_specific.data_ptr = F::from_canonical_u32(start); + + // read p1, p2 + let ps: [F; EXT_DEG * 2] = tracing_read_native_helper( + state.memory, + prod_evals_ptr.as_canonical_u32() + start, + prod_specific.read_records[1].as_mut(), + ); + let p1: [F; EXT_DEG] = ps[0..EXT_DEG].try_into().unwrap(); + let p2: [F; EXT_DEG] = ps[EXT_DEG..(EXT_DEG * 2)].try_into().unwrap(); + + prod_specific.p = ps; + + // compute expected eval + let eval = match mode { + NEXT_LAYER_MODE => FieldExtension::add( + FieldExtension::multiply(p1, c1), + FieldExtension::multiply(p2, c2), + ), + CURRENT_LAYER_MODE => FieldExtension::multiply(p1, p2), + _ => unreachable!("mode should be {CURRENT_LAYER_MODE} or {NEXT_LAYER_MODE}"), + }; + prod_specific.p_evals = eval; + + match mode { + NEXT_LAYER_MODE => { + prod_row.prod_next_round_evaluation = F::ONE; + } + CURRENT_LAYER_MODE => { + prod_row.prod_in_round_evaluation = F::ONE; + } + _ => unreachable!("mode should be {CURRENT_LAYER_MODE} or {NEXT_LAYER_MODE}"), + } + + // write p eval + tracing_write_native_inplace( + state.memory, + r_evals_ptr.as_canonical_u32() + (1 + i as u32) * (EXT_DEG as u32), + eval, + &mut prod_specific.write_record, + ); + cur_timestamp += 2; + + let eval_rlc = FieldExtension::multiply(alpha_acc, eval); + prod_specific.eval_rlc = eval_rlc; + + let to_next_round = if mode == NEXT_LAYER_MODE { 1 } else { 0 }; + if round + to_next_round < max_round - 1 { + eval_acc = FieldExtension::add(eval_acc, eval_rlc); + prod_row.should_acc = F::ONE; + prod_row.prod_acc = F::ONE; + prod_row.eval_acc = eval_acc; + } + } + + alpha_acc = FieldExtension::multiply(alpha_acc, alpha); + } + + // logup rows + for (i, logup_row) in rows.iter_mut().skip(1 + num_prod_spec as usize).enumerate() { + let logup_specific: &mut LogupSpecificCols = + logup_row.specific[..LogupSpecificCols::::width()].borrow_mut(); + + logup_row.logup_row = F::ONE; + logup_row.logup_continued = if i < (num_logup_spec - 1) as usize { + F::ONE + } else { + F::ZERO + }; + logup_row.curr_logup_n = F::from_canonical_usize(i + 1); // curr_logup_n starts from 1 + 
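// Illustrative sketch of the alpha-power schedule used by the prod rows above and
// the logup rows below, with u64 arithmetic modulo a small prime standing in for
// the degree-4 extension field (all names here are hypothetical, not part of the
// patch): prod spec i is weighted by alpha^i, while logup spec j uses
// alpha^(P + 2j) for its numerator and alpha^(P + 2j + 1) for its denominator,
// where P = num_prod_spec.
fn alpha_schedule_sketch(num_prod_spec: u64, num_logup_spec: u64, alpha: u64, m: u64) {
    let pow = |e: u64| (0..e).fold(1u64, |acc, _| acc * alpha % m);
    let mut alpha_acc = 1u64; // mirrors `elem_to_ext(F::ONE)` at the start
    for i in 0..num_prod_spec {
        assert_eq!(alpha_acc, pow(i)); // prod spec i sees alpha^i
        alpha_acc = alpha_acc * alpha % m; // `multiply(alpha_acc, alpha)` per prod row
    }
    for j in 0..num_logup_spec {
        let alpha_numerator = alpha_acc;
        let alpha_denominator = alpha_acc * alpha % m;
        assert_eq!(alpha_numerator, pow(num_prod_spec + 2 * j));
        assert_eq!(alpha_denominator, pow(num_prod_spec + 2 * j + 1));
        // mirrors `alpha_acc = multiply(alpha_denominator, alpha)`: each logup
        // spec advances the accumulator by alpha^2
        alpha_acc = alpha_denominator * alpha % m;
    }
}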
logup_row.start_timestamp = F::from_canonical_u32(cur_timestamp); + + let [max_round]: [F; 1] = tracing_read_native_helper( + state.memory, + ctx_ptr.as_canonical_u32() + num_prod_spec + (CONTEXT_ARR_BASE_LEN + i) as u32, + logup_specific.read_records[0].as_mut(), + ); + logup_row.max_round = max_round; + cur_timestamp += 1; + + let alpha_numerator = alpha_acc; + let alpha_denominator = FieldExtension::multiply(alpha_acc, alpha); + logup_row.challenges[0..EXT_DEG].copy_from_slice(&alpha_acc); + logup_row.challenges[3 * EXT_DEG..(4 * EXT_DEG)].copy_from_slice(&alpha_denominator); + + let max_round = max_round.as_canonical_u32(); + if round < max_round - 1 { + logup_row.within_round_limit = F::ONE; + let start = calculate_3d_ext_idx( + logup_specs_inner_inner_len, + logup_specs_inner_len, + i as u32, + round, + 0, + ); + logup_specific.data_ptr = F::from_canonical_u32(start); + + // read p1, p2, q1, q2 + let pqs: [F; EXT_DEG * 4] = tracing_read_native_helper( + state.memory, + logup_evals_ptr.as_canonical_u32() + start, + logup_specific.read_records[1].as_mut(), + ); + let p1: [F; EXT_DEG] = pqs[0..EXT_DEG].try_into().unwrap(); + let p2: [F; EXT_DEG] = pqs[EXT_DEG..(EXT_DEG * 2)].try_into().unwrap(); + let q1: [F; EXT_DEG] = pqs[(EXT_DEG * 2)..(EXT_DEG * 3)].try_into().unwrap(); + let q2: [F; EXT_DEG] = pqs[(EXT_DEG * 3)..(EXT_DEG * 4)].try_into().unwrap(); + + logup_specific.pq = pqs; + + // compute expected evals + let p_eval = match mode { + NEXT_LAYER_MODE => FieldExtension::add( + FieldExtension::multiply(p1, c1), + FieldExtension::multiply(p2, c2), + ), + CURRENT_LAYER_MODE => FieldExtension::add( + FieldExtension::multiply(p1, q2), + FieldExtension::multiply(p2, q1), + ), + _ => unreachable!("mode should be {CURRENT_LAYER_MODE} or {NEXT_LAYER_MODE}"), + }; + let q_eval = match mode { + NEXT_LAYER_MODE => FieldExtension::add( + FieldExtension::multiply(q1, c1), + FieldExtension::multiply(q2, c2), + ), + CURRENT_LAYER_MODE => FieldExtension::multiply(q1, q2), + _ => unreachable!("mode should be {CURRENT_LAYER_MODE} or {NEXT_LAYER_MODE}"), + }; + + match mode { + NEXT_LAYER_MODE => { + logup_row.logup_next_round_evaluation = F::ONE; + } + CURRENT_LAYER_MODE => { + logup_row.logup_in_round_evaluation = F::ONE; + } + _ => unreachable!("mode should be {CURRENT_LAYER_MODE} or {NEXT_LAYER_MODE}"), + } + + logup_specific.p_evals = p_eval; + logup_specific.q_evals = q_eval; + + // write p_eval + tracing_write_native_inplace( + state.memory, + r_evals_ptr.as_canonical_u32() + + (1 + num_prod_spec + i as u32) * (EXT_DEG as u32), + p_eval, + &mut logup_specific.write_records[0], + ); + // write q_eval + tracing_write_native_inplace( + state.memory, + r_evals_ptr.as_canonical_u32() + + (1 + num_prod_spec + num_logup_spec + i as u32) * (EXT_DEG as u32), + q_eval, + &mut logup_specific.write_records[1], + ); + cur_timestamp += 3; // 1 read, 2 writes + + let eval_rlc = FieldExtension::add( + FieldExtension::multiply(alpha_numerator, p_eval), + FieldExtension::multiply(alpha_denominator, q_eval), + ); + logup_specific.eval_rlc = eval_rlc; + let to_next_round = if mode == NEXT_LAYER_MODE { 1 } else { 0 }; + if round + to_next_round < max_round - 1 { + eval_acc = FieldExtension::add(eval_acc, eval_rlc); + logup_row.should_acc = F::ONE; + logup_row.logup_acc = F::ONE; + logup_row.eval_acc = eval_acc; + } + } + + alpha_acc = FieldExtension::multiply(alpha_denominator, alpha); + } + + if let Some(last_row) = rows.last_mut() { + last_row.is_end = F::ONE; + } + + let head_row = &mut rows[0]; + 
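// A hand-checked sketch (hypothetical helper, not part of the patch) of the
// timestamp accounting that the AIR transition constraints mirror: the header
// consumes 7 timestamps (5 register reads, the ctx read, and the challenges
// read), each prod row consumes 1 + 2 * within_limit (max_round read, then the
// p read and p_evals write), and each logup row consumes 1 + 3 * within_limit
// (max_round read, then the pq read and two writes). The final-result write
// lands at `last_timestamp - 1`, hence the trailing `+ 1` in the assignment
// just below.
fn expected_last_timestamp(first: u32, prod_within: &[bool], logup_within: &[bool]) -> u32 {
    let mut ts = first + 7;
    for &w in prod_within {
        ts += 1 + if w { 2 } else { 0 };
    }
    for &w in logup_within {
        ts += 1 + if w { 3 } else { 0 };
    }
    ts + 1 // reserve the last slot for the header's final-result write
}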
head_row.last_timestamp = F::from_canonical_u32(cur_timestamp + 1); + + let head_specific: &mut HeaderSpecificCols = + head_row.specific[..HeaderSpecificCols::::width()].borrow_mut(); + + tracing_write_native_inplace( + state.memory, + r_evals_ptr.as_canonical_u32(), + eval_acc, + &mut head_specific.write_records, + ); + + for row in rows.iter_mut() { + if row.header_row == F::ONE { + row.eval_acc = eval_acc; + } else if row.prod_row == F::ONE { + let specific: &mut ProdSpecificCols = + row.specific[..ProdSpecificCols::::width()].borrow_mut(); + + if row.should_acc == F::ONE { + eval_acc = FieldExtension::subtract(eval_acc, specific.eval_rlc); + } + row.eval_acc = eval_acc; + } else if row.logup_row == F::ONE { + let specific: &mut LogupSpecificCols = + row.specific[..LogupSpecificCols::::width()].borrow_mut(); + if row.should_acc == F::ONE { + eval_acc = FieldExtension::subtract(eval_acc, specific.eval_rlc); + } + row.eval_acc = eval_acc; + } + } + assert_eq!(eval_acc, elem_to_ext(F::from_canonical_u32(0)),); + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + Ok(()) + } + + // GKR layered IOP for product and logup relations + fn get_opcode_name(&self, opcode: usize) -> String { + assert_eq!(opcode, SUMCHECK_LAYER_EVAL.global_opcode().as_usize()); + String::from("SUMCHECK_LAYER_EVAL") + } +} + +impl TraceFiller for NativeSumcheckFiller { + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let cols: &mut NativeSumcheckCols = row_slice.borrow_mut(); + let start_timestamp = cols.start_timestamp.as_canonical_u32(); + let last_timestamp = cols.last_timestamp.as_canonical_u32(); + + if cols.header_row == F::ONE { + let header: &mut HeaderSpecificCols = + cols.specific[..HeaderSpecificCols::::width()].borrow_mut(); + + for i in 0..7usize { + mem_fill_helper( + mem_helper, + start_timestamp + i as u32, + header.read_records[i].as_mut(), + ); + } + mem_fill_helper( + mem_helper, + last_timestamp - 1, + header.write_records.as_mut(), + ); + } else if cols.prod_row == F::ONE { + let prod_row_specific: &mut ProdSpecificCols = + cols.specific[..ProdSpecificCols::::width()].borrow_mut(); + + // read max_round + mem_fill_helper( + mem_helper, + start_timestamp, + prod_row_specific.read_records[0].as_mut(), + ); + if cols.within_round_limit == F::ONE { + // read p1, p2 + mem_fill_helper( + mem_helper, + start_timestamp + 1, + prod_row_specific.read_records[1].as_mut(), + ); + // write p_eval + mem_fill_helper( + mem_helper, + start_timestamp + 2, + prod_row_specific.write_record.as_mut(), + ); + } + } else if cols.logup_row == F::ONE { + let logup_row_specific: &mut LogupSpecificCols = + cols.specific[..LogupSpecificCols::::width()].borrow_mut(); + + // read max_round + mem_fill_helper( + mem_helper, + start_timestamp, + logup_row_specific.read_records[0].as_mut(), + ); + if cols.within_round_limit == F::ONE { + // read p1, p2, q1, q2 + mem_fill_helper( + mem_helper, + start_timestamp + 1, + logup_row_specific.read_records[1].as_mut(), + ); + // write p_eval + mem_fill_helper( + mem_helper, + start_timestamp + 2, + logup_row_specific.write_records[0].as_mut(), + ); + // write q_eval + mem_fill_helper( + mem_helper, + start_timestamp + 3, + logup_row_specific.write_records[1].as_mut(), + ); + } + } + } + + fn fill_dummy_trace_row(&self, row_slice: &mut [F]) { + let cols: &mut NativeSumcheckCols = row_slice.borrow_mut(); + + cols.is_end = F::ONE; + } +} diff --git a/extensions/native/circuit/src/sumcheck/columns.rs b/extensions/native/circuit/src/sumcheck/columns.rs 
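// Standalone sketch of the backward pass at the end of `execute` above
// (illustrative i64 arithmetic, not part of the patch), treating `rlcs` as the
// contributions of the accumulated rows in order: execution sums each flagged
// row's `eval_rlc` into a total, and the fix-up loop then rewrites every row's
// `eval_acc` as the suffix sum of the contributions after it, matching the AIR
// transition `next.eval_acc + next.eval_rlc = eval_acc`. Peeling all
// contributions off the total must return it to zero, which is what the final
// `assert_eq!` enforces and what the `is_end` constraint checks in the AIR.
fn backward_eval_acc_sketch(rlcs: &[i64]) -> Vec<i64> {
    let total: i64 = rlcs.iter().sum(); // the header row carries the total
    let mut acc = total;
    let mut per_row = Vec::with_capacity(rlcs.len());
    for &rlc in rlcs {
        acc -= rlc; // subtract this row's contribution, as in the fix-up loop
        per_row.push(acc); // row i stores the sum of contributions after row i
    }
    assert_eq!(acc, 0);
    per_row
}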
new file mode 100644
index 0000000000..b3e6bf4f25
--- /dev/null
+++ b/extensions/native/circuit/src/sumcheck/columns.rs
@@ -0,0 +1,135 @@
+use openvm_circuit::system::memory::offline_checker::{MemoryReadAuxCols, MemoryWriteAuxCols};
+use openvm_circuit_primitives_derive::AlignedBorrow;
+
+use crate::{field_extension::EXT_DEG, utils::const_max};
+
+const fn max3(a: usize, b: usize, c: usize) -> usize {
+    const_max(a, const_max(b, c))
+}
+
+#[repr(C)]
+#[derive(AlignedBorrow)]
+pub struct NativeSumcheckCols<T> {
+    /// Indicates that this row is the header for a layer sum operation
+    pub header_row: T,
+    /// Indicates that this row is a step for prod_spec in the layer sum operation
+    pub prod_row: T,
+    /// Indicates that this row is a step for logup_spec in the layer sum operation
+    pub logup_row: T,
+    /// Indicates that this row is the end of the entire layer sum operation
+    pub is_end: T,
+
+    pub prod_continued: T,
+    pub logup_continued: T,
+
+    /// Indicates what type of evaluation constraints should be applied
+    pub prod_in_round_evaluation: T,
+    pub prod_next_round_evaluation: T,
+    pub logup_in_round_evaluation: T,
+    pub logup_next_round_evaluation: T,
+
+    /// Indicates if evaluations are accumulated
+    pub prod_acc: T,
+    pub logup_acc: T,
+
+    /// Timestamps
+    pub first_timestamp: T,
+    pub start_timestamp: T,
+    pub last_timestamp: T,
+
+    // Register values
+    pub register_ptrs: [T; 5],
+
+    // Context variables
+    // [
+    //     round,
+    //     num_prod_spec,
+    //     num_logup_spec,
+    //     prod_spec_inner_len,
+    //     prod_spec_inner_inner_len,
+    //     logup_spec_inner_len,
+    //     logup_spec_inner_inner_len,
+    //     in_round (operational mode),
+    // ]
+    pub ctx: [T; EXT_DEG * 2],
+
+    pub prod_nested_len: T,
+    pub logup_nested_len: T,
+
+    pub curr_prod_n: T,
+    pub curr_logup_n: T,
+
+    pub alpha: [T; EXT_DEG],
+    // alpha1, c1, c2, alpha2 (for logup rows)
+    pub challenges: [T; EXT_DEG * 4],
+
+    // Specific to each row
+    pub max_round: T,
+    // Is this round within max_round
+    pub within_round_limit: T,
+    // Should the evaluation be accumulated
+    pub should_acc: T,
+
+    // The current final evaluation accumulator. Extension element.
+    pub eval_acc: [T; EXT_DEG],
+
+    // /// 1. For header row, 5 registers, ctx, challenges
+    // /// 2. For the rest: max_variables, p1, p2, q1, q2
+    // pub read_records: [MemoryReadAuxCols<T>; 7],
+    // /// 1. For header row, write final result
+    // /// 2. For prod rows: write prod_evals
+    // /// 3.
For logup rows: write q_evals, p_evals + // pub write_records: [MemoryWriteAuxCols; 2], + pub specific: [T; max3( + HeaderSpecificCols::::width(), + ProdSpecificCols::::width(), + LogupSpecificCols::::width(), + )], +} + +#[repr(C)] +#[derive(AlignedBorrow)] +pub struct HeaderSpecificCols { + pub pc: T, + pub registers: [T; 5], + /// 5 register reads + ctx read + challenges read + pub read_records: [MemoryReadAuxCols; 7], + /// Write the final evaluation + pub write_records: MemoryWriteAuxCols, +} + +#[repr(C)] +#[derive(AlignedBorrow)] +pub struct ProdSpecificCols { + /// Pointer + pub data_ptr: T, + /// 2 extension elements + pub p: [T; EXT_DEG * 2], + /// read max varibale and 2 p values + pub read_records: [MemoryReadAuxCols; 2], + /// Calculated p evals + pub p_evals: [T; EXT_DEG], + /// write p_evals + pub write_record: MemoryWriteAuxCols, + /// p_evals * alpha^i + pub eval_rlc: [T; EXT_DEG], +} + +#[repr(C)] +#[derive(AlignedBorrow)] +pub struct LogupSpecificCols { + /// Pointer + pub data_ptr: T, + /// 4 extension elements + pub pq: [T; EXT_DEG * 4], + /// read max variable and 4 values: p1, p2, q1, q2 + pub read_records: [MemoryReadAuxCols; 2], + /// Calculated p evals + pub p_evals: [T; EXT_DEG], + /// Calculated q evals + pub q_evals: [T; EXT_DEG], + /// write both p_evals and q_evals + pub write_records: [MemoryWriteAuxCols; 2], + /// Evaluation for the accumulator + pub eval_rlc: [T; EXT_DEG], +} diff --git a/extensions/native/circuit/src/sumcheck/cuda.rs b/extensions/native/circuit/src/sumcheck/cuda.rs new file mode 100644 index 0000000000..60aba15b95 --- /dev/null +++ b/extensions/native/circuit/src/sumcheck/cuda.rs @@ -0,0 +1,57 @@ +use std::{mem::size_of, slice::from_raw_parts, sync::Arc}; + +use derive_new::new; +use openvm_circuit::{arch::DenseRecordArena, utils::next_power_of_two_or_zero}; +use openvm_circuit_primitives::var_range::VariableRangeCheckerChipGPU; +use openvm_cuda_backend::{ + base::DeviceMatrix, chip::get_empty_air_proving_ctx, prover_backend::GpuBackend, types::F, +}; +use openvm_cuda_common::copy::MemCopyH2D; +use openvm_stark_backend::{prover::types::AirProvingContext, Chip}; + +use super::columns::NativeSumcheckCols; +use crate::cuda_abi::sumcheck_cuda; + +#[derive(new)] +pub struct NativeSumcheckChipGpu { + pub range_checker: Arc, + pub timestamp_max_bits: usize, +} + +impl Chip for NativeSumcheckChipGpu { + fn generate_proving_ctx(&self, arena: DenseRecordArena) -> AirProvingContext { + let records = arena.allocated(); + if records.is_empty() { + return get_empty_air_proving_ctx::(); + } + + let width = NativeSumcheckCols::::width(); + let record_size = width * size_of::(); + assert_eq!(records.len() % record_size, 0); + + let height = records.len() / record_size; + let padded_height = next_power_of_two_or_zero(height); + let trace = DeviceMatrix::::with_capacity(padded_height, width); + + let record_slice = unsafe { + let ptr = records.as_ptr(); + from_raw_parts(ptr as *const F, records.len() / size_of::()) + }; + let d_records = record_slice.to_device().unwrap(); + + unsafe { + sumcheck_cuda::tracegen( + trace.buffer(), + padded_height, + width, + &d_records, + height, + &self.range_checker.count, + self.timestamp_max_bits as u32, + ) + .unwrap(); + } + + AirProvingContext::simple_no_pis(trace) + } +} diff --git a/extensions/native/circuit/src/sumcheck/execution.rs b/extensions/native/circuit/src/sumcheck/execution.rs new file mode 100644 index 0000000000..a475bf9e49 --- /dev/null +++ b/extensions/native/circuit/src/sumcheck/execution.rs @@ -0,0 
+1,347 @@ +use std::{ + borrow::{Borrow, BorrowMut}, + mem::size_of, +}; + +use openvm_circuit::{arch::*, system::memory::online::GuestMemory}; +use openvm_circuit_primitives::AlignedBytesBorrow; +use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, NATIVE_AS}; +use openvm_stark_backend::p3_field::PrimeField32; + +use crate::{ + field_extension::{FieldExtension, EXT_DEG}, + fri::elem_to_ext, + sumcheck::chip::{ + calculate_3d_ext_idx, NativeSumcheckExecutor, CONTEXT_ARR_BASE_LEN, CURRENT_LAYER_MODE, + NEXT_LAYER_MODE, + }, +}; + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct NativeSumcheckPreCompute { + r_evals_reg: u32, + ctx_reg: u32, + challenges_reg: u32, + prod_evals_reg: u32, + logup_evals_reg: u32, +} + +impl NativeSumcheckExecutor { + #[inline(always)] + fn pre_compute_impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut NativeSumcheckPreCompute, + ) -> Result<(), StaticProgramError> { + let &Instruction { + a, + b, + c, + d, + e, + f, + g, + .. + } = inst; + + let r_evals_reg = a.as_canonical_u32(); + let ctx_reg = b.as_canonical_u32(); + let challenges_reg = c.as_canonical_u32(); + let prod_evals_reg = f.as_canonical_u32(); + let logup_evals_reg = g.as_canonical_u32(); + + if d.as_canonical_u32() != NATIVE_AS { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + if e.as_canonical_u32() != NATIVE_AS { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + + *data = NativeSumcheckPreCompute { + r_evals_reg, + ctx_reg, + challenges_reg, + prod_evals_reg, + logup_evals_reg, + }; + + Ok(()) + } +} + +impl Executor for NativeSumcheckExecutor +where + F: PrimeField32, +{ + #[cfg(feature = "tco")] + fn handler( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: ExecutionCtxTrait, + { + let pre_compute: &mut NativeSumcheckPreCompute = data.borrow_mut(); + + self.pre_compute_impl(pc, inst, pre_compute)?; + + let fn_ptr = execute_e1_handler; + Ok(fn_ptr) + } + + #[inline(always)] + fn pre_compute_size(&self) -> usize { + size_of::() + } + + #[cfg(not(feature = "tco"))] + #[inline(always)] + fn pre_compute( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let pre_compute: &mut NativeSumcheckPreCompute = data.borrow_mut(); + + self.pre_compute_impl(pc, inst, pre_compute)?; + + let fn_ptr = execute_e1_impl; + Ok(fn_ptr) + } +} + +impl MeteredExecutor for NativeSumcheckExecutor +where + F: PrimeField32, +{ + #[inline(always)] + fn metered_pre_compute_size(&self) -> usize { + size_of::>() + } + + #[cfg(not(feature = "tco"))] + #[inline(always)] + fn metered_pre_compute( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let pre_compute: &mut E2PreCompute = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + + self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; + + let fn_ptr = execute_e2_impl; + Ok(fn_ptr) + } + + #[cfg(feature = "tco")] + fn metered_handler( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> { + let pre_compute: &mut E2PreCompute = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + + self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; + + let fn_ptr = execute_e2_handler; + Ok(fn_ptr) + } +} + +#[create_handler] +#[inline(always)] +unsafe fn execute_e1_impl( + pre_compute: &[u8], + instret: &mut u64, + pc: &mut u32, + _instret_end: u64, + 
exec_state: &mut VmExecState, +) { + let pre_compute: &NativeSumcheckPreCompute = pre_compute.borrow(); + execute_e12_impl(pre_compute, instret, pc, exec_state); +} + +#[create_handler] +#[inline(always)] +unsafe fn execute_e2_impl( + pre_compute: &[u8], + instret: &mut u64, + pc: &mut u32, + _arg: u64, + exec_state: &mut VmExecState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + let height = execute_e12_impl(&pre_compute.data, instret, pc, exec_state); + exec_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, height); +} + +#[inline(always)] +unsafe fn execute_e12_impl( + pre_compute: &NativeSumcheckPreCompute, + instret: &mut u64, + pc: &mut u32, + exec_state: &mut VmExecState, +) -> u32 { + let [r_evals_ptr]: [F; 1] = exec_state.vm_read(NATIVE_AS, pre_compute.r_evals_reg); + let [ctx_ptr]: [F; 1] = exec_state.vm_read(NATIVE_AS, pre_compute.ctx_reg); + let [challenges_ptr]: [F; 1] = exec_state.vm_read(NATIVE_AS, pre_compute.challenges_reg); + let [prod_evals_ptr]: [F; 1] = exec_state.vm_read(NATIVE_AS, pre_compute.prod_evals_reg); + let [logup_evals_ptr]: [F; 1] = exec_state.vm_read(NATIVE_AS, pre_compute.logup_evals_reg); + + let r_evals_ptr_u32 = r_evals_ptr.as_canonical_u32(); + let ctx_ptr_u32 = ctx_ptr.as_canonical_u32(); + let logup_evals_ptr = logup_evals_ptr.as_canonical_u32(); + let prod_evals_ptr = prod_evals_ptr.as_canonical_u32(); + + let ctx: [u32; 8] = exec_state + .vm_read(NATIVE_AS, ctx_ptr_u32) + .map(|x: F| x.as_canonical_u32()); + let [round, num_prod_spec, num_logup_spec, prod_specs_inner_len, prod_specs_inner_inner_len, logup_specs_inner_len, logup_specs_inner_inner_len, mode] = + ctx; + let challenges: [F; EXT_DEG * 4] = + exec_state.vm_read(NATIVE_AS, challenges_ptr.as_canonical_u32()); + let alpha: [F; EXT_DEG] = challenges[0..EXT_DEG].try_into().unwrap(); + let c1: [F; EXT_DEG] = challenges[EXT_DEG..EXT_DEG * 2].try_into().unwrap(); + let c2: [F; EXT_DEG] = challenges[EXT_DEG * 2..EXT_DEG * 3].try_into().unwrap(); + + let mut height = 1; + let mut alpha_acc = elem_to_ext(F::ONE); + let mut eval_acc = elem_to_ext(F::ZERO); + + let prod_offset = ctx_ptr_u32 + CONTEXT_ARR_BASE_LEN as u32; + for i in 0..num_prod_spec { + let [max_round]: [u32; 1] = exec_state + .vm_read(NATIVE_AS, prod_offset + i) + .map(|x: F| x.as_canonical_u32()); + + let start = calculate_3d_ext_idx( + prod_specs_inner_inner_len, + prod_specs_inner_len, + i, + round, + 0, + ); + + if round < max_round - 1 { + let ps: [F; EXT_DEG * 2] = exec_state.vm_read(NATIVE_AS, prod_evals_ptr + start); + let p1: [F; EXT_DEG] = ps[0..EXT_DEG].try_into().unwrap(); + let p2: [F; EXT_DEG] = ps[EXT_DEG..EXT_DEG * 2].try_into().unwrap(); + + let eval = match mode { + CURRENT_LAYER_MODE => FieldExtension::multiply(p1, p2), + NEXT_LAYER_MODE => FieldExtension::add( + FieldExtension::multiply(p1, c1), + FieldExtension::multiply(p2, c2), + ), + _ => unreachable!("mode can only be {CURRENT_LAYER_MODE} or {NEXT_LAYER_MODE}"), + }; + + exec_state.vm_write(NATIVE_AS, r_evals_ptr_u32 + (1 + i) * EXT_DEG as u32, &eval); + + let to_next_round = if mode == NEXT_LAYER_MODE { 1 } else { 0 }; + if round + to_next_round < max_round - 1 { + // update eval_acc + eval_acc = FieldExtension::add(eval_acc, FieldExtension::multiply(alpha_acc, eval)); + } + } + + // update alpha_acc + alpha_acc = FieldExtension::multiply(alpha_acc, alpha); + height += 1; + } + + let logup_offset = ctx_ptr_u32 + CONTEXT_ARR_BASE_LEN as u32 + num_prod_spec; + for i in 0..num_logup_spec { + // read max_round + let [max_round]: 
[u32; 1] = exec_state + .vm_read(NATIVE_AS, logup_offset + i) + .map(|x: F| x.as_canonical_u32()); + let start = calculate_3d_ext_idx( + logup_specs_inner_inner_len, + logup_specs_inner_len, + i, + round, + 0, + ); + + let alpha_denominator = FieldExtension::multiply(alpha_acc, alpha); + let alpha_numerator = alpha_acc; + + if round < max_round - 1 { + // read logup_evals + let pqs: [F; EXT_DEG * 4] = exec_state.vm_read(NATIVE_AS, logup_evals_ptr + start); + let p1: [F; EXT_DEG] = pqs[0..EXT_DEG].try_into().unwrap(); + let p2: [F; EXT_DEG] = pqs[EXT_DEG..EXT_DEG * 2].try_into().unwrap(); + let q1: [F; EXT_DEG] = pqs[EXT_DEG * 2..EXT_DEG * 3].try_into().unwrap(); + let q2: [F; EXT_DEG] = pqs[EXT_DEG * 3..EXT_DEG * 4].try_into().unwrap(); + + // compute p_eval and q_eval + let p_eval = match mode { + CURRENT_LAYER_MODE => FieldExtension::add( + FieldExtension::multiply(p1, q2), + FieldExtension::multiply(p2, q1), + ), + NEXT_LAYER_MODE => FieldExtension::add( + FieldExtension::multiply(p1, c1), + FieldExtension::multiply(p2, c2), + ), + _ => unreachable!("mode can only be {CURRENT_LAYER_MODE} or {NEXT_LAYER_MODE}"), + }; + let q_eval = match mode { + CURRENT_LAYER_MODE => FieldExtension::multiply(q1, q2), + NEXT_LAYER_MODE => FieldExtension::add( + FieldExtension::multiply(q1, c1), + FieldExtension::multiply(q2, c2), + ), + _ => unreachable!("mode can only be {CURRENT_LAYER_MODE} or {NEXT_LAYER_MODE}"), + }; + + // write eval to r_evals + exec_state.vm_write( + NATIVE_AS, + r_evals_ptr_u32 + (1 + num_prod_spec + i) * EXT_DEG as u32, + &p_eval, + ); + exec_state.vm_write( + NATIVE_AS, + r_evals_ptr_u32 + (1 + num_prod_spec + num_logup_spec + i) * EXT_DEG as u32, + &q_eval, + ); + + let eval_rlc = FieldExtension::add( + FieldExtension::multiply(alpha_numerator, p_eval), + FieldExtension::multiply(alpha_denominator, q_eval), + ); + let to_next_round = if mode == NEXT_LAYER_MODE { 1 } else { 0 }; + if round + to_next_round < max_round - 1 { + // update eval_acc + eval_acc = FieldExtension::add(eval_acc, eval_rlc); + } + } + + // update alpha_acc + alpha_acc = FieldExtension::multiply(alpha_denominator, alpha); + height += 1; + } + + *pc += DEFAULT_PC_STEP; + *instret += 1; + + exec_state.vm_write(NATIVE_AS, r_evals_ptr_u32, &eval_acc); + // return height delta + height +} diff --git a/extensions/native/circuit/src/sumcheck/mod.rs b/extensions/native/circuit/src/sumcheck/mod.rs new file mode 100644 index 0000000000..8c35a0a7aa --- /dev/null +++ b/extensions/native/circuit/src/sumcheck/mod.rs @@ -0,0 +1,11 @@ +pub mod air; +pub mod chip; +mod columns; + +#[cfg(feature = "cuda")] +mod cuda; +#[cfg(feature = "cuda")] +pub use cuda::*; + +// mod tests; +mod execution; diff --git a/extensions/native/circuit/src/utils.rs b/extensions/native/circuit/src/utils.rs index c38e656e74..3d05656f16 100644 --- a/extensions/native/circuit/src/utils.rs +++ b/extensions/native/circuit/src/utils.rs @@ -1,9 +1,36 @@ +use openvm_circuit::system::{ + memory::{offline_checker::MemoryBaseAuxCols, online::TracingMemory, MemoryAuxColsFactory}, + native_adapter::util::tracing_read_native, +}; +use p3_field::PrimeField32; + pub(crate) const CASTF_MAX_BITS: usize = 30; pub(crate) const fn const_max(a: usize, b: usize) -> usize { [a, b][(a < b) as usize] } +/// Fill `MemoryBaseAuxCols`, assuming that the `prev_timestamp` is already set in `base_aux`. 
+pub(crate) fn mem_fill_helper( + mem_helper: &MemoryAuxColsFactory, + timestamp: u32, + base_aux: &mut MemoryBaseAuxCols, +) { + let prev_ts = base_aux.prev_timestamp.as_canonical_u32(); + mem_helper.fill(prev_ts, timestamp, base_aux); +} + +pub(crate) fn tracing_read_native_helper( + memory: &mut TracingMemory, + ptr: u32, + base_aux: &mut MemoryBaseAuxCols, +) -> [F; BLOCK_SIZE] { + let mut prev_ts = 0; + let ret = tracing_read_native(memory, ptr, &mut prev_ts); + base_aux.set_prev(F::from_canonical_u32(prev_ts)); + ret +} + /// Testing framework #[cfg(any(test, feature = "test-utils"))] pub mod test_utils { diff --git a/extensions/native/circuit/tests/ext.rs b/extensions/native/circuit/tests/ext.rs index 5da70cb53b..70584fd926 100644 --- a/extensions/native/circuit/tests/ext.rs +++ b/extensions/native/circuit/tests/ext.rs @@ -35,7 +35,7 @@ fn test_ext2felt() { } #[test] -fn test_ext_from_base_slice() { +fn test_ext_from_base_vec() { const D: usize = 4; type F = BabyBear; type EF = BinomialExtensionField; @@ -52,8 +52,9 @@ fn test_ext_from_base_slice() { let val = EF::from_base_slice(base_slice); let expected: Ext<_, _> = builder.constant(val); - let felts = base_slice.map(|e| builder.constant::>(e)); - let actual = builder.ext_from_base_slice(&felts); + let felts = base_slice.map(|e| builder.constant::>(e)).to_vec(); + let actual = builder.uninit(); + builder.ext_from_base_vec(actual, felts); builder.assert_ext_eq(actual, expected); builder.halt(); diff --git a/extensions/native/compiler/src/asm/compiler.rs b/extensions/native/compiler/src/asm/compiler.rs index a09c8a217e..c80615ca7a 100644 --- a/extensions/native/compiler/src/asm/compiler.rs +++ b/extensions/native/compiler/src/asm/compiler.rs @@ -489,6 +489,17 @@ impl + TwoAdicField> AsmCo DslIr::HintBitsF(var, len) => { self.push(AsmInstruction::HintBits(var.fp(), len), debug_info); } + DslIr::Poseidon2MultiObserve(dst, init_pos, arr_ptr, len) => { + self.push( + AsmInstruction::Poseidon2MultiObserve( + dst.fp(), + init_pos.fp(), + arr_ptr.fp(), + len.get_var().fp(), + ), + debug_info, + ); + } DslIr::Poseidon2PermuteBabyBear(dst, src) => match (dst, src) { (Array::Dyn(dst, _), Array::Dyn(src, _)) => self.push( AsmInstruction::Poseidon2Permute(dst.fp(), src.fp()), @@ -617,6 +628,27 @@ impl + TwoAdicField> AsmCo debug_info, ); } + DslIr::ExtFromBaseVec(ext, base_vec) => { + assert_eq!(base_vec.len(), EF::D); + for (i, base) in base_vec.into_iter().enumerate() { + self.push( + AsmInstruction::CopyF(ext.fp() + (i as i32), base.fp()), + debug_info.clone(), + ); + } + } + DslIr::SumcheckLayerEval(input_ctx, challenges, prod_ptr, logup_ptr, r_ptr) => { + self.push( + AsmInstruction::SumcheckLayerEval( + input_ctx.fp(), + challenges.fp(), + prod_ptr.fp(), + logup_ptr.fp(), + r_ptr.fp(), + ), + debug_info, + ); + } _ => unimplemented!(), } } diff --git a/extensions/native/compiler/src/asm/instruction.rs b/extensions/native/compiler/src/asm/instruction.rs index 1aa5ea8527..3d498c00f4 100644 --- a/extensions/native/compiler/src/asm/instruction.rs +++ b/extensions/native/compiler/src/asm/instruction.rs @@ -110,6 +110,11 @@ pub enum AsmInstruction { /// Halt. Halt, + /// Absorbs multiple base elements into a duplex transcript with Poseidon2 permutation + /// (sponge_state, init_pos, arr_ptr, len) + /// Returns the final index position of hash sponge + Poseidon2MultiObserve(i32, i32, i32, i32), + /// Perform a Poseidon2 permutation on state starting at address `lhs` /// and store new state at `rhs`. /// (a, b) are pointers to (lhs, rhs). 
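// Sketch of the absorb schedule behind `Poseidon2MultiObserve` (a standalone,
// hypothetical helper, not part of the patch): starting at offset `pos` in the
// RATE-sized sponge buffer (CHUNK = 8 assumed here), the input is split into
// chunks, a permutation fires whenever a chunk reaches the end of the buffer,
// and the returned index is the offset where the next observation continues.
fn multi_observe_chunks(mut pos: usize, mut len: usize) -> (Vec<(usize, usize)>, usize) {
    const CHUNK: usize = 8;
    let mut chunks = vec![];
    while len > 0 {
        if len >= CHUNK - pos {
            chunks.push((pos, CHUNK)); // fills the buffer: triggers a permutation
            len -= CHUNK - pos;
            pos = 0;
        } else {
            chunks.push((pos, pos + len)); // partial chunk, no permutation
            pos += len;
            len = 0;
        }
    }
    let final_idx = chunks.last().map_or(pos, |(_, end)| end % CHUNK);
    (chunks, final_idx)
}

// For example, `multi_observe_chunks(3, 20)` returns the chunks
// (3, 8), (0, 8), (0, 7) and final index 7: two permutations fire along the
// way, and the sponge is left 7 elements full for the next observation.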
@@ -166,6 +171,15 @@ pub enum AsmInstruction { CycleTrackerStart(), CycleTrackerEnd(), + + // Native opcode for calculating sumcheck layer evaluation + // SumcheckLayerEval(reg_a, reg_b, reg_c, ... , reg_f, reg_g) + // - reg_a: Output ptr for next layer's evaluations + // - reg_b: Context variables + // - reg_c: Challenge values (alpha, coeff) + // - reg_g: GKR product IOP evaluations + // - reg_f: GKR logup IOP evaluations + SumcheckLayerEval(i32, i32, i32, i32, i32), } impl> AsmInstruction { @@ -334,6 +348,13 @@ impl> AsmInstruction { AsmInstruction::Trap => write!(f, "trap"), AsmInstruction::Halt => write!(f, "halt"), AsmInstruction::HintBits(src, len) => write!(f, "hint_bits ({})fp, {}", src, len), + AsmInstruction::Poseidon2MultiObserve(dst, init_pos, arr, len) => { + write!( + f, + "poseidon2_multi_observe ({})fp, ({})fp ({})fp ({})fp", + dst, init_pos, arr, len + ) + } AsmInstruction::Poseidon2Permute(dst, lhs) => { write!(f, "poseidon2_permute ({})fp, ({})fp", dst, lhs) } @@ -395,6 +416,13 @@ impl> AsmInstruction { AsmInstruction::RangeCheck(fp, lo_bits, hi_bits) => { write!(f, "range_check_fp ({})fp, ({}), ({})", fp, lo_bits, hi_bits) } + AsmInstruction::SumcheckLayerEval(ctx, cs, p_ptr, l_ptr, r_ptr) => { + write!( + f, + "sumcheck_layer_eval ({})fp, ({})fp, ({})fp, ({})fp, ({})fp", + ctx, cs, p_ptr, l_ptr, r_ptr + ) + } } } } diff --git a/extensions/native/compiler/src/constraints/halo2/compiler.rs b/extensions/native/compiler/src/constraints/halo2/compiler.rs index fd75d526d2..3bc7efcf04 100644 --- a/extensions/native/compiler/src/constraints/halo2/compiler.rs +++ b/extensions/native/compiler/src/constraints/halo2/compiler.rs @@ -493,11 +493,11 @@ impl Halo2ConstraintCompiler { } DslIr::CycleTrackerStart(_name) => { #[cfg(feature = "metrics")] - cell_tracker.start(_name); + cell_tracker.start(_name, 0); } DslIr::CycleTrackerEnd(_name) => { #[cfg(feature = "metrics")] - cell_tracker.end(_name); + cell_tracker.end(_name, 0); } DslIr::CircuitPublish(val, index) => { public_values[index] = vars[&val.0]; diff --git a/extensions/native/compiler/src/conversion/mod.rs b/extensions/native/compiler/src/conversion/mod.rs index af4e5080fb..e71608c190 100644 --- a/extensions/native/compiler/src/conversion/mod.rs +++ b/extensions/native/compiler/src/conversion/mod.rs @@ -12,7 +12,7 @@ use crate::{ asm::{AsmInstruction, AssemblyCode}, FieldArithmeticOpcode, FieldExtensionOpcode, FriOpcode, NativeBranchEqualOpcode, NativeJalOpcode, NativeLoadStore4Opcode, NativeLoadStoreOpcode, NativePhantom, - NativeRangeCheckOpcode, Poseidon2Opcode, VerifyBatchOpcode, + NativeRangeCheckOpcode, Poseidon2Opcode, SumcheckOpcode, VerifyBatchOpcode, }; #[derive(Clone, Copy, Debug, Serialize, Deserialize)] @@ -441,6 +441,18 @@ fn convert_instruction>( AS::Native, AS::Native, )], + AsmInstruction::Poseidon2MultiObserve(dst, init, arr, len) => vec![ + Instruction { + opcode: options.opcode_with_offset(Poseidon2Opcode::MULTI_OBSERVE), + a: i32_f(dst), + b: i32_f(init), + c: i32_f(arr), + d: AS::Native.to_field(), + e: AS::Native.to_field(), + f: i32_f(len), + g: F::ZERO, + } + ], AsmInstruction::Poseidon2Compress(dst, src1, src2) => vec![inst( options.opcode_with_offset(Poseidon2Opcode::COMP_POS2), i32_f(dst), @@ -523,7 +535,19 @@ fn convert_instruction>( // Here it just requires a 0 AS::Immediate, )] - } + }, + AsmInstruction::SumcheckLayerEval(ctx, cs, p_ptr, l_ptr, r_ptr) => vec![ + Instruction { + opcode: options.opcode_with_offset(SumcheckOpcode::SUMCHECK_LAYER_EVAL), + a: i32_f(r_ptr), + b: i32_f(ctx), + c: 
i32_f(cs), + d: AS::Native.to_field(), + e: AS::Native.to_field(), + f: i32_f(p_ptr), + g: i32_f(l_ptr), + } + ], }; let debug_infos = vec![debug_info; instructions.len()]; diff --git a/extensions/native/compiler/src/ir/builder.rs b/extensions/native/compiler/src/ir/builder.rs index 966c0db21c..64c62e64fa 100644 --- a/extensions/native/compiler/src/ir/builder.rs +++ b/extensions/native/compiler/src/ir/builder.rs @@ -620,6 +620,10 @@ impl Builder { self.witness_space.get(id.value()).unwrap() } + pub fn ext_from_base_vec(&mut self, ext: Ext, base_vec: Vec>) { + self.push(DslIr::ExtFromBaseVec(ext, base_vec)); + } + /// Throws an error. pub fn error(&mut self) { self.operations.trace_push(DslIr::Error()); diff --git a/extensions/native/compiler/src/ir/instructions.rs b/extensions/native/compiler/src/ir/instructions.rs index fa4d8b9931..78347283d5 100644 --- a/extensions/native/compiler/src/ir/instructions.rs +++ b/extensions/native/compiler/src/ir/instructions.rs @@ -208,6 +208,14 @@ pub enum DslIr { /// Permutes an array of Bn254 elements using Poseidon2 (output = p2_permute(array)). Should /// only be used when target is a circuit. CircuitPoseidon2Permute([Var; 3]), + /// Absorbs an array of baby bear elements into a duplex transcript with Poseidon2 permutations + /// (output = p2_multi_observe(array, els)). + Poseidon2MultiObserve( + Ptr, // sponge_state + Var, // initial input_ptr position + Ptr, // input array (base elements) + Usize, // len of els + ), // Miscellaneous instructions. /// Prints a variable. @@ -241,6 +249,9 @@ pub enum DslIr { /// Operation to halt the program. Should be the last instruction in the program. Halt, + /// Packs a vector of felts into an ext. + ExtFromBaseVec(Ext, Vec>), + // Public inputs for circuits. /// Publish a field element as the ith public value. Should only be used when target is a /// circuit. @@ -309,6 +320,31 @@ pub enum DslIr { CycleTrackerStart(String), /// End the cycle tracker used by a block of code annotated by the string input. CycleTrackerEnd(String), + + /// Native operation for calculating a sumcheck layer's evaluation + /// This op supports two modes: + /// 1. for computing expected evaluation for current layer, output = [ \sum_i alpha^i * + /// prod[i][0] * prod[i][1] + \sum_j alpha^(2j) * (logup_q[i][0] * logup_q[i][1] + alpha* + /// logup_p[i][0] * logup_q[i][1] + alpha * logup_p[i][1] * logup_q[i][0] ]; + /// + /// 2. for computing expected evaluation of next layer, output[1+i] = eq(0,r)*p[i][0] + eq(1,r) + /// * p[i][1]. + SumcheckLayerEval( + Ptr, // Context variables: + // 0. round, + // 1. number of product + // 2. number of logup + // 3. (3D array description) prod_specs_eval inner length + // 4. (3D array description) prod_specs_eval inner_inner length + // 5. (3D array description) logup_spec_eval inner length + // 6. (3D array description) logup_spec_eval inner length + // 7. Operational mode indicator + // 8+. 
usize-type variables indicating maximum rounds + Ptr, // Challenges: alpha, coeffs + Ptr, // prod_specs_eval + Ptr, // logup_specs_eval + Ptr, // output + ), } impl Default for DslIr { diff --git a/extensions/native/compiler/src/ir/mod.rs b/extensions/native/compiler/src/ir/mod.rs index 47e901cd3a..f708318c34 100644 --- a/extensions/native/compiler/src/ir/mod.rs +++ b/extensions/native/compiler/src/ir/mod.rs @@ -18,6 +18,7 @@ mod instructions; mod poseidon; mod ptr; mod select; +mod sumcheck; mod symbolic; mod types; mod utils; diff --git a/extensions/native/compiler/src/ir/poseidon.rs b/extensions/native/compiler/src/ir/poseidon.rs index 12ec526c89..6d32f89409 100644 --- a/extensions/native/compiler/src/ir/poseidon.rs +++ b/extensions/native/compiler/src/ir/poseidon.rs @@ -2,12 +2,53 @@ use openvm_native_compiler_derive::iter_zip; use openvm_stark_backend::p3_field::FieldAlgebra; use super::{Array, ArrayLike, Builder, Config, DslIr, Ext, Felt, MemIndex, Ptr, Usize, Var}; +use crate::ir::Variable; pub const DIGEST_SIZE: usize = 8; pub const HASH_RATE: usize = 8; pub const PERMUTATION_WIDTH: usize = 16; impl Builder { + /// Extends native VM ability to observe multiple base elements in one opcode operation + /// Absorbs elements sequentially at the RATE portion of sponge state and performs as many + /// permutations as necessary. Returns the index position of the next input_ptr. + /// + /// [Reference](https://docs.rs/p3-poseidon2/latest/p3_poseidon2/struct.Poseidon2.html) + pub fn poseidon2_multi_observe( + &mut self, + sponge_state: &Array>, + input_ptr: Ptr, + arr: &Array>, + ) -> Usize { + let buffer_size: Var = Var::uninit(self); + self.assign(&buffer_size, C::N::from_canonical_usize(HASH_RATE)); + + match sponge_state { + Array::Fixed(_) => { + panic!("Poseidon2 permutation is not allowed on fixed arrays"); + } + Array::Dyn(sponge_ptr, _) => match arr { + Array::Fixed(_) => { + panic!("Base elements input must be dynamic"); + } + Array::Dyn(ptr, len) => { + let init_pos: Var = Var::uninit(self); + self.assign(&init_pos, input_ptr.address - sponge_ptr.address); + + self.operations.push(DslIr::Poseidon2MultiObserve( + *sponge_ptr, + init_pos, + *ptr, + len.clone(), + )); + + // automatically updated by Poseidon2MultiObserve operation + Usize::Var(init_pos) + } + }, + } + } + /// Applies the Poseidon2 permutation to the given array. /// /// [Reference](https://docs.rs/p3-poseidon2/latest/p3_poseidon2/struct.Poseidon2.html) diff --git a/extensions/native/compiler/src/ir/sumcheck.rs b/extensions/native/compiler/src/ir/sumcheck.rs new file mode 100644 index 0000000000..0237fd6740 --- /dev/null +++ b/extensions/native/compiler/src/ir/sumcheck.rs @@ -0,0 +1,48 @@ +use super::{Array, Builder, Config, DslIr, Ext, Usize}; + +impl Builder { + /// Extends native VM ability to calculate the evaluation for a sumcheck layer + /// This opcode supports two modes (indicated by a context variable): + /// 1. calculate the expected evaluation of two types of sumchecks (prod, logup) + /// 2. calculate the expected value of next layer p[r] = eq(0,r)*p[0] + eq(1,r)*p[1] + /// + /// Context variables + /// + /// 0: round, + /// 1: number of product + /// 2. number of logup + /// 3. (3D array description) prod_specs_eval inner length + /// 4. (3D array description) prod_specs_eval inner_inner length + /// 5. (3D array description) logup_spec_eval inner length + /// 6. (3D array description) logup_spec_eval inner length + /// 7. 
diff --git a/extensions/native/compiler/src/ir/sumcheck.rs b/extensions/native/compiler/src/ir/sumcheck.rs
new file mode 100644
index 0000000000..0237fd6740
--- /dev/null
+++ b/extensions/native/compiler/src/ir/sumcheck.rs
@@ -0,0 +1,48 @@
+use super::{Array, Builder, Config, DslIr, Ext, Usize};
+
+impl<C: Config> Builder<C> {
+    /// Extends the native VM's ability to calculate the evaluation for a sumcheck layer.
+    /// This opcode supports two modes (indicated by a context variable):
+    /// 1. calculate the expected evaluation of two types of sumchecks (prod, logup)
+    /// 2. calculate the expected value of the next layer p[r] = eq(0,r)*p[0] + eq(1,r)*p[1]
+    ///
+    /// Context variables
+    ///
+    /// 0. round
+    /// 1. number of product specs
+    /// 2. number of logup specs
+    /// 3. (3D array description) prod_specs_eval inner length
+    /// 4. (3D array description) prod_specs_eval inner_inner length
+    /// 5. (3D array description) logup_specs_eval inner length
+    /// 6. (3D array description) logup_specs_eval inner_inner length
+    /// 7. operational mode indicator
+    /// 8+. additional usize-type variables indicating maximum rounds
+    ///
+    /// Output
+    ///
+    /// 1. for computing the expected evaluation, output = [ \sum_i alpha^i * prod[i][0] *
+    ///    prod[i][1] + \sum_j alpha^(2j) * (logup_q[j][0] * logup_q[j][1] + alpha *
+    ///    logup_p[j][0] * logup_q[j][1] + alpha * logup_p[j][1] * logup_q[j][0]) ];
+    ///
+    /// 2. for computing the expected eval of the next layer, output[1+i] = eq(0,r)*p[i][0] +
+    ///    eq(1,r) * p[i][1].
+    pub fn sumcheck_layer_eval(
+        &mut self,
+        input_ctx: &Array<C, Usize<C::N>>,       // Context variables
+        challenges: &Array<C, Ext<C::F, C::EF>>, // Challenges
+        prod_specs_eval: &Array<C, Ext<C::F, C::EF>>, /* GKR product IOP evaluations. Flattened
+                                                  * from 3D array. */
+        logup_specs_eval: &Array<C, Ext<C::F, C::EF>>, /* GKR logup IOP evaluations. Flattened
+                                                   * from 3D array. */
+        r_evals: &Array<C, Ext<C::F, C::EF>>, /* Next layer's evaluations (pointer used for
+                                               * storing opcode output) */
+    ) {
+        self.operations.push(DslIr::SumcheckLayerEval(
+            input_ctx.ptr(),
+            challenges.ptr(),
+            prod_specs_eval.ptr(),
+            logup_specs_eval.ptr(),
+            r_evals.ptr(),
+        ));
+    }
+}
diff --git a/extensions/native/compiler/src/lib.rs b/extensions/native/compiler/src/lib.rs
index ef28b37139..efb45b0159 100644
--- a/extensions/native/compiler/src/lib.rs
+++ b/extensions/native/compiler/src/lib.rs
@@ -184,6 +184,7 @@ pub enum NativePhantom {
 pub enum Poseidon2Opcode {
     PERM_POS2,
     COMP_POS2,
+    MULTI_OBSERVE,
 }
 
 /// Opcodes for FRI opening proofs.
@@ -211,3 +212,18 @@ pub enum VerifyBatchOpcode {
     /// per column polynomial, per opening point
     VERIFY_BATCH,
 }
+
+/// Opcodes for sumcheck.
+#[derive(
+    Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, EnumCount, EnumIter, FromRepr, LocalOpcode,
+)]
+#[opcode_offset = 0x180]
+#[repr(usize)]
+#[allow(non_camel_case_types)]
+pub enum SumcheckOpcode {
+    /// Compute the expected evaluation for each layer in the tower structure that the GKR
+    /// product IOP and logup IOP use. Supports two modes of operation:
+    /// 1. Calculate the current layer's expected evaluation
+    /// 2. Calculate the next layer's evaluation
+    SUMCHECK_LAYER_EVAL,
+}
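Mode 2 is plain multilinear interpolation of each pair at the round challenge r: since eq(0,r) = 1 - r and eq(1,r) = r, it reduces to the one-liner below. This is an illustrative sketch, not part of the patch; the test at the end of this patch instead reads two precomputed challenges c1 and c2 in place of eq(0,r) and eq(1,r).

use openvm_stark_backend::p3_field::FieldAlgebra;

/// Mode-2 reference: fold a (p0, p1) pair at challenge r, using
/// eq(0,r)*p0 + eq(1,r)*p1 = (1 - r)*p0 + r*p1 = p0 + r*(p1 - p0).
fn next_layer_eval<EF: FieldAlgebra + Copy>(p0: EF, p1: EF, r: EF) -> EF {
    p0 + r * (p1 - p0)
}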
diff --git a/extensions/native/recursion/Cargo.toml b/extensions/native/recursion/Cargo.toml
index a8efd69ab3..f47f263693 100644
--- a/extensions/native/recursion/Cargo.toml
+++ b/extensions/native/recursion/Cargo.toml
@@ -8,6 +8,7 @@ repository.workspace = true
 [dependencies]
 openvm-stark-backend = { workspace = true }
+openvm-cuda-backend = { workspace = true, optional = true }
 openvm-native-circuit = { workspace = true, features = ["test-utils"] }
 openvm-native-compiler = { workspace = true }
 openvm-native-compiler-derive = { workspace = true }
@@ -58,4 +59,4 @@ parallel = ["openvm-stark-backend/parallel"]
 mimalloc = ["openvm-stark-backend/mimalloc"]
 jemalloc = ["openvm-stark-backend/jemalloc"]
 nightly-features = ["openvm-circuit/nightly-features"]
-cuda = ["openvm-circuit/cuda", "openvm-native-circuit/cuda"]
+cuda = ["openvm-circuit/cuda", "openvm-native-circuit/cuda", "dep:openvm-cuda-backend"]
diff --git a/extensions/native/recursion/src/challenger/duplex.rs b/extensions/native/recursion/src/challenger/duplex.rs
index 7c0cd4dd88..440b14ec59 100644
--- a/extensions/native/recursion/src/challenger/duplex.rs
+++ b/extensions/native/recursion/src/challenger/duplex.rs
@@ -77,6 +77,24 @@ impl<C: Config> DuplexChallengerVariable<C> {
         }
     }
 
+    /// Observes multiple elements from an array.
+    /// This is equivalent to calling `observe` multiple times, but more efficient.
+    pub fn observe_slice_opt(&self, builder: &mut Builder<C>, arr: &Array<C, Felt<C::F>>) {
+        builder.if_ne(arr.len(), Usize::from(0)).then(|builder| {
+            let next_pos = builder.poseidon2_multi_observe(&self.sponge_state, self.input_ptr, arr);
+
+            builder.assign(&self.input_ptr, self.io_empty_ptr + next_pos.clone());
+            builder.if_ne(next_pos, Usize::from(0)).then_or_else(
+                |builder| {
+                    builder.assign(&self.output_ptr, self.io_empty_ptr);
+                },
+                |builder| {
+                    builder.assign(&self.output_ptr, self.io_full_ptr);
+                },
+            );
+        });
+    }
+
     fn sample(&self, builder: &mut Builder<C>) -> Felt<C::F> {
         builder
             .if_ne(self.input_ptr.address, self.io_empty_ptr.address)
@@ -101,7 +119,9 @@ impl<C: Config> DuplexChallengerVariable<C> {
         let b = self.sample(builder);
         let c = self.sample(builder);
         let d = self.sample(builder);
-        builder.ext_from_base_slice(&[a, b, c, d])
+        let ext = builder.uninit();
+        builder.ext_from_base_vec(ext, vec![a, b, c, d]);
+        ext
     }
 
     fn sample_bits(&self, builder: &mut Builder<C>, nb_bits: RVar<C::N>) -> Array<C, Var<C::N>>
diff --git a/extensions/native/recursion/tests/recursion.rs b/extensions/native/recursion/tests/recursion.rs
index 6322614784..68e033845b 100644
--- a/extensions/native/recursion/tests/recursion.rs
+++ b/extensions/native/recursion/tests/recursion.rs
@@ -4,13 +4,22 @@ use openvm_circuit::{
         instructions::program::Program, PreflightExecutionOutput, PreflightExecutor, VmBuilder,
         VmCircuitConfig, VmExecutionConfig,
     },
-    utils::TestStarkEngine,
+    utils::{air_test_impl, TestStarkEngine},
 };
+#[cfg(feature = "cuda")]
+use openvm_cuda_backend::engine::GpuBabyBearPoseidon2Engine;
 use openvm_native_circuit::{
     execute_program_with_config, test_native_config, NativeBuilder, NativeConfig,
 };
-use openvm_native_compiler::{asm::AsmBuilder, ir::Felt};
-use openvm_native_recursion::testing_utils::inner::run_recursive_test;
+use openvm_native_compiler::{
+    asm::{AsmBuilder, AsmCompiler},
+    conversion::{convert_program, CompilerOptions},
+    ir::{Array, Builder, Config, Felt},
+};
+use openvm_native_recursion::{
+    challenger::{duplex::DuplexChallengerVariable, CanObserveVariable, CanSampleVariable},
+    testing_utils::inner::run_recursive_test,
+};
 use openvm_stark_backend::{
     config::{Domain, StarkGenericConfig, Val},
     p3_commit::PolynomialSpace,
@@ -24,12 +33,14 @@ use openvm_stark_backend::{
 use openvm_stark_sdk::{
     config::{
         baby_bear_poseidon2::{BabyBearPoseidon2Config, BabyBearPoseidon2Engine},
+        fri_params::standard_fri_params_with_100_bits_conjectured_security,
         FriParameters,
     },
     engine::StarkFriEngine,
     p3_baby_bear::BabyBear,
-    utils::ProofInputForTest,
+    utils::{create_seeded_rng, ProofInputForTest},
 };
+use rand::Rng;
 
 fn fibonacci_program(a: u32, b: u32, n: u32) -> Program<BabyBear> {
     type F = BabyBear;
@@ -166,3 +177,80 @@ fn test_fibonacci_program_halo2_verify() {
         FriParameters::new_for_testing(LOG_BLOWUP),
     );
 }
+
+#[test]
+fn test_multi_observe() {
+    type F = BabyBear;
+    type EF = BinomialExtensionField<F, 4>;
+    let mut builder = AsmBuilder::<F, EF>::default();
+
+    build_test_program(&mut builder);
+
+    let compilation_options = CompilerOptions::default().with_cycle_tracker();
+    let mut compiler = AsmCompiler::new(compilation_options.word_size);
+    compiler.build(builder.operations);
+    let asm_code = compiler.code();
+
+    let program: Program<_> = convert_program(asm_code, compilation_options);
+
+    let poseidon2_max_constraint_degree = 3;
+
+    let fri_params = if matches!(std::env::var("OPENVM_FAST_TEST"), Ok(x) if &x == "1") {
+        FriParameters {
+            // max constraint degree = 2^log_blowup + 1
+            log_blowup: 1,
+            log_final_poly_len: 0,
+            num_queries: 2,
+            proof_of_work_bits: 0,
+        }
+    } else {
+        standard_fri_params_with_100_bits_conjectured_security(1)
+    };
+
+    let mut config = NativeConfig::aggregation(0, poseidon2_max_constraint_degree);
+    config.system.memory_config.max_access_adapter_n = 16;
+
+    let vb = NativeBuilder::default();
+    #[cfg(not(feature = "cuda"))]
+    air_test_impl::<TestStarkEngine, _>(fri_params, vb, config, program, vec![], 1, true)
+        .unwrap();
+    #[cfg(feature = "cuda")]
+    {
+        air_test_impl::<GpuBabyBearPoseidon2Engine, _>(
+            fri_params,
+            vb,
+            config,
+            program,
+            vec![],
+            1,
+            true,
+        )
+        .unwrap();
+    }
+}
+
+fn build_test_program<C: Config>(builder: &mut Builder<C>) {
+    let sample_lens: Vec<usize> = vec![10, 2, 1, 0, 3, 20, 200, 400];
+
+    let mut rng = create_seeded_rng();
+
+    let mut c1 = DuplexChallengerVariable::new(builder);
+    let mut c2 = DuplexChallengerVariable::new(builder);
+
+    for l in sample_lens {
+        let sample_input: Array<C, Felt<C::F>> = builder.dyn_array(l);
+        builder.range(0, l).for_each(|idx_vec, builder| {
+            let f_u32: u32 = rng.gen_range(1..1 << 30);
+            builder.set(&sample_input, idx_vec[0], C::F::from_canonical_u32(f_u32));
+        });
+
+        c1.observe_slice_opt(builder, &sample_input);
+        c2.observe_slice(builder, sample_input);
+
+        let e1 = c1.sample(builder);
+        let e2 = c2.sample(builder);
+
+        builder.assert_felt_eq(e1, e2);
+    }
+    builder.halt();
+}
diff --git a/extensions/native/recursion/tests/sumcheck.rs b/extensions/native/recursion/tests/sumcheck.rs
new file mode 100644
index 0000000000..a500ee6aac
--- /dev/null
+++ b/extensions/native/recursion/tests/sumcheck.rs
@@ -0,0 +1,221 @@
+use std::iter::{once, repeat_n};
+
+use openvm_circuit::{arch::instructions::program::Program, utils::air_test_impl};
+#[cfg(feature = "cuda")]
+use openvm_cuda_backend::engine::GpuBabyBearPoseidon2Engine;
+use openvm_native_circuit::{NativeBuilder, NativeConfig, EXT_DEG};
+use openvm_native_compiler::{
+    asm::{AsmBuilder, AsmCompiler},
+    conversion::{convert_program, CompilerOptions},
+    ir::{Ext, Usize},
+    prelude::*,
+};
+use openvm_stark_backend::p3_field::{
+    extension::BinomialExtensionField, FieldAlgebra, FieldExtensionAlgebra,
+};
+#[cfg(not(feature = "cuda"))]
+use openvm_stark_sdk::config::baby_bear_poseidon2::BabyBearPoseidon2Engine;
+use openvm_stark_sdk::{
+    config::{fri_params::standard_fri_params_with_100_bits_conjectured_security, FriParameters},
+    p3_baby_bear::BabyBear,
+};
+use rand::{thread_rng, RngCore};
+
+pub type F = BabyBear;
+pub type E = BinomialExtensionField<F, 4>;
+
+#[test]
+fn test_sumcheck_layer_eval() {
+    let mut builder = AsmBuilder::<F, BinomialExtensionField<F, 4>>::default();
+
+    build_test_program(&mut builder);
+
+    let compilation_options = CompilerOptions::default().with_cycle_tracker();
+    let mut compiler = AsmCompiler::new(compilation_options.word_size);
+    compiler.build(builder.operations);
+    let asm_code = compiler.code();
+
+    let program: Program<_> = convert_program(asm_code, compilation_options);
+    let sumcheck_max_constraint_degree = 3;
+    let fri_params = if matches!(std::env::var("OPENVM_FAST_TEST"), Ok(x) if &x == "1") {
+        FriParameters {
+            // max constraint degree = 2^log_blowup + 1
+            log_blowup: 1,
+            log_final_poly_len: 0,
+            num_queries: 2,
+            proof_of_work_bits: 0,
+        }
+    } else {
+        standard_fri_params_with_100_bits_conjectured_security(1)
+    };
+
+    let mut config = NativeConfig::aggregation(0, sumcheck_max_constraint_degree);
+    config.system.memory_config.max_access_adapter_n = 16;
+
+    let vb = NativeBuilder::default();
+    #[cfg(not(feature = "cuda"))]
+    air_test_impl::<BabyBearPoseidon2Engine, _>(fri_params, vb, config, program, vec![], 1, true)
+        .unwrap();
+    #[cfg(feature = "cuda")]
+    {
+        air_test_impl::<GpuBabyBearPoseidon2Engine, _>(
+            fri_params,
+            vb,
+            config,
+            program,
+            vec![],
+            1,
+            true,
+        )
+        .unwrap();
+    }
+}
+
+fn new_rand_ext<C: Config, R: RngCore>(rng: &mut R) -> C::EF {
+    C::EF::from_base_slice(&[
+        C::F::from_canonical_u32(rng.next_u32()),
+        C::F::from_canonical_u32(rng.next_u32()),
+        C::F::from_canonical_u32(rng.next_u32()),
+        C::F::from_canonical_u32(rng.next_u32()),
+    ])
+}
+
+fn build_test_program<C: Config>(builder: &mut Builder<C>) {
+    let mut rng = thread_rng();
+    // 6 prod specs in 8 layers, 5 logup specs in 8 layers
+    let round = 3;
+    let num_prod_specs = 6;
+    let num_logup_specs = 5;
+    let num_layers = 8;
+    let mode = 1; // current_layer
+
+    let mut ctx_u32s = vec![
+        round,
+        num_prod_specs,
+        num_logup_specs,
+        num_layers,
+        2,
+        num_layers,
+        4,
+        mode,
+    ];
+    ctx_u32s.extend(repeat_n(num_layers, num_prod_specs + num_logup_specs));
+
+    let ctx: Array<C, Usize<C::N>> = builder.dyn_array(ctx_u32s.len());
+    for (idx, n) in ctx_u32s.into_iter().enumerate() {
+        builder.set(&ctx, idx, Usize::from(n as usize));
+    }
+
+    #[rustfmt::skip]
+    let challenges_u32s = [
+        548478283u32, 456436544, 1716290291, 791326976,
+        1829717553, 1422025771, 1917123958, 727015942,
+        183548369, 591240150, 96141963, 1286249979,
+        0, 0, 0, 0,
+    ];
+    let challenges: Array<C, Ext<C::F, C::EF>> = builder.dyn_array(challenges_u32s.len() / EXT_DEG);
+    for (idx, n) in challenges_u32s.chunks(EXT_DEG).enumerate() {
+        let e: Ext<C::F, C::EF> = builder.constant(C::EF::from_base_slice(&[
+            C::F::from_canonical_u32(n[0]),
+            C::F::from_canonical_u32(n[1]),
+            C::F::from_canonical_u32(n[2]),
+            C::F::from_canonical_u32(n[3]),
+        ]));
+
+        builder.set(&challenges, idx, e);
+    }
+
+    let num_prod_evals = num_prod_specs * num_layers * 2;
+    let prod_spec_evals: Array<C, Ext<C::F, C::EF>> = builder.dyn_array(num_prod_evals);
+    for idx in 0..num_prod_evals {
+        let e: Ext<C::F, C::EF> = builder.constant(new_rand_ext::<C, _>(&mut rng));
+
+        builder.set(&prod_spec_evals, idx, e);
+    }
+
+    let num_logup_evals = num_logup_specs * num_layers * 4;
+    let logup_spec_evals: Array<C, Ext<C::F, C::EF>> = builder.dyn_array(num_logup_evals);
+    for idx in 0..num_logup_evals {
+        let e: Ext<C::F, C::EF> = builder.constant(new_rand_ext::<C, _>(&mut rng));
+
+        builder.set(&logup_spec_evals, idx, e);
+    }
+
+    let alpha = builder.get(&challenges, 0);
+    let c1 = builder.get(&challenges, 1);
+    let c2 = builder.get(&challenges, 2);
+
+    let alpha_acc: Ext<C::F, C::EF> = builder.constant(C::EF::ONE);
+    let eval_acc: Ext<C::F, C::EF> = builder.constant(C::EF::ZERO);
+
+    let mut p_evals = vec![];
+    for i in 0..num_prod_specs {
+        let start = num_layers * 2 * i + 2 * round;
+        let p1 = builder.get(&prod_spec_evals, start);
+        let p2 = builder.get(&prod_spec_evals, start + 1);
+        let p_eval: Ext<C::F, C::EF> = if mode == 1 {
+            // current layer
+            builder.eval(p1 * p2)
+        } else {
+            // next layer
+            builder.eval(p1 * c1 + p2 * c2)
+        };
+        p_evals.push(p_eval);
+        let eval_rlc: Ext<C::F, C::EF> = builder.eval(alpha_acc * p_eval);
+        builder.assign(&eval_acc, eval_acc + eval_rlc);
+        builder.assign(&alpha_acc, alpha_acc * alpha);
+    }
+
+    let mut logup_p_evals = vec![];
+    let mut logup_q_evals = vec![];
+    for i in 0..num_logup_specs {
+        let start = num_layers * 4 * i + 4 * round;
+        let p1 = builder.get(&logup_spec_evals, start);
+        let p2 = builder.get(&logup_spec_evals, start + 1);
+        let q1 = builder.get(&logup_spec_evals, start + 2);
+        let q2 = builder.get(&logup_spec_evals, start + 3);
+        let p_eval: Ext<C::F, C::EF> = if mode == 1 {
+            builder.eval(p1 * q2 + p2 * q1)
+        } else {
+            builder.eval(p1 * c1 + p2 * c2)
+        };
+        let q_eval: Ext<C::F, C::EF> = if mode == 1 {
+            builder.eval(q1 * q2)
+        } else {
+            builder.eval(q1 * c1 + q2 * c2)
+        };
+
+        logup_p_evals.push(p_eval);
+        logup_q_evals.push(q_eval);
+
+        let alpha_denominator: Ext<C::F, C::EF> =
+            builder.eval(alpha_acc * alpha);
+        let eval_rlc: Ext<C::F, C::EF> =
+            builder.eval(alpha_acc * p_eval + alpha_denominator * q_eval);
+
+        builder.assign(&eval_acc, eval_acc + eval_rlc);
+        builder.assign(&alpha_acc, alpha_acc * alpha * alpha);
+    }
+
+    let r_evals = once(eval_acc)
+        .chain(p_evals.into_iter())
+        .chain(logup_p_evals.into_iter())
+        .chain(logup_q_evals.into_iter())
+        .collect::<Vec<_>>();
+
+    let next_layer_evals: Array<C, Ext<C::F, C::EF>> = builder.dyn_array(r_evals.len());
+
+    builder.sumcheck_layer_eval(
+        &ctx,
+        &challenges,
+        &prod_spec_evals,
+        &logup_spec_evals,
+        &next_layer_evals,
+    );
+
+    for (idx, e) in r_evals.into_iter().enumerate() {
+        let next_eval = builder.get(&next_layer_evals, idx);
+        builder.assert_ext_eq(next_eval, e);
+    }
+
+    builder.halt();
+}
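For readers following the index arithmetic above: the spec-by-round-by-slot 3D arrays are flattened row-major, which the test encodes as `num_layers * 2 * i + 2 * round` and `num_layers * 4 * i + 4 * round`. A pair of illustrative helpers (not part of the patch) makes the layout explicit:

/// Index of slot `k` (0 or 1) of prod spec `i` at round `r` in the flattened
/// prod_specs_eval array; each spec stores num_layers (p0, p1) pairs.
fn prod_eval_idx(num_layers: usize, i: usize, r: usize, k: usize) -> usize {
    num_layers * 2 * i + 2 * r + k
}

/// Index of slot `k` (0..4, i.e. p0, p1, q0, q1) of logup spec `i` at round `r`
/// in the flattened logup_specs_eval array.
fn logup_eval_idx(num_layers: usize, i: usize, r: usize, k: usize) -> usize {
    num_layers * 4 * i + 4 * r + k
}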