Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ Cargo.lock
**/.env
.DS_Store

# Log outputs
*.log

.cache/
rustc-*

Expand Down
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions crates/sdk/src/prover/agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ where
E: StarkFriEngine<SC = SC>,
NativeBuilder: VmBuilder<E, VmConfig = NativeConfig>,
{
leaf_prover: VmInstance<E, NativeBuilder>,
leaf_controller: LeafProvingController,
pub leaf_prover: VmInstance<E, NativeBuilder>,
pub leaf_controller: LeafProvingController,

pub internal_prover: VmInstance<E, NativeBuilder>,
#[cfg(feature = "evm-prove")]
root_prover: RootVerifierLocalProver,
pub root_prover: RootVerifierLocalProver,
pub num_children_internal: usize,
pub max_internal_wrapper_layers: usize,
}
Expand Down
44 changes: 36 additions & 8 deletions crates/vm/src/metrics/cycle_tracker/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,18 @@
/// Stats for a nested span in the execution segment that is tracked by the [`CycleTracker`].
#[derive(Clone, Debug, Default)]
pub struct SpanInfo {
/// The name of the span.
pub tag: String,
/// The cycle count at which the span starts.
pub start: usize,
}

#[derive(Clone, Debug, Default)]
pub struct CycleTracker {
/// Stack of span names, with most recent at the end
stack: Vec<String>,
stack: Vec<SpanInfo>,
/// Depth of the stack.
depth: usize,
}

