
feat: Add HiDream-O1 transformer and image generation pipeline #13749

Open

chinoll wants to merge 8 commits into huggingface:main from chinoll:hidream-o1-transformer-model

Conversation

@chinoll (Contributor) commented May 14, 2026

What does this PR do?

This PR adds Diffusers support for HiDream-O1 image generation.

HiDream-O1 is a Qwen3-VL based image generation model that denoises raw RGB image patches directly.
Unlike HiDream-I1 and most image diffusion pipelines, it does not use a VAE component.

This PR adds:

  • HiDreamO1Transformer2DModel, a ModelMixin / ConfigMixin wrapper for HiDream-O1 checkpoints.
  • HiDreamO1AttnProcessor, a dedicated attention processor for the HiDream-O1 two-pass attention path.
  • HiDreamO1ImagePipeline, a text-to-image pipeline for raw RGB patch denoising.
  • Loading support for official Transformers-style HiDream-O1 checkpoints.
  • A generation script at scripts/generate_hidream_o1_image.py.
  • API documentation for the model and pipeline.
  • Tests for model loading, serialization, attention processor behavior, official implementation parity, and pipeline smoke generation.

Original implementation and checkpoints:

Notes

HiDream-O1 does not use a VAE. The pipeline prepares Qwen3-VL chat inputs, constructs O1 multimodal RoPE positions, denoises patchified RGB noise, and unpatchifies the final tensor into image space.
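
For intuition, here is a minimal sketch of what the patchify/unpatchify step looks like for a raw-RGB patch model (illustrative shapes only; the pipeline's actual internal helpers may differ):

import torch

def patchify(images: torch.Tensor, patch_size: int) -> torch.Tensor:
    # images: (B, C, H, W) -> patches: (B, num_patches, C * patch_size * patch_size)
    b, c, h, w = images.shape
    patches = images.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
    return patches.permute(0, 2, 3, 1, 4, 5).reshape(b, -1, c * patch_size * patch_size)

def unpatchify(patches: torch.Tensor, patch_size: int, height: int, width: int) -> torch.Tensor:
    # patches: (B, num_patches, C * p * p) -> images: (B, C, H, W)
    b, n, d = patches.shape
    c = d // (patch_size * patch_size)
    gh, gw = height // patch_size, width // patch_size
    images = patches.reshape(b, gh, gw, c, patch_size, patch_size)
    return images.permute(0, 3, 1, 4, 2, 5).reshape(b, c, height, width)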

The transformer can also be loaded independently:

import torch
from diffusers import HiDreamO1Transformer2DModel

transformer = HiDreamO1Transformer2DModel.from_pretrained(
    "HiDream-ai/HiDream-O1-Image",
    torch_dtype=torch.bfloat16,
)

The pipeline can be loaded with:

import torch
from diffusers import HiDreamO1ImagePipeline

pipe = HiDreamO1ImagePipeline.from_pretrained(
    "HiDream-ai/HiDream-O1-Image",
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

image = pipe(
    "A cinematic portrait of a glass astronaut standing in a neon-lit botanical garden.",
    generator=torch.Generator("cuda").manual_seed(32),
).images[0]

For the dev checkpoint:

import torch
from diffusers import HiDreamO1ImagePipeline

pipe = HiDreamO1ImagePipeline.from_pretrained(
    "HiDream-ai/HiDream-O1-Image-Dev",
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

image = pipe(
    "A cinematic portrait of a glass astronaut standing in a neon-lit botanical garden.",
    model_type="dev",
    generator=torch.Generator("cuda").manual_seed(32),
).images[0]

Tests

PYTHONPATH=src python -m pytest -q tests/models/transformers/test_models_transformer_hidream_o1.py tests/pipelines/hidream_o1/test_pipeline_hidream_o1.py

Result:

5 passed, 6 warnings

I also ran real image generation tests with the full and dev checkpoints in bfloat16, including multiple aspect ratios.

Before submitting

Who can review?

@yiyixuxu, @asomoza, @sayakpaul

Generate Image

[Attached image: hidream_o1_processor_full_bf16_exact_2048x1024_50_seed32]

@github-actions bot added labels documentation (Improvements or additions to documentation), models, tests, utils, pipelines, and size/L (PR with diff > 200 LOC) on May 14, 2026
@chinoll changed the title from "Add HiDream-O1 transformer and image generation pipeline" to "feat: Add HiDream-O1 transformer and image generation pipeline" on May 14, 2026
@sayakpaul (Member) commented:

@claude could you do a review?

github-actions bot commented May 14, 2026

Claude finished @sayakpaul's task in 3m 44s.


Review of HiDream-O1 Transformer & Pipeline PR

  • Read review guidelines (AGENTS.md, models.md, pipelines.md, modular.md)
  • Read transformer model code
  • Read pipeline code
  • Read pipeline __init__.py and registration
  • Read tests
  • Trace call path (pipeline → model) for dead code analysis
  • Post review

Thanks for the contribution! This is a solid start for integrating HiDream-O1 into diffusers, and the parity testing approach is commendable. There are several architectural issues that need to be addressed before this can be merged, mostly around aligning with diffusers conventions.


Blocking Issues

1. Attention must use dispatch_attention_fn, not direct F.scaled_dot_product_attention / flash_attn

transformer_hidream_o1.py:42-61, 97-173

The HiDreamO1AttnProcessor directly imports flash_attn at module level and calls F.scaled_dot_product_attention in its _attention method. This violates the core diffusers attention pattern documented in .ai/models.md:

Attention must follow the diffusers pattern: the processor's __call__ must use dispatch_attention_fn rather than calling F.scaled_dot_product_attention directly.

No other transformer in the repo imports flash_attn directly. The correct approach is to use dispatch_attention_fn which handles backend selection transparently (flash, sage, SDPA, etc.). The attention class should also inherit AttentionModuleMixin and declare _default_processor_cls and _available_processors.

See transformer_qwenimage.py for the canonical pattern — it also handles Qwen-style attention and uses dispatch_attention_fn with proper backend/parallel config support.

The use_flash_attn parameter propagated through the model and pipeline would be replaced by the framework's attention backend selection.
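
A minimal sketch of the requested processor shape follows. It assumes the dispatch helper lives at diffusers.models.attention_dispatch and that the attention module exposes the usual to_q/to_k/to_v/to_out projections and a heads attribute; this is illustrative, not the PR's actual code:

import torch
from diffusers.models.attention_dispatch import dispatch_attention_fn


class HiDreamO1AttnProcessor:
    """Sketch of a processor that routes through dispatch_attention_fn."""

    _attention_backend = None

    def __call__(self, attn, hidden_states, attention_mask=None, **kwargs):
        # Project and reshape to (batch, seq_len, heads, head_dim).
        query = attn.to_q(hidden_states).unflatten(-1, (attn.heads, -1))
        key = attn.to_k(hidden_states).unflatten(-1, (attn.heads, -1))
        value = attn.to_v(hidden_states).unflatten(-1, (attn.heads, -1))

        # dispatch_attention_fn selects the configured backend (SDPA, flash, sage, ...)
        # instead of hard-coding F.scaled_dot_product_attention or flash_attn.
        hidden_states = dispatch_attention_fn(
            query, key, value, attn_mask=attention_mask, backend=self._attention_backend
        )

        hidden_states = hidden_states.flatten(2, 3).to(query.dtype)
        return attn.to_out[0](hidden_states)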


2. Environment variables for runtime configuration (os.environ)

transformer_hidream_o1.py:43-61 (FA_VERSION) and transformer_hidream_o1.py:68 (USE_BF16_ROPE)

No other transformer in the repo uses os.environ for configuration. FA_VERSION controls flash attention version selection and USE_BF16_ROPE changes RoPE precision. These should be removed entirely:

  • FA_VERSION becomes unnecessary once dispatch_attention_fn is used.
  • USE_BF16_ROPE should be a proper constructor parameter or removed if only one path is needed for inference.

3. Missing model class attributes

transformer_hidream_o1.py:766-767

The model is missing several standard class attributes that other transformers declare:

  • _repeated_blocks: Every transformer in the repo declares this (e.g., FluxTransformer2DModel, WanTransformer3DModel, QwenImageTransformer2DModel). Since this model wraps Qwen3VL decoder layers, it should list the appropriate block class name.
  • _skip_layerwise_casting_patterns: Needed for enable_layerwise_casting(). Should skip at minimum the patch embedding, timestep embedder, and normalization layers (e.g., ["x_embedder", "t_embedder", "norm", "patch"]).
  • _keep_in_fp32_modules: Consider listing precision-sensitive modules like the timestep embedder.
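
Roughly, the declarations would look like the sketch below; the repeated-block class name and the module patterns for HiDream-O1 are assumptions, not the PR's actual values:

class HiDreamO1Transformer2DModel(ModelMixin, ConfigMixin):
    _supports_gradient_checkpointing = True
    # Block class repeated across the depth of the model (name assumed here).
    _repeated_blocks = ["Qwen3VLDecoderLayer"]
    # Module-name patterns to skip when layerwise casting is enabled.
    _skip_layerwise_casting_patterns = ["x_embedder", "t_embedder", "norm", "patch"]
    # Optional: precision-sensitive modules kept in fp32.
    _keep_in_fp32_modules = ["t_embedder"]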

4. Pipeline uses custom _get_module_device / _get_module_dtype instead of framework utilities

pipeline_hidream_o1.py:275-284

DiffusionPipeline already provides self._execution_device (handles CPU offload correctly) and ModelMixin provides self.dtype. The custom _get_module_device and _get_module_dtype functions won't work correctly with enable_sequential_cpu_offload() or enable_model_cpu_offload(), breaking a core diffusers feature.

Replace:

device = _get_module_device(self.transformer)
dtype = _get_module_dtype(self.transformer)

with:

device = self._execution_device
dtype = self.transformer.dtype

5. Pipeline from_pretrained override is too complex and fragile

pipeline_hidream_o1.py:359-428

The custom from_pretrained with try/except fallback logic, manual kwarg filtering, and dual-mode loading (Diffusers vs Transformers checkpoint) is fragile. The standard diffusers pattern is for from_pretrained to just work via model_index.json. The Transformers-checkpoint loading should be a separate conversion script or handled in the model's own from_pretrained, not in the pipeline.

Other pipelines (Flux, Wan, QwenImage) do not override from_pretrained at all. The HiDreamO1Transformer2DModel.from_pretrained already handles both checkpoint formats — that's sufficient.

6. Custom denoising loop bypasses the scheduler

pipeline_hidream_o1.py:680-715

The pipeline inlines its own velocity/guidance math instead of letting the scheduler handle it properly:

v_cond = (x_pred_cond.float() - patches.float()) / sigma
...
model_output = -v_guided

Per .ai/pipelines.md Gotcha #3:

Check src/diffusers/schedulers/ before adding new logic. Reimplementing what the scheduler already does is a common mistake.

The velocity computation and guidance should work through the scheduler's step method. Also, the _set_timesteps helper manually sets scheduler.timesteps and scheduler.sigmas which is not standard — the scheduler should own its own state.
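
For reference, the standard scheduler-driven loop looks roughly like this (variable names are illustrative, not the PR's actual code):

self.scheduler.set_timesteps(num_inference_steps, device=device)
latents = noise  # patchified RGB noise for HiDream-O1

for t in self.scheduler.timesteps:
    pred_cond = self.transformer(latents, t, cond_embeds).sample
    pred_uncond = self.transformer(latents, t, uncond_embeds).sample
    model_output = pred_uncond + guidance_scale * (pred_cond - pred_uncond)

    # The scheduler owns the update rule (sample/velocity/epsilon handling).
    latents = self.scheduler.step(model_output, t, latents).prev_sample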

7. model_type parameter in __call__ is an anti-pattern

pipeline_hidream_o1.py:572

A model_type parameter that switches between hardcoded preset configurations in __call__ is not used anywhere else in diffusers. The full vs dev distinction should be handled by:

  • Different scheduler configs at construction time
  • Different default parameters in the checkpoint's config
  • Or simply documented as different num_inference_steps / guidance_scale / shift values the user passes

The user should configure these explicitly rather than through a magic string.
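
A usage sketch of configuring this explicitly; the step counts and guidance values below are placeholders, not the official HiDream-O1 defaults, and pipe/pipe_dev are assumed pipeline instances:

# Full checkpoint: the user passes the settings documented for it.
image = pipe(prompt, num_inference_steps=50, guidance_scale=5.0).images[0]

# Dev checkpoint: a separate pipeline instance with its own documented settings.
image = pipe_dev(prompt, num_inference_steps=28, guidance_scale=1.0).images[0]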


Non-Blocking Issues / Suggestions

8. Training-time dead code in _forward_generation

transformer_hidream_o1.py:461-482

The elif torch.is_grad_enabled(): branch creates fake pixel values and runs them through the vision encoder to keep gradients flowing during training. This is a training-time code path. Per .ai/AGENTS.md:

When porting from a research repo, delete training-time code paths entirely — only keep the inference path you are actually integrating.

9. Several output fields and dataclasses appear unused

transformer_hidream_o1.py:175-220

Three output dataclasses (HiDreamO1Transformer2DModelOutput, HiDreamO1Qwen3VLModelOutputWithPast, HiDreamO1Qwen3VLCausalLMOutputWithPast) carry fields (mid_results, cond_image_embeds, cond_deepstack_image_embeds) that are never used by the pipeline. The pipeline only reads outputs.sample which maps to x_pred. HiDreamO1ForConditionalGeneration and the lm_head on the wrapper model are also only used for the non-generation (language modeling) path and are dead code in the context of image generation.

Consider slimming down to only what the pipeline actually consumes.

10. Module-level constants duplicating configuration

pipeline_hidream_o1.py:33-82

PATCH_SIZE, T_EPS, FULL_NOISE_SCALE, DEV_FLASH_NOISE_SCALE, DEV_FLASH_NOISE_CLIP_STD, PREDEFINED_RESOLUTIONS, DEFAULT_TIMESTEPS are all hardcoded module-level constants. At minimum, PATCH_SIZE should come from the model config (self.transformer.config.patch_size).
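
For example:

# Read the patch size from the model config instead of a module-level constant.
patch_size = self.transformer.config.patch_size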

11. Missing noise_scale_start/noise_scale_end/noise_clip_std docstrings

pipeline_hidream_o1.py:573-575

These three __call__ parameters have no docstring entries despite being important for output quality.

12. _maybe_set_scheduler_shift is too defensive

pipeline_hidream_o1.py:287-294

This function probes multiple attribute names with hasattr fallbacks. The scheduler's API should be called directly — if the scheduler doesn't support shift, that's a user error, not something to silently handle.

13. initialize_weights called in __init__

transformer_hidream_o1.py:229-234

HiDreamO1BottleneckPatchEmbed.__init__ calls self.initialize_weights() and HiDreamO1FinalLayer.__init__ calls self.apply(self._init_weights). These will be overwritten when loading pretrained weights, so they're no-ops for inference. Not harmful, but unnecessary complexity.

14. Script at scripts/generate_hidream_o1_image.py imports from the official repo

scripts/generate_hidream_o1_image.py:127-146

The script imports FlashFlowMatchEulerDiscreteScheduler, FlowUniPCMultistepScheduler, and get_rope_index_fix_point from the official HiDream-O1 repo via sys.path manipulation. Per review rules, scripts with hardcoded developer paths or imports from the reference repo are ephemeral context that shouldn't be merged. This script should either use only diffusers components or be removed.

15. Gradient checkpointing wiring

transformer_hidream_o1.py:361, 386-398

The model declares _supports_gradient_checkpointing = True but the checkpointing branch reads text_model.gradient_checkpointing (from the inner Qwen3VL language model) rather than self.gradient_checkpointing (from the ModelMixin wrapper). When a user calls model.enable_gradient_checkpointing(), it sets self.gradient_checkpointing = True on the ModelMixin model, but the inner _run_decoder_two_pass_attention checks a different attribute. This likely means gradient checkpointing silently doesn't work.
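
The expected wiring looks roughly like the sketch below; the inner loop and layer call signature are assumptions:

for layer in text_model.layers:
    if torch.is_grad_enabled() and self.gradient_checkpointing:
        # ModelMixin's helper, toggled by model.enable_gradient_checkpointing().
        hidden_states = self._gradient_checkpointing_func(layer, hidden_states, attention_mask)
    else:
        hidden_states = layer(hidden_states, attention_mask)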

16. Tests use hardcoded /tmp/ paths

test_models_transformer_hidream_o1.py:163

repo_root = os.environ.get("HIDREAM_O1_OFFICIAL_REPO", "/tmp/HiDream-O1-Image")

Per .ai/skills/parity-testing/pitfalls.md Pitfall #10: "NEVER store converted models in /tmp/". The parity test's default path points to /tmp/.
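
One way to require the variable explicitly instead of defaulting to /tmp (a sketch; the test's actual structure may differ):

import os

import pytest

repo_root = os.environ.get("HIDREAM_O1_OFFICIAL_REPO")
if repo_root is None:
    pytest.skip("Set HIDREAM_O1_OFFICIAL_REPO to run the official-implementation parity test.")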


Suggestions / Additional Info: Dead Code Analysis

Tracing the pipeline's __call__ → _forward_transformer → self.transformer() → HiDreamO1Transformer2DModel.forward() → HiDreamO1Qwen3VLModel.forward() → _forward_generation():

The pipeline always passes vinputs, so it always hits _forward_generation. The pipeline never passes:

  • pixel_values, pixel_values_videos, image_grid_thw, video_grid_thw → All the vision-encoder branches in _forward_generation (lines 444-515) are unreachable from the pipeline
  • precomputed_image_embeds, precomputed_deepstack_image_embeds → Dead
  • return_mid_results_layers → Always None, so mid_results tracking is dead
  • labels, logits_to_keep, past_key_values, inputs_embeds, cache_position → Only used in the LM path, dead for image generation

Methods on HiDreamO1Transformer2DModel that are never called from the pipeline:

  • get_image_features(), get_video_features() — only useful for the vision-encoder path
  • get_input_embeddings() — useful for prompt building but called on processor, not the model

HiDreamO1ForConditionalGeneration (with lm_head, GenerationMixin, loss computation, etc.) is only used as an intermediate in from_pretrained for checkpoint conversion. The lm_head on the wrapper HiDreamO1Transformer2DModel is loaded but never used by the pipeline.

Note: These findings are based on the default pipeline call path. The dead code may be useful for other contexts (e.g., reference-image generation, fine-tuning), but under the current pipeline, it represents significant dead weight. Per .ai/AGENTS.md, only the inference path being integrated should be kept.


@chinoll (Contributor, Author) commented May 14, 2026

Thanks for the review. I pushed follow-up changes up to f9e374a59 on chinoll:hidream-o1-transformer-model.

Summary by review item:

  1. Attention backend: addressed. HiDreamO1AttnProcessor now uses dispatch_attention_fn, and the attention module follows the diffusers processor pattern with AttentionModuleMixin, _default_processor_cls, and _available_processors.
  2. Runtime env vars: addressed. The flash-attn / RoPE environment-variable switches were removed.
  3. Model class attributes: addressed for the required attributes. Added _repeated_blocks, _skip_layerwise_casting_patterns, and _no_split_modules. I did not add _keep_in_fp32_modules yet because it was phrased as optional; I can add it if maintainers prefer a specific list.
  4. Pipeline device/dtype helpers: addressed. The pipeline now uses self._execution_device and self.transformer.dtype.
  5. Pipeline from_pretrained: addressed. The custom pipeline override was removed; model loading compatibility stays in HiDreamO1Transformer2DModel.from_pretrained.
  6. Scheduler / denoising: addressed. The pipeline no longer mutates scheduler.timesteps / scheduler.sigmas directly, and now routes custom schedules through scheduler.set_timesteps(...). The default scheduler uses prediction_type="sample", so the model's x0 prediction is passed to scheduler.step() directly instead of converting to a hand-written velocity prediction.
  7. model_type: addressed. The magic model_type preset switch was removed from the pipeline call; full/dev settings are passed explicitly by users or by the helper script.
  8. Training-time branch: intentionally kept for now. The public HiDream-O1 checkpoints are Qwen3-VL-style checkpoints, not a pruned image-only checkpoint, and keeping this code path helps preserve compatibility with the official module structure while loading the official weights directly.
  9. Output fields / LM path / lm_head: intentionally kept for now. The official checkpoints include the CausalLM head and related Qwen3-VL weights. Removing lm_head / CausalLM-compatible structures would require key filtering or a conversion-only loading path, and would break direct loading from the official checkpoint format. The text-to-image pipeline does not use these fields, but they are retained to keep checkpoint compatibility and avoid silently dropping official weights.
  10. Module constants: partially addressed. Removed obsolete scheduler constants from the pipeline. PATCH_SIZE and the official resolution buckets remain because HiDream-O1 is a raw-RGB patch model with fixed official generation buckets; I can move PATCH_SIZE reads to self.transformer.config.patch_size in a follow-up if preferred.
  11. Missing docstrings: addressed. Added docs for noise_scale_start, noise_scale_end, and noise_clip_std.
  12. Scheduler shift helper: addressed in f9e374a59. The helper is now explicit: it supports schedulers with flow_shift config or set_shift(), and raises a ValueError otherwise instead of silently probing/falling through.
  13. Init-time weight initialization: addressed. Removed the extra initialize_weights() / _init_weights() calls from the O1 patch and final layers.
  14. Generation script official-repo dependency: addressed. The script now uses diffusers components and no longer imports schedulers or helpers from the official repository.
  15. Gradient checkpointing: addressed. The two-pass decoder path now uses the model's _gradient_checkpointing_func. I also verified locally that HiDreamO1Transformer2DModel.enable_gradient_checkpointing() propagates to the inner Qwen3-VL language model used by this path.
  16. /tmp in tests: addressed. The parity test now requires HIDREAM_O1_OFFICIAL_REPO explicitly and no longer defaults to /tmp/HiDream-O1-Image.

Validation run locally:

  • python -m py_compile ...
  • PYTHONPATH=src python -m pytest -q tests/pipelines/hidream_o1/test_pipeline_hidream_o1.py tests/models/transformers/test_models_transformer_hidream_o1.py -> 5 passed, 1 skipped
  • HIDREAM_O1_OFFICIAL_REPO=/tmp/HiDream-O1-Image PYTHONPATH=src python -m pytest -q tests/models/transformers/test_models_transformer_hidream_o1.py tests/pipelines/hidream_o1/test_pipeline_hidream_o1.py -> 5 passed

I also smoke-tested Dev bf16 generation on a CUDA machine at 1024x1024 with both the official Dev timestep schedule and the default 28-step schedule; both generated images successfully.

@chinoll (Contributor, Author) commented May 14, 2026

@claude review

@yiyixuxu (Collaborator) left a comment

Thanks for opening the PR.
I have two pieces of high-level feedback:

  1. Can you host the transformer in the transformers library using remote code?
  2. Can you host the pipeline in Modular Diffusers only? Documentation is here: https://huggingface.co/docs/diffusers/main/en/modular_diffusers/overview. We also have pretty good docs for AI agents on this, so they should have a good idea of how to turn it into a modular pipeline.
