Skip to content

Commit 2d1794e

Browse files
committed
Refactor LTX2 text encoders: replace Video/AV classes with unified EmbeddingsProcessor; move tests to tests/ltx2/
Signed-off-by: James Huang <syhuang1201@gmail.com>
1 parent 02dbc99 commit 2d1794e

File tree

7 files changed

+197
-221
lines changed

7 files changed

+197
-221
lines changed

src/maxdiffusion/models/ltx2/text_encoders/embeddings_connector_ltx2.py

Lines changed: 92 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,13 @@ def __init__(
4343
heads=heads,
4444
dim_head=dim_head,
4545
rope_type=rope_type,
46-
bias=True, # LTX-2 default
46+
bias=True,
4747
out_bias=True,
4848
attention_kernel=attention_kernel,
4949
mesh=mesh,
5050
rngs=rngs,
5151
)
52-
self.ff = NNXSimpleFeedForward(rngs=rngs, dim=dim, dim_out=dim)
52+
self.ff = NNXSimpleFeedForward(rngs=rngs, dim=dim, dim_out=dim, activation_fn="gelu_tanh")
5353
self.norm1 = nnx.RMSNorm(dim, epsilon=1e-6, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=False, rngs=rngs)
5454
self.norm2 = nnx.RMSNorm(dim, epsilon=1e-6, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=False, rngs=rngs)
5555

@@ -87,20 +87,27 @@ def __init__(
8787
theta: float = 10000.0,
8888
num_learnable_registers: int = 128,
8989
rope_type: str = "interleaved",
90+
base_seq_len: int = 4096,
91+
double_precision: bool = True,
9092
attention_kernel: str = "flash",
9193
mesh: jax.sharding.Mesh = None,
9294
rngs: nnx.Rngs = None,
9395
):
9496
self.dim = input_dim
97+
self.heads = heads
98+
self.head_dim = head_dim
9599
self.theta = theta
96100
self.num_learnable_registers = num_learnable_registers
97101
self.num_layers = layers
102+
self.rope_type = rope_type
103+
self.base_seq_len = base_seq_len
104+
self.double_precision = double_precision
98105

99106
# 1. Initialize Stacked Layers using vmap
100107
# This creates a single module where parameters have an extra leading dimension [layers, ...]
101108
# We need to ensure rngs are split for each layer
102109
@nnx.split_rngs(splits=layers)
103-
@nnx.vmap(in_axes=0, out_axes=0, axis_size=layers)
110+
@nnx.vmap(in_axes=0, out_axes=0, axis_size=layers, transform_metadata={nnx.PARTITION_NAME: "layers"})
104111
def create_block(rngs):
105112
return _BasicTransformerBlock1D(
106113
dim=input_dim,
@@ -122,9 +129,7 @@ def create_block(rngs):
122129
jax.random.uniform(key, (num_learnable_registers, self.dim), dtype=jnp.bfloat16) * 2.0 - 1.0
123130
)
124131

125-
self.final_norm = nnx.RMSNorm(
126-
self.dim, epsilon=1e-6, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=False, rngs=rngs
127-
)
132+
self.final_norm = nnx.RMSNorm(self.dim, epsilon=1e-6, dtype=jnp.float32, use_scale=False, rngs=rngs)
128133

