From da6479199a02910ac2b63123955483266570c1c5 Mon Sep 17 00:00:00 2001 From: dzzz2001 Date: Fri, 29 May 2026 09:58:48 +0800 Subject: [PATCH 1/3] perf(gint): shape-exact bucketing + tile ladder + wide-LDS vbatched GEMM Optimize the GPU gint batched-GEMM path (gemm_{nn,tn}_vbatch, driven from phi_mul_phi / phi_mul_dm) for FP64 on V100/A100-class GPUs. - phi_operator_gpu: replace the single max-shape vbatch launch with shape-exact bucketing. Atom pairs are grouped by (nw1, nw2) via a dense NW_MAX*NW_MAX counting-sort table, pre-enumerated once per batch in set_bgrid_batch, so each bucket hands the kernel a scalar (m, n, k) and the tile ladder picks the tightest tile per shape -- no cross-species tile waste, no over-launched blocks. A guard aborts if any atom nw >= NW_MAX. - dgemm_vbatch: scalar (m, n, k) dispatch (drops the per-batchid M/N/K device arrays) feeding a 4x2 (NN) / 4x4 (TN) BLK_{M,N} ladder over {8,16,32,48}. - gemm_{nn,tn}_vbatch: K-inner shared-memory layout + wide (double2/float4) LDS inner loop -- one 16-byte LDS feeds VK FMAs per (m,n); PAD keeps the shmem stride 16-byte aligned and warp access bank-conflict-free. C accumulators stay double regardless of input type T, preserving the mixed-precision fp64-accumulator fix (#7368); the phi_operator kernel optimizations from #7366 (WantPhi dispatch, single-warp reduce) are retained. FP64 15-case GPU benchmark: end-to-end ~1.05x (A800) / ~1.04x (V100), with cal_gint_vl up to ~1.5x and cal_gint_rho up to ~1.65x; energies and pressures match develop to ~1e-10 on every case. 
Co-Authored-By: Claude Opus 4.8 (1M context) --- .../module_gint/kernel/dgemm_vbatch.cu | 111 +++++- .../module_gint/kernel/dgemm_vbatch.h | 61 +-- .../module_gint/kernel/gemm_nn_vbatch.cuh | 177 ++++++--- .../module_gint/kernel/gemm_tn_vbatch.cuh | 156 +++++--- .../module_gint/kernel/gint_helper.cuh | 41 ++ .../module_gint/kernel/phi_operator_gpu.cu | 349 +++++++++++------- .../module_gint/kernel/phi_operator_gpu.h | 44 ++- 7 files changed, 618 insertions(+), 321 deletions(-) diff --git a/source/source_lcao/module_gint/kernel/dgemm_vbatch.cu b/source/source_lcao/module_gint/kernel/dgemm_vbatch.cu index 38946d51492..98764571baa 100644 --- a/source/source_lcao/module_gint/kernel/dgemm_vbatch.cu +++ b/source/source_lcao/module_gint/kernel/dgemm_vbatch.cu @@ -5,58 +5,129 @@ template void gemm_nn_vbatch( - int max_m, int max_n, int max_k, - const int* m_d, const int* n_d, const int* k_d, + int m, int n, int k, const T* const* A_array_d, const int* lda_d, const T* const* B_array_d, const int* ldb_d, double** C_array_d, const int* ldc_d, int batchCount, cudaStream_t stream, const T* alpha) { - vbatched_gemm_nn_impl - (max_m, max_n, m_d, n_d, k_d, - A_array_d, lda_d, - B_array_d, ldb_d, - C_array_d, ldc_d, - batchCount, stream, alpha); + // 4 (nw2 bracket) x 2 (bxyz bracket) = 8 instantiations. + // + // Mapping into the impl's parameter list is: + // + // which satisfies the kernel's tile-divisibility asserts because every + // (BLK_M, BLK_N, BLK_K=16) chosen below is a multiple of the matching + // (DIM_X, DIM_Y) pair. + #define NN_DISPATCH(DX, DY, BM, BN) \ + vbatched_gemm_nn_impl( \ + m, n, k, \ + A_array_d, lda_d, B_array_d, ldb_d, \ + C_array_d, ldc_d, batchCount, stream, alpha) + // BLK_M bracket -- smallest tile in {8,16,32,48} covering nw2. + const int blk_m_tag = (n <= 8) ? 0 + : (n <= 16) ? 1 + : (n <= 32) ? 2 + : 3; + + // BLK_N bracket -- 32 only when bxyz <=32 (caps mask waste at 50% for + // bxyz=27); 64 for everything else (best LDS reuse). 
+ const int blk_n_tag = (m <= 32) ? 0 : 1; + + switch (blk_m_tag * 2 + blk_n_tag) + { + // BLK_M=8 (nw2 <=8 ). DIM=4x8 -> THR_M=2. + case 0: NN_DISPATCH( 4, 8, 8, 32); break; // THR=2*4=8 (under) + case 1: NN_DISPATCH( 4, 8, 8, 64); break; // THR=2*8=16 (in band) + // BLK_M=16 (nw2<=16). DIM=4x8 -> THR_M=4. + case 2: NN_DISPATCH( 4, 8, 16, 32); break; // THR=4*4=16 (in band) + case 3: NN_DISPATCH( 4, 8, 16, 64); break; // THR=4*8=32 (in band) + // BLK_M=32 (nw2<=32). DIM=8x8 -> THR_M=4. + case 4: NN_DISPATCH( 8, 8, 32, 32); break; // THR=4*4=16 (in band) + case 5: NN_DISPATCH( 8, 8, 32, 64); break; // THR=4*8=32 (in band) + // BLK_M=48 (nw2<=48). DIM=16x8 -> THR_M=3 (cap at 3 to keep + // register pressure room for the BLK_N=64 sibling). + case 6: NN_DISPATCH(16, 8, 48, 32); break; // THR=3*4=12 (just under) + case 7: NN_DISPATCH(16, 8, 48, 64); break; // THR=3*8=24 (in band) + } + + #undef NN_DISPATCH } template void gemm_tn_vbatch( - int max_m, int max_n, int max_k, - const int* m_d, const int* n_d, const int* k_d, + int m, int n, int k, const T* const* A_array_d, const int* lda_d, const T* const* B_array_d, const int* ldb_d, double** C_array_d, const int* ldc_d, int batchCount, cudaStream_t stream, const T* alpha) { - vbatched_gemm_tn_impl - (max_m, max_n, m_d, n_d, k_d, - A_array_d, lda_d, - B_array_d, ldb_d, - C_array_d, ldc_d, - batchCount, stream, alpha); + // 4 (nw2 bracket) x 4 (nw1 bracket) = 16 instantiations. + // + // Both output axes here are the small nw axis, so we use the same + // {8,16,32,48} ladder on both. BLK_K = 32 (the bxyz axis -- large). + #define TN_DISPATCH(DX, DY, BM, BN) \ + vbatched_gemm_tn_impl( \ + m, n, k, \ + A_array_d, lda_d, B_array_d, ldb_d, \ + C_array_d, ldc_d, batchCount, stream, alpha) + + auto bracket = [](int x) { + return (x <= 8) ? 0 + : (x <= 16) ? 1 + : (x <= 32) ? 
2 + : 3; + }; + const int blk_m_tag = bracket(n); // BLK_M <- nw2 + const int blk_n_tag = bracket(m); // BLK_N <- nw1 + + switch (blk_m_tag * 4 + blk_n_tag) + { + // BLK_M=8 rungs (nw2<=8). DIM_X=4, THR_M=2. + case 0: TN_DISPATCH(4, 8, 8, 8); break; // THR=2*1=2 (corner) + case 1: TN_DISPATCH(4, 8, 8, 16); break; // THR=2*2=4 + case 2: TN_DISPATCH(4, 8, 8, 32); break; // THR=2*4=8 + case 3: TN_DISPATCH(4, 8, 8, 48); break; // THR=2*6=12 + // BLK_M=16 rungs (nw2<=16). DIM_X=4, THR_M=4. + case 4: TN_DISPATCH(4, 8, 16, 8); break; // THR=4*1=4 + case 5: TN_DISPATCH(4, 8, 16, 16); break; // THR=4*2=8 + case 6: TN_DISPATCH(4, 8, 16, 32); break; // THR=4*4=16 (in band) + case 7: TN_DISPATCH(4, 8, 16, 48); break; // THR=4*6=24 (in band) + // BLK_M=32 rungs (nw2<=32). DIM_X=8, THR_M=4. + case 8: TN_DISPATCH(8, 4, 32, 8); break; // THR=4*2=8 + case 9: TN_DISPATCH(8, 4, 32, 16); break; // THR=4*4=16 (in band) + case 10: TN_DISPATCH(8, 8, 32, 32); break; // THR=4*4=16 (in band) + case 11: TN_DISPATCH(8, 8, 32, 48); break; // THR=4*6=24 (in band) + // BLK_M=48 rungs (nw2<=48). DIM_X=8, THR_M=6. 
+ case 12: TN_DISPATCH(8, 4, 48, 8); break; // THR=6*2=12 + case 13: TN_DISPATCH(8, 4, 48, 16); break; // THR=6*4=24 (in band) + case 14: TN_DISPATCH(8, 8, 48, 32); break; // THR=6*4=24 (in band) + case 15: TN_DISPATCH(8, 8, 48, 48); break; // THR=6*6=36 (top of band) + } + + #undef TN_DISPATCH } // Explicit instantiations template void gemm_nn_vbatch( - int, int, int, const int*, const int*, const int*, + int, int, int, const double* const*, const int*, const double* const*, const int*, double**, const int*, int, cudaStream_t, const double*); template void gemm_nn_vbatch( - int, int, int, const int*, const int*, const int*, + int, int, int, const float* const*, const int*, const float* const*, const int*, double**, const int*, int, cudaStream_t, const float*); template void gemm_tn_vbatch( - int, int, int, const int*, const int*, const int*, + int, int, int, const double* const*, const int*, const double* const*, const int*, double**, const int*, int, cudaStream_t, const double*); template void gemm_tn_vbatch( - int, int, int, const int*, const int*, const int*, + int, int, int, const float* const*, const int*, const float* const*, const int*, double**, const int*, int, cudaStream_t, const float*); diff --git a/source/source_lcao/module_gint/kernel/dgemm_vbatch.h b/source/source_lcao/module_gint/kernel/dgemm_vbatch.h index 68505c1f4e4..3052011767b 100644 --- a/source/source_lcao/module_gint/kernel/dgemm_vbatch.h +++ b/source/source_lcao/module_gint/kernel/dgemm_vbatch.h @@ -2,61 +2,36 @@ #include -// Template version: C(batch_id) = alpha * A(batch_id) * B(batch_id) + C(batch_id) -// As with gemm_tn_vbatch, the C accumulator is always double regardless of the -// input type T so the per-block reduction and device-side atomicAdd run in fp64. +// Shape-exact batched GEMM dispatchers. +// +// Every (A_i, B_i, C_i) in the batch has exactly the same (m, n, k); the +// caller (phi_operator_gpu.cu) enforces this by bucketing atom pairs on +// (nw1, nw2). 
The scalars drive tile-ladder selection, grid sizing, and +// flow all the way through the kernel -- there is no per-batchid M/N/K +// indirection left. +// +// The C output is always double regardless of the input type T: a fp32 +// GEMM path (T=float) multiplies and accumulates each thread's partial tile in +// fp32 registers (rC is declared as T) and widens to double only for the +// device-side fp64 atomicAdd into C, so summing many atom-pair contributions +// into the same hr_gint / phi_dm element does not drift. For T=double, A, B +// and C are all double and this matches the legacy signature. + +// C(batch) = alpha * A(batch) * B(batch) + C(batch) template void gemm_nn_vbatch( - int max_m, int max_n, int max_k, - const int* m_d, const int* n_d, const int* k_d, + int m, int n, int k, const T* const* A_array_d, const int* lda_d, const T* const* B_array_d, const int* ldb_d, double** C_array_d, const int* ldc_d, int batchCount, cudaStream_t stream, const T* alpha = nullptr); -// Template version: C(batch_id) = alpha * A(batch_id)^T * B(batch_id) + C(batch_id) -// The C accumulator is always double regardless of input type T: a fp32 GEMM -// path (T=float) feeds fp32 multiplies into fp64 accumulators (registers and -// device-side atomicAdds) to avoid catastrophic precision loss across many -// atom-pair contributions to the same hr_gint element. 
+// C(batch) = alpha * A(batch)^T * B(batch) + C(batch) template void gemm_tn_vbatch( - int max_m, int max_n, int max_k, - const int* m_d, const int* n_d, const int* k_d, + int m, int n, int k, const T* const* A_array_d, const int* lda_d, const T* const* B_array_d, const int* ldb_d, double** C_array_d, const int* ldc_d, int batchCount, cudaStream_t stream, const T* alpha = nullptr); - -// Legacy double-only aliases for backward compatibility -inline void dgemm_nn_vbatch( - int max_m, int max_n, int max_k, - const int* m_d, const int* n_d, const int* k_d, - const double* const* A_array_d, const int* lda_d, - const double* const* B_array_d, const int* ldb_d, - double** C_array_d, const int* ldc_d, - int batchCount, cudaStream_t stream, - const double* alpha = nullptr) -{ - gemm_nn_vbatch(max_m, max_n, max_k, - m_d, n_d, k_d, A_array_d, lda_d, B_array_d, ldb_d, - C_array_d, ldc_d, batchCount, stream, alpha); -} - -inline void dgemm_tn_vbatch( - int max_m, int max_n, int max_k, - const int* m_d, const int* n_d, const int* k_d, - const double* const* A_array_d, const int* lda_d, - const double* const* B_array_d, const int* ldb_d, - double** C_array_d, const int* ldc_d, - int batchCount, cudaStream_t stream, - const double* alpha = nullptr) -{ - // T=double path: A, B, and C are all double — the C-channel double-fix - // matches the legacy signature here. 
- gemm_tn_vbatch(max_m, max_n, max_k, - m_d, n_d, k_d, A_array_d, lda_d, B_array_d, ldb_d, - C_array_d, ldc_d, batchCount, stream, alpha); -} \ No newline at end of file diff --git a/source/source_lcao/module_gint/kernel/gemm_nn_vbatch.cuh b/source/source_lcao/module_gint/kernel/gemm_nn_vbatch.cuh index 39e10e4252d..aa1464341ae 100644 --- a/source/source_lcao/module_gint/kernel/gemm_nn_vbatch.cuh +++ b/source/source_lcao/module_gint/kernel/gemm_nn_vbatch.cuh @@ -12,7 +12,29 @@ #include "source_base/module_device/device_check.h" #include "source_base/module_device/kernel_compat.h" -#define sA(i, j) sA[(j)*slda + (i)] +// V1 K-inner shmem layout +// sA(m, k) = sA[m * slda + k] row-major in M, K-inner; slda = BLK_K + PAD +// sB(k, n) = sB[n * sldb + k] col-major in N, K-inner; sldb = BLK_K + PAD +// Both layouts make the inner loop read VK consecutive K elements per LDS, +// turning one scalar LDS-per-FMA into one 16-byte LDS-per-VK-FMAs. +// PAD comes from gemm_vec_traits::PAD (FP32: +4, FP64: +2) and is what +// makes slda/sldb 16-byte aligned for LDS.{64,128}. +// +// Phase V3 bank-conflict audit (sA inner-loop read, idx-strided lanes): +// FP64, DIM_X= 8 (8x16 thread tiles): slda=BLK_K+2 -> 8 lanes at +// stride 4 banks each side -> banks {0,4,...,28} disjoint -> 0 conflicts. +// FP32, DIM_X= 8 (8x16 thread tiles): slda=BLK_K+4 -> 8 lanes at +// stride 4 banks (4-bank vec) -> disjoint -> 0 conflicts. +// FP64, DIM_X=16 (V2 16x16 big tile): slda=BLK_K+2, 16 lanes; even +// slda forces gcd(2*slda,32) >= 2, so the LOW/HIGH bank pair lands +// on distinct banks for all 16 lanes only when 2*slda has order >=16 +// mod 32. With BLK_K=16 -> slda=18 -> 36 mod 32 = 4 -> 8-distinct +// -> 2-way conflict. Accepted in V2: still beats scalar LDS by ~VK/2, +// and removing the conflict requires a swizzled layout (Step 2). 
+// sB inner-loop read uses idy-strided lanes; with DIM_Y in {8,16} the +// warp covers only 2-4 distinct n_col values, broadcast factor >= 8 +// -> always conflict-free regardless of sldb. +#define sA(i, j) sA[(i)*slda + (j)] #define sB(i, j) sB[(j)*sldb + (i)] #define fetch(A, m, n, bound) offs_d##A[min(n * LD##A + m, bound)] @@ -43,6 +65,31 @@ static __device__ void vbatched_gemm_nn_device(int M, int sldb, T alpha) { + using vec_t = typename gemm_vec_traits::vec_t; + constexpr int VK = gemm_vec_traits::VK; + + // V1 contract: BLK_K must be a whole number of VK chunks so the + // vectorized FMA loop below covers it cleanly. PAD makes slda * 8/4 + // a multiple of 16 (LDS alignment) -- enforced at the kernel scope. + static_assert(BLK_K % VK == 0, + "BLK_K must be divisible by VK (16 / sizeof(T))"); + + // Tile-divisibility (Phase V3 audit): every dev->shmem load loop + // assumes the BLK_* dim is an exact multiple of the corresponding + // DIM_*, and the per-thread fan-out THR_M/N is BLK_M/N / DIM_X/Y. + // A mis-spec'd new template instantiation would silently load + // garbage; these asserts surface it at compile time. 
+ static_assert(BLK_M % DIM_X == 0, "BLK_M must be divisible by DIM_X"); + static_assert(BLK_N % DIM_Y == 0, "BLK_N must be divisible by DIM_Y"); + static_assert(BLK_M % DIM_XA == 0, "BLK_M must be divisible by DIM_XA"); + static_assert(BLK_K % DIM_YA == 0, "BLK_K must be divisible by DIM_YA"); + static_assert(BLK_K % DIM_XB == 0, "BLK_K must be divisible by DIM_XB"); + static_assert(BLK_N % DIM_YB == 0, "BLK_N must be divisible by DIM_YB"); + static_assert(DIM_XA * DIM_YA == DIM_X * DIM_Y, + "A-loader thread grid must cover the whole block"); + static_assert(DIM_XB * DIM_YB == DIM_X * DIM_Y, + "B-loader thread grid must cover the whole block"); + int idx = threadIdx.x; // thread's m dimension int idy = threadIdx.y; // thread's n dimension @@ -57,13 +104,14 @@ static __device__ void vbatched_gemm_nn_device(int M, int blx = blockIdx.x; // block's m dimension int bly = blockIdx.y; // block's n dimension - // Registers for the innermost loop. rC accumulates in T; the widening to - // double happens only at the final atomicAdd into C. + // Accumulator tile (registers). Layout matches the original. T rC[THR_N][THR_M]; - T rA[THR_M]; - T rB[THR_N]; - // Registers for the dev->shmem copy + // Per-VK-step shmem->reg tiles. One LDS feeds VK FMAs per (m,n). + T rA[THR_M][VK]; + T rB[THR_N][VK]; + + // Registers for the dev->shmem copy (next-K-tile prefetch). T ra[BLK_K / DIM_YA][BLK_M / DIM_XA]; T rb[BLK_N / DIM_YB][BLK_K / DIM_XB]; @@ -86,7 +134,7 @@ static __device__ void vbatched_gemm_nn_device(int M, #pragma unroll for (m = 0; m < THR_M; m++) { - rC[n][m] = T(0); + rC[n][m] = 0.0; } } @@ -143,32 +191,44 @@ static __device__ void vbatched_gemm_nn_device(int M, } } -// Multiply +// Wide-LDS FMA: VK FMAs per shmem read. +// FP32: LDS.128 (float4) -> 4 FMAs per (m,n) per inner step +// FP64: LDS.128 (double2) -> 2 FMAs per (m,n) per inner step +// Both rely on slda/sldb being 16-byte aligned (PAD math) and on BLK_K +// being a whole number of VK chunks (static_assert above). 
#pragma unroll - for (k = 0; k < BLK_K; k++) + for (k = 0; k < BLK_K; k += VK) { // Load A shmem->regs #pragma unroll for (m = 0; m < THR_M; m++) { - rA[m] = sA(m * DIM_X + idx, k); + vec_t va = *reinterpret_cast( + &sA(m * DIM_X + idx, k)); + gemm_vec_traits::unpack(va, rA[m]); } // Load B shmem->regs #pragma unroll for (n = 0; n < THR_N; n++) { - rB[n] = sB(k, n * DIM_Y + idy); + vec_t vb = *reinterpret_cast( + &sB(k, n * DIM_Y + idy)); + gemm_vec_traits::unpack(vb, rB[n]); } -// Compute +// Compute (VK fan-out per (m,n)). #pragma unroll - for (n = 0; n < THR_N; n++) + for (int kv = 0; kv < VK; kv++) { #pragma unroll - for (m = 0; m < THR_M; m++) + for (n = 0; n < THR_N; n++) { - rC[n][m] += rA[m] * rB[n]; +#pragma unroll + for (m = 0; m < THR_M; m++) + { + rC[n][m] += rA[m][kv] * rB[n][kv]; + } } } } @@ -199,36 +259,36 @@ static __device__ void vbatched_gemm_nn_device(int M, __syncthreads(); } - // Multiply last full (BLK_K) or partial block of - // columns of op(A) and rows of op(B). + // Tail: last full (BLK_K) or partial block. Scalar from the K-inner + // layout -- the partial-K block can land on an odd k count (e.g. + // bxyz=27 -> tail 11), so don't try to vectorize it. // It's okay that m,n exceed matrix bounds as all work is in registers // or shared memory, and out-of-bounds rC[n][m] will not be saved later. 
kk = K - kk; #pragma unroll for (k = 0; k < kk; k++) { -// Load A shmem->regs + T rA_s[THR_M]; + T rB_s[THR_N]; #pragma unroll for (m = 0; m < THR_M; m++) { - rA[m] = sA(m * DIM_X + idx, k); + rA_s[m] = sA(m * DIM_X + idx, k); } -// Load B shmem->regs #pragma unroll for (n = 0; n < THR_N; n++) { - rB[n] = sB(k, n * DIM_Y + idy); + rB_s[n] = sB(k, n * DIM_Y + idy); } -// Compute #pragma unroll for (n = 0; n < THR_N; n++) { #pragma unroll for (m = 0; m < THR_M; m++) { - rC[n][m] += rA[m] * rB[n]; + rC[n][m] += rA_s[m] * rB_s[n]; } } } @@ -263,9 +323,9 @@ template -static __global__ void vbatched_gemm_nn_kernel(const int* M, - const int* N, - const int* K, +static __global__ void vbatched_gemm_nn_kernel(int M, + int N, + int K, const T* const* global_A_array, const int* global_lda, const T* const* global_B_array, @@ -274,23 +334,25 @@ static __global__ void vbatched_gemm_nn_kernel(const int* M, const int* global_ldc, const T* alpha) { - extern __shared__ __align__(sizeof(double)) unsigned char smem[]; + // 16-byte align for vec_t (double2 / float4) loads. + extern __shared__ __align__(16) unsigned char smem[]; T* shared_mem = reinterpret_cast(smem); int batchid = blockIdx.z; - int local_M = (int)M[batchid]; - int local_N = (int)N[batchid]; - int local_K = (int)K[batchid]; - - if (blockIdx.x >= (local_M + BLK_M - 1) / BLK_M) - return; - if (blockIdx.y >= (local_N + BLK_N - 1) / BLK_N) - return; - int shared_lda = BLK_M + 1; - int shared_ldb = BLK_K + 1; + constexpr int PAD = gemm_vec_traits::PAD; + static_assert(((BLK_K + PAD) * sizeof(T)) % 16 == 0, + "shmem K-stride * sizeof(T) must be 16-byte aligned for " + "LDS.{64,128}"); + static_assert(BLK_K % gemm_vec_traits::VK == 0, + "BLK_K must be divisible by VK = 16 / sizeof(T)"); + + // V1 K-inner: slda is the K-axis stride for sA (M-rows of (BLK_K + PAD)), + // sldb is the K-axis stride for sB (N-cols of (BLK_K + PAD)). 
+ int shared_lda = BLK_K + PAD; + int shared_ldb = BLK_K + PAD; T* shared_A = (T*)shared_mem; - T* shared_B = shared_A + shared_lda * BLK_K; + T* shared_B = shared_A + BLK_M * shared_lda; T alpha_tmp = T(1.0); if (alpha != nullptr) { @@ -307,9 +369,9 @@ static __global__ void vbatched_gemm_nn_kernel(const int* M, DIM_XB, DIM_YB, (BLK_M / DIM_X), - (BLK_N / DIM_Y)>(local_M, - local_N, - local_K, + (BLK_N / DIM_Y)>(M, + N, + K, global_A_array[batchid], (int)global_lda[batchid], global_B_array[batchid], @@ -343,12 +405,9 @@ static __global__ void vbatched_gemm_nn_kernel(const int* M, * matrix B. * @tparam DIM_YB The number of threads in the y-dimension used for loading * matrix B. - * @param max_m The maximum number of rows in the matrices. - * @param max_n The maximum number of columns in the matrices. - * @param m An array of batch sizes for the number of rows in each matrix. - * @param n An array of batch sizes for the number of columns in each matrix. - * @param k An array of batch sizes for the number of elements in each matrix - * along the K dimension. + * @param m The number of rows in each matrix (same across the batch). + * @param n The number of columns in each matrix (same across the batch). + * @param k The number of elements along the K dimension (same across the batch). * @param global_A_array An array of pointers to the input matrices A. * @param global_lda An array of leading dimensions for the input matrices A. * @param global_B_array An array of pointers to the input matrices B. @@ -358,7 +417,7 @@ static __global__ void vbatched_gemm_nn_kernel(const int* M, * @param batchCount The number of matrices in the batch. * @param stream The CUDA stream to use for the computation. * @param alpha The scalar value to multiply the matrices by (optional, default - * is nullptr). generate by copilot + * is nullptr). 
*/ template -void vbatched_gemm_nn_impl(int max_m, - int max_n, - const int* m, - const int* n, - const int* k, +void vbatched_gemm_nn_impl(int m, + int n, + int k, const T* const* global_A_array, const int* global_lda, const T* const* global_B_array, @@ -389,17 +446,21 @@ void vbatched_gemm_nn_impl(int max_m, // This is because vbatch_gemm_nn_kernel is column major, // but vatched_gemm_nn_impl is designed to be row major, + // V1 K-inner shmem footprint: + // sA: BLK_M rows of (BLK_K + PAD) elements + // sB: BLK_N cols of (BLK_K + PAD) elements + constexpr int PAD = gemm_vec_traits::PAD; size_t shared_mem_size = 0; - shared_mem_size += (BLK_M + 1) * BLK_K * sizeof(T); - shared_mem_size += (BLK_K + 1) * BLK_N * sizeof(T); + shared_mem_size += BLK_M * (BLK_K + PAD) * sizeof(T); + shared_mem_size += BLK_N * (BLK_K + PAD) * sizeof(T); dim3 dimBlock(DIM_X, DIM_Y); const int max_batch_count = 32768; for (int i = 0; i < batchCount; i += max_batch_count) { const int ibatch = min(max_batch_count, batchCount - i); - dim3 dimGrid(ceil_div(max_n, BLK_M), - ceil_div(max_m, BLK_N), + dim3 dimGrid(ceil_div(n, BLK_M), + ceil_div(m, BLK_N), ibatch); const T* alpha_tmp = nullptr; if (alpha != nullptr) @@ -418,7 +479,7 @@ void vbatched_gemm_nn_impl(int max_m, DIM_XB, DIM_YB> <<>>( - n + i, m + i, k + i, + n, m, k, global_B_array + i, global_ldb + i, global_A_array + i, global_lda + i, global_C_array + i, global_ldc + i, diff --git a/source/source_lcao/module_gint/kernel/gemm_tn_vbatch.cuh b/source/source_lcao/module_gint/kernel/gemm_tn_vbatch.cuh index ae6a4a00551..7ed8d3e7af1 100644 --- a/source/source_lcao/module_gint/kernel/gemm_tn_vbatch.cuh +++ b/source/source_lcao/module_gint/kernel/gemm_tn_vbatch.cuh @@ -12,7 +12,16 @@ #include "source_base/module_device/device_check.h" #include "source_base/module_device/kernel_compat.h" -#define sA(i, j) sA[(j)*slda + (i)] +// V1 K-inner shmem layout (matches gemm_nn_vbatch.cuh): +// sA(m, k) = sA[m * slda + k] row-major in M, K-inner; slda 
= BLK_K + PAD +// sB(k, n) = sB[n * sldb + k] col-major in N, K-inner; sldb = BLK_K + PAD +// PAD comes from gemm_vec_traits (FP32: +4, FP64: +2) and makes the +// stride 16-byte aligned + bank-conflict-free for warp-wide LDS. +// See gemm_nn_vbatch.cuh for the full Phase V3 bank-conflict audit table; +// the TN inner loop uses the same indexing pattern, so the same analysis +// applies (the only structural difference is sB's load loop, which writes +// to the same K-inner storage layout). +#define sA(i, j) sA[(i)*slda + (j)] #define sB(i, j) sB[(j)*sldb + (i)] #define fetch(A, m, n, bound) offs_d##A[min(n * LD##A + m, bound)] @@ -43,6 +52,26 @@ static __device__ void vbatched_gemm_nt_device(int M, int sldb, T alpha) { + using vec_t = typename gemm_vec_traits::vec_t; + constexpr int VK = gemm_vec_traits::VK; + + static_assert(BLK_K % VK == 0, + "BLK_K must be divisible by VK (16 / sizeof(T))"); + + // Tile-divisibility (Phase V3 audit): same checks as gemm_nn_vbatch. + // sB load loop in TN traverses (BLK_K rows x BLK_N cols), so the + // divisibility constraints on DIM_XB / DIM_YB are mirrored. 
+ static_assert(BLK_M % DIM_X == 0, "BLK_M must be divisible by DIM_X"); + static_assert(BLK_N % DIM_Y == 0, "BLK_N must be divisible by DIM_Y"); + static_assert(BLK_M % DIM_XA == 0, "BLK_M must be divisible by DIM_XA"); + static_assert(BLK_K % DIM_YA == 0, "BLK_K must be divisible by DIM_YA"); + static_assert(BLK_N % DIM_XB == 0, "BLK_N must be divisible by DIM_XB"); + static_assert(BLK_K % DIM_YB == 0, "BLK_K must be divisible by DIM_YB"); + static_assert(DIM_XA * DIM_YA == DIM_X * DIM_Y, + "A-loader thread grid must cover the whole block"); + static_assert(DIM_XB * DIM_YB == DIM_X * DIM_Y, + "B-loader thread grid must cover the whole block"); + int idx = threadIdx.x; // thread's m dimension int idy = threadIdx.y; // thread's n dimension @@ -57,13 +86,14 @@ static __device__ void vbatched_gemm_nt_device(int M, int blx = blockIdx.x; // block's m dimension int bly = blockIdx.y; // block's n dimension - // Registers for the innermost loop. rC accumulates in T; the widening to - // double happens only at the final atomicAdd into C. + // Accumulator tile (registers). T rC[THR_N][THR_M]; - T rA[THR_M]; - T rB[THR_N]; - // Registers for the dev->shmem copy + // Per-VK-step shmem->reg tiles. One LDS feeds VK FMAs per (m,n). + T rA[THR_M][VK]; + T rB[THR_N][VK]; + + // Registers for the dev->shmem copy (next-K-tile prefetch). T ra[BLK_K / DIM_YA][BLK_M / DIM_XA]; T rb[BLK_K / DIM_YB][BLK_N / DIM_XB]; @@ -86,7 +116,7 @@ static __device__ void vbatched_gemm_nt_device(int M, #pragma unroll for (m = 0; m < THR_M; m++) { - rC[n][m] = T(0); + rC[n][m] = 0.0; } } @@ -143,32 +173,42 @@ static __device__ void vbatched_gemm_nt_device(int M, } } -// Multiply +// Wide-LDS FMA: VK FMAs per shmem read. 
+// FP32: LDS.128 (float4) -> 4 FMAs per (m,n) per inner step +// FP64: LDS.128 (double2) -> 2 FMAs per (m,n) per inner step #pragma unroll - for (k = 0; k < BLK_K; k++) + for (k = 0; k < BLK_K; k += VK) { // Load A shmem->regs #pragma unroll for (m = 0; m < THR_M; m++) { - rA[m] = sA(m * DIM_X + idx, k); + vec_t va = *reinterpret_cast( + &sA(m * DIM_X + idx, k)); + gemm_vec_traits::unpack(va, rA[m]); } // Load B shmem->regs #pragma unroll for (n = 0; n < THR_N; n++) { - rB[n] = sB(k, n * DIM_Y + idy); + vec_t vb = *reinterpret_cast( + &sB(k, n * DIM_Y + idy)); + gemm_vec_traits::unpack(vb, rB[n]); } -// Compute +// Compute (VK fan-out per (m,n)). #pragma unroll - for (n = 0; n < THR_N; n++) + for (int kv = 0; kv < VK; kv++) { #pragma unroll - for (m = 0; m < THR_M; m++) + for (n = 0; n < THR_N; n++) { - rC[n][m] += rA[m] * rB[n]; +#pragma unroll + for (m = 0; m < THR_M; m++) + { + rC[n][m] += rA[m][kv] * rB[n][kv]; + } } } } @@ -199,36 +239,35 @@ static __device__ void vbatched_gemm_nt_device(int M, __syncthreads(); } - // Multiply last full (BLK_K) or partial block of - // columns of op(A) and rows of op(B). + // Tail: scalar from the K-inner layout. Partial-K blocks can be odd + // (bxyz=27 -> tail 11; bxyz=125 -> tail 13), so don't try to vectorize. // It's okay that m,n exceed matrix bounds as all work is in registers // or shared memory, and out-of-bounds rC[n][m] will not be saved later. 
kk = K - kk; #pragma unroll for (k = 0; k < kk; k++) { -// Load A shmem->regs + T rA_s[THR_M]; + T rB_s[THR_N]; #pragma unroll for (m = 0; m < THR_M; m++) { - rA[m] = sA(m * DIM_X + idx, k); + rA_s[m] = sA(m * DIM_X + idx, k); } -// Load B shmem->regs #pragma unroll for (n = 0; n < THR_N; n++) { - rB[n] = sB(k, n * DIM_Y + idy); + rB_s[n] = sB(k, n * DIM_Y + idy); } -// Compute #pragma unroll for (n = 0; n < THR_N; n++) { #pragma unroll for (m = 0; m < THR_M; m++) { - rC[n][m] += rA[m] * rB[n]; + rC[n][m] += rA_s[m] * rB_s[n]; } } } @@ -263,9 +302,9 @@ template -static __global__ void vbatched_gemm_nt_kernel(const int* M, - const int* N, - const int* K, +static __global__ void vbatched_gemm_nt_kernel(int M, + int N, + int K, const T* const* global_A_array, const int* global_lda, const T* const* global_B_array, @@ -274,23 +313,25 @@ static __global__ void vbatched_gemm_nt_kernel(const int* M, const int* global_ldc, const T* alpha) { - extern __shared__ __align__(sizeof(double)) unsigned char smem[]; + // 16-byte align for vec_t (double2 / float4) loads. + extern __shared__ __align__(16) unsigned char smem[]; T* shared_mem = reinterpret_cast(smem); int batchid = blockIdx.z; - int local_M = (int)M[batchid]; - int local_N = (int)N[batchid]; - int local_K = (int)K[batchid]; - - if (blockIdx.x >= (local_M + BLK_M - 1) / BLK_M) - return; - if (blockIdx.y >= (local_N + BLK_N - 1) / BLK_N) - return; - int shared_lda = BLK_M + 1; - int shared_ldb = BLK_K + 1; + constexpr int PAD = gemm_vec_traits::PAD; + static_assert(((BLK_K + PAD) * sizeof(T)) % 16 == 0, + "shmem K-stride * sizeof(T) must be 16-byte aligned for " + "LDS.{64,128}"); + static_assert(BLK_K % gemm_vec_traits::VK == 0, + "BLK_K must be divisible by VK = 16 / sizeof(T)"); + + // V1 K-inner: slda = K-axis stride for sA (BLK_M rows of (BLK_K + PAD)), + // sldb = K-axis stride for sB (BLK_N cols of (BLK_K + PAD)). 
+ int shared_lda = BLK_K + PAD; + int shared_ldb = BLK_K + PAD; T* shared_A = (T*)shared_mem; - T* shared_B = shared_A + shared_lda * BLK_K; + T* shared_B = shared_A + BLK_M * shared_lda; T alpha_tmp = T(1.0); if (alpha != nullptr) { @@ -307,9 +348,9 @@ static __global__ void vbatched_gemm_nt_kernel(const int* M, DIM_XB, DIM_YB, (BLK_M / DIM_X), - (BLK_N / DIM_Y)>(local_M, - local_N, - local_K, + (BLK_N / DIM_Y)>(M, + N, + K, global_A_array[batchid], (int)global_lda[batchid], global_B_array[batchid], @@ -343,12 +384,9 @@ static __global__ void vbatched_gemm_nt_kernel(const int* M, * matrix B. * @tparam DIM_YB The number of threads in the y-dimension used for loading * matrix B. - * @param max_m The maximum number of rows in the matrices. - * @param max_n The maximum number of columns in the matrices. - * @param m An array of batch sizes for the number of rows in each matrix. - * @param n An array of batch sizes for the number of columns in each matrix. - * @param k An array of batch sizes for the number of elements in each matrix - * along the K dimension. + * @param m The number of rows in each matrix (same across the batch). + * @param n The number of columns in each matrix (same across the batch). + * @param k The number of elements along the K dimension (same across the batch). * @param global_A_array An array of pointers to the input matrices A. * @param global_lda An array of leading dimensions for the input matrices A. * @param global_B_array An array of pointers to the input matrices B. @@ -358,7 +396,7 @@ static __global__ void vbatched_gemm_nt_kernel(const int* M, * @param batchCount The number of matrices in the batch. * @param stream The CUDA stream to use for the computation. * @param alpha The scalar value to multiply the matrices by (optional, default - * is nullptr). generate by copilot + * is nullptr). 
*/ /* @@ -395,11 +433,9 @@ template -void vbatched_gemm_tn_impl(int max_m, - int max_n, - const int* m, - const int* n, - const int* k, +void vbatched_gemm_tn_impl(int m, + int n, + int k, const T* const* global_A_array, const int* global_lda, const T* const* global_B_array, @@ -414,17 +450,21 @@ void vbatched_gemm_tn_impl(int max_m, // This is because vbatch_gemm__tn_kernel is column major, // but vatched_gemm_nt_impl is designed to be row major, + // V1 K-inner shmem footprint (matches gemm_nn_vbatch_impl): + // sA: BLK_M rows of (BLK_K + PAD) elements + // sB: BLK_N cols of (BLK_K + PAD) elements + constexpr int PAD = gemm_vec_traits::PAD; size_t shared_mem_size = 0; - shared_mem_size += (BLK_M + 1) * BLK_K * sizeof(T); - shared_mem_size += (BLK_K + 1) * BLK_N * sizeof(T); + shared_mem_size += BLK_M * (BLK_K + PAD) * sizeof(T); + shared_mem_size += BLK_N * (BLK_K + PAD) * sizeof(T); dim3 dimBlock(DIM_X, DIM_Y); const int max_batch_count = 32768; for (int i = 0; i < batchCount; i += max_batch_count) { const int ibatch = min(max_batch_count, batchCount - i); - dim3 dimGrid(ceil_div(max_n, BLK_M), - ceil_div(max_m, BLK_N), + dim3 dimGrid(ceil_div(n, BLK_M), + ceil_div(m, BLK_N), ibatch); const T* alpha_tmp = nullptr; if (alpha != nullptr) @@ -443,7 +483,7 @@ void vbatched_gemm_tn_impl(int max_m, DIM_XB, DIM_YB> <<>>( - n + i, m + i, k + i, + n, m, k, global_B_array + i, global_ldb + i, global_A_array + i, global_lda + i, global_C_array + i, global_ldc + i, diff --git a/source/source_lcao/module_gint/kernel/gint_helper.cuh b/source/source_lcao/module_gint/kernel/gint_helper.cuh index eae5953654b..ec1007b1d23 100644 --- a/source/source_lcao/module_gint/kernel/gint_helper.cuh +++ b/source/source_lcao/module_gint/kernel/gint_helper.cuh @@ -43,3 +43,44 @@ inline int ceil_div(const int a, const int b) return a / b + (a % b != 0 && (a ^ b) > 0); } +// --------------------------------------------------------------------------- +// gemm_vec_traits -- wide-LDS primitive for 
the V1 K-inner inner loop. +// +// VK = how many T elements pack into one 16-byte LDS (4 for FP32, 2 for FP64) +// vec_t = the 16-byte CUDA vector type used for the LDS +// PAD = K-stride padding that makes (BLK_K + PAD) * sizeof(T) a multiple of +// 16 *and* keeps the warp's idx-strided shmem access bank-conflict-free +// (gcd((BLK_K+PAD) % 32, 32) == VK). +// +// The load is issued as one *reinterpret_cast(&sA(m,k)); the +// component fan-out is done by unpack(). FP64 needs the explicit cast -- +// the compiler's auto-vectorizer is reliable for float4 but not for +// double2; per-component .x/.y/.z/.w writes guarantee the LDS.{64,128} +// SASS forms emit. +// --------------------------------------------------------------------------- +template struct gemm_vec_traits; + +template <> struct gemm_vec_traits +{ + using vec_t = float4; + static constexpr int VK = 4; + static constexpr int PAD = 4; + __forceinline__ __device__ + static void unpack(const vec_t& v, float* d) + { + d[0] = v.x; d[1] = v.y; d[2] = v.z; d[3] = v.w; + } +}; + +template <> struct gemm_vec_traits +{ + using vec_t = double2; + static constexpr int VK = 2; + static constexpr int PAD = 2; + __forceinline__ __device__ + static void unpack(const vec_t& v, double* d) + { + d[0] = v.x; d[1] = v.y; + } +}; + diff --git a/source/source_lcao/module_gint/kernel/phi_operator_gpu.cu b/source/source_lcao/module_gint/kernel/phi_operator_gpu.cu index 7021f44e3fe..e29ec0a95dc 100644 --- a/source/source_lcao/module_gint/kernel/phi_operator_gpu.cu +++ b/source/source_lcao/module_gint/kernel/phi_operator_gpu.cu @@ -2,7 +2,12 @@ #include "phi_operator_kernel.cuh" #include "dgemm_vbatch.h" #include +#include +#include +#include #include +#include +#include #include "source_base/module_device/device_check.h" namespace ModuleGint @@ -18,11 +23,8 @@ bgrid_phi_start_(BatchBigGrid::get_max_batch_size(), stream_, true), atoms_iat_(BatchBigGrid::get_max_atoms_num(), stream_, true), 
atoms_bgrids_rcoords_(BatchBigGrid::get_max_atoms_num(), stream_, true), atom_phi_start_(BatchBigGrid::get_max_atoms_num(), stream_, true), -batch_mgrid_lidx_(BatchBigGrid::get_max_batch_size() +batch_mgrid_lidx_(BatchBigGrid::get_max_batch_size() * BatchBigGrid::get_bgrid_info()->get_mgrids_num(), stream_, true), -gemm_m_(BatchBigGrid::get_max_atom_pairs_num(), stream_, true), -gemm_n_(BatchBigGrid::get_max_atom_pairs_num(), stream_, true), -gemm_k_(BatchBigGrid::get_max_atom_pairs_num(), stream_, true), gemm_lda_(BatchBigGrid::get_max_atom_pairs_num(), stream_, true), gemm_ldb_(BatchBigGrid::get_max_atom_pairs_num(), stream_, true), gemm_ldc_(BatchBigGrid::get_max_atom_pairs_num(), stream_, true), @@ -90,6 +92,53 @@ void PhiOperatorGpu::set_bgrid_batch(std::shared_ptr bgrid_b atom_phi_start_.copy_host_to_device_async(bgrid_batch->get_atoms_num()); batch_mgrid_lidx_.copy_host_to_device_async(bgrid_batch->get_batch_size() * mgrids_num_); CHECK_CUDA(cudaEventRecord(event_, stream_)); + + // Pre-enumerate every (ia_1, ia_2) pair so phi_mul_phi / phi_mul_dm can + // skip this O(sum atoms_i^2) walk on every call. HContainer lookups and + // per-bgrid filters stay in the hot path since they depend on the + // specific caller (hRGint vs dm, symmetric vs not). 
+ pair_cache_.clear(); + for (int i = 0; i < bgrid_batch->get_batch_size(); ++i) + { + const auto& bgrid = bgrid_batch->get_bgrids()[i]; + const int pre_atoms = atoms_num_info_h[i].y; + const int atoms_num = bgrid->get_atoms_num(); + const int phi_len_mgrid = bgrid->get_phi_len(); + const auto& atoms = bgrid->get_atoms(); + for (int ia_1 = 0; ia_1 < atoms_num; ++ia_1) + { + const auto& atom_1 = atoms[ia_1]; + for (int ia_2 = 0; ia_2 < atoms_num; ++ia_2) + { + const auto& atom_2 = atoms[ia_2]; + PairInfo p; + p.phi_1_offset = atom_phi_start_h[pre_atoms + ia_1]; + p.phi_2_offset = atom_phi_start_h[pre_atoms + ia_2]; + p.phi_len_mgrid = phi_len_mgrid; + p.iat_1 = atom_1->get_iat(); + p.iat_2 = atom_2->get_iat(); + p.r_diff = atom_1->get_R() - atom_2->get_R(); + p.nw1 = static_cast(atom_1->get_nw()); + p.nw2 = static_cast(atom_2->get_nw()); + // The shape key (nw1 * NW_MAX + nw2) indexes a dense + // NW_MAX*NW_MAX table in phi_mul_phi / phi_mul_dm, so nw must + // stay below NW_MAX or those passes write out of bounds. Guard + // it here (always on, including release builds). + if (p.nw1 >= NW_MAX || p.nw2 >= NW_MAX) + { + fprintf(stderr, + "PhiOperatorGpu: per-atom nw (%d, %d) >= NW_MAX " + "(%d); increase NW_MAX in phi_operator_gpu.h\n", + int(p.nw1), int(p.nw2), NW_MAX); + std::abort(); + } + p.ia_le = static_cast(ia_1 <= ia_2); + p.is_diag = static_cast(ia_1 == ia_2); + pair_cache_.push_back(p); + } + } + } + pair_scratch_offset_.assign(pair_cache_.size(), -1); } template @@ -234,57 +283,74 @@ void PhiOperatorGpu::phi_mul_phi( HContainer& hRGint, double* hr_d) const { - // ap_num means number of atom pairs + // Shape-exact bucketing: group atom pairs by (nw1, nw2). K = mgrids_num_ + // is already batch-wide constant, so (nw1, nw2) fully determines the GEMM + // shape. 
Each bucket hands gemm_tn_vbatch scalar (nw1, nw2, mgrids_num_), + // so the 3x3 template ladder picks the tightest tile for every item and + // the wrapper sizes the grid exactly -- no cross-species tile waste, no + // over-launched blocks. + // + // Algorithm: counting-sort-style two-pass over the pre-enumerated + // pair_cache_ populated in set_bgrid_batch(). + // Pass 1: HContainer lookup -> stash hr_offset, count items per shape. + // Prefix sum: build the list of non-empty buckets + their flat offsets. + // Pass 2: scatter A/B/C pointers + lda/ldb/ldc into the flat host arrays + // at each bucket's slot, then one H2D copy per array and one + // vbatch launch per bucket. (m/n/k arrays are no longer + // scattered -- the wrapper fills them on-device from the + // scalar bucket shape.) + + std::array counts{}; + + // Pass 1: filter + HContainer lookup + per-shape count. + for (size_t i = 0; i < pair_cache_.size(); ++i) + { + const auto& p = pair_cache_[i]; + if (p.iat_1 > p.iat_2) { pair_scratch_offset_[i] = -1; continue; } + const int hr = hRGint.find_matrix_offset(p.iat_1, p.iat_2, p.r_diff); + pair_scratch_offset_[i] = hr; + if (hr == -1) { continue; } + counts[p.nw1 * NW_MAX + p.nw2]++; + } + + // Prefix sum over dense keys -> compact bucket list. 
+ struct Bucket { int key; int off; int cnt; }; + std::vector buckets; + buckets.reserve(32); + std::array key_to_base{}; int ap_num = 0; - int max_m = 0; - int max_n = 0; - int max_k = mgrids_num_; - CHECK_CUDA(cudaEventSynchronize(event_)); - for (int i = 0; i < bgrid_batch_->get_batch_size(); i++) + for (int k = 0; k < NW_MAX * NW_MAX; ++k) { - auto bgrid = bgrid_batch_->get_bgrids()[i]; - // the length of phi on a mesh grid - const int phi_len_mgrid = bgrid->get_phi_len(); - const int pre_atoms = atoms_num_info_.get_host_ptr()[i].y; - for (int ia_1 = 0; ia_1 < bgrid->get_atoms_num(); ia_1++) - { - auto atom_1 = bgrid->get_atoms()[ia_1]; - const int iat_1 = atom_1->get_iat(); - const auto& r_1 = atom_1->get_R(); - const int nw1 = atom_1->get_nw(); - const int phi_1_offset = atom_phi_start_.get_host_ptr()[pre_atoms + ia_1]; + if (counts[k] == 0) { continue; } + buckets.push_back({k, ap_num, counts[k]}); + key_to_base[k] = ap_num; + ap_num += counts[k]; + } - for (int ia_2 = 0; ia_2 < bgrid->get_atoms_num(); ia_2++) - { - auto atom_2 = bgrid->get_atoms()[ia_2]; - const int iat_2 = atom_2->get_iat(); - const auto& r_2 = atom_2->get_R(); - const int nw2 = atom_2->get_nw(); - - if(iat_1 > iat_2) - { continue; } - - int hr_offset = hRGint.find_matrix_offset(iat_1, iat_2, r_1 - r_2); - if (hr_offset == -1) - { continue; } - - const int phi_2_offset = atom_phi_start_.get_host_ptr()[pre_atoms + ia_2]; - - gemm_A_.get_host_ptr()[ap_num] = phi_d + phi_1_offset; - gemm_B_.get_host_ptr()[ap_num] = phi_vldr3_d + phi_2_offset; - gemm_C_.get_host_ptr()[ap_num] = hr_d + hr_offset; - gemm_lda_.get_host_ptr()[ap_num] = phi_len_mgrid; - gemm_ldb_.get_host_ptr()[ap_num] = phi_len_mgrid; - gemm_ldc_.get_host_ptr()[ap_num] = nw2; - gemm_m_.get_host_ptr()[ap_num] = nw1; - gemm_n_.get_host_ptr()[ap_num] = nw2; - gemm_k_.get_host_ptr()[ap_num] = bgrid->get_mgrids_num(); - ap_num++; - - max_m = std::max(max_m, nw1); - max_n = std::max(max_n, nw2); - } - } + auto* h_A = 
gemm_A_.get_host_ptr(); + auto* h_B = gemm_B_.get_host_ptr(); + auto* h_C = gemm_C_.get_host_ptr(); + auto* h_lda = gemm_lda_.get_host_ptr(); + auto* h_ldb = gemm_ldb_.get_host_ptr(); + auto* h_ldc = gemm_ldc_.get_host_ptr(); + + CHECK_CUDA(cudaEventSynchronize(event_)); + + // Pass 2: scatter into the flat host arrays at per-bucket cursors. + std::array cursor{}; + for (size_t i = 0; i < pair_cache_.size(); ++i) + { + const int hr = pair_scratch_offset_[i]; + if (hr == -1) { continue; } + const auto& p = pair_cache_[i]; + const int key = p.nw1 * NW_MAX + p.nw2; + const int pos = key_to_base[key] + cursor[key]++; + h_A[pos] = phi_d + p.phi_1_offset; + h_B[pos] = phi_vldr3_d + p.phi_2_offset; + h_C[pos] = hr_d + hr; + h_lda[pos] = p.phi_len_mgrid; + h_ldb[pos] = p.phi_len_mgrid; + h_ldc[pos] = p.nw2; } gemm_A_.copy_host_to_device_async(ap_num); @@ -293,26 +359,25 @@ void PhiOperatorGpu::phi_mul_phi( gemm_lda_.copy_host_to_device_async(ap_num); gemm_ldb_.copy_host_to_device_async(ap_num); gemm_ldc_.copy_host_to_device_async(ap_num); - gemm_m_.copy_host_to_device_async(ap_num); - gemm_n_.copy_host_to_device_async(ap_num); - gemm_k_.copy_host_to_device_async(ap_num); CHECK_CUDA(cudaEventRecord(event_, stream_)); - gemm_tn_vbatch(max_m, - max_n, - max_k, - gemm_m_.get_device_ptr(), - gemm_n_.get_device_ptr(), - gemm_k_.get_device_ptr(), - gemm_A_.get_device_ptr(), - gemm_lda_.get_device_ptr(), - gemm_B_.get_device_ptr(), - gemm_ldb_.get_device_ptr(), - gemm_C_.get_device_ptr(), - gemm_ldc_.get_device_ptr(), - ap_num, - stream_, - nullptr); + for (const auto& b : buckets) + { + const int nw1 = b.key / NW_MAX; + const int nw2 = b.key % NW_MAX; + gemm_tn_vbatch(nw1, + nw2, + mgrids_num_, + gemm_A_.get_device_ptr() + b.off, + gemm_lda_.get_device_ptr() + b.off, + gemm_B_.get_device_ptr() + b.off, + gemm_ldb_.get_device_ptr() + b.off, + gemm_C_.get_device_ptr() + b.off, + gemm_ldc_.get_device_ptr() + b.off, + b.cnt, + stream_, + nullptr); + } } template @@ -324,54 +389,68 @@ 
void PhiOperatorGpu::phi_mul_dm( double* phi_dm_d) { CHECK_CUDA(cudaMemsetAsync(phi_dm_d, 0, phi_len_ * sizeof(double), stream_)); - // ap_num means number of atom pairs + + // Shape-exact bucketing: same structure as phi_mul_phi, but NN-flavored. + // M = mgrids_num_ (batch-wide constant), N = nw2, K = nw1. + // Shape key is still (nw1, nw2); M is absent from the key since it's + // identical across every pair in the batch. is_symm selects the + // upper-triangle (ia_1 <= ia_2) subset and fills per-pair alpha. + + std::array counts{}; + + // Pass 1: filter + HContainer lookup + per-shape count. + for (size_t i = 0; i < pair_cache_.size(); ++i) + { + const auto& p = pair_cache_[i]; + if (is_symm && !p.ia_le) { pair_scratch_offset_[i] = -1; continue; } + const int dm_offset = dm.find_matrix_offset(p.iat_1, p.iat_2, p.r_diff); + pair_scratch_offset_[i] = dm_offset; + if (dm_offset == -1) { continue; } + counts[p.nw1 * NW_MAX + p.nw2]++; + } + + // Prefix sum over dense keys -> compact bucket list. + struct Bucket { int key; int off; int cnt; }; + std::vector buckets; + buckets.reserve(32); + std::array key_to_base{}; int ap_num = 0; - int max_m = mgrids_num_; - int max_n = 0; - int max_k = 0; + for (int k = 0; k < NW_MAX * NW_MAX; ++k) + { + if (counts[k] == 0) { continue; } + buckets.push_back({k, ap_num, counts[k]}); + key_to_base[k] = ap_num; + ap_num += counts[k]; + } + + auto* h_A = gemm_A_.get_host_ptr(); + auto* h_B = gemm_B_.get_host_ptr(); + auto* h_C = gemm_C_.get_host_ptr(); + auto* h_lda = gemm_lda_.get_host_ptr(); + auto* h_ldb = gemm_ldb_.get_host_ptr(); + auto* h_ldc = gemm_ldc_.get_host_ptr(); + auto* h_alpha = gemm_alpha_.get_host_ptr(); + CHECK_CUDA(cudaEventSynchronize(event_)); - for (int i = 0; i < bgrid_batch_->get_batch_size(); i++) + + // Pass 2: scatter. 
+ std::array cursor{}; + for (size_t i = 0; i < pair_cache_.size(); ++i) { - auto bgrid = bgrid_batch_->get_bgrids()[i]; - // the length of phi on a mesh grid - const int phi_len_mgrid = bgrid->get_phi_len(); - const int pre_atoms = atoms_num_info_.get_host_ptr()[i].y; - for (int ia_1 = 0; ia_1 < bgrid->get_atoms_num(); ia_1++) + const int dm_offset = pair_scratch_offset_[i]; + if (dm_offset == -1) { continue; } + const auto& p = pair_cache_[i]; + const int key = p.nw1 * NW_MAX + p.nw2; + const int pos = key_to_base[key] + cursor[key]++; + h_A[pos] = phi_d + p.phi_1_offset; + h_B[pos] = dm_d + dm_offset; + h_C[pos] = phi_dm_d + p.phi_2_offset; + h_lda[pos] = p.phi_len_mgrid; + h_ldb[pos] = p.nw2; + h_ldc[pos] = p.phi_len_mgrid; + if (is_symm) { - auto atom_1 = bgrid->get_atoms()[ia_1]; - const int iat_1 = atom_1->get_iat(); - const auto& r_1 = atom_1->get_R(); - const int nw1 = atom_1->get_nw(); - const int phi_1_offset = atom_phi_start_.get_host_ptr()[pre_atoms + ia_1]; - int ia_2 = is_symm ? ia_1 : 0; - for (; ia_2 < bgrid->get_atoms_num(); ia_2++) - { - auto atom_2 = bgrid->get_atoms()[ia_2]; - const int iat_2 = atom_2->get_iat(); - const auto& r_2 = atom_2->get_R(); - const int nw2 = atom_2->get_nw(); - - int dm_offset = dm.find_matrix_offset(iat_1, iat_2, r_1-r_2); - if (dm_offset == -1) - { continue; } - - const int phi_dm_offset = atom_phi_start_.get_host_ptr()[pre_atoms + ia_2]; - - gemm_A_.get_host_ptr()[ap_num] = phi_d + phi_1_offset; - gemm_B_.get_host_ptr()[ap_num] = dm_d + dm_offset; - gemm_C_.get_host_ptr()[ap_num] = phi_dm_d + phi_dm_offset; - gemm_lda_.get_host_ptr()[ap_num] = phi_len_mgrid; - gemm_ldb_.get_host_ptr()[ap_num] = nw2; - gemm_ldc_.get_host_ptr()[ap_num] = phi_len_mgrid; - gemm_m_.get_host_ptr()[ap_num] = mgrids_num_; - gemm_n_.get_host_ptr()[ap_num] = nw2; - gemm_k_.get_host_ptr()[ap_num] = nw1; - gemm_alpha_.get_host_ptr()[ap_num] = ia_1 == ia_2 ? 
Real(1.0) : Real(2.0); - ap_num++; - - max_n = std::max(max_n, nw2); - max_k = std::max(max_k, nw1); - } + h_alpha[pos] = p.is_diag ? Real(1.0) : Real(2.0); } } @@ -381,33 +460,31 @@ void PhiOperatorGpu::phi_mul_dm( gemm_lda_.copy_host_to_device_async(ap_num); gemm_ldb_.copy_host_to_device_async(ap_num); gemm_ldc_.copy_host_to_device_async(ap_num); - gemm_m_.copy_host_to_device_async(ap_num); - gemm_n_.copy_host_to_device_async(ap_num); - gemm_k_.copy_host_to_device_async(ap_num); - if(is_symm) + if (is_symm) { - // if is_symm == false, gemm_alpha_ always equals 1.0, - // so we don't need to copy it to device + // if is_symm == false, gemm_alpha_ is always 1.0 and is skipped on device gemm_alpha_.copy_host_to_device_async(ap_num); } CHECK_CUDA(cudaEventRecord(event_, stream_)); - auto alpha_ptr = is_symm ? gemm_alpha_.get_device_ptr() : nullptr; - gemm_nn_vbatch(max_m, - max_n, - max_k, - gemm_m_.get_device_ptr(), - gemm_n_.get_device_ptr(), - gemm_k_.get_device_ptr(), - gemm_A_.get_device_ptr(), - gemm_lda_.get_device_ptr(), - gemm_B_.get_device_ptr(), - gemm_ldb_.get_device_ptr(), - gemm_C_.get_device_ptr(), - gemm_ldc_.get_device_ptr(), - ap_num, - stream_, - alpha_ptr); + for (const auto& b : buckets) + { + const int nw1 = b.key / NW_MAX; + const int nw2 = b.key % NW_MAX; + auto alpha_ptr = is_symm ? 
(gemm_alpha_.get_device_ptr() + b.off) : nullptr; + gemm_nn_vbatch(mgrids_num_, + nw2, + nw1, + gemm_A_.get_device_ptr() + b.off, + gemm_lda_.get_device_ptr() + b.off, + gemm_B_.get_device_ptr() + b.off, + gemm_ldb_.get_device_ptr() + b.off, + gemm_C_.get_device_ptr() + b.off, + gemm_ldc_.get_device_ptr() + b.off, + b.cnt, + stream_, + alpha_ptr); + } } template diff --git a/source/source_lcao/module_gint/kernel/phi_operator_gpu.h b/source/source_lcao/module_gint/kernel/phi_operator_gpu.h index fbdbc95352a..9ba43420236 100644 --- a/source/source_lcao/module_gint/kernel/phi_operator_gpu.h +++ b/source/source_lcao/module_gint/kernel/phi_operator_gpu.h @@ -1,5 +1,7 @@ #pragma once #include +#include +#include #include #include "source_lcao/module_gint/batch_biggrid.h" @@ -9,6 +11,30 @@ namespace ModuleGint { +// Upper bound for per-atom orbital count. Used to flatten a (nw1, nw2) pair +// into a single integer key via `nw1 * NW_MAX + nw2`, so that shape-exact +// bucketing of phi_mul_phi / phi_mul_dm can index a dense counts[] table +// instead of hashing. 64 comfortably covers the typical max nw (~25). +constexpr int NW_MAX = 64; + +// Per-atom-pair metadata cached in PhiOperatorGpu::set_bgrid_batch() so that +// phi_mul_phi / phi_mul_dm skip the O(bgrid * atoms^2) enumeration on every +// call. Holds only the fields both callers need; HContainer lookups still +// happen lazily in the hot path because they depend on hRGint / dm. 
+struct PairInfo +{ + int phi_1_offset; + int phi_2_offset; + int phi_len_mgrid; + int iat_1; + int iat_2; + Vec3i r_diff; + uint16_t nw1; + uint16_t nw2; + uint8_t ia_le; // (ia_1 <= ia_2) within the bgrid, for is_symm filter + uint8_t is_diag; // (ia_1 == ia_2), for phi_mul_dm is_symm alpha +}; + template class PhiOperatorGpu { @@ -102,19 +128,25 @@ class PhiOperatorGpu // Mapping of the index of meshgrid in the batch of biggrids to the index of meshgrid in the local cell CudaMemWrapper batch_mgrid_lidx_; - mutable CudaMemWrapper gemm_m_; - mutable CudaMemWrapper gemm_n_; - mutable CudaMemWrapper gemm_k_; mutable CudaMemWrapper gemm_lda_; mutable CudaMemWrapper gemm_ldb_; mutable CudaMemWrapper gemm_ldc_; mutable CudaMemWrapper gemm_A_; mutable CudaMemWrapper gemm_B_; - // Single C-pointer buffer: both phi_mul_phi (output hr) and phi_mul_dm - // (output phi_dm) write into double* accumulators, so a single shared - // gemm_C_ device buffer can serve both call sites. + // C accumulator pointers are always double*: both phi_mul_phi (hr) and + // phi_mul_dm (phi_dm) write into fp64 buffers via the GEMM's fp64 atomicAdd. mutable CudaMemWrapper gemm_C_; mutable CudaMemWrapper gemm_alpha_; + + // Full (ia_1, ia_2) pair enumeration, rebuilt in set_bgrid_batch(). + // Consumed by phi_mul_phi (TN, iat_1 <= iat_2 filter) and phi_mul_dm + // (NN, optional is_symm upper-triangle filter). + std::vector pair_cache_; + + // Scratch buffer reused across phi_mul_phi / phi_mul_dm calls to cache + // per-pair HContainer offsets from Pass 1 and replay them in Pass 2 + // without a second find_matrix_offset() call. 
+ mutable std::vector pair_scratch_offset_; }; } \ No newline at end of file From 0fe0310acca0817948f47f71ff4a5033500effb0 Mon Sep 17 00:00:00 2001 From: dzzz2001 Date: Fri, 29 May 2026 14:37:45 +0800 Subject: [PATCH 2/3] refactor(gint): derive shape-bucket stride from ucell.nwmax, drop hardcoded NW_MAX The (nw1, nw2) shape-bucketing in phi_mul_phi / phi_mul_dm flattened pairs into a dense table key via `nw1 * NW_MAX + nw2`, with NW_MAX a hardcoded 64. That was both a magic number and an artificial ceiling: a basis with nw >= 64 would abort(), and 64 was only a guess at the real max. The true upper bound is already known to the code as ucell.nwmax (max orbital count over all atom types), exposed via gint_gpu_vars_->nwmax. Use it: set nw_stride_ = nwmax + 1 once in the ctor so the bucket table is sized exactly to the basis -- no cap to maintain. A runtime stride can't index std::array, so the three counting-sort tables (counts / base / cursor) move to mutable std::vector members allocated once and re-zeroed per call. For typical nwmax~25 that's ~676 ints vs the old fixed 4096, so the hot path zeroes less and never reallocates. The set_bgrid_batch() abort guard becomes a structurally-unreachable assert, since nwmax is by definition the largest nw. Drop now-unused includes (<array>, <cstdio>, <cstdlib>); add <cassert>.
Co-Authored-By: Claude Opus 4.8 (1M context) --- .../module_gint/kernel/phi_operator_gpu.cu | 68 ++++++++++--------- .../module_gint/kernel/phi_operator_gpu.h | 21 ++++-- 2 files changed, 52 insertions(+), 37 deletions(-) diff --git a/source/source_lcao/module_gint/kernel/phi_operator_gpu.cu b/source/source_lcao/module_gint/kernel/phi_operator_gpu.cu index e29ec0a95dc..0969036dc9d 100644 --- a/source/source_lcao/module_gint/kernel/phi_operator_gpu.cu +++ b/source/source_lcao/module_gint/kernel/phi_operator_gpu.cu @@ -3,11 +3,9 @@ #include "dgemm_vbatch.h" #include #include -#include #include #include -#include -#include +#include #include "source_base/module_device/device_check.h" namespace ModuleGint @@ -34,6 +32,14 @@ gemm_C_(BatchBigGrid::get_max_atom_pairs_num(), stream_, true), gemm_alpha_(BatchBigGrid::get_max_atom_pairs_num(), stream_, true) { CHECK_CUDA(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming)); + // nwmax is the largest per-atom orbital count in the cell, so a stride of + // nwmax + 1 lets every valid (nw1, nw2) flatten to a distinct bucket key + // with no artificial cap. Allocate the bucketing scratch once here; the hot + // path only re-zeroes it. + nw_stride_ = gint_gpu_vars_->nwmax + 1; + bucket_counts_.assign(nw_stride_ * nw_stride_, 0); + bucket_base_.assign(nw_stride_ * nw_stride_, 0); + bucket_cursor_.assign(nw_stride_ * nw_stride_, 0); } template @@ -120,18 +126,12 @@ void PhiOperatorGpu::set_bgrid_batch(std::shared_ptr bgrid_b p.r_diff = atom_1->get_R() - atom_2->get_R(); p.nw1 = static_cast(atom_1->get_nw()); p.nw2 = static_cast(atom_2->get_nw()); - // The shape key (nw1 * NW_MAX + nw2) indexes a dense - // NW_MAX*NW_MAX table in phi_mul_phi / phi_mul_dm, so nw must - // stay below NW_MAX or those passes write out of bounds. Guard - // it here (always on, including release builds). 
- if (p.nw1 >= NW_MAX || p.nw2 >= NW_MAX) - { - fprintf(stderr, - "PhiOperatorGpu: per-atom nw (%d, %d) >= NW_MAX " - "(%d); increase NW_MAX in phi_operator_gpu.h\n", - int(p.nw1), int(p.nw2), NW_MAX); - std::abort(); - } + // The shape key (nw1 * nw_stride_ + nw2) indexes a dense + // nw_stride_^2 table in phi_mul_phi / phi_mul_dm. nw_stride_ = + // nwmax + 1 and nwmax is by construction the largest per-atom nw + // in the cell, so this can only trip on an upstream + // inconsistency rather than an undersized cap. + assert(p.nw1 < nw_stride_ && p.nw2 < nw_stride_); p.ia_le = static_cast(ia_1 <= ia_2); p.is_diag = static_cast(ia_1 == ia_2); pair_cache_.push_back(p); @@ -300,7 +300,8 @@ void PhiOperatorGpu::phi_mul_phi( // scattered -- the wrapper fills them on-device from the // scalar bucket shape.) - std::array counts{}; + auto& counts = bucket_counts_; + std::fill(counts.begin(), counts.end(), 0); // Pass 1: filter + HContainer lookup + per-shape count. for (size_t i = 0; i < pair_cache_.size(); ++i) @@ -310,16 +311,17 @@ void PhiOperatorGpu::phi_mul_phi( const int hr = hRGint.find_matrix_offset(p.iat_1, p.iat_2, p.r_diff); pair_scratch_offset_[i] = hr; if (hr == -1) { continue; } - counts[p.nw1 * NW_MAX + p.nw2]++; + counts[p.nw1 * nw_stride_ + p.nw2]++; } // Prefix sum over dense keys -> compact bucket list. struct Bucket { int key; int off; int cnt; }; std::vector buckets; buckets.reserve(32); - std::array key_to_base{}; + auto& key_to_base = bucket_base_; + std::fill(key_to_base.begin(), key_to_base.end(), 0); int ap_num = 0; - for (int k = 0; k < NW_MAX * NW_MAX; ++k) + for (int k = 0; k < nw_stride_ * nw_stride_; ++k) { if (counts[k] == 0) { continue; } buckets.push_back({k, ap_num, counts[k]}); @@ -337,13 +339,14 @@ void PhiOperatorGpu::phi_mul_phi( CHECK_CUDA(cudaEventSynchronize(event_)); // Pass 2: scatter into the flat host arrays at per-bucket cursors. 
- std::array cursor{}; + auto& cursor = bucket_cursor_; + std::fill(cursor.begin(), cursor.end(), 0); for (size_t i = 0; i < pair_cache_.size(); ++i) { const int hr = pair_scratch_offset_[i]; if (hr == -1) { continue; } const auto& p = pair_cache_[i]; - const int key = p.nw1 * NW_MAX + p.nw2; + const int key = p.nw1 * nw_stride_ + p.nw2; const int pos = key_to_base[key] + cursor[key]++; h_A[pos] = phi_d + p.phi_1_offset; h_B[pos] = phi_vldr3_d + p.phi_2_offset; @@ -363,8 +366,8 @@ void PhiOperatorGpu::phi_mul_phi( for (const auto& b : buckets) { - const int nw1 = b.key / NW_MAX; - const int nw2 = b.key % NW_MAX; + const int nw1 = b.key / nw_stride_; + const int nw2 = b.key % nw_stride_; gemm_tn_vbatch(nw1, nw2, mgrids_num_, @@ -396,7 +399,8 @@ void PhiOperatorGpu::phi_mul_dm( // identical across every pair in the batch. is_symm selects the // upper-triangle (ia_1 <= ia_2) subset and fills per-pair alpha. - std::array counts{}; + auto& counts = bucket_counts_; + std::fill(counts.begin(), counts.end(), 0); // Pass 1: filter + HContainer lookup + per-shape count. for (size_t i = 0; i < pair_cache_.size(); ++i) @@ -406,16 +410,17 @@ void PhiOperatorGpu::phi_mul_dm( const int dm_offset = dm.find_matrix_offset(p.iat_1, p.iat_2, p.r_diff); pair_scratch_offset_[i] = dm_offset; if (dm_offset == -1) { continue; } - counts[p.nw1 * NW_MAX + p.nw2]++; + counts[p.nw1 * nw_stride_ + p.nw2]++; } // Prefix sum over dense keys -> compact bucket list. struct Bucket { int key; int off; int cnt; }; std::vector buckets; buckets.reserve(32); - std::array key_to_base{}; + auto& key_to_base = bucket_base_; + std::fill(key_to_base.begin(), key_to_base.end(), 0); int ap_num = 0; - for (int k = 0; k < NW_MAX * NW_MAX; ++k) + for (int k = 0; k < nw_stride_ * nw_stride_; ++k) { if (counts[k] == 0) { continue; } buckets.push_back({k, ap_num, counts[k]}); @@ -434,13 +439,14 @@ void PhiOperatorGpu::phi_mul_dm( CHECK_CUDA(cudaEventSynchronize(event_)); // Pass 2: scatter. 
- std::array cursor{}; + auto& cursor = bucket_cursor_; + std::fill(cursor.begin(), cursor.end(), 0); for (size_t i = 0; i < pair_cache_.size(); ++i) { const int dm_offset = pair_scratch_offset_[i]; if (dm_offset == -1) { continue; } const auto& p = pair_cache_[i]; - const int key = p.nw1 * NW_MAX + p.nw2; + const int key = p.nw1 * nw_stride_ + p.nw2; const int pos = key_to_base[key] + cursor[key]++; h_A[pos] = phi_d + p.phi_1_offset; h_B[pos] = dm_d + dm_offset; @@ -469,8 +475,8 @@ void PhiOperatorGpu::phi_mul_dm( for (const auto& b : buckets) { - const int nw1 = b.key / NW_MAX; - const int nw2 = b.key % NW_MAX; + const int nw1 = b.key / nw_stride_; + const int nw2 = b.key % nw_stride_; auto alpha_ptr = is_symm ? (gemm_alpha_.get_device_ptr() + b.off) : nullptr; gemm_nn_vbatch(mgrids_num_, nw2, diff --git a/source/source_lcao/module_gint/kernel/phi_operator_gpu.h b/source/source_lcao/module_gint/kernel/phi_operator_gpu.h index 9ba43420236..5631926aa3e 100644 --- a/source/source_lcao/module_gint/kernel/phi_operator_gpu.h +++ b/source/source_lcao/module_gint/kernel/phi_operator_gpu.h @@ -11,12 +11,6 @@ namespace ModuleGint { -// Upper bound for per-atom orbital count. Used to flatten a (nw1, nw2) pair -// into a single integer key via `nw1 * NW_MAX + nw2`, so that shape-exact -// bucketing of phi_mul_phi / phi_mul_dm can index a dense counts[] table -// instead of hashing. 64 comfortably covers the typical max nw (~25). -constexpr int NW_MAX = 64; - // Per-atom-pair metadata cached in PhiOperatorGpu::set_bgrid_batch() so that // phi_mul_phi / phi_mul_dm skip the O(bgrid * atoms^2) enumeration on every // call. Holds only the fields both callers need; HContainer lookups still @@ -106,6 +100,14 @@ class PhiOperatorGpu int phi_len_; + // Stride for flattening a (nw1, nw2) pair into a single dense bucket key + // (`nw1 * nw_stride_ + nw2`), so shape-exact bucketing of phi_mul_phi / + // phi_mul_dm can index a flat table instead of hashing. 
Set once in the + // ctor to ucell.nwmax + 1 -- nwmax is the largest per-atom orbital count in + // the cell, so there is no artificial ceiling: the table is sized to the + // actual basis (typical nwmax ~25). + int nw_stride_ = 0; + cudaStream_t stream_ = 0; cudaEvent_t event_; @@ -147,6 +149,13 @@ class PhiOperatorGpu // per-pair HContainer offsets from Pass 1 and replay them in Pass 2 // without a second find_matrix_offset() call. mutable std::vector pair_scratch_offset_; + + // Dense (nw_stride_ * nw_stride_) counting-sort scratch shared by + // phi_mul_phi / phi_mul_dm. Sized once in the ctor and just re-zeroed per + // call, so the hot path never reallocates. + mutable std::vector bucket_counts_; + mutable std::vector bucket_base_; + mutable std::vector bucket_cursor_; }; } \ No newline at end of file From 29ca1dfd3f6465c4c800b435fff50db9e5fe8e50 Mon Sep 17 00:00:00 2001 From: dzzz2001 Date: Fri, 29 May 2026 17:57:52 +0800 Subject: [PATCH 3/3] refactor(gint): clarify GEMM kernel comments, hoist shape-bucket struct Follow-up cleanup on the shape-exact vbatched GEMM path. No behavior change. - gemm_{nn,tn}_vbatch, dgemm_vbatch, gint_helper: rewrite the kernel comments to describe the actual mechanism (K-inner shared-memory layout, wide vector loads feeding VK FMAs per load, the tile ladder, fp64 cross-item accumulation) and drop the internal "V1/V3/Phase" development shorthand that carried no meaning outside the original work log. - phi_operator_gpu: the local `Bucket` struct was declared identically inside both phi_mul_phi and phi_mul_dm. Hoist it to a named GemmShapeBucket type and reuse a single buckets_ member vector (cleared, not reallocated) across both, reserved once in the ctor -- one less per-call heap allocation on the hot path. - phi_operator_gpu: pair_scratch_offset_ is fully overwritten in Pass 1 before Pass 2 reads it, so resize() it instead of assign(..., -1); the -1 sentinel was never observed. 
Co-Authored-By: Claude Opus 4.8 (1M context) --- .../module_gint/kernel/dgemm_vbatch.cu | 23 +++++- .../module_gint/kernel/dgemm_vbatch.h | 19 ++--- .../module_gint/kernel/gemm_nn_vbatch.cuh | 76 ++++++++----------- .../module_gint/kernel/gemm_tn_vbatch.cuh | 47 ++++++------ .../module_gint/kernel/gint_helper.cuh | 23 +++--- .../module_gint/kernel/phi_operator_gpu.cu | 19 ++--- .../module_gint/kernel/phi_operator_gpu.h | 25 +++++- 7 files changed, 131 insertions(+), 101 deletions(-) diff --git a/source/source_lcao/module_gint/kernel/dgemm_vbatch.cu b/source/source_lcao/module_gint/kernel/dgemm_vbatch.cu index 98764571baa..0b4df833b3b 100644 --- a/source/source_lcao/module_gint/kernel/dgemm_vbatch.cu +++ b/source/source_lcao/module_gint/kernel/dgemm_vbatch.cu @@ -3,6 +3,21 @@ #include "dgemm_vbatch.h" #include "source_base/module_device/device.h" +// Tile ladder +// ----------- +// The caller splits each batch into buckets of identical (m, n, k) and calls +// in once per bucket. The dispatchers below pick, for each bucket, the kernel +// instantiation whose (BLK_M, BLK_N) tile is the smallest rung that still +// covers the bucket's output shape, so boundary blocks don't spend most of +// their work on masked-off padding. +// +// Each thread owns a THR_M x THR_N register accumulator tile, i.e. it computes +// THR = THR_M * THR_N = (BLK_M / DIM_X) * (BLK_N / DIM_Y) +// output elements. We aim to keep THR in roughly [16, 36]: below that the inner +// FMAs don't amortize the shared-memory traffic and there's too little ILP; +// above it register pressure starts cutting occupancy. The "(in band)" / +// "(under)" notes on each case below mark where that rung lands. + template void gemm_nn_vbatch( int m, int n, int k, @@ -32,8 +47,10 @@ void gemm_nn_vbatch( : (n <= 32) ? 2 : 3; - // BLK_N bracket -- 32 only when bxyz <=32 (caps mask waste at 50% for - // bxyz=27); 64 for everything else (best LDS reuse). + // BLK_N bracket -- tiles the bxyz (mesh-grid) axis. 
Use 32 when bxyz<=32 so + // a partial final block-row isn't mostly masked padding (e.g. bxyz=27 in a + // 64-row tile leaves ~58% of the rows idle); use 64 above that, where the + // larger tile gives better shared-memory reuse. const int blk_n_tag = (m <= 32) ? 0 : 1; switch (blk_m_tag * 2 + blk_n_tag) @@ -87,7 +104,7 @@ void gemm_tn_vbatch( switch (blk_m_tag * 4 + blk_n_tag) { // BLK_M=8 rungs (nw2<=8). DIM_X=4, THR_M=2. - case 0: TN_DISPATCH(4, 8, 8, 8); break; // THR=2*1=2 (corner) + case 0: TN_DISPATCH(4, 8, 8, 8); break; // THR=2*1=2 (well under band) case 1: TN_DISPATCH(4, 8, 8, 16); break; // THR=2*2=4 case 2: TN_DISPATCH(4, 8, 8, 32); break; // THR=2*4=8 case 3: TN_DISPATCH(4, 8, 8, 48); break; // THR=2*6=12 diff --git a/source/source_lcao/module_gint/kernel/dgemm_vbatch.h b/source/source_lcao/module_gint/kernel/dgemm_vbatch.h index 3052011767b..9d70c84348c 100644 --- a/source/source_lcao/module_gint/kernel/dgemm_vbatch.h +++ b/source/source_lcao/module_gint/kernel/dgemm_vbatch.h @@ -5,16 +5,17 @@ // Shape-exact batched GEMM dispatchers. // // Every (A_i, B_i, C_i) in the batch has exactly the same (m, n, k); the -// caller (phi_operator_gpu.cu) enforces this by bucketing atom pairs on -// (nw1, nw2). The scalars drive tile-ladder selection, grid sizing, and -// flow all the way through the kernel -- there is no per-batchid M/N/K -// indirection left. +// caller (phi_operator_gpu.cu) guarantees this by bucketing atom pairs on +// (nw1, nw2) before calling in. The scalar m/n/k drive tile-ladder selection, +// grid sizing, and the kernel itself -- there is no per-batch-id M/N/K +// indirection. // -// The C accumulator is always double regardless of the input type T: a fp32 -// GEMM path (T=float) feeds fp32 multiplies into fp64 registers and a -// device-side fp64 atomicAdd, so summing many atom-pair contributions into the -// same hr_gint / phi_dm element does not drift. For T=double, A, B and C are -// all double and this matches the legacy signature. 
+// The C output is always double, independent of T. For T=float the per-item +// inner products accumulate in fp32, but the cross-item accumulation into a +// shared C element is done with a device-side fp64 atomicAdd (see the kernels' +// store loop), so summing many atom-pair contributions into the same +// hr_gint / phi_dm element does not drift. For T=double, A, B and C are all +// double. // C(batch) = alpha * A(batch) * B(batch) + C(batch) template diff --git a/source/source_lcao/module_gint/kernel/gemm_nn_vbatch.cuh b/source/source_lcao/module_gint/kernel/gemm_nn_vbatch.cuh index aa1464341ae..0ae74f3d9ae 100644 --- a/source/source_lcao/module_gint/kernel/gemm_nn_vbatch.cuh +++ b/source/source_lcao/module_gint/kernel/gemm_nn_vbatch.cuh @@ -12,28 +12,17 @@ #include "source_base/module_device/device_check.h" #include "source_base/module_device/kernel_compat.h" -// V1 K-inner shmem layout -// sA(m, k) = sA[m * slda + k] row-major in M, K-inner; slda = BLK_K + PAD -// sB(k, n) = sB[n * sldb + k] col-major in N, K-inner; sldb = BLK_K + PAD -// Both layouts make the inner loop read VK consecutive K elements per LDS, -// turning one scalar LDS-per-FMA into one 16-byte LDS-per-VK-FMAs. -// PAD comes from gemm_vec_traits::PAD (FP32: +4, FP64: +2) and is what -// makes slda/sldb 16-byte aligned for LDS.{64,128}. -// -// Phase V3 bank-conflict audit (sA inner-loop read, idx-strided lanes): -// FP64, DIM_X= 8 (8x16 thread tiles): slda=BLK_K+2 -> 8 lanes at -// stride 4 banks each side -> banks {0,4,...,28} disjoint -> 0 conflicts. -// FP32, DIM_X= 8 (8x16 thread tiles): slda=BLK_K+4 -> 8 lanes at -// stride 4 banks (4-bank vec) -> disjoint -> 0 conflicts. -// FP64, DIM_X=16 (V2 16x16 big tile): slda=BLK_K+2, 16 lanes; even -// slda forces gcd(2*slda,32) >= 2, so the LOW/HIGH bank pair lands -// on distinct banks for all 16 lanes only when 2*slda has order >=16 -// mod 32. With BLK_K=16 -> slda=18 -> 36 mod 32 = 4 -> 8-distinct -// -> 2-way conflict. 
Accepted in V2: still beats scalar LDS by ~VK/2, -// and removing the conflict requires a swizzled layout (Step 2). -// sB inner-loop read uses idy-strided lanes; with DIM_Y in {8,16} the -// warp covers only 2-4 distinct n_col values, broadcast factor >= 8 -// -> always conflict-free regardless of sldb. +// Shared-memory tile layout (K-inner): both operands store the contraction +// (K) axis contiguously so the inner loop can read VK consecutive K elements +// with a single 16-byte vector load instead of VK scalar loads. +// sA(m, k) = sA[m * slda + k] -- M indexes the row, K is contiguous +// sB(k, n) = sB[n * sldb + k] -- N indexes the column, K is contiguous +// slda = sldb = BLK_K + PAD +// PAD (from gemm_vec_traits: +4 for FP32, +2 for FP64) keeps the K-stride a +// whole number of 16-byte words so the vector loads stay aligned, and offsets +// the per-warp access so the strided shared-memory reads spread across banks. +// The widest FP64 tiles (DIM_X=16) still take a few bank conflicts on the sA +// read, but that is a net win over the scalar layout this replaces. #define sA(i, j) sA[(i)*slda + (j)] #define sB(i, j) sB[(j)*sldb + (i)] #define fetch(A, m, n, bound) offs_d##A[min(n * LD##A + m, bound)] @@ -68,17 +57,16 @@ static __device__ void vbatched_gemm_nn_device(int M, using vec_t = typename gemm_vec_traits::vec_t; constexpr int VK = gemm_vec_traits::VK; - // V1 contract: BLK_K must be a whole number of VK chunks so the - // vectorized FMA loop below covers it cleanly. PAD makes slda * 8/4 - // a multiple of 16 (LDS alignment) -- enforced at the kernel scope. + // BLK_K must be a whole number of VK chunks so the vectorized FMA loop + // below covers it exactly; the 16-byte alignment of the K-stride is + // enforced separately at kernel scope. 
static_assert(BLK_K % VK == 0, "BLK_K must be divisible by VK (16 / sizeof(T))"); - // Tile-divisibility (Phase V3 audit): every dev->shmem load loop - // assumes the BLK_* dim is an exact multiple of the corresponding - // DIM_*, and the per-thread fan-out THR_M/N is BLK_M/N / DIM_X/Y. - // A mis-spec'd new template instantiation would silently load - // garbage; these asserts surface it at compile time. + // Tile divisibility: every dev->shmem load loop assumes BLK_* is an exact + // multiple of the matching DIM_*, and the per-thread fan-out THR_M/THR_N is + // BLK_M/BLK_N divided by DIM_X/DIM_Y. A mis-specified instantiation would + // silently load garbage, so surface it at compile time. static_assert(BLK_M % DIM_X == 0, "BLK_M must be divisible by DIM_X"); static_assert(BLK_N % DIM_Y == 0, "BLK_N must be divisible by DIM_Y"); static_assert(BLK_M % DIM_XA == 0, "BLK_M must be divisible by DIM_XA"); @@ -104,10 +92,11 @@ static __device__ void vbatched_gemm_nn_device(int M, int blx = blockIdx.x; // block's m dimension int bly = blockIdx.y; // block's n dimension - // Accumulator tile (registers). Layout matches the original. + // Accumulator tile (registers). rC accumulates in T; the widening to + // double happens only at the final atomicAdd into C. T rC[THR_N][THR_M]; - // Per-VK-step shmem->reg tiles. One LDS feeds VK FMAs per (m,n). + // Per-VK-step shmem->reg tiles. One load feeds VK FMAs per (m,n). T rA[THR_M][VK]; T rB[THR_N][VK]; @@ -191,11 +180,11 @@ static __device__ void vbatched_gemm_nn_device(int M, } } -// Wide-LDS FMA: VK FMAs per shmem read. -// FP32: LDS.128 (float4) -> 4 FMAs per (m,n) per inner step -// FP64: LDS.64 (double2) -> 2 FMAs per (m,n) per inner step -// Both rely on slda/sldb being 16-byte aligned (PAD math) and on BLK_K -// being a whole number of VK chunks (static_assert above). +// Wide-load FMA: one vector load feeds VK FMAs per (m, n). 
+// FP32: float4 -> 4 FMAs per (m,n) per inner step +// FP64: double2 -> 2 FMAs per (m,n) per inner step +// Relies on the K-stride being 16-byte aligned and BLK_K being a whole +// number of VK chunks (static_asserts above). #pragma unroll for (k = 0; k < BLK_K; k += VK) { @@ -259,9 +248,10 @@ static __device__ void vbatched_gemm_nn_device(int M, __syncthreads(); } - // Tail: last full (BLK_K) or partial block. Scalar from the K-inner - // layout -- the partial-K block can land on an odd k count (e.g. - // bxyz=27 -> tail 11), so don't try to vectorize it. + // Tail: the leftover K columns after the BLK_K-strided main loop (here K is + // the contraction length, nw1). The remainder is generally not a multiple + // of VK, so it runs with scalar shared-memory reads instead of the vector + // load. // It's okay that m,n exceed matrix bounds as all work is in registers // or shared memory, and out-of-bounds rC[n][m] will not be saved later. kk = K - kk; @@ -347,8 +337,8 @@ static __global__ void vbatched_gemm_nn_kernel(int M, static_assert(BLK_K % gemm_vec_traits::VK == 0, "BLK_K must be divisible by VK = 16 / sizeof(T)"); - // V1 K-inner: slda is the K-axis stride for sA (M-rows of (BLK_K + PAD)), - // sldb is the K-axis stride for sB (N-cols of (BLK_K + PAD)). + // K-inner layout: slda/sldb are the K-axis stride (BLK_K + PAD) for sA + // (BLK_M rows) and sB (BLK_N columns) respectively. 
int shared_lda = BLK_K + PAD; int shared_ldb = BLK_K + PAD; T* shared_A = (T*)shared_mem; @@ -446,7 +436,7 @@ void vbatched_gemm_nn_impl(int m, // This is because vbatch_gemm_nn_kernel is column major, // but vatched_gemm_nn_impl is designed to be row major, - // V1 K-inner shmem footprint: + // K-inner shared-memory footprint: // sA: BLK_M rows of (BLK_K + PAD) elements // sB: BLK_N cols of (BLK_K + PAD) elements constexpr int PAD = gemm_vec_traits::PAD; diff --git a/source/source_lcao/module_gint/kernel/gemm_tn_vbatch.cuh b/source/source_lcao/module_gint/kernel/gemm_tn_vbatch.cuh index 7ed8d3e7af1..b846bbf50b2 100644 --- a/source/source_lcao/module_gint/kernel/gemm_tn_vbatch.cuh +++ b/source/source_lcao/module_gint/kernel/gemm_tn_vbatch.cuh @@ -12,15 +12,15 @@ #include "source_base/module_device/device_check.h" #include "source_base/module_device/kernel_compat.h" -// V1 K-inner shmem layout (matches gemm_nn_vbatch.cuh): -// sA(m, k) = sA[m * slda + k] row-major in M, K-inner; slda = BLK_K + PAD -// sB(k, n) = sB[n * sldb + k] col-major in N, K-inner; sldb = BLK_K + PAD -// PAD comes from gemm_vec_traits (FP32: +4, FP64: +2) and makes the -// stride 16-byte aligned + bank-conflict-free for warp-wide LDS. -// See gemm_nn_vbatch.cuh for the full Phase V3 bank-conflict audit table; -// the TN inner loop uses the same indexing pattern, so the same analysis -// applies (the only structural difference is sB's load loop, which writes -// to the same K-inner storage layout). +// Shared-memory tile layout (K-inner), identical to gemm_nn_vbatch.cuh: +// sA(m, k) = sA[m * slda + k] -- M indexes the row, K is contiguous +// sB(k, n) = sB[n * sldb + k] -- N indexes the column, K is contiguous +// slda = sldb = BLK_K + PAD +// PAD (from gemm_vec_traits: +4 for FP32, +2 for FP64) keeps the K-stride a +// whole number of 16-byte words for the vector loads and spreads the strided +// reads across banks. 
See gemm_nn_vbatch.cuh for the layout rationale; the TN +// inner loop uses the same access pattern, the only difference being how the +// dev->shmem load loop for sB is indexed. #define sA(i, j) sA[(i)*slda + (j)] #define sB(i, j) sB[(j)*sldb + (i)] #define fetch(A, m, n, bound) offs_d##A[min(n * LD##A + m, bound)] @@ -58,9 +58,9 @@ static __device__ void vbatched_gemm_nt_device(int M, static_assert(BLK_K % VK == 0, "BLK_K must be divisible by VK (16 / sizeof(T))"); - // Tile-divisibility (Phase V3 audit): same checks as gemm_nn_vbatch. - // sB load loop in TN traverses (BLK_K rows x BLK_N cols), so the - // divisibility constraints on DIM_XB / DIM_YB are mirrored. + // Tile divisibility: same constraints as gemm_nn_vbatch. The TN sB load + // loop traverses (BLK_K rows x BLK_N cols), so the DIM_XB / DIM_YB + // divisibility checks are mirrored accordingly. static_assert(BLK_M % DIM_X == 0, "BLK_M must be divisible by DIM_X"); static_assert(BLK_N % DIM_Y == 0, "BLK_N must be divisible by DIM_Y"); static_assert(BLK_M % DIM_XA == 0, "BLK_M must be divisible by DIM_XA"); @@ -86,10 +86,11 @@ static __device__ void vbatched_gemm_nt_device(int M, int blx = blockIdx.x; // block's m dimension int bly = blockIdx.y; // block's n dimension - // Accumulator tile (registers). + // Accumulator tile (registers). rC accumulates in T; the widening to + // double happens only at the final atomicAdd into C. T rC[THR_N][THR_M]; - // Per-VK-step shmem->reg tiles. One LDS feeds VK FMAs per (m,n). + // Per-VK-step shmem->reg tiles. One load feeds VK FMAs per (m,n). T rA[THR_M][VK]; T rB[THR_N][VK]; @@ -173,9 +174,9 @@ static __device__ void vbatched_gemm_nt_device(int M, } } -// Wide-LDS FMA: VK FMAs per shmem read. -// FP32: LDS.128 (float4) -> 4 FMAs per (m,n) per inner step -// FP64: LDS.64 (double2) -> 2 FMAs per (m,n) per inner step +// Wide-load FMA: one vector load feeds VK FMAs per (m, n). 
+// FP32: float4 -> 4 FMAs per (m,n) per inner step +// FP64: double2 -> 2 FMAs per (m,n) per inner step #pragma unroll for (k = 0; k < BLK_K; k += VK) { @@ -239,8 +240,10 @@ static __device__ void vbatched_gemm_nt_device(int M, __syncthreads(); } - // Tail: scalar from the K-inner layout. Partial-K blocks can be odd - // (bxyz=27 -> tail 11; bxyz=125 -> tail 13), so don't try to vectorize. + // Tail: the leftover K columns after the BLK_K-strided main loop (here K is + // the contraction length, the mesh-grid count bxyz). The remainder is + // generally not a multiple of VK, so it runs with scalar shared-memory + // reads instead of the vector load. // It's okay that m,n exceed matrix bounds as all work is in registers // or shared memory, and out-of-bounds rC[n][m] will not be saved later. kk = K - kk; @@ -326,8 +329,8 @@ static __global__ void vbatched_gemm_nt_kernel(int M, static_assert(BLK_K % gemm_vec_traits::VK == 0, "BLK_K must be divisible by VK = 16 / sizeof(T)"); - // V1 K-inner: slda = K-axis stride for sA (BLK_M rows of (BLK_K + PAD)), - // sldb = K-axis stride for sB (BLK_N cols of (BLK_K + PAD)). + // K-inner layout: slda/sldb are the K-axis stride (BLK_K + PAD) for sA + // (BLK_M rows) and sB (BLK_N columns) respectively. 
int shared_lda = BLK_K + PAD; int shared_ldb = BLK_K + PAD; T* shared_A = (T*)shared_mem; @@ -450,7 +453,7 @@ void vbatched_gemm_tn_impl(int m, // This is because vbatch_gemm__tn_kernel is column major, // but vatched_gemm_nt_impl is designed to be row major, - // V1 K-inner shmem footprint (matches gemm_nn_vbatch_impl): + // K-inner shared-memory footprint (matches vbatched_gemm_nn_impl): // sA: BLK_M rows of (BLK_K + PAD) elements // sB: BLK_N cols of (BLK_K + PAD) elements constexpr int PAD = gemm_vec_traits::PAD; diff --git a/source/source_lcao/module_gint/kernel/gint_helper.cuh b/source/source_lcao/module_gint/kernel/gint_helper.cuh index ec1007b1d23..1f6d431003d 100644 --- a/source/source_lcao/module_gint/kernel/gint_helper.cuh +++ b/source/source_lcao/module_gint/kernel/gint_helper.cuh @@ -44,19 +44,20 @@ inline int ceil_div(const int a, const int b) } // --------------------------------------------------------------------------- -// gemm_vec_traits -- wide-LDS primitive for the V1 K-inner inner loop. +// gemm_vec_traits -- the wide-load primitive used by the GEMM inner loop, +// which reads VK consecutive K elements per shared-memory load instead of one. // -// VK = how many T elements pack into one 16-byte LDS (4 for FP32, 2 for FP64) -// vec_t = the 16-byte CUDA vector type used for the LDS -// PAD = K-stride padding that makes (BLK_K + PAD) * sizeof(T) a multiple of -// 16 *and* keeps the warp's idx-strided shmem access bank-conflict-free -// (gcd((BLK_K+PAD) % 32, 32) == VK). +// VK = number of T elements in one 16-byte load (4 for FP32, 2 for FP64) +// vec_t = the 16-byte vector type used for that load (float4 / double2) +// PAD = padding added to the shared-memory K-stride so that (BLK_K + PAD) +// elements span a whole number of 16-byte words, keeping the +// vectorized shared-memory loads aligned and spreading the warp's +// strided reads across banks. 
// -// The load is issued as one *reinterpret_cast(&sA(m,k)); the -// component fan-out is done by unpack(). FP64 needs the explicit cast -- -// the compiler's auto-vectorizer is reliable for float4 but not for -// double2; per-component .x/.y/.z/.w writes guarantee the LDS.{64,128} -// SASS forms emit. +// The load is one *reinterpret_cast(&sA(m, k)); unpack() then fans the +// vector out into the per-thread registers. The explicit per-component copy is +// deliberate: nvcc reliably vectorizes float4 but not double2, and writing +// .x/.y(/.z/.w) by hand guarantees the wide load instruction is emitted. // --------------------------------------------------------------------------- template struct gemm_vec_traits; diff --git a/source/source_lcao/module_gint/kernel/phi_operator_gpu.cu b/source/source_lcao/module_gint/kernel/phi_operator_gpu.cu index 0969036dc9d..60e193e16a5 100644 --- a/source/source_lcao/module_gint/kernel/phi_operator_gpu.cu +++ b/source/source_lcao/module_gint/kernel/phi_operator_gpu.cu @@ -40,6 +40,7 @@ gemm_alpha_(BatchBigGrid::get_max_atom_pairs_num(), stream_, true) bucket_counts_.assign(nw_stride_ * nw_stride_, 0); bucket_base_.assign(nw_stride_ * nw_stride_, 0); bucket_cursor_.assign(nw_stride_ * nw_stride_, 0); + buckets_.reserve(32); } template @@ -138,7 +139,9 @@ void PhiOperatorGpu::set_bgrid_batch(std::shared_ptr bgrid_b } } } - pair_scratch_offset_.assign(pair_cache_.size(), -1); + // Sized to match pair_cache_; Pass 1 of phi_mul_phi / phi_mul_dm overwrites + // every entry before Pass 2 reads it, so no initial value is needed. + pair_scratch_offset_.resize(pair_cache_.size()); } template @@ -286,8 +289,8 @@ void PhiOperatorGpu::phi_mul_phi( // Shape-exact bucketing: group atom pairs by (nw1, nw2). K = mgrids_num_ // is already batch-wide constant, so (nw1, nw2) fully determines the GEMM // shape. 
Each bucket hands gemm_tn_vbatch scalar (nw1, nw2, mgrids_num_), - // so the 3x3 template ladder picks the tightest tile for every item and - // the wrapper sizes the grid exactly -- no cross-species tile waste, no + // so the tile ladder picks the tightest tile for every shape and the + // wrapper sizes the grid exactly -- no cross-shape tile waste, no // over-launched blocks. // // Algorithm: counting-sort-style two-pass over the pre-enumerated @@ -315,9 +318,8 @@ void PhiOperatorGpu::phi_mul_phi( } // Prefix sum over dense keys -> compact bucket list. - struct Bucket { int key; int off; int cnt; }; - std::vector buckets; - buckets.reserve(32); + auto& buckets = buckets_; + buckets.clear(); auto& key_to_base = bucket_base_; std::fill(key_to_base.begin(), key_to_base.end(), 0); int ap_num = 0; @@ -414,9 +416,8 @@ void PhiOperatorGpu::phi_mul_dm( } // Prefix sum over dense keys -> compact bucket list. - struct Bucket { int key; int off; int cnt; }; - std::vector buckets; - buckets.reserve(32); + auto& buckets = buckets_; + buckets.clear(); auto& key_to_base = bucket_base_; std::fill(key_to_base.begin(), key_to_base.end(), 0); int ap_num = 0; diff --git a/source/source_lcao/module_gint/kernel/phi_operator_gpu.h b/source/source_lcao/module_gint/kernel/phi_operator_gpu.h index 5631926aa3e..d93270ea7e6 100644 --- a/source/source_lcao/module_gint/kernel/phi_operator_gpu.h +++ b/source/source_lcao/module_gint/kernel/phi_operator_gpu.h @@ -4,6 +4,7 @@ #include #include +#include "source_lcao/module_gint/gint_type.h" // Vec3i, used by PairInfo #include "source_lcao/module_gint/batch_biggrid.h" #include "gint_gpu_vars.h" #include "cuda_mem_wrapper.h" @@ -29,6 +30,16 @@ struct PairInfo uint8_t is_diag; // (ia_1 == ia_2), for phi_mul_dm is_symm alpha }; +// One non-empty (nw1, nw2) bucket produced by the counting sort in +// phi_mul_phi / phi_mul_dm: a flattened shape key, the bucket's start offset in +// the flat gemm_* arrays, and how many atom pairs landed in it. 
+struct GemmShapeBucket +{ + int key; // nw1 * nw_stride_ + nw2 + int off; // start offset in the gemm_* host/device arrays + int cnt; // number of atom pairs in this bucket +}; + template class PhiOperatorGpu { @@ -53,10 +64,12 @@ class PhiOperatorGpu const Real* phi_d, Real* result_d) const; - // All GEMM accumulators (hr in phi_mul_phi, phi_dm in phi_mul_dm) are - // double-typed regardless of Real: when Real=float the multiplies stay in - // fp32 (cheap) but per-block reductions and device-side atomicAdd run in - // fp64 so the global reductions don't drift. + // The GEMM output buffers (hr in phi_mul_phi, phi_dm in phi_mul_dm) are + // always double, independent of Real. When Real=float the per-pair inner + // products are reduced in fp32 (cheap); the cross-pair accumulation into a + // shared hr/phi_dm element is what runs in fp64, via an atomicAdd into + // these double buffers, so summing many atom-pair contributions doesn't + // drift. void phi_mul_phi( const Real* phi_d, const Real* phi_vldr3_d, @@ -156,6 +169,10 @@ class PhiOperatorGpu mutable std::vector bucket_counts_; mutable std::vector bucket_base_; mutable std::vector bucket_cursor_; + + // Compact list of non-empty buckets for the current call. Reused (cleared, + // not reallocated) by both phi_mul_phi and phi_mul_dm. + mutable std::vector buckets_; }; } \ No newline at end of file