@@ -43,13 +43,13 @@ def __init__(
             heads=heads,
             dim_head=dim_head,
             rope_type=rope_type,
-            bias=True,  # LTX-2 default
+            bias=True,
             out_bias=True,
             attention_kernel=attention_kernel,
             mesh=mesh,
             rngs=rngs,
         )
-        self.ff = NNXSimpleFeedForward(rngs=rngs, dim=dim, dim_out=dim)
+        self.ff = NNXSimpleFeedForward(rngs=rngs, dim=dim, dim_out=dim, activation_fn="gelu_tanh")
         self.norm1 = nnx.RMSNorm(dim, epsilon=1e-6, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=False, rngs=rngs)
         self.norm2 = nnx.RMSNorm(dim, epsilon=1e-6, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=False, rngs=rngs)
 
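For reference, `activation_fn="gelu_tanh"` selects the tanh-approximate GELU inside the feed-forward. Below is a minimal sketch of what such a block computes, not the actual `NNXSimpleFeedForward` internals; the hidden width and layer names are assumptions.

```python
import jax
from flax import nnx

class FeedForwardSketch(nnx.Module):
    """Illustrative two-layer MLP; the 4 * dim hidden width is an assumption."""

    def __init__(self, dim: int, dim_out: int, *, rngs: nnx.Rngs):
        self.proj_in = nnx.Linear(dim, 4 * dim, rngs=rngs)
        self.proj_out = nnx.Linear(4 * dim, dim_out, rngs=rngs)

    def __call__(self, x: jax.Array) -> jax.Array:
        # "gelu_tanh" corresponds to GELU with the tanh approximation.
        return self.proj_out(jax.nn.gelu(self.proj_in(x), approximate=True))
```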
@@ -87,20 +87,27 @@ def __init__(
         theta: float = 10000.0,
         num_learnable_registers: int = 128,
         rope_type: str = "interleaved",
+        base_seq_len: int = 4096,
+        double_precision: bool = True,
         attention_kernel: str = "flash",
         mesh: jax.sharding.Mesh = None,
         rngs: nnx.Rngs = None,
     ):
         self.dim = input_dim
+        self.heads = heads
+        self.head_dim = head_dim
         self.theta = theta
         self.num_learnable_registers = num_learnable_registers
         self.num_layers = layers
+        self.rope_type = rope_type
+        self.base_seq_len = base_seq_len
+        self.double_precision = double_precision
 
         # 1. Initialize Stacked Layers using vmap
         # This creates a single module where parameters have an extra leading dimension [layers, ...]
         # We need to ensure rngs are split for each layer
         @nnx.split_rngs(splits=layers)
-        @nnx.vmap(in_axes=0, out_axes=0, axis_size=layers)
+        @nnx.vmap(in_axes=0, out_axes=0, axis_size=layers, transform_metadata={nnx.PARTITION_NAME: "layers"})
         def create_block(rngs):
             return _BasicTransformerBlock1D(
                 dim=input_dim,
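The added `transform_metadata={nnx.PARTITION_NAME: "layers"}` names the new leading layer axis so the stacked parameters can be sharded along a mesh axis. A minimal, self-contained sketch of the underlying `split_rngs` + `vmap` stacking pattern (partition metadata omitted; the layer type and sizes are placeholders):

```python
from flax import nnx

num_layers = 4  # placeholder depth

@nnx.split_rngs(splits=num_layers)
@nnx.vmap(in_axes=0, out_axes=0, axis_size=num_layers)
def create_layer(rngs: nnx.Rngs):
    # Each vmapped call gets its own RNG stream, so the layers are
    # initialized independently but stored in one stacked module.
    return nnx.Linear(8, 8, rngs=rngs)

stacked = create_layer(nnx.Rngs(0))
print(stacked.kernel.value.shape)  # (4, 8, 8): the leading axis indexes the layer
```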
@@ -122,9 +129,7 @@ def create_block(rngs):
             jax.random.uniform(key, (num_learnable_registers, self.dim), dtype=jnp.bfloat16) * 2.0 - 1.0
         )
 
-        self.final_norm = nnx.RMSNorm(
-            self.dim, epsilon=1e-6, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=False, rngs=rngs
-        )
+        self.final_norm = nnx.RMSNorm(self.dim, epsilon=1e-6, dtype=jnp.float32, use_scale=False, rngs=rngs)
 
     def _replace_padded_with_learnable_registers(self, hidden_states: Array, attention_mask: Array) -> Tuple[Array, Array]:
         b, t, d = hidden_states.shape
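The register table is initialized uniformly in [-1, 1) and, further down, tiled along the sequence axis so every padded slot has a register row. A small sketch with illustrative sizes:

```python
import jax
import jax.numpy as jnp

num_registers, dim, seq_len = 4, 8, 12  # illustrative; seq_len must be a multiple of num_registers
key = jax.random.PRNGKey(0)

# uniform in [0, 1) scaled to [-1, 1), matching the initializer in the diff
registers = jax.random.uniform(key, (num_registers, dim), dtype=jnp.bfloat16) * 2.0 - 1.0

# the table is repeated along the sequence axis so each padded position maps to a register row
tiled = jnp.tile(registers, (seq_len // num_registers, 1))
print(tiled.shape)  # (12, 8)
```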
@@ -133,39 +138,103 @@ def _replace_padded_with_learnable_registers(self, hidden_states: Array, attenti
 
         num_duplications = t // self.num_learnable_registers
         registers = jnp.tile(self.learnable_registers[...], (num_duplications, 1))
-        registers = jnp.expand_dims(registers, 0)
 
         if attention_mask.ndim == 2:
-            mask = attention_mask[:, :, None]
-        else:
             mask = attention_mask
+        else:
+            mask = attention_mask.squeeze(-1)  # [B, T]
+
+        # Mask valid tokens as 1 (or True)
+        curr_mask = (mask > 0.5).astype(jnp.int32)
+
+        # 1. Shift valid tokens to the left
+        num_valid = jnp.sum(curr_mask, axis=1, keepdims=True)
+        valid_positions = jnp.cumsum(curr_mask, axis=1) - 1
+        invalid_positions = jnp.cumsum(1 - curr_mask, axis=1) - 1 + num_valid
+        target_indices = jnp.where(curr_mask == 1, valid_positions, invalid_positions)
+
+        b_idx = jnp.arange(b)[:, None]
+
+        # Shift hidden states
+        shifted_hidden_states = jnp.zeros_like(hidden_states)
+        shifted_hidden_states = shifted_hidden_states.at[b_idx, target_indices, :].set(hidden_states)
+
+        # Shift mask
+        shifted_mask = jnp.zeros_like(curr_mask)
+        shifted_mask = shifted_mask.at[b_idx, target_indices].set(curr_mask)
+
+        # 2. Add Learnable Registers
+        # Where shifted_mask is 1, keep valid tokens. Where 0, insert registers.
+        output = jnp.where(shifted_mask[..., None] == 1, shifted_hidden_states, registers)
+
+        # Padding has been filled with valid register tokens. The entire sequence
+        # must now be attended to, so return an all-ones mask (matching diffusers).
+        new_mask = jnp.ones((b, t), dtype=jnp.int32)
 
-        output = jnp.where(mask > 0.5, hidden_states, registers)
-        new_mask = jnp.ones_like(attention_mask)
         return output, new_mask
 
-    def _compute_1d_rope(self, seq_len: int, dtype: DType) -> Tuple[Array, Array]:
-        t = jnp.arange(seq_len, dtype=jnp.float32)
-        freqs = 1.0 / (self.theta ** (jnp.arange(0, self.dim, 2, dtype=jnp.float32) / self.dim))
-        emb = jnp.outer(t, freqs)
-        cos = jnp.cos(emb)
-        sin = jnp.sin(emb)
-        cos = jnp.repeat(cos, 2, axis=-1)
-        sin = jnp.repeat(sin, 2, axis=-1)
-        return cos[None, ...], sin[None, ...]
+    def _compute_1d_rope(self, batch_size: int, seq_len: int, dtype: DType) -> Tuple[Array, Array]:
+        grid_1d = jnp.arange(seq_len, dtype=jnp.float32)
+        grid_1d = grid_1d / self.base_seq_len
+        grid = jnp.expand_dims(grid_1d, 0)
+        grid = jnp.tile(grid, (batch_size, 1))
+
+        num_rope_elems = 2
+        freqs_dtype = jnp.float64 if self.double_precision else jnp.float32
+        steps = self.dim // num_rope_elems
+        pow_indices = jnp.power(self.theta, jnp.linspace(0.0, 1.0, steps, dtype=freqs_dtype))
+        base_freqs = (pow_indices * jnp.pi / 2.0).astype(jnp.float32)
+
+        freqs = (jnp.expand_dims(grid, -1) * 2.0 - 1.0) * base_freqs
+
+        cos_freqs = jnp.cos(freqs)
+        sin_freqs = jnp.sin(freqs)
+
+        if self.rope_type == "interleaved":
+            cos_freqs = jnp.repeat(cos_freqs, 2, axis=-1)
+            sin_freqs = jnp.repeat(sin_freqs, 2, axis=-1)
+
+            if self.dim % num_rope_elems != 0:
+                curr_dim = cos_freqs.shape[-1]
+                pad_amt = self.dim - curr_dim
+                if pad_amt > 0:
+                    cos_padding = jnp.ones((*cos_freqs.shape[:-1], pad_amt), dtype=cos_freqs.dtype)
+                    sin_padding = jnp.zeros((*sin_freqs.shape[:-1], pad_amt), dtype=sin_freqs.dtype)
+                    cos_freqs = jnp.concatenate([cos_padding, cos_freqs], axis=-1)
+                    sin_freqs = jnp.concatenate([sin_padding, sin_freqs], axis=-1)
+
+        elif self.rope_type == "split":
+            expected_freqs = self.dim // 2
+            current_freqs = freqs.shape[-1]
+            pad_size = expected_freqs - current_freqs
+
+            if pad_size > 0:
+                cos_padding = jnp.ones((*cos_freqs.shape[:-1], pad_size), dtype=cos_freqs.dtype)
+                sin_padding = jnp.zeros((*sin_freqs.shape[:-1], pad_size), dtype=sin_freqs.dtype)
+                cos_freqs = jnp.concatenate([cos_padding, cos_freqs], axis=-1)
+                sin_freqs = jnp.concatenate([sin_padding, sin_freqs], axis=-1)
+
+        b = cos_freqs.shape[0]
+        t = cos_freqs.shape[1]
+        h = self.heads
+        cos_freqs = cos_freqs.reshape(b, t, h, -1).transpose(0, 2, 1, 3)
+        sin_freqs = sin_freqs.reshape(b, t, h, -1).transpose(0, 2, 1, 3)
+
+        return cos_freqs, sin_freqs
 
     def __call__(
         self,
         hidden_states: Array,
         attention_mask: Optional[Array] = None,
-    ) -> Array:
+    ) -> Tuple[Array, Array]:
         # 1. Thinking Tokens
         if self.num_learnable_registers > 0 and attention_mask is not None:
             hidden_states, attention_mask = self._replace_padded_with_learnable_registers(hidden_states, attention_mask)
 
         # 2. RoPE
+        batch_size = hidden_states.shape[0]
         seq_len = hidden_states.shape[1]
-        rotary_emb = self._compute_1d_rope(seq_len, hidden_states.dtype)
+        rotary_emb = self._compute_1d_rope(batch_size, seq_len, hidden_states.dtype)
 
         # 3. Transformer Blocks (Scan)
 
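The cumsum/scatter logic above packs the valid tokens to the front of the sequence and fills the freed tail slots with rows from the tiled register table. A tiny standalone worked example of the same index arithmetic, with made-up values and D = 1:

```python
import jax.numpy as jnp

mask = jnp.array([[1, 1, 0, 1, 0]], dtype=jnp.int32)           # [B=1, T=5], 1 = real token
tokens = jnp.array([[[10.0], [11.0], [0.0], [13.0], [0.0]]])   # [B, T, D=1]
registers = -jnp.arange(5, dtype=jnp.float32)[:, None]         # [T, D], stand-in for the tiled register table

num_valid = jnp.sum(mask, axis=1, keepdims=True)
valid_pos = jnp.cumsum(mask, axis=1) - 1                        # destinations for real tokens
invalid_pos = jnp.cumsum(1 - mask, axis=1) - 1 + num_valid      # destinations for padding
target = jnp.where(mask == 1, valid_pos, invalid_pos)           # [[0, 1, 3, 2, 4]]

b_idx = jnp.arange(1)[:, None]
shifted = jnp.zeros_like(tokens).at[b_idx, target, :].set(tokens)
shifted_mask = jnp.zeros_like(mask).at[b_idx, target].set(mask)

out = jnp.where(shifted_mask[..., None] == 1, shifted, registers)
print(out[0, :, 0])  # [10. 11. 13. -3. -4.]: real tokens packed left, registers fill the tail
```

The new `_compute_1d_rope` normalizes positions by `base_seq_len`, maps them to [-1, 1), scales them by `theta ** linspace(0, 1, dim // 2) * pi / 2`, and finally splits the tables per head; the closing reshape assumes `dim == heads * head_dim`. A shape-only sketch with toy sizes:

```python
import jax.numpy as jnp

batch, seq_len, dim, heads = 1, 6, 8, 2          # toy sizes
theta, base_seq_len = 10000.0, 4096

grid = jnp.tile(jnp.arange(seq_len, dtype=jnp.float32)[None, :] / base_seq_len, (batch, 1))
base_freqs = jnp.power(theta, jnp.linspace(0.0, 1.0, dim // 2)) * jnp.pi / 2.0
freqs = (grid[..., None] * 2.0 - 1.0) * base_freqs   # [B, T, dim // 2]

cos = jnp.repeat(jnp.cos(freqs), 2, axis=-1)          # interleaved layout -> [B, T, dim]
cos = cos.reshape(batch, seq_len, heads, -1).transpose(0, 2, 1, 3)
print(cos.shape)  # (1, 2, 6, 4): [B, heads, T, dim // heads]
```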
@@ -187,4 +256,4 @@ def block_scan_fn(carry, block_module):
         # 4. Final Norm
         hidden_states = self.final_norm(hidden_states)
 
-        return hidden_states
+        return hidden_states, attention_mask
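Call sites now have to unpack a two-element return. A hypothetical usage sketch; the encoder class name `Text1DConditioningEncoder` and all argument values are assumptions, not taken from this diff:

```python
import jax.numpy as jnp
from flax import nnx

# Hypothetical names/values for illustration only.
encoder = Text1DConditioningEncoder(
    input_dim=2048,
    layers=2,
    heads=16,
    head_dim=128,
    num_learnable_registers=128,
    rope_type="interleaved",
    rngs=nnx.Rngs(0),
)

hidden = jnp.zeros((1, 256, 2048), dtype=jnp.bfloat16)
mask = jnp.ones((1, 256), dtype=jnp.int32)

# Previously this returned only hidden_states; it now also returns the
# (possibly register-filled) attention mask.
hidden, mask = encoder(hidden, attention_mask=mask)
```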