llama/compat: add qwen3-vl + qwen2.5-vl handlers

Two QwenVL families that ship with monolithic vision+text Ollama GGUFs.

qwen2.5-vl (text + clip):
  * Text: arch translation qwen25vl→qwen2vl, KV prefix copy, mrope_section
    (3 elems) → rope.dimension_sections (4 elems, padded with 0).
  * Clip: tensor renames (v.merger.* → v.post_ln/mm.0/mm.2,
    v.patch_embd_{0,1} → v.patch_embd.weight{,.1}), use_silu, derive
    n_wa_pattern from fullatt_block_indexes[0]+1, image_size=560 default,
    standard CLIP image_mean/std, F32 promote of patch_embd for Metal.

qwen3-vl (text + clip):
  * Text: inject qwen3vl.rope.dimension_sections=[24,20,20,0] (Qwen3-VL-8B
    HF default — head_dim=128, so the sections sum to head_dim/2 = 64) and qwen3vl.n_deepstack_layers
    derived from deepstack_visual_indexes array length.
  * Clip: per-block QKV merge (Ollama stores separate Q/K/V; upstream
    qwen3vl_merger graph reads combined attn_qkv); deepstack remap
    v.deepstack_merger.X.* → v.deepstack.{indexes[X]}.* (with
    linear_fc{1,2} → fc{1,2}); merger renames matching upstream LLaVA
    proj layout (v.post_ln + mm.0/mm.2); patch_embed split (16x16x2
    Conv3D → two 16x16 Conv2Ds, F16→F32); per-block substring renames
    (norm1/2 → ln1/2, mlp.linear_fc1/2 → ffn_up/down); F32 promote of
    position_embd.

The qwen3-vl QKV merge needed a new util helper
(register_concat_load_to_f32) because the Ollama blob mixes types within
a single block (F16 Q/K + Q8_0 V) — the existing byte-concat
register_concat_load only works when all sources share a type. The new
helper dequantizes each source to F32 via ggml_get_type_traits->to_float
and concatenates; caller sets the destination tensor type to F32.

Verified end-to-end via `ollama run`:
  * qwen2.5vl: image of NYT moon-landing front page → correct caption.
  * qwen3-vl:  same image → "New York Times front page from July 21,
    1969, headlines the Apollo 11 moon landing with 'MEN WALK ON MOON'
    and details of astronauts collecting rocks and planting a flag."

Also added qwen25vl + qwen3vl to the compatClipArches allowlist in
llm/llama_server.go so the auto-mmproj path activates for these
monolithic blobs (same mechanism as gemma3/4/qwen35moe/etc.).
This commit is contained in:
jmorganca 2026-04-19 19:41:05 -07:00
parent cd2dcaff49
commit 1ce8a6b26d
4 changed files with 387 additions and 0 deletions

View file

