diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 394d539350d6..6c456fde4d71 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -504,6 +504,8 @@ title: AuraFlow - local: api/pipelines/blip_diffusion title: BLIP-Diffusion + - local: api/pipelines/block_refinement + title: Block Refinement - local: api/pipelines/bria_3_2 title: Bria 3.2 - local: api/pipelines/bria_fibo @@ -542,6 +544,8 @@ title: DDPM - local: api/pipelines/deepfloyd_if title: DeepFloyd IF + - local: api/pipelines/dflash + title: DFlash - local: api/pipelines/diffedit title: DiffEdit - local: api/pipelines/dit @@ -562,6 +566,8 @@ title: Hunyuan-DiT - local: api/pipelines/hunyuanimage21 title: HunyuanImage2.1 + - local: api/pipelines/hybrid_token_diffusion + title: Hybrid Token Diffusion - local: api/pipelines/pix2pix title: InstructPix2Pix - local: api/pipelines/kandinsky @@ -614,6 +620,8 @@ title: Sana Sprint - local: api/pipelines/sana_video title: Sana Video + - local: api/pipelines/sdar + title: SDAR - local: api/pipelines/self_attention_guidance title: Self-Attention Guidance - local: api/pipelines/semantic_stable_diffusion @@ -659,6 +667,8 @@ title: Stable Diffusion - local: api/pipelines/stable_unclip title: Stable unCLIP + - local: api/pipelines/token_diffusion + title: Token Diffusion - local: api/pipelines/unclip title: unCLIP - local: api/pipelines/unidiffuser @@ -697,6 +707,8 @@ title: Kandinsky 5.0 Video - local: api/pipelines/latte title: Latte + - local: api/pipelines/llada2 + title: LLaDA2 - local: api/pipelines/ltx2 title: LTX-2 - local: api/pipelines/ltx_video @@ -720,6 +732,8 @@ - sections: - local: api/schedulers/overview title: Overview + - local: api/schedulers/bd3lm_token_diffusion + title: BD3LMTokenDiffusionScheduler - local: api/schedulers/block_refinement title: BlockRefinementScheduler - local: api/schedulers/cm_stochastic_iterative @@ -740,6 +754,8 @@ title: DDPMScheduler - local: api/schedulers/deis title: DEISMultistepScheduler + - local: api/schedulers/dflash_token_diffusion + title: DFlashTokenDiffusionScheduler - local: api/schedulers/multistep_dpm_solver_inverse title: DPMSolverMultistepInverse - local: api/schedulers/multistep_dpm_solver @@ -766,6 +782,8 @@ title: HeliosScheduler - local: api/schedulers/heun title: HeunDiscreteScheduler + - local: api/schedulers/hybrid_token_diffusion + title: HybridTokenDiffusionScheduler - local: api/schedulers/ipndm title: IPNDMScheduler - local: api/schedulers/stochastic_karras_ve @@ -786,8 +804,12 @@ title: ScoreSdeVeScheduler - local: api/schedulers/score_sde_vp title: ScoreSdeVpScheduler + - local: api/schedulers/sdar_token_diffusion + title: SDARTokenDiffusionScheduler - local: api/schedulers/tcd title: TCDScheduler + - local: api/schedulers/token_diffusion + title: TokenDiffusionScheduler - local: api/schedulers/unipc title: UniPCMultistepScheduler - local: api/schedulers/vq_diffusion diff --git a/docs/source/en/api/pipelines/dflash.md b/docs/source/en/api/pipelines/dflash.md new file mode 100644 index 000000000000..95847e1fdd82 --- /dev/null +++ b/docs/source/en/api/pipelines/dflash.md @@ -0,0 +1,24 @@ + + +# DFlash + +`DFlashPipeline` performs block-diffusion speculative decoding using a diffusion draft model and a target causal LM. +The draft model is conditioned on target hidden features extracted during prefill and verification steps. + +## DFlashPipeline +[[autodoc]] DFlashPipeline + - all + - __call__ + +## DFlashPipelineOutput +[[autodoc]] pipelines.DFlashPipelineOutput diff --git a/docs/source/en/api/pipelines/hybrid_token_diffusion.md b/docs/source/en/api/pipelines/hybrid_token_diffusion.md new file mode 100644 index 000000000000..56ccd61bfbc8 --- /dev/null +++ b/docs/source/en/api/pipelines/hybrid_token_diffusion.md @@ -0,0 +1,23 @@ + + +# Hybrid Token Diffusion + +`HybridTokenDiffusionPipeline` is an alias of `TokenDiffusionPipeline` for hybrid-transition schedulers. + +## HybridTokenDiffusionPipeline +[[autodoc]] HybridTokenDiffusionPipeline + - all + - __call__ + +## TokenDiffusionPipelineOutput +[[autodoc]] pipelines.TokenDiffusionPipelineOutput diff --git a/docs/source/en/api/pipelines/overview.md b/docs/source/en/api/pipelines/overview.md index 3cfdfee8cc2b..6042cc44c1a4 100644 --- a/docs/source/en/api/pipelines/overview.md +++ b/docs/source/en/api/pipelines/overview.md @@ -48,11 +48,14 @@ The table below lists all the pipelines currently available in πŸ€— Diffusers an | [Dance Diffusion](dance_diffusion) | unconditional audio generation | | [DDIM](ddim) | unconditional image generation | | [DDPM](ddpm) | unconditional image generation | +| [DFlash](dflash) | text2text | +| [SDAR](sdar) | text2text | | [DeepFloyd IF](deepfloyd_if) | text2image, image2image, inpainting, super-resolution | | [DiffEdit](diffedit) | inpainting | | [DiT](dit) | text2image | | [Flux](flux) | text2image | | [Hunyuan-DiT](hunyuandit) | text2image | +| [Hybrid Token Diffusion](hybrid_token_diffusion) | text2text | | [I2VGen-XL](i2vgenxl) | image2video | | [InstructPix2Pix](pix2pix) | image editing | | [Kandinsky 2.1](kandinsky) | text2image, image2image, inpainting, interpolation | @@ -85,6 +88,7 @@ The table below lists all the pipelines currently available in πŸ€— Diffusers an | [T2I-Adapter](stable_diffusion/adapter) | text2image | | [Text2Video](text_to_video) | text2video, video2video | | [Text2Video-Zero](text_to_video_zero) | text2video | +| [Token Diffusion](token_diffusion) | text2text | | [unCLIP](unclip) | text2image, image variation | | [UniDiffuser](unidiffuser) | text2image, image2text, image variation, text variation, unconditional image generation, unconditional audio generation | | [Value-guided planning](value_guided_sampling) | value guided sampling | diff --git a/docs/source/en/api/pipelines/sdar.md b/docs/source/en/api/pipelines/sdar.md new file mode 100644 index 000000000000..bfa6aa3fb104 --- /dev/null +++ b/docs/source/en/api/pipelines/sdar.md @@ -0,0 +1,23 @@ + + +# SDAR + +`SDARPipeline` performs block diffusion decoding with iterative remasking strategies. + +## SDARPipeline +[[autodoc]] SDARPipeline + - all + - __call__ + +## SDARPipelineOutput +[[autodoc]] pipelines.SDARPipelineOutput diff --git a/docs/source/en/api/pipelines/token_diffusion.md b/docs/source/en/api/pipelines/token_diffusion.md new file mode 100644 index 000000000000..c105f53abece --- /dev/null +++ b/docs/source/en/api/pipelines/token_diffusion.md @@ -0,0 +1,24 @@ + + +# Token Diffusion + +`TokenDiffusionPipeline` provides a generic token-space diffusion sampler for discrete denoising over token IDs. It +pairs a token denoiser model with a token diffusion scheduler. + +## TokenDiffusionPipeline +[[autodoc]] TokenDiffusionPipeline + - all + - __call__ + +## TokenDiffusionPipelineOutput +[[autodoc]] pipelines.TokenDiffusionPipelineOutput diff --git a/docs/source/en/api/schedulers/dflash_token_diffusion.md b/docs/source/en/api/schedulers/dflash_token_diffusion.md new file mode 100644 index 000000000000..c98b11bc9963 --- /dev/null +++ b/docs/source/en/api/schedulers/dflash_token_diffusion.md @@ -0,0 +1,22 @@ + + +# DFlashTokenDiffusionScheduler + +`DFlashTokenDiffusionScheduler` implements the acceptance and posterior sampling logic used in DFlash-style block +diffusion speculative decoding. + +## DFlashTokenDiffusionScheduler +[[autodoc]] DFlashTokenDiffusionScheduler + +## DFlashTokenDiffusionSchedulerOutput +[[autodoc]] schedulers.scheduling_dflash_token_diffusion.DFlashTokenDiffusionSchedulerOutput diff --git a/docs/source/en/api/schedulers/hybrid_token_diffusion.md b/docs/source/en/api/schedulers/hybrid_token_diffusion.md new file mode 100644 index 000000000000..4dcdda0ea49c --- /dev/null +++ b/docs/source/en/api/schedulers/hybrid_token_diffusion.md @@ -0,0 +1,22 @@ + + +# HybridTokenDiffusionScheduler + +`HybridTokenDiffusionScheduler` defines hybrid discrete token diffusion updates with separate transitions for +masked and unmasked tokens. + +## HybridTokenDiffusionScheduler +[[autodoc]] HybridTokenDiffusionScheduler + +## HybridTokenDiffusionSchedulerOutput +[[autodoc]] schedulers.scheduling_hybrid_token_diffusion.HybridTokenDiffusionSchedulerOutput diff --git a/docs/source/en/api/schedulers/overview.md b/docs/source/en/api/schedulers/overview.md index a57e99a3e46e..5f171266b4a8 100644 --- a/docs/source/en/api/schedulers/overview.md +++ b/docs/source/en/api/schedulers/overview.md @@ -54,6 +54,25 @@ Many schedulers are implemented from the [k-diffusion](https://github.com/crowso | exponential | init with `timestep_spacing="linspace"`, `use_exponential_sigmas=True` | | beta | init with `timestep_spacing="linspace"`, `use_beta_sigmas=True` | +## Token diffusion schedulers + +These schedulers operate over categorical token IDs instead of continuous latents. They are designed for discrete +token diffusion models and expose the same `set_timesteps`/`step` interface as other schedulers. + +Differences between the discrete token schedulers: +- `TokenDiffusionScheduler`: token-level diffusion with per-token corruption (e.g. mask/uniform) and a single-step `step` to denoise logits. +- `HybridTokenDiffusionScheduler`: hybrid transitions that combine token- and block-wise updates in the same schedule. +- `DFlashTokenDiffusionScheduler`: block diffusion scheduler specialized for speculative decoding with a draft model and target acceptance. +- `SDARTokenDiffusionScheduler`: block diffusion scheduler with remasking strategies (sequential/low-confidence/entropy-bounded) per step. + +[[autodoc]] TokenDiffusionScheduler + +[[autodoc]] HybridTokenDiffusionScheduler + +[[autodoc]] DFlashTokenDiffusionScheduler + +[[autodoc]] SDARTokenDiffusionScheduler + All schedulers are built from the base [`SchedulerMixin`] class which implements low level utilities shared by all schedulers. ## SchedulerMixin diff --git a/docs/source/en/api/schedulers/sdar_token_diffusion.md b/docs/source/en/api/schedulers/sdar_token_diffusion.md new file mode 100644 index 000000000000..7af2185f190f --- /dev/null +++ b/docs/source/en/api/schedulers/sdar_token_diffusion.md @@ -0,0 +1,21 @@ + + +# SDARTokenDiffusionScheduler + +`SDARTokenDiffusionScheduler` implements block diffusion remasking and sampling for SDAR-style decoding. + +## SDARTokenDiffusionScheduler +[[autodoc]] SDARTokenDiffusionScheduler + +## SDARTokenDiffusionSchedulerOutput +[[autodoc]] schedulers.scheduling_sdar_token_diffusion.SDARTokenDiffusionSchedulerOutput diff --git a/docs/source/en/api/schedulers/token_diffusion.md b/docs/source/en/api/schedulers/token_diffusion.md new file mode 100644 index 000000000000..fe5305c00ae5 --- /dev/null +++ b/docs/source/en/api/schedulers/token_diffusion.md @@ -0,0 +1,22 @@ + + +# TokenDiffusionScheduler + +`TokenDiffusionScheduler` defines discrete token diffusion updates over categorical token IDs and supports multiple +forward processes and alpha schedules. + +## TokenDiffusionScheduler +[[autodoc]] TokenDiffusionScheduler + +## TokenDiffusionSchedulerOutput +[[autodoc]] schedulers.scheduling_token_diffusion.TokenDiffusionSchedulerOutput diff --git a/examples/discrete_diffusion/README.md b/examples/discrete_diffusion/README.md index a3a8253b1927..9257fbcefb63 100644 --- a/examples/discrete_diffusion/README.md +++ b/examples/discrete_diffusion/README.md @@ -48,3 +48,142 @@ python examples/discrete_diffusion/sample_llada2.py \ --use_chat_template \ --add_generation_prompt ``` + +## MDLM-style absorbing diffusion + +`train_mdlm.py` trains a masked/absorbing discrete diffusion model: +- Forward process: with probability `1 - alpha(t)`, replace tokens with `mask_token_id` +- Noise schedule: log-linear `alpha(t) = 1 - (1 - eps) * t` +- Loss: weighted token reconstruction NLL over masked positions + +### Run + +```bash +accelerate launch examples/discrete_diffusion/train_mdlm.py \ + --model_name_or_path bert-base-uncased \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --output_dir mdlm-output \ + --max_train_steps 1000 \ + --lambda_conf 0.0 \ + --conf_temperature 1.0 +``` + +The script saves: +- `transformers` model + tokenizer +- `diffusers.TokenDiffusionScheduler` + +into `--output_dir` checkpoints and `--output_dir/final`. + +### Sample + +```bash +python examples/discrete_diffusion/sample_mdlm.py \ + --checkpoint_path mdlm-output/final \ + --num_samples 4 \ + --seq_len 64 \ + --num_inference_steps 128 +``` + +## Block-wise sampling + +Block-wise sampling updates the sequence in chunks, refining only the active block at a time. + +```bash +python examples/discrete_diffusion/sample_block_token_diffusion.py \ + --checkpoint_path mdlm-output/final \ + --num_samples 4 \ + --seq_len 256 \ + --block_size 32 \ + --num_inference_steps 64 \ + --top_p 0.9 +``` + +## DFlash speculative decoding + +Use a diffusion draft model with a target causal LM for block-wise speculative decoding. + +```bash +python examples/discrete_diffusion/sample_dflash.py \ + --draft_model_id z-lab/Qwen3-8B-DFlash-b16 \ + --target_model_id Qwen/Qwen3-8B \ + --prompt "How many positive whole-number divisors does 196 have?" \ + --max_new_tokens 256 \ + --use_chat_template \ + --add_generation_prompt +``` + +## SDAR block diffusion decoding + +Run SDAR-style block diffusion sampling with remasking strategies. + +```bash +python examples/discrete_diffusion/sample_sdar.py \ + --model_id JetLM/SDAR-1.7B-Chat \ + --prompt "Explain what reinforcement learning is in simple terms." \ + --max_new_tokens 256 \ + --block_length 4 \ + --num_inference_steps 4 \ + --remasking_strategy low_confidence_dynamic \ + --confidence_threshold 0.9 \ + --use_chat_template \ + --add_generation_prompt +``` + +### Fine-tune (draft model) + +```bash +accelerate launch examples/discrete_diffusion/train_dflash.py \ + --draft_model_id z-lab/Qwen3-4B-DFlash-b16 \ + --target_model_id Qwen/Qwen3-4B \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --output_dir dflash-output \ + --max_train_steps 100 \ + --logging_steps 10 +``` + +## Hybrid sampling + +Hybrid sampling uses a different transition kernel than absorbing/uniform diffusion and requires a compatible scheduler +configuration saved in the checkpoint directory. + +```bash +python examples/discrete_diffusion/sample_hybrid_token_diffusion.py \ + --checkpoint_path hybrid-output/final \ + --num_samples 4 \ + --seq_len 256 \ + --num_inference_steps 64 +``` + +### Train + +```bash +accelerate launch examples/discrete_diffusion/train_hybrid_token_diffusion.py \ + --model_name_or_path bert-base-uncased \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --output_dir hybrid-output \ + --max_train_steps 1000 \ + --lambda_conf 0.0 \ + --conf_temperature 1.0 +``` + +## UDLM-style uniform diffusion + +`train_udlm.py` trains a uniform token diffusion model: +- Forward process: with probability `1 - alpha(t)`, replace tokens with a uniform random token +- Noise schedule: configurable via `--alpha_schedule` (`log_linear`, `linear`, `cosine`, `geometric`) +- Loss: diffusion loss for uniform token diffusion + +### Run + +```bash +accelerate launch examples/discrete_diffusion/train_udlm.py \ + --model_name_or_path bert-base-uncased \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --output_dir udlm-output \ + --max_train_steps 1000 \ + --exclude_mask_from_uniform +``` diff --git a/examples/discrete_diffusion/sample_dflash.py b/examples/discrete_diffusion/sample_dflash.py new file mode 100644 index 000000000000..6ff0f293410d --- /dev/null +++ b/examples/discrete_diffusion/sample_dflash.py @@ -0,0 +1,142 @@ +#!/usr/bin/env python +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Sample script for DFlash speculative decoding. + +Example: + python sample_dflash.py \ + --draft_model_id z-lab/Qwen3-8B-DFlash-b16 \ + --target_model_id Qwen/Qwen3-8B \ + --prompt "How many positive whole-number divisors does 196 have?" \ + --max_new_tokens 256 +""" + +import argparse + +import torch + +from diffusers import DFlashPipeline + + +def main(): + parser = argparse.ArgumentParser(description="Run DFlash speculative decoding.") + parser.add_argument( + "--draft_model_id", + type=str, + default="z-lab/Qwen3-8B-DFlash-b16", + help="Draft model ID or local path.", + ) + parser.add_argument( + "--target_model_id", + type=str, + default="Qwen/Qwen3-8B", + help="Target model ID or local path.", + ) + parser.add_argument( + "--prompt", + type=str, + default="How many positive whole-number divisors does 196 have?", + help="Prompt text to generate from.", + ) + parser.add_argument( + "--max_new_tokens", + type=int, + default=2048, + help="Maximum number of new tokens to generate.", + ) + parser.add_argument( + "--temperature", + type=float, + default=0.0, + help="Sampling temperature.", + ) + parser.add_argument( + "--use_chat_template", + action="store_true", + help="Use the tokenizer chat template for the prompt.", + ) + parser.add_argument( + "--add_generation_prompt", + action="store_true", + help="Add the generation prompt when using the chat template.", + ) + parser.add_argument( + "--enable_thinking", + action="store_true", + help="Enable chat-template thinking mode if supported by the tokenizer.", + ) + parser.add_argument( + "--mask_token", + type=str, + default="<|MASK|>", + help="Mask token to add if the tokenizer does not define one.", + ) + parser.add_argument( + "--device", + type=str, + default="cuda" if torch.cuda.is_available() else "cpu", + help="Device to run inference on.", + ) + parser.add_argument( + "--dtype", + type=str, + default="auto", + choices=["auto", "float32", "float16", "bfloat16"], + help="Model dtype.", + ) + + args = parser.parse_args() + + dtype_map = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16} + torch_dtype = dtype_map.get(args.dtype) + + print(f"Loading draft model: {args.draft_model_id}") + print(f"Loading target model: {args.target_model_id}") + dtype_arg = torch_dtype if torch_dtype is not None else "auto" + pipe = DFlashPipeline.from_pretrained( + draft_model_id=args.draft_model_id, + target_model_id=args.target_model_id, + mask_token=args.mask_token, + draft_model_kwargs={ + "trust_remote_code": True, + "dtype": dtype_arg, + "device_map": args.device, + }, + target_model_kwargs={ + "dtype": dtype_arg, + "device_map": args.device, + }, + ) + + chat_kwargs = {"enable_thinking": args.enable_thinking} + + print(f"\nPrompt: {args.prompt}") + output = pipe( + prompt=args.prompt, + max_new_tokens=args.max_new_tokens, + temperature=args.temperature, + use_chat_template=args.use_chat_template, + add_generation_prompt=args.add_generation_prompt, + chat_template_kwargs=chat_kwargs, + ) + + print("\nGenerated text:") + print(output.texts[0]) + print(f"\nGenerated {output.sequences.shape[1]} tokens") + + +if __name__ == "__main__": + main() diff --git a/examples/discrete_diffusion/sample_hybrid_token_diffusion.py b/examples/discrete_diffusion/sample_hybrid_token_diffusion.py new file mode 100644 index 000000000000..81f35ae5b9c6 --- /dev/null +++ b/examples/discrete_diffusion/sample_hybrid_token_diffusion.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python + +import argparse +from typing import Optional + +import torch +from transformers import AutoModelForMaskedLM, AutoTokenizer + +from diffusers import HybridTokenDiffusionPipeline, HybridTokenDiffusionScheduler + + +def parse_args(): + parser = argparse.ArgumentParser(description="Sample with a hybrid-transition token diffusion scheduler.") + parser.add_argument( + "--checkpoint_path", type=str, required=True, help="Path containing a model + scheduler config." + ) + parser.add_argument("--prompt", type=str, default=None, help="Optional prompt; will be used as a fixed prefix.") + parser.add_argument("--num_samples", type=int, default=4) + parser.add_argument("--seq_len", type=int, default=64) + parser.add_argument("--num_inference_steps", type=int, default=64) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu") + parser.add_argument("--inject_start_token", action="store_true") + return parser.parse_args() + + +def main(): + args = parse_args() + device = torch.device(args.device) + + tokenizer = AutoTokenizer.from_pretrained(args.checkpoint_path, use_fast=True) + model = AutoModelForMaskedLM.from_pretrained(args.checkpoint_path).to(device) + scheduler = HybridTokenDiffusionScheduler.from_pretrained(args.checkpoint_path) + + pipe = HybridTokenDiffusionPipeline(model=model, scheduler=scheduler, tokenizer=tokenizer).to(device) + model.eval() + + generator: Optional[torch.Generator] = torch.Generator(device=device).manual_seed(args.seed) + + prefix_ids = None + if args.prompt is not None: + encoded = tokenizer(args.prompt, return_tensors="pt", add_special_tokens=True) + prefix_ids = encoded["input_ids"].to(device=device, dtype=torch.long) + if prefix_ids.shape[1] > args.seq_len: + raise ValueError(f"--seq_len ({args.seq_len}) must be >= prompt length ({prefix_ids.shape[1]}).") + + out = pipe( + batch_size=args.num_samples, + seq_len=args.seq_len, + num_inference_steps=args.num_inference_steps, + generator=generator, + prefix_ids=prefix_ids, + inject_start_token=args.inject_start_token, + return_text=True, + ) + + for i, t in enumerate(out.texts or []): + print(f"[{i}] {t}") + + +if __name__ == "__main__": + main() diff --git a/examples/discrete_diffusion/sample_mdlm.py b/examples/discrete_diffusion/sample_mdlm.py new file mode 100644 index 000000000000..bc435e647967 --- /dev/null +++ b/examples/discrete_diffusion/sample_mdlm.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Sample script for MDLM-style absorbing token diffusion text generation. + +This script demonstrates how to use the TokenDiffusionPipeline for unconditional +text generation using absorbing-state discrete diffusion. + +Example usage: + python sample_mdlm.py --model_id kuleshov-group/mdlm-owt --num_samples 4 --seq_len 64 +""" + +import argparse + +import torch +from transformers import AutoModelForMaskedLM, AutoTokenizer + +from diffusers import TokenDiffusionPipeline, TokenDiffusionScheduler + + +def main(): + parser = argparse.ArgumentParser(description="Sample from an absorbing token diffusion LM (MDLM-style).") + parser.add_argument( + "--model_id", + type=str, + default="kuleshov-group/mdlm-owt", + help="HuggingFace model ID or path to local checkpoint.", + ) + parser.add_argument("--num_samples", type=int, default=4) + parser.add_argument("--seq_len", type=int, default=64) + parser.add_argument("--num_inference_steps", type=int, default=128) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu") + parser.add_argument("--inject_bos", action="store_true") + parser.add_argument("--trust_remote_code", action="store_true", help="Trust remote code for model/tokenizer.") + + args = parser.parse_args() + device = torch.device(args.device) + + print(f"Loading model: {args.model_id}") + tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True) + model = AutoModelForMaskedLM.from_pretrained(args.model_id, trust_remote_code=args.trust_remote_code).to(device) + model.eval() + + mask_token_id = len(tokenizer) # MDLM appends mask token after vocab + vocab_size = mask_token_id + 1 + scheduler = TokenDiffusionScheduler(vocab_size=vocab_size, mask_token_id=mask_token_id) + + pipe = TokenDiffusionPipeline(model=model, scheduler=scheduler, tokenizer=tokenizer) + + generator = torch.Generator(device=device).manual_seed(args.seed) + + print(f"Generating {args.num_samples} samples of {args.seq_len} tokens with {args.num_inference_steps} steps") + print("-" * 50) + + output = pipe( + batch_size=args.num_samples, + seq_len=args.seq_len, + num_inference_steps=args.num_inference_steps, + generator=generator, + inject_start_token=args.inject_bos, + ) + + for i, text in enumerate(output.texts): + print(f"[{i}] {text}") + + +if __name__ == "__main__": + main() diff --git a/examples/discrete_diffusion/sample_sdar.py b/examples/discrete_diffusion/sample_sdar.py new file mode 100644 index 000000000000..2d671df542bc --- /dev/null +++ b/examples/discrete_diffusion/sample_sdar.py @@ -0,0 +1,140 @@ +#!/usr/bin/env python +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Sample script for SDAR-style block diffusion decoding. + +Example: + python sample_sdar.py \ + --model_id JetLM/SDAR-1.7B-Chat \ + --prompt "Explain what reinforcement learning is in simple terms." \ + --max_new_tokens 256 +""" + +import argparse + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +from diffusers import SDARPipeline + + +def main(): + parser = argparse.ArgumentParser(description="Run SDAR block diffusion decoding.") + parser.add_argument( + "--model_id", + type=str, + default="JetLM/SDAR-1.7B-Chat", + help="Model ID or local path.", + ) + parser.add_argument( + "--prompt", + type=str, + default="Explain what reinforcement learning is in simple terms.", + help="Prompt text to generate from.", + ) + parser.add_argument("--max_new_tokens", type=int, default=256) + parser.add_argument("--block_length", type=int, default=4) + parser.add_argument("--num_inference_steps", type=int, default=4) + parser.add_argument("--temperature", type=float, default=1.0) + parser.add_argument("--top_k", type=int, default=0) + parser.add_argument("--top_p", type=float, default=1.0) + parser.add_argument( + "--remasking_strategy", + type=str, + default="low_confidence_dynamic", + choices=["low_confidence_dynamic", "low_confidence_static", "sequential", "entropy_bounded"], + ) + parser.add_argument("--confidence_threshold", type=float, default=0.9) + parser.add_argument("--entropy_threshold", type=float, default=0.35) + parser.add_argument("--mask_token_id", type=int, default=None) + parser.add_argument( + "--use_chat_template", + action="store_true", + help="Use the tokenizer chat template for the prompt.", + ) + parser.add_argument( + "--add_generation_prompt", + action="store_true", + help="Add the generation prompt when using the chat template.", + ) + parser.add_argument( + "--device", + type=str, + default="cuda" if torch.cuda.is_available() else "cpu", + help="Device to run inference on.", + ) + parser.add_argument( + "--dtype", + type=str, + default="auto", + choices=["auto", "float32", "float16", "bfloat16"], + help="Model dtype.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + help="Model revision (branch, tag, or commit hash).", + ) + + args = parser.parse_args() + + dtype_map = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16} + torch_dtype = dtype_map.get(args.dtype) + + print(f"Loading model: {args.model_id}") + model = AutoModelForCausalLM.from_pretrained( + args.model_id, + trust_remote_code=True, + torch_dtype=torch_dtype if torch_dtype is not None else "auto", + device_map=args.device, + revision=args.revision, + ) + tokenizer = AutoTokenizer.from_pretrained(args.model_id, trust_remote_code=True, revision=args.revision) + + if tokenizer.mask_token_id is None: + tokenizer.add_special_tokens({"mask_token": "<|MASK|>"}) + + pipe = SDARPipeline(model=model, tokenizer=tokenizer) + + print(f"\nPrompt: {args.prompt}") + output = pipe( + prompt=args.prompt, + max_new_tokens=args.max_new_tokens, + block_length=args.block_length, + num_inference_steps=args.num_inference_steps, + temperature=args.temperature, + top_k=args.top_k, + top_p=args.top_p, + remasking_strategy=args.remasking_strategy, + confidence_threshold=args.confidence_threshold, + entropy_threshold=args.entropy_threshold, + mask_token_id=args.mask_token_id, + use_chat_template=args.use_chat_template, + add_generation_prompt=args.add_generation_prompt, + ) + + print("\nGenerated text:") + print( + output.texts[0] + if output.texts is not None + else tokenizer.decode(output.sequences[0], skip_special_tokens=True) + ) + print(f"\nGenerated {output.sequences.shape[1]} tokens") + + +if __name__ == "__main__": + main() diff --git a/examples/discrete_diffusion/train_dflash.py b/examples/discrete_diffusion/train_dflash.py new file mode 100644 index 000000000000..673a2173a058 --- /dev/null +++ b/examples/discrete_diffusion/train_dflash.py @@ -0,0 +1,319 @@ +#!/usr/bin/env python +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import math +import os +from dataclasses import asdict, dataclass +from typing import Dict, Optional + +import torch +import torch.nn.functional as F +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from datasets import load_dataset +from torch.utils.data import DataLoader +from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling, get_scheduler + + +logger = get_logger(__name__) + + +@dataclass +class TrainConfig: + draft_model_id: str + target_model_id: str + dataset_name: str + dataset_config_name: Optional[str] + text_column: str + + output_dir: str + seed: int + max_train_steps: int + checkpointing_steps: int + logging_steps: int + + per_device_train_batch_size: int + gradient_accumulation_steps: int + learning_rate: float + weight_decay: float + lr_scheduler: str + lr_warmup_steps: int + + max_length: int + block_size: int + mask_token: str + + +def parse_args() -> TrainConfig: + parser = argparse.ArgumentParser(description="Fine-tune a DFlash draft model with target-conditioned blocks.") + + parser.add_argument("--draft_model_id", type=str, default="z-lab/Qwen3-4B-DFlash-b16") + parser.add_argument("--target_model_id", type=str, default="Qwen/Qwen3-4B") + parser.add_argument("--dataset_name", type=str, default="wikitext") + parser.add_argument("--dataset_config_name", type=str, default="wikitext-2-raw-v1") + parser.add_argument("--text_column", type=str, default="text") + + parser.add_argument("--output_dir", type=str, default="dflash-output") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--max_train_steps", type=int, default=1000) + parser.add_argument("--checkpointing_steps", type=int, default=500) + parser.add_argument("--logging_steps", type=int, default=50) + + parser.add_argument("--per_device_train_batch_size", type=int, default=2) + parser.add_argument("--gradient_accumulation_steps", type=int, default=1) + parser.add_argument("--learning_rate", type=float, default=2e-5) + parser.add_argument("--weight_decay", type=float, default=0.0) + parser.add_argument( + "--lr_scheduler", type=str, default="cosine", choices=["linear", "cosine", "cosine_with_restarts"] + ) + parser.add_argument("--lr_warmup_steps", type=int, default=100) + + parser.add_argument("--max_length", type=int, default=512) + parser.add_argument( + "--block_size", type=int, default=0, help="Override draft block size (0 uses the model config)." + ) + parser.add_argument("--mask_token", type=str, default="<|MASK|>") + + args = parser.parse_args() + return TrainConfig(**vars(args)) + + +def tokenize_fn(examples: Dict, tokenizer, text_column: str, max_length: int): + texts = examples[text_column] + texts = [t for t in texts if isinstance(t, str) and len(t.strip()) > 0] + return tokenizer(texts, truncation=True, padding=False, max_length=max_length) + + +def build_target_layer_ids(num_target_layers: int, num_draft_layers: int): + if num_draft_layers == 1: + return [int(num_target_layers // 2)] + start = 1 + end = int(num_target_layers) - 3 + span = end - start + return [int(round(start + (i * span) / (num_draft_layers - 1))) for i in range(int(num_draft_layers))] + + +def extract_context_feature(hidden_states, layer_ids): + offset = 1 + selected_states = [hidden_states[layer_id + offset] for layer_id in layer_ids] + return torch.cat(selected_states, dim=-1) + + +def get_target_input_embeddings(model: torch.nn.Module) -> torch.nn.Module: + embeddings = model.get_input_embeddings() + if embeddings is None: + base = getattr(model, "model", None) + embeddings = getattr(base, "embed_tokens", None) + if embeddings is None: + raise ValueError("Target model must expose input embeddings.") + return embeddings + + +def get_target_output_embeddings(model: torch.nn.Module) -> torch.nn.Module: + embeddings = model.get_output_embeddings() + if embeddings is None: + embeddings = getattr(model, "lm_head", None) + if embeddings is None: + raise ValueError("Target model must expose output embeddings.") + return embeddings + + +def main(): + cfg = parse_args() + + project_config = ProjectConfiguration(project_dir=cfg.output_dir, logging_dir=os.path.join(cfg.output_dir, "logs")) + accelerator = Accelerator( + gradient_accumulation_steps=cfg.gradient_accumulation_steps, + project_config=project_config, + ) + + if accelerator.is_main_process: + os.makedirs(cfg.output_dir, exist_ok=True) + accelerator.wait_for_everyone() + + set_seed(cfg.seed) + logger.info("Training configuration: %s", asdict(cfg)) + + tokenizer = AutoTokenizer.from_pretrained(cfg.target_model_id, use_fast=True) + if tokenizer.pad_token_id is None: + tokenizer.pad_token = tokenizer.eos_token + if tokenizer.mask_token_id is None: + tokenizer.add_special_tokens({"mask_token": cfg.mask_token}) + + draft_model = AutoModel.from_pretrained(cfg.draft_model_id, trust_remote_code=True) + target_model = AutoModelForCausalLM.from_pretrained(cfg.target_model_id) + target_model.eval() + target_model.requires_grad_(False) + + mask_token_id = tokenizer.mask_token_id + if mask_token_id is None: + raise ValueError("Tokenizer must define a mask token for DFlash training.") + + input_embeddings = get_target_input_embeddings(target_model) + output_embeddings = get_target_output_embeddings(target_model) + + block_size = int(cfg.block_size) + if block_size <= 0: + block_size = getattr(draft_model, "block_size", None) or getattr( + getattr(draft_model, "config", None), "block_size", None + ) + if block_size is None: + raise ValueError("Draft model must define `block_size` or pass --block_size.") + block_size = int(block_size) + if block_size < 2: + raise ValueError("`block_size` must be at least 2 for DFlash training.") + + layer_ids = getattr(draft_model, "target_layer_ids", None) + if layer_ids is None: + cfg_draft = getattr(draft_model, "config", None) + num_target_layers = getattr(cfg_draft, "num_target_layers", None) + num_hidden_layers = getattr(cfg_draft, "num_hidden_layers", None) + if num_target_layers is None or num_hidden_layers is None: + raise ValueError("Draft model must expose `target_layer_ids` or `num_target_layers` in config.") + layer_ids = build_target_layer_ids(int(num_target_layers), int(num_hidden_layers)) + + raw_datasets = load_dataset(cfg.dataset_name, cfg.dataset_config_name) + if "train" not in raw_datasets: + raise ValueError(f"Dataset {cfg.dataset_name} has no 'train' split.") + + with accelerator.main_process_first(): + tokenized = raw_datasets["train"].map( + lambda ex: tokenize_fn(ex, tokenizer, cfg.text_column, cfg.max_length), + batched=True, + remove_columns=raw_datasets["train"].column_names, + desc="Tokenizing", + ) + + collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False, return_tensors="pt") + train_dataloader = DataLoader( + tokenized, shuffle=True, collate_fn=collator, batch_size=cfg.per_device_train_batch_size, drop_last=True + ) + + optimizer = torch.optim.AdamW(draft_model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay) + + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / cfg.gradient_accumulation_steps) + num_train_epochs = math.ceil(cfg.max_train_steps / num_update_steps_per_epoch) + + lr_scheduler = get_scheduler( + name=cfg.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=cfg.lr_warmup_steps, + num_training_steps=cfg.max_train_steps, + ) + + draft_model, optimizer, train_dataloader, lr_scheduler, target_model = accelerator.prepare( + draft_model, optimizer, train_dataloader, lr_scheduler, target_model + ) + input_embeddings = get_target_input_embeddings(target_model) + output_embeddings = get_target_output_embeddings(target_model) + + global_step = 0 + draft_model.train() + + for epoch in range(num_train_epochs): + for batch in train_dataloader: + with accelerator.accumulate(draft_model): + input_ids = batch["input_ids"] + attention_mask = batch.get("attention_mask", torch.ones_like(input_ids)) + + valid_lengths = attention_mask.sum(dim=1) + min_valid = int(valid_lengths.min().item()) + if min_valid <= block_size: + continue + + max_start = min_valid - block_size + start = torch.randint(1, max_start + 1, (1,), device=input_ids.device).item() + + block_output_ids = torch.full( + (input_ids.shape[0], block_size), + int(mask_token_id), + device=input_ids.device, + dtype=torch.long, + ) + block_output_ids[:, 0] = input_ids[:, start] + block_targets = input_ids[:, start + 1 : start + block_size] + block_mask = attention_mask[:, start + 1 : start + block_size] + + position_ids = torch.arange(start, start + block_size, device=input_ids.device).unsqueeze(0) + position_ids = position_ids.expand(input_ids.shape[0], -1) + + with torch.no_grad(): + target_out = target_model( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + ) + target_hidden = extract_context_feature(target_out.hidden_states, layer_ids) + target_hidden = target_hidden[:, :start, :] + + noise_embedding = input_embeddings(block_output_ids) + draft_hidden = draft_model( + target_hidden=target_hidden, + noise_embedding=noise_embedding, + position_ids=position_ids, + use_cache=False, + is_causal=False, + ) + if not torch.is_tensor(draft_hidden): + draft_hidden = getattr(draft_hidden, "last_hidden_state", draft_hidden[0]) + + logits = output_embeddings(draft_hidden[:, -block_size + 1 :, :]) + vocab_size = logits.shape[-1] + loss = F.cross_entropy(logits.view(-1, vocab_size), block_targets.reshape(-1), reduction="none") + loss = loss.view(block_targets.shape[0], -1) + loss = (loss * block_mask.to(loss.dtype)).sum() / block_mask.sum().clamp_min(1) + + accelerator.backward(loss) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + if accelerator.sync_gradients: + global_step += 1 + + if global_step % cfg.logging_steps == 0 and accelerator.is_main_process: + logger.info("step=%d loss=%.4f lr=%.6g", global_step, loss.item(), lr_scheduler.get_last_lr()[0]) + + if cfg.checkpointing_steps > 0 and global_step % cfg.checkpointing_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + save_dir = os.path.join(cfg.output_dir, f"checkpoint-{global_step}") + os.makedirs(save_dir, exist_ok=True) + unwrapped = accelerator.unwrap_model(draft_model) + unwrapped.save_pretrained(save_dir, save_function=accelerator.save) + tokenizer.save_pretrained(save_dir) + + if global_step >= cfg.max_train_steps: + break + + if global_step >= cfg.max_train_steps: + break + + accelerator.wait_for_everyone() + if accelerator.is_main_process: + final_dir = os.path.join(cfg.output_dir, "final") + os.makedirs(final_dir, exist_ok=True) + unwrapped = accelerator.unwrap_model(draft_model) + unwrapped.save_pretrained(final_dir, save_function=accelerator.save) + tokenizer.save_pretrained(final_dir) + + logger.info("Done.") + + +if __name__ == "__main__": + main() diff --git a/examples/discrete_diffusion/train_hybrid_token_diffusion.py b/examples/discrete_diffusion/train_hybrid_token_diffusion.py new file mode 100644 index 000000000000..ee23b430f555 --- /dev/null +++ b/examples/discrete_diffusion/train_hybrid_token_diffusion.py @@ -0,0 +1,260 @@ +#!/usr/bin/env python + +import argparse +import math +import os +from dataclasses import asdict, dataclass +from typing import Dict, Optional + +import torch +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from datasets import load_dataset +from torch.utils.data import DataLoader +from transformers import ( + AutoConfig, + AutoModelForMaskedLM, + AutoTokenizer, + DataCollatorForLanguageModeling, + get_scheduler, +) + +from diffusers import HybridTokenDiffusionScheduler +from diffusers.training_utils import compute_confidence_aware_loss + + +logger = get_logger(__name__) + + +@dataclass +class TrainConfig: + model_name_or_path: str + dataset_name: str + dataset_config_name: Optional[str] + text_column: str + + output_dir: str + seed: int + max_train_steps: int + checkpointing_steps: int + logging_steps: int + + per_device_train_batch_size: int + gradient_accumulation_steps: int + learning_rate: float + weight_decay: float + lr_scheduler: str + lr_warmup_steps: int + + max_length: int + num_train_timesteps: int + t_eps: float + p_uniform: float + gamma: float + lambda_conf: float + conf_temperature: float + + +def parse_args() -> TrainConfig: + parser = argparse.ArgumentParser(description="Train a hybrid-transition token diffusion model with Accelerate.") + + parser.add_argument("--model_name_or_path", type=str, default="bert-base-uncased") + parser.add_argument("--dataset_name", type=str, default="wikitext") + parser.add_argument("--dataset_config_name", type=str, default="wikitext-2-raw-v1") + parser.add_argument("--text_column", type=str, default="text") + + parser.add_argument("--output_dir", type=str, default="hybrid-output") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--max_train_steps", type=int, default=1000) + parser.add_argument("--checkpointing_steps", type=int, default=500) + parser.add_argument("--logging_steps", type=int, default=50) + + parser.add_argument("--per_device_train_batch_size", type=int, default=8) + parser.add_argument("--gradient_accumulation_steps", type=int, default=1) + parser.add_argument("--learning_rate", type=float, default=5e-5) + parser.add_argument("--weight_decay", type=float, default=0.01) + parser.add_argument( + "--lr_scheduler", type=str, default="cosine", choices=["linear", "cosine", "cosine_with_restarts"] + ) + parser.add_argument("--lr_warmup_steps", type=int, default=100) + + parser.add_argument("--max_length", type=int, default=256) + parser.add_argument("--num_train_timesteps", type=int, default=1000) + parser.add_argument("--t_eps", type=float, default=1e-4) + parser.add_argument("--p_uniform", type=float, default=0.0) + parser.add_argument("--gamma", type=float, default=1.0) + parser.add_argument( + "--lambda_conf", + type=float, + default=0.0, + help="Optional confidence-aware penalty weight (entropy on correctly predicted tokens).", + ) + parser.add_argument( + "--conf_temperature", + type=float, + default=1.0, + help="Temperature for the confidence term only; lower values sharpen the entropy penalty.", + ) + + args = parser.parse_args() + return TrainConfig(**vars(args)) + + +def tokenize_fn(examples: Dict, tokenizer, text_column: str, max_length: int): + texts = examples[text_column] + texts = [t for t in texts if isinstance(t, str) and len(t.strip()) > 0] + return tokenizer( + texts, + truncation=True, + padding=False, + max_length=max_length, + return_special_tokens_mask=True, + ) + + +def main(): + cfg = parse_args() + + project_config = ProjectConfiguration(project_dir=cfg.output_dir, logging_dir=os.path.join(cfg.output_dir, "logs")) + accelerator = Accelerator( + gradient_accumulation_steps=cfg.gradient_accumulation_steps, + project_config=project_config, + ) + + if accelerator.is_main_process: + os.makedirs(cfg.output_dir, exist_ok=True) + accelerator.wait_for_everyone() + + set_seed(cfg.seed) + logger.info("Training configuration: %s", asdict(cfg)) + + tokenizer = AutoTokenizer.from_pretrained(cfg.model_name_or_path, use_fast=True) + if tokenizer.mask_token_id is None: + tokenizer.add_special_tokens({"mask_token": "[MASK]"}) + + config = AutoConfig.from_pretrained(cfg.model_name_or_path) + model = AutoModelForMaskedLM.from_pretrained(cfg.model_name_or_path, config=config) + model.resize_token_embeddings(len(tokenizer)) + + scheduler = HybridTokenDiffusionScheduler( + vocab_size=len(tokenizer), + mask_token_id=int(tokenizer.mask_token_id), + num_train_timesteps=cfg.num_train_timesteps, + t_eps=cfg.t_eps, + p_uniform=cfg.p_uniform, + gamma=cfg.gamma, + ) + + raw_datasets = load_dataset(cfg.dataset_name, cfg.dataset_config_name) + if "train" not in raw_datasets: + raise ValueError(f"Dataset {cfg.dataset_name} has no 'train' split.") + + with accelerator.main_process_first(): + tokenized = raw_datasets["train"].map( + lambda ex: tokenize_fn(ex, tokenizer, cfg.text_column, cfg.max_length), + batched=True, + remove_columns=raw_datasets["train"].column_names, + desc="Tokenizing", + ) + + collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False, return_tensors="pt") + train_dataloader = DataLoader( + tokenized, shuffle=True, collate_fn=collator, batch_size=cfg.per_device_train_batch_size, drop_last=True + ) + + optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay) + + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / cfg.gradient_accumulation_steps) + num_train_epochs = math.ceil(cfg.max_train_steps / num_update_steps_per_epoch) + + lr_scheduler = get_scheduler( + name=cfg.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=cfg.lr_warmup_steps, + num_training_steps=cfg.max_train_steps, + ) + + model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + model, optimizer, train_dataloader, lr_scheduler + ) + + global_step = 0 + model.train() + + for epoch in range(num_train_epochs): + for batch in train_dataloader: + with accelerator.accumulate(model): + input_ids = batch["input_ids"] + attention_mask = batch.get("attention_mask", torch.ones_like(input_ids)) + + timesteps = torch.randint( + 0, scheduler.num_train_timesteps, (input_ids.shape[0],), device=input_ids.device, dtype=torch.long + ) + + x_t = scheduler.add_noise(input_ids, noise=None, timesteps=timesteps) + logits = model(input_ids=x_t, attention_mask=attention_mask).logits + + # For this hybrid kernel, we use a simple denoising objective: predict x0 from z_t. + logits = logits.clone() + logits[..., scheduler.mask_token_id] = torch.finfo(logits.dtype).min + + labels = input_ids.clone() + labels[attention_mask.eq(0)] = -100 + per_token_weights = attention_mask.to(dtype=logits.dtype) + loss, loss_sft, loss_conf = compute_confidence_aware_loss( + logits, + labels, + lambda_conf=cfg.lambda_conf, + temperature=cfg.conf_temperature, + per_token_weights=per_token_weights, + ) + + accelerator.backward(loss) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + if accelerator.sync_gradients: + global_step += 1 + + if global_step % cfg.logging_steps == 0 and accelerator.is_main_process: + logger.info( + "step=%d loss=%.4f loss_sft=%.4f loss_conf=%.4f lr=%.6g", + global_step, + loss.item(), + loss_sft.item(), + loss_conf.item(), + lr_scheduler.get_last_lr()[0], + ) + + if cfg.checkpointing_steps > 0 and global_step % cfg.checkpointing_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + save_dir = os.path.join(cfg.output_dir, f"checkpoint-{global_step}") + os.makedirs(save_dir, exist_ok=True) + unwrapped = accelerator.unwrap_model(model) + unwrapped.save_pretrained(save_dir, save_function=accelerator.save) + tokenizer.save_pretrained(save_dir) + scheduler.save_pretrained(save_dir) + + if global_step >= cfg.max_train_steps: + break + + if global_step >= cfg.max_train_steps: + break + + accelerator.wait_for_everyone() + if accelerator.is_main_process: + final_dir = os.path.join(cfg.output_dir, "final") + os.makedirs(final_dir, exist_ok=True) + unwrapped = accelerator.unwrap_model(model) + unwrapped.save_pretrained(final_dir, save_function=accelerator.save) + tokenizer.save_pretrained(final_dir) + scheduler.save_pretrained(final_dir) + + logger.info("Done.") + + +if __name__ == "__main__": + main() diff --git a/examples/discrete_diffusion/train_mdlm.py b/examples/discrete_diffusion/train_mdlm.py new file mode 100644 index 000000000000..59b323d35275 --- /dev/null +++ b/examples/discrete_diffusion/train_mdlm.py @@ -0,0 +1,298 @@ +#!/usr/bin/env python +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import math +import os +from dataclasses import asdict, dataclass +from typing import Dict, Optional + +import torch +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from datasets import load_dataset +from torch.utils.data import DataLoader +from transformers import ( + AutoConfig, + AutoModelForMaskedLM, + AutoTokenizer, + DataCollatorForLanguageModeling, + get_scheduler, +) + +from diffusers import TokenDiffusionScheduler +from diffusers.training_utils import compute_confidence_aware_loss + + +logger = get_logger(__name__) + + +@dataclass +class TrainConfig: + model_name_or_path: str + dataset_name: str + dataset_config_name: Optional[str] + text_column: str + + output_dir: str + seed: int + max_train_steps: int + checkpointing_steps: int + logging_steps: int + + per_device_train_batch_size: int + gradient_accumulation_steps: int + learning_rate: float + weight_decay: float + lr_scheduler: str + lr_warmup_steps: int + + max_length: int + num_train_timesteps: int + alpha_schedule: str + eps: float + sigma_min: float + sigma_max: float + min_timestep: int + lambda_conf: float + conf_temperature: float + + +def parse_args() -> TrainConfig: + parser = argparse.ArgumentParser(description="Train an absorbing token diffusion LM (MDLM-style) with Accelerate.") + + parser.add_argument("--model_name_or_path", type=str, default="bert-base-uncased") + parser.add_argument("--dataset_name", type=str, default="wikitext") + parser.add_argument("--dataset_config_name", type=str, default="wikitext-2-raw-v1") + parser.add_argument("--text_column", type=str, default="text") + + parser.add_argument("--output_dir", type=str, default="mdlm-output") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--max_train_steps", type=int, default=1000) + parser.add_argument("--checkpointing_steps", type=int, default=500) + parser.add_argument("--logging_steps", type=int, default=50) + + parser.add_argument("--per_device_train_batch_size", type=int, default=8) + parser.add_argument("--gradient_accumulation_steps", type=int, default=1) + parser.add_argument("--learning_rate", type=float, default=5e-5) + parser.add_argument("--weight_decay", type=float, default=0.01) + parser.add_argument( + "--lr_scheduler", type=str, default="cosine", choices=["linear", "cosine", "cosine_with_restarts"] + ) + parser.add_argument("--lr_warmup_steps", type=int, default=100) + + parser.add_argument("--max_length", type=int, default=256) + + parser.add_argument("--num_train_timesteps", type=int, default=1000) + parser.add_argument( + "--alpha_schedule", + type=str, + default="log_linear", + choices=["log_linear", "linear", "cosine", "geometric"], + ) + parser.add_argument("--eps", type=float, default=1e-3) + parser.add_argument("--sigma_min", type=float, default=1e-4) + parser.add_argument("--sigma_max", type=float, default=20.0) + parser.add_argument("--min_timestep", type=int, default=1, help="Avoid t=0 to prevent 1/t weighting blow-ups.") + parser.add_argument( + "--lambda_conf", + type=float, + default=0.0, + help="Optional confidence-aware penalty weight (entropy on correctly predicted tokens).", + ) + parser.add_argument( + "--conf_temperature", + type=float, + default=1.0, + help="Temperature for the confidence term only; lower values sharpen the entropy penalty.", + ) + + args = parser.parse_args() + return TrainConfig(**vars(args)) + + +def tokenize_fn(examples: Dict, tokenizer, text_column: str, max_length: int): + texts = examples[text_column] + # drop empty lines + texts = [t for t in texts if isinstance(t, str) and len(t.strip()) > 0] + return tokenizer( + texts, + truncation=True, + padding=False, + max_length=max_length, + return_special_tokens_mask=True, + ) + + +def main(): + cfg = parse_args() + + project_config = ProjectConfiguration(project_dir=cfg.output_dir, logging_dir=os.path.join(cfg.output_dir, "logs")) + accelerator = Accelerator( + gradient_accumulation_steps=cfg.gradient_accumulation_steps, + project_config=project_config, + ) + + if accelerator.is_main_process: + os.makedirs(cfg.output_dir, exist_ok=True) + accelerator.wait_for_everyone() + + set_seed(cfg.seed) + logger.info("Training configuration: %s", asdict(cfg)) + + tokenizer = AutoTokenizer.from_pretrained(cfg.model_name_or_path, use_fast=True) + if tokenizer.mask_token_id is None: + # MDLM-style absorbing diffusion assumes a mask token exists. + tokenizer.add_special_tokens({"mask_token": "[MASK]"}) + + config = AutoConfig.from_pretrained(cfg.model_name_or_path) + model = AutoModelForMaskedLM.from_pretrained(cfg.model_name_or_path, config=config) + model.resize_token_embeddings(len(tokenizer)) + + scheduler = TokenDiffusionScheduler( + vocab_size=len(tokenizer), + mask_token_id=int(tokenizer.mask_token_id), + num_train_timesteps=cfg.num_train_timesteps, + alpha_schedule=cfg.alpha_schedule, + eps=cfg.eps, + sigma_min=cfg.sigma_min, + sigma_max=cfg.sigma_max, + ) + + raw_datasets = load_dataset(cfg.dataset_name, cfg.dataset_config_name) + if "train" not in raw_datasets: + raise ValueError(f"Dataset {cfg.dataset_name} has no 'train' split.") + + with accelerator.main_process_first(): + tokenized = raw_datasets["train"].map( + lambda ex: tokenize_fn(ex, tokenizer, cfg.text_column, cfg.max_length), + batched=True, + remove_columns=raw_datasets["train"].column_names, + desc="Tokenizing", + ) + + # We reuse the standard MLM collator to pad and build attention masks; we won't use its masking. + collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False, return_tensors="pt") + train_dataloader = DataLoader( + tokenized, shuffle=True, collate_fn=collator, batch_size=cfg.per_device_train_batch_size, drop_last=True + ) + + optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay) + + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / cfg.gradient_accumulation_steps) + num_train_epochs = math.ceil(cfg.max_train_steps / num_update_steps_per_epoch) + + lr_scheduler = get_scheduler( + name=cfg.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=cfg.lr_warmup_steps, + num_training_steps=cfg.max_train_steps, + ) + + model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + model, optimizer, train_dataloader, lr_scheduler + ) + + global_step = 0 + model.train() + + for epoch in range(num_train_epochs): + for batch in train_dataloader: + with accelerator.accumulate(model): + input_ids = batch["input_ids"] + attention_mask = batch.get("attention_mask", torch.ones_like(input_ids)) + + # Sample discrete time indices (avoid timestep 0 for stability with 1/t weighting). + min_t = max(1, int(cfg.min_timestep)) + max_t = scheduler.num_train_timesteps - 1 + timesteps = torch.randint(min_t, max_t + 1, (input_ids.shape[0],), device=input_ids.device) + + # Forward process q(x_t | x_0): replace tokens with [MASK] according to alpha(t). + x_t = scheduler.add_noise(input_ids, noise=None, timesteps=timesteps) + + # Model predicts token logits for x0 reconstruction. + logits = model(input_ids=x_t, attention_mask=attention_mask).logits # [B, L, V] + + # MDLM-style constraints: + # - Do not predict the mask token as x0. + logits = logits.clone() + logits[..., scheduler.mask_token_id] = torch.finfo(logits.dtype).min + + # Only compute loss on tokens that were masked by the forward process. + mask_positions = x_t.eq(scheduler.mask_token_id) & attention_mask.to(dtype=torch.bool) + + weights = scheduler.get_mdlm_loss_weights(timesteps) + + labels = input_ids.clone() + labels[~mask_positions] = -100 + + per_token_weights = weights.to(dtype=logits.dtype).expand_as(labels) + loss, loss_sft, loss_conf = compute_confidence_aware_loss( + logits, + labels, + lambda_conf=cfg.lambda_conf, + temperature=cfg.conf_temperature, + per_token_weights=per_token_weights, + ) + + accelerator.backward(loss) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + if accelerator.sync_gradients: + global_step += 1 + + if global_step % cfg.logging_steps == 0 and accelerator.is_main_process: + logger.info( + "step=%d loss=%.4f loss_sft=%.4f loss_conf=%.4f lr=%.6g", + global_step, + loss.item(), + loss_sft.item(), + loss_conf.item(), + lr_scheduler.get_last_lr()[0], + ) + + if cfg.checkpointing_steps > 0 and global_step % cfg.checkpointing_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + save_dir = os.path.join(cfg.output_dir, f"checkpoint-{global_step}") + os.makedirs(save_dir, exist_ok=True) + unwrapped = accelerator.unwrap_model(model) + unwrapped.save_pretrained(save_dir, save_function=accelerator.save) + tokenizer.save_pretrained(save_dir) + scheduler.save_pretrained(save_dir) + + if global_step >= cfg.max_train_steps: + break + + if global_step >= cfg.max_train_steps: + break + + accelerator.wait_for_everyone() + if accelerator.is_main_process: + final_dir = os.path.join(cfg.output_dir, "final") + os.makedirs(final_dir, exist_ok=True) + unwrapped = accelerator.unwrap_model(model) + unwrapped.save_pretrained(final_dir, save_function=accelerator.save) + tokenizer.save_pretrained(final_dir) + scheduler.save_pretrained(final_dir) + + logger.info("Done.") + + +if __name__ == "__main__": + main() diff --git a/examples/discrete_diffusion/train_udlm.py b/examples/discrete_diffusion/train_udlm.py new file mode 100644 index 000000000000..8c61790defc2 --- /dev/null +++ b/examples/discrete_diffusion/train_udlm.py @@ -0,0 +1,311 @@ +#!/usr/bin/env python +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import math +import os +from dataclasses import asdict, dataclass +from typing import Dict, Optional + +import torch +import torch.nn.functional as F +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from datasets import load_dataset +from torch.utils.data import DataLoader +from transformers import ( + AutoConfig, + AutoModelForMaskedLM, + AutoTokenizer, + DataCollatorForLanguageModeling, + get_scheduler, +) + +from diffusers import TokenDiffusionScheduler + + +logger = get_logger(__name__) + + +@dataclass +class TrainConfig: + model_name_or_path: str + dataset_name: str + dataset_config_name: Optional[str] + text_column: str + + output_dir: str + seed: int + max_train_steps: int + checkpointing_steps: int + logging_steps: int + + per_device_train_batch_size: int + gradient_accumulation_steps: int + learning_rate: float + weight_decay: float + lr_scheduler: str + lr_warmup_steps: int + + max_length: int + num_train_timesteps: int + alpha_schedule: str + eps: float + sigma_min: float + sigma_max: float + min_timestep: int + exclude_mask_from_uniform: bool + + +def parse_args() -> TrainConfig: + parser = argparse.ArgumentParser(description="Train a uniform token diffusion LM (UDLM-style) with Accelerate.") + + parser.add_argument("--model_name_or_path", type=str, default="bert-base-uncased") + parser.add_argument("--dataset_name", type=str, default="wikitext") + parser.add_argument("--dataset_config_name", type=str, default="wikitext-2-raw-v1") + parser.add_argument("--text_column", type=str, default="text") + + parser.add_argument("--output_dir", type=str, default="udlm-output") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--max_train_steps", type=int, default=1000) + parser.add_argument("--checkpointing_steps", type=int, default=500) + parser.add_argument("--logging_steps", type=int, default=50) + + parser.add_argument("--per_device_train_batch_size", type=int, default=8) + parser.add_argument("--gradient_accumulation_steps", type=int, default=1) + parser.add_argument("--learning_rate", type=float, default=5e-5) + parser.add_argument("--weight_decay", type=float, default=0.01) + parser.add_argument( + "--lr_scheduler", type=str, default="cosine", choices=["linear", "cosine", "cosine_with_restarts"] + ) + parser.add_argument("--lr_warmup_steps", type=int, default=100) + + parser.add_argument("--max_length", type=int, default=256) + parser.add_argument("--num_train_timesteps", type=int, default=1000) + parser.add_argument( + "--alpha_schedule", + type=str, + default="log_linear", + choices=["log_linear", "linear", "cosine", "geometric"], + help="Alpha schedule used for the uniform forward process and the continuous-time UDLM objective.", + ) + parser.add_argument("--eps", type=float, default=1e-3) + parser.add_argument("--sigma_min", type=float, default=1e-4) + parser.add_argument("--sigma_max", type=float, default=20.0) + parser.add_argument("--min_timestep", type=int, default=1) + parser.add_argument( + "--exclude_mask_from_uniform", action="store_true", help="Exclude mask token from uniform draws." + ) + + args = parser.parse_args() + return TrainConfig(**vars(args)) + + +def tokenize_fn(examples: Dict, tokenizer, text_column: str, max_length: int): + texts = examples[text_column] + texts = [t for t in texts if isinstance(t, str) and len(t.strip()) > 0] + return tokenizer( + texts, + truncation=True, + padding=False, + max_length=max_length, + return_special_tokens_mask=True, + ) + + +def udlm_diffusion_loss( + logits: torch.Tensor, + x0: torch.LongTensor, + x_t: torch.LongTensor, + *, + alpha_t: torch.Tensor, + dalpha_t: torch.Tensor, +): + """ + UDLM diffusion loss (continuous-time form). + + Args: + logits: [B, L, V] + x0: [B, L] + x_t: [B, L] + alpha_t: [B, 1] alpha(t) for the uniform forward process. + dalpha_t: [B, 1] time derivative alpha'(t) with respect to continuous time t in [0, 1]. + Returns: + loss_per_token: [B, L] + """ + log_x_theta = torch.log_softmax(logits, dim=-1) + B, L, V = log_x_theta.shape + + alpha = alpha_t.to(dtype=log_x_theta.dtype).view(B, 1, 1) + alpha_prime = dalpha_t.to(dtype=log_x_theta.dtype).view(B, 1, 1) + + x0_one_hot = F.one_hot(x0, V).to(dtype=log_x_theta.dtype) + xt_one_hot = F.one_hot(x_t, V).to(dtype=log_x_theta.dtype) + + x_bar = V * alpha * x0_one_hot + 1.0 - alpha + x_bar_theta = V * alpha * log_x_theta.exp() + 1.0 - alpha + + coeff = alpha_prime / (V * alpha.clamp_min(torch.finfo(alpha.dtype).eps)) + + x_bar_zt = (x_bar * xt_one_hot).sum(dim=-1, keepdim=True) # (B, L, 1) + x_bar_theta_zt = (x_bar_theta * xt_one_hot).sum(dim=-1, keepdim=True) # (B, L, 1) + + term1 = (V / x_bar_zt) - (V / x_bar_theta_zt) + + term2 = ((x_bar / x_bar_zt) * (x_bar_theta_zt.log() - x_bar_theta.log() + x_bar.log() - x_bar_zt.log())).sum( + dim=-1, keepdim=True + ) + + diffusion_loss = (coeff * (term1 - term2)).squeeze(-1) # (B, L) + return diffusion_loss + + +def main(): + cfg = parse_args() + + project_config = ProjectConfiguration(project_dir=cfg.output_dir, logging_dir=os.path.join(cfg.output_dir, "logs")) + accelerator = Accelerator( + gradient_accumulation_steps=cfg.gradient_accumulation_steps, + project_config=project_config, + ) + + if accelerator.is_main_process: + os.makedirs(cfg.output_dir, exist_ok=True) + accelerator.wait_for_everyone() + + set_seed(cfg.seed) + logger.info("Training configuration: %s", asdict(cfg)) + + tokenizer = AutoTokenizer.from_pretrained(cfg.model_name_or_path, use_fast=True) + if tokenizer.mask_token_id is None: + tokenizer.add_special_tokens({"mask_token": "[MASK]"}) + + config = AutoConfig.from_pretrained(cfg.model_name_or_path) + model = AutoModelForMaskedLM.from_pretrained(cfg.model_name_or_path, config=config) + model.resize_token_embeddings(len(tokenizer)) + + scheduler = TokenDiffusionScheduler( + vocab_size=len(tokenizer), + mask_token_id=int(tokenizer.mask_token_id), + num_train_timesteps=cfg.num_train_timesteps, + alpha_schedule=cfg.alpha_schedule, + eps=cfg.eps, + sigma_min=cfg.sigma_min, + sigma_max=cfg.sigma_max, + forward_process="uniform", + exclude_mask_from_uniform=cfg.exclude_mask_from_uniform, + ) + + raw_datasets = load_dataset(cfg.dataset_name, cfg.dataset_config_name) + if "train" not in raw_datasets: + raise ValueError(f"Dataset {cfg.dataset_name} has no 'train' split.") + + with accelerator.main_process_first(): + tokenized = raw_datasets["train"].map( + lambda ex: tokenize_fn(ex, tokenizer, cfg.text_column, cfg.max_length), + batched=True, + remove_columns=raw_datasets["train"].column_names, + desc="Tokenizing", + ) + + collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False, return_tensors="pt") + train_dataloader = DataLoader( + tokenized, shuffle=True, collate_fn=collator, batch_size=cfg.per_device_train_batch_size, drop_last=True + ) + + optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay) + + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / cfg.gradient_accumulation_steps) + num_train_epochs = math.ceil(cfg.max_train_steps / num_update_steps_per_epoch) + + lr_scheduler = get_scheduler( + name=cfg.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=cfg.lr_warmup_steps, + num_training_steps=cfg.max_train_steps, + ) + + model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + model, optimizer, train_dataloader, lr_scheduler + ) + + global_step = 0 + model.train() + + for epoch in range(num_train_epochs): + for batch in train_dataloader: + with accelerator.accumulate(model): + input_ids = batch["input_ids"] + attention_mask = batch.get("attention_mask", torch.ones_like(input_ids)) + + min_t = max(1, int(cfg.min_timestep)) + max_t = scheduler.num_train_timesteps - 1 + timesteps = torch.randint(min_t, max_t + 1, (input_ids.shape[0],), device=input_ids.device) + + x_t = scheduler.add_noise(input_ids, noise=None, timesteps=timesteps) + logits = model(input_ids=x_t, attention_mask=attention_mask).logits + + if scheduler.exclude_mask_from_uniform: + logits = logits.clone() + logits[..., scheduler.mask_token_id] = torch.finfo(logits.dtype).min + + alpha_t = scheduler.get_alpha(timesteps) + dalpha_t = scheduler.get_alpha_prime(timesteps) + loss_per_token = udlm_diffusion_loss(logits, input_ids, x_t, alpha_t=alpha_t, dalpha_t=dalpha_t) + loss = (loss_per_token * attention_mask.to(loss_per_token.dtype)).sum() + loss = loss / attention_mask.sum().clamp_min(1) + + accelerator.backward(loss) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + if accelerator.sync_gradients: + global_step += 1 + + if global_step % cfg.logging_steps == 0 and accelerator.is_main_process: + logger.info("step=%d loss=%.4f lr=%.6g", global_step, loss.item(), lr_scheduler.get_last_lr()[0]) + + if cfg.checkpointing_steps > 0 and global_step % cfg.checkpointing_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + save_dir = os.path.join(cfg.output_dir, f"checkpoint-{global_step}") + os.makedirs(save_dir, exist_ok=True) + unwrapped = accelerator.unwrap_model(model) + unwrapped.save_pretrained(save_dir, save_function=accelerator.save) + tokenizer.save_pretrained(save_dir) + scheduler.save_pretrained(save_dir) + + if global_step >= cfg.max_train_steps: + break + + if global_step >= cfg.max_train_steps: + break + + accelerator.wait_for_everyone() + if accelerator.is_main_process: + final_dir = os.path.join(cfg.output_dir, "final") + os.makedirs(final_dir, exist_ok=True) + unwrapped = accelerator.unwrap_model(model) + unwrapped.save_pretrained(final_dir, save_function=accelerator.save) + tokenizer.save_pretrained(final_dir) + scheduler.save_pretrained(final_dir) + + logger.info("Done.") + + +if __name__ == "__main__": + main() diff --git a/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py b/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py index a65767d084b6..4c8951a4400e 100644 --- a/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py +++ b/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py @@ -28,7 +28,6 @@ import torch.nn.functional as F import torch.utils.checkpoint import transformers -import wandb from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration, set_seed @@ -43,6 +42,7 @@ from transformers import AutoTokenizer, PretrainedConfig import diffusers +import wandb from diffusers import ( AutoencoderKL, DDPMScheduler, diff --git a/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py b/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py index 756b20bb8d26..a7af773e2fb7 100644 --- a/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py +++ b/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py @@ -29,7 +29,6 @@ import torch.nn.functional as F import torch.utils.checkpoint import transformers -import wandb from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration, set_seed @@ -45,6 +44,7 @@ from transformers import AutoTokenizer, PretrainedConfig import diffusers +import wandb from diffusers import ( AutoencoderKL, DDPMScheduler, diff --git a/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py b/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py index 5a1b26f88604..f8850f591aec 100644 --- a/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py +++ b/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py @@ -29,7 +29,6 @@ import torch.nn.functional as F import torch.utils.checkpoint import transformers -import wandb from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration, set_seed @@ -45,6 +44,7 @@ from transformers import AutoTokenizer, PretrainedConfig import diffusers +import wandb from diffusers import ( AutoencoderKL, DDPMScheduler, diff --git a/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py b/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py index f1bfaa2fb551..1fc89191b212 100644 --- a/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py +++ b/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py @@ -28,7 +28,6 @@ import torch.nn.functional as F import torch.utils.checkpoint import transformers -import wandb import webdataset as wds from accelerate import Accelerator from accelerate.logging import get_logger @@ -43,6 +42,7 @@ from transformers import AutoTokenizer, PretrainedConfig import diffusers +import wandb from diffusers import ( AutoencoderKL, DDPMScheduler, diff --git a/examples/research_projects/unified_latents/README.md b/examples/research_projects/unified_latents/README.md new file mode 100644 index 000000000000..97ef95ca033e --- /dev/null +++ b/examples/research_projects/unified_latents/README.md @@ -0,0 +1,189 @@ +# Unified Latents (UL) Training (Diffusers Research Scripts) + +This folder contains a Diffusers-based implementation of Unified Latents from `2602.17270`. + +Current scripts: +- `train_ul_stage1.py`: joint UL stage-1 training (`encoder + prior + decoder`) +- `train_ul_stage2_base.py`: UL stage-2 base model training on frozen stage-1 encoder latents + +## What Is Implemented + +- Stage 1 (UL latent learning): + - deterministic ResNet-like encoder (`AutoencoderULEncoder`) with Section 5.1-style stage depths `[2,2,2,3]` + - latent prior denoiser (DiT-style) + - decoder denoiser (UViT-style approximation) +- Stage 2 (UL base model): + - two-stage ViT-like latent denoiser approximation + - trained with stage-1 encoder frozen + +Notes on architecture fidelity: +- Section 5.1 defaults are used where practical. +- Decoder is an approximation (concat-conditioned conv+attention) rather than the paper's exact dedicated UViT implementation. + +## Requirements + +Run from repo root with Diffusers import path available: + +```bash +cd /path/to/diffusers +export PYTHONPATH=src +``` + +Use `accelerate launch ...` for training. + +## Dataset Format + +Both stages expect `torchvision.datasets.ImageFolder` layout: + +```text +data_root/ + class_a/ + img1.png + img2.png + class_b/ + img3.png +``` + +## Stage 1 Training + +Stage-1 objective implementation follows UL Algorithm 1: +- prior term uses `(-d lambda_z/dt) * exp(lambda_z) / 2 * ||z_clean - z_hat||^2` +- decoder term uses `(-d lambda_x/dt) * exp(lambda_x) / 2 * w(lambda_x) * ||x - x_hat||^2` with `w(lambda)=sigmoid(lambda-b)` +- prior terminal KL `KL[q(z1|x)||N(0,I)]` is always included (paper Algorithm 1) +- `||.||^2` uses true squared sums and losses are reported in bits-per-pixel via division by `num_pixels * ln(2)` + +Command: + +```bash +accelerate launch examples/research_projects/unified_latents/train_ul_stage1.py \ + --train_data_dir /path/to/imagefolder \ + --output_dir ul-stage1 \ + --resolution 256 \ + --train_batch_size 8 \ + --max_train_steps 10000 +``` + +Recommended useful options: +- `--report_to tensorboard|wandb` +- `--tracker_project_name unified-latents-stage1` +- `--checkpoints_total_limit 3` +- `--resume_from_checkpoint latest` +- `--num_workers 0` (useful in restricted environments) + +Stage 1 outputs: +- `ul-stage1/checkpoint-*/`: + - accelerate state files + - `encoder/`, `prior/`, `decoder/` (Diffusers `save_pretrained` format) +- `ul-stage1/final/`: + - `encoder/`, `prior/`, `decoder/` (Diffusers format) + - `encoder.pt`, `prior.pt`, `decoder.pt` (raw state_dict) + +## Stage 2 Training + +Train the base model using the frozen stage-1 encoder: + +Stage-2 objective uses paper-style weighted ELBO on latents: +- diffusion target is the clean encoder mean latent `z_clean` +- training target is clean encoder latent (`z_clean`) for lower variance +- base model uses v-prediction parameterization and computes weighted ELBO in x-space +- loss uses `(-d lambda_z/dt) * exp(lambda_z) / 2 * w(lambda_z) * ||z_clean - z_hat||^2` with `w(lambda)=sigmoid(lambda-b)` +- sampling stops at `logsnr_0` and passes the resulting noisy latent `z0` to the decoder conditioning path + +```bash +accelerate launch examples/research_projects/unified_latents/train_ul_stage2_base.py \ + --train_data_dir /path/to/imagefolder \ + --stage1_encoder_path /path/to/ul-stage1/final/encoder \ + --output_dir ul-stage2-base \ + --resolution 256 \ + --train_batch_size 8 \ + --max_train_steps 10000 +``` + +`--stage1_encoder_path` supports: +- Diffusers encoder directory (recommended), e.g. `.../final/encoder` +- raw checkpoint file, e.g. `.../final/encoder.pt` + +Recommended useful options: +- `--report_to tensorboard|wandb` +- `--tracker_project_name unified-latents-stage2` +- `--checkpoints_total_limit 3` +- `--resume_from_checkpoint latest` +- `--num_workers 0` (useful in restricted environments) + +Stage 2 outputs: +- `ul-stage2-base/checkpoint-*/`: + - accelerate state files + - `base_model/` (Diffusers `save_pretrained` format) +- `ul-stage2-base/final/`: + - `base_model/` (Diffusers format) + - `base_model.pt` (raw state_dict) + +## Resume Examples + +Stage 1 resume: + +```bash +accelerate launch examples/research_projects/unified_latents/train_ul_stage1.py \ + --train_data_dir /path/to/imagefolder \ + --output_dir ul-stage1 \ + --resume_from_checkpoint latest +``` + +Stage 2 resume: + +```bash +accelerate launch examples/research_projects/unified_latents/train_ul_stage2_base.py \ + --train_data_dir /path/to/imagefolder \ + --stage1_encoder_path /path/to/ul-stage1/final/encoder \ + --output_dir ul-stage2-base \ + --resume_from_checkpoint latest +``` + +## Quick Smoke-Test Settings + +For a fast sanity run on a tiny dataset: + +- `--resolution 64` +- `--train_batch_size 2` +- `--max_train_steps 1` +- `--save_steps 1` +- `--num_workers 0` +- `--mixed_precision no` + +These settings validate training loop, checkpointing, and serialization paths. + + +## ImageNet-512 + Weights & Biases (wandb) + +Use either a Hub dataset (`--dataset_name`) or local ImageNet-512 imagefolder (`--train_data_dir`). + +Stage 1 example (local ImageNet-512): + +```bash +accelerate launch examples/research_projects/unified_latents/train_ul_stage1.py \ + --train_data_dir /path/to/imagenet512_imagefolder \ + --output_dir ul-stage1-imagenet512 \ + --resolution 512 \ + --train_batch_size 8 \ + --max_train_steps 100000 \ + --report_to wandb \ + --tracker_project_name ul-imagenet512-stage1 +``` + +Stage 2 example: + +```bash +accelerate launch examples/research_projects/unified_latents/train_ul_stage2_base.py \ + --train_data_dir /path/to/imagenet512_imagefolder \ + --stage1_encoder_path /path/to/ul-stage1-imagenet512/final/encoder \ + --output_dir ul-stage2-imagenet512 \ + --resolution 512 \ + --train_batch_size 8 \ + --max_train_steps 100000 \ + --report_to wandb \ + --tracker_project_name ul-imagenet512-stage2 +``` + +Hub dataset variant: +- replace `--train_data_dir ...` with `--dataset_name ` +- optionally set `--dataset_config_name ...` and `--image_column ...` diff --git a/examples/research_projects/unified_latents/eval_ul.py b/examples/research_projects/unified_latents/eval_ul.py new file mode 100644 index 000000000000..98003d9d0170 --- /dev/null +++ b/examples/research_projects/unified_latents/eval_ul.py @@ -0,0 +1,412 @@ +#!/usr/bin/env python +# coding=utf-8 + +import argparse +import json +from pathlib import Path + +import torch +import torch.nn.functional as F +from torchvision import transforms +from torchvision.datasets import ImageFolder +from torchvision.models import Inception_V3_Weights, inception_v3 +from torchvision.utils import make_grid, save_image +from ul_models import ULTwoStageBaseModel + +from diffusers import UNet2DModel +from diffusers.models.autoencoders import AutoencoderULEncoder +from diffusers.training_utils import ul_alpha_sigma_from_logsnr, ul_logsnr_schedule + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Evaluate Unified Latents stage-1 reconstruction and stage-2 realism." + ) + parser.add_argument("--train_data_dir", type=str, required=True, help="ImageFolder root with real images.") + parser.add_argument("--stage1_encoder_path", type=str, required=True, help="Path to stage-1 encoder directory.") + parser.add_argument("--stage1_decoder_path", type=str, required=True, help="Path to stage-1 decoder directory.") + parser.add_argument("--stage2_base_path", type=str, required=True, help="Path to stage-2 base model directory.") + parser.add_argument("--output_dir", type=str, default="ul-eval", help="Directory to store evaluation outputs.") + parser.add_argument("--resolution", type=int, default=256) + parser.add_argument("--batch_size", type=int, default=8) + parser.add_argument("--num_workers", type=int, default=0) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--device", type=str, default="cuda") + + parser.add_argument("--num_recon_samples", type=int, default=256) + parser.add_argument("--num_gen_samples", type=int, default=1024) + parser.add_argument("--num_sampling_steps", type=int, default=30) + parser.add_argument("--num_train_timesteps", type=int, default=1000) + parser.add_argument("--latent_channels", type=int, default=4) + parser.add_argument("--latent_downsample_factor", type=int, default=8) + + parser.add_argument("--base_schedule", type=str, default="linear", choices=["linear", "cosine"]) + parser.add_argument("--decoder_schedule", type=str, default="linear", choices=["linear", "cosine"]) + parser.add_argument("--lambda_z_min", type=float, default=-10.0) + parser.add_argument("--lambda_z_max", type=float, default=5.0) + parser.add_argument("--lambda_x_min", type=float, default=-10.0) + parser.add_argument("--lambda_x_max", type=float, default=10.0) + + parser.add_argument( + "--recon_use_noisy_z0", action="store_true", help="Use noisy z0 for recon instead of deterministic z0." + ) + parser.add_argument("--decoder_prediction_type", type=str, default="x0", choices=["epsilon", "x0"]) + parser.add_argument("--save_recon_grid", type=str, default="stage1_recon_grid.png") + parser.add_argument("--save_stage2_grid", type=str, default="stage2_samples_grid.png") + parser.add_argument("--grid_items", type=int, default=8) + + parser.add_argument("--kid_subset_size", type=int, default=256) + parser.add_argument("--kid_subsets", type=int, default=10) + + parser.add_argument("--pass_psnr", type=float, default=18.0) + parser.add_argument("--pass_mae", type=float, default=0.10) + parser.add_argument("--pass_kid", type=float, default=0.02) + + return parser.parse_args() + + +def _build_loader(data_dir: str, resolution: int, batch_size: int, num_workers: int): + image_transform = transforms.Compose( + [ + transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(resolution), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), + ] + ) + dataset = ImageFolder(root=data_dir, transform=image_transform) + loader = torch.utils.data.DataLoader( + dataset, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + pin_memory=True, + drop_last=False, + ) + return dataset, loader + + +def _decode_image_diffusion( + decoder: UNet2DModel, + z0: torch.Tensor, + *, + resolution: int, + num_sampling_steps: int, + num_train_timesteps: int, + schedule: str, + lambda_min: float, + lambda_max: float, + decoder_prediction_type: str, +) -> torch.Tensor: + bsz = z0.shape[0] + z0_up = F.interpolate(z0, size=(resolution, resolution), mode="bilinear", align_corners=False) + + x_t = torch.randn(bsz, 3, resolution, resolution, device=z0.device, dtype=z0.dtype) + t_grid = torch.linspace(1.0, 0.0, num_sampling_steps + 1, device=z0.device, dtype=z0.dtype) + for i in range(num_sampling_steps): + t_cur = t_grid[i].repeat(bsz) + t_nxt = t_grid[i + 1].repeat(bsz) + + lambda_cur = ul_logsnr_schedule(t_cur, schedule_type=schedule, lambda_min=lambda_min, lambda_max=lambda_max) + lambda_nxt = ul_logsnr_schedule(t_nxt, schedule_type=schedule, lambda_min=lambda_min, lambda_max=lambda_max) + alpha_cur, sigma_cur = ul_alpha_sigma_from_logsnr(lambda_cur) + alpha_nxt, sigma_nxt = ul_alpha_sigma_from_logsnr(lambda_nxt) + + timestep_idx = (t_cur * (num_train_timesteps - 1)).long().clamp(0, num_train_timesteps - 1) + decoder_pred = decoder(torch.cat([x_t, z0_up], dim=1), timestep_idx).sample + if decoder_prediction_type == "epsilon": + eps_hat = decoder_pred + x0_hat = (x_t - sigma_cur[:, None, None, None] * eps_hat) / alpha_cur[:, None, None, None].clamp_min(1e-5) + else: + x0_hat = decoder_pred + eps_hat = (x_t - alpha_cur[:, None, None, None] * x0_hat) / sigma_cur[:, None, None, None].clamp_min(1e-5) + x_t = alpha_nxt[:, None, None, None] * x0_hat + sigma_nxt[:, None, None, None] * eps_hat + + return x_t.clamp(-1, 1) + + +def _sample_stage2_latent( + base_model: ULTwoStageBaseModel, + *, + batch_size: int, + latent_channels: int, + latent_size: int, + num_sampling_steps: int, + num_train_timesteps: int, + schedule: str, + lambda_min: float, + lambda_max: float, + device: torch.device, + dtype: torch.dtype, +) -> torch.Tensor: + z_t = torch.randn(batch_size, latent_channels, latent_size, latent_size, device=device, dtype=dtype) + t_grid = torch.linspace(1.0, 0.0, num_sampling_steps + 1, device=device, dtype=dtype) + for i in range(num_sampling_steps): + t_cur = t_grid[i].repeat(batch_size) + t_nxt = t_grid[i + 1].repeat(batch_size) + + lambda_cur = ul_logsnr_schedule(t_cur, schedule_type=schedule, lambda_min=lambda_min, lambda_max=lambda_max) + lambda_nxt = ul_logsnr_schedule(t_nxt, schedule_type=schedule, lambda_min=lambda_min, lambda_max=lambda_max) + alpha_cur, sigma_cur = ul_alpha_sigma_from_logsnr(lambda_cur) + alpha_nxt, sigma_nxt = ul_alpha_sigma_from_logsnr(lambda_nxt) + + timestep_idx = (t_cur * (num_train_timesteps - 1)).long().clamp(0, num_train_timesteps - 1) + dummy_labels = torch.zeros((batch_size,), device=device, dtype=torch.long) + v_pred = base_model(z_t, timestep_idx, dummy_labels) + + z0_hat = alpha_cur[:, None, None, None] * z_t - sigma_cur[:, None, None, None] * v_pred + eps_hat = sigma_cur[:, None, None, None] * z_t + alpha_cur[:, None, None, None] * v_pred + z_t = alpha_nxt[:, None, None, None] * z0_hat + sigma_nxt[:, None, None, None] * eps_hat + + # Paper-ground-truth: hand off noisy latent at logsnr_0. + return z_t + + +def _collect_inception_features( + images_m11: torch.Tensor, + inception: torch.nn.Module, + *, + device: torch.device, +) -> torch.Tensor: + x = (images_m11 + 1.0) / 2.0 + x = x.clamp(0.0, 1.0) + x = F.interpolate(x, size=(299, 299), mode="bilinear", align_corners=False) + mean = torch.tensor([0.485, 0.456, 0.406], device=device, dtype=x.dtype)[None, :, None, None] + std = torch.tensor([0.229, 0.224, 0.225], device=device, dtype=x.dtype)[None, :, None, None] + x = (x - mean) / std + with torch.no_grad(): + feats = inception(x) + return feats.float().cpu() + + +def _kid_mmd2_unbiased(x: torch.Tensor, y: torch.Tensor) -> float: + # Polynomial kernel used in KID. + # x: [n, d], y: [m, d] + n = x.shape[0] + m = y.shape[0] + d = x.shape[1] + k_xx = ((x @ x.T) / d + 1.0) ** 3 + k_yy = ((y @ y.T) / d + 1.0) ** 3 + k_xy = ((x @ y.T) / d + 1.0) ** 3 + + sum_xx = (k_xx.sum() - torch.diagonal(k_xx).sum()) / (n * (n - 1)) + sum_yy = (k_yy.sum() - torch.diagonal(k_yy).sum()) / (m * (m - 1)) + sum_xy = k_xy.mean() + return (sum_xx + sum_yy - 2.0 * sum_xy).item() + + +def _compute_kid( + real_feats: torch.Tensor, + fake_feats: torch.Tensor, + *, + subset_size: int, + subsets: int, + seed: int, +) -> tuple[float, float]: + g = torch.Generator(device="cpu") + g.manual_seed(seed) + + n = min(real_feats.shape[0], fake_feats.shape[0]) + subset_size = min(subset_size, n) + if subset_size < 2: + raise ValueError("Need at least 2 samples to compute KID.") + + values = [] + for _ in range(subsets): + idx_r = torch.randperm(real_feats.shape[0], generator=g)[:subset_size] + idx_f = torch.randperm(fake_feats.shape[0], generator=g)[:subset_size] + values.append(_kid_mmd2_unbiased(real_feats[idx_r], fake_feats[idx_f])) + + vals = torch.tensor(values, dtype=torch.float32) + return vals.mean().item(), vals.std(unbiased=False).item() + + +def main(): + args = parse_args() + out_dir = Path(args.output_dir) + out_dir.mkdir(parents=True, exist_ok=True) + + torch.manual_seed(args.seed) + device = torch.device(args.device if torch.cuda.is_available() else "cpu") + + dataset, loader = _build_loader(args.train_data_dir, args.resolution, args.batch_size, args.num_workers) + + encoder = AutoencoderULEncoder.from_pretrained(args.stage1_encoder_path).to(device).eval() + decoder = UNet2DModel.from_pretrained(args.stage1_decoder_path).to(device).eval() + base_model = ULTwoStageBaseModel.from_pretrained(args.stage2_base_path).to(device).eval() + + latent_size = int(getattr(base_model.config, "latent_size", args.resolution // args.latent_downsample_factor)) + + # Stage-1 reconstruction metrics. + recon_count = 0 + recon_mae = [] + recon_mse = [] + vis_pairs = [] + + for x, _ in loader: + if recon_count >= args.num_recon_samples: + break + x = x.to(device) + if recon_count + x.shape[0] > args.num_recon_samples: + x = x[: args.num_recon_samples - recon_count] + + with torch.no_grad(): + z_clean = encoder.encode(x).latent + lambda0 = torch.full((x.shape[0],), args.lambda_z_max, device=device, dtype=x.dtype) + alpha0, sigma0 = ul_alpha_sigma_from_logsnr(lambda0) + if args.recon_use_noisy_z0: + eps0 = torch.randn_like(z_clean) + else: + eps0 = torch.zeros_like(z_clean) + z0 = alpha0[:, None, None, None] * z_clean + sigma0[:, None, None, None] * eps0 + + recon = _decode_image_diffusion( + decoder, + z0, + resolution=args.resolution, + num_sampling_steps=args.num_sampling_steps, + num_train_timesteps=args.num_train_timesteps, + schedule=args.decoder_schedule, + lambda_min=args.lambda_x_min, + lambda_max=args.lambda_x_max, + decoder_prediction_type=args.decoder_prediction_type, + ) + + diff = recon - x + mse = diff.square().mean(dim=(1, 2, 3)) + mae = diff.abs().mean(dim=(1, 2, 3)) + recon_mse.append(mse.cpu()) + recon_mae.append(mae.cpu()) + + if len(vis_pairs) < args.grid_items: + take = min(args.grid_items - len(vis_pairs), x.shape[0]) + for i in range(take): + vis_pairs.append(torch.cat([x[i], recon[i]], dim=2).detach().cpu()) + + recon_count += x.shape[0] + + recon_mse = torch.cat(recon_mse) + recon_mae = torch.cat(recon_mae) + recon_psnr = -10.0 * torch.log10(recon_mse.clamp_min(1e-12)) + + if vis_pairs: + recon_grid = make_grid([(p + 1.0) / 2.0 for p in vis_pairs], nrow=2) + save_image(recon_grid, out_dir / args.save_recon_grid) + + # Stage-2 generation and realism metrics. + gen_images = [] + gen_count = 0 + while gen_count < args.num_gen_samples: + bsz = min(args.batch_size, args.num_gen_samples - gen_count) + with torch.no_grad(): + z0_sampled = _sample_stage2_latent( + base_model, + batch_size=bsz, + latent_channels=args.latent_channels, + latent_size=latent_size, + num_sampling_steps=args.num_sampling_steps, + num_train_timesteps=args.num_train_timesteps, + schedule=args.base_schedule, + lambda_min=args.lambda_z_min, + lambda_max=args.lambda_z_max, + device=device, + dtype=torch.float32, + ) + x_gen = _decode_image_diffusion( + decoder, + z0_sampled, + resolution=args.resolution, + num_sampling_steps=args.num_sampling_steps, + num_train_timesteps=args.num_train_timesteps, + schedule=args.decoder_schedule, + lambda_min=args.lambda_x_min, + lambda_max=args.lambda_x_max, + decoder_prediction_type=args.decoder_prediction_type, + ) + gen_images.append(x_gen.cpu()) + gen_count += bsz + + gen_images = torch.cat(gen_images, dim=0)[: args.num_gen_samples] + + stage2_vis = make_grid(((gen_images[: args.grid_items] + 1.0) / 2.0), nrow=2) + save_image(stage2_vis, out_dir / args.save_stage2_grid) + + # Real images for comparison features. + real_images = [] + real_count = 0 + for x, _ in loader: + if real_count >= args.num_gen_samples: + break + if real_count + x.shape[0] > args.num_gen_samples: + x = x[: args.num_gen_samples - real_count] + real_images.append(x) + real_count += x.shape[0] + real_images = torch.cat(real_images, dim=0) + + inception = inception_v3(weights=Inception_V3_Weights.IMAGENET1K_V1, aux_logits=True, transform_input=False) + inception.aux_logits = False + inception.AuxLogits = None + inception.fc = torch.nn.Identity() + inception = inception.to(device).eval() + + real_feats = [] + fake_feats = [] + with torch.no_grad(): + for i in range(0, real_images.shape[0], args.batch_size): + real_feats.append( + _collect_inception_features(real_images[i : i + args.batch_size].to(device), inception, device=device) + ) + fake_feats.append( + _collect_inception_features(gen_images[i : i + args.batch_size].to(device), inception, device=device) + ) + real_feats = torch.cat(real_feats, dim=0) + fake_feats = torch.cat(fake_feats, dim=0) + + kid_mean, kid_std = _compute_kid( + real_feats, + fake_feats, + subset_size=args.kid_subset_size, + subsets=args.kid_subsets, + seed=args.seed, + ) + + metrics = { + "dataset_size": len(dataset), + "num_recon_samples": int(recon_mse.numel()), + "num_gen_samples": int(gen_images.shape[0]), + "stage1": { + "mae_mean": float(recon_mae.mean().item()), + "mse_mean": float(recon_mse.mean().item()), + "psnr_mean_db": float(recon_psnr.mean().item()), + "psnr_median_db": float(recon_psnr.median().item()), + "recon_uses_noisy_z0": bool(args.recon_use_noisy_z0), + }, + "stage2": { + "kid_mean": float(kid_mean), + "kid_std": float(kid_std), + }, + "passes": { + "stage1_psnr": bool(recon_psnr.mean().item() >= args.pass_psnr), + "stage1_mae": bool(recon_mae.mean().item() <= args.pass_mae), + "stage2_kid": bool(kid_mean <= args.pass_kid), + }, + "thresholds": { + "pass_psnr": float(args.pass_psnr), + "pass_mae": float(args.pass_mae), + "pass_kid": float(args.pass_kid), + }, + "artifacts": { + "recon_grid": str((out_dir / args.save_recon_grid).resolve()), + "stage2_grid": str((out_dir / args.save_stage2_grid).resolve()), + }, + } + metrics["passes"]["all"] = all(metrics["passes"].values()) + + metrics_path = out_dir / "metrics.json" + metrics_path.write_text(json.dumps(metrics, indent=2)) + print(json.dumps(metrics, indent=2)) + print(f"Saved metrics to {metrics_path}") + + +if __name__ == "__main__": + main() diff --git a/examples/research_projects/unified_latents/sample_ul_stage2.py b/examples/research_projects/unified_latents/sample_ul_stage2.py new file mode 100644 index 000000000000..34d4cb8e7613 --- /dev/null +++ b/examples/research_projects/unified_latents/sample_ul_stage2.py @@ -0,0 +1,131 @@ +#!/usr/bin/env python +# coding=utf-8 + +import argparse +from pathlib import Path + +import torch +import torch.nn.functional as F +from torchvision.utils import save_image +from ul_models import ULTwoStageBaseModel + +from diffusers import UNet2DModel +from diffusers.training_utils import ( + ul_alpha_sigma_from_logsnr, + ul_logsnr_schedule, +) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Sample UL Stage-2 latents and decode to images.") + parser.add_argument("--stage2_base_path", type=str, required=True, help="Path to stage2 base_model directory.") + parser.add_argument("--stage1_decoder_path", type=str, required=True, help="Path to stage1 decoder directory.") + parser.add_argument("--output_path", type=str, required=True, help="Output image path.") + parser.add_argument("--resolution", type=int, default=256) + parser.add_argument("--latent_channels", type=int, default=4) + parser.add_argument("--latent_downsample_factor", type=int, default=8) + parser.add_argument("--num_train_timesteps", type=int, default=1000) + parser.add_argument("--num_sampling_steps", type=int, default=50) + parser.add_argument("--batch_size", type=int, default=4) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--base_schedule", type=str, default="linear", choices=["linear", "cosine"]) + parser.add_argument("--decoder_schedule", type=str, default="linear", choices=["linear", "cosine"]) + parser.add_argument("--lambda_z_min", type=float, default=-10.0) + parser.add_argument("--lambda_z_max", type=float, default=5.0) + parser.add_argument("--lambda_x_min", type=float, default=-10.0) + parser.add_argument("--lambda_x_max", type=float, default=10.0) + parser.add_argument("--decoder_prediction_type", type=str, default="x0", choices=["epsilon", "x0"]) + return parser.parse_args() + + +@torch.no_grad() +def main(): + args = parse_args() + device = torch.device(args.device if torch.cuda.is_available() else "cpu") + torch.manual_seed(args.seed) + latent_size = args.resolution // args.latent_downsample_factor + + base_model = ULTwoStageBaseModel.from_pretrained(args.stage2_base_path).to(device).eval() + decoder = UNet2DModel.from_pretrained(args.stage1_decoder_path).to(device).eval() + + # Stage-2 latent sampling (DDIM-like deterministic update with v-pred model). + z_t = torch.randn(args.batch_size, args.latent_channels, latent_size, latent_size, device=device) + t_grid = torch.linspace(1.0, 0.0, args.num_sampling_steps + 1, device=device) + + for i in range(args.num_sampling_steps): + t_cur = t_grid[i].repeat(args.batch_size) + t_nxt = t_grid[i + 1].repeat(args.batch_size) + + lambda_cur = ul_logsnr_schedule( + t_cur, + schedule_type=args.base_schedule, + lambda_min=args.lambda_z_min, + lambda_max=args.lambda_z_max, + ) + lambda_nxt = ul_logsnr_schedule( + t_nxt, + schedule_type=args.base_schedule, + lambda_min=args.lambda_z_min, + lambda_max=args.lambda_z_max, + ) + alpha_cur, sigma_cur = ul_alpha_sigma_from_logsnr(lambda_cur) + alpha_nxt, sigma_nxt = ul_alpha_sigma_from_logsnr(lambda_nxt) + + timestep_idx = (t_cur * (args.num_train_timesteps - 1)).long().clamp(0, args.num_train_timesteps - 1) + dummy_labels = torch.zeros((args.batch_size,), device=device, dtype=torch.long) + v_pred = base_model(z_t, timestep_idx, dummy_labels) + + z0_hat = alpha_cur[:, None, None, None] * z_t - sigma_cur[:, None, None, None] * v_pred + eps_hat = sigma_cur[:, None, None, None] * z_t + alpha_cur[:, None, None, None] * v_pred + + z_t = alpha_nxt[:, None, None, None] * z0_hat + sigma_nxt[:, None, None, None] * eps_hat + + # Paper Sec. 3.3: hand off the noisy latent at logsnr_0 to the decoder. + z_handoff = z_t + + # Decoder sampling conditioned on final stage-2 latent. + z0_up = F.interpolate(z_handoff, size=(args.resolution, args.resolution), mode="bilinear", align_corners=False) + + x_t = torch.randn(args.batch_size, 3, args.resolution, args.resolution, device=device) + x_grid = torch.linspace(1.0, 0.0, args.num_sampling_steps + 1, device=device) + + for i in range(args.num_sampling_steps): + t_cur = x_grid[i].repeat(args.batch_size) + t_nxt = x_grid[i + 1].repeat(args.batch_size) + + lambda_cur = ul_logsnr_schedule( + t_cur, + schedule_type=args.decoder_schedule, + lambda_min=args.lambda_x_min, + lambda_max=args.lambda_x_max, + ) + lambda_nxt = ul_logsnr_schedule( + t_nxt, + schedule_type=args.decoder_schedule, + lambda_min=args.lambda_x_min, + lambda_max=args.lambda_x_max, + ) + alpha_cur, sigma_cur = ul_alpha_sigma_from_logsnr(lambda_cur) + alpha_nxt, sigma_nxt = ul_alpha_sigma_from_logsnr(lambda_nxt) + + timestep_idx = (t_cur * (args.num_train_timesteps - 1)).long().clamp(0, args.num_train_timesteps - 1) + decoder_input = torch.cat([x_t, z0_up], dim=1) + decoder_pred = decoder(decoder_input, timestep_idx).sample + if args.decoder_prediction_type == "epsilon": + eps_hat = decoder_pred + x0_hat = (x_t - sigma_cur[:, None, None, None] * eps_hat) / alpha_cur[:, None, None, None].clamp_min(1e-5) + else: + x0_hat = decoder_pred + eps_hat = (x_t - alpha_cur[:, None, None, None] * x0_hat) / sigma_cur[:, None, None, None].clamp_min(1e-5) + x_t = alpha_nxt[:, None, None, None] * x0_hat + sigma_nxt[:, None, None, None] * eps_hat + + x_final = x_t.clamp(-1, 1) + output_path = Path(args.output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + save_image((x_final + 1.0) / 2.0, output_path, nrow=2) + print(f"Saved samples to {output_path}") + + +if __name__ == "__main__": + main() diff --git a/examples/research_projects/unified_latents/train_ul_stage1.py b/examples/research_projects/unified_latents/train_ul_stage1.py new file mode 100644 index 000000000000..db6f445c032c --- /dev/null +++ b/examples/research_projects/unified_latents/train_ul_stage1.py @@ -0,0 +1,588 @@ +#!/usr/bin/env python +# coding=utf-8 + +import argparse +import logging +import math +import os +import shutil +from pathlib import Path + +import torch +import torch.nn.functional as F +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from datasets import load_dataset +from torch import nn +from torch.utils.data import DataLoader, IterableDataset +from torchvision import transforms +from torchvision.utils import save_image +from tqdm.auto import tqdm + +from diffusers.models.autoencoders import AutoencoderULEncoder +from diffusers.models.transformers.dit_transformer_2d import DiTTransformer2DModel +from diffusers.models.unets.unet_2d import UNet2DModel +from diffusers.optimization import get_scheduler +from diffusers.training_utils import ( + ul_alpha_sigma_from_logsnr, + ul_decoder_loss_weight, + ul_dlogsnr_dt, + ul_elbo_prefactor, + ul_logsnr_schedule, + ul_sample_t, + ul_terminal_gaussian_kl, +) + + +logger = get_logger(__name__, log_level="INFO") + + +def parse_args(): + parser = argparse.ArgumentParser(description="Train Unified Latents Stage-1 (UL) with Diffusers.") + + parser.add_argument("--dataset_name", type=str, default=None, help="Dataset name on the Hub.") + parser.add_argument("--dataset_config_name", type=str, default=None, help="Dataset config name.") + parser.add_argument( + "--train_data_dir", type=str, default=None, help="Local dataset directory (imagefolder style)." + ) + parser.add_argument("--cache_dir", type=str, default=None, help="Hugging Face datasets cache directory.") + parser.add_argument("--image_column", type=str, default="image", help="Column containing images.") + parser.add_argument("--streaming", action="store_true", help="Stream dataset from the Hub/local files.") + parser.add_argument("--output_dir", type=str, default="ul-stage1", help="Where to save checkpoints and samples.") + parser.add_argument("--logging_dir", type=str, default="logs") + parser.add_argument("--report_to", type=str, default="tensorboard") + parser.add_argument("--tracker_project_name", type=str, default="unified-latents-stage1") + parser.add_argument("--resolution", type=int, default=256) + parser.add_argument("--center_crop", action="store_true") + parser.add_argument("--random_flip", action="store_true") + parser.add_argument("--num_workers", type=int, default=4) + + parser.add_argument("--train_batch_size", type=int, default=8) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument("--max_train_steps", type=int, default=10000) + parser.add_argument("--gradient_accumulation_steps", type=int, default=1) + parser.add_argument("--learning_rate", type=float, default=1e-4) + parser.add_argument("--adam_beta1", type=float, default=0.9) + parser.add_argument("--adam_beta2", type=float, default=0.999) + parser.add_argument("--adam_weight_decay", type=float, default=1e-2) + parser.add_argument("--adam_epsilon", type=float, default=1e-8) + parser.add_argument("--max_grad_norm", type=float, default=1.0) + parser.add_argument("--lr_scheduler", type=str, default="cosine") + parser.add_argument("--lr_warmup_steps", type=int, default=500) + + parser.add_argument("--num_train_timesteps", type=int, default=1000) + parser.add_argument("--lambda_z_min", type=float, default=-10.0) + parser.add_argument("--lambda_z_max", type=float, default=5.0) + parser.add_argument("--lambda_x_min", type=float, default=-10.0) + parser.add_argument("--lambda_x_max", type=float, default=10.0) + parser.add_argument("--prior_schedule", type=str, default="linear", choices=["linear", "cosine"]) + parser.add_argument("--decoder_schedule", type=str, default="linear", choices=["linear", "cosine"]) + parser.add_argument("--decoder_loss_factor", type=float, default=1.6) + parser.add_argument("--decoder_sigmoid_bias", type=float, default=0.0) + parser.add_argument( + "--decoder_prediction_type", + type=str, + default="x0", + choices=["epsilon", "x0"], + help="Decoder target parameterization.", + ) + parser.add_argument("--latent_channels", type=int, default=4) + parser.add_argument( + "--latent_downsample_factor", + type=int, + default=8, + help="Downsample factor of encoder output. Current encoder config is 8x; paper target is 16x with 2x2 patching.", + ) + + # Prior (paper Sec. 5.1: single-level ViT, 8 blocks, 1024 channels) + parser.add_argument("--prior_num_layers", type=int, default=8) + parser.add_argument("--prior_num_heads", type=int, default=16) + parser.add_argument("--prior_head_dim", type=int, default=64) # 16 * 64 = 1024 + parser.add_argument("--prior_patch_size", type=int, default=1) + + # Decoder approximation (paper Sec. 5.1: UViT conv channels [128, 256, 512], dropout 0.1) + parser.add_argument("--decoder_block_out_channels", type=str, default="128,256,512") + parser.add_argument("--decoder_layers_per_block", type=int, default=2) + parser.add_argument("--decoder_dropout", type=float, default=0.1) + + parser.add_argument("--save_steps", type=int, default=500) + parser.add_argument("--checkpoints_total_limit", type=int, default=None) + parser.add_argument("--resume_from_checkpoint", type=str, default=None) + parser.add_argument("--sample_steps", type=int, default=250) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--mixed_precision", type=str, default="bf16", choices=["no", "fp16", "bf16"]) + + return parser.parse_args() + + +def build_transforms(args): + transform_list = [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + ] + if args.center_crop: + transform_list.append(transforms.CenterCrop(args.resolution)) + else: + transform_list.append(transforms.RandomCrop(args.resolution)) + + if args.random_flip: + transform_list.append(transforms.RandomHorizontalFlip()) + + transform_list.extend( + [ + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), + ] + ) + return transforms.Compose(transform_list) + + +def parse_channels(channels_str: str) -> tuple[int, ...]: + values = [int(x.strip()) for x in channels_str.split(",") if x.strip()] + if not values: + raise ValueError("`decoder_block_out_channels` must contain at least one integer.") + return tuple(values) + + +class HFStreamingImageDataset(IterableDataset): + def __init__(self, hf_iterable, image_column: str, image_transform): + super().__init__() + self.hf_iterable = hf_iterable + self.image_column = image_column + self.image_transform = image_transform + + def __iter__(self): + for example in self.hf_iterable: + image = example[self.image_column].convert("RGB") + yield {"pixel_values": self.image_transform(image)} + + +def get_train_dataset_and_collate(args, image_transform): + if args.dataset_name is not None: + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + data_dir=args.train_data_dir, + streaming=args.streaming, + ) + else: + if args.train_data_dir is None: + raise ValueError("Provide either `--dataset_name` or `--train_data_dir`.") + data_files = {"train": os.path.join(args.train_data_dir, "**")} + dataset = load_dataset( + "imagefolder", data_files=data_files, cache_dir=args.cache_dir, streaming=args.streaming + ) + + train_dataset = dataset["train"] + + if args.streaming: + image_column = args.image_column + train_dataset = HFStreamingImageDataset( + train_dataset, image_column=image_column, image_transform=image_transform + ) + else: + column_names = train_dataset.column_names + image_column = args.image_column if args.image_column in column_names else column_names[0] + + def preprocess_train(examples): + images = [image.convert("RGB") for image in examples[image_column]] + examples["pixel_values"] = [image_transform(image) for image in images] + return examples + + train_dataset = train_dataset.with_transform(preprocess_train) + + def collate_fn(examples): + pixel_values = torch.stack([example["pixel_values"] for example in examples]) + return {"pixel_values": pixel_values.contiguous().float()} + + return train_dataset, collate_fn + + +def _save_pretrained_components( + accelerator: Accelerator, encoder: nn.Module, prior: nn.Module, decoder: nn.Module, output_dir +): + output_dir = Path(output_dir) + accelerator.unwrap_model(encoder).save_pretrained(output_dir / "encoder") + accelerator.unwrap_model(prior).save_pretrained(output_dir / "prior") + accelerator.unwrap_model(decoder).save_pretrained(output_dir / "decoder") + + +def main(): + args = parse_args() + if args.seed is not None: + set_seed(args.seed) + + output_dir = Path(args.output_dir) + logging_dir = output_dir / args.logging_dir + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=str(logging_dir)) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state) + + if accelerator.is_main_process: + output_dir.mkdir(parents=True, exist_ok=True) + (output_dir / "samples").mkdir(exist_ok=True) + logger.info("Starting UL Stage-1 training") + logger.info(f"Output dir: {output_dir}") + + image_transform = build_transforms(args) + train_dataset, collate_fn = get_train_dataset_and_collate(args, image_transform) + train_dataloader = DataLoader( + train_dataset, + batch_size=args.train_batch_size, + shuffle=not args.streaming, + collate_fn=collate_fn, + num_workers=args.num_workers, + pin_memory=True, + drop_last=True, + ) + + latent_size = args.resolution // args.latent_downsample_factor + if latent_size < 1: + raise ValueError("`resolution // latent_downsample_factor` must be >= 1.") + + encoder = AutoencoderULEncoder(in_channels=3, latent_channels=args.latent_channels) + + prior = DiTTransformer2DModel( + num_attention_heads=args.prior_num_heads, + attention_head_dim=args.prior_head_dim, + in_channels=args.latent_channels, + out_channels=args.latent_channels, + num_layers=args.prior_num_layers, + sample_size=latent_size, + patch_size=args.prior_patch_size, + num_embeds_ada_norm=args.num_train_timesteps, + ) + + decoder_channels = parse_channels(args.decoder_block_out_channels) + decoder = UNet2DModel( + sample_size=args.resolution, + in_channels=3 + args.latent_channels, + out_channels=3, + down_block_types=("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"), + up_block_types=("AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"), + block_out_channels=decoder_channels, + layers_per_block=args.decoder_layers_per_block, + dropout=args.decoder_dropout, + attention_head_dim=8, + add_attention=True, + num_train_timesteps=args.num_train_timesteps, + ) + + params = list(encoder.parameters()) + list(prior.parameters()) + list(decoder.parameters()) + optimizer = torch.optim.AdamW( + params, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + has_dataloader_length = True + try: + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + except TypeError: + has_dataloader_length = False + num_update_steps_per_epoch = None + if args.max_train_steps is None or args.max_train_steps <= 0: + if not has_dataloader_length: + raise ValueError("For streaming datasets, set `--max_train_steps` to a positive value.") + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + + lr_scheduler = get_scheduler( + name=args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + ) + + encoder, prior, decoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + encoder, prior, decoder, optimizer, train_dataloader, lr_scheduler + ) + trainable_params = list(encoder.parameters()) + list(prior.parameters()) + list(decoder.parameters()) + + # Recompute schedule values after `accelerator.prepare`. + if has_dataloader_length: + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + else: + args.num_train_epochs = max(args.num_train_epochs, 1) + + if accelerator.is_main_process: + accelerator.init_trackers(args.tracker_project_name, config=vars(args)) + + logger.info("***** Running training *****") + if has_dataloader_length: + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + else: + logger.info(" Num examples = streaming (unknown)") + logger.info(" Num batches each epoch = streaming (unknown)") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info( + " Total train batch size (w. parallel, distributed & accumulation) = " + f"{args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps}" + ) + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + + global_step = 0 + first_epoch = 0 + + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + dirs = os.listdir(args.output_dir) if os.path.exists(args.output_dir) else [] + dirs = [d for d in dirs if d.startswith("checkpoint-")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + logger.info(f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.") + else: + logger.info(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + first_epoch = global_step // num_update_steps_per_epoch if has_dataloader_length else 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=global_step, + desc="Steps", + disable=not accelerator.is_local_main_process, + ) + + train_loss = 0.0 + fixed_eval_batch_cpu = None + for epoch in range(first_epoch, args.num_train_epochs): + encoder.train() + prior.train() + decoder.train() + + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(encoder, prior, decoder): + x = batch["pixel_values"].to(accelerator.device) + bsz = x.shape[0] + if fixed_eval_batch_cpu is None: + # Keep a fixed mini-batch for deterministic reconstruction snapshots. + fixed_eval_batch_cpu = x[: min(4, bsz)].detach().float().cpu() + + # Encode deterministic latent z_clean + z_clean = encoder.encode(x).latent + + # Prior branch: z_t -> z_clean with unweighted ELBO-style denoising loss + t_z = ul_sample_t(bsz, x.device) + lambda_z = ul_logsnr_schedule( + t_z, + schedule_type=args.prior_schedule, + lambda_min=args.lambda_z_min, + lambda_max=args.lambda_z_max, + ) + dlogsnr_dt_z = ul_dlogsnr_dt( + t_z, + schedule_type=args.prior_schedule, + lambda_min=args.lambda_z_min, + lambda_max=args.lambda_z_max, + ) + prefactor_z = ul_elbo_prefactor(lambda_z, dlogsnr_dt_z) + alpha_z, sigma_z = ul_alpha_sigma_from_logsnr(lambda_z) + eps_z = torch.randn_like(z_clean) + z_t = alpha_z[:, None, None, None] * z_clean + sigma_z[:, None, None, None] * eps_z + t_z_idx = (t_z * (args.num_train_timesteps - 1)).long().clamp(0, args.num_train_timesteps - 1) + dummy_labels = torch.zeros((bsz,), device=x.device, dtype=torch.long) + z_pred = prior(z_t, timestep=t_z_idx, class_labels=dummy_labels).sample + prior_per_sample = F.mse_loss(z_pred.float(), z_clean.float(), reduction="none").sum(dim=(1, 2, 3)) + prior_loss = (prefactor_z * prior_per_sample).mean() + prior_loss = prior_loss + ul_terminal_gaussian_kl(z_clean, logsnr_terminal=args.lambda_z_min).mean() + + # Decoder branch: x_t conditioned on fixed-noise z0 + lambda0 = torch.full((bsz,), args.lambda_z_max, device=x.device, dtype=x.dtype) + alpha0, sigma0 = ul_alpha_sigma_from_logsnr(lambda0) + eps_0 = torch.randn_like(z_clean) + z0 = alpha0[:, None, None, None] * z_clean + sigma0[:, None, None, None] * eps_0 + + t_x = ul_sample_t(bsz, x.device) + lambda_x = ul_logsnr_schedule( + t_x, + schedule_type=args.decoder_schedule, + lambda_min=args.lambda_x_min, + lambda_max=args.lambda_x_max, + ) + dlogsnr_dt_x = ul_dlogsnr_dt( + t_x, + schedule_type=args.decoder_schedule, + lambda_min=args.lambda_x_min, + lambda_max=args.lambda_x_max, + ) + prefactor_x = ul_elbo_prefactor(lambda_x, dlogsnr_dt_x) + alpha_x, sigma_x = ul_alpha_sigma_from_logsnr(lambda_x) + eps_x = torch.randn_like(x) + x_t = alpha_x[:, None, None, None] * x + sigma_x[:, None, None, None] * eps_x + t_x_idx = (t_x * (args.num_train_timesteps - 1)).long().clamp(0, args.num_train_timesteps - 1) + + z0_up = F.interpolate(z0, size=x.shape[-2:], mode="bilinear", align_corners=False) + decoder_input = torch.cat([x_t, z0_up], dim=1) + decoder_pred = decoder(decoder_input, t_x_idx).sample + if args.decoder_prediction_type == "epsilon": + decoder_target = eps_x + decoder_prefactor = 0.5 * (-dlogsnr_dt_x) + else: + decoder_target = x + decoder_prefactor = prefactor_x + + decoder_per_sample = F.mse_loss(decoder_pred.float(), decoder_target.float(), reduction="none").sum( + dim=(1, 2, 3) + ) + decoder_w = ul_decoder_loss_weight( + lambda_x, + bias=args.decoder_sigmoid_bias, + loss_factor=args.decoder_loss_factor, + invert=False, + ) + decoder_loss = (decoder_prefactor * decoder_w * decoder_per_sample).mean() + + num_pixels = x.shape[-2] * x.shape[-1] + bpp_denom = num_pixels * math.log(2.0) + prior_bpp = prior_loss / bpp_denom + decoder_bpp = decoder_loss / bpp_denom + loss = prior_bpp + decoder_bpp + accelerator.backward(loss) + + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(trainable_params, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + avg_loss = accelerator.gather(loss.detach().repeat(args.train_batch_size)).mean() + train_loss += avg_loss.item() / args.gradient_accumulation_steps + + if accelerator.sync_gradients: + global_step += 1 + progress_bar.update(1) + accelerator.log({"train_loss": train_loss}, step=global_step) + train_loss = 0.0 + + if accelerator.is_main_process and global_step % 50 == 0: + logger.info( + f"step={global_step} loss_bpp={loss.item():.4f} prior_bpp={prior_bpp.item():.4f} " + f"decoder_bpp={decoder_bpp.item():.4f} lr={lr_scheduler.get_last_lr()[0]:.2e}" + ) + + if accelerator.is_main_process and global_step % args.sample_steps == 0: + with torch.no_grad(): + encoder.eval() + decoder.eval() + + eval_x = fixed_eval_batch_cpu.to(accelerator.device) + eval_z_clean = encoder.encode(eval_x).latent + eval_z_up = F.interpolate( + eval_z_clean, size=eval_x.shape[-2:], mode="bilinear", align_corners=False + ) + + # Deterministic decode from encoder latents at t=0 with zero x-noise input. + eval_x_t = torch.zeros_like(eval_x) + eval_t_idx = torch.zeros((eval_x.shape[0],), device=eval_x.device, dtype=torch.long) + eval_decoder_input = torch.cat([eval_x_t, eval_z_up], dim=1) + eval_decoder_out = decoder(eval_decoder_input, eval_t_idx).sample + if args.decoder_prediction_type == "epsilon": + eval_lambda0 = ul_logsnr_schedule( + torch.zeros((eval_x.shape[0],), device=eval_x.device), + schedule_type=args.decoder_schedule, + lambda_min=args.lambda_x_min, + lambda_max=args.lambda_x_max, + ) + eval_alpha0, eval_sigma0 = ul_alpha_sigma_from_logsnr(eval_lambda0) + eval_recon = ( + (eval_x_t - eval_sigma0[:, None, None, None] * eval_decoder_out) + / eval_alpha0[:, None, None, None].clamp_min(1e-5) + ).clamp(-1, 1) + else: + eval_recon = eval_decoder_out.clamp(-1, 1) + + # Save side-by-side: left=input, right=reconstruction. + side_by_side = torch.cat([eval_x.clamp(-1, 1), eval_recon], dim=3).detach().float().cpu() + save_image( + (side_by_side + 1.0) / 2.0, + output_dir / "samples" / f"step_{global_step:07d}.png", + nrow=2, + ) + + encoder.train() + decoder.train() + + if accelerator.is_main_process and global_step % args.save_steps == 0: + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint-")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[:num_to_remove] + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)}" + ) + for removing_checkpoint in removing_checkpoints: + shutil.rmtree(os.path.join(args.output_dir, removing_checkpoint)) + + save_dir = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_dir) + _save_pretrained_components( + accelerator=accelerator, + encoder=encoder, + prior=prior, + decoder=decoder, + output_dir=save_dir, + ) + logger.info(f"Saved state to {save_dir}") + + logs = { + "step_loss_bpp": loss.detach().item(), + "prior_loss_bpp": prior_bpp.detach().item(), + "decoder_loss_bpp": decoder_bpp.detach().item(), + "lr": lr_scheduler.get_last_lr()[0], + } + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if global_step >= args.max_train_steps: + break + + accelerator.wait_for_everyone() + if accelerator.is_main_process: + final_dir = output_dir / "final" + final_dir.mkdir(parents=True, exist_ok=True) + torch.save(accelerator.unwrap_model(encoder).state_dict(), final_dir / "encoder.pt") + torch.save(accelerator.unwrap_model(prior).state_dict(), final_dir / "prior.pt") + torch.save(accelerator.unwrap_model(decoder).state_dict(), final_dir / "decoder.pt") + _save_pretrained_components( + accelerator=accelerator, + encoder=encoder, + prior=prior, + decoder=decoder, + output_dir=final_dir, + ) + logger.info(f"Training finished. Saved final checkpoints to {final_dir}") + accelerator.end_training() + + +if __name__ == "__main__": + main() diff --git a/examples/research_projects/unified_latents/train_ul_stage2_base.py b/examples/research_projects/unified_latents/train_ul_stage2_base.py new file mode 100644 index 000000000000..373a0bb5ffa7 --- /dev/null +++ b/examples/research_projects/unified_latents/train_ul_stage2_base.py @@ -0,0 +1,462 @@ +#!/usr/bin/env python +# coding=utf-8 + +import argparse +import logging +import math +import os +import shutil +from pathlib import Path + +import torch +import torch.nn.functional as F +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from datasets import load_dataset +from torch import nn +from torch.utils.data import DataLoader, IterableDataset +from torchvision import transforms +from tqdm.auto import tqdm +from ul_models import ULTwoStageBaseModel + +from diffusers.models.autoencoders import AutoencoderULEncoder +from diffusers.optimization import get_scheduler +from diffusers.training_utils import ( + ul_alpha_sigma_from_logsnr, + ul_decoder_loss_weight, + ul_dlogsnr_dt, + ul_elbo_prefactor, + ul_logsnr_schedule, + ul_sample_t, +) + + +logger = get_logger(__name__, log_level="INFO") + + +def parse_args(): + parser = argparse.ArgumentParser(description="Train Unified Latents Stage-2 base model with Diffusers.") + + parser.add_argument("--dataset_name", type=str, default=None, help="Dataset name on the Hub.") + parser.add_argument("--dataset_config_name", type=str, default=None, help="Dataset config name.") + parser.add_argument( + "--train_data_dir", type=str, default=None, help="Local dataset directory (imagefolder style)." + ) + parser.add_argument("--cache_dir", type=str, default=None, help="Hugging Face datasets cache directory.") + parser.add_argument("--image_column", type=str, default="image", help="Column containing images.") + parser.add_argument("--streaming", action="store_true", help="Stream dataset from the Hub/local files.") + parser.add_argument( + "--stage1_encoder_path", + type=str, + required=True, + help="Path to stage-1 encoder checkpoint. Supports a `.pt` file or a `save_pretrained` directory.", + ) + + parser.add_argument("--output_dir", type=str, default="ul-stage2-base") + parser.add_argument("--logging_dir", type=str, default="logs") + parser.add_argument("--report_to", type=str, default="tensorboard") + parser.add_argument("--tracker_project_name", type=str, default="unified-latents-stage2") + + parser.add_argument("--resolution", type=int, default=256) + parser.add_argument("--center_crop", action="store_true") + parser.add_argument("--random_flip", action="store_true") + parser.add_argument("--num_workers", type=int, default=4) + + parser.add_argument("--train_batch_size", type=int, default=8) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument("--max_train_steps", type=int, default=10000) + parser.add_argument("--gradient_accumulation_steps", type=int, default=1) + + parser.add_argument("--learning_rate", type=float, default=1e-4) + parser.add_argument("--adam_beta1", type=float, default=0.9) + parser.add_argument("--adam_beta2", type=float, default=0.999) + parser.add_argument("--adam_weight_decay", type=float, default=1e-2) + parser.add_argument("--adam_epsilon", type=float, default=1e-8) + parser.add_argument("--max_grad_norm", type=float, default=1.0) + parser.add_argument("--lr_scheduler", type=str, default="cosine") + parser.add_argument("--lr_warmup_steps", type=int, default=500) + + parser.add_argument("--num_train_timesteps", type=int, default=1000) + parser.add_argument("--base_schedule", type=str, default="linear", choices=["linear", "cosine"]) + parser.add_argument("--lambda_z_min", type=float, default=-10.0) + parser.add_argument("--lambda_z_max", type=float, default=5.0) + + parser.add_argument("--loss_factor", type=float, default=1.0) + parser.add_argument("--sigmoid_bias", type=float, default=0.0) + + parser.add_argument("--latent_channels", type=int, default=4) + parser.add_argument("--latent_downsample_factor", type=int, default=8) + parser.add_argument("--base_patch_size", type=int, default=1) + parser.add_argument("--base_stage_a_layers", type=int, default=6) + parser.add_argument("--base_stage_b_layers", type=int, default=16) + parser.add_argument("--base_stage_a_heads", type=int, default=8) + parser.add_argument("--base_stage_a_head_dim", type=int, default=64) + parser.add_argument("--base_stage_b_heads", type=int, default=16) + parser.add_argument("--base_stage_b_head_dim", type=int, default=64) + + parser.add_argument("--save_steps", type=int, default=500) + parser.add_argument("--checkpoints_total_limit", type=int, default=None) + parser.add_argument("--resume_from_checkpoint", type=str, default=None) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--mixed_precision", type=str, default="bf16", choices=["no", "fp16", "bf16"]) + + return parser.parse_args() + + +def build_transforms(args): + transform_list = [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + ] + if args.center_crop: + transform_list.append(transforms.CenterCrop(args.resolution)) + else: + transform_list.append(transforms.RandomCrop(args.resolution)) + + if args.random_flip: + transform_list.append(transforms.RandomHorizontalFlip()) + + transform_list.extend( + [ + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), + ] + ) + return transforms.Compose(transform_list) + + +class HFStreamingImageDataset(IterableDataset): + def __init__(self, hf_iterable, image_column: str, image_transform): + super().__init__() + self.hf_iterable = hf_iterable + self.image_column = image_column + self.image_transform = image_transform + + def __iter__(self): + for example in self.hf_iterable: + image = example[self.image_column].convert("RGB") + yield {"pixel_values": self.image_transform(image)} + + +def get_train_dataset_and_collate(args, image_transform): + if args.dataset_name is not None: + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + data_dir=args.train_data_dir, + streaming=args.streaming, + ) + else: + if args.train_data_dir is None: + raise ValueError("Provide either `--dataset_name` or `--train_data_dir`.") + data_files = {"train": os.path.join(args.train_data_dir, "**")} + dataset = load_dataset( + "imagefolder", data_files=data_files, cache_dir=args.cache_dir, streaming=args.streaming + ) + + train_dataset = dataset["train"] + + if args.streaming: + image_column = args.image_column + train_dataset = HFStreamingImageDataset( + train_dataset, image_column=image_column, image_transform=image_transform + ) + else: + column_names = train_dataset.column_names + image_column = args.image_column if args.image_column in column_names else column_names[0] + + def preprocess_train(examples): + images = [image.convert("RGB") for image in examples[image_column]] + examples["pixel_values"] = [image_transform(image) for image in images] + return examples + + train_dataset = train_dataset.with_transform(preprocess_train) + + def collate_fn(examples): + pixel_values = torch.stack([example["pixel_values"] for example in examples]) + return {"pixel_values": pixel_values.contiguous().float()} + + return train_dataset, collate_fn + + +def _save_pretrained_base_model(accelerator: Accelerator, base_model: nn.Module, output_dir): + output_dir = Path(output_dir) + accelerator.unwrap_model(base_model).save_pretrained(output_dir / "base_model") + + +def main(): + args = parse_args() + if args.seed is not None: + set_seed(args.seed) + + output_dir = Path(args.output_dir) + logging_dir = output_dir / args.logging_dir + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=str(logging_dir)) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state) + + if accelerator.is_main_process: + output_dir.mkdir(parents=True, exist_ok=True) + logger.info("Starting UL Stage-2 base training") + logger.info(f"Output dir: {output_dir}") + + image_transform = build_transforms(args) + train_dataset, collate_fn = get_train_dataset_and_collate(args, image_transform) + train_dataloader = DataLoader( + train_dataset, + batch_size=args.train_batch_size, + shuffle=not args.streaming, + collate_fn=collate_fn, + num_workers=args.num_workers, + pin_memory=True, + drop_last=True, + ) + + latent_size = args.resolution // args.latent_downsample_factor + if latent_size < 1: + raise ValueError("`resolution // latent_downsample_factor` must be >= 1.") + + encoder_path = Path(args.stage1_encoder_path) + if encoder_path.is_dir(): + encoder = AutoencoderULEncoder.from_pretrained(str(encoder_path)) + else: + encoder = AutoencoderULEncoder(in_channels=3, latent_channels=args.latent_channels) + state_dict = torch.load(args.stage1_encoder_path, map_location="cpu") + missing, unexpected = encoder.load_state_dict(state_dict, strict=False) + if len(missing) > 0: + logger.warning(f"Missing keys when loading stage-1 encoder: {missing}") + if len(unexpected) > 0: + logger.warning(f"Unexpected keys when loading stage-1 encoder: {unexpected}") + encoder.requires_grad_(False) + encoder.eval() + + base_model = ULTwoStageBaseModel( + latent_channels=args.latent_channels, + latent_size=latent_size, + num_train_timesteps=args.num_train_timesteps, + stage_a_layers=args.base_stage_a_layers, + stage_b_layers=args.base_stage_b_layers, + stage_a_heads=args.base_stage_a_heads, + stage_a_head_dim=args.base_stage_a_head_dim, + stage_b_heads=args.base_stage_b_heads, + stage_b_head_dim=args.base_stage_b_head_dim, + patch_size=args.base_patch_size, + ) + + optimizer = torch.optim.AdamW( + base_model.parameters(), + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + has_dataloader_length = True + try: + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + except TypeError: + has_dataloader_length = False + num_update_steps_per_epoch = None + if args.max_train_steps is None or args.max_train_steps <= 0: + if not has_dataloader_length: + raise ValueError("For streaming datasets, set `--max_train_steps` to a positive value.") + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + + lr_scheduler = get_scheduler( + name=args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + ) + + encoder, base_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + encoder, base_model, optimizer, train_dataloader, lr_scheduler + ) + + if has_dataloader_length: + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + else: + args.num_train_epochs = max(args.num_train_epochs, 1) + + if accelerator.is_main_process: + accelerator.init_trackers(args.tracker_project_name, config=vars(args)) + + logger.info("***** Running stage-2 base training *****") + if has_dataloader_length: + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + else: + logger.info(" Num examples = streaming (unknown)") + logger.info(" Num batches each epoch = streaming (unknown)") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info( + " Total train batch size (w. parallel, distributed & accumulation) = " + f"{args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps}" + ) + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + + global_step = 0 + first_epoch = 0 + + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + dirs = os.listdir(args.output_dir) if os.path.exists(args.output_dir) else [] + dirs = [d for d in dirs if d.startswith("checkpoint-")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + logger.info(f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.") + else: + logger.info(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + first_epoch = global_step // num_update_steps_per_epoch if has_dataloader_length else 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=global_step, + desc="Steps", + disable=not accelerator.is_local_main_process, + ) + + train_loss = 0.0 + + for epoch in range(first_epoch, args.num_train_epochs): + base_model.train() + + for _, batch in enumerate(train_dataloader): + with accelerator.accumulate(base_model): + x = batch["pixel_values"].to(accelerator.device) + bsz = x.shape[0] + + with torch.no_grad(): + z_clean = encoder.encode(x).latent + + # Stage-2 uses clean encoder means as training targets. + z_target = z_clean + t = ul_sample_t(bsz, x.device) + lambda_t = ul_logsnr_schedule( + t, + schedule_type=args.base_schedule, + lambda_min=args.lambda_z_min, + lambda_max=args.lambda_z_max, + ) + dlogsnr_dt = ul_dlogsnr_dt( + t, + schedule_type=args.base_schedule, + lambda_min=args.lambda_z_min, + lambda_max=args.lambda_z_max, + ) + prefactor = ul_elbo_prefactor(lambda_t, dlogsnr_dt) + alpha_t, sigma_t = ul_alpha_sigma_from_logsnr(lambda_t) + eps = torch.randn_like(z_target) + z_t = alpha_t[:, None, None, None] * z_target + sigma_t[:, None, None, None] * eps + + timesteps = (t * (args.num_train_timesteps - 1)).long().clamp(0, args.num_train_timesteps - 1) + dummy_labels = torch.zeros((bsz,), device=x.device, dtype=torch.long) + + # Preferred stage-2 parameterization: predict velocity v. + v_pred = base_model(z_t, timesteps, dummy_labels) + z_target_hat = alpha_t[:, None, None, None] * z_t - sigma_t[:, None, None, None] * v_pred + + per_sample = F.mse_loss(z_target_hat.float(), z_target.float(), reduction="none").sum(dim=(1, 2, 3)) + weights = ul_decoder_loss_weight( + lambda_t, + bias=args.sigmoid_bias, + loss_factor=args.loss_factor, + invert=False, + ) + loss_raw = (prefactor * weights * per_sample).mean() + num_pixels = x.shape[-2] * x.shape[-1] + bpp_denom = num_pixels * math.log(2.0) + loss = loss_raw / bpp_denom + + accelerator.backward(loss) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(base_model.parameters(), args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + avg_loss = accelerator.gather(loss.detach().repeat(args.train_batch_size)).mean() + train_loss += avg_loss.item() / args.gradient_accumulation_steps + + if accelerator.sync_gradients: + global_step += 1 + progress_bar.update(1) + accelerator.log({"train_loss": train_loss}, step=global_step) + train_loss = 0.0 + + if accelerator.is_main_process and global_step % args.save_steps == 0: + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint-")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[:num_to_remove] + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)}" + ) + for removing_checkpoint in removing_checkpoints: + shutil.rmtree(os.path.join(args.output_dir, removing_checkpoint)) + + save_dir = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_dir) + _save_pretrained_base_model( + accelerator=accelerator, + base_model=base_model, + output_dir=save_dir, + ) + logger.info(f"Saved state to {save_dir}") + + logs = { + "step_loss": loss.detach().item(), + "lr": lr_scheduler.get_last_lr()[0], + } + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if global_step >= args.max_train_steps: + break + + accelerator.wait_for_everyone() + if accelerator.is_main_process: + final_dir = output_dir / "final" + final_dir.mkdir(parents=True, exist_ok=True) + torch.save(accelerator.unwrap_model(base_model).state_dict(), final_dir / "base_model.pt") + _save_pretrained_base_model( + accelerator=accelerator, + base_model=base_model, + output_dir=final_dir, + ) + logger.info(f"Training finished. Saved final checkpoint to {final_dir / 'base_model.pt'}") + + accelerator.end_training() + + +if __name__ == "__main__": + main() diff --git a/examples/research_projects/unified_latents/ul_models.py b/examples/research_projects/unified_latents/ul_models.py new file mode 100644 index 000000000000..0c01d00ffa97 --- /dev/null +++ b/examples/research_projects/unified_latents/ul_models.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python +# coding=utf-8 + +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.transformers.dit_transformer_2d import DiTTransformer2DModel + + +class ULTwoStageBaseModel(ModelMixin, ConfigMixin): + """ + Approximation of UL Sec. 5.1 stage-2 base model: + - Stage A: ViT-like latent denoiser, width ~512, 6 layers + - Stage B: ViT-like latent denoiser, width ~1024, 16 layers + + Both stages are DiT blocks and predict residual denoised latents. + """ + + @register_to_config + def __init__( + self, + latent_channels: int, + latent_size: int, + num_train_timesteps: int, + stage_a_layers: int = 6, + stage_b_layers: int = 16, + stage_a_heads: int = 8, + stage_a_head_dim: int = 64, # 512 width + stage_b_heads: int = 16, + stage_b_head_dim: int = 64, # 1024 width + patch_size: int = 1, + ): + super().__init__() + self.stage_a = DiTTransformer2DModel( + num_attention_heads=stage_a_heads, + attention_head_dim=stage_a_head_dim, + in_channels=latent_channels, + out_channels=latent_channels, + num_layers=stage_a_layers, + sample_size=latent_size, + patch_size=patch_size, + num_embeds_ada_norm=num_train_timesteps, + dropout=0.1, + ) + self.stage_b = DiTTransformer2DModel( + num_attention_heads=stage_b_heads, + attention_head_dim=stage_b_head_dim, + in_channels=latent_channels, + out_channels=latent_channels, + num_layers=stage_b_layers, + sample_size=latent_size, + patch_size=patch_size, + num_embeds_ada_norm=num_train_timesteps, + dropout=0.1, + ) + + def forward(self, z_t: torch.Tensor, timesteps: torch.LongTensor, class_labels: torch.LongTensor) -> torch.Tensor: + h = self.stage_a(z_t, timestep=timesteps, class_labels=class_labels).sample + out = self.stage_b(h, timestep=timesteps, class_labels=class_labels).sample + return out diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 7d966452d1a2..5eb1fa40dd06 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -340,10 +340,21 @@ "StableDiffusionMixin", ] ) + _import_structure["pipelines"].extend( + [ + "BD3LMPipeline", + "BD3LMPipelineOutput", + "HybridTokenDiffusionPipeline", + "TokenDiffusionPipeline", + "TokenDiffusionPipelineOutput", + ] + ) _import_structure["quantizers"] = ["DiffusersQuantizer"] _import_structure["schedulers"].extend( [ "AmusedScheduler", + "BD3LMTokenDiffusionScheduler", + "BD3LMTokenDiffusionSchedulerOutput", "BlockRefinementScheduler", "BlockRefinementSchedulerOutput", "CMStochasticIterativeScheduler", @@ -356,6 +367,8 @@ "DDPMScheduler", "DDPMWuerstchenScheduler", "DEISMultistepScheduler", + "DFlashTokenDiffusionScheduler", + "DFlashTokenDiffusionSchedulerOutput", "DPMSolverMultistepInverseScheduler", "DPMSolverMultistepScheduler", "DPMSolverSinglestepScheduler", @@ -369,6 +382,8 @@ "HeliosDMDScheduler", "HeliosScheduler", "HeunDiscreteScheduler", + "HybridTokenDiffusionScheduler", + "HybridTokenDiffusionSchedulerOutput", "IPNDMScheduler", "KarrasVeScheduler", "KDPM2AncestralDiscreteScheduler", @@ -381,7 +396,10 @@ "SchedulerMixin", "SCMScheduler", "ScoreSdeVeScheduler", + "SDARTokenDiffusionScheduler", + "SDARTokenDiffusionSchedulerOutput", "TCDScheduler", + "TokenDiffusionScheduler", "UnCLIPScheduler", "UniPCMultistepScheduler", "VQDiffusionScheduler", @@ -511,6 +529,8 @@ "CosmosTextToWorldPipeline", "CosmosVideoToWorldPipeline", "CycleDiffusionPipeline", + "DFlashPipeline", + "DFlashPipelineOutput", "EasyAnimateControlPipeline", "EasyAnimateInpaintPipeline", "EasyAnimatePipeline", @@ -631,6 +651,8 @@ "SanaSprintPipeline", "SanaVideoPipeline", "SanaVideoPipeline", + "SDARPipeline", + "SDARPipelineOutput", "SemanticStableDiffusionPipeline", "ShapEImg2ImgPipeline", "ShapEPipeline", @@ -1107,6 +1129,8 @@ AutoPipelineForImage2Image, AutoPipelineForInpainting, AutoPipelineForText2Image, + BD3LMPipeline, + BD3LMPipelineOutput, BlipDiffusionControlNetPipeline, BlipDiffusionPipeline, CLIPImageProjection, @@ -1116,6 +1140,7 @@ DDPMPipeline, DiffusionPipeline, DiTPipeline, + HybridTokenDiffusionPipeline, ImagePipelineOutput, KarrasVePipeline, LDMPipeline, @@ -1124,10 +1149,14 @@ RePaintPipeline, ScoreSdeVePipeline, StableDiffusionMixin, + TokenDiffusionPipeline, + TokenDiffusionPipelineOutput, ) from .quantizers import DiffusersQuantizer from .schedulers import ( AmusedScheduler, + BD3LMTokenDiffusionScheduler, + BD3LMTokenDiffusionSchedulerOutput, BlockRefinementScheduler, BlockRefinementSchedulerOutput, CMStochasticIterativeScheduler, @@ -1140,6 +1169,8 @@ DDPMScheduler, DDPMWuerstchenScheduler, DEISMultistepScheduler, + DFlashTokenDiffusionScheduler, + DFlashTokenDiffusionSchedulerOutput, DPMSolverMultistepInverseScheduler, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler, @@ -1153,6 +1184,8 @@ HeliosDMDScheduler, HeliosScheduler, HeunDiscreteScheduler, + HybridTokenDiffusionScheduler, + HybridTokenDiffusionSchedulerOutput, IPNDMScheduler, KarrasVeScheduler, KDPM2AncestralDiscreteScheduler, @@ -1165,7 +1198,10 @@ SchedulerMixin, SCMScheduler, ScoreSdeVeScheduler, + SDARTokenDiffusionScheduler, + SDARTokenDiffusionSchedulerOutput, TCDScheduler, + TokenDiffusionScheduler, UnCLIPScheduler, UniPCMultistepScheduler, VQDiffusionScheduler, @@ -1274,6 +1310,8 @@ CosmosTextToWorldPipeline, CosmosVideoToWorldPipeline, CycleDiffusionPipeline, + DFlashPipeline, + DFlashPipelineOutput, EasyAnimateControlPipeline, EasyAnimateInpaintPipeline, EasyAnimatePipeline, @@ -1393,6 +1431,8 @@ SanaSprintImg2ImgPipeline, SanaSprintPipeline, SanaVideoPipeline, + SDARPipeline, + SDARPipelineOutput, SemanticStableDiffusionPipeline, ShapEImg2ImgPipeline, ShapEPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 3dafb56fdd65..2c9da4570cdf 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -47,11 +47,14 @@ "AutoPipelineForInpainting", "AutoPipelineForText2Image", ] + _import_structure["bd3lm"] = ["BD3LMPipeline", "BD3LMPipelineOutput"] _import_structure["consistency_models"] = ["ConsistencyModelPipeline"] _import_structure["dance_diffusion"] = ["DanceDiffusionPipeline"] _import_structure["ddim"] = ["DDIMPipeline"] _import_structure["ddpm"] = ["DDPMPipeline"] _import_structure["dit"] = ["DiTPipeline"] + _import_structure["hybrid_token_diffusion"] = ["HybridTokenDiffusionPipeline"] + _import_structure["token_diffusion"] = ["TokenDiffusionPipeline", "TokenDiffusionPipelineOutput"] _import_structure["latent_diffusion"].extend(["LDMSuperResolutionPipeline"]) _import_structure["pipeline_utils"] = [ "AudioPipelineOutput", @@ -416,6 +419,9 @@ "Kandinsky5T2IPipeline", "Kandinsky5I2IPipeline", ] + _import_structure["dflash"] = ["DFlashPipeline", "DFlashPipelineOutput"] + _import_structure["sdar"] = ["SDARPipeline", "SDARPipelineOutput"] + _import_structure["llada2"] = ["LLaDA2Pipeline", "LLaDA2PipelineOutput"] _import_structure["z_image"] = [ "ZImageControlNetInpaintPipeline", "ZImageControlNetPipeline", @@ -543,12 +549,14 @@ AutoPipelineForInpainting, AutoPipelineForText2Image, ) + from .bd3lm import BD3LMPipeline, BD3LMPipelineOutput from .consistency_models import ConsistencyModelPipeline from .dance_diffusion import DanceDiffusionPipeline from .ddim import DDIMPipeline from .ddpm import DDPMPipeline from .deprecated import KarrasVePipeline, LDMPipeline, PNDMPipeline, RePaintPipeline, ScoreSdeVePipeline from .dit import DiTPipeline + from .hybrid_token_diffusion import HybridTokenDiffusionPipeline from .latent_diffusion import LDMSuperResolutionPipeline from .pipeline_utils import ( AudioPipelineOutput, @@ -556,6 +564,7 @@ ImagePipelineOutput, StableDiffusionMixin, ) + from .token_diffusion import TokenDiffusionPipeline, TokenDiffusionPipelineOutput try: if not (is_torch_available() and is_librosa_available()): @@ -651,6 +660,7 @@ VersatileDiffusionTextToImagePipeline, VQDiffusionPipeline, ) + from .dflash import DFlashPipeline, DFlashPipelineOutput from .easyanimate import ( EasyAnimateControlPipeline, EasyAnimateInpaintPipeline, @@ -792,6 +802,7 @@ SanaSprintPipeline, ) from .sana_video import SanaImageToVideoPipeline, SanaVideoPipeline + from .sdar import SDARPipeline, SDARPipelineOutput from .semantic_stable_diffusion import SemanticStableDiffusionPipeline from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline from .stable_audio import StableAudioPipeline, StableAudioProjectionModel diff --git a/src/diffusers/pipelines/bd3lm/__init__.py b/src/diffusers/pipelines/bd3lm/__init__.py new file mode 100644 index 000000000000..18f81b69d58d --- /dev/null +++ b/src/diffusers/pipelines/bd3lm/__init__.py @@ -0,0 +1,47 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_bd3lm"] = ["BD3LMPipeline", "BD3LMPipelineOutput"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_bd3lm import BD3LMPipeline, BD3LMPipelineOutput +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/bd3lm/pipeline_bd3lm.py b/src/diffusers/pipelines/bd3lm/pipeline_bd3lm.py new file mode 100644 index 000000000000..0459d26a0ac4 --- /dev/null +++ b/src/diffusers/pipelines/bd3lm/pipeline_bd3lm.py @@ -0,0 +1,382 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable + +import torch +from tqdm.auto import tqdm + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...schedulers import BD3LMTokenDiffusionScheduler +from ...utils import BaseOutput, logging, replace_example_docstring +from ..pipeline_utils import DiffusionPipeline, DiscreteDiffusionPipelineMixin + + +logger = logging.get_logger(__name__) + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from transformers import AutoModelForMaskedLM, GPT2TokenizerFast + >>> from diffusers import BD3LMPipeline, BD3LMTokenDiffusionScheduler + + >>> model_id = "kuleshov-group/bd3lm-owt-block_size4" + >>> model = AutoModelForMaskedLM.from_pretrained(model_id, trust_remote_code=True, dtype=torch.bfloat16).cuda() + >>> tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") + >>> scheduler = BD3LMTokenDiffusionScheduler( + ... block_size=model.config.block_size, + ... mask_token_id=model.config.vocab_size - 1, + ... ) + + >>> pipe = BD3LMPipeline(model=model, scheduler=scheduler, tokenizer=tokenizer) + >>> output = pipe(gen_length=64, num_inference_steps=64, nucleus_p=0.9) + >>> print(output.texts[0]) + ``` +""" + + +@dataclass +class BD3LMPipelineOutput(BaseOutput): + sequences: torch.LongTensor + texts: list[str] | None = None + + +class BD3LMPipeline(DiffusionPipeline, DiscreteDiffusionPipelineMixin): + r""" + Pipeline for BD3LM (Block Discrete Denoising Diffusion Language Model) text generation via semi-autoregressive + block diffusion. + + BD3LM generates text by autoregressively appending masked blocks and denoising each block via discrete DDPM + updates. At each stride, a new block of mask tokens is appended and iteratively denoised using the model's + predicted token probabilities. + + The model is expected to accept `(input_ids, timesteps, sample_mode)` and return logits of shape `[batch, + block_length, vocab_size]` when in sample mode. + """ + + model: Any + scheduler: BD3LMTokenDiffusionScheduler + tokenizer: Any + + _callback_tensor_inputs = ["x_accum"] + + def __init__( + self, + model: Any, + scheduler: BD3LMTokenDiffusionScheduler, + tokenizer: Any | None = None, + ): + super().__init__() + self.register_modules(model=model, scheduler=scheduler, tokenizer=tokenizer) + + # Resolve mask token ID: model.config.mask_index > vocab_size - 1 (BD3LM convention) > tokenizer + self.mask_token_id: int | None = None + if hasattr(self.model, "config"): + self.mask_token_id = getattr(self.model.config, "mask_index", None) + if self.mask_token_id is None: + vocab_size = getattr(self.model.config, "vocab_size", None) + if vocab_size is not None: + self.mask_token_id = vocab_size - 1 + if self.mask_token_id is None and self.tokenizer is not None: + self.mask_token_id = getattr(self.tokenizer, "mask_token_id", None) + + self.eos_token_id = getattr(self.tokenizer, "eos_token_id", None) if self.tokenizer is not None else None + + @property + def num_timesteps(self): + return self._num_timesteps + + def check_inputs( + self, + prompt: str | list[str] | None, + input_ids: torch.LongTensor | None, + gen_length: int, + block_length: int, + num_inference_steps: int, + nucleus_p: float, + output_type: str, + callback_on_step_end: Callable | PipelineCallback | MultiPipelineCallbacks | None, + callback_on_step_end_tensor_inputs: list[str] | None, + ): + if input_ids is not None: + if input_ids.ndim not in (1, 2): + raise ValueError(f"`input_ids` must be 1D or 2D, got shape {tuple(input_ids.shape)}.") + if input_ids.dtype != torch.long: + raise ValueError(f"`input_ids` must be int64 token IDs, got dtype={input_ids.dtype}.") + if prompt is not None and input_ids is None and self.tokenizer is None: + raise ValueError("Tokenizer is required when `input_ids` is not provided.") + + if gen_length <= 0: + raise ValueError(f"`gen_length` must be > 0, got {gen_length}.") + if block_length <= 0: + raise ValueError(f"`block_length` must be > 0, got {block_length}.") + if num_inference_steps <= 0: + raise ValueError(f"`num_inference_steps` must be > 0, got {num_inference_steps}.") + if not (0.0 < nucleus_p <= 1.0): + raise ValueError(f"`nucleus_p` must be in (0, 1], got {nucleus_p}.") + if output_type not in {"seq", "text"}: + raise ValueError(f"`output_type` must be 'seq' or 'text', got {output_type!r}.") + + if callback_on_step_end is not None and isinstance( + callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks) + ): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found " + f"{[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] | None = None, + input_ids: torch.LongTensor | None = None, + gen_length: int = 256, + block_length: int | None = None, + num_inference_steps: int = 64, + nucleus_p: float = 1.0, + mask_token_id: int | None = None, + eos_token_id: int | None = None, + eos_early_stop: bool = True, + generator: torch.Generator | None = None, + output_type: str = "text", + return_dict: bool = True, + callback_on_step_end: Callable[[int, int, dict], None] + | PipelineCallback + | MultiPipelineCallbacks + | None = None, + callback_on_step_end_tensor_inputs: list[str] | None = None, + ) -> BD3LMPipelineOutput | tuple[torch.LongTensor, list[str] | None]: + """ + Generate text with BD3LM semi-autoregressive block diffusion. + + Args: + prompt (`str` or `List[str]`, *optional*): + Prompt text. When provided, it is tokenized and used as a prefix for generation. + input_ids (`torch.LongTensor`, *optional*): + Pre-tokenized input IDs. Takes precedence over `prompt`. + gen_length (`int`, defaults to `256`): + Total number of new tokens to generate. + block_length (`int`, *optional*): + Block size for diffusion. If not provided, reads `model.config.block_size`. + num_inference_steps (`int`, defaults to `64`): + Number of DDPM denoising steps per block. + nucleus_p (`float`, defaults to `1.0`): + Nucleus sampling probability threshold. Set to `1.0` to disable nucleus filtering. + mask_token_id (`int`, *optional*): + Mask token ID. Resolved from model config or scheduler config if not provided. + eos_token_id (`int`, *optional*): + EOS token ID for early stopping. + eos_early_stop (`bool`, defaults to `True`): + Whether to stop generation when EOS is produced. + generator (`torch.Generator`, *optional*): + RNG for reproducibility. + output_type (`str`, defaults to `"text"`): + Output format. `"text"` decodes sequences into strings. `"seq"` returns raw token IDs. + return_dict (`bool`, defaults to `True`): + Whether to return a [`BD3LMPipelineOutput`] instead of a tuple. + callback_on_step_end (`Callable` or `PipelineCallback`, *optional*): + Callback executed after each denoising step. + callback_on_step_end_tensor_inputs (`List[str]`, *optional*): + Tensor keys to pass to the callback. Allowed keys: `x_accum`. + + Examples: + """ + # 1. Resolve defaults and check inputs early + if block_length is None: + block_length = getattr(getattr(self.model, "config", None), "block_size", None) + if block_length is None: + raise ValueError("`block_length` must be provided or available as `model.config.block_size`.") + + if mask_token_id is None: + mask_token_id = self.mask_token_id + if mask_token_id is None: + mask_token_id = self.scheduler.config.mask_token_id + if mask_token_id is None: + raise ValueError("`mask_token_id` must be provided (or available via model config or scheduler config).") + + if eos_token_id is None: + eos_token_id = self.eos_token_id + + if callback_on_step_end is not None and isinstance( + callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks) + ): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + if callback_on_step_end_tensor_inputs is None: + callback_on_step_end_tensor_inputs = ["x_accum"] + + self.check_inputs( + prompt=prompt, + input_ids=input_ids, + gen_length=gen_length, + block_length=block_length, + num_inference_steps=num_inference_steps, + nucleus_p=nucleus_p, + output_type=output_type, + callback_on_step_end=callback_on_step_end, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + device = self._execution_device + + # 2. Prepare prompt IDs (prefix) + if prompt is not None or input_ids is not None: + prompt_ids = self._prepare_input_ids( + prompt=prompt, + messages=None, + input_ids=input_ids, + use_chat_template=False, + add_generation_prompt=False, + chat_template_kwargs=None, + ) + if prompt_ids.ndim == 1: + prompt_ids = prompt_ids.unsqueeze(0) + prompt_ids = prompt_ids.to(device=device) + batch_size = prompt_ids.shape[0] + else: + batch_size = 1 + bos_id = self._resolve_start_token_id() + if bos_id is None: + raise ValueError( + "No prompt provided and no BOS token found on the tokenizer. " + "Provide `prompt`, `input_ids`, or a tokenizer with a BOS token." + ) + prompt_ids = torch.tensor([[bos_id]], device=device, dtype=torch.long) + + prompt_length = prompt_ids.shape[1] + + # 3. Set up scheduler timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + num_strides = (gen_length + block_length - 1) // block_length + self._num_timesteps = num_inference_steps * num_strides + + # 4. Semi-autoregressive block diffusion loop + finished = torch.zeros((batch_size,), device=device, dtype=torch.bool) + global_step = 0 + x_accum: torch.LongTensor = prompt_ids + + # Initialize KV cache for cross-block context + if hasattr(self.model, "reset_kv_cache"): + self.model.reset_kv_cache(eval_batch_size=batch_size) + + block_progress_bar_config = getattr(self, "_progress_bar_config", {}).copy() + block_progress_bar_config["position"] = 0 + block_progress_bar_config["desc"] = "Blocks" + + for stride_num in tqdm(range(num_strides), **block_progress_bar_config): + # Append a new masked block + masked_block = torch.full((batch_size, block_length), mask_token_id, device=device, dtype=torch.long) + x_accum = torch.cat([x_accum, masked_block], dim=1) + + # DDPM denoising steps within this block + p_x0_cache = None + + self.set_progress_bar_config(position=1, leave=False, desc=f"Block {stride_num} Denoising") + progress_bar = self.progress_bar(total=num_inference_steps) + + for step_idx in range(num_inference_steps): + # Check if all mask tokens in current block are resolved + if self.scheduler.check_should_stop(x_accum[:, -block_length:], mask_token_id): + progress_bar.update(num_inference_steps - step_idx) + break + + t = self.scheduler.timesteps[step_idx] + + # Get model predictions only when p_x0 cache is invalidated + if p_x0_cache is None: + sigma_t = self.scheduler.compute_sigma(t, batch_size) + model_input = x_accum[:, -block_length:] + model_output = self.model( + input_ids=model_input, + timesteps=sigma_t.to(model_input.device), + sample_mode=True, + ) + logits = model_output.logits if hasattr(model_output, "logits") else model_output + else: + # Reuse cached p_x0 distribution (convert back to log-probs for scheduler) + logits = p_x0_cache.log() + + # Scheduler step: DDPM update with subs parameterization + scheduler_output = self.scheduler.step( + model_output=logits, + timestep=t, + sample=x_accum, + mask_token_id=mask_token_id, + nucleus_p=nucleus_p, + generator=generator, + return_dict=True, + ) + + x_accum = scheduler_output.prev_sample + p_x0_cache = scheduler_output.p_x0_cache + + # Callback + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, global_step, step_idx, callback_kwargs) + x_accum = callback_outputs.pop("x_accum", x_accum) + + global_step += 1 + progress_bar.update(1) + + progress_bar.close() + + # Store denoised block's KVs into cache for cross-block context + if hasattr(self.model, "reset_kv_cache"): + denoised_block = x_accum[:, -block_length:] + sigma_store = self.scheduler.compute_sigma(self.scheduler.timesteps[0], batch_size) + self.model( + input_ids=denoised_block, + timesteps=sigma_store.to(denoised_block.device), + sample_mode=True, + store_kv=True, + ) + + # EOS early stopping (delegated to scheduler) + if eos_early_stop and eos_token_id is not None: + finished = self.scheduler.check_eos_finished(x_accum, prompt_length, eos_token_id, finished) + if finished.all(): + break + + # 5. Post-process output + total_generated = x_accum.shape[1] - prompt_length + trim_length = min(total_generated, gen_length) + sequences = x_accum[:, prompt_length : prompt_length + trim_length] + + if eos_token_id is not None and batch_size == 1: + eos_positions = (sequences[0] == eos_token_id).nonzero(as_tuple=True)[0] + if len(eos_positions) > 0: + sequences = sequences[:, : int(eos_positions[0].item()) + 1] + + texts = None + if output_type == "text" and self.tokenizer is not None: + texts = self.tokenizer.batch_decode(sequences, skip_special_tokens=True) + + if not return_dict: + return sequences.to(device=device), texts + return BD3LMPipelineOutput(sequences=sequences.to(device=device), texts=texts) + + +__all__ = ["BD3LMPipeline", "BD3LMPipelineOutput"] diff --git a/src/diffusers/pipelines/dflash/__init__.py b/src/diffusers/pipelines/dflash/__init__.py new file mode 100644 index 000000000000..c5d0f5fae4cd --- /dev/null +++ b/src/diffusers/pipelines/dflash/__init__.py @@ -0,0 +1,47 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_dflash"] = ["DFlashPipeline", "DFlashPipelineOutput"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_dflash import DFlashPipeline, DFlashPipelineOutput +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/dflash/pipeline_dflash.py b/src/diffusers/pipelines/dflash/pipeline_dflash.py new file mode 100644 index 000000000000..544f6b8f62ee --- /dev/null +++ b/src/diffusers/pipelines/dflash/pipeline_dflash.py @@ -0,0 +1,495 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable + +import torch +from tqdm.auto import tqdm +from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, DynamicCache + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...schedulers import DFlashTokenDiffusionScheduler +from ...utils import BaseOutput, logging, replace_example_docstring +from ..pipeline_utils import DiffusionPipeline, DiscreteDiffusionPipelineMixin + + +logger = logging.get_logger(__name__) + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import DFlashPipeline + + >>> draft_id = "z-lab/Qwen3-8B-DFlash-b16" + >>> target_id = "Qwen/Qwen3-8B" + >>> pipe = DFlashPipeline.from_pretrained( + ... draft_model_id=draft_id, + ... target_model_id=target_id, + ... draft_model_kwargs={"trust_remote_code": True, "dtype": torch.bfloat16}, + ... target_model_kwargs={"dtype": torch.bfloat16}, + ... ) + >>> out = pipe(prompt="How many positive whole-number divisors does 196 have?") + >>> print(out.texts[0]) + ``` +""" + + +@dataclass +class DFlashPipelineOutput(BaseOutput): + sequences: torch.LongTensor + texts: list[str] | None = None + + +def _build_target_layer_ids(num_target_layers: int, num_draft_layers: int) -> list[int]: + if num_draft_layers == 1: + return [int(num_target_layers // 2)] + start = 1 + end = int(num_target_layers) - 3 + span = end - start + return [int(round(start + (i * span) / (num_draft_layers - 1))) for i in range(int(num_draft_layers))] + + +def _extract_context_feature(hidden_states: list[torch.Tensor], layer_ids: list[int]) -> torch.Tensor: + offset = 1 + selected_states = [hidden_states[layer_id + offset] for layer_id in layer_ids] + return torch.cat(selected_states, dim=-1) + + +class DFlashPipeline(DiffusionPipeline, DiscreteDiffusionPipelineMixin): + r""" + Block diffusion pipeline for speculative decoding with a DFlash draft model and a target causal LM. + """ + + draft_model: Any + target_model: Any + tokenizer: Any + scheduler: DFlashTokenDiffusionScheduler + _callback_tensor_inputs = ["block_output_ids", "draft_logits", "accepted_length", "next_token", "output_ids"] + + def __init__( + self, + draft_model: torch.nn.Module, + target_model: torch.nn.Module, + tokenizer: Any | None = None, + scheduler: DFlashTokenDiffusionScheduler | None = None, + ): + super().__init__() + if scheduler is None: + scheduler = DFlashTokenDiffusionScheduler() + self.register_modules( + draft_model=draft_model, target_model=target_model, tokenizer=tokenizer, scheduler=scheduler + ) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str | None = None, + *, + draft_model_id: str | None = None, + target_model_id: str | None = None, + tokenizer_id: str | None = None, + mask_token: str | None = "<|MASK|>", + scheduler: DFlashTokenDiffusionScheduler | None = None, + draft_model_kwargs: dict[str, object] | None = None, + target_model_kwargs: dict[str, object] | None = None, + tokenizer_kwargs: dict[str, object] | None = None, + **pipeline_kwargs, + ) -> "DFlashPipeline": + if draft_model_id is None and target_model_id is None and pretrained_model_name_or_path is not None: + return super().from_pretrained(pretrained_model_name_or_path, **pipeline_kwargs) + + if draft_model_id is None: + if pretrained_model_name_or_path is None: + raise ValueError("Provide `draft_model_id` or `pretrained_model_name_or_path`.") + draft_model_id = str(pretrained_model_name_or_path) + if target_model_id is None: + raise ValueError("`target_model_id` must be provided when loading draft/target models separately.") + + draft_model_kwargs = dict(draft_model_kwargs or {}) + draft_model_kwargs.setdefault("trust_remote_code", True) + target_model_kwargs = dict(target_model_kwargs or {}) + tokenizer_kwargs = dict(tokenizer_kwargs or {}) + + draft = AutoModel.from_pretrained(draft_model_id, **draft_model_kwargs) + target = AutoModelForCausalLM.from_pretrained(target_model_id, **target_model_kwargs) + tokenizer = AutoTokenizer.from_pretrained(tokenizer_id or target_model_id, **tokenizer_kwargs) + + if mask_token is not None and tokenizer.mask_token_id is None: + tokenizer.add_special_tokens({"mask_token": mask_token}) + + return cls( + draft_model=draft, + target_model=target, + tokenizer=tokenizer, + scheduler=scheduler, + **pipeline_kwargs, + ) + + def check_inputs( + self, + prompt: str | list[str] | None, + messages: list[dict[str, str]] | None, + input_ids: torch.LongTensor | None, + max_new_tokens: int, + output_type: str, + callback_on_step_end: Callable | PipelineCallback | MultiPipelineCallbacks | None, + callback_on_step_end_tensor_inputs: list[str] | None, + ): + # Input source validation + if prompt is None and messages is None and input_ids is None: + raise ValueError("Provide one of `prompt`, `messages`, or `input_ids`.") + if prompt is not None and messages is not None: + raise ValueError("Provide either `prompt` or `messages`, not both.") + if input_ids is not None: + if input_ids.ndim not in (1, 2): + raise ValueError(f"`input_ids` must be 1D or 2D, got shape {tuple(input_ids.shape)}.") + if input_ids.dtype != torch.long: + raise ValueError(f"`input_ids` must be int64 token IDs, got dtype={input_ids.dtype}.") + if prompt is not None and input_ids is None and self.tokenizer is None: + raise ValueError("Tokenizer is required when `input_ids` is not provided.") + if messages is not None and input_ids is None and self.tokenizer is None: + raise ValueError("Tokenizer is required when `input_ids` is not provided.") + + # Generation parameter validation + if max_new_tokens <= 0: + raise ValueError(f"`max_new_tokens` must be > 0, got {max_new_tokens}.") + if output_type not in {"seq", "text"}: + raise ValueError(f"`output_type` must be 'seq' or 'text', got {output_type!r}.") + + # Callback validation + if callback_on_step_end is not None and isinstance( + callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks) + ): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found " + f"{[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + def prepare_latents( + self, + max_length: int, + block_size: int, + mask_token_id: int, + device: torch.device, + ) -> torch.LongTensor: + return torch.full( + (1, max_length + int(block_size)), + int(mask_token_id), + dtype=torch.long, + device=device, + ) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] | None = None, + messages: list[dict[str, str]] | None = None, + input_ids: torch.LongTensor | None = None, + max_new_tokens: int = 2048, + temperature: float = 0.0, + stop_token_ids: list[int] | None = None, + mask_token_id: int | None = None, + use_chat_template: bool = True, + add_generation_prompt: bool = True, + chat_template_kwargs: dict[str, object] | None = None, + output_type: str = "text", + return_dict: bool = True, + callback_on_step_end: Callable[[int, int, dict], None] + | PipelineCallback + | MultiPipelineCallbacks + | None = None, + callback_on_step_end_tensor_inputs: list[str] | None = None, + ) -> DFlashPipelineOutput | tuple[torch.LongTensor, list[str] | None]: + """ + Generate text using block-diffusion speculative decoding. + + Args: + prompt (`str` or `list[str]`, *optional*): + Prompt text. When `use_chat_template` is `True` (default) and a tokenizer with a chat template is + available, the prompt is wrapped in a chat message before tokenization. + messages (`list[dict[str, str]]`, *optional*): + Chat messages to encode. Takes precedence over `prompt` when provided. + input_ids (`torch.LongTensor`, *optional*): + Pre-tokenized input IDs. Takes precedence over `prompt` and `messages`. + max_new_tokens (`int`): + Maximum number of new tokens to generate. + temperature (`float`): + Sampling temperature. + stop_token_ids (`list[int]`, *optional*): + Token IDs that signal generation should stop. + mask_token_id (`int`, *optional*): + Mask token ID for the draft model. + use_chat_template (`bool`, defaults to `True`): + Whether to wrap the prompt in a chat template. + add_generation_prompt (`bool`, defaults to `True`): + Whether to add the generation prompt when using chat templates. + chat_template_kwargs (`dict[str, object]`, *optional*): + Additional keyword arguments for the chat template. + output_type (`str`, defaults to `"text"`): + Output format. `"text"` decodes sequences into strings (requires a tokenizer). `"seq"` returns raw + token ID sequences only. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`DFlashPipelineOutput`] instead of a tuple. + callback_on_step_end (`Callable` or `PipelineCallback`, *optional*): + Callback executed after each speculative decoding step. + callback_on_step_end_tensor_inputs (`list[str]`, *optional*): + Tensor keys to pass to the callback. + + Examples: + """ + # 1. Check inputs early + if callback_on_step_end is not None and isinstance( + callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks) + ): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + if callback_on_step_end_tensor_inputs is None: + callback_on_step_end_tensor_inputs = ["block_output_ids"] + + self.check_inputs( + prompt=prompt, + messages=messages, + input_ids=input_ids, + max_new_tokens=max_new_tokens, + output_type=output_type, + callback_on_step_end=callback_on_step_end, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + # 2. Prepare input IDs from prompt/messages/input_ids + input_ids = self._prepare_input_ids( + prompt=prompt, + messages=messages, + input_ids=input_ids, + use_chat_template=use_chat_template, + add_generation_prompt=add_generation_prompt, + chat_template_kwargs=chat_template_kwargs, + ) + + if mask_token_id is None: + mask_token_id = getattr(getattr(self, "tokenizer", None), "mask_token_id", None) + if mask_token_id is None: + # DFlash models store mask_token_id in config.dflash_config + dflash_config = getattr(getattr(self.draft_model, "config", None), "dflash_config", None) + if dflash_config is not None: + mask_token_id = dflash_config.get("mask_token_id", None) + if mask_token_id is None: + raise ValueError("`mask_token_id` must be provided (or available on the tokenizer/model config).") + if input_ids.shape[0] != 1: + raise ValueError("DFlashPipeline currently supports batch_size=1 input_ids.") + + target_params = list(self.target_model.parameters()) if hasattr(self.target_model, "parameters") else [] + device = target_params[0].device if len(target_params) > 0 else torch.device("cpu") + input_ids = input_ids.to(device=device) + draft_params = list(self.draft_model.parameters()) if hasattr(self.draft_model, "parameters") else [] + draft_device = draft_params[0].device if len(draft_params) > 0 else device + if draft_device != device: + logger.warning( + "Draft model is on %s while target model is on %s. For best performance, place both on the same device.", + draft_device, + device, + ) + + if stop_token_ids is None: + eos_token_id = getattr(getattr(self, "tokenizer", None), "eos_token_id", None) + stop_token_ids = [int(eos_token_id)] if eos_token_id is not None else None + if stop_token_ids is not None: + stop_token_ids = [int(token_id) for token_id in stop_token_ids] + + # 3. Setup models and scheduler + self.draft_model.eval() + self.target_model.eval() + self.scheduler.set_timesteps(1, device=device) + + block_size = self._get_block_size() + target_layer_ids = self._get_target_layer_ids() + input_embeddings = self._get_target_input_embeddings() + output_embeddings = self._get_target_output_embeddings() + + num_input_tokens = input_ids.shape[1] + max_length = num_input_tokens + int(max_new_tokens) + + output_ids = self.prepare_latents(max_length, block_size, int(mask_token_id), device) + position_ids = torch.arange(output_ids.shape[1], device=device).unsqueeze(0) + + past_key_values_target = DynamicCache() + past_key_values_draft = DynamicCache() + + # 4. Prefill step + output = self._target_forward( + input_ids=input_ids, + position_ids=position_ids[:, :num_input_tokens], + past_key_values=past_key_values_target, + output_hidden_states=True, + logits_to_keep=1, + ) + output_ids[:, :num_input_tokens] = input_ids + output_ids[:, num_input_tokens : num_input_tokens + 1] = self.scheduler.sample( + output.logits[:, -1:], temperature=temperature + ) + target_hidden = _extract_context_feature(output.hidden_states, target_layer_ids) + + start = num_input_tokens + global_step = 0 + num_blocks = (max_length - num_input_tokens + block_size - 1) // block_size + + # 5. Block-wise speculative decoding loop + block_progress_bar_config = getattr(self, "_progress_bar_config", {}).copy() + block_progress_bar_config["position"] = 0 + block_progress_bar_config["desc"] = "Blocks" + block_iter = tqdm(range(num_blocks), **block_progress_bar_config) + + for _block_idx in block_iter: + if start >= max_length: + break + + block_output_ids = output_ids[:, start : start + int(block_size)].clone() + block_position_ids = position_ids[:, start : start + int(block_size)] + noise_embedding = input_embeddings(block_output_ids) + draft_hidden = self.draft_model( + target_hidden=target_hidden, + noise_embedding=noise_embedding, + position_ids=position_ids[:, past_key_values_draft.get_seq_length() : start + int(block_size)], + past_key_values=past_key_values_draft, + use_cache=True, + is_causal=False, + ) + if not torch.is_tensor(draft_hidden): + draft_hidden = getattr(draft_hidden, "last_hidden_state", draft_hidden[0]) + draft_logits = output_embeddings(draft_hidden[:, -int(block_size) + 1 :, :]) + past_key_values_draft.crop(start) + block_output_ids[:, 1:] = self.scheduler.sample(draft_logits, temperature=temperature) + + output = self._target_forward( + input_ids=block_output_ids, + position_ids=block_position_ids, + past_key_values=past_key_values_target, + output_hidden_states=True, + logits_to_keep=None, + ) + step_output = self.scheduler.step( + block_output_ids, output.logits, temperature=temperature, return_dict=True + ) + accepted_length = step_output.accepted_length + next_token = step_output.next_token + acceptance_length = int(step_output.accepted_length[0].item()) + output_ids[:, start : start + acceptance_length + 1] = block_output_ids[:, : acceptance_length + 1] + output_ids[:, start + acceptance_length + 1] = step_output.next_token + start += acceptance_length + 1 + past_key_values_target.crop(start) + target_hidden = _extract_context_feature(output.hidden_states, target_layer_ids)[ + :, : acceptance_length + 1, : + ] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, global_step, 0, callback_kwargs) + output_ids = callback_outputs.pop("output_ids", output_ids) + global_step += 1 + + if self.scheduler.check_should_stop(output_ids, stop_token_ids, num_input_tokens): + break + + # 6. Post-process output + output_ids = output_ids[:, :max_length] + output_ids = output_ids[:, output_ids[0] != int(mask_token_id)] + if stop_token_ids is not None: + stop_tensor = torch.tensor(stop_token_ids, device=device, dtype=torch.long) + stop_positions = torch.isin(output_ids[0, num_input_tokens:], stop_tensor).nonzero(as_tuple=True)[0] + if stop_positions.numel() > 0: + output_ids = output_ids[:, : num_input_tokens + int(stop_positions[0].item()) + 1] + + prompt_len = input_ids.shape[1] + sequences = output_ids[:, prompt_len:] + + texts = None + if output_type == "text" and getattr(self, "tokenizer", None) is not None: + texts = self.tokenizer.batch_decode(sequences, skip_special_tokens=True) + + if not return_dict: + return sequences, texts + return DFlashPipelineOutput(sequences=sequences, texts=texts) + + def _get_block_size(self) -> int: + block_size = getattr(self.draft_model, "block_size", None) + if block_size is None: + block_size = getattr(getattr(self.draft_model, "config", None), "block_size", None) + if block_size is None: + raise ValueError("`draft_model` must define `block_size` on the module or its config.") + return int(block_size) + + def _get_target_layer_ids(self) -> list[int]: + layer_ids = getattr(self.draft_model, "target_layer_ids", None) + if layer_ids is not None: + return list(layer_ids) + cfg = getattr(self.draft_model, "config", None) + num_target_layers = getattr(cfg, "num_target_layers", None) + num_hidden_layers = getattr(cfg, "num_hidden_layers", None) + if num_target_layers is None or num_hidden_layers is None: + raise ValueError("`draft_model` must define `target_layer_ids` or expose `num_target_layers` in config.") + return _build_target_layer_ids(int(num_target_layers), int(num_hidden_layers)) + + def _get_target_input_embeddings(self) -> torch.nn.Module: + embeddings = self.target_model.get_input_embeddings() + if embeddings is None: + base_model = getattr(self.target_model, "model", None) + embeddings = getattr(base_model, "embed_tokens", None) + if embeddings is None: + raise ValueError("`target_model` must provide input embeddings for DFlash decoding.") + return embeddings + + def _get_target_output_embeddings(self) -> torch.nn.Module: + embeddings = self.target_model.get_output_embeddings() + if embeddings is None: + embeddings = getattr(self.target_model, "lm_head", None) + if embeddings is None: + raise ValueError("`target_model` must provide output embeddings for DFlash decoding.") + return embeddings + + def _target_forward( + self, + *, + input_ids: torch.LongTensor, + position_ids: torch.LongTensor, + past_key_values: DynamicCache, + output_hidden_states: bool, + logits_to_keep: int | None, + ): + kwargs = { + "input_ids": input_ids, + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": True, + "output_hidden_states": output_hidden_states, + } + if logits_to_keep is not None: + try: + return self.target_model(**kwargs, logits_to_keep=logits_to_keep) + except TypeError: + pass + return self.target_model(**kwargs) + + +__all__ = ["DFlashPipeline", "DFlashPipelineOutput"] diff --git a/src/diffusers/pipelines/hybrid_token_diffusion/__init__.py b/src/diffusers/pipelines/hybrid_token_diffusion/__init__.py new file mode 100644 index 000000000000..b3a1792f2a64 --- /dev/null +++ b/src/diffusers/pipelines/hybrid_token_diffusion/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .pipeline_hybrid_token_diffusion import HybridTokenDiffusionPipeline, HybridTokenDiffusionPipelineOutput + + +__all__ = ["HybridTokenDiffusionPipeline", "HybridTokenDiffusionPipelineOutput"] diff --git a/src/diffusers/pipelines/hybrid_token_diffusion/pipeline_hybrid_token_diffusion.py b/src/diffusers/pipelines/hybrid_token_diffusion/pipeline_hybrid_token_diffusion.py new file mode 100644 index 000000000000..e2869eabec2b --- /dev/null +++ b/src/diffusers/pipelines/hybrid_token_diffusion/pipeline_hybrid_token_diffusion.py @@ -0,0 +1,242 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Callable + +import torch + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...utils import BaseOutput, logging +from ..token_diffusion.pipeline_token_diffusion import TokenDiffusionPipeline + + +logger = logging.get_logger(__name__) + + +@dataclass +class HybridTokenDiffusionPipelineOutput(BaseOutput): + """ + Output class for hybrid token diffusion pipelines. + + Args: + sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Sampled token IDs. + texts (`list[str]`, *optional*): + Decoded texts if a tokenizer was provided and `output_type="text"`. + """ + + sequences: torch.LongTensor + texts: list[str] | None = None + + +class HybridTokenDiffusionPipeline(TokenDiffusionPipeline): + """ + Pipeline for hybrid-transition discrete token diffusion sampling. + + This pipeline extends `TokenDiffusionPipeline` with conventions aligned to LLaDA2-style pipelines: `output_type` + parameter, input validation via `check_inputs`, and progress bar support. + """ + + def check_inputs( + self, + batch_size: int, + seq_len: int, + num_inference_steps: int, + output_type: str, + callback_on_step_end: Callable | PipelineCallback | MultiPipelineCallbacks | None, + callback_on_step_end_tensor_inputs: list[str] | None, + infill_mask: torch.BoolTensor | None, + prefix_ids: torch.LongTensor | None, + ): + # Generation parameter validation + if batch_size <= 0: + raise ValueError(f"`batch_size` must be > 0, got {batch_size}.") + if seq_len <= 0: + raise ValueError(f"`seq_len` must be > 0, got {seq_len}.") + if num_inference_steps <= 0: + raise ValueError(f"`num_inference_steps` must be > 0, got {num_inference_steps}.") + if output_type not in {"seq", "text"}: + raise ValueError(f"`output_type` must be 'seq' or 'text', got {output_type!r}.") + + # Callback validation + if callback_on_step_end is not None and isinstance( + callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks) + ): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found " + f"{[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + # Conditioning validation + if infill_mask is not None and infill_mask.shape != (batch_size, seq_len): + raise ValueError(f"`infill_mask` must have shape {(batch_size, seq_len)}, got {tuple(infill_mask.shape)}.") + if prefix_ids is not None: + p = prefix_ids + if p.ndim == 1: + p = p.unsqueeze(0) + if p.ndim == 2 and p.shape[1] > seq_len: + raise ValueError(f"`prefix_ids` length {p.shape[1]} must be <= seq_len={seq_len}.") + + @torch.no_grad() + def __call__( + self, + batch_size: int = 1, + seq_len: int = 64, + num_inference_steps: int = 128, + generator: torch.Generator | None = None, + prefix_ids: torch.LongTensor | None = None, + infill_mask: torch.BoolTensor | None = None, + inject_start_token: bool = False, + output_type: str = "text", + return_dict: bool = True, + callback_on_step_end: Callable[[int, int, dict], None] + | PipelineCallback + | MultiPipelineCallbacks + | None = None, + callback_on_step_end_tensor_inputs: list[str] | None = None, + **model_kwargs, + ) -> HybridTokenDiffusionPipelineOutput | tuple[torch.LongTensor, list[str] | None]: + """ + Generate token sequences via hybrid-transition discrete diffusion. + + Args: + batch_size (`int`, defaults to `1`): + Number of sequences to generate. + seq_len (`int`, defaults to `64`): + Sequence length in tokens. + num_inference_steps (`int`, defaults to `128`): + Number of reverse diffusion steps. + generator (`torch.Generator`, *optional*): + Optional torch generator for determinism. + prefix_ids (`torch.LongTensor`, *optional*): + Optional prefix token IDs to keep fixed at the start of each sequence. Shape `[P]` or `[batch_size, + P]`. + infill_mask (`torch.BoolTensor`, *optional*): + Optional boolean mask of shape `[batch_size, seq_len]` indicating which positions are editable (`True`) + vs fixed (`False`). Fixed positions are clamped to the initial values on every step. + inject_start_token (`bool`, defaults to `False`): + If True, inject `bos_token_id` (or `cls_token_id`) into position 0 (if available). + output_type (`str`, defaults to `"text"`): + Output format. `"text"` decodes sequences into strings (requires a tokenizer). `"seq"` returns raw + token ID sequences only. + return_dict (`bool`, defaults to `True`): + Whether to return a [`HybridTokenDiffusionPipelineOutput`] instead of a tuple. + callback_on_step_end (`Callable` or `PipelineCallback`, *optional*): + A function called after each denoising step with signature `callback_on_step_end(self, step: int, + timestep: int, callback_kwargs: Dict)`. + callback_on_step_end_tensor_inputs (`list[str]`, *optional*): + List of tensor keys to include in `callback_kwargs`. + model_kwargs: + Forward kwargs passed to `model(...)` (e.g. attention mask overrides). + """ + # 1. Check inputs early + if callback_on_step_end is not None and isinstance( + callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks) + ): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + if callback_on_step_end_tensor_inputs is None: + callback_on_step_end_tensor_inputs = ["input_ids"] + + self.check_inputs( + batch_size=batch_size, + seq_len=seq_len, + num_inference_steps=num_inference_steps, + output_type=output_type, + callback_on_step_end=callback_on_step_end, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + infill_mask=infill_mask, + prefix_ids=prefix_ids, + ) + + # 2. Set up device and scheduler + device = self._execution_device + + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + self._num_timesteps = len(timesteps) + + # 3. Prepare latents + input_ids = self.prepare_latents(batch_size, seq_len, generator=generator, device=device) + attention_mask = torch.ones_like(input_ids, dtype=torch.long) + + fixed_mask = None + fixed_values = None + if infill_mask is not None: + fixed_mask = (~infill_mask.to(device=device)).to(dtype=torch.bool) + fixed_values = input_ids.clone() + + if prefix_ids is not None: + prefix_ids = self._normalize_prefix_ids(prefix_ids, batch_size=batch_size, device=device) + prefix_len = prefix_ids.shape[1] + + input_ids[:, :prefix_len] = prefix_ids + if fixed_mask is None: + fixed_mask = torch.zeros((batch_size, seq_len), device=device, dtype=torch.bool) + fixed_values = input_ids.clone() + fixed_mask[:, :prefix_len] = True + fixed_values[:, :prefix_len] = prefix_ids + + start_token_id = self._resolve_start_token_id() + if inject_start_token and start_token_id is not None: + input_ids[:, 0] = start_token_id + if fixed_mask is not None: + fixed_mask[:, 0] = True + fixed_values[:, 0] = start_token_id + + # 4. Denoising loop with progress bar + progress_bar = self.progress_bar(total=len(timesteps)) + for step_idx, t in enumerate(timesteps): + timestep = t.expand(batch_size) + out = self.model(input_ids=input_ids, timesteps=timestep, return_dict=True, **model_kwargs) + logits = getattr(out, "logits", None) + if logits is None: + # Fall back to tuple-style returns. + logits = out[0] + + input_ids = self.scheduler.step(logits, t, input_ids, generator=generator, return_dict=True).prev_sample + + if fixed_mask is not None: + input_ids = torch.where(fixed_mask, fixed_values, input_ids) + + if inject_start_token and start_token_id is not None: + input_ids[:, 0] = start_token_id + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, step_idx, t, callback_kwargs) + input_ids = callback_outputs.pop("input_ids", input_ids) + + progress_bar.update(1) + progress_bar.close() + + # 5. Post-process output + texts = None + if output_type == "text" and getattr(self, "tokenizer", None) is not None: + texts = self.tokenizer.batch_decode(input_ids, skip_special_tokens=True) + + if not return_dict: + return (input_ids, texts) + return HybridTokenDiffusionPipelineOutput(sequences=input_ids, texts=texts) + + +__all__ = ["HybridTokenDiffusionPipeline", "HybridTokenDiffusionPipelineOutput"] diff --git a/src/diffusers/pipelines/llada2/pipeline_llada2.py b/src/diffusers/pipelines/llada2/pipeline_llada2.py index d4b037ada151..7655cb92ab8e 100644 --- a/src/diffusers/pipelines/llada2/pipeline_llada2.py +++ b/src/diffusers/pipelines/llada2/pipeline_llada2.py @@ -23,7 +23,7 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...schedulers import BlockRefinementScheduler from ...utils import BaseOutput, logging, replace_example_docstring -from ..pipeline_utils import DiffusionPipeline +from ..pipeline_utils import DiffusionPipeline, DiscreteDiffusionPipelineMixin logger = logging.get_logger(__name__) @@ -56,7 +56,7 @@ class LLaDA2PipelineOutput(BaseOutput): texts: list[str] | None = None -class LLaDA2Pipeline(DiffusionPipeline): +class LLaDA2Pipeline(DiffusionPipeline, DiscreteDiffusionPipelineMixin): r""" Pipeline for LLaDA2-style discrete diffusion text generation via block-wise iterative refinement. @@ -88,65 +88,6 @@ def __init__( def num_timesteps(self): return self._num_timesteps - # --- Prompt encoding --- - - def _prepare_input_ids( - self, - *, - prompt: str | list[str] | None, - messages: list[dict[str, str]] | None, - input_ids: torch.LongTensor | None, - use_chat_template: bool, - add_generation_prompt: bool, - chat_template_kwargs: dict[str, Any] | None, - ) -> torch.LongTensor: - """Convert prompt/messages/input_ids to a [batch, seq] LongTensor.""" - if input_ids is not None: - if input_ids.ndim == 1: - input_ids = input_ids.unsqueeze(0) - if input_ids.ndim != 2: - raise ValueError(f"`input_ids` must be 2D, got shape {tuple(input_ids.shape)}.") - if input_ids.dtype != torch.long: - raise ValueError(f"`input_ids` must be int64 token IDs, got dtype={input_ids.dtype}.") - return input_ids - - if self.tokenizer is None: - raise ValueError("Tokenizer is required when `input_ids` is not provided.") - - if messages is not None and prompt is not None: - raise ValueError("Provide either `prompt` or `messages`, not both.") - if messages is None and prompt is None: - raise ValueError("Provide one of `prompt`, `messages`, or `input_ids`.") - - chat_template_kwargs = chat_template_kwargs or {} - - if messages is not None: - encoded = self.tokenizer.apply_chat_template( - messages, - add_generation_prompt=add_generation_prompt, - tokenize=True, - return_tensors="pt", - return_dict=True, - **chat_template_kwargs, - ) - return encoded["input_ids"] - - if use_chat_template and getattr(self.tokenizer, "chat_template", None): - if isinstance(prompt, list): - raise ValueError("`prompt` must be a string when `use_chat_template=True`.") - encoded = self.tokenizer.apply_chat_template( - [{"role": "user", "content": prompt}], - add_generation_prompt=add_generation_prompt, - tokenize=True, - return_tensors="pt", - return_dict=True, - **chat_template_kwargs, - ) - return encoded["input_ids"] - - encoded = self.tokenizer(prompt, return_tensors="pt", padding=isinstance(prompt, list)) - return encoded["input_ids"] - def check_inputs( self, prompt: str | list[str] | None, diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index d675f1de04a7..ba854532a9e3 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -2383,3 +2383,100 @@ def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True): else: self.vae.unfuse_qkv_projections() self.fusing_vae = False + + +class DiscreteDiffusionPipelineMixin: + """Shared utilities for discrete (token) diffusion pipelines. + + Provides common helper methods for pipelines that operate on discrete token sequences, including prompt encoding, + prefix handling, and start token resolution. + """ + + def _resolve_start_token_id(self) -> "int | None": + """Resolve BOS or CLS token ID from self.tokenizer.""" + tok = getattr(self, "tokenizer", None) + if tok is None: + return None + for attr in ("bos_token_id", "cls_token_id"): + token_id = getattr(tok, attr, None) + if token_id is not None: + return int(token_id) + return None + + def _normalize_prefix_ids( + self, prefix_ids: "torch.LongTensor", batch_size: int, device: "torch.device" + ) -> "torch.LongTensor": + """Validate shape/dtype and broadcast prefix token IDs.""" + if prefix_ids.ndim == 1: + prefix_ids = prefix_ids.unsqueeze(0) + if prefix_ids.ndim != 2: + raise ValueError( + f"`prefix_ids` must have shape [prefix_len] or [batch, prefix_len], got {prefix_ids.shape}." + ) + if prefix_ids.shape[0] not in (1, batch_size): + raise ValueError( + f"`prefix_ids` batch dim must be 1 or batch_size={batch_size}, got {prefix_ids.shape[0]}." + ) + if prefix_ids.dtype != torch.long: + raise ValueError(f"`prefix_ids` must be int64 token IDs, got dtype={prefix_ids.dtype}.") + prefix_ids = prefix_ids.to(device=device) + if prefix_ids.shape[0] == 1 and batch_size > 1: + prefix_ids = prefix_ids.expand(batch_size, -1) + return prefix_ids + + def _prepare_input_ids( + self, + *, + prompt: "str | list[str] | None", + messages: "list[dict[str, str]] | None", + input_ids: "torch.LongTensor | None", + use_chat_template: bool, + add_generation_prompt: bool, + chat_template_kwargs: "dict[str, object] | None", + ) -> "torch.LongTensor": + """Convert prompt/messages/input_ids to a [batch, seq] LongTensor.""" + if input_ids is not None: + if input_ids.ndim == 1: + input_ids = input_ids.unsqueeze(0) + if input_ids.ndim != 2: + raise ValueError(f"`input_ids` must be 2D, got shape {tuple(input_ids.shape)}.") + if input_ids.dtype != torch.long: + raise ValueError(f"`input_ids` must be int64 token IDs, got dtype={input_ids.dtype}.") + return input_ids + + if self.tokenizer is None: + raise ValueError("Tokenizer is required when `input_ids` is not provided.") + + if messages is not None and prompt is not None: + raise ValueError("Provide either `prompt` or `messages`, not both.") + if messages is None and prompt is None: + raise ValueError("Provide one of `prompt`, `messages`, or `input_ids`.") + + chat_template_kwargs = chat_template_kwargs or {} + + if messages is not None: + encoded = self.tokenizer.apply_chat_template( + messages, + add_generation_prompt=add_generation_prompt, + tokenize=True, + return_tensors="pt", + return_dict=True, + **chat_template_kwargs, + ) + return encoded["input_ids"] + + if use_chat_template and getattr(self.tokenizer, "chat_template", None): + if isinstance(prompt, list): + raise ValueError("`prompt` must be a string when `use_chat_template=True`.") + encoded = self.tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + add_generation_prompt=add_generation_prompt, + tokenize=True, + return_tensors="pt", + return_dict=True, + **chat_template_kwargs, + ) + return encoded["input_ids"] + + encoded = self.tokenizer(prompt, return_tensors="pt", padding=isinstance(prompt, list)) + return encoded["input_ids"] diff --git a/src/diffusers/pipelines/sdar/__init__.py b/src/diffusers/pipelines/sdar/__init__.py new file mode 100644 index 000000000000..13f8e30c9962 --- /dev/null +++ b/src/diffusers/pipelines/sdar/__init__.py @@ -0,0 +1,47 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_sdar"] = ["SDARPipeline", "SDARPipelineOutput"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_sdar import SDARPipeline, SDARPipelineOutput +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/sdar/pipeline_sdar.py b/src/diffusers/pipelines/sdar/pipeline_sdar.py new file mode 100644 index 000000000000..b63b0620420b --- /dev/null +++ b/src/diffusers/pipelines/sdar/pipeline_sdar.py @@ -0,0 +1,479 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable + +import torch +from tqdm.auto import tqdm +from transformers import DynamicCache + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...schedulers import SDARTokenDiffusionScheduler +from ...utils import BaseOutput, logging, replace_example_docstring +from ..pipeline_utils import DiffusionPipeline, DiscreteDiffusionPipelineMixin + + +logger = logging.get_logger(__name__) + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from transformers import AutoModelForCausalLM, AutoTokenizer + >>> from diffusers import SDARPipeline + + >>> model_id = "JetLM/SDAR-1.7B-Chat" + >>> model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, dtype=torch.bfloat16) + >>> tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) + >>> tokenizer.add_special_tokens({"mask_token": "<|MASK|>"}) + + >>> pipe = SDARPipeline(model=model, tokenizer=tokenizer) + >>> out = pipe(prompt="Explain what reinforcement learning is in simple terms.") + >>> print(out.texts[0]) + ``` +""" + + +@dataclass +class SDARPipelineOutput(BaseOutput): + sequences: torch.LongTensor + texts: list[str] | None = None + + +class SDARPipeline(DiffusionPipeline, DiscreteDiffusionPipelineMixin): + r""" + Block diffusion pipeline for SDAR-style token generation. + + This pipeline generates text by processing blocks of tokens in a semi-autoregressive fashion. Each block is + iteratively denoised using a masked diffusion process, where tokens are progressively revealed based on model + confidence. + + The model is expected to accept an attention mask and `position_ids`, and to return logits of shape `[batch, seq, + vocab_size]`. + """ + + model: Any + scheduler: SDARTokenDiffusionScheduler + tokenizer: Any + + _callback_tensor_inputs = ["cur_x", "logits", "sampled_tokens", "sampled_probs", "transfer_index"] + + def __init__( + self, + model: Any, + scheduler: SDARTokenDiffusionScheduler | None = None, + tokenizer: Any | None = None, + ): + super().__init__() + if scheduler is None: + scheduler = SDARTokenDiffusionScheduler() + self.register_modules(model=model, tokenizer=tokenizer, scheduler=scheduler) + self._store_kv_supported: bool | None = None + + @property + def num_timesteps(self): + return self._num_timesteps + + def check_inputs( + self, + prompt: str | list[str] | None, + messages: list[dict[str, str]] | None, + input_ids: torch.LongTensor | None, + block_length: int, + num_inference_steps: int, + mask_token_id: int | None, + output_type: str, + callback_on_step_end: Callable | PipelineCallback | MultiPipelineCallbacks | None, + callback_on_step_end_tensor_inputs: list[str] | None, + ): + # Input source validation + if prompt is None and messages is None and input_ids is None: + raise ValueError("Provide one of `prompt`, `messages`, or `input_ids`.") + if prompt is not None and messages is not None: + raise ValueError("Provide either `prompt` or `messages`, not both.") + if input_ids is not None: + if input_ids.ndim not in (1, 2): + raise ValueError(f"`input_ids` must be 1D or 2D, got shape {tuple(input_ids.shape)}.") + if input_ids.dtype != torch.long: + raise ValueError(f"`input_ids` must be int64 token IDs, got dtype={input_ids.dtype}.") + if input_ids.ndim == 2 and input_ids.shape[0] != 1: + raise ValueError("SDARPipeline currently supports batch_size=1 input_ids.") + if prompt is not None and input_ids is None and self.tokenizer is None: + raise ValueError("Tokenizer is required when `input_ids` is not provided.") + if messages is not None and input_ids is None and self.tokenizer is None: + raise ValueError("Tokenizer is required when `input_ids` is not provided.") + + # Generation parameter validation + if block_length <= 0: + raise ValueError(f"`block_length` must be > 0, got {block_length}.") + if num_inference_steps <= 0: + raise ValueError(f"`num_inference_steps` must be > 0, got {num_inference_steps}.") + if mask_token_id is None: + raise ValueError("`mask_token_id` must be provided (or available on the tokenizer).") + if output_type not in {"seq", "text"}: + raise ValueError(f"`output_type` must be 'seq' or 'text', got {output_type!r}.") + + # Callback validation + if callback_on_step_end is not None and isinstance( + callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks) + ): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found " + f"{[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + def prepare_latents( + self, + total_length: int, + mask_token_id: int, + device: torch.device, + ) -> torch.LongTensor: + return torch.full( + (1, total_length), + int(mask_token_id), + dtype=torch.long, + device=device, + ) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] | None = None, + messages: list[dict[str, str]] | None = None, + input_ids: torch.LongTensor | None = None, + max_new_tokens: int = 256, + block_length: int = 4, + num_inference_steps: int = 4, + temperature: float = 1.0, + top_k: int = 0, + top_p: float = 1.0, + remasking_strategy: str = "low_confidence_dynamic", + confidence_threshold: float = 0.9, + entropy_threshold: float = 0.35, + stop_token_ids: list[int] | None = None, + mask_token_id: int | None = None, + use_chat_template: bool = True, + add_generation_prompt: bool = True, + chat_template_kwargs: dict[str, object] | None = None, + generator: torch.Generator | None = None, + output_type: str = "text", + return_dict: bool = True, + callback_on_step_end: Callable[[int, int, dict], None] + | PipelineCallback + | MultiPipelineCallbacks + | None = None, + callback_on_step_end_tensor_inputs: list[str] | None = None, + ) -> SDARPipelineOutput | tuple[torch.LongTensor, list[str] | None]: + """ + Generate text using SDAR-style block diffusion decoding. + + Args: + prompt (`str` or `List[str]`, *optional*): + Prompt text. When `use_chat_template` is `True` (default) and a tokenizer with a chat template is + available, the prompt is wrapped in a chat message before tokenization. + messages (`List[Dict[str, str]]`, *optional*): + Chat messages to encode (e.g. `[{"role": "user", "content": "Hello"}]`). Takes precedence over `prompt` + when provided. Requires a tokenizer with `apply_chat_template`. + input_ids (`torch.LongTensor`, *optional*): + Pre-tokenized input IDs. Takes precedence over `prompt` and `messages`. + max_new_tokens (`int`): + Number of tokens to generate. + block_length (`int`): + Block size for denoising. + num_inference_steps (`int`): + Number of denoising steps per block. + temperature (`float`): + Sampling temperature. + top_k (`int`): + Top-k sampling cutoff. + top_p (`float`): + Nucleus sampling cutoff. + remasking_strategy (`str`): + Strategy for selecting which tokens to commit (`sequential`, `low_confidence_static`, + `low_confidence_dynamic`, `entropy_bounded`). + confidence_threshold (`float`): + Confidence threshold for dynamic remasking. + entropy_threshold (`float`): + Entropy threshold for entropy-bounded remasking. + stop_token_ids (`list[int]`, *optional*): + Token IDs that signal generation should stop. + mask_token_id (`int`, *optional*): + Mask token ID to use for the template. + use_chat_template (`bool`, defaults to `True`): + Whether to wrap the prompt in a chat template. + add_generation_prompt (`bool`, defaults to `True`): + Whether to add the generation prompt when using chat templates. + chat_template_kwargs (`dict`, *optional*): + Extra kwargs for `apply_chat_template`. + generator (`torch.Generator`, *optional*): + RNG for sampling. + output_type (`str`, defaults to `"text"`): + Output format. `"text"` decodes sequences into strings (requires a tokenizer). `"seq"` returns raw + token ID sequences only. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`SDARPipelineOutput`] instead of a tuple. + callback_on_step_end (`Callable` or `PipelineCallback`, *optional*): + Callback executed after each denoising step. + callback_on_step_end_tensor_inputs (`List[str]`, *optional*): + Tensor keys to pass to the callback. + + Examples: + """ + # 1. Check inputs early + if callback_on_step_end is not None and isinstance( + callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks) + ): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + if callback_on_step_end_tensor_inputs is None: + callback_on_step_end_tensor_inputs = ["cur_x"] + + # Resolve block_length from model if not explicitly overridden by the user + model_block_length = getattr(self.model, "block_length", None) + if model_block_length is None: + model_block_length = getattr(getattr(self.model, "config", None), "block_length", None) + if model_block_length is not None: + block_length = int(model_block_length) + + if mask_token_id is None: + mask_token_id = getattr(getattr(self, "tokenizer", None), "mask_token_id", None) + + self.check_inputs( + prompt=prompt, + messages=messages, + input_ids=input_ids, + block_length=block_length, + num_inference_steps=num_inference_steps, + mask_token_id=mask_token_id, + output_type=output_type, + callback_on_step_end=callback_on_step_end, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + # 2. Prepare input IDs from prompt/messages/input_ids + input_ids = self._prepare_input_ids( + prompt=prompt, + messages=messages, + input_ids=input_ids, + use_chat_template=use_chat_template, + add_generation_prompt=add_generation_prompt, + chat_template_kwargs=chat_template_kwargs, + ) + + device = self._execution_device + input_ids = input_ids.to(device=device) + + if stop_token_ids is None: + eos_token_id = getattr(getattr(self, "tokenizer", None), "eos_token_id", None) + stop_token_ids = [int(eos_token_id)] if eos_token_id is not None else None + if stop_token_ids is not None: + stop_token_ids = [int(token_id) for token_id in stop_token_ids] + + self.model.eval() + self.scheduler.set_timesteps(int(num_inference_steps), device=device) + + prompt_length = input_ids.shape[1] + num_blocks = (prompt_length + int(max_new_tokens) + int(block_length) - 1) // int(block_length) + total_length = int(num_blocks) * int(block_length) + + # 3. Build 2D attention mask β€” the model handles backend-specific conversion internally. + attn_mask = self._build_block_attention_mask_2d( + num_blocks=num_blocks, + block_length=block_length, + total_length=total_length, + device=device, + ) + + x = self.prepare_latents(total_length, int(mask_token_id), device) + x[:, :prompt_length] = input_ids + + position_ids = torch.arange(total_length, device=device).unsqueeze(0) + past_key_values = DynamicCache() + + prefill_blocks = prompt_length // int(block_length) + prefill_length = int(prefill_blocks) * int(block_length) + + self._num_timesteps = num_inference_steps * max(num_blocks - prefill_blocks, 0) + + if prefill_length > 0: + cur_x = x[:, :prefill_length] + cur_position_ids = position_ids[:, :prefill_length] + cur_attn_mask = attn_mask[:prefill_length, :prefill_length].unsqueeze(0) + self._model_forward_logits( + input_ids=cur_x, + attention_mask=cur_attn_mask, + position_ids=cur_position_ids, + past_key_values=past_key_values, + store_kv=True, + ) + + num_transfer_tokens = self.scheduler.get_num_transfer_tokens(int(block_length), int(num_inference_steps)).to( + device=device + ) + + global_step = 0 + + # 4. Block-wise generation loop + block_progress_bar_config = getattr(self, "_progress_bar_config", {}).copy() + block_progress_bar_config["position"] = 0 + block_progress_bar_config["desc"] = "Blocks" + for block_idx in tqdm(range(prefill_blocks, int(num_blocks)), **block_progress_bar_config): + start = int(block_idx) * int(block_length) + end = start + int(block_length) + cur_x = x[:, start:end].clone() + cur_position_ids = position_ids[:, start:end] + cur_attn_mask = attn_mask[start:end, :end].unsqueeze(0) + + self.set_progress_bar_config(position=1, leave=False, desc=f"Block {block_idx} Inference Steps") + progress_bar = self.progress_bar(total=num_inference_steps) + + for step in range(int(num_inference_steps) + 1): + mask_index = cur_x == int(mask_token_id) + if mask_index.sum() == 0: + self._model_forward_logits( + input_ids=cur_x, + attention_mask=cur_attn_mask, + position_ids=cur_position_ids, + past_key_values=past_key_values, + store_kv=True, + ) + break + + logits = self._model_forward_logits( + input_ids=cur_x, + attention_mask=cur_attn_mask, + position_ids=cur_position_ids, + past_key_values=past_key_values, + store_kv=False, + ) + + step_output = self.scheduler.step( + logits, + step, + cur_x, + mask_token_id=int(mask_token_id), + num_transfer_tokens=num_transfer_tokens, + remasking_strategy=remasking_strategy, + confidence_threshold=confidence_threshold, + entropy_threshold=entropy_threshold, + temperature=temperature, + top_k=top_k, + top_p=top_p, + generator=generator, + return_dict=True, + ) + cur_x = step_output.prev_sample + transfer_index = step_output.transfer_index + sampled_tokens = step_output.sampled_tokens + sampled_probs = step_output.sampled_probs + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, global_step, step, callback_kwargs) + cur_x = callback_outputs.pop("cur_x", cur_x) + + global_step += 1 + progress_bar.update(1) + + progress_bar.close() + x[:, start:end] = cur_x + + if self.scheduler.check_should_stop(x, prompt_length, stop_token_ids): + break + + # 5. Post-process output + output_ids = x[:, : prompt_length + int(max_new_tokens)] + if stop_token_ids is not None: + stop_tensor = torch.tensor(stop_token_ids, device=device, dtype=torch.long) + stop_positions = torch.isin(output_ids[0, prompt_length:], stop_tensor).nonzero(as_tuple=True)[0] + if stop_positions.numel() > 0: + output_ids = output_ids[:, : prompt_length + int(stop_positions[0].item()) + 1] + + if output_ids.shape[0] == 1: + output_ids = output_ids[:, output_ids[0] != int(mask_token_id)] + + sequences = output_ids[:, prompt_length:] + texts = None + if output_type == "text" and self.tokenizer is not None: + texts = self.tokenizer.batch_decode(sequences, skip_special_tokens=True) + + if not return_dict: + return sequences, texts + return SDARPipelineOutput(sequences=sequences, texts=texts) + + def _model_forward_logits( + self, + *, + input_ids: torch.LongTensor, + attention_mask: torch.Tensor, + position_ids: torch.LongTensor, + past_key_values: DynamicCache, + store_kv: bool, + ) -> torch.Tensor: + """Run the model forward pass and return logits. + + Passes a 2D attention mask and lets the model handle internal conversion. + """ + kwargs = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": True, + } + if self._store_kv_supported is False: + output = self.model(**kwargs) + return output.logits if hasattr(output, "logits") else output[0] + if self._store_kv_supported is True: + kwargs["store_kv"] = store_kv + output = self.model(**kwargs) + return output.logits if hasattr(output, "logits") else output[0] + try: + kwargs["store_kv"] = store_kv + output = self.model(**kwargs) + self._store_kv_supported = True + return output.logits if hasattr(output, "logits") else output[0] + except TypeError: + output = self.model(**kwargs) + self._store_kv_supported = False + return output.logits if hasattr(output, "logits") else output[0] + + def _build_block_attention_mask_2d( + self, + *, + num_blocks: int, + block_length: int, + total_length: int, + device: torch.device, + ) -> torch.Tensor: + """Build a 2D block-causal attention mask of shape `(total_length, total_length)`. + + Each position can attend to all positions in the same or earlier blocks. + """ + block_mask = torch.tril(torch.ones(num_blocks, num_blocks, device=device, dtype=torch.long)) + attn = block_mask.repeat_interleave(block_length, dim=0).repeat_interleave(block_length, dim=1) + return attn[:total_length, :total_length] + + +__all__ = ["SDARPipeline", "SDARPipelineOutput"] diff --git a/src/diffusers/pipelines/token_diffusion/__init__.py b/src/diffusers/pipelines/token_diffusion/__init__.py new file mode 100644 index 000000000000..3fc40f86ee12 --- /dev/null +++ b/src/diffusers/pipelines/token_diffusion/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .pipeline_token_diffusion import TokenDiffusionPipeline, TokenDiffusionPipelineOutput + + +__all__ = ["TokenDiffusionPipeline", "TokenDiffusionPipelineOutput"] diff --git a/src/diffusers/pipelines/token_diffusion/pipeline_token_diffusion.py b/src/diffusers/pipelines/token_diffusion/pipeline_token_diffusion.py new file mode 100644 index 000000000000..e93762b8c4dd --- /dev/null +++ b/src/diffusers/pipelines/token_diffusion/pipeline_token_diffusion.py @@ -0,0 +1,265 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable + +import torch + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...utils import BaseOutput, logging +from ..pipeline_utils import DiffusionPipeline, DiscreteDiffusionPipelineMixin + + +logger = logging.get_logger(__name__) + + +@dataclass +class TokenDiffusionPipelineOutput(BaseOutput): + """ + Output class for token diffusion pipelines. + + Args: + sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Sampled token IDs. + texts (`list[str]`, *optional*): + Decoded texts if a tokenizer was provided and `output_type="text"`. + """ + + sequences: torch.LongTensor + texts: list[str] | None = None + + +class TokenDiffusionPipeline(DiffusionPipeline, DiscreteDiffusionPipelineMixin): + """ + Generic token diffusion sampling pipeline. + + This pipeline is intended as a minimal, diffusers-native wrapper around: + - a token denoiser model (e.g. `transformers.AutoModelForMaskedLM`-like, returning logits over vocab), and + - a discrete token scheduler (e.g. `TokenDiffusionScheduler`) that implements `set_timesteps()` and `step()`. + + The pipeline supports multiple forward processes via the scheduler configuration (e.g. absorbing/mask, uniform). + Conditioning (prefix/infill) is intentionally out of scope for the first version. + """ + + model: Any + tokenizer: Any + scheduler: Any + + _callback_tensor_inputs = ["input_ids", "logits"] + + def __init__( + self, + model: Any, + scheduler: Any, + tokenizer: Any | None = None, + ): + super().__init__() + self.register_modules(model=model, scheduler=scheduler, tokenizer=tokenizer) + + @property + def num_timesteps(self): + return self._num_timesteps + + def prepare_latents( + self, + batch_size: int, + seq_len: int, + generator: torch.Generator | None = None, + device: torch.device | None = None, + ) -> torch.LongTensor: + shape = torch.Size((batch_size, seq_len)) + return self.scheduler.sample_prior(shape, device=device, generator=generator) + + def check_inputs( + self, + batch_size: int, + seq_len: int, + num_inference_steps: int, + output_type: str, + callback_on_step_end: Callable | PipelineCallback | MultiPipelineCallbacks | None, + callback_on_step_end_tensor_inputs: list[str] | None, + infill_mask: torch.BoolTensor | None, + prefix_ids: torch.LongTensor | None, + ): + # Generation parameter validation + if batch_size <= 0: + raise ValueError(f"`batch_size` must be > 0, got {batch_size}.") + if seq_len <= 0: + raise ValueError(f"`seq_len` must be > 0, got {seq_len}.") + if num_inference_steps <= 0: + raise ValueError(f"`num_inference_steps` must be > 0, got {num_inference_steps}.") + if output_type not in {"seq", "text"}: + raise ValueError(f"`output_type` must be 'seq' or 'text', got {output_type!r}.") + + # Callback validation + if callback_on_step_end is not None and isinstance( + callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks) + ): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found " + f"{[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + # Mask / prefix validation + if infill_mask is not None and infill_mask.shape != (batch_size, seq_len): + raise ValueError(f"`infill_mask` must have shape {(batch_size, seq_len)}, got {tuple(infill_mask.shape)}.") + if prefix_ids is not None: + p = prefix_ids + if p.ndim == 1: + p = p.unsqueeze(0) + if p.ndim == 2 and p.shape[1] > seq_len: + raise ValueError(f"`prefix_ids` length {p.shape[1]} must be <= seq_len={seq_len}.") + + @torch.no_grad() + def __call__( + self, + batch_size: int = 1, + seq_len: int = 64, + num_inference_steps: int = 128, + generator: torch.Generator | None = None, + prefix_ids: torch.LongTensor | None = None, + infill_mask: torch.BoolTensor | None = None, + inject_start_token: bool = False, + output_type: str = "text", + return_dict: bool = True, + callback_on_step_end: Callable[[int, int, dict], None] + | PipelineCallback + | MultiPipelineCallbacks + | None = None, + callback_on_step_end_tensor_inputs: list[str] | None = None, + **model_kwargs, + ) -> TokenDiffusionPipelineOutput | tuple[torch.LongTensor, list[str] | None]: + """ + Args: + batch_size: Number of sequences to generate. + seq_len: Sequence length in tokens. + num_inference_steps: Number of reverse diffusion steps. + generator: Optional torch generator for determinism. + prefix_ids: Optional prefix token IDs to keep fixed at the start of each sequence. Shape `[P]` or + `[batch_size, P]`. + infill_mask: + Optional boolean mask of shape `[batch_size, seq_len]` indicating which positions are editable (`True`) + vs fixed (`False`). Fixed positions are clamped to the initial values on every step. + inject_start_token: If True, inject `bos_token_id` (or `cls_token_id`) into position 0 (if available). + output_type (`str`, defaults to `"text"`): + Output format. `"text"` decodes sequences into strings (requires a tokenizer). `"seq"` returns raw + token ID sequences only. + return_dict: If True, returns a `TokenDiffusionPipelineOutput`. + callback_on_step_end: A function called after each denoising step with signature + `callback_on_step_end(self, step: int, timestep: int, callback_kwargs: dict)`. + callback_on_step_end_tensor_inputs: List of tensor keys to include in `callback_kwargs`. + model_kwargs: Forward kwargs passed to `model(...)` (e.g. attention mask overrides). + """ + # 1. Check inputs early + if callback_on_step_end is not None and isinstance( + callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks) + ): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + if callback_on_step_end_tensor_inputs is None: + callback_on_step_end_tensor_inputs = ["input_ids"] + + self.check_inputs( + batch_size=batch_size, + seq_len=seq_len, + num_inference_steps=num_inference_steps, + output_type=output_type, + callback_on_step_end=callback_on_step_end, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + infill_mask=infill_mask, + prefix_ids=prefix_ids, + ) + + # 2. Prepare timesteps + device = self._execution_device + + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + self._num_timesteps = len(timesteps) + + # 3. Prepare latents + input_ids = self.prepare_latents(batch_size, seq_len, generator=generator, device=device) + attention_mask = torch.ones_like(input_ids, dtype=torch.long) + + # 4. Build fixed masks for prefix / infill conditioning + fixed_mask = None + fixed_values = None + if infill_mask is not None: + fixed_mask = (~infill_mask.to(device=device)).to(dtype=torch.bool) + fixed_values = input_ids.clone() + + if prefix_ids is not None: + prefix_ids = self._normalize_prefix_ids(prefix_ids, batch_size=batch_size, device=device) + prefix_len = prefix_ids.shape[1] + + input_ids[:, :prefix_len] = prefix_ids + if fixed_mask is None: + fixed_mask = torch.zeros((batch_size, seq_len), device=device, dtype=torch.bool) + fixed_values = input_ids.clone() + fixed_mask[:, :prefix_len] = True + fixed_values[:, :prefix_len] = prefix_ids + + start_token_id = self._resolve_start_token_id() + if inject_start_token and start_token_id is not None: + input_ids[:, 0] = start_token_id + if fixed_mask is not None: + fixed_mask[:, 0] = True + fixed_values[:, 0] = start_token_id + + # 5. Denoising loop + progress_bar = self.progress_bar(total=num_inference_steps) + for step_idx, t in enumerate(timesteps): + timestep = t.expand(batch_size) + out = self.model(input_ids=input_ids, timesteps=timestep, return_dict=True, **model_kwargs) + logits = getattr(out, "logits", None) + if logits is None: + # Fall back to tuple-style returns. + logits = out[0] + + input_ids = self.scheduler.step(logits, t, input_ids, generator=generator, return_dict=True).prev_sample + + # Enforce fixed masks (prefix / infill conditioning) + if fixed_mask is not None: + input_ids = self.scheduler.enforce_fixed_masks(input_ids, fixed_mask, fixed_values) + + if inject_start_token and start_token_id is not None: + input_ids[:, 0] = start_token_id + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, step_idx, t, callback_kwargs) + input_ids = callback_outputs.pop("input_ids", input_ids) + + progress_bar.update(1) + progress_bar.close() + + # 6. Post-process output + texts = None + if output_type == "text" and getattr(self, "tokenizer", None) is not None: + texts = self.tokenizer.batch_decode(input_ids, skip_special_tokens=True) + + if not return_dict: + return (input_ids, texts) + return TokenDiffusionPipelineOutput(sequences=input_ids, texts=texts) + + +__all__ = ["TokenDiffusionPipeline", "TokenDiffusionPipelineOutput"] diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index b1f75bed7dc5..e0954aec37f5 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -40,6 +40,10 @@ else: _import_structure["deprecated"] = ["KarrasVeScheduler", "ScoreSdeVpScheduler"] _import_structure["scheduling_amused"] = ["AmusedScheduler"] + _import_structure["scheduling_bd3lm_token_diffusion"] = [ + "BD3LMTokenDiffusionScheduler", + "BD3LMTokenDiffusionSchedulerOutput", + ] _import_structure["scheduling_block_refinement"] = ["BlockRefinementScheduler", "BlockRefinementSchedulerOutput"] _import_structure["scheduling_consistency_decoder"] = ["ConsistencyDecoderScheduler"] _import_structure["scheduling_consistency_models"] = ["CMStochasticIterativeScheduler"] @@ -51,6 +55,10 @@ _import_structure["scheduling_ddpm_parallel"] = ["DDPMParallelScheduler"] _import_structure["scheduling_ddpm_wuerstchen"] = ["DDPMWuerstchenScheduler"] _import_structure["scheduling_deis_multistep"] = ["DEISMultistepScheduler"] + _import_structure["scheduling_dflash_token_diffusion"] = [ + "DFlashTokenDiffusionScheduler", + "DFlashTokenDiffusionSchedulerOutput", + ] _import_structure["scheduling_dpm_cogvideox"] = ["CogVideoXDPMScheduler"] _import_structure["scheduling_dpmsolver_multistep"] = ["DPMSolverMultistepScheduler"] _import_structure["scheduling_dpmsolver_multistep_inverse"] = ["DPMSolverMultistepInverseScheduler"] @@ -65,6 +73,10 @@ _import_structure["scheduling_helios"] = ["HeliosScheduler"] _import_structure["scheduling_helios_dmd"] = ["HeliosDMDScheduler"] _import_structure["scheduling_heun_discrete"] = ["HeunDiscreteScheduler"] + _import_structure["scheduling_hybrid_token_diffusion"] = [ + "HybridTokenDiffusionScheduler", + "HybridTokenDiffusionSchedulerOutput", + ] _import_structure["scheduling_ipndm"] = ["IPNDMScheduler"] _import_structure["scheduling_k_dpm_2_ancestral_discrete"] = ["KDPM2AncestralDiscreteScheduler"] _import_structure["scheduling_k_dpm_2_discrete"] = ["KDPM2DiscreteScheduler"] @@ -74,8 +86,13 @@ _import_structure["scheduling_repaint"] = ["RePaintScheduler"] _import_structure["scheduling_sasolver"] = ["SASolverScheduler"] _import_structure["scheduling_scm"] = ["SCMScheduler"] + _import_structure["scheduling_sdar_token_diffusion"] = [ + "SDARTokenDiffusionScheduler", + "SDARTokenDiffusionSchedulerOutput", + ] _import_structure["scheduling_sde_ve"] = ["ScoreSdeVeScheduler"] _import_structure["scheduling_tcd"] = ["TCDScheduler"] + _import_structure["scheduling_token_diffusion"] = ["TokenDiffusionScheduler"] _import_structure["scheduling_unclip"] = ["UnCLIPScheduler"] _import_structure["scheduling_unipc_multistep"] = ["UniPCMultistepScheduler"] _import_structure["scheduling_utils"] = ["AysSchedules", "KarrasDiffusionSchedulers", "SchedulerMixin"] @@ -146,6 +163,7 @@ else: from .deprecated import KarrasVeScheduler, ScoreSdeVpScheduler from .scheduling_amused import AmusedScheduler + from .scheduling_bd3lm_token_diffusion import BD3LMTokenDiffusionScheduler, BD3LMTokenDiffusionSchedulerOutput from .scheduling_block_refinement import BlockRefinementScheduler, BlockRefinementSchedulerOutput from .scheduling_consistency_decoder import ConsistencyDecoderScheduler from .scheduling_consistency_models import CMStochasticIterativeScheduler @@ -157,6 +175,10 @@ from .scheduling_ddpm_parallel import DDPMParallelScheduler from .scheduling_ddpm_wuerstchen import DDPMWuerstchenScheduler from .scheduling_deis_multistep import DEISMultistepScheduler + from .scheduling_dflash_token_diffusion import ( + DFlashTokenDiffusionScheduler, + DFlashTokenDiffusionSchedulerOutput, + ) from .scheduling_dpm_cogvideox import CogVideoXDPMScheduler from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler from .scheduling_dpmsolver_multistep_inverse import DPMSolverMultistepInverseScheduler @@ -180,6 +202,7 @@ from .scheduling_repaint import RePaintScheduler from .scheduling_sasolver import SASolverScheduler from .scheduling_scm import SCMScheduler + from .scheduling_sdar_token_diffusion import SDARTokenDiffusionScheduler, SDARTokenDiffusionSchedulerOutput from .scheduling_sde_ve import ScoreSdeVeScheduler from .scheduling_tcd import TCDScheduler from .scheduling_unclip import UnCLIPScheduler diff --git a/src/diffusers/schedulers/scheduling_bd3lm_token_diffusion.py b/src/diffusers/schedulers/scheduling_bd3lm_token_diffusion.py new file mode 100644 index 000000000000..f484939ea725 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_bd3lm_token_diffusion.py @@ -0,0 +1,428 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import math +from dataclasses import dataclass + +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from .scheduling_utils import SchedulerMixin + + +@dataclass +class BD3LMTokenDiffusionSchedulerOutput(BaseOutput): + """ + Output class for BD3LM token diffusion scheduling. + + Args: + prev_sample (`torch.LongTensor` of shape `(batch_size, seq_len)`): + Updated token sequence after the current denoising step. + p_x0_cache (`torch.Tensor` of shape `(batch_size, block_size, vocab_size)` or `None`): + Cached clean-token probability distribution. When `None`, the model should be called again at the next + step; when not `None`, the cached distribution can be reused. + """ + + prev_sample: torch.LongTensor + p_x0_cache: torch.Tensor | None + + +class BD3LMTokenDiffusionScheduler(SchedulerMixin, ConfigMixin): + """ + Scheduler for Block Discrete Denoising Diffusion Language Models (BD3LM). + + Implements the DDPM-style caching update from BD3LM, which iteratively denoises masked token sequences block by + block. At each step the scheduler computes posterior transition probabilities q(x_s | x_t, x_0) and samples new + tokens for currently masked positions while preserving already-unmasked tokens. + + Supports multiple noise schedules: loglinear, cosine, square, square_root, and log. + """ + + order = 1 + + @register_to_config + def __init__( + self, + block_size: int = 1024, + num_inference_steps: int = 1024, + noise_type: str = "loglinear", + nucleus_p: float = 1.0, + mask_token_id: int = 32000, + ): + self.num_inference_steps = num_inference_steps + self.timesteps: torch.Tensor | None = None + self._dt: float | None = None + + # ------------------------------------------------------------------ + # Timestep management + # ------------------------------------------------------------------ + + def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None: + """ + Create linearly spaced timesteps from 1 to 0 (exclusive). + + Args: + num_inference_steps (`int`): + Number of denoising steps. + device (`str` or `torch.device`, *optional*): + Device for the timestep tensor. + """ + if num_inference_steps <= 0: + raise ValueError(f"`num_inference_steps` must be > 0, got {num_inference_steps}.") + self.num_inference_steps = num_inference_steps + self.timesteps = torch.linspace(1.0, 0.0, num_inference_steps, device=device) + self._dt = 1.0 / num_inference_steps + + # ------------------------------------------------------------------ + # Noise schedule utilities + # ------------------------------------------------------------------ + + def _compute_move_chance(self, t: torch.Tensor) -> torch.Tensor: + """ + Compute the probability that a token has been masked (move chance) at continuous time *t*. + + The move chance depends on the configured ``noise_type``: + - **loglinear**: ``move_chance = t`` + - **cosine**: ``move_chance = 1 - (1 - eps) * cos(t * pi / 2)`` + - **square**: ``move_chance = t ** 2`` + - **square_root**: ``move_chance = t ** 0.5`` + - **log**: ``move_chance = log(1 + t) / log(2)`` + + Args: + t (`torch.Tensor`): + Continuous timestep values in [0, 1]. + + Returns: + `torch.Tensor`: Move chance at each timestep value, same shape as *t*. + """ + noise_type = self.config.noise_type + eps = 1e-3 + if noise_type == "loglinear": + return t + elif noise_type == "cosine": + return 1.0 - (1.0 - eps) * torch.cos(t * math.pi / 2.0) + elif noise_type == "square": + return torch.clamp(t**2, min=eps) + elif noise_type == "square_root": + return torch.clamp(t**0.5, min=eps) + elif noise_type == "log": + return torch.log1p(t) / math.log(2.0) + else: + raise ValueError( + f"Unknown noise_type '{noise_type}'. Must be one of: loglinear, cosine, square, square_root, log." + ) + + # ------------------------------------------------------------------ + # Nucleus (top-p) filtering + # ------------------------------------------------------------------ + + @staticmethod + def _nucleus_filtering(probs: torch.Tensor, nucleus_p: float) -> torch.Tensor: + """ + Apply nucleus (top-p) filtering to a probability distribution. + + Tokens outside the top-p cumulative probability mass are zeroed out and the distribution is renormalised. + + Args: + probs (`torch.Tensor` of shape `(*, vocab_size)`): + Token probability distributions (already softmaxed). + nucleus_p (`float`): + Cumulative probability threshold. Use 1.0 to disable filtering. + + Returns: + `torch.Tensor`: Filtered and renormalised probability distributions. + """ + if nucleus_p >= 1.0: + return probs + sorted_probs, sorted_indices = probs.sort(dim=-1, descending=True) + cumulative_probs = sorted_probs.cumsum(dim=-1) + nucleus_mask = cumulative_probs <= nucleus_p + # Always keep at least the top-1 token + nucleus_mask[..., 0] = True + sorted_probs = sorted_probs * nucleus_mask + # Scatter back to original order + filtered = torch.zeros_like(probs) + filtered.scatter_(-1, sorted_indices, sorted_probs) + filtered = filtered / filtered.sum(dim=-1, keepdim=True) + return filtered + + # ------------------------------------------------------------------ + # Core step + # ------------------------------------------------------------------ + + def step( + self, + model_output: torch.Tensor, + timestep: float | torch.Tensor, + sample: torch.LongTensor, + *, + mask_token_id: int | None = None, + nucleus_p: float | None = None, + generator: torch.Generator | None = None, + return_dict: bool = True, + ) -> BD3LMTokenDiffusionSchedulerOutput | tuple[torch.LongTensor, torch.Tensor | None]: + """ + Perform a single DDPM caching denoising step. + + The method implements the BD3LM reverse-process update: given predicted clean-token logits from the model, it + computes the posterior q(x_s | x_t, x_0), samples new tokens for masked positions, and copies through tokens + that are already unmasked. + + Args: + model_output (`torch.Tensor` of shape `(batch_size, seq_len, vocab_size)`): + Raw logits from the model. Softmax and nucleus filtering are applied internally. + timestep (`float` or `torch.Tensor`): + Current continuous timestep *t* (in [0, 1], starting at 1 and decreasing toward 0). + sample (`torch.LongTensor` of shape `(batch_size, seq_len)`): + Current noisy token sequence. Masked positions contain ``mask_token_id``. + mask_token_id (`int`, *optional*): + Token ID used for masked positions. Defaults to the value from the scheduler config. + nucleus_p (`float`, *optional*): + Nucleus sampling threshold. Defaults to the value from the scheduler config. + generator (`torch.Generator`, *optional*): + Random number generator for reproducible sampling. + return_dict (`bool`): + Whether to return a [`BD3LMTokenDiffusionSchedulerOutput`] or a plain tuple. + + Returns: + [`BD3LMTokenDiffusionSchedulerOutput`] or `tuple`: + The denoised sample and the p_x0 cache (``None`` when the sample changed, meaning the cache is + invalidated and the model must be called again at the next step). + """ + if mask_token_id is None: + mask_token_id = self.config.mask_token_id + if nucleus_p is None: + nucleus_p = self.config.nucleus_p + + block_size = self.config.block_size + dt = self._dt if self._dt is not None else 1.0 / self.num_inference_steps + + # Ensure timestep is a tensor + if not isinstance(timestep, torch.Tensor): + t = torch.tensor([timestep], device=sample.device, dtype=torch.float64) + else: + t = timestep.to(dtype=torch.float64) + if t.dim() == 0: + t = t.unsqueeze(0) + + # ------------------------------------------------------------------ + # Compute move chances at t and s = t - dt + # ------------------------------------------------------------------ + move_chance_t = self._compute_move_chance(t).to(dtype=torch.float64) + move_chance_s = self._compute_move_chance(t - dt).to(dtype=torch.float64) + + # Expand to (batch, 1) for broadcasting against (batch, seq_len) + if move_chance_t.dim() == 1: + move_chance_t = move_chance_t.unsqueeze(-1) + move_chance_s = move_chance_s.unsqueeze(-1) + + # mask_prob: probability that a token stays masked at s given it was masked at t + mask_prob = move_chance_s / move_chance_t # (batch, 1) + + # ------------------------------------------------------------------ + # Apply subs parameterization and convert to p(x_0) + # ------------------------------------------------------------------ + logits = model_output[:, -block_size:].to(dtype=torch.float64) + + # Subs parameterization: mask token gets -inf, then log_softmax normalizes. + # For unmasked positions, the distribution is forced to be the identity. + logits[..., mask_token_id] = -1e9 + logits = logits - torch.logsumexp(logits, dim=-1, keepdim=True) + + x_current_block = sample[:, -block_size:] + unmasked = x_current_block != mask_token_id + logits[unmasked] = -1e9 + logits[unmasked, x_current_block[unmasked]] = 0.0 + + # Convert log-probs to probs and apply nucleus filtering + p_x0 = logits.exp() + p_x0 = self._nucleus_filtering(p_x0, nucleus_p) + + # ------------------------------------------------------------------ + # Compute posterior q(x_s | x_t, x_0) and sample + # ------------------------------------------------------------------ + # For non-mask tokens: q_xs = p_x0 * (1 - mask_prob), q_xs[mask] = mask_prob + q_xs = p_x0 * (1.0 - mask_prob.unsqueeze(-1)) + q_xs[..., mask_token_id] = mask_prob.squeeze(-1) + + # Gumbel-argmax categorical sampling + gumbel_noise = -(torch.rand_like(q_xs, generator=generator) + 1e-10).log() + gumbel_noise = (1e-10 + gumbel_noise).clamp(min=1e-30) + x_block = (q_xs / gumbel_noise).argmax(dim=-1) + + # ------------------------------------------------------------------ + # Copy flag: preserve tokens that are already unmasked + # ------------------------------------------------------------------ + x_current_block = sample[:, -block_size:] + is_masked = (x_current_block == mask_token_id).to(dtype=x_block.dtype) + x_block = (1 - is_masked) * x_current_block + is_masked * x_block + + # Assemble full sequence + if sample.shape[-1] > block_size: + prev_sample = torch.cat([sample[:, :-block_size], x_block], dim=-1) + else: + prev_sample = x_block + + # ------------------------------------------------------------------ + # Determine p_x0 cache validity + # ------------------------------------------------------------------ + # If any token changed, invalidate the cache so the model is called again. + if not torch.equal(prev_sample, sample): + p_x0_cache = None + else: + p_x0_cache = p_x0 + + if not return_dict: + return prev_sample, p_x0_cache + return BD3LMTokenDiffusionSchedulerOutput( + prev_sample=prev_sample, + p_x0_cache=p_x0_cache, + ) + + # ------------------------------------------------------------------ + # Forward (noising) process + # ------------------------------------------------------------------ + + def add_noise( + self, + original_samples: torch.LongTensor, + timesteps: torch.Tensor, + mask_token_id: int | None = None, + generator: torch.Generator | None = None, + ) -> torch.LongTensor: + """ + Apply the forward noising process: randomly mask tokens with probability determined by the noise schedule. + + Args: + original_samples (`torch.LongTensor` of shape `(batch_size, seq_len)`): + Clean token IDs. + timesteps (`torch.Tensor` of shape `(batch_size,)` or `(batch_size, seq_len)`): + Continuous timestep values in [0, 1] controlling the amount of noise. + mask_token_id (`int`, *optional*): + Token ID to use for masked positions. Defaults to the value from the scheduler config. + generator (`torch.Generator`, *optional*): + Random number generator for reproducibility. + + Returns: + `torch.LongTensor`: Noisy token sequence with the same shape as *original_samples*. + """ + if mask_token_id is None: + mask_token_id = self.config.mask_token_id + + move_chance = self._compute_move_chance(timesteps) + # Expand move_chance to match sample dimensions for broadcasting + if move_chance.dim() == 1: + move_chance = move_chance.unsqueeze(-1) # (batch, 1) + + # Sample uniform noise and mask tokens where noise < move_chance + uniform_noise = torch.rand( + original_samples.shape, + device=original_samples.device, + dtype=move_chance.dtype, + generator=generator, + ) + mask = uniform_noise < move_chance + noisy_samples = torch.where(mask, mask_token_id, original_samples) + return noisy_samples + + # ------------------------------------------------------------------ + # Sigma computation (for model input) + # ------------------------------------------------------------------ + + def compute_sigma(self, t: float | torch.Tensor, batch_size: int = 1) -> torch.Tensor: + """ + Compute the sigma value (noise level) for a given timestep. + + Sigma is derived from the noise schedule's move chance: ``sigma = -log(1 - move_chance)``, clamped at + ``sigma_max = -log(eps)`` where ``eps = 1e-3``. + + Args: + t (`float` or `torch.Tensor`): + Continuous timestep value in [0, 1]. + batch_size (`int`): + Batch size for expanding the result. + + Returns: + `torch.Tensor`: Sigma values of shape ``(batch_size,)`` in float32. + """ + if not isinstance(t, torch.Tensor): + t = torch.tensor([t], dtype=torch.float64) + t = t.to(dtype=torch.float64) + if t.dim() == 0: + t = t.unsqueeze(0) + t = t.expand(batch_size) + + eps = 1e-3 + sigma_max = -torch.log(torch.tensor(eps, device=t.device, dtype=torch.float64)) + move_chance = self._compute_move_chance(t) + sigma = torch.min(-torch.log(1.0 - move_chance), sigma_max) + return sigma.float() + + # ------------------------------------------------------------------ + # Stopping criteria + # ------------------------------------------------------------------ + + @staticmethod + def check_eos_finished( + sequences: torch.LongTensor, + prompt_length: int, + eos_token_id: int, + finished: torch.BoolTensor, + ) -> torch.BoolTensor: + """ + Update per-batch finished flags when EOS tokens appear in the generated portion. + + Args: + sequences (`torch.LongTensor` of shape `(batch_size, seq_len)`): + Full accumulated sequence including prompt. + prompt_length (`int`): + Number of prompt tokens at the start. + eos_token_id (`int`): + EOS token ID. + finished (`torch.BoolTensor` of shape `(batch_size,)`): + Current per-batch finished flags. + + Returns: + `torch.BoolTensor`: Updated finished flags. + """ + batch_size = sequences.shape[0] + for b in range(batch_size): + if finished[b]: + continue + generated = sequences[b, prompt_length:] + if (generated == eos_token_id).any(): + finished[b] = True + return finished + + @staticmethod + def check_should_stop(sequences: torch.LongTensor, mask_token_id: int) -> bool: + """ + Check whether all mask tokens have been resolved. + + Args: + sequences (`torch.LongTensor`): + Current token sequences. + mask_token_id (`int`): + Token ID used for masked positions. + + Returns: + `bool`: `True` if no mask tokens remain in *sequences*. + """ + return (sequences == mask_token_id).sum().item() == 0 + + +__all__ = ["BD3LMTokenDiffusionScheduler", "BD3LMTokenDiffusionSchedulerOutput"] diff --git a/src/diffusers/schedulers/scheduling_dflash_token_diffusion.py b/src/diffusers/schedulers/scheduling_dflash_token_diffusion.py new file mode 100644 index 000000000000..1190cc5fe123 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_dflash_token_diffusion.py @@ -0,0 +1,187 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass + +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from .scheduling_utils import SchedulerMixin + + +@dataclass +class DFlashTokenDiffusionSchedulerOutput(BaseOutput): + """ + Output class for DFlash-style speculative token scheduling. + + Args: + prev_sample (`torch.LongTensor` of shape `(batch_size, block_size)`): + The proposed block tokens from the draft model. + accepted_length (`torch.LongTensor` of shape `(batch_size,)`): + Number of consecutive accepted tokens from the block. + next_token (`torch.LongTensor` of shape `(batch_size,)`): + Next token sampled from the target posterior at the first rejection. + posterior (`torch.LongTensor` of shape `(batch_size, block_size)`): + Sampled tokens from the target posterior used for acceptance checks. + """ + + prev_sample: torch.LongTensor + accepted_length: torch.LongTensor + next_token: torch.LongTensor + posterior: torch.LongTensor + + +class DFlashTokenDiffusionScheduler(SchedulerMixin, ConfigMixin): + """ + Scheduler for DFlash-style block diffusion speculative decoding. + + This scheduler samples target posteriors and computes acceptance lengths for draft blocks. + """ + + order = 1 + + @register_to_config + def __init__(self): + self.num_inference_steps = 1 + self.timesteps = torch.tensor([0], dtype=torch.long) + + def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None: + if num_inference_steps <= 0: + raise ValueError(f"`num_inference_steps` must be > 0, got {num_inference_steps}.") + self.num_inference_steps = int(num_inference_steps) + self.timesteps = torch.arange(self.num_inference_steps - 1, -1, -1, device=device, dtype=torch.long) + + def sample(self, logits: torch.Tensor, temperature: float = 0.0) -> torch.LongTensor: + if temperature < 1e-5: + return torch.argmax(logits, dim=-1) + bsz, seq_len, vocab_size = logits.shape + flat = logits.view(-1, vocab_size) / float(temperature) + probs = torch.softmax(flat, dim=-1) + return torch.multinomial(probs, num_samples=1).view(bsz, seq_len) + + def step( + self, + draft_tokens: torch.LongTensor, + target_logits: torch.Tensor, + *, + temperature: float = 0.0, + return_dict: bool = True, + ) -> ( + DFlashTokenDiffusionSchedulerOutput + | tuple[torch.LongTensor, torch.LongTensor, torch.LongTensor, torch.LongTensor] + ): + posterior = self.sample(target_logits, temperature=temperature) + if draft_tokens.shape[1] > 1: + matches = draft_tokens[:, 1:] == posterior[:, :-1] + accepted_length = matches.int().cumprod(dim=1).sum(dim=1) + else: + accepted_length = torch.zeros((draft_tokens.shape[0],), device=draft_tokens.device, dtype=torch.long) + + next_token = posterior.gather(1, accepted_length.unsqueeze(1)).squeeze(1) + + if not return_dict: + return draft_tokens, accepted_length, next_token, posterior + return DFlashTokenDiffusionSchedulerOutput( + prev_sample=draft_tokens, + accepted_length=accepted_length, + next_token=next_token, + posterior=posterior, + ) + + @staticmethod + def check_should_stop( + output_ids: torch.LongTensor, + stop_token_ids: list[int] | None, + num_input_tokens: int, + ) -> bool: + """ + Check whether any stop token has been generated in the output sequence. + + Args: + output_ids (`torch.LongTensor` of shape `(batch_size, seq_len)`): + Current output token IDs including prompt and generated tokens. + stop_token_ids (`list[int]` or `None`): + Token IDs that signal generation should stop. + num_input_tokens (`int`): + Number of prompt tokens at the start of the sequence. + + Returns: + `bool`: `True` if generation should stop, `False` otherwise. + """ + if stop_token_ids is None: + return False + stop_tensor = torch.tensor(stop_token_ids, device=output_ids.device, dtype=torch.long) + return torch.isin(output_ids[:, num_input_tokens:], stop_tensor).any().item() + + def add_noise( + self, + original_samples: torch.LongTensor, + attention_mask: torch.LongTensor, + *, + prompt_length: int, + block_size: int, + mask_token_id: int, + generator: torch.Generator | None = None, + ) -> tuple[torch.LongTensor, torch.BoolTensor]: + """ + Apply the forward (noising) process for DFlash-style block diffusion training. + + For each block after the prompt, a random fraction of valid (non-padding) tokens are replaced with + `mask_token_id`. + + Args: + original_samples (`torch.LongTensor` of shape `(batch_size, seq_len)`): + Clean token IDs. + attention_mask (`torch.LongTensor` of shape `(batch_size, seq_len)`): + Padding mask (1 for valid, 0 for padding). + prompt_length (`int`): + Number of leading prompt tokens to keep unmasked. + block_size (`int`): + Block size for masking. + mask_token_id (`int`): + Token ID to use for masked positions. + generator (`torch.Generator`, *optional*): + RNG for reproducibility. + + Returns: + `tuple[torch.LongTensor, torch.BoolTensor]`: + `(noisy, masked)` -- the noisy sequence and the boolean mask indicating which positions were masked. + """ + batch_size, seq_len = original_samples.shape + device = original_samples.device + + noisy = original_samples.clone() + masked = torch.zeros_like(original_samples, dtype=torch.bool) + + valid = attention_mask.to(dtype=torch.bool) + for block_start in range(prompt_length, seq_len, block_size): + block_end = min(seq_len, block_start + block_size) + seg_len = block_end - block_start + if seg_len <= 0: + continue + + p_mask = torch.rand((batch_size, 1), device=device, generator=generator) + seg = torch.rand((batch_size, seg_len), device=device, generator=generator) < p_mask + seg = seg & valid[:, block_start:block_end] + + masked[:, block_start:block_end] = seg + + noisy = torch.where(masked, torch.full_like(noisy, mask_token_id), noisy) + return noisy, masked + + +__all__ = ["DFlashTokenDiffusionScheduler", "DFlashTokenDiffusionSchedulerOutput"] diff --git a/src/diffusers/schedulers/scheduling_hybrid_token_diffusion.py b/src/diffusers/schedulers/scheduling_hybrid_token_diffusion.py new file mode 100644 index 000000000000..b738895e3221 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_hybrid_token_diffusion.py @@ -0,0 +1,267 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import math +from dataclasses import dataclass + +import numpy as np +import torch +import torch.nn.functional as F + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from .scheduling_token_diffusion import _gumbel_argmax +from .scheduling_utils import SchedulerMixin + + +@dataclass +class HybridTokenDiffusionSchedulerOutput(BaseOutput): + prev_sample: torch.LongTensor + + +class HybridTokenDiffusionScheduler(SchedulerMixin, ConfigMixin): + """ + Hybrid-transition discrete token diffusion scheduler. + + This scheduler defines a forward transition kernel that mixes: + - keeping the current token (scaled by alpha(t)) + - moving toward a mixture distribution over tokens (beta_pi(t)) + + The scheduler exposes: + - `add_noise(...)` for forward corruption + - `step(...)` for reverse updates using the model's predicted token distribution + """ + + order = 1 + + @register_to_config + def __init__( + self, + vocab_size: int, + mask_token_id: int, + num_train_timesteps: int = 1000, + t_eps: float = 1e-4, + p_uniform: float = 0.0, + clip_noise: float = 20.0, + gamma: float = 1.0, + ): + if vocab_size <= 0: + raise ValueError(f"`vocab_size` must be > 0, got {vocab_size}.") + if not (0 <= mask_token_id < vocab_size): + raise ValueError(f"`mask_token_id` must be in [0, vocab_size), got {mask_token_id}.") + if num_train_timesteps <= 1: + raise ValueError(f"`num_train_timesteps` must be > 1, got {num_train_timesteps}.") + if not (0.0 < t_eps < 0.5): + raise ValueError(f"`t_eps` must be in (0, 0.5), got {t_eps}.") + if gamma <= 0: + raise ValueError(f"`gamma` must be > 0, got {gamma}.") + + self.vocab_size = int(vocab_size) + self.mask_token_id = int(mask_token_id) + self.num_train_timesteps = int(num_train_timesteps) + self.t_eps = float(t_eps) + + p_uniform = max(math.exp(-float(clip_noise)), float(p_uniform)) + log_B = float(gamma) * math.log(2.0) + math.log(p_uniform) - math.log(1.0 - p_uniform) + log_B = float(np.clip(log_B, -float(clip_noise), float(clip_noise))) + self.log_B = float(log_B) + self.log_gamma = float(math.log(float(gamma))) + + self.num_inference_steps = None + self.timesteps = None + self._timesteps_with_end = None + + mask = torch.zeros(self.vocab_size, dtype=torch.float32) + mask[self.mask_token_id] = 1.0 + self.mask = mask + + unif = (1.0 - mask) / max(self.vocab_size - 1, 1) + self.unif = unif + + def sample_prior( + self, + shape: torch.Size, + device: torch.device, + generator: torch.Generator | None = None, + ) -> torch.LongTensor: + """ + Sample from the prior distribution at t=1. + + At t=1, the stationary distribution concentrates on the mask token, so this returns a tensor filled with + `mask_token_id`. + + Args: + shape (`torch.Size`): + Desired output shape, e.g. `(batch_size, seq_len)`. + device (`torch.device`): + Device for the output tensor. + generator (`torch.Generator`, *optional*): + Optional generator for determinism (unused for the absorbing prior). + + Returns: + `torch.LongTensor` of shape `shape` with `mask_token_id`. + """ + return torch.full(shape, self.mask_token_id, device=device, dtype=torch.long) + + def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None: + if num_inference_steps <= 0: + raise ValueError(f"`num_inference_steps` must be > 0, got {num_inference_steps}.") + self.num_inference_steps = int(num_inference_steps) + + t0 = 1.0 - float(self.t_eps) + t1 = float(self.t_eps) + timesteps = torch.linspace(t0, t1, self.num_inference_steps + 1, dtype=torch.float32, device=device) + self._timesteps_with_end = timesteps + self.timesteps = timesteps[:-1] + + def scale_model_input(self, sample: torch.Tensor, timestep: int | torch.Tensor | None = None) -> torch.Tensor: + return sample + + def _to_continuous_t(self, timesteps: torch.Tensor, device: torch.device) -> torch.Tensor: + if timesteps.dtype in (torch.float16, torch.float32, torch.float64, torch.bfloat16): + t = timesteps.to(device=device, dtype=torch.float32) + return t.clamp(float(self.t_eps), 1.0 - float(self.t_eps)) + + if timesteps.dtype not in (torch.int32, torch.int64): + raise ValueError(f"`timesteps` must be float or int, got dtype={timesteps.dtype}.") + + t = timesteps.to(device=device, dtype=torch.float32) / float(self.num_train_timesteps - 1) + t = (1.0 - 2.0 * float(self.t_eps)) * t + float(self.t_eps) + return t.clamp(float(self.t_eps), 1.0 - float(self.t_eps)) + + def _get_alpha_betapi(self, t: torch.Tensor, eps: float = 1e-6) -> tuple[torch.Tensor, torch.Tensor]: + t = t.view(-1, 1) + t1m = 1.0 - t + + gamma = float(math.exp(self.log_gamma)) + B = float(math.exp(self.log_B)) + c_t = (t.pow(gamma / 2.0) * t1m.pow(gamma / 2.0) * B).to(dtype=torch.float32) + C_t = (1.0 + c_t).clamp_min(eps) + + alpha_t = t1m / C_t + beta_pi = ( + t * self.mask.to(device=t.device, dtype=torch.float32) + + c_t * self.unif.to(device=t.device, dtype=torch.float32) + ) / C_t + return alpha_t, beta_pi + + def _probs_at_t(self, probs: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + alpha_t, beta_pi = self._get_alpha_betapi(t) + alpha_t = alpha_t.to(dtype=probs.dtype) + beta_pi = beta_pi.to(dtype=probs.dtype) + + out = probs.mul(alpha_t.unsqueeze(1)) + out[..., : beta_pi.shape[-1]].add_(beta_pi.unsqueeze(1)) + return out + + def _sample_categorical(self, probs: torch.Tensor, generator: torch.Generator | None) -> torch.LongTensor: + bsz, seqlen, vocab = probs.shape + flat = probs.view(-1, vocab).clamp_min(torch.finfo(probs.dtype).tiny) + flat = flat / flat.sum(dim=-1, keepdim=True).clamp_min(torch.finfo(probs.dtype).eps) + sample = torch.multinomial(flat, num_samples=1, generator=generator).view(bsz, seqlen) + return sample.to(dtype=torch.long) + + def add_noise( + self, + original_samples: torch.LongTensor, + noise: torch.Tensor | None, + timesteps: torch.Tensor, + ) -> torch.LongTensor: + del noise + if original_samples.dtype != torch.long: + raise ValueError(f"`original_samples` must be int64 token IDs, got dtype={original_samples.dtype}.") + + device = original_samples.device + t = self._to_continuous_t(timesteps.to(device=device), device=device) + onehot = F.one_hot(original_samples, num_classes=self.vocab_size).to(dtype=torch.float32) + probs = self._probs_at_t(onehot, t) + return self._sample_categorical(probs, generator=None) + + def _index_for_timestep(self, timestep: float | torch.Tensor) -> int: + if self.timesteps is None: + raise ValueError("Call `set_timesteps(...)` before calling `step()`.") + + if isinstance(timestep, torch.Tensor): + t = float(timestep.detach().cpu().item()) + else: + t = float(timestep) + + idx = int(torch.argmin(torch.abs(self.timesteps.detach().cpu() - torch.tensor(t))).item()) + return idx + + def step( + self, + model_output: torch.Tensor, + timestep: float | torch.Tensor, + sample: torch.LongTensor, + generator: torch.Generator | None = None, + return_dict: bool = True, + ) -> HybridTokenDiffusionSchedulerOutput | tuple[torch.LongTensor]: + if sample.dtype != torch.long: + raise ValueError(f"`sample` must be int64 token IDs, got dtype={sample.dtype}.") + if model_output.ndim != 3 or model_output.shape[-1] != self.vocab_size: + raise ValueError( + f"`model_output` must have shape [batch, seq_len, vocab_size={self.vocab_size}], got {tuple(model_output.shape)}." + ) + if model_output.shape[0] != sample.shape[0] or model_output.shape[1] != sample.shape[1]: + raise ValueError( + f"`model_output` batch/seq dims {tuple(model_output.shape[:2])} must match `sample` {tuple(sample.shape)}." + ) + + if self._timesteps_with_end is None: + raise ValueError("Call `set_timesteps(...)` before calling `step()`.") + + device = sample.device + batch_size, seq_len = sample.shape + + step_index = self._index_for_timestep(timestep) + t_val = self._timesteps_with_end[step_index].to(device=device) + s_val = self._timesteps_with_end[step_index + 1].to(device=device) + + t = t_val * torch.ones(batch_size, device=device, dtype=torch.float32) + s = s_val * torch.ones(batch_size, device=device, dtype=torch.float32) + + logits = model_output.to(dtype=torch.float32) + logits = logits.clone() + logits[..., self.mask_token_id] = torch.finfo(logits.dtype).min + probs = logits.softmax(dim=-1) + + q_s = self._probs_at_t(probs, s) + q_t = self._probs_at_t(probs, t) + q_zt = q_t.gather(-1, sample.unsqueeze(-1)).clamp_min(torch.finfo(torch.float32).eps) + + alpha_t, beta_pi_t = self._get_alpha_betapi(t) + alpha_s, beta_pi_s = self._get_alpha_betapi(s) + + alpha_ts = (alpha_t / alpha_s).clamp_min(torch.finfo(torch.float32).eps) + beta_pi_ts = beta_pi_t - (alpha_t / alpha_s) * beta_pi_s + + vz_t = F.one_hot(sample, num_classes=self.vocab_size).to(dtype=torch.float32) + beta_pi_ts_at_zt = beta_pi_ts.unsqueeze(1).expand_as(vz_t).gather(-1, sample.unsqueeze(-1)) + q_ts = alpha_ts.view(batch_size, 1, 1) * vz_t + beta_pi_ts_at_zt + + q_st = q_ts * q_s / q_zt + q_st = q_st.clamp_min(torch.finfo(torch.float32).tiny) + q_st = q_st / q_st.sum(dim=-1, keepdim=True).clamp_min(torch.finfo(torch.float32).eps) + + x_prev = _gumbel_argmax(torch.log(q_st), generator=generator).to(dtype=torch.long) + + if not return_dict: + return (x_prev,) + return HybridTokenDiffusionSchedulerOutput(prev_sample=x_prev) + + +__all__ = ["HybridTokenDiffusionScheduler", "HybridTokenDiffusionSchedulerOutput"] diff --git a/src/diffusers/schedulers/scheduling_sdar_token_diffusion.py b/src/diffusers/schedulers/scheduling_sdar_token_diffusion.py new file mode 100644 index 000000000000..e52eb95e04c3 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_sdar_token_diffusion.py @@ -0,0 +1,323 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass + +import torch +import torch.nn.functional as F + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from .scheduling_utils import SchedulerMixin + + +@dataclass +class SDARTokenDiffusionSchedulerOutput(BaseOutput): + """ + Output class for SDAR-style block diffusion scheduling. + + Args: + prev_sample (`torch.LongTensor` of shape `(batch_size, block_length)`): + Updated block tokens after the current denoising step. + transfer_index (`torch.BoolTensor` of shape `(batch_size, block_length)`): + Boolean mask indicating which tokens were updated. + sampled_tokens (`torch.LongTensor` of shape `(batch_size, block_length)`): + Sampled token IDs from the model logits. + sampled_probs (`torch.Tensor` of shape `(batch_size, block_length)`): + Probabilities of the sampled tokens. + """ + + prev_sample: torch.LongTensor + transfer_index: torch.BoolTensor + sampled_tokens: torch.LongTensor + sampled_probs: torch.Tensor + + +class SDARTokenDiffusionScheduler(SchedulerMixin, ConfigMixin): + """ + Scheduler for SDAR-style block diffusion decoding. + """ + + order = 1 + + @register_to_config + def __init__( + self, + block_length: int = 4, + num_inference_steps: int = 4, + remasking_strategy: str = "low_confidence_dynamic", + confidence_threshold: float = 0.9, + entropy_threshold: float = 0.35, + temperature: float = 1.0, + top_k: int = 0, + top_p: float = 1.0, + ): + self.num_inference_steps = num_inference_steps + self.timesteps = torch.arange(self.num_inference_steps - 1, -1, -1, dtype=torch.long) + + def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None: + if num_inference_steps <= 0: + raise ValueError(f"`num_inference_steps` must be > 0, got {num_inference_steps}.") + self.num_inference_steps = int(num_inference_steps) + self.timesteps = torch.arange(self.num_inference_steps - 1, -1, -1, device=device, dtype=torch.long) + + def get_num_transfer_tokens(self, block_length: int, num_inference_steps: int) -> torch.LongTensor: + if num_inference_steps <= 0: + raise ValueError(f"`num_inference_steps` must be > 0, got {num_inference_steps}.") + base = int(block_length) // int(num_inference_steps) + remainder = int(block_length) % int(num_inference_steps) + num_transfer_tokens = torch.zeros(int(num_inference_steps), dtype=torch.long) + num_transfer_tokens += base + if remainder > 0: + num_transfer_tokens[:remainder] += 1 + return num_transfer_tokens + + def _top_k_logits(self, logits: torch.Tensor, k: int) -> torch.Tensor: + if k <= 0: + return logits + values, _ = torch.topk(logits, k) + min_values = values[..., -1, None] + return torch.where(logits < min_values, torch.full_like(logits, float("-inf")), logits) + + def _top_p_logits(self, logits: torch.Tensor, p: float) -> torch.Tensor: + if p >= 1.0: + return logits + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + sorted_mask = cumulative_probs > p + sorted_mask[..., 1:] = sorted_mask[..., :-1].clone() + sorted_mask[..., 0] = False + mask_indices = torch.scatter(torch.zeros_like(logits, dtype=torch.bool), -1, sorted_indices, sorted_mask) + return logits.masked_fill(mask_indices, float("-inf")) + + def sample( + self, + logits: torch.Tensor, + *, + temperature: float | None = None, + top_k: int | None = None, + top_p: float | None = None, + generator: torch.Generator | None = None, + ) -> tuple[torch.LongTensor, torch.Tensor]: + if temperature is None: + temperature = float(self.config.temperature) + if top_k is None: + top_k = int(self.config.top_k) + if top_p is None: + top_p = float(self.config.top_p) + + orig_shape = logits.shape[:-1] + vocab_size = logits.shape[-1] + flat = logits.view(-1, vocab_size) + + if temperature < 1e-5: + probs = F.softmax(flat, dim=-1) + tokens = torch.argmax(flat, dim=-1, keepdim=True) + token_probs = torch.gather(probs, -1, tokens) + return tokens.view(*orig_shape), token_probs.view(*orig_shape) + + flat = flat / float(temperature) + flat = self._top_k_logits(flat, int(top_k)) + flat = self._top_p_logits(flat, float(top_p)) + probs = F.softmax(flat, dim=-1) + tokens = torch.multinomial(probs, num_samples=1, generator=generator) + token_probs = torch.gather(probs, -1, tokens) + return tokens.view(*orig_shape), token_probs.view(*orig_shape) + + def check_should_stop( + self, + sequences: torch.LongTensor, + prompt_length: int, + stop_token_ids: list[int] | None = None, + ) -> bool: + """ + Check whether generation should stop based on stop token IDs. + + Args: + sequences (`torch.LongTensor` of shape `(batch_size, seq_len)`): + Current full sequence including prompt. + prompt_length (`int`): + Number of prompt tokens at the start of the sequence. + stop_token_ids (`list[int]`, *optional*): + Token IDs that signal generation should stop. + + Returns: + `bool`: `True` if any stop token is found in the generated portion. + """ + if stop_token_ids is None or len(stop_token_ids) == 0: + return False + stop_tensor = torch.tensor(stop_token_ids, device=sequences.device, dtype=torch.long) + return torch.isin(sequences[:, prompt_length:], stop_tensor).any().item() + + def add_noise( + self, + original_samples: torch.LongTensor, + attention_mask: torch.LongTensor, + *, + prompt_length: int, + block_length: int, + mask_token_id: int, + generator: torch.Generator | None = None, + ) -> tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor, torch.BoolTensor]: + """ + Apply the forward (noising) process for semi-autoregressive block masking. + + For each block after the prompt, a random fraction of valid (non-padding) tokens are replaced with + `mask_token_id`. Two complementary views are returned: `noisy` and `noisy_rev`, where the masked positions in + one are the unmasked positions in the other. + + Args: + original_samples (`torch.LongTensor` of shape `(batch_size, seq_len)`): + Clean token IDs. + attention_mask (`torch.LongTensor` of shape `(batch_size, seq_len)`): + Padding mask (1 for valid, 0 for padding). + prompt_length (`int`): + Number of leading prompt tokens to keep unmasked. + block_length (`int`): + Block size for masking. + mask_token_id (`int`): + Token ID to use for masked positions. + generator (`torch.Generator`, *optional*): + RNG for reproducibility. + + Returns: + `tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor, torch.BoolTensor]`: + `(noisy, noisy_rev, masked, masked_rev)` β€” the two complementary noisy sequences and their + corresponding boolean masks. + """ + batch_size, seq_len = original_samples.shape + device = original_samples.device + + noisy = original_samples.clone() + noisy_rev = original_samples.clone() + masked = torch.zeros_like(original_samples, dtype=torch.bool) + masked_rev = torch.zeros_like(original_samples, dtype=torch.bool) + + valid = attention_mask.to(dtype=torch.bool) + for block_start in range(prompt_length, seq_len, block_length): + block_end = min(seq_len, block_start + block_length) + seg_len = block_end - block_start + if seg_len <= 0: + continue + + p_mask = torch.rand((batch_size, 1), device=device, generator=generator) + seg = torch.rand((batch_size, seg_len), device=device, generator=generator) < p_mask + seg = seg & valid[:, block_start:block_end] + seg_rev = (~seg) & valid[:, block_start:block_end] + + masked[:, block_start:block_end] = seg + masked_rev[:, block_start:block_end] = seg_rev + + noisy = torch.where(masked, torch.full_like(noisy, mask_token_id), noisy) + noisy_rev = torch.where(masked_rev, torch.full_like(noisy_rev, mask_token_id), noisy_rev) + return noisy, noisy_rev, masked, masked_rev + + def step( + self, + model_output: torch.Tensor, + timestep: int | torch.Tensor, + sample: torch.LongTensor, + *, + mask_token_id: int, + num_transfer_tokens: torch.LongTensor, + remasking_strategy: str | None = None, + confidence_threshold: float | None = None, + entropy_threshold: float | None = None, + temperature: float | None = None, + top_k: int | None = None, + top_p: float | None = None, + generator: torch.Generator | None = None, + return_dict: bool = True, + ) -> SDARTokenDiffusionSchedulerOutput | tuple[torch.LongTensor, torch.BoolTensor, torch.LongTensor, torch.Tensor]: + if remasking_strategy is None: + remasking_strategy = str(self.config.remasking_strategy) + if confidence_threshold is None: + confidence_threshold = float(self.config.confidence_threshold) + if entropy_threshold is None: + entropy_threshold = float(self.config.entropy_threshold) + + sampled_tokens, sampled_probs = self.sample( + model_output, temperature=temperature, top_k=top_k, top_p=top_p, generator=generator + ) + mask_index = sample == int(mask_token_id) + transfer_index = torch.zeros_like(mask_index) + + if isinstance(timestep, torch.Tensor): + step_index = int(timestep.item()) + else: + step_index = int(timestep) + + if step_index >= int(num_transfer_tokens.numel()): + step_index = int(num_transfer_tokens.numel()) - 1 + step_transfer = int(num_transfer_tokens[step_index].item()) + + if remasking_strategy == "sequential": + for j in range(sample.shape[0]): + if not mask_index[j].any(): + continue + num_masked = int(mask_index[j].sum().item()) + k = min(step_transfer, num_masked) + first_mask_index = mask_index[j].nonzero(as_tuple=True)[0].min().item() + transfer_index[j, first_mask_index : first_mask_index + k] = True + + elif remasking_strategy in {"low_confidence_static", "low_confidence_dynamic"}: + confidence = torch.where(mask_index, sampled_probs, torch.full_like(sampled_probs, float("-inf"))) + for j in range(confidence.shape[0]): + if not mask_index[j].any(): + continue + num_masked = int(mask_index[j].sum().item()) + k = min(step_transfer, num_masked) + if remasking_strategy == "low_confidence_dynamic": + high_conf_mask = confidence[j] > confidence_threshold + if int(high_conf_mask.sum().item()) >= k: + transfer_index[j] = high_conf_mask + continue + _, idx = torch.topk(confidence[j], k) + transfer_index[j, idx] = True + + elif remasking_strategy == "entropy_bounded": + eps = 1e-12 + entropies = -(sampled_probs.clamp_min(eps) * sampled_probs.clamp_min(eps).log()).sum(dim=-1) + entropies = torch.where(mask_index, entropies, torch.full_like(sampled_probs, float("inf"))) + ent_sorted, order = torch.sort(entropies, dim=1, descending=False) + cumsum = torch.cumsum(ent_sorted, dim=1) + for j in range(sampled_probs.shape[0]): + if not mask_index[j].any(): + continue + threshold_tensor = torch.tensor(entropy_threshold, device=sampled_probs.device) + k = int(torch.searchsorted(cumsum[j], threshold_tensor, right=False).item()) + num_masked = int(mask_index[j].sum().item()) + k = max(1, min(k, num_masked)) + selected_token_indices = order[j, :k] + transfer_index[j, selected_token_indices] = True + + else: + raise ValueError(f"Unknown remasking strategy: {remasking_strategy}") + + prev_sample = sample.clone() + prev_sample[transfer_index] = sampled_tokens[transfer_index] + + if not return_dict: + return prev_sample, transfer_index, sampled_tokens, sampled_probs + return SDARTokenDiffusionSchedulerOutput( + prev_sample=prev_sample, + transfer_index=transfer_index, + sampled_tokens=sampled_tokens, + sampled_probs=sampled_probs, + ) + + +__all__ = ["SDARTokenDiffusionScheduler", "SDARTokenDiffusionSchedulerOutput"] diff --git a/src/diffusers/schedulers/scheduling_token_diffusion.py b/src/diffusers/schedulers/scheduling_token_diffusion.py new file mode 100644 index 000000000000..b00a1a743b97 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_token_diffusion.py @@ -0,0 +1,567 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass + +import numpy as np +import torch +import torch.nn.functional as F + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from .scheduling_utils import SchedulerMixin + + +@dataclass +class TokenDiffusionSchedulerOutput(BaseOutput): + """ + Output class for discrete token schedulers. + + Args: + prev_sample (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Sample at the previous timestep. This should be fed into the model at the next denoising iteration. + """ + + prev_sample: torch.LongTensor + + +def _gumbel_argmax(logits: torch.Tensor, generator: torch.Generator | None = None) -> torch.LongTensor: + """ + Sample from a categorical distribution defined by (unnormalized) logits via Gumbel-max. + + Args: + logits: Tensor of shape `(..., vocab_size)`. + generator: Optional torch generator for determinism. + + Returns: + `torch.LongTensor` of shape `logits.shape[:-1]` with sampled indices. + """ + # Gumbel(0,1) noise: -log(-log(U)) + uniform = torch.rand(logits.shape, device=logits.device, dtype=logits.dtype, generator=generator).clamp_(1e-30, 1) + gumbel = -torch.log(-torch.log(uniform)) + return (logits + gumbel).argmax(dim=-1) + + +class TokenDiffusionScheduler(SchedulerMixin, ConfigMixin): + """ + Discrete diffusion scheduler over token IDs (categorical states). + + This scheduler is designed for *token-space* diffusion (e.g. masked/absorbing diffusion language models) and + follows the diffusers scheduler API where possible: `set_timesteps()` for inference and `step()` for reverse + updates. + + Currently implemented: + - Forward process: + - `absorbing`: with probability `1 - alpha(t)` replace token with `mask_token_id`. + - `uniform`: with probability `1 - alpha(t)` replace token with a uniform random token. + - Noise schedule: selectable `alpha(t)` families with `t in [0, 1]`. + + Notes: + - `step()` expects the model to return logits over vocabulary for `x0` reconstruction. + - The mask token is treated as an *absorbing state* and is never sampled as an `x0` prediction. + """ + + order = 1 + + @register_to_config + def __init__( + self, + vocab_size: int, + mask_token_id: int, + num_train_timesteps: int = 1000, + alpha_schedule: str = "log_linear", + eps: float = 1e-3, + sigma_min: float = 1e-4, + sigma_max: float = 20.0, + forward_process: str = "absorbing", + exclude_mask_from_uniform: bool = True, + ): + if vocab_size <= 0: + raise ValueError(f"`vocab_size` must be > 0, got {vocab_size}.") + if num_train_timesteps <= 1: + raise ValueError(f"`num_train_timesteps` must be > 1, got {num_train_timesteps}.") + if not (0.0 < eps < 1.0): + raise ValueError(f"`eps` must be in (0, 1), got {eps}.") + if not (0 <= mask_token_id < vocab_size): + raise ValueError(f"`mask_token_id` must be in [0, vocab_size), got {mask_token_id}.") + alpha_schedule = str(alpha_schedule).lower() + if alpha_schedule not in {"log_linear", "linear", "cosine", "geometric"}: + raise ValueError( + "`alpha_schedule` must be one of {'log_linear','linear','cosine','geometric'}, got" + f" {alpha_schedule!r}." + ) + if sigma_min <= 0 or sigma_max <= 0: + raise ValueError( + f"`sigma_min` and `sigma_max` must be > 0, got sigma_min={sigma_min}, sigma_max={sigma_max}." + ) + if sigma_max <= sigma_min: + raise ValueError(f"`sigma_max` must be > `sigma_min`, got sigma_min={sigma_min}, sigma_max={sigma_max}.") + if forward_process not in {"absorbing", "uniform"}: + raise ValueError(f"`forward_process` must be one of {{'absorbing','uniform'}}, got {forward_process!r}.") + + self.vocab_size = int(vocab_size) + self.mask_token_id = int(mask_token_id) + self.num_train_timesteps = int(num_train_timesteps) + self.alpha_schedule = alpha_schedule + self.eps = float(eps) + self.sigma_min = float(sigma_min) + self.sigma_max = float(sigma_max) + self.forward_process = str(forward_process) + self.exclude_mask_from_uniform = bool(exclude_mask_from_uniform) + + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy()) + self.alphas = None + self._step_index_map = None + + def _effective_vocab_size(self) -> int: + if self.forward_process == "uniform" and self.exclude_mask_from_uniform: + return self.vocab_size - 1 + return self.vocab_size + + def _sample_uniform_tokens( + self, shape: torch.Size, device: torch.device, dtype: torch.dtype, generator: torch.Generator | None = None + ) -> torch.LongTensor: + """ + Sample uniform token IDs, optionally excluding `mask_token_id` (by shifting indices around it). + """ + if self.forward_process != "uniform": + raise ValueError("Uniform token sampling is only valid for `forward_process='uniform'`.") + + if not self.exclude_mask_from_uniform: + return torch.randint(0, self.vocab_size, shape, device=device, dtype=dtype, generator=generator) + + # Sample in [0, vocab_size-1) and shift around mask_token_id. + v_eff = self.vocab_size - 1 + draw = torch.randint(0, v_eff, shape, device=device, dtype=dtype, generator=generator) + return torch.where(draw >= self.mask_token_id, draw + 1, draw) + + def sample_prior( + self, + shape: torch.Size, + device: torch.device, + generator: torch.Generator | None = None, + ) -> torch.LongTensor: + """ + Sample from the prior distribution of the forward process at t=1. + + For `forward_process="absorbing"`, returns a tensor filled with `mask_token_id`. For + `forward_process="uniform"`, returns uniform random token IDs (optionally excluding `mask_token_id`). + + Args: + shape (`torch.Size`): + Desired output shape, e.g. `(batch_size, seq_len)`. + device (`torch.device`): + Device for the output tensor. + generator (`torch.Generator`, *optional*): + Optional generator for determinism (only used for the uniform process). + + Returns: + `torch.LongTensor` of shape `shape` with sampled prior token IDs. + """ + if self.forward_process == "uniform": + return self._sample_uniform_tokens(shape, device=device, dtype=torch.long, generator=generator) + return torch.full(shape, self.mask_token_id, device=device, dtype=torch.long) + + def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None: + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Timesteps are stored in descending order, so `timesteps[0]` is the noisiest step. Alpha values are pre-computed + for each timestep so that `step()` can look them up by index instead of recomputing them on every call. + """ + if num_inference_steps <= 0: + raise ValueError(f"`num_inference_steps` must be > 0, got {num_inference_steps}.") + self.num_inference_steps = int(num_inference_steps) + + # Standard diffusers behavior: map inference steps onto training step indices. + timesteps = torch.linspace( + self.num_train_timesteps - 1, 0, self.num_inference_steps, dtype=torch.float32 + ).round() + self.timesteps = timesteps.to(dtype=torch.long, device=device) + + # Pre-compute alpha(t) for every inference timestep. + t_continuous = timesteps / float(self.num_train_timesteps - 1) + t_continuous = t_continuous.clamp_(0.0, 1.0) + self.alphas = self._alpha_t(t_continuous).to(dtype=torch.float32, device=device) + + # Build a map from timestep value β†’ index for O(1) lookup in step(). + self._step_index_map = {int(self.timesteps[i].item()): i for i in range(len(self.timesteps))} + + def scale_model_input(self, sample: torch.Tensor, timestep: int | None = None) -> torch.Tensor: + return sample + + def _t_from_timestep(self, timestep: int | torch.Tensor, device: torch.device) -> torch.Tensor: + """ + Convert an integer training timestep index into continuous time `t in [0, 1]`. + """ + if isinstance(timestep, torch.Tensor): + t_idx = timestep.to(device=device, dtype=torch.float32) + else: + t_idx = torch.tensor(float(timestep), device=device, dtype=torch.float32) + denom = float(self.num_train_timesteps - 1) + return (t_idx / denom).clamp_(0.0, 1.0) + + def _alpha_t(self, t: torch.Tensor) -> torch.Tensor: + """ + Compute alpha(t) for the configured schedule. + + The returned tensor is expected to be in (0, 1] and monotone decreasing in `t`. + """ + if self.alpha_schedule == "log_linear": + # alpha(t) = 1 - (1 - eps) * t + return 1.0 - (1.0 - self.eps) * t + + if self.alpha_schedule == "linear": + # alpha(t) = (1 - 2*eps) * (1 - t) + eps + return (1.0 - 2.0 * self.eps) * (1.0 - t) + self.eps + + if self.alpha_schedule == "cosine": + # alpha_base(t) = 1 - cos(pi/2 * (1 - t)) + # alpha(t) = (1 - 2*eps) * alpha_base(t) + eps + base = 1.0 - torch.cos(torch.pi / 2.0 * (1.0 - t)) + return (1.0 - 2.0 * self.eps) * base + self.eps + + if self.alpha_schedule == "geometric": + # total_noise(t) = sigma_min^(1-t) * sigma_max^t + # alpha(t) = exp(-total_noise(t)) + sigma_min = torch.as_tensor(self.sigma_min, device=t.device, dtype=t.dtype) + sigma_max = torch.as_tensor(self.sigma_max, device=t.device, dtype=t.dtype) + total_noise = (sigma_min ** (1.0 - t)) * (sigma_max**t) + return (-total_noise).exp() + + raise ValueError(f"Unsupported alpha schedule: {self.alpha_schedule!r}") + + def _alpha_prime_t(self, t: torch.Tensor) -> torch.Tensor: + """ + Compute d/dt alpha(t) for the configured schedule. + """ + if self.alpha_schedule == "log_linear": + return -(1.0 - self.eps) * torch.ones_like(t) + + if self.alpha_schedule == "linear": + return -(1.0 - 2.0 * self.eps) * torch.ones_like(t) + + if self.alpha_schedule == "cosine": + base_prime = -(torch.pi / 2.0) * torch.sin(torch.pi / 2.0 * (1.0 - t)) + return (1.0 - 2.0 * self.eps) * base_prime + + if self.alpha_schedule == "geometric": + sigma_min = torch.as_tensor(self.sigma_min, device=t.device, dtype=t.dtype) + sigma_max = torch.as_tensor(self.sigma_max, device=t.device, dtype=t.dtype) + total_noise = (sigma_min ** (1.0 - t)) * (sigma_max**t) + alpha = (-total_noise).exp() + rate = total_noise * (sigma_max.log() - sigma_min.log()) + return -alpha * rate + + raise ValueError(f"Unsupported alpha schedule: {self.alpha_schedule!r}") + + def get_mdlm_loss_weights(self, timesteps: torch.LongTensor) -> torch.Tensor: + """ + Return per-example positive loss weights for masked-token reconstruction objectives. + + The weight corresponds to `-alpha'(t) / (1 - alpha(t))`, which is positive for monotone decreasing alpha(t). + + Args: + timesteps (`torch.LongTensor` of shape `(batch_size,)`): + Training timestep indices in `[0, num_train_timesteps-1]`. + + Returns: + `torch.FloatTensor` of shape `(batch_size, 1)`: + Positive weights to multiply token-level cross-entropy by. + """ + if timesteps.dtype not in (torch.int32, torch.int64): + raise ValueError(f"`timesteps` must be an integer tensor, got dtype={timesteps.dtype}.") + device = timesteps.device + t = self._t_from_timestep(timesteps.to(device), device=device) + t = t.to(dtype=torch.float32) + alpha = self._alpha_t(t).to(dtype=torch.float32) + dalpha = self._alpha_prime_t(t).to(dtype=torch.float32) + denom = (1.0 - alpha).clamp_min(torch.finfo(torch.float32).eps) + w = (-dalpha / denom).clamp_min(torch.finfo(torch.float32).tiny) + return w.view(-1, 1) + + def get_alpha(self, timesteps: torch.LongTensor) -> torch.Tensor: + """ + Return per-example alpha(t) values for the configured schedule. + + Args: + timesteps (`torch.LongTensor` of shape `(batch_size,)`): + Training timestep indices in `[0, num_train_timesteps-1]`. + + Returns: + `torch.FloatTensor` of shape `(batch_size, 1)`: + Alpha values in `(0, 1]` for each example. + """ + if timesteps.dtype not in (torch.int32, torch.int64): + raise ValueError(f"`timesteps` must be an integer tensor, got dtype={timesteps.dtype}.") + device = timesteps.device + t = self._t_from_timestep(timesteps.to(device), device=device).to(dtype=torch.float32) + alpha = self._alpha_t(t).to(dtype=torch.float32) + return alpha.view(-1, 1) + + def get_alpha_prime(self, timesteps: torch.LongTensor) -> torch.Tensor: + """ + Return per-example time derivative alpha'(t) for the configured schedule. + + Args: + timesteps (`torch.LongTensor` of shape `(batch_size,)`): + Training timestep indices in `[0, num_train_timesteps-1]`. + + Returns: + `torch.FloatTensor` of shape `(batch_size, 1)`: + Alpha derivatives with respect to continuous time `t in [0, 1]`. + """ + if timesteps.dtype not in (torch.int32, torch.int64): + raise ValueError(f"`timesteps` must be an integer tensor, got dtype={timesteps.dtype}.") + device = timesteps.device + t = self._t_from_timestep(timesteps.to(device), device=device).to(dtype=torch.float32) + dalpha = self._alpha_prime_t(t).to(dtype=torch.float32) + return dalpha.view(-1, 1) + + def add_noise( + self, + original_samples: torch.LongTensor, + noise: torch.Tensor | None, + timesteps: torch.LongTensor, + block_mask: torch.BoolTensor | None = None, + ) -> torch.LongTensor: + """ + Apply the forward process q(x_t | x_0). + + The `noise` argument is accepted for API compatibility but is not used for the absorbing kernel. + + Args: + original_samples (`torch.LongTensor`): + Original token IDs of shape `(batch_size, seq_len)`. + noise (`torch.Tensor`, *optional*): + Accepted for API compatibility; unused. + timesteps (`torch.LongTensor`): + Per-example timestep indices of shape `(batch_size,)`. + block_mask (`torch.BoolTensor`, *optional*): + Boolean mask of shape `(batch_size, seq_len)`. When provided, only positions where `block_mask` is + `True` are noised; other positions are kept unchanged. + """ + del noise + + if original_samples.dtype != torch.long: + raise ValueError(f"`original_samples` must be int64 token IDs, got dtype={original_samples.dtype}.") + if timesteps.dtype not in (torch.int32, torch.int64): + raise ValueError(f"`timesteps` must be an integer tensor, got dtype={timesteps.dtype}.") + if block_mask is not None: + if block_mask.dtype != torch.bool: + raise ValueError(f"`block_mask` must be boolean, got dtype={block_mask.dtype}.") + if block_mask.shape != original_samples.shape: + raise ValueError( + f"`block_mask` must have shape {tuple(original_samples.shape)}, got {tuple(block_mask.shape)}." + ) + + batch_size, seq_len = original_samples.shape + device = original_samples.device + + # Convert per-example timesteps into alpha(t) in [eps, 1]. + t = self._t_from_timestep(timesteps.to(device), device=device).view(batch_size, 1) + alpha = self._alpha_t(t).to(dtype=torch.float32) + + p_replace = (1.0 - alpha).expand(batch_size, seq_len) + rand = torch.rand((batch_size, seq_len), device=device, dtype=torch.float32) + replace_positions = rand < p_replace + + if self.forward_process == "absorbing": + replacement = torch.full_like(original_samples, self.mask_token_id) + elif self.forward_process == "uniform": + replacement = self._sample_uniform_tokens( + original_samples.shape, device=device, dtype=original_samples.dtype, generator=None + ) + else: + raise ValueError(f"Unsupported forward process: {self.forward_process!r}") + + noised = torch.where(replace_positions, replacement, original_samples) + if block_mask is not None: + noised = torch.where(block_mask.to(device=device), noised, original_samples) + return noised + + def enforce_fixed_masks( + self, + sample: torch.LongTensor, + fixed_mask: torch.BoolTensor, + fixed_values: torch.LongTensor, + ) -> torch.LongTensor: + """ + Re-apply fixed token values at positions indicated by `fixed_mask`. + + This is used by the pipeline to enforce prefix / infill conditioning after each scheduler step. + + Args: + sample (`torch.LongTensor` of shape `(batch_size, seq_len)`): + Current token IDs after a scheduler step. + fixed_mask (`torch.BoolTensor` of shape `(batch_size, seq_len)`): + Boolean mask where `True` indicates a position whose value must be restored. + fixed_values (`torch.LongTensor` of shape `(batch_size, seq_len)`): + Token IDs to restore at the fixed positions. + + Returns: + `torch.LongTensor`: Token IDs with fixed positions restored. + """ + return torch.where(fixed_mask, fixed_values, sample) + + def step( + self, + model_output: torch.Tensor, + timestep: int | torch.Tensor, + sample: torch.LongTensor, + generator: torch.Generator | None = None, + return_dict: bool = True, + block_mask: torch.BoolTensor | None = None, + ) -> TokenDiffusionSchedulerOutput | tuple[torch.LongTensor]: + """ + Reverse diffusion step for the configured forward process. + + For `forward_process="absorbing"`, the update mirrors the common absorbing posterior: + - Keep all unmasked positions fixed. + - For masked positions, with probability p_denoise replace mask by a sample from p_theta(x0 | x_t, t). + + For `forward_process="uniform"`, this implements the discrete posterior used by UDLM-style uniform token + diffusion. + """ + if sample.dtype != torch.long: + raise ValueError(f"`sample` must be int64 token IDs, got dtype={sample.dtype}.") + if model_output.ndim != 3 or model_output.shape[-1] != self.vocab_size: + raise ValueError( + f"`model_output` must have shape [batch, seq_len, vocab_size={self.vocab_size}], got {tuple(model_output.shape)}." + ) + if model_output.shape[0] != sample.shape[0] or model_output.shape[1] != sample.shape[1]: + raise ValueError( + f"`model_output` batch/seq dims {tuple(model_output.shape[:2])} must match `sample` {tuple(sample.shape)}." + ) + if block_mask is not None: + if block_mask.dtype != torch.bool: + raise ValueError(f"`block_mask` must be boolean, got dtype={block_mask.dtype}.") + if block_mask.shape != sample.shape: + raise ValueError(f"`block_mask` must have shape {tuple(sample.shape)}, got {tuple(block_mask.shape)}.") + + device = sample.device + batch_size, seq_len = sample.shape + + # Figure out the previous timestep in the configured inference schedule. + if self.num_inference_steps is None: + raise ValueError("Call `set_timesteps(num_inference_steps, ...)` before calling `step()`.") + + if isinstance(timestep, torch.Tensor): + timestep_int = int(timestep.item()) + else: + timestep_int = int(timestep) + + # Look up the step index and pre-computed alpha values. + if self._step_index_map is not None and timestep_int in self._step_index_map: + step_index = self._step_index_map[timestep_int] + else: + current_indices = (self.timesteps == timestep_int).nonzero(as_tuple=False) + if current_indices.numel() == 0: + raise ValueError(f"`timestep` ({timestep_int}) must be one of `self.timesteps`.") + step_index = int(current_indices[0].item()) + + is_noise_removal_step = step_index + 1 >= len(self.timesteps) + + # Use pre-computed alphas when available, otherwise fall back to computing on the fly. + if self.alphas is not None: + alpha_t = self.alphas[step_index].to(device=device) + if is_noise_removal_step: + alpha_prev = torch.tensor(1.0, device=device, dtype=torch.float32) + else: + alpha_prev = self.alphas[step_index + 1].to(device=device) + else: + t = self._t_from_timestep(timestep_int, device=device) + alpha_t = self._alpha_t(t).to(dtype=torch.float32) + if is_noise_removal_step: + alpha_prev = torch.tensor(1.0, device=device, dtype=torch.float32) + else: + prev_timestep_int = int(self.timesteps[step_index + 1].item()) + t_prev = self._t_from_timestep(prev_timestep_int, device=device) + alpha_prev = self._alpha_t(t_prev).to(dtype=torch.float32) + + if self.forward_process == "uniform": + # Convert logits to probabilities for x0; optionally forbid mask token. + logits = model_output.to(dtype=torch.float32) + if self.exclude_mask_from_uniform: + logits = logits.clone() + logits[..., self.mask_token_id] = torch.finfo(logits.dtype).min + p_x0 = logits.softmax(dim=-1) + + V = self.vocab_size + x = sample + xt_one_hot = F.one_hot(x, V).to(dtype=p_x0.dtype) + + alpha_ts = (alpha_t / alpha_prev).clamp_min(torch.finfo(torch.float32).eps) + + if self.exclude_mask_from_uniform: + limiting = torch.full((V,), 1.0 / float(V - 1), device=device, dtype=p_x0.dtype) + limiting[self.mask_token_id] = 0.0 + else: + limiting = torch.full((V,), 1.0 / float(V), device=device, dtype=p_x0.dtype) + limiting = limiting.view(1, 1, -1) + + alpha_t3 = alpha_t.view(1, 1, 1) + alpha_s3 = alpha_prev.view(1, 1, 1) + alpha_ts3 = alpha_ts.view(1, 1, 1) + + numerator = ( + (alpha_t3 * V * p_x0 * xt_one_hot) + + ((alpha_ts3 - alpha_t3) * xt_one_hot) + + ((alpha_s3 - alpha_t3) * p_x0) + + ((1.0 - alpha_ts3) * (1.0 - alpha_s3) * limiting) + ) + denom = (alpha_t3 * V * p_x0.gather(-1, x.unsqueeze(-1)) + (1.0 - alpha_t3)).clamp_min( + torch.finfo(torch.float32).eps + ) + + q_xs = numerator / denom + q_xs = q_xs.clamp_min(torch.finfo(torch.float32).tiny) + q_xs = q_xs / q_xs.sum(dim=-1, keepdim=True).clamp_min(torch.finfo(torch.float32).eps) + + x_prev = _gumbel_argmax(torch.log(q_xs), generator=generator).to(dtype=torch.long) + + elif self.forward_process == "absorbing": + # p_denoise = (alpha_prev - alpha_t) / (1 - alpha_t) + denom = (1.0 - alpha_t).clamp_min(torch.finfo(torch.float32).eps) + p_denoise = ((alpha_prev - alpha_t) / denom).clamp(0.0, 1.0) + + # Sample x0 predictions (never sample the mask token). + logits = model_output.to(dtype=torch.float32) + logits[..., self.mask_token_id] = torch.finfo(logits.dtype).min + sampled_x0 = _gumbel_argmax(logits, generator=generator).to(dtype=torch.long) + + # Only masked positions can change. + is_masked = sample == self.mask_token_id + + # Bernoulli draw for whether to denoise at this step (only matters on masked positions). + rand = torch.rand((batch_size, seq_len), device=device, dtype=torch.float32, generator=generator) + should_denoise = rand < float(p_denoise.item()) + + x_prev = torch.where(is_masked & should_denoise, sampled_x0, sample) + + else: + raise ValueError(f"Unsupported forward process for `step()`: {self.forward_process!r}") + + if block_mask is not None: + x_prev = torch.where(block_mask.to(device=device), x_prev, sample) + + if not return_dict: + return (x_prev,) + return TokenDiffusionSchedulerOutput(prev_sample=x_prev) + + +__all__ = ["TokenDiffusionScheduler", "TokenDiffusionSchedulerOutput"] diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index fa37388fe75a..d3274c6e9f49 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -2233,6 +2233,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class BD3LMPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class BD3LMPipelineOutput(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class BlipDiffusionControlNetPipeline(metaclass=DummyObject): _backends = ["torch"] @@ -2368,6 +2398,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class HybridTokenDiffusionPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class ImagePipelineOutput(metaclass=DummyObject): _backends = ["torch"] @@ -2488,6 +2533,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class TokenDiffusionPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class TokenDiffusionPipelineOutput(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class DiffusersQuantizer(metaclass=DummyObject): _backends = ["torch"] @@ -2518,6 +2593,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class BD3LMTokenDiffusionScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class BD3LMTokenDiffusionSchedulerOutput(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class BlockRefinementScheduler(metaclass=DummyObject): _backends = ["torch"] @@ -2698,6 +2803,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class DFlashTokenDiffusionScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class DFlashTokenDiffusionSchedulerOutput(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class DPMSolverMultistepInverseScheduler(metaclass=DummyObject): _backends = ["torch"] @@ -2893,6 +3028,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class HybridTokenDiffusionScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class HybridTokenDiffusionSchedulerOutput(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class IPNDMScheduler(metaclass=DummyObject): _backends = ["torch"] @@ -3073,6 +3238,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class SDARTokenDiffusionScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class SDARTokenDiffusionSchedulerOutput(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class TCDScheduler(metaclass=DummyObject): _backends = ["torch"] @@ -3088,6 +3283,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class TokenDiffusionScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class UnCLIPScheduler(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 1e4d14566160..aa1f4eb11624 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1157,6 +1157,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class DFlashPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class DFlashPipelineOutput(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class EasyAnimateControlPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -2942,6 +2972,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class SDARPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class SDARPipelineOutput(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class SemanticStableDiffusionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/pipelines/bd3lm/__init__.py b/tests/pipelines/bd3lm/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/bd3lm/test_bd3lm.py b/tests/pipelines/bd3lm/test_bd3lm.py new file mode 100644 index 000000000000..d31c91a11f5e --- /dev/null +++ b/tests/pipelines/bd3lm/test_bd3lm.py @@ -0,0 +1,226 @@ +import unittest + +import torch + +from diffusers import BD3LMPipeline, BD3LMTokenDiffusionScheduler + + +class _DummyConfig: + def __init__(self, block_size, vocab_size, mask_index): + self.block_size = block_size + self.vocab_size = vocab_size + self.mask_index = mask_index + + +class _DummyBD3LMModel(torch.nn.Module): + """Minimal model that satisfies BD3LMPipeline's interface.""" + + def __init__(self, vocab_size=32, block_size=4): + super().__init__() + self.config = _DummyConfig( + block_size=block_size, + vocab_size=vocab_size, + mask_index=vocab_size - 1, + ) + self.backbone = torch.nn.Linear(1, vocab_size, bias=False) + self.register_buffer("_device_anchor", torch.empty(0)) + + @property + def dtype(self): + return torch.float32 + + @property + def device(self): + return self._device_anchor.device + + def reset_kv_cache(self, eval_batch_size): + pass + + def forward(self, input_ids, timesteps, sample_mode=False, store_kv=False): + batch_size, seq_len = input_ids.shape + logits = torch.zeros( + (batch_size, seq_len, self.config.vocab_size), + device=input_ids.device, + dtype=torch.float32, + ) + # Make logits vary by position so denoising is deterministic. + positions = torch.arange(seq_len, device=input_ids.device, dtype=torch.float32).view(1, seq_len, 1) + token_ids = (torch.arange(seq_len, device=input_ids.device) % (self.config.vocab_size - 2)).view(1, seq_len, 1) + logits.scatter_( + 2, + token_ids.expand(batch_size, -1, -1), + 1.0 + positions.expand(batch_size, -1, -1) * 0.1, + ) + return logits + + +def _make_pipeline(tokenizer=None, vocab_size=32, block_size=4): + model = _DummyBD3LMModel(vocab_size=vocab_size, block_size=block_size) + scheduler = BD3LMTokenDiffusionScheduler( + block_size=block_size, + mask_token_id=vocab_size - 1, + ) + return BD3LMPipeline(model=model, scheduler=scheduler, tokenizer=tokenizer) + + +class BD3LMPipelineTest(unittest.TestCase): + def test_pipeline_runs(self): + """Basic end-to-end generation with input_ids.""" + pipe = _make_pipeline().to("cpu") + + input_ids = torch.tensor([[5, 6, 7, 8], [1, 2, 3, 4]], dtype=torch.long) + out = pipe( + input_ids=input_ids, + gen_length=8, + num_inference_steps=8, + nucleus_p=1.0, + output_type="seq", + ) + + self.assertEqual(out.sequences.shape[0], 2) + self.assertGreater(out.sequences.shape[1], 0) + self.assertLessEqual(out.sequences.shape[1], 8) + + def test_output_type_seq(self): + """output_type='seq' returns sequences but no texts.""" + pipe = _make_pipeline().to("cpu") + + out = pipe( + input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long), + gen_length=8, + num_inference_steps=4, + nucleus_p=1.0, + output_type="seq", + ) + + self.assertIsNotNone(out.sequences) + self.assertIsNone(out.texts) + + def test_output_type_text_with_tokenizer(self): + """output_type='text' with a tokenizer should return decoded texts.""" + tok = type( + "Tok", + (), + { + "eos_token_id": None, + "mask_token_id": 31, + "batch_decode": lambda self, seqs, **kw: [f"decoded_{len(s)}" for s in seqs], + }, + )() + pipe = _make_pipeline(tokenizer=tok).to("cpu") + + out = pipe( + input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long), + gen_length=8, + num_inference_steps=4, + nucleus_p=1.0, + output_type="text", + ) + + self.assertIsNotNone(out.sequences) + self.assertIsNotNone(out.texts) + self.assertEqual(len(out.texts), 1) + self.assertTrue(out.texts[0].startswith("decoded_")) + + def test_output_type_text_without_tokenizer(self): + """output_type='text' without a tokenizer should return texts=None.""" + pipe = _make_pipeline(tokenizer=None).to("cpu") + + out = pipe( + input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long), + gen_length=8, + num_inference_steps=4, + nucleus_p=1.0, + output_type="text", + ) + + self.assertIsNotNone(out.sequences) + self.assertIsNone(out.texts) + + def test_output_type_invalid_raises(self): + """Invalid output_type should raise ValueError.""" + pipe = _make_pipeline().to("cpu") + + with self.assertRaises(ValueError): + pipe( + input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long), + gen_length=8, + num_inference_steps=4, + nucleus_p=1.0, + output_type="invalid", + ) + + def test_return_dict_false(self): + """return_dict=False should return a plain tuple.""" + pipe = _make_pipeline().to("cpu") + + result = pipe( + input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long), + gen_length=8, + num_inference_steps=4, + nucleus_p=1.0, + output_type="seq", + return_dict=False, + ) + + self.assertIsInstance(result, tuple) + self.assertEqual(len(result), 2) + sequences, texts = result + self.assertIsInstance(sequences, torch.Tensor) + self.assertIsNone(texts) + + def test_check_inputs_bad_gen_length(self): + """gen_length <= 0 should raise ValueError.""" + pipe = _make_pipeline().to("cpu") + + with self.assertRaises(ValueError): + pipe( + input_ids=torch.tensor([[1, 2]], dtype=torch.long), + gen_length=0, + num_inference_steps=4, + nucleus_p=1.0, + output_type="seq", + ) + + def test_check_inputs_bad_num_inference_steps(self): + """num_inference_steps <= 0 should raise ValueError.""" + pipe = _make_pipeline().to("cpu") + + with self.assertRaises(ValueError): + pipe( + input_ids=torch.tensor([[1, 2]], dtype=torch.long), + gen_length=8, + num_inference_steps=0, + nucleus_p=1.0, + output_type="seq", + ) + + def test_check_inputs_bad_nucleus_p(self): + """nucleus_p out of (0, 1] should raise ValueError.""" + pipe = _make_pipeline().to("cpu") + + with self.assertRaises(ValueError): + pipe( + input_ids=torch.tensor([[1, 2]], dtype=torch.long), + gen_length=8, + num_inference_steps=4, + nucleus_p=0.0, + output_type="seq", + ) + + def test_check_inputs_bad_output_type(self): + """output_type not in {'seq', 'text'} should raise ValueError.""" + pipe = _make_pipeline().to("cpu") + + with self.assertRaises(ValueError): + pipe( + input_ids=torch.tensor([[1, 2]], dtype=torch.long), + gen_length=8, + num_inference_steps=4, + nucleus_p=1.0, + output_type="bad", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/pipelines/dflash/__init__.py b/tests/pipelines/dflash/__init__.py new file mode 100644 index 000000000000..8b137891791f --- /dev/null +++ b/tests/pipelines/dflash/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/pipelines/dflash/test_dflash.py b/tests/pipelines/dflash/test_dflash.py new file mode 100644 index 000000000000..dcb3249ca68c --- /dev/null +++ b/tests/pipelines/dflash/test_dflash.py @@ -0,0 +1,393 @@ +import unittest + +import torch + +from diffusers import DFlashTokenDiffusionScheduler +from diffusers.pipelines.dflash.pipeline_dflash import DFlashPipeline + + +class _DummyModelOutput: + def __init__(self, logits, hidden_states=None): + self.logits = logits + self.hidden_states = hidden_states + + +class _DummyConfig: + def __init__(self, block_size, num_target_layers, num_hidden_layers): + self.block_size = block_size + self.num_target_layers = num_target_layers + self.num_hidden_layers = num_hidden_layers + + +class _DummyTargetModel(torch.nn.Module): + """Minimal target (causal LM) model that returns logits and hidden_states.""" + + def __init__(self, vocab_size: int, hidden_dim: int, num_layers: int): + super().__init__() + self.vocab_size = vocab_size + self.hidden_dim = hidden_dim + self.num_layers = num_layers + self.embed = torch.nn.Embedding(vocab_size, hidden_dim) + self.lm_head = torch.nn.Linear(hidden_dim, vocab_size, bias=False) + + def get_input_embeddings(self): + return self.embed + + def get_output_embeddings(self): + return self.lm_head + + def forward( + self, + input_ids, + position_ids=None, + past_key_values=None, + use_cache=False, + output_hidden_states=False, + logits_to_keep=None, + **kwargs, + ): + bsz, seq_len = input_ids.shape + h = self.embed(input_ids) + # Create hidden_states list: one entry per layer + 1 for the embedding layer + hidden_states = [h] * (self.num_layers + 1) if output_hidden_states else None + logits = self.lm_head(h) + # Make token 0 the most likely so acceptance is deterministic + logits[:, :, 0] = 10.0 + return _DummyModelOutput(logits=logits, hidden_states=hidden_states) + + def parameters(self): + return super().parameters() + + +class _DummyDraftModel(torch.nn.Module): + """Minimal draft model that returns hidden states of the expected shape.""" + + def __init__(self, hidden_dim: int, num_target_layers: int, block_size: int): + super().__init__() + self.block_size = block_size + self.config = _DummyConfig( + block_size=block_size, + num_target_layers=num_target_layers, + num_hidden_layers=1, + ) + # The draft model receives concatenated hidden states from num_target_layers target layers, + # each of dim hidden_dim, and produces a hidden state of dim hidden_dim. + self.proj = torch.nn.Linear(hidden_dim * num_target_layers, hidden_dim, bias=False) + self.register_buffer("_device_anchor", torch.empty(0)) + + @property + def device(self): + return self._device_anchor.device + + def forward( + self, + target_hidden, + noise_embedding, + position_ids=None, + past_key_values=None, + use_cache=False, + is_causal=False, + **kwargs, + ): + # Return a tensor with shape (batch, seq_len, hidden_dim) + bsz = noise_embedding.shape[0] + seq_len = position_ids.shape[1] if position_ids is not None else noise_embedding.shape[1] + h = torch.zeros(bsz, seq_len, self.proj.out_features, device=noise_embedding.device) + return h + + +def _make_pipeline(tokenizer=None, vocab_size=32, hidden_dim=16, num_target_layers=4, block_size=4): + target = _DummyTargetModel(vocab_size=vocab_size, hidden_dim=hidden_dim, num_layers=num_target_layers) + draft = _DummyDraftModel(hidden_dim=hidden_dim, num_target_layers=1, block_size=block_size) + # Set target_layer_ids directly so we skip the config-based computation. + draft.target_layer_ids = [1] + scheduler = DFlashTokenDiffusionScheduler() + return DFlashPipeline(draft_model=draft, target_model=target, tokenizer=tokenizer, scheduler=scheduler) + + +class DFlashPipelineTest(unittest.TestCase): + # ------------------------------------------------------------------ + # Pipeline runs + # ------------------------------------------------------------------ + def test_pipeline_runs_with_input_ids(self): + pipe = _make_pipeline() + input_ids = torch.tensor([[5, 6, 7, 8]], dtype=torch.long) + + out = pipe( + input_ids=input_ids, + max_new_tokens=8, + temperature=0.0, + mask_token_id=31, + stop_token_ids=None, + output_type="seq", + ) + + self.assertIsNotNone(out.sequences) + self.assertEqual(out.sequences.ndim, 2) + self.assertEqual(out.sequences.shape[0], 1) + # Generated tokens should not be longer than max_new_tokens + self.assertLessEqual(out.sequences.shape[1], 8) + + # ------------------------------------------------------------------ + # output_type="seq" + # ------------------------------------------------------------------ + def test_output_type_seq(self): + pipe = _make_pipeline() + input_ids = torch.tensor([[1, 2, 3]], dtype=torch.long) + + out = pipe( + input_ids=input_ids, + max_new_tokens=8, + temperature=0.0, + mask_token_id=31, + output_type="seq", + ) + + self.assertIsNotNone(out.sequences) + self.assertIsNone(out.texts) + + # ------------------------------------------------------------------ + # output_type="text" with mock tokenizer + # ------------------------------------------------------------------ + def test_output_type_text_with_tokenizer(self): + tok = type( + "Tok", + (), + { + "eos_token_id": None, + "mask_token_id": 31, + "batch_decode": lambda self, seqs, **kw: [f"decoded_{len(s)}" for s in seqs], + }, + )() + pipe = _make_pipeline(tokenizer=tok) + + out = pipe( + input_ids=torch.tensor([[1, 2, 3]], dtype=torch.long), + max_new_tokens=8, + temperature=0.0, + output_type="text", + ) + + self.assertIsNotNone(out.sequences) + self.assertIsNotNone(out.texts) + self.assertEqual(len(out.texts), 1) + self.assertTrue(out.texts[0].startswith("decoded_")) + + def test_output_type_text_without_tokenizer(self): + """output_type='text' without a tokenizer should return texts=None.""" + pipe = _make_pipeline(tokenizer=None) + + out = pipe( + input_ids=torch.tensor([[1, 2, 3]], dtype=torch.long), + max_new_tokens=8, + temperature=0.0, + mask_token_id=31, + output_type="text", + ) + + self.assertIsNotNone(out.sequences) + self.assertIsNone(out.texts) + + # ------------------------------------------------------------------ + # output_type invalid + # ------------------------------------------------------------------ + def test_output_type_invalid_raises(self): + pipe = _make_pipeline() + + with self.assertRaises(ValueError): + pipe( + input_ids=torch.tensor([[1, 2, 3]], dtype=torch.long), + max_new_tokens=8, + mask_token_id=31, + output_type="invalid", + ) + + # ------------------------------------------------------------------ + # return_dict=False + # ------------------------------------------------------------------ + def test_pipeline_return_tuple(self): + pipe = _make_pipeline() + input_ids = torch.tensor([[1, 2, 3]], dtype=torch.long) + + result = pipe( + input_ids=input_ids, + max_new_tokens=8, + temperature=0.0, + mask_token_id=31, + output_type="seq", + return_dict=False, + ) + + self.assertIsInstance(result, tuple) + self.assertEqual(len(result), 2) + sequences, texts = result + self.assertIsNotNone(sequences) + self.assertIsNone(texts) + + # ------------------------------------------------------------------ + # check_inputs validation + # ------------------------------------------------------------------ + def test_check_inputs_no_inputs_raises(self): + pipe = _make_pipeline() + with self.assertRaises(ValueError): + pipe.check_inputs( + prompt=None, + messages=None, + input_ids=None, + max_new_tokens=16, + output_type="seq", + callback_on_step_end=None, + callback_on_step_end_tensor_inputs=None, + ) + + def test_check_inputs_both_prompt_and_messages_raises(self): + pipe = _make_pipeline() + with self.assertRaises(ValueError): + pipe.check_inputs( + prompt="hello", + messages=[{"role": "user", "content": "hi"}], + input_ids=None, + max_new_tokens=16, + output_type="seq", + callback_on_step_end=None, + callback_on_step_end_tensor_inputs=None, + ) + + def test_check_inputs_invalid_input_ids_ndim_raises(self): + pipe = _make_pipeline() + bad_ids = torch.zeros(2, 3, 4, dtype=torch.long) + with self.assertRaises(ValueError): + pipe.check_inputs( + prompt=None, + messages=None, + input_ids=bad_ids, + max_new_tokens=16, + output_type="seq", + callback_on_step_end=None, + callback_on_step_end_tensor_inputs=None, + ) + + def test_check_inputs_invalid_input_ids_dtype_raises(self): + pipe = _make_pipeline() + bad_ids = torch.zeros(1, 4, dtype=torch.float32) + with self.assertRaises(ValueError): + pipe.check_inputs( + prompt=None, + messages=None, + input_ids=bad_ids, + max_new_tokens=16, + output_type="seq", + callback_on_step_end=None, + callback_on_step_end_tensor_inputs=None, + ) + + def test_check_inputs_invalid_max_new_tokens_raises(self): + pipe = _make_pipeline() + with self.assertRaises(ValueError): + pipe.check_inputs( + prompt=None, + messages=None, + input_ids=torch.tensor([[1, 2]], dtype=torch.long), + max_new_tokens=0, + output_type="seq", + callback_on_step_end=None, + callback_on_step_end_tensor_inputs=None, + ) + + def test_check_inputs_invalid_output_type_raises(self): + pipe = _make_pipeline() + with self.assertRaises(ValueError): + pipe.check_inputs( + prompt=None, + messages=None, + input_ids=torch.tensor([[1, 2]], dtype=torch.long), + max_new_tokens=16, + output_type="bad", + callback_on_step_end=None, + callback_on_step_end_tensor_inputs=None, + ) + + def test_check_inputs_prompt_without_tokenizer_raises(self): + pipe = _make_pipeline(tokenizer=None) + with self.assertRaises(ValueError): + pipe.check_inputs( + prompt="hello", + messages=None, + input_ids=None, + max_new_tokens=16, + output_type="seq", + callback_on_step_end=None, + callback_on_step_end_tensor_inputs=None, + ) + + def test_check_inputs_messages_without_tokenizer_raises(self): + pipe = _make_pipeline(tokenizer=None) + with self.assertRaises(ValueError): + pipe.check_inputs( + prompt=None, + messages=[{"role": "user", "content": "hi"}], + input_ids=None, + max_new_tokens=16, + output_type="seq", + callback_on_step_end=None, + callback_on_step_end_tensor_inputs=None, + ) + + def test_check_inputs_valid_input_ids_passes(self): + pipe = _make_pipeline() + # Should not raise. + pipe.check_inputs( + prompt=None, + messages=None, + input_ids=torch.tensor([[1, 2, 3]], dtype=torch.long), + max_new_tokens=16, + output_type="seq", + callback_on_step_end=None, + callback_on_step_end_tensor_inputs=None, + ) + + # ------------------------------------------------------------------ + # _prepare_input_ids + # ------------------------------------------------------------------ + def test_prepare_input_ids_from_tensor(self): + pipe = _make_pipeline() + ids = torch.tensor([[1, 2, 3]], dtype=torch.long) + result = pipe._prepare_input_ids( + prompt=None, + messages=None, + input_ids=ids, + use_chat_template=False, + add_generation_prompt=False, + chat_template_kwargs=None, + ) + self.assertTrue(torch.equal(result, ids)) + + def test_prepare_input_ids_from_1d_tensor(self): + pipe = _make_pipeline() + ids = torch.tensor([1, 2, 3], dtype=torch.long) + result = pipe._prepare_input_ids( + prompt=None, + messages=None, + input_ids=ids, + use_chat_template=False, + add_generation_prompt=False, + chat_template_kwargs=None, + ) + self.assertEqual(result.shape, (1, 3)) + + # ------------------------------------------------------------------ + # prepare_latents + # ------------------------------------------------------------------ + def test_prepare_latents(self): + pipe = _make_pipeline() + mask_token_id = 99 + latents = pipe.prepare_latents( + max_length=10, block_size=4, mask_token_id=mask_token_id, device=torch.device("cpu") + ) + self.assertEqual(latents.shape, (1, 14)) # 10 + 4 + self.assertTrue((latents == mask_token_id).all().item()) + self.assertEqual(latents.dtype, torch.long) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/pipelines/sdar/__init__.py b/tests/pipelines/sdar/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/sdar/test_sdar.py b/tests/pipelines/sdar/test_sdar.py new file mode 100644 index 000000000000..830e1ec0da54 --- /dev/null +++ b/tests/pipelines/sdar/test_sdar.py @@ -0,0 +1,261 @@ +import unittest + +import torch + +from diffusers import SDARPipeline, SDARTokenDiffusionScheduler + + +class _DummyModelOutput: + def __init__(self, logits): + self.logits = logits + + +class _DummyCausalLM(torch.nn.Module): + """Minimal causal LM that returns deterministic logits given input_ids.""" + + def __init__(self, vocab_size: int): + super().__init__() + self.vocab_size = int(vocab_size) + self.register_buffer("_device_anchor", torch.empty(0)) + + @property + def dtype(self): + return torch.float32 + + @property + def device(self): + return self._device_anchor.device + + def forward(self, input_ids, attention_mask=None, position_ids=None, **kwargs): + batch_size, seq_len = input_ids.shape + logits = torch.zeros((batch_size, seq_len, self.vocab_size), device=input_ids.device, dtype=torch.float32) + # Make confidence vary with token position so top-k commits are deterministic. + positions = torch.arange(seq_len, device=input_ids.device, dtype=torch.float32).view(1, seq_len, 1) + token_ids = (torch.arange(seq_len, device=input_ids.device) % (self.vocab_size - 2)).view(1, seq_len, 1) + logits.scatter_(2, token_ids.expand(batch_size, -1, -1), 1.0 + positions.expand(batch_size, -1, -1) * 0.1) + return _DummyModelOutput(logits=logits) + + +def _make_pipeline(tokenizer=None): + model = _DummyCausalLM(vocab_size=32) + scheduler = SDARTokenDiffusionScheduler() + return SDARPipeline(model=model, scheduler=scheduler, tokenizer=tokenizer) + + +class SDARPipelineTest(unittest.TestCase): + # ------------------------------------------------------------------ + # Basic pipeline run + # ------------------------------------------------------------------ + def test_pipeline_runs_with_input_ids(self): + pipe = _make_pipeline().to("cpu") + + input_ids = torch.tensor([[5, 6, 7, 8]], dtype=torch.long) + out = pipe( + input_ids=input_ids, + use_chat_template=False, + max_new_tokens=16, + block_length=4, + num_inference_steps=4, + temperature=0.0, + mask_token_id=31, + output_type="seq", + ) + + self.assertIsNotNone(out.sequences) + self.assertEqual(out.sequences.ndim, 2) + # Generated tokens only (prompt stripped) + self.assertGreater(out.sequences.shape[1], 0) + + # ------------------------------------------------------------------ + # output_type="seq" β†’ texts is None + # ------------------------------------------------------------------ + def test_output_type_seq(self): + pipe = _make_pipeline().to("cpu") + + out = pipe( + input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long), + use_chat_template=False, + max_new_tokens=16, + block_length=4, + num_inference_steps=4, + temperature=0.0, + mask_token_id=31, + output_type="seq", + ) + + self.assertIsNotNone(out.sequences) + self.assertIsNone(out.texts) + + # ------------------------------------------------------------------ + # output_type="text" with dummy tokenizer + # ------------------------------------------------------------------ + def test_output_type_text_with_tokenizer(self): + tok = type( + "Tok", + (), + { + "eos_token_id": None, + "mask_token_id": 31, + "batch_decode": lambda self, seqs, **kw: [f"decoded_{len(s)}" for s in seqs], + }, + )() + pipe = _make_pipeline(tokenizer=tok).to("cpu") + + out = pipe( + input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long), + use_chat_template=False, + max_new_tokens=16, + block_length=4, + num_inference_steps=4, + temperature=0.0, + output_type="text", + ) + + self.assertIsNotNone(out.sequences) + self.assertIsNotNone(out.texts) + self.assertEqual(len(out.texts), 1) + self.assertTrue(out.texts[0].startswith("decoded_")) + + # ------------------------------------------------------------------ + # output_type="text" without tokenizer β†’ texts is None + # ------------------------------------------------------------------ + def test_output_type_text_without_tokenizer(self): + pipe = _make_pipeline(tokenizer=None).to("cpu") + + out = pipe( + input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long), + use_chat_template=False, + max_new_tokens=16, + block_length=4, + num_inference_steps=4, + temperature=0.0, + mask_token_id=31, + output_type="text", + ) + + self.assertIsNotNone(out.sequences) + self.assertIsNone(out.texts) + + # ------------------------------------------------------------------ + # Invalid output_type raises ValueError + # ------------------------------------------------------------------ + def test_output_type_invalid_raises(self): + pipe = _make_pipeline().to("cpu") + + with self.assertRaises(ValueError): + pipe( + input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long), + use_chat_template=False, + max_new_tokens=16, + block_length=4, + num_inference_steps=4, + mask_token_id=31, + output_type="invalid", + ) + + # ------------------------------------------------------------------ + # check_inputs validation + # ------------------------------------------------------------------ + def test_check_inputs_no_source_raises(self): + pipe = _make_pipeline() + with self.assertRaises(ValueError): + pipe.check_inputs( + prompt=None, + messages=None, + input_ids=None, + block_length=4, + num_inference_steps=4, + mask_token_id=31, + output_type="seq", + callback_on_step_end=None, + callback_on_step_end_tensor_inputs=None, + ) + + def test_check_inputs_prompt_and_messages_raises(self): + pipe = _make_pipeline() + with self.assertRaises(ValueError): + pipe.check_inputs( + prompt="hello", + messages=[{"role": "user", "content": "hi"}], + input_ids=None, + block_length=4, + num_inference_steps=4, + mask_token_id=31, + output_type="seq", + callback_on_step_end=None, + callback_on_step_end_tensor_inputs=None, + ) + + def test_check_inputs_bad_input_ids_ndim_raises(self): + pipe = _make_pipeline() + bad_ids = torch.zeros(2, 3, 4, dtype=torch.long) + with self.assertRaises(ValueError): + pipe.check_inputs( + prompt=None, + messages=None, + input_ids=bad_ids, + block_length=4, + num_inference_steps=4, + mask_token_id=31, + output_type="seq", + callback_on_step_end=None, + callback_on_step_end_tensor_inputs=None, + ) + + def test_check_inputs_bad_block_length_raises(self): + pipe = _make_pipeline() + with self.assertRaises(ValueError): + pipe.check_inputs( + prompt=None, + messages=None, + input_ids=torch.tensor([[1, 2]], dtype=torch.long), + block_length=0, + num_inference_steps=4, + mask_token_id=31, + output_type="seq", + callback_on_step_end=None, + callback_on_step_end_tensor_inputs=None, + ) + + def test_check_inputs_no_mask_token_raises(self): + pipe = _make_pipeline() + with self.assertRaises(ValueError): + pipe.check_inputs( + prompt=None, + messages=None, + input_ids=torch.tensor([[1, 2]], dtype=torch.long), + block_length=4, + num_inference_steps=4, + mask_token_id=None, + output_type="seq", + callback_on_step_end=None, + callback_on_step_end_tensor_inputs=None, + ) + + # ------------------------------------------------------------------ + # return_dict=False returns tuple + # ------------------------------------------------------------------ + def test_return_dict_false(self): + pipe = _make_pipeline().to("cpu") + + result = pipe( + input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long), + use_chat_template=False, + max_new_tokens=16, + block_length=4, + num_inference_steps=4, + temperature=0.0, + mask_token_id=31, + output_type="seq", + return_dict=False, + ) + + self.assertIsInstance(result, tuple) + self.assertEqual(len(result), 2) + sequences, texts = result + self.assertIsNotNone(sequences) + self.assertIsNone(texts) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/pipelines/test_pipeline_hybrid_token_diffusion.py b/tests/pipelines/test_pipeline_hybrid_token_diffusion.py new file mode 100644 index 000000000000..25c2f7a9ae82 --- /dev/null +++ b/tests/pipelines/test_pipeline_hybrid_token_diffusion.py @@ -0,0 +1,56 @@ +import unittest + +import torch + +from diffusers import HybridTokenDiffusionPipeline, HybridTokenDiffusionScheduler + + +class _DummyTokenizer: + cls_token_id = 1 + bos_token_id = None + + def batch_decode(self, sequences, skip_special_tokens=True): + return [" ".join(map(str, row)) for row in sequences.tolist()] + + +class _DummyModelOutput: + def __init__(self, logits): + self.logits = logits + + +class _DummyMLM(torch.nn.Module): + def __init__(self, vocab_size: int): + super().__init__() + self.vocab_size = vocab_size + self.register_buffer("_device_anchor", torch.empty(0)) + + @property + def dtype(self): + return torch.float32 + + @property + def device(self): + return self._device_anchor.device + + def forward(self, input_ids=None, attention_mask=None, **kwargs): + batch_size, seq_len = input_ids.shape + logits = torch.zeros((batch_size, seq_len, self.vocab_size), device=input_ids.device, dtype=torch.float32) + return _DummyModelOutput(logits=logits) + + +class HybridTokenDiffusionPipelineTest(unittest.TestCase): + def test_pipeline_runs(self): + vocab_size = 32 + scheduler = HybridTokenDiffusionScheduler(vocab_size=vocab_size, mask_token_id=vocab_size - 1) + model = _DummyMLM(vocab_size=vocab_size) + tokenizer = _DummyTokenizer() + pipe = HybridTokenDiffusionPipeline(model=model, scheduler=scheduler, tokenizer=tokenizer).to("cpu") + + gen = torch.Generator().manual_seed(0) + out = pipe(batch_size=2, seq_len=8, num_inference_steps=2, generator=gen, inject_start_token=True) + self.assertEqual(out.sequences.shape, (2, 8)) + self.assertEqual(len(out.texts), 2) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/pipelines/test_pipeline_token_diffusion.py b/tests/pipelines/test_pipeline_token_diffusion.py new file mode 100644 index 000000000000..57c26d13a6a2 --- /dev/null +++ b/tests/pipelines/test_pipeline_token_diffusion.py @@ -0,0 +1,129 @@ +import unittest + +import torch + +from diffusers import TokenDiffusionPipeline, TokenDiffusionScheduler + + +class _DummyTokenizer: + bos_token_id = None + cls_token_id = 1 + + def batch_decode(self, sequences, skip_special_tokens=True): + # Deterministic, cheap β€œdecode”: join token ids as strings. + out = [] + for row in sequences.tolist(): + out.append(" ".join(str(i) for i in row)) + return out + + +class _DummyModelOutput: + def __init__(self, logits): + self.logits = logits + + +class _DummyMLM(torch.nn.Module): + def __init__(self, vocab_size: int): + super().__init__() + self.vocab_size = vocab_size + self.register_buffer("_device_anchor", torch.empty(0)) + + @property + def dtype(self): + return torch.float32 + + @property + def device(self): + return self._device_anchor.device + + def forward(self, input_ids=None, attention_mask=None, **kwargs): + batch_size, seq_len = input_ids.shape + logits = torch.zeros((batch_size, seq_len, self.vocab_size), device=input_ids.device, dtype=torch.float32) + return _DummyModelOutput(logits=logits) + + +class TokenDiffusionPipelineTest(unittest.TestCase): + def test_absorbing_pipeline_runs(self): + vocab_size = 32 + scheduler = TokenDiffusionScheduler( + vocab_size=vocab_size, mask_token_id=vocab_size - 1, forward_process="absorbing" + ) + model = _DummyMLM(vocab_size=vocab_size) + tokenizer = _DummyTokenizer() + + pipe = TokenDiffusionPipeline(model=model, scheduler=scheduler, tokenizer=tokenizer) + pipe = pipe.to("cpu") + + out = pipe(batch_size=2, seq_len=8, num_inference_steps=2, inject_start_token=True) + self.assertEqual(out.sequences.shape, (2, 8)) + self.assertEqual(len(out.texts), 2) + + def test_uniform_pipeline_runs(self): + vocab_size = 32 + scheduler = TokenDiffusionScheduler( + vocab_size=vocab_size, + mask_token_id=vocab_size - 1, + forward_process="uniform", + exclude_mask_from_uniform=True, + ) + model = _DummyMLM(vocab_size=vocab_size) + tokenizer = _DummyTokenizer() + + pipe = TokenDiffusionPipeline(model=model, scheduler=scheduler, tokenizer=tokenizer) + pipe = pipe.to("cpu") + + gen = torch.Generator().manual_seed(0) + out = pipe(batch_size=2, seq_len=8, num_inference_steps=2, generator=gen, inject_start_token=True) + self.assertEqual(out.sequences.shape, (2, 8)) + self.assertFalse((out.sequences == scheduler.mask_token_id).any().item()) + + def test_prefix_ids_are_fixed(self): + vocab_size = 32 + scheduler = TokenDiffusionScheduler( + vocab_size=vocab_size, mask_token_id=vocab_size - 1, forward_process="absorbing" + ) + model = _DummyMLM(vocab_size=vocab_size) + tokenizer = _DummyTokenizer() + + pipe = TokenDiffusionPipeline(model=model, scheduler=scheduler, tokenizer=tokenizer).to("cpu") + prefix = torch.tensor([5, 6, 7], dtype=torch.long) + out = pipe(batch_size=2, seq_len=8, num_inference_steps=2, prefix_ids=prefix, return_text=False) + + self.assertTrue((out.sequences[:, :3] == prefix.view(1, -1)).all().item()) + + def test_infill_mask_freezes_positions(self): + vocab_size = 32 + scheduler = TokenDiffusionScheduler( + vocab_size=vocab_size, + mask_token_id=vocab_size - 1, + forward_process="uniform", + exclude_mask_from_uniform=True, + ) + model = _DummyMLM(vocab_size=vocab_size) + tokenizer = _DummyTokenizer() + + pipe = TokenDiffusionPipeline(model=model, scheduler=scheduler, tokenizer=tokenizer).to("cpu") + + # Only positions 2..7 are editable, first two positions are fixed to the initial values. + infill_mask = torch.ones((2, 8), dtype=torch.bool) + infill_mask[:, :2] = False + gen = torch.Generator().manual_seed(0) + out = pipe( + batch_size=2, seq_len=8, num_inference_steps=2, generator=gen, infill_mask=infill_mask, return_text=False + ) + + # Fixed positions should be unchanged from the initial latents (for uniform, these are random but clamped). + # Since the model predicts uniform logits and the scheduler would otherwise resample, this checks clamping works. + out2 = pipe( + batch_size=2, + seq_len=8, + num_inference_steps=2, + generator=torch.Generator().manual_seed(0), + infill_mask=infill_mask, + return_text=False, + ) + self.assertTrue((out.sequences[:, :2] == out2.sequences[:, :2]).all().item()) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/schedulers/test_scheduler_bd3lm_token_diffusion.py b/tests/schedulers/test_scheduler_bd3lm_token_diffusion.py new file mode 100644 index 000000000000..a13f063753b8 --- /dev/null +++ b/tests/schedulers/test_scheduler_bd3lm_token_diffusion.py @@ -0,0 +1,338 @@ +import math +import tempfile +import unittest + +import torch + +from diffusers import BD3LMTokenDiffusionScheduler + + +class BD3LMTokenDiffusionSchedulerTest(unittest.TestCase): + def get_scheduler(self, **kwargs): + config = { + "block_size": 8, + "num_inference_steps": 8, + "noise_type": "loglinear", + "nucleus_p": 1.0, + "mask_token_id": 31, + } + config.update(kwargs) + return BD3LMTokenDiffusionScheduler(**config) + + # ------------------------------------------------------------------ + # Timestep management + # ------------------------------------------------------------------ + + def test_set_timesteps(self): + scheduler = self.get_scheduler() + scheduler.set_timesteps(8) + self.assertEqual(scheduler.num_inference_steps, 8) + self.assertEqual(len(scheduler.timesteps), 8) + # Should go from 1.0 down to near 0.0 + self.assertAlmostEqual(scheduler.timesteps[0].item(), 1.0, places=4) + self.assertAlmostEqual(scheduler.timesteps[-1].item(), 0.0, places=4) + + def test_set_timesteps_invalid(self): + scheduler = self.get_scheduler() + with self.assertRaises(ValueError): + scheduler.set_timesteps(0) + + # ------------------------------------------------------------------ + # Noise schedule: _compute_move_chance + # ------------------------------------------------------------------ + + def test_compute_move_chance_loglinear(self): + scheduler = self.get_scheduler(noise_type="loglinear") + t = torch.tensor([0.0, 0.5, 1.0]) + mc = scheduler._compute_move_chance(t) + self.assertTrue(torch.allclose(mc, t)) + + def test_compute_move_chance_cosine(self): + scheduler = self.get_scheduler(noise_type="cosine") + t = torch.tensor([0.0, 0.5, 1.0]) + mc = scheduler._compute_move_chance(t) + eps = 1e-3 + expected_0 = 1.0 - (1.0 - eps) * math.cos(0.0) + expected_half = 1.0 - (1.0 - eps) * math.cos(0.5 * math.pi / 2.0) + expected_1 = 1.0 - (1.0 - eps) * math.cos(math.pi / 2.0) + self.assertAlmostEqual(mc[0].item(), expected_0, places=5) + self.assertAlmostEqual(mc[1].item(), expected_half, places=5) + self.assertAlmostEqual(mc[2].item(), expected_1, places=5) + + def test_compute_move_chance_square(self): + scheduler = self.get_scheduler(noise_type="square") + t = torch.tensor([0.0, 0.5, 1.0]) + mc = scheduler._compute_move_chance(t) + eps = 1e-3 + self.assertAlmostEqual(mc[0].item(), eps, places=5) + self.assertAlmostEqual(mc[1].item(), 0.25, places=5) + self.assertAlmostEqual(mc[2].item(), 1.0, places=5) + + def test_compute_move_chance_square_root(self): + scheduler = self.get_scheduler(noise_type="square_root") + t = torch.tensor([0.0, 0.25, 1.0]) + mc = scheduler._compute_move_chance(t) + eps = 1e-3 + self.assertAlmostEqual(mc[0].item(), eps, places=5) + self.assertAlmostEqual(mc[1].item(), 0.5, places=5) + self.assertAlmostEqual(mc[2].item(), 1.0, places=5) + + def test_compute_move_chance_log(self): + scheduler = self.get_scheduler(noise_type="log") + t = torch.tensor([0.0, 1.0]) + mc = scheduler._compute_move_chance(t) + self.assertAlmostEqual(mc[0].item(), 0.0, places=5) + self.assertAlmostEqual(mc[1].item(), 1.0, places=5) + + # ------------------------------------------------------------------ + # Sigma computation + # ------------------------------------------------------------------ + + def test_compute_sigma(self): + scheduler = self.get_scheduler(noise_type="loglinear") + sigma = scheduler.compute_sigma(0.5, batch_size=2) + self.assertEqual(sigma.shape, (2,)) + # sigma = -log(1 - move_chance) = -log(1 - 0.5) = log(2) + expected = math.log(2.0) + self.assertAlmostEqual(sigma[0].item(), expected, places=4) + self.assertAlmostEqual(sigma[1].item(), expected, places=4) + + def test_compute_sigma_clamps_at_max(self): + scheduler = self.get_scheduler(noise_type="loglinear") + # At t=1.0, move_chance=1.0, so -log(0) -> inf, should be clamped. + sigma = scheduler.compute_sigma(1.0, batch_size=1) + eps = 1e-3 + sigma_max = -math.log(eps) + self.assertAlmostEqual(sigma[0].item(), sigma_max, places=3) + + # ------------------------------------------------------------------ + # Config save/load + # ------------------------------------------------------------------ + + def test_save_load_config_round_trip(self): + scheduler = self.get_scheduler(block_size=16, noise_type="cosine", nucleus_p=0.9, mask_token_id=99) + with tempfile.TemporaryDirectory() as tmpdir: + scheduler.save_config(tmpdir) + loaded = BD3LMTokenDiffusionScheduler.from_pretrained(tmpdir) + + self.assertEqual(loaded.config.block_size, 16) + self.assertEqual(loaded.config.noise_type, "cosine") + self.assertAlmostEqual(loaded.config.nucleus_p, 0.9) + self.assertEqual(loaded.config.mask_token_id, 99) + + def test_from_config(self): + scheduler = self.get_scheduler(block_size=16, noise_type="square") + new_scheduler = BD3LMTokenDiffusionScheduler.from_config(scheduler.config) + self.assertEqual(new_scheduler.config.block_size, 16) + self.assertEqual(new_scheduler.config.noise_type, "square") + + # ------------------------------------------------------------------ + # step() + # ------------------------------------------------------------------ + + def test_step_commits_tokens(self): + """Running enough steps should commit masked tokens to non-mask values.""" + scheduler = self.get_scheduler(block_size=4) + scheduler.set_timesteps(8) + + batch_size, block_size, vocab_size = 1, 4, 32 + mask_id = 31 + + sample = torch.full((batch_size, block_size), mask_id, dtype=torch.long) + + # Create logits with strong preference for non-mask tokens. + logits = torch.zeros(batch_size, block_size, vocab_size) + for i in range(block_size): + logits[0, i, i] = 10.0 + + # Run all denoising steps. + for step_idx in range(scheduler.num_inference_steps): + t = scheduler.timesteps[step_idx] + out = scheduler.step( + model_output=logits, + timestep=t, + sample=sample, + mask_token_id=mask_id, + return_dict=True, + ) + sample = out.prev_sample + + # After all steps, no mask tokens should remain. + self.assertFalse((sample == mask_id).any().item()) + + def test_step_preserves_unmasked_tokens(self): + """Already-unmasked positions must be preserved (copy flag).""" + scheduler = self.get_scheduler(block_size=4) + scheduler.set_timesteps(4) + + batch_size, block_size, vocab_size = 1, 4, 32 + mask_id = 31 + + # Positions 0,1 are already unmasked; positions 2,3 are masked. + sample = torch.tensor([[5, 10, mask_id, mask_id]], dtype=torch.long) + logits = torch.zeros(batch_size, block_size, vocab_size) + for i in range(block_size): + logits[0, i, i % (vocab_size - 2)] = 10.0 + + out = scheduler.step( + model_output=logits, + timestep=scheduler.timesteps[0], + sample=sample, + mask_token_id=mask_id, + return_dict=True, + ) + + # Unmasked positions must be unchanged. + self.assertEqual(out.prev_sample[0, 0].item(), 5) + self.assertEqual(out.prev_sample[0, 1].item(), 10) + + def test_step_return_tuple(self): + """return_dict=False should return a plain tuple.""" + scheduler = self.get_scheduler(block_size=4) + scheduler.set_timesteps(4) + + vocab_size = 32 + sample = torch.full((1, 4), 31, dtype=torch.long) + logits = torch.randn(1, 4, vocab_size) + + result = scheduler.step( + model_output=logits, + timestep=scheduler.timesteps[0], + sample=sample, + mask_token_id=31, + return_dict=False, + ) + + self.assertIsInstance(result, tuple) + self.assertEqual(len(result), 2) + + def test_step_batched(self): + """step works with batch_size > 1.""" + scheduler = self.get_scheduler(block_size=4) + scheduler.set_timesteps(4) + + batch_size, vocab_size = 3, 32 + mask_id = 31 + sample = torch.full((batch_size, 4), mask_id, dtype=torch.long) + logits = torch.randn(batch_size, 4, vocab_size) + + out = scheduler.step( + model_output=logits, + timestep=scheduler.timesteps[0], + sample=sample, + mask_token_id=mask_id, + return_dict=True, + ) + + self.assertEqual(out.prev_sample.shape, (batch_size, 4)) + + # ------------------------------------------------------------------ + # Nucleus filtering + # ------------------------------------------------------------------ + + def test_nucleus_filtering_passthrough(self): + """nucleus_p=1.0 should not alter the distribution.""" + probs = torch.softmax(torch.randn(1, 4, 32), dim=-1) + filtered = BD3LMTokenDiffusionScheduler._nucleus_filtering(probs, nucleus_p=1.0) + self.assertTrue(torch.allclose(probs, filtered)) + + def test_nucleus_filtering_truncates(self): + """nucleus_p < 1.0 should zero out low-probability tokens.""" + probs = torch.zeros(1, 1, 4) + probs[0, 0] = torch.tensor([0.5, 0.3, 0.15, 0.05]) + filtered = BD3LMTokenDiffusionScheduler._nucleus_filtering(probs, nucleus_p=0.8) + # Token with prob 0.05 should be zeroed out. + self.assertAlmostEqual(filtered[0, 0, 3].item(), 0.0, places=5) + # Filtered probs should still sum to ~1. + self.assertAlmostEqual(filtered.sum().item(), 1.0, places=4) + + def test_nucleus_filtering_keeps_top1(self): + """Nucleus filtering always keeps at least the top-1 token.""" + probs = torch.zeros(1, 1, 4) + probs[0, 0] = torch.tensor([0.1, 0.1, 0.1, 0.7]) + filtered = BD3LMTokenDiffusionScheduler._nucleus_filtering(probs, nucleus_p=0.01) + # Top-1 (index 3) must be kept. + self.assertGreater(filtered[0, 0, 3].item(), 0.0) + + # ------------------------------------------------------------------ + # Stopping criteria + # ------------------------------------------------------------------ + + def test_check_should_stop_all_unmasked(self): + mask_id = 31 + sequences = torch.tensor([[1, 2, 3, 4]], dtype=torch.long) + self.assertTrue(BD3LMTokenDiffusionScheduler.check_should_stop(sequences, mask_id)) + + def test_check_should_stop_has_masks(self): + mask_id = 31 + sequences = torch.tensor([[1, 31, 3, 4]], dtype=torch.long) + self.assertFalse(BD3LMTokenDiffusionScheduler.check_should_stop(sequences, mask_id)) + + def test_check_eos_finished(self): + eos_id = 2 + prompt_length = 2 + sequences = torch.tensor([[10, 11, 5, eos_id, 7, 8]], dtype=torch.long) + finished = torch.tensor([False]) + + finished = BD3LMTokenDiffusionScheduler.check_eos_finished(sequences, prompt_length, eos_id, finished) + self.assertTrue(finished[0].item()) + + def test_check_eos_finished_no_eos(self): + eos_id = 2 + prompt_length = 2 + sequences = torch.tensor([[10, 11, 5, 6, 7, 8]], dtype=torch.long) + finished = torch.tensor([False]) + + finished = BD3LMTokenDiffusionScheduler.check_eos_finished(sequences, prompt_length, eos_id, finished) + self.assertFalse(finished[0].item()) + + def test_check_eos_finished_already_finished(self): + eos_id = 2 + sequences = torch.tensor([[10, 11, 5, 6]], dtype=torch.long) + finished = torch.tensor([True]) + + finished = BD3LMTokenDiffusionScheduler.check_eos_finished(sequences, 2, eos_id, finished) + self.assertTrue(finished[0].item()) + + # ------------------------------------------------------------------ + # add_noise + # ------------------------------------------------------------------ + + def test_add_noise(self): + scheduler = self.get_scheduler() + original = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.long) + mask_id = 31 + + gen = torch.Generator().manual_seed(42) + # At t=1.0 (loglinear), all tokens should be masked. + t_full = torch.tensor([1.0]) + noisy = scheduler.add_noise(original, t_full, mask_token_id=mask_id, generator=gen) + self.assertTrue((noisy == mask_id).all().item()) + + def test_add_noise_zero(self): + scheduler = self.get_scheduler() + original = torch.tensor([[1, 2, 3, 4]], dtype=torch.long) + mask_id = 31 + + # At t=0.0 (loglinear), no tokens should be masked. + t_zero = torch.tensor([0.0]) + noisy = scheduler.add_noise(original, t_zero, mask_token_id=mask_id) + self.assertTrue(torch.equal(noisy, original)) + + def test_add_noise_partial(self): + scheduler = self.get_scheduler() + original = torch.arange(100).unsqueeze(0).long() + mask_id = 999 + + gen = torch.Generator().manual_seed(0) + t_half = torch.tensor([0.5]) + noisy = scheduler.add_noise(original, t_half, mask_token_id=mask_id, generator=gen) + + num_masked = (noisy == mask_id).sum().item() + # With 100 tokens and move_chance=0.5, we expect roughly 50 masked. + self.assertGreater(num_masked, 20) + self.assertLess(num_masked, 80) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/schedulers/test_scheduler_dflash_token_diffusion.py b/tests/schedulers/test_scheduler_dflash_token_diffusion.py new file mode 100644 index 000000000000..9cbb52b28193 --- /dev/null +++ b/tests/schedulers/test_scheduler_dflash_token_diffusion.py @@ -0,0 +1,310 @@ +import tempfile +import unittest + +import torch + +from diffusers import DFlashTokenDiffusionScheduler + + +class DFlashTokenDiffusionSchedulerTest(unittest.TestCase): + def get_scheduler(self): + return DFlashTokenDiffusionScheduler() + + # ------------------------------------------------------------------ + # set_timesteps + # ------------------------------------------------------------------ + def test_set_timesteps(self): + scheduler = self.get_scheduler() + scheduler.set_timesteps(4) + self.assertEqual(scheduler.num_inference_steps, 4) + self.assertEqual(len(scheduler.timesteps), 4) + self.assertEqual(scheduler.timesteps[0].item(), 3) + self.assertEqual(scheduler.timesteps[-1].item(), 0) + + def test_set_timesteps_single(self): + scheduler = self.get_scheduler() + scheduler.set_timesteps(1) + self.assertEqual(scheduler.num_inference_steps, 1) + self.assertEqual(len(scheduler.timesteps), 1) + self.assertEqual(scheduler.timesteps[0].item(), 0) + + def test_set_timesteps_invalid(self): + scheduler = self.get_scheduler() + with self.assertRaises(ValueError): + scheduler.set_timesteps(0) + with self.assertRaises(ValueError): + scheduler.set_timesteps(-1) + + # ------------------------------------------------------------------ + # Config round-trip + # ------------------------------------------------------------------ + def test_save_load_config_round_trip(self): + scheduler = self.get_scheduler() + with tempfile.TemporaryDirectory() as tmpdir: + scheduler.save_config(tmpdir) + loaded = DFlashTokenDiffusionScheduler.from_pretrained(tmpdir) + # The scheduler has no user-configurable params, but it should survive the round-trip. + self.assertIsInstance(loaded, DFlashTokenDiffusionScheduler) + self.assertEqual(loaded.order, 1) + + def test_from_config(self): + scheduler = self.get_scheduler() + new_scheduler = DFlashTokenDiffusionScheduler.from_config(scheduler.config) + self.assertIsInstance(new_scheduler, DFlashTokenDiffusionScheduler) + self.assertEqual(new_scheduler.order, 1) + + # ------------------------------------------------------------------ + # sample() – greedy + # ------------------------------------------------------------------ + def test_sample_greedy(self): + scheduler = self.get_scheduler() + logits = torch.tensor([[[1.0, 5.0, 2.0], [3.0, 1.0, 4.0]]]) # (1, 2, 3) + tokens = scheduler.sample(logits, temperature=0.0) + self.assertEqual(tokens.shape, (1, 2)) + self.assertEqual(tokens[0, 0].item(), 1) # argmax of [1,5,2] + self.assertEqual(tokens[0, 1].item(), 2) # argmax of [3,1,4] + + def test_sample_greedy_batched(self): + scheduler = self.get_scheduler() + logits = torch.tensor( + [ + [[10.0, 0.0], [0.0, 10.0]], + [[0.0, 10.0], [10.0, 0.0]], + ] + ) # (2, 2, 2) + tokens = scheduler.sample(logits, temperature=0.0) + self.assertEqual(tokens.shape, (2, 2)) + self.assertEqual(tokens[0, 0].item(), 0) + self.assertEqual(tokens[0, 1].item(), 1) + self.assertEqual(tokens[1, 0].item(), 1) + self.assertEqual(tokens[1, 1].item(), 0) + + # ------------------------------------------------------------------ + # sample() – multinomial + # ------------------------------------------------------------------ + def test_sample_multinomial(self): + scheduler = self.get_scheduler() + # One token has overwhelming probability; multinomial should pick it. + logits = torch.tensor([[[0.0, 100.0, -100.0]]]) # (1, 1, 3) + tokens = scheduler.sample(logits, temperature=1.0) + self.assertEqual(tokens.shape, (1, 1)) + self.assertEqual(tokens[0, 0].item(), 1) + + # ------------------------------------------------------------------ + # step() – return dict + # ------------------------------------------------------------------ + def test_step_all_accepted(self): + """All draft tokens match the posterior => accepted_length == block_size - 1.""" + scheduler = self.get_scheduler() + batch_size, block_size, vocab_size = 1, 4, 8 + + # Draft tokens: [0, 3, 3, 3] + draft_tokens = torch.tensor([[0, 3, 3, 3]], dtype=torch.long) + # Target logits: make argmax = [3, 3, 3, X] so posterior[:, :-1] matches draft[:, 1:] + logits = torch.zeros(batch_size, block_size, vocab_size) + logits[:, 0, 3] = 10.0 + logits[:, 1, 3] = 10.0 + logits[:, 2, 3] = 10.0 + logits[:, 3, 5] = 10.0 # last posterior token (next_token candidate) + + out = scheduler.step(draft_tokens, logits, temperature=0.0, return_dict=True) + + self.assertEqual(out.prev_sample.shape, (1, 4)) + self.assertEqual(out.accepted_length.shape, (1,)) + self.assertEqual(out.accepted_length[0].item(), 3) # all 3 comparisons match + self.assertEqual(out.next_token.shape, (1,)) + self.assertEqual(out.next_token[0].item(), 5) + self.assertEqual(out.posterior.shape, (1, 4)) + + def test_step_none_accepted(self): + """First draft token already mismatches => accepted_length == 0.""" + scheduler = self.get_scheduler() + batch_size, block_size, vocab_size = 1, 4, 8 + + draft_tokens = torch.tensor([[0, 1, 2, 3]], dtype=torch.long) + logits = torch.zeros(batch_size, block_size, vocab_size) + logits[:, 0, 5] = 10.0 # posterior[0] = 5, but draft[1] = 1 => mismatch + logits[:, 1, 2] = 10.0 + logits[:, 2, 3] = 10.0 + logits[:, 3, 4] = 10.0 + + out = scheduler.step(draft_tokens, logits, temperature=0.0, return_dict=True) + + self.assertEqual(out.accepted_length[0].item(), 0) + self.assertEqual(out.next_token[0].item(), 5) # posterior at index 0 + + def test_step_partial_accepted(self): + """First two match, third does not => accepted_length == 2.""" + scheduler = self.get_scheduler() + batch_size, block_size, vocab_size = 1, 5, 8 + + # draft: [0, 3, 4, 7, 2] + draft_tokens = torch.tensor([[0, 3, 4, 7, 2]], dtype=torch.long) + logits = torch.zeros(batch_size, block_size, vocab_size) + logits[:, 0, 3] = 10.0 # match draft[1]=3 + logits[:, 1, 4] = 10.0 # match draft[2]=4 + logits[:, 2, 0] = 10.0 # mismatch draft[3]=7 + logits[:, 3, 2] = 10.0 + logits[:, 4, 6] = 10.0 + + out = scheduler.step(draft_tokens, logits, temperature=0.0, return_dict=True) + + self.assertEqual(out.accepted_length[0].item(), 2) + self.assertEqual(out.next_token[0].item(), 0) # posterior at index 2 + + def test_step_single_token_block(self): + """Block with a single token => accepted_length == 0.""" + scheduler = self.get_scheduler() + draft_tokens = torch.tensor([[5]], dtype=torch.long) + logits = torch.zeros(1, 1, 8) + logits[:, 0, 3] = 10.0 + + out = scheduler.step(draft_tokens, logits, temperature=0.0, return_dict=True) + self.assertEqual(out.accepted_length[0].item(), 0) + self.assertEqual(out.next_token[0].item(), 3) + + # ------------------------------------------------------------------ + # step() – return tuple + # ------------------------------------------------------------------ + def test_step_return_tuple(self): + scheduler = self.get_scheduler() + draft_tokens = torch.tensor([[0, 1, 2]], dtype=torch.long) + logits = torch.randn(1, 3, 8) + + result = scheduler.step(draft_tokens, logits, temperature=0.0, return_dict=False) + + self.assertIsInstance(result, tuple) + self.assertEqual(len(result), 4) + prev_sample, accepted_length, next_token, posterior = result + self.assertEqual(prev_sample.shape, (1, 3)) + self.assertEqual(accepted_length.shape, (1,)) + self.assertEqual(next_token.shape, (1,)) + self.assertEqual(posterior.shape, (1, 3)) + + # ------------------------------------------------------------------ + # step() – batched + # ------------------------------------------------------------------ + def test_step_batched(self): + scheduler = self.get_scheduler() + batch_size, block_size, vocab_size = 3, 4, 16 + draft_tokens = torch.randint(0, vocab_size, (batch_size, block_size)) + logits = torch.randn(batch_size, block_size, vocab_size) + + out = scheduler.step(draft_tokens, logits, temperature=0.0, return_dict=True) + + self.assertEqual(out.prev_sample.shape, (batch_size, block_size)) + self.assertEqual(out.accepted_length.shape, (batch_size,)) + self.assertEqual(out.next_token.shape, (batch_size,)) + self.assertEqual(out.posterior.shape, (batch_size, block_size)) + + # ------------------------------------------------------------------ + # check_should_stop() + # ------------------------------------------------------------------ + def test_check_should_stop_no_stop_tokens(self): + output_ids = torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.long) + self.assertFalse(DFlashTokenDiffusionScheduler.check_should_stop(output_ids, None, 2)) + + def test_check_should_stop_found(self): + # Stop token 99 is in the generated portion (after num_input_tokens=2). + output_ids = torch.tensor([[1, 2, 3, 99, 5]], dtype=torch.long) + self.assertTrue(DFlashTokenDiffusionScheduler.check_should_stop(output_ids, [99], 2)) + + def test_check_should_stop_only_in_prompt(self): + # Stop token 1 is only in the prompt portion => should NOT stop. + output_ids = torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.long) + self.assertFalse(DFlashTokenDiffusionScheduler.check_should_stop(output_ids, [1], 2)) + + def test_check_should_stop_multiple_stop_tokens(self): + output_ids = torch.tensor([[10, 20, 30, 40, 50]], dtype=torch.long) + self.assertTrue(DFlashTokenDiffusionScheduler.check_should_stop(output_ids, [40, 99], 2)) + self.assertFalse(DFlashTokenDiffusionScheduler.check_should_stop(output_ids, [99, 100], 2)) + + # ------------------------------------------------------------------ + # add_noise() + # ------------------------------------------------------------------ + def test_add_noise_prompt_preserved(self): + scheduler = self.get_scheduler() + original = torch.tensor([[10, 11, 12, 13, 14, 15, 16, 17]], dtype=torch.long) + attention_mask = torch.ones_like(original) + mask_token_id = 99 + prompt_length = 3 + + gen = torch.Generator().manual_seed(42) + noisy, masked = scheduler.add_noise( + original, + attention_mask, + prompt_length=prompt_length, + block_size=4, + mask_token_id=mask_token_id, + generator=gen, + ) + + # Prompt positions should never be masked. + self.assertFalse(masked[0, :prompt_length].any().item()) + # Prompt tokens should be unchanged. + self.assertTrue(torch.equal(noisy[0, :prompt_length], original[0, :prompt_length])) + + def test_add_noise_masked_positions(self): + scheduler = self.get_scheduler() + original = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.long) + attention_mask = torch.ones_like(original) + mask_token_id = 99 + + gen = torch.Generator().manual_seed(0) + noisy, masked = scheduler.add_noise( + original, + attention_mask, + prompt_length=2, + block_size=3, + mask_token_id=mask_token_id, + generator=gen, + ) + + # Where masked is True, noisy should equal mask_token_id. + self.assertTrue((noisy[masked] == mask_token_id).all().item()) + # Where masked is False, noisy should equal original. + self.assertTrue(torch.equal(noisy[~masked], original[~masked])) + + def test_add_noise_respects_attention_mask(self): + scheduler = self.get_scheduler() + original = torch.tensor([[1, 2, 3, 4, 0, 0]], dtype=torch.long) + attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0]], dtype=torch.long) + mask_token_id = 99 + + gen = torch.Generator().manual_seed(42) + noisy, masked = scheduler.add_noise( + original, + attention_mask, + prompt_length=1, + block_size=3, + mask_token_id=mask_token_id, + generator=gen, + ) + + # Padding positions (attention_mask=0) should never be masked. + self.assertFalse(masked[0, 4].item()) + self.assertFalse(masked[0, 5].item()) + + def test_add_noise_output_shapes(self): + scheduler = self.get_scheduler() + batch_size, seq_len = 2, 10 + original = torch.randint(0, 50, (batch_size, seq_len)) + attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long) + mask_token_id = 99 + + noisy, masked = scheduler.add_noise( + original, + attention_mask, + prompt_length=2, + block_size=4, + mask_token_id=mask_token_id, + ) + + self.assertEqual(noisy.shape, (batch_size, seq_len)) + self.assertEqual(masked.shape, (batch_size, seq_len)) + self.assertEqual(noisy.dtype, torch.long) + self.assertEqual(masked.dtype, torch.bool) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/schedulers/test_scheduler_hybrid_token_diffusion.py b/tests/schedulers/test_scheduler_hybrid_token_diffusion.py new file mode 100644 index 000000000000..a3146698f22b --- /dev/null +++ b/tests/schedulers/test_scheduler_hybrid_token_diffusion.py @@ -0,0 +1,29 @@ +import unittest + +import torch + +from diffusers import HybridTokenDiffusionScheduler + + +class HybridTokenDiffusionSchedulerTest(unittest.TestCase): + def test_add_noise_and_step_shapes(self): + vocab_size = 32 + scheduler = HybridTokenDiffusionScheduler(vocab_size=vocab_size, mask_token_id=vocab_size - 1) + scheduler.set_timesteps(4, device="cpu") + + batch_size, seq_len = 2, 8 + x0 = torch.randint(0, vocab_size, (batch_size, seq_len), dtype=torch.long) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (batch_size,), dtype=torch.long) + x_t = scheduler.add_noise(x0, noise=None, timesteps=timesteps) + self.assertEqual(x_t.shape, x0.shape) + self.assertEqual(x_t.dtype, torch.long) + + logits = torch.zeros((batch_size, seq_len, vocab_size), dtype=torch.float32) + gen = torch.Generator().manual_seed(0) + out = scheduler.step(logits, scheduler.timesteps[0], x_t, generator=gen, return_dict=True) + self.assertEqual(out.prev_sample.shape, x0.shape) + self.assertTrue(((out.prev_sample >= 0) & (out.prev_sample < vocab_size)).all().item()) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/schedulers/test_scheduler_sdar_token_diffusion.py b/tests/schedulers/test_scheduler_sdar_token_diffusion.py new file mode 100644 index 000000000000..a51fe13514cb --- /dev/null +++ b/tests/schedulers/test_scheduler_sdar_token_diffusion.py @@ -0,0 +1,317 @@ +import tempfile +import unittest + +import torch + +from diffusers import SDARTokenDiffusionScheduler + + +class SDARTokenDiffusionSchedulerTest(unittest.TestCase): + def get_scheduler(self, **kwargs): + config = { + "block_length": 32, + "num_inference_steps": 8, + "remasking_strategy": "low_confidence_dynamic", + "confidence_threshold": 0.9, + "entropy_threshold": 0.35, + "temperature": 1.0, + "top_k": 0, + "top_p": 1.0, + } + config.update(kwargs) + return SDARTokenDiffusionScheduler(**config) + + # ------------------------------------------------------------------ + # set_timesteps + # ------------------------------------------------------------------ + def test_set_timesteps(self): + scheduler = self.get_scheduler() + scheduler.set_timesteps(8) + self.assertEqual(scheduler.num_inference_steps, 8) + self.assertEqual(len(scheduler.timesteps), 8) + self.assertEqual(scheduler.timesteps[0].item(), 7) + self.assertEqual(scheduler.timesteps[-1].item(), 0) + + def test_set_timesteps_invalid(self): + scheduler = self.get_scheduler() + with self.assertRaises(ValueError): + scheduler.set_timesteps(0) + + # ------------------------------------------------------------------ + # get_num_transfer_tokens + # ------------------------------------------------------------------ + def test_get_num_transfer_tokens_even(self): + scheduler = self.get_scheduler() + schedule = scheduler.get_num_transfer_tokens(block_length=32, num_inference_steps=8) + self.assertEqual(schedule.sum().item(), 32) + self.assertEqual(len(schedule), 8) + self.assertTrue((schedule == 4).all().item()) + + def test_get_num_transfer_tokens_remainder(self): + scheduler = self.get_scheduler() + schedule = scheduler.get_num_transfer_tokens(block_length=10, num_inference_steps=3) + self.assertEqual(schedule.sum().item(), 10) + self.assertEqual(len(schedule), 3) + # First `remainder` entries get +1 + self.assertEqual(schedule[0].item(), 4) + self.assertEqual(schedule[1].item(), 3) + self.assertEqual(schedule[2].item(), 3) + + # ------------------------------------------------------------------ + # save / load config round trip + # ------------------------------------------------------------------ + def test_save_load_config_round_trip(self): + scheduler = self.get_scheduler( + block_length=64, + remasking_strategy="sequential", + confidence_threshold=0.8, + entropy_threshold=0.5, + ) + with tempfile.TemporaryDirectory() as tmpdir: + scheduler.save_config(tmpdir) + loaded = SDARTokenDiffusionScheduler.from_pretrained(tmpdir) + + self.assertEqual(loaded.config.block_length, 64) + self.assertEqual(loaded.config.remasking_strategy, "sequential") + self.assertEqual(loaded.config.confidence_threshold, 0.8) + self.assertEqual(loaded.config.entropy_threshold, 0.5) + + # ------------------------------------------------------------------ + # from_config + # ------------------------------------------------------------------ + def test_from_config(self): + scheduler = self.get_scheduler(block_length=16, remasking_strategy="entropy_bounded") + new_scheduler = SDARTokenDiffusionScheduler.from_config(scheduler.config) + self.assertEqual(new_scheduler.config.block_length, 16) + self.assertEqual(new_scheduler.config.remasking_strategy, "entropy_bounded") + + # ------------------------------------------------------------------ + # step – remasking strategies + # ------------------------------------------------------------------ + def _make_step_inputs(self, batch_size=1, block_length=8, vocab_size=32, mask_id=31, num_steps=2): + sample = torch.full((batch_size, block_length), mask_id, dtype=torch.long) + logits = torch.zeros(batch_size, block_length, vocab_size) + for i in range(block_length): + logits[:, i, i % (vocab_size - 1)] = 10.0 - i # decreasing confidence + scheduler = self.get_scheduler(block_length=block_length, num_inference_steps=num_steps) + scheduler.set_timesteps(num_steps) + num_transfer_tokens = scheduler.get_num_transfer_tokens(block_length, num_steps) + return scheduler, logits, sample, num_transfer_tokens, mask_id + + def test_step_sequential(self): + scheduler, logits, sample, ntt, mask_id = self._make_step_inputs() + out = scheduler.step( + model_output=logits, + timestep=0, + sample=sample, + mask_token_id=mask_id, + num_transfer_tokens=ntt, + remasking_strategy="sequential", + temperature=0.0, + return_dict=True, + ) + # With 8 tokens and 2 steps, first step commits 4 tokens sequentially from the first mask + committed = out.transfer_index[0].sum().item() + self.assertEqual(committed, 4) + # Sequential: first 4 positions should be committed + self.assertTrue(out.transfer_index[0, :4].all().item()) + self.assertFalse(out.transfer_index[0, 4:].any().item()) + + def test_step_low_confidence_static(self): + scheduler, logits, sample, ntt, mask_id = self._make_step_inputs() + out = scheduler.step( + model_output=logits, + timestep=0, + sample=sample, + mask_token_id=mask_id, + num_transfer_tokens=ntt, + remasking_strategy="low_confidence_static", + temperature=0.0, + return_dict=True, + ) + committed = out.transfer_index[0].sum().item() + self.assertEqual(committed, 4) + + def test_step_low_confidence_dynamic(self): + scheduler, logits, sample, ntt, mask_id = self._make_step_inputs() + out = scheduler.step( + model_output=logits, + timestep=0, + sample=sample, + mask_token_id=mask_id, + num_transfer_tokens=ntt, + remasking_strategy="low_confidence_dynamic", + confidence_threshold=0.9, + temperature=0.0, + return_dict=True, + ) + # Should commit at least step_transfer tokens + committed = out.transfer_index[0].sum().item() + self.assertGreaterEqual(committed, 4) + + def test_step_entropy_bounded(self): + scheduler, logits, sample, ntt, mask_id = self._make_step_inputs() + out = scheduler.step( + model_output=logits, + timestep=0, + sample=sample, + mask_token_id=mask_id, + num_transfer_tokens=ntt, + remasking_strategy="entropy_bounded", + entropy_threshold=0.35, + temperature=0.0, + return_dict=True, + ) + committed = out.transfer_index[0].sum().item() + self.assertGreater(committed, 0) + + def test_step_unknown_strategy_raises(self): + scheduler, logits, sample, ntt, mask_id = self._make_step_inputs() + with self.assertRaises(ValueError): + scheduler.step( + model_output=logits, + timestep=0, + sample=sample, + mask_token_id=mask_id, + num_transfer_tokens=ntt, + remasking_strategy="nonexistent", + temperature=0.0, + ) + + # ------------------------------------------------------------------ + # step – output shapes + # ------------------------------------------------------------------ + def test_step_output_shapes(self): + batch_size, block_length, vocab_size = 2, 8, 32 + scheduler, logits, sample, ntt, mask_id = self._make_step_inputs( + batch_size=batch_size, block_length=block_length, vocab_size=vocab_size + ) + out = scheduler.step( + model_output=logits, + timestep=0, + sample=sample, + mask_token_id=mask_id, + num_transfer_tokens=ntt, + temperature=0.0, + return_dict=True, + ) + self.assertEqual(out.prev_sample.shape, (batch_size, block_length)) + self.assertEqual(out.transfer_index.shape, (batch_size, block_length)) + self.assertEqual(out.sampled_tokens.shape, (batch_size, block_length)) + self.assertEqual(out.sampled_probs.shape, (batch_size, block_length)) + + # ------------------------------------------------------------------ + # step – return_dict=False + # ------------------------------------------------------------------ + def test_step_return_tuple(self): + scheduler, logits, sample, ntt, mask_id = self._make_step_inputs() + result = scheduler.step( + model_output=logits, + timestep=0, + sample=sample, + mask_token_id=mask_id, + num_transfer_tokens=ntt, + temperature=0.0, + return_dict=False, + ) + self.assertIsInstance(result, tuple) + self.assertEqual(len(result), 4) + + # ------------------------------------------------------------------ + # step – batched + # ------------------------------------------------------------------ + def test_step_batched(self): + batch_size = 3 + scheduler, logits, sample, ntt, mask_id = self._make_step_inputs(batch_size=batch_size) + out = scheduler.step( + model_output=logits, + timestep=0, + sample=sample, + mask_token_id=mask_id, + num_transfer_tokens=ntt, + temperature=0.0, + return_dict=True, + ) + self.assertEqual(out.prev_sample.shape, (batch_size, 8)) + self.assertEqual(out.transfer_index.shape, (batch_size, 8)) + + # ------------------------------------------------------------------ + # sample – greedy and multinomial + # ------------------------------------------------------------------ + def test_sample_greedy(self): + scheduler = self.get_scheduler() + logits = torch.tensor([[[1.0, 5.0, 2.0]]]) # (1, 1, 3) + tokens, probs = scheduler.sample(logits, temperature=0.0) + self.assertEqual(tokens.item(), 1) + self.assertEqual(tokens.shape, (1, 1)) + self.assertEqual(probs.shape, (1, 1)) + + def test_sample_multinomial(self): + scheduler = self.get_scheduler() + logits = torch.tensor([[[0.0, 100.0, -100.0]]]) + gen = torch.Generator().manual_seed(42) + tokens, probs = scheduler.sample(logits, temperature=1.0, generator=gen) + self.assertEqual(tokens.item(), 1) + + # ------------------------------------------------------------------ + # check_should_stop + # ------------------------------------------------------------------ + def test_check_should_stop_with_stop_tokens(self): + scheduler = self.get_scheduler() + sequences = torch.tensor([[1, 2, 3, 99, 5]], dtype=torch.long) + self.assertTrue(scheduler.check_should_stop(sequences, prompt_length=2, stop_token_ids=[99])) + + def test_check_should_stop_without_stop_tokens(self): + scheduler = self.get_scheduler() + sequences = torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.long) + self.assertFalse(scheduler.check_should_stop(sequences, prompt_length=2, stop_token_ids=None)) + + def test_check_should_stop_no_match(self): + scheduler = self.get_scheduler() + sequences = torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.long) + self.assertFalse(scheduler.check_should_stop(sequences, prompt_length=2, stop_token_ids=[99])) + + def test_check_should_stop_in_prompt_only(self): + scheduler = self.get_scheduler() + # Stop token present only in the prompt region β€” should NOT trigger stop + sequences = torch.tensor([[99, 2, 3, 4, 5]], dtype=torch.long) + self.assertFalse(scheduler.check_should_stop(sequences, prompt_length=2, stop_token_ids=[99])) + + # ------------------------------------------------------------------ + # add_noise + # ------------------------------------------------------------------ + def test_add_noise(self): + scheduler = self.get_scheduler(block_length=4) + input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.long) + attention_mask = torch.ones_like(input_ids) + mask_token_id = 99 + + gen = torch.Generator().manual_seed(42) + noisy, noisy_rev, masked, masked_rev = scheduler.add_noise( + input_ids, + attention_mask, + prompt_length=2, + block_length=4, + mask_token_id=mask_token_id, + generator=gen, + ) + + # Prompt positions should never be masked + self.assertFalse(masked[0, 0].item()) + self.assertFalse(masked[0, 1].item()) + self.assertFalse(masked_rev[0, 0].item()) + self.assertFalse(masked_rev[0, 1].item()) + + # Noisy should have mask_token_id where masked is True + self.assertTrue((noisy[masked] == mask_token_id).all().item()) + self.assertTrue((noisy_rev[masked_rev] == mask_token_id).all().item()) + + # masked and masked_rev should be complementary within valid non-prompt positions + non_prompt = torch.zeros_like(masked) + non_prompt[0, 2:] = True + combined = masked | masked_rev + self.assertTrue((combined[0, 2:] == non_prompt[0, 2:]).all().item()) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/schedulers/test_scheduler_token_diffusion.py b/tests/schedulers/test_scheduler_token_diffusion.py new file mode 100644 index 000000000000..66738dceedc4 --- /dev/null +++ b/tests/schedulers/test_scheduler_token_diffusion.py @@ -0,0 +1,239 @@ +import unittest + +import torch + +from diffusers import TokenDiffusionScheduler + + +class TokenDiffusionSchedulerTest(unittest.TestCase): + def get_scheduler(self, **kwargs): + config = { + "vocab_size": 128, + "mask_token_id": 127, + "num_train_timesteps": 100, + "eps": 1e-3, + } + config.update(kwargs) + return TokenDiffusionScheduler(**config) + + def test_set_timesteps(self): + scheduler = self.get_scheduler() + scheduler.set_timesteps(10) + self.assertEqual(len(scheduler.timesteps), 10) + self.assertTrue((scheduler.timesteps[:-1] >= scheduler.timesteps[1:]).all().item()) + + def test_alpha_schedule_monotone_and_bounded(self): + # alpha(t) should be in (0, 1] and non-increasing in t for supported schedules. + schedules = ["log_linear", "linear", "cosine", "geometric"] + t = torch.linspace(0, 1, 33, dtype=torch.float32) + + for name in schedules: + scheduler = self.get_scheduler(alpha_schedule=name) + alpha = scheduler._alpha_t(t) + self.assertTrue(((alpha > 0) & (alpha <= 1)).all().item()) + # monotone non-increasing: alpha[i] >= alpha[i+1] + self.assertTrue((alpha[:-1] >= alpha[1:]).all().item()) + + def test_mdlm_weights_match_log_linear_1_over_t(self): + scheduler = self.get_scheduler(alpha_schedule="log_linear", eps=1e-3, num_train_timesteps=1000) + timesteps = torch.tensor([1, 10, 100, 999], dtype=torch.long) + w = scheduler.get_mdlm_loss_weights(timesteps).squeeze(-1) + t_cont = timesteps.to(dtype=torch.float32) / float(scheduler.num_train_timesteps - 1) + expected = 1.0 / t_cont + self.assertTrue(torch.allclose(w, expected, rtol=5e-5, atol=1e-5)) + + def test_add_noise_absorbing_keeps_shape_dtype(self): + scheduler = self.get_scheduler() + batch_size, seq_len = 4, 16 + x0 = torch.randint(0, scheduler.vocab_size, (batch_size, seq_len), dtype=torch.long) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (batch_size,), dtype=torch.long) + + xt = scheduler.add_noise(x0, noise=None, timesteps=timesteps) + self.assertEqual(xt.shape, x0.shape) + self.assertEqual(xt.dtype, torch.long) + + # xt values must be valid token ids + self.assertTrue(((xt >= 0) & (xt < scheduler.vocab_size)).all().item()) + + def test_step_preserves_unmasked_tokens(self): + scheduler = self.get_scheduler() + scheduler.set_timesteps(5) + + batch_size, seq_len = 2, 12 + x_t = torch.randint(0, scheduler.vocab_size, (batch_size, seq_len), dtype=torch.long) + x_t[:, :3] = scheduler.mask_token_id # ensure some masked positions + + # Model predicts uniform logits; step should never change already unmasked positions + logits = torch.zeros((batch_size, seq_len, scheduler.vocab_size), dtype=torch.float32) + out = scheduler.step(logits, scheduler.timesteps[0], x_t, return_dict=True) + x_prev = out.prev_sample + + self.assertTrue((x_prev[:, 3:] == x_t[:, 3:]).all().item()) + + def test_step_never_samples_mask_token(self): + scheduler = self.get_scheduler() + # Use a single inference step so the scheduler denoises to t=0 in one go (p_denoise = 1). + scheduler.set_timesteps(1) + + batch_size, seq_len = 2, 12 + x_t = torch.full((batch_size, seq_len), scheduler.mask_token_id, dtype=torch.long) + logits = torch.zeros((batch_size, seq_len, scheduler.vocab_size), dtype=torch.float32) + + gen = torch.Generator().manual_seed(0) + x_prev = scheduler.step(logits, scheduler.timesteps[0], x_t, generator=gen, return_dict=True).prev_sample + + # Mask token is forbidden as an x0 prediction, and the scheduler performs a final noise-removal step. + self.assertTrue((x_prev != scheduler.mask_token_id).all().item()) + + def test_uniform_add_noise_excludes_mask_if_configured(self): + scheduler = self.get_scheduler(forward_process="uniform", exclude_mask_from_uniform=True) + batch_size, seq_len = 8, 64 + x0 = torch.randint(0, scheduler.vocab_size, (batch_size, seq_len), dtype=torch.long) + # Make sure some originals are mask token too (uniform should still sample non-mask replacements). + x0[:, :5] = scheduler.mask_token_id + + # Use the noisiest time (highest replace probability). + timesteps = torch.full((batch_size,), scheduler.num_train_timesteps - 1, dtype=torch.long) + xt = scheduler.add_noise(x0, noise=None, timesteps=timesteps) + + # Mask token should be rare-to-absent under uniform corruption when excluded. + self.assertFalse((xt == scheduler.mask_token_id).any().item()) + + def test_uniform_step_runs_and_returns_valid_ids(self): + scheduler = self.get_scheduler(forward_process="uniform", exclude_mask_from_uniform=True) + scheduler.set_timesteps(2) + + batch_size, seq_len = 2, 16 + x_t = torch.randint(0, scheduler.vocab_size, (batch_size, seq_len), dtype=torch.long) + logits = torch.zeros((batch_size, seq_len, scheduler.vocab_size), dtype=torch.float32) + + gen = torch.Generator().manual_seed(0) + x_prev = scheduler.step(logits, scheduler.timesteps[0], x_t, generator=gen, return_dict=True).prev_sample + + self.assertEqual(x_prev.shape, x_t.shape) + self.assertTrue(((x_prev >= 0) & (x_prev < scheduler.vocab_size)).all().item()) + # With exclusion, mask token should not appear. + self.assertFalse((x_prev == scheduler.mask_token_id).any().item()) + + def test_alpha_helpers_shapes(self): + scheduler = self.get_scheduler(num_train_timesteps=10) + timesteps = torch.tensor([0, 1, 9], dtype=torch.long) + + alpha = scheduler.get_alpha(timesteps) + dalpha = scheduler.get_alpha_prime(timesteps) + + self.assertEqual(alpha.shape, (3, 1)) + self.assertEqual(dalpha.shape, (3, 1)) + + def test_set_timesteps_precomputes_alphas(self): + for schedule in ["log_linear", "linear", "cosine", "geometric"]: + scheduler = self.get_scheduler(alpha_schedule=schedule) + scheduler.set_timesteps(10) + + self.assertIsNotNone(scheduler.alphas) + self.assertEqual(len(scheduler.alphas), 10) + + # Verify pre-computed alphas match on-the-fly computation. + for i, ts in enumerate(scheduler.timesteps): + t = scheduler._t_from_timestep(int(ts.item()), device=torch.device("cpu")) + expected = scheduler._alpha_t(t).to(dtype=torch.float32) + self.assertTrue( + torch.allclose(scheduler.alphas[i].cpu(), expected, atol=1e-6), + f"Alpha mismatch at step {i} for schedule {schedule}", + ) + + def test_set_timesteps_builds_step_index_map(self): + scheduler = self.get_scheduler() + scheduler.set_timesteps(5) + + self.assertIsNotNone(scheduler._step_index_map) + self.assertEqual(len(scheduler._step_index_map), 5) + + for i, ts in enumerate(scheduler.timesteps): + self.assertEqual(scheduler._step_index_map[int(ts.item())], i) + + def test_step_respects_block_mask(self): + scheduler = self.get_scheduler() + scheduler.set_timesteps(1) + + batch_size, seq_len = 2, 8 + x = torch.full((batch_size, seq_len), scheduler.mask_token_id, dtype=torch.long) + block_mask = torch.zeros_like(x, dtype=torch.bool) + block_mask[:, :4] = True + + logits = torch.zeros((batch_size, seq_len, scheduler.config.vocab_size), dtype=torch.float32) + gen = torch.Generator().manual_seed(0) + out = scheduler.step(logits, scheduler.timesteps[0], x, generator=gen, return_dict=True, block_mask=block_mask) + + # Block positions should be denoised (non-mask) after the final noise-removal step. + self.assertTrue((out.prev_sample[:, :4] != scheduler.mask_token_id).all().item()) + # Outside the block, tokens should remain unchanged (still mask). + self.assertTrue((out.prev_sample[:, 4:] == scheduler.mask_token_id).all().item()) + + def test_add_noise_respects_block_mask(self): + scheduler = self.get_scheduler() + batch_size, seq_len = 2, 16 + x0 = torch.randint(0, scheduler.config.vocab_size - 1, (batch_size, seq_len), dtype=torch.long) + block_mask = torch.zeros_like(x0, dtype=torch.bool) + block_mask[:, :4] = True + + # Use high noise timestep so almost all block positions get noised. + timesteps = torch.full((batch_size,), scheduler.num_train_timesteps - 1, dtype=torch.long) + xt = scheduler.add_noise(x0, noise=None, timesteps=timesteps, block_mask=block_mask) + + # Outside the block, tokens must be unchanged. + self.assertTrue(torch.equal(xt[:, 4:], x0[:, 4:])) + + def test_step_without_block_mask_unchanged(self): + """Passing block_mask=None should produce the same result as before.""" + scheduler = self.get_scheduler() + scheduler.set_timesteps(3) + + batch_size, seq_len = 2, 8 + x_t = torch.full((batch_size, seq_len), scheduler.mask_token_id, dtype=torch.long) + logits = torch.randn((batch_size, seq_len, scheduler.config.vocab_size), dtype=torch.float32) + + gen1 = torch.Generator().manual_seed(42) + out1 = scheduler.step(logits, scheduler.timesteps[0], x_t, generator=gen1, return_dict=True) + + gen2 = torch.Generator().manual_seed(42) + out2 = scheduler.step(logits, scheduler.timesteps[0], x_t, generator=gen2, return_dict=True, block_mask=None) + + self.assertTrue(torch.equal(out1.prev_sample, out2.prev_sample)) + + def test_step_uses_precomputed_alphas_consistent_with_recompute(self): + """Verify step() produces identical results whether using pre-computed or recomputed alphas.""" + for process in ["absorbing", "uniform"]: + scheduler = self.get_scheduler(forward_process=process) + scheduler.set_timesteps(3) + + batch_size, seq_len = 2, 8 + if process == "absorbing": + x_t = torch.full((batch_size, seq_len), scheduler.mask_token_id, dtype=torch.long) + else: + x_t = torch.randint(0, scheduler.vocab_size, (batch_size, seq_len), dtype=torch.long) + logits = torch.randn((batch_size, seq_len, scheduler.vocab_size), dtype=torch.float32) + + gen1 = torch.Generator().manual_seed(42) + out1 = scheduler.step(logits, scheduler.timesteps[0], x_t, generator=gen1, return_dict=True) + + # Temporarily clear pre-computed alphas to force recomputation. + saved_alphas = scheduler.alphas + saved_map = scheduler._step_index_map + scheduler.alphas = None + scheduler._step_index_map = None + + gen2 = torch.Generator().manual_seed(42) + out2 = scheduler.step(logits, scheduler.timesteps[0], x_t, generator=gen2, return_dict=True) + + scheduler.alphas = saved_alphas + scheduler._step_index_map = saved_map + + self.assertTrue( + torch.equal(out1.prev_sample, out2.prev_sample), + f"Mismatch for forward_process={process}", + ) + + +if __name__ == "__main__": + unittest.main()