Conversation
| @@ -0,0 +1,284 @@ | |||
| """ | |||
There was a problem hiding this comment.
can we add a test for end to end encode/decode correctness? Like this one https://github.com/AI-Hypercomputer/maxdiffusion/blob/main/src/maxdiffusion/tests/wan_vae_test.py#L488
It doesn't need to run in the GitHub runner, but at least it can be run manually for validation.
There was a problem hiding this comment.
Ran it and achieved an ssim_score of 0.9985
I will add this test in the incoming t2v pipeline PR, as this test requires loading VAE weights.
| self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width | ||
| self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames | ||
|
|
||
| def blend_v(self, a: jax.Array, b: jax.Array, blend_extent: int) -> jax.Array: |
There was a problem hiding this comment.
These for-loops are making JAX tracing time huge; Gemini suggests a more JAX-compatible solution for these three methods:
def blend_v(self, a: jax.Array, b: jax.Array, blend_extent: int) -> jax.Array:
blend_extent = min(a.shape[2], b.shape[2], blend_extent)
if blend_extent <= 0:
return b
# Create broadcastable blending weights: (1, 1, blend_extent, 1, 1)
y = jnp.arange(blend_extent, dtype=a.dtype).reshape(1, 1, -1, 1, 1)
val = a[:, :, -blend_extent:, :, :] * (1.0 - y / blend_extent) + \
b[:, :, :blend_extent, :, :] * (y / blend_extent)
return b.at[:, :, :blend_extent, :, :].set(val)
def blend_h(self, a: jax.Array, b: jax.Array, blend_extent: int) -> jax.Array:
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
if blend_extent <= 0:
return b
# Create broadcastable blending weights: (1, 1, 1, blend_extent, 1)
x = jnp.arange(blend_extent, dtype=a.dtype).reshape(1, 1, 1, -1, 1)
val = a[:, :, :, -blend_extent:, :] * (1.0 - x / blend_extent) + \
b[:, :, :, :blend_extent, :] * (x / blend_extent)
return b.at[:, :, :, :blend_extent, :].set(val)
def blend_t(self, a: jax.Array, b: jax.Array, blend_extent: int) -> jax.Array:
blend_extent = min(a.shape[1], b.shape[1], blend_extent)
if blend_extent <= 0:
return b
# Create broadcastable blending weights: (1, blend_extent, 1, 1, 1)
x = jnp.arange(blend_extent, dtype=a.dtype).reshape(1, -1, 1, 1, 1)
val = a[:, -blend_extent:, :, :, :] * (1.0 - x / blend_extent) + \
b[:, :blend_extent, :, :, :] * (x / blend_extent)
return b.at[:, :blend_extent, :, :, :].set(val)
| self.per_channel_scale2 = None | ||
|
|
||
| if timestep_conditioning: | ||
| self.scale_shift_table = nnx.Param(jax.random.normal(rngs.params(), (4, in_channels)) / (in_channels**0.5)) |
There was a problem hiding this comment.
Please add dtype to the jax.random.normal call, or it will default to float32 (or float64, depending on the x64 setting).
| # Compute mean of squared values along channel dimension. | ||
| mean_sq = jnp.mean(jnp.square(x), axis=channel_dim, keepdims=True) | ||
| rms = jnp.sqrt(mean_sq + self.eps) | ||
| return x / rms |
There was a problem hiding this comment.
better to use jax.lax.rsqrt:
return x * jax.lax.rsqrt(mean_sq + self.eps)
| ): | ||
| self.stride = _canonicalize_tuple(stride, 3, "stride") | ||
| self.group_size = (in_channels * self.stride[0] * self.stride[1] * self.stride[2]) // out_channels | ||
|
|
There was a problem hiding this comment.
nit: add assert self.group_size > 0?
d1d4bf6 to
8971606
Compare
This PR adds the Video VAE component for LTX-2. This implementation ensures numerical and shape parity with the reference PyTorch/Diffusers logic.
New files added:
autoencoder_kl_ltx2.py: Video VAE component for LTX-2
test_video_vae_ltx2.py: unit tests for the Video VAE