@ -319,4 +319,44 @@ void register_concat_load(const gguf_context * meta, std::string dest_name,
});
}
void register_concat_load_to_f32(const gguf_context * meta,
                                 const ggml_context * ctx,
                                 std::string dest_name,
                                 const std::vector<std::string> & src_names) {
    // Registers a load op that reads each source tensor's raw bytes from the
    // GGUF blob, dequantizes it to F32, and concatenates the F32 arrays into
    // the destination. Caller must have set the destination tensor's type to
    // GGML_TYPE_F32 so dst_size equals the F32 concat size.
    //
    // Describes one source tensor's raw byte span inside the blob.
    struct Region { size_t offset; size_t size; ggml_type type; size_t n_elem; };
    std::vector<Region> regions;
    regions.reserve(src_names.size());
    for (const auto & n : src_names) {
        const int64_t id = gguf_find_tensor(meta, n.c_str());
        if (id < 0) return; // any missing source -> register nothing
        const ggml_tensor * t = ggml_get_tensor(const_cast<ggml_context *>(ctx), n.c_str());
        if (!t) return;
        regions.push_back({
            gguf_get_data_offset(meta) + gguf_get_tensor_offset(meta, id),
            gguf_get_tensor_size(meta, id),
            t->type,
            (size_t) ggml_nelements(t),
        });
    }
    register_load_op(std::move(dest_name), LoadOp{
        [regions](const char * path, void * dst, size_t dst_size) {
            size_t total_elems = 0;
            for (const auto & r : regions) total_elems += r.n_elem;
            if (total_elems * sizeof(float) != dst_size) return false;
            float * dp = static_cast<float *>(dst);
            std::vector<uint8_t> src; // reused scratch buffer across regions
            for (const auto & r : regions) {
                src.resize(r.size);
                if (!read_at(path, r.offset, src.data(), r.size)) return false;
                if (r.type == GGML_TYPE_F32) {
                    // ggml's type-traits table exposes no to_float for F32
                    // (there is nothing to convert), so the generic path below
                    // would fail for already-F32 sources — common for biases.
                    // Copy the floats straight through instead.
                    if (r.size != r.n_elem * sizeof(float)) return false;
                    const float * sp = reinterpret_cast<const float *>(src.data());
                    for (size_t i = 0; i < r.n_elem; ++i) dp[i] = sp[i];
                } else {
                    const auto * tt = ggml_get_type_traits(r.type);
                    if (!tt || !tt->to_float) return false;
                    tt->to_float(src.data(), dp, (int64_t) r.n_elem);
                }
                dp += r.n_elem;
            }
            return true;
        },
        "concat sources (mixed types -> F32)",
    });
}
} // namespace llama_ollama_compat::detail

View file

@ -117,4 +117,14 @@ void promote_tensor_to_f32(gguf_context * meta, ggml_context * ctx, const char *
void register_concat_load(const gguf_context * meta, std::string dest_name,
const std::vector<std::string> & src_names);
// Mixed-type variant of register_concat_load: dequantizes each source to
// F32 via its ggml_type_traits.to_float and concatenates the F32 arrays.
// Use when sources differ in quantization (e.g. F16 q/k + Q8_0 v in some
// Ollama vision blobs). Caller must set the destination tensor's type to
// GGML_TYPE_F32 so dst_size matches the F32 concat size.
// If any source tensor is missing from `meta`/`ctx`, no load op is registered.
void register_concat_load_to_f32(const gguf_context * meta,
                                 const ggml_context * ctx,
                                 std::string dest_name,
                                 const std::vector<std::string> & src_names);
} // namespace llama_ollama_compat::detail

View file

