Skip to content

Jax sampling tweaks#149

Open
danlkv wants to merge 5 commits into
QuEraComputing:mainfrom
danlkv:dlykov/jax-sampling-tweaks
Open

Jax sampling tweaks#149
danlkv wants to merge 5 commits into
QuEraComputing:mainfrom
danlkv:dlykov/jax-sampling-tweaks

Conversation

@danlkv
Copy link
Copy Markdown

@danlkv danlkv commented May 20, 2026

Two minor JAX-side tweaks that give a significant boost on msc_3 and star_d7.

  1. exact_scalar.py: lax.scan replaces associative_scan in .sum / .prod
    Helps circuits with big-G components and a small number of trivial
    components (msc_3).

    Reductions did lax.associative_scan(op, ..., axis) + take(-1, axis),
    which materialises the full prefix tensor along the scan axis just to
    return the last element. Replaced with a lax.scan-based reduction that
    keeps only a single (power, coeffs) carry — O(1) extra memory along the
    scan axis vs O(N) before, which was dominating memory cost.

    _SCAN_UNROLL = 16 is a heuristically chosen constant; may need tuning
    (1 = sequential, higher = larger unrolled body, longer compile).

  2. sampler.py: device-side concat + single d2h
    Most relevant when there are a large number of trivial components
    (star_d7).

    result = np.concatenate(batches)[:shots] triggers one array per
    jax.Array in the list (each forces a d2h) plus a host-side memcpy into
    the freshly-allocated concat output. On a (500k, 528) bool tensor
    (~264 MB) the CPU memcpy alone was ~1 s on top of the PCIe transfer.

    Replace with combined = batches[0] if len == 1 else jnp.concatenate(..., axis=0) followed by a single
    np.asarray(combined)[:shots]. One d2h, no extra host memcpy.
    Measured on H100 / star_d7 cat5, 500k shots: d2h block drops from
    ~2.5 s to ~0.8 s.

@chatgpt-codex-connector
Copy link
Copy Markdown

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.
Credits must be used to enable repository wide code reviews.

@danlkv danlkv force-pushed the dlykov/jax-sampling-tweaks branch 2 times, most recently from 6f39b9e to 3ff19bd Compare May 22, 2026 18:24
danlkv and others added 5 commits May 22, 2026 13:29
The previous np.concatenate(batches)[:shots] triggered (1) one __array__ call
per jax.Array in the list — each forcing its own d2h transfer — and (2) a
fresh host buffer allocation plus a host-side memcpy from the d2h'd buffer
into the concat output. On large bool tensors (e.g. 500k shots × 528
detector bits = 264 MB) the host concat alone was about 1 s on top of the
PCIe transfer.

Concatenating on device first and then doing a single np.asarray means one
d2h, no extra host memcpy. For the common batch_size == shots case we
also skip the jnp.concatenate (single-element batches list).

No behavior change, no new flag.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
ExactScalarArray.sum and .prod previously did:

    scanned = lax.associative_scan(op, ..., axis=axis)
    return take(scanned, -1, axis=axis)

That builds the full O(N) prefix tensor along the scan axis just to throw
all but the last element away. For the batches the sampler now reaches
(B up to 500k+, T up to several dozen) the prefix tensor was the dominant
memory cost in .sum / .prod and the limiting factor for batch size.

Replace with a lax.scan-based reduction that keeps a single
(power, coeffs) carry — O(1) extra memory along the scan axis. Compute
depth becomes O(N) sequential instead of O(log N) parallel, but the public
.sum/.prod API only needs the final value, so the parallel-prefix advantage
of associative_scan doesn't apply.

unroll=4 default keeps XLA's traced body small enough that compile times
stay reasonable for typical scan lengths.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Default numpy d2h (np.asarray on a jax.Array) lands in pageable host memory,
which forces the CUDA driver to stage the transfer through internal pinned
scratch before copying into the user buffer. The second hop is bound by host
DRAM bandwidth and caps the effective d2h throughput well below PCIe line
rate. Pinning the destination skips the staging hop.

Measured throughput at 3.4 GB transfer (rotated surface code d=7, 10M shots):

                                    H100 (gen4 x16)   B200 (gen5 x16)
  cp.asnumpy / np.asarray (pageable)      1.9 GB/s        4.1 GB/s
  cudaMemcpy → cudaHostAlloc (pinned)    23.4 GB/s       51.4 GB/s

Translated to per-shot wall on the surface-code-noise sweep:

         p           vanilla    pre-pin (+G)   post-pin (+G)
  1e-6  (10M shots)   0.084µs     0.162µs        0.021µs
  1e-4  (10M shots)   0.176µs     0.188µs        0.021µs
  1e-2  (10K shots)   4.86µs      0.102µs        0.091µs

The pinned path lifts +G from "loses to vanilla below p~1e-5" to
"monotonically faster across the whole sweep."

Implementation:
- New tsim.utils.cuda_helpers module: _PinnedBuf (RAII over cudaHostAlloc),
  alloc_pinned_numpy (returns a pinned-backed ndarray with lifetime tied to
  the underlying region via ctypes + ndarray.base), copy_d2h (the public
  entry, picks pinned fast path or numpy fallback based on import of
  cuda.bindings).
- sampler._sample_batches replaces np.asarray(combined)[:shots] with
  copy_d2h(combined)[:shots] and adds the matching jax.block_until_ready
  before the call.

cuda.bindings is a soft dep — when import fails, copy_d2h falls back to
np.array, preserving the prior behavior. The pyproject.toml is unchanged.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
unroll
------
Sweeping ``_SCAN_UNROLL ∈ {1, 4, 8, 16, True}`` on msc_3 cutting 100k
(H100):

  unroll  us/shot  compile_s
       1     6.16       6.82
       4     5.35       6.78    (prior default)
       8     5.07       6.77
      16     4.22       6.82
    True     4.24       6.66

16 ties ``True`` on speed and stays explicit. Compile time is flat
across the range, so there's no compile-cost reason to keep 4. Verified
on star_d7 cat5 500k: unroll=16 gives 1.81 µs/shot vs 1.96 at unroll=4
(no regression).

Net vs upstream/main vanilla on msc_3 cutting 100k: 4.22 vs 5.31 µs/shot
(-21%).

fixpoint
--------
``_scalar_add_with_power`` and ``_scalar_mul_with_power`` each apply one
``_reduce_power_coeffs_step`` per call. The tree-shaped
``lax.associative_scan`` this PR replaced naturally produces canonical
form (gcd of coeffs is odd) because each combine sees two operands of
equal accumulation depth — one reduce per node suffices. Sequential
``lax.scan`` accumulation can lag canonical form by up to ``log2(N)``
reductions for an N-element scan, breaking
``test_sum_reduces_while_adding``.

Bring the carry back to canonical form with a ``lax.while_loop`` that
early-exits as soon as no element reduces further. Most inputs converge
in 1-2 iters; on msc_3 cutting / star_d7 cat5 benches the added cost is
inside measurement noise.

cuda_helpers
------------
Also gate ``copy_d2h``'s pinned-d2h fast path on the source actually
living on a CUDA device. JAX on a CPU-only jaxlib reports
``src.unsafe_buffer_pointer()`` as a host pointer; feeding that to
``cudaMemcpy(..., DeviceToHost)`` returns ``cudaErrorInvalidValue``.
Detect via ``src.devices()`` platform and route to the numpy fallback
there too.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@danlkv danlkv force-pushed the dlykov/jax-sampling-tweaks branch from 3ff19bd to 81203a6 Compare May 22, 2026 18:29
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant