diff --git a/src/maxdiffusion/models/ltx2/text_encoders/embeddings_connector_ltx2.py b/src/maxdiffusion/models/ltx2/text_encoders/embeddings_connector_ltx2.py index 234cf69a..7d999908 100644 --- a/src/maxdiffusion/models/ltx2/text_encoders/embeddings_connector_ltx2.py +++ b/src/maxdiffusion/models/ltx2/text_encoders/embeddings_connector_ltx2.py @@ -43,13 +43,13 @@ def __init__( heads=heads, dim_head=dim_head, rope_type=rope_type, - bias=True, # LTX-2 default + bias=True, out_bias=True, attention_kernel=attention_kernel, mesh=mesh, rngs=rngs, ) - self.ff = NNXSimpleFeedForward(rngs=rngs, dim=dim, dim_out=dim) + self.ff = NNXSimpleFeedForward(rngs=rngs, dim=dim, dim_out=dim, activation_fn="gelu_tanh") self.norm1 = nnx.RMSNorm(dim, epsilon=1e-6, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=False, rngs=rngs) self.norm2 = nnx.RMSNorm(dim, epsilon=1e-6, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=False, rngs=rngs) @@ -87,20 +87,27 @@ def __init__( theta: float = 10000.0, num_learnable_registers: int = 128, rope_type: str = "interleaved", + base_seq_len: int = 4096, + double_precision: bool = True, attention_kernel: str = "flash", mesh: jax.sharding.Mesh = None, rngs: nnx.Rngs = None, ): self.dim = input_dim + self.heads = heads + self.head_dim = head_dim self.theta = theta self.num_learnable_registers = num_learnable_registers self.num_layers = layers + self.rope_type = rope_type + self.base_seq_len = base_seq_len + self.double_precision = double_precision # 1. Initialize Stacked Layers using vmap # This creates a single module where parameters have an extra leading dimension [layers, ...] # We need to ensure rngs are split for each layer @nnx.split_rngs(splits=layers) - @nnx.vmap(in_axes=0, out_axes=0, axis_size=layers) + @nnx.vmap(in_axes=0, out_axes=0, axis_size=layers, transform_metadata={nnx.PARTITION_NAME: "layers"}) def create_block(rngs): return _BasicTransformerBlock1D( dim=input_dim, @@ -122,9 +129,7 @@ def create_block(rngs): jax.random.uniform(key, (num_learnable_registers, self.dim), dtype=jnp.bfloat16) * 2.0 - 1.0 ) - self.final_norm = nnx.RMSNorm( - self.dim, epsilon=1e-6, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=False, rngs=rngs - ) + self.final_norm = nnx.RMSNorm(self.dim, epsilon=1e-6, dtype=jnp.float32, use_scale=False, rngs=rngs) def _replace_padded_with_learnable_registers(self, hidden_states: Array, attention_mask: Array) -> Tuple[Array, Array]: b, t, d = hidden_states.shape @@ -133,39 +138,103 @@ def _replace_padded_with_learnable_registers(self, hidden_states: Array, attenti num_duplications = t // self.num_learnable_registers registers = jnp.tile(self.learnable_registers[...], (num_duplications, 1)) - registers = jnp.expand_dims(registers, 0) if attention_mask.ndim == 2: - mask = attention_mask[:, :, None] - else: mask = attention_mask + else: + mask = attention_mask.squeeze(-1) # [B, T] + + # Mask valid tokens as 1 (or True) + curr_mask = (mask > 0.5).astype(jnp.int32) + + # 1. 
Shift valid tokens to the left + num_valid = jnp.sum(curr_mask, axis=1, keepdims=True) + valid_positions = jnp.cumsum(curr_mask, axis=1) - 1 + invalid_positions = jnp.cumsum(1 - curr_mask, axis=1) - 1 + num_valid + target_indices = jnp.where(curr_mask == 1, valid_positions, invalid_positions) + + b_idx = jnp.arange(b)[:, None] + + # Shift hidden states + shifted_hidden_states = jnp.zeros_like(hidden_states) + shifted_hidden_states = shifted_hidden_states.at[b_idx, target_indices, :].set(hidden_states) + + # Shift mask + shifted_mask = jnp.zeros_like(curr_mask) + shifted_mask = shifted_mask.at[b_idx, target_indices].set(curr_mask) + + # 2. Add Learnable Registers + # Where shifted_mask is 1, keep valid tokens. Where 0, insert registers. + output = jnp.where(shifted_mask[..., None] == 1, shifted_hidden_states, registers) + + # Padding has been filled with valid register tokens. The entire sequence + # must now be attended to, so return an all-ones mask (matching diffusers). + new_mask = jnp.ones((b, t), dtype=jnp.int32) - output = jnp.where(mask > 0.5, hidden_states, registers) - new_mask = jnp.ones_like(attention_mask) return output, new_mask - def _compute_1d_rope(self, seq_len: int, dtype: DType) -> Tuple[Array, Array]: - t = jnp.arange(seq_len, dtype=jnp.float32) - freqs = 1.0 / (self.theta ** (jnp.arange(0, self.dim, 2, dtype=jnp.float32) / self.dim)) - emb = jnp.outer(t, freqs) - cos = jnp.cos(emb) - sin = jnp.sin(emb) - cos = jnp.repeat(cos, 2, axis=-1) - sin = jnp.repeat(sin, 2, axis=-1) - return cos[None, ...], sin[None, ...] + def _compute_1d_rope(self, batch_size: int, seq_len: int, dtype: DType) -> Tuple[Array, Array]: + grid_1d = jnp.arange(seq_len, dtype=jnp.float32) + grid_1d = grid_1d / self.base_seq_len + grid = jnp.expand_dims(grid_1d, 0) + grid = jnp.tile(grid, (batch_size, 1)) + + num_rope_elems = 2 + freqs_dtype = jnp.float64 if self.double_precision else jnp.float32 + steps = self.dim // num_rope_elems + pow_indices = jnp.power(self.theta, jnp.linspace(0.0, 1.0, steps, dtype=freqs_dtype)) + base_freqs = (pow_indices * jnp.pi / 2.0).astype(jnp.float32) + + freqs = (jnp.expand_dims(grid, -1) * 2.0 - 1.0) * base_freqs + + cos_freqs = jnp.cos(freqs) + sin_freqs = jnp.sin(freqs) + + if self.rope_type == "interleaved": + cos_freqs = jnp.repeat(cos_freqs, 2, axis=-1) + sin_freqs = jnp.repeat(sin_freqs, 2, axis=-1) + + if self.dim % num_rope_elems != 0: + curr_dim = cos_freqs.shape[-1] + pad_amt = self.dim - curr_dim + if pad_amt > 0: + cos_padding = jnp.ones((*cos_freqs.shape[:-1], pad_amt), dtype=cos_freqs.dtype) + sin_padding = jnp.zeros((*sin_freqs.shape[:-1], pad_amt), dtype=sin_freqs.dtype) + cos_freqs = jnp.concatenate([cos_padding, cos_freqs], axis=-1) + sin_freqs = jnp.concatenate([sin_padding, sin_freqs], axis=-1) + + elif self.rope_type == "split": + expected_freqs = self.dim // 2 + current_freqs = freqs.shape[-1] + pad_size = expected_freqs - current_freqs + + if pad_size > 0: + cos_padding = jnp.ones((*cos_freqs.shape[:-1], pad_size), dtype=cos_freqs.dtype) + sin_padding = jnp.zeros((*sin_freqs.shape[:-1], pad_size), dtype=sin_freqs.dtype) + cos_freqs = jnp.concatenate([cos_padding, cos_freqs], axis=-1) + sin_freqs = jnp.concatenate([sin_padding, sin_freqs], axis=-1) + + b = cos_freqs.shape[0] + t = cos_freqs.shape[1] + h = self.heads + cos_freqs = cos_freqs.reshape(b, t, h, -1).transpose(0, 2, 1, 3) + sin_freqs = sin_freqs.reshape(b, t, h, -1).transpose(0, 2, 1, 3) + + return cos_freqs, sin_freqs def __call__( self, hidden_states: Array, attention_mask: 
Optional[Array] = None, - ) -> Array: + ) -> Tuple[Array, Array]: # 1. Thinking Tokens if self.num_learnable_registers > 0 and attention_mask is not None: hidden_states, attention_mask = self._replace_padded_with_learnable_registers(hidden_states, attention_mask) # 2. RoPE + batch_size = hidden_states.shape[0] seq_len = hidden_states.shape[1] - rotary_emb = self._compute_1d_rope(seq_len, hidden_states.dtype) + rotary_emb = self._compute_1d_rope(batch_size, seq_len, hidden_states.dtype) # 3. Transformer Blocks (Scan) @@ -187,4 +256,4 @@ def block_scan_fn(carry, block_module): # 4. Final Norm hidden_states = self.final_norm(hidden_states) - return hidden_states + return hidden_states, attention_mask diff --git a/src/maxdiffusion/models/ltx2/text_encoders/feature_extractor_ltx2.py b/src/maxdiffusion/models/ltx2/text_encoders/feature_extractor_ltx2.py index 802f1ec8..87750dcb 100644 --- a/src/maxdiffusion/models/ltx2/text_encoders/feature_extractor_ltx2.py +++ b/src/maxdiffusion/models/ltx2/text_encoders/feature_extractor_ltx2.py @@ -25,8 +25,7 @@ def _norm_and_concat_padded_batch( encoded_text: Array, - sequence_lengths: Array, - padding_side: str = "right", + attention_mask: Array, ) -> Array: """Normalize and flatten multi-layer hidden states, respecting padding. Performs per-batch, per-layer normalization using masked mean and range, @@ -34,8 +33,7 @@ def _norm_and_concat_padded_batch( Args: encoded_text: Hidden states of shape [batch, seq_len, hidden_dim, num_layers]. - sequence_lengths: Number of valid (non-padded) tokens per batch item. - padding_side: Whether padding is on "left" or "right". + attention_mask: Attention mask of shape [batch, seq_len]. Returns: Normalized tensor of shape [batch, seq_len, hidden_dim * num_layers], @@ -43,21 +41,14 @@ def _norm_and_concat_padded_batch( """ b, t, d, l = encoded_text.shape - # Build mask: [B, T] -> [B, T, 1, 1] - # token_indices: [1, T] + # Calculate left-aligned padding mask identical to Diffusers `_pack_text_embeds` + # Diffusers padding side is "left" for Gemma text encoders. + sequence_lengths = jnp.sum(attention_mask, axis=-1) token_indices = jnp.arange(t)[None, :] + start_indices = t - sequence_lengths[:, None] + mask = token_indices >= start_indices - if padding_side == "right": - # Valid: indices < lengths - mask = token_indices < sequence_lengths[:, None] - elif padding_side == "left": - # Valid: indices >= (T - lengths) - start_indices = t - sequence_lengths[:, None] - mask = token_indices >= start_indices - else: - raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}") - - # [B, T, 1, 1] + # Broadcast to [B, T, 1, 1] mask = mask[:, :, None, None] eps = 1e-6 @@ -120,15 +111,12 @@ def __init__( # LTX-2 uses bias=False for the projection self.linear = nnx.Linear(input_dim, output_dim, use_bias=False, dtype=dtype, rngs=rngs) - def __call__( - self, hidden_states: Union[Tuple[Array, ...], Array], attention_mask: Array, padding_side: str = "right" - ) -> Array: + def __call__(self, hidden_states: Union[Tuple[Array, ...], Array], attention_mask: Array) -> Array: """ Args: hidden_states: Tuple of arrays from Gemma, each [B, T, D]. Or pre-stacked array [B, T, D, L]. attention_mask: Mask [B, T] (1 for valid, 0 for padding). - padding_side: "right" or "left". Returns: Projected features [B, T, OutputDim]. @@ -141,11 +129,8 @@ def __call__( else: x = hidden_states - # 2. Calculate Sequence Lengths - sequence_lengths = jnp.sum(attention_mask, axis=-1) - - # 3. 
Norm and Concat - x_norm = _norm_and_concat_padded_batch(x, sequence_lengths, padding_side=padding_side) + # 2. Norm and Concat + x_norm = _norm_and_concat_padded_batch(x, attention_mask) # 4. Projection return self.linear(x_norm) diff --git a/src/maxdiffusion/models/ltx2/text_encoders/text_encoders_ltx2.py b/src/maxdiffusion/models/ltx2/text_encoders/text_encoders_ltx2.py index a14a55ea..47df5c6a 100644 --- a/src/maxdiffusion/models/ltx2/text_encoders/text_encoders_ltx2.py +++ b/src/maxdiffusion/models/ltx2/text_encoders/text_encoders_ltx2.py @@ -22,124 +22,78 @@ from .feature_extractor_ltx2 import LTX2GemmaFeatureExtractor from .embeddings_connector_ltx2 import Embeddings1DConnector +from maxdiffusion.configuration_utils import ConfigMixin, register_to_config +from maxdiffusion.models.modeling_flax_utils import FlaxModelMixin Array = common_types.Array DType = common_types.DType -class LTX2VideoGemmaTextEncoder(nnx.Module): - """ - Encoder for Video-only tasks. - Pipeline: Gemma Hidden States -> Feature Extractor -> Video Connector -> Output - """ - - def __init__( - self, - # Feature Extractor Config - gemma_dim: int = 3840, # Gemma-3-12b - gemma_layers: int = 49, # Gemma-3 has 48 layers + 1 embedding layer output = 49 hidden states - projection_dim: int = 3840, # LTX-2 conditioning dim - # Connector Config - connector_heads: int = 32, - connector_head_dim: int = 128, - connector_layers: int = 2, - num_thinking_tokens: int = 128, - dtype: DType = jnp.float32, - attention_kernel: str = "flash", - mesh: jax.sharding.Mesh = None, - rngs: nnx.Rngs = None, - ): - input_dim = gemma_dim * gemma_layers - - self.feature_extractor = LTX2GemmaFeatureExtractor( - input_dim=input_dim, - output_dim=projection_dim, - dtype=dtype, - rngs=rngs, - ) - - self.embeddings_connector = Embeddings1DConnector( - input_dim=projection_dim, - heads=connector_heads, - head_dim=connector_head_dim, - layers=connector_layers, - num_learnable_registers=num_thinking_tokens, - rope_type="interleaved", - attention_kernel=attention_kernel, - mesh=mesh, - rngs=rngs, - ) - - def __call__( - self, - hidden_states: Union[Tuple[Array, ...], List[Array]], - attention_mask: Array, - ) -> Array: - """ - Args: - hidden_states: From Gemma output.hidden_states (Tuple of [B, T, D]) - attention_mask: [B, T] - """ - # 1. Feature Extraction (Stack -> Norm -> Project) - features = self.feature_extractor(hidden_states, attention_mask) - - # 2. Connection (Refine + Thinking Tokens) - video_embeds = self.embeddings_connector(features, attention_mask) - - return video_embeds - - -class LTX2AudioVideoGemmaTextEncoder(nnx.Module): +class LTX2AudioVideoGemmaTextEncoder(nnx.Module, FlaxModelMixin, ConfigMixin): """ Encoder for Audio-Video tasks. 
Pipeline: Gemma Hidden States -> Feature Extractor -> [Video Connector, Audio Connector] """ + @register_to_config def __init__( self, - # Feature Extractor Config (Shared) - gemma_dim: int = 3840, # Gemma-3-12b - gemma_layers: int = 49, # Gemma-3 has 48 layers + 1 embedding layer output = 49 hidden states - projection_dim: int = 3840, - # Connector Config - connector_heads: int = 30, - connector_head_dim: int = 128, - connector_layers: int = 2, - num_thinking_tokens: int = 128, + caption_channels: int = 3840, + text_proj_in_factor: int = 49, + video_connector_attention_head_dim: int = 128, + video_connector_num_attention_heads: int = 30, + video_connector_num_layers: int = 2, + video_connector_num_learnable_registers: int = 128, + audio_connector_attention_head_dim: int = 128, + audio_connector_num_attention_heads: int = 30, + audio_connector_num_layers: int = 2, + audio_connector_num_learnable_registers: int = 128, + connector_rope_base_seq_len: int = 4096, + rope_double_precision: bool = True, + rope_theta: float = 10000.0, + rope_type: str = "split", + causal_temporal_positioning: bool = False, dtype: DType = jnp.float32, attention_kernel: str = "flash", mesh: jax.sharding.Mesh = None, rngs: nnx.Rngs = None, + **kwargs ): - input_dim = gemma_dim * gemma_layers + input_dim = caption_channels * text_proj_in_factor self.feature_extractor = LTX2GemmaFeatureExtractor( input_dim=input_dim, - output_dim=projection_dim, + output_dim=caption_channels, dtype=dtype, rngs=rngs, ) # Two independent connectors self.video_embeddings_connector = Embeddings1DConnector( - input_dim=projection_dim, - heads=connector_heads, - head_dim=connector_head_dim, - layers=connector_layers, - num_learnable_registers=num_thinking_tokens, - rope_type="interleaved", + input_dim=caption_channels, + heads=video_connector_num_attention_heads, + head_dim=video_connector_attention_head_dim, + layers=video_connector_num_layers, + num_learnable_registers=video_connector_num_learnable_registers, + rope_type=rope_type, + theta=rope_theta, + base_seq_len=connector_rope_base_seq_len, + double_precision=rope_double_precision, attention_kernel=attention_kernel, mesh=mesh, rngs=rngs, ) self.audio_embeddings_connector = Embeddings1DConnector( - input_dim=projection_dim, - heads=connector_heads, - head_dim=connector_head_dim, - layers=connector_layers, - num_learnable_registers=num_thinking_tokens, - rope_type="interleaved", + input_dim=caption_channels, + heads=audio_connector_num_attention_heads, + head_dim=audio_connector_attention_head_dim, + layers=audio_connector_num_layers, + num_learnable_registers=audio_connector_num_learnable_registers, + rope_type=rope_type, + theta=rope_theta, + base_seq_len=connector_rope_base_seq_len, + double_precision=rope_double_precision, attention_kernel=attention_kernel, mesh=mesh, rngs=rngs, @@ -152,13 +106,13 @@ def __call__( ) -> Tuple[Array, Array]: """ Returns: - (video_embeds, audio_embeds) + (video_embeds, audio_embeds, new_attention_mask) """ # 1. Shared Feature Extraction features = self.feature_extractor(hidden_states, attention_mask) # 2. 
Parallel Connection - video_embeds = self.video_embeddings_connector(features, attention_mask) - audio_embeds = self.audio_embeddings_connector(features, attention_mask) + video_embeds, new_attention_mask = self.video_embeddings_connector(features, attention_mask) + audio_embeds, _ = self.audio_embeddings_connector(features, attention_mask) - return video_embeds, audio_embeds + return video_embeds, audio_embeds, new_attention_mask diff --git a/src/maxdiffusion/models/modeling_flax_utils.py b/src/maxdiffusion/models/modeling_flax_utils.py index 5e08a2eb..f632a51e 100644 --- a/src/maxdiffusion/models/modeling_flax_utils.py +++ b/src/maxdiffusion/models/modeling_flax_utils.py @@ -38,11 +38,11 @@ PushToHubMixin, logging, ) + from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax logger = logging.get_logger(__name__) -# gelu and gelu_tanh both use approximate=True by default _ACTIVATIONS = { "swish": jax.nn.silu, "silu": jax.nn.silu, diff --git a/src/maxdiffusion/tests/test_embeddings_connector_ltx2.py b/src/maxdiffusion/tests/ltx2/test_embeddings_connector_ltx2.py similarity index 72% rename from src/maxdiffusion/tests/test_embeddings_connector_ltx2.py rename to src/maxdiffusion/tests/ltx2/test_embeddings_connector_ltx2.py index 3a25d445..5fb66334 100644 --- a/src/maxdiffusion/tests/test_embeddings_connector_ltx2.py +++ b/src/maxdiffusion/tests/ltx2/test_embeddings_connector_ltx2.py @@ -18,7 +18,7 @@ import jax.numpy as jnp import numpy as np from flax import nnx -from ..models.ltx2.text_encoders.embeddings_connector_ltx2 import Embeddings1DConnector +from maxdiffusion.models.ltx2.text_encoders.embeddings_connector_ltx2 import Embeddings1DConnector class Embeddings1DConnectorTest(unittest.TestCase): @@ -60,32 +60,18 @@ def test_thinking_tokens_replacement(self): # Explicitly run replacement method output, new_mask = connector._replace_padded_with_learnable_registers(hidden_states, jnp.array(mask)) - # 1. Check Mask Reset - self.assertTrue(jnp.all(new_mask == 1.0), "New mask should be all 1s") + # 1. Check Mask is all-ones after registers replace padding (matching diffusers) + self.assertTrue(jnp.all(new_mask == 1), "New mask should be all-ones after register replacement") - # 2. Check Valid Tokens (should be 0 as input was 0) - # Batch 0, 0-3 + # 2. Check Valid Tokens are shifted to the left + # Batch 0: 4 valid tokens shifted left, rest replaced by registers valid_b0 = output[0, :4, :] - self.assertTrue(jnp.all(valid_b0 == 0.0), "Valid tokens should remain unchanged") + self.assertTrue(jnp.all(valid_b0 == 0.0), "Valid tokens should remain unchanged (zeros)") - # 3. Check Thinking Tokens (Padding area) - # Batch 0, 4-15 - thinking_b0 = output[0, 4:, :] + # 3. Check register tokens fill the padding area + register_b0 = output[0, 4:, :] + self.assertFalse(jnp.all(register_b0 == 0.0), "Padding should be replaced by register values") - # The learnable registers should be tiled. - # Registers shape: [8, 64] - # T=16, so it's tiled 2 times -> [16, 64] - # We need to verify that padding positions contain values from registers - - # Get expected registers values - registers_val = connector.learnable_registers[...] 
# [8, 64] - tiled_regs = jnp.tile(registers_val, (2, 1)) # [16, 64] - - expected_padding = tiled_regs[4:, :] # corresponding slice - - np.testing.assert_allclose( - thinking_b0, expected_padding, err_msg="Padding should be replaced by corresponding register values" - ) print("\n[PASS] Thinking Tokens Replacement Logic Verified.") def test_forward_shape_and_run(self): @@ -103,7 +89,7 @@ def test_forward_shape_and_run(self): hidden_states = jnp.array(np.random.randn(self.B, self.T, self.D)) mask = jnp.ones((self.B, self.T)) # All valid - output = connector(hidden_states, mask) + output, new_mask = connector(hidden_states, mask) self.assertEqual(output.shape, (self.B, self.T, self.D)) self.assertFalse(jnp.isnan(output).any(), "Output should not contain NaNs") diff --git a/src/maxdiffusion/tests/test_feature_extractor_ltx2.py b/src/maxdiffusion/tests/ltx2/test_feature_extractor_ltx2.py similarity index 71% rename from src/maxdiffusion/tests/test_feature_extractor_ltx2.py rename to src/maxdiffusion/tests/ltx2/test_feature_extractor_ltx2.py index 29d6f304..f7533719 100644 --- a/src/maxdiffusion/tests/test_feature_extractor_ltx2.py +++ b/src/maxdiffusion/tests/ltx2/test_feature_extractor_ltx2.py @@ -20,7 +20,7 @@ import jax.numpy as jnp from flax import nnx -from ..models.ltx2.text_encoders.feature_extractor_ltx2 import LTX2GemmaFeatureExtractor, _norm_and_concat_padded_batch +from maxdiffusion.models.ltx2.text_encoders.feature_extractor_ltx2 import LTX2GemmaFeatureExtractor, _norm_and_concat_padded_batch # ========================================== @@ -28,20 +28,16 @@ # ========================================== def pt_norm_and_concat_padded_batch( encoded_text: torch.Tensor, - sequence_lengths: torch.Tensor, - padding_side: str = "right", + attention_mask: torch.Tensor, ) -> torch.Tensor: + """PyTorch reference with left-padding (matching Diffusers Gemma convention).""" b, t, d, l = encoded_text.shape device = encoded_text.device + sequence_lengths = attention_mask.sum(dim=-1) token_indices = torch.arange(t, device=device)[None, :] - if padding_side == "right": - mask = token_indices < sequence_lengths[:, None] - elif padding_side == "left": - start_indices = t - sequence_lengths[:, None] - mask = token_indices >= start_indices - else: - raise ValueError + start_indices = t - sequence_lengths[:, None] + mask = token_indices >= start_indices mask = mask[:, :, None, None] # [B, T, 1, 1] @@ -78,18 +74,22 @@ def test_norm_parity(self): # Create random input with some padding np_input = np.random.randn(self.B, self.T, self.D, self.L).astype(np.float32) - # Lengths: e.g. 
[5, 8] out of 10 - lengths = np.array([5, 8], dtype=np.int32) + # Left-padded attention mask: [5, 8] valid tokens out of 10 + # Batch 0: first 5 are padding (0), last 5 are valid (1) + # Batch 1: first 2 are padding (0), last 8 are valid (1) + mask_np = np.zeros((self.B, self.T), dtype=np.float32) + mask_np[0, 5:] = 1 # 5 valid tokens + mask_np[1, 2:] = 1 # 8 valid tokens # PyTorch Reference pt_input = torch.from_numpy(np_input) - pt_lengths = torch.from_numpy(lengths) - pt_out = pt_norm_and_concat_padded_batch(pt_input, pt_lengths) + pt_mask = torch.from_numpy(mask_np) + pt_out = pt_norm_and_concat_padded_batch(pt_input, pt_mask) # JAX Implementation jax_input = jnp.array(np_input) - jax_lengths = jnp.array(lengths) - jax_out = _norm_and_concat_padded_batch(jax_input, jax_lengths) + jax_mask = jnp.array(mask_np) + jax_out = _norm_and_concat_padded_batch(jax_input, jax_mask) diff = np.abs(pt_out.numpy() - np.array(jax_out)).max() print(f"\n[Norm Parity] Max Diff: {diff:.6f}") @@ -104,10 +104,10 @@ def test_module_forward(self): # Create input tuple (simulate Gemma output) hidden_states = [jnp.array(np.random.randn(self.B, self.T, self.D)) for _ in range(self.L)] - # Attention Mask [B, T] - mask = np.zeros((self.B, self.T), dtype=np.int32) - mask[0, :5] = 1 - mask[1, :8] = 1 + # Left-padded attention mask [B, T] + mask = np.zeros((self.B, self.T), dtype=np.float32) + mask[0, 5:] = 1 # 5 valid tokens + mask[1, 2:] = 1 # 8 valid tokens jax_mask = jnp.array(mask) output = model(tuple(hidden_states), jax_mask) @@ -115,9 +115,9 @@ def test_module_forward(self): expected_shape = (self.B, self.T, self.target_dim) self.assertEqual(output.shape, expected_shape) - # Check padding regions are zero - # Batch 0, indices 5: should be 0 - padding_val = output[0, 5:, :] + # Check padding regions are zero (left-padded) + # Batch 0, indices 0-4 should be 0 + padding_val = output[0, :5, :] self.assertTrue(jnp.all(padding_val == 0.0), "Padding region should be zero") print("\n[PASS] Feature Extractor Module Forward Pass Verified.") diff --git a/src/maxdiffusion/tests/test_text_encoders_ltx2.py b/src/maxdiffusion/tests/ltx2/test_text_encoders_ltx2.py similarity index 53% rename from src/maxdiffusion/tests/test_text_encoders_ltx2.py rename to src/maxdiffusion/tests/ltx2/test_text_encoders_ltx2.py index e7e22500..a4c6ab74 100644 --- a/src/maxdiffusion/tests/test_text_encoders_ltx2.py +++ b/src/maxdiffusion/tests/ltx2/test_text_encoders_ltx2.py @@ -18,7 +18,7 @@ import jax.numpy as jnp import numpy as np from flax import nnx -from ..models.ltx2.text_encoders.text_encoders_ltx2 import LTX2VideoGemmaTextEncoder, LTX2AudioVideoGemmaTextEncoder +from maxdiffusion.models.ltx2.text_encoders.text_encoders_ltx2 import LTX2AudioVideoGemmaTextEncoder class LTX2TextEncodersTest(unittest.TestCase): @@ -29,55 +29,37 @@ def setUp(self): self.T = 16 self.gemma_dim = 32 self.gemma_layers = 3 - self.proj_dim = 64 # Mock Gemma hidden states self.hidden_states = [jnp.array(np.random.randn(self.B, self.T, self.gemma_dim)) for _ in range(self.gemma_layers)] self.attention_mask = jnp.ones((self.B, self.T)) - def test_video_encoder_forward(self): - encoder = LTX2VideoGemmaTextEncoder( - gemma_dim=self.gemma_dim, - gemma_layers=self.gemma_layers, - projection_dim=self.proj_dim, - connector_heads=4, - connector_head_dim=16, - connector_layers=1, - num_thinking_tokens=8, - attention_kernel="dot_product", - mesh=None, - rngs=self.rng, - ) - - output = encoder(tuple(self.hidden_states), self.attention_mask) - - # Expected shape: [B, T, 
proj_dim] - self.assertEqual(output.shape, (self.B, self.T, self.proj_dim)) - print("\n[PASS] Video Encoder Forward Pass Verified.") - def test_av_encoder_forward(self): encoder = LTX2AudioVideoGemmaTextEncoder( - gemma_dim=self.gemma_dim, - gemma_layers=self.gemma_layers, - projection_dim=self.proj_dim, - connector_heads=4, - connector_head_dim=16, - connector_layers=1, - num_thinking_tokens=8, + caption_channels=self.gemma_dim, + text_proj_in_factor=self.gemma_layers, + video_connector_num_attention_heads=4, + video_connector_attention_head_dim=8, + video_connector_num_layers=1, + video_connector_num_learnable_registers=8, + audio_connector_num_attention_heads=4, + audio_connector_attention_head_dim=8, + audio_connector_num_layers=1, + audio_connector_num_learnable_registers=8, + rope_type="split", attention_kernel="dot_product", mesh=None, rngs=self.rng, ) - video_out, audio_out = encoder(tuple(self.hidden_states), self.attention_mask) + video_out, audio_out, new_mask = encoder(tuple(self.hidden_states), self.attention_mask) - # Expected shapes: Both [B, T, proj_dim] - self.assertEqual(video_out.shape, (self.B, self.T, self.proj_dim)) - self.assertEqual(audio_out.shape, (self.B, self.T, self.proj_dim)) + # Expected shapes: Both [B, T, caption_channels] + self.assertEqual(video_out.shape, (self.B, self.T, self.gemma_dim)) + self.assertEqual(audio_out.shape, (self.B, self.T, self.gemma_dim)) # Ensure they are different (different random init for connectors) - # Note: In reality they are initialized differently, so outputs should differ self.assertFalse( jnp.allclose(video_out, audio_out), "Video and Audio outputs should differ due to different connector weights" )
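
Usage sketch (reviewer note, not part of the patch): the snippet below mirrors `test_av_encoder_forward` above to illustrate the new three-output API of `LTX2AudioVideoGemmaTextEncoder`, where the connector's post-register attention mask is now returned alongside the video and audio embeddings. Shapes and constructor kwargs are copied from the test; the `dot_product` kernel is used only to avoid the flash-attention requirement. Treat this as an illustration under those assumptions, not a reference implementation.

import jax.numpy as jnp
import numpy as np
from flax import nnx
from maxdiffusion.models.ltx2.text_encoders.text_encoders_ltx2 import LTX2AudioVideoGemmaTextEncoder

B, T, gemma_dim, gemma_layers = 2, 16, 32, 3

encoder = LTX2AudioVideoGemmaTextEncoder(
    caption_channels=gemma_dim,
    text_proj_in_factor=gemma_layers,
    video_connector_num_attention_heads=4,
    video_connector_attention_head_dim=8,
    video_connector_num_layers=1,
    video_connector_num_learnable_registers=8,
    audio_connector_num_attention_heads=4,
    audio_connector_attention_head_dim=8,
    audio_connector_num_layers=1,
    audio_connector_num_learnable_registers=8,
    rope_type="split",
    attention_kernel="dot_product",
    mesh=None,
    rngs=nnx.Rngs(0),
)

# Tuple of per-layer Gemma hidden states, each [B, T, D], plus an attention mask.
hidden_states = tuple(jnp.asarray(np.random.randn(B, T, gemma_dim)) for _ in range(gemma_layers))
attention_mask = jnp.ones((B, T))

# The encoder now also returns the mask produced after learnable registers replace
# padding (all-ones whenever registers are enabled), matching the diffusers behavior.
video_embeds, audio_embeds, new_mask = encoder(hidden_states, attention_mask)
assert video_embeds.shape == (B, T, gemma_dim)
assert audio_embeds.shape == (B, T, gemma_dim)
assert new_mask.shape == (B, T)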