Jax sampling tweaks#149
Open
danlkv wants to merge 5 commits into
Open
Conversation
|
Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits. |
6f39b9e to
3ff19bd
Compare
The previous np.concatenate(batches)[:shots] triggered (1) one __array__ call per jax.Array in the list — each forcing its own d2h transfer — and (2) a fresh host buffer allocation plus a host-side memcpy from the d2h'd buffer into the concat output. On large bool tensors (e.g. 500k shots × 528 detector bits = 264 MB) the host concat alone was about 1 s on top of the PCIe transfer. Concatenating on device first and then doing a single np.asarray means one d2h, no extra host memcpy. For the common batch_size == shots case we also skip the jnp.concatenate (single-element batches list). No behavior change, no new flag. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
ExactScalarArray.sum and .prod previously did:
scanned = lax.associative_scan(op, ..., axis=axis)
return take(scanned, -1, axis=axis)
That builds the full O(N) prefix tensor along the scan axis just to throw
all but the last element away. For the batches the sampler now reaches
(B up to 500k+, T up to several dozen) the prefix tensor was the dominant
memory cost in .sum / .prod and the limiting factor for batch size.
Replace with a lax.scan-based reduction that keeps a single
(power, coeffs) carry — O(1) extra memory along the scan axis. Compute
depth becomes O(N) sequential instead of O(log N) parallel, but the public
.sum/.prod API only needs the final value, so the parallel-prefix advantage
of associative_scan doesn't apply.
unroll=4 default keeps XLA's traced body small enough that compile times
stay reasonable for typical scan lengths.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Default numpy d2h (np.asarray on a jax.Array) lands in pageable host memory,
which forces the CUDA driver to stage the transfer through internal pinned
scratch before copying into the user buffer. The second hop is bound by host
DRAM bandwidth and caps the effective d2h throughput well below PCIe line
rate. Pinning the destination skips the staging hop.
Measured throughput at 3.4 GB transfer (rotated surface code d=7, 10M shots):
H100 (gen4 x16) B200 (gen5 x16)
cp.asnumpy / np.asarray (pageable) 1.9 GB/s 4.1 GB/s
cudaMemcpy → cudaHostAlloc (pinned) 23.4 GB/s 51.4 GB/s
Translated to per-shot wall on the surface-code-noise sweep:
p vanilla pre-pin (+G) post-pin (+G)
1e-6 (10M shots) 0.084µs 0.162µs 0.021µs
1e-4 (10M shots) 0.176µs 0.188µs 0.021µs
1e-2 (10K shots) 4.86µs 0.102µs 0.091µs
The pinned path lifts +G from "loses to vanilla below p~1e-5" to
"monotonically faster across the whole sweep."
Implementation:
- New tsim.utils.cuda_helpers module: _PinnedBuf (RAII over cudaHostAlloc),
alloc_pinned_numpy (returns a pinned-backed ndarray with lifetime tied to
the underlying region via ctypes + ndarray.base), copy_d2h (the public
entry, picks pinned fast path or numpy fallback based on import of
cuda.bindings).
- sampler._sample_batches replaces np.asarray(combined)[:shots] with
copy_d2h(combined)[:shots] and adds the matching jax.block_until_ready
before the call.
cuda.bindings is a soft dep — when import fails, copy_d2h falls back to
np.array, preserving the prior behavior. The pyproject.toml is unchanged.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
unroll
------
Sweeping ``_SCAN_UNROLL ∈ {1, 4, 8, 16, True}`` on msc_3 cutting 100k
(H100):
unroll us/shot compile_s
1 6.16 6.82
4 5.35 6.78 (prior default)
8 5.07 6.77
16 4.22 6.82
True 4.24 6.66
16 ties ``True`` on speed and stays explicit. Compile time is flat
across the range, so there's no compile-cost reason to keep 4. Verified
on star_d7 cat5 500k: unroll=16 gives 1.81 µs/shot vs 1.96 at unroll=4
(no regression).
Net vs upstream/main vanilla on msc_3 cutting 100k: 4.22 vs 5.31 µs/shot
(-21%).
fixpoint
--------
``_scalar_add_with_power`` and ``_scalar_mul_with_power`` each apply one
``_reduce_power_coeffs_step`` per call. The tree-shaped
``lax.associative_scan`` this PR replaced naturally produces canonical
form (gcd of coeffs is odd) because each combine sees two operands of
equal accumulation depth — one reduce per node suffices. Sequential
``lax.scan`` accumulation can lag canonical form by up to ``log2(N)``
reductions for an N-element scan, breaking
``test_sum_reduces_while_adding``.
Bring the carry back to canonical form with a ``lax.while_loop`` that
early-exits as soon as no element reduces further. Most inputs converge
in 1-2 iters; on msc_3 cutting / star_d7 cat5 benches the added cost is
inside measurement noise.
cuda_helpers
------------
Also gate ``copy_d2h``'s pinned-d2h fast path on the source actually
living on a CUDA device. JAX on a CPU-only jaxlib reports
``src.unsafe_buffer_pointer()`` as a host pointer; feeding that to
``cudaMemcpy(..., DeviceToHost)`` returns ``cudaErrorInvalidValue``.
Detect via ``src.devices()`` platform and route to the numpy fallback
there too.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
3ff19bd to
81203a6
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Two minor JAX-side tweaks that give a significant boost on msc_3 and star_d7.
exact_scalar.py: lax.scan replaces associative_scan in .sum / .prod
Helps circuits with big-G components and a small number of trivial
components (msc_3).
Reductions did
lax.associative_scan(op, ..., axis) + take(-1, axis),which materialises the full prefix tensor along the scan axis just to
return the last element. Replaced with a lax.scan-based reduction that
keeps only a single (power, coeffs) carry — O(1) extra memory along the
scan axis vs O(N) before, which was dominating memory cost.
_SCAN_UNROLL = 16is a heuristically chosen constant; may need tuning(1 = sequential, higher = larger unrolled body, longer compile).
sampler.py: device-side concat + single d2h
Most relevant when there are a large number of trivial components
(star_d7).
result = np.concatenate(batches)[:shots]triggers one array perjax.Array in the list (each forces a d2h) plus a host-side memcpy into
the freshly-allocated concat output. On a (500k, 528) bool tensor
(~264 MB) the CPU memcpy alone was ~1 s on top of the PCIe transfer.
Replace with
combined = batches[0] if len == 1 else jnp.concatenate(..., axis=0)followed by a singlenp.asarray(combined)[:shots]. One d2h, no extra host memcpy.Measured on H100 / star_d7 cat5, 500k shots: d2h block drops from
~2.5 s to ~0.8 s.