diff --git a/zstd/zstdgpu/zstdgpu_shaders.h b/zstd/zstdgpu/zstdgpu_shaders.h index 83e2cd8..7aba197 100644 --- a/zstd/zstdgpu/zstdgpu_shaders.h +++ b/zstd/zstdgpu/zstdgpu_shaders.h @@ -777,6 +777,43 @@ static void zstdgpu_ParseFseHeader(ZSTDGPU_PARAM_INOUT(zstdgpu_Forward_BitBuffer outFseInfo[outFseTableIndex] = zstdgpu_CreateFseInfo(symbol, accuracyLog2); } + +// Active lanes either contain a "filler" xor a "hole" value. +// +// If a lane with a hole value can't have a filler value propagated to it from a lower lane, +// its value is unchanged (remains a hole). +// +// NOTE: ensure kzstdgpu_TgSizeX_ParseCompressedBlocks <= 32 +// so HLSL lane masks are easy to work with. +// +// Example with lower lane IDs on the left for "Wave8" where filler values are even integers (holes are odd integers): +// input = { 1, 4, 3, 3, 6, 8, 5, 5 } +// output = { 1, 4, 4, 4, 6, 8, 8, 8 } +inline uint32_t zstdgpu_WaveReplicateFillerUpwardsToHoles(uint32_t v_value, bool v_isFiller) +{ + const uint32_t s_hasFillerMask = WaveActiveBallot(v_isFiller).x; // assume <= Wave32 + const uint32_t v_selfMask = 1u << WaveGetLaneIndex(); + + uint32_t v_srcLanesMask = s_hasFillerMask & (v_selfMask - 1); + // If this lane already has a filler value, or it has no lane with a filler value to read from, make it read from itself: + if (v_isFiller || v_srcLanesMask == 0) + { + v_srcLanesMask = v_selfMask; + } + + return WaveReadLaneAt(v_value, zstdgpu_FindFirstBitHiU32(v_srcLanesMask)); +} + +inline uint32_t zstdgpu_WavePropogateFseTableIndex(uint32_t tableIndex) +{ +#if (kzstdgpu_TgSizeX_ParseCompressedBlocks - 1u) >= 32u + // Parsing compressed blocks can be divergent, so probably don't want a large thread group anyway. + #error "kzstdgpu_TgSizeX_ParseCompressedBlocks must be in [1:32], else implement WaveActiveBallot.y[zw] handling." +#endif + const bool isFiller = tableIndex < kzstdgpu_FseProbTableIndex_Repeat; + return zstdgpu_WaveReplicateFillerUpwardsToHoles(tableIndex, isFiller); +} + static void zstdgpu_ShaderEntry_ParseCompressedBlocks(ZSTDGPU_PARAM_INOUT(zstdgpu_ParseCompressedBlocks_SRT) srt, uint32_t threadId) { if (threadId >= srt.compressedBlockCount) @@ -1247,26 +1284,6 @@ static void zstdgpu_ShaderEntry_ParseCompressedBlocks(ZSTDGPU_PARAM_INOUT(zstdgp const uint32_t lastLocalIndex = WaveActiveCountBits(true) - 1u; - #define WAVE_SHUFFLE(v, and_mask, or_mask, xor_mask) WaveReadLaneAt(v, ((WaveGetLaneIndex() & (and_mask)) | (or_mask)) ^ (xor_mask)) - - #define WAVE_BROADCAST(v, group_size, group_lane) WAVE_SHUFFLE(v, ~(group_size - 1u), group_lane, 0) - - #define WAVE_PROPAGATE_STEP(p, group_size) \ - if (blockSize >= group_size /** this condition is expected to be a compile-time condition, so no real branch */) \ - { \ - /* for every group of `group_size` consecutive lanes, broadcast the value from the last lane of the "odd" sub-group of 2x smaller size) */ \ - uint32_t b = WAVE_BROADCAST(p, group_size, group_size / 2u - 1u); \ - /* for every group of `group_size` consecutive lanes */ \ - /* propagate element from the last lane of the "odd" sub-group of 2x smaller size */ \ - /* into all elements of the "even" sub-group of 2x smaller size when propagated value makes sense */\ - [flatten] if ((WaveGetLaneIndex() & (group_size / 2u))) \ - { \ - /* We propagate only non-Repeat and not-Unused values to lanes containing Repeat/Unused values*/\ - if (p >= kzstdgpu_FseProbTableIndex_Repeat && b < kzstdgpu_FseProbTableIndex_Repeat) \ - p = b; \ - } \ - } - // To propagate FSE table indices, we use a variant of "Decoupled Lookback" // 1. Each block (a group of `blockSize` threads) looks at indices of each type of FSE table // and checks for each of FSE table type if there's any FSE table "index" that is not `Unused` @@ -1325,12 +1342,7 @@ static void zstdgpu_ShaderEntry_ParseCompressedBlocks(ZSTDGPU_PARAM_INOUT(zstdgp #define LOOKBACK_STORE_EARLY_ANY_VALID(name) \ if (WaveActiveAnyTrue(indexValid##name)) \ { \ - uint32_t x = outBlockData.fseTableIndex##name; \ - WAVE_PROPAGATE_STEP(x, 2) \ - WAVE_PROPAGATE_STEP(x, 4) \ - WAVE_PROPAGATE_STEP(x, 8) \ - WAVE_PROPAGATE_STEP(x, 16) \ - WAVE_PROPAGATE_STEP(x, 32) \ + const uint32_t x = zstdgpu_WavePropogateFseTableIndex(outBlockData.fseTableIndex##name);\ const uint32_t xLast = WaveReadLaneAt(x, lastLocalIndex); \ if (WaveIsFirstLane()) \ { \ @@ -1451,15 +1463,10 @@ static void zstdgpu_ShaderEntry_ParseCompressedBlocks(ZSTDGPU_PARAM_INOUT(zstdgp // NOTE(pamartis): Because the first lane containining "non-Unused" index was set to something other than `Repeat`, // we can propagate indices across the wave (if needed of course, if the wave needs that -- contains any number of lanes with `Repeat` indices) #define PROPAGATE_ACROSS_WAVE_IF_NEEDED(name) \ - const bool needPropagateAcrossWave##name = fseTableIndexPropagated##name == kzstdgpu_FseProbTableIndex_Repeat; \ + const bool needPropagateAcrossWave##name = fseTableIndexPropagated##name == kzstdgpu_FseProbTableIndex_Repeat; \ if (WaveActiveAnyTrue(needPropagateAcrossWave##name)) \ { \ - uint32_t x = fseTableIndexPropagated##name; \ - WAVE_PROPAGATE_STEP(x, 2) \ - WAVE_PROPAGATE_STEP(x, 4) \ - WAVE_PROPAGATE_STEP(x, 8) \ - WAVE_PROPAGATE_STEP(x, 16) \ - WAVE_PROPAGATE_STEP(x, 32) \ + const uint32_t x = zstdgpu_WavePropogateFseTableIndex(fseTableIndexPropagated##name); \ if (needPropagateAcrossWave##name) \ { \ fseTableIndexPropagated##name = x; \ @@ -1478,10 +1485,6 @@ static void zstdgpu_ShaderEntry_ParseCompressedBlocks(ZSTDGPU_PARAM_INOUT(zstdgp outBlockData.fseTableIndexOffs = fseTableIndexPropagatedOffs; outBlockData.fseTableIndexMLen = fseTableIndexPropagatedMLen; - #undef WAVE_PROPAGATE_STEP - #undef WAVE_BROADCAST - #undef WAVE_SHUFFLE - #else // use static variables on CPU because this function is expected to be called in a loop for all compressed blocks static uint32_t lastHufWIndex = kzstdgpu_FseProbTableIndex_Unused; diff --git a/zstd/zstdgpu/zstdgpu_structs.h b/zstd/zstdgpu/zstdgpu_structs.h index 2f3f724..de7f544 100644 --- a/zstd/zstdgpu/zstdgpu_structs.h +++ b/zstd/zstdgpu/zstdgpu_structs.h @@ -380,7 +380,7 @@ static const uint32_t kzstdgpu_TgSizeX_PrefixSum = 64; static const uint32_t kzstdgpu_TgSizeX_PrefixSum = 32; #endif -static const uint32_t kzstdgpu_TgSizeX_ParseCompressedBlocks = 32; +#define kzstdgpu_TgSizeX_ParseCompressedBlocks 32 // #define since dxc may lack static_assert static const uint32_t kzstdgpu_TgSizeX_Memset = 64; // NOTE(pamartis): The rationale behind the below choice of TG sizes is the following