129134
def _replace_padded_with_learnable_registers(self, hidden_states: Array, attention_mask: Array) -> Tuple[Array, Array]:
130135
b, t, d = hidden_states.shape
@@ -133,39 +138,103 @@ def _replace_padded_with_learnable_registers(self, hidden_states: Array, attenti
133138

134139
num_duplications = t // self.num_learnable_registers
135140
registers = jnp.tile(self.learnable_registers[...], (num_duplications, 1))
136-
registers = jnp.expand_dims(registers, 0)
137141

138142
if attention_mask.ndim == 2:
139-
mask = attention_mask[:, :, None]
140-
else:
141143
mask = attention_mask
144+
else:
145+
mask = attention_mask.squeeze(-1) # [B, T]
146+
147+
# Mask valid tokens as 1 (or True)
148+
curr_mask = (mask > 0.5).astype(jnp.int32)
149+
150+
# 1. Shift valid tokens to the left
151+
num_valid = jnp.sum(curr_mask, axis=1, keepdims=True)
152+
valid_positions = jnp.cumsum(curr_mask, axis=1) - 1
153+
invalid_positions = jnp.cumsum(1 - curr_mask, axis=1) - 1 + num_valid
154+
target_indices = jnp.where(curr_mask == 1, valid_positions, invalid_positions)
155+
156+
b_idx = jnp.arange(b)[:, None]
157+
158+
# Shift hidden states
159+
shifted_hidden_states = jnp.zeros_like(hidden_states)
160+
shifted_hidden_states = shifted_hidden_states.at[b_idx, target_indices, :].set(hidden_states)
161+
162+
# Shift mask
163+
shifted_mask = jnp.zeros_like(curr_mask)
164+
shifted_mask = shifted_mask.at[b_idx, target_indices].set(curr_mask)
165+
166+
# 2. Add Learnable Registers
167+
# Where shifted_mask is 1, keep valid tokens. Where 0, insert registers.
168+
output = jnp.where(shifted_mask[..., None] == 1, shifted_hidden_states, registers)
169+
170+
# Padding has been filled with valid register tokens. The entire sequence
171+
# must now be attended to, so return an all-ones mask (matching diffusers).
172+
new_mask = jnp.ones((b, t), dtype=jnp.int32)
142173

143-
output = jnp.where(mask > 0.5, hidden_states, registers)
144-
new_mask = jnp.ones_like(attention_mask)
145174
return output, new_mask
146175

147-
def _compute_1d_rope(self, seq_len: int, dtype: DType) -> Tuple[Array, Array]:
148-
t = jnp.arange(seq_len, dtype=jnp.float32)
149-
freqs = 1.0 / (self.theta ** (jnp.arange(0, self.dim, 2, dtype=jnp.float32) / self.dim))
150-
emb = jnp.outer(t, freqs)
151-
cos = jnp.cos(emb)
152-
sin = jnp.sin(emb)
153-
cos = jnp.repeat(cos, 2, axis=-1)
154-
sin = jnp.repeat(sin, 2, axis=-1)
155-
return cos[None, ...], sin[None, ...]
176+
def _compute_1d_rope(self, batch_size: int, seq_len: int, dtype: DType) -> Tuple[Array, Array]:
177+
grid_1d = jnp.arange(seq_len, dtype=jnp.float32)
178+
grid_1d = grid_1d / self.base_seq_len
179+
grid = jnp.expand_dims(grid_1d, 0)
180+
grid = jnp.tile(grid, (batch_size, 1))
181+
182+
num_rope_elems = 2
183+
freqs_dtype = jnp.float64 if self.double_precision else jnp.float32
184+
steps = self.dim // num_rope_elems
185+
pow_indices = jnp.power(self.theta, jnp.linspace(0.0, 1.0, steps, dtype=freqs_dtype))
186+
base_freqs = (pow_indices * jnp.pi / 2.0).astype(jnp.float32)
187+
188+
freqs = (jnp.expand_dims(grid, -1) * 2.0 - 1.0) * base_freqs
189+
190+
cos_freqs = jnp.cos(freqs)
191+
sin_freqs = jnp.sin(freqs)
192+
193+
if self.rope_type == "interleaved":
194+
cos_freqs = jnp.repeat(cos_freqs, 2, axis=-1)
195+
sin_freqs = jnp.repeat(sin_freqs, 2, axis=-1)
196+
197+
if self.dim % num_rope_elems != 0:
198+
curr_dim = cos_freqs.shape[-1]
199+
pad_amt = self.dim - curr_dim
200+
if pad_amt > 0:
201+
cos_padding = jnp.ones((*cos_freqs.shape[:-1], pad_amt), dtype=cos_freqs.dtype)
202+
sin_padding = jnp.zeros((*sin_freqs.shape[:-1], pad_amt), dtype=sin_freqs.dtype)
203+
cos_freqs = jnp.concatenate([cos_padding, cos_freqs], axis=-1)
204+
sin_freqs = jnp.concatenate([sin_padding, sin_freqs], axis=-1)
205+
206+
elif self.rope_type == "split":
207+
expected_freqs = self.dim // 2
208+
current_freqs = freqs.shape[-1]
209+
pad_size = expected_freqs - current_freqs
210+
211+
if pad_size > 0:
212+
cos_padding = jnp.ones((*cos_freqs.shape[:-1], pad_size), dtype=cos_freqs.dtype)
213+
sin_padding = jnp.zeros((*sin_freqs.shape[:-1], pad_size), dtype=sin_freqs.dtype)
214+
cos_freqs = jnp.concatenate([cos_padding, cos_freqs], axis=-1)
215+
sin_freqs = jnp.concatenate([sin_padding, sin_freqs], axis=-1)
216+
217+
b = cos_freqs.shape[0]
218+
t = cos_freqs.shape[1]
219+
h = self.heads
220+
cos_freqs = cos_freqs.reshape(b, t, h, -1).transpose(0, 2, 1, 3)
221+
sin_freqs = sin_freqs.reshape(b, t, h, -1).transpose(0, 2, 1, 3)
222+
223+
return cos_freqs, sin_freqs
156224

157225
def __call__(
158226
self,
159227
hidden_states: Array,
160228
attention_mask: Optional[Array] = None,
161-
) -> Array:
229+
) -> Tuple[Array, Array]:
162230
# 1. Thinking Tokens
163231
if self.num_learnable_registers > 0 and attention_mask is not None:
164232
hidden_states, attention_mask = self._replace_padded_with_learnable_registers(hidden_states, attention_mask)
165233

