llama/compat: add qwen35moe vision (clip) support

Extends the compat layer with the vision side for Ollama's monolithic
qwen3.5 blobs. All changes in llama/compat/ — no new upstream patch edits.

New generic infra (reused by gemma3's existing promotion):
  - LoadOp registry (g_loadops). Any dest tensor whose name is registered
    gets its bytes produced by a closure instead of being read straight
    from disk. maybe_load_tensor consults it.
  - promote_tensor_to_f32(meta, ctx, name) now captures the source offset
    at registration time and becomes a LoadOp. Gemma3 already migrated.
  - register_concat_load(meta, dest, {srcs...}) captures the file offsets
    of N source tensors and registers a LoadOp that concatenates them.
    Assumes sources concatenate along their slowest ggml axis, which in
    C order means the dest bytes are src[0] || src[1] || ... (see the
    sketch after this list).
  - set_tensor_shape / set_tensor_type helpers for in-place edits.
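
As a standalone illustration, the registry pattern boils down to the
following std-only sketch (hypothetical names; the real code captures
the (offset, size) regions out of the gguf metadata rather than taking
them as arguments):

    #include <cstdint>
    #include <cstdio>
    #include <functional>
    #include <string>
    #include <unordered_map>
    #include <utility>
    #include <vector>

    // dest tensor name -> closure that produces its final bytes
    using LoadFn = std::function<bool(const char *, void *, size_t)>;
    static std::unordered_map<std::string, LoadFn> loadops;

    // Read `size` bytes at `offset` from `path` into `dst`.
    static bool read_at(const char * path, size_t offset, void * dst, size_t size) {
        FILE * f = std::fopen(path, "rb");
        if (!f) return false;
        const bool ok = std::fseek(f, (long) offset, SEEK_SET) == 0
                     && std::fread(dst, 1, size, f) == size;
        std::fclose(f);
        return ok;
    }

    // Concat op: dest bytes are src[0] || src[1] || ... in registration
    // order. Regions are (file_offset, size) pairs captured up front, so
    // later renames/reshapes of the sources cannot change what gets read.
    static void register_concat(std::string dest,
                                std::vector<std::pair<size_t, size_t>> regions) {
        loadops[std::move(dest)] = [regions](const char * path, void * dst, size_t dst_size) {
            size_t total = 0;
            for (const auto & r : regions) total += r.second;
            if (total != dst_size) return false; // dest shape/type drifted
            uint8_t * p = static_cast<uint8_t *>(dst);
            for (const auto & [off, sz] : regions) {
                if (!read_at(path, off, p, sz)) return false;
                p += sz;
            }
            return true;
        };
    }

    // Load hook: a registered closure wins; otherwise the caller falls
    // back to the normal disk read.
    static bool maybe_load(const char * name, const char * file, void * dst, size_t n) {
        const auto it = loadops.find(name);
        return it != loadops.end() && it->second(file, dst, n);
    }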

qwen35moe clip handler (handle_qwen35moe_clip):
  - Detection reuses detect_ollama_qwen35moe; additionally requires
    embedded v.* tensors so we don't fire for text-only files.
  - KV synth: clip.vision.* from qwen35moe.vision.* + sensible defaults
    (feed_forward_length=4304, image_size=768, layer_norm_epsilon=1e-6,
    is_deepstack_layers=false[27], image_mean/std=[0.5,0.5,0.5]).
  - Arch rewrite: general.architecture=clip, projector_type=qwen3vl_merger.
  - QKV merge per block (27x): captures q/k/v file offsets, registers a
    concat LoadOp, renames attn_q -> attn_qkv and widens its shape from
    [hidden, hidden] to [hidden, 3*hidden].
  - patch_embed split: source [16,16,2,3456] F16 -> two dests
    [16,16,3,1152] F32, permuting (c_out*3+c_in) packed_c back into
    separate c_in/c_out dims; the index math is worked through in the
    snippet after this list. Matches upstream convert_hf's
    Qwen3VLVisionModel.modify_tensors split.
  - Tensor renames (substring-matched): pos_embed -> position_embd,
    merger.norm -> post_ln, merger.linear_fc1/2 -> mm.0/mm.2,
    mlp.linear_fc1/2 -> ffn_up/ffn_down, norm1/2 -> ln1/ln2.
  - F16 -> F32 promote for v.position_embd.weight.
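
The patch_embed index math is small enough to sanity-check in isolation.
A hypothetical standalone snippet (constants and formulas copied from the
handler below; H = W here, so both readings of the w stride agree):

    #include <cstddef>

    // src ne=[16,16,2,3456]: h fastest, then w, t, packed_c.
    // dst ne=[16,16,3,1152]: h, w, c_in, c_out; packed_c = c_out*3 + c_in.
    constexpr int H = 16, W = 16, T = 2, CIN = 3;

    constexpr size_t src_idx(int h, int w, int t, int c_in, int c_out) {
        return h + w * W + t * W * H + ((size_t) c_out * CIN + c_in) * W * H * T;
    }
    constexpr size_t dst_idx(int h, int w, int c_in, int c_out) {
        return h + w * W + c_in * W * H + (size_t) c_out * W * H * CIN;
    }

    // Worked element: slice t=0 at (h=0, w=0, c_in=2, c_out=1).
    // packed_c = 1*3 + 2 = 5 and W*H*T = 512, so the source element sits
    // 5 * 512 = 2560 elements in; in slice 0's output it lands at
    // 2*256 + 1*768 = 1280.
    static_assert(src_idx(0, 0, 0, 2, 1) == 2560, "src index");
    static_assert(dst_idx(0, 0, 2, 1) == 1280, "dst index");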

Ctx-pool trick for the sibling tensor:
  clip.cpp sizes its ggml_context for exactly the gguf's tensor count
  (+1), so adding v.patch_embd.weight.1 via ggml_new_tensor would
  overflow the pool (sizing sketch below). Since
  v.blk.0.attn_k.weight is orphaned after the QKV merge (clip only
  requests the merged attn_qkv), steal that slot: rename it to
  v.patch_embd.weight.1 and reshape to [16,16,3,1152] F32. Its original
  file offset is ignored; the LoadOp we register overrides the read.
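
For reference, the sizing in question looks roughly like this
(paraphrased from upstream clip.cpp, not verbatim; n_tensors is the
gguf's tensor count):

    // Meta context: room for the gguf's tensor structs plus one, and no
    // data buffers. One ggml_new_tensor() past this budget fails the
    // pool allocator -- hence reusing the orphaned attn_k slot instead.
    struct ggml_init_params params = {
        /*.mem_size   =*/ (n_tensors + 1) * ggml_tensor_overhead(),
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,
    };
    ggml_context * ctx_meta = ggml_init(params);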

Go side: adds qwen35moe to the auto-mmproj arch allowlist. ollama now
passes the monolithic blob as both --model and --mmproj for qwen3.5.

Verified end-to-end: ollama run qwen3.5:35b-a3b-q4_K_M with an image
correctly describes the image ("screenshot of a chat interface...
'open the browser, open never gonna give you up on youtube'..."). Text
inference still works on the same blob.
jmorganca 2026-04-19 12:42:28 -07:00
parent 8fa6648650
commit db0c745308
2 changed files with 343 additions and 54 deletions


@@ -120,25 +120,119 @@ void add_skip_prefix(const llama_model_loader * ml, std::string prefix) {
 }
 // -------------------------------------------------------------------------
-// F16 -> F32 tensor promotion (needed for Metal IM2COL on gemma3 conv weights)
+// Load-time tensor transforms (registry consumed by maybe_load_tensor)
+//
+// Each registered op produces the final bytes for a single destination
+// tensor by reading + transforming bytes from the source GGUF file.
+// Used for F16->F32 promotion, QKV merging, and patch-embed splitting.
 // -------------------------------------------------------------------------
-std::mutex g_promote_mutex;
-std::unordered_set<std::string> g_promote_f16_to_f32;
+struct LoadOp {
+    // apply() reads what it needs from `src_file` and fills `dst` (dst_size
+    // bytes). Returns false on failure.
+    std::function<bool(const char * src_file, void * dst, size_t dst_size)> apply;
+    const char * description;
+};
-// Set a tensor's type + strides in a ggml_context. The companion to this is
-// the `maybe_load_tensor` read hook, which converts F16 bytes from disk into
-// the newly-wider F32 buffer at load time.
-void promote_tensor_to_f32(ggml_context * ctx, const char * name) {
-    ggml_tensor * t = ggml_get_tensor(ctx, name);
-    if (!t) return;
-    t->type = GGML_TYPE_F32;
-    t->nb[0] = ggml_type_size(GGML_TYPE_F32);
-    t->nb[1] = t->nb[0] * (t->ne[0] / ggml_blck_size(GGML_TYPE_F32));
-    std::lock_guard<std::mutex> lk(g_promote_mutex);
-    g_promote_f16_to_f32.insert(name);
-}
+std::mutex g_loadop_mutex;
+std::unordered_map<std::string, LoadOp> g_loadops;
+void register_load_op(std::string dest_name, LoadOp op) {
+    std::lock_guard<std::mutex> lk(g_loadop_mutex);
+    g_loadops[std::move(dest_name)] = std::move(op);
+}
+// Helper: read `size` bytes at `offset` from `path` into `dst`.
+bool read_at(const char * path, size_t offset, void * dst, size_t size) {
+    FILE * f = std::fopen(path, "rb");
+    if (!f) return false;
+    bool ok = (std::fseek(f, (long) offset, SEEK_SET) == 0
+               && std::fread(dst, 1, size, f) == size);
+    std::fclose(f);
+    return ok;
+}
+// Capture a tensor's absolute file offset BEFORE any rename or reshape.
+size_t tensor_file_offset(const gguf_context * meta, const char * name) {
+    const int64_t id = gguf_find_tensor(meta, name);
+    if (id < 0) return 0;
+    return gguf_get_data_offset(meta) + gguf_get_tensor_offset(meta, id);
+}
+// Set a tensor's type and recompute strides in a ggml_context.
+void set_tensor_type(ggml_tensor * t, ggml_type type) {
+    t->type = type;
+    t->nb[0] = ggml_type_size(type);
+    t->nb[1] = t->nb[0] * (t->ne[0] / ggml_blck_size(type));
+    for (int i = 2; i < GGML_MAX_DIMS; ++i) t->nb[i] = t->nb[i - 1] * t->ne[i - 1];
+}
+// Set a tensor's shape and recompute strides in a ggml_context.
+void set_tensor_shape(ggml_tensor * t, std::initializer_list<int64_t> shape) {
+    int i = 0;
+    for (auto v : shape) t->ne[i++] = v;
+    for (; i < GGML_MAX_DIMS; ++i) t->ne[i] = 1;
+    set_tensor_type(t, t->type);
+}
+// Promote a tensor F16 -> F32. The disk bytes stay F16; we register a
+// load op that converts on read.
+void promote_tensor_to_f32(gguf_context * meta, ggml_context * ctx, const char * name) {
+    const int64_t tid = gguf_find_tensor(meta, name);
+    if (tid < 0) return;
+    ggml_tensor * t = ggml_get_tensor(ctx, name);
+    if (!t || t->type != GGML_TYPE_F16) return;
+    const size_t src_offset = tensor_file_offset(meta, name);
+    const size_t n_elem = ggml_nelements(t);
+    const size_t src_size = n_elem * sizeof(uint16_t);
+    set_tensor_type(t, GGML_TYPE_F32);
+    register_load_op(name, LoadOp{
+        [src_offset, src_size, n_elem](const char * path, void * dst, size_t dst_size) {
+            (void) dst_size;
+            std::vector<uint8_t> src(src_size);
+            if (!read_at(path, src_offset, src.data(), src_size)) return false;
+            const uint16_t * sp = reinterpret_cast<const uint16_t *>(src.data());
+            float * dp = reinterpret_cast<float *>(dst);
+            for (size_t i = 0; i < n_elem; ++i) dp[i] = ggml_fp16_to_fp32(sp[i]);
+            return true;
+        },
+        "F16->F32 promote",
+    });
+}
+// Concatenate N source tensors into one destination tensor. Captures
+// source file offsets and sizes at registration time so later renames or
+// reshapes don't affect the read. Layout assumption: the source tensors
+// concatenate cleanly along their slowest dim, which in C/ggml order
+// means the destination's bytes are just src[0] || src[1] || ... .
+void register_concat_load(const gguf_context * meta, std::string dest_name,
+                          const std::vector<std::string> & src_names) {
+    std::vector<std::pair<size_t, size_t>> regions; // (offset, size)
+    regions.reserve(src_names.size());
+    for (const auto & n : src_names) {
+        const int64_t id = gguf_find_tensor(meta, n.c_str());
+        if (id < 0) return; // bail; downstream will fail loudly
+        regions.emplace_back(
+            gguf_get_data_offset(meta) + gguf_get_tensor_offset(meta, id),
+            gguf_get_tensor_size(meta, id));
+    }
+    register_load_op(std::move(dest_name), LoadOp{
+        [regions](const char * path, void * dst, size_t dst_size) {
+            size_t total = 0;
+            for (auto & [_, sz] : regions) total += sz;
+            if (total != dst_size) return false;
+            uint8_t * p = static_cast<uint8_t *>(dst);
+            for (auto & [off, sz] : regions) {
+                if (!read_at(path, off, p, sz)) return false;
+                p += sz;
+            }
+            return true;
+        },
+        "concat sources",
+    });
+}
 // -------------------------------------------------------------------------
@@ -228,21 +322,22 @@ void handle_gemma3(const llama_model_loader * ml, gguf_context * meta, ggml_cont
 // -------------------------------------------------------------------------
 bool detect_ollama_qwen35moe(const gguf_context * meta, const ggml_context * ctx) {
-    // Strongest markers: vision KVs live in-file (upstream splits to mmproj)
-    // or MTP tensors are present (upstream strips them).
-    if (has_key(meta, "qwen35moe.vision.block_count")) return true;
-    if (has_key(meta, "qwen35moe.image_token_id")) return true;
-    if (has_key(meta, "qwen35moe.ssm.v_head_reordered")) return true;
-    if (has_key(meta, "qwen35moe.feed_forward_length")) return true; // upstream omits (=0 stored)
-    if (has_key(meta, "qwen35moe.rope.mrope_interleaved")) return true;
-    if (any_tensor_with_prefix(ctx, "mtp.")) return true;
-    if (any_tensor_with_prefix(ctx, "v.")) return true;
+    // Require the file to declare itself qwen35moe first.
+    const int64_t arch_kid = gguf_find_key(meta, "general.architecture");
+    if (arch_kid < 0) return false;
+    if (std::strcmp(gguf_get_val_str(meta, arch_kid), "qwen35moe") != 0) return false;
-    // Scalar-vs-array: upstream writes head_count_kv as UINT32; Ollama wrote
-    // it as a per-layer array. has_key alone can't tell us that, but a mismatch
-    // shows up as a type-mismatch crash downstream, which is worse than over-
-    // detecting. If any of the above markers fire we'll normalize it below.
-    return false;
+    // Then: at least one Ollama-ism. Upstream qwen35moe text files have none
+    // of these — the vision KVs move to mmproj, MTP tensors are dropped,
+    // head_count_kv is a scalar not an array, and the various extra rope /
+    // ssm KVs below are either absent or stored differently.
+    return has_key(meta, "qwen35moe.vision.block_count")
+        || has_key(meta, "qwen35moe.image_token_id")
+        || has_key(meta, "qwen35moe.ssm.v_head_reordered")
+        || has_key(meta, "qwen35moe.feed_forward_length")
+        || has_key(meta, "qwen35moe.rope.mrope_interleaved")
+        || any_tensor_with_prefix(ctx, "mtp.")
+        || any_tensor_with_prefix(ctx, "v.");
 }
 void handle_qwen35moe(const llama_model_loader * ml, gguf_context * meta, ggml_context * ctx) {
@@ -357,8 +452,204 @@ void handle_gemma3_clip(gguf_context * meta, ggml_context * ctx) {
     // Upstream stores patch_embd/position_embd as F32 (Gemma3VisionModel
     // tensor_force_quant); Ollama stored F16. Metal's IM2COL convolution
     // requires F32, so promote both at load time.
-    promote_tensor_to_f32(ctx, "v.patch_embd.weight");
-    promote_tensor_to_f32(ctx, "v.position_embd.weight");
+    promote_tensor_to_f32(meta, ctx, "v.patch_embd.weight");
+    promote_tensor_to_f32(meta, ctx, "v.position_embd.weight");
 }
+// -------------------------------------------------------------------------
+// qwen35moe (clip side)
+// -------------------------------------------------------------------------
+// Substring renames. One entry handles both `.weight` and `.bias` variants.
+constexpr std::pair<const char *, const char *> kQwen35moeClipRenames[] = {
+    {"v.pos_embed", "v.position_embd"},
+    {"v.patch_embed", "v.patch_embd"},
+    {"v.merger.norm", "v.post_ln"},
+    {"v.merger.linear_fc1", "mm.0"},
+    {"v.merger.linear_fc2", "mm.2"},
+    {".mlp.linear_fc1", ".ffn_up"},
+    {".mlp.linear_fc2", ".ffn_down"},
+    {".norm1", ".ln1"},
+    {".norm2", ".ln2"},
+};
+// Register a QKV merge for a single block: Ollama has separate attn_q,
+// attn_k, attn_v tensors; upstream wants them concatenated along their
+// slow axis. Capture source file offsets BEFORE renaming.
+void register_qwen35moe_qkv_merge(gguf_context * meta, ggml_context * ctx, int block_idx) {
+    char qname[64], kname[64], vname[64];
+    std::snprintf(qname, sizeof(qname), "v.blk.%d.attn_q.weight", block_idx);
+    std::snprintf(kname, sizeof(kname), "v.blk.%d.attn_k.weight", block_idx);
+    std::snprintf(vname, sizeof(vname), "v.blk.%d.attn_v.weight", block_idx);
+    const ggml_tensor * q = ggml_get_tensor(ctx, qname);
+    if (!q) return; // not a qwen35moe vision block
+    // Set up the destination tensor. We rename attn_q -> attn_qkv and
+    // widen its slow axis from [1152, 1152] to [1152, 3456] (3 * hidden).
+    char qkv_w[64], qkv_b[64], qbias[64], kbias[64], vbias[64];
+    std::snprintf(qkv_w, sizeof(qkv_w), "v.blk.%d.attn_qkv.weight", block_idx);
+    std::snprintf(qkv_b, sizeof(qkv_b), "v.blk.%d.attn_qkv.bias", block_idx);
+    std::snprintf(qbias, sizeof(qbias), "v.blk.%d.attn_q.bias", block_idx);
+    std::snprintf(kbias, sizeof(kbias), "v.blk.%d.attn_k.bias", block_idx);
+    std::snprintf(vbias, sizeof(vbias), "v.blk.%d.attn_v.bias", block_idx);
+    // Capture source offsets for the concat BEFORE renaming.
+    register_concat_load(meta, qkv_w, {qname, kname, vname});
+    register_concat_load(meta, qkv_b, {qbias, kbias, vbias});
+    // Rename attn_q -> attn_qkv and widen shape.
+    rename_tensor(meta, ctx, qname, qkv_w);
+    if (ggml_tensor * t = ggml_get_tensor(ctx, qkv_w)) {
+        set_tensor_shape(t, {t->ne[0], t->ne[1] * 3});
+    }
+    // Rename attn_q.bias -> attn_qkv.bias and widen from [1152] to [3456].
+    rename_tensor(meta, ctx, qbias, qkv_b);
+    if (ggml_tensor * t = ggml_get_tensor(ctx, qkv_b)) {
+        set_tensor_shape(t, {t->ne[0] * 3});
+    }
+}
+// Register the patch_embed reshape + split + F16->F32.
+//
+// Source: one Ollama tensor `v.patch_embed.weight`, ggml shape
+//   [h=16, w=16, t=2, packed=3456] F16
+// where `packed` is the PyTorch row-major flattening of HF's
+// [out_c=1152, in_c=3, ...] dim pair, so packed_c = c_out*3 + c_in.
+//
+// Destination: two upstream tensors with ggml shape
+//   [h=16, w=16, c_in=3, c_out=1152] F32 each,
+// one per temporal slice. Matches upstream's
+//   yield data_torch[:, :, 0, ...]  # PyTorch [1152, 3, 16, 16]
+//   yield data_torch[:, :, 1, ...]
+// which reverses to ggml ne=[16, 16, 3, 1152] per slice.
+//
+// For each output element (h, w, c_in, c_out):
+//   src_idx = h + w*W + t*W*H + (c_out*C_in + c_in)*W*H*T
+//   dst_idx = h + w*W + c_in*W*H + c_out*W*H*C_in
+void register_qwen35moe_patch_embed_split(gguf_context * meta, ggml_context * ctx) {
+    const char * src_name = "v.patch_embed.weight";
+    const int64_t tid = gguf_find_tensor(meta, src_name);
+    if (tid < 0) return;
+    ggml_tensor * src_t = ggml_get_tensor(ctx, src_name);
+    if (!src_t) return;
+    const size_t src_offset = tensor_file_offset(meta, src_name);
+    const size_t src_size = ggml_nelements(src_t) * sizeof(uint16_t);
+    constexpr int H = 16, W = 16, T = 2, CIN = 3, COUT = 1152;
+    constexpr size_t HW = (size_t) H * W;
+    auto make_slice_op = [=](int slice_idx) {
+        return LoadOp{
+            [=](const char * path, void * dst, size_t dst_size) {
+                if (dst_size != (size_t) H * W * CIN * COUT * sizeof(float)) return false;
+                std::vector<uint8_t> src(src_size);
+                if (!read_at(path, src_offset, src.data(), src_size)) return false;
+                const uint16_t * sp = reinterpret_cast<const uint16_t *>(src.data());
+                float * dp = reinterpret_cast<float *>(dst);
+                for (int c_out = 0; c_out < COUT; ++c_out) {
+                    for (int c_in = 0; c_in < CIN; ++c_in) {
+                        const size_t packed = (size_t) c_out * CIN + c_in;
+                        const uint16_t * in_base = sp + HW * (slice_idx + T * packed);
+                        float * out_base = dp + HW * (c_in + CIN * c_out);
+                        for (size_t i = 0; i < HW; ++i) out_base[i] = ggml_fp16_to_fp32(in_base[i]);
+                    }
+                }
+                return true;
+            },
+            slice_idx == 0 ? "patch_embed slice 0 (permute+F16->F32)"
+                           : "patch_embed slice 1 (permute+F16->F32)",
+        };
+    };
+    // Rename src -> `v.patch_embd.weight`, reshape to dest layout, register
+    // the slice-0 load op against its new name.
+    rename_tensor(meta, ctx, src_name, "v.patch_embd.weight");
+    ggml_tensor * dest0 = ggml_get_tensor(ctx, "v.patch_embd.weight");
+    if (!dest0) return;
+    set_tensor_shape(dest0, {16, 16, 3, 1152});
+    set_tensor_type (dest0, GGML_TYPE_F32);
+    register_load_op("v.patch_embd.weight", make_slice_op(0));
+    // We need a sibling tensor `v.patch_embd.weight.1` in ctx_meta so clip's
+    // get_tensor() can find it. ggml_new_tensor() would blow ctx_meta's
+    // fixed memory pool (sized exactly for the original tensor count).
+    // Instead, steal an unused slot: after the QKV merge, `v.blk.0.attn_k`
+    // is orphaned in ctx_meta — clip never looks it up because it asks for
+    // the merged `attn_qkv`. Rename it to our sibling and reshape.
+    rename_tensor(meta, ctx, "v.blk.0.attn_k.weight", "v.patch_embd.weight.1");
+    ggml_tensor * dest1 = ggml_get_tensor(ctx, "v.patch_embd.weight.1");
+    if (!dest1) return;
+    set_tensor_shape(dest1, {16, 16, 3, 1152});
+    set_tensor_type (dest1, GGML_TYPE_F32);
+    register_load_op("v.patch_embd.weight.1", make_slice_op(1));
+}
+void handle_qwen35moe_clip(gguf_context * meta, ggml_context * ctx) {
+    LLAMA_LOG_INFO("%s: detected Ollama-format qwen35moe GGUF used as mmproj; translating\n", __func__);
+    // KV synthesis: clip.vision.* from qwen35moe.vision.* (plus defaults).
+    copy_u32_kv(meta, "qwen35moe.vision.block_count", "clip.vision.block_count");
+    copy_u32_kv(meta, "qwen35moe.vision.embedding_length", "clip.vision.embedding_length");
+    copy_u32_kv(meta, "qwen35moe.vision.attention.head_count", "clip.vision.attention.head_count");
+    copy_u32_kv(meta, "qwen35moe.vision.patch_size", "clip.vision.patch_size");
+    copy_u32_kv(meta, "qwen35moe.vision.spatial_merge_size", "clip.vision.spatial_merge_size");
+    copy_u32_kv(meta, "qwen35moe.vision.num_channels", "clip.vision.num_channels");
+    // projection_dim is the text model's embedding_length (merger out dim).
+    copy_u32_kv(meta, "qwen35moe.embedding_length", "clip.vision.projection_dim");
+    // Ollama omitted these; defaults match reference (ref_Q3.5-35B-A3B mmproj).
+    if (!has_key(meta, "clip.vision.feed_forward_length"))
+        gguf_set_val_u32(meta, "clip.vision.feed_forward_length", 4304);
+    if (!has_key(meta, "clip.vision.image_size"))
+        gguf_set_val_u32(meta, "clip.vision.image_size", 768);
+    if (!has_key(meta, "clip.vision.attention.layer_norm_epsilon"))
+        gguf_set_val_f32(meta, "clip.vision.attention.layer_norm_epsilon", 1e-6f);
+    // image_mean / image_std — constants for qwen3.5 vision.
+    if (!has_key(meta, "clip.vision.image_mean")) {
+        const float v[3] = {0.5f, 0.5f, 0.5f};
+        gguf_set_arr_data(meta, "clip.vision.image_mean", GGUF_TYPE_FLOAT32, v, 3);
+    }
+    if (!has_key(meta, "clip.vision.image_std")) {
+        const float v[3] = {0.5f, 0.5f, 0.5f};
+        gguf_set_arr_data(meta, "clip.vision.image_std", GGUF_TYPE_FLOAT32, v, 3);
+    }
+    // is_deepstack_layers: qwen3.5 35B has no deepstack layers. Set a
+    // 27-element array of False matching clip.vision.block_count.
+    if (!has_key(meta, "clip.vision.is_deepstack_layers")) {
+        uint8_t bools[27] = {};
+        gguf_set_arr_data(meta, "clip.vision.is_deepstack_layers", GGUF_TYPE_BOOL, bools, 27);
+    }
+    if (!has_key(meta, "clip.has_vision_encoder")) gguf_set_val_bool(meta, "clip.has_vision_encoder", true);
+    if (!has_key(meta, "clip.use_gelu")) gguf_set_val_bool(meta, "clip.use_gelu", true);
+    gguf_set_val_str(meta, "clip.projector_type", "qwen3vl_merger");
+    gguf_set_val_str(meta, "general.architecture", "clip");
+    // QKV merge per block. Runs BEFORE the substring renames so we can
+    // reliably find attn_q / attn_k / attn_v by name.
+    const int64_t n_blocks_key = gguf_find_key(meta, "clip.vision.block_count");
+    const uint32_t n_blocks = n_blocks_key >= 0 ? gguf_get_val_u32(meta, n_blocks_key) : 27;
+    for (uint32_t b = 0; b < n_blocks; ++b) {
+        register_qwen35moe_qkv_merge(meta, ctx, (int) b);
+    }
+    // patch_embed: reshape + temporal split + F16->F32. Also BEFORE renames
+    // because it references `v.patch_embed.weight` by name.
+    register_qwen35moe_patch_embed_split(meta, ctx);
+    // Substring renames (last). These handle the simple pos_embed, merger.*,
+    // linear_fc1/2, norm1/2 conversions.
+    for (const auto & [from, to] : kQwen35moeClipRenames) {
+        rename_tensors_containing(meta, ctx, from, to);
+    }
+    // F16 -> F32 on position_embd after rename.
+    promote_tensor_to_f32(meta, ctx, "v.position_embd.weight");
+}
 } // anonymous namespace
@@ -379,10 +670,16 @@ void translate_metadata(const llama_model_loader * ml,
 void translate_clip_metadata(gguf_context * meta, ggml_context * ctx) {
     if (!meta) return;
-    // Require both the gemma3 markers AND embedded vision tensors to fire.
-    if (detect_ollama_gemma3(meta, ctx) && any_tensor_with_prefix(ctx, "v.")) {
+    if (!any_tensor_with_prefix(ctx, "v.")) return; // nothing to translate
+    if (detect_ollama_gemma3(meta, ctx)) {
         LLAMA_LOG_INFO("%s: detected Ollama-format gemma3 GGUF used as mmproj; translating\n", __func__);
         handle_gemma3_clip(meta, ctx);
+        return;
     }
+    if (detect_ollama_qwen35moe(meta, ctx)) {
+        handle_qwen35moe_clip(meta, ctx);
+        return;
+    }
 }
@@ -400,35 +697,27 @@ bool maybe_load_tensor(ggml_tensor * cur,
                        const char * source_file,
                        size_t file_offset,
                        ggml_backend_buffer_type_t buft) {
+    (void) file_offset; // registered ops capture their own offsets
+    LoadOp op;
     {
-        std::lock_guard<std::mutex> lk(g_promote_mutex);
-        if (g_promote_f16_to_f32.find(ggml_get_name(cur)) == g_promote_f16_to_f32.end()) return false;
+        std::lock_guard<std::mutex> lk(g_loadop_mutex);
+        auto it = g_loadops.find(ggml_get_name(cur));
+        if (it == g_loadops.end()) return false;
+        op = it->second;
     }
-    if (cur->type != GGML_TYPE_F32) return false;
-    const size_t n_elem = ggml_nelements(cur);
-    const size_t src_size = n_elem * sizeof(uint16_t);
-    const size_t dst_size = n_elem * sizeof(float);
-    std::vector<uint8_t> src(src_size);
-    FILE * f = std::fopen(source_file, "rb");
-    if (!f || std::fseek(f, (long) file_offset, SEEK_SET) != 0
-           || std::fread(src.data(), 1, src_size, f) != src_size) {
-        if (f) std::fclose(f);
-        LLAMA_LOG_ERROR("%s: failed to read F16 bytes for '%s'\n", __func__, ggml_get_name(cur));
+    const size_t dst_size = ggml_nbytes(cur);
+    std::vector<uint8_t> dst(dst_size);
+    if (!op.apply(source_file, dst.data(), dst_size)) {
+        LLAMA_LOG_ERROR("%s: %s failed for %s\n", __func__, op.description, ggml_get_name(cur));
         return false;
     }
-    std::fclose(f);
-    std::vector<uint8_t> dst(dst_size);
-    const uint16_t * sp = reinterpret_cast<const uint16_t *>(src.data());
-    float * dp = reinterpret_cast<float *>(dst.data());
-    for (size_t i = 0; i < n_elem; ++i) dp[i] = ggml_fp16_to_fp32(sp[i]);
     if (ggml_backend_buft_is_host(buft)) std::memcpy(cur->data, dst.data(), dst_size);
     else ggml_backend_tensor_set(cur, dst.data(), 0, dst_size);
-    LLAMA_LOG_INFO("%s: promoted F16->F32 for %s (%zu elems)\n", __func__, ggml_get_name(cur), n_elem);
+    LLAMA_LOG_INFO("%s: %s for %s (%zu bytes)\n", __func__, op.description, ggml_get_name(cur), dst_size);
     return true;
 }


@@ -433,9 +433,9 @@ func NewLlamaServerRunner(
     // and aborts model load. So gate on an explicit allowlist that mirrors
     // the compat layer's clip-side coverage in llama/compat/.
     compatClipArches := map[string]bool{
-        "gemma3": true,
+        "gemma3":    true,
+        "qwen35moe": true,
         // Add entries as llama/compat grows clip handlers.
-        // "qwen35moe": true,
     }
     if len(projectors) == 0 &&
         len(f.Tensors().Items("v.")) > 0 &&