diff --git a/.gitignore b/.gitignore index 55d8a4c3a..2c2fca099 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,6 @@ /target +/target-* +/.lm-* .vscode /docs/benchmark_graphs/.venv minimal_zkVM.synctex.gz diff --git a/Cargo.lock b/Cargo.lock index a1e508d94..0bb8c2bce 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -104,7 +104,7 @@ dependencies = [ "mt-symetric", "mt-utils", "mt-whir", - "rayon", + "parallel", "tracing", ] @@ -138,6 +138,16 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" +[[package]] +name = "cc" +version = "1.2.63" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "556e016178bb5662a08681bbe0f00f8e17631781a4dfc8c45e466e4b185ec27f" +dependencies = [ + "find-msvc-tools", + "shlex", +] + [[package]] name = "cfg-if" version = "1.0.4" @@ -240,31 +250,6 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "790eea4361631c5e7d22598ecd5723ff611904e3344ce8720784c93e3d83d40b" -[[package]] -name = "crossbeam-deque" -version = "0.8.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" -dependencies = [ - "crossbeam-epoch", - "crossbeam-utils", -] - -[[package]] -name = "crossbeam-epoch" -version = "0.9.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" -dependencies = [ - "crossbeam-utils", -] - -[[package]] -name = "crossbeam-utils" -version = "0.8.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" - [[package]] name = "crypto-common" version = "0.1.7" @@ -284,6 +269,12 @@ dependencies = [ "hybrid-array", ] +[[package]] +name = "cty" +version = "0.2.2" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "b365fabc795046672053e29c954733ec3b05e4be654ab130fe8f1f94d7051f35" + [[package]] name = "digest" version = "0.10.7" @@ -329,6 +320,12 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" +[[package]] +name = "find-msvc-tools" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" + [[package]] name = "foldhash" version = "0.1.5" @@ -493,6 +490,9 @@ dependencies = [ "backend", "clap", "lean_vm", + "libc", + "libmimalloc-sys", + "mimalloc", "rand", "rec_aggregation", "serde_json", @@ -500,7 +500,6 @@ dependencies = [ "system-info", "utils", "xmss", - "zk-alloc", ] [[package]] @@ -565,6 +564,16 @@ version = "0.2.186" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "68ab91017fe16c622486840e4c83c9a37afeff978bd239b5293d61ece587de66" +[[package]] +name = "libmimalloc-sys" +version = "0.1.49" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a45a52f43e1c16f667ccfe4dd8c85b7f7c204fd5e3bf46c5b0db9a5c3c0b8e9" +dependencies = [ + "cc", + "cty", +] + [[package]] name = "lock_api" version = "0.4.14" @@ -604,6 +613,15 @@ version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" +[[package]] +name = "mimalloc" +version = "0.1.52" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d4139bb28d14ad1facf21d5eb8825051b326e172d216b39f6d31df53cc97862" +dependencies = [ + "libmimalloc-sys", +] + [[package]] name = "mt-air" version = "0.1.0" @@ -620,7 +638,7 @@ dependencies = [ "mt-koala-bear", "mt-symetric", "mt-utils", - "rayon", + "parallel", "serde", "tracing", ] @@ -632,9 +650,9 @@ dependencies = [ "itertools", 
"mt-utils", "num-bigint", + "parallel", "paste", "rand", - "rayon", "serde", "tracing", ] @@ -649,7 +667,6 @@ dependencies = [ "num-bigint", "paste", "rand", - "rayon", "serde", "tracing", ] @@ -662,8 +679,8 @@ dependencies = [ "mt-field", "mt-koala-bear", "mt-utils", + "parallel", "rand", - "rayon", "serde", "system-info", ] @@ -677,7 +694,7 @@ dependencies = [ "mt-field", "mt-koala-bear", "mt-poly", - "rayon", + "parallel", "tracing", ] @@ -687,7 +704,7 @@ version = "0.1.0" dependencies = [ "mt-field", "mt-koala-bear", - "rayon", + "parallel", ] [[package]] @@ -709,8 +726,8 @@ dependencies = [ "mt-sumcheck", "mt-symetric", "mt-utils", + "parallel", "rand", - "rayon", "system-info", "tracing", "tracing-forest", @@ -791,6 +808,13 @@ version = "1.70.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" +[[package]] +name = "parallel" +version = "0.1.0" +dependencies = [ + "system-info", +] + [[package]] name = "paste" version = "1.0.15" @@ -910,26 +934,6 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c8d0fd677905edcbeedbf2edb6494d676f0e98d54d5cf9bda0b061cb8fb8aba" -[[package]] -name = "rayon" -version = "1.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f" -dependencies = [ - "either", - "rayon-core", -] - -[[package]] -name = "rayon-core" -version = "1.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91" -dependencies = [ - "crossbeam-deque", - "crossbeam-utils", -] - [[package]] name = "rec_aggregation" version = "0.1.0" @@ -949,7 +953,6 @@ dependencies = [ "tracing", "utils", "xmss", - "zk-alloc", ] [[package]] @@ -1063,6 +1066,12 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "shlex" +version = "2.0.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8fadd59c855ef2080decdef8ff161eb6661b86933c9d82e5ba29dc602a55aba" + [[package]] name = "smallvec" version = "1.15.1" @@ -1118,7 +1127,6 @@ name = "system-info" version = "0.1.0" dependencies = [ "libc", - "rayon", ] [[package]] @@ -1473,15 +1481,6 @@ dependencies = [ "utils", ] -[[package]] -name = "zk-alloc" -version = "0.1.0" -dependencies = [ - "libc", - "rayon", - "system-info", -] - [[package]] name = "zmij" version = "1.0.21" diff --git a/Cargo.toml b/Cargo.toml index f8e2ada76..ce778c312 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,7 +19,7 @@ members = [ "crates/backend/fiat-shamir", "crates/backend/sumcheck", "crates/backend/system-info", - "crates/backend/zk-alloc", + "crates/backend/parallel", ] [workspace.lints] @@ -61,14 +61,13 @@ lean_compiler = { path = "crates/lean_compiler" } lean_prover = { path = "crates/lean_prover" } rec_aggregation = { path = "crates/rec_aggregation" } backend = { path = "crates/backend" } -zk-alloc = { path = "crates/backend/zk-alloc" } system-info = { path = "crates/backend/system-info" } +parallel = { path = "crates/backend/parallel" } # External sha3 = "0.11.0" clap = { version = "4.5.59", features = ["derive"] } rand = "0.10.0" -rayon = "1.11.0" pest = "2.7" pest_derive = "2.7" itertools = "0.14.0" @@ -83,12 +82,15 @@ include_dir = "0.7" [features] prox-gaps-conjecture = ["rec_aggregation/prox-gaps-conjecture"] -standard-alloc = ["rec_aggregation/standard-alloc"] +# Build with the plain system allocator instead of mimalloc (for comparison/debugging). 
+standard-alloc = [] [dependencies] clap.workspace = true rec_aggregation.workspace = true -zk-alloc.workspace = true +mimalloc = "0.1" +libmimalloc-sys = { version = "0.1", features = ["extended"] } +libc = "0.2" rand.workspace = true sub_protocols.workspace = true utils.workspace = true diff --git a/crates/backend/Cargo.toml b/crates/backend/Cargo.toml index 3f61957af..d56cf56a8 100644 --- a/crates/backend/Cargo.toml +++ b/crates/backend/Cargo.toml @@ -9,7 +9,7 @@ poly = { path = "poly", package = "mt-poly" } sumcheck = { path = "sumcheck", package = "mt-sumcheck" } field = { path = "field", package = "mt-field" } air = { path = "air", package = "mt-air" } -rayon.workspace = true +parallel.workspace = true whir = { path = "../whir", package = "mt-whir" } tracing.workspace = true fiat-shamir = { path = "fiat-shamir", package = "mt-fiat-shamir" } diff --git a/crates/backend/fiat-shamir/Cargo.toml b/crates/backend/fiat-shamir/Cargo.toml index ec8649bc2..57f32d2ba 100644 --- a/crates/backend/fiat-shamir/Cargo.toml +++ b/crates/backend/fiat-shamir/Cargo.toml @@ -10,4 +10,4 @@ symetric = { path = "../symetric", package = "mt-symetric" } utils = { path = "../utils", package = "mt-utils" } tracing.workspace = true serde.workspace = true -rayon.workspace = true +parallel.workspace = true diff --git a/crates/backend/fiat-shamir/src/prover.rs b/crates/backend/fiat-shamir/src/prover.rs index 80bb6d13e..79d3859bb 100644 --- a/crates/backend/fiat-shamir/src/prover.rs +++ b/crates/backend/fiat-shamir/src/prover.rs @@ -9,8 +9,7 @@ use field::PrimeCharacteristicRing; use field::integers::QuotientMap; use field::{ExtensionField, PrimeField64}; use koala_bear::symmetric::Permutation; -use rayon::prelude::*; -use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; use std::time::Duration; use std::{fmt::Debug, sync::Mutex, time::Instant}; @@ -132,9 +131,22 @@ where let witness_found = Mutex::>>::new(None); // each batch tests lanes 
witnesses simultaneously let num_batches = PF::::ORDER_U64.div_ceil(lanes as u64); - (0..num_batches) - .into_par_iter() - .find_any(|&batch| { + + // Parallel short-circuiting search (replaces rayon `find_any`): spawn one + // searcher per worker, each claiming batches from a shared counter and bailing + // as soon as any worker finds a witness. Bounds the work to ~expected + a few + // extra batches instead of enumerating all `num_batches` (which can be ~2^31). + let next_batch = AtomicU64::new(0); + let found = AtomicBool::new(false); + parallel::for_each_index(parallel::num_threads(), |_| { + loop { + if found.load(Ordering::Relaxed) { + return; + } + let batch = next_batch.fetch_add(1, Ordering::Relaxed); + if batch >= num_batches { + return; + } let base = batch * lanes as u64; let packed_witnesses = Packed::::from_fn(|lane| { @@ -159,12 +171,13 @@ where let rand_usize = sample.as_canonical_u64() as usize; if (rand_usize & ((1 << bits) - 1)) == 0 { *witness_found.lock().unwrap() = Some(*witness); - return true; + found.store(true, Ordering::Relaxed); + return; } } - false - }) - .expect("failed to find witness"); + } + }); + assert!(found.load(Ordering::Relaxed), "failed to find witness"); let witness = witness_found.lock().unwrap().unwrap(); diff --git a/crates/backend/field/Cargo.toml b/crates/backend/field/Cargo.toml index 89e87c133..cde41bb61 100644 --- a/crates/backend/field/Cargo.toml +++ b/crates/backend/field/Cargo.toml @@ -9,7 +9,7 @@ utils = { path = "../utils", package = "mt-utils" } itertools.workspace = true num-bigint = "*" paste = "*" +parallel.workspace = true rand.workspace = true -rayon.workspace = true serde.workspace = true tracing.workspace = true diff --git a/crates/backend/field/src/field.rs b/crates/backend/field/src/field.rs index b44ed45ed..836529cf9 100644 --- a/crates/backend/field/src/field.rs +++ b/crates/backend/field/src/field.rs @@ -9,7 +9,6 @@ use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAss 
use core::{array, slice}; use num_bigint::BigUint; -use rayon::{current_num_threads, prelude::*}; use serde::Serialize; use serde::de::DeserializeOwned; use utils::{flatten_to_base, iter_array_chunks_padded}; @@ -1020,7 +1019,7 @@ impl BoundedPowers { let mut points_packed = F::Packing::zero_vec(num_packed); // Split computation evenly among threads - let num_threads = current_num_threads().max(1); + let num_threads = parallel::num_threads().max(1); let chunk_size = num_packed.div_ceil(num_threads); // Precompute base for each chunk. @@ -1028,16 +1027,13 @@ impl BoundedPowers { let chunk_base = base.exp_u64((chunk_size * width) as u64); let shift = self.iter.current; - points_packed - .par_chunks_mut(chunk_size) - .enumerate() - .for_each(|(chunk_idx, chunk_slice)| { - // First power in this chunk - let chunk_start = shift * chunk_base.exp_u64(chunk_idx as u64); + parallel::par_chunks_mut(&mut points_packed, chunk_size, |chunk_idx, chunk_slice| { + // First power in this chunk + let chunk_start = shift * chunk_base.exp_u64(chunk_idx as u64); - // Fill the chunk with packed powers. - F::Packing::packed_shifted_powers(base, chunk_start).fill(chunk_slice); - }); + // Fill the chunk with packed powers. + F::Packing::packed_shifted_powers(base, chunk_start).fill(chunk_slice); + }); // return the number of requested points, discarding the unused packed powers // SAFETY: size_of:: always divides size_of::. 
diff --git a/crates/backend/koala-bear/Cargo.toml b/crates/backend/koala-bear/Cargo.toml index aba2ab231..5ce4ad111 100644 --- a/crates/backend/koala-bear/Cargo.toml +++ b/crates/backend/koala-bear/Cargo.toml @@ -8,7 +8,6 @@ field = { path = "../field", package = "mt-field" } utils = { path = "../utils", package = "mt-utils" } rand.workspace = true -rayon.workspace = true serde.workspace = true itertools.workspace = true tracing.workspace = true diff --git a/crates/backend/parallel/Cargo.toml b/crates/backend/parallel/Cargo.toml new file mode 100644 index 000000000..731b5163d --- /dev/null +++ b/crates/backend/parallel/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "parallel" +version.workspace = true +edition.workspace = true +description = "Minimal fixed-size thread pool for static data-parallel kernels" + +[dependencies] +system-info.workspace = true + +[lints] +workspace = true diff --git a/crates/backend/parallel/src/lib.rs b/crates/backend/parallel/src/lib.rs new file mode 100644 index 000000000..0fca9a4b9 --- /dev/null +++ b/crates/backend/parallel/src/lib.rs @@ -0,0 +1,439 @@ +//! Minimal fixed-size thread pool for flat data-parallel kernels: "split a range into +//! pieces, run a closure on each." No nested work-stealing, no per-dispatch allocation — +//! owning the runtime lets us attach per-worker scratch buffers and drop rayon. +//! +//! ## Model +//! +//! `NUM_THREADS - 1` background workers (ids `1..NUM_THREADS`); the dispatching thread is +//! worker `0` and runs its share inline (no oversubscription). Tasks are claimed from a +//! shared atomic counter for dynamic load balancing. +//! +//! ## Lock-free dispatch +//! +//! A mutex+condvar wake-up costs ~2x per dispatch. Instead, dispatch bumps a `generation` +//! counter that idle workers **spin** on (back-to-back dispatches pay no syscall), parking +//! only after `SPIN_LIMIT` unrewarded spins. Completion is a lock-free `working` countdown +//! the dispatcher spins on. 
The per-worker `parked` flag is ordered SeqCst against +//! `generation` so no wake-up is lost and the unpark syscall is skipped while workers spin. +//! +//! ## Nesting is forbidden +//! +//! A flat pool can't dispatch from inside a task (would deadlock the dispatch lock), so a +//! dispatch issued from within a pool task panics. The outer level has already saturated all +//! cores, so nested parallelism would buy nothing anyway. +//! +//! ## Constraint +//! +//! One dispatcher at a time: concurrent (non-nested) dispatches are serialized by a mutex. + +use std::cell::{Cell, UnsafeCell}; +use std::ptr::NonNull; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::sync::{Mutex, Once, OnceLock}; +use std::thread::Thread; + +use system_info::NUM_THREADS; + +/// Idle spins before a worker parks. Tuned so back-to-back dispatches stay hot but +/// sequential gaps let workers sleep, freeing cores for the active thread. +const SPIN_LIMIT: u32 = 1 << 12; + +/// Cap on a single guided-self-scheduling claim (see [`drain`]). Bounds worst-case load +/// imbalance while keeping million-task kernels to a few thousand claims. +const MAX_CLAIM_BATCH: usize = 1 << 12; + +/// Total worker count (including the dispatching thread). Equal to build-time `NUM_THREADS`. +#[must_use] +pub const fn num_threads() -> usize { + NUM_THREADS +} + +/// Chunk size for a flat fan-out: a few chunks per worker so the atomic counter can +/// rebalance heterogeneous cores, coarse enough to amortize dispatch. +#[must_use] +#[inline] +pub fn recommended_chunk_size(n_items: usize) -> usize { + n_items.div_ceil(NUM_THREADS * 4).max(1) +} + +thread_local! { + /// Stable id of this thread within the pool. Set once per background worker; + /// stays `0` on the dispatching thread (worker 0) and on any non-worker thread. + static WORKER_ID: Cell = const { Cell::new(0) }; + /// True while this thread is executing a pool task. 
A dispatch issued in that state is + /// nested parallelism, which is forbidden and panics (see module docs). + static IN_TASK: Cell = const { Cell::new(false) }; +} + +/// Stable id of the calling worker, in `0..NUM_THREADS` (`0` off-pool). The hook for +/// per-worker scratch buffers. +#[must_use] +pub fn current_worker_id() -> usize { + WORKER_ID.with(Cell::get) +} + +/// Type-erased unit of work: a `&(dyn Fn(usize, usize) + Sync)` whose lifetime is erased +/// to `'static`. Only dereferenced inside a dispatch window during which the dispatcher +/// blocks, so the source borrow outlives every call. +struct Job { + /// Range-based work: `f(start, end)` processes the half-open task range `start..end`. + /// Building the primitive on ranges (rather than single indices) lets reductions look + /// up their per-worker accumulator once per claimed batch instead of once per element. + f: NonNull, + n_tasks: usize, +} + +struct Pool { + /// Current job. Written by the dispatcher before bumping `generation`, read by + /// workers after they observe the bump; `generation` supplies the happens-before. + job: UnsafeCell>, + /// Bumped once per dispatch. Idle workers watch it (spin, then park). + generation: AtomicUsize, + /// Next task index to claim. Reset to 0 before each dispatch. + counter: AtomicUsize, + /// Background workers still draining the current dispatch; dispatcher spins to 0. + working: AtomicUsize, + shutdown: AtomicBool, + /// Per-worker "currently parked" flags (indexed by worker id; slot 0 unused). + parked: Vec, + /// Per-worker thread handles for `unpark` (indexed by worker id; slot 0 unused). + handles: Vec>, + /// Serializes dispatchers: only one thread may drive the pool at a time. + dispatch: Mutex<()>, +} + +// SAFETY: `job` is mutated only by the unique dispatcher while workers are parked or +// before they observe the generation bump, and read only after; the `generation` +// release/acquire (and SeqCst park protocol) order these phases. 
The erased `Job` +// pointer is never used outside a dispatch window during which its borrow is live. +unsafe impl Sync for Pool {} +unsafe impl Send for Pool {} + +/// Idempotent warm-up: spawn workers and run one dispatch so the pool, parkers, and the +/// mutex's lazily-allocated `pthread_mutex_t` (macOS) exist before timed proving work. +/// Without it the pool initializes lazily on the first parallel call. +pub fn init() { + static INIT: Once = Once::new(); + INIT.call_once(|| { + let _ = pool(); + if NUM_THREADS > 1 { + for_each_index(NUM_THREADS, |_| {}); + } + }); +} + +fn pool() -> &'static Pool { + static POOL: OnceLock<&'static Pool> = OnceLock::new(); + POOL.get_or_init(|| { + let n = NUM_THREADS.max(1); + let p: &'static Pool = Box::leak(Box::new(Pool { + job: UnsafeCell::new(None), + generation: AtomicUsize::new(0), + counter: AtomicUsize::new(0), + working: AtomicUsize::new(0), + shutdown: AtomicBool::new(false), + parked: (0..n).map(|_| AtomicBool::new(false)).collect(), + handles: (0..n).map(|_| OnceLock::new()).collect(), + dispatch: Mutex::new(()), + })); + for id in 1..n { + std::thread::Builder::new() + .name(format!("parallel-worker-{id}")) + .spawn(move || worker_main(p, id)) + .expect("failed to spawn pool worker"); + } + p + }) +} + +fn worker_main(pool: &'static Pool, id: usize) { + WORKER_ID.with(|c| c.set(id)); + let _ = pool.handles[id].set(std::thread::current()); + + let mut last_gen = 0usize; + loop { + let mut spins = 0u32; + let g = loop { + let g = pool.generation.load(Ordering::Acquire); + if g != last_gen { + break g; + } + if pool.shutdown.load(Ordering::Acquire) { + return; + } + if spins < SPIN_LIMIT { + spins += 1; + std::hint::spin_loop(); + } else { + // About to park. Publish it, then re-check `generation`: by SeqCst + // total order with the dispatcher's `generation` bump and `parked` + // load, at least one side sees the other, so no wake-up is lost. 
+ pool.parked[id].store(true, Ordering::SeqCst); + if pool.generation.load(Ordering::SeqCst) != last_gen { + pool.parked[id].store(false, Ordering::SeqCst); + } else if pool.shutdown.load(Ordering::SeqCst) { + pool.parked[id].store(false, Ordering::SeqCst); + return; + } else { + std::thread::park(); + pool.parked[id].store(false, Ordering::SeqCst); + } + spins = 0; + } + }; + last_gen = g; + drain(pool); + pool.working.fetch_sub(1, Ordering::Release); + } +} + +/// Claim and run task indices until the counter is exhausted, using **guided +/// self-scheduling**: each claim grabs `remaining / (NUM_THREADS * 2)`, capped at +/// [`MAX_CLAIM_BATCH`]. Large early batches keep counter contention low; the proportional +/// shrink keeps the tail finely divided for load balance. +fn drain(pool: &Pool) { + // SAFETY: the dispatcher published `Some(job)` before the generation bump this + // worker just observed, and overwrites it only on the next dispatch (gated on + // `working == 0`); nobody writes during drain. + let job = unsafe { (*pool.job.get()).as_ref().expect("drain without a published job") }; + // SAFETY: `job.f` points at a `&dyn Fn` borrow held live by the blocked dispatcher. + let f = unsafe { job.f.as_ref() }; + let n = job.n_tasks; + // Mark this thread as in-task so a nested dispatch panics (see `for_each_chunk`) + // rather than deadlocking on the dispatch lock. + let prev = IN_TASK.replace(true); + loop { + // Stale counter read only affects batch granularity, not correctness: `fetch_add` + // tiles `0..n` into non-overlapping claims, and out-of-range tails are clamped. 
+ let observed = pool.counter.load(Ordering::Relaxed); + if observed >= n { + break; + } + let batch = ((n - observed) / (NUM_THREADS * 2)).clamp(1, MAX_CLAIM_BATCH); + let start = pool.counter.fetch_add(batch, Ordering::Relaxed); + if start >= n { + break; + } + let end = (start + batch).min(n); + f(start, end); + } + IN_TASK.set(prev); +} + +/// Core dispatch: run `f(start, end)` over disjoint contiguous sub-ranges that together +/// tile `0..n_tasks`, in parallel across the pool. The ranges are produced by guided +/// self-scheduling (see [`drain`]); a worker may receive several. Blocks until all +/// complete; the dispatching thread participates as worker 0. +/// +/// This is the primitive everything else is built on. Range-based (rather than per-index) +/// so a reduction can look up its per-worker accumulator once per claimed batch. +pub fn for_each_chunk(n_tasks: usize, f: F) { + // Nesting is forbidden (would deadlock the dispatch lock): panic rather than silently + // running sequentially, so an accidental nested dispatch is caught instead of going slow. + assert!(!IN_TASK.get(), "nested parallel dispatch from within a pool task"); + + // Trivial sizes and single-core builds run sequentially inline. + if NUM_THREADS <= 1 || n_tasks <= 1 { + if n_tasks > 0 { + f(0, n_tasks); + } + return; + } + + let pool = pool(); + let _guard = pool.dispatch.lock().unwrap(); + let n = NUM_THREADS; + + // SAFETY: erase the borrow's lifetime so the closure can live in the `'static` + // `Job`. The dispatcher blocks on `working` below before returning, so `f` + // outlives every worker dereference of this pointer. NOTE(review): this holds only if `f` does not panic on the dispatcher; unwinding out of `drain` would skip that wait. + let f_ref: &(dyn Fn(usize, usize) + Sync) = &f; + let f_erased: NonNull = unsafe { std::mem::transmute(NonNull::from(f_ref)) }; + + // SAFETY: all workers finished the previous dispatch (we waited for `working == 0`) + // and none observes this one until the generation bump, so we are the sole writer.
+ unsafe { *pool.job.get() = Some(Job { f: f_erased, n_tasks }) }; + pool.counter.store(0, Ordering::Relaxed); + pool.working.store(n - 1, Ordering::Release); + + // Publish the dispatch. SeqCst so the parked-flag protocol can't lose a wake-up. + pool.generation.fetch_add(1, Ordering::SeqCst); + + // Wake only workers that actually parked; hot (spinning) ones see the bump for free. + for id in 1..n { + if pool.parked[id].load(Ordering::SeqCst) + && let Some(t) = pool.handles[id].get() + { + t.unpark(); + } + } + + drain(pool); // dispatcher runs as worker 0 + + // Lock-free completion: wait for every background worker to finish draining. + while pool.working.load(Ordering::Acquire) != 0 { + std::hint::spin_loop(); + } +} + +/// Run `f(i)` for every `i` in `0..n_tasks`, in parallel across the pool. Blocks until +/// all tasks complete; the dispatching thread participates as worker 0. +pub fn for_each_index(n_tasks: usize, f: F) { + for_each_chunk(n_tasks, |start, end| { + for i in start..end { + f(i); + } + }); +} + +/// A raw base pointer shareable across pool workers. Sound only because callers partition +/// the allocation by task index, so each worker touches a disjoint region. Reuse this for +/// the "share a `*mut` across a dispatch" pattern instead of redefining it per crate. +#[derive(Debug)] +pub struct SendPtr(pub *mut T); +// SAFETY: accesses are partitioned by task index (see callers). +unsafe impl Send for SendPtr {} +unsafe impl Sync for SendPtr {} + +impl SendPtr { + /// Offset the base pointer by `n` elements. + /// + /// # Safety + /// `n` must keep the result within the original allocation, and any write through it + /// must target a slot no other concurrent task touches. + #[inline] + pub unsafe fn add(&self, n: usize) -> *mut T { + unsafe { self.0.add(n) } + } + + /// Reconstruct the `len`-long slice starting at element offset `off`. 
+ /// + /// # Safety + /// `off`/`len` must stay in-bounds and the slice must be disjoint from every other + /// slice any concurrent task reconstructs. + #[inline] + pub unsafe fn slice<'a>(&self, off: usize, len: usize) -> &'a mut [T] { + unsafe { std::slice::from_raw_parts_mut(self.0.add(off), len) } + } +} + +/// Parallel equivalent of `data.chunks_mut(chunk).enumerate().for_each(...)`, running +/// `f(chunk_index, chunk)` on each in parallel. The final chunk may be shorter. +pub fn par_chunks_mut(data: &mut [T], chunk: usize, f: F) +where + F: Fn(usize, &mut [T]) + Sync, +{ + assert!(chunk > 0, "chunk size must be non-zero"); + let len = data.len(); + let n_chunks = len.div_ceil(chunk); + let base = SendPtr(data.as_mut_ptr()); + + for_each_index(n_chunks, |i| { + let start = i * chunk; + let this_len = chunk.min(len - start); + // SAFETY: distinct `i` produce non-overlapping in-bounds ranges, and `data` + // stays borrowed for the whole call. + let slice = unsafe { std::slice::from_raw_parts_mut(base.add(start), this_len) }; + f(i, slice); + }); +} + +/// Parallel `data.iter_mut().enumerate().for_each(...)`, sized with +/// [`recommended_chunk_size`]. Hands the closure each element's **global** index. +/// `#[inline]` to fold into the per-region caller (recovers the hand-written codegen). +#[inline] +pub fn par_for_each_mut(data: &mut [T], f: F) +where + F: Fn(usize, &mut T) + Sync, +{ + let chunk = recommended_chunk_size(data.len()); + par_chunks_mut(data, chunk, |ci, sub| { + let base = ci * chunk; + for (k, slot) in sub.iter_mut().enumerate() { + f(base + k, slot); + } + }); +} + +/// Give each worker exclusive, persistent access to its own `Option` slot while it +/// drains `0..n_tasks`: `run(slot, start, end)` is called once per claimed batch, always +/// with the same slot for a given worker, so state accumulates across the batches it +/// claims. Returns the per-worker slots (one per worker that ran, rest `None`) for the +/// caller to combine. 
This is the single home of the cross-worker slot `unsafe`. +fn drain_into_slots(n_tasks: usize, run: impl Fn(&mut Option, usize, usize) + Sync) -> Vec> { + let mut slots: Vec> = (0..NUM_THREADS).map(|_| None).collect(); + let ptr = SendPtr(slots.as_mut_ptr()); + for_each_chunk(n_tasks, |start, end| { + // SAFETY: `current_worker_id()` is unique per live worker and < NUM_THREADS, so + // workers touch disjoint slots; `slots` outlives the dispatch. + let slot = unsafe { &mut *ptr.add(current_worker_id()) }; + run(slot, start, end); + }); + slots +} + +/// Parallel map-reduce over `0..n_tasks`, equivalent to +/// `(0..n_tasks).into_par_iter().map(map).reduce(identity, reduce)` when `reduce` is commutative. +/// +/// Each worker folds the task indices it claims into one local accumulator (one per +/// worker, not per task); the per-worker partials are combined on the dispatcher. +/// `reduce` must be associative and commutative: a worker's claimed batches need not be adjacent, so results are combined out of index order. +pub fn map_reduce(n_tasks: usize, identity: ID, map: M, reduce: R) -> T +where + T: Send, + ID: Fn() -> T, + M: Fn(usize) -> T + Sync, + R: Fn(T, T) -> T + Sync, +{ + if NUM_THREADS <= 1 || n_tasks <= 1 { + return (0..n_tasks).fold(identity(), |acc, i| reduce(acc, map(i))); + } + let slots = drain_into_slots(n_tasks, |slot, start, end| { + // Fold the batch into the worker's accumulator, taking/replacing the shared slot + // just once so per-element writes stay off the cross-worker pointer. + let acc = (start..end).fold(slot.take(), |acc, i| { + Some(match acc { + Some(a) => reduce(a, map(i)), + None => map(i), + }) + }); + *slot = acc; + }); + slots.into_iter().flatten().fold(identity(), &reduce) +} + +/// Parallel reduce where each worker keeps reusable **scratch** alongside its +/// accumulator, so the per-task body can avoid allocating. Each worker creates +/// `(scratch, acc)` once on its first task and reuses the scratch across all the +/// tasks it claims; the per-worker `acc`s are then combined. `combine` must be +/// associative and commutative, and `fold` must not depend on task order (batches are claimed out of index order).
+pub fn map_reduce_with_state(n_tasks: usize, init_state: IS, init_acc: IA, fold: F, combine: C) -> A +where + S: Send, + A: Send, + IS: Fn() -> S + Sync, + IA: Fn() -> A + Sync, + F: Fn(&mut S, &mut A, usize) + Sync, + C: Fn(A, A) -> A, +{ + if NUM_THREADS <= 1 || n_tasks <= 1 { + let mut state = init_state(); + let mut acc = init_acc(); + for i in 0..n_tasks { + fold(&mut state, &mut acc, i); + } + return acc; + } + let slots = drain_into_slots(n_tasks, |slot, start, end| { + // Scratch and accumulator are created once and threaded through every batch. + let (state, acc) = slot.get_or_insert_with(|| (init_state(), init_acc())); + for i in start..end { + fold(state, acc, i); + } + }); + slots + .into_iter() + .flatten() + .map(|(_, acc)| acc) + .fold(init_acc(), &combine) +} diff --git a/crates/backend/poly/Cargo.toml b/crates/backend/poly/Cargo.toml index dcdf80aed..f198a2d19 100644 --- a/crates/backend/poly/Cargo.toml +++ b/crates/backend/poly/Cargo.toml @@ -7,9 +7,9 @@ edition.workspace = true field = { path = "../field", package = "mt-field" } utils = { path = "../utils", package = "mt-utils" } system-info.workspace = true +parallel.workspace = true itertools.workspace = true -rayon.workspace = true rand.workspace = true serde.workspace = true diff --git a/crates/backend/poly/src/eq_mle.rs b/crates/backend/poly/src/eq_mle.rs index 64d3733f5..16c59f015 100644 --- a/crates/backend/poly/src/eq_mle.rs +++ b/crates/backend/poly/src/eq_mle.rs @@ -2,13 +2,68 @@ use crate::*; use crate::{EFPacking, PF}; use ::utils::{iter_array_chunks_padded, log2_ceil_usize, log2_strict_usize}; use field::*; -use rayon::prelude::*; use system_info::NUM_THREADS; const LOG_NUM_THREADS: usize = log2_ceil_usize(NUM_THREADS); -const NUM_THREADS_PADDED: usize = 1 << LOG_NUM_THREADS; const LOG_BATCHED_TILE_SIZE: usize = 14; +/// log2 oversubscription for the eq_mle fan-out: emit `NUM_THREADS << this` chunks so the +/// pool's task counter rebalances across heterogeneous cores (e.g. P/E). 
`0` = one chunk +/// per worker. Default `2` (4x) is conservative; a runtime knob so the benchmark can sweep. +pub static PARALLEL_LOG_OVERSUB: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(2); + +/// `(log2(n_chunks), n_chunks)` for the parallel fan-out, honoring [`PARALLEL_LOG_OVERSUB`]. +#[inline] +fn parallel_split() -> (usize, usize) { + let log_chunks = LOG_NUM_THREADS + PARALLEL_LOG_OVERSUB.load(std::sync::atomic::Ordering::Relaxed); + (log_chunks, 1 << log_chunks) +} + +/// Parallel equivalent of +/// `out.par_chunks_exact_mut(chunk).zip(buf).enumerate().for_each(|(i, (c, _))| g(i, c, &buf[i]))`, +/// dispatched through the in-house [`parallel`] pool. `chunk` must divide `out.len()` +/// exactly into `buf.len()` chunks (the eq_mle fan-out always does). +#[inline] +fn par_chunks_zip(out: &mut [T], chunk: usize, buf: &[A], g: G) +where + T: Send, + A: Sync, + G: Fn(&mut [T], &A) + Sync, +{ + parallel::par_chunks_mut(out, chunk, |i, c| g(c, &buf[i])); +} + +/// Shared parallel tail of the `compute_eval_eq*` family. With `eval` split into +/// `log_chunks` leading variables (handled one-per-chunk), `log_packing_width` trailing +/// variables already folded into `seed = buffer[0]`, and the middle variables left for +/// `kernel`, this builds the per-chunk equality buffer and runs +/// `kernel(middle, out_chunk, buffer_val)` over `out` in parallel. `kernel` fires once +/// per chunk (not per element), so threading it through a closure costs nothing. 
+#[inline] +fn par_eval_eq( + eval: &[In], + out: &mut [Out], + log_chunks: usize, + n_chunks: usize, + log_packing_width: usize, + seed: Buf, + kernel: impl Fn(&[In], &mut [Out], Buf) + Sync, +) where + In: Field, + Buf: Algebra + Copy + Send + Sync, + Out: Send, +{ + let mut buffer = Buf::zero_vec(n_chunks); + buffer[0] = seed; + fill_buffer(eval[..log_chunks].iter().rev(), &mut buffer); + + let out_chunk_size = out.len() / n_chunks; + let middle = &eval[log_chunks..(eval.len() - log_packing_width)]; + par_chunks_zip(out, out_chunk_size, &buffer, |out_chunk, buffer_val| { + kernel(middle, out_chunk, *buffer_val); + }); +} + /// Given `evals` = (α_1, ..., α_n), returns a multilinear polynomial P in n variables, /// defined on the boolean hypercube by: ∀ (x_1, ..., x_n) ∈ {0, 1}^n, /// P(x_1, ..., x_n) = Π_{i=1}^{n} (x_i.α_i + (1 - x_i).(1 - α_i)) @@ -87,62 +142,33 @@ where F: Field, EF: ExtensionField, { - // It's possible for this to be called with F = EF (Despite F actually being an extension field). - // - // IMPORTANT: We previously checked here that `packing_width > 1`, - // but this check is **not viable** for Goldilocks on Neon or when not using `target-cpu=native`. - // - // Why? Because Neon SIMD vectors are 128 bits and Goldilocks elements are already 64 bits, - // so no packing happens (width stays 1), and there's no performance advantage. - // - // Be careful: this means code relying on packing optimizations should **not assume** - // `packing_width > 1` is always true. + // `packing_width` may be 1 (e.g. Goldilocks on Neon, or without `target-cpu=native`), + // so nothing here may assume it is > 1. let log_packing_width = log2_strict_usize(F::Packing::WIDTH); - - // Ensure that the output buffer size is correct: - // It should be of size `2^n`, where `n` is the number of variables. debug_assert_eq!(out.len(), 1 << eval.len()); - // If the number of variables is small, there is no need to use - // parallelization or packings. 
- if eval.len() <= log_packing_width + 1 + LOG_NUM_THREADS { - // A basic recursive approach. + let (log_chunks, n_chunks) = parallel_split(); + if eval.len() <= log_packing_width + 1 + log_chunks { + // Too small to be worth packing/parallelizing. eval_eq_basic::<_, _, _, INITIALIZED>(eval, out, scalar); return; } - let eval_len_min_packing = eval.len() - log_packing_width; - - // We split eval into three parts: - // - eval[..LOG_NUM_THREADS] (the first LOG_NUM_THREADS elements) - // - eval[LOG_NUM_THREADS..eval_len_min_packing] (the middle elements) - // - eval[eval_len_min_packing..] (the last log_packing_width elements) - - // The middle elements are the ones which will be computed in parallel. - // The last log_packing_width elements are the ones which will be packed. - - // We make a buffer of elements of size `NUM_THREADS`. - let mut parallel_buffer = EF::ExtensionPacking::zero_vec(NUM_THREADS_PADDED); - let out_chunk_size = out.len() / NUM_THREADS_PADDED; - - // Compute the equality polynomial corresponding to the last log_packing_width elements - // and pack these. - parallel_buffer[0] = packed_eq_poly(&eval[eval_len_min_packing..], scalar); - - // Update the buffer so it contains the evaluations of the equality polynomial - // with respect to parts one and three. - fill_buffer(eval[..LOG_NUM_THREADS].iter().rev(), &mut parallel_buffer); - - // Finally do all computations involving the middle elements in parallel. - out.par_chunks_exact_mut(out_chunk_size) - .zip(parallel_buffer.par_iter()) - .for_each(|(out_chunk, buffer_val)| { - eval_eq_with_packed_scalar::<_, _, INITIALIZED>( - &eval[LOG_NUM_THREADS..(eval.len() - log_packing_width)], - out_chunk, - *buffer_val, - ); - }); + // Split `eval` into [leading `log_chunks` | middle | trailing `log_packing_width`]: the + // trailing vars fold into the per-chunk seed, the leading vars index the chunks, the + // middle runs in parallel. 
+ let seed = packed_eq_poly(&eval[eval.len() - log_packing_width..], scalar); + par_eval_eq( + eval, + out, + log_chunks, + n_chunks, + log_packing_width, + seed, + |middle, out_chunk, buffer_val| { + eval_eq_with_packed_scalar::<_, _, INITIALIZED>(middle, out_chunk, buffer_val); + }, + ); } #[inline] @@ -150,16 +176,8 @@ pub fn compute_eval_eq_packed(eval: &[EF], out: &mu where EF: ExtensionField>, { - // It's possible for this to be called with F = EF (Despite F actually being an extension field). - // - // IMPORTANT: We previously checked here that `packing_width > 1`, - // but this check is **not viable** for Goldilocks on Neon or when not using `target-cpu=native`. - // - // Why? Because Neon SIMD vectors are 128 bits and Goldilocks elements are already 64 bits, - // so no packing happens (width stays 1), and there's no performance advantage. - // - // Be careful: this means code relying on packing optimizations should **not assume** - // `packing_width > 1` is always true. + // `packing_width` may be 1 (e.g. Goldilocks on Neon, or without `target-cpu=native`), + // so nothing here may assume it is > 1. let packing_width = packing_width::(); let log_packing_width = log2_strict_usize(packing_width); @@ -168,12 +186,13 @@ where // If the number of variables is small, there is no need to use // parallelization or packings. - if eval.len() <= log_packing_width + 1 + LOG_NUM_THREADS { - // A basic recursive approach. - let mut output_no_packing = EF::zero_vec(1 << eval.len()); - eval_eq_basic::<_, _, _, false>(eval, &mut output_no_packing, scalar); - out.par_iter_mut() - .zip(output_no_packing.par_chunks_exact(packing_width)) + let (log_chunks, n_chunks) = parallel_split(); + if eval.len() <= log_packing_width + 1 + log_chunks { + // Small case: evaluate unpacked, then pack lanes into `out`. 
+ let mut unpacked = EF::zero_vec(1 << eval.len()); + eval_eq_basic::<_, _, _, false>(eval, &mut unpacked, scalar); + out.iter_mut() + .zip(unpacked.chunks_exact(packing_width)) .for_each(|(out_elem, chunk)| { if INITIALIZED { *out_elem += EF::ExtensionPacking::from_ext_slice(chunk); @@ -181,40 +200,22 @@ where *out_elem = EF::ExtensionPacking::from_ext_slice(chunk); } }); - } else { - let eval_len_min_packing = eval.len() - log_packing_width; - - // We split eval into three parts: - // - eval[..LOG_NUM_THREADS] (the first LOG_NUM_THREADS elements) - // - eval[LOG_NUM_THREADS..eval_len_min_packing] (the middle elements) - // - eval[eval_len_min_packing..] (the last log_packing_width elements) - - // The middle elements are the ones which will be computed in parallel. - // The last log_packing_width elements are the ones which will be packed. - - // We make a buffer of elements of size `NUM_THREADS`. - let mut parallel_buffer = EF::ExtensionPacking::zero_vec(NUM_THREADS_PADDED); - let out_chunk_size = out.len() / NUM_THREADS_PADDED; - - // Compute the equality polynomial corresponding to the last log_packing_width elements - // and pack these. - parallel_buffer[0] = packed_eq_poly(&eval[eval_len_min_packing..], scalar); - - // Update the buffer so it contains the evaluations of the equality polynomial - // with respect to parts one and three. - fill_buffer(eval[..LOG_NUM_THREADS].iter().rev(), &mut parallel_buffer); - - // Finally do all computations involving the middle elements in parallel. - out.par_chunks_exact_mut(out_chunk_size) - .zip(parallel_buffer.par_iter()) - .for_each(|(out_chunk, buffer_val)| { - eval_eq_with_packed_output::<_, _, INITIALIZED>( - &eval[LOG_NUM_THREADS..(eval.len() - log_packing_width)], - out_chunk, - *buffer_val, - ); - }); + return; } + + // See `compute_eval_eq` for the leading/middle/trailing split. 
+ let seed = packed_eq_poly(&eval[eval.len() - log_packing_width..], scalar); + par_eval_eq( + eval, + out, + log_chunks, + n_chunks, + log_packing_width, + seed, + |middle, out_chunk, buffer_val| { + eval_eq_with_packed_output::<_, _, INITIALIZED>(middle, out_chunk, buffer_val); + }, + ); } /// Computes the equality polynomial evaluations efficiently. @@ -240,57 +241,30 @@ where F: Field, EF: ExtensionField, { - // we assume that packing_width is a power of 2. let log_packing_width = log2_strict_usize(F::Packing::WIDTH); - - // Ensure that the output buffer size is correct: - // It should be of size `2^n`, where `n` is the number of variables. debug_assert_eq!(out.len(), 1 << eval.len()); - // If the number of variables is small, there is no need to use - // parallelization or packings. - if eval.len() <= log_packing_width + 1 + LOG_NUM_THREADS { - // A basic recursive approach. + let (log_chunks, n_chunks) = parallel_split(); + if eval.len() <= log_packing_width + 1 + log_chunks { eval_eq_basic::<_, _, _, INITIALIZED>(eval, out, scalar); return; } - let eval_len_min_packing = eval.len() - log_packing_width; - - // We split eval into three parts: - // - eval[..LOG_NUM_THREADS] (the first LOG_NUM_THREADS elements) - // - eval[LOG_NUM_THREADS..eval_len_min_packing] (the middle elements) - // - eval[eval_len_min_packing..] (the last log_packing_width elements) - - // The middle elements are the ones which will be computed in parallel. - // The last log_packing_width elements are the ones which will be packed. - - // We make a buffer of PackedField elements of size `NUM_THREADS`. - // Note that this is a slightly different strategy to `eval_eq` which instead - // uses PackedExtensionField elements. Whilst this involves slightly more mathematical - // operations, it seems to be faster in practice due to less data moving around. 
- let mut parallel_buffer = F::Packing::zero_vec(NUM_THREADS_PADDED); - let out_chunk_size = out.len() / NUM_THREADS_PADDED; - - // Compute the equality polynomial corresponding to the last log_packing_width elements - // and pack these. - parallel_buffer[0] = packed_eq_poly(&eval[eval_len_min_packing..], F::ONE); - - // Update the buffer so it contains the evaluations of the equality polynomial - // with respect to parts one and three. - fill_buffer(eval[..LOG_NUM_THREADS].iter().rev(), &mut parallel_buffer); - - // Finally do all computations involving the middle elements in parallel. - out.par_chunks_exact_mut(out_chunk_size) - .zip(parallel_buffer.par_iter()) - .for_each(|(out_chunk, buffer_val)| { - base_eval_eq_packed::<_, _, INITIALIZED>( - &eval[LOG_NUM_THREADS..(eval.len() - log_packing_width)], - out_chunk, - *buffer_val, - scalar, - ); - }); + // Base-field input: seed the per-chunk buffer with `F::Packing` (not `EF::ExtensionPacking`) + // and apply `scalar` inside the kernel — slightly more ops but less data movement, which is + // faster here in practice. See `compute_eval_eq` for the leading/middle/trailing split. + let seed = packed_eq_poly(&eval[eval.len() - log_packing_width..], F::ONE); + par_eval_eq( + eval, + out, + log_chunks, + n_chunks, + log_packing_width, + seed, + |middle, out_chunk, buffer_val| { + base_eval_eq_packed::<_, _, INITIALIZED>(middle, out_chunk, buffer_val, scalar); + }, + ); } #[inline] @@ -302,24 +276,18 @@ pub fn compute_eval_eq_base_packed( F: Field, EF: ExtensionField, { - // we assume that packing_width is a power of 2. let packing_width = F::Packing::WIDTH; let log_packing_width = log2_strict_usize(packing_width); assert!(log_packing_width <= eval.len()); assert_eq!(out.len(), 1 << (eval.len() - log_packing_width)); - // Ensure that the output buffer size is correct: - // It should be of size `2^n`, where `n` is the number of variables. 
- debug_assert_eq!(out.len(), 1 << (eval.len() - log_packing_width)); - - // If the number of variables is small, there is no need to use - // parallelization or packings. - if eval.len() <= log_packing_width + 1 + LOG_NUM_THREADS { - // A basic recursive approach. - let mut output_no_packing = EF::zero_vec(1 << eval.len()); - eval_eq_basic::<_, _, _, false>(eval, &mut output_no_packing, scalar); - out.par_iter_mut() - .zip(output_no_packing.par_chunks_exact(packing_width)) + let (log_chunks, n_chunks) = parallel_split(); + if eval.len() <= log_packing_width + 1 + log_chunks { + // Small case: evaluate unpacked, then pack lanes into `out`. + let mut unpacked = EF::zero_vec(1 << eval.len()); + eval_eq_basic::<_, _, _, false>(eval, &mut unpacked, scalar); + out.iter_mut() + .zip(unpacked.chunks_exact(packing_width)) .for_each(|(out_elem, chunk)| { if INITIALIZED { *out_elem += EF::ExtensionPacking::from_ext_slice(chunk); @@ -327,45 +295,24 @@ pub fn compute_eval_eq_base_packed( *out_elem = EF::ExtensionPacking::from_ext_slice(chunk); } }); - } else { - let eval_len_min_packing = eval.len() - log_packing_width; - - // We split eval into three parts: - // - eval[..LOG_NUM_THREADS] (the first LOG_NUM_THREADS elements) - // - eval[LOG_NUM_THREADS..eval_len_min_packing] (the middle elements) - // - eval[eval_len_min_packing..] (the last log_packing_width elements) - - // The middle elements are the ones which will be computed in parallel. - // The last log_packing_width elements are the ones which will be packed. - - // We make a buffer of PackedField elements of size `NUM_THREADS`. - // Note that this is a slightly different strategy to `eval_eq` which instead - // uses PackedExtensionField elements. Whilst this involves slightly more mathematical - // operations, it seems to be faster in practice due to less data moving around. 
- let mut parallel_buffer = F::Packing::zero_vec(NUM_THREADS_PADDED); - let out_chunk_size = out.len() / NUM_THREADS_PADDED; - - // Compute the equality polynomial corresponding to the last log_packing_width elements - // and pack these. - parallel_buffer[0] = packed_eq_poly(&eval[eval_len_min_packing..], F::ONE); - - // Update the buffer so it contains the evaluations of the equality polynomial - // with respect to parts one and three. - fill_buffer(eval[..LOG_NUM_THREADS].iter().rev(), &mut parallel_buffer); - - // Finally do all computations involving the middle elements in parallel. - let scalar_packed = EF::ExtensionPacking::from(scalar); - out.par_chunks_exact_mut(out_chunk_size) - .zip(parallel_buffer.par_iter()) - .for_each(|(out_chunk, buffer_val)| { - base_eval_eq_packed_with_packed_output::( - &eval[LOG_NUM_THREADS..(eval.len() - log_packing_width)], - out_chunk, - *buffer_val, - scalar_packed, - ); - }); + return; } + + // Base-field input: seed with `F::Packing` and apply `scalar` in the kernel (less data + // movement — see `compute_eval_eq_base`). See `compute_eval_eq` for the split. + let scalar_packed = EF::ExtensionPacking::from(scalar); + let seed = packed_eq_poly(&eval[eval.len() - log_packing_width..], F::ONE); + par_eval_eq( + eval, + out, + log_chunks, + n_chunks, + log_packing_width, + seed, + |middle, out_chunk, buffer_val| { + base_eval_eq_packed_with_packed_output::(middle, out_chunk, buffer_val, scalar_packed); + }, + ); } #[inline] @@ -412,21 +359,21 @@ pub fn compute_eval_eq_base_packed_batched( }) .collect(); - out.par_chunks_exact_mut(tile_packed_size) - .enumerate() - .for_each(|(tile_idx, out_tile)| { - for (eq_prefix, middle, eq_suffix) in &per_query { - // Here e could precompute the eq poly, trading some memory for less computation - // (2x faster on M4 max, but 2x slower on machines with smaller caches. - // TODO implement both and choose based on cache size?) 
- base_eval_eq_packed_with_packed_output::( - middle, - out_tile, - *eq_suffix, - EF::ExtensionPacking::from(eq_prefix[tile_idx]), - ); - } - }); + // `out` already splits into `2^n_prefix_levels` tiles — many more than there are + // workers — so the pool's task counter load-balances these directly. + parallel::par_chunks_mut(out, tile_packed_size, |tile_idx, out_tile| { + for (eq_prefix, middle, eq_suffix) in &per_query { + // Here e could precompute the eq poly, trading some memory for less computation + // (2x faster on M4 max, but 2x slower on machines with smaller caches. + // TODO implement both and choose based on cache size?) + base_eval_eq_packed_with_packed_output::( + middle, + out_tile, + *eq_suffix, + EF::ExtensionPacking::from(eq_prefix[tile_idx]), + ); + } + }); } /// Fills the `buffer` with evaluations of the equality polynomial @@ -944,39 +891,40 @@ pub fn compute_eval_eq_packed_dual( assert!(log_packing_width <= eval_a.len()); assert_eq!(out.len(), 1 << (eval_a.len() - log_packing_width)); - if eval_a.len() <= log_packing_width + 1 + LOG_NUM_THREADS { + let (log_chunks, n_chunks) = parallel_split(); + if eval_a.len() <= log_packing_width + 1 + log_chunks { let mut output_no_packing = EF::zero_vec(1 << eval_a.len()); eval_eq_basic::<_, _, _, false>(eval_a, &mut output_no_packing, scalar_a); eval_eq_basic::<_, _, _, true>(eval_b, &mut output_no_packing, scalar_b); - out.par_iter_mut() - .zip(output_no_packing.par_chunks_exact(packing_width)) + out.iter_mut() + .zip(output_no_packing.chunks_exact(packing_width)) .for_each(|(out_elem, chunk)| { *out_elem = EF::ExtensionPacking::from_ext_slice(chunk); }); } else { let eval_len_min_packing = eval_a.len() - log_packing_width; - let mut parallel_buffer_a = EF::ExtensionPacking::zero_vec(NUM_THREADS_PADDED); - let mut parallel_buffer_b = EF::ExtensionPacking::zero_vec(NUM_THREADS_PADDED); - let out_chunk_size = out.len() / NUM_THREADS_PADDED; + let mut parallel_buffer_a = 
EF::ExtensionPacking::zero_vec(n_chunks); + let mut parallel_buffer_b = EF::ExtensionPacking::zero_vec(n_chunks); + let out_chunk_size = out.len() / n_chunks; parallel_buffer_a[0] = packed_eq_poly(&eval_a[eval_len_min_packing..], scalar_a); - fill_buffer(eval_a[..LOG_NUM_THREADS].iter().rev(), &mut parallel_buffer_a); + fill_buffer(eval_a[..log_chunks].iter().rev(), &mut parallel_buffer_a); parallel_buffer_b[0] = packed_eq_poly(&eval_b[eval_len_min_packing..], scalar_b); - fill_buffer(eval_b[..LOG_NUM_THREADS].iter().rev(), &mut parallel_buffer_b); - - out.par_chunks_exact_mut(out_chunk_size) - .enumerate() - .for_each(|(i, out_chunk)| { - eval_eq_with_packed_output_dual::, EF>( - &eval_a[LOG_NUM_THREADS..eval_len_min_packing], - &eval_b[LOG_NUM_THREADS..eval_len_min_packing], - out_chunk, - parallel_buffer_a[i], - parallel_buffer_b[i], - ); - }); + fill_buffer(eval_b[..log_chunks].iter().rev(), &mut parallel_buffer_b); + + let middle_a = &eval_a[log_chunks..eval_len_min_packing]; + let middle_b = &eval_b[log_chunks..eval_len_min_packing]; + parallel::par_chunks_mut(out, out_chunk_size, |i, out_chunk| { + eval_eq_with_packed_output_dual::, EF>( + middle_a, + middle_b, + out_chunk, + parallel_buffer_a[i], + parallel_buffer_b[i], + ); + }); } } @@ -1312,7 +1260,7 @@ mod tests { let time = Instant::now(); compute_eval_eq::(&eval, &mut out_3, scalar); let out_3_packed = out_3 - .par_chunks_exact(packing_width) + .chunks_exact(packing_width) .map(>::ExtensionPacking::from_ext_slice) .collect::>(); println!("EXTENSION PACKED AFTER: {:?}", time.elapsed()); @@ -1347,7 +1295,7 @@ mod tests { let time = Instant::now(); compute_eval_eq_base::(&eval, &mut out_3, scalar); let out_3_packed = out_3 - .par_chunks_exact(packing_width) + .chunks_exact(packing_width) .map(>::ExtensionPacking::from_ext_slice) .collect::>(); println!("BASE PACKED AFTER: {:?}", time.elapsed()); @@ -1357,6 +1305,84 @@ mod tests { } } + #[test] + #[ignore = "benchmark; run explicitly with --ignored 
--nocapture"] + fn bench_pool_oversub() { + use std::sync::atomic::Ordering; + use std::time::Instant; + + // Sweep oversubscription log-factors. 0 == one chunk per worker. + const FACTORS: [usize; 6] = [0, 1, 2, 3, 4, 5]; + + let mut rng = StdRng::seed_from_u64(0); + + // Time `f` over `iters` runs after `warmup` discarded runs; report best. + fn timed(warmup: usize, iters: usize, mut f: impl FnMut()) -> std::time::Duration { + for _ in 0..warmup { + f(); + } + let mut best = std::time::Duration::MAX; + for _ in 0..iters { + let t = Instant::now(); + f(); + best = best.min(t.elapsed()); + } + best + } + + // Time `run` at every oversub factor; print ms (best-of-N) per factor, with + // the per-row best marked. Lets us pick a factor robust across machines. + fn sweep(label: &str, n_vars: usize, warmup: usize, iters: usize, mut run: impl FnMut()) { + let restore = PARALLEL_LOG_OVERSUB.load(Ordering::Relaxed); + let times: Vec = FACTORS + .iter() + .map(|&f| { + PARALLEL_LOG_OVERSUB.store(f, Ordering::Relaxed); + timed(warmup, iters, &mut run).as_secs_f64() * 1e3 + }) + .collect(); + PARALLEL_LOG_OVERSUB.store(restore, Ordering::Relaxed); + let best = times.iter().copied().fold(f64::MAX, f64::min); + print!(" {label:>14} n={n_vars:>2} |"); + for &t in × { + let mark = if (t - best).abs() < 1e-9 { '*' } else { ' ' }; + print!(" {t:>6.2}{mark}"); + } + println!(); + } + + print!("\n oversub factor:"); + for f in FACTORS { + print!(" {f:>2}x "); + } + println!(" (ms, best-of-N, * = row best)"); + for n_vars in [18usize, 20, 22, 23, 24] { + let eval_ef: Vec = (0..n_vars).map(|_| rng.random()).collect(); + let eval_f: Vec = (0..n_vars).map(|_| rng.random()).collect(); + let scalar: EF = rng.random(); + let (warmup, iters) = if n_vars >= 23 { (1, 3) } else { (2, 10) }; + + // Correctness: the factor must not change the result. 
+ PARALLEL_LOG_OVERSUB.store(0, Ordering::Relaxed); + let mut ref_out = EF::zero_vec(1 << n_vars); + compute_eval_eq::(&eval_ef, &mut ref_out, scalar); + for f in FACTORS { + PARALLEL_LOG_OVERSUB.store(f, Ordering::Relaxed); + let mut out = EF::zero_vec(1 << n_vars); + compute_eval_eq::(&eval_ef, &mut out, scalar); + assert_eq!(ref_out, out, "oversub {f} changed output (ext) at n={n_vars}"); + } + + let mut out = EF::zero_vec(1 << n_vars); + sweep("eval_eq (ext)", n_vars, warmup, iters, || { + compute_eval_eq::(&eval_ef, &mut out, scalar); + }); + sweep("eval_eq_base", n_vars, warmup, iters, || { + compute_eval_eq_base::(&eval_f, &mut out, scalar); + }); + } + } + #[test] fn test_compute_eval_eq_packed_dual() { let packing_width = ::Packing::WIDTH; diff --git a/crates/backend/poly/src/evals.rs b/crates/backend/poly/src/evals.rs index 7e0e07b4f..c1168470f 100644 --- a/crates/backend/poly/src/evals.rs +++ b/crates/backend/poly/src/evals.rs @@ -1,8 +1,8 @@ use crate::*; use crate::{EFPacking, PF}; +use ::utils::log2_ceil_usize; use field::{ExtensionField, Field, PrimeCharacteristicRing}; use itertools::Itertools; -use rayon::{join, prelude::*}; use std::borrow::Borrow; pub trait EvaluationsList { @@ -87,7 +87,11 @@ pub fn scale_poly>(poly: &[F], factor: EF) -> Ve if poly.len() < PARALLEL_THRESHOLD { poly.iter().map(|&e| factor * e).collect() } else { - poly.par_iter().map(|&e| factor * e).collect() + let mut out: Vec = unsafe { uninitialized_vec(poly.len()) }; + parallel::par_for_each_mut(&mut out, |i, o| { + *o = factor * poly[i]; + }); + out } } @@ -257,20 +261,23 @@ where // // This chain of operations computes the regrouped sum: // Σ_{v_high} eq(v_high, p_high) * (Σ_{v_low} f(v_high, v_low) * eq(v_low, p_low)) - evals - .par_chunks(left.len()) - .zip_eq(right.par_iter()) - .map(|(part, &c)| { + let left_len = left.len(); + parallel::map_reduce( + right.len(), + || Res::ZERO, + |i| { + let part = &evals[i * left_len..][..left_len]; // This is the inner sum: a dot 
product between the evaluation chunk and the `left` basis values. mul_res_point( part.iter() .zip_eq(left.iter()) .map(|(&a, &b)| mul_coeffs_point(a, b)) .sum::(), - c, + right[i], ) - }) - .sum() + }, + |a, b| a + b, + ) } else { evals .chunks(left.len()) @@ -290,62 +297,77 @@ where } else { // For moderately sized inputs (5 to 19 variables), use the recursive strategy. // - // Split the evaluations into two halves, corresponding to the first variable being 0 or 1. - let (f0, f1) = evals.split_at(evals.len() / 2); - - // Recursively evaluate on the two smaller hypercubes. - let (f0_eval, f1_eval) = { - // Only spawn parallel tasks if the subproblem is large enough to overcome - // the overhead of threading. - let work_size: usize = (1 << 15) / std::mem::size_of::(); - if evals.len() > work_size && PARALLEL { - join( - || { - eval_multilinear_generic::<_, _, _, _, _, _, PARALLEL>( - f0, - tail, - mul_coeffs_point, - add_res_coeffs, - mul_res_point, - ) - }, - || { - eval_multilinear_generic::<_, _, _, _, _, _, PARALLEL>( - f1, - tail, - mul_coeffs_point, - add_res_coeffs, - mul_res_point, - ) - }, - ) - } else { - // For smaller subproblems, execute sequentially. - ( - eval_multilinear_generic::<_, _, _, _, _, _, false>( - f0, - tail, - mul_coeffs_point, - add_res_coeffs, - mul_res_point, - ), - eval_multilinear_generic::<_, _, _, _, _, _, false>( - f1, - tail, - mul_coeffs_point, - add_res_coeffs, - mul_res_point, - ), - ) - } - }; - // Perform the final linear interpolation for the first variable `x`. - f0_eval + mul_res_point(f1_eval - f0_eval, *x) + // Only spawn parallel tasks if the subproblem is large enough to overcome + // the overhead of threading. 
+ let work_size: usize = (1 << 15) / std::mem::size_of::(); + if evals.len() > work_size && PARALLEL { + // Flat fan-out: peel the `n_split` leading variables into `2^n_split` + // independent subproblems, evaluate each over the remaining coordinates + // sequentially across the pool, then interpolate the partial results over + // the leading coordinates. Equivalent to the recursive `join` split, but + // flat so the in-house pool can parallelize it (nested dispatches fall + // back to sequential, so a recursive split would lose all parallelism). + let log_work = log2_ceil_usize(work_size.max(2)); + let n_split = point.len().saturating_sub(log_work).max(1); + let (lead, sub_point) = point.split_at(n_split); + let n_chunks = 1 << n_split; + let chunk = evals.len() >> n_split; + let mut partials = vec![Res::ZERO; n_chunks]; + parallel::par_chunks_mut(&mut partials, 1, |j, slot| { + slot[0] = eval_multilinear_generic::<_, _, _, _, _, _, false>( + &evals[j * chunk..][..chunk], + sub_point, + mul_coeffs_point, + add_res_coeffs, + mul_res_point, + ); + }); + interpolate_res(&partials, lead, mul_res_point) + } else { + let (f0, f1) = evals.split_at(evals.len() / 2); + let f0_eval = eval_multilinear_generic::<_, _, _, _, _, _, false>( + f0, + tail, + mul_coeffs_point, + add_res_coeffs, + mul_res_point, + ); + let f1_eval = eval_multilinear_generic::<_, _, _, _, _, _, false>( + f1, + tail, + mul_coeffs_point, + add_res_coeffs, + mul_res_point, + ); + // Perform the final linear interpolation for the first variable `x`. + f0_eval + mul_res_point(f1_eval - f0_eval, *x) + } } } } } +/// Multilinear interpolation of `values` (the `2^point.len()` hypercube evaluations of a +/// function, indexed lexicographically) at `point`, using only `Res` arithmetic and the +/// `mul_res_point` scaling. Used to recombine the partial results of the flat parallel +/// fan-out in [`eval_multilinear_generic`]. 
+fn interpolate_res(values: &[Res], point: &[Point], mul_res_point: &MRP) -> Res +where + Point: Field, + Res: Copy + PrimeCharacteristicRing, + MRP: Fn(Res, Point) -> Res, +{ + match point { + [] => values[0], + [x, tail @ ..] => { + let (low, high) = values.split_at(values.len() / 2); + let p0 = interpolate_res(low, tail, mul_res_point); + let p1 = interpolate_res(high, tail, mul_res_point); + p0 + mul_res_point(p1 - p0, *x) + } + } +} + #[cfg(test)] mod tests { use std::time::Instant; diff --git a/crates/backend/poly/src/mle/mle_single_ref.rs b/crates/backend/poly/src/mle/mle_single_ref.rs index 61d607d76..269884fdf 100644 --- a/crates/backend/poly/src/mle/mle_single_ref.rs +++ b/crates/backend/poly/src/mle/mle_single_ref.rs @@ -119,13 +119,15 @@ impl<'a, EF: ExtensionField>> MleRef<'a, EF> { pub fn fold(&self, alpha: EF) -> MleOwned { match self { - Self::Base(pols) => MleOwned::Extension(fold_multilinear(pols, alpha, &|a, b| b * a)), - Self::Extension(pols) => MleOwned::Extension(fold_multilinear(pols, alpha, &|a, b| b * a)), + Self::Base(pols) => MleOwned::Extension(fold_multilinear(pols, alpha, &|a, b| b * a, false)), + Self::Extension(pols) => MleOwned::Extension(fold_multilinear(pols, alpha, &|a, b| b * a, false)), Self::BasePacked(pols) => { let alpha_packed = EFPacking::::from(alpha); - MleOwned::ExtensionPacked(fold_multilinear(pols, alpha_packed, &|a, b| b * a)) + MleOwned::ExtensionPacked(fold_multilinear(pols, alpha_packed, &|a, b| b * a, false)) + } + Self::ExtensionPacked(pols) => { + MleOwned::ExtensionPacked(fold_multilinear(pols, alpha, &|a, b| a * b, false)) } - Self::ExtensionPacked(pols) => MleOwned::ExtensionPacked(fold_multilinear(pols, alpha, &|a, b| a * b)), } } } diff --git a/crates/backend/poly/src/utils.rs b/crates/backend/poly/src/utils.rs index 5bb5fb1b4..cc729bc74 100644 --- a/crates/backend/poly/src/utils.rs +++ b/crates/backend/poly/src/utils.rs @@ -1,14 +1,9 @@ use std::{ mem::ManuallyDrop, - ops::{Add, Range, Sub}, + ops::{Add, 
Sub}, }; use field::*; -use rayon::{ - iter::Zip, - prelude::*, - slice::{Iter, IterMut}, -}; use crate::{EFPacking, PF, PFPacking}; @@ -26,9 +21,9 @@ pub fn pack_extension>>(slice: &[EF]) -> Vec>>(vec: &[EFPacking]) -> Ve write(chunk, x); } } else { - out.par_chunks_exact_mut(width) - .zip(vec.par_iter()) - .for_each(|(chunk, x)| write(chunk, x)); + // One pool task per group of `group` packed elements, each writing `group * width` + // contiguous output scalars from a disjoint slice of `vec`. + let group = parallel::recommended_chunk_size(vec.len()); + parallel::par_chunks_mut(&mut out, group * width, |ci, out_chunk| { + for (k, sub) in out_chunk.chunks_exact_mut(width).enumerate() { + write(sub, &vec[ci * group + k]); + } + }); } out } @@ -67,31 +67,24 @@ pub const fn must_unpack_multilinears(n_vars: usize) -> bool { n_vars <= 1 + packing_log_width::() } -pub fn batch_fold_multilinears< - EF: PrimeCharacteristicRing + Copy + Send + Sync, - IF: Copy + Sub + Send + Sync, - OF: Copy + Add + Send + Sync, - F: Fn(IF, EF) -> OF + Sync + Send, ->( - polys: &[&[IF]], - alpha: EF, - mul_if_of: F, -) -> Vec> { - let total_size: usize = polys.iter().map(|p| p.len()).sum(); - if total_size < PARALLEL_THRESHOLD { - polys - .iter() - .map(|poly| fold_multilinear(poly, alpha, &mul_if_of)) - .collect() +/// Fill `len` output slots with `compute(i)`, parallelizing via the pool when the work is +/// large enough. `seq` forces the sequential path: the batched wrappers below dispatch one +/// pool task per poly, so their inner fold must not nest a parallel dispatch (which would +/// panic in [`parallel`]). 
+#[inline] +fn fold_fill OF + Sync>(len: usize, seq: bool, compute: C) -> Vec { + let mut res = unsafe { uninitialized_vec(len) }; + if seq || len < PARALLEL_THRESHOLD { + for (i, r) in res.iter_mut().enumerate() { + *r = compute(i); + } } else { - polys - .par_iter() - .map(|poly| fold_multilinear(poly, alpha, &mul_if_of)) - .collect() + parallel::par_for_each_mut(&mut res, |i, r| *r = compute(i)); } + res } -pub fn fold_multilinear_lsb< +fn fold_multilinear_lsb< EF: PrimeCharacteristicRing + Copy + Send + Sync, IF: Copy + Sub + Send + Sync, OF: Copy + Add + Send + Sync, @@ -100,20 +93,14 @@ pub fn fold_multilinear_lsb< m: &[IF], alpha: EF, mul_if_of: &Mul, + seq: bool, ) -> Vec { - let new_size = m.len() / 2; - let mut res = unsafe { uninitialized_vec(new_size) }; - let compute = |(c, r_v): (&[IF], &mut OF)| { - *r_v = mul_if_of(c[1] - c[0], alpha) + c[0]; - }; - if new_size < PARALLEL_THRESHOLD { - m.chunks_exact(2).zip(res.iter_mut()).for_each(compute); - } else { - m.par_chunks_exact(2).zip(res.par_iter_mut()).for_each(compute); - } - res + fold_fill(m.len() / 2, seq, |j| { + mul_if_of(m[2 * j + 1] - m[2 * j], alpha) + m[2 * j] + }) } +/// Fold `m` at variable `bit`. `seq` forces sequential execution (see [`fold_fill`]). 
pub fn fold_multilinear_at_bit< EF: PrimeCharacteristicRing + Copy + Send + Sync, IF: Copy + Sub + Send + Sync, @@ -124,40 +111,24 @@ pub fn fold_multilinear_at_bit< alpha: EF, bit: usize, mul_if_of: &Mul, + seq: bool, ) -> Vec { - let new_size = m.len() / 2; assert!(m.len() >= 2 * (1 << bit), "bit out of range for slice length"); - if bit == 0 { - return fold_multilinear_lsb(m, alpha, mul_if_of); + return fold_multilinear_lsb(m, alpha, mul_if_of, seq); } - let stride = 1usize << bit; let lo_mask = stride - 1; - let mut res = unsafe { uninitialized_vec(new_size) }; - - let compute = |new_j: usize| { + fold_fill(m.len() / 2, seq, |new_j| { let i_hi = new_j >> bit; let i_lo = new_j & lo_mask; let i0 = (i_hi << (bit + 1)) | i_lo; let i1 = i0 | stride; mul_if_of(m[i1] - m[i0], alpha) + m[i0] - }; - - if new_size < PARALLEL_THRESHOLD { - for (new_j, res_v) in res.iter_mut().enumerate() { - *res_v = compute(new_j); - } - } else { - (0..new_size) - .into_par_iter() - .with_min_len(PARALLEL_THRESHOLD) - .map(compute) - .collect_into_vec(&mut res); - } - res + }) } +/// Fold `m` at its top variable. `seq` forces sequential execution (see [`fold_fill`]). 
pub fn fold_multilinear< EF: PrimeCharacteristicRing + Copy + Send + Sync, IF: Copy + Sub + Send + Sync, @@ -167,22 +138,35 @@ pub fn fold_multilinear< m: &[IF], alpha: EF, mul_if_of: &F, + seq: bool, ) -> Vec { let new_size = m.len() / 2; - let mut res = unsafe { uninitialized_vec(new_size) }; + fold_fill(new_size, seq, |i| mul_if_of(m[i + new_size] - m[i], alpha) + m[i]) +} - if new_size < PARALLEL_THRESHOLD { - for i in 0..new_size { - res[i] = mul_if_of(m[i + new_size] - m[i], alpha) + m[i]; - } +pub fn batch_fold_multilinears< + EF: PrimeCharacteristicRing + Copy + Send + Sync, + IF: Copy + Sub + Send + Sync, + OF: Copy + Add + Send + Sync, + F: Fn(IF, EF) -> OF + Sync + Send, +>( + polys: &[&[IF]], + alpha: EF, + mul_if_of: F, +) -> Vec> { + let total_size: usize = polys.iter().map(|p| p.len()).sum(); + if total_size < PARALLEL_THRESHOLD { + polys + .iter() + .map(|poly| fold_multilinear(poly, alpha, &mul_if_of, true)) + .collect() } else { - (0..new_size) - .into_par_iter() - .with_min_len(PARALLEL_THRESHOLD) - .map(|i| mul_if_of(m[i + new_size] - m[i], alpha) + m[i]) - .collect_into_vec(&mut res); + let mut out: Vec> = (0..polys.len()).map(|_| Vec::new()).collect(); + parallel::par_chunks_mut(&mut out, 1, |i, slot| { + slot[0] = fold_multilinear(polys[i], alpha, &mul_if_of, true); + }); + out } - res } pub fn batch_fold_multilinears_at_bit< @@ -196,17 +180,19 @@ pub fn batch_fold_multilinears_at_bit< bit: usize, mul_if_of: F, ) -> Vec> { + // See `batch_fold_multilinears`: one task per poly, inner fold forced sequential. 
let total_size: usize = polys.iter().map(|p| p.len()).sum(); if total_size < PARALLEL_THRESHOLD { polys .iter() - .map(|poly| fold_multilinear_at_bit(poly, alpha, bit, &mul_if_of)) + .map(|poly| fold_multilinear_at_bit(poly, alpha, bit, &mul_if_of, true)) .collect() } else { - polys - .par_iter() - .map(|poly| fold_multilinear_at_bit(poly, alpha, bit, &mul_if_of)) - .collect() + let mut out: Vec> = (0..polys.len()).map(|_| Vec::new()).collect(); + parallel::par_chunks_mut(&mut out, 1, |i, slot| { + slot[0] = fold_multilinear_at_bit(polys[i], alpha, bit, &mul_if_of, true); + }); + out } } @@ -281,54 +267,6 @@ pub fn split_at_mut_many<'a, A>(slice: &'a mut [A], indices: &[usize]) -> Vec<&' result } -// Parallel - -#[allow(clippy::type_complexity)] -pub fn par_iter_split_4<'a, A: Sync + Send>( - u: &'a [A], -) -> Zip, Iter<'a, A>>, Zip, Iter<'a, A>>> { - let n = u.len(); - assert!(n.is_multiple_of(4)); - let [u_ll, u_lr, u_rl, u_rr] = split_at_many(u, &[n / 4, n / 2, 3 * n / 4]).try_into().ok().unwrap(); - (u_ll.par_iter().zip(u_lr)).zip(u_rl.par_iter().zip(u_rr.par_iter())) -} - -pub fn par_iter_split_2<'a, A: Sync + Send>(u: &'a [A]) -> Zip, Iter<'a, A>> { - par_iter_split_2_capped(u, 0..u.len() / 2) -} - -pub fn par_iter_split_2_capped<'a, A: Sync + Send>(u: &'a [A], range: Range) -> Zip, Iter<'a, A>> { - let n = u.len(); - assert!(n.is_multiple_of(2)); - let (u_left, u_right) = u.split_at(n / 2); - u_left[range.clone()].par_iter().zip(u_right[range.clone()].par_iter()) -} - -pub fn par_iter_mut_split_2<'a, A: Sync + Send>(u: &'a mut [A]) -> Zip, IterMut<'a, A>> { - par_iter_mut_split_2_capped(u, 0..u.len() / 2) -} - -pub fn par_iter_mut_split_2_capped<'a, A: Sync + Send>( - u: &'a mut [A], - range: Range, -) -> Zip, IterMut<'a, A>> { - let n = u.len(); - assert!(n.is_multiple_of(2)); - let (u_left, u_right) = u.split_at_mut(n / 2); - u_left[range.clone()].par_iter_mut().zip(u_right[range].par_iter_mut()) -} - -#[allow(clippy::type_complexity)] -pub fn 
par_zip_fold_2<'a, 'b, A: Sync + Send, B: Sync + Send>( - u: &'a [A], - folded: &'b mut [B], -) -> Zip, Iter<'a, A>>, Zip, Iter<'a, A>>>, Zip, IterMut<'b, B>>> { - let n = u.len(); - assert!(n.is_multiple_of(4)); - assert_eq!(folded.len(), n / 2); - par_iter_split_4(u).zip(par_iter_mut_split_2(folded)) -} - // Sequential pub fn iter_split_2(u: &[A]) -> impl Iterator { diff --git a/crates/backend/src/lib.rs b/crates/backend/src/lib.rs index cbd44fb2b..f4cc2d18f 100644 --- a/crates/backend/src/lib.rs +++ b/crates/backend/src/lib.rs @@ -2,9 +2,8 @@ pub use air::*; pub use fiat_shamir::*; pub use field::*; pub use koala_bear::*; +pub use parallel; pub use poly::*; -pub use rayon; -pub use rayon::prelude::*; pub use sumcheck::*; pub use symetric::*; pub use utils::*; diff --git a/crates/backend/sumcheck/Cargo.toml b/crates/backend/sumcheck/Cargo.toml index 91085f352..1d5f486ca 100644 --- a/crates/backend/sumcheck/Cargo.toml +++ b/crates/backend/sumcheck/Cargo.toml @@ -8,8 +8,8 @@ field = { path = "../field", package = "mt-field" } air = { path = "../air", package = "mt-air" } poly = { path = "../poly", package = "mt-poly" } fiat-shamir = { path = "../fiat-shamir", package = "mt-fiat-shamir" } +parallel.workspace = true tracing.workspace = true -rayon.workspace = true [dev-dependencies] koala-bear = { path = "../koala-bear", package = "mt-koala-bear" } diff --git a/crates/backend/sumcheck/src/product_computation.rs b/crates/backend/sumcheck/src/product_computation.rs index 2828af039..4f93ad176 100644 --- a/crates/backend/sumcheck/src/product_computation.rs +++ b/crates/backend/sumcheck/src/product_computation.rs @@ -1,7 +1,6 @@ use fiat_shamir::*; use field::*; use poly::*; -use rayon::prelude::*; use tracing::instrument; use crate::{SumcheckComputation, sumcheck_prove_many_rounds}; @@ -146,15 +145,21 @@ pub fn compute_product_sumcheck_polynomial< (a0 + b0, a2 + b2) }) } else { - pol_0[..n / 2] - .par_iter() - .zip(pol_0[n / 2..].par_iter()) - .zip(pol_1[..n / 
2].par_iter().zip(pol_1[n / 2..].par_iter())) - .map(sumcheck_quadratic) - .reduce( - || (EFPacking::ZERO, EFPacking::ZERO), - |(a0, a2), (b0, b2)| (a0 + b0, a2 + b2), - ) + // Per-worker in-place accumulation: each worker folds the contiguous range it + // claims straight into its own `(c0, c2)` accumulator (no per-chunk tuple to build + // and reduce, worker-slot lookup amortized once per batch by `for_each_chunk`). + let half = n / 2; + parallel::map_reduce_with_state( + half, + || (), + || (EFPacking::ZERO, EFPacking::ZERO), + |(), acc, i| { + let (b0, b2) = sumcheck_quadratic(((&pol_0[i], &pol_0[half + i]), (&pol_1[i], &pol_1[half + i]))); + acc.0 += b0; + acc.1 += b2; + }, + |(a0, a2), (b0, b2)| (a0 + b0, a2 + b2), + ) }; let c0 = decompose(c0_packed).into_iter().sum::(); @@ -186,15 +191,17 @@ pub fn compute_product_sumcheck_polynomial_base_ext_packed< let chunk_size = 1024; - let (c0_acc, c2_acc) = pol_0[..half] - .par_chunks(chunk_size) - .zip(pol_0[half..].par_chunks(chunk_size)) - .zip( - pol_1[..half] - .par_chunks(chunk_size) - .zip(pol_1[half..].par_chunks(chunk_size)), - ) - .map(|((b_lo, b_hi), (e_lo, e_hi))| { + let n_chunks = half.div_ceil(chunk_size); + let (c0_acc, c2_acc) = parallel::map_reduce( + n_chunks, + || ([0u128; DIM], [0i128; DIM]), + |c| { + let start = c * chunk_size; + let end = (start + chunk_size).min(half); + let b_lo = &pol_0[start..end]; + let b_hi = &pol_0[half + start..half + end]; + let e_lo = &pol_1[start..end]; + let e_hi = &pol_1[half + start..half + end]; let mut c0 = [0u128; DIM]; let mut c2 = [0i128; DIM]; for i in 0..b_lo.len() { @@ -216,17 +223,15 @@ pub fn compute_product_sumcheck_polynomial_base_ext_packed< } } (c0, c2) - }) - .reduce( - || ([0u128; DIM], [0i128; DIM]), - |(mut a0, mut a2): Acc, (b0, b2): Acc| { - for j in 0..DIM { - a0[j] += b0[j]; - a2[j] += b2[j]; - } - (a0, a2) - }, - ); + }, + |(mut a0, mut a2): Acc, (b0, b2): Acc| { + for j in 0..DIM { + a0[j] += b0[j]; + a2[j] += b2[j]; + } + (a0, a2) + }, + 
); let c0 = EF::from_basis_coefficients_fn(|j| F::reduce_product_sum(c0_acc[j])); let c2 = EF::from_basis_coefficients_fn(|j| F::reduce_signed_product_sum(c2_acc[j])); @@ -283,13 +288,41 @@ pub fn fold_and_compute_product_sumcheck_polynomial< (a0 + b0, a2 + b2) }) } else { - par_zip_fold_2(pol_0, &mut pol_0_folded) - .zip(par_zip_fold_2(pol_1, &mut pol_1_folded)) - .map(|(p0, p1)| process_element(p0, p1)) - .reduce( - || (EFPacking::ZERO, EFPacking::ZERO), - |(a0, a2), (b0, b2)| (a0 + b0, a2 + b2), - ) + // Fused single pass with per-worker in-place accumulation: fold both polynomials + // (writing the disjoint `i` / `quarter + i` output slots) and accumulate the + // per-index quadratic straight into the worker's `(c0, c2)` — no per-chunk tuple. + let quarter = n / 4; + let p0f = parallel::SendPtr(pol_0_folded.as_mut_ptr()); + let p1f = parallel::SendPtr(pol_1_folded.as_mut_ptr()); + parallel::map_reduce_with_state( + quarter, + || (), + || (EFPacking::ZERO, EFPacking::ZERO), + |(), acc, i| { + let diff_0 = pol_0[2 * quarter + i] - pol_0[i]; + let diff_1 = pol_0[3 * quarter + i] - pol_0[quarter + i]; + let x_0 = prev_folding_factor_packed * diff_0 + pol_0[i]; + let x_1 = prev_folding_factor_packed * diff_1 + pol_0[quarter + i]; + + let y_0 = prev_folding_factor_packed * (pol_1[2 * quarter + i] - pol_1[i]) + pol_1[i]; + let y_1 = + prev_folding_factor_packed * (pol_1[3 * quarter + i] - pol_1[quarter + i]) + pol_1[quarter + i]; + + // SAFETY: distinct `i` write disjoint slots `i` and `quarter + i` in + // `[0, n/2)`; the dispatcher keeps both buffers borrowed for the call. 
+ unsafe { + *p0f.add(i) = x_0; + *p0f.add(quarter + i) = x_1; + *p1f.add(i) = y_0; + *p1f.add(quarter + i) = y_1; + } + + let (b0, b2) = sumcheck_quadratic(((&x_0, &x_1), (&y_0, &y_1))); + acc.0 += b0; + acc.1 += b2; + }, + |(a0, a2), (b0, b2)| (a0 + b0, a2 + b2), + ) }; let c0 = decompose(c0_packed).into_iter().sum::(); diff --git a/crates/backend/sumcheck/src/sc_computation.rs b/crates/backend/sumcheck/src/sc_computation.rs index 6f589bb7b..56d63d0b3 100644 --- a/crates/backend/sumcheck/src/sc_computation.rs +++ b/crates/backend/sumcheck/src/sc_computation.rs @@ -2,10 +2,16 @@ use crate::*; use air::*; use field::*; use poly::*; -use rayon::prelude::*; use std::any::TypeId; use std::ops::{Add, AddAssign, Mul, MulAssign, Sub}; +fn add_assign_vec(mut a: Vec, b: Vec) -> Vec { + for (x, y) in a.iter_mut().zip(b) { + *x += y; + } + a +} + pub trait SumcheckComputation>>: Sync { type ExtraData: Send + Sync + 'static; @@ -58,44 +64,12 @@ where } } -fn parallel_sum(size: usize, n: usize, init_state: IS, compute_iteration: F) -> Vec -where - T: PrimeCharacteristicRing + Send + Sync, - S: Send, - IS: Fn() -> S + Sync + Send, - F: Fn(&mut S, usize) -> Vec + Sync + Send, -{ - let accumulate = |mut acc: Vec, sums: Vec| { - for (j, sum) in sums.into_iter().enumerate() { - acc[j] += sum; - } - acc - }; - - if size < PARALLEL_THRESHOLD { - let mut state = init_state(); - (0..size).fold(T::zero_vec(n), |acc, i| { - accumulate(acc, compute_iteration(&mut state, i)) - }) - } else { - (0..size) - .into_par_iter() - .map_init(&init_state, |state, i| compute_iteration(state, i)) - .reduce(|| T::zero_vec(n), accumulate) - } -} - fn build_evals>>( sums: impl IntoIterator, missing_mul_factor: Option, ) -> Vec { sums.into_iter() - .map(|mut sum| { - if let Some(factor) = missing_mul_factor { - sum *= factor; - } - sum - }) + .map(|sum| missing_mul_factor.map_or(sum, |f| sum * f)) .collect() } @@ -425,49 +399,49 @@ where + MulAssign, SC: SumcheckComputation, { + // Per-worker scratch: 
`rows` (the [lo, diff, hi] triples) and `point` (the + // evaluation point handed to `eval_fn`) are reused across every task a worker + // owns, so the hot loop allocates nothing. `acc` (length `degree`) is the + // per-worker partial sum. let n_mult = multilinears.len(); - let compute_at = |(rows, point): &mut (Vec<[IF; 3]>, Vec), i: usize| -> Vec { - let eq_val = eq_at(i); - - rows.clear(); - rows.extend(multilinears.iter().map(|m| { - let lo = m[i]; - let hi = m[i + fold_size]; - [lo, hi - lo, hi] - })); - - // z = 0 - point.clear(); - point.extend(rows.iter().map(|row| row[0])); - let mut eval_0 = eval_fn(computation, point, extra_data); - if let Some(eq) = eq_val { - eval_0 *= eq; - } - - let mut evals = Vec::with_capacity(degree); - evals.push(eval_0); - - // z = 2, 3, ... - for _ in 1..degree { - for [_, diff_hi_lo, running] in rows.iter_mut() { - *running += *diff_hi_lo; - } + let sums = parallel::map_reduce_with_state( + fold_size, + || (Vec::<[IF; 3]>::with_capacity(n_mult), Vec::::with_capacity(n_mult)), + || EFT::zero_vec(degree), + |(rows, point), acc, i| { + let eq_val = eq_at(i); + + rows.clear(); + rows.extend(multilinears.iter().map(|m| { + let lo = m[i]; + let hi = m[i + fold_size]; + [lo, hi - lo, hi] + })); + + // z = 0 point.clear(); - point.extend(rows.iter().map(|row| row[2])); - let mut eval = eval_fn(computation, point, extra_data); + point.extend(rows.iter().map(|row| row[0])); + let mut eval_0 = eval_fn(computation, point, extra_data); if let Some(eq) = eq_val { - eval *= eq; + eval_0 *= eq; } - evals.push(eval); - } - evals - }; + acc[0] += eval_0; - let sums = parallel_sum( - fold_size, - degree, - || (Vec::<[IF; 3]>::with_capacity(n_mult), Vec::::with_capacity(n_mult)), - compute_at, + // z = 2, 3, ... 
+ for acc_d in acc.iter_mut().skip(1) { + for [_, diff_hi_lo, running] in rows.iter_mut() { + *running += *diff_hi_lo; + } + point.clear(); + point.extend(rows.iter().map(|row| row[2])); + let mut eval = eval_fn(computation, point, extra_data); + if let Some(eq) = eq_val { + eval *= eq; + } + *acc_d += eval; + } + }, + add_assign_vec, ); let unpacked_sums = sums.into_iter().map(&unpack_sum); build_evals(unpacked_sums, missing_mul_factor) @@ -500,54 +474,54 @@ where .map(|_| FT::zero_vec(prev_folded_size)) .collect(); + // Per-worker scratch: `rows_f` (the [lo, diff, hi] triples) and `point` (the + // evaluation point handed to `eval_fn`) are reused across every task a worker + // owns, so the hot loop allocates nothing. `acc` (length `degree`) is the + // per-worker partial sum. let n_mult = multilinears.len(); - let compute_iteration = |(rows_f, point): &mut (Vec<[FT; 3]>, Vec), i: usize| -> Vec { - let eq_mle_eval = eq_at(i); - - rows_f.clear(); - rows_f.extend(multilinears.iter().enumerate().map(|(j, m)| { - let lo = fold_f(m, i); - let hi = fold_f(m, i + compute_fold_size); - unsafe { - let ptr = folded_f[j].as_ptr() as *mut FT; - *ptr.add(i) = lo; - *ptr.add(i + compute_fold_size) = hi; - } - [lo, hi - lo, hi] - })); - - // z = 0 - point.clear(); - point.extend(rows_f.iter().map(|row| row[0])); - let mut eval_0 = eval_fn(computation, point, extra_data); - if let Some(eq) = eq_mle_eval { - eval_0 *= eq; - } - - let mut evals = Vec::with_capacity(degree); - evals.push(eval_0); + let sums = parallel::map_reduce_with_state( + compute_fold_size, + || (Vec::<[FT; 3]>::with_capacity(n_mult), Vec::::with_capacity(n_mult)), + || FT::zero_vec(degree), + |(rows_f, point), acc, i| { + let eq_mle_eval = eq_at(i); + + rows_f.clear(); + rows_f.extend(multilinears.iter().enumerate().map(|(j, m)| { + let lo = fold_f(m, i); + let hi = fold_f(m, i + compute_fold_size); + unsafe { + let ptr = folded_f[j].as_ptr() as *mut FT; + *ptr.add(i) = lo; + *ptr.add(i + compute_fold_size) = 
hi; + } + [lo, hi - lo, hi] + })); - // z = 2, 3, ... - for _ in 1..degree { - for [_, diff_hi_lo, running] in rows_f.iter_mut() { - *running += *diff_hi_lo; - } + // z = 0 point.clear(); - point.extend(rows_f.iter().map(|row| row[2])); - let mut eval = eval_fn(computation, point, extra_data); + point.extend(rows_f.iter().map(|row| row[0])); + let mut eval_0 = eval_fn(computation, point, extra_data); if let Some(eq) = eq_mle_eval { - eval *= eq; + eval_0 *= eq; } - evals.push(eval); - } - evals - }; + acc[0] += eval_0; - let sums = parallel_sum( - compute_fold_size, - degree, - || (Vec::<[FT; 3]>::with_capacity(n_mult), Vec::::with_capacity(n_mult)), - compute_iteration, + // z = 2, 3, ... + for acc_d in acc.iter_mut().skip(1) { + for [_, diff_hi_lo, running] in rows_f.iter_mut() { + *running += *diff_hi_lo; + } + point.clear(); + point.extend(rows_f.iter().map(|row| row[2])); + let mut eval = eval_fn(computation, point, extra_data); + if let Some(eq) = eq_mle_eval { + eval *= eq; + } + *acc_d += eval; + } + }, + add_assign_vec, ); let unpacked_sums = sums.into_iter().map(&unpack_sum); (build_evals(unpacked_sums, missing_mul_factor), wrap_f(folded_f)) @@ -575,65 +549,60 @@ where let eq_lo = &split_eq.eq_lo; let eq_hi = &split_eq.eq_hi_packed; - let zero = || EFPacking::::zero_vec(degree); - let accumulate = |mut acc: Vec>, vals: Vec>| -> Vec> { - for (a, v) in acc.iter_mut().zip(vals.iter()) { - *a += *v; - } - acc - }; - + // Per-worker scratch reused across every `b_lo` task: `rows` ([lo, diff, hi] + // triples), `point` (handed to `eval_fn`), and `block_acc` (per-`b_lo` partial + // sum, scaled by `eq_lo` before folding into the worker accumulator `acc`). 
let n_mult = multilinears.len(); - let sums: Vec> = (0..n_lo) - .into_par_iter() - .map_init( - || { - ( - Vec::<[EFPacking; 3]>::with_capacity(n_mult), - Vec::>::with_capacity(n_mult), - ) - }, - |(rows, point), b_lo| { - let eq_lo_bc = EFPacking::::from(eq_lo[b_lo]); - let base = b_lo << log_packed_hi; - let mut block_acc = zero(); - for k in 0..packed_hi { - let i = base + k; - let eq_val = eq_hi[k]; - - rows.clear(); - rows.extend(multilinears.iter().map(|m| { - let lo = m[i]; - let hi = m[i + fold_size]; - [lo, hi - lo, hi] - })); - - // z = 0 - point.clear(); - point.extend(rows.iter().map(|r| r[0])); - let mut e0 = eval_fn(computation, point, extra_data); - e0 *= eq_val; - block_acc[0] += e0; - - // z = 2, 3, ... - for d in 1..degree { - for [_, diff, running] in rows.iter_mut() { - *running += *diff; - } - point.clear(); - point.extend(rows.iter().map(|r| r[2])); - let mut ev = eval_fn(computation, point, extra_data); - ev *= eq_val; - block_acc[d] += ev; + let sums: Vec> = parallel::map_reduce_with_state( + n_lo, + || { + ( + Vec::<[EFPacking; 3]>::with_capacity(n_mult), + Vec::>::with_capacity(n_mult), + EFPacking::::zero_vec(degree), + ) + }, + || EFPacking::::zero_vec(degree), + |(rows, point, block_acc), acc, b_lo| { + let eq_lo_bc = EFPacking::::from(eq_lo[b_lo]); + let base = b_lo << log_packed_hi; + block_acc.iter_mut().for_each(|x| *x = EFPacking::::ZERO); + for k in 0..packed_hi { + let i = base + k; + let eq_val = eq_hi[k]; + + rows.clear(); + rows.extend(multilinears.iter().map(|m| { + let lo = m[i]; + let hi = m[i + fold_size]; + [lo, hi - lo, hi] + })); + + // z = 0 + point.clear(); + point.extend(rows.iter().map(|r| r[0])); + let mut e0 = eval_fn(computation, point, extra_data); + e0 *= eq_val; + block_acc[0] += e0; + + // z = 2, 3, ... 
+ for d in 1..degree { + for [_, diff, running] in rows.iter_mut() { + *running += *diff; } + point.clear(); + point.extend(rows.iter().map(|r| r[2])); + let mut ev = eval_fn(computation, point, extra_data); + ev *= eq_val; + block_acc[d] += ev; } - for a in &mut block_acc { - *a *= eq_lo_bc; - } - block_acc - }, - ) - .reduce(zero, accumulate); + } + for (a, b) in acc.iter_mut().zip(block_acc.iter()) { + *a += *b * eq_lo_bc; + } + }, + add_assign_vec, + ); let unpacked = sums.into_iter().map(&unpack_sum); build_evals(unpacked, missing_mul_factor) @@ -670,69 +639,64 @@ where let eq_lo = &split_eq.eq_lo; let eq_hi = &split_eq.eq_hi_packed; - let zero = || EFPacking::::zero_vec(degree); - let accumulate = |mut acc: Vec>, vals: Vec>| -> Vec> { - for (a, v) in acc.iter_mut().zip(vals.iter()) { - *a += *v; - } - acc - }; - + // Per-worker scratch reused across every `b_lo` task (see `sumcheck_compute_with_split_eq`): + // `rows_f` triples, `point` for `eval_fn`, and the per-`b_lo` `block_acc`. 
let n_mult = multilinears.len(); - let sums: Vec> = (0..n_lo) - .into_par_iter() - .map_init( - || { - ( - Vec::<[EFPacking; 3]>::with_capacity(n_mult), - Vec::>::with_capacity(n_mult), - ) - }, - |(rows_f, point), b_lo| { - let eq_lo_bc = EFPacking::::from(eq_lo[b_lo]); - let base = b_lo << log_packed_hi; - let mut block_acc = zero(); - for k in 0..packed_hi { - let i = base + k; - let eq_val = eq_hi[k]; - - rows_f.clear(); - rows_f.extend(multilinears.iter().enumerate().map(|(j, m)| { - let lo = fold_f(m, i); - let hi = fold_f(m, i + compute_fold_size); - unsafe { - let ptr = folded_f[j].as_ptr() as *mut EFPacking; - *ptr.add(i) = lo; - *ptr.add(i + compute_fold_size) = hi; - } - [lo, hi - lo, hi] - })); - - // z = 0 - point.clear(); - point.extend(rows_f.iter().map(|r| r[0])); - let mut e0 = eval_fn(computation, point, extra_data); - e0 *= eq_val; - block_acc[0] += e0; - - for d in 1..degree { - for [_, diff, running] in rows_f.iter_mut() { - *running += *diff; - } - point.clear(); - point.extend(rows_f.iter().map(|r| r[2])); - let mut ev = eval_fn(computation, point, extra_data); - ev *= eq_val; - block_acc[d] += ev; + let sums: Vec> = parallel::map_reduce_with_state( + n_lo, + || { + ( + Vec::<[EFPacking; 3]>::with_capacity(n_mult), + Vec::>::with_capacity(n_mult), + EFPacking::::zero_vec(degree), + ) + }, + || EFPacking::::zero_vec(degree), + |(rows_f, point, block_acc), acc, b_lo| { + let eq_lo_bc = EFPacking::::from(eq_lo[b_lo]); + let base = b_lo << log_packed_hi; + block_acc.iter_mut().for_each(|x| *x = EFPacking::::ZERO); + for k in 0..packed_hi { + let i = base + k; + let eq_val = eq_hi[k]; + + rows_f.clear(); + rows_f.extend(multilinears.iter().enumerate().map(|(j, m)| { + let lo = fold_f(m, i); + let hi = fold_f(m, i + compute_fold_size); + unsafe { + let ptr = folded_f[j].as_ptr() as *mut EFPacking; + *ptr.add(i) = lo; + *ptr.add(i + compute_fold_size) = hi; } + [lo, hi - lo, hi] + })); + + // z = 0 + point.clear(); + 
point.extend(rows_f.iter().map(|r| r[0])); + let mut e0 = eval_fn(computation, point, extra_data); + e0 *= eq_val; + block_acc[0] += e0; + + // z = 2, 3, ... + for d in 1..degree { + for [_, diff, running] in rows_f.iter_mut() { + *running += *diff; + } + point.clear(); + point.extend(rows_f.iter().map(|r| r[2])); + let mut ev = eval_fn(computation, point, extra_data); + ev *= eq_val; + block_acc[d] += ev; } - for a in &mut block_acc { - *a *= eq_lo_bc; - } - block_acc - }, - ) - .reduce(zero, accumulate); + } + for (a, b) in acc.iter_mut().zip(block_acc.iter()) { + *a += *b * eq_lo_bc; + } + }, + add_assign_vec, + ); let unpacked = sums.into_iter().map(&unpack_sum); (build_evals(unpacked, missing_mul_factor), wrap_f(folded_f)) diff --git a/crates/backend/symetric/Cargo.toml b/crates/backend/symetric/Cargo.toml index 125fb5535..86b7c3cd3 100644 --- a/crates/backend/symetric/Cargo.toml +++ b/crates/backend/symetric/Cargo.toml @@ -6,4 +6,4 @@ edition.workspace = true [dependencies] koala-bear = { path = "../koala-bear", package = "mt-koala-bear" } field = { path = "../field", package = "mt-field" } -rayon.workspace = true +parallel.workspace = true diff --git a/crates/backend/symetric/src/merkle.rs b/crates/backend/symetric/src/merkle.rs index 2fe194855..4b609a09d 100644 --- a/crates/backend/symetric/src/merkle.rs +++ b/crates/backend/symetric/src/merkle.rs @@ -4,7 +4,6 @@ use std::array; use field::PackedValue; -use rayon::prelude::*; use crate::Compression; @@ -67,18 +66,18 @@ where let default_digest = [P::Value::default(); DIGEST_ELEMS]; let mut next_digests = vec![default_digest; next_len_padded]; - next_digests[0..next_len] - .par_chunks_exact_mut(width) - .enumerate() - .for_each(|(i, digests_chunk)| { - let first_row = i * width; - let left = array::from_fn(|j| P::from_fn(|k| prev_layer[2 * (first_row + k)][j])); - let right = array::from_fn(|j| P::from_fn(|k| prev_layer[2 * (first_row + k) + 1][j])); - let packed_digest = crate::compress(comp, [left, 
right]); - for (dst, src) in digests_chunk.iter_mut().zip(unpack_array(packed_digest)) { - *dst = src; - } - }); + // Process only the full packed chunks in parallel (matches `par_chunks_exact_mut`); + // the `< width` remainder is handled by the sequential tail loop below. + let n_full = next_len / width * width; + parallel::par_chunks_mut(&mut next_digests[0..n_full], width, |i, digests_chunk| { + let first_row = i * width; + let left = array::from_fn(|j| P::from_fn(|k| prev_layer[2 * (first_row + k)][j])); + let right = array::from_fn(|j| P::from_fn(|k| prev_layer[2 * (first_row + k) + 1][j])); + let packed_digest = crate::compress(comp, [left, right]); + for (dst, src) in digests_chunk.iter_mut().zip(unpack_array(packed_digest)) { + *dst = src; + } + }); for i in (next_len / width * width)..next_len { let left = prev_layer[2 * i]; diff --git a/crates/backend/system-info/Cargo.toml b/crates/backend/system-info/Cargo.toml index c63ee1297..862e36e89 100644 --- a/crates/backend/system-info/Cargo.toml +++ b/crates/backend/system-info/Cargo.toml @@ -5,7 +5,6 @@ edition.workspace = true [dependencies] libc = "0.2" -rayon.workspace = true [lints] workspace = true diff --git a/crates/backend/system-info/src/lib.rs b/crates/backend/system-info/src/lib.rs index 07180559b..5323c1ce4 100644 --- a/crates/backend/system-info/src/lib.rs +++ b/crates/backend/system-info/src/lib.rs @@ -9,36 +9,3 @@ pub fn peak_rss_bytes() -> u64 { // ru_maxrss unit: bytes on macOS, KiB on Linux. if cfg!(target_os = "macos") { max } else { max * 1024 } } - -/// Number of jobs [`flush_rayon`] pushes. Must exceed -/// `crossbeam_deque::deque::BLOCK_CAP` (currently 63 — -/// `crossbeam-deque-0.8.6/src/deque.rs:1191`). -const RAYON_FLUSH_JOBS: usize = 256; - -/// Drain rayon's internal queues so they release any storage allocated during the -/// previous phase. 
-/// -/// Rayon's global pool owns a `crossbeam_deque::Injector`, internally a linked list -/// of fixed-size blocks (`Block` and `Injector::push` — -/// `crossbeam-deque-0.8.6/src/deque.rs:1219` and `:1371`). A block is freed only -/// once its last slot has been consumed. -/// -/// `rayon::join` from a non-worker thread reaches that injector via -/// `join` (`rayon-core-1.13.0/src/join/mod.rs:132`) -> -/// `registry::in_worker` (`registry.rs:946`) -> -/// `Registry::in_worker_cold` (`:517`) -> -/// `Registry::inject` (`:428`) -> `Injector::push`. -/// -/// Under an arena allocator that recycles memory between phases (e.g. `zk-alloc`), -/// a block allocated *during* a phase points into a slab the next `begin_phase()` -/// will reuse. The next push then writes a `JobRef` straight through whatever the -/// application has placed on top, silently corrupting it. -/// -/// Pushing more than `BLOCK_CAP` jobs while the arena is off forces the Injector -/// to allocate a fresh tail block (which lands in System), and forces workers to -/// steal the last slot of every preceding block (which destroys them). 
-pub fn flush_rayon() { - for _ in 0..RAYON_FLUSH_JOBS { - rayon::join(|| {}, || {}); - } -} diff --git a/crates/backend/zk-alloc/Cargo.toml b/crates/backend/zk-alloc/Cargo.toml deleted file mode 100644 index fe4c12233..000000000 --- a/crates/backend/zk-alloc/Cargo.toml +++ /dev/null @@ -1,17 +0,0 @@ -[package] -name = "zk-alloc" -version.workspace = true -edition.workspace = true -description = "Bump+reset arena allocator for ZK proving workloads" - -[dependencies] -system-info.workspace = true - -[dev-dependencies] -rayon.workspace = true - -[target.'cfg(not(all(target_os = "linux", target_arch = "x86_64")))'.dependencies] -libc = "0.2" - -[lints] -workspace = true diff --git a/crates/backend/zk-alloc/src/lib.rs b/crates/backend/zk-alloc/src/lib.rs deleted file mode 100644 index 1b43143d6..000000000 --- a/crates/backend/zk-alloc/src/lib.rs +++ /dev/null @@ -1,198 +0,0 @@ -//! Bump-pointer arena allocator. -//! -//! One mmap region split into per-thread slabs. Allocation = increment a thread-local -//! pointer; free = no-op. `begin_phase()` resets the arena: each thread's next -//! allocation starts over at the beginning of its slab, overwriting the previous -//! phase's data. Allocations that don't fit (too large, or beyond `MAX_THREADS`) fall -//! back to the system allocator. -//! -//! ```ignore -//! init(); // once, at process start -//! loop { -//! begin_phase(); // arena ON; slabs reset lazily -//! let res = heavy_work(); // fast increments -//! end_phase(); // arena OFF; new allocations go to System -//! let copy = res.clone(); // detach from arena before next phase resets it -//! } -//! ``` - -use std::alloc::{GlobalAlloc, Layout}; -use std::cell::Cell; -use std::sync::Once; -use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; - -use system_info::NUM_THREADS; - -mod syscall; - -const SLAB_SIZE: usize = 8 << 30; // 8GB -const SLACK: usize = 4; // SLACK absorbs the main thread and any non-rayon helpers. 
-const MAX_THREADS: usize = NUM_THREADS + SLACK; -const REGION_SIZE: usize = SLAB_SIZE * MAX_THREADS; - -#[derive(Debug)] -pub struct ZkAllocator; - -/// Incremented by `begin_phase()`. Every thread caches the last value it saw in -/// `ARENA_GEN`; when they differ, the thread resets its allocation cursor to the start -/// of its slab on the next allocation. This is how a single store on the main thread -/// "resets" every other thread's slab without any cross-thread synchronization. -static GENERATION: AtomicUsize = AtomicUsize::new(0); - -/// Master switch for the arena. `true` (set by `begin_phase`) routes allocations -/// through the arena; `false` (set by `end_phase`) routes them to the system allocator. -static ARENA_ACTIVE: AtomicBool = AtomicBool::new(false); - -/// Base address of the mmap'd region, or `0` before `ensure_region` runs. Read on -/// every `dealloc` to test whether a pointer belongs to us. -static REGION_BASE: AtomicUsize = AtomicUsize::new(0); - -/// Synchronizes the one-time mmap so concurrent first-allocators don't race. -static REGION_INIT: Once = Once::new(); - -/// Monotonic counter handed out to threads to pick their slab. `fetch_add`'d once per -/// thread on its first arena allocation. Threads that get `idx >= MAX_THREADS` mark -/// themselves `ARENA_NO_SLAB` and permanently fall through to the system allocator. -static THREAD_IDX: AtomicUsize = AtomicUsize::new(0); - -thread_local! { - /// Where this thread's next allocation lands. Advanced past each allocation. - static ARENA_PTR: Cell = const { Cell::new(0) }; - /// One past the last byte of this thread's slab. An alloc fits iff - /// `aligned + size <= ARENA_END`. - static ARENA_END: Cell = const { Cell::new(0) }; - /// Base address of this thread's slab (`0` = not yet claimed). On reset, - /// `ARENA_PTR` is set back to this value. - static ARENA_BASE: Cell = const { Cell::new(0) }; - /// Last `GENERATION` value this thread observed. 
When the global moves past - /// this, the next allocation resets `ARENA_PTR` to `ARENA_BASE` and updates - /// this field. - static ARENA_GEN: Cell = const { Cell::new(0) }; - /// `true` if this thread was created after `MAX_THREADS` was already exhausted. - /// Such threads skip arena logic entirely and always go to the system allocator. - static ARENA_NO_SLAB: Cell = const { Cell::new(false) }; -} - -/// Returns the base address of the mmap'd region, mapping it on the first call. -fn ensure_region() -> usize { - REGION_INIT.call_once(|| { - // SAFETY: mmap_anonymous returns a page-aligned pointer or null. MAP_NORESERVE - // means no physical memory is committed until pages are touched. - let ptr = unsafe { syscall::mmap_anonymous(REGION_SIZE) }; - if ptr.is_null() { - std::process::abort(); - } - unsafe { syscall::madvise(ptr, REGION_SIZE, syscall::MADV_NOHUGEPAGE) }; - REGION_BASE.store(ptr as usize, Ordering::Release); - }); - REGION_BASE.load(Ordering::Acquire) -} - -/// Call once at process start, before any `begin_phase()`. -pub fn init() { - let actual_num_threads = std::thread::available_parallelism().unwrap().get(); - assert_eq!( - actual_num_threads, NUM_THREADS, - "built for {NUM_THREADS} threads but this machine reports {actual_num_threads} -> please rebuild`" - ); -} - -/// Activates the arena and resets every thread's slab. All allocations until the next -/// `end_phase()` go to the arena; the previous phase's data is overwritten in place. -pub fn begin_phase() { - let prev_active = ARENA_ACTIVE.swap(true, Ordering::Release); - assert!( - !prev_active, - "begin_phase() called while another phase is already active — phases must not nest" - ); - GENERATION.fetch_add(1, Ordering::Release); -} - -/// Deactivates the arena. New allocations go to the system allocator; existing arena -/// pointers stay valid until the next `begin_phase()` resets the slabs. 
-/// -/// Also calls [`system_info::flush_rayon`] to release any rayon/crossbeam storage -/// still referencing this phase's arena memory. -pub fn end_phase() { - ARENA_ACTIVE.store(false, Ordering::Release); - system_info::flush_rayon(); -} - -#[cold] -#[inline(never)] -unsafe fn arena_alloc_cold(size: usize, align: usize) -> *mut u8 { - let generation = GENERATION.load(Ordering::Relaxed); - if !ARENA_NO_SLAB.get() && ARENA_GEN.get() != generation { - let mut base = ARENA_BASE.get(); - if base == 0 { - let region = ensure_region(); - let idx = THREAD_IDX.fetch_add(1, Ordering::Relaxed); - if idx >= MAX_THREADS { - ARENA_NO_SLAB.set(true); - return unsafe { std::alloc::System.alloc(Layout::from_size_align_unchecked(size, align)) }; - } - base = region + idx * SLAB_SIZE; - ARENA_BASE.set(base); - ARENA_END.set(base + SLAB_SIZE); - } - ARENA_PTR.set(base); - ARENA_GEN.set(generation); - let aligned = base.next_multiple_of(align); - let new_ptr = aligned + size; - if new_ptr <= ARENA_END.get() { - ARENA_PTR.set(new_ptr); - return aligned as *mut u8; - } - } - unsafe { std::alloc::System.alloc(Layout::from_size_align_unchecked(size, align)) } -} - -// SAFETY: All pointers returned are either from our mmap'd region (valid, aligned, -// non-overlapping per thread) or from System. The arena is thread-local so no data -// races. Relaxed ordering on ARENA_ACTIVE/GENERATION is sound: worst case a thread -// sees a stale value and does one extra system-alloc before picking up the new -// generation on the next call. 
-unsafe impl GlobalAlloc for ZkAllocator { - #[inline(always)] - unsafe fn alloc(&self, layout: Layout) -> *mut u8 { - if ARENA_ACTIVE.load(Ordering::Relaxed) { - let generation = GENERATION.load(Ordering::Relaxed); - if ARENA_GEN.get() == generation { - let align = layout.align(); - let aligned = (ARENA_PTR.get() + align - 1) & !(align - 1); - let new_ptr = aligned + layout.size(); - if new_ptr <= ARENA_END.get() { - ARENA_PTR.set(new_ptr); - return aligned as *mut u8; - } - } - return unsafe { arena_alloc_cold(layout.size(), layout.align()) }; - } - unsafe { std::alloc::System.alloc(layout) } - } - - #[inline(always)] - unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) { - let addr = ptr as usize; - let base = REGION_BASE.load(Ordering::Relaxed); - if base != 0 && addr >= base && addr < base + REGION_SIZE { - return; // arena-owned pointer — free is a no-op - } - unsafe { std::alloc::System.dealloc(ptr, layout) }; - } - - #[inline(always)] - unsafe fn realloc(&self, ptr: *mut u8, layout: Layout, new_size: usize) -> *mut u8 { - if new_size <= layout.size() { - return ptr; - } - // SAFETY: new_size > layout.size() > 0, align unchanged from valid layout. - let new_layout = unsafe { Layout::from_size_align_unchecked(new_size, layout.align()) }; - let new_ptr = unsafe { self.alloc(new_layout) }; - if !new_ptr.is_null() { - unsafe { std::ptr::copy(ptr, new_ptr, layout.size()) }; - unsafe { self.dealloc(ptr, layout) }; - } - new_ptr - } -} diff --git a/crates/backend/zk-alloc/src/syscall.rs b/crates/backend/zk-alloc/src/syscall.rs deleted file mode 100644 index 13d71531d..000000000 --- a/crates/backend/zk-alloc/src/syscall.rs +++ /dev/null @@ -1,163 +0,0 @@ -// Raw syscalls instead of libc wrappers to avoid reentrancy: libc's mmap/madvise -// may internally call malloc, which would deadlock when called from inside -// #[global_allocator]. 
- -#[cfg(all(target_os = "linux", target_arch = "x86_64"))] -mod imp { - use std::ptr; - - const SYS_MMAP: usize = 9; - const SYS_MADVISE: usize = 28; - - const PROT_READ: usize = 1; - const PROT_WRITE: usize = 2; - const MAP_PRIVATE: usize = 0x02; - const MAP_ANONYMOUS: usize = 0x20; - const MAP_NORESERVE: usize = 0x4000; - - pub const MADV_NOHUGEPAGE: usize = 15; - - #[inline] - unsafe fn syscall6(nr: usize, a1: usize, a2: usize, a3: usize, a4: usize, a5: usize, a6: usize) -> isize { - let ret: isize; - unsafe { - std::arch::asm!( - "syscall", - inlateout("rax") nr as isize => ret, - in("rdi") a1, - in("rsi") a2, - in("rdx") a3, - in("r10") a4, - in("r8") a5, - in("r9") a6, - lateout("rcx") _, - lateout("r11") _, - options(nostack), - ); - } - ret - } - - #[inline] - unsafe fn syscall3(nr: usize, a1: usize, a2: usize, a3: usize) -> isize { - let ret: isize; - unsafe { - std::arch::asm!( - "syscall", - inlateout("rax") nr as isize => ret, - in("rdi") a1, - in("rsi") a2, - in("rdx") a3, - lateout("rcx") _, - lateout("r11") _, - lateout("r10") _, - options(nostack), - ); - } - ret - } - - #[inline] - pub unsafe fn mmap_anonymous(size: usize) -> *mut u8 { - let flags = MAP_PRIVATE | MAP_ANONYMOUS | MAP_NORESERVE; - let ret = unsafe { syscall6(SYS_MMAP, 0, size, PROT_READ | PROT_WRITE, flags, usize::MAX, 0) }; - if ret < 0 { ptr::null_mut() } else { ret as *mut u8 } - } - - #[inline] - pub unsafe fn madvise(ptr: *mut u8, size: usize, advice: usize) { - unsafe { syscall3(SYS_MADVISE, ptr as usize, size, advice) }; - } -} - -#[cfg(all(target_os = "linux", target_arch = "aarch64"))] -mod imp { - use std::ptr; - - const SYS_MMAP: usize = 222; - const SYS_MADVISE: usize = 233; - - const PROT_READ: usize = 1; - const PROT_WRITE: usize = 2; - const MAP_PRIVATE: usize = 0x02; - const MAP_ANONYMOUS: usize = 0x20; - const MAP_NORESERVE: usize = 0x4000; - - pub const MADV_NOHUGEPAGE: usize = 15; - - #[inline] - unsafe fn syscall6(nr: usize, a1: usize, a2: usize, a3: usize, a4: 
usize, a5: usize, a6: usize) -> isize { - let ret: isize; - unsafe { - std::arch::asm!( - "svc 0", - in("x8") nr, - inlateout("x0") a1 as isize => ret, - in("x1") a2, - in("x2") a3, - in("x3") a4, - in("x4") a5, - in("x5") a6, - options(nostack), - ); - } - ret - } - - #[inline] - unsafe fn syscall3(nr: usize, a1: usize, a2: usize, a3: usize) -> isize { - let ret: isize; - unsafe { - std::arch::asm!( - "svc 0", - in("x8") nr, - inlateout("x0") a1 as isize => ret, - in("x1") a2, - in("x2") a3, - options(nostack), - ); - } - ret - } - - #[inline] - pub unsafe fn mmap_anonymous(size: usize) -> *mut u8 { - let flags = MAP_PRIVATE | MAP_ANONYMOUS | MAP_NORESERVE; - let ret = unsafe { syscall6(SYS_MMAP, 0, size, PROT_READ | PROT_WRITE, flags, usize::MAX, 0) }; - if ret < 0 { ptr::null_mut() } else { ret as *mut u8 } - } - - #[inline] - pub unsafe fn madvise(ptr: *mut u8, size: usize, advice: usize) { - unsafe { syscall3(SYS_MADVISE, ptr as usize, size, advice) }; - } -} - -#[cfg(not(all(target_os = "linux", any(target_arch = "x86_64", target_arch = "aarch64"))))] -mod imp { - use std::ptr; - - pub const MADV_NOHUGEPAGE: usize = 15; - - #[inline] - pub unsafe fn mmap_anonymous(size: usize) -> *mut u8 { - // MAP_NORESERVE is Linux-only. macOS lazily backs anonymous mappings - // with physical memory by default, so the large virtual reservation - // is fine without NORESERVE. - let prot = libc::PROT_READ | libc::PROT_WRITE; - let flags = libc::MAP_PRIVATE | libc::MAP_ANON; - let ret = unsafe { libc::mmap(ptr::null_mut(), size, prot, flags, -1, 0) }; - if ret == libc::MAP_FAILED { - ptr::null_mut() - } else { - ret.cast::() - } - } - - #[inline] - pub unsafe fn madvise(_ptr: *mut u8, _size: usize, _advice: usize) { - // The advice values we pass are Linux-specific. 
- } -} - -pub use imp::{MADV_NOHUGEPAGE, madvise, mmap_anonymous}; diff --git a/crates/backend/zk-alloc/tests/test_rayon.rs b/crates/backend/zk-alloc/tests/test_rayon.rs deleted file mode 100644 index ae084af21..000000000 --- a/crates/backend/zk-alloc/tests/test_rayon.rs +++ /dev/null @@ -1,26 +0,0 @@ -//! Regression test for the bug prevented by `system_info::flush_rayon`. - -use rayon::prelude::*; - -#[global_allocator] -static A: zk_alloc::ZkAllocator = zk_alloc::ZkAllocator; - -#[test] -fn rayon_does_not_corrupt_zkalloc() { - zk_alloc::init(); - let _: u64 = (0..1_000_000_u64).into_par_iter().sum(); - - zk_alloc::begin_phase(); - for _ in 0..200 { - rayon::join(|| {}, || {}); - } - zk_alloc::end_phase(); - - zk_alloc::begin_phase(); - let canary = vec![0xAB_u8; 8192]; - rayon::join(|| {}, || {}); - zk_alloc::end_phase(); - - let pos = canary.iter().position(|&b| b != 0xAB); - assert!(pos.is_none(), "canary corrupted at offset {}", pos.unwrap()); -} diff --git a/crates/lean_compiler/src/a_simplify_lang/mod.rs b/crates/lean_compiler/src/a_simplify_lang/mod.rs index c824c118c..d993fa7e4 100644 --- a/crates/lean_compiler/src/a_simplify_lang/mod.rs +++ b/crates/lean_compiler/src/a_simplify_lang/mod.rs @@ -3174,10 +3174,8 @@ fn replace_vars_by_const_in_expr(expr: &mut Expression, map: &BTreeMap) fn replace_vars_by_const_in_lines(lines: &mut [Line], map: &BTreeMap) -> Result<(), String> { for line in lines { match line { - Line::ForwardDeclaration { var, .. } => { - if map.contains_key(var) { - return Err(format!("Variable {var} is a constant")); - } + Line::ForwardDeclaration { var, .. } if map.contains_key(var) => { + return Err(format!("Variable {var} is a constant")); } Line::Statement { targets, .. 
} => { for target in targets.iter() { diff --git a/crates/lean_compiler/src/c_compile_final.rs b/crates/lean_compiler/src/c_compile_final.rs index e10516e11..c6e06ac32 100644 --- a/crates/lean_compiler/src/c_compile_final.rs +++ b/crates/lean_compiler/src/c_compile_final.rs @@ -132,7 +132,10 @@ pub fn compile_to_low_level_bytecode( validate_instruction(instruction)?; } - let instructions_encoded = instructions.par_iter().map(field_representation).collect::>(); + let mut instructions_encoded: Vec<[F; N_INSTRUCTION_COLUMNS]> = unsafe { uninitialized_vec(instructions.len()) }; + parallel::par_for_each_mut(&mut instructions_encoded, |i, out| { + *out = field_representation(&instructions[i]); + }); let mut instructions_multilinear = vec![]; for instr in &instructions_encoded { diff --git a/crates/lean_prover/src/prove_execution.rs b/crates/lean_prover/src/prove_execution.rs index aaf50be3b..a952b373b 100644 --- a/crates/lean_prover/src/prove_execution.rs +++ b/crates/lean_prover/src/prove_execution.rs @@ -165,8 +165,9 @@ pub fn prove_execution( }) .collect(); let _span = info_span!("Computing shifted columns for AIR sumcheck").entered(); + // Only a few tables; run them serially and let `compute_shifted_columns` use the full pool. 
let shifted_rows: Vec>> = ALL_TABLES - .par_iter() + .iter() .zip(&column_refs) .map(|(table, cols)| compute_shifted_columns(table.n_shift_columns(), cols)) .collect(); diff --git a/crates/lean_prover/src/trace_gen.rs b/crates/lean_prover/src/trace_gen.rs index cd0e401be..86a736e25 100644 --- a/crates/lean_prover/src/trace_gen.rs +++ b/crates/lean_prover/src/trace_gen.rs @@ -1,7 +1,7 @@ use backend::*; use lean_vm::*; use std::{array, collections::BTreeMap}; -use utils::{ToUsize, get_poseidon_16_of_zero, transposed_par_iter_mut}; +use utils::{ToUsize, get_poseidon_16_of_zero, transposed_par_for_each_mut}; #[derive(Debug)] pub struct ExecutionTrace { @@ -27,74 +27,76 @@ pub fn get_execution_trace( } } - transposed_par_iter_mut(&mut main_trace) - .zip(execution_result.pcs.par_iter()) - .zip(execution_result.fps.par_iter()) - .for_each(|((trace_row, &pc), &fp)| { - let instruction = &bytecode.code[pc].instruction; - let field_repr = &bytecode.instructions_multilinear[pc * N_INSTRUCTION_COLUMNS.next_power_of_two()..] 
- [..N_INSTRUCTION_COLUMNS]; - - let flag_a = field_repr[instr_idx(EXEC_COL_FLAG_A)]; - let flag_b = field_repr[instr_idx(EXEC_COL_FLAG_B)]; - let flag_c = field_repr[instr_idx(EXEC_COL_FLAG_C)]; - let flag_c_fp = field_repr[instr_idx(EXEC_COL_FLAG_C_FP)]; - let flag_ab_fp = field_repr[instr_idx(EXEC_COL_FLAG_AB_FP)]; - let aux_1 = field_repr[instr_idx(EXEC_COL_AUX_1)]; - let is_deref = aux_1 == F::TWO; - - let mut addr_a = F::ZERO; - if flag_a.is_zero() && flag_ab_fp.is_zero() { - addr_a = F::from_usize(fp) + field_repr[instr_idx(EXEC_COL_OPERAND_A)]; - } - let value_a = memory.0.get(addr_a.to_usize()).copied().flatten().unwrap_or_default(); - - let mut addr_b = F::ZERO; - if flag_b.is_zero() && flag_ab_fp.is_zero() { - addr_b = F::from_usize(fp) + field_repr[instr_idx(EXEC_COL_OPERAND_B)]; - } else if is_deref { - // DEREF: addr_B = value_A + operand_B - addr_b = value_a + field_repr[instr_idx(EXEC_COL_OPERAND_B)]; - } - let value_b = memory.0.get(addr_b.to_usize()).copied().flatten().unwrap_or_default(); + transposed_par_for_each_mut(&mut main_trace, |i, trace_row| { + let pc = execution_result.pcs[i]; + let fp = execution_result.fps[i]; + let instruction = &bytecode.code[pc].instruction; + let field_repr = &bytecode.instructions_multilinear[pc * N_INSTRUCTION_COLUMNS.next_power_of_two()..] 
+ [..N_INSTRUCTION_COLUMNS]; + + let flag_a = field_repr[instr_idx(EXEC_COL_FLAG_A)]; + let flag_b = field_repr[instr_idx(EXEC_COL_FLAG_B)]; + let flag_c = field_repr[instr_idx(EXEC_COL_FLAG_C)]; + let flag_c_fp = field_repr[instr_idx(EXEC_COL_FLAG_C_FP)]; + let flag_ab_fp = field_repr[instr_idx(EXEC_COL_FLAG_AB_FP)]; + let aux_1 = field_repr[instr_idx(EXEC_COL_AUX_1)]; + let is_deref = aux_1 == F::TWO; + + let mut addr_a = F::ZERO; + if flag_a.is_zero() && flag_ab_fp.is_zero() { + addr_a = F::from_usize(fp) + field_repr[instr_idx(EXEC_COL_OPERAND_A)]; + } + let value_a = memory.0.get(addr_a.to_usize()).copied().flatten().unwrap_or_default(); + + let mut addr_b = F::ZERO; + if flag_b.is_zero() && flag_ab_fp.is_zero() { + addr_b = F::from_usize(fp) + field_repr[instr_idx(EXEC_COL_OPERAND_B)]; + } else if is_deref { + // DEREF: addr_B = value_A + operand_B + addr_b = value_a + field_repr[instr_idx(EXEC_COL_OPERAND_B)]; + } + let value_b = memory.0.get(addr_b.to_usize()).copied().flatten().unwrap_or_default(); - let mut addr_c = F::ZERO; - if flag_c.is_zero() && flag_c_fp.is_zero() { - addr_c = F::from_usize(fp) + field_repr[instr_idx(EXEC_COL_OPERAND_C)]; - } - let value_c = memory.0.get(addr_c.to_usize()).copied().flatten().unwrap_or_default(); + let mut addr_c = F::ZERO; + if flag_c.is_zero() && flag_c_fp.is_zero() { + addr_c = F::from_usize(fp) + field_repr[instr_idx(EXEC_COL_OPERAND_C)]; + } + let value_c = memory.0.get(addr_c.to_usize()).copied().flatten().unwrap_or_default(); - for (j, field) in field_repr.iter().enumerate() { - *trace_row[j + N_RUNTIME_COLUMNS] = *field; - } + for (j, field) in field_repr.iter().enumerate() { + *trace_row[j + N_RUNTIME_COLUMNS] = *field; + } - let nu_a = flag_a * field_repr[instr_idx(EXEC_COL_OPERAND_A)] - + (F::ONE - flag_a - flag_ab_fp) * value_a - + flag_ab_fp * (F::from_usize(fp) + field_repr[instr_idx(EXEC_COL_OPERAND_A)]); - let nu_b = flag_b * field_repr[instr_idx(EXEC_COL_OPERAND_B)] - + (F::ONE - flag_b - flag_ab_fp) 
* value_b - + flag_ab_fp * (F::from_usize(fp) + field_repr[instr_idx(EXEC_COL_OPERAND_B)]); - let nu_c = flag_c * field_repr[instr_idx(EXEC_COL_OPERAND_C)] - + (F::ONE - flag_c - flag_c_fp) * value_c - + flag_c_fp * (F::from_usize(fp) + field_repr[instr_idx(EXEC_COL_OPERAND_C)]); - if let Instruction::Precompile(..) = instruction { - *trace_row[EXEC_COL_FLAG_PRECOMPILE] = F::ONE; - } - *trace_row[EXEC_COL_NU_A] = nu_a; - *trace_row[EXEC_COL_NU_B] = nu_b; - *trace_row[EXEC_COL_NU_C] = nu_c; - - *trace_row[EXEC_COL_VALUE_A] = value_a; - *trace_row[EXEC_COL_VALUE_B] = value_b; - *trace_row[EXEC_COL_VALUE_C] = value_c; - *trace_row[EXEC_COL_PC] = F::from_usize(pc); - *trace_row[EXEC_COL_FP] = F::from_usize(fp); - *trace_row[EXEC_COL_ADDR_A] = addr_a; - *trace_row[EXEC_COL_ADDR_B] = addr_b; - *trace_row[EXEC_COL_ADDR_C] = addr_c; - }); + let nu_a = flag_a * field_repr[instr_idx(EXEC_COL_OPERAND_A)] + + (F::ONE - flag_a - flag_ab_fp) * value_a + + flag_ab_fp * (F::from_usize(fp) + field_repr[instr_idx(EXEC_COL_OPERAND_A)]); + let nu_b = flag_b * field_repr[instr_idx(EXEC_COL_OPERAND_B)] + + (F::ONE - flag_b - flag_ab_fp) * value_b + + flag_ab_fp * (F::from_usize(fp) + field_repr[instr_idx(EXEC_COL_OPERAND_B)]); + let nu_c = flag_c * field_repr[instr_idx(EXEC_COL_OPERAND_C)] + + (F::ONE - flag_c - flag_c_fp) * value_c + + flag_c_fp * (F::from_usize(fp) + field_repr[instr_idx(EXEC_COL_OPERAND_C)]); + if let Instruction::Precompile(..) 
= instruction { + *trace_row[EXEC_COL_FLAG_PRECOMPILE] = F::ONE; + } + *trace_row[EXEC_COL_NU_A] = nu_a; + *trace_row[EXEC_COL_NU_B] = nu_b; + *trace_row[EXEC_COL_NU_C] = nu_c; + + *trace_row[EXEC_COL_VALUE_A] = value_a; + *trace_row[EXEC_COL_VALUE_B] = value_b; + *trace_row[EXEC_COL_VALUE_C] = value_c; + *trace_row[EXEC_COL_PC] = F::from_usize(pc); + *trace_row[EXEC_COL_FP] = F::from_usize(fp); + *trace_row[EXEC_COL_ADDR_A] = addr_a; + *trace_row[EXEC_COL_ADDR_B] = addr_b; + *trace_row[EXEC_COL_ADDR_C] = addr_c; + }); - let mut memory_padded = memory.0.par_iter().map(|&v| v.unwrap_or(F::ZERO)).collect::>(); + let mut memory_padded: Vec = unsafe { uninitialized_vec(memory.0.len()) }; + parallel::par_for_each_mut(&mut memory_padded, |i, slot| { + *slot = memory.0[i].unwrap_or(F::ZERO); + }); // Write [0000000000000000 | poseidon_compress(0000000000000000)] (to make lookups work on padding-rows). let padding_zero_vec_ptr = memory_padded.len(); @@ -124,23 +126,22 @@ pub fn get_execution_trace( const N: usize = HALF_DIGEST_LEN + DIGEST_LEN; let cols: &mut [Vec; N] = (&mut right[..N]).try_into().unwrap(); - transposed_par_iter_mut(cols) - .zip(flag_out4_col) - .zip(flag_out8_col) - .zip(nu_c_col) - .for_each(|(((row, &flag_out4), &flag_out8), &nu_c)| { - let base = nu_c.to_usize(); - if flag_out4 == F::ONE { - for j in 0..HALF_DIGEST_LEN { - *row[j] = memory_padded[base + HALF_DIGEST_LEN + j]; - } + transposed_par_for_each_mut(cols, |i, row| { + let flag_out4 = flag_out4_col[i]; + let flag_out8 = flag_out8_col[i]; + let nu_c = nu_c_col[i]; + let base = nu_c.to_usize(); + if flag_out4 == F::ONE { + for j in 0..HALF_DIGEST_LEN { + *row[j] = memory_padded[base + HALF_DIGEST_LEN + j]; } - if flag_out8 == F::ONE || flag_out4 == F::ONE { - for j in 0..DIGEST_LEN { - *row[HALF_DIGEST_LEN + j] = memory_padded[base + DIGEST_LEN + j]; - } + } + if flag_out8 == F::ONE || flag_out4 == F::ONE { + for j in 0..DIGEST_LEN { + *row[HALF_DIGEST_LEN + j] = memory_padded[base + DIGEST_LEN 
+ j]; } - }); + } + }); } let extension_op_trace = traces.get_mut(&Table::extension_op()).unwrap(); @@ -197,7 +198,8 @@ fn pad_table( trace.log_n_rows = log2_ceil_usize(h + 1).max(min_log_n_rows); let n_rows = 1 << trace.log_n_rows; let padding_row = table.padding_row(zero_vec_ptr, null_poseidon_16_hash_ptr, ending_pc); - trace.columns.par_iter_mut().enumerate().for_each(|(i, col)| { + parallel::par_chunks_mut(&mut trace.columns, 1, |i, slot| { + let col = &mut slot[0]; assert!(col.len() <= h); // potentially some columns have not been filled (in Poseidon -> we fill it later with SIMD + parallelism), but the first one should always be representative col.resize(n_rows, padding_row[i]); }); diff --git a/crates/lean_vm/src/execution/memory.rs b/crates/lean_vm/src/execution/memory.rs index 364eb7471..38b5a41e2 100644 --- a/crates/lean_vm/src/execution/memory.rs +++ b/crates/lean_vm/src/execution/memory.rs @@ -11,6 +11,16 @@ pub trait MemoryAccess { (0..len).map(|i| self.get(start + i)).collect() } + /// In-place version of [`get_slice`] that writes into a caller-provided buffer, + /// avoiding a per-call heap allocation on the hot interpreter path (Poseidon / + /// extension-op slice reads run hundreds of thousands of times per proof). 
+ fn get_slice_into(&self, start: usize, dest: &mut [F]) -> Result<(), RunnerError> { + for (i, d) in dest.iter_mut().enumerate() { + *d = self.get(start + i)?; + } + Ok(()) + } + fn set_slice(&mut self, start: usize, values: &[F]) -> Result<(), RunnerError> { for (i, v) in values.iter().enumerate() { self.set(start + i, *v)?; @@ -77,7 +87,7 @@ impl MemoryAccess for Memory { impl Memory { pub fn new(public_memory: Vec) -> Self { - Self(public_memory.into_par_iter().map(Some).collect()) + Self(public_memory.into_iter().map(Some).collect()) } pub fn get(&self, index: usize) -> Result { diff --git a/crates/lean_vm/src/execution/runner.rs b/crates/lean_vm/src/execution/runner.rs index f00e04880..8a47e716d 100644 --- a/crates/lean_vm/src/execution/runner.rs +++ b/crates/lean_vm/src/execution/runner.rs @@ -333,7 +333,12 @@ fn execute_bytecode_helper( None }; let runtime_memory_size = memory.0.len() - PUBLIC_INPUT_LEN - witness.preamble_memory_len; - let used_memory_cells = memory.0.par_iter().filter(|&&x| x.is_some()).count(); + let used_memory_cells = parallel::map_reduce( + memory.0.len(), + || 0usize, + |i| usize::from(memory.0[i].is_some()), + |a, b| a + b, + ); let metadata = ExecutionMetadata { cycles: trace.pcs.len(), memory: memory.0.len(), @@ -432,13 +437,26 @@ fn handle_parallel_batch( let split_at = batch.batch_fp + stride; // end of iteration 0's frame let (left, right) = memory.0.split_at_mut(split_at); let shared: &[Option] = &*left; - let segment_slices: Vec<&mut [Option]> = right.chunks_mut(stride).take(n_par).collect(); + let mut segment_slices: Vec<&mut [Option]> = right.chunks_mut(stride).take(n_par).collect(); type SegResult = Result<(Trace, Vec<(usize, F)>), RunnerError>; - let results: Vec = segment_slices - .into_par_iter() - .enumerate() - .map(|(i, seg_slice)| { + + // Raw base pointer + length per disjoint segment, so the pool can run each segment + // on its own slice without moving `&mut` references through the `Fn` task closure. 
+ // SAFETY: segments are non-overlapping `chunks_mut` of `right`; task `i` touches only `i`. + let seg_info: Vec<(parallel::SendPtr>, usize)> = segment_slices + .iter_mut() + .map(|s| (parallel::SendPtr(s.as_mut_ptr()), s.len())) + .collect(); + // Release the `&mut` borrows so only the raw pointers alias the segments. + drop(segment_slices); + + let mut results: Vec> = (0..n_par).map(|_| None).collect(); + parallel::par_chunks_mut(&mut results, 1, |i, out| { + let (seg_ptr, seg_len) = &seg_info[i]; + // SAFETY: distinct `i` reconstruct disjoint segments of `right`, valid for the dispatch. + let seg_slice: &mut [Option] = unsafe { std::slice::from_raw_parts_mut(seg_ptr.0, *seg_len) }; + out[0] = Some((|| -> SegResult { let seg_start = split_at + i * stride; let mut seg_mem = SegmentMemory::new(shared, seg_slice, seg_start); let fp_i = batch.batch_fp + (i + 1) * stride; @@ -452,8 +470,10 @@ fn handle_parallel_batch( cursor.index += i * delta; } } - let seg_start_indices: HashMap<_, _> = - seg_named_hints.iter().map(|(name, c)| (name.clone(), c.index)).collect(); + let seg_start_indices: HashMap<_, _> = seg_named_hints + .iter() + .map(|(name, c)| (name.clone(), c.index)) + .collect(); let mut hints = HintState { diagnostics: None, named_hints: &mut seg_named_hints, @@ -478,8 +498,9 @@ fn handle_parallel_batch( } let deferred = seg_mem.into_deferred_writes(); Ok((seg_trace, deferred)) - }) - .collect(); + })()); + }); + let results: Vec = results.into_iter().map(Option::unwrap).collect(); for (idx, result) in results.into_iter().enumerate() { let (seg_trace, deferred) = result.map_err(|e| RunnerError::ParallelSegmentFailed(idx + 1, Box::new(e)))?; diff --git a/crates/lean_vm/src/tables/poseidon/mod.rs b/crates/lean_vm/src/tables/poseidon/mod.rs index fb1efcb8d..f0fff6dd0 100644 --- a/crates/lean_vm/src/tables/poseidon/mod.rs +++ b/crates/lean_vm/src/tables/poseidon/mod.rs @@ -241,14 +241,14 @@ impl TableT for Poseidon16Precompile { } else { arg_a_usize + 
HALF_DIGEST_LEN }; - let arg0_first = ctx.memory.get_slice(left_first_addr, HALF_DIGEST_LEN)?; - let arg0_second = ctx.memory.get_slice(left_second_addr, HALF_DIGEST_LEN)?; - let arg1 = ctx.memory.get_slice(arg_b.to_usize(), DIGEST_LEN)?; - + // Fill the Poseidon input array directly from memory — no per-call Vec allocation + // (this runs once per Poseidon instruction, the dominant small-alloc source). let mut input = [F::ZERO; DIGEST_LEN * 2]; - input[..HALF_DIGEST_LEN].copy_from_slice(&arg0_first); - input[HALF_DIGEST_LEN..DIGEST_LEN].copy_from_slice(&arg0_second); - input[DIGEST_LEN..].copy_from_slice(&arg1); + ctx.memory + .get_slice_into(left_first_addr, &mut input[..HALF_DIGEST_LEN])?; + ctx.memory + .get_slice_into(left_second_addr, &mut input[HALF_DIGEST_LEN..DIGEST_LEN])?; + ctx.memory.get_slice_into(arg_b.to_usize(), &mut input[DIGEST_LEN..])?; let res_addr = index_res_a.to_usize(); if permute { diff --git a/crates/lean_vm/src/tables/poseidon/trace_gen.rs b/crates/lean_vm/src/tables/poseidon/trace_gen.rs index 9022f6c33..dc3963b75 100644 --- a/crates/lean_vm/src/tables/poseidon/trace_gen.rs +++ b/crates/lean_vm/src/tables/poseidon/trace_gen.rs @@ -20,9 +20,10 @@ pub fn fill_trace_poseidon_16(trace: &mut [Vec]) { const N_COLS: usize = super::num_cols_poseidon_16(); - // fill the packed rows + // fill the packed rows. Bind a fixed-size array ref so the per-row `array::from_fn` + // indexing elides bounds checks (one length check here, none in the hot loop). 
let cols: &[&[FPacking]; N_COLS] = (&trace_packed[..N_COLS]).try_into().unwrap(); - (0..m / packing_width::()).into_par_iter().for_each(|i| { + parallel::for_each_index(m / packing_width::(), |i| { let ptrs: [*mut FPacking; N_COLS] = std::array::from_fn(|c| unsafe { (cols[c].as_ptr() as *mut FPacking).add(i) }); let perm: &mut Poseidon1Cols16<&mut FPacking> = diff --git a/crates/rec_aggregation/Cargo.toml b/crates/rec_aggregation/Cargo.toml index ac111ba3f..a28af7b9b 100644 --- a/crates/rec_aggregation/Cargo.toml +++ b/crates/rec_aggregation/Cargo.toml @@ -8,7 +8,6 @@ workspace = true [features] prox-gaps-conjecture = ["lean_prover/prox-gaps-conjecture"] -standard-alloc = [] [dependencies] utils.workspace = true @@ -25,7 +24,6 @@ backend.workspace = true postcard.workspace = true lz4_flex.workspace = true serde.workspace = true -zk-alloc.workspace = true [target.'cfg(target_os = "macos")'.dependencies] objc2 = { version = "0.6.4", default-features = false, features = ["std"] } diff --git a/crates/rec_aggregation/src/benchmark.rs b/crates/rec_aggregation/src/benchmark.rs index c9f188fdd..c80174d0a 100644 --- a/crates/rec_aggregation/src/benchmark.rs +++ b/crates/rec_aggregation/src/benchmark.rs @@ -397,9 +397,6 @@ fn build_aggregation( let mut last_result: Option = None; let own_display_index = display_index + count_nodes(topology) - 1; for _ in 0..repeat { - #[cfg(not(feature = "standard-alloc"))] - zk_alloc::begin_phase(); - let time = Instant::now(); let result = aggregate_single_msg_signatures( &children, @@ -411,13 +408,6 @@ fn build_aggregation( .unwrap(); let elapsed = time.elapsed(); - // Clone the outputs out of the arena before the next phase resets its slabs. 
- #[cfg(not(feature = "standard-alloc"))] - let result = { - zk_alloc::end_phase(); - result.clone() - }; - times.push(elapsed.as_secs_f64()); last_result = Some(result); diff --git a/crates/rec_aggregation/src/bytecode_claims.rs b/crates/rec_aggregation/src/bytecode_claims.rs index 91c44b369..081347b8a 100644 --- a/crates/rec_aggregation/src/bytecode_claims.rs +++ b/crates/rec_aggregation/src/bytecode_claims.rs @@ -64,15 +64,11 @@ pub(crate) fn reduce_bytecode_claims(verified: &[InnerVerified]) -> ReducedBytec let alpha: EF = reduction_prover.sample(); let alpha_powers: Vec = alpha.powers().take(n_claims).collect(); - let weights_packed = claims - .par_iter() - .zip(&alpha_powers) - .map(|(eval, &alpha_i)| eval_eq_packed_scaled(&eval.point.0, alpha_i)) - .reduce_with(|mut acc, eq_i| { - acc.par_iter_mut().zip(&eq_i).for_each(|(w, e)| *w += *e); - acc - }) - .unwrap(); + let n_vars = claims[0].point.0.len(); + let mut weights_packed = EFPacking::::zero_vec(1 << (n_vars - packing_log_width::())); + for (claim, &alpha_pow) in claims.iter().zip(&alpha_powers) { + compute_eval_eq_packed::(&claim.point.0, &mut weights_packed, alpha_pow); + } let claimed_sum: EF = dot_product(claims.iter().map(|c| c.value), alpha_powers.iter().copied()); diff --git a/crates/sub_protocols/src/air_sumcheck.rs b/crates/sub_protocols/src/air_sumcheck.rs index 0f536d7fa..8600de649 100644 --- a/crates/sub_protocols/src/air_sumcheck.rs +++ b/crates/sub_protocols/src/air_sumcheck.rs @@ -89,22 +89,20 @@ where let _span = info_span!("chunk-bit-reversing columns").entered(); let chunk_size = 1usize << pivot; let shift = usize::BITS as usize - pivot; - let bit_reversed = cols - .par_iter() - .map(|&src| { - let mut dst: Vec> = unsafe { uninitialized_vec(src.len()) }; - let src_u = PFPacking::::unpack_slice(src); - let dst_u = PFPacking::::unpack_slice_mut(&mut dst); - for (src_chunk, dst_chunk) in - src_u.chunks_exact(chunk_size).zip(dst_u.chunks_exact_mut(chunk_size)) - { - for (p, slot) in 
dst_chunk.iter_mut().enumerate() { - *slot = src_chunk[p.reverse_bits() >> shift]; - } + let mut bit_reversed: Vec>> = (0..cols.len()).map(|_| Vec::new()).collect(); + parallel::par_chunks_mut(&mut bit_reversed, 1, |i, out_slot| { + let src = cols[i]; + let mut dst: Vec> = unsafe { uninitialized_vec(src.len()) }; + let src_u = PFPacking::::unpack_slice(src); + let dst_u = PFPacking::::unpack_slice_mut(&mut dst); + for (src_chunk, dst_chunk) in src_u.chunks_exact(chunk_size).zip(dst_u.chunks_exact_mut(chunk_size)) + { + for (p, slot) in dst_chunk.iter_mut().enumerate() { + *slot = src_chunk[p.reverse_bits() >> shift]; } - dst - }) - .collect(); + } + out_slot[0] = dst; + }); MleGroup::Owned(MleGroupOwned::BasePacked(bit_reversed)) } _ => unreachable!(), @@ -438,120 +436,112 @@ where let hi_zs_halved: Vec<_> = hi_zs.iter().map(|&tz| tz.halve()).collect(); let lagrange_coeffs = lagrange_basis_evals(&low_zs, &hi_zs); - let acc = (0..active_count_pairs) - .into_par_iter() - .fold( - || { - ( - vec![EFPacking::::ZERO; degree], - Vec::::with_capacity(n_cols), - Vec::::with_capacity(n_cols), - vec![EFPacking::::ZERO; n_full], - Vec::::new(), - Vec::::new(), - Vec::::new(), - ) - }, - |(mut acc, mut point, mut diff, mut low_evals, mut state_0, mut state_2, mut cached_buf), new_j| { - let i_hi = new_j >> fold_bit; - let i_lo = new_j & lo_mask; - let i0 = (i_hi << (fold_bit + 1)) | i_lo; - let i1 = i0 | stride; - let partial_eq = get_split_eq(new_j); - - // `point` holds column values at z=0; `diff[k] = col_k[i1] - col_k[i0]`. - // Invariant for the rest of this closure: `col_k(z) = point[k] + z · diff[k]`, - // so advancing z by 1 means `point[k] += diff[k]` for all k. 
- point.clear(); - diff.clear(); - for c in cols { - let lo = c[i0]; - let hi = c[i1]; - point.push(lo); - diff.push(hi - lo); - } + let acc = parallel::map_reduce_with_state( + active_count_pairs, + || { + ( + Vec::::with_capacity(n_cols), + Vec::::with_capacity(n_cols), + vec![EFPacking::::ZERO; n_full], + Vec::::new(), + Vec::::new(), + Vec::::new(), + ) + }, + || vec![EFPacking::::ZERO; degree], + |(point, diff, low_evals, state_0, state_2, cached_buf), acc, new_j| { + let i_hi = new_j >> fold_bit; + let i_lo = new_j & lo_mask; + let i0 = (i_hi << (fold_bit + 1)) | i_lo; + let i1 = i0 | stride; + let partial_eq = get_split_eq(new_j); + + // `point` holds column values at z=0; `diff[k] = col_k[i1] - col_k[i0]`. + // Invariant for the rest of this closure: `col_k(z) = point[k] + z · diff[k]`, + // so advancing z by 1 means `point[k] += diff[k]` for all k. + point.clear(); + diff.clear(); + for c in cols { + let lo = c[i0]; + let hi = c[i1]; + point.push(lo); + diff.push(hi - lo); + } - // Phase 1: full AIR constraints + // Phase 1: full AIR constraints - // z = 0: full eval, capture post-block state. - { - let mut folder = ConstraintFolderPacked::new(&point[..n_flat], &point[n_flat..], extra_data); - folder.cached_state = Some(state_0); - Air::eval(computation, &mut folder, extra_data); - acc[0] += folder.accumulator * partial_eq; - low_evals[0] = folder.accumulator_low; - state_0 = folder.cached_state.unwrap(); - } + // z = 0: full eval, capture post-block state. + { + let mut folder = ConstraintFolderPacked::new(&point[..n_flat], &point[n_flat..], extra_data); + folder.cached_state = Some(std::mem::take(state_0)); + Air::eval(computation, &mut folder, extra_data); + acc[0] += folder.accumulator * partial_eq; + low_evals[0] = folder.accumulator_low; + *state_0 = folder.cached_state.unwrap(); + } - // z = 2: advance `point` by 2·diff, full eval, capture post-block state. 
- // Together with `state_0` this pins down the linear `state(z)` (linear when we "omit" the low degree constraints of the block) + // z = 2: advance `point` by 2·diff, full eval, capture post-block state. + // Together with `state_0` this pins down the linear `state(z)` (linear when we "omit" the low degree constraints of the block) + for k in 0..n_cols { + point[k] += diff[k].double(); + } + { + let mut folder = ConstraintFolderPacked::new(&point[..n_flat], &point[n_flat..], extra_data); + folder.cached_state = Some(std::mem::take(state_2)); + Air::eval(computation, &mut folder, extra_data); + acc[1] += folder.accumulator * partial_eq; + low_evals[1] = folder.accumulator_low; + *state_2 = folder.cached_state.unwrap(); + } + + // z = 3, …, d_low+1: still doing full eval + for z_idx in 2..n_full { for k in 0..n_cols { - point[k] += diff[k].double(); - } - { - let mut folder = ConstraintFolderPacked::new(&point[..n_flat], &point[n_flat..], extra_data); - folder.cached_state = Some(state_2); - Air::eval(computation, &mut folder, extra_data); - acc[1] += folder.accumulator * partial_eq; - low_evals[1] = folder.accumulator_low; - state_2 = folder.cached_state.unwrap(); + point[k] += diff[k]; } + let mut folder = ConstraintFolderPacked::new(&point[..n_flat], &point[n_flat..], extra_data); + Air::eval(computation, &mut folder, extra_data); + acc[z_idx] += folder.accumulator * partial_eq; + low_evals[z_idx] = folder.accumulator_low; + } - // z = 3, …, d_low+1: still doing full eval - for z_idx in 2..n_full { - for k in 0..n_cols { - point[k] += diff[k]; - } - let mut folder = ConstraintFolderPacked::new(&point[..n_flat], &point[n_flat..], extra_data); - Air::eval(computation, &mut folder, extra_data); - acc[z_idx] += folder.accumulator * partial_eq; - low_evals[z_idx] = folder.accumulator_low; + // Phase 2: skip the low degree constraints of the block + // For each skipped point, assemble Constraints(z) = high(z) + low(z): + // -high(z): run folder with `skip_low = true` 
+ // -low(z): deduce it via Lagrange-interpolation from previous computations + for t in 0..n_skip { + for k in 0..n_cols { + point[k] += diff[k]; } - // Phase 2: skip the low degree constraints of the block - // For each skipped point, assemble Constraints(z) = high(z) + low(z): - // -high(z): run folder with `skip_low = true` - // -low(z): deduce it via Lagrange-interpolation from previous computations - for t in 0..n_skip { - for k in 0..n_cols { - point[k] += diff[k]; - } - - cached_buf.clear(); - for i in 0..state_0.len() { - cached_buf - .push(state_0[i] + (state_2[i] - state_0[i]) * PFPacking::::from(hi_zs_halved[t])); - } - - let mut folder = ConstraintFolderPacked::new(&point[..n_flat], &point[n_flat..], extra_data); - folder.skip_low = true; - folder.cached_state = Some(cached_buf); - folder.low_ci_count = low_n_constraints; - Air::eval(computation, &mut folder, extra_data); - cached_buf = folder.cached_state.unwrap(); - - // low(hi_zs[t]) = Σ_i L_i(hi_zs[t]) · low(low_zs[i]) - let mut low_interpolated = EFPacking::::ZERO; - for (i, lc) in lagrange_coeffs[t].iter().enumerate() { - low_interpolated += low_evals[i] * PFPacking::::from(*lc); - } - - acc[n_full + t] += (folder.accumulator + low_interpolated) * partial_eq; + cached_buf.clear(); + for i in 0..state_0.len() { + cached_buf.push(state_0[i] + (state_2[i] - state_0[i]) * PFPacking::::from(hi_zs_halved[t])); } - (acc, point, diff, low_evals, state_0, state_2, cached_buf) - }, - ) - .map(|(acc, ..)| acc) - .reduce( - || vec![EFPacking::::ZERO; degree], - |mut a, b| { - for i in 0..degree { - a[i] += b[i]; + let mut folder = ConstraintFolderPacked::new(&point[..n_flat], &point[n_flat..], extra_data); + folder.skip_low = true; + folder.cached_state = Some(std::mem::take(cached_buf)); + folder.low_ci_count = low_n_constraints; + Air::eval(computation, &mut folder, extra_data); + *cached_buf = folder.cached_state.unwrap(); + + // low(hi_zs[t]) = Σ_i L_i(hi_zs[t]) · low(low_zs[i]) + let mut 
low_interpolated = EFPacking::::ZERO; + for (i, lc) in lagrange_coeffs[t].iter().enumerate() { + low_interpolated += low_evals[i] * PFPacking::::from(*lc); } - a - }, - ); + + acc[n_full + t] += (folder.accumulator + low_interpolated) * partial_eq; + } + }, + |mut a, b| { + for i in 0..degree { + a[i] += b[i]; + } + a + }, + ); acc.into_iter().map(&unpack_sum).collect() } @@ -581,54 +571,43 @@ where let stride = 1usize << fold_bit; let lo_mask = stride - 1; - let acc = (0..active_count_pairs) - .into_par_iter() - .fold( - || { - ( - vec![EFT::ZERO; degree], - Vec::::with_capacity(n_cols), - Vec::::with_capacity(n_cols), - ) - }, - |(mut acc, mut point, mut diff), new_j| { - let i_hi = new_j >> fold_bit; - let i_lo = new_j & lo_mask; - let i0 = (i_hi << (fold_bit + 1)) | i_lo; - let i1 = i0 | stride; - let partial_eq = get_split_eq(new_j); - point.clear(); - diff.clear(); - for c in cols { - let lo = c[i0]; - let hi = c[i1]; - point.push(lo); - diff.push(hi - lo); - } - // z = 0 then (skip z = 1) z = 2, 3, …, degree. - acc[0] += eval_fn(computation, &point, extra_data) * partial_eq; + let acc = parallel::map_reduce_with_state( + active_count_pairs, + || (Vec::::with_capacity(n_cols), Vec::::with_capacity(n_cols)), + || vec![EFT::ZERO; degree], + |(point, diff), acc, new_j| { + let i_hi = new_j >> fold_bit; + let i_lo = new_j & lo_mask; + let i0 = (i_hi << (fold_bit + 1)) | i_lo; + let i1 = i0 | stride; + let partial_eq = get_split_eq(new_j); + point.clear(); + diff.clear(); + for c in cols { + let lo = c[i0]; + let hi = c[i1]; + point.push(lo); + diff.push(hi - lo); + } + // z = 0 then (skip z = 1) z = 2, 3, …, degree. + acc[0] += eval_fn(computation, point, extra_data) * partial_eq; + for k in 0..n_cols { + point[k] += diff[k]; + } + for acc_z in &mut acc[1..] { for k in 0..n_cols { point[k] += diff[k]; } - for acc_z in &mut acc[1..] 
{ - for k in 0..n_cols { - point[k] += diff[k]; - } - *acc_z += eval_fn(computation, &point, extra_data) * partial_eq; - } - (acc, point, diff) - }, - ) - .map(|(acc, _, _)| acc) - .reduce( - || vec![EFT::ZERO; degree], - |mut a, b| { - for i in 0..degree { - a[i] += b[i]; - } - a - }, - ); + *acc_z += eval_fn(computation, point, extra_data) * partial_eq; + } + }, + |mut a, b| { + for i in 0..degree { + a[i] += b[i]; + } + a + }, + ); acc.into_iter().map(unpack_sum).collect() } @@ -680,15 +659,15 @@ pub fn prove_batched_air_sumcheck<'a, EF: ExtensionField>>( pub fn compute_shifted_columns(n_shift_columns: usize, columns: &[&[F]]) -> Vec> { // Convention: the first `n_shift_columns` columns are the ones that get shifted. - columns[..n_shift_columns] - .par_iter() - .map(|column| { - let mut shifted = unsafe { uninitialized_vec(column.len()) }; - shifted[..column.len() - 1].copy_from_slice(&column[1..]); - shifted[column.len() - 1] = column[column.len() - 1]; - shifted - }) - .collect() + let mut out: Vec> = (0..n_shift_columns).map(|_| Vec::new()).collect(); + parallel::par_chunks_mut(&mut out, 1, |i, slot| { + let column = columns[i]; + let mut shifted = unsafe { uninitialized_vec(column.len()) }; + shifted[..column.len() - 1].copy_from_slice(&column[1..]); + shifted[column.len() - 1] = column[column.len() - 1]; + slot[0] = shifted; + }); + out } pub fn natural_ordering_point_for_session(sumcheck_air_point: &[EF], log_n_rows: usize) -> Vec { diff --git a/crates/sub_protocols/src/logup.rs b/crates/sub_protocols/src/logup.rs index 55af0a320..85578a1d7 100644 --- a/crates/sub_protocols/src/logup.rs +++ b/crates/sub_protocols/src/logup.rs @@ -72,15 +72,13 @@ pub fn prove_generic_logup( }; let fill_num_from = |dst: &mut [F], src: &[F], neg: bool| { - dst.par_chunks_exact_mut(chunk_size) - .enumerate() - .for_each(|(c, dst_chunk)| { - let src_chunk = &src[c * chunk_size..][..chunk_size]; - for (i, slot) in dst_chunk.iter_mut().enumerate() { - let v = 
src_chunk[i.reverse_bits() >> chunk_shift]; - *slot = if neg { -v } else { v }; - } - }); + parallel::par_chunks_mut(dst, chunk_size, |c, dst_chunk| { + let src_chunk = &src[c * chunk_size..][..chunk_size]; + for (i, slot) in dst_chunk.iter_mut().enumerate() { + let v = src_chunk[i.reverse_bits() >> chunk_shift]; + *slot = if neg { -v } else { v }; + } + }); }; let mut offset = 0; @@ -118,12 +116,14 @@ pub fn prove_generic_logup( ); if 1 << log_bytecode < max_table_height { // padding - numerators[offset + (1 << log_bytecode)..offset + max_table_height] - .par_iter_mut() - .for_each(|n| *n = F::ZERO); - denominators[(offset + (1 << log_bytecode)) / width..(offset + max_table_height) / width] - .par_iter_mut() - .for_each(|d| *d = EFPacking::::ONE); + par_fill( + &mut numerators[offset + (1 << log_bytecode)..offset + max_table_height], + |_| F::ZERO, + ); + par_fill( + &mut denominators[(offset + (1 << log_bytecode)) / width..(offset + max_table_height) / width], + |_| EFPacking::::ONE, + ); } offset += max_table_height.max(1 << log_bytecode); @@ -142,17 +142,15 @@ pub fn prove_generic_logup( let col_index = &trace.columns[group.idx_col]; let packed_chunk_size = (1 << log_n_rows) / width; - numerators[offset..][..group_len << log_n_rows] - .par_iter_mut() - .for_each(|n| *n = F::ONE); + par_fill(&mut numerators[offset..][..group_len << log_n_rows], |_| F::ONE); - denominators[offset / width..][..group_len * packed_chunk_size] - .par_chunks_exact_mut(packed_chunk_size) - .enumerate() - .for_each(|(i, denom_chunk)| { + parallel::par_chunks_mut( + &mut denominators[offset / width..][..group_len * packed_chunk_size], + packed_chunk_size, + |i, denom_chunk| { let i_field = F::from_usize(i); let col_value = &trace.columns[group.value_cols[i]]; - denom_chunk.par_iter_mut().enumerate().for_each(|(p, slot)| { + for (p, slot) in denom_chunk.iter_mut().enumerate() { *slot = c_packed - finger_print_packed::( memory_domainsep_packed, @@ -162,8 +160,9 @@ pub fn 
prove_generic_logup( ], &alphas_packed, ); - }); - }); + } + }, + ); offset += group_len << log_n_rows; bus_idx += group_len; next_group += 1; @@ -175,7 +174,7 @@ pub fn prove_generic_logup( match bus.multiplicity { BusMultiplicity::One => { let val = bus.direction.to_field_flag(); - slice.par_iter_mut().for_each(|n| *n = val); + par_fill(slice, |_| val); } BusMultiplicity::Column(col) => { fill_num_from(slice, &trace.columns[col], matches!(bus.direction, BusDirection::Pull)); @@ -532,5 +531,12 @@ fn fill_denoms(dst: &mut [EFPacking], build: Build) where Build: Fn(usize) -> EFPacking + Sync, { - dst.par_iter_mut().enumerate().for_each(|(p, slot)| *slot = build(p)); + par_fill(dst, build); +} + +/// Fill `dst` in parallel through the in-house pool, computing each slot from its +/// global index. Replaces the rayon `par_iter_mut().enumerate()` constant/index fills. +#[inline] +fn par_fill T + Sync>(dst: &mut [T], build: Build) { + parallel::par_for_each_mut(dst, |i, slot| *slot = build(i)); } diff --git a/crates/sub_protocols/src/quotient_gkr/layers.rs b/crates/sub_protocols/src/quotient_gkr/layers.rs index 0ff9e1663..b6c4502e1 100644 --- a/crates/sub_protocols/src/quotient_gkr/layers.rs +++ b/crates/sub_protocols/src/quotient_gkr/layers.rs @@ -111,13 +111,12 @@ pub(super) fn bit_reverse_chunks(v: &[T], chunk_log: usiz return out; } let shift = usize::BITS as usize - chunk_log; - out.par_chunks_exact_mut(chunk_size) - .zip(v.par_chunks_exact(chunk_size)) - .for_each(|(dst, src)| { - for (p, slot) in dst.iter_mut().enumerate() { - *slot = src[p.reverse_bits() >> shift]; - } - }); + parallel::par_chunks_mut(&mut out, chunk_size, |c, dst| { + let src = &v[c * chunk_size..][..chunk_size]; + for (p, slot) in dst.iter_mut().enumerate() { + *slot = src[p.reverse_bits() >> shift]; + } + }); out } @@ -130,18 +129,18 @@ fn sum_quotients_2_by_2>>(nums: &[EF], dens: &[EF]) -> let mut new_nums: Vec = unsafe { uninitialized_vec(new_active) }; let mut new_dens: Vec = unsafe { 
uninitialized_vec(new_active) }; - new_nums[..full_pairs] - .par_iter_mut() - .zip(new_dens[..full_pairs].par_iter_mut()) - .enumerate() - .for_each(|(i, (num, den))| { + { + let dp = parallel::SendPtr(new_dens.as_mut_ptr()); + parallel::par_for_each_mut(&mut new_nums[..full_pairs], |i, num| { let n0 = nums[2 * i]; let n1 = nums[2 * i + 1]; let d0 = dens[2 * i]; let d1 = dens[2 * i + 1]; *num = d1 * n0 + d0 * n1; - *den = d0 * d1; + // SAFETY: each `i` writes a distinct slot in `new_dens`, a separate buffer. + unsafe { *dp.add(i) = d0 * d1 }; }); + } // Boundary (at most one pair: a/b + 0/1 = a/b). if full_pairs < new_active { @@ -172,18 +171,18 @@ where let mut new_nums: Vec> = unsafe { uninitialized_vec(nums.len() >> 1) }; let mut new_dens: Vec> = unsafe { uninitialized_vec(nums.len() >> 1) }; - new_nums - .par_iter_mut() - .zip(new_dens.par_iter_mut()) - .enumerate() - .for_each(|(new_j, (num_out, den_out))| { + { + let dp = parallel::SendPtr(new_dens.as_mut_ptr()); + parallel::par_for_each_mut(&mut new_nums, |new_j, num_out| { let i_hi = new_j >> bit; let i_lo = new_j & lo_mask; let i0 = (i_hi << (bit + 1)) | i_lo; let i1 = i0 | stride; *num_out = dens[i1] * nums[i0] + dens[i0] * nums[i1]; - *den_out = dens[i0] * dens[i1]; + // SAFETY: each `new_j` writes a distinct slot in `new_dens`, a separate buffer. 
+ unsafe { *dp.add(new_j) = dens[i0] * dens[i1] }; }); + } (new_nums, new_dens) } diff --git a/crates/sub_protocols/src/quotient_gkr/mod.rs b/crates/sub_protocols/src/quotient_gkr/mod.rs index 26fa25a65..c9e9554fc 100644 --- a/crates/sub_protocols/src/quotient_gkr/mod.rs +++ b/crates/sub_protocols/src/quotient_gkr/mod.rs @@ -207,7 +207,7 @@ mod tests { type EF = QuinticExtensionFieldKB; fn sum_all_quotients(nums: &[F], den: &[EF]) -> EF { - nums.par_iter().zip(den).map(|(&n, &d)| EF::from(n) / d).sum() + nums.iter().zip(den).map(|(&n, &d)| EF::from(n) / d).sum() } fn bit_reverse_chunks_and_pack_ext>>(v: &[EF], chunk_log: usize) -> Vec> { diff --git a/crates/sub_protocols/src/quotient_gkr/sumcheck_utils.rs b/crates/sub_protocols/src/quotient_gkr/sumcheck_utils.rs index 27afd58f7..49e0cc0cf 100644 --- a/crates/sub_protocols/src/quotient_gkr/sumcheck_utils.rs +++ b/crates/sub_protocols/src/quotient_gkr/sumcheck_utils.rs @@ -255,8 +255,8 @@ pub(super) fn run_phase1_sumcheck<'a, EF: ExtensionField>>( if let Some(prev_r) = pending_r { let prev_bit = layer_chunk_log - 1 - w; let mul = |x: EFPacking, a: EF| x * a; - nums = Cow::Owned(fold_multilinear_at_bit(nums.as_ref(), prev_r, prev_bit, &mul)); - dens = Cow::Owned(fold_multilinear_at_bit(dens.as_ref(), prev_r, prev_bit, &mul)); + nums = Cow::Owned(fold_multilinear_at_bit(nums.as_ref(), prev_r, prev_bit, &mul, false)); + dens = Cow::Owned(fold_multilinear_at_bit(dens.as_ref(), prev_r, prev_bit, &mul, false)); } let nums_nat = unpack_and_unreverse_active::(nums.as_ref(), layer_chunk_log); @@ -328,10 +328,7 @@ pub(super) fn run_phase2_sumcheck>>( }; let acc: RoundCoeffs = if active_pairs > PARALLEL_THRESHOLD { - (0..active_pairs) - .into_par_iter() - .map(term) - .reduce(RoundCoeffs::zero, Add::add) + parallel::map_reduce(active_pairs, RoundCoeffs::zero, term, Add::add) } else { (0..active_pairs).map(term).fold(RoundCoeffs::::zero(), Add::add) }; @@ -362,7 +359,9 @@ pub(super) fn run_phase2_sumcheck>>( if new_eq_len > 0 { 
let fold_eq = |i: usize| eq_table[2 * i] + eq_table[2 * i + 1]; eq_table = if new_eq_len >= PARALLEL_THRESHOLD { - (0..new_eq_len).into_par_iter().map(fold_eq).collect() + let mut out: Vec = unsafe { uninitialized_vec(new_eq_len) }; + parallel::par_for_each_mut(&mut out, |i, slot| *slot = fold_eq(i)); + out } else { (0..new_eq_len).map(fold_eq).collect() }; @@ -392,10 +391,7 @@ fn fold_normal_with_padding>>(m: &[EF], r: EF, pad_val if new_active < PARALLEL_THRESHOLD { out.iter_mut().enumerate().for_each(compute); } else { - out.par_iter_mut() - .with_min_len(PARALLEL_THRESHOLD) - .enumerate() - .for_each(compute); + parallel::par_for_each_mut(&mut out, |i, slot| compute((i, slot))); } out } @@ -420,10 +416,13 @@ where debug_assert_eq!(dens.len(), nums.len()); debug_assert_eq!(eq_within.len(), quarter); - nums.par_chunks_exact(layer_packed) - .zip(dens.par_chunks_exact(layer_packed)) - .enumerate() - .fold(RoundCoeffs::zero, |mut acc, (c, (n_c, d_c))| { + let n_chunks = nums.len() / layer_packed; + parallel::map_reduce( + n_chunks, + RoundCoeffs::zero, + |c| { + let n_c = &nums[c * layer_packed..][..layer_packed]; + let d_c = &dens[c * layer_packed..][..layer_packed]; let eq_o: EF = eq_outer.get(c).copied().unwrap_or(EF::ONE); let mut local = RoundCoeffs::>::zero(); for inner in 0..quarter { @@ -435,10 +434,10 @@ where ); local += coeffs * eq_within[inner]; } - acc += local * eq_o; - acc - }) - .reduce(RoundCoeffs::zero, Add::add) + local * eq_o + }, + Add::add, + ) } #[allow(clippy::type_complexity)] @@ -472,13 +471,19 @@ where let mut new_dens: Vec> = unsafe { uninitialized_vec(active_out_packed) }; let prev_r_packed: EFPacking = as From>::from(prev_r); - let coeffs = nums - .par_chunks_exact(in_packed) - .zip(dens.par_chunks_exact(in_packed)) - .zip(new_nums.par_chunks_exact_mut(out_packed)) - .zip(new_dens.par_chunks_exact_mut(out_packed)) - .enumerate() - .fold(RoundCoeffs::zero, |mut acc, (c, (((n_c, d_c), nn_c), nd_c))| { + let n_chunks = nums.len() / 
in_packed; + let nn = parallel::SendPtr(new_nums.as_mut_ptr()); + let nd = parallel::SendPtr(new_dens.as_mut_ptr()); + let coeffs = parallel::map_reduce( + n_chunks, + RoundCoeffs::zero, + |c| { + let n_c = &nums[c * in_packed..][..in_packed]; + let d_c = &dens[c * in_packed..][..in_packed]; + // SAFETY: chunk `c` owns the disjoint `out_packed`-sized regions of the two + // output buffers at `c * out_packed`; no other task touches them. + let nn_c = unsafe { std::slice::from_raw_parts_mut(nn.add(c * out_packed), out_packed) }; + let nd_c = unsafe { std::slice::from_raw_parts_mut(nd.add(c * out_packed), out_packed) }; let eq_o: EF = eq_outer.get(c).copied().unwrap_or(EF::ONE); let mut local = RoundCoeffs::>::zero(); for i in 0..in_eighth { @@ -499,10 +504,10 @@ where ); local += round * eq_within[i]; } - acc += local * eq_o; - acc - }) - .reduce(RoundCoeffs::zero, Add::add); + local * eq_o + }, + Add::add, + ); (new_nums, new_dens, coeffs) } diff --git a/crates/utils/src/misc.rs b/crates/utils/src/misc.rs index ff317ae4c..09bad5162 100644 --- a/crates/utils/src/misc.rs +++ b/crates/utils/src/misc.rs @@ -7,15 +7,23 @@ pub fn from_end(slice: &[A], n: usize) -> &[A] { &slice[slice.len() - n..] } -pub fn transposed_par_iter_mut( - array: &mut [Vec; N], // all vectors must have the same length -) -> impl IndexedParallelIterator + '_ { +/// Run `g(i, row)` in parallel over `i in 0..len`, where `row` is `[&mut A; N]` holding +/// the `i`-th element of each of the `N` equal-length vectors (a transposed row). +/// Dispatched through the in-house [`parallel`] pool. 
+pub fn transposed_par_for_each_mut(array: &mut [Vec; N], g: G) +where + G: Fn(usize, [&mut A; N]) + Sync, +{ + // all vectors must have the same length let len = array[0].len(); let data_ptrs: [AtomicPtr; N] = array.each_mut().map(|v| AtomicPtr::new(v.as_mut_ptr())); - (0..len) - .into_par_iter() - .map(move |i| unsafe { std::array::from_fn(|j| &mut *data_ptrs[j].load(Ordering::Relaxed).add(i)) }) + parallel::for_each_index(len, |i| { + // SAFETY: distinct `i` access disjoint row `i` of each of the `N` vectors, and the + // arrays outlive the dispatch (the dispatcher blocks until all tasks complete). + let row: [&mut A; N] = unsafe { std::array::from_fn(|j| &mut *data_ptrs[j].load(Ordering::Relaxed).add(i)) }; + g(i, row); + }); } pub fn collect_refs(vecs: &[Vec]) -> Vec<&[T]> { diff --git a/crates/utils/src/multilinear.rs b/crates/utils/src/multilinear.rs index 2030ca5f4..0147d28a0 100644 --- a/crates/utils/src/multilinear.rs +++ b/crates/utils/src/multilinear.rs @@ -13,10 +13,12 @@ pub fn multilinears_linear_combination, P: Borro assert_eq!(pols.len(), scalars.len()); let n_vars = log2_strict_usize(pols[0].borrow().len()); assert!(pols.iter().all(|p| log2_strict_usize(p.borrow().len()) == n_vars)); - (0..1 << n_vars) - .into_par_iter() - .map(|i| dot_product(scalars.iter().copied(), pols.iter().map(|p| p.borrow()[i]))) - .collect::>() + let n = 1usize << n_vars; + let mut out: Vec = unsafe { uninitialized_vec(n) }; + parallel::par_for_each_mut(&mut out, |i, slot| { + *slot = dot_product(scalars.iter().copied(), pols.iter().map(|p| p.borrow()[i])); + }); + out } pub fn multilinear_eval_constants_at_right(limit: usize, point: &[F]) -> F { diff --git a/crates/whir/Cargo.toml b/crates/whir/Cargo.toml index 1c2a2b0a7..6845c34b4 100644 --- a/crates/whir/Cargo.toml +++ b/crates/whir/Cargo.toml @@ -14,7 +14,7 @@ symetric = { path = "../backend/symetric", package = "mt-symetric" } system-info.workspace = true itertools.workspace = true -rayon.workspace = true 
+parallel.workspace = true rand.workspace = true tracing.workspace = true diff --git a/crates/whir/src/dft.rs b/crates/whir/src/dft.rs index 277597eb8..9b0d8329e 100644 --- a/crates/whir/src/dft.rs +++ b/crates/whir/src/dft.rs @@ -29,7 +29,6 @@ use field::PackedValue; use field::{BasedVectorSpace, Field, PackedField, TwoAdicField}; use itertools::Itertools; -use rayon::prelude::*; use tracing::instrument; use utils::{as_base_slice, log2_strict_usize}; @@ -164,7 +163,8 @@ where /// also divide by the height. #[inline] fn par_initial_layers(mat: &mut [F], chunk_size: usize, root_table: &[Vec], width: usize) { - mat.par_chunks_exact_mut(chunk_size).for_each(|chunk| { + let n_full = mat.len() / chunk_size * chunk_size; + parallel::par_chunks_mut(&mut mat[..n_full], chunk_size, |_, chunk| { initial_layers(chunk, root_table, width); }); } @@ -197,14 +197,20 @@ fn dft_layer>(vec: &mut [F], twiddles: &[B], width: us #[inline] fn dft_layer_par>(vec: &mut [F], twiddles: &[B], width: usize) { - vec.par_chunks_exact_mut(twiddles.len() * 2 * width).for_each(|block| { - let (left, right) = block.split_at_mut(twiddles.len() * width); - left.par_chunks_exact_mut(width) - .zip(right.par_chunks_exact_mut(width)) - .zip(twiddles.par_iter()) - .for_each(|((hi_chunk, lo_chunk), twiddle)| { - twiddle.apply_to_rows(hi_chunk, lo_chunk); - }); + let ts = twiddles.len(); + let block_size = 2 * ts * width; + let n_blocks = vec.len() / block_size; + // Flatten (block, group) into one parallel loop over `n_blocks * ts` groups so coarse + // layers (few blocks) still parallelize; guided scheduling keeps a worker's batch of + // consecutive groups within the same block, preserving the per-block cache locality. + let base = parallel::SendPtr(vec.as_mut_ptr()); + parallel::for_each_index(n_blocks * ts, |g| { + let block_base = (g / ts) * block_size; + let ind = g % ts; + // SAFETY: distinct `g` map to disjoint (hi, lo) `width`-rows. 
+ let hi = unsafe { base.slice(block_base + ind * width, width) }; + let lo = unsafe { base.slice(block_base + (ts + ind) * width, width) }; + twiddles[ind].apply_to_rows(hi, lo); }); } @@ -234,40 +240,25 @@ fn dft_layer_par_double, M: MultiLayerButterfly> assert_eq!(twiddles_large.len(), twiddles_small.len() * 2); - // TODO optimal workload size with L1 cache - mat.values - .par_chunks_exact_mut(twiddles_large.len() * 2 * width) - .for_each(|block| { - // (0..twiddles_small.len()).into_par_iter().for_each(|ind| { - // let hi_hi = slice_ref_mut(block, ind * width, width); - // let hi_lo = slice_ref_mut(block, (ind + twiddles_small.len()) * width, width); - // let lo_hi = slice_ref_mut(block, (ind + 2 * twiddles_small.len()) * width, width); - // let lo_lo = slice_ref_mut(block, (ind + 3 * twiddles_small.len()) * width, width); - // multi_butterfly.apply_2_layers( - // ((hi_hi, hi_lo), (lo_hi, lo_lo)), - // ind, - // twiddles_small, - // twiddles_large, - // ); - // }); - let (hi_blocks, lo_blocks) = block.split_at_mut(twiddles_small.len() * width * 2); - let (hi_hi_blocks, hi_lo_blocks) = hi_blocks.split_at_mut(twiddles_small.len() * width); - let (lo_hi_blocks, lo_lo_blocks) = lo_blocks.split_at_mut(twiddles_small.len() * width); - hi_hi_blocks - .par_chunks_exact_mut(width) - .zip(hi_lo_blocks.par_chunks_exact_mut(width)) - .zip(lo_hi_blocks.par_chunks_exact_mut(width)) - .zip(lo_lo_blocks.par_chunks_exact_mut(width)) - .enumerate() - .for_each(|(ind, (((hi_hi, hi_lo), lo_hi), lo_lo))| { - multi_butterfly.apply_2_layers( - ((hi_hi, hi_lo), (lo_hi, lo_lo)), - ind, - twiddles_small, - twiddles_large, - ); - }); - }); + // Flatten (block, inner-group) into one parallel loop. A block is `4·ts` rows of + // `width`; group `ind` touches the 4 rows at sub-block offsets `k·ts + ind` (k=0..3). + // Coarse layers (few blocks) thus still parallelize over their `ts` inner groups, and + // guided scheduling keeps a worker's consecutive groups within one block (cache-local). 
+ let ts = twiddles_small.len(); + let block_size = 4 * ts * width; // == twiddles_large.len() * 2 * width + let n_blocks = mat.values.len() / block_size; + let base = parallel::SendPtr(mat.values.as_mut_ptr()); + parallel::for_each_index(n_blocks * ts, |g| { + let block_base = (g / ts) * block_size; + let ind = g % ts; + let row = |k: usize| block_base + (k * ts + ind) * width; + // SAFETY: distinct `g` map to disjoint sets of 4 `width`-rows. + let hi_hi = unsafe { base.slice(row(0), width) }; + let hi_lo = unsafe { base.slice(row(1), width) }; + let lo_hi = unsafe { base.slice(row(2), width) }; + let lo_lo = unsafe { base.slice(row(3), width) }; + multi_butterfly.apply_2_layers(((hi_hi, hi_lo), (lo_hi, lo_lo)), ind, twiddles_small, twiddles_large); + }); } /// Applies three layers of a Radix-2 FFT butterfly network making use of parallelization. @@ -303,44 +294,38 @@ fn dft_layer_par_triple, M: MultiLayerButterfly> // let inner_chunk_size = // (workload_size::().next_power_of_two() / 8).min(eighth_outer_block_size); - mat.values - .par_chunks_exact_mut(twiddles_large.len() * 2 * width) - .for_each(|block| { - let (hi_blocks, lo_blocks) = block.split_at_mut(twiddles_small.len() * width * 4); - let (hi_hi_blocks, hi_lo_blocks) = hi_blocks.split_at_mut(twiddles_small.len() * width * 2); - let (lo_hi_blocks, lo_lo_blocks) = lo_blocks.split_at_mut(twiddles_small.len() * width * 2); - let (hi_hi_hi_blocks, hi_hi_lo_blocks) = hi_hi_blocks.split_at_mut(twiddles_small.len() * width); - let (hi_lo_hi_blocks, hi_lo_lo_blocks) = hi_lo_blocks.split_at_mut(twiddles_small.len() * width); - let (lo_hi_hi_blocks, lo_hi_lo_blocks) = lo_hi_blocks.split_at_mut(twiddles_small.len() * width); - let (lo_lo_hi_blocks, lo_lo_lo_blocks) = lo_lo_blocks.split_at_mut(twiddles_small.len() * width); - hi_hi_hi_blocks - .par_chunks_exact_mut(width) - .zip(hi_hi_lo_blocks.par_chunks_exact_mut(width)) - .zip(hi_lo_hi_blocks.par_chunks_exact_mut(width)) - 
.zip(hi_lo_lo_blocks.par_chunks_exact_mut(width)) - .zip(lo_hi_hi_blocks.par_chunks_exact_mut(width)) - .zip(lo_hi_lo_blocks.par_chunks_exact_mut(width)) - .zip(lo_lo_hi_blocks.par_chunks_exact_mut(width)) - .zip(lo_lo_lo_blocks.par_chunks_exact_mut(width)) - .enumerate() - .for_each( - |( - ind, - (((((((hi_hi_hi, hi_hi_lo), hi_lo_hi), hi_lo_lo), lo_hi_hi), lo_hi_lo), lo_lo_hi), lo_lo_lo), - )| { - multi_butterfly.apply_3_layers( - ( - ((hi_hi_hi, hi_hi_lo), (hi_lo_hi, hi_lo_lo)), - ((lo_hi_hi, lo_hi_lo), (lo_lo_hi, lo_lo_lo)), - ), - ind, - twiddles_small, - twiddles_med, - twiddles_large, - ); - }, - ); - }); + // Flatten (block, inner-group) into one parallel loop. A block is `8·ts` rows of + // `width`; group `ind` touches the 8 rows at sub-block offsets `k·ts + ind` (k=0..7). + // Coarse layers still parallelize over their `ts` inner groups; guided scheduling keeps + // a worker's consecutive groups within one block (cache-local). + let ts = twiddles_small.len(); + let block_size = 8 * ts * width; // == twiddles_large.len() * 2 * width + let n_blocks = mat.values.len() / block_size; + let base = parallel::SendPtr(mat.values.as_mut_ptr()); + parallel::for_each_index(n_blocks * ts, |g| { + let block_base = (g / ts) * block_size; + let ind = g % ts; + let row = |k: usize| block_base + (k * ts + ind) * width; + // SAFETY: distinct `g` map to disjoint sets of 8 `width`-rows. 
+ let hi_hi_hi = unsafe { base.slice(row(0), width) }; + let hi_hi_lo = unsafe { base.slice(row(1), width) }; + let hi_lo_hi = unsafe { base.slice(row(2), width) }; + let hi_lo_lo = unsafe { base.slice(row(3), width) }; + let lo_hi_hi = unsafe { base.slice(row(4), width) }; + let lo_hi_lo = unsafe { base.slice(row(5), width) }; + let lo_lo_hi = unsafe { base.slice(row(6), width) }; + let lo_lo_lo = unsafe { base.slice(row(7), width) }; + multi_butterfly.apply_3_layers( + ( + ((hi_hi_hi, hi_hi_lo), (hi_lo_hi, hi_lo_lo)), + ((lo_hi_hi, lo_hi_lo), (lo_lo_hi, lo_lo_lo)), + ), + ind, + twiddles_small, + twiddles_med, + twiddles_large, + ); + }); } /// Applies the remaining layers of the Radix-2 FFT butterfly network in parallel. diff --git a/crates/whir/src/merkle.rs b/crates/whir/src/merkle.rs index eadf8212a..035942927 100644 --- a/crates/whir/src/merkle.rs +++ b/crates/whir/src/merkle.rs @@ -12,7 +12,6 @@ use field::PrimeCharacteristicRing; use koala_bear::{KoalaBear, QuinticExtensionFieldKB, default_koalabear_poseidon1_16}; use poly::*; -use rayon::prelude::*; use symetric::merkle::unpack_array; use tracing::instrument; use utils::log2_ceil_usize; @@ -198,22 +197,20 @@ where let mut digests = unsafe { uninitialized_vec(height) }; - digests - .par_chunks_exact_mut(width) - .enumerate() - .for_each(|(i, digests_chunk)| { - let first_row = i * width; - let rtl_iter = matrix.vertically_packed_row_rtl::

