Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions tools/clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1170,12 +1170,14 @@ static const ArBasicKind g_RayQueryCT[] = {AR_OBJECT_RAY_QUERY,
AR_BASIC_UNKNOWN};

static const ArBasicKind g_LinAlgCT[] = {
AR_BASIC_LITERAL_FLOAT, AR_BASIC_FLOAT16,
AR_BASIC_FLOAT32, AR_BASIC_FLOAT32_PARTIAL_PRECISION,
AR_BASIC_FLOAT16, AR_BASIC_INT32,
AR_BASIC_INT16, AR_BASIC_UINT32,
AR_BASIC_UINT16, AR_BASIC_INT8_4PACKED,
AR_BASIC_UINT8_4PACKED, AR_BASIC_NOCAST,
AR_BASIC_UNKNOWN};
AR_BASIC_FLOAT64, AR_BASIC_LITERAL_INT,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually don't think we want/need to keep AR_BASIC_LITERAL_INT, or AR_BASIC_LITERAL_FLOAT here.

Also, AR_BASIC_FLOAT32_PARTIAL_PRECISION is half when in min-precision mode that maps to ordinary float. I don't think it's necessary to support that, but it probably doesn't hurt.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fix will be up in a couple minutes

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay so removing AR_BASIC_LITERAL_INT caused some test failures. Specifically, in SetElement

dxc/DirectXShaderCompiler/tools/clang/test/CodeGenDXIL/hlsl/linalg/builtins/matrixsetelement/nominal.hlsl:21:3: note: candidate function not viable: no known conversion from 'literal int' to 'unsigned int' for 4th argument
  __builtin_LinAlg_MatrixSetElement(mat2, mat1, 1, 5);

I think this would be fine in real use cases because of the header implementation but given that is a slightly larger change/longer discussion I'm going to take your offer to punt :)

AR_BASIC_UINT16, AR_BASIC_UINT32,
AR_BASIC_UINT64, AR_BASIC_INT16,
AR_BASIC_INT32, AR_BASIC_INT64,
AR_BASIC_UINT8_4PACKED, AR_BASIC_INT8_4PACKED,
AR_BASIC_NOCAST, AR_BASIC_UNKNOWN};
Comment thread
V-FEXrt marked this conversation as resolved.

