[LTX2] LTX2 Video VAE implementation #346

Open
prishajain1 wants to merge 1 commit into main from prisha/ltx2_vae

Conversation

@prishajain1 Collaborator

This PR adds the Video VAE component for LTX-2. The implementation ensures numerical and shape parity with the reference PyTorch/Diffusers logic.

New files added:

  • autoencoder_kl_ltx2.py : Video VAE component for LTX2
  • test_video_vae_ltx2.py : unit tests for the Video VAE

@prishajain1 prishajain1 requested a review from entrpn as a code owner March 5, 2026 05:58
github-actions bot commented Mar 5, 2026

@prishajain1 prishajain1 changed the title LTX2 Video VAE implementation [LTX2] LTX2 Video VAE implementation Mar 6, 2026
@@ -0,0 +1,284 @@
"""
Collaborator

Can we add a test for end-to-end encode/decode correctness? Like this one: https://github.com/AI-Hypercomputer/maxdiffusion/blob/main/src/maxdiffusion/tests/wan_vae_test.py#L488

It doesn't need to run in the GitHub runner, but at least it can be run manually for validation.
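A minimal sketch of what such a round-trip check could look like. The real test would load pretrained VAE weights and call the model's actual `encode`/`decode`; here an identity stand-in is passed instead so only the SSIM machinery is illustrated, and `ssim_global` is a simplified single-window SSIM, not the windowed version a library like scikit-image computes.

```python
# Hypothetical sketch of an end-to-end encode/decode correctness check.
# A real test would plug in the VAE's encode/decode; an identity pair
# stands in here so the scoring logic can run without model weights.
import numpy as np

def ssim_global(a: np.ndarray, b: np.ndarray, data_range: float = 1.0) -> float:
    """Single-window SSIM over the whole array (no sliding window)."""
    c1 = (0.01 * data_range) ** 2
    c2 = (0.03 * data_range) ** 2
    mu_a, mu_b = a.mean(), b.mean()
    var_a, var_b = a.var(), b.var()
    cov = ((a - mu_a) * (b - mu_b)).mean()
    return float(((2 * mu_a * mu_b + c1) * (2 * cov + c2))
                 / ((mu_a**2 + mu_b**2 + c1) * (var_a + var_b + c2)))

def roundtrip_ssim(video: np.ndarray, encode, decode) -> float:
    """Score reconstruction quality of decode(encode(video)) against the input."""
    return ssim_global(video, np.asarray(decode(encode(video))))

rng = np.random.default_rng(0)
video = rng.uniform(size=(1, 9, 32, 32, 3)).astype(np.float32)  # (B, T, H, W, C)
score = roundtrip_ssim(video, encode=lambda v: v, decode=lambda z: z)
```

With the identity stand-in the score is exactly 1.0; with real VAE weights one would assert a threshold (e.g. `score > 0.99`) instead.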

Collaborator Author

Ran it and achieved an SSIM score of 0.9985.
I will add this test in the upcoming t2v pipeline PR, since it requires loading VAE weights.

@prishajain1 prishajain1 requested a review from entrpn March 11, 2026 18:30
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames

def blend_v(self, a: jax.Array, b: jax.Array, blend_extent: int) -> jax.Array:
Collaborator

These for-loops make JAX tracing time huge; Gemini suggests a more JAX-compatible solution for these three methods:

def blend_v(self, a: jax.Array, b: jax.Array, blend_extent: int) -> jax.Array:
    blend_extent = min(a.shape[2], b.shape[2], blend_extent)
    if blend_extent <= 0:
        return b
        
    # Create broadcastable blending weights: (1, 1, blend_extent, 1, 1)
    y = jnp.arange(blend_extent, dtype=a.dtype).reshape(1, 1, -1, 1, 1)
    
    val = a[:, :, -blend_extent:, :, :] * (1.0 - y / blend_extent) + \
          b[:, :, :blend_extent, :, :] * (y / blend_extent)
          
    return b.at[:, :, :blend_extent, :, :].set(val)

def blend_h(self, a: jax.Array, b: jax.Array, blend_extent: int) -> jax.Array:
    blend_extent = min(a.shape[3], b.shape[3], blend_extent)
    if blend_extent <= 0:
        return b
        
    # Create broadcastable blending weights: (1, 1, 1, blend_extent, 1)
    x = jnp.arange(blend_extent, dtype=a.dtype).reshape(1, 1, 1, -1, 1)
    
    val = a[:, :, :, -blend_extent:, :] * (1.0 - x / blend_extent) + \
          b[:, :, :, :blend_extent, :] * (x / blend_extent)
          
    return b.at[:, :, :, :blend_extent, :].set(val)

def blend_t(self, a: jax.Array, b: jax.Array, blend_extent: int) -> jax.Array:
    blend_extent = min(a.shape[1], b.shape[1], blend_extent)
    if blend_extent <= 0:
        return b
        
    # Create broadcastable blending weights: (1, blend_extent, 1, 1, 1)
    x = jnp.arange(blend_extent, dtype=a.dtype).reshape(1, -1, 1, 1, 1)
    
    val = a[:, -blend_extent:, :, :, :] * (1.0 - x / blend_extent) + \
          b[:, :blend_extent, :, :, :] * (x / blend_extent)
          
    return b.at[:, :blend_extent, :, :, :].set(val)
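One way to validate the vectorized rewrite is to compare it against a per-row loop of the kind the suggestion replaces. The loop version below is a hypothetical reconstruction of the original pattern (one `.at[].set()` per blended row), written as free functions with `self` dropped, assuming a `(B, T, H, W, C)` layout where `blend_v` blends along the H axis:

```python
import jax.numpy as jnp

def blend_v(a, b, blend_extent):
    # Vectorized: one broadcasted update instead of blend_extent traced ops.
    blend_extent = min(a.shape[2], b.shape[2], blend_extent)
    if blend_extent <= 0:
        return b
    y = jnp.arange(blend_extent, dtype=a.dtype).reshape(1, 1, -1, 1, 1)
    val = a[:, :, -blend_extent:, :, :] * (1.0 - y / blend_extent) + \
          b[:, :, :blend_extent, :, :] * (y / blend_extent)
    return b.at[:, :, :blend_extent, :, :].set(val)

def blend_v_loop(a, b, blend_extent):
    # Reference loop form: each iteration adds traced ops, inflating trace time.
    blend_extent = min(a.shape[2], b.shape[2], blend_extent)
    for y in range(blend_extent):
        row = a[:, :, -blend_extent + y, :, :] * (1 - y / blend_extent) + \
              b[:, :, y, :, :] * (y / blend_extent)
        b = b.at[:, :, y, :, :].set(row)
    return b

a = jnp.arange(2 * 3 * 8 * 4 * 5, dtype=jnp.float32).reshape(2, 3, 8, 4, 5)
b = a + 100.0
match = jnp.allclose(blend_v(a, b, 4), blend_v_loop(a, b, 4))
```

The same comparison applies to `blend_h` and `blend_t` with the axis shifted.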

self.per_channel_scale2 = None

if timestep_conditioning:
self.scale_shift_table = nnx.Param(jax.random.normal(rngs.params(), (4, in_channels)) / (in_channels**0.5))
Collaborator

Please add dtype to the jax.random.normal call, or it will default to float32 (or float64 when x64 is enabled).
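A short illustration of the point, using a standalone key rather than the module's `rngs.params()` (which is specific to the flax.nnx setup in the PR):

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
in_channels = 8

# Without an explicit dtype, jax.random.normal returns float32
# (or float64 when jax_enable_x64 is set). Passing dtype pins it.
table = jax.random.normal(key, (4, in_channels), dtype=jnp.bfloat16) / (in_channels**0.5)
```

Dividing by the Python scalar `in_channels**0.5` keeps the array's dtype, since Python scalars are weakly typed in JAX's promotion rules.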

# Compute mean of squared values along channel dimension.
mean_sq = jnp.mean(jnp.square(x), axis=channel_dim, keepdims=True)
rms = jnp.sqrt(mean_sq + self.eps)
return x / rms
Collaborator

better to use jax.lax.rsqrt:

return x * jax.lax.rsqrt(mean_sq + self.eps)
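A quick check that the two forms agree numerically. The helper names below are illustrative, not the module's actual method names; both compute the same RMS normalization, with `jax.lax.rsqrt` replacing the divide-by-sqrt:

```python
import jax
import jax.numpy as jnp

def rms_norm_div(x, eps=1e-6, channel_dim=-1):
    # Original form: divide by the root-mean-square.
    mean_sq = jnp.mean(jnp.square(x), axis=channel_dim, keepdims=True)
    return x / jnp.sqrt(mean_sq + eps)

def rms_norm_rsqrt(x, eps=1e-6, channel_dim=-1):
    # Suggested form: multiply by the reciprocal square root.
    mean_sq = jnp.mean(jnp.square(x), axis=channel_dim, keepdims=True)
    return x * jax.lax.rsqrt(mean_sq + eps)

x = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 8))
close = jnp.allclose(rms_norm_div(x), rms_norm_rsqrt(x), rtol=1e-5, atol=1e-5)
```

`rsqrt` maps to a single hardware instruction on most accelerators, avoiding the separate sqrt and divide.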

):
self.stride = _canonicalize_tuple(stride, 3, "stride")
self.group_size = (in_channels * self.stride[0] * self.stride[1] * self.stride[2]) // out_channels

Collaborator

nit: add assert self.group_size > 0?
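For context, the guard would catch configurations where integer division silently yields zero. A plain-Python sketch of the failure mode the assert protects against (values here are hypothetical, not from the PR):

```python
# group_size uses integer division, so it silently becomes 0 whenever
# out_channels exceeds in_channels * prod(stride) -- the assert makes
# that misconfiguration fail loudly at construction time.
stride = (1, 2, 2)
in_channels, out_channels = 4, 8
group_size = (in_channels * stride[0] * stride[1] * stride[2]) // out_channels
assert group_size > 0, "out_channels too large for in_channels * prod(stride)"
```

Here `4 * 1 * 2 * 2 = 16` channels-per-position split over 8 output channels gives a valid group size of 2; with `out_channels = 32` the assert would fire.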
