Skip to content

support qwen3.5(vit) inference in turbomind backend#4602

Open
irexyc wants to merge 11 commits into
InternLM:mainfrom
irexyc:qwen3.5-2
Open

support qwen3.5(vit) inference in turbomind backend#4602
irexyc wants to merge 11 commits into
InternLM:mainfrom
irexyc:qwen3.5-2

Conversation

@irexyc
Copy link
Copy Markdown
Collaborator

@irexyc irexyc commented May 20, 2026

Thanks for your contribution and we appreciate it a lot. The following instructions would make your pull request more healthy and more easily receiving feedbacks. If you do not understand some items, don't worry, just make the pull request and seek help from maintainers.

Motivation

Please describe the motivation of this PR and the goal you want to achieve through this PR.

Modification

Please briefly describe what modification is made in this PR.

BC-breaking (Optional)

Does the modification introduce changes that break the backward-compatibility of the downstream repositories?
If so, please describe how it breaks the compatibility and how the downstream projects should modify their code to keep compatibility with this PR.

Use cases (Optional)

If this PR introduces a new feature, it is better to list some use cases here, and update the documentation.

Checklist

  1. Pre-commit or other linting tools are used to fix the potential lint issues.
  2. The modification is covered by complete unit tests. If not, please add more unit tests to ensure the correctness.
  3. If the modification has a dependency on downstream projects of a newer version, this PR should be tested with all supported versions of downstream projects.
  4. The documentation has been modified accordingly, like docstring or example tutorials.

irexyc and others added 9 commits May 15, 2026 10:28
Port the Qwen3.5-VLM C++ infrastructure hooks onto the irexyc/main
turbomind tree: VisualModel-aware engine pipeline (Engine owns a
VisualModel run before the language model), InputProcessor multimodal
embedding patching, ModelRoot visual child, checked_child_cast for
abstract bases, multimodal_input value type, and bind.cpp/turbomind.cc
wiring.

Reworks RoPE: drops RopeType::kMrope in favour of an orthogonal
MropeMode (kNone/kChunked/kInterleaved), adds interleaved mrope and
non-causal attention (causal flag threaded through attention kernels,
desc and reference). mrope device buffers are borrowed from env when
the C++ vision encoder produced them, else lazily allocated for the
legacy r.inputs path.

Adapted to irexyc/main's post-a4025b9 conventions (TM_CUDA_CHECK,
scope tracing, void LlamaLinear::Forward with Ref<Tensor> output).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…ion)

