diff --git a/tools/clang/unittests/HLSLExec/LinAlgTests.cpp b/tools/clang/unittests/HLSLExec/LinAlgTests.cpp index 7b8aeeeba7..b3c4b761ea 100644 --- a/tools/clang/unittests/HLSLExec/LinAlgTests.cpp +++ b/tools/clang/unittests/HLSLExec/LinAlgTests.cpp @@ -359,6 +359,9 @@ class DxilConf_SM610_LinAlg { // Convert TEST_METHOD(Convert); + // Vector Accumulate + TEST_METHOD(VectorAccumulateDescriptor_Thread_F16); + private: CComPtr D3DDevice; dxc::SpecificDllLoader DxcSupport; @@ -1693,7 +1696,7 @@ static const char ConvertShader[] = R"( static void runConvert(ID3D12Device *Device, dxc::SpecificDllLoader &DxcSupport, bool Verbose) { - std::string Args = "-HV 202x"; + std::string Args = "-HV 202x -enable-16bit-types"; MatrixDim NumElements = 4; size_t BufferSize = elementSize(ComponentType::F32) * NumElements; @@ -1718,4 +1721,44 @@ void DxilConf_SM610_LinAlg::Convert() { runConvert(D3DDevice, DxcSupport, VerboseLogging); } +static const char VectorAccumulateDescriptorShader[] = R"( + RWByteAddressBuffer Output : register(u0); + + [numthreads(1, 1, 1)] + void main() { + vector InVec = {1.0, 2.0, 3.0, 4.0}; + __builtin_LinAlg_VectorAccumulateToDescriptor(InVec, Output, 0, 64); + } +)"; + +static void runVectorAccumulateDescriptor(ID3D12Device *Device, + dxc::SpecificDllLoader &DxcSupport, + bool Verbose) { + std::string Args = "-HV 202x -enable-16bit-types"; + MatrixDim NumElements = 4; + size_t BufferSize = elementSize(ComponentType::F16) * NumElements; + + compileShader(DxcSupport, VectorAccumulateDescriptorShader, "cs_6_10", Args, + Verbose); + + auto Expected = makeExpectedVec(ComponentType::F16, NumElements, 1.0); + + auto Op = createComputeOp(VectorAccumulateDescriptorShader, "cs_6_10", + "UAV(u0)", Args.c_str()); + addUAVBuffer(Op.get(), "Output", BufferSize, true); + addRootView(Op.get(), 0, "Output"); + + auto Result = runShaderOp(Device, DxcSupport, std::move(Op)); + + MappedData OutData; + Result->Test->GetReadBackData("Output", &OutData); + + VERIFY_IS_TRUE(verifyComponentBuffer(ComponentType::F16, OutData.data(), + Expected, NumElements, Verbose)); +} + +void DxilConf_SM610_LinAlg::VectorAccumulateDescriptor_Thread_F16() { + runVectorAccumulateDescriptor(D3DDevice, DxcSupport, VerboseLogging); +} + } // namespace LinAlg