(first_row, effective_base_width, n_pad); - let packed_digest: [P; DIGEST_ELEMS] = - symetric::hash_rtl_iter_with_initial_state::<_, _, _, WIDTH, RATE, DIGEST_ELEMS>( - perm, - rtl_iter, - packed_initial_state, - ); - for (dst, src) in digests_chunk.iter_mut().zip(unpack_array(packed_digest)) { - *dst = src; - } - }); + // `height` is a multiple of `width`, so every chunk is exactly `width` long. + parallel::par_chunks_mut(&mut digests, width, |i, digests_chunk| { + let first_row = i * width; + let rtl_iter = matrix.vertically_packed_row_rtl::

(first_row, effective_base_width, n_pad); + let packed_digest: [P; DIGEST_ELEMS] = + symetric::hash_rtl_iter_with_initial_state::<_, _, _, WIDTH, RATE, DIGEST_ELEMS>( + perm, + rtl_iter, + packed_initial_state, + ); + for (dst, src) in digests_chunk.iter_mut().zip(unpack_array(packed_digest)) { + *dst = src; + } + }); digests } diff --git a/crates/whir/src/open.rs b/crates/whir/src/open.rs index 6636b77c7..6b737b9c4 100644 --- a/crates/whir/src/open.rs +++ b/crates/whir/src/open.rs @@ -5,7 +5,6 @@ use fiat_shamir::{FSProver, MerklePath, ProofResult}; use field::PrimeCharacteristicRing; use field::{ExtensionField, Field, TwoAdicField}; use poly::*; -use rayon::prelude::*; use sumcheck::{ProductComputation, run_product_sumcheck, sumcheck_prove_many_rounds}; use tracing::{info_span, instrument}; @@ -594,17 +593,15 @@ where for (e, &scalar) in smt.values.iter().zip(&next_gamma_powers) { combined_sum += e.value * scalar; } - chunks_mut - .into_par_iter() - .zip(&indexed_smt_values) - .for_each(|(out_buff, &(origin_index, _))| { - out_buff[..1 << shift] - .par_iter_mut() - .zip(&inner_poly) - .for_each(|(out_elem, &poly_elem)| { - *out_elem += poly_elem * next_gamma_powers[origin_index]; - }); + // Few sparse statements (the outer chunks) but each inner accumulation can be + // large, so parallelize the inner loop per statement (the outer runs serial). 
+ for (out_buff, &(origin_index, _)) in chunks_mut.iter_mut().zip(&indexed_smt_values) { + let out = &mut out_buff[..1 << shift]; + let scalar = next_gamma_powers[origin_index]; + parallel::par_for_each_mut(out, |i, out_elem| { + *out_elem += inner_poly[i] * scalar; }); + } gamma_pow = *next_gamma_powers.last().unwrap() * gamma; } } diff --git a/crates/whir/src/utils.rs b/crates/whir/src/utils.rs index e64799149..06288dc98 100644 --- a/crates/whir/src/utils.rs +++ b/crates/whir/src/utils.rs @@ -6,7 +6,6 @@ use field::Field; use field::PackedValue; use field::{ExtensionField, TwoAdicField}; use poly::*; -use rayon::prelude::*; use std::any::{Any, TypeId}; use std::collections::HashMap; use std::sync::{Arc, Mutex, OnceLock}; @@ -138,15 +137,28 @@ fn prepare_evals_for_fft_unpacked( let log_block_size = log2_strict_usize(block_size); let out_len = block_size * dft_n_cols; - (0..out_len) - .into_par_iter() - .map(|i| { - let block_index = i % dft_n_cols; - let offset_in_block = i / dft_n_cols; - let src_index = ((block_index << log_block_size) + offset_in_block) >> log_inv_rate; - unsafe { *evals.get_unchecked(src_index) } - }) - .collect() + let mut out: Vec = unsafe { uninitialized_vec(out_len) }; + if block_size == 0 || dft_n_cols == 0 { + return out; + } + + let rows_per_band = ((system_info::L1_CACHE_SIZE / 2) / (dft_n_cols * size_of::())).clamp(1, block_size); + let band_len = rows_per_band * dft_n_cols; + + parallel::par_chunks_mut(&mut out, band_len, |band_idx, band| { + let row0 = band_idx * rows_per_band; + let n_rows = band.len() / dft_n_cols; + for col in 0..dft_n_cols { + let col_base = col << log_block_size; + for r in 0..n_rows { + let src = (col_base + row0 + r) >> log_inv_rate; + unsafe { + *band.get_unchecked_mut(r * dft_n_cols + col) = *evals.get_unchecked(src); + } + } + } + }); + out } fn prepare_evals_for_fft_packed_extension>>( @@ -160,25 +172,38 @@ fn prepare_evals_for_fft_packed_extension>>( let full_len = evals.len() << (log_inv_rate + 
log_packing); let block_size = full_len / n_blocks; let log_block_size = log2_strict_usize(block_size); - let n_blocks_mask = n_blocks - 1; let packing_mask = (1 << log_packing) - 1; - (0..full_len) - .into_par_iter() - .map(|i| { - let block_index = i & n_blocks_mask; - let offset_in_block = i >> folding_factor; - let src_index = ((block_index << log_block_size) + offset_in_block) >> log_inv_rate; - let packed_src_index = src_index >> log_packing; - let offset_in_packing = src_index & packing_mask; - let packed = unsafe { evals.get_unchecked(packed_src_index) }; - let unpacked: &[PFPacking] = packed.as_basis_coefficients_slice(); - EF::from_basis_coefficients_fn(|i| unsafe { - let u: &PFPacking = unpacked.get_unchecked(i); - *u.as_slice().get_unchecked(offset_in_packing) - }) - }) - .collect() + let mut out: Vec = unsafe { uninitialized_vec(full_len) }; + if block_size == 0 || n_blocks == 0 { + return out; + } + + let rows_per_band = ((system_info::L1_CACHE_SIZE / 2) / (n_blocks * size_of::())).clamp(1, block_size); + let band_len = rows_per_band * n_blocks; + + parallel::par_chunks_mut(&mut out, band_len, |band_idx, band| { + let row0 = band_idx * rows_per_band; + let n_rows = band.len() / n_blocks; + for col in 0..n_blocks { + let col_base = col << log_block_size; + for r in 0..n_rows { + let src_index = (col_base + row0 + r) >> log_inv_rate; + let packed_src_index = src_index >> log_packing; + let offset_in_packing = src_index & packing_mask; + let packed = unsafe { evals.get_unchecked(packed_src_index) }; + let unpacked: &[PFPacking] = packed.as_basis_coefficients_slice(); + let val = EF::from_basis_coefficients_fn(|j| unsafe { + let u: &PFPacking = unpacked.get_unchecked(j); + *u.as_slice().get_unchecked(offset_in_packing) + }); + unsafe { + *band.get_unchecked_mut(r * n_blocks + col) = val; + } + } + } + }); + out } type CacheKey = TypeId; diff --git a/crates/xmss/src/signers_cache.rs b/crates/xmss/src/signers_cache.rs index 6e7a9956e..27c73dd15 100644 --- 
a/crates/xmss/src/signers_cache.rs +++ b/crates/xmss/src/signers_cache.rs @@ -89,18 +89,18 @@ fn gen_benchmark_signers_cache() -> Vec<(XmssPublicKey, XmssSignature)> { let completed = AtomicUsize::new(1); let time = Instant::now(); - let rest: Vec<_> = (1..NUM_BENCHMARK_SIGNERS) - .into_par_iter() - .map(|index| { - let signer = compute_signer(index); - let done = completed.fetch_add(1, Ordering::Relaxed) + 1; - print!( - "\rPrecomputing benchmark signatures (cached after first run): {:.0}%", - 100.0 * done as f64 / NUM_BENCHMARK_SIGNERS as f64 - ); - signer - }) - .collect(); + let n_rest = NUM_BENCHMARK_SIGNERS - 1; + let mut rest_opt: Vec> = (0..n_rest).map(|_| None).collect(); + parallel::par_for_each_mut(&mut rest_opt, |i, out| { + let signer = compute_signer(1 + i); + let done = completed.fetch_add(1, Ordering::Relaxed) + 1; + print!( + "\rPrecomputing benchmark signatures (cached after first run): {:.0}%", + 100.0 * done as f64 / NUM_BENCHMARK_SIGNERS as f64 + ); + *out = Some(signer); + }); + let rest: Vec<_> = rest_opt.into_iter().map(Option::unwrap).collect(); println!( "\rGenerating signatures for benchmark (one-time operation): 100% - done ({:.2}s)", @@ -128,7 +128,8 @@ fn gen_benchmark_signers_cache() -> Vec<(XmssPublicKey, XmssSignature)> { #[test] fn test_signature_cache() { let signatures = get_benchmark_signatures(); - signatures.par_iter().enumerate().for_each(|(i, (pk, sig))| { + parallel::for_each_index(signatures.len(), |i| { + let (pk, sig) = &signatures[i]; xmss_verify(pk, &message_for_benchmark(), sig, BENCHMARK_SLOT) .unwrap_or_else(|_| panic!("Signature {} failed to verify", i)); }); diff --git a/crates/xmss/src/xmss.rs b/crates/xmss/src/xmss.rs index d5f69f445..e56771cc8 100644 --- a/crates/xmss/src/xmss.rs +++ b/crates/xmss/src/xmss.rs @@ -84,13 +84,13 @@ pub fn xmss_key_gen( } let public_param: PublicParam = gen_public_param(&seed); // Level 0: WOTS leaf hashes for slots in [slot_start, slot_end] - let leaves: Vec = 
(slot_start..=slot_end) - .into_par_iter() - .map(|slot| { - let wots = gen_wots_secret_key(&seed, slot, public_param); - wots.public_key().hash(public_param, slot) - }) - .collect(); + let n_leaves = (slot_end - slot_start + 1) as usize; + let mut leaves: Vec = unsafe { uninitialized_vec(n_leaves) }; + parallel::par_for_each_mut(&mut leaves, |i, out| { + let slot = slot_start + i as u32; + let wots = gen_wots_secret_key(&seed, slot, public_param); + *out = wots.public_key().hash(public_param, slot); + }); let mut merkle_tree = vec![leaves]; // Build levels 1..=LOG_LIFETIME. // At level l, we store nodes with index in [(slot_start >> l), (slot_end >> l)]. @@ -102,30 +102,31 @@ pub fn xmss_key_gen( let prev_top: u64 = (slot_end as u64) >> (level - 1); let nodes: Vec = { let prev = &merkle_tree[level - 1]; - (base..=top) - .into_par_iter() - .map(|i| { - let left_idx = 2 * i; - let right_idx = 2 * i + 1; - let left = if left_idx >= prev_base && left_idx <= prev_top { - prev[(left_idx - prev_base) as usize] - } else { - gen_random_node(&seed, level - 1, left_idx) - }; - let right = if right_idx >= prev_base && right_idx <= prev_top { - prev[(right_idx - prev_base) as usize] - } else { - gen_random_node(&seed, level - 1, right_idx) - }; - let merkle_data = build_merkle_data( - make_tweak(TWEAK_TYPE_MERKLE, level, i as u32), - &public_param, - &left, - &right, - ); - poseidon16_compress(merkle_data)[..XMSS_DIGEST_LEN].try_into().unwrap() - }) - .collect() + let n_nodes = (top - base + 1) as usize; + let mut nodes: Vec = unsafe { uninitialized_vec(n_nodes) }; + parallel::par_for_each_mut(&mut nodes, |k, out| { + let i = base + k as u64; + let left_idx = 2 * i; + let right_idx = 2 * i + 1; + let left = if left_idx >= prev_base && left_idx <= prev_top { + prev[(left_idx - prev_base) as usize] + } else { + gen_random_node(&seed, level - 1, left_idx) + }; + let right = if right_idx >= prev_base && right_idx <= prev_top { + prev[(right_idx - prev_base) as usize] + } else { + 
gen_random_node(&seed, level - 1, right_idx) + }; + let merkle_data = build_merkle_data( + make_tweak(TWEAK_TYPE_MERKLE, level, i as u32), + &public_param, + &left, + &right, + ); + *out = poseidon16_compress(merkle_data)[..XMSS_DIGEST_LEN].try_into().unwrap(); + }); + nodes }; merkle_tree.push(nodes); } diff --git a/crates/xmss/tests/xmss_tests.rs b/crates/xmss/tests/xmss_tests.rs index 0fb08e01d..7abf02312 100644 --- a/crates/xmss/tests/xmss_tests.rs +++ b/crates/xmss/tests/xmss_tests.rs @@ -46,17 +46,19 @@ fn encoding_grinding_bits() { merkle_root: Default::default(), public_param: Default::default(), }; - let total_iters = (0..n) - .into_par_iter() - .map(|i| { + let total_iters = parallel::map_reduce( + n, + || 0usize, + |i| { let message: [F; MESSAGE_LEN_FE] = Default::default(); let slot = i as u32; let mut rng = StdRng::seed_from_u64(i as u64); let (_randomness, _encoding, num_iters) = find_randomness_for_wots_encoding(&message, slot, &xmss_pub_key, &mut rng); num_iters - }) - .sum::(); + }, + |a, b| a + b, + ); let grinding = ((total_iters as f64) / (n as f64)).log2(); println!("Average grinding bits: {:.1}", grinding); } diff --git a/src/lib.rs b/src/lib.rs index 577853996..48a4a0be0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,8 +11,41 @@ pub use xmss::{MESSAGE_LEN_FE, XmssPublicKey, XmssSecretKey, XmssSignature, xmss pub type F = KoalaBear; +/// Tune the allocator and VM memory policy for the prover's churn of huge buffers. +/// +/// Two things, both before any heavy proving allocation (idempotent): +/// +/// 1. **Disable mimalloc purging** so freed large blocks are *retained* rather than returned +/// to the OS and re-faulted on the next allocation — what made the old bump arena fast +/// (page reuse); mimalloc-with-retention matches and beats it. +/// +/// 2. 
**Disable Transparent Huge Pages for this process.** On Zen4 (and likely other x86 with +/// physically-indexed L2/L3), when the kernel promotes the allocator's large arenas to +/// 2 MB huge pages, the prover's strided multilinear/NTT array access collapses into a few +/// cache sets — measured **+217% cache-misses, IPC 0.85 → 0.51, +50% wall time** on +/// `fancy-aggregation`. It's intermittent (only fires when 2 MB-contiguous memory is free +/// for THP promotion), which is what made it so hard to pin down. `prctl(PR_SET_THP_DISABLE)` +/// is process-local and overrides even a system-wide `THP=always`. No-op off Linux (macOS +/// has no THP — Apple silicon was never affected). Applies under any allocator. +pub fn tune_allocator() { + // mimalloc v3 option index `mi_option_purge_delay` = 15; value -1 = never purge + // (equivalent to `MIMALLOC_PURGE_DELAY=-1`). No-op under the `standard-alloc` (plain + // system allocator) build. + #[cfg(not(feature = "standard-alloc"))] + unsafe { + libmimalloc_sys::mi_option_set(15, -1); + } + // Keep allocator arenas on 4 KB pages (see point 2 above). + #[cfg(target_os = "linux")] + unsafe { + libc::prctl(libc::PR_SET_THP_DISABLE, 1, 0, 0, 0); + } +} + /// Call once before proving. Compiles the aggregation program and precomputes DFT twiddles. pub fn setup_prover() { + tune_allocator(); + parallel::init(); // construct the thread pool up front (was done by `zk_alloc::begin_phase`) rec_aggregation::init_aggregation_bytecode(); precompute_dft_twiddles::(1 << 24); } @@ -21,16 +54,3 @@ pub fn setup_prover() { pub fn setup_verifier() { rec_aggregation::init_aggregation_bytecode(); } - -/// Bump-arena allocator. -/// -/// **Optional.** -/// -/// To enable, set it as the `#[global_allocator]` in your binary and call -/// [`init_allocator`] once at startup. 
Then bracket each proving call with -/// [`begin_phase`] / [`end_phase`] and **clone the outputs after -/// [`end_phase`]** so the cloned copy lands in the system allocator before the -/// next [`begin_phase`] resets the arena slabs. -/// -/// See `tests/test_zk_alloc.rs` for a runnable end-to-end example. -pub use zk_alloc::{ZkAllocator, begin_phase, end_phase, init as init_allocator}; diff --git a/src/main.rs b/src/main.rs index 646fc6f64..4328b882b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,9 +1,15 @@ use clap::Parser; use rec_aggregation::benchmark::{AggregationTopology, biggest_leaf, run_aggregation_benchmark}; +// Allocator: mimalloc — a robust production allocator, tuned to retain freed memory (see +// `lean_multisig::tune_allocator`). Replaces the former `zk-alloc` bump arena, which was +// fast but fragile: any allocation outliving a phase, or a pointer retained across +// `begin_phase`'s slab reset (e.g. from a background thread, tracing, or a stray clone), +// silently corrupted memory. mimalloc-with-retention is both **faster** here and stable. +// The `standard-alloc` feature selects the plain system allocator for comparison. #[cfg(not(feature = "standard-alloc"))] #[global_allocator] -static ALLOC: zk_alloc::ZkAllocator = zk_alloc::ZkAllocator; +static ALLOC: mimalloc::MiMalloc = mimalloc::MiMalloc; #[derive(Parser)] enum Cli { @@ -67,8 +73,9 @@ fn run_with_warmup(topology: &AggregationTopology, tracing: bool, json: bool, re #[allow(clippy::too_many_lines)] fn main() { - #[cfg(not(feature = "standard-alloc"))] - zk_alloc::init(); + // Retain freed memory (no purging) so the prover's huge buffer churn reuses pages + // instead of re-faulting — the property that made the old arena fast. Before any work. 
+ lean_multisig::tune_allocator(); let cli = Cli::parse(); diff --git a/tests/test_aggregation.rs b/tests/test_aggregation.rs new file mode 100644 index 000000000..14e06d5fb --- /dev/null +++ b/tests/test_aggregation.rs @@ -0,0 +1,21 @@ +use lean_multisig::{aggregate_single_msg_signatures, setup_prover, verify_single_message_aggregate}; +use xmss::signers_cache::{BENCHMARK_SLOT, get_benchmark_signatures, message_for_benchmark}; + +// End-to-end prove+verify under the default (mimalloc) allocator. Repeated to catch any +// cross-run state corruption. (Replaces the former `test_zk_alloc` arena regression test; +// the bump arena and its phase ceremony were removed in favor of a robust allocator.) +#[test] +fn test_aggregation_prove_verify() { + setup_prover(); + + let log_inv_rate = 2; + let message = message_for_benchmark(); + let slot: u32 = BENCHMARK_SLOT; + let signatures = get_benchmark_signatures(); + let raw_xmss = signatures[0..6].to_vec(); + + for _ in 0..2 { + let aggregated = aggregate_single_msg_signatures(&[], raw_xmss.clone(), message, slot, log_inv_rate).unwrap(); + verify_single_message_aggregate(&aggregated).unwrap(); + } +} diff --git a/tests/test_zk_alloc.rs b/tests/test_zk_alloc.rs deleted file mode 100644 index 4596bde61..000000000 --- a/tests/test_zk_alloc.rs +++ /dev/null @@ -1,27 +0,0 @@ -use lean_multisig::{ - ZkAllocator, aggregate_single_msg_signatures, begin_phase, end_phase, setup_prover, verify_single_message_aggregate, -}; -use xmss::signers_cache::{BENCHMARK_SLOT, get_benchmark_signatures, message_for_benchmark}; - -#[global_allocator] -static ALLOC: ZkAllocator = ZkAllocator; - -#[test] -#[allow(clippy::redundant_clone)] -fn test_aggregation_with_zk_alloc() { - setup_prover(); - - let log_inv_rate = 2; - let message = message_for_benchmark(); - let slot: u32 = BENCHMARK_SLOT; - let signatures = get_benchmark_signatures(); - let raw_xmss = signatures[0..6].to_vec(); - - begin_phase(); - let aggregated = 
aggregate_single_msg_signatures(&[], raw_xmss, message, slot, log_inv_rate).unwrap(); - end_phase(); - // IMPORTANT: clone to move the data out of the arena memory - let aggregated = aggregated.clone(); - - verify_single_message_aggregate(&aggregated).unwrap(); -}