Add the native Qwen3.5 vision encoder for the turbomind backend:
qwen3_5vit/* (patch embed, fast pos/rotary embeds, fused embed merge,
mrope position ids, qkv preprocess, bias-gelu, ViT/block weights),
the VisualModel/VisualModelWeight base, LayerNormWeight, the
norm/layer_norm kernel, and CMake wiring.

New CUDA/C++ written against the pre-a4025b9 convention is migrated to
the consolidated error handling: sync_check_cuda_error() ->
TM_CUDA_CHECK(cudaGetLastError()), returning LlamaLinear::Forward() ->
void + Ref<Tensor> output wrapped in TM_SCOPE_CALL, TM_FUNCTION_SCOPE()
at the model entry, TM_CHECK(0) << msg -> TM_LOG_FATAL(...), and the
core/logger.h include the TM_DISPATCH_* macros now require.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Add the Python turbomind loader for Qwen3.5: _Qwen3_5Model (text) and
the registered Qwen3_5Model (text + native ViT) extending it, the
VisualModelBuilder / LayerNormBuilder, and converter support for
disable_vision_encoder (maps qwen3_5 -> _qwen3_5) and the _vision
aggregate config selection.

Adapted to the 01ddf16 modeling infrastructure already on
irexyc/main: _Qwen3_5Model mirrors Qwen3_5TextModel's topology with
_uses_prefix removed and the DeltaNet builder calls de-qkv_split'd
(builder derives the split internally); _cpp_dtype dropped in favour
of the torch.dtype path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Wire Qwen3.5 VLM through the user-facing paths: enable the turbomind
backend for Qwen3_5/Qwen3_5Moe (keep InternS2-Preview gated), thread
disable_vision_encoder through archs/cli/messages/pipeline/api_server,
and add the _turbomind_native_vision path so the native ViT bypasses
the legacy Python vision preprocessor.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
MSVC odr-uses constexpr locals read from an enclosing lambda and then
rejects them as non-type template arguments, breaking the Windows wheel
build with C2672 on kernel::LayerNorm / kernel::ResidualBiasLayerNorm.
Redeclare kThreads/kVecSize inside the nested launch lambdas to match the
working pattern in rms_norm.cu (RMSNormQK).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Copilot AI review requested due to automatic review settings May 20, 2026 08:07
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR adds native Qwen3.5 (ViT) multimodal inference support to the TurboMind backend by introducing a C++ “visual model” sub-graph (weights + runtime), wiring multimodal inputs through the Python bindings/engine request path, and extending attention/RoPE codepaths to support the required non-causal ViT attention and updated M-RoPE behavior.

Changes:

  • Introduce VisualModel / VisualModelWeight and integrate an optional visual encoder into TurboMind engine construction and execution.
  • Implement the Qwen3.5 ViT encoder runtime + supporting CUDA kernels (pos-embed interpolation helpers, RoPE table generation, QKV preprocessing, mrope position-id scatter, fused embed merge, bias+GELU, LayerNorm).
  • Extend RoPE and attention kernels to support MropeMode (chunked/interleaved) orthogonally to base RoPE type and add non-causal prefill attention kernel variants; update Python pipeline pieces to pass multimodal payloads and to optionally disable the vision encoder.

Reviewed changes

Copilot reviewed 94 out of 94 changed files in this pull request and generated 3 comments.

Show a summary per file
File Description
src/turbomind/turbomind.cc Create optional visual runtime from weights and adjust global communicator sizing for outer DP.
src/turbomind/python/bind.cpp Add multimodal Value pybind + conversion and pass mm_inputs through ModelRequest.forward.
src/turbomind/models/visual_model.h Define VisualModel base + multimodal data/embedding carrier types.
src/turbomind/models/visual_model.cc Add translation unit for VisualModel (linker anchor).
src/turbomind/models/visual_model_weight.h Define VisualModelWeight base and make_model() factory for visual runtimes.
src/turbomind/models/qwen3_5vit/test_mrope_position_ids.cu Add standalone CUDA test for mrope position-id scatter kernel.
src/turbomind/models/qwen3_5vit/qwen3_5vit.h Declare Qwen3.5 ViT VisualModel runtime.
src/turbomind/models/qwen3_5vit/qwen3_5vit.cc Implement Qwen3.5 ViT runtime (setup/prepare/forward) incl. mrope table production and embedding injection.
src/turbomind/models/qwen3_5vit/qwen3_5vit_weight.h Add Qwen3.5 ViT visual weight tree + config.
src/turbomind/models/qwen3_5vit/qwen3_5vit_weight.cc Implement Qwen3.5 ViT weight verification/prepare and runtime factory; register module.
src/turbomind/models/qwen3_5vit/qwen3_5vit_block_weight.h Add per-block ViT weight module + config.
src/turbomind/models/qwen3_5vit/qwen3_5vit_block_weight.cc Register block weight module.
src/turbomind/models/qwen3_5vit/qkv_preprocess.h Declare fused bias+RoPE+transpose QKV preprocess kernel entry point.
src/turbomind/models/qwen3_5vit/qkv_preprocess.cu Implement QKV preprocess kernel with head-dim specializations.
src/turbomind/models/qwen3_5vit/mrope_position_ids.h Declare mrope segment descriptor and scatter API.
src/turbomind/models/qwen3_5vit/mrope_position_ids.cu Implement mrope position-id scatter kernel.
src/turbomind/models/qwen3_5vit/fused_embed_merge.h Declare fused pos-embed merge kernel.
src/turbomind/models/qwen3_5vit/fused_embed_merge.cu Implement fused pos-embed merge kernel.
src/turbomind/models/qwen3_5vit/fast_rotary_pos_emb.h Declare fast rotary table generator for vision tokens.
src/turbomind/models/qwen3_5vit/fast_rotary_pos_emb.cu Implement fast rotary table generator.
src/turbomind/models/qwen3_5vit/fast_pos_embed.h Declare bilinear gather index/weight generator for pos-embed interpolation.
src/turbomind/models/qwen3_5vit/fast_pos_embed.cu Implement bilinear gather index/weight generator.
src/turbomind/models/qwen3_5vit/bias_gelu.h Declare in-place bias+activation kernel for ViT.
src/turbomind/models/qwen3_5vit/bias_gelu.cu Implement bias+GELU kernels (pytorch-tanh + erf variants).
src/turbomind/models/model_root.h Extend ModelRoot with optional visual_model child accessor.
src/turbomind/models/llama/unified_attention_layer.cc Allow mrope to be layered over base rope and borrow mrope tensors from env when produced by vision runtime.
src/turbomind/models/llama/SequenceManager.h Store per-sequence multimodal inputs for later scheduling/compute.
src/turbomind/models/llama/llama_rope.h Replace kMrope rope type with orthogonal MropeMode (none/chunked/interleaved).
src/turbomind/models/layer_norm_weight.h Add LayerNorm weight module (gamma+beta) and config.
src/turbomind/models/layer_norm_weight.cc Implement LayerNorm weight prepare + registration.
src/turbomind/models/language_model.cc Route embedding patching through TensorMap& env (to allow multimodal injection).
src/turbomind/models/input_processor.h Update patch embedding API to accept TensorMap& env.
src/turbomind/models/input_processor.cc Patch embeddings with both input-embedding ranges and optional multimodal embeddings from env.
src/turbomind/models/CMakeLists.txt Build/link new visual model sources and layer-norm kernel; add mrope test target.
src/turbomind/models/attention_weight.h Add rope flag for interleaved mrope and a causal flag to attention config.
src/turbomind/models/attention_weight.cc Initialize MropeMode orthogonally to base rope type; plumb causal flag.
src/turbomind/kernels/norm/layer_norm.h Declare LayerNorm kernel APIs.
src/turbomind/kernels/norm/layer_norm.cu Implement LayerNorm + residual-bias LayerNorm CUDA kernels.
src/turbomind/kernels/norm/CMakeLists.txt Build layer_norm kernel library.
src/turbomind/kernels/attention/test_attention.cu Extend attention test to cover causal vs non-causal paths.
src/turbomind/kernels/attention/rotary_embedding.h Update FastRoPE to support MropeMode layered over base rope type.
src/turbomind/kernels/attention/registry.cu Include causal in attention kernel selection.
src/turbomind/kernels/attention/reference.h Add causal flag to reference attention implementation.
src/turbomind/kernels/attention/reference.cu Generalize mask generation for causal/non-causal reference.
src/turbomind/kernels/attention/mainloop_sm80.h Template masking logic on causal/non-causal behavior and handle partial K tiles.
src/turbomind/kernels/attention/mainloop_sm70.h Template masking logic on causal/non-causal behavior and handle partial K tiles.
src/turbomind/kernels/attention/kernel/attention_sm80_64.cu Register non-causal kernel variants.
src/turbomind/kernels/attention/kernel/attention_sm80_576.cu Register non-causal kernel variants.
src/turbomind/kernels/attention/kernel/attention_sm80_256.cu Register non-causal kernel variants.
src/turbomind/kernels/attention/kernel/attention_sm80_192.cu Register non-causal kernel variants.
src/turbomind/kernels/attention/kernel/attention_sm80_128.cu Register non-causal kernel variants.
src/turbomind/kernels/attention/kernel/attention_sm75_64.cu Register non-causal kernel variants.
src/turbomind/kernels/attention/kernel/attention_sm75_576.cu Register non-causal kernel variants.
src/turbomind/kernels/attention/kernel/attention_sm75_256.cu Register non-causal kernel variants.
src/turbomind/kernels/attention/kernel/attention_sm75_128.cu Register non-causal kernel variants.
src/turbomind/kernels/attention/kernel/attention_sm70_64.cu Register non-causal kernel variants.
src/turbomind/kernels/attention/kernel/attention_sm70_576.cu Register non-causal kernel variants.
src/turbomind/kernels/attention/kernel/attention_sm70_256.cu Register non-causal kernel variants.
src/turbomind/kernels/attention/kernel/attention_sm70_128.cu Register non-causal kernel variants.
src/turbomind/kernels/attention/kernel_impl.h Record causal attribute in kernel descriptor.
src/turbomind/kernels/attention/desc.h Add causal flag to descriptors and naming.
src/turbomind/kernels/attention/decoding.cu Force decoding kernels to remain causal.
src/turbomind/kernels/attention/attention.cu Populate descriptor causal flag from params.
src/turbomind/kernels/attention/attention_universal.h Template attention universal kernel on causal/non-causal and adjust tiling/masking.
src/turbomind/kernels/attention/attention_template.h Adjust tile-count computation for non-causal mode.
src/turbomind/kernels/attention/attention_params.h Add causal flag to attention params.
src/turbomind/kernels/activation.h Add GELU activation types for ViT paths.
src/turbomind/engine/request.h Add mm_inputs to Request structure (currently not serialized).
src/turbomind/engine/multimodal_input.h Introduce multimodal::Value schema for structured multimodal payloads.
src/turbomind/engine/model_request.h Add multimodal input pointer to request forwarding parameters.
src/turbomind/engine/model_request.cc Store mm_inputs into the created Request.
src/turbomind/engine/model_executor.h Allow executor to run an optional visual model before the language model.
src/turbomind/engine/model_executor.cc Execute visual prepare/forward around language model execution.
src/turbomind/engine/engine.h Extend Engine ctor to accept optional VisualModel.
src/turbomind/engine/engine.cc Invoke visual model for relevant batch ops and pass it into the executor.
src/turbomind/core/module.h Allow child attachment to accept derived instances when the declared child type is abstract.
lmdeploy/vl/model/qwen3_5.py Mark Qwen3.5 VLM as TurboMind-native vision capable.
lmdeploy/vl/model/preprocess_utils.py Ensure expanded multimodal items are sorted by offset.
lmdeploy/vl/model/builder.py Skip building Python vision model when TurboMind-native vision is available.
lmdeploy/vl/model/base.py Add _turbomind_native_vision capability flag to base vision model.
lmdeploy/turbomind/turbomind.py Add multimodal argument and pass it into _turbomind forward call.
lmdeploy/turbomind/models/utils.py Treat mrope as orthogonal to base rope type and always propagate mrope section when present.
lmdeploy/turbomind/models/qwen3_5.py Add TurboMind builders for Qwen3.5 vision sub-tree and enable interleaved mrope; implement ViT head-dim padding.
lmdeploy/turbomind/models/init.py Export Qwen3.5 vision model symbol.
lmdeploy/turbomind/converter.py Pass disable_vision_encoder only to models that support it.
lmdeploy/turbomind/builders/layer_norm.py Add Python builder for affine LayerNorm weights.
lmdeploy/turbomind/builders/_base.py Add VisualModelBuilder that attaches visual_model under ModelRoot.
lmdeploy/turbomind/builders/init.py Export new builders and config helpers.
lmdeploy/serve/processors/multimodal.py Skip Python-side VL inference path when using TurboMind-native vision.
lmdeploy/serve/openai/api_server.py Pass backend_config into get_task so disable_vision_encoder can affect routing.
lmdeploy/pipeline.py Pass backend_config into get_task so disable_vision_encoder can affect routing.
lmdeploy/messages.py Add disable_vision_encoder to TurboMind engine config.
lmdeploy/cli/serve.py Wire --disable-vision-encoder into TurboMind server args.
lmdeploy/archs.py Make get_task aware of disable_vision_encoder and update TurboMind VL arch allowlist.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +301 to +306
if (py::hasattr(obj, "__dlpack__")) {
py::capsule cap = obj.attr("__dlpack__")();
auto* dlmt = static_cast<DLManagedTensor*>(PyCapsule_GetPointer(cap.ptr(), kDlTensorCapsuleName));
auto ret = DLManagedTensorToTritonTensor(dlmt);
cap.set_name("used_dltensor");
return Value{std::move(*ret)};
// reference to IO tensors
TensorMap inputs;
TensorMap outputs;
// TODO: update serdes to support multiple nodes inference
Comment on lines +210 to +223
if (r.mm_inputs && !r.mm_inputs->is_null()) {
if ((not r.session.start_flag) or (not r.session.end_flag)) {
// only support non-interactive inference
return Request::kInvalid;
}

const auto& mm_inputs = r.mm_inputs->get<multimodal::Array>();
for (const auto& item : mm_inputs) {
const auto& doc = item.get<multimodal::Document>();
const auto& modality = doc.at("modality").get<std::string>();
const auto& ranges = doc.at("offset").get<multimodal::Array>();
const int offset = ranges[0].get<int64_t>();
const int tokens = ranges[1].get<int64_t>() - offset;

@lvhan028 lvhan028 added the enhancement New feature or request label May 20, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement New feature or request

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants