Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,13 @@ def __init__(
heads=heads,
dim_head=dim_head,
rope_type=rope_type,
bias=True, # LTX-2 default
bias=True,
out_bias=True,
attention_kernel=attention_kernel,
mesh=mesh,
rngs=rngs,
)
self.ff = NNXSimpleFeedForward(rngs=rngs, dim=dim, dim_out=dim)
self.ff = NNXSimpleFeedForward(rngs=rngs, dim=dim, dim_out=dim, activation_fn="gelu_tanh")
self.norm1 = nnx.RMSNorm(dim, epsilon=1e-6, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=False, rngs=rngs)
self.norm2 = nnx.RMSNorm(dim, epsilon=1e-6, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=False, rngs=rngs)

Expand Down Expand Up @@ -87,20 +87,27 @@ def __init__(
theta: float = 10000.0,
num_learnable_registers: int = 128,
rope_type: str = "interleaved",
base_seq_len: int = 4096,
double_precision: bool = True,
attention_kernel: str = "flash",
mesh: jax.sharding.Mesh = None,
rngs: nnx.Rngs = None,
):
self.dim = input_dim
self.heads = heads
self.head_dim = head_dim
self.theta = theta
self.num_learnable_registers = num_learnable_registers
self.num_layers = layers
self.rope_type = rope_type
self.base_seq_len = base_seq_len
self.double_precision = double_precision

# 1. Initialize Stacked Layers using vmap
# This creates a single module where parameters have an extra leading dimension [layers, ...]
# We need to ensure rngs are split for each layer
@nnx.split_rngs(splits=layers)
@nnx.vmap(in_axes=0, out_axes=0, axis_size=layers)
@nnx.vmap(in_axes=0, out_axes=0, axis_size=layers, transform_metadata={nnx.PARTITION_NAME: "layers"})
def create_block(rngs):
return _BasicTransformerBlock1D(
dim=input_dim,
Expand All @@ -122,9 +129,7 @@ def create_block(rngs):
jax.random.uniform(key, (num_learnable_registers, self.dim), dtype=jnp.bfloat16) * 2.0 - 1.0
)

self.final_norm = nnx.RMSNorm(
self.dim, epsilon=1e-6, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=False, rngs=rngs
)
self.final_norm = nnx.RMSNorm(self.dim, epsilon=1e-6, dtype=jnp.float32, use_scale=False, rngs=rngs)

def _replace_padded_with_learnable_registers(self, hidden_states: Array, attention_mask: Array) -> Tuple[Array, Array]:
b, t, d = hidden_states.shape
Expand All @@ -133,39 +138,103 @@ def _replace_padded_with_learnable_registers(self, hidden_states: Array, attenti

num_duplications = t // self.num_learnable_registers
registers = jnp.tile(self.learnable_registers[...], (num_duplications, 1))
registers = jnp.expand_dims(registers, 0)

if attention_mask.ndim == 2:
mask = attention_mask[:, :, None]
else:
mask = attention_mask
else:
mask = attention_mask.squeeze(-1) # [B, T]

# Mask valid tokens as 1 (or True)
curr_mask = (mask > 0.5).astype(jnp.int32)

# 1. Shift valid tokens to the left
num_valid = jnp.sum(curr_mask, axis=1, keepdims=True)
valid_positions = jnp.cumsum(curr_mask, axis=1) - 1
invalid_positions = jnp.cumsum(1 - curr_mask, axis=1) - 1 + num_valid
target_indices = jnp.where(curr_mask == 1, valid_positions, invalid_positions)

b_idx = jnp.arange(b)[:, None]

# Shift hidden states
shifted_hidden_states = jnp.zeros_like(hidden_states)
shifted_hidden_states = shifted_hidden_states.at[b_idx, target_indices, :].set(hidden_states)

# Shift mask
shifted_mask = jnp.zeros_like(curr_mask)
shifted_mask = shifted_mask.at[b_idx, target_indices].set(curr_mask)

# 2. Add Learnable Registers
# Where shifted_mask is 1, keep valid tokens. Where 0, insert registers.
output = jnp.where(shifted_mask[..., None] == 1, shifted_hidden_states, registers)

# Padding has been filled with valid register tokens. The entire sequence
# must now be attended to, so return an all-ones mask (matching diffusers).
new_mask = jnp.ones((b, t), dtype=jnp.int32)

output = jnp.where(mask > 0.5, hidden_states, registers)
new_mask = jnp.ones_like(attention_mask)
return output, new_mask

# Build a 1D rotary-position-embedding (RoPE) table of shape [1, seq_len, dim].
# NOTE(review): this is the pre-change version from the diff; the `dtype`
# parameter is accepted but never used in the body — outputs are float32.
def _compute_1d_rope(self, seq_len: int, dtype: DType) -> Tuple[Array, Array]:
# Integer positions 0..seq_len-1.
t = jnp.arange(seq_len, dtype=jnp.float32)
# Classic RoPE inverse-frequency ladder: theta^(-2i/dim) for even indices i.
freqs = 1.0 / (self.theta ** (jnp.arange(0, self.dim, 2, dtype=jnp.float32) / self.dim))
# Outer product -> per-position angle per frequency: [seq_len, dim // 2].
emb = jnp.outer(t, freqs)
cos = jnp.cos(emb)
sin = jnp.sin(emb)
# Duplicate each frequency so adjacent channel pairs share one angle
# (interleaved layout): [seq_len, dim // 2] -> [seq_len, dim].
cos = jnp.repeat(cos, 2, axis=-1)
sin = jnp.repeat(sin, 2, axis=-1)
# Add a leading broadcast (batch) axis: [1, seq_len, dim].
return cos[None, ...], sin[None, ...]
# Build per-batch 1D RoPE cos/sin tables, returned per attention head as
# [batch, heads, seq_len, dim_per_head-ish]. Positions are normalized by
# `base_seq_len` and mapped to [-1, 1] before being scaled by the frequency
# ladder, so absolute position is expressed relative to the training length.
# NOTE(review): the `dtype` parameter is never used in this body — confirm
# whether outputs were meant to be cast to it.
def _compute_1d_rope(self, batch_size: int, seq_len: int, dtype: DType) -> Tuple[Array, Array]:
# Normalized positions in [0, seq_len / base_seq_len), tiled per batch item.
grid_1d = jnp.arange(seq_len, dtype=jnp.float32)
grid_1d = grid_1d / self.base_seq_len
grid = jnp.expand_dims(grid_1d, 0)
grid = jnp.tile(grid, (batch_size, 1))

# Two channels (cos/sin pair) consume one frequency each.
num_rope_elems = 2
# NOTE(review): float64 here only takes effect if JAX x64 mode is enabled
# (jax_enable_x64); otherwise jnp silently downcasts — confirm intent.
freqs_dtype = jnp.float64 if self.double_precision else jnp.float32
steps = self.dim // num_rope_elems
# Geometric frequency ladder theta^linspace(0,1), scaled by pi/2.
pow_indices = jnp.power(self.theta, jnp.linspace(0.0, 1.0, steps, dtype=freqs_dtype))
base_freqs = (pow_indices * jnp.pi / 2.0).astype(jnp.float32)

# Map normalized positions to [-1, 1] and take the outer product with the
# frequencies: [B, T, steps].
freqs = (jnp.expand_dims(grid, -1) * 2.0 - 1.0) * base_freqs

cos_freqs = jnp.cos(freqs)
sin_freqs = jnp.sin(freqs)

if self.rope_type == "interleaved":
# Adjacent channel pairs share one angle: [B, T, steps] -> [B, T, 2*steps].
cos_freqs = jnp.repeat(cos_freqs, 2, axis=-1)
sin_freqs = jnp.repeat(sin_freqs, 2, axis=-1)

# Odd `dim`: pad the leading channels with cos=1 / sin=0 — presumably an
# identity rotation for the unpaired channels (verify against the rotation
# formula used by the attention kernel).
if self.dim % num_rope_elems != 0:
curr_dim = cos_freqs.shape[-1]
pad_amt = self.dim - curr_dim
if pad_amt > 0:
cos_padding = jnp.ones((*cos_freqs.shape[:-1], pad_amt), dtype=cos_freqs.dtype)
sin_padding = jnp.zeros((*sin_freqs.shape[:-1], pad_amt), dtype=sin_freqs.dtype)
cos_freqs = jnp.concatenate([cos_padding, cos_freqs], axis=-1)
sin_freqs = jnp.concatenate([sin_padding, sin_freqs], axis=-1)

elif self.rope_type == "split":
# Split layout keeps one angle per frequency; target width is dim // 2.
expected_freqs = self.dim // 2
current_freqs = freqs.shape[-1]
pad_size = expected_freqs - current_freqs

# Same identity-style padding (cos=1, sin=0) prepended when short.
if pad_size > 0:
cos_padding = jnp.ones((*cos_freqs.shape[:-1], pad_size), dtype=cos_freqs.dtype)
sin_padding = jnp.zeros((*sin_freqs.shape[:-1], pad_size), dtype=sin_freqs.dtype)
cos_freqs = jnp.concatenate([cos_padding, cos_freqs], axis=-1)
sin_freqs = jnp.concatenate([sin_padding, sin_freqs], axis=-1)

# Split the channel axis across heads and move heads before sequence:
# [B, T, C] -> [B, T, H, C/H] -> [B, H, T, C/H].
b = cos_freqs.shape[0]
t = cos_freqs.shape[1]
h = self.heads
cos_freqs = cos_freqs.reshape(b, t, h, -1).transpose(0, 2, 1, 3)
sin_freqs = sin_freqs.reshape(b, t, h, -1).transpose(0, 2, 1, 3)

return cos_freqs, sin_freqs

def __call__(
self,
hidden_states: Array,
attention_mask: Optional[Array] = None,
) -> Array:
) -> Tuple[Array, Array]:
# 1. Thinking Tokens
if self.num_learnable_registers > 0 and attention_mask is not None:
hidden_states, attention_mask = self._replace_padded_with_learnable_registers(hidden_states, attention_mask)

