From ead66023d7c952bac0ac58a6b17878ca585212ea Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Sat, 7 Mar 2026 01:52:54 -0500 Subject: [PATCH 01/18] hint bridge --- crates/vm/src/arch/testing/cpu.rs | 10 +- crates/vm/src/arch/testing/cuda.rs | 10 +- crates/vm/src/arch/testing/mod.rs | 1 + .../system/memory/offline_checker/bridge.rs | 136 +++++++++++++++++- .../src/system/memory/offline_checker/bus.rs | 101 +++++++++++++ .../system/memory/offline_checker/columns.rs | 45 ++++++ crates/vm/src/system/mod.rs | 10 +- .../algebra/circuit/src/extension/fp2.rs | 1 + .../algebra/circuit/src/extension/modular.rs | 1 + .../bigint/circuit/src/extension/mod.rs | 1 + .../ecc/circuit/src/extension/weierstrass.rs | 1 + .../keccak256/circuit/src/extension/mod.rs | 1 + .../native/circuit/src/extension/mod.rs | 2 + .../rv32im/circuit/src/extension/mod.rs | 3 + .../sha256/circuit/src/sha256_chip/air.rs | 1 + 15 files changed, 316 insertions(+), 8 deletions(-) diff --git a/crates/vm/src/arch/testing/cpu.rs b/crates/vm/src/arch/testing/cpu.rs index 70c374968c..ae2e9e906d 100644 --- a/crates/vm/src/arch/testing/cpu.rs +++ b/crates/vm/src/arch/testing/cpu.rs @@ -37,8 +37,9 @@ use crate::{ testing::{ execution::air::ExecutionDummyAir, program::{air::ProgramDummyAir, ProgramTester}, - ExecutionTester, MemoryTester, TestBuilder, TestChipHarness, EXECUTION_BUS, MEMORY_BUS, - MEMORY_MERKLE_BUS, POSEIDON2_DIRECT_BUS, RANGE_CHECKER_BUS, READ_INSTRUCTION_BUS, + ExecutionTester, MemoryTester, TestBuilder, TestChipHarness, EXECUTION_BUS, HINT_BUS, + MEMORY_BUS, MEMORY_MERKLE_BUS, POSEIDON2_DIRECT_BUS, RANGE_CHECKER_BUS, + READ_INSTRUCTION_BUS, }, vm_poseidon2_config, Arena, ExecutionBridge, ExecutionBus, ExecutionState, MatrixRecordArena, MemoryConfig, PreflightExecutor, Streams, VmStateMut, @@ -46,7 +47,7 @@ use crate::{ system::{ memory::{ adapter::records::arena_size_bound, - offline_checker::{MemoryBridge, MemoryBus}, + offline_checker::{HintBridge, HintBus, MemoryBridge, MemoryBus}, online::TracingMemory, MemoryAirInventory, MemoryController, SharedMemoryHelper, CHUNK, }, @@ -258,10 +259,13 @@ impl VmChipTestBuilder { } pub fn system_port(&self) -> SystemPort { + let hint_bus = HintBus::new(HINT_BUS); + let hint_bridge = HintBridge::new(hint_bus); SystemPort { execution_bus: self.execution.bus, program_bus: self.program.bus, memory_bridge: self.memory_bridge(), + hint_bridge, } } diff --git a/crates/vm/src/arch/testing/cuda.rs b/crates/vm/src/arch/testing/cuda.rs index 0427f50671..e53ded66cc 100644 --- a/crates/vm/src/arch/testing/cuda.rs +++ b/crates/vm/src/arch/testing/cuda.rs @@ -50,7 +50,7 @@ use crate::{ execution::{air::ExecutionDummyAir, DeviceExecutionTester}, memory::DeviceMemoryTester, program::{air::ProgramDummyAir, DeviceProgramTester}, - TestBuilder, TestChipHarness, EXECUTION_BUS, MEMORY_BUS, MEMORY_MERKLE_BUS, + TestBuilder, TestChipHarness, EXECUTION_BUS, HINT_BUS, MEMORY_BUS, MEMORY_MERKLE_BUS, POSEIDON2_DIRECT_BUS, READ_INSTRUCTION_BUS, }, Arena, DenseRecordArena, ExecutionBridge, ExecutionBus, ExecutionState, MatrixRecordArena, @@ -59,7 +59,7 @@ use crate::{ system::{ cuda::{poseidon2::Poseidon2PeripheryChipGPU, DIGEST_WIDTH}, memory::{ - offline_checker::{MemoryBridge, MemoryBus}, + offline_checker::{HintBridge, HintBus, MemoryBridge, MemoryBus}, MemoryAirInventory, SharedMemoryHelper, }, poseidon2::air::Poseidon2PeripheryAir, @@ -393,8 +393,14 @@ impl GpuChipTestBuilder { execution_bus: self.execution_bus(), program_bus: self.program_bus(), memory_bridge: self.memory_bridge(), + hint_bridge: self.hint_bridge(), } } + + pub fn hint_bridge(&self) -> HintBridge { + let hint_bus = HintBus::new(HINT_BUS); + HintBridge::new(hint_bus) + } pub fn execution_bridge(&self) -> ExecutionBridge { ExecutionBridge::new(self.execution.bus(), self.program.bus()) } diff --git a/crates/vm/src/arch/testing/mod.rs b/crates/vm/src/arch/testing/mod.rs index 5293a0275a..52a9385e25 100644 --- a/crates/vm/src/arch/testing/mod.rs +++ b/crates/vm/src/arch/testing/mod.rs @@ -29,6 +29,7 @@ pub const BITWISE_OP_LOOKUP_BUS: BusIndex = 9; pub const BYTE_XOR_BUS: BusIndex = 10; pub const RANGE_TUPLE_CHECKER_BUS: BusIndex = 11; pub const MEMORY_MERKLE_BUS: BusIndex = 12; +pub const HINT_BUS: BusIndex = 13; pub const RANGE_CHECKER_BUS: BusIndex = 4; diff --git a/crates/vm/src/system/memory/offline_checker/bridge.rs b/crates/vm/src/system/memory/offline_checker/bridge.rs index 367e1344d7..e377497375 100644 --- a/crates/vm/src/system/memory/offline_checker/bridge.rs +++ b/crates/vm/src/system/memory/offline_checker/bridge.rs @@ -10,10 +10,11 @@ use openvm_stark_backend::{ interaction::InteractionBuilder, p3_air::AirBuilder, p3_field::FieldAlgebra, }; -use super::bus::MemoryBus; +use super::bus::{HintBus, MemoryBus}; use crate::system::memory::{ offline_checker::columns::{ - MemoryBaseAuxCols, MemoryReadAuxCols, MemoryReadOrImmediateAuxCols, MemoryWriteAuxCols, + HintReadAuxCols, HintWriteAuxCols, MemoryBaseAuxCols, MemoryReadAuxCols, + MemoryReadOrImmediateAuxCols, MemoryWriteAuxCols, }, MemoryAddress, }; @@ -326,3 +327,134 @@ impl MemoryOfflineChecker { .eval(builder, enabled); } } + +/// The [HintBridge] is used to constrain logical hint memory operations (read/write). +/// It adds all necessary constraints and interactions for a separate hint space. +#[derive(Clone, Copy, Debug)] +pub struct HintBridge { + hint_bus: HintBus, +} + +impl HintBridge { + /// Create a new [HintBridge] with the provided hint bus. + pub fn new(hint_bus: HintBus) -> Self { + Self { hint_bus } + } + + pub fn hint_bus(&self) -> HintBus { + self.hint_bus + } + + /// Prepare a logical hint read operation. + /// A read or write interaction has the form: (hint_id, offset, value, timestamp) + #[must_use] + pub fn read<'a, T, V>( + &self, + hint_id: impl Into, + offset: impl Into, + value: impl Into, + timestamp: impl Into, + aux: &'a HintReadAuxCols, + ) -> HintReadOperation<'a, T, V> { + HintReadOperation { + hint_bus: self.hint_bus, + hint_id: hint_id.into(), + offset: offset.into(), + value: value.into(), + timestamp: timestamp.into(), + aux, + } + } + + /// Prepare a logical hint write operation. + /// A read or write interaction has the form: (hint_id, offset, value, timestamp) + #[must_use] + pub fn write<'a, T, V>( + &self, + hint_id: impl Into, + offset: impl Into, + value: impl Into, + timestamp: impl Into, + aux: &'a HintWriteAuxCols, + ) -> HintWriteOperation<'a, T, V> { + HintWriteOperation { + hint_bus: self.hint_bus, + hint_id: hint_id.into(), + offset: offset.into(), + value: value.into(), + timestamp: timestamp.into(), + aux, + } + } +} + +/// Constraints and interactions for a logical hint read of `(hint_id, offset, value)` at time `timestamp`. +/// This reads `(hint_id, offset, value, timestamp_prev)` from the hint bus and writes +/// `(hint_id, offset, value, timestamp)` to the hint bus. +/// +/// The generic `T` type is intended to be `AB::Expr` where `AB` is the [AirBuilder]. +/// The auxiliary columns are not expected to be expressions, so the generic `V` type is intended +/// to be `AB::Var`. +pub struct HintReadOperation<'a, T, V> { + hint_bus: HintBus, + hint_id: T, + offset: T, + value: T, + timestamp: T, + aux: &'a HintReadAuxCols, +} + +impl> HintReadOperation<'_, F, V> { + /// Evaluate constraints and send/receive interactions. + pub fn eval(self, builder: &mut AB, enabled: impl Into) + where + AB: InteractionBuilder, + { + let enabled = enabled.into(); + + self.hint_bus + .receive(self.hint_id.clone(), self.offset.clone(), self.value.clone(), self.aux.prev_timestamp) + .eval(builder, enabled.clone()); + + self.hint_bus + .send(self.hint_id, self.offset, self.value, self.timestamp) + .eval(builder, enabled); + } +} + +/// Constraints and interactions for a logical hint write of `(hint_id, offset, value)` at time +/// `timestamp`. This reads `(hint_id, offset, prev_value, timestamp_prev)` from the hint bus +/// and writes `(hint_id, offset, value, timestamp)` to the hint bus. +/// +/// **Note:** This can be used as a logical read operation by setting `prev_value = value`. +pub struct HintWriteOperation<'a, T, V> { + hint_bus: HintBus, + hint_id: T, + offset: T, + value: T, + timestamp: T, + aux: &'a HintWriteAuxCols, +} + +impl> HintWriteOperation<'_, T, V> { + /// Evaluate constraints and send/receive interactions. `enabled` must be boolean. + pub fn eval(self, builder: &mut AB, enabled: impl Into) + where + AB: InteractionBuilder, + { + let enabled = enabled.into(); + + self.hint_bus + .receive( + self.hint_id.clone(), + self.offset.clone(), + self.aux.prev_value.clone(), + self.aux.prev_timestamp, + ) + .eval(builder, enabled.clone()); + + self.hint_bus + .send(self.hint_id, self.offset, self.value, self.timestamp) + .eval(builder, enabled); + } +} diff --git a/crates/vm/src/system/memory/offline_checker/bus.rs b/crates/vm/src/system/memory/offline_checker/bus.rs index d15f5798ea..b082510bb2 100644 --- a/crates/vm/src/system/memory/offline_checker/bus.rs +++ b/crates/vm/src/system/memory/offline_checker/bus.rs @@ -101,3 +101,104 @@ impl MemoryBusInteraction { } } } + +/// Represents a hint bus identified by a unique bus index (`usize`). +/// Used to check correct read/write operations in hint space. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct HintBus { + pub inner: PermutationCheckBus, +} + +impl HintBus { + pub const fn new(index: BusIndex) -> Self { + Self { + inner: PermutationCheckBus::new(index), + } + } +} + +impl HintBus { + #[inline(always)] + pub fn index(&self) -> BusIndex { + self.inner.index + } + + /// Prepares a write operation through the hint bus. + pub fn send( + &self, + hint_id: impl Into, + offset: impl Into, + value: impl Into, + timestamp: impl Into, + ) -> HintBusInteraction { + self.push(true, hint_id, offset, value, timestamp) + } + + /// Prepares a read operation through the hint bus. + pub fn receive( + &self, + hint_id: impl Into, + offset: impl Into, + value: impl Into, + timestamp: impl Into, + ) -> HintBusInteraction { + self.push(false, hint_id, offset, value, timestamp) + } + + /// Prepares a hint operation (read or write) through the hint bus. + fn push( + &self, + is_send: bool, + hint_id: impl Into, + offset: impl Into, + value: impl Into, + timestamp: impl Into, + ) -> HintBusInteraction { + HintBusInteraction { + bus: self.inner, + is_send, + hint_id: hint_id.into(), + offset: offset.into(), + value: value.into(), + timestamp: timestamp.into(), + } + } +} + +#[derive(Clone, Debug)] +pub struct HintBusInteraction { + pub bus: PermutationCheckBus, + pub is_send: bool, + pub hint_id: T, + pub offset: T, + pub value: T, + pub timestamp: T, +} + +impl HintBusInteraction { + /// Finalizes and sends/receives the hint operation with the specified direction over the bus. + /// + /// A read corresponds to a receive, and a write corresponds to a send. + /// + /// The parameter `direction` can be -1, 0, or 1. A value of 1 means perform the same action + /// (send/receive), a value of -1 reverses the action, and a value of 0 means disabled. + /// + /// Caller must constrain `direction` to be in {-1, 0, 1}. + pub fn eval(self, builder: &mut AB, direction: impl Into) + where + AB: InteractionBuilder, + { + let fields = iter::empty() + .chain(iter::once(self.hint_id)) + .chain(iter::once(self.offset)) + .chain(iter::once(self.value)) + .chain(iter::once(self.timestamp)); + + if self.is_send { + self.bus.interact(builder, fields, direction); + } else { + self.bus + .interact(builder, fields, AB::Expr::NEG_ONE * direction.into()); + } + } +} diff --git a/crates/vm/src/system/memory/offline_checker/columns.rs b/crates/vm/src/system/memory/offline_checker/columns.rs index 63c114193d..6fe34d1180 100644 --- a/crates/vm/src/system/memory/offline_checker/columns.rs +++ b/crates/vm/src/system/memory/offline_checker/columns.rs @@ -123,3 +123,48 @@ impl AsMut> for MemoryReadOrImmediateAuxCols { &mut self.base } } +/// The auxiliary columns for a hint read operation. +/// Stores the previous timestamp to enable permutation checking through the hint bus. +#[repr(C)] +#[derive(Clone, Copy, Debug, AlignedBorrow)] +pub struct HintReadAuxCols { + /// The previous timestamp when this hint location was accessed + pub prev_timestamp: T, +} + +impl HintReadAuxCols { + pub fn new(prev_timestamp: F) -> Self { + Self { prev_timestamp } + } + + #[inline(always)] + pub fn set_prev(&mut self, timestamp: F) { + self.prev_timestamp = timestamp; + } +} + +/// The auxiliary columns for a hint write operation. +/// Stores the previous timestamp and previous value to enable permutation checking through the hint bus. +#[repr(C)] +#[derive(Clone, Copy, Debug, AlignedBorrow)] +pub struct HintWriteAuxCols { + /// The previous timestamp when this hint location was accessed + pub prev_timestamp: T, + /// The previous value at this hint location + pub prev_value: T, +} + +impl HintWriteAuxCols { + pub fn new(prev_timestamp: F, prev_value: F) -> Self { + Self { + prev_timestamp, + prev_value, + } + } + + #[inline(always)] + pub fn set_prev(&mut self, timestamp: F, value: F) { + self.prev_timestamp = timestamp; + self.prev_value = value; + } +} \ No newline at end of file diff --git a/crates/vm/src/system/mod.rs b/crates/vm/src/system/mod.rs index 8c0d2a3d37..5bb9a4be07 100644 --- a/crates/vm/src/system/mod.rs +++ b/crates/vm/src/system/mod.rs @@ -36,7 +36,7 @@ use crate::{ connector::VmConnectorChip, memory::{ interface::MemoryInterfaceAirs, - offline_checker::{MemoryBridge, MemoryBus}, + offline_checker::{HintBridge, HintBus, MemoryBridge, MemoryBus}, online::GuestMemory, MemoryAirInventory, MemoryController, TimestampedEquipartition, CHUNK, }, @@ -149,6 +149,7 @@ pub struct SystemPort { pub execution_bus: ExecutionBus, pub program_bus: ProgramBus, pub memory_bridge: MemoryBridge, + pub hint_bridge: HintBridge, } #[derive(Clone)] @@ -156,6 +157,7 @@ pub struct SystemAirInventory { pub program: ProgramAir, pub connector: VmConnectorAir, pub memory: MemoryAirInventory, + pub hint_bridge: HintBridge, /// Public values AIR exists if and only if continuations is disabled and `num_public_values` /// is greater than 0. pub public_values: Option, @@ -171,6 +173,7 @@ impl SystemAirInventory { execution_bus, program_bus, memory_bridge, + hint_bridge, } = port; let range_bus = memory_bridge.range_bus(); let program = ProgramAir::new(program_bus); @@ -212,6 +215,7 @@ impl SystemAirInventory { program, connector, memory, + hint_bridge, public_values, } } @@ -221,6 +225,7 @@ impl SystemAirInventory { memory_bridge: self.memory.bridge, program_bus: self.program.bus, execution_bus: self.connector.execution_bus, + hint_bridge: self.hint_bridge, } } @@ -300,10 +305,13 @@ impl VmCircuitConfig for SystemConfig { }; let memory_bridge = MemoryBridge::new(memory_bus, self.memory_config.timestamp_max_bits, range_bus); + let hint_bus = HintBus::new(bus_idx_mgr.new_bus_idx()); + let hint_bridge = HintBridge::new(hint_bus); let system_port = SystemPort { execution_bus, program_bus, memory_bridge, + hint_bridge, }; let system = SystemAirInventory::new(self, system_port, merkle_compression_buses); diff --git a/extensions/algebra/circuit/src/extension/fp2.rs b/extensions/algebra/circuit/src/extension/fp2.rs index 3081c88565..ce0938af4a 100644 --- a/extensions/algebra/circuit/src/extension/fp2.rs +++ b/extensions/algebra/circuit/src/extension/fp2.rs @@ -175,6 +175,7 @@ impl VmCircuitExtension for Fp2Extension { execution_bus, program_bus, memory_bridge, + hint_bridge: _, } = inventory.system().port(); let exec_bridge = ExecutionBridge::new(execution_bus, program_bus); diff --git a/extensions/algebra/circuit/src/extension/modular.rs b/extensions/algebra/circuit/src/extension/modular.rs index 8946daa9c3..5c5be255fe 100644 --- a/extensions/algebra/circuit/src/extension/modular.rs +++ b/extensions/algebra/circuit/src/extension/modular.rs @@ -231,6 +231,7 @@ impl VmCircuitExtension for ModularExtension { execution_bus, program_bus, memory_bridge, + hint_bridge: _, } = inventory.system().port(); let exec_bridge = ExecutionBridge::new(execution_bus, program_bus); diff --git a/extensions/bigint/circuit/src/extension/mod.rs b/extensions/bigint/circuit/src/extension/mod.rs index 1725a4860d..90e3688d0d 100644 --- a/extensions/bigint/circuit/src/extension/mod.rs +++ b/extensions/bigint/circuit/src/extension/mod.rs @@ -143,6 +143,7 @@ impl VmCircuitExtension for Int256 { execution_bus, program_bus, memory_bridge, + hint_bridge: _, } = inventory.system().port(); let exec_bridge = ExecutionBridge::new(execution_bus, program_bus); diff --git a/extensions/ecc/circuit/src/extension/weierstrass.rs b/extensions/ecc/circuit/src/extension/weierstrass.rs index 5048584183..10181b1ea5 100644 --- a/extensions/ecc/circuit/src/extension/weierstrass.rs +++ b/extensions/ecc/circuit/src/extension/weierstrass.rs @@ -200,6 +200,7 @@ impl VmCircuitExtension for WeierstrassExtension { execution_bus, program_bus, memory_bridge, + hint_bridge: _, } = inventory.system().port(); let exec_bridge = ExecutionBridge::new(execution_bus, program_bus); diff --git a/extensions/keccak256/circuit/src/extension/mod.rs b/extensions/keccak256/circuit/src/extension/mod.rs index 9f6e55a540..bbdd64cc06 100644 --- a/extensions/keccak256/circuit/src/extension/mod.rs +++ b/extensions/keccak256/circuit/src/extension/mod.rs @@ -148,6 +148,7 @@ impl VmCircuitExtension for Keccak256 { execution_bus, program_bus, memory_bridge, + hint_bridge: _, } = inventory.system().port(); let exec_bridge = ExecutionBridge::new(execution_bus, program_bus); diff --git a/extensions/native/circuit/src/extension/mod.rs b/extensions/native/circuit/src/extension/mod.rs index 9f3e2035ad..7f0fc0a981 100644 --- a/extensions/native/circuit/src/extension/mod.rs +++ b/extensions/native/circuit/src/extension/mod.rs @@ -219,6 +219,7 @@ where execution_bus, program_bus, memory_bridge, + hint_bridge: _, } = inventory.system().port(); let exec_bridge = ExecutionBridge::new(execution_bus, program_bus); let range_checker = inventory.range_checker().bus; @@ -542,6 +543,7 @@ impl VmCircuitExtension for CastFExtension { execution_bus, program_bus, memory_bridge, + hint_bridge: _, } = inventory.system().port(); let exec_bridge = ExecutionBridge::new(execution_bus, program_bus); let range_checker = inventory.range_checker().bus; diff --git a/extensions/rv32im/circuit/src/extension/mod.rs b/extensions/rv32im/circuit/src/extension/mod.rs index 8055c16b54..fc9cb9ab63 100644 --- a/extensions/rv32im/circuit/src/extension/mod.rs +++ b/extensions/rv32im/circuit/src/extension/mod.rs @@ -202,6 +202,7 @@ impl VmCircuitExtension for Rv32I { execution_bus, program_bus, memory_bridge, + hint_bridge: _, } = inventory.system().port(); let exec_bridge = ExecutionBridge::new(execution_bus, program_bus); @@ -466,6 +467,7 @@ impl VmCircuitExtension for Rv32M { execution_bus, program_bus, memory_bridge, + hint_bridge: _, } = inventory.system().port(); let exec_bridge = ExecutionBridge::new(execution_bus, program_bus); @@ -635,6 +637,7 @@ impl VmCircuitExtension for Rv32Io { execution_bus, program_bus, memory_bridge, + hint_bridge: _, } = inventory.system().port(); let exec_bridge = ExecutionBridge::new(execution_bus, program_bus); diff --git a/extensions/sha256/circuit/src/sha256_chip/air.rs b/extensions/sha256/circuit/src/sha256_chip/air.rs index 2fe1cb26c0..fb5ce2070f 100644 --- a/extensions/sha256/circuit/src/sha256_chip/air.rs +++ b/extensions/sha256/circuit/src/sha256_chip/air.rs @@ -53,6 +53,7 @@ impl Sha256VmAir { execution_bus, program_bus, memory_bridge, + hint_bridge: _, }: SystemPort, bitwise_lookup_bus: BitwiseOperationLookupBus, ptr_max_bits: usize, From 257ba4f5fc9d3862643fc0526144cc6c7e9c8d5b Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Sun, 8 Mar 2026 04:36:57 -0400 Subject: [PATCH 02/18] progress --- extensions/native/circuit/src/extension/mod.rs | 5 +++-- extensions/native/circuit/src/poseidon2/air.rs | 5 ++++- extensions/native/circuit/src/sumcheck/air.rs | 10 ++++++++-- 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/extensions/native/circuit/src/extension/mod.rs b/extensions/native/circuit/src/extension/mod.rs index 7f0fc0a981..81633c32b3 100644 --- a/extensions/native/circuit/src/extension/mod.rs +++ b/extensions/native/circuit/src/extension/mod.rs @@ -219,7 +219,7 @@ where execution_bus, program_bus, memory_bridge, - hint_bridge: _, + hint_bridge, } = inventory.system().port(); let exec_bridge = ExecutionBridge::new(execution_bus, program_bus); let range_checker = inventory.range_checker().bus; @@ -270,12 +270,13 @@ where let verify_batch = NativePoseidon2Air::<_, 1>::new( exec_bridge, memory_bridge, + hint_bridge, VerifyBatchBus::new(inventory.new_bus_idx()), Poseidon2Config::default(), ); inventory.add_air(verify_batch); - let tower_evaluate = NativeSumcheckAir::new(exec_bridge, memory_bridge); + let tower_evaluate = NativeSumcheckAir::new(exec_bridge, memory_bridge, hint_bridge); inventory.add_air(tower_evaluate); Ok(()) diff --git a/extensions/native/circuit/src/poseidon2/air.rs b/extensions/native/circuit/src/poseidon2/air.rs index baf18b06a3..9e9cdf5ce8 100644 --- a/extensions/native/circuit/src/poseidon2/air.rs +++ b/extensions/native/circuit/src/poseidon2/air.rs @@ -3,7 +3,7 @@ use std::{array::from_fn, borrow::Borrow, sync::Arc}; use itertools::Itertools; use openvm_circuit::{ arch::{ExecutionBridge, ExecutionState}, - system::memory::{offline_checker::MemoryBridge, MemoryAddress, CHUNK}, + system::memory::{offline_checker::{HintBridge, MemoryBridge}, MemoryAddress, CHUNK}, }; use openvm_circuit_primitives::utils::not; use openvm_instructions::LocalOpcode; @@ -36,6 +36,7 @@ use crate::poseidon2::{ pub struct NativePoseidon2Air { pub execution_bridge: ExecutionBridge, pub memory_bridge: MemoryBridge, + pub hint_bridge: HintBridge, pub internal_bus: VerifyBatchBus, pub(crate) subair: Arc>, pub(crate) address_space: F, @@ -45,12 +46,14 @@ impl NativePoseidon2Air, ) -> Self { NativePoseidon2Air { execution_bridge, memory_bridge, + hint_bridge, internal_bus: verify_batch_bus, subair: Arc::new(Poseidon2SubAir::new(poseidon2_config.constants.into())), address_space: F::from_canonical_u32(AS::Native as u32), diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index c9bbf1279e..74c99fe2b2 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -3,7 +3,7 @@ use std::borrow::Borrow; use openvm_circuit::{ arch::{ExecutionBridge, ExecutionState}, system::memory::{ - offline_checker::{MemoryBridge, MemoryReadAuxCols}, + offline_checker::{HintBridge, MemoryBridge, MemoryReadAuxCols}, MemoryAddress, }, }; @@ -33,13 +33,19 @@ pub const NUM_RWS_FOR_LOGUP: usize = 3; pub struct NativeSumcheckAir { pub execution_bridge: ExecutionBridge, pub memory_bridge: MemoryBridge, + pub hint_bridge: HintBridge, } impl NativeSumcheckAir { - pub fn new(execution_bridge: ExecutionBridge, memory_bridge: MemoryBridge) -> Self { + pub fn new( + execution_bridge: ExecutionBridge, + memory_bridge: MemoryBridge, + hint_bridge: HintBridge, + ) -> Self { Self { execution_bridge, memory_bridge, + hint_bridge, } } } From 798c00d80edca854c05a4915d30f1527f53522b9 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Sun, 8 Mar 2026 04:53:30 -0400 Subject: [PATCH 03/18] optional writeback --- .../native/circuit/src/extension/mod.rs | 2 +- extensions/native/circuit/src/sumcheck/air.rs | 68 ++------- .../native/circuit/src/sumcheck/chip.rs | 136 +++++++++--------- .../native/circuit/src/sumcheck/columns.rs | 4 +- .../native/circuit/src/sumcheck/execution.rs | 29 +--- 5 files changed, 86 insertions(+), 153 deletions(-) diff --git a/extensions/native/circuit/src/extension/mod.rs b/extensions/native/circuit/src/extension/mod.rs index 81633c32b3..01ba49e3f3 100644 --- a/extensions/native/circuit/src/extension/mod.rs +++ b/extensions/native/circuit/src/extension/mod.rs @@ -276,7 +276,7 @@ where ); inventory.add_air(verify_batch); - let tower_evaluate = NativeSumcheckAir::new(exec_bridge, memory_bridge, hint_bridge); + let tower_evaluate = NativeSumcheckAir::new(exec_bridge, memory_bridge); inventory.add_air(tower_evaluate); Ok(()) diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index 74c99fe2b2..aad3838b3a 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -3,7 +3,7 @@ use std::borrow::Borrow; use openvm_circuit::{ arch::{ExecutionBridge, ExecutionState}, system::memory::{ - offline_checker::{HintBridge, MemoryBridge, MemoryReadAuxCols}, + offline_checker::MemoryBridge, MemoryAddress, }, }; @@ -26,26 +26,20 @@ use crate::{ }, }; -pub const NUM_RWS_FOR_PRODUCT: usize = 2; -pub const NUM_RWS_FOR_LOGUP: usize = 3; - #[derive(Clone, Debug)] pub struct NativeSumcheckAir { pub execution_bridge: ExecutionBridge, pub memory_bridge: MemoryBridge, - pub hint_bridge: HintBridge, } impl NativeSumcheckAir { pub fn new( execution_bridge: ExecutionBridge, memory_bridge: MemoryBridge, - hint_bridge: HintBridge, ) -> Self { Self { execution_bridge, memory_bridge, - hint_bridge, } } } @@ -111,7 +105,7 @@ impl Air for NativeSumcheckAir { within_round_limit, should_acc, eval_acc, - is_hint_src_id, + is_writeback, specific, } = local; @@ -241,22 +235,6 @@ impl Air for NativeSumcheckAir { next.start_timestamp, start_timestamp + AB::F::from_canonical_usize(8), ); - builder - .when(prod_row) - .when(next.prod_row + next.logup_row) - .assert_eq( - next.start_timestamp, - start_timestamp - + within_round_limit * AB::F::from_canonical_usize(NUM_RWS_FOR_PRODUCT), - ); - builder - .when(logup_row) - .when(next.prod_row + next.logup_row) - .assert_eq( - next.start_timestamp, - start_timestamp - + within_round_limit * AB::F::from_canonical_usize(NUM_RWS_FOR_LOGUP), - ); // Termination condition assert_array_eq( @@ -355,7 +333,7 @@ impl Air for NativeSumcheckAir { native_as, register_ptrs[0] + AB::F::from_canonical_usize(CONTEXT_ARR_BASE_LEN), ), - [max_round, is_hint_src_id], + [max_round, is_writeback], first_timestamp + AB::F::from_canonical_usize(7), &header_row_specific.read_records[7], ) @@ -398,21 +376,6 @@ impl Air for NativeSumcheckAir { ); builder.assert_eq(prod_row * should_acc, prod_acc); - // Read p1, p2 from witness arrays - self.memory_bridge - .read( - MemoryAddress::new(native_as, register_ptrs[2] + prod_row_specific.data_ptr), - prod_row_specific.p, - start_timestamp, - &MemoryReadAuxCols { - base: prod_row_specific.ps_record.base, - }, - ) - .eval( - builder, - (prod_in_round_evaluation + prod_next_round_evaluation) * not(is_hint_src_id), - ); - // Obtain p1, p2 from hint space and write back to witness arrays self.memory_bridge .write( @@ -423,7 +386,7 @@ impl Air for NativeSumcheckAir { ) .eval( builder, - (prod_in_round_evaluation + prod_next_round_evaluation) * is_hint_src_id, + (prod_in_round_evaluation + prod_next_round_evaluation) * is_writeback, ); let p1: [AB::Var; EXT_DEG] = prod_row_specific.p[0..EXT_DEG].try_into().unwrap(); @@ -438,7 +401,7 @@ impl Air for NativeSumcheckAir { register_ptrs[4] + curr_prod_n * AB::F::from_canonical_usize(EXT_DEG), ), prod_row_specific.p_evals, - start_timestamp + AB::F::ONE, + start_timestamp + is_writeback * AB::F::ONE, &prod_row_specific.write_record, ) .eval(builder, prod_row * within_round_limit); @@ -505,21 +468,6 @@ impl Air for NativeSumcheckAir { ); builder.assert_eq(logup_row * should_acc, logup_acc); - // Read p1, p2, q1, q2 from witness arrays - self.memory_bridge - .read( - MemoryAddress::new(native_as, register_ptrs[3] + logup_row_specific.data_ptr), - logup_row_specific.pq, - start_timestamp, - &MemoryReadAuxCols { - base: logup_row_specific.pqs_record.base, - }, - ) - .eval( - builder, - (logup_in_round_evaluation + logup_next_round_evaluation) * not(is_hint_src_id), - ); - // Obtain p1, p2, q1, q2 from hint space self.memory_bridge .write( @@ -530,7 +478,7 @@ impl Air for NativeSumcheckAir { ) .eval( builder, - (logup_in_round_evaluation + logup_next_round_evaluation) * is_hint_src_id, + (logup_in_round_evaluation + logup_next_round_evaluation) * is_writeback, ); let p1: [_; EXT_DEG] = logup_row_specific.pq[0..EXT_DEG].try_into().unwrap(); let p2: [_; EXT_DEG] = logup_row_specific.pq[EXT_DEG..(EXT_DEG * 2)] @@ -552,7 +500,7 @@ impl Air for NativeSumcheckAir { + (num_prod_spec + curr_logup_n) * AB::F::from_canonical_usize(EXT_DEG), ), logup_row_specific.p_evals, - start_timestamp + AB::F::ONE, + start_timestamp + is_writeback * AB::F::ONE, &logup_row_specific.write_records[0], ) .eval(builder, logup_row * within_round_limit); @@ -567,7 +515,7 @@ impl Air for NativeSumcheckAir { * AB::F::from_canonical_usize(EXT_DEG), ), logup_row_specific.q_evals, - start_timestamp + AB::F::TWO, + start_timestamp + is_writeback * AB::F::ONE + AB::F::ONE, &logup_row_specific.write_records[1], ) .eval(builder, logup_row * within_round_limit); diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs index d4fbf2524d..9efaf62a48 100644 --- a/extensions/native/circuit/src/sumcheck/chip.rs +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -207,7 +207,7 @@ where challenges_ptr.as_canonical_u32(), head_specific.read_records[6].as_mut(), ); - let [max_round, is_hint_src_id]: [F; 2] = tracing_read_native_helper( + let [max_round, is_writeback]: [F; 2] = tracing_read_native_helper( state.memory, ctx_ptr.as_canonical_u32() + CONTEXT_ARR_BASE_LEN as u32, head_specific.read_records[7].as_mut(), @@ -242,21 +242,13 @@ where row.register_ptrs[3] = logup_evals_ptr; row.register_ptrs[4] = r_evals_ptr; row.max_round = max_round; - row.is_hint_src_id = is_hint_src_id; + row.is_writeback = is_writeback; } - // Load hints if source is a ptr - let is_hint_src_id = is_hint_src_id > F::ZERO; let prod_evals_id = prod_evals_id.as_canonical_u32(); let logup_evals_id = logup_evals_id.as_canonical_u32(); - let (prod_evals, logup_evals) = if is_hint_src_id { - ( - state.streams.hint_space[prod_evals_id as usize].clone(), - state.streams.hint_space[logup_evals_id as usize].clone(), - ) - } else { - (Vec::new(), Vec::new()) - }; + let prod_evals = state.streams.hint_space[prod_evals_id as usize].clone(); + let logup_evals = state.streams.hint_space[logup_evals_id as usize].clone(); // product rows for (i, prod_row) in rows @@ -292,24 +284,16 @@ where prod_specific.data_ptr = F::from_canonical_u32(start); // read p1, p2 - let ps: [F; EXT_DEG * 2] = if is_hint_src_id { - prod_evals[(start as usize)..((start as usize) + EXT_DEG * 2)] - .try_into() - .unwrap() - } else { - tracing_read_native_helper( - state.memory, - prod_evals_ptr.as_canonical_u32() + start, - prod_specific.ps_record.as_mut(), - ) - }; + let ps: [F; EXT_DEG * 2] = prod_evals[(start as usize)..((start as usize) + EXT_DEG * 2)] + .try_into() + .unwrap(); let p1: [F; EXT_DEG] = ps[0..EXT_DEG].try_into().unwrap(); let p2: [F; EXT_DEG] = ps[EXT_DEG..(EXT_DEG * 2)].try_into().unwrap(); prod_specific.p = ps; // If p values come from the hint stream, write back to the actual witness array - if is_hint_src_id { + if is_writeback > F::ZERO { tracing_write_native_inplace( state.memory, prod_evals_ptr.as_canonical_u32() + start, @@ -346,7 +330,7 @@ where eval, &mut prod_specific.write_record, ); - cur_timestamp += 2; // Either 1 read, 1 write (witness array input), or 2 writes (hint_ptr_id) + cur_timestamp += if is_writeback > F::ZERO { 2 } else { 1 }; // Only write back to the witness array when the is_writeback indicator is true let eval_rlc = FieldExtension::multiply(alpha_acc, eval); prod_specific.eval_rlc = eval_rlc; @@ -394,17 +378,9 @@ where logup_specific.data_ptr = F::from_canonical_u32(start); // read p1, p2, q1, q2 - let pqs: [F; EXT_DEG * 4] = if is_hint_src_id { - logup_evals[(start as usize)..(start as usize) + EXT_DEG * 4] - .try_into() - .unwrap() - } else { - tracing_read_native_helper( - state.memory, - logup_evals_ptr.as_canonical_u32() + start, - logup_specific.pqs_record.as_mut(), - ) - }; + let pqs: [F; EXT_DEG * 4] = logup_evals[(start as usize)..(start as usize) + EXT_DEG * 4] + .try_into() + .unwrap(); let p1: [F; EXT_DEG] = pqs[0..EXT_DEG].try_into().unwrap(); let p2: [F; EXT_DEG] = pqs[EXT_DEG..(EXT_DEG * 2)].try_into().unwrap(); let q1: [F; EXT_DEG] = pqs[(EXT_DEG * 2)..(EXT_DEG * 3)].try_into().unwrap(); @@ -413,7 +389,7 @@ where logup_specific.pq = pqs; // write pqs - if is_hint_src_id { + if is_writeback > F::ZERO { tracing_write_native_inplace( state.memory, logup_evals_ptr.as_canonical_u32() + start, @@ -472,7 +448,7 @@ where q_eval, &mut logup_specific.write_records[1], ); - cur_timestamp += 3; // 1 read, 2 writes (witness array case) or 3 writes (hint space ptr case) + cur_timestamp += if is_writeback > F::ZERO { 3 } else { 2 }; // Only write back to the witness array when the is_writeback indicator is true let eval_rlc = FieldExtension::add( FieldExtension::multiply(alpha_numerator, p_eval), @@ -568,42 +544,66 @@ impl TraceFiller for NativeSumcheckFiller { cols.specific[..ProdSpecificCols::::width()].borrow_mut(); if cols.within_round_limit == F::ONE { - // obtain p1, p2 - mem_fill_helper( - mem_helper, - start_timestamp, - prod_row_specific.ps_record.as_mut(), - ); - // write p_eval - mem_fill_helper( - mem_helper, - start_timestamp + 1, - prod_row_specific.write_record.as_mut(), - ); + if cols.is_writeback == F::ONE { + // writeback p1, p2 + mem_fill_helper( + mem_helper, + start_timestamp, + prod_row_specific.ps_record.as_mut(), + ); + // write p_eval + mem_fill_helper( + mem_helper, + start_timestamp + 1, + prod_row_specific.write_record.as_mut(), + ); + } else { + // write p_eval + mem_fill_helper( + mem_helper, + start_timestamp, + prod_row_specific.write_record.as_mut(), + ); + } } } else if cols.logup_row == F::ONE { let logup_row_specific: &mut LogupSpecificCols = cols.specific[..LogupSpecificCols::::width()].borrow_mut(); if cols.within_round_limit == F::ONE { - // obtain p1, p2, q1, q2 - mem_fill_helper( - mem_helper, - start_timestamp, - logup_row_specific.pqs_record.as_mut(), - ); - // write p_eval - mem_fill_helper( - mem_helper, - start_timestamp + 1, - logup_row_specific.write_records[0].as_mut(), - ); - // write q_eval - mem_fill_helper( - mem_helper, - start_timestamp + 2, - logup_row_specific.write_records[1].as_mut(), - ); + if cols.is_writeback == F::ONE { + // writeback p1, p2, q1, q2 + mem_fill_helper( + mem_helper, + start_timestamp, + logup_row_specific.pqs_record.as_mut(), + ); + // write p_eval + mem_fill_helper( + mem_helper, + start_timestamp + 1, + logup_row_specific.write_records[0].as_mut(), + ); + // write q_eval + mem_fill_helper( + mem_helper, + start_timestamp + 2, + logup_row_specific.write_records[1].as_mut(), + ); + } else { + // write p_eval + mem_fill_helper( + mem_helper, + start_timestamp, + logup_row_specific.write_records[0].as_mut(), + ); + // write q_eval + mem_fill_helper( + mem_helper, + start_timestamp + 1, + logup_row_specific.write_records[1].as_mut(), + ); + } } } } diff --git a/extensions/native/circuit/src/sumcheck/columns.rs b/extensions/native/circuit/src/sumcheck/columns.rs index f02f154cf2..eeb81df134 100644 --- a/extensions/native/circuit/src/sumcheck/columns.rs +++ b/extensions/native/circuit/src/sumcheck/columns.rs @@ -73,8 +73,8 @@ pub struct NativeSumcheckCols { // The current final evaluation accumulator. Extension element. pub eval_acc: [T; EXT_DEG], - // Indicator for an alternative source form of the inputs prod_evals/logup_evals - pub is_hint_src_id: T, + // Indicate whether the values read from hint slices should be written back to a witness array + pub is_writeback: T, // /// 1. For header row, 5 registers, ctx, challenges // /// 2. For the rest: max_variables, p1, p2, q1, q2 diff --git a/extensions/native/circuit/src/sumcheck/execution.rs b/extensions/native/circuit/src/sumcheck/execution.rs index a311e634b8..7cbd1fe319 100644 --- a/extensions/native/circuit/src/sumcheck/execution.rs +++ b/extensions/native/circuit/src/sumcheck/execution.rs @@ -217,7 +217,7 @@ unsafe fn execute_e12_impl( ctx; let challenges: [F; EXT_DEG * 4] = exec_state.vm_read(NATIVE_AS, challenges_ptr.as_canonical_u32()); - let [max_round, is_hint_space_ids]: [u32; 2] = exec_state + let [max_round, is_writeback]: [u32; 2] = exec_state .vm_read(NATIVE_AS, ctx_ptr_u32 + CONTEXT_ARR_BASE_LEN as u32) .map(|x: F| x.as_canonical_u32()); let alpha: [F; EXT_DEG] = challenges[0..EXT_DEG].try_into().unwrap(); @@ -228,14 +228,8 @@ unsafe fn execute_e12_impl( let mut alpha_acc = elem_to_ext(F::ONE); let mut eval_acc = elem_to_ext(F::ZERO); - let (prod_evals, logup_evals) = if is_hint_space_ids > 0 { - ( - exec_state.streams.hint_space[prod_evals_id as usize].clone(), - exec_state.streams.hint_space[logup_evals_id as usize].clone(), - ) - } else { - (Vec::new(), Vec::new()) - }; + let prod_evals = exec_state.streams.hint_space[prod_evals_id as usize].clone(); + let logup_evals = exec_state.streams.hint_space[logup_evals_id as usize].clone(); for i in 0..num_prod_spec { let start = calculate_3d_ext_idx( @@ -247,16 +241,11 @@ unsafe fn execute_e12_impl( ); if round < max_round - 1 { - let ps: [F; EXT_DEG * 2] = if is_hint_space_ids > 0 { - prod_evals[(start as usize)..(start as usize) + EXT_DEG * 2].try_into().unwrap() - } else { - exec_state.vm_read::<_, { EXT_DEG * 2 }>(NATIVE_AS, prod_evals_ptr + start).try_into().unwrap() - }; - + let ps: [F; EXT_DEG * 2] = prod_evals[(start as usize)..(start as usize) + EXT_DEG * 2].try_into().unwrap(); let p1: [F; EXT_DEG] = ps[0..EXT_DEG].try_into().unwrap(); let p2: [F; EXT_DEG] = ps[EXT_DEG..EXT_DEG * 2].try_into().unwrap(); - if is_hint_space_ids > 0 { + if is_writeback > 0 { exec_state.vm_write(NATIVE_AS, prod_evals_ptr + start, &ps); } @@ -297,17 +286,13 @@ unsafe fn execute_e12_impl( if round < max_round - 1 { // read logup_evals - let pqs: [F; EXT_DEG * 4] = if is_hint_space_ids > 0 { - logup_evals[(start as usize)..(start as usize) + EXT_DEG * 4].try_into().unwrap() - } else { - exec_state.vm_read::<_, { EXT_DEG * 4 }>(NATIVE_AS, logup_evals_ptr + start).try_into().unwrap() - }; + let pqs: [F; EXT_DEG * 4] = logup_evals[(start as usize)..(start as usize) + EXT_DEG * 4].try_into().unwrap(); let p1: [F; EXT_DEG] = pqs[0..EXT_DEG].try_into().unwrap(); let p2: [F; EXT_DEG] = pqs[EXT_DEG..EXT_DEG * 2].try_into().unwrap(); let q1: [F; EXT_DEG] = pqs[EXT_DEG * 2..EXT_DEG * 3].try_into().unwrap(); let q2: [F; EXT_DEG] = pqs[EXT_DEG * 3..EXT_DEG * 4].try_into().unwrap(); - if is_hint_space_ids > 0 { + if is_writeback > 0 { exec_state.vm_write(NATIVE_AS, logup_evals_ptr + start, &pqs); } From 6070b7f802b4a80b5eec37016052aa952e2f92d2 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Sun, 8 Mar 2026 04:59:50 -0400 Subject: [PATCH 04/18] hint bridge --- extensions/native/circuit/src/extension/mod.rs | 2 +- extensions/native/circuit/src/sumcheck/air.rs | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/extensions/native/circuit/src/extension/mod.rs b/extensions/native/circuit/src/extension/mod.rs index 01ba49e3f3..81633c32b3 100644 --- a/extensions/native/circuit/src/extension/mod.rs +++ b/extensions/native/circuit/src/extension/mod.rs @@ -276,7 +276,7 @@ where ); inventory.add_air(verify_batch); - let tower_evaluate = NativeSumcheckAir::new(exec_bridge, memory_bridge); + let tower_evaluate = NativeSumcheckAir::new(exec_bridge, memory_bridge, hint_bridge); inventory.add_air(tower_evaluate); Ok(()) diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index aad3838b3a..1dcace3e13 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -3,7 +3,7 @@ use std::borrow::Borrow; use openvm_circuit::{ arch::{ExecutionBridge, ExecutionState}, system::memory::{ - offline_checker::MemoryBridge, + offline_checker::{HintBridge, MemoryBridge}, MemoryAddress, }, }; @@ -30,16 +30,19 @@ use crate::{ pub struct NativeSumcheckAir { pub execution_bridge: ExecutionBridge, pub memory_bridge: MemoryBridge, + pub hint_bridge: HintBridge, } impl NativeSumcheckAir { pub fn new( execution_bridge: ExecutionBridge, memory_bridge: MemoryBridge, + hint_bridge: HintBridge, ) -> Self { Self { execution_bridge, memory_bridge, + hint_bridge, } } } From 7ae75f116dfd4440069020bca4db04175fe92eba Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Sun, 8 Mar 2026 07:12:30 -0400 Subject: [PATCH 05/18] hint bridge --- .../system/memory/offline_checker/bridge.rs | 137 ++++-------------- .../src/system/memory/offline_checker/bus.rs | 121 ++++++---------- .../native/circuit/src/extension/mod.rs | 19 ++- .../native/circuit/src/hint_space_provider.rs | 132 +++++++++++++++++ extensions/native/circuit/src/lib.rs | 1 + extensions/native/circuit/src/sumcheck/air.rs | 32 ++++ .../native/circuit/src/sumcheck/chip.rs | 29 +++- .../native/circuit/src/sumcheck/columns.rs | 4 + 8 files changed, 281 insertions(+), 194 deletions(-) create mode 100644 extensions/native/circuit/src/hint_space_provider.rs diff --git a/crates/vm/src/system/memory/offline_checker/bridge.rs b/crates/vm/src/system/memory/offline_checker/bridge.rs index e377497375..867348fb57 100644 --- a/crates/vm/src/system/memory/offline_checker/bridge.rs +++ b/crates/vm/src/system/memory/offline_checker/bridge.rs @@ -13,8 +13,7 @@ use openvm_stark_backend::{ use super::bus::{HintBus, MemoryBus}; use crate::system::memory::{ offline_checker::columns::{ - HintReadAuxCols, HintWriteAuxCols, MemoryBaseAuxCols, MemoryReadAuxCols, - MemoryReadOrImmediateAuxCols, MemoryWriteAuxCols, + MemoryBaseAuxCols, MemoryReadAuxCols, MemoryReadOrImmediateAuxCols, MemoryWriteAuxCols, }, MemoryAddress, }; @@ -328,8 +327,9 @@ impl MemoryOfflineChecker { } } -/// The [HintBridge] is used to constrain logical hint memory operations (read/write). -/// It adds all necessary constraints and interactions for a separate hint space. +/// The [HintBridge] is used to constrain hint space lookups. +/// Consumer chips call `lookup` to verify that values they read from hint_space +/// match what was originally loaded via the hint bus lookup table. #[derive(Clone, Copy, Debug)] pub struct HintBridge { hint_bus: HintBus, @@ -345,116 +345,33 @@ impl HintBridge { self.hint_bus } - /// Prepare a logical hint read operation. - /// A read or write interaction has the form: (hint_id, offset, value, timestamp) - #[must_use] - pub fn read<'a, T, V>( + /// Perform a lookup on the hint bus for a single element. + /// + /// Constrains that `(hint_id, offset, value)` exists in the hint lookup table. + /// Caller must constrain that `enabled` is boolean. + pub fn lookup( &self, - hint_id: impl Into, - offset: impl Into, - value: impl Into, - timestamp: impl Into, - aux: &'a HintReadAuxCols, - ) -> HintReadOperation<'a, T, V> { - HintReadOperation { - hint_bus: self.hint_bus, - hint_id: hint_id.into(), - offset: offset.into(), - value: value.into(), - timestamp: timestamp.into(), - aux, - } + builder: &mut AB, + hint_id: impl Into, + offset: impl Into, + value: impl Into, + enabled: impl Into, + ) { + self.hint_bus.lookup(builder, hint_id, offset, value, enabled); } - /// Prepare a logical hint write operation. - /// A read or write interaction has the form: (hint_id, offset, value, timestamp) - #[must_use] - pub fn write<'a, T, V>( + /// Add a key to the hint lookup table. + /// + /// Provider chips call this to register that `(hint_id, offset, value)` is available. + pub fn provide( &self, - hint_id: impl Into, - offset: impl Into, - value: impl Into, - timestamp: impl Into, - aux: &'a HintWriteAuxCols, - ) -> HintWriteOperation<'a, T, V> { - HintWriteOperation { - hint_bus: self.hint_bus, - hint_id: hint_id.into(), - offset: offset.into(), - value: value.into(), - timestamp: timestamp.into(), - aux, - } - } -} - -/// Constraints and interactions for a logical hint read of `(hint_id, offset, value)` at time `timestamp`. -/// This reads `(hint_id, offset, value, timestamp_prev)` from the hint bus and writes -/// `(hint_id, offset, value, timestamp)` to the hint bus. -/// -/// The generic `T` type is intended to be `AB::Expr` where `AB` is the [AirBuilder]. -/// The auxiliary columns are not expected to be expressions, so the generic `V` type is intended -/// to be `AB::Var`. -pub struct HintReadOperation<'a, T, V> { - hint_bus: HintBus, - hint_id: T, - offset: T, - value: T, - timestamp: T, - aux: &'a HintReadAuxCols, -} - -impl> HintReadOperation<'_, F, V> { - /// Evaluate constraints and send/receive interactions. - pub fn eval(self, builder: &mut AB, enabled: impl Into) - where - AB: InteractionBuilder, - { - let enabled = enabled.into(); - - self.hint_bus - .receive(self.hint_id.clone(), self.offset.clone(), self.value.clone(), self.aux.prev_timestamp) - .eval(builder, enabled.clone()); - - self.hint_bus - .send(self.hint_id, self.offset, self.value, self.timestamp) - .eval(builder, enabled); - } -} - -/// Constraints and interactions for a logical hint write of `(hint_id, offset, value)` at time -/// `timestamp`. This reads `(hint_id, offset, prev_value, timestamp_prev)` from the hint bus -/// and writes `(hint_id, offset, value, timestamp)` to the hint bus. -/// -/// **Note:** This can be used as a logical read operation by setting `prev_value = value`. -pub struct HintWriteOperation<'a, T, V> { - hint_bus: HintBus, - hint_id: T, - offset: T, - value: T, - timestamp: T, - aux: &'a HintWriteAuxCols, -} - -impl> HintWriteOperation<'_, T, V> { - /// Evaluate constraints and send/receive interactions. `enabled` must be boolean. - pub fn eval(self, builder: &mut AB, enabled: impl Into) - where - AB: InteractionBuilder, - { - let enabled = enabled.into(); - - self.hint_bus - .receive( - self.hint_id.clone(), - self.offset.clone(), - self.aux.prev_value.clone(), - self.aux.prev_timestamp, - ) - .eval(builder, enabled.clone()); - + builder: &mut AB, + hint_id: impl Into, + offset: impl Into, + value: impl Into, + num_lookups: impl Into, + ) { self.hint_bus - .send(self.hint_id, self.offset, self.value, self.timestamp) - .eval(builder, enabled); + .provide(builder, hint_id, offset, value, num_lookups); } } diff --git a/crates/vm/src/system/memory/offline_checker/bus.rs b/crates/vm/src/system/memory/offline_checker/bus.rs index b082510bb2..4d5b9ef5cd 100644 --- a/crates/vm/src/system/memory/offline_checker/bus.rs +++ b/crates/vm/src/system/memory/offline_checker/bus.rs @@ -1,7 +1,7 @@ use std::iter; use openvm_stark_backend::{ - interaction::{BusIndex, InteractionBuilder, PermutationCheckBus}, + interaction::{BusIndex, InteractionBuilder, LookupBus, PermutationCheckBus}, p3_field::FieldAlgebra, }; @@ -102,103 +102,64 @@ impl MemoryBusInteraction { } } -/// Represents a hint bus identified by a unique bus index (`usize`). -/// Used to check correct read/write operations in hint space. +/// Represents a hint bus identified by a unique bus index. +/// Used as a lookup table to constrain values read from hint space. +/// +/// Consumer chips (e.g. NativeSumcheck) perform lookups to verify that +/// hint_space values match what was originally loaded. +/// Provider chips (e.g. a hint loader) add keys to the lookup table. #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub struct HintBus { - pub inner: PermutationCheckBus, + pub inner: LookupBus, } impl HintBus { pub const fn new(index: BusIndex) -> Self { Self { - inner: PermutationCheckBus::new(index), + inner: LookupBus::new(index), } } -} -impl HintBus { #[inline(always)] pub fn index(&self) -> BusIndex { self.inner.index } - /// Prepares a write operation through the hint bus. - pub fn send( - &self, - hint_id: impl Into, - offset: impl Into, - value: impl Into, - timestamp: impl Into, - ) -> HintBusInteraction { - self.push(true, hint_id, offset, value, timestamp) - } - - /// Prepares a read operation through the hint bus. - pub fn receive( - &self, - hint_id: impl Into, - offset: impl Into, - value: impl Into, - timestamp: impl Into, - ) -> HintBusInteraction { - self.push(false, hint_id, offset, value, timestamp) - } - - /// Prepares a hint operation (read or write) through the hint bus. - fn push( + /// Performs a lookup on the hint bus. + /// + /// Asserts that `(hint_id, offset, value)` is present in the hint lookup table. + /// Caller must constrain that `enabled` is boolean. + pub fn lookup( &self, - is_send: bool, - hint_id: impl Into, - offset: impl Into, - value: impl Into, - timestamp: impl Into, - ) -> HintBusInteraction { - HintBusInteraction { - bus: self.inner, - is_send, - hint_id: hint_id.into(), - offset: offset.into(), - value: value.into(), - timestamp: timestamp.into(), - } + builder: &mut AB, + hint_id: impl Into, + offset: impl Into, + value: impl Into, + enabled: impl Into, + ) { + self.inner.lookup_key( + builder, + [hint_id.into(), offset.into(), value.into()], + enabled, + ); } -} -#[derive(Clone, Debug)] -pub struct HintBusInteraction { - pub bus: PermutationCheckBus, - pub is_send: bool, - pub hint_id: T, - pub offset: T, - pub value: T, - pub timestamp: T, -} - -impl HintBusInteraction { - /// Finalizes and sends/receives the hint operation with the specified direction over the bus. - /// - /// A read corresponds to a receive, and a write corresponds to a send. + /// Adds a key to the hint lookup table. /// - /// The parameter `direction` can be -1, 0, or 1. A value of 1 means perform the same action - /// (send/receive), a value of -1 reverses the action, and a value of 0 means disabled. - /// - /// Caller must constrain `direction` to be in {-1, 0, 1}. - pub fn eval(self, builder: &mut AB, direction: impl Into) - where - AB: InteractionBuilder, - { - let fields = iter::empty() - .chain(iter::once(self.hint_id)) - .chain(iter::once(self.offset)) - .chain(iter::once(self.value)) - .chain(iter::once(self.timestamp)); - - if self.is_send { - self.bus.interact(builder, fields, direction); - } else { - self.bus - .interact(builder, fields, AB::Expr::NEG_ONE * direction.into()); - } + /// The `num_lookups` parameter should equal the number of enabled lookups performed + /// for this key. + pub fn provide( + &self, + builder: &mut AB, + hint_id: impl Into, + offset: impl Into, + value: impl Into, + num_lookups: impl Into, + ) { + self.inner.add_key_with_lookups( + builder, + [hint_id.into(), offset.into(), value.into()], + num_lookups, + ); } } diff --git a/extensions/native/circuit/src/extension/mod.rs b/extensions/native/circuit/src/extension/mod.rs index 81633c32b3..86867ca1f5 100644 --- a/extensions/native/circuit/src/extension/mod.rs +++ b/extensions/native/circuit/src/extension/mod.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use alu_native_adapter::{AluNativeAdapterAir, AluNativeAdapterExecutor}; use branch_native_adapter::{BranchNativeAdapterAir, BranchNativeAdapterExecutor}; use convert_adapter::{ConvertAdapterAir, ConvertAdapterExecutor}; @@ -49,6 +51,7 @@ use crate::{ FriReducedOpeningAir, FriReducedOpeningChip, FriReducedOpeningExecutor, FriReducedOpeningFiller, }, + hint_space_provider::{HintSpaceProviderAir, HintSpaceProviderChip}, jal_rangecheck::{ JalRangeCheckAir, JalRangeCheckExecutor, JalRangeCheckFiller, NativeJalRangeCheckChip, }, @@ -276,6 +279,11 @@ where ); inventory.add_air(verify_batch); + let hint_space_provider = HintSpaceProviderAir { + hint_bus: hint_bridge.hint_bus(), + }; + inventory.add_air(hint_space_provider); + let tower_evaluate = NativeSumcheckAir::new(exec_bridge, memory_bridge, hint_bridge); inventory.add_air(tower_evaluate); @@ -359,7 +367,16 @@ where ); inventory.add_executor_chip(poseidon2); - let tower_verify = NativeSumcheckChip::new(NativeSumcheckFiller::new(), mem_helper.clone()); + let hint_bus = inventory.airs().system().hint_bridge.hint_bus(); + let hint_space_provider = Arc::new(HintSpaceProviderChip::new(hint_bus)); + + inventory.next_air::()?; + inventory.add_periphery_chip(hint_space_provider.clone()); + + let tower_verify = NativeSumcheckChip::new( + NativeSumcheckFiller::new(hint_space_provider), + mem_helper.clone(), + ); inventory.add_executor_chip(tower_verify); Ok(()) diff --git a/extensions/native/circuit/src/hint_space_provider.rs b/extensions/native/circuit/src/hint_space_provider.rs new file mode 100644 index 0000000000..b85479720d --- /dev/null +++ b/extensions/native/circuit/src/hint_space_provider.rs @@ -0,0 +1,132 @@ +use std::{ + borrow::{Borrow, BorrowMut}, + mem::size_of, + sync::{Arc, Mutex}, +}; + +use openvm_circuit::system::memory::offline_checker::HintBus; +use openvm_circuit_primitives_derive::AlignedBorrow; +use openvm_stark_backend::{ + config::{StarkGenericConfig, Val}, + interaction::InteractionBuilder, + p3_air::{Air, BaseAir}, + p3_field::{Field, PrimeField32}, + p3_matrix::{dense::RowMajorMatrix, Matrix}, + prover::{cpu::CpuBackend, types::AirProvingContext}, + rap::{get_air_name, BaseAirWithPublicValues, PartitionedBaseAir}, + Chip, ChipUsageGetter, +}; + +#[derive(Default, AlignedBorrow, Copy, Clone)] +#[repr(C)] +pub struct HintSpaceProviderCols { + pub hint_id: T, + pub offset: T, + pub value: T, + pub is_valid: T, +} + +pub const NUM_HINT_SPACE_PROVIDER_COLS: usize = size_of::>(); + +#[derive(Clone, Debug)] +pub struct HintSpaceProviderAir { + pub hint_bus: HintBus, +} + +impl BaseAirWithPublicValues for HintSpaceProviderAir {} +impl PartitionedBaseAir for HintSpaceProviderAir {} + +impl BaseAir for HintSpaceProviderAir { + fn width(&self) -> usize { + NUM_HINT_SPACE_PROVIDER_COLS + } +} + +impl Air for HintSpaceProviderAir { + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + let local = main.row_slice(0); + let local: &HintSpaceProviderCols = (*local).borrow(); + + builder.assert_bool(local.is_valid); + + self.hint_bus.provide( + builder, + local.hint_id, + local.offset, + local.value, + local.is_valid, + ); + } +} + +pub struct HintSpaceProviderChip { + pub air: HintSpaceProviderAir, + data: Mutex>, +} + +pub type SharedHintSpaceProviderChip = Arc>; + +impl HintSpaceProviderChip { + pub fn new(hint_bus: HintBus) -> Self { + Self { + air: HintSpaceProviderAir { hint_bus }, + data: Mutex::new(Vec::new()), + } + } + + /// Register a (hint_id, offset, value) triple for the provider trace. + /// Called by consumer chips during trace filling to match each lookup. + pub fn request(&self, hint_id: F, offset: F, value: F) { + self.data.lock().unwrap().push((hint_id, offset, value)); + } +} + +impl HintSpaceProviderChip { + pub fn generate_trace(&self) -> RowMajorMatrix { + let data = self.data.lock().unwrap(); + let num_real_rows = data.len(); + let trace_height = num_real_rows.next_power_of_two().max(2); + + let mut rows = F::zero_vec(trace_height * NUM_HINT_SPACE_PROVIDER_COLS); + for (n, row) in rows + .chunks_exact_mut(NUM_HINT_SPACE_PROVIDER_COLS) + .enumerate() + { + if n < num_real_rows { + let cols: &mut HintSpaceProviderCols = row.borrow_mut(); + cols.hint_id = data[n].0; + cols.offset = data[n].1; + cols.value = data[n].2; + cols.is_valid = F::ONE; + } + // padding rows are already zero (is_valid = 0) + } + RowMajorMatrix::new(rows, NUM_HINT_SPACE_PROVIDER_COLS) + } +} + +impl Chip> for HintSpaceProviderChip> +where + Val: PrimeField32, +{ + fn generate_proving_ctx(&self, _: R) -> AirProvingContext> { + let trace = self.generate_trace(); + AirProvingContext::simple_no_pis(Arc::new(trace)) + } +} + +impl ChipUsageGetter for HintSpaceProviderChip { + fn air_name(&self) -> String { + get_air_name(&self.air) + } + fn constant_trace_height(&self) -> Option { + None + } + fn current_trace_height(&self) -> usize { + self.data.lock().unwrap().len().next_power_of_two().max(2) + } + fn trace_width(&self) -> usize { + NUM_HINT_SPACE_PROVIDER_COLS + } +} diff --git a/extensions/native/circuit/src/lib.rs b/extensions/native/circuit/src/lib.rs index ce257c9c22..dc1c4dd050 100644 --- a/extensions/native/circuit/src/lib.rs +++ b/extensions/native/circuit/src/lib.rs @@ -39,6 +39,7 @@ mod castf; mod field_arithmetic; mod field_extension; mod fri; +pub(crate) mod hint_space_provider; mod jal_rangecheck; mod loadstore; mod poseidon2; diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index 1dcace3e13..472e3e40bd 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -109,6 +109,8 @@ impl Air for NativeSumcheckAir { should_acc, eval_acc, is_writeback, + prod_hint_id, + logup_hint_id, specific, } = local; @@ -186,6 +188,12 @@ impl Air for NativeSumcheckAir { builder .when(next.prod_row + next.logup_row) .assert_eq(logup_nested_len, next.logup_nested_len); + builder + .when(next.prod_row + next.logup_row) + .assert_eq(prod_hint_id, next.prod_hint_id); + builder + .when(next.prod_row + next.logup_row) + .assert_eq(logup_hint_id, next.logup_hint_id); //////////////////////////////////////////////////////////////// // Row transitions from current to next row @@ -392,6 +400,18 @@ impl Air for NativeSumcheckAir { (prod_in_round_evaluation + prod_next_round_evaluation) * is_writeback, ); + // Lookup each element of p in the hint bus to constrain hint_space reads + let prod_enabled: AB::Expr = prod_in_round_evaluation + prod_next_round_evaluation; + for (j, &val) in prod_row_specific.p.iter().enumerate() { + self.hint_bridge.lookup( + builder, + prod_hint_id, + prod_row_specific.data_ptr + AB::F::from_canonical_usize(j), + val, + prod_enabled.clone(), + ); + } + let p1: [AB::Var; EXT_DEG] = prod_row_specific.p[0..EXT_DEG].try_into().unwrap(); let p2: [AB::Var; EXT_DEG] = prod_row_specific.p[EXT_DEG..(EXT_DEG * 2)] .try_into() @@ -483,6 +503,18 @@ impl Air for NativeSumcheckAir { builder, (logup_in_round_evaluation + logup_next_round_evaluation) * is_writeback, ); + + // Lookup each element of pq in the hint bus to constrain hint_space reads + let logup_enabled: AB::Expr = logup_in_round_evaluation + logup_next_round_evaluation; + for (j, &val) in logup_row_specific.pq.iter().enumerate() { + self.hint_bridge.lookup( + builder, + logup_hint_id, + logup_row_specific.data_ptr + AB::F::from_canonical_usize(j), + val, + logup_enabled.clone(), + ); + } let p1: [_; EXT_DEG] = logup_row_specific.pq[0..EXT_DEG].try_into().unwrap(); let p2: [_; EXT_DEG] = logup_row_specific.pq[EXT_DEG..(EXT_DEG * 2)] .try_into() diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs index 9efaf62a48..89ff6b9902 100644 --- a/extensions/native/circuit/src/sumcheck/chip.rs +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -17,6 +17,7 @@ use openvm_stark_backend::p3_field::PrimeField32; use crate::{ field_extension::{FieldExtension, EXT_DEG}, fri::elem_to_ext, + hint_space_provider::SharedHintSpaceProviderChip, mem_fill_helper, sumcheck::columns::{ HeaderSpecificCols, LogupSpecificCols, NativeSumcheckCols, ProdSpecificCols, @@ -96,9 +97,11 @@ impl SizedRecord for NativeSumcheck pub struct NativeSumcheckExecutor; #[derive(derive_new::new)] -pub struct NativeSumcheckFiller; +pub struct NativeSumcheckFiller { + pub hint_space_provider: SharedHintSpaceProviderChip, +} -pub type NativeSumcheckChip = VmChipWrapper; +pub type NativeSumcheckChip = VmChipWrapper>; impl Default for NativeSumcheckExecutor { fn default() -> Self { @@ -243,6 +246,8 @@ where row.register_ptrs[4] = r_evals_ptr; row.max_round = max_round; row.is_writeback = is_writeback; + row.prod_hint_id = prod_evals_id; + row.logup_hint_id = logup_evals_id; } let prod_evals_id = prod_evals_id.as_canonical_u32(); @@ -517,7 +522,7 @@ where } } -impl TraceFiller for NativeSumcheckFiller { +impl TraceFiller for NativeSumcheckFiller { fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { let cols: &mut NativeSumcheckCols = row_slice.borrow_mut(); let start_timestamp = cols.start_timestamp.as_canonical_u32(); @@ -544,6 +549,15 @@ impl TraceFiller for NativeSumcheckFiller { cols.specific[..ProdSpecificCols::::width()].borrow_mut(); if cols.within_round_limit == F::ONE { + // Register each p element with the hint space provider for the lookup bus + for (j, &val) in prod_row_specific.p.iter().enumerate() { + self.hint_space_provider.request( + cols.prod_hint_id, + prod_row_specific.data_ptr + F::from_canonical_usize(j), + val, + ); + } + if cols.is_writeback == F::ONE { // writeback p1, p2 mem_fill_helper( @@ -571,6 +585,15 @@ impl TraceFiller for NativeSumcheckFiller { cols.specific[..LogupSpecificCols::::width()].borrow_mut(); if cols.within_round_limit == F::ONE { + // Register each pq element with the hint space provider for the lookup bus + for (j, &val) in logup_row_specific.pq.iter().enumerate() { + self.hint_space_provider.request( + cols.logup_hint_id, + logup_row_specific.data_ptr + F::from_canonical_usize(j), + val, + ); + } + if cols.is_writeback == F::ONE { // writeback p1, p2, q1, q2 mem_fill_helper( diff --git a/extensions/native/circuit/src/sumcheck/columns.rs b/extensions/native/circuit/src/sumcheck/columns.rs index eeb81df134..ace9ad8a3d 100644 --- a/extensions/native/circuit/src/sumcheck/columns.rs +++ b/extensions/native/circuit/src/sumcheck/columns.rs @@ -76,6 +76,10 @@ pub struct NativeSumcheckCols { // Indicate whether the values read from hint slices should be written back to a witness array pub is_writeback: T, + // Hint space IDs for lookup bus interactions + pub prod_hint_id: T, + pub logup_hint_id: T, + // /// 1. For header row, 5 registers, ctx, challenges // /// 2. For the rest: max_variables, p1, p2, q1, q2 // pub read_records: [MemoryReadAuxCols; 7], From 435b8ad9bba3f5ec231cc092e342d04b732ff496 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Mon, 9 Mar 2026 17:45:31 -0400 Subject: [PATCH 06/18] fix stale data --- crates/continuations/src/verifier/common/non_leaf.rs | 1 - crates/vm/src/arch/vm.rs | 4 ---- extensions/native/circuit/src/hint_space_provider.rs | 2 +- 3 files changed, 1 insertion(+), 6 deletions(-) diff --git a/crates/continuations/src/verifier/common/non_leaf.rs b/crates/continuations/src/verifier/common/non_leaf.rs index b167e2933c..a41383e9fd 100644 --- a/crates/continuations/src/verifier/common/non_leaf.rs +++ b/crates/continuations/src/verifier/common/non_leaf.rs @@ -48,7 +48,6 @@ impl NonLeafVerifierVariables { let proof = builder.get(proofs, i); assert_required_air_for_agg_vm_present(builder, &proof); let proof_vm_pvs = self.verify_internal_or_leaf_verifier_proof(builder, &proof); - assert_single_segment_vm_exit_successfully(builder, &proof); builder.if_eq(i, RVar::zero()).then_or_else( |builder| { diff --git a/crates/vm/src/arch/vm.rs b/crates/vm/src/arch/vm.rs index cb58a0b77a..516efc8078 100644 --- a/crates/vm/src/arch/vm.rs +++ b/crates/vm/src/arch/vm.rs @@ -474,7 +474,6 @@ where .map(|(&h, w)| (h as usize, w)) .collect::>(); let ctx = PreflightCtx::new_with_capacity(&capacities, instret_end); - let system_config: &SystemConfig = self.config().as_ref(); let adapter_offset = system_config.access_adapter_air_id_offset(); // ATTENTION: this must agree with `num_memory_airs` @@ -511,7 +510,6 @@ where .finalize::>(system_config.continuation_enabled); #[cfg(feature = "perf-metrics")] crate::metrics::end_segment_metrics(&mut exec_state); - let instret = exec_state.vm_state.instret(); let pc = exec_state.vm_state.pc(); let memory = exec_state.vm_state.memory; @@ -685,7 +683,6 @@ where PreflightExecutor, VB::RecordArena>, { self.transport_init_memory_to_device(&state.memory); - let PreflightExecutionOutput { system_records, record_arenas, @@ -696,7 +693,6 @@ where (system_records.exit_code == Some(ExitCode::Success as u32)).then_some(to_state.memory); let ctx = self.generate_proving_ctx(system_records, record_arenas)?; let proof = self.engine.prove(&self.pk, ctx); - Ok((proof, final_memory)) } diff --git a/extensions/native/circuit/src/hint_space_provider.rs b/extensions/native/circuit/src/hint_space_provider.rs index b85479720d..4feac677b6 100644 --- a/extensions/native/circuit/src/hint_space_provider.rs +++ b/extensions/native/circuit/src/hint_space_provider.rs @@ -84,7 +84,7 @@ impl HintSpaceProviderChip { impl HintSpaceProviderChip { pub fn generate_trace(&self) -> RowMajorMatrix { - let data = self.data.lock().unwrap(); + let data = std::mem::take(&mut *self.data.lock().unwrap()); let num_real_rows = data.len(); let trace_height = num_real_rows.next_power_of_two().max(2); From dcd6bfc71fb7550059012ca3454f474e64161ef2 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Mon, 9 Mar 2026 18:18:26 -0400 Subject: [PATCH 07/18] hint provider gpu --- .../circuit/cuda/src/hint_space_provider.cu | 56 +++++++++++++++++++ extensions/native/circuit/src/cuda_abi.rs | 30 ++++++++++ .../native/circuit/src/extension/cuda.rs | 10 ++++ .../native/circuit/src/hint_space_provider.rs | 53 ++++++++++++++++++ 4 files changed, 149 insertions(+) create mode 100644 extensions/native/circuit/cuda/src/hint_space_provider.cu diff --git a/extensions/native/circuit/cuda/src/hint_space_provider.cu b/extensions/native/circuit/cuda/src/hint_space_provider.cu new file mode 100644 index 0000000000..a87c1155ac --- /dev/null +++ b/extensions/native/circuit/cuda/src/hint_space_provider.cu @@ -0,0 +1,56 @@ +#include "launcher.cuh" +#include "primitives/trace_access.h" + +// Columns layout matches HintSpaceProviderCols in hint_space_provider.rs +// Fields: hint_id, offset, value, is_valid +template struct HintSpaceProviderCols { + T hint_id; + T offset; + T value; + T is_valid; +}; + +constexpr uint32_t HINT_SPACE_PROVIDER_WIDTH = sizeof(HintSpaceProviderCols); + +__global__ void hint_space_provider_tracegen( + Fp *trace, + size_t height, + const Fp *records, + size_t rows_used +) { + uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= height) { + return; + } + + RowSlice row(trace + idx, height); + if (idx < rows_used) { + // Each record is a triple (hint_id, offset, value) + const Fp *rec = records + idx * 3; + COL_WRITE_VALUE(row, HintSpaceProviderCols, hint_id, rec[0]); + COL_WRITE_VALUE(row, HintSpaceProviderCols, offset, rec[1]); + COL_WRITE_VALUE(row, HintSpaceProviderCols, value, rec[2]); + COL_WRITE_VALUE(row, HintSpaceProviderCols, is_valid, Fp::one()); + } else { + row.fill_zero(0, HINT_SPACE_PROVIDER_WIDTH); + } +} + +extern "C" int _hint_space_provider_tracegen( + Fp *d_trace, + size_t height, + size_t width, + const Fp *d_records, + size_t rows_used +) { + assert((height & (height - 1)) == 0); + assert(width == HINT_SPACE_PROVIDER_WIDTH); + auto [grid, block] = kernel_launch_params(height); + hint_space_provider_tracegen<<>>( + d_trace, + height, + d_records, + rows_used + ); + return CHECK_KERNEL(); +} diff --git a/extensions/native/circuit/src/cuda_abi.rs b/extensions/native/circuit/src/cuda_abi.rs index 5de9124f0d..4530290521 100644 --- a/extensions/native/circuit/src/cuda_abi.rs +++ b/extensions/native/circuit/src/cuda_abi.rs @@ -345,3 +345,33 @@ pub mod native_jal_rangecheck_cuda { )) } } + +pub mod hint_space_provider_cuda { + use super::*; + + extern "C" { + pub fn _hint_space_provider_tracegen( + d_trace: *mut F, + height: usize, + width: usize, + d_records: *const F, + rows_used: usize, + ) -> i32; + } + + pub unsafe fn tracegen( + d_trace: &DeviceBuffer, + height: usize, + width: usize, + d_records: &DeviceBuffer, + rows_used: usize, + ) -> Result<(), CudaError> { + CudaError::from_result(_hint_space_provider_tracegen( + d_trace.as_mut_ptr(), + height, + width, + d_records.as_ptr(), + rows_used, + )) + } +} diff --git a/extensions/native/circuit/src/extension/cuda.rs b/extensions/native/circuit/src/extension/cuda.rs index 765ce8d6cc..885abcbe1a 100644 --- a/extensions/native/circuit/src/extension/cuda.rs +++ b/extensions/native/circuit/src/extension/cuda.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use openvm_circuit::{ arch::{ChipInventory, ChipInventoryError, DenseRecordArena, VmProverExtension}, system::cuda::extensions::get_inventory_range_checker, @@ -14,6 +16,7 @@ use crate::{ field_arithmetic::{FieldArithmeticAir, FieldArithmeticChipGpu}, field_extension::{FieldExtensionAir, FieldExtensionChipGpu}, fri::{FriReducedOpeningAir, FriReducedOpeningChipGpu}, + hint_space_provider::{cuda::HintSpaceProviderChipGpu, HintSpaceProviderAir, HintSpaceProviderChip}, jal_rangecheck::{JalRangeCheckAir, JalRangeCheckGpu}, loadstore::{NativeLoadStoreAir, NativeLoadStoreChipGpu}, poseidon2::{air::NativePoseidon2Air, NativePoseidon2ChipGpu}, @@ -76,6 +79,13 @@ impl VmProverExtension let poseidon2 = NativePoseidon2ChipGpu::<1>::new(range_checker.clone(), timestamp_max_bits); inventory.add_executor_chip(poseidon2); + // HintSpaceProvider must be registered BEFORE NativeSumcheck because chips are + // dispatched in reverse order: sumcheck runs first and populates the provider. + let hint_air: &HintSpaceProviderAir = inventory.next_air::()?; + let cpu_chip = Arc::new(HintSpaceProviderChip::new(hint_air.hint_bus.clone())); + let provider_gpu = HintSpaceProviderChipGpu::new(cpu_chip); + inventory.add_periphery_chip(provider_gpu); + inventory.next_air::()?; let sumcheck = NativeSumcheckChipGpu::new(range_checker.clone(), timestamp_max_bits); inventory.add_executor_chip(sumcheck); diff --git a/extensions/native/circuit/src/hint_space_provider.rs b/extensions/native/circuit/src/hint_space_provider.rs index 4feac677b6..65c26bb299 100644 --- a/extensions/native/circuit/src/hint_space_provider.rs +++ b/extensions/native/circuit/src/hint_space_provider.rs @@ -130,3 +130,56 @@ impl ChipUsageGetter for HintSpaceProviderChip { NUM_HINT_SPACE_PROVIDER_COLS } } + +#[cfg(feature = "cuda")] +pub mod cuda { + use std::sync::Arc; + + use openvm_circuit::arch::DenseRecordArena; + use openvm_cuda_backend::{base::DeviceMatrix, prover_backend::GpuBackend, types::F}; + use openvm_cuda_common::copy::MemCopyH2D; + use openvm_stark_backend::{prover::types::AirProvingContext, Chip}; + + use super::{HintSpaceProviderChip, NUM_HINT_SPACE_PROVIDER_COLS}; + use crate::cuda_abi::hint_space_provider_cuda; + + pub struct HintSpaceProviderChipGpu { + pub cpu_chip: Arc>, + } + + impl HintSpaceProviderChipGpu { + pub fn new(cpu_chip: Arc>) -> Self { + Self { cpu_chip } + } + } + + impl Chip for HintSpaceProviderChipGpu { + fn generate_proving_ctx(&self, _: DenseRecordArena) -> AirProvingContext { + let data = std::mem::take(&mut *self.cpu_chip.data.lock().unwrap()); + let rows_used = data.len(); + let height = rows_used.next_power_of_two().max(2); + + // Flatten (hint_id, offset, value) triples into a contiguous [F] buffer + let flat: Vec = data + .into_iter() + .flat_map(|(h, o, v)| [h, o, v]) + .collect(); + + let d_records = flat.to_device().unwrap(); + let trace = DeviceMatrix::::with_capacity(height, NUM_HINT_SPACE_PROVIDER_COLS); + + unsafe { + hint_space_provider_cuda::tracegen( + trace.buffer(), + height, + NUM_HINT_SPACE_PROVIDER_COLS, + &d_records, + rows_used, + ) + .unwrap(); + } + + AirProvingContext::simple_no_pis(trace) + } + } +} From f2cd33241ce4307ad0fdc292bbe850ca2d7f494a Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Mon, 9 Mar 2026 18:36:45 -0400 Subject: [PATCH 08/18] resolve native sumcheck gpu trace fill problem --- .../native/circuit/src/extension/cuda.rs | 7 +- .../native/circuit/src/sumcheck/cuda.rs | 66 +++++++++++++++++-- 2 files changed, 66 insertions(+), 7 deletions(-) diff --git a/extensions/native/circuit/src/extension/cuda.rs b/extensions/native/circuit/src/extension/cuda.rs index 885abcbe1a..9ae1558500 100644 --- a/extensions/native/circuit/src/extension/cuda.rs +++ b/extensions/native/circuit/src/extension/cuda.rs @@ -82,12 +82,13 @@ impl VmProverExtension // HintSpaceProvider must be registered BEFORE NativeSumcheck because chips are // dispatched in reverse order: sumcheck runs first and populates the provider. let hint_air: &HintSpaceProviderAir = inventory.next_air::()?; - let cpu_chip = Arc::new(HintSpaceProviderChip::new(hint_air.hint_bus.clone())); - let provider_gpu = HintSpaceProviderChipGpu::new(cpu_chip); + let cpu_chip = Arc::new(HintSpaceProviderChip::new(hint_air.hint_bus)); + let provider_gpu = HintSpaceProviderChipGpu::new(cpu_chip.clone()); inventory.add_periphery_chip(provider_gpu); inventory.next_air::()?; - let sumcheck = NativeSumcheckChipGpu::new(range_checker.clone(), timestamp_max_bits); + let sumcheck = + NativeSumcheckChipGpu::new(range_checker.clone(), timestamp_max_bits, cpu_chip); inventory.add_executor_chip(sumcheck); Ok(()) diff --git a/extensions/native/circuit/src/sumcheck/cuda.rs b/extensions/native/circuit/src/sumcheck/cuda.rs index 60aba15b95..7dce367a5b 100644 --- a/extensions/native/circuit/src/sumcheck/cuda.rs +++ b/extensions/native/circuit/src/sumcheck/cuda.rs @@ -1,4 +1,4 @@ -use std::{mem::size_of, slice::from_raw_parts, sync::Arc}; +use std::{borrow::Borrow, mem::size_of, slice::from_raw_parts, sync::Arc}; use derive_new::new; use openvm_circuit::{arch::DenseRecordArena, utils::next_power_of_two_or_zero}; @@ -7,15 +7,70 @@ use openvm_cuda_backend::{ base::DeviceMatrix, chip::get_empty_air_proving_ctx, prover_backend::GpuBackend, types::F, }; use openvm_cuda_common::copy::MemCopyH2D; -use openvm_stark_backend::{prover::types::AirProvingContext, Chip}; +use openvm_stark_backend::{p3_field::PrimeField32, prover::types::AirProvingContext, Chip}; -use super::columns::NativeSumcheckCols; -use crate::cuda_abi::sumcheck_cuda; +use super::columns::{LogupSpecificCols, NativeSumcheckCols, ProdSpecificCols}; +use crate::{ + cuda_abi::sumcheck_cuda, + hint_space_provider::SharedHintSpaceProviderChip, +}; #[derive(new)] pub struct NativeSumcheckChipGpu { pub range_checker: Arc, pub timestamp_max_bits: usize, + pub hint_space_provider: SharedHintSpaceProviderChip, +} + +impl NativeSumcheckChipGpu { + /// Scans execution records to populate the hint space provider with + /// (hint_id, offset, value) triples for each hint element referenced + /// by prod and logup rows. This bridges the gap between CPU execution + /// (which produces the records) and GPU trace generation. + fn populate_hint_provider(&self, records: &[u8]) { + let width = NativeSumcheckCols::::width(); + let record_size = width * size_of::(); + if records.len() % record_size != 0 { + return; + } + let num_rows = records.len() / record_size; + + let row_slice = unsafe { + let ptr = records.as_ptr() as *const F; + from_raw_parts(ptr, num_rows * width) + }; + + for i in 0..num_rows { + let row_data = &row_slice[i * width..(i + 1) * width]; + let cols: &NativeSumcheckCols = row_data.borrow(); + + if cols.within_round_limit != F::ONE { + continue; + } + + if cols.prod_row == F::ONE { + let prod_specific: &ProdSpecificCols = + cols.specific[..ProdSpecificCols::::width()].borrow(); + for (j, &val) in prod_specific.p.iter().enumerate() { + self.hint_space_provider.request( + cols.prod_hint_id, + prod_specific.data_ptr + F::from_canonical_usize(j), + val, + ); + } + } else if cols.logup_row == F::ONE { + let logup_specific: &LogupSpecificCols = + cols.specific[..LogupSpecificCols::::width()].borrow(); + for (j, &val) in logup_specific.pq.iter().enumerate() { + self.hint_space_provider.request( + cols.logup_hint_id, + logup_specific.data_ptr + F::from_canonical_usize(j), + val, + ); + } + } + } + } } impl Chip for NativeSumcheckChipGpu { @@ -25,6 +80,9 @@ impl Chip for NativeSumcheckChipGpu { return get_empty_air_proving_ctx::(); } + // Populate hint space provider from execution records before GPU upload. + self.populate_hint_provider(records); + let width = NativeSumcheckCols::::width(); let record_size = width * size_of::(); assert_eq!(records.len() % record_size, 0); From 65823528741a3afaac2add5c1dee2d3f7fa680dc Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Mon, 9 Mar 2026 20:18:57 -0400 Subject: [PATCH 09/18] cuda adjust --- .../circuit/cuda/include/native/sumcheck.cuh | 5 +- .../native/circuit/cuda/src/sumcheck.cu | 79 +++++++++++++------ extensions/native/circuit/src/sumcheck/air.rs | 4 + 3 files changed, 62 insertions(+), 26 deletions(-) diff --git a/extensions/native/circuit/cuda/include/native/sumcheck.cuh b/extensions/native/circuit/cuda/include/native/sumcheck.cuh index a7d6eee536..3ba5a27195 100644 --- a/extensions/native/circuit/cuda/include/native/sumcheck.cuh +++ b/extensions/native/circuit/cuda/include/native/sumcheck.cuh @@ -82,7 +82,10 @@ template struct NativeSumcheckCols { T eval_acc[EXT_DEG]; - T is_hint_src_id; + T is_writeback; + + T prod_hint_id; + T logup_hint_id; T specific[COL_SPECIFIC_WIDTH]; }; diff --git a/extensions/native/circuit/cuda/src/sumcheck.cu b/extensions/native/circuit/cuda/src/sumcheck.cu index 139a56473f..224703517e 100644 --- a/extensions/native/circuit/cuda/src/sumcheck.cu +++ b/extensions/native/circuit/cuda/src/sumcheck.cu @@ -32,34 +32,63 @@ __device__ void fill_sumcheck_specific(RowSlice row, MemoryAuxColsFactory &mem_h ); } else if (row[COL_INDEX(NativeSumcheckCols, prod_row)] == Fp::one()) { if (row[COL_INDEX(NativeSumcheckCols, within_round_limit)] == Fp::one()) { - mem_fill_base( - mem_helper, - start_timestamp, - specific.slice_from(COL_INDEX(ProdSpecificCols, ps_record.base)) - ); - mem_fill_base( - mem_helper, - start_timestamp + 1, - specific.slice_from(COL_INDEX(ProdSpecificCols, write_record.base)) - ); + if (row[COL_INDEX(NativeSumcheckCols, is_writeback)] == Fp::one()) { + // writeback p1, p2 + mem_fill_base( + mem_helper, + start_timestamp, + specific.slice_from(COL_INDEX(ProdSpecificCols, ps_record.base)) + ); + // write p_eval + mem_fill_base( + mem_helper, + start_timestamp + 1, + specific.slice_from(COL_INDEX(ProdSpecificCols, write_record.base)) + ); + } else { + // write p_eval only + mem_fill_base( + mem_helper, + start_timestamp, + specific.slice_from(COL_INDEX(ProdSpecificCols, write_record.base)) + ); + } } } else if (row[COL_INDEX(NativeSumcheckCols, logup_row)] == Fp::one()) { if (row[COL_INDEX(NativeSumcheckCols, within_round_limit)] == Fp::one()) { - mem_fill_base( - mem_helper, - start_timestamp, - specific.slice_from(COL_INDEX(LogupSpecificCols, pqs_record.base)) - ); - mem_fill_base( - mem_helper, - start_timestamp + 1, - specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[0].base)) - ); - mem_fill_base( - mem_helper, - start_timestamp + 2, - specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[1].base)) - ); + if (row[COL_INDEX(NativeSumcheckCols, is_writeback)] == Fp::one()) { + // writeback p1, p2, q1, q2 + mem_fill_base( + mem_helper, + start_timestamp, + specific.slice_from(COL_INDEX(LogupSpecificCols, pqs_record.base)) + ); + // write p_eval + mem_fill_base( + mem_helper, + start_timestamp + 1, + specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[0].base)) + ); + // write q_eval + mem_fill_base( + mem_helper, + start_timestamp + 2, + specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[1].base)) + ); + } else { + // write p_eval + mem_fill_base( + mem_helper, + start_timestamp, + specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[0].base)) + ); + // write q_eval + mem_fill_base( + mem_helper, + start_timestamp + 1, + specific.slice_from(COL_INDEX(LogupSpecificCols, write_records[1].base)) + ); + } } } } diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index 472e3e40bd..3eee3906d6 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -120,6 +120,7 @@ impl Air for NativeSumcheckAir { builder.assert_bool(prod_row); builder.assert_bool(logup_row); builder.assert_bool(within_round_limit); + builder.assert_bool(is_writeback); builder.assert_bool(prod_in_round_evaluation); builder.assert_bool(logup_in_round_evaluation); @@ -188,6 +189,9 @@ impl Air for NativeSumcheckAir { builder .when(next.prod_row + next.logup_row) .assert_eq(logup_nested_len, next.logup_nested_len); + builder + .when(next.prod_row + next.logup_row) + .assert_eq(is_writeback, next.is_writeback); builder .when(next.prod_row + next.logup_row) .assert_eq(prod_hint_id, next.prod_hint_id); From d2b5424b4228d585f9f9f9ccf27470d9cc39505b Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Mon, 9 Mar 2026 22:27:56 -0400 Subject: [PATCH 10/18] fix compilation --- extensions/native/circuit/src/sumcheck/cuda.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/extensions/native/circuit/src/sumcheck/cuda.rs b/extensions/native/circuit/src/sumcheck/cuda.rs index 7dce367a5b..951e94b83d 100644 --- a/extensions/native/circuit/src/sumcheck/cuda.rs +++ b/extensions/native/circuit/src/sumcheck/cuda.rs @@ -14,6 +14,7 @@ use crate::{ cuda_abi::sumcheck_cuda, hint_space_provider::SharedHintSpaceProviderChip, }; +use p3_field::FieldAlgebra; #[derive(new)] pub struct NativeSumcheckChipGpu { @@ -54,7 +55,7 @@ impl NativeSumcheckChipGpu { for (j, &val) in prod_specific.p.iter().enumerate() { self.hint_space_provider.request( cols.prod_hint_id, - prod_specific.data_ptr + F::from_canonical_usize(j), + prod_specific.data_ptr + F::from_canonical_u16(j), val, ); } @@ -64,7 +65,7 @@ impl NativeSumcheckChipGpu { for (j, &val) in logup_specific.pq.iter().enumerate() { self.hint_space_provider.request( cols.logup_hint_id, - logup_specific.data_ptr + F::from_canonical_usize(j), + logup_specific.data_ptr + F::from_canonical_u16(j), val, ); } From 05843db42b36085a196bd7d98ea555f7396cd7f2 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Mon, 9 Mar 2026 22:41:00 -0400 Subject: [PATCH 11/18] fix compilation --- extensions/native/circuit/src/sumcheck/cuda.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/extensions/native/circuit/src/sumcheck/cuda.rs b/extensions/native/circuit/src/sumcheck/cuda.rs index 951e94b83d..2dcecd5756 100644 --- a/extensions/native/circuit/src/sumcheck/cuda.rs +++ b/extensions/native/circuit/src/sumcheck/cuda.rs @@ -55,7 +55,7 @@ impl NativeSumcheckChipGpu { for (j, &val) in prod_specific.p.iter().enumerate() { self.hint_space_provider.request( cols.prod_hint_id, - prod_specific.data_ptr + F::from_canonical_u16(j), + prod_specific.data_ptr + F::from_canonical_usize(j), val, ); } @@ -65,7 +65,7 @@ impl NativeSumcheckChipGpu { for (j, &val) in logup_specific.pq.iter().enumerate() { self.hint_space_provider.request( cols.logup_hint_id, - logup_specific.data_ptr + F::from_canonical_u16(j), + logup_specific.data_ptr + F::from_canonical_usize(j), val, ); } From ea70a288f69d909bfad0ef70c93e9c1c20e16eb0 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Tue, 10 Mar 2026 17:05:00 -0400 Subject: [PATCH 12/18] account for empty gpu trace --- .../native/circuit/src/hint_space_provider.rs | 38 +++++++++++-------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/extensions/native/circuit/src/hint_space_provider.rs b/extensions/native/circuit/src/hint_space_provider.rs index 65c26bb299..867a9803d7 100644 --- a/extensions/native/circuit/src/hint_space_provider.rs +++ b/extensions/native/circuit/src/hint_space_provider.rs @@ -159,24 +159,30 @@ pub mod cuda { let rows_used = data.len(); let height = rows_used.next_power_of_two().max(2); - // Flatten (hint_id, offset, value) triples into a contiguous [F] buffer - let flat: Vec = data - .into_iter() - .flat_map(|(h, o, v)| [h, o, v]) - .collect(); - - let d_records = flat.to_device().unwrap(); let trace = DeviceMatrix::::with_capacity(height, NUM_HINT_SPACE_PROVIDER_COLS); - unsafe { - hint_space_provider_cuda::tracegen( - trace.buffer(), - height, - NUM_HINT_SPACE_PROVIDER_COLS, - &d_records, - rows_used, - ) - .unwrap(); + if rows_used > 0 { + // Flatten (hint_id, offset, value) triples into a contiguous [F] buffer + let flat: Vec = data + .into_iter() + .flat_map(|(h, o, v)| [h, o, v]) + .collect(); + + let d_records = flat.to_device().unwrap(); + + unsafe { + hint_space_provider_cuda::tracegen( + trace.buffer(), + height, + NUM_HINT_SPACE_PROVIDER_COLS, + &d_records, + rows_used, + ) + .unwrap(); + } + } else { + // No data — zero-fill the trace (all padding rows with is_valid=0) + trace.buffer().fill_zero().unwrap(); } AirProvingContext::simple_no_pis(trace) From 3222314707e7e5eda8c197fa8d9b8b87649fbe79 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Wed, 11 Mar 2026 18:21:20 -0400 Subject: [PATCH 13/18] constraints --- .../src/verifier/common/non_leaf.rs | 1 + crates/vm/src/arch/vm.rs | 4 ++ .../system/memory/offline_checker/columns.rs | 45 ------------------- extensions/native/circuit/src/sumcheck/air.rs | 20 +++++++++ .../native/circuit/src/sumcheck/chip.rs | 8 ++-- 5 files changed, 29 insertions(+), 49 deletions(-) diff --git a/crates/continuations/src/verifier/common/non_leaf.rs b/crates/continuations/src/verifier/common/non_leaf.rs index a41383e9fd..b167e2933c 100644 --- a/crates/continuations/src/verifier/common/non_leaf.rs +++ b/crates/continuations/src/verifier/common/non_leaf.rs @@ -48,6 +48,7 @@ impl NonLeafVerifierVariables { let proof = builder.get(proofs, i); assert_required_air_for_agg_vm_present(builder, &proof); let proof_vm_pvs = self.verify_internal_or_leaf_verifier_proof(builder, &proof); + assert_single_segment_vm_exit_successfully(builder, &proof); builder.if_eq(i, RVar::zero()).then_or_else( |builder| { diff --git a/crates/vm/src/arch/vm.rs b/crates/vm/src/arch/vm.rs index 516efc8078..cb58a0b77a 100644 --- a/crates/vm/src/arch/vm.rs +++ b/crates/vm/src/arch/vm.rs @@ -474,6 +474,7 @@ where .map(|(&h, w)| (h as usize, w)) .collect::>(); let ctx = PreflightCtx::new_with_capacity(&capacities, instret_end); + let system_config: &SystemConfig = self.config().as_ref(); let adapter_offset = system_config.access_adapter_air_id_offset(); // ATTENTION: this must agree with `num_memory_airs` @@ -510,6 +511,7 @@ where .finalize::>(system_config.continuation_enabled); #[cfg(feature = "perf-metrics")] crate::metrics::end_segment_metrics(&mut exec_state); + let instret = exec_state.vm_state.instret(); let pc = exec_state.vm_state.pc(); let memory = exec_state.vm_state.memory; @@ -683,6 +685,7 @@ where PreflightExecutor, VB::RecordArena>, { self.transport_init_memory_to_device(&state.memory); + let PreflightExecutionOutput { system_records, record_arenas, @@ -693,6 +696,7 @@ where (system_records.exit_code == Some(ExitCode::Success as u32)).then_some(to_state.memory); let ctx = self.generate_proving_ctx(system_records, record_arenas)?; let proof = self.engine.prove(&self.pk, ctx); + Ok((proof, final_memory)) } diff --git a/crates/vm/src/system/memory/offline_checker/columns.rs b/crates/vm/src/system/memory/offline_checker/columns.rs index 6fe34d1180..f614a43026 100644 --- a/crates/vm/src/system/memory/offline_checker/columns.rs +++ b/crates/vm/src/system/memory/offline_checker/columns.rs @@ -122,49 +122,4 @@ impl AsMut> for MemoryReadOrImmediateAuxCols { fn as_mut(&mut self) -> &mut MemoryBaseAuxCols { &mut self.base } -} -/// The auxiliary columns for a hint read operation. -/// Stores the previous timestamp to enable permutation checking through the hint bus. -#[repr(C)] -#[derive(Clone, Copy, Debug, AlignedBorrow)] -pub struct HintReadAuxCols { - /// The previous timestamp when this hint location was accessed - pub prev_timestamp: T, -} - -impl HintReadAuxCols { - pub fn new(prev_timestamp: F) -> Self { - Self { prev_timestamp } - } - - #[inline(always)] - pub fn set_prev(&mut self, timestamp: F) { - self.prev_timestamp = timestamp; - } -} - -/// The auxiliary columns for a hint write operation. -/// Stores the previous timestamp and previous value to enable permutation checking through the hint bus. -#[repr(C)] -#[derive(Clone, Copy, Debug, AlignedBorrow)] -pub struct HintWriteAuxCols { - /// The previous timestamp when this hint location was accessed - pub prev_timestamp: T, - /// The previous value at this hint location - pub prev_value: T, -} - -impl HintWriteAuxCols { - pub fn new(prev_timestamp: F, prev_value: F) -> Self { - Self { - prev_timestamp, - prev_value, - } - } - - #[inline(always)] - pub fn set_prev(&mut self, timestamp: F, value: F) { - self.prev_timestamp = timestamp; - self.prev_value = value; - } } \ No newline at end of file diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index 3eee3906d6..6d49b05798 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -251,6 +251,26 @@ impl Air for NativeSumcheckAir { start_timestamp + AB::F::from_canonical_usize(8), ); + // Prod row timestamp transition + builder + .when_transition() + .when(prod_row) + .assert_eq( + next.start_timestamp, + start_timestamp + + within_round_limit * (AB::Expr::ONE + is_writeback), + ); + + // Logup row timestamp transition + builder + .when_transition() + .when(logup_row) + .assert_eq( + next.start_timestamp, + start_timestamp + + within_round_limit * (AB::Expr::TWO + is_writeback), + ); + // Termination condition assert_array_eq( &mut builder.when::(is_end.into()), diff --git a/extensions/native/circuit/src/sumcheck/chip.rs b/extensions/native/circuit/src/sumcheck/chip.rs index 89ff6b9902..df6a41d1ca 100644 --- a/extensions/native/circuit/src/sumcheck/chip.rs +++ b/extensions/native/circuit/src/sumcheck/chip.rs @@ -298,7 +298,7 @@ where prod_specific.p = ps; // If p values come from the hint stream, write back to the actual witness array - if is_writeback > F::ZERO { + if is_writeback != F::ZERO { tracing_write_native_inplace( state.memory, prod_evals_ptr.as_canonical_u32() + start, @@ -335,7 +335,7 @@ where eval, &mut prod_specific.write_record, ); - cur_timestamp += if is_writeback > F::ZERO { 2 } else { 1 }; // Only write back to the witness array when the is_writeback indicator is true + cur_timestamp += if is_writeback != F::ZERO { 2 } else { 1 }; // Only write back to the witness array when the is_writeback indicator is true let eval_rlc = FieldExtension::multiply(alpha_acc, eval); prod_specific.eval_rlc = eval_rlc; @@ -394,7 +394,7 @@ where logup_specific.pq = pqs; // write pqs - if is_writeback > F::ZERO { + if is_writeback != F::ZERO { tracing_write_native_inplace( state.memory, logup_evals_ptr.as_canonical_u32() + start, @@ -453,7 +453,7 @@ where q_eval, &mut logup_specific.write_records[1], ); - cur_timestamp += if is_writeback > F::ZERO { 3 } else { 2 }; // Only write back to the witness array when the is_writeback indicator is true + cur_timestamp += if is_writeback != F::ZERO { 3 } else { 2 }; // Only write back to the witness array when the is_writeback indicator is true let eval_rlc = FieldExtension::add( FieldExtension::multiply(alpha_numerator, p_eval), From b8cda79665b91f16f1ba1cf8b3c07a3a481744d2 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Wed, 11 Mar 2026 19:52:57 -0400 Subject: [PATCH 14/18] column structure --- .../circuit/cuda/src/hint_space_provider.cu | 56 ----- extensions/native/circuit/src/cuda_abi.rs | 30 --- .../native/circuit/src/hint_space_provider.rs | 195 ++++++++++++------ 3 files changed, 133 insertions(+), 148 deletions(-) delete mode 100644 extensions/native/circuit/cuda/src/hint_space_provider.cu diff --git a/extensions/native/circuit/cuda/src/hint_space_provider.cu b/extensions/native/circuit/cuda/src/hint_space_provider.cu deleted file mode 100644 index a87c1155ac..0000000000 --- a/extensions/native/circuit/cuda/src/hint_space_provider.cu +++ /dev/null @@ -1,56 +0,0 @@ -#include "launcher.cuh" -#include "primitives/trace_access.h" - -// Columns layout matches HintSpaceProviderCols in hint_space_provider.rs -// Fields: hint_id, offset, value, is_valid -template struct HintSpaceProviderCols { - T hint_id; - T offset; - T value; - T is_valid; -}; - -constexpr uint32_t HINT_SPACE_PROVIDER_WIDTH = sizeof(HintSpaceProviderCols); - -__global__ void hint_space_provider_tracegen( - Fp *trace, - size_t height, - const Fp *records, - size_t rows_used -) { - uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= height) { - return; - } - - RowSlice row(trace + idx, height); - if (idx < rows_used) { - // Each record is a triple (hint_id, offset, value) - const Fp *rec = records + idx * 3; - COL_WRITE_VALUE(row, HintSpaceProviderCols, hint_id, rec[0]); - COL_WRITE_VALUE(row, HintSpaceProviderCols, offset, rec[1]); - COL_WRITE_VALUE(row, HintSpaceProviderCols, value, rec[2]); - COL_WRITE_VALUE(row, HintSpaceProviderCols, is_valid, Fp::one()); - } else { - row.fill_zero(0, HINT_SPACE_PROVIDER_WIDTH); - } -} - -extern "C" int _hint_space_provider_tracegen( - Fp *d_trace, - size_t height, - size_t width, - const Fp *d_records, - size_t rows_used -) { - assert((height & (height - 1)) == 0); - assert(width == HINT_SPACE_PROVIDER_WIDTH); - auto [grid, block] = kernel_launch_params(height); - hint_space_provider_tracegen<<>>( - d_trace, - height, - d_records, - rows_used - ); - return CHECK_KERNEL(); -} diff --git a/extensions/native/circuit/src/cuda_abi.rs b/extensions/native/circuit/src/cuda_abi.rs index 4530290521..5de9124f0d 100644 --- a/extensions/native/circuit/src/cuda_abi.rs +++ b/extensions/native/circuit/src/cuda_abi.rs @@ -345,33 +345,3 @@ pub mod native_jal_rangecheck_cuda { )) } } - -pub mod hint_space_provider_cuda { - use super::*; - - extern "C" { - pub fn _hint_space_provider_tracegen( - d_trace: *mut F, - height: usize, - width: usize, - d_records: *const F, - rows_used: usize, - ) -> i32; - } - - pub unsafe fn tracegen( - d_trace: &DeviceBuffer, - height: usize, - width: usize, - d_records: &DeviceBuffer, - rows_used: usize, - ) -> Result<(), CudaError> { - CudaError::from_result(_hint_space_provider_tracegen( - d_trace.as_mut_ptr(), - height, - width, - d_records.as_ptr(), - rows_used, - )) - } -} diff --git a/extensions/native/circuit/src/hint_space_provider.rs b/extensions/native/circuit/src/hint_space_provider.rs index 867a9803d7..c12ecab16b 100644 --- a/extensions/native/circuit/src/hint_space_provider.rs +++ b/extensions/native/circuit/src/hint_space_provider.rs @@ -1,5 +1,6 @@ use std::{ borrow::{Borrow, BorrowMut}, + collections::HashMap, mem::size_of, sync::{Arc, Mutex}, }; @@ -9,8 +10,8 @@ use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_stark_backend::{ config::{StarkGenericConfig, Val}, interaction::InteractionBuilder, - p3_air::{Air, BaseAir}, - p3_field::{Field, PrimeField32}, + p3_air::{Air, AirBuilder, BaseAir}, + p3_field::{Field, FieldAlgebra, PrimeField32}, p3_matrix::{dense::RowMajorMatrix, Matrix}, prover::{cpu::CpuBackend, types::AirProvingContext}, rap::{get_air_name, BaseAirWithPublicValues, PartitionedBaseAir}, @@ -23,7 +24,15 @@ pub struct HintSpaceProviderCols { pub hint_id: T, pub offset: T, pub value: T, - pub is_valid: T, + pub multiplicity: T, + /// Inverse of multiplicity when nonzero; 0 for padding rows. + /// Used to derive a virtual boolean `is_non_padding = multiplicity * mult_inv`. + pub mult_inv: T, + /// Boolean: 1 if hint_id changes between this row and the next non-padding row. + pub hint_id_changed: T, + /// When hint_id_changed = 1: inverse of (next.hint_id - hint_id), proving they differ. + /// When hint_id_changed = 0: unused (zero). + pub diff_hint_id_inv: T, } pub const NUM_HINT_SPACE_PROVIDER_COLS: usize = size_of::>(); @@ -45,24 +54,77 @@ impl BaseAir for HintSpaceProviderAir { impl Air for HintSpaceProviderAir { fn eval(&self, builder: &mut AB) { let main = builder.main(); - let local = main.row_slice(0); - let local: &HintSpaceProviderCols = (*local).borrow(); + let curr = main.row_slice(0); + let curr: &HintSpaceProviderCols = (*curr).borrow(); + let next = main.row_slice(1); + let next: &HintSpaceProviderCols = (*next).borrow(); - builder.assert_bool(local.is_valid); + // Derive virtual boolean `is_non_padding` from multiplicity. + let curr_is_non_padding: AB::Expr = curr.multiplicity * curr.mult_inv; + let next_is_non_padding: AB::Expr = next.multiplicity * next.mult_inv; + + // is_non_padding must be boolean. + builder.assert_zero(curr_is_non_padding.clone() * (curr_is_non_padding.clone() - AB::Expr::ONE)); + // multiplicity = 0 when is_non_padding = 0 (padding rows can't provide to the bus). + builder.assert_zero((AB::Expr::ONE - curr_is_non_padding.clone()) * curr.multiplicity); + + builder.assert_bool(curr.hint_id_changed); + + // Non-padding rows must appear before padding rows (is_non_padding is non-increasing). + builder + .when_transition() + .when(next_is_non_padding.clone()) + .assert_one(curr_is_non_padding.clone()); + + // Uniqueness of (hint_id, offset) among non-padding rows. + // Rows are sorted by (hint_id, offset). For consecutive non-padding rows: + // - Same hint_id (hint_id_changed=0): offset must increase by exactly 1. + // - Different hint_id (hint_id_changed=1): hint_id must actually differ + // (proven via inverse), and the new block starts at offset 0. + // Since offsets within a hint_id form a contiguous 0,1,2,...,N-1 sequence, + // this prevents any duplicate (hint_id, offset) pairs. + let both_non_padding: AB::Expr = curr_is_non_padding * next_is_non_padding; + let d_id: AB::Expr = next.hint_id - curr.hint_id; + + // hint_id_changed = 0 => same hint_id, offset increases by 1 + builder + .when_transition() + .when(both_non_padding.clone()) + .when_ne(curr.hint_id_changed, AB::Expr::ONE) + .assert_zero(d_id.clone()); + builder + .when_transition() + .when(both_non_padding.clone()) + .when_ne(curr.hint_id_changed, AB::Expr::ONE) + .assert_eq(next.offset, curr.offset + AB::Expr::ONE); + + // hint_id_changed = 1 => hint_id actually differs, and new block starts at offset 0 + builder + .when_transition() + .when(both_non_padding.clone()) + .when(curr.hint_id_changed) + .assert_one(d_id * curr.diff_hint_id_inv); + builder + .when_transition() + .when(both_non_padding) + .when(curr.hint_id_changed) + .assert_zero(next.offset); self.hint_bus.provide( builder, - local.hint_id, - local.offset, - local.value, - local.is_valid, + curr.hint_id, + curr.offset, + curr.value, + curr.multiplicity, ); } } pub struct HintSpaceProviderChip { pub air: HintSpaceProviderAir, - data: Mutex>, + /// Maps (hint_id, offset) -> (value, multiplicity). + /// Deduplicates keys and tracks how many times each is looked up. + data: Mutex>, } pub type SharedHintSpaceProviderChip = Arc>; @@ -71,37 +133,71 @@ impl HintSpaceProviderChip { pub fn new(hint_bus: HintBus) -> Self { Self { air: HintSpaceProviderAir { hint_bus }, - data: Mutex::new(Vec::new()), + data: Mutex::new(HashMap::new()), } } +} +impl HintSpaceProviderChip { /// Register a (hint_id, offset, value) triple for the provider trace. /// Called by consumer chips during trace filling to match each lookup. + /// Deduplicates by (hint_id, offset) and increments the multiplicity counter. pub fn request(&self, hint_id: F, offset: F, value: F) { - self.data.lock().unwrap().push((hint_id, offset, value)); + self.data + .lock() + .unwrap() + .entry((hint_id, offset)) + .and_modify(|(v, m)| { + debug_assert_eq!(*v, value, "conflicting values for same (hint_id, offset)"); + *m += F::ONE; + }) + .or_insert((value, F::ONE)); } } impl HintSpaceProviderChip { pub fn generate_trace(&self) -> RowMajorMatrix { let data = std::mem::take(&mut *self.data.lock().unwrap()); - let num_real_rows = data.len(); - let trace_height = num_real_rows.next_power_of_two().max(2); + // Collect into a Vec and sort by (hint_id, offset) to satisfy the AIR ordering constraints. + let mut entries: Vec<_> = data.into_iter().collect(); + entries.sort_by_key(|((h, o), _)| (h.as_canonical_u64(), o.as_canonical_u64())); + + let num_non_padding_rows = entries.len(); + let trace_height = num_non_padding_rows.next_power_of_two().max(2); let mut rows = F::zero_vec(trace_height * NUM_HINT_SPACE_PROVIDER_COLS); - for (n, row) in rows - .chunks_exact_mut(NUM_HINT_SPACE_PROVIDER_COLS) - .enumerate() - { - if n < num_real_rows { - let cols: &mut HintSpaceProviderCols = row.borrow_mut(); - cols.hint_id = data[n].0; - cols.offset = data[n].1; - cols.value = data[n].2; - cols.is_valid = F::ONE; + for (n, ((hint_id, offset), (value, multiplicity))) in entries.iter().enumerate() { + let row = + &mut rows[n * NUM_HINT_SPACE_PROVIDER_COLS..(n + 1) * NUM_HINT_SPACE_PROVIDER_COLS]; + let cols: &mut HintSpaceProviderCols = row.borrow_mut(); + cols.hint_id = *hint_id; + cols.offset = *offset; + cols.value = *value; + cols.multiplicity = *multiplicity; + cols.mult_inv = multiplicity.try_inverse().unwrap(); + + // Fill auxiliary columns for the uniqueness constraint. + if n + 1 < num_non_padding_rows { + let next_hint_id = entries[n + 1].0 .0; + let d_id = next_hint_id - *hint_id; + if d_id != F::ZERO { + cols.hint_id_changed = F::ONE; + cols.diff_hint_id_inv = d_id.try_inverse().unwrap(); + } else { + debug_assert_eq!( + entries[n + 1].0 .1, + *offset + F::ONE, + "Offsets for hint_id {:?} are not consecutive: {:?} -> {:?}", + hint_id, + offset, + entries[n + 1].0 .1 + ); + // hint_id_changed = 0, diff_hint_id_inv = 0 (defaults) + } } - // padding rows are already zero (is_valid = 0) + // Last non-padding row: aux columns stay zero (no next non-padding row to compare) } + // padding rows are already zero (multiplicity = 0) RowMajorMatrix::new(rows, NUM_HINT_SPACE_PROVIDER_COLS) } } @@ -136,12 +232,15 @@ pub mod cuda { use std::sync::Arc; use openvm_circuit::arch::DenseRecordArena; - use openvm_cuda_backend::{base::DeviceMatrix, prover_backend::GpuBackend, types::F}; - use openvm_cuda_common::copy::MemCopyH2D; - use openvm_stark_backend::{prover::types::AirProvingContext, Chip}; + use openvm_cuda_backend::{ + chip::cpu_proving_ctx_to_gpu, prover_backend::GpuBackend, types::F, types::SC, + }; + use openvm_stark_backend::{ + prover::{cpu::CpuBackend, types::AirProvingContext}, + Chip, + }; - use super::{HintSpaceProviderChip, NUM_HINT_SPACE_PROVIDER_COLS}; - use crate::cuda_abi::hint_space_provider_cuda; + use super::HintSpaceProviderChip; pub struct HintSpaceProviderChipGpu { pub cpu_chip: Arc>, @@ -155,37 +254,9 @@ pub mod cuda { impl Chip for HintSpaceProviderChipGpu { fn generate_proving_ctx(&self, _: DenseRecordArena) -> AirProvingContext { - let data = std::mem::take(&mut *self.cpu_chip.data.lock().unwrap()); - let rows_used = data.len(); - let height = rows_used.next_power_of_two().max(2); - - let trace = DeviceMatrix::::with_capacity(height, NUM_HINT_SPACE_PROVIDER_COLS); - - if rows_used > 0 { - // Flatten (hint_id, offset, value) triples into a contiguous [F] buffer - let flat: Vec = data - .into_iter() - .flat_map(|(h, o, v)| [h, o, v]) - .collect(); - - let d_records = flat.to_device().unwrap(); - - unsafe { - hint_space_provider_cuda::tracegen( - trace.buffer(), - height, - NUM_HINT_SPACE_PROVIDER_COLS, - &d_records, - rows_used, - ) - .unwrap(); - } - } else { - // No data — zero-fill the trace (all padding rows with is_valid=0) - trace.buffer().fill_zero().unwrap(); - } - - AirProvingContext::simple_no_pis(trace) + let cpu_ctx: AirProvingContext> = + AirProvingContext::simple_no_pis(Arc::new(self.cpu_chip.generate_trace())); + cpu_proving_ctx_to_gpu(cpu_ctx) } } } From 5ba7f71da9bf5ed2fc479ad1a43260f43e2060a3 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Thu, 12 Mar 2026 20:00:19 -0400 Subject: [PATCH 15/18] adjust degree --- .../native/circuit/src/hint_space_provider.rs | 65 ++++++++++++------- extensions/native/circuit/src/sumcheck/air.rs | 43 ++++++------ 2 files changed, 67 insertions(+), 41 deletions(-) diff --git a/extensions/native/circuit/src/hint_space_provider.rs b/extensions/native/circuit/src/hint_space_provider.rs index c12ecab16b..c9dc95be0d 100644 --- a/extensions/native/circuit/src/hint_space_provider.rs +++ b/extensions/native/circuit/src/hint_space_provider.rs @@ -26,13 +26,18 @@ pub struct HintSpaceProviderCols { pub value: T, pub multiplicity: T, /// Inverse of multiplicity when nonzero; 0 for padding rows. - /// Used to derive a virtual boolean `is_non_padding = multiplicity * mult_inv`. pub mult_inv: T, /// Boolean: 1 if hint_id changes between this row and the next non-padding row. pub hint_id_changed: T, /// When hint_id_changed = 1: inverse of (next.hint_id - hint_id), proving they differ. /// When hint_id_changed = 0: unused (zero). pub diff_hint_id_inv: T, + /// Boolean: 1 if this row is not a padding row (multiplicity > 0). + pub curr_is_non_padding: T, + /// Boolean: 1 if the next row is not a padding row. + pub next_is_non_padding: T, + /// Boolean: curr_is_non_padding * next_is_non_padding. + pub both_non_padding: T, } pub const NUM_HINT_SPACE_PROVIDER_COLS: usize = size_of::>(); @@ -59,54 +64,65 @@ impl Air for HintSpaceProviderAir { let next = main.row_slice(1); let next: &HintSpaceProviderCols = (*next).borrow(); - // Derive virtual boolean `is_non_padding` from multiplicity. - let curr_is_non_padding: AB::Expr = curr.multiplicity * curr.mult_inv; - let next_is_non_padding: AB::Expr = next.multiplicity * next.mult_inv; - - // is_non_padding must be boolean. - builder.assert_zero(curr_is_non_padding.clone() * (curr_is_non_padding.clone() - AB::Expr::ONE)); - // multiplicity = 0 when is_non_padding = 0 (padding rows can't provide to the bus). - builder.assert_zero((AB::Expr::ONE - curr_is_non_padding.clone()) * curr.multiplicity); + // curr_is_non_padding is boolean and tied to multiplicity via mult_inv. + builder.assert_bool(curr.curr_is_non_padding); + builder.assert_eq( + curr.curr_is_non_padding, + curr.multiplicity * curr.mult_inv, + ); + // Padding rows must have multiplicity = 0. + builder.assert_zero( + (AB::Expr::ONE - curr.curr_is_non_padding) * curr.multiplicity, + ); builder.assert_bool(curr.hint_id_changed); - // Non-padding rows must appear before padding rows (is_non_padding is non-increasing). + // Tie next_is_non_padding and both_non_padding columns to their definitions. builder .when_transition() - .when(next_is_non_padding.clone()) - .assert_one(curr_is_non_padding.clone()); + .assert_eq(curr.next_is_non_padding, next.curr_is_non_padding); + builder.when_transition().assert_eq( + curr.both_non_padding, + curr.curr_is_non_padding * curr.next_is_non_padding, + ); + + // Non-padding rows must appear before padding rows (non-increasing). + builder + .when_transition() + .when(curr.next_is_non_padding) + .assert_one(curr.curr_is_non_padding); // Uniqueness of (hint_id, offset) among non-padding rows. // Rows are sorted by (hint_id, offset). For consecutive non-padding rows: // - Same hint_id (hint_id_changed=0): offset must increase by exactly 1. // - Different hint_id (hint_id_changed=1): hint_id must actually differ // (proven via inverse), and the new block starts at offset 0. - // Since offsets within a hint_id form a contiguous 0,1,2,...,N-1 sequence, - // this prevents any duplicate (hint_id, offset) pairs. - let both_non_padding: AB::Expr = curr_is_non_padding * next_is_non_padding; let d_id: AB::Expr = next.hint_id - curr.hint_id; // hint_id_changed = 0 => same hint_id, offset increases by 1 builder .when_transition() - .when(both_non_padding.clone()) + .when(curr.both_non_padding) .when_ne(curr.hint_id_changed, AB::Expr::ONE) .assert_zero(d_id.clone()); builder .when_transition() - .when(both_non_padding.clone()) + .when(curr.both_non_padding) .when_ne(curr.hint_id_changed, AB::Expr::ONE) .assert_eq(next.offset, curr.offset + AB::Expr::ONE); - // hint_id_changed = 1 => hint_id actually differs, and new block starts at offset 0 + // Combined inverse check: d_id * diff_hint_id_inv = hint_id_changed. + // When hint_id_changed = 0: trivially 0 = 0 (since d_id = 0 from above). + // When hint_id_changed = 1: proves d_id != 0 (hint_id actually changed). builder .when_transition() - .when(both_non_padding.clone()) - .when(curr.hint_id_changed) - .assert_one(d_id * curr.diff_hint_id_inv); + .when(curr.both_non_padding) + .assert_eq(d_id * curr.diff_hint_id_inv, curr.hint_id_changed); + + // hint_id_changed = 1 => new block starts at offset 0 builder .when_transition() - .when(both_non_padding) + .when(curr.both_non_padding) .when(curr.hint_id_changed) .assert_zero(next.offset); @@ -175,6 +191,11 @@ impl HintSpaceProviderChip { cols.value = *value; cols.multiplicity = *multiplicity; cols.mult_inv = multiplicity.try_inverse().unwrap(); + cols.curr_is_non_padding = F::ONE; + cols.next_is_non_padding = + if n + 1 < num_non_padding_rows { F::ONE } else { F::ZERO }; + cols.both_non_padding = + if n + 1 < num_non_padding_rows { F::ONE } else { F::ZERO }; // Fill auxiliary columns for the uniqueness constraint. if n + 1 < num_non_padding_rows { diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index 6d49b05798..fd8f31233f 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -251,25 +251,30 @@ impl Air for NativeSumcheckAir { start_timestamp + AB::F::from_canonical_usize(8), ); - // Prod row timestamp transition - builder - .when_transition() - .when(prod_row) - .assert_eq( - next.start_timestamp, - start_timestamp - + within_round_limit * (AB::Expr::ONE + is_writeback), - ); - - // Logup row timestamp transition - builder - .when_transition() - .when(logup_row) - .assert_eq( - next.start_timestamp, - start_timestamp - + within_round_limit * (AB::Expr::TWO + is_writeback), - ); + // _debug + // // Prod row timestamp transition + // // Equivalent to: when(prod_row): next_ts = ts + within_round_limit * (1 + is_writeback) + // // Reformulated using prod_in_round_evaluation + prod_next_round_evaluation = prod_row * within_round_limit + // // to stay within max constraint degree 3 while guarding against is_end boundary rows. + // builder + // .when_transition() + // .when_ne(is_end, AB::Expr::ONE) + // .assert_eq( + // prod_row * (next.start_timestamp - start_timestamp), + // (prod_in_round_evaluation + prod_next_round_evaluation) + // * (AB::Expr::ONE + is_writeback), + // ); + + // // Logup row timestamp transition + // // Same reformulation using logup_in_round_evaluation + logup_next_round_evaluation = logup_row * within_round_limit + // builder + // .when_transition() + // .when_ne(is_end, AB::Expr::ONE) + // .assert_eq( + // logup_row * (next.start_timestamp - start_timestamp), + // (logup_in_round_evaluation + logup_next_round_evaluation) + // * (AB::Expr::TWO + is_writeback), + // ); // Termination condition assert_array_eq( From efa8f0b20f800148e97c798baa6a7bb6b2ae5c3b Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Thu, 12 Mar 2026 20:07:31 -0400 Subject: [PATCH 16/18] remove debug flag --- extensions/native/circuit/src/sumcheck/air.rs | 43 ++++++++----------- 1 file changed, 19 insertions(+), 24 deletions(-) diff --git a/extensions/native/circuit/src/sumcheck/air.rs b/extensions/native/circuit/src/sumcheck/air.rs index fd8f31233f..4207695486 100644 --- a/extensions/native/circuit/src/sumcheck/air.rs +++ b/extensions/native/circuit/src/sumcheck/air.rs @@ -251,30 +251,25 @@ impl Air for NativeSumcheckAir { start_timestamp + AB::F::from_canonical_usize(8), ); - // _debug - // // Prod row timestamp transition - // // Equivalent to: when(prod_row): next_ts = ts + within_round_limit * (1 + is_writeback) - // // Reformulated using prod_in_round_evaluation + prod_next_round_evaluation = prod_row * within_round_limit - // // to stay within max constraint degree 3 while guarding against is_end boundary rows. - // builder - // .when_transition() - // .when_ne(is_end, AB::Expr::ONE) - // .assert_eq( - // prod_row * (next.start_timestamp - start_timestamp), - // (prod_in_round_evaluation + prod_next_round_evaluation) - // * (AB::Expr::ONE + is_writeback), - // ); - - // // Logup row timestamp transition - // // Same reformulation using logup_in_round_evaluation + logup_next_round_evaluation = logup_row * within_round_limit - // builder - // .when_transition() - // .when_ne(is_end, AB::Expr::ONE) - // .assert_eq( - // logup_row * (next.start_timestamp - start_timestamp), - // (logup_in_round_evaluation + logup_next_round_evaluation) - // * (AB::Expr::TWO + is_writeback), - // ); + // Prod row timestamp transition + builder + .when_transition() + .when_ne(is_end, AB::Expr::ONE) + .assert_eq( + prod_row * (next.start_timestamp - start_timestamp), + (prod_in_round_evaluation + prod_next_round_evaluation) + * (AB::Expr::ONE + is_writeback), + ); + + // Logup row timestamp transition + builder + .when_transition() + .when_ne(is_end, AB::Expr::ONE) + .assert_eq( + logup_row * (next.start_timestamp - start_timestamp), + (logup_in_round_evaluation + logup_next_round_evaluation) + * (AB::Expr::TWO + is_writeback), + ); // Termination condition assert_array_eq( From 4a49ffe54a9e8dec767f4e8b61c6d75da4a08aa8 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Sun, 15 Mar 2026 16:31:09 -0400 Subject: [PATCH 17/18] Adjust hint constraints --- .../native/circuit/src/extension/cuda.rs | 8 ++- .../native/circuit/src/extension/mod.rs | 11 ++- .../native/circuit/src/hint_space_provider.rs | 72 +++++++++++++------ 3 files changed, 67 insertions(+), 24 deletions(-) diff --git a/extensions/native/circuit/src/extension/cuda.rs b/extensions/native/circuit/src/extension/cuda.rs index 9ae1558500..50a9cba86d 100644 --- a/extensions/native/circuit/src/extension/cuda.rs +++ b/extensions/native/circuit/src/extension/cuda.rs @@ -79,10 +79,12 @@ impl VmProverExtension let poseidon2 = NativePoseidon2ChipGpu::<1>::new(range_checker.clone(), timestamp_max_bits); inventory.add_executor_chip(poseidon2); - // HintSpaceProvider must be registered BEFORE NativeSumcheck because chips are - // dispatched in reverse order: sumcheck runs first and populates the provider. let hint_air: &HintSpaceProviderAir = inventory.next_air::()?; - let cpu_chip = Arc::new(HintSpaceProviderChip::new(hint_air.hint_bus)); + let cpu_chip = Arc::new(HintSpaceProviderChip::new( + hint_air.hint_bus, + range_checker.clone(), + timestamp_max_bits, + )); let provider_gpu = HintSpaceProviderChipGpu::new(cpu_chip.clone()); inventory.add_periphery_chip(provider_gpu); diff --git a/extensions/native/circuit/src/extension/mod.rs b/extensions/native/circuit/src/extension/mod.rs index 86867ca1f5..924d4927e8 100644 --- a/extensions/native/circuit/src/extension/mod.rs +++ b/extensions/native/circuit/src/extension/mod.rs @@ -14,6 +14,7 @@ use openvm_circuit::{ }, system::{memory::SharedMemoryHelper, SystemPort}, }; +use openvm_circuit_primitives::is_less_than::IsLtSubAir; use openvm_circuit_derive::{AnyEnum, Executor, MeteredExecutor, PreflightExecutor}; use openvm_instructions::{program::DEFAULT_PC_STEP, LocalOpcode, PhantomDiscriminant}; use openvm_native_compiler::{ @@ -281,6 +282,10 @@ where let hint_space_provider = HintSpaceProviderAir { hint_bus: hint_bridge.hint_bus(), + lt_air: IsLtSubAir::new( + range_checker, + inventory.config().memory_config.timestamp_max_bits, + ), }; inventory.add_air(hint_space_provider); @@ -368,7 +373,11 @@ where inventory.add_executor_chip(poseidon2); let hint_bus = inventory.airs().system().hint_bridge.hint_bus(); - let hint_space_provider = Arc::new(HintSpaceProviderChip::new(hint_bus)); + let hint_space_provider = Arc::new(HintSpaceProviderChip::new( + hint_bus, + range_checker.clone(), + timestamp_max_bits, + )); inventory.next_air::()?; inventory.add_periphery_chip(hint_space_provider.clone()); diff --git a/extensions/native/circuit/src/hint_space_provider.rs b/extensions/native/circuit/src/hint_space_provider.rs index c9dc95be0d..2c5fdf0652 100644 --- a/extensions/native/circuit/src/hint_space_provider.rs +++ b/extensions/native/circuit/src/hint_space_provider.rs @@ -6,6 +6,11 @@ use std::{ }; use openvm_circuit::system::memory::offline_checker::HintBus; +use openvm_circuit_primitives::{ + is_less_than::{IsLessThanIo, IsLtSubAir}, + var_range::SharedVariableRangeCheckerChip, + SubAir, TraceSubRowGenerator, +}; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_stark_backend::{ config::{StarkGenericConfig, Val}, @@ -17,6 +22,7 @@ use openvm_stark_backend::{ rap::{get_air_name, BaseAirWithPublicValues, PartitionedBaseAir}, Chip, ChipUsageGetter, }; +pub const HINT_ID_LT_AUX_LEN: usize = 2; #[derive(Default, AlignedBorrow, Copy, Clone)] #[repr(C)] @@ -29,9 +35,8 @@ pub struct HintSpaceProviderCols { pub mult_inv: T, /// Boolean: 1 if hint_id changes between this row and the next non-padding row. pub hint_id_changed: T, - /// When hint_id_changed = 1: inverse of (next.hint_id - hint_id), proving they differ. - /// When hint_id_changed = 0: unused (zero). - pub diff_hint_id_inv: T, + /// Auxiliary limbs for IsLtSubAir range check decomposition of (curr.hint_id < next.hint_id). + pub hint_id_lt_aux: [T; HINT_ID_LT_AUX_LEN], /// Boolean: 1 if this row is not a padding row (multiplicity > 0). pub curr_is_non_padding: T, /// Boolean: 1 if the next row is not a padding row. @@ -45,6 +50,7 @@ pub const NUM_HINT_SPACE_PROVIDER_COLS: usize = size_of:: BaseAirWithPublicValues for HintSpaceProviderAir {} @@ -95,8 +101,8 @@ impl Air for HintSpaceProviderAir { // Uniqueness of (hint_id, offset) among non-padding rows. // Rows are sorted by (hint_id, offset). For consecutive non-padding rows: // - Same hint_id (hint_id_changed=0): offset must increase by exactly 1. - // - Different hint_id (hint_id_changed=1): hint_id must actually differ - // (proven via inverse), and the new block starts at offset 0. + // - Different hint_id (hint_id_changed=1): hint_id strictly increases + // (proven via IsLtSubAir range-check), and the new block starts at offset 0. let d_id: AB::Expr = next.hint_id - curr.hint_id; // hint_id_changed = 0 => same hint_id, offset increases by 1 @@ -104,20 +110,27 @@ impl Air for HintSpaceProviderAir { .when_transition() .when(curr.both_non_padding) .when_ne(curr.hint_id_changed, AB::Expr::ONE) - .assert_zero(d_id.clone()); + .assert_zero(d_id); builder .when_transition() .when(curr.both_non_padding) .when_ne(curr.hint_id_changed, AB::Expr::ONE) .assert_eq(next.offset, curr.offset + AB::Expr::ONE); - // Combined inverse check: d_id * diff_hint_id_inv = hint_id_changed. - // When hint_id_changed = 0: trivially 0 = 0 (since d_id = 0 from above). - // When hint_id_changed = 1: proves d_id != 0 (hint_id actually changed). - builder - .when_transition() - .when(curr.both_non_padding) - .assert_eq(d_id * curr.diff_hint_id_inv, curr.hint_id_changed); + // hint_id_changed = 1 => hint_id strictly increases (curr.hint_id < next.hint_id) + let lt_count: AB::Expr = curr.hint_id_changed.into() * curr.both_non_padding.into(); + self.lt_air.eval( + builder, + ( + IsLessThanIo { + x: curr.hint_id.into(), + y: next.hint_id.into(), + out: curr.hint_id_changed.into(), + count: lt_count, + }, + &curr.hint_id_lt_aux, + ), + ); // hint_id_changed = 1 => new block starts at offset 0 builder @@ -138,6 +151,7 @@ impl Air for HintSpaceProviderAir { pub struct HintSpaceProviderChip { pub air: HintSpaceProviderAir, + range_checker: SharedVariableRangeCheckerChip, /// Maps (hint_id, offset) -> (value, multiplicity). /// Deduplicates keys and tracks how many times each is looked up. data: Mutex>, @@ -146,9 +160,21 @@ pub struct HintSpaceProviderChip { pub type SharedHintSpaceProviderChip = Arc>; impl HintSpaceProviderChip { - pub fn new(hint_bus: HintBus) -> Self { + pub fn new( + hint_bus: HintBus, + range_checker: SharedVariableRangeCheckerChip, + hint_id_max_bits: usize, + ) -> Self { + let lt_air = IsLtSubAir::new(range_checker.bus(), hint_id_max_bits); + assert_eq!( + lt_air.decomp_limbs, HINT_ID_LT_AUX_LEN, + "hint_id_max_bits={hint_id_max_bits} with range_max_bits={} requires {} limbs, but HINT_ID_LT_AUX_LEN={HINT_ID_LT_AUX_LEN}", + range_checker.range_max_bits(), + lt_air.decomp_limbs + ); Self { - air: HintSpaceProviderAir { hint_bus }, + air: HintSpaceProviderAir { hint_bus, lt_air }, + range_checker, data: Mutex::new(HashMap::new()), } } @@ -200,10 +226,16 @@ impl HintSpaceProviderChip { // Fill auxiliary columns for the uniqueness constraint. if n + 1 < num_non_padding_rows { let next_hint_id = entries[n + 1].0 .0; - let d_id = next_hint_id - *hint_id; - if d_id != F::ZERO { - cols.hint_id_changed = F::ONE; - cols.diff_hint_id_inv = d_id.try_inverse().unwrap(); + if next_hint_id != *hint_id { + // hint_id changes: fill IsLtSubAir aux columns + self.air.lt_air.generate_subrow( + ( + self.range_checker.as_ref(), + hint_id.as_canonical_u32(), + next_hint_id.as_canonical_u32(), + ), + (&mut cols.hint_id_lt_aux, &mut cols.hint_id_changed), + ); } else { debug_assert_eq!( entries[n + 1].0 .1, @@ -213,7 +245,7 @@ impl HintSpaceProviderChip { offset, entries[n + 1].0 .1 ); - // hint_id_changed = 0, diff_hint_id_inv = 0 (defaults) + // hint_id_changed = 0, hint_id_lt_aux = [0; ..] (defaults) } } // Last non-padding row: aux columns stay zero (no next non-padding row to compare) From 23c956d672b9d4aa1c7d8fc8117f9b980a053efc Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Sun, 22 Mar 2026 21:11:18 -0700 Subject: [PATCH 18/18] Allow Reading from Hint Space for `MULTI_OBSERVE` (#41) * rebase hint multi observe * adjust degree * fix * debug * adjust * adjust cuda * fix cuda * fix cuda * fix cuda * add debug utilities * add debug flags * remove debug flags * remove debug flag * remove debug flag --- crates/circuits/mod-builder/src/utils.rs | 2 +- .../circuit/cuda/include/native/poseidon2.cuh | 15 +- .../native/circuit/cuda/src/poseidon2.cu | 43 ++++- .../native/circuit/src/extension/cuda.rs | 20 +- .../native/circuit/src/extension/mod.rs | 36 ++-- extensions/native/circuit/src/fri/cuda.rs | 4 +- .../native/circuit/src/jal_rangecheck/cuda.rs | 4 +- .../native/circuit/src/poseidon2/air.rs | 132 ++++++++++--- .../native/circuit/src/poseidon2/chip.rs | 182 ++++++++++++++---- .../native/circuit/src/poseidon2/columns.rs | 30 ++- .../native/circuit/src/poseidon2/cuda.rs | 112 ++++++++++- .../native/circuit/src/poseidon2/execution.rs | 40 +++- .../native/compiler/src/asm/compiler.rs | 6 +- .../native/compiler/src/asm/instruction.rs | 11 +- .../native/compiler/src/conversion/mod.rs | 6 +- .../native/compiler/src/ir/instructions.rs | 13 +- extensions/native/compiler/src/ir/poseidon.rs | 33 +++- .../native/recursion/src/challenger/duplex.rs | 2 +- 18 files changed, 534 insertions(+), 157 deletions(-) diff --git a/crates/circuits/mod-builder/src/utils.rs b/crates/circuits/mod-builder/src/utils.rs index 2f2561ba87..f8dcd948b2 100644 --- a/crates/circuits/mod-builder/src/utils.rs +++ b/crates/circuits/mod-builder/src/utils.rs @@ -11,4 +11,4 @@ pub fn biguint_to_limbs_vec(x: &BigUint, num_limbs: usize) -> Vec { .chain(std::iter::repeat(0u8)) .take(num_limbs) .collect() -} +} \ No newline at end of file diff --git a/extensions/native/circuit/cuda/include/native/poseidon2.cuh b/extensions/native/circuit/cuda/include/native/poseidon2.cuh index 206c0e16c0..b794aacaef 100644 --- a/extensions/native/circuit/cuda/include/native/poseidon2.cuh +++ b/extensions/native/circuit/cuda/include/native/poseidon2.cuh @@ -65,14 +65,17 @@ template struct SimplePoseidonSpecificCols { template struct MultiObserveCols { T pc; T final_timestamp_increment; + T state_ptr_register; + T ctx_register; + T input_ptr_register; + T hint_id_register; T state_ptr; + T ctx_ptr; T input_ptr; - T init_pos; - T len; - T input_register_1; - T input_register_2; - T input_register_3; - T output_register; + T hint_id; + T ctx[4]; + MemoryReadAuxCols read_ctx; + T chunk_ts_count; T is_first; T is_last; T curr_len; diff --git a/extensions/native/circuit/cuda/src/poseidon2.cu b/extensions/native/circuit/cuda/src/poseidon2.cu index fdbe0d3ce5..32c0d36ec9 100644 --- a/extensions/native/circuit/cuda/src/poseidon2.cu +++ b/extensions/native/circuit/cuda/src/poseidon2.cu @@ -24,6 +24,7 @@ template struct NativePoseidon2Cols { T inside_row; T simple; T multi_observe_row; + T not_hint_multi_observe; T end_inside_row; T end_top_level; @@ -355,31 +356,57 @@ template struct Poseidon2Wrapper { if (specific[COL_INDEX(MultiObserveCols, is_first)] == Fp::one()) { uint32_t very_start_timestamp = row[COL_INDEX(Cols, very_first_timestamp)].asUInt32(); - for (uint32_t i = 0; i < 4; ++i) { + for (uint32_t i = 0; i < 3; ++i) { mem_fill_base( mem_helper, very_start_timestamp + i, specific.slice_from(COL_INDEX(MultiObserveCols, read_data[i].base)) ); } + mem_fill_base( + mem_helper, + very_start_timestamp + 3, + specific.slice_from(COL_INDEX(MultiObserveCols, read_ctx.base)) + ); + mem_fill_base( + mem_helper, + very_start_timestamp + 4, + specific.slice_from(COL_INDEX(MultiObserveCols, read_data[3].base)) + ); + + // Zero-length MULTI_OBSERVE case: head row is both first and last. + // The final ctx[0] writeback lives at row.start_timestamp. + if (specific[COL_INDEX(MultiObserveCols, is_last)] == Fp::one()) { + uint32_t start_timestamp = row[COL_INDEX(Cols, start_timestamp)].asUInt32(); + mem_fill_base( + mem_helper, + start_timestamp, + specific.slice_from(COL_INDEX(MultiObserveCols, write_final_idx.base)) + ); + } } else { uint32_t start_timestamp = row[COL_INDEX(Cols, start_timestamp)].asUInt32(); uint32_t chunk_start = specific[COL_INDEX(MultiObserveCols, start_idx)].asUInt32(); uint32_t chunk_end = specific[COL_INDEX(MultiObserveCols, end_idx)].asUInt32(); + uint32_t is_hint = + specific[COL_INDEX(MultiObserveCols, ctx[2])].asUInt32(); + uint32_t ts_per_element = 2 - is_hint; for (uint32_t j = chunk_start; j < chunk_end; ++j) { + if (!is_hint) { + mem_fill_base( + mem_helper, + start_timestamp, + specific.slice_from(COL_INDEX(MultiObserveCols, read_data[j].base)) + ); + } mem_fill_base( mem_helper, - start_timestamp, - specific.slice_from(COL_INDEX(MultiObserveCols, read_data[j].base)) - ); - mem_fill_base( - mem_helper, - start_timestamp + 1, + start_timestamp + (1 - is_hint), specific.slice_from(COL_INDEX(MultiObserveCols, write_data[j].base)) ); - start_timestamp += 2; + start_timestamp += ts_per_element; } if (chunk_end >= CHUNK) { mem_fill_base( diff --git a/extensions/native/circuit/src/extension/cuda.rs b/extensions/native/circuit/src/extension/cuda.rs index 50a9cba86d..0d476413ac 100644 --- a/extensions/native/circuit/src/extension/cuda.rs +++ b/extensions/native/circuit/src/extension/cuda.rs @@ -75,19 +75,29 @@ impl VmProverExtension FriReducedOpeningChipGpu::new(range_checker.clone(), timestamp_max_bits); inventory.add_executor_chip(fri_reduced_opening); - inventory.next_air::>()?; - let poseidon2 = NativePoseidon2ChipGpu::<1>::new(range_checker.clone(), timestamp_max_bits); - inventory.add_executor_chip(poseidon2); - let hint_air: &HintSpaceProviderAir = inventory.next_air::()?; + let cpu_range_checker = range_checker + .cpu_chip + .clone() + .expect("VariableRangeCheckerChipGPU is expected to be hybrid with cpu_chip"); let cpu_chip = Arc::new(HintSpaceProviderChip::new( hint_air.hint_bus, - range_checker.clone(), + cpu_range_checker, timestamp_max_bits, )); + let provider_gpu = HintSpaceProviderChipGpu::new(cpu_chip.clone()); inventory.add_periphery_chip(provider_gpu); + inventory.next_air::>()?; + + let poseidon2 = NativePoseidon2ChipGpu::<1>::new_with_hint_space_provider( + range_checker.clone(), + timestamp_max_bits, + cpu_chip.clone(), + ); + inventory.add_executor_chip(poseidon2); + inventory.next_air::()?; let sumcheck = NativeSumcheckChipGpu::new(range_checker.clone(), timestamp_max_bits, cpu_chip); diff --git a/extensions/native/circuit/src/extension/mod.rs b/extensions/native/circuit/src/extension/mod.rs index 924d4927e8..4a930ddafe 100644 --- a/extensions/native/circuit/src/extension/mod.rs +++ b/extensions/native/circuit/src/extension/mod.rs @@ -271,15 +271,6 @@ where ); inventory.add_air(fri_reduced_opening); - let verify_batch = NativePoseidon2Air::<_, 1>::new( - exec_bridge, - memory_bridge, - hint_bridge, - VerifyBatchBus::new(inventory.new_bus_idx()), - Poseidon2Config::default(), - ); - inventory.add_air(verify_batch); - let hint_space_provider = HintSpaceProviderAir { hint_bus: hint_bridge.hint_bus(), lt_air: IsLtSubAir::new( @@ -289,6 +280,15 @@ where }; inventory.add_air(hint_space_provider); + let verify_batch = NativePoseidon2Air::<_, 1>::new( + exec_bridge, + memory_bridge, + hint_bridge, + VerifyBatchBus::new(inventory.new_bus_idx()), + Poseidon2Config::default(), + ); + inventory.add_air(verify_batch); + let tower_evaluate = NativeSumcheckAir::new(exec_bridge, memory_bridge, hint_bridge); inventory.add_air(tower_evaluate); @@ -365,13 +365,6 @@ where FriReducedOpeningChip::new(FriReducedOpeningFiller::new(), mem_helper.clone()); inventory.add_executor_chip(fri_reduced_opening); - inventory.next_air::, 1>>()?; - let poseidon2 = NativePoseidon2Chip::<_, 1>::new( - NativePoseidon2Filler::new(Poseidon2Config::default()), - mem_helper.clone(), - ); - inventory.add_executor_chip(poseidon2); - let hint_bus = inventory.airs().system().hint_bridge.hint_bus(); let hint_space_provider = Arc::new(HintSpaceProviderChip::new( hint_bus, @@ -382,8 +375,17 @@ where inventory.next_air::()?; inventory.add_periphery_chip(hint_space_provider.clone()); + inventory.next_air::, 1>>()?; + + let poseidon2 = NativePoseidon2Chip::<_, 1>::new( + NativePoseidon2Filler::new(Poseidon2Config::default(), hint_space_provider.clone()), + mem_helper.clone(), + ); + inventory.add_executor_chip(poseidon2); + + inventory.next_air::()?; let tower_verify = NativeSumcheckChip::new( - NativeSumcheckFiller::new(hint_space_provider), + NativeSumcheckFiller::new(hint_space_provider.clone()), mem_helper.clone(), ); inventory.add_executor_chip(tower_verify); diff --git a/extensions/native/circuit/src/fri/cuda.rs b/extensions/native/circuit/src/fri/cuda.rs index 06f3e91180..4f4ef59e60 100644 --- a/extensions/native/circuit/src/fri/cuda.rs +++ b/extensions/native/circuit/src/fri/cuda.rs @@ -13,7 +13,9 @@ use openvm_cuda_common::copy::MemCopyH2D; use openvm_stark_backend::{prover::types::AirProvingContext, Chip}; use super::{FriReducedOpeningRecordMut, OVERALL_WIDTH}; -use crate::cuda_abi::fri_cuda; +use crate::{ + cuda_abi::fri_cuda, +}; #[derive(new)] pub struct FriReducedOpeningChipGpu { diff --git a/extensions/native/circuit/src/jal_rangecheck/cuda.rs b/extensions/native/circuit/src/jal_rangecheck/cuda.rs index a273a1b443..7ba9fa198b 100644 --- a/extensions/native/circuit/src/jal_rangecheck/cuda.rs +++ b/extensions/native/circuit/src/jal_rangecheck/cuda.rs @@ -10,7 +10,9 @@ use openvm_cuda_common::copy::MemCopyH2D; use openvm_stark_backend::{prover::types::AirProvingContext, Chip}; use super::{JalRangeCheckCols, JalRangeCheckRecord}; -use crate::cuda_abi::native_jal_rangecheck_cuda; +use crate::{ + cuda_abi::native_jal_rangecheck_cuda, +}; #[derive(new)] pub struct JalRangeCheckGpu { diff --git a/extensions/native/circuit/src/poseidon2/air.rs b/extensions/native/circuit/src/poseidon2/air.rs index 9e9cdf5ce8..8b3b02ebc2 100644 --- a/extensions/native/circuit/src/poseidon2/air.rs +++ b/extensions/native/circuit/src/poseidon2/air.rs @@ -94,6 +94,7 @@ impl Air inside_row, simple, multi_observe_row, + not_hint_multi_observe, end_inside_row, end_top_level, start_top_level, @@ -713,10 +714,17 @@ impl Air let &MultiObserveCols { pc, final_timestamp_increment, + state_ptr_register, + ctx_register, + input_ptr_register, + hint_id_register, state_ptr, + ctx_ptr, input_ptr, - init_pos, - len, + hint_id, + ctx, + read_ctx, + chunk_ts_count, is_first, is_last, curr_len, @@ -731,26 +739,38 @@ impl Air should_permute, write_sponge_state, write_final_idx, - input_register_1, - input_register_2, - input_register_3, - output_register, } = multi_observe_specific; + // Alias context values + let init_pos = ctx[0]; + let len = ctx[1]; + let is_hint = ctx[2]; + builder.when(multi_observe_row).assert_bool(is_first); builder.when(multi_observe_row).assert_bool(is_last); builder.when(multi_observe_row).assert_bool(should_permute); + builder.when(multi_observe_row).assert_bool(is_hint); + builder.assert_eq( + not_hint_multi_observe, + multi_observe_row * (AB::Expr::ONE - is_hint), + ); + let hint_multi_observe: AB::Expr = multi_observe_row - not_hint_multi_observe; + // chunk_ts_count = (end_idx - start_idx) * (2 - is_hint) + builder.when(multi_observe_row).assert_eq( + chunk_ts_count, + (end_idx - start_idx) * AB::F::TWO - (end_idx - start_idx) * is_hint, + ); self.execution_bridge .execute_and_increment_pc( AB::F::from_canonical_usize(MULTI_OBSERVE.global_opcode().as_usize()), [ - output_register.into(), - input_register_1.into(), - input_register_2.into(), + state_ptr_register.into(), + ctx_register.into(), + input_ptr_register.into(), self.address_space.into(), self.address_space.into(), - input_register_3.into(), + hint_id_register.into(), ], ExecutionState::new(pc, very_first_timestamp), final_timestamp_increment, @@ -759,7 +779,7 @@ impl Air self.memory_bridge .read( - MemoryAddress::new(self.address_space, output_register), + MemoryAddress::new(self.address_space, state_ptr_register), [state_ptr], very_first_timestamp, &read_data[0], @@ -768,8 +788,8 @@ impl Air self.memory_bridge .read( - MemoryAddress::new(self.address_space, input_register_1), - [init_pos], + MemoryAddress::new(self.address_space, ctx_register), + [ctx_ptr], very_first_timestamp + AB::F::ONE, &read_data[1], ) @@ -777,24 +797,47 @@ impl Air self.memory_bridge .read( - MemoryAddress::new(self.address_space, input_register_2), + MemoryAddress::new(self.address_space, input_ptr_register), [input_ptr], very_first_timestamp + AB::F::TWO, &read_data[2], ) .eval(builder, multi_observe_row * is_first); + // Read context array: [init_pos, len, is_hint, reserved] from ctx_ptr self.memory_bridge .read( - MemoryAddress::new(self.address_space, input_register_3), - [len], + MemoryAddress::new(self.address_space, ctx_ptr), + ctx, very_first_timestamp + AB::F::from_canonical_usize(3), + &read_ctx, + ) + .eval(builder, multi_observe_row * is_first); + + // Read hint_id from register (reuse spare read_data[3] on head row) + self.memory_bridge + .read( + MemoryAddress::new(self.address_space, hint_id_register), + [hint_id], + very_first_timestamp + AB::F::from_canonical_usize(4), &read_data[3], ) .eval(builder, multi_observe_row * is_first); + // Per-element constraints for chunk rows. for i in 0..CHUNK { let i_var = AB::F::from_canonical_usize(i); + + // Hint mode: lookup from hint space. + self.hint_bridge.lookup( + builder, + hint_id, + curr_len + i_var - start_idx, + data[i], + hint_multi_observe.clone() * aux_read_enabled[i], + ); + + // Non-hint mode: read from memory. self.memory_bridge .read( MemoryAddress::new( @@ -805,8 +848,7 @@ impl Air start_timestamp + i_var * AB::F::TWO - start_idx * AB::F::TWO, &read_data[i], ) - .eval(builder, multi_observe_row * aux_read_enabled[i]); - + .eval(builder, not_hint_multi_observe * aux_read_enabled[i]); self.memory_bridge .write( MemoryAddress::new(self.address_space, state_ptr + i_var), @@ -814,7 +856,16 @@ impl Air start_timestamp + i_var * AB::F::TWO - start_idx * AB::F::TWO + AB::F::ONE, &write_data[i], ) - .eval(builder, multi_observe_row * aux_read_enabled[i]); + .eval(builder, not_hint_multi_observe * aux_read_enabled[i]); + + self.memory_bridge + .write( + MemoryAddress::new(self.address_space, state_ptr + i_var), + [data[i]], + start_timestamp + i_var - start_idx, + &write_data[i], + ) + .eval(builder, hint_multi_observe.clone() * aux_read_enabled[i]); } for i in 0..(CHUNK - 1) { @@ -885,7 +936,7 @@ impl Air .write( MemoryAddress::new(self.address_space, state_ptr), full_sponge_output, - start_timestamp + (end_idx - start_idx) * AB::F::TWO, + start_timestamp + chunk_ts_count, &write_sponge_state, ) .eval(builder, multi_observe_row * should_permute); @@ -909,11 +960,12 @@ impl Air // final_idx = aux_read_enabled[CHUNK-1] * 0 + (1 - aux_read_enabled[CHUNK-1]) * end_idx let final_idx = aux_read_enabled[CHUNK - 1] * AB::Expr::ZERO + (AB::Expr::ONE - aux_read_enabled[CHUNK - 1]) * end_idx; + // Write final_idx back to ctx[0] (ctx_ptr address) self.memory_bridge .write( - MemoryAddress::new(self.address_space, input_register_1), + MemoryAddress::new(self.address_space, ctx_ptr), [final_idx], - start_timestamp + (end_idx - start_idx) * AB::F::TWO + should_permute, + start_timestamp + chunk_ts_count + should_permute, &write_final_idx, ) .eval(builder, multi_observe_row * is_last); @@ -962,41 +1014,59 @@ impl Air builder .when(next.multi_observe_row) .when(not(next_multi_observe_specific.is_first)) - .assert_eq(init_pos, next_multi_observe_specific.init_pos); + .assert_eq(init_pos, next_multi_observe_specific.ctx[0]); + + builder + .when(next.multi_observe_row) + .when(not(next_multi_observe_specific.is_first)) + .assert_eq(len, next_multi_observe_specific.ctx[1]); builder .when(next.multi_observe_row) .when(not(next_multi_observe_specific.is_first)) - .assert_eq(len, next_multi_observe_specific.len); + .assert_eq( + state_ptr_register, + next_multi_observe_specific.state_ptr_register, + ); builder .when(next.multi_observe_row) .when(not(next_multi_observe_specific.is_first)) .assert_eq( - input_register_1, - next_multi_observe_specific.input_register_1, + ctx_register, + next_multi_observe_specific.ctx_register, ); builder .when(next.multi_observe_row) .when(not(next_multi_observe_specific.is_first)) .assert_eq( - input_register_2, - next_multi_observe_specific.input_register_2, + input_ptr_register, + next_multi_observe_specific.input_ptr_register, ); + builder + .when(next.multi_observe_row) + .when(not(next_multi_observe_specific.is_first)) + .assert_eq(ctx_ptr, next_multi_observe_specific.ctx_ptr); + + builder + .when(next.multi_observe_row) + .when(not(next_multi_observe_specific.is_first)) + .assert_eq(hint_id, next_multi_observe_specific.hint_id); + builder .when(next.multi_observe_row) .when(not(next_multi_observe_specific.is_first)) .assert_eq( - input_register_3, - next_multi_observe_specific.input_register_3, + hint_id_register, + next_multi_observe_specific.hint_id_register, ); builder .when(next.multi_observe_row) .when(not(next_multi_observe_specific.is_first)) - .assert_eq(output_register, next_multi_observe_specific.output_register); + .assert_eq(is_hint, next_multi_observe_specific.ctx[2]); // Timestamp constraints builder diff --git a/extensions/native/circuit/src/poseidon2/chip.rs b/extensions/native/circuit/src/poseidon2/chip.rs index 770efc7307..e136c2fbb0 100644 --- a/extensions/native/circuit/src/poseidon2/chip.rs +++ b/extensions/native/circuit/src/poseidon2/chip.rs @@ -1,4 +1,5 @@ use std::borrow::{Borrow, BorrowMut}; +use std::sync::Arc; use openvm_circuit::{ arch::*, @@ -22,6 +23,7 @@ use openvm_stark_backend::{ }; use crate::{ + hint_space_provider::HintSpaceProviderChip, mem_fill_helper, poseidon2::{ columns::{ @@ -45,6 +47,7 @@ pub struct NativePoseidon2Filler { // pre-computed Poseidon2 sub cols for dummy rows. empty_poseidon2_sub_cols: Vec, pub(super) subchip: Poseidon2SubChip, + pub hint_space_provider: Arc>, } impl NativePoseidon2Executor { @@ -71,12 +74,16 @@ pub(crate) fn compress( } impl NativePoseidon2Filler { - pub fn new(poseidon2_config: Poseidon2Config) -> Self { + pub fn new( + poseidon2_config: Poseidon2Config, + hint_space_provider: Arc>, + ) -> Self { let subchip = Poseidon2SubChip::new(poseidon2_config.constants); let empty_poseidon2_sub_cols = subchip.generate_trace(vec![[F::ZERO; CHUNK * 2]]).values; Self { empty_poseidon2_sub_cols, subchip, + hint_space_provider, } } } @@ -649,11 +656,11 @@ where } else if instruction.opcode == MULTI_OBSERVE.global_opcode() { let &Instruction { a: state_ptr_register, - b: init_pos_register, + b: ctx_register, c: input_ptr_register, d: register_address_space, e: data_address_space, - f: len_register, + f: hint_id_register, .. } = instruction; @@ -663,31 +670,52 @@ where ); assert_eq!(data_address_space, F::from_canonical_u32(AS::Native as u32)); - let [init_pos]: [F; 1] = - memory_read_native(state.memory.data(), init_pos_register.as_canonical_u32()); - let [input_len]: [F; 1] = - memory_read_native(state.memory.data(), len_register.as_canonical_u32()); + // Read ctx_ptr from register, then read context array from memory + let [ctx_ptr]: [F; 1] = + memory_read_native(state.memory.data(), ctx_register.as_canonical_u32()); + let ctx: [F; 4] = + memory_read_native(state.memory.data(), ctx_ptr.as_canonical_u32()); + let init_pos = ctx[0]; + let input_len = ctx[1]; + let is_hint = ctx[2].as_canonical_u32() != 0; + + // Read hint_id from register + let [hint_id]: [F; 1] = + memory_read_native(state.memory.data(), hint_id_register.as_canonical_u32()); + + // Get hint_space data if in hint mode + let hint_data: Vec = if is_hint { + state.streams.hint_space[hint_id.as_canonical_u32() as usize].clone() + } else { + vec![] + }; let mut len = input_len.as_canonical_u32() as usize; let mut pos = init_pos.as_canonical_u32() as usize; let mut chunks: Vec<(usize, usize)> = vec![]; - const NUM_HEAD_ACCESSES: usize = 4; + // 3 register reads + 1 context array read + 1 hint_id register read = 5 head accesses + const NUM_HEAD_ACCESSES: usize = 5; let mut final_timestamp_inc = NUM_HEAD_ACCESSES; + // In hint mode: 1 timestamp per element (write only) + // In non-hint mode: 2 timestamps per element (read + write) + let ts_per_element = if is_hint { 1 } else { 2 }; while len > 0 { if len >= (CHUNK - pos) { chunks.push((pos, CHUNK)); len -= CHUNK - pos; - final_timestamp_inc += 2 * (CHUNK - pos) + 1; + final_timestamp_inc += ts_per_element * (CHUNK - pos) + 1; pos = 0; } else { chunks.push((pos, pos + len)); - final_timestamp_inc += 2 * len; + final_timestamp_inc += ts_per_element * len; len = 0; pos += len; } } - final_timestamp_inc += 1; // write back to init_pos_register + // Final ctx[0] writeback always happens (including zero-length input + // where the head row is both the first and last row). + final_timestamp_inc += 1; let allocated_rows = arena .alloc(MultiRowLayout::new(NativePoseidon2Metadata { @@ -698,14 +726,15 @@ where let head_multi_observe_cols: &mut MultiObserveCols = head_cols.specific[..MultiObserveCols::::width()].borrow_mut(); + // 3 register reads: state_ptr, ctx_ptr, input_ptr let [state_ptr]: [F; 1] = tracing_read_native_helper( state.memory, state_ptr_register.as_canonical_u32(), head_multi_observe_cols.read_data[0].as_mut(), ); - let [init_pos]: [F; 1] = tracing_read_native_helper( + let [ctx_ptr]: [F; 1] = tracing_read_native_helper( state.memory, - init_pos_register.as_canonical_u32(), + ctx_register.as_canonical_u32(), head_multi_observe_cols.read_data[1].as_mut(), ); let [input_ptr]: [F; 1] = tracing_read_native_helper( @@ -713,9 +742,16 @@ where input_ptr_register.as_canonical_u32(), head_multi_observe_cols.read_data[2].as_mut(), ); - let [input_len]: [F; 1] = tracing_read_native_helper( + // 1 context array read: [init_pos, len, is_hint, reserved] + let ctx: [F; 4] = tracing_read_native_helper( + state.memory, + ctx_ptr.as_canonical_u32(), + head_multi_observe_cols.read_ctx.as_mut(), + ); + // 1 hint_id register read (reuse spare read_data[3] on head row) + let [hint_id]: [F; 1] = tracing_read_native_helper( state.memory, - len_register.as_canonical_u32(), + hint_id_register.as_canonical_u32(), head_multi_observe_cols.read_data[3].as_mut(), ); @@ -727,16 +763,20 @@ where for (i, cols) in allocated_rows.iter_mut().enumerate() { let multi_observe_cols: &mut MultiObserveCols = cols.specific[..MultiObserveCols::::width()].borrow_mut(); - multi_observe_cols.input_register_1 = init_pos_register; - multi_observe_cols.input_register_2 = input_ptr_register; - multi_observe_cols.input_register_3 = len_register; - multi_observe_cols.output_register = state_ptr_register; - multi_observe_cols.init_pos = init_pos; - multi_observe_cols.input_ptr = input_ptr; + multi_observe_cols.state_ptr_register = state_ptr_register; + multi_observe_cols.ctx_register = ctx_register; + multi_observe_cols.input_ptr_register = input_ptr_register; + multi_observe_cols.hint_id_register = hint_id_register; multi_observe_cols.state_ptr = state_ptr; - multi_observe_cols.len = input_len; + multi_observe_cols.ctx_ptr = ctx_ptr; + multi_observe_cols.input_ptr = input_ptr; + multi_observe_cols.hint_id = hint_id; + multi_observe_cols.ctx = ctx; + + // chunk_ts_count will be filled per-chunk row below cols.multi_observe_row = F::ONE; + cols.not_hint_multi_observe = if is_hint { F::ZERO } else { F::ONE }; cols.very_first_timestamp = init_timestamp; if i == 0 { @@ -746,9 +786,26 @@ where multi_observe_cols.final_timestamp_increment = F::from_canonical_usize(final_timestamp_inc); multi_observe_cols.is_first = F::ONE; - multi_observe_cols.is_last = F::ZERO; + multi_observe_cols.is_last = if chunks.is_empty() { F::ONE } else { F::ZERO }; multi_observe_cols.curr_len = F::ZERO; multi_observe_cols.should_permute = F::ZERO; + if chunks.is_empty() { + // Zero-length input: head row is both first and last. + // Set start_timestamp to right after the 5 head reads, + // and write back init_pos (unchanged) to ctx_ptr[0]. + cols.start_timestamp = F::from_canonical_u32( + init_timestamp_u32 + NUM_HEAD_ACCESSES as u32, + ); + multi_observe_cols.start_idx = init_pos; + multi_observe_cols.end_idx = init_pos; + // state.memory.timestamp == init_ts + NUM_HEAD_ACCESSES here. + tracing_write_native_inplace( + state.memory, + ctx_ptr.as_canonical_u32(), + [init_pos], + &mut multi_observe_cols.write_final_idx, + ); + } } } @@ -767,6 +824,7 @@ where multi_observe_cols.start_idx = F::from_canonical_usize(chunk_start); multi_observe_cols.end_idx = F::from_canonical_usize(chunk_end); + multi_observe_cols.chunk_ts_count = F::from_canonical_usize((chunk_end - chunk_start) * ts_per_element); multi_observe_cols.is_first = F::ZERO; multi_observe_cols.is_last = if i == num_chunks - 1 { F::ONE } else { F::ZERO }; @@ -779,21 +837,29 @@ where multi_observe_cols.aux_before_end[j] = F::ONE; } for j in chunk_start..chunk_end { - let n_f: [F; 1] = tracing_read_native_helper( - state.memory, - input_ptr_u32 + input_idx as u32, - multi_observe_cols.read_data[j].as_mut(), - ); + let n_f: F = if is_hint { + // In hint mode: read from hint_space + hint_data[input_idx] + } else { + // In non-hint mode: read from memory via tracing read + let [v]: [F; 1] = tracing_read_native_helper( + state.memory, + input_ptr_u32 + input_idx as u32, + multi_observe_cols.read_data[j].as_mut(), + ); + v + }; + multi_observe_cols.aux_read_enabled[j] = F::ONE; tracing_write_native_inplace( state.memory, state_ptr_u32 + j as u32, - n_f, + [n_f], &mut multi_observe_cols.write_data[j], ); - multi_observe_cols.data[j] = n_f[0]; + multi_observe_cols.data[j] = n_f; input_idx += 1; - cur_timestamp += 2; + cur_timestamp += ts_per_element as u32; } let permutation_input: [F; 16] = @@ -817,7 +883,7 @@ where let final_idx = F::from_canonical_usize(chunk_end % CHUNK); tracing_write_native_inplace( state.memory, - init_pos_register.as_canonical_u32(), + ctx_ptr.as_canonical_u32(), [final_idx], &mut multi_observe_cols.write_final_idx, ); @@ -1161,7 +1227,7 @@ impl NativePoseidon2Filler::width()].borrow_mut(); let start_timestamp_u32 = head_cols.very_first_timestamp.as_canonical_u32(); - // state_ptr, init_pos, input_ptr, len + // 3 register reads: state_ptr, ctx_ptr, input_ptr mem_fill_helper( mem_helper, start_timestamp_u32, @@ -1177,12 +1243,23 @@ impl NativePoseidon2Filler = chunk_slice @@ -1194,6 +1271,8 @@ impl NativePoseidon2Filler = chunk_slice[row_idx * width..(row_idx + 1) * width].borrow_mut(); @@ -1205,18 +1284,32 @@ impl NativePoseidon2Filler= CHUNK as u32 { @@ -1235,6 +1328,15 @@ impl NativePoseidon2Filler = + chunk_slice[..width].borrow_mut(); + let head_mo: &mut MultiObserveCols = + head_c.specific[..MultiObserveCols::::width()].borrow_mut(); + let head_ts = head_c.start_timestamp.as_canonical_u32(); + mem_fill_helper(mem_helper, head_ts, head_mo.write_final_idx.as_mut()); + } } #[inline(always)] diff --git a/extensions/native/circuit/src/poseidon2/columns.rs b/extensions/native/circuit/src/poseidon2/columns.rs index abb8db54a2..67557a5f73 100644 --- a/extensions/native/circuit/src/poseidon2/columns.rs +++ b/extensions/native/circuit/src/poseidon2/columns.rs @@ -31,6 +31,10 @@ pub struct NativePoseidon2Cols { /// Indicates that this row is a multi_observe row. pub multi_observe_row: T, + /// Materialized column: multi_observe_row * (1 - is_hint). + /// Lives in main cols (not overlaid specific) so it is 0 on non-multi_observe rows. + pub not_hint_multi_observe: T, + /// Indicates the last row in an inside-row block. pub end_inside_row: T, /// Indicates the last row in a top-level block. @@ -211,16 +215,24 @@ pub struct MultiObserveCols { pub pc: T, pub final_timestamp_increment: T, - // Initial reads from registers - // They are same across same instance of multi_observe + // Register addresses + pub state_ptr_register: T, + pub ctx_register: T, + pub input_ptr_register: T, + pub hint_id_register: T, + + // Values read from registers pub state_ptr: T, + pub ctx_ptr: T, pub input_ptr: T, - pub init_pos: T, - pub len: T, - pub input_register_1: T, - pub input_register_2: T, - pub input_register_3: T, - pub output_register: T, + pub hint_id: T, + + // Context array values read from ctx_ptr + // ctx[0] = init_pos, ctx[1] = len, ctx[2] = is_hint, ctx[3] = reserved + pub ctx: [T; 4], + pub read_ctx: MemoryReadAuxCols, + + pub chunk_ts_count: T, pub is_first: T, pub is_last: T, @@ -240,6 +252,6 @@ pub struct MultiObserveCols { pub should_permute: T, pub write_sponge_state: MemoryWriteAuxCols, - // Final write back and registers + // Final write back to ctx[0] pub write_final_idx: MemoryWriteAuxCols, } diff --git a/extensions/native/circuit/src/poseidon2/cuda.rs b/extensions/native/circuit/src/poseidon2/cuda.rs index 0425cfda18..4bdf337a75 100644 --- a/extensions/native/circuit/src/poseidon2/cuda.rs +++ b/extensions/native/circuit/src/poseidon2/cuda.rs @@ -1,6 +1,5 @@ use std::{borrow::Borrow, mem::size_of, slice::from_raw_parts, sync::Arc}; -use derive_new::new; use openvm_circuit::{arch::DenseRecordArena, utils::next_power_of_two_or_zero}; use openvm_circuit_primitives::var_range::VariableRangeCheckerChipGPU; use openvm_cuda_backend::{ @@ -8,15 +7,115 @@ use openvm_cuda_backend::{ }; use openvm_cuda_common::copy::MemCopyH2D; use openvm_stark_backend::{prover::types::AirProvingContext, Chip}; -use p3_field::{Field, PrimeField32}; +use p3_field::{Field, FieldAlgebra, PrimeField32}; -use super::columns::NativePoseidon2Cols; -use crate::cuda_abi::poseidon2_cuda; +use super::columns::{MultiObserveCols, NativePoseidon2Cols}; +use crate::{ + cuda_abi::poseidon2_cuda, + hint_space_provider::SharedHintSpaceProviderChip, +}; -#[derive(new)] pub struct NativePoseidon2ChipGpu { pub range_checker: Arc, pub timestamp_max_bits: usize, + pub hint_space_provider: Option>, +} + +impl NativePoseidon2ChipGpu { + pub fn new(range_checker: Arc, timestamp_max_bits: usize) -> Self { + Self { + range_checker, + timestamp_max_bits, + hint_space_provider: None, + } + } + + pub fn new_with_hint_space_provider( + range_checker: Arc, + timestamp_max_bits: usize, + hint_space_provider: SharedHintSpaceProviderChip, + ) -> Self { + Self { + range_checker, + timestamp_max_bits, + hint_space_provider: Some(hint_space_provider), + } + } + + /// Scans multi-observe execution records to populate the hint provider with + /// (hint_id, offset, value) triples for hint-mode rows. + fn populate_hint_provider(&self, records: &[u8]) { + let Some(hint_space_provider) = &self.hint_space_provider else { + return; + }; + + let width = NativePoseidon2Cols::::width(); + let record_size = width * size_of::(); + if records.len() % record_size != 0 { + return; + } + let height = records.len() / record_size; + + let row_slice = unsafe { + let ptr = records.as_ptr() as *const F; + from_raw_parts(ptr, height * width) + }; + + let mut row_idx = 0; + while row_idx < height { + let start = row_idx * width; + let cols: &NativePoseidon2Cols = + row_slice[start..(start + width)].borrow(); + + if cols.multi_observe_row.is_one() { + let num_rows = cols.inner.export.as_canonical_u32() as usize; + if num_rows > 1 { + let head_multi_observe_cols: &MultiObserveCols = + cols.specific[..MultiObserveCols::::width()].borrow(); + let is_hint = head_multi_observe_cols.ctx[2] != F::ZERO; + if is_hint { + let hint_id = head_multi_observe_cols.hint_id; + for local_row in 1..num_rows { + let chunk_cols: &NativePoseidon2Cols = + row_slice[(row_idx + local_row) * width + ..(row_idx + local_row + 1) * width] + .borrow(); + let multi_observe_cols: &MultiObserveCols = chunk_cols.specific + [..MultiObserveCols::::width()] + .borrow(); + + let chunk_start = multi_observe_cols.start_idx.as_canonical_u32(); + let chunk_end = multi_observe_cols.end_idx.as_canonical_u32(); + let curr_len = multi_observe_cols.curr_len.as_canonical_u32(); + + for j in chunk_start..chunk_end { + let input_idx = curr_len + (j - chunk_start); + let val = multi_observe_cols.data[j as usize]; + hint_space_provider.request( + hint_id, + F::from_canonical_u32(input_idx), + val, + ); + } + } + } + } + row_idx += num_rows.max(1); + continue; + } + + if cols.simple.is_one() { + row_idx += 1; + } else { + let num_non_inside_row = cols.inner.export.as_canonical_u32() as usize; + let non_inside_start = start + (num_non_inside_row - 1) * width; + let last_non_inside_cols: &NativePoseidon2Cols = + row_slice[non_inside_start..(non_inside_start + width)].borrow(); + let total_num_row = last_non_inside_cols.inner.export.as_canonical_u32() as usize; + row_idx += total_num_row; + } + } + } } impl Chip @@ -28,6 +127,9 @@ impl Chip return get_empty_air_proving_ctx::(); } + // Populate hint space provider from multi-observe records before GPU upload. + self.populate_hint_provider(records); + // For Poseidon2, the records are already the trace rows // Use the columns width directly let width = NativePoseidon2Cols::::width(); diff --git a/extensions/native/circuit/src/poseidon2/execution.rs b/extensions/native/circuit/src/poseidon2/execution.rs index a0c1fc72a2..d41d911812 100644 --- a/extensions/native/circuit/src/poseidon2/execution.rs +++ b/extensions/native/circuit/src/poseidon2/execution.rs @@ -35,9 +35,9 @@ struct Pos2PreCompute<'a, F: Field, const SBOX_REGISTERS: usize> { #[repr(C)] struct MultiObservePreCompute<'a, F: Field, const SBOX_REGISTERS: usize> { subchip: &'a Poseidon2SubChip, - pub init_pos_register: u32, + pub ctx_register: u32, pub input_ptr_register: u32, - pub len_register: u32, + pub hint_id_register: u32, pub state_ptr_register: u32, } @@ -137,9 +137,9 @@ impl<'a, F: PrimeField32, const SBOX_REGISTERS: usize> NativePoseidon2Executor = if is_hint { + exec_state.streams.hint_space[hint_id_u32 as usize].clone() + } else { + vec![] + }; for (chunk_start, chunk_end) in observation_chunks { for j in chunk_start..chunk_end { - let [n_f]: [F; 1] = exec_state.vm_read(NATIVE_AS, input_ptr_u32 + input_idx); + let n_f = if is_hint { + hint_data[input_idx as usize] + } else { + let [v]: [F; 1] = exec_state.vm_read(NATIVE_AS, input_ptr_u32 + input_idx); + v + }; + exec_state.vm_write(NATIVE_AS, sponge_ptr_u32 + (j as u32), &[n_f]); input_idx += 1; } @@ -634,9 +653,10 @@ unsafe fn execute_multi_observe_e12_impl< height += 1; } if let Some(final_idx) = final_idx { + // Write final_idx back to ctx[0] (overwriting init_pos in context array) exec_state.vm_write::( NATIVE_AS, - pre_compute.init_pos_register, + ctx_ptr.as_canonical_u32(), &[F::from_canonical_usize(final_idx)], ); } diff --git a/extensions/native/compiler/src/asm/compiler.rs b/extensions/native/compiler/src/asm/compiler.rs index 689bb1ebd0..14d84b4f1f 100644 --- a/extensions/native/compiler/src/asm/compiler.rs +++ b/extensions/native/compiler/src/asm/compiler.rs @@ -489,13 +489,13 @@ impl + TwoAdicField> AsmCo DslIr::HintBitsF(var, len) => { self.push(AsmInstruction::HintBits(var.fp(), len), debug_info); } - DslIr::Poseidon2MultiObserve(dst, init_pos, arr_ptr, len) => { + DslIr::Poseidon2MultiObserve(dst, ctx_ptr, arr_ptr, hint_id) => { self.push( AsmInstruction::Poseidon2MultiObserve( dst.fp(), - init_pos.fp(), + ctx_ptr.fp(), arr_ptr.fp(), - len.get_var().fp(), + hint_id.fp(), ), debug_info, ); diff --git a/extensions/native/compiler/src/asm/instruction.rs b/extensions/native/compiler/src/asm/instruction.rs index b715d97cf1..16b2be4b49 100644 --- a/extensions/native/compiler/src/asm/instruction.rs +++ b/extensions/native/compiler/src/asm/instruction.rs @@ -110,9 +110,10 @@ pub enum AsmInstruction { /// Halt. Halt, - /// Absorbs multiple base elements into a duplex transcript with Poseidon2 permutation - /// (sponge_state, init_pos, arr_ptr, len) - /// Returns the final index position of hash sponge + /// Absorbs multiple base elements into a duplex transcript with Poseidon2 permutation. + /// (sponge_state, ctx_ptr, arr_ptr, hint_id) + /// Context array at ctx_ptr: [init_pos, len, is_hint, reserved] + /// When is_hint=1, data is read from hint space using hint_id. Poseidon2MultiObserve(i32, i32, i32, i32), /// Perform a Poseidon2 permutation on state starting at address `lhs` @@ -350,11 +351,11 @@ impl> AsmInstruction { AsmInstruction::Trap => write!(f, "trap"), AsmInstruction::Halt => write!(f, "halt"), AsmInstruction::HintBits(src, len) => write!(f, "hint_bits ({})fp, {}", src, len), - AsmInstruction::Poseidon2MultiObserve(dst, init_pos, arr, len) => { + AsmInstruction::Poseidon2MultiObserve(dst, ctx, arr, hint_id) => { write!( f, "poseidon2_multi_observe ({})fp, ({})fp ({})fp ({})fp", - dst, init_pos, arr, len + dst, ctx, arr, hint_id ) } AsmInstruction::Poseidon2Permute(dst, lhs) => { diff --git a/extensions/native/compiler/src/conversion/mod.rs b/extensions/native/compiler/src/conversion/mod.rs index 0ff358ec70..61fd726d3f 100644 --- a/extensions/native/compiler/src/conversion/mod.rs +++ b/extensions/native/compiler/src/conversion/mod.rs @@ -441,15 +441,15 @@ fn convert_instruction>( AS::Native, AS::Native, )], - AsmInstruction::Poseidon2MultiObserve(dst, init, arr, len) => vec![ + AsmInstruction::Poseidon2MultiObserve(dst, ctx, arr, hint_id) => vec![ Instruction { opcode: options.opcode_with_offset(Poseidon2Opcode::MULTI_OBSERVE), a: i32_f(dst), - b: i32_f(init), + b: i32_f(ctx), c: i32_f(arr), d: AS::Native.to_field(), e: AS::Native.to_field(), - f: i32_f(len), + f: i32_f(hint_id), g: F::ZERO, } ], diff --git a/extensions/native/compiler/src/ir/instructions.rs b/extensions/native/compiler/src/ir/instructions.rs index a4932d2826..8658a8a06e 100644 --- a/extensions/native/compiler/src/ir/instructions.rs +++ b/extensions/native/compiler/src/ir/instructions.rs @@ -208,13 +208,14 @@ pub enum DslIr { /// Permutes an array of Bn254 elements using Poseidon2 (output = p2_permute(array)). Should /// only be used when target is a circuit. CircuitPoseidon2Permute([Var; 3]), - /// Absorbs an array of baby bear elements into a duplex transcript with Poseidon2 permutations - /// (output = p2_multi_observe(array, els)). + /// Absorbs an array of baby bear elements into a duplex transcript with Poseidon2 permutations. + /// Context values (init_pos, len, is_hint) are passed via a context array instead of separate registers. + /// When is_hint=1, data is read from hint space using hint_id instead of from input array pointer. Poseidon2MultiObserve( - Ptr, // sponge_state - Var, // initial input_ptr position - Ptr, // input array (base elements) - Usize, // len of els + Ptr, // sponge_state + Ptr, // ctx_ptr (context array: [init_pos, len, is_hint, reserved]) + Ptr, // input array (base elements; used when is_hint=0) + Var, // hint_id (hint space id; used when is_hint=1) ), // Miscellaneous instructions. diff --git a/extensions/native/compiler/src/ir/poseidon.rs b/extensions/native/compiler/src/ir/poseidon.rs index 6d32f89409..6310917c0d 100644 --- a/extensions/native/compiler/src/ir/poseidon.rs +++ b/extensions/native/compiler/src/ir/poseidon.rs @@ -19,6 +19,8 @@ impl Builder { sponge_state: &Array>, input_ptr: Ptr, arr: &Array>, + input_len: Usize, + hint_id: Option>, ) -> Usize { let buffer_size: Var = Var::uninit(self); self.assign(&buffer_size, C::N::from_canonical_usize(HASH_RATE)); @@ -31,19 +33,40 @@ impl Builder { Array::Fixed(_) => { panic!("Base elements input must be dynamic"); } - Array::Dyn(ptr, len) => { + Array::Dyn(ptr, _) => { let init_pos: Var = Var::uninit(self); self.assign(&init_pos, input_ptr.address - sponge_ptr.address); + let is_hint = hint_id.is_some(); + let hint_id_var: Var = if let Some(id) = hint_id { + id + } else { + let v: Var = Var::uninit(self); + self.assign(&v, C::N::ZERO); + v + }; + + // Allocate context array: [init_pos, len, is_hint, reserved] + let ctx = self.dyn_array::>(4usize); + self.set(&ctx, 0, init_pos); + self.set(&ctx, 1, input_len.get_var()); + self.set( + &ctx, + 2, + if is_hint { C::N::ONE } else { C::N::ZERO }, + ); + self.set(&ctx, 3, C::N::ZERO); + self.operations.push(DslIr::Poseidon2MultiObserve( *sponge_ptr, - init_pos, + ctx.ptr(), *ptr, - len.clone(), + hint_id_var, )); - // automatically updated by Poseidon2MultiObserve operation - Usize::Var(init_pos) + // Read back the updated init_pos from ctx[0] + let final_pos: Var = self.get(&ctx, 0); + Usize::Var(final_pos) } }, } diff --git a/extensions/native/recursion/src/challenger/duplex.rs b/extensions/native/recursion/src/challenger/duplex.rs index 440b14ec59..10b9bc62e9 100644 --- a/extensions/native/recursion/src/challenger/duplex.rs +++ b/extensions/native/recursion/src/challenger/duplex.rs @@ -81,7 +81,7 @@ impl DuplexChallengerVariable { // This is equivalent to calling `observe` multiple times, but more efficient. pub fn observe_slice_opt(&self, builder: &mut Builder, arr: &Array>) { builder.if_ne(arr.len(), Usize::from(0)).then(|builder| { - let next_pos = builder.poseidon2_multi_observe(&self.sponge_state, self.input_ptr, arr); + let next_pos = builder.poseidon2_multi_observe(&self.sponge_state, self.input_ptr, arr, arr.len(), None); builder.assign(&self.input_ptr, self.io_empty_ptr + next_pos.clone()); builder.if_ne(next_pos, Usize::from(0)).then_or_else(