diff --git a/source/source_lcao/module_gint/kernel/dgemm_vbatch.cu b/source/source_lcao/module_gint/kernel/dgemm_vbatch.cu index 38946d51492..0b4df833b3b 100644 --- a/source/source_lcao/module_gint/kernel/dgemm_vbatch.cu +++ b/source/source_lcao/module_gint/kernel/dgemm_vbatch.cu @@ -3,60 +3,148 @@ #include "dgemm_vbatch.h" #include "source_base/module_device/device.h" +// Tile ladder +// ----------- +// The caller splits each batch into buckets of identical (m, n, k) and calls +// in once per bucket. The dispatchers below pick, for each bucket, the kernel +// instantiation whose (BLK_M, BLK_N) tile is the smallest rung that still +// covers the bucket's output shape, so boundary blocks don't spend most of +// their work on masked-off padding. +// +// Each thread owns a THR_M x THR_N register accumulator tile, i.e. it computes +// THR = THR_M * THR_N = (BLK_M / DIM_X) * (BLK_N / DIM_Y) +// output elements. We aim to keep THR in roughly [16, 36]: below that the inner +// FMAs don't amortize the shared-memory traffic and there's too little ILP; +// above it register pressure starts cutting occupancy. The "(in band)" / +// "(under)" notes on each case below mark where that rung lands. + template void gemm_nn_vbatch( - int max_m, int max_n, int max_k, - const int* m_d, const int* n_d, const int* k_d, + int m, int n, int k, const T* const* A_array_d, const int* lda_d, const T* const* B_array_d, const int* ldb_d, double** C_array_d, const int* ldc_d, int batchCount, cudaStream_t stream, const T* alpha) { - vbatched_gemm_nn_impl - (max_m, max_n, m_d, n_d, k_d, - A_array_d, lda_d, - B_array_d, ldb_d, - C_array_d, ldc_d, - batchCount, stream, alpha); + // 4 (nw2 bracket) x 2 (bxyz bracket) = 8 instantiations. + // + // Mapping into the impl's parameter list is: + // + // which satisfies the kernel's tile-divisibility asserts because every + // (BLK_M, BLK_N, BLK_K=16) chosen below is a multiple of the matching + // (DIM_X, DIM_Y) pair. 
+ #define NN_DISPATCH(DX, DY, BM, BN) \ + vbatched_gemm_nn_impl( \ + m, n, k, \ + A_array_d, lda_d, B_array_d, ldb_d, \ + C_array_d, ldc_d, batchCount, stream, alpha) + + // BLK_M bracket -- smallest tile in {8,16,32,48} covering nw2 (48 also serves nw2>48). + const int blk_m_tag = (n <= 8) ? 0 + : (n <= 16) ? 1 + : (n <= 32) ? 2 + : 3; + + // BLK_N bracket -- tiles the bxyz (mesh-grid) axis. Use 32 when bxyz<=32 so + // a partial final block-row isn't mostly masked padding (e.g. bxyz=27 in a + // 64-row tile leaves ~58% of the rows idle); use 64 above that, where the + // larger tile gives better shared-memory reuse. + const int blk_n_tag = (m <= 32) ? 0 : 1; + + switch (blk_m_tag * 2 + blk_n_tag) + { + // BLK_M=8 (nw2<=8). DIM=4x8 -> THR_M=2. + case 0: NN_DISPATCH( 4, 8, 8, 32); break; // THR=2*4=8 (under) + case 1: NN_DISPATCH( 4, 8, 8, 64); break; // THR=2*8=16 (in band) + // BLK_M=16 (nw2<=16). DIM=4x8 -> THR_M=4. + case 2: NN_DISPATCH( 4, 8, 16, 32); break; // THR=4*4=16 (in band) + case 3: NN_DISPATCH( 4, 8, 16, 64); break; // THR=4*8=32 (in band) + // BLK_M=32 (nw2<=32). DIM=8x8 -> THR_M=4. + case 4: NN_DISPATCH( 8, 8, 32, 32); break; // THR=4*4=16 (in band) + case 5: NN_DISPATCH( 8, 8, 32, 64); break; // THR=4*8=32 (in band) + // BLK_M=48 (nw2>32; nw2>48 spans several 48-wide blocks). DIM=16x8 -> THR_M=3 (cap at 3 to keep + // register pressure room for the BLK_N=64 sibling). 
+ case 6: NN_DISPATCH(16, 8, 48, 32); break; // THR=3*4=12 (just under) + case 7: NN_DISPATCH(16, 8, 48, 64); break; // THR=3*8=24 (in band) + } + #undef NN_DISPATCH } template void gemm_tn_vbatch( - int max_m, int max_n, int max_k, - const int* m_d, const int* n_d, const int* k_d, + int m, int n, int k, const T* const* A_array_d, const int* lda_d, const T* const* B_array_d, const int* ldb_d, double** C_array_d, const int* ldc_d, int batchCount, cudaStream_t stream, const T* alpha) { - vbatched_gemm_tn_impl - (max_m, max_n, m_d, n_d, k_d, - A_array_d, lda_d, - B_array_d, ldb_d, - C_array_d, ldc_d, - batchCount, stream, alpha); + // 4 (nw2 bracket) x 4 (nw1 bracket) = 16 instantiations. + // + // Both output axes here are the small nw axis, so we use the same + // {8,16,32,48} ladder on both. BLK_K = 32 (the bxyz axis -- large). + #define TN_DISPATCH(DX, DY, BM, BN) \ + vbatched_gemm_tn_impl( \ + m, n, k, \ + A_array_d, lda_d, B_array_d, ldb_d, \ + C_array_d, ldc_d, batchCount, stream, alpha) + + auto bracket = [](int x) { + return (x <= 8) ? 0 + : (x <= 16) ? 1 + : (x <= 32) ? 2 + : 3; + }; + const int blk_m_tag = bracket(n); // BLK_M <- nw2 + const int blk_n_tag = bracket(m); // BLK_N <- nw1 + + switch (blk_m_tag * 4 + blk_n_tag) + { + // BLK_M=8 rungs (nw2<=8). DIM_X=4, THR_M=2. + case 0: TN_DISPATCH(4, 8, 8, 8); break; // THR=2*1=2 (well under band) + case 1: TN_DISPATCH(4, 8, 8, 16); break; // THR=2*2=4 + case 2: TN_DISPATCH(4, 8, 8, 32); break; // THR=2*4=8 + case 3: TN_DISPATCH(4, 8, 8, 48); break; // THR=2*6=12 + // BLK_M=16 rungs (nw2<=16). DIM_X=4, THR_M=4. + case 4: TN_DISPATCH(4, 8, 16, 8); break; // THR=4*1=4 + case 5: TN_DISPATCH(4, 8, 16, 16); break; // THR=4*2=8 + case 6: TN_DISPATCH(4, 8, 16, 32); break; // THR=4*4=16 (in band) + case 7: TN_DISPATCH(4, 8, 16, 48); break; // THR=4*6=24 (in band) + // BLK_M=32 rungs (nw2<=32). DIM_X=8, THR_M=4. 
+ case 8: TN_DISPATCH(8, 4, 32, 8); break; // THR=4*2=8 + case 9: TN_DISPATCH(8, 4, 32, 16); break; // THR=4*4=16 (in band) + case 10: TN_DISPATCH(8, 8, 32, 32); break; // THR=4*4=16 (in band) + case 11: TN_DISPATCH(8, 8, 32, 48); break; // THR=4*6=24 (in band) + // BLK_M=48 rungs (nw2>32; nw2>48 spans several 48-wide blocks). DIM_X=8, THR_M=6. + case 12: TN_DISPATCH(8, 4, 48, 8); break; // THR=6*2=12 + case 13: TN_DISPATCH(8, 4, 48, 16); break; // THR=6*4=24 (in band) + case 14: TN_DISPATCH(8, 8, 48, 32); break; // THR=6*4=24 (in band) + case 15: TN_DISPATCH(8, 8, 48, 48); break; // THR=6*6=36 (top of band) + } + + #undef TN_DISPATCH } // Explicit instantiations template void gemm_nn_vbatch( - int, int, int, const int*, const int*, const int*, + int, int, int, const double* const*, const int*, const double* const*, const int*, double**, const int*, int, cudaStream_t, const double*); template void gemm_nn_vbatch( - int, int, int, const int*, const int*, const int*, + int, int, int, const float* const*, const int*, const float* const*, const int*, double**, const int*, int, cudaStream_t, const float*); template void gemm_tn_vbatch( - int, int, int, const int*, const int*, const int*, + int, int, int, const double* const*, const int*, const double* const*, const int*, double**, const int*, int, cudaStream_t, const double*); template void gemm_tn_vbatch( - int, int, int, const int*, const int*, const int*, + int, int, int, const float* const*, const int*, const float* const*, const int*, double**, const int*, int, cudaStream_t, const float*); diff --git a/source/source_lcao/module_gint/kernel/dgemm_vbatch.h b/source/source_lcao/module_gint/kernel/dgemm_vbatch.h index 68505c1f4e4..9d70c84348c 100644 --- a/source/source_lcao/module_gint/kernel/dgemm_vbatch.h +++ b/source/source_lcao/module_gint/kernel/dgemm_vbatch.h @@ -2,61 +2,37 @@ #include -// Template version: C(batch_id) = alpha * A(batch_id) * B(batch_id) + C(batch_id) -// As with gemm_tn_vbatch, the C accumulator is always double regardless of the 
-// input type T so the per-block reduction and device-side atomicAdd run in fp64. +// Shape-exact batched GEMM dispatchers. +// +// Every (A_i, B_i, C_i) in the batch has exactly the same (m, n, k); the +// caller (phi_operator_gpu.cu) guarantees this by bucketing atom pairs on +// (nw1, nw2) before calling in. The scalar m/n/k drive tile-ladder selection, +// grid sizing, and the kernel itself -- there is no per-batch-id M/N/K +// indirection. +// +// The C output is always double, independent of T. For T=float the per-item +// inner products accumulate in fp32, but the cross-item accumulation into a +// shared C element is done with a device-side fp64 atomicAdd (see the kernels' +// store loop), so summing many atom-pair contributions into the same +// hr_gint / phi_dm element does not drift. For T=double, A, B and C are all +// double. + +// C(batch) = alpha * A(batch) * B(batch) + C(batch) template void gemm_nn_vbatch( - int max_m, int max_n, int max_k, - const int* m_d, const int* n_d, const int* k_d, + int m, int n, int k, const T* const* A_array_d, const int* lda_d, const T* const* B_array_d, const int* ldb_d, double** C_array_d, const int* ldc_d, int batchCount, cudaStream_t stream, const T* alpha = nullptr); -// Template version: C(batch_id) = alpha * A(batch_id)^T * B(batch_id) + C(batch_id) -// The C accumulator is always double regardless of input type T: a fp32 GEMM -// path (T=float) feeds fp32 multiplies into fp64 accumulators (registers and -// device-side atomicAdds) to avoid catastrophic precision loss across many -// atom-pair contributions to the same hr_gint element. 
+// C(batch) = alpha * A(batch)^T * B(batch) + C(batch) template void gemm_tn_vbatch( - int max_m, int max_n, int max_k, - const int* m_d, const int* n_d, const int* k_d, + int m, int n, int k, const T* const* A_array_d, const int* lda_d, const T* const* B_array_d, const int* ldb_d, double** C_array_d, const int* ldc_d, int batchCount, cudaStream_t stream, const T* alpha = nullptr); - -// Legacy double-only aliases for backward compatibility -inline void dgemm_nn_vbatch( - int max_m, int max_n, int max_k, - const int* m_d, const int* n_d, const int* k_d, - const double* const* A_array_d, const int* lda_d, - const double* const* B_array_d, const int* ldb_d, - double** C_array_d, const int* ldc_d, - int batchCount, cudaStream_t stream, - const double* alpha = nullptr) -{ - gemm_nn_vbatch(max_m, max_n, max_k, - m_d, n_d, k_d, A_array_d, lda_d, B_array_d, ldb_d, - C_array_d, ldc_d, batchCount, stream, alpha); -} - -inline void dgemm_tn_vbatch( - int max_m, int max_n, int max_k, - const int* m_d, const int* n_d, const int* k_d, - const double* const* A_array_d, const int* lda_d, - const double* const* B_array_d, const int* ldb_d, - double** C_array_d, const int* ldc_d, - int batchCount, cudaStream_t stream, - const double* alpha = nullptr) -{ - // T=double path: A, B, and C are all double — the C-channel double-fix - // matches the legacy signature here. 
- gemm_tn_vbatch(max_m, max_n, max_k, - m_d, n_d, k_d, A_array_d, lda_d, B_array_d, ldb_d, - C_array_d, ldc_d, batchCount, stream, alpha); -} \ No newline at end of file diff --git a/source/source_lcao/module_gint/kernel/gemm_nn_vbatch.cuh b/source/source_lcao/module_gint/kernel/gemm_nn_vbatch.cuh index 39e10e4252d..0ae74f3d9ae 100644 --- a/source/source_lcao/module_gint/kernel/gemm_nn_vbatch.cuh +++ b/source/source_lcao/module_gint/kernel/gemm_nn_vbatch.cuh @@ -12,7 +12,18 @@ #include "source_base/module_device/device_check.h" #include "source_base/module_device/kernel_compat.h" -#define sA(i, j) sA[(j)*slda + (i)] +// Shared-memory tile layout (K-inner): both operands store the contraction +// (K) axis contiguously so the inner loop can read VK consecutive K elements +// with a single 16-byte vector load instead of VK scalar loads. +// sA(m, k) = sA[m * slda + k] -- M indexes the row, K is contiguous +// sB(k, n) = sB[n * sldb + k] -- N indexes the column, K is contiguous +// slda = sldb = BLK_K + PAD +// PAD (from gemm_vec_traits: +4 for FP32, +2 for FP64) keeps the K-stride a +// whole number of 16-byte words so the vector loads stay aligned, and offsets +// the per-warp access so the strided shared-memory reads spread across banks. +// The widest FP64 tiles (DIM_X=16) still take a few bank conflicts on the sA +// read, but that is a net win over the scalar layout this replaces. +#define sA(i, j) sA[(i)*slda + (j)] #define sB(i, j) sB[(j)*sldb + (i)] #define fetch(A, m, n, bound) offs_d##A[min(n * LD##A + m, bound)] @@ -43,6 +54,30 @@ static __device__ void vbatched_gemm_nn_device(int M, int sldb, T alpha) { + using vec_t = typename gemm_vec_traits::vec_t; + constexpr int VK = gemm_vec_traits::VK; + + // BLK_K must be a whole number of VK chunks so the vectorized FMA loop + // below covers it exactly; the 16-byte alignment of the K-stride is + // enforced separately at kernel scope. 
+ static_assert(BLK_K % VK == 0, + "BLK_K must be divisible by VK (16 / sizeof(T))"); + + // Tile divisibility: every dev->shmem load loop assumes BLK_* is an exact + // multiple of the matching DIM_*, and the per-thread fan-out THR_M/THR_N is + // BLK_M/BLK_N divided by DIM_X/DIM_Y. A mis-specified instantiation would + // silently load garbage, so surface it at compile time. + static_assert(BLK_M % DIM_X == 0, "BLK_M must be divisible by DIM_X"); + static_assert(BLK_N % DIM_Y == 0, "BLK_N must be divisible by DIM_Y"); + static_assert(BLK_M % DIM_XA == 0, "BLK_M must be divisible by DIM_XA"); + static_assert(BLK_K % DIM_YA == 0, "BLK_K must be divisible by DIM_YA"); + static_assert(BLK_K % DIM_XB == 0, "BLK_K must be divisible by DIM_XB"); + static_assert(BLK_N % DIM_YB == 0, "BLK_N must be divisible by DIM_YB"); + static_assert(DIM_XA * DIM_YA == DIM_X * DIM_Y, + "A-loader thread grid must cover the whole block"); + static_assert(DIM_XB * DIM_YB == DIM_X * DIM_Y, + "B-loader thread grid must cover the whole block"); + int idx = threadIdx.x; // thread's m dimension int idy = threadIdx.y; // thread's n dimension @@ -57,13 +92,15 @@ static __device__ void vbatched_gemm_nn_device(int M, int blx = blockIdx.x; // block's m dimension int bly = blockIdx.y; // block's n dimension - // Registers for the innermost loop. rC accumulates in T; the widening to + // Accumulator tile (registers). rC accumulates in T; the widening to // double happens only at the final atomicAdd into C. T rC[THR_N][THR_M]; - T rA[THR_M]; - T rB[THR_N]; - // Registers for the dev->shmem copy + // Per-VK-step shmem->reg tiles. One load feeds VK FMAs per (m,n). + T rA[THR_M][VK]; + T rB[THR_N][VK]; + + // Registers for the dev->shmem copy (next-K-tile prefetch). 
T ra[BLK_K / DIM_YA][BLK_M / DIM_XA]; T rb[BLK_N / DIM_YB][BLK_K / DIM_XB]; @@ -86,7 +123,7 @@ static __device__ void vbatched_gemm_nn_device(int M, #pragma unroll for (m = 0; m < THR_M; m++) { - rC[n][m] = T(0); + rC[n][m] = 0.0; } } @@ -143,32 +180,44 @@ static __device__ void vbatched_gemm_nn_device(int M, } } -// Multiply +// Wide-load FMA: one vector load feeds VK FMAs per (m, n). +// FP32: float4 -> 4 FMAs per (m,n) per inner step +// FP64: double2 -> 2 FMAs per (m,n) per inner step +// Relies on the K-stride being 16-byte aligned and BLK_K being a whole +// number of VK chunks (static_asserts above). #pragma unroll - for (k = 0; k < BLK_K; k++) + for (k = 0; k < BLK_K; k += VK) { // Load A shmem->regs #pragma unroll for (m = 0; m < THR_M; m++) { - rA[m] = sA(m * DIM_X + idx, k); + vec_t va = *reinterpret_cast( + &sA(m * DIM_X + idx, k)); + gemm_vec_traits::unpack(va, rA[m]); } // Load B shmem->regs #pragma unroll for (n = 0; n < THR_N; n++) { - rB[n] = sB(k, n * DIM_Y + idy); + vec_t vb = *reinterpret_cast( + &sB(k, n * DIM_Y + idy)); + gemm_vec_traits::unpack(vb, rB[n]); } -// Compute +// Compute (VK fan-out per (m,n)). #pragma unroll - for (n = 0; n < THR_N; n++) + for (int kv = 0; kv < VK; kv++) { #pragma unroll - for (m = 0; m < THR_M; m++) + for (n = 0; n < THR_N; n++) { - rC[n][m] += rA[m] * rB[n]; +#pragma unroll + for (m = 0; m < THR_M; m++) + { + rC[n][m] += rA[m][kv] * rB[n][kv]; + } } } } @@ -199,36 +248,37 @@ static __device__ void vbatched_gemm_nn_device(int M, __syncthreads(); } - // Multiply last full (BLK_K) or partial block of - // columns of op(A) and rows of op(B). + // Tail: the leftover K columns after the BLK_K-strided main loop (here K is + // the contraction length, nw1). The remainder is generally not a multiple + // of VK, so it runs with scalar shared-memory reads instead of the vector + // load. 
// It's okay that m,n exceed matrix bounds as all work is in registers // or shared memory, and out-of-bounds rC[n][m] will not be saved later. kk = K - kk; #pragma unroll for (k = 0; k < kk; k++) { -// Load A shmem->regs + T rA_s[THR_M]; + T rB_s[THR_N]; #pragma unroll for (m = 0; m < THR_M; m++) { - rA[m] = sA(m * DIM_X + idx, k); + rA_s[m] = sA(m * DIM_X + idx, k); } -// Load B shmem->regs #pragma unroll for (n = 0; n < THR_N; n++) { - rB[n] = sB(k, n * DIM_Y + idy); + rB_s[n] = sB(k, n * DIM_Y + idy); } -// Compute #pragma unroll for (n = 0; n < THR_N; n++) { #pragma unroll for (m = 0; m < THR_M; m++) { - rC[n][m] += rA[m] * rB[n]; + rC[n][m] += rA_s[m] * rB_s[n]; } } } @@ -263,9 +313,9 @@ template -static __global__ void vbatched_gemm_nn_kernel(const int* M, - const int* N, - const int* K, +static __global__ void vbatched_gemm_nn_kernel(int M, + int N, + int K, const T* const* global_A_array, const int* global_lda, const T* const* global_B_array, @@ -274,23 +324,25 @@ static __global__ void vbatched_gemm_nn_kernel(const int* M, const int* global_ldc, const T* alpha) { - extern __shared__ __align__(sizeof(double)) unsigned char smem[]; + // 16-byte align for vec_t (double2 / float4) loads. 
+ extern __shared__ __align__(16) unsigned char smem[]; T* shared_mem = reinterpret_cast(smem); int batchid = blockIdx.z; - int local_M = (int)M[batchid]; - int local_N = (int)N[batchid]; - int local_K = (int)K[batchid]; - - if (blockIdx.x >= (local_M + BLK_M - 1) / BLK_M) - return; - if (blockIdx.y >= (local_N + BLK_N - 1) / BLK_N) - return; - int shared_lda = BLK_M + 1; - int shared_ldb = BLK_K + 1; + constexpr int PAD = gemm_vec_traits::PAD; + static_assert(((BLK_K + PAD) * sizeof(T)) % 16 == 0, + "shmem K-stride * sizeof(T) must be 16-byte aligned for " + "LDS.{64,128}"); + static_assert(BLK_K % gemm_vec_traits::VK == 0, + "BLK_K must be divisible by VK = 16 / sizeof(T)"); + + // K-inner layout: slda/sldb are the K-axis stride (BLK_K + PAD) for sA + // (BLK_M rows) and sB (BLK_N columns) respectively. + int shared_lda = BLK_K + PAD; + int shared_ldb = BLK_K + PAD; T* shared_A = (T*)shared_mem; - T* shared_B = shared_A + shared_lda * BLK_K; + T* shared_B = shared_A + BLK_M * shared_lda; T alpha_tmp = T(1.0); if (alpha != nullptr) { @@ -307,9 +359,9 @@ static __global__ void vbatched_gemm_nn_kernel(const int* M, DIM_XB, DIM_YB, (BLK_M / DIM_X), - (BLK_N / DIM_Y)>(local_M, - local_N, - local_K, + (BLK_N / DIM_Y)>(M, + N, + K, global_A_array[batchid], (int)global_lda[batchid], global_B_array[batchid], @@ -343,12 +395,9 @@ static __global__ void vbatched_gemm_nn_kernel(const int* M, * matrix B. * @tparam DIM_YB The number of threads in the y-dimension used for loading * matrix B. - * @param max_m The maximum number of rows in the matrices. - * @param max_n The maximum number of columns in the matrices. - * @param m An array of batch sizes for the number of rows in each matrix. - * @param n An array of batch sizes for the number of columns in each matrix. - * @param k An array of batch sizes for the number of elements in each matrix - * along the K dimension. + * @param m The number of rows in each matrix (same across the batch). 
+ * @param n The number of columns in each matrix (same across the batch). + * @param k The number of elements along the K dimension (same across the batch). * @param global_A_array An array of pointers to the input matrices A. * @param global_lda An array of leading dimensions for the input matrices A. * @param global_B_array An array of pointers to the input matrices B. @@ -358,7 +407,7 @@ static __global__ void vbatched_gemm_nn_kernel(const int* M, * @param batchCount The number of matrices in the batch. * @param stream The CUDA stream to use for the computation. * @param alpha The scalar value to multiply the matrices by (optional, default - * is nullptr). generate by copilot + * is nullptr). */ template -void vbatched_gemm_nn_impl(int max_m, - int max_n, - const int* m, - const int* n, - const int* k, +void vbatched_gemm_nn_impl(int m, + int n, + int k, const T* const* global_A_array, const int* global_lda, const T* const* global_B_array, @@ -389,17 +436,21 @@ void vbatched_gemm_nn_impl(int max_m, // This is because vbatch_gemm_nn_kernel is column major, // but vatched_gemm_nn_impl is designed to be row major, + // K-inner shared-memory footprint: + // sA: BLK_M rows of (BLK_K + PAD) elements + // sB: BLK_N cols of (BLK_K + PAD) elements + constexpr int PAD = gemm_vec_traits::PAD; size_t shared_mem_size = 0; - shared_mem_size += (BLK_M + 1) * BLK_K * sizeof(T); - shared_mem_size += (BLK_K + 1) * BLK_N * sizeof(T); + shared_mem_size += BLK_M * (BLK_K + PAD) * sizeof(T); + shared_mem_size += BLK_N * (BLK_K + PAD) * sizeof(T); dim3 dimBlock(DIM_X, DIM_Y); const int max_batch_count = 32768; for (int i = 0; i < batchCount; i += max_batch_count) { const int ibatch = min(max_batch_count, batchCount - i); - dim3 dimGrid(ceil_div(max_n, BLK_M), - ceil_div(max_m, BLK_N), + dim3 dimGrid(ceil_div(n, BLK_M), + ceil_div(m, BLK_N), ibatch); const T* alpha_tmp = nullptr; if (alpha != nullptr) @@ -418,7 +469,7 @@ void vbatched_gemm_nn_impl(int max_m, DIM_XB, DIM_YB> <<>>( - n 
+ i, m + i, k + i, + n, m, k, global_B_array + i, global_ldb + i, global_A_array + i, global_lda + i, global_C_array + i, global_ldc + i, diff --git a/source/source_lcao/module_gint/kernel/gemm_tn_vbatch.cuh b/source/source_lcao/module_gint/kernel/gemm_tn_vbatch.cuh index ae6a4a00551..b846bbf50b2 100644 --- a/source/source_lcao/module_gint/kernel/gemm_tn_vbatch.cuh +++ b/source/source_lcao/module_gint/kernel/gemm_tn_vbatch.cuh @@ -12,7 +12,16 @@ #include "source_base/module_device/device_check.h" #include "source_base/module_device/kernel_compat.h" -#define sA(i, j) sA[(j)*slda + (i)] +// Shared-memory tile layout (K-inner), identical to gemm_nn_vbatch.cuh: +// sA(m, k) = sA[m * slda + k] -- M indexes the row, K is contiguous +// sB(k, n) = sB[n * sldb + k] -- N indexes the column, K is contiguous +// slda = sldb = BLK_K + PAD +// PAD (from gemm_vec_traits: +4 for FP32, +2 for FP64) keeps the K-stride a +// whole number of 16-byte words for the vector loads and spreads the strided +// reads across banks. See gemm_nn_vbatch.cuh for the layout rationale; the TN +// inner loop uses the same access pattern, the only difference being how the +// dev->shmem load loop for sB is indexed. +#define sA(i, j) sA[(i)*slda + (j)] #define sB(i, j) sB[(j)*sldb + (i)] #define fetch(A, m, n, bound) offs_d##A[min(n * LD##A + m, bound)] @@ -43,6 +52,26 @@ static __device__ void vbatched_gemm_nt_device(int M, int sldb, T alpha) { + using vec_t = typename gemm_vec_traits::vec_t; + constexpr int VK = gemm_vec_traits::VK; + + static_assert(BLK_K % VK == 0, + "BLK_K must be divisible by VK (16 / sizeof(T))"); + + // Tile divisibility: same constraints as gemm_nn_vbatch. The TN sB load + // loop traverses (BLK_K rows x BLK_N cols), so the DIM_XB / DIM_YB + // divisibility checks are mirrored accordingly. 
+ static_assert(BLK_M % DIM_X == 0, "BLK_M must be divisible by DIM_X"); + static_assert(BLK_N % DIM_Y == 0, "BLK_N must be divisible by DIM_Y"); + static_assert(BLK_M % DIM_XA == 0, "BLK_M must be divisible by DIM_XA"); + static_assert(BLK_K % DIM_YA == 0, "BLK_K must be divisible by DIM_YA"); + static_assert(BLK_N % DIM_XB == 0, "BLK_N must be divisible by DIM_XB"); + static_assert(BLK_K % DIM_YB == 0, "BLK_K must be divisible by DIM_YB"); + static_assert(DIM_XA * DIM_YA == DIM_X * DIM_Y, + "A-loader thread grid must cover the whole block"); + static_assert(DIM_XB * DIM_YB == DIM_X * DIM_Y, + "B-loader thread grid must cover the whole block"); + int idx = threadIdx.x; // thread's m dimension int idy = threadIdx.y; // thread's n dimension @@ -57,13 +86,15 @@ static __device__ void vbatched_gemm_nt_device(int M, int blx = blockIdx.x; // block's m dimension int bly = blockIdx.y; // block's n dimension - // Registers for the innermost loop. rC accumulates in T; the widening to + // Accumulator tile (registers). rC accumulates in T; the widening to // double happens only at the final atomicAdd into C. T rC[THR_N][THR_M]; - T rA[THR_M]; - T rB[THR_N]; - // Registers for the dev->shmem copy + // Per-VK-step shmem->reg tiles. One load feeds VK FMAs per (m,n). + T rA[THR_M][VK]; + T rB[THR_N][VK]; + + // Registers for the dev->shmem copy (next-K-tile prefetch). T ra[BLK_K / DIM_YA][BLK_M / DIM_XA]; T rb[BLK_K / DIM_YB][BLK_N / DIM_XB]; @@ -86,7 +117,7 @@ static __device__ void vbatched_gemm_nt_device(int M, #pragma unroll for (m = 0; m < THR_M; m++) { - rC[n][m] = T(0); + rC[n][m] = 0.0; } } @@ -143,32 +174,42 @@ static __device__ void vbatched_gemm_nt_device(int M, } } -// Multiply +// Wide-load FMA: one vector load feeds VK FMAs per (m, n). 
+// FP32: float4 -> 4 FMAs per (m,n) per inner step +// FP64: double2 -> 2 FMAs per (m,n) per inner step #pragma unroll - for (k = 0; k < BLK_K; k++) + for (k = 0; k < BLK_K; k += VK) { // Load A shmem->regs #pragma unroll for (m = 0; m < THR_M; m++) { - rA[m] = sA(m * DIM_X + idx, k); + vec_t va = *reinterpret_cast( + &sA(m * DIM_X + idx, k)); + gemm_vec_traits::unpack(va, rA[m]); } // Load B shmem->regs #pragma unroll for (n = 0; n < THR_N; n++) { - rB[n] = sB(k, n * DIM_Y + idy); + vec_t vb = *reinterpret_cast( + &sB(k, n * DIM_Y + idy)); + gemm_vec_traits::unpack(vb, rB[n]); } -// Compute +// Compute (VK fan-out per (m,n)). #pragma unroll - for (n = 0; n < THR_N; n++) + for (int kv = 0; kv < VK; kv++) { #pragma unroll - for (m = 0; m < THR_M; m++) + for (n = 0; n < THR_N; n++) { - rC[n][m] += rA[m] * rB[n]; +#pragma unroll + for (m = 0; m < THR_M; m++) + { + rC[n][m] += rA[m][kv] * rB[n][kv]; + } } } } @@ -199,36 +240,37 @@ static __device__ void vbatched_gemm_nt_device(int M, __syncthreads(); } - // Multiply last full (BLK_K) or partial block of - // columns of op(A) and rows of op(B). + // Tail: the leftover K columns after the BLK_K-strided main loop (here K is + // the contraction length, the mesh-grid count bxyz). The remainder is + // generally not a multiple of VK, so it runs with scalar shared-memory + // reads instead of the vector load. // It's okay that m,n exceed matrix bounds as all work is in registers // or shared memory, and out-of-bounds rC[n][m] will not be saved later. 
kk = K - kk; #pragma unroll for (k = 0; k < kk; k++) { -// Load A shmem->regs + T rA_s[THR_M]; + T rB_s[THR_N]; #pragma unroll for (m = 0; m < THR_M; m++) { - rA[m] = sA(m * DIM_X + idx, k); + rA_s[m] = sA(m * DIM_X + idx, k); } -// Load B shmem->regs #pragma unroll for (n = 0; n < THR_N; n++) { - rB[n] = sB(k, n * DIM_Y + idy); + rB_s[n] = sB(k, n * DIM_Y + idy); } -// Compute #pragma unroll for (n = 0; n < THR_N; n++) { #pragma unroll for (m = 0; m < THR_M; m++) { - rC[n][m] += rA[m] * rB[n]; + rC[n][m] += rA_s[m] * rB_s[n]; } } } @@ -263,9 +305,9 @@ template -static __global__ void vbatched_gemm_nt_kernel(const int* M, - const int* N, - const int* K, +static __global__ void vbatched_gemm_nt_kernel(int M, + int N, + int K, const T* const* global_A_array, const int* global_lda, const T* const* global_B_array, @@ -274,23 +316,25 @@ static __global__ void vbatched_gemm_nt_kernel(const int* M, const int* global_ldc, const T* alpha) { - extern __shared__ __align__(sizeof(double)) unsigned char smem[]; + // 16-byte align for vec_t (double2 / float4) loads. + extern __shared__ __align__(16) unsigned char smem[]; T* shared_mem = reinterpret_cast(smem); int batchid = blockIdx.z; - int local_M = (int)M[batchid]; - int local_N = (int)N[batchid]; - int local_K = (int)K[batchid]; - - if (blockIdx.x >= (local_M + BLK_M - 1) / BLK_M) - return; - if (blockIdx.y >= (local_N + BLK_N - 1) / BLK_N) - return; - int shared_lda = BLK_M + 1; - int shared_ldb = BLK_K + 1; + constexpr int PAD = gemm_vec_traits::PAD; + static_assert(((BLK_K + PAD) * sizeof(T)) % 16 == 0, + "shmem K-stride * sizeof(T) must be 16-byte aligned for " + "LDS.{64,128}"); + static_assert(BLK_K % gemm_vec_traits::VK == 0, + "BLK_K must be divisible by VK = 16 / sizeof(T)"); + + // K-inner layout: slda/sldb are the K-axis stride (BLK_K + PAD) for sA + // (BLK_M rows) and sB (BLK_N columns) respectively. 
+ int shared_lda = BLK_K + PAD; + int shared_ldb = BLK_K + PAD; T* shared_A = (T*)shared_mem; - T* shared_B = shared_A + shared_lda * BLK_K; + T* shared_B = shared_A + BLK_M * shared_lda; T alpha_tmp = T(1.0); if (alpha != nullptr) { @@ -307,9 +351,9 @@ static __global__ void vbatched_gemm_nt_kernel(const int* M, DIM_XB, DIM_YB, (BLK_M / DIM_X), - (BLK_N / DIM_Y)>(local_M, - local_N, - local_K, + (BLK_N / DIM_Y)>(M, + N, + K, global_A_array[batchid], (int)global_lda[batchid], global_B_array[batchid], @@ -343,12 +387,9 @@ static __global__ void vbatched_gemm_nt_kernel(const int* M, * matrix B. * @tparam DIM_YB The number of threads in the y-dimension used for loading * matrix B. - * @param max_m The maximum number of rows in the matrices. - * @param max_n The maximum number of columns in the matrices. - * @param m An array of batch sizes for the number of rows in each matrix. - * @param n An array of batch sizes for the number of columns in each matrix. - * @param k An array of batch sizes for the number of elements in each matrix - * along the K dimension. + * @param m The number of rows in each matrix (same across the batch). + * @param n The number of columns in each matrix (same across the batch). + * @param k The number of elements along the K dimension (same across the batch). * @param global_A_array An array of pointers to the input matrices A. * @param global_lda An array of leading dimensions for the input matrices A. * @param global_B_array An array of pointers to the input matrices B. @@ -358,7 +399,7 @@ static __global__ void vbatched_gemm_nt_kernel(const int* M, * @param batchCount The number of matrices in the batch. * @param stream The CUDA stream to use for the computation. * @param alpha The scalar value to multiply the matrices by (optional, default - * is nullptr). generate by copilot + * is nullptr). 
*/ /* @@ -395,11 +436,9 @@ template -void vbatched_gemm_tn_impl(int max_m, - int max_n, - const int* m, - const int* n, - const int* k, +void vbatched_gemm_tn_impl(int m, + int n, + int k, const T* const* global_A_array, const int* global_lda, const T* const* global_B_array, @@ -414,17 +453,21 @@ void vbatched_gemm_tn_impl(int max_m, // This is because vbatch_gemm__tn_kernel is column major, // but vatched_gemm_nt_impl is designed to be row major, + // K-inner shared-memory footprint (matches vbatched_gemm_nn_impl): + // sA: BLK_M rows of (BLK_K + PAD) elements + // sB: BLK_N cols of (BLK_K + PAD) elements + constexpr int PAD = gemm_vec_traits::PAD; size_t shared_mem_size = 0; - shared_mem_size += (BLK_M + 1) * BLK_K * sizeof(T); - shared_mem_size += (BLK_K + 1) * BLK_N * sizeof(T); + shared_mem_size += BLK_M * (BLK_K + PAD) * sizeof(T); + shared_mem_size += BLK_N * (BLK_K + PAD) * sizeof(T); dim3 dimBlock(DIM_X, DIM_Y); const int max_batch_count = 32768; for (int i = 0; i < batchCount; i += max_batch_count) { const int ibatch = min(max_batch_count, batchCount - i); - dim3 dimGrid(ceil_div(max_n, BLK_M), - ceil_div(max_m, BLK_N), + dim3 dimGrid(ceil_div(n, BLK_M), + ceil_div(m, BLK_N), ibatch); const T* alpha_tmp = nullptr; if (alpha != nullptr) @@ -443,7 +486,7 @@ void vbatched_gemm_tn_impl(int max_m, DIM_XB, DIM_YB> <<>>( - n + i, m + i, k + i, + n, m, k, global_B_array + i, global_ldb + i, global_A_array + i, global_lda + i, global_C_array + i, global_ldc + i, diff --git a/source/source_lcao/module_gint/kernel/gint_helper.cuh b/source/source_lcao/module_gint/kernel/gint_helper.cuh index eae5953654b..1f6d431003d 100644 --- a/source/source_lcao/module_gint/kernel/gint_helper.cuh +++ b/source/source_lcao/module_gint/kernel/gint_helper.cuh @@ -43,3 +43,45 @@ inline int ceil_div(const int a, const int b) return a / b + (a % b != 0 && (a ^ b) > 0); } +// --------------------------------------------------------------------------- +// gemm_vec_traits -- the wide-load 
primitive used by the GEMM inner loop, +// which reads VK consecutive K elements per shared-memory load instead of one. +// +// VK = number of T elements in one 16-byte load (4 for FP32, 2 for FP64) +// vec_t = the 16-byte vector type used for that load (float4 / double2) +// PAD = padding added to the shared-memory K-stride so that (BLK_K + PAD) +// elements span a whole number of 16-byte words, keeping the +// vectorized shared-memory loads aligned and spreading the warp's +// strided reads across banks. +// +// The load is one *reinterpret_cast(&sA(m, k)); unpack() then fans the +// vector out into the per-thread registers. The explicit per-component copy is +// deliberate: nvcc reliably vectorizes float4 but not double2, and writing +// .x/.y(/.z/.w) by hand guarantees the wide load instruction is emitted. +// --------------------------------------------------------------------------- +template struct gemm_vec_traits; + +template <> struct gemm_vec_traits +{ + using vec_t = float4; + static constexpr int VK = 4; + static constexpr int PAD = 4; + __forceinline__ __device__ + static void unpack(const vec_t& v, float* d) + { + d[0] = v.x; d[1] = v.y; d[2] = v.z; d[3] = v.w; + } +}; + +template <> struct gemm_vec_traits +{ + using vec_t = double2; + static constexpr int VK = 2; + static constexpr int PAD = 2; + __forceinline__ __device__ + static void unpack(const vec_t& v, double* d) + { + d[0] = v.x; d[1] = v.y; + } +}; + diff --git a/source/source_lcao/module_gint/kernel/phi_operator_gpu.cu b/source/source_lcao/module_gint/kernel/phi_operator_gpu.cu index 7021f44e3fe..60e193e16a5 100644 --- a/source/source_lcao/module_gint/kernel/phi_operator_gpu.cu +++ b/source/source_lcao/module_gint/kernel/phi_operator_gpu.cu @@ -2,7 +2,10 @@ #include "phi_operator_kernel.cuh" #include "dgemm_vbatch.h" #include +#include +#include #include +#include #include "source_base/module_device/device_check.h" namespace ModuleGint @@ -18,11 +21,8 @@ 
bgrid_phi_start_(BatchBigGrid::get_max_batch_size(), stream_, true), atoms_iat_(BatchBigGrid::get_max_atoms_num(), stream_, true), atoms_bgrids_rcoords_(BatchBigGrid::get_max_atoms_num(), stream_, true), atom_phi_start_(BatchBigGrid::get_max_atoms_num(), stream_, true), -batch_mgrid_lidx_(BatchBigGrid::get_max_batch_size() +batch_mgrid_lidx_(BatchBigGrid::get_max_batch_size() * BatchBigGrid::get_bgrid_info()->get_mgrids_num(), stream_, true), -gemm_m_(BatchBigGrid::get_max_atom_pairs_num(), stream_, true), -gemm_n_(BatchBigGrid::get_max_atom_pairs_num(), stream_, true), -gemm_k_(BatchBigGrid::get_max_atom_pairs_num(), stream_, true), gemm_lda_(BatchBigGrid::get_max_atom_pairs_num(), stream_, true), gemm_ldb_(BatchBigGrid::get_max_atom_pairs_num(), stream_, true), gemm_ldc_(BatchBigGrid::get_max_atom_pairs_num(), stream_, true), @@ -32,6 +32,15 @@ gemm_C_(BatchBigGrid::get_max_atom_pairs_num(), stream_, true), gemm_alpha_(BatchBigGrid::get_max_atom_pairs_num(), stream_, true) { CHECK_CUDA(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming)); + // nwmax is the largest per-atom orbital count in the cell, so a stride of + // nwmax + 1 lets every valid (nw1, nw2) flatten to a distinct bucket key + // with no artificial cap. Allocate the bucketing scratch once here; the hot + // path only re-zeroes it. + nw_stride_ = gint_gpu_vars_->nwmax + 1; + bucket_counts_.assign(nw_stride_ * nw_stride_, 0); + bucket_base_.assign(nw_stride_ * nw_stride_, 0); + bucket_cursor_.assign(nw_stride_ * nw_stride_, 0); + buckets_.reserve(32); } template @@ -90,6 +99,49 @@ void PhiOperatorGpu::set_bgrid_batch(std::shared_ptr bgrid_b atom_phi_start_.copy_host_to_device_async(bgrid_batch->get_atoms_num()); batch_mgrid_lidx_.copy_host_to_device_async(bgrid_batch->get_batch_size() * mgrids_num_); CHECK_CUDA(cudaEventRecord(event_, stream_)); + + // Pre-enumerate every (ia_1, ia_2) pair so phi_mul_phi / phi_mul_dm can + // skip this O(sum atoms_i^2) walk on every call. 
HContainer lookups and + // per-bgrid filters stay in the hot path since they depend on the + // specific caller (hRGint vs dm, symmetric vs not). + pair_cache_.clear(); + for (int i = 0; i < bgrid_batch->get_batch_size(); ++i) + { + const auto& bgrid = bgrid_batch->get_bgrids()[i]; + const int pre_atoms = atoms_num_info_h[i].y; + const int atoms_num = bgrid->get_atoms_num(); + const int phi_len_mgrid = bgrid->get_phi_len(); + const auto& atoms = bgrid->get_atoms(); + for (int ia_1 = 0; ia_1 < atoms_num; ++ia_1) + { + const auto& atom_1 = atoms[ia_1]; + for (int ia_2 = 0; ia_2 < atoms_num; ++ia_2) + { + const auto& atom_2 = atoms[ia_2]; + PairInfo p; + p.phi_1_offset = atom_phi_start_h[pre_atoms + ia_1]; + p.phi_2_offset = atom_phi_start_h[pre_atoms + ia_2]; + p.phi_len_mgrid = phi_len_mgrid; + p.iat_1 = atom_1->get_iat(); + p.iat_2 = atom_2->get_iat(); + p.r_diff = atom_1->get_R() - atom_2->get_R(); + p.nw1 = static_cast(atom_1->get_nw()); + p.nw2 = static_cast(atom_2->get_nw()); + // The shape key (nw1 * nw_stride_ + nw2) indexes a dense + // nw_stride_^2 table in phi_mul_phi / phi_mul_dm. nw_stride_ = + // nwmax + 1 and nwmax is by construction the largest per-atom nw + // in the cell, so this can only trip on an upstream + // inconsistency rather than an undersized cap. + assert(p.nw1 < nw_stride_ && p.nw2 < nw_stride_); + p.ia_le = static_cast(ia_1 <= ia_2); + p.is_diag = static_cast(ia_1 == ia_2); + pair_cache_.push_back(p); + } + } + } + // Sized to match pair_cache_; Pass 1 of phi_mul_phi / phi_mul_dm overwrites + // every entry before Pass 2 reads it, so no initial value is needed. + pair_scratch_offset_.resize(pair_cache_.size()); } template @@ -234,57 +286,76 @@ void PhiOperatorGpu::phi_mul_phi( HContainer& hRGint, double* hr_d) const { - // ap_num means number of atom pairs + // Shape-exact bucketing: group atom pairs by (nw1, nw2). K = mgrids_num_ + // is already batch-wide constant, so (nw1, nw2) fully determines the GEMM + // shape. 
Each bucket hands gemm_tn_vbatch scalar (nw1, nw2, mgrids_num_), + // so the tile ladder picks the tightest tile for every shape and the + // wrapper sizes the grid exactly -- no cross-shape tile waste, no + // over-launched blocks. + // + // Algorithm: counting-sort-style two-pass over the pre-enumerated + // pair_cache_ populated in set_bgrid_batch(). + // Pass 1: HContainer lookup -> stash hr_offset, count items per shape. + // Prefix sum: build the list of non-empty buckets + their flat offsets. + // Pass 2: scatter A/B/C pointers + lda/ldb/ldc into the flat host arrays + // at each bucket's slot, then one H2D copy per array and one + // vbatch launch per bucket. (m/n/k arrays are no longer + // built or copied -- each bucket's shape reaches the + // wrapper as three scalars.) + + auto& counts = bucket_counts_; + std::fill(counts.begin(), counts.end(), 0); + + // Pass 1: filter + HContainer lookup + per-shape count. + for (size_t i = 0; i < pair_cache_.size(); ++i) + { + const auto& p = pair_cache_[i]; + if (p.iat_1 > p.iat_2) { pair_scratch_offset_[i] = -1; continue; } + const int hr = hRGint.find_matrix_offset(p.iat_1, p.iat_2, p.r_diff); + pair_scratch_offset_[i] = hr; + if (hr == -1) { continue; } + counts[p.nw1 * nw_stride_ + p.nw2]++; + } + + // Prefix sum over dense keys -> compact bucket list. 
+ auto& buckets = buckets_; + buckets.clear(); + auto& key_to_base = bucket_base_; + std::fill(key_to_base.begin(), key_to_base.end(), 0); int ap_num = 0; - int max_m = 0; - int max_n = 0; - int max_k = mgrids_num_; - CHECK_CUDA(cudaEventSynchronize(event_)); - for (int i = 0; i < bgrid_batch_->get_batch_size(); i++) + for (int k = 0; k < nw_stride_ * nw_stride_; ++k) { - auto bgrid = bgrid_batch_->get_bgrids()[i]; - // the length of phi on a mesh grid - const int phi_len_mgrid = bgrid->get_phi_len(); - const int pre_atoms = atoms_num_info_.get_host_ptr()[i].y; - for (int ia_1 = 0; ia_1 < bgrid->get_atoms_num(); ia_1++) - { - auto atom_1 = bgrid->get_atoms()[ia_1]; - const int iat_1 = atom_1->get_iat(); - const auto& r_1 = atom_1->get_R(); - const int nw1 = atom_1->get_nw(); - const int phi_1_offset = atom_phi_start_.get_host_ptr()[pre_atoms + ia_1]; + if (counts[k] == 0) { continue; } + buckets.push_back({k, ap_num, counts[k]}); + key_to_base[k] = ap_num; + ap_num += counts[k]; + } - for (int ia_2 = 0; ia_2 < bgrid->get_atoms_num(); ia_2++) - { - auto atom_2 = bgrid->get_atoms()[ia_2]; - const int iat_2 = atom_2->get_iat(); - const auto& r_2 = atom_2->get_R(); - const int nw2 = atom_2->get_nw(); - - if(iat_1 > iat_2) - { continue; } - - int hr_offset = hRGint.find_matrix_offset(iat_1, iat_2, r_1 - r_2); - if (hr_offset == -1) - { continue; } - - const int phi_2_offset = atom_phi_start_.get_host_ptr()[pre_atoms + ia_2]; - - gemm_A_.get_host_ptr()[ap_num] = phi_d + phi_1_offset; - gemm_B_.get_host_ptr()[ap_num] = phi_vldr3_d + phi_2_offset; - gemm_C_.get_host_ptr()[ap_num] = hr_d + hr_offset; - gemm_lda_.get_host_ptr()[ap_num] = phi_len_mgrid; - gemm_ldb_.get_host_ptr()[ap_num] = phi_len_mgrid; - gemm_ldc_.get_host_ptr()[ap_num] = nw2; - gemm_m_.get_host_ptr()[ap_num] = nw1; - gemm_n_.get_host_ptr()[ap_num] = nw2; - gemm_k_.get_host_ptr()[ap_num] = bgrid->get_mgrids_num(); - ap_num++; - - max_m = std::max(max_m, nw1); - max_n = std::max(max_n, nw2); - } - } + auto* 
h_A = gemm_A_.get_host_ptr(); + auto* h_B = gemm_B_.get_host_ptr(); + auto* h_C = gemm_C_.get_host_ptr(); + auto* h_lda = gemm_lda_.get_host_ptr(); + auto* h_ldb = gemm_ldb_.get_host_ptr(); + auto* h_ldc = gemm_ldc_.get_host_ptr(); + + CHECK_CUDA(cudaEventSynchronize(event_)); + + // Pass 2: scatter into the flat host arrays at per-bucket cursors. + auto& cursor = bucket_cursor_; + std::fill(cursor.begin(), cursor.end(), 0); + for (size_t i = 0; i < pair_cache_.size(); ++i) + { + const int hr = pair_scratch_offset_[i]; + if (hr == -1) { continue; } + const auto& p = pair_cache_[i]; + const int key = p.nw1 * nw_stride_ + p.nw2; + const int pos = key_to_base[key] + cursor[key]++; + h_A[pos] = phi_d + p.phi_1_offset; + h_B[pos] = phi_vldr3_d + p.phi_2_offset; + h_C[pos] = hr_d + hr; + h_lda[pos] = p.phi_len_mgrid; + h_ldb[pos] = p.phi_len_mgrid; + h_ldc[pos] = p.nw2; } gemm_A_.copy_host_to_device_async(ap_num); @@ -293,26 +364,25 @@ void PhiOperatorGpu::phi_mul_phi( gemm_lda_.copy_host_to_device_async(ap_num); gemm_ldb_.copy_host_to_device_async(ap_num); gemm_ldc_.copy_host_to_device_async(ap_num); - gemm_m_.copy_host_to_device_async(ap_num); - gemm_n_.copy_host_to_device_async(ap_num); - gemm_k_.copy_host_to_device_async(ap_num); CHECK_CUDA(cudaEventRecord(event_, stream_)); - gemm_tn_vbatch(max_m, - max_n, - max_k, - gemm_m_.get_device_ptr(), - gemm_n_.get_device_ptr(), - gemm_k_.get_device_ptr(), - gemm_A_.get_device_ptr(), - gemm_lda_.get_device_ptr(), - gemm_B_.get_device_ptr(), - gemm_ldb_.get_device_ptr(), - gemm_C_.get_device_ptr(), - gemm_ldc_.get_device_ptr(), - ap_num, - stream_, - nullptr); + for (const auto& b : buckets) + { + const int nw1 = b.key / nw_stride_; + const int nw2 = b.key % nw_stride_; + gemm_tn_vbatch(nw1, + nw2, + mgrids_num_, + gemm_A_.get_device_ptr() + b.off, + gemm_lda_.get_device_ptr() + b.off, + gemm_B_.get_device_ptr() + b.off, + gemm_ldb_.get_device_ptr() + b.off, + gemm_C_.get_device_ptr() + b.off, + gemm_ldc_.get_device_ptr() + 
b.off, + b.cnt, + stream_, + nullptr); + } } template @@ -324,54 +394,70 @@ void PhiOperatorGpu::phi_mul_dm( double* phi_dm_d) { CHECK_CUDA(cudaMemsetAsync(phi_dm_d, 0, phi_len_ * sizeof(double), stream_)); - // ap_num means number of atom pairs + + // Shape-exact bucketing: same structure as phi_mul_phi, but NN-flavored. + // M = mgrids_num_ (batch-wide constant), N = nw2, K = nw1. + // Shape key is still (nw1, nw2); M is absent from the key since it's + // identical across every pair in the batch. is_symm selects the + // upper-triangle (ia_1 <= ia_2) subset and fills per-pair alpha. + + auto& counts = bucket_counts_; + std::fill(counts.begin(), counts.end(), 0); + + // Pass 1: filter + HContainer lookup + per-shape count. + for (size_t i = 0; i < pair_cache_.size(); ++i) + { + const auto& p = pair_cache_[i]; + if (is_symm && !p.ia_le) { pair_scratch_offset_[i] = -1; continue; } + const int dm_offset = dm.find_matrix_offset(p.iat_1, p.iat_2, p.r_diff); + pair_scratch_offset_[i] = dm_offset; + if (dm_offset == -1) { continue; } + counts[p.nw1 * nw_stride_ + p.nw2]++; + } + + // Prefix sum over dense keys -> compact bucket list. + auto& buckets = buckets_; + buckets.clear(); + auto& key_to_base = bucket_base_; + std::fill(key_to_base.begin(), key_to_base.end(), 0); int ap_num = 0; - int max_m = mgrids_num_; - int max_n = 0; - int max_k = 0; + for (int k = 0; k < nw_stride_ * nw_stride_; ++k) + { + if (counts[k] == 0) { continue; } + buckets.push_back({k, ap_num, counts[k]}); + key_to_base[k] = ap_num; + ap_num += counts[k]; + } + + auto* h_A = gemm_A_.get_host_ptr(); + auto* h_B = gemm_B_.get_host_ptr(); + auto* h_C = gemm_C_.get_host_ptr(); + auto* h_lda = gemm_lda_.get_host_ptr(); + auto* h_ldb = gemm_ldb_.get_host_ptr(); + auto* h_ldc = gemm_ldc_.get_host_ptr(); + auto* h_alpha = gemm_alpha_.get_host_ptr(); + CHECK_CUDA(cudaEventSynchronize(event_)); - for (int i = 0; i < bgrid_batch_->get_batch_size(); i++) + + // Pass 2: scatter. 
+ auto& cursor = bucket_cursor_; + std::fill(cursor.begin(), cursor.end(), 0); + for (size_t i = 0; i < pair_cache_.size(); ++i) { - auto bgrid = bgrid_batch_->get_bgrids()[i]; - // the length of phi on a mesh grid - const int phi_len_mgrid = bgrid->get_phi_len(); - const int pre_atoms = atoms_num_info_.get_host_ptr()[i].y; - for (int ia_1 = 0; ia_1 < bgrid->get_atoms_num(); ia_1++) + const int dm_offset = pair_scratch_offset_[i]; + if (dm_offset == -1) { continue; } + const auto& p = pair_cache_[i]; + const int key = p.nw1 * nw_stride_ + p.nw2; + const int pos = key_to_base[key] + cursor[key]++; + h_A[pos] = phi_d + p.phi_1_offset; + h_B[pos] = dm_d + dm_offset; + h_C[pos] = phi_dm_d + p.phi_2_offset; + h_lda[pos] = p.phi_len_mgrid; + h_ldb[pos] = p.nw2; + h_ldc[pos] = p.phi_len_mgrid; + if (is_symm) { - auto atom_1 = bgrid->get_atoms()[ia_1]; - const int iat_1 = atom_1->get_iat(); - const auto& r_1 = atom_1->get_R(); - const int nw1 = atom_1->get_nw(); - const int phi_1_offset = atom_phi_start_.get_host_ptr()[pre_atoms + ia_1]; - int ia_2 = is_symm ? ia_1 : 0; - for (; ia_2 < bgrid->get_atoms_num(); ia_2++) - { - auto atom_2 = bgrid->get_atoms()[ia_2]; - const int iat_2 = atom_2->get_iat(); - const auto& r_2 = atom_2->get_R(); - const int nw2 = atom_2->get_nw(); - - int dm_offset = dm.find_matrix_offset(iat_1, iat_2, r_1-r_2); - if (dm_offset == -1) - { continue; } - - const int phi_dm_offset = atom_phi_start_.get_host_ptr()[pre_atoms + ia_2]; - - gemm_A_.get_host_ptr()[ap_num] = phi_d + phi_1_offset; - gemm_B_.get_host_ptr()[ap_num] = dm_d + dm_offset; - gemm_C_.get_host_ptr()[ap_num] = phi_dm_d + phi_dm_offset; - gemm_lda_.get_host_ptr()[ap_num] = phi_len_mgrid; - gemm_ldb_.get_host_ptr()[ap_num] = nw2; - gemm_ldc_.get_host_ptr()[ap_num] = phi_len_mgrid; - gemm_m_.get_host_ptr()[ap_num] = mgrids_num_; - gemm_n_.get_host_ptr()[ap_num] = nw2; - gemm_k_.get_host_ptr()[ap_num] = nw1; - gemm_alpha_.get_host_ptr()[ap_num] = ia_1 == ia_2 ? 
Real(1.0) : Real(2.0); - ap_num++; - - max_n = std::max(max_n, nw2); - max_k = std::max(max_k, nw1); - } + h_alpha[pos] = p.is_diag ? Real(1.0) : Real(2.0); } } @@ -381,33 +467,31 @@ void PhiOperatorGpu::phi_mul_dm( gemm_lda_.copy_host_to_device_async(ap_num); gemm_ldb_.copy_host_to_device_async(ap_num); gemm_ldc_.copy_host_to_device_async(ap_num); - gemm_m_.copy_host_to_device_async(ap_num); - gemm_n_.copy_host_to_device_async(ap_num); - gemm_k_.copy_host_to_device_async(ap_num); - if(is_symm) + if (is_symm) { - // if is_symm == false, gemm_alpha_ always equals 1.0, - // so we don't need to copy it to device + // if is_symm == false, alpha_ptr below is nullptr and gemm_alpha_ is neither filled nor copied gemm_alpha_.copy_host_to_device_async(ap_num); } CHECK_CUDA(cudaEventRecord(event_, stream_)); - auto alpha_ptr = is_symm ? gemm_alpha_.get_device_ptr() : nullptr; - gemm_nn_vbatch(max_m, - max_n, - max_k, - gemm_m_.get_device_ptr(), - gemm_n_.get_device_ptr(), - gemm_k_.get_device_ptr(), - gemm_A_.get_device_ptr(), - gemm_lda_.get_device_ptr(), - gemm_B_.get_device_ptr(), - gemm_ldb_.get_device_ptr(), - gemm_C_.get_device_ptr(), - gemm_ldc_.get_device_ptr(), - ap_num, - stream_, - alpha_ptr); + for (const auto& b : buckets) + { + const int nw1 = b.key / nw_stride_; + const int nw2 = b.key % nw_stride_; + auto alpha_ptr = is_symm ? 
(gemm_alpha_.get_device_ptr() + b.off) : nullptr; + gemm_nn_vbatch(mgrids_num_, + nw2, + nw1, + gemm_A_.get_device_ptr() + b.off, + gemm_lda_.get_device_ptr() + b.off, + gemm_B_.get_device_ptr() + b.off, + gemm_ldb_.get_device_ptr() + b.off, + gemm_C_.get_device_ptr() + b.off, + gemm_ldc_.get_device_ptr() + b.off, + b.cnt, + stream_, + alpha_ptr); + } } template diff --git a/source/source_lcao/module_gint/kernel/phi_operator_gpu.h b/source/source_lcao/module_gint/kernel/phi_operator_gpu.h index fbdbc95352a..d93270ea7e6 100644 --- a/source/source_lcao/module_gint/kernel/phi_operator_gpu.h +++ b/source/source_lcao/module_gint/kernel/phi_operator_gpu.h @@ -1,7 +1,10 @@ #pragma once #include +#include +#include #include +#include "source_lcao/module_gint/gint_type.h" // Vec3i, used by PairInfo #include "source_lcao/module_gint/batch_biggrid.h" #include "gint_gpu_vars.h" #include "cuda_mem_wrapper.h" @@ -9,6 +12,34 @@ namespace ModuleGint { +// Per-atom-pair metadata cached in PhiOperatorGpu::set_bgrid_batch() so that +// phi_mul_phi / phi_mul_dm skip the O(bgrid * atoms^2) enumeration on every +// call. Holds only the fields both callers need; HContainer lookups still +// happen lazily in the hot path because they depend on hRGint / dm. +struct PairInfo +{ + int phi_1_offset; + int phi_2_offset; + int phi_len_mgrid; + int iat_1; + int iat_2; + Vec3i r_diff; + uint16_t nw1; + uint16_t nw2; + uint8_t ia_le; // (ia_1 <= ia_2) within the bgrid, for is_symm filter + uint8_t is_diag; // (ia_1 == ia_2), for phi_mul_dm is_symm alpha +}; + +// One non-empty (nw1, nw2) bucket produced by the counting sort in +// phi_mul_phi / phi_mul_dm: a flattened shape key, the bucket's start offset in +// the flat gemm_* arrays, and how many atom pairs landed in it. 
+struct GemmShapeBucket +{ + int key; // nw1 * nw_stride_ + nw2 + int off; // start offset in the gemm_* host/device arrays + int cnt; // number of atom pairs in this bucket +}; + template class PhiOperatorGpu { @@ -33,10 +64,12 @@ class PhiOperatorGpu const Real* phi_d, Real* result_d) const; - // All GEMM accumulators (hr in phi_mul_phi, phi_dm in phi_mul_dm) are - // double-typed regardless of Real: when Real=float the multiplies stay in - // fp32 (cheap) but per-block reductions and device-side atomicAdd run in - // fp64 so the global reductions don't drift. + // The GEMM output buffers (hr in phi_mul_phi, phi_dm in phi_mul_dm) are + // always double, independent of Real. When Real=float the per-pair inner + // products are reduced in fp32 (cheap); the cross-pair accumulation into a + // shared hr/phi_dm element is what runs in fp64, via an atomicAdd into + // these double buffers, so summing many atom-pair contributions doesn't + // drift. void phi_mul_phi( const Real* phi_d, const Real* phi_vldr3_d, @@ -80,6 +113,14 @@ class PhiOperatorGpu int phi_len_; + // Stride for flattening a (nw1, nw2) pair into a single dense bucket key + // (`nw1 * nw_stride_ + nw2`), so shape-exact bucketing of phi_mul_phi / + // phi_mul_dm can index a flat table instead of hashing. Set once in the + // ctor to gint_gpu_vars_->nwmax + 1 -- nwmax is the largest per-atom orbital count in + // the cell, so there is no artificial ceiling: the table is sized to the + // actual basis (typical nwmax ~25). 
+ int nw_stride_ = 0; + cudaStream_t stream_ = 0; cudaEvent_t event_; @@ -102,19 +143,36 @@ class PhiOperatorGpu // Mapping of the index of meshgrid in the batch of biggrids to the index of meshgrid in the local cell CudaMemWrapper batch_mgrid_lidx_; - mutable CudaMemWrapper gemm_m_; - mutable CudaMemWrapper gemm_n_; - mutable CudaMemWrapper gemm_k_; mutable CudaMemWrapper gemm_lda_; mutable CudaMemWrapper gemm_ldb_; mutable CudaMemWrapper gemm_ldc_; mutable CudaMemWrapper gemm_A_; mutable CudaMemWrapper gemm_B_; - // Single C-pointer buffer: both phi_mul_phi (output hr) and phi_mul_dm - // (output phi_dm) write into double* accumulators, so a single shared - // gemm_C_ device buffer can serve both call sites. + // C accumulator pointers are always double*: both phi_mul_phi (hr) and + // phi_mul_dm (phi_dm) write into fp64 buffers via the GEMM's fp64 atomicAdd. mutable CudaMemWrapper gemm_C_; mutable CudaMemWrapper gemm_alpha_; + + // Full (ia_1, ia_2) pair enumeration, rebuilt in set_bgrid_batch(). + // Consumed by phi_mul_phi (TN, iat_1 <= iat_2 filter) and phi_mul_dm + // (NN, optional is_symm upper-triangle filter). + std::vector pair_cache_; + + // Scratch buffer reused across phi_mul_phi / phi_mul_dm calls to cache + // per-pair HContainer offsets from Pass 1 and replay them in Pass 2 + // without a second find_matrix_offset() call. + mutable std::vector pair_scratch_offset_; + + // Dense (nw_stride_ * nw_stride_) counting-sort scratch shared by + // phi_mul_phi / phi_mul_dm. Sized once in the ctor and just re-zeroed per + // call, so the hot path never reallocates. + mutable std::vector bucket_counts_; + mutable std::vector bucket_base_; + mutable std::vector bucket_cursor_; + + // Compact list of non-empty buckets for the current call. Reused (cleared, + // not reallocated) by both phi_mul_phi and phi_mul_dm. + mutable std::vector buckets_; }; } \ No newline at end of file