# 2. RoPE
batch_size = hidden_states.shape[0]
seq_len = hidden_states.shape[1]
rotary_emb = self._compute_1d_rope(seq_len, hidden_states.dtype)
rotary_emb = self._compute_1d_rope(batch_size, seq_len, hidden_states.dtype)

# 3. Transformer Blocks (Scan)

Expand All @@ -187,4 +256,4 @@ def block_scan_fn(carry, block_module):
# 4. Final Norm
hidden_states = self.final_norm(hidden_states)

return hidden_states
return hidden_states, attention_mask
Original file line number Diff line number Diff line change
Expand Up @@ -25,39 +25,30 @@

def _norm_and_concat_padded_batch(
encoded_text: Array,
sequence_lengths: Array,
padding_side: str = "right",
attention_mask: Array,
) -> Array:
"""Normalize and flatten multi-layer hidden states, respecting padding.
Performs per-batch, per-layer normalization using masked mean and range,
then concatenates across the layer dimension.

Args:
encoded_text: Hidden states of shape [batch, seq_len, hidden_dim, num_layers].
sequence_lengths: Number of valid (non-padded) tokens per batch item.
padding_side: Whether padding is on "left" or "right".
attention_mask: Attention mask of shape [batch, seq_len].

Returns:
Normalized tensor of shape [batch, seq_len, hidden_dim * num_layers],
with padded positions zeroed out.
"""
b, t, d, l = encoded_text.shape

# Build mask: [B, T] -> [B, T, 1, 1]
# token_indices: [1, T]
# Calculate left-aligned padding mask identical to Diffusers `_pack_text_embeds`
# Diffusers padding side is "left" for Gemma text encoders.
sequence_lengths = jnp.sum(attention_mask, axis=-1)
token_indices = jnp.arange(t)[None, :]
start_indices = t - sequence_lengths[:, None]
mask = token_indices >= start_indices

if padding_side == "right":
# Valid: indices < lengths
mask = token_indices < sequence_lengths[:, None]
elif padding_side == "left":
# Valid: indices >= (T - lengths)
start_indices = t - sequence_lengths[:, None]
mask = token_indices >= start_indices
else:
raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")

# [B, T, 1, 1]
# Broadcast to [B, T, 1, 1]
mask = mask[:, :, None, None]

eps = 1e-6
Expand Down Expand Up @@ -120,15 +111,12 @@ def __init__(
# LTX-2 uses bias=False for the projection
self.linear = nnx.Linear(input_dim, output_dim, use_bias=False, dtype=dtype, rngs=rngs)

def __call__(
self, hidden_states: Union[Tuple[Array, ...], Array], attention_mask: Array, padding_side: str = "right"
) -> Array:
def __call__(self, hidden_states: Union[Tuple[Array, ...], Array], attention_mask: Array) -> Array:
"""
Args:
hidden_states: Tuple of arrays from Gemma, each [B, T, D].
Or pre-stacked array [B, T, D, L].
attention_mask: Mask [B, T] (1 for valid, 0 for padding).
padding_side: "right" or "left".

Returns:
Projected features [B, T, OutputDim].
Expand All @@ -141,11 +129,8 @@ def __call__(
else:
x = hidden_states

# 2. Calculate Sequence Lengths
sequence_lengths = jnp.sum(attention_mask, axis=-1)

# 3. Norm and Concat
x_norm = _norm_and_concat_padded_batch(x, sequence_lengths, padding_side=padding_side)
# 2. Norm and Concat
x_norm = _norm_and_concat_padded_batch(x, attention_mask)

# 4. Projection
return self.linear(x_norm)
Loading
Loading