diff --git a/tools/clang/lib/Sema/SemaHLSL.cpp b/tools/clang/lib/Sema/SemaHLSL.cpp index 127733e71e..2067ebaf26 100644 --- a/tools/clang/lib/Sema/SemaHLSL.cpp +++ b/tools/clang/lib/Sema/SemaHLSL.cpp @@ -1170,12 +1170,14 @@ static const ArBasicKind g_RayQueryCT[] = {AR_OBJECT_RAY_QUERY, AR_BASIC_UNKNOWN}; static const ArBasicKind g_LinAlgCT[] = { + AR_BASIC_LITERAL_FLOAT, AR_BASIC_FLOAT16, AR_BASIC_FLOAT32, AR_BASIC_FLOAT32_PARTIAL_PRECISION, - AR_BASIC_FLOAT16, AR_BASIC_INT32, - AR_BASIC_INT16, AR_BASIC_UINT32, - AR_BASIC_UINT16, AR_BASIC_INT8_4PACKED, - AR_BASIC_UINT8_4PACKED, AR_BASIC_NOCAST, - AR_BASIC_UNKNOWN}; + AR_BASIC_FLOAT64, AR_BASIC_LITERAL_INT, + AR_BASIC_UINT16, AR_BASIC_UINT32, + AR_BASIC_UINT64, AR_BASIC_INT16, + AR_BASIC_INT32, AR_BASIC_INT64, + AR_BASIC_UINT8_4PACKED, AR_BASIC_INT8_4PACKED, + AR_BASIC_NOCAST, AR_BASIC_UNKNOWN}; static const ArBasicKind g_AccelerationStructCT[] = { AR_OBJECT_ACCELERATION_STRUCT, AR_BASIC_UNKNOWN}; diff --git a/tools/clang/test/CodeGenDXIL/hlsl/linalg/builtins/matrixvectormultiplyadd/nominal.hlsl b/tools/clang/test/CodeGenDXIL/hlsl/linalg/builtins/matrixvectormultiplyadd/nominal.hlsl index d4f0037460..edfa306148 100644 --- a/tools/clang/test/CodeGenDXIL/hlsl/linalg/builtins/matrixvectormultiplyadd/nominal.hlsl +++ b/tools/clang/test/CodeGenDXIL/hlsl/linalg/builtins/matrixvectormultiplyadd/nominal.hlsl @@ -49,4 +49,18 @@ void main() { // CHECK2-SAME: i1 true, <4 x i64> %{{[0-9]+}}, i32 1, <4 x i64> %{{[0-9]+}}, i32 0) __builtin_LinAlg_MatrixVectorMultiplyAdd(result3, mat3, true, vec3, 1, result3, 0); + + // CHECK: call <8 x i32> @dx.op.linAlgMatVecMulAdd.v8i32.mC17M8N8U0S0.v8i32.v8i32(i32 -2147483622, + // CHECK-SAME: %dx.types.LinAlgMatrixC17M8N8U0S0 {{.*}}, i1 true, <8 x i32> zeroinitializer, + // CHECK-SAME: i32 1, <8 x i32> zeroinitializer, i32 0) + // CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation) + + // CHECK2: call void @"dx.hl.op..void (i32, <8 x i32>*, %dx.types.LinAlgMatrixC17M8N8U0S0, i1, <8 x i32>, + // CHECK2-SAME: i32, <8 x i32>, i32)"(i32 419, <8 x i32>* %result4, %dx.types.LinAlgMatrixC17M8N8U0S0 %{{[0-9]+}}, + // CHECK2-SAME: i1 true, <8 x i32> %{{[0-9]+}}, i32 1, <8 x i32> %{{[0-9]+}}, i32 0) + + __builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(17, 8, 8, 0, 0)]] mat4; + vector vec4 = 0; + vector result4 = 0; + __builtin_LinAlg_MatrixVectorMultiplyAdd(result4, mat4, true, vec4, 1, result4, 0); } diff --git a/utils/hct/gen_intrin_main.txt b/utils/hct/gen_intrin_main.txt index c0d1d3dcfc..37a4ee406f 100644 --- a/utils/hct/gen_intrin_main.txt +++ b/utils/hct/gen_intrin_main.txt @@ -393,24 +393,24 @@ void [[min_sm=6.10]] __builtin_LinAlg_FillMatrix(out LinAlgMatrix ret, in numeri void [[min_sm=6.10]] __builtin_LinAlg_CopyConvertMatrix(out LinAlgMatrix ret, in LinAlgMatrix source, in bool transpose); void [[min_sm=6.10]] __builtin_LinAlg_MatrixLoadFromDescriptor(out LinAlgMatrix ret, in ByteAddressBuffer buf, in uint offset, in uint stride, in uint layout, in uint align); void [[min_sm=6.10]] __builtin_LinAlg_MatrixLoadFromDescriptor(out LinAlgMatrix ret, in RWByteAddressBuffer buf, in uint offset, in uint stride, in uint layout, in uint align); -void [[min_sm=6.10]] __builtin_LinAlg_MatrixLoadFromMemory(out LinAlgMatrix ret, groupshared numeric[] memory, in uint offset, in uint stride, in uint layout); +void [[min_sm=6.10]] __builtin_LinAlg_MatrixLoadFromMemory(out LinAlgMatrix ret, groupshared LinAlg[] memory, in uint offset, in uint stride, in uint layout); uint [[min_sm=6.10]] __builtin_LinAlg_MatrixLength(in LinAlgMatrix matrix); uint<2> [[min_sm=6.10]] __builtin_LinAlg_MatrixGetCoordinate(in LinAlgMatrix matrix, in uint threadLocalIndex); -void [[min_sm=6.10]] __builtin_LinAlg_MatrixGetElement(out numeric ret, in LinAlgMatrix matrix, in uint threadLocalIndex); -void [[min_sm=6.10]] __builtin_LinAlg_MatrixSetElement(ref LinAlgMatrix ret, in LinAlgMatrix matrix, in uint threadLocalIndex, in numeric value); +void [[min_sm=6.10]] __builtin_LinAlg_MatrixGetElement(out LinAlg ret, in LinAlgMatrix matrix, in uint threadLocalIndex); +void [[min_sm=6.10]] __builtin_LinAlg_MatrixSetElement(ref LinAlgMatrix ret, in LinAlgMatrix matrix, in uint threadLocalIndex, in LinAlg value); void [[min_sm=6.10]] __builtin_LinAlg_MatrixStoreToDescriptor(in LinAlgMatrix matrix, in RWByteAddressBuffer buf, in uint offset, in uint stride, in uint layout, in uint align); -void [[min_sm=6.10]] __builtin_LinAlg_MatrixStoreToMemory(in LinAlgMatrix matrix, groupshared numeric[] memory, in uint offset, in uint stride, in uint layout); +void [[min_sm=6.10]] __builtin_LinAlg_MatrixStoreToMemory(in LinAlgMatrix matrix, groupshared LinAlg[] memory, in uint offset, in uint stride, in uint layout); uint [[min_sm=6.10]] __builtin_LinAlg_MatrixQueryAccumulatorLayout(); void [[min_sm=6.10]] __builtin_LinAlg_MatrixMatrixMultiply(out LinAlgMatrix matrixC, in LinAlgMatrix matrixA, in LinAlgMatrix matrixB); void [[min_sm=6.10]] __builtin_LinAlg_MatrixMatrixMultiplyAccumulate(ref LinAlgMatrix matrixR, in LinAlgMatrix matrixA, in LinAlgMatrix matrixB, in LinAlgMatrix matrixC); void [[min_sm=6.10]] __builtin_LinAlg_MatrixAccumulate(ref LinAlgMatrix matrixC, in LinAlgMatrix matrixLHS, in LinAlgMatrix matrixRHS); -void [[min_sm=6.10]] __builtin_LinAlg_MatrixVectorMultiply(out numeric ret, in LinAlgMatrix mat, in bool isOutputSigned, in numeric input, in uint inputInterp); -void [[min_sm=6.10]] __builtin_LinAlg_MatrixVectorMultiplyAdd(out numeric ret, in LinAlgMatrix mat, in bool isOutputSigned, in numeric input, in uint inputInterp, in numeric bias, in uint biasInterp); +void [[min_sm=6.10]] __builtin_LinAlg_MatrixVectorMultiply(out LinAlg ret, in LinAlgMatrix mat, in bool isOutputSigned, in LinAlg input, in uint inputInterp); +void [[min_sm=6.10]] __builtin_LinAlg_MatrixVectorMultiplyAdd(out LinAlg ret, in LinAlgMatrix mat, in bool isOutputSigned, in LinAlg input, in uint inputInterp, in LinAlg bias, in uint biasInterp); void [[min_sm=6.10]] __builtin_LinAlg_MatrixAccumulateToDescriptor(in LinAlgMatrix matrix, in RWByteAddressBuffer buf, in uint offset, in uint stride, in uint layout, in uint align); -void [[min_sm=6.10]] __builtin_LinAlg_MatrixAccumulateToMemory(in LinAlgMatrix matrix, groupshared numeric[] memory, in uint offset, in uint stride, in uint layout); -void [[min_sm=6.10]] __builtin_LinAlg_MatrixOuterProduct(out LinAlgMatrix ret, in numeric vecA, in numeric vecB); -void [[min_sm=6.10]] __builtin_LinAlg_Convert(out numeric ret, in numeric vec, in uint input_interp, in uint output_interp); -void [[min_sm=6.10]] __builtin_LinAlg_VectorAccumulateToDescriptor(in numeric<> vec, in RWByteAddressBuffer buf, in uint offset, in uint align); +void [[min_sm=6.10]] __builtin_LinAlg_MatrixAccumulateToMemory(in LinAlgMatrix matrix, groupshared LinAlg[] memory, in uint offset, in uint stride, in uint layout); +void [[min_sm=6.10]] __builtin_LinAlg_MatrixOuterProduct(out LinAlgMatrix ret, in LinAlg vecA, in LinAlg vecB); +void [[min_sm=6.10]] __builtin_LinAlg_Convert(out LinAlg ret, in LinAlg vec, in uint input_interp, in uint output_interp); +void [[min_sm=6.10]] __builtin_LinAlg_VectorAccumulateToDescriptor(in LinAlg<> vec, in RWByteAddressBuffer buf, in uint offset, in uint align); } namespace