static const ArBasicKind g_AccelerationStructCT[] = {
AR_OBJECT_ACCELERATION_STRUCT, AR_BASIC_UNKNOWN};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,18 @@ void main() {
// CHECK2-SAME: i1 true, <4 x i64> %{{[0-9]+}}, i32 1, <4 x i64> %{{[0-9]+}}, i32 0)

__builtin_LinAlg_MatrixVectorMultiplyAdd(result3, mat3, true, vec3, 1, result3, 0);

// CHECK: call <8 x i32> @dx.op.linAlgMatVecMulAdd.v8i32.mC17M8N8U0S0.v8i32.v8i32(i32 -2147483622,
// CHECK-SAME: %dx.types.LinAlgMatrixC17M8N8U0S0 {{.*}}, i1 true, <8 x i32> zeroinitializer,
// CHECK-SAME: i32 1, <8 x i32> zeroinitializer, i32 0)
// CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation)

// CHECK2: call void @"dx.hl.op..void (i32, <8 x i32>*, %dx.types.LinAlgMatrixC17M8N8U0S0, i1, <8 x i32>,
// CHECK2-SAME: i32, <8 x i32>, i32)"(i32 419, <8 x i32>* %result4, %dx.types.LinAlgMatrixC17M8N8U0S0 %{{[0-9]+}},
// CHECK2-SAME: i1 true, <8 x i32> %{{[0-9]+}}, i32 1, <8 x i32> %{{[0-9]+}}, i32 0)

__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(17, 8, 8, 0, 0)]] mat4;
vector<int8_t4_packed, 8> vec4 = 0;
vector<int8_t4_packed, 8> result4 = 0;
__builtin_LinAlg_MatrixVectorMultiplyAdd(result4, mat4, true, vec4, 1, result4, 0);
}
20 changes: 10 additions & 10 deletions utils/hct/gen_intrin_main.txt
Original file line number Diff line number Diff line change
Expand Up @@ -393,24 +393,24 @@ void [[min_sm=6.10]] __builtin_LinAlg_FillMatrix(out LinAlgMatrix ret, in numeri
void [[min_sm=6.10]] __builtin_LinAlg_CopyConvertMatrix(out LinAlgMatrix ret, in LinAlgMatrix source, in bool transpose);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixLoadFromDescriptor(out LinAlgMatrix ret, in ByteAddressBuffer buf, in uint offset, in uint stride, in uint layout, in uint align);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixLoadFromDescriptor(out LinAlgMatrix ret, in RWByteAddressBuffer buf, in uint offset, in uint stride, in uint layout, in uint align);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixLoadFromMemory(out LinAlgMatrix ret, groupshared numeric[] memory, in uint offset, in uint stride, in uint layout);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixLoadFromMemory(out LinAlgMatrix ret, groupshared LinAlg[] memory, in uint offset, in uint stride, in uint layout);
uint [[min_sm=6.10]] __builtin_LinAlg_MatrixLength(in LinAlgMatrix matrix);
uint<2> [[min_sm=6.10]] __builtin_LinAlg_MatrixGetCoordinate(in LinAlgMatrix matrix, in uint threadLocalIndex);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixGetElement(out numeric ret, in LinAlgMatrix matrix, in uint threadLocalIndex);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixSetElement(ref LinAlgMatrix ret, in LinAlgMatrix matrix, in uint threadLocalIndex, in numeric value);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixGetElement(out LinAlg ret, in LinAlgMatrix matrix, in uint threadLocalIndex);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixSetElement(ref LinAlgMatrix ret, in LinAlgMatrix matrix, in uint threadLocalIndex, in LinAlg value);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixStoreToDescriptor(in LinAlgMatrix matrix, in RWByteAddressBuffer buf, in uint offset, in uint stride, in uint layout, in uint align);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixStoreToMemory(in LinAlgMatrix matrix, groupshared numeric[] memory, in uint offset, in uint stride, in uint layout);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixStoreToMemory(in LinAlgMatrix matrix, groupshared LinAlg[] memory, in uint offset, in uint stride, in uint layout);
uint [[min_sm=6.10]] __builtin_LinAlg_MatrixQueryAccumulatorLayout();
void [[min_sm=6.10]] __builtin_LinAlg_MatrixMatrixMultiply(out LinAlgMatrix matrixC, in LinAlgMatrix matrixA, in LinAlgMatrix matrixB);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixMatrixMultiplyAccumulate(ref LinAlgMatrix matrixR, in LinAlgMatrix matrixA, in LinAlgMatrix matrixB, in LinAlgMatrix matrixC);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixAccumulate(ref LinAlgMatrix matrixC, in LinAlgMatrix matrixLHS, in LinAlgMatrix matrixRHS);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixVectorMultiply(out numeric<c> ret, in LinAlgMatrix mat, in bool isOutputSigned, in numeric<c2> input, in uint inputInterp);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixVectorMultiplyAdd(out numeric<c> ret, in LinAlgMatrix mat, in bool isOutputSigned, in numeric<c2> input, in uint inputInterp, in numeric<c> bias, in uint biasInterp);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixVectorMultiply(out LinAlg<c> ret, in LinAlgMatrix mat, in bool isOutputSigned, in LinAlg<c2> input, in uint inputInterp);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixVectorMultiplyAdd(out LinAlg<c> ret, in LinAlgMatrix mat, in bool isOutputSigned, in LinAlg<c2> input, in uint inputInterp, in LinAlg<c> bias, in uint biasInterp);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixAccumulateToDescriptor(in LinAlgMatrix matrix, in RWByteAddressBuffer buf, in uint offset, in uint stride, in uint layout, in uint align);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixAccumulateToMemory(in LinAlgMatrix matrix, groupshared numeric[] memory, in uint offset, in uint stride, in uint layout);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixOuterProduct(out LinAlgMatrix ret, in numeric<c> vecA, in numeric<c2> vecB);
void [[min_sm=6.10]] __builtin_LinAlg_Convert(out numeric<c> ret, in numeric<c2> vec, in uint input_interp, in uint output_interp);
void [[min_sm=6.10]] __builtin_LinAlg_VectorAccumulateToDescriptor(in numeric<> vec, in RWByteAddressBuffer buf, in uint offset, in uint align);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixAccumulateToMemory(in LinAlgMatrix matrix, groupshared LinAlg[] memory, in uint offset, in uint stride, in uint layout);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixOuterProduct(out LinAlgMatrix ret, in LinAlg<c> vecA, in LinAlg<c2> vecB);
void [[min_sm=6.10]] __builtin_LinAlg_Convert(out LinAlg<c> ret, in LinAlg<c2> vec, in uint input_interp, in uint output_interp);
void [[min_sm=6.10]] __builtin_LinAlg_VectorAccumulateToDescriptor(in LinAlg<> vec, in RWByteAddressBuffer buf, in uint offset, in uint align);

} namespace

Expand Down
Loading