166234
# 2. RoPE
235+
batch_size = hidden_states.shape[0]
167236
seq_len = hidden_states.shape[1]
168-
rotary_emb = self._compute_1d_rope(seq_len, hidden_states.dtype)
237+
rotary_emb = self._compute_1d_rope(batch_size, seq_len, hidden_states.dtype)
169238

170239
# 3. Transformer Blocks (Scan)
171240

@@ -187,4 +256,4 @@ def block_scan_fn(carry, block_module):
187256
# 4. Final Norm
188257
hidden_states = self.final_norm(hidden_states)
189258

190-
return hidden_states
259+
return hidden_states, attention_mask

src/maxdiffusion/models/ltx2/text_encoders/feature_extractor_ltx2.py

Lines changed: 11 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -25,39 +25,30 @@
2525

2626
def _norm_and_concat_padded_batch(
2727
encoded_text: Array,
28-
sequence_lengths: Array,
29-
padding_side: str = "right",
28+
attention_mask: Array,
3029
) -> Array:
3130
"""Normalize and flatten multi-layer hidden states, respecting padding.
3231
Performs per-batch, per-layer normalization using masked mean and range,
3332
then concatenates across the layer dimension.
3433
3534
Args:
3635
encoded_text: Hidden states of shape [batch, seq_len, hidden_dim, num_layers].
37-
sequence_lengths: Number of valid (non-padded) tokens per batch item.
38-
padding_side: Whether padding is on "left" or "right".
36+
attention_mask: Attention mask of shape [batch, seq_len].
3937
4038
Returns:
4139
Normalized tensor of shape [batch, seq_len, hidden_dim * num_layers],
4240
with padded positions zeroed out.
4341
"""
4442
b, t, d, l = encoded_text.shape
4543

46-
# Build mask: [B, T] -> [B, T, 1, 1]
47-
# token_indices: [1, T]
44+
# Calculate left-aligned padding mask identical to Diffusers `_pack_text_embeds`
45+
# Diffusers padding side is "left" for Gemma text encoders.
46+
sequence_lengths = jnp.sum(attention_mask, axis=-1)
4847
token_indices = jnp.arange(t)[None, :]
48+
start_indices = t - sequence_lengths[:, None]
49+
mask = token_indices >= start_indices
4950

50-
if padding_side == "right":
51-
# Valid: indices < lengths
52-
mask = token_indices < sequence_lengths[:, None]
53-
elif padding_side == "left":
54-
# Valid: indices >= (T - lengths)
55-
start_indices = t - sequence_lengths[:, None]
56-
mask = token_indices >= start_indices
57-
else:
58-
raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
59-
60-
# [B, T, 1, 1]
51+
# Broadcast to [B, T, 1, 1]
6152
mask = mask[:, :, None, None]
6253

6354
eps = 1e-6
@@ -120,15 +111,12 @@ def __init__(
120111
# LTX-2 uses bias=False for the projection
121112
self.linear = nnx.Linear(input_dim, output_dim, use_bias=False, dtype=dtype, rngs=rngs)
122113

123-
def __call__(
124-
self, hidden_states: Union[Tuple[Array, ...], Array], attention_mask: Array, padding_side: str = "right"
125-
) -> Array:
114+
def __call__(self, hidden_states: Union[Tuple[Array, ...], Array], attention_mask: Array) -> Array:
126115
"""
127116
Args:
128117
hidden_states: Tuple of arrays from Gemma, each [B, T, D].
129118
Or pre-stacked array [B, T, D, L].
130119
attention_mask: Mask [B, T] (1 for valid, 0 for padding).
131-
padding_side: "right" or "left".
132120
133121
Returns:
134122
Projected features [B, T, OutputDim].
@@ -141,11 +129,8 @@ def __call__(
141129
else:
142130
x = hidden_states
143131

144-
# 2. Calculate Sequence Lengths
145-
sequence_lengths = jnp.sum(attention_mask, axis=-1)
146-
147-
# 3. Norm and Concat
148-
x_norm = _norm_and_concat_padded_batch(x, sequence_lengths, padding_side=padding_side)
132+
# 2. Norm and Concat
133+
x_norm = _norm_and_concat_padded_batch(x, attention_mask)
149134

150135
# 4. Projection
151136
return self.linear(x_norm)

0 commit comments

Comments
 (0)