@ -717,6 +717,101 @@ void handle_mistral3(const llama_model_loader * ml, gguf_context * meta, ggml_co
add_skip_prefix(ml, "mm.");
}
// =========================================================================
// qwen25vl (text side — Qwen2.5-VL, arch translation to qwen2vl)
// =========================================================================
//
// Ollama publishes Qwen2.5-VL with general.architecture=qwen25vl, but
// upstream loads both Qwen2-VL and Qwen2.5-VL under arch=qwen2vl
// (which reads the rope.dimension_sections KV for M-RoPE).
//
// Translation:
// * arch_name: qwen25vl → qwen2vl (loader uses arch as KV prefix)
// * KV prefix: qwen25vl.* → qwen2vl.*
// * rope.mrope_section (3 elements) → rope.dimension_sections (4, padded with 0)
// * Hide vision+projector tensors from the text loader.
// True iff the blob declares general.architecture == "qwen25vl"
// (Ollama's naming; upstream uses "qwen2vl" for both families).
bool detect_ollama_qwen25vl(const gguf_context * meta) {
    const int64_t key_id = gguf_find_key(meta, "general.architecture");
    if (key_id >= 0) {
        return std::strcmp(gguf_get_val_str(meta, key_id), "qwen25vl") == 0;
    }
    return false;
}
// Translate an Ollama qwen25vl text model so it loads under upstream's
// qwen2vl architecture (which implements M-RoPE). No-op for other blobs.
void handle_qwen25vl(const llama_model_loader * ml, gguf_context * meta,
                     ggml_context * ctx, std::string & arch_name) {
    (void) ctx;
    if (!detect_ollama_qwen25vl(meta)) return;
    LLAMA_LOG_INFO("%s: detected Ollama-format qwen25vl GGUF; translating to qwen2vl\n", __func__);
    // Retarget the loader at qwen2vl: it will use the qwen2vl build path and
    // read KVs under the qwen2vl.* prefix.
    arch_name = "qwen2vl";
    gguf_set_val_str(meta, "general.architecture", "qwen2vl");
    // Duplicate every qwen25vl.* KV under qwen2vl.*; the stale originals stay
    // behind but nothing reads them.
    rename_kv_prefix(meta, "qwen25vl.", "qwen2vl.");
    // qwen2vl wants four rope dimension sections; the Ollama blob ships three
    // mrope sections, so append a trailing zero.
    const int64_t mrope_kid = gguf_find_key(meta, "qwen2vl.rope.mrope_section");
    if (mrope_kid >= 0 && gguf_get_arr_n(meta, mrope_kid) >= 3) {
        const auto * sections = static_cast<const int32_t *>(gguf_get_arr_data(meta, mrope_kid));
        const int32_t dims[4] = { sections[0], sections[1], sections[2], 0 };
        gguf_set_arr_data(meta, "qwen2vl.rope.dimension_sections",
                          GGUF_TYPE_INT32, dims, 4);
    }
    // Keep the vision tower and projector out of the text model loader.
    add_skip_prefix(ml, "v.");
    add_skip_prefix(ml, "mm.");
}
// =========================================================================
// qwen3vl (text side — Qwen3-VL)
// =========================================================================
//
// Ollama publishes Qwen3-VL with general.architecture=qwen3vl (matches
// upstream). Two missing KVs that the upstream qwen3vl loader requires:
//
// * qwen3vl.rope.dimension_sections — M-RoPE section sizes. Derived from
// the HF config (rope_scaling.mrope_section). Hardcoded here as
// [24, 20, 20, 0] which matches Qwen3-VL-8B (head_dim=128, sum=64).
// If new Qwen3-VL variants ship with a different mrope, derive from
// the head_dim or read from a published KV.
//
// * qwen3vl.n_deepstack_layers — count of deepstack adapters. Length of
// qwen3vl.vision.deepstack_visual_indexes (3 for Qwen3-VL-8B).
// True iff this is an Ollama-converted qwen3vl blob: arch matches upstream's
// name, but the rope.dimension_sections KV (always present in
// upstream-converted GGUFs) is missing.
bool detect_ollama_qwen3vl(const gguf_context * meta, const ggml_context * ctx) {
    (void) ctx;
    const int64_t key_id = gguf_find_key(meta, "general.architecture");
    if (key_id < 0 || std::strcmp(gguf_get_val_str(meta, key_id), "qwen3vl") != 0) {
        return false;
    }
    return !has_key(meta, "qwen3vl.rope.dimension_sections");
}
// Inject the two KVs the upstream qwen3vl text loader requires but Ollama's
// conversion omits; hide vision tensors. No-op for non-Ollama blobs.
void handle_qwen3vl(const llama_model_loader * ml, gguf_context * meta, ggml_context * ctx) {
    (void) ctx;
    if (!detect_ollama_qwen3vl(meta, ctx)) return;
    LLAMA_LOG_INFO("%s: detected Ollama-format qwen3vl GGUF; applying compatibility fixes\n", __func__);
    // M-RoPE section sizes, hardcoded to the Qwen3-VL-8B HF default
    // (head_dim=128; the sections sum to head_dim/2 = 64).
    static const int32_t kMropeSections[4] = { 24, 20, 20, 0 };
    gguf_set_arr_data(meta, "qwen3vl.rope.dimension_sections",
                      GGUF_TYPE_INT32, kMropeSections, 4);
    // Deepstack adapter count == length of the deepstack visual indexes array
    // (0 if the array is absent).
    const int64_t idx_kid = gguf_find_key(meta, "qwen3vl.vision.deepstack_visual_indexes");
    const uint32_t n_adapters = idx_kid >= 0 ? (uint32_t) gguf_get_arr_n(meta, idx_kid) : 0;
    inject_u32_if_missing(meta, "qwen3vl.n_deepstack_layers", n_adapters);
    // Keep the vision tower and projector out of the text model loader.
    add_skip_prefix(ml, "v.");
    add_skip_prefix(ml, "mm.");
}
// =========================================================================
// gemma3 (clip side)
// =========================================================================
@ -1410,6 +1505,234 @@ void handle_mistral3_clip(gguf_context * meta, ggml_context * ctx) {
promote_tensor_to_f32(meta, ctx, "v.patch_embd.weight");
}
// =========================================================================
// qwen25vl (clip side — Qwen2.5-VL vision tower + merger)
// =========================================================================
//
// Ollama qwen25vl has a vision tower with mostly upstream-compatible
// tensor names. Five tensor renames + KV translation:
//
// v.merger.ln_q.weight → v.post_ln.weight (post-tower norm)
// v.merger.mlp.0.{weight,bias} → mm.0.{weight,bias} (LLaVA proj 0)
// v.merger.mlp.2.{weight,bias} → mm.2.{weight,bias} (LLaVA proj 2)
// v.patch_embd_0.weight → v.patch_embd.weight (slice 0)
// v.patch_embd_1.weight → v.patch_embd.weight.1 (slice 1)
//
// The KV side maps qwen25vl.vision.* → clip.vision.*, sets the projector
// type and use_silu, derives n_wa_pattern from fullatt_block_indexes[0]+1
// (per upstream's qwen2.5vl converter), and supplies image_size=560 and
// projection_dim (= text embedding_length, qwen25vl.embedding_length).
// Translate the Ollama qwen25vl vision tower + merger into the clip.* KV /
// tensor layout that upstream's qwen2.5vl_merger projector expects.
void handle_qwen25vl_clip(gguf_context * meta, ggml_context * ctx) {
    LLAMA_LOG_INFO("%s: detected Ollama-format qwen25vl GGUF used as mmproj; translating\n", __func__);
    // Mirror the vision hyperparameters under the clip.vision.* prefix the
    // mmproj loader reads.
    copy_u32_kv(meta, "qwen25vl.vision.attention.head_count", "clip.vision.attention.head_count");
    copy_f32_kv(meta, "qwen25vl.vision.attention.layer_norm_epsilon", "clip.vision.attention.layer_norm_epsilon");
    copy_u32_kv(meta, "qwen25vl.vision.block_count", "clip.vision.block_count");
    copy_u32_kv(meta, "qwen25vl.vision.embedding_length", "clip.vision.embedding_length");
    copy_u32_kv(meta, "qwen25vl.vision.num_channels", "clip.vision.num_channels");
    copy_u32_kv(meta, "qwen25vl.vision.patch_size", "clip.vision.patch_size");
    copy_u32_kv(meta, "qwen25vl.vision.spatial_merge_size", "clip.vision.spatial_merge_size");
    copy_u32_kv(meta, "qwen25vl.vision.window_size", "clip.vision.window_size");
    // Projection dim = the text model's embedding width.
    copy_u32_kv(meta, "qwen25vl.embedding_length", "clip.vision.projection_dim");
    // No feed_forward_length KV? Read it off the first block's ffn_up weight.
    if (!has_key(meta, "clip.vision.feed_forward_length")) {
        ggml_tensor * ffn_up = ggml_get_tensor(ctx, "v.blk.0.ffn_up.weight");
        if (ffn_up) {
            gguf_set_val_u32(meta, "clip.vision.feed_forward_length", (uint32_t) ffn_up->ne[1]);
        }
    }
    // n_wa_pattern = fullatt_block_indexes[0] + 1, per upstream's qwen2.5vl
    // converter convention.
    const int64_t fullatt_kid = gguf_find_key(meta, "qwen25vl.vision.fullatt_block_indexes");
    if (fullatt_kid >= 0 && gguf_get_arr_n(meta, fullatt_kid) >= 1) {
        const auto * indexes = static_cast<const int32_t *>(gguf_get_arr_data(meta, fullatt_kid));
        gguf_set_val_u32(meta, "clip.vision.n_wa_pattern", (uint32_t)(indexes[0] + 1));
    }
    // The HF config carries no image_size; 560 is the Qwen2VLVisionModel default.
    inject_u32_if_missing(meta, "clip.vision.image_size", 560);
    // Standard CLIP-convention normalization constants.
    static const float kImageMean[3] = {0.48145466f, 0.4578275f, 0.40821073f};
    static const float kImageStd [3] = {0.26862954f, 0.26130258f, 0.27577711f};
    inject_f32_arr_if_missing(meta, "clip.vision.image_mean", kImageMean, 3);
    inject_f32_arr_if_missing(meta, "clip.vision.image_std",  kImageStd,  3);
    inject_bool_if_missing(meta, "clip.has_vision_encoder", true);
    inject_bool_if_missing(meta, "clip.use_silu", true);
    gguf_set_val_str(meta, "clip.projector_type", "qwen2.5vl_merger");
    gguf_set_val_str(meta, "general.architecture", "clip");
    // Merger MLP -> LLaVA-style projector names; patch-embed slices -> the
    // upstream .weight / .weight.1 pair.
    rename_tensor(meta, ctx, "v.merger.ln_q.weight",  "v.post_ln.weight");
    rename_tensor(meta, ctx, "v.merger.mlp.0.weight", "mm.0.weight");
    rename_tensor(meta, ctx, "v.merger.mlp.0.bias",   "mm.0.bias");
    rename_tensor(meta, ctx, "v.merger.mlp.2.weight", "mm.2.weight");
    rename_tensor(meta, ctx, "v.merger.mlp.2.bias",   "mm.2.bias");
    rename_tensor(meta, ctx, "v.patch_embd_0.weight", "v.patch_embd.weight");
    rename_tensor(meta, ctx, "v.patch_embd_1.weight", "v.patch_embd.weight.1");
    // Metal's IM2COL needs F32 patch embeddings (same issue as gemma3 / glmocr).
    promote_tensor_to_f32(meta, ctx, "v.patch_embd.weight");
    promote_tensor_to_f32(meta, ctx, "v.patch_embd.weight.1");
}
// =========================================================================
// qwen3vl (clip side — Qwen3-VL vision tower + deepstack adapters)
// =========================================================================
//
// Ollama qwen3vl monolithic GGUF embeds the vision tower (27 blocks),
// deepstack merger adapters (3 of them, indexed 0/1/2), and the merger
// MLP. Compared to upstream's qwen3vl_merger expectations:
//
// * Per-block leaf renames: norm1→ln1, norm2→ln2, mlp.linear_fc1→ffn_up,
// mlp.linear_fc2→ffn_down.
// * Merger renames: v.merger.norm→v.post_ln, v.merger.linear_fc1→mm.0,
// v.merger.linear_fc2→mm.2 (LLaVA proj).
// * Deepstack remap: v.deepstack_merger.X.* → v.deepstack.{indexes[X]}.*
// where indexes is qwen3vl.vision.deepstack_visual_indexes (e.g.
// [8, 16, 24] for Qwen3-VL-8B). The leaf names also rename:
// linear_fc1→fc1, linear_fc2→fc2.
// * Per-block QKV merge: upstream's qwen3vl graph reads a single
// attn_qkv tensor (shape [hidden, 3*hidden]); Ollama stores separate
// Q/K/V. Same merge as qwen35moe — reuse that helper.
// * Patch embed: split the merged Conv3D weight [W,H,T,OUT*IN] into two
// Conv2D weights [W,H,IN,OUT], one per temporal slice. Same logic and
// donor (orphaned attn_k from QKV merge) as qwen35moe; reuse that helper.
// Translate the Ollama qwen3vl vision tower + deepstack adapters + merger
// into the layout upstream's qwen3vl_merger projector expects. Ordering is
// load-bearing: QKV merge -> patch-embed split -> full-name renames ->
// deepstack renames -> substring renames (see comments at each step).
void handle_qwen3vl_clip(gguf_context * meta, ggml_context * ctx) {
    LLAMA_LOG_INFO("%s: detected Ollama-format qwen3vl GGUF used as mmproj; translating\n", __func__);
    // Mirror the vision hyperparameters under the clip.vision.* prefix.
    copy_u32_kv(meta, "qwen3vl.vision.attention.head_count", "clip.vision.attention.head_count");
    copy_f32_kv(meta, "qwen3vl.vision.attention.layer_norm_epsilon", "clip.vision.attention.layer_norm_epsilon");
    copy_u32_kv(meta, "qwen3vl.vision.block_count", "clip.vision.block_count");
    copy_u32_kv(meta, "qwen3vl.vision.embedding_length", "clip.vision.embedding_length");
    copy_u32_kv(meta, "qwen3vl.vision.num_channels", "clip.vision.num_channels");
    copy_u32_kv(meta, "qwen3vl.vision.patch_size", "clip.vision.patch_size");
    copy_u32_kv(meta, "qwen3vl.vision.spatial_merge_size", "clip.vision.spatial_merge_size");
    // Projection dim = the text model's embedding width.
    copy_u32_kv(meta, "qwen3vl.embedding_length", "clip.vision.projection_dim");
    // Derive feed_forward_length from ffn_up / mlp.linear_fc1 shape.
    if (!has_key(meta, "clip.vision.feed_forward_length")) {
        if (ggml_tensor * t = ggml_get_tensor(ctx, "v.blk.0.mlp.linear_fc1.weight")) {
            gguf_set_val_u32(meta, "clip.vision.feed_forward_length", (uint32_t) t->ne[1]);
        }
    }
    // image_size = sqrt(num_position_embeddings) * patch_size. v.pos_embed
    // shape is [n_embd, num_positions], so num_positions = ne[1].
    if (!has_key(meta, "clip.vision.image_size")) {
        ggml_tensor * pe = ggml_get_tensor(ctx, "v.pos_embed.weight");
        const int64_t patch_kid = gguf_find_key(meta, "qwen3vl.vision.patch_size");
        if (pe && patch_kid >= 0) {
            const uint32_t patch = gguf_get_val_u32(meta, patch_kid);
            // Truncating sqrt: assumes num_positions is a perfect square
            // (square position-embedding grid) — TODO confirm for all variants.
            const uint32_t side = (uint32_t) std::sqrt((double) pe->ne[1]);
            gguf_set_val_u32(meta, "clip.vision.image_size", side * patch);
        }
    }
    // Image mean/std (Qwen3-VL uses [0.5, 0.5, 0.5] for both, per HF config).
    static const float kHalfHalfHalf[3] = {0.5f, 0.5f, 0.5f};
    inject_f32_arr_if_missing(meta, "clip.vision.image_mean", kHalfHalfHalf, 3);
    inject_f32_arr_if_missing(meta, "clip.vision.image_std", kHalfHalfHalf, 3);
    inject_bool_if_missing(meta, "clip.has_vision_encoder", true);
    inject_bool_if_missing(meta, "clip.use_gelu", true);
    gguf_set_val_str(meta, "clip.projector_type", "qwen3vl_merger");
    gguf_set_val_str(meta, "general.architecture", "clip");
    // Per-block QKV merge: upstream's qwen3vl_merger graph reads a single
    // `v.blk.X.attn_qkv.weight` (shape [hidden, 3*hidden]) — Ollama stores
    // separate Q/K/V. Unlike qwen35moe (where Q/K/V are uniformly F16), the
    // qwen3vl Ollama blob can mix F16 (Q/K) with Q8_0 (V), so a raw byte
    // concat fails. Dequantize all three to F32 and concat in F32 instead.
    // After the merge, attn_k/attn_v become orphaned in the clip ctx, which
    // the patch_embed split then reclaims for `v.patch_embd.weight.1`.
    const int64_t n_blocks_key = gguf_find_key(meta, "clip.vision.block_count");
    // Fallback 27 matches the Qwen3-VL tower size described above.
    const uint32_t n_blocks = n_blocks_key >= 0 ? gguf_get_val_u32(meta, n_blocks_key) : 27;
    for (uint32_t b = 0; b < n_blocks; ++b) {
        char q[64], k[64], v[64], qb[64], kb[64], vb[64], qkv_w[64], qkv_b[64];
        std::snprintf(q, sizeof(q), "v.blk.%u.attn_q.weight", b);
        std::snprintf(k, sizeof(k), "v.blk.%u.attn_k.weight", b);
        std::snprintf(v, sizeof(v), "v.blk.%u.attn_v.weight", b);
        std::snprintf(qb, sizeof(qb), "v.blk.%u.attn_q.bias", b);
        std::snprintf(kb, sizeof(kb), "v.blk.%u.attn_k.bias", b);
        std::snprintf(vb, sizeof(vb), "v.blk.%u.attn_v.bias", b);
        std::snprintf(qkv_w, sizeof(qkv_w), "v.blk.%u.attn_qkv.weight", b);
        std::snprintf(qkv_b, sizeof(qkv_b), "v.blk.%u.attn_qkv.bias", b);
        // Block already merged (or absent): skip.
        if (!ggml_get_tensor(ctx, q)) continue;
        register_concat_load_to_f32(meta, ctx, qkv_w, {q, k, v});
        register_concat_load_to_f32(meta, ctx, qkv_b, {qb, kb, vb});
        // Repoint the Q tensor's entry at the merged name; attn_k/attn_v
        // stay behind as orphans (reclaimed by the patch-embed split below).
        rename_tensor(meta, ctx, q, qkv_w);
        if (ggml_tensor * t = ggml_get_tensor(ctx, qkv_w)) {
            // Triple ne[1] to hold the stacked Q/K/V rows; the registered
            // load op fills the tensor in F32, so retype to match.
            set_tensor_shape(t, {t->ne[0], t->ne[1] * 3});
            set_tensor_type (t, GGML_TYPE_F32);
        }
        rename_tensor(meta, ctx, qb, qkv_b);
        if (ggml_tensor * t = ggml_get_tensor(ctx, qkv_b)) {
            set_tensor_shape(t, {t->ne[0] * 3});
            set_tensor_type (t, GGML_TYPE_F32);
        }
    }
    // Patch embed split runs BEFORE per-block substring renames so it can
    // find the source by name `v.patch_embed.weight`. Same shape as
    // qwen35moe (16x16 patches, 2 temporal slices, 3 in_ch, 1152 out_ch).
    register_qwen35moe_patch_embed_split(meta, ctx);
    // Top-level renames (full names) — must run before substring per-block
    // renames so .linear_fc1 substring matches only inside .mlp.linear_fc1.
    rename_tensor(meta, ctx, "v.merger.norm.weight", "v.post_ln.weight");
    rename_tensor(meta, ctx, "v.merger.norm.bias", "v.post_ln.bias");
    rename_tensor(meta, ctx, "v.merger.linear_fc1.weight", "mm.0.weight");
    rename_tensor(meta, ctx, "v.merger.linear_fc1.bias", "mm.0.bias");
    rename_tensor(meta, ctx, "v.merger.linear_fc2.weight", "mm.2.weight");
    rename_tensor(meta, ctx, "v.merger.linear_fc2.bias", "mm.2.bias");
    rename_tensor(meta, ctx, "v.patch_embed.bias", "v.patch_embd.bias");
    rename_tensor(meta, ctx, "v.pos_embed.weight", "v.position_embd.weight");
    // Deepstack remap: v.deepstack_merger.X.{norm,linear_fc1,linear_fc2}.{weight,bias}
    // → v.deepstack.{deepstack_visual_indexes[X]}.{norm,fc1,fc2}.{weight,bias}.
    // Upstream stores deepstack tensors at the absolute clip layer index
    // (e.g. v.deepstack.8.* for the adapter that fires after layer 8).
    {
        const int64_t ds_kid = gguf_find_key(meta, "qwen3vl.vision.deepstack_visual_indexes");
        if (ds_kid >= 0) {
            const size_t n = gguf_get_arr_n(meta, ds_kid);
            const auto * idx = static_cast<const int32_t *>(gguf_get_arr_data(meta, ds_kid));
            for (size_t i = 0; i < n; ++i) {
                char from[GGML_MAX_NAME], to[GGML_MAX_NAME];
                // Helper: rename one leaf of adapter i to its upstream name
                // at the absolute layer index idx[i].
                auto rn = [&](const char * leaf_from, const char * leaf_to, const char * suffix) {
                    std::snprintf(from, sizeof(from), "v.deepstack_merger.%zu.%s.%s", i, leaf_from, suffix);
                    std::snprintf(to, sizeof(to), "v.deepstack.%d.%s.%s", idx[i], leaf_to, suffix);
                    rename_tensor(meta, ctx, from, to);
                };
                rn("norm", "norm", "weight");
                rn("norm", "norm", "bias");
                rn("linear_fc1", "fc1", "weight");
                rn("linear_fc1", "fc1", "bias");
                rn("linear_fc2", "fc2", "weight");
                rn("linear_fc2", "fc2", "bias");
            }
        }
    }
    // Per-block substring renames (safe — these substrings now only appear
    // in v.blk.X.* paths after the top-level/deepstack renames above).
    rename_tensors_containing(meta, ctx, ".norm1", ".ln1");
    rename_tensors_containing(meta, ctx, ".norm2", ".ln2");
    rename_tensors_containing(meta, ctx, ".mlp.linear_fc1", ".ffn_up");
    rename_tensors_containing(meta, ctx, ".mlp.linear_fc2", ".ffn_down");
    // Position embed should be F32 (precision matters for resize_position_embeddings).
    promote_tensor_to_f32(meta, ctx, "v.position_embd.weight");
}
} // anonymous namespace
// =========================================================================
@ -1437,6 +1760,10 @@ void translate_metadata(const llama_model_loader * ml,
if (arch_name == "gptoss") handle_gptoss (ml, meta, ctx, arch_name);
if (arch_name == "lfm2") handle_lfm2 (ml, meta, ctx);
if (arch_name == "mistral3") handle_mistral3 (ml, meta, ctx);
// qwen25vl must run before any qwen2vl-targeted handler — it switches
// arch_name to "qwen2vl" so the loader uses qwen2vl.* keys.
if (arch_name == "qwen25vl") handle_qwen25vl (ml, meta, ctx, arch_name);
if (arch_name == "qwen3vl") handle_qwen3vl (ml, meta, ctx);
if (arch_name == "deepseekocr") handle_deepseekocr (ml, meta, ctx, arch_name);
if (arch_name == "nemotron_h_moe") handle_nemotron_h_moe(ml, meta, ctx);
if (arch_name == "llama4") handle_llama4 (ml, meta, ctx);
@ -1477,6 +1804,14 @@ void translate_clip_metadata(gguf_context * meta, ggml_context * ctx) {
handle_glmocr_clip(meta, ctx);
return;
}
if (detect_ollama_qwen25vl(meta)) {
handle_qwen25vl_clip(meta, ctx);
return;
}
if (detect_ollama_qwen3vl(meta, ctx)) {
handle_qwen3vl_clip(meta, ctx);
return;
}
}
bool should_skip_tensor(const llama_model_loader * ml, const char * tensor_name) {

View file

@ -436,6 +436,8 @@ func NewLlamaServerRunner(
"gemma3": true,
"gemma4": true,
"qwen35moe": true,
"qwen25vl": true,
"qwen3vl": true,
"mistral3": true,
"deepseekocr": true,
"glmocr": true,