impl CycleTracker {
Expand All @@ -10,29 +21,42 @@ impl CycleTracker {
}

pub fn top(&self) -> Option<&String> {
self.stack.last()
match self.stack.last() {
Some(span) => Some(&span.tag),
_ => None,
}
}

/// Starts a new cycle tracker span for the given name.
/// If a span already exists for the given name, it ends the existing span and pushes a new one
/// to the vec.
pub fn start(&mut self, mut name: String) {
pub fn start(&mut self, mut name: String, cycles_count: usize) {
// hack to remove "CT-" prefix
if name.starts_with("CT-") {
name = name.split_off(3);
}
self.stack.push(name);
self.stack.push(SpanInfo {
tag: name.clone(),
start: cycles_count,
});
let padding = "│ ".repeat(self.depth);
tracing::info!("{}┌╴{}", padding, name);
self.depth += 1;
}

/// Ends the cycle tracker span for the given name.
/// If no span exists for the given name, it panics.
pub fn end(&mut self, mut name: String) {
pub fn end(&mut self, mut name: String, cycles_count: usize) {
// hack to remove "CT-" prefix
if name.starts_with("CT-") {
name = name.split_off(3);
}
let stack_top = self.stack.pop();
assert_eq!(stack_top.unwrap(), name, "Stack top does not match name");
let SpanInfo { tag, start } = self.stack.pop().unwrap();
assert_eq!(tag, name, "Stack top does not match name");
self.depth -= 1;
let padding = "│ ".repeat(self.depth);
let span_cycles = cycles_count - start;
tracing::info!("{}└╴{} cycles", padding, span_cycles);
}

/// Ends the current cycle tracker span.
Expand All @@ -42,7 +66,11 @@ impl CycleTracker {

/// Get full name of span with all parent names separated by ";" in flamegraph format
pub fn get_full_name(&self) -> String {
self.stack.join(";")
self.stack
.iter()
.map(|span_info| span_info.tag.clone())
.collect::<Vec<String>>()
.join(";")
}
}

Expand Down
2 changes: 1 addition & 1 deletion crates/vm/src/metrics/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ impl VmMetrics {
.map(|(_, func)| (*func).clone())
.unwrap();
if pc == self.current_fn.start {
self.cycle_tracker.start(self.current_fn.name.clone());
self.cycle_tracker.start(self.current_fn.name.clone(), 0);
} else {
while let Some(name) = self.cycle_tracker.top() {
if name == &self.current_fn.name {
Expand Down
32 changes: 31 additions & 1 deletion extensions/native/circuit/cuda/include/native/poseidon2.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,42 @@ template <typename T> struct SimplePoseidonSpecificCols {
MemoryWriteAuxCols<T, CHUNK> write_data_2;
};

template <typename T> struct MultiObserveCols {
T pc;
T final_timestamp_increment;
T state_ptr;
T input_ptr;
T init_pos;
T len;
T input_register_1;
T input_register_2;
T input_register_3;
T output_register;
T is_first;
T is_last;
T curr_len;
T start_idx;
T end_idx;
T aux_after_start[CHUNK];
T aux_before_end[CHUNK];
T aux_read_enabled[CHUNK];
MemoryReadAuxCols<T> read_data[CHUNK];
MemoryWriteAuxCols<T, 1> write_data[CHUNK];
T data[CHUNK];
T should_permute;
MemoryWriteAuxCols<T, CHUNK * 2> write_sponge_state;
MemoryWriteAuxCols<T, 1> write_final_idx;
};

template <typename T> constexpr T constexpr_max(T a, T b) { return a > b ? a : b; }

constexpr size_t COL_SPECIFIC_WIDTH = constexpr_max(
sizeof(TopLevelSpecificCols<uint8_t>),
constexpr_max(
sizeof(InsideRowSpecificCols<uint8_t>),
sizeof(SimplePoseidonSpecificCols<uint8_t>)
constexpr_max(
sizeof(SimplePoseidonSpecificCols<uint8_t>),
sizeof(MultiObserveCols<uint8_t>)
)
)
);
85 changes: 85 additions & 0 deletions extensions/native/circuit/cuda/include/native/sumcheck.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
#pragma once

#include "primitives/constants.h"
#include "system/memory/offline_checker.cuh"

using namespace native;

template <typename T> struct HeaderSpecificCols {
T pc;
T registers[5];
MemoryReadAuxCols<T> read_records[7];
MemoryWriteAuxCols<T, EXT_DEG> write_records;
};

template <typename T> struct ProdSpecificCols {
T data_ptr;
T p[EXT_DEG * 2];
MemoryReadAuxCols<T> read_records[2];
T p_evals[EXT_DEG];
MemoryWriteAuxCols<T, EXT_DEG> write_record;
T eval_rlc[EXT_DEG];
};

template <typename T> struct LogupSpecificCols {
T data_ptr;
T pq[EXT_DEG * 4];
MemoryReadAuxCols<T> read_records[2];
T p_evals[EXT_DEG];
T q_evals[EXT_DEG];
MemoryWriteAuxCols<T, EXT_DEG> write_records[2];
T eval_rlc[EXT_DEG];
};

template <typename T> constexpr T constexpr_max(T a, T b) {
return a > b ? a : b;
}

constexpr size_t COL_SPECIFIC_WIDTH = constexpr_max(
sizeof(HeaderSpecificCols<uint8_t>),
constexpr_max(sizeof(ProdSpecificCols<uint8_t>), sizeof(LogupSpecificCols<uint8_t>))
);

template <typename T> struct NativeSumcheckCols {
T header_row;
T prod_row;
T logup_row;
T is_end;

T prod_continued;
T logup_continued;

T prod_in_round_evaluation;
T prod_next_round_evaluation;
T logup_in_round_evaluation;
T logup_next_round_evaluation;

T prod_acc;
T logup_acc;

T first_timestamp;
T start_timestamp;
T last_timestamp;

T register_ptrs[5];

T ctx[EXT_DEG * 2];

T prod_nested_len;
T logup_nested_len;

T curr_prod_n;
T curr_logup_n;

T alpha[EXT_DEG];
T challenges[EXT_DEG * 4];

T max_round;
T within_round_limit;
T should_acc;

T eval_acc[EXT_DEG];

T specific[COL_SPECIFIC_WIDTH];
};

13 changes: 13 additions & 0 deletions extensions/native/circuit/cuda/include/native/utils.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#pragma once

#include "primitives/trace_access.h"
#include "system/memory/controller.cuh"

__device__ __forceinline__ void mem_fill_base(
MemoryAuxColsFactory &mem_helper,
uint32_t timestamp,
RowSlice base_aux
) {
uint32_t prev = base_aux[COL_INDEX(MemoryBaseAuxCols, prev_timestamp)].asUInt32();
mem_helper.fill(base_aux, prev, timestamp);
}
81 changes: 72 additions & 9 deletions extensions/native/circuit/cuda/src/poseidon2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "poseidon2-air/columns.cuh"
#include "poseidon2-air/params.cuh"
#include "poseidon2-air/tracegen.cuh"
#include "native/utils.cuh"
#include "primitives/trace_access.h"
#include "system/memory/controller.cuh"

Expand All @@ -22,6 +23,7 @@ template <typename T, size_t SBOX_REGISTERS> struct NativePoseidon2Cols {
T incorporate_sibling;
T inside_row;
T simple;
T multi_observe_row;

T end_inside_row;
T end_top_level;
Expand All @@ -37,15 +39,6 @@ template <typename T, size_t SBOX_REGISTERS> struct NativePoseidon2Cols {
T specific[COL_SPECIFIC_WIDTH];
};

__device__ void mem_fill_base(
MemoryAuxColsFactory mem_helper,
uint32_t timestamp,
RowSlice base_aux
) {
uint32_t prev = base_aux[COL_INDEX(MemoryBaseAuxCols, prev_timestamp)].asUInt32();
mem_helper.fill(base_aux, prev, timestamp);
}

template <size_t SBOX_REGISTERS> struct Poseidon2Wrapper {
template <typename T> using Cols = NativePoseidon2Cols<T, SBOX_REGISTERS>;
using Poseidon2Row =
Expand All @@ -58,6 +51,8 @@ template <size_t SBOX_REGISTERS> struct Poseidon2Wrapper {
) {
if (row[COL_INDEX(Cols, simple)] == Fp::one()) {
fill_simple_chunk(row, range_checker, timestamp_max_bits);
} else if (row[COL_INDEX(Cols, multi_observe_row)] == Fp::one()) {
fill_multi_observe_chunk(row, range_checker, timestamp_max_bits);
} else {
fill_verify_batch_chunk(row, range_checker, timestamp_max_bits);
}
Expand Down Expand Up @@ -335,6 +330,74 @@ template <size_t SBOX_REGISTERS> struct Poseidon2Wrapper {
}
}
}

__device__ static void fill_multi_observe_chunk(
RowSlice row,
VariableRangeChecker range_checker,
uint32_t timestamp_max_bits
) {
MemoryAuxColsFactory mem_helper(range_checker, timestamp_max_bits);
Poseidon2Row head_row(row);
uint32_t num_rows = head_row.export_col()[0].asUInt32();

for (uint32_t idx = 0; idx < num_rows; ++idx) {
RowSlice curr_row = row.shift_row(idx);
fill_inner(curr_row);
fill_multi_observe_specific(curr_row, mem_helper);
}
}

__device__ static void fill_multi_observe_specific(
RowSlice row,
MemoryAuxColsFactory &mem_helper
) {
RowSlice specific = row.slice_from(COL_INDEX(Cols, specific));
if (specific[COL_INDEX(MultiObserveCols, is_first)] == Fp::one()) {
uint32_t very_start_timestamp =
row[COL_INDEX(Cols, very_first_timestamp)].asUInt32();
for (uint32_t i = 0; i < 4; ++i) {
mem_fill_base(
mem_helper,
very_start_timestamp + i,
specific.slice_from(COL_INDEX(MultiObserveCols, read_data[i].base))
);
}
} else {
uint32_t start_timestamp = row[COL_INDEX(Cols, start_timestamp)].asUInt32();
uint32_t chunk_start =
specific[COL_INDEX(MultiObserveCols, start_idx)].asUInt32();
uint32_t chunk_end =
specific[COL_INDEX(MultiObserveCols, end_idx)].asUInt32();
for (uint32_t j = chunk_start; j < chunk_end; ++j) {
mem_fill_base(
mem_helper,
start_timestamp,
specific.slice_from(COL_INDEX(MultiObserveCols, read_data[j].base))
);
mem_fill_base(
mem_helper,
start_timestamp + 1,
specific.slice_from(COL_INDEX(MultiObserveCols, write_data[j].base))
);
start_timestamp += 2;
}
if (chunk_end >= CHUNK) {
mem_fill_base(
mem_helper,
start_timestamp,
specific.slice_from(COL_INDEX(MultiObserveCols, write_sponge_state.base))
);
start_timestamp += 1;
}
if (specific[COL_INDEX(MultiObserveCols, is_last)] == Fp::one()) {
mem_fill_base(
mem_helper,
start_timestamp,
specific.slice_from(COL_INDEX(MultiObserveCols, write_final_idx.base))
);
}
}
}
};

template <size_t SBOX_REGISTERS>
Expand Down
Loading
Loading