mlx: Gemma4 MTP speculative decoding (#15980)

This change adds support for MTP (multi-token prediction) speculative decoding for the
gemma4 model family.

It includes:
  * support for importing safetensors based gemma4 draft models with `ollama create`
  * a new DRAFT command in the Modelfile for specifying draft models
  * a --draft-quantize flag for the ollama create command to quantize the draft model
  * cache support for speculation
  * changes to the rotating cache to be able to handle MTP correctly
  * sampling support for draft model token prediction

---------

Co-authored-by: Daniel Hiltgen <daniel@ollama.com>
This commit is contained in:
Patrick Devine 2026-05-05 08:55:04 -07:00 committed by GitHub
parent 4017af96cd
commit 15e6076d79
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
28 changed files with 2928 additions and 42 deletions

View file

@ -23,6 +23,7 @@ import (
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/create"
imagemanifest "github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/safetensors"
)
@ -34,6 +35,7 @@ type ModelfileConfig struct {
Template string
System string
License string
Draft string
Parser string
Renderer string
Parameters map[string]any
@ -67,6 +69,8 @@ func ConfigFromModelfile(modelfile *parser.Modelfile) (string, *ModelfileConfig,
mfConfig.System = cmd.Args
case "license":
mfConfig.License = cmd.Args
case "draft":
mfConfig.Draft = cmd.Args
case "parser":
mfConfig.Parser = cmd.Args
case "renderer":
@ -108,10 +112,12 @@ func ConfigFromModelfile(modelfile *parser.Modelfile) (string, *ModelfileConfig,
// CreateOptions holds all options for model creation.
type CreateOptions struct {
ModelName string
ModelDir string
Quantize string // "int4", "int8", "nvfp4", "mxfp4", or "mxfp8" for quantization
Modelfile *ModelfileConfig // template/system/license/parser/renderer/parameters from Modelfile
ModelName string
ModelDir string
Quantize string // "int4", "int8", "nvfp4", "mxfp4", or "mxfp8" for quantization
DraftQuantize string // optional quantization level for draft model tensors
Modelfile *ModelfileConfig // template/system/license/parser/renderer/parameters from Modelfile
BaseConfig *model.ConfigV2
}
// CreateModel imports a model from a local directory.
@ -121,11 +127,23 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
// Detect model type
isSafetensors := create.IsSafetensorsModelDir(opts.ModelDir)
isImageGen := create.IsTensorModelDir(opts.ModelDir)
hasDraft := opts.Modelfile != nil && opts.Modelfile.Draft != ""
isBaseModelWithDraft := hasDraft && !isSafetensors && create.IsSafetensorsLLMModel(opts.ModelDir)
if opts.DraftQuantize != "" && !hasDraft {
return fmt.Errorf("--draft-quantize requires a DRAFT model")
}
if !isSafetensors && !isImageGen {
if !isSafetensors && !isImageGen && !isBaseModelWithDraft {
return fmt.Errorf("%s is not a supported model directory (needs config.json + *.safetensors or model_index.json)", opts.ModelDir)
}
if hasDraft && !create.IsSafetensorsModelDir(opts.Modelfile.Draft) {
return fmt.Errorf("draft %s is not a supported safetensors model directory", opts.Modelfile.Draft)
}
if hasDraft && isImageGen {
return fmt.Errorf("draft models are only supported for safetensors LLM models")
}
// Determine model type settings
var modelType, spinnerKey string
var capabilities []string
@ -138,6 +156,9 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
parserName = getParserName(opts.ModelDir)
rendererName = getRendererName(opts.ModelDir)
capabilities = inferSafetensorsCapabilities(opts.ModelDir, resolveParserName(opts.Modelfile, parserName))
} else if isBaseModelWithDraft {
modelType = "safetensors model"
spinnerKey = "create"
} else {
modelType = "image generation model"
spinnerKey = "imagegen"
@ -156,13 +177,44 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
p.Add(spinnerKey, spinner)
}
// Create the model using shared callbacks
var draftLayers []create.LayerInfo
var err error
if hasDraft {
draftLayers, err = create.CreateDraftSafetensorsLayers(
opts.Modelfile.Draft,
"draft.",
"draft",
opts.DraftQuantize,
newLayerCreator(),
newTensorLayerCreator(),
progressFn,
)
if err != nil {
spinner.Stop()
return err
}
}
if isBaseModelWithDraft {
err = createModelFromBaseWithDraft(opts, draftLayers, progressFn)
spinner.Stop()
if err != nil {
return err
}
fmt.Printf("Created safetensors model '%s'\n", opts.ModelName)
return nil
}
// Create the model using shared callbacks
if isSafetensors {
writer := newManifestWriter(opts, capabilities, parserName, rendererName)
if len(draftLayers) > 0 {
writer = appendLayersManifestWriter(writer, draftLayers)
}
err = create.CreateSafetensorsModel(
opts.ModelName, opts.ModelDir, opts.Quantize,
newLayerCreator(), newTensorLayerCreator(),
newManifestWriter(opts, capabilities, parserName, rendererName),
writer,
progressFn,
newPackedTensorLayerCreator(),
)
@ -184,6 +236,68 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
return nil
}
// appendLayersManifestWriter wraps an existing ManifestWriter so that the
// given extra layers are appended to every layer list before delegation.
func appendLayersManifestWriter(next create.ManifestWriter, extra []create.LayerInfo) create.ManifestWriter {
	return func(modelName string, config create.LayerInfo, layers []create.LayerInfo) error {
		// Tack the extra layers onto the end, then hand off unchanged.
		for _, layer := range extra {
			layers = append(layers, layer)
		}
		return next(modelName, config, layers)
	}
}
// createModelFromBaseWithDraft assembles a new model manifest from an already
// imported base model plus a set of pre-built draft layers: the base model's
// blobs and config are reused as-is, the draft layers are appended, and a
// manifest is written under the new model name.
func createModelFromBaseWithDraft(opts CreateOptions, draftLayers []create.LayerInfo, progressFn func(string)) error {
	progressFn(fmt.Sprintf("loading base model %s", opts.ModelDir))

	mf, err := imagemanifest.LoadManifest(opts.ModelDir)
	if err != nil {
		return err
	}
	cfg, err := readConfigV2(mf)
	if err != nil {
		return err
	}
	// Carry the base config forward so the manifest writer can merge it.
	opts.BaseConfig = cfg

	cfgLayer := mf.GetConfigLayer("config.json")
	if cfgLayer == nil {
		return fmt.Errorf("base model %s does not contain config.json", opts.ModelDir)
	}

	// Reuse every blob from the base manifest, then append the draft layers.
	combined := make([]create.LayerInfo, 0, len(mf.Manifest.Layers)+len(draftLayers))
	for i := range mf.Manifest.Layers {
		src := &mf.Manifest.Layers[i]
		combined = append(combined, create.LayerInfo{
			Digest:    src.Digest,
			Size:      src.Size,
			MediaType: src.MediaType,
			Name:      src.Name,
		})
	}
	combined = append(combined, draftLayers...)

	progressFn(fmt.Sprintf("writing manifest for %s", opts.ModelName))
	writeManifest := newManifestWriter(opts, cfg.Capabilities, cfg.Parser, cfg.Renderer)
	configInfo := create.LayerInfo{
		Digest:    cfgLayer.Digest,
		Size:      cfgLayer.Size,
		MediaType: cfgLayer.MediaType,
		Name:      cfgLayer.Name,
	}
	return writeManifest(opts.ModelName, configInfo, combined)
}
// readConfigV2 loads and decodes the base model's ConfigV2 blob from the
// manifest's config digest.
func readConfigV2(m *imagemanifest.ModelManifest) (*model.ConfigV2, error) {
	raw, err := os.ReadFile(m.BlobPath(m.Manifest.Config.Digest))
	if err != nil {
		return nil, fmt.Errorf("failed to read base config: %w", err)
	}
	cfg := &model.ConfigV2{}
	if err := json.Unmarshal(raw, cfg); err != nil {
		return nil, fmt.Errorf("failed to parse base config: %w", err)
	}
	return cfg, nil
}
func inferSafetensorsCapabilities(modelDir, parserName string) []string {
capabilities := []string{"completion"}
@ -359,14 +473,26 @@ func newManifestWriter(opts CreateOptions, capabilities []string, parserName, re
}
}
// Create config blob with version requirement
configData := model.ConfigV2{
ModelFormat: "safetensors",
FileType: strings.ToLower(strings.TrimSpace(opts.Quantize)),
Capabilities: caps,
Requires: MinOllamaVersion,
Parser: resolveParserName(opts.Modelfile, parserName),
Renderer: resolveRendererName(opts.Modelfile, rendererName),
// Create config blob with version requirement.
configData := model.ConfigV2{}
if opts.BaseConfig != nil {
configData = *opts.BaseConfig
}
configData.ModelFormat = "safetensors"
if opts.Quantize != "" || configData.FileType == "" {
configData.FileType = strings.ToLower(strings.TrimSpace(opts.Quantize))
}
configData.Capabilities = caps
configData.Requires = MinOllamaVersion
configData.Parser = resolveParserName(opts.Modelfile, parserName)
configData.Renderer = resolveRendererName(opts.Modelfile, rendererName)
if opts.Modelfile != nil && opts.Modelfile.Draft != "" {
configData.Draft = &model.Draft{
ModelFormat: "safetensors",
Architecture: "Gemma4AssistantForCausalLM",
TensorPrefix: "draft.",
Config: "draft/config.json",
}
}
configJSON, err := json.Marshal(configData)
if err != nil {

View file

@ -44,6 +44,7 @@ func TestModelfileConfig(t *testing.T) {
func TestConfigFromModelfile(t *testing.T) {
modelfile, err := parser.ParseFile(strings.NewReader(`
FROM ./model
DRAFT ./assistant
TEMPLATE {{ .Prompt }}
PARAMETER temperature 0.7
PARAMETER stop USER:
@ -66,6 +67,10 @@ PARAMETER stop ASSISTANT:
t.Fatalf("Template = %q, want %q", mfConfig.Template, "{{ .Prompt }}")
}
if mfConfig.Draft != "./assistant" {
t.Fatalf("Draft = %q, want %q", mfConfig.Draft, "./assistant")
}
if got := mfConfig.Parameters["temperature"]; got != float32(0.7) {
t.Fatalf("temperature = %#v, want %v", got, float32(0.7))
}
@ -153,11 +158,23 @@ func TestCreateModel_NotSafetensorsDir(t *testing.T) {
}
}
func TestCreateModel_DraftQuantizeRequiresDraft(t *testing.T) {
err := CreateModel(CreateOptions{
ModelName: "test-model",
ModelDir: t.TempDir(),
DraftQuantize: "mxfp8",
}, nil)
if err == nil || !strings.Contains(err.Error(), "--draft-quantize requires a DRAFT model") {
t.Fatalf("error = %v, want draft-quantize requires DRAFT", err)
}
}
func TestCreateOptions(t *testing.T) {
opts := CreateOptions{
ModelName: "my-model",
ModelDir: "/path/to/model",
Quantize: "fp8",
ModelName: "my-model",
ModelDir: "/path/to/model",
Quantize: "fp8",
DraftQuantize: "mxfp8",
Modelfile: &ModelfileConfig{
Template: "test",
System: "system",
@ -179,6 +196,9 @@ func TestCreateOptions(t *testing.T) {
if opts.Quantize != "fp8" {
t.Errorf("Quantize = %q, want %q", opts.Quantize, "fp8")
}
if opts.DraftQuantize != "mxfp8" {
t.Errorf("DraftQuantize = %q, want %q", opts.DraftQuantize, "mxfp8")
}
if opts.Modelfile == nil {
t.Error("Modelfile should not be nil")
}
@ -286,6 +306,9 @@ func TestCreateOptions_Defaults(t *testing.T) {
if opts.Quantize != "" {
t.Errorf("Quantize should be empty by default, got %q", opts.Quantize)
}
if opts.DraftQuantize != "" {
t.Errorf("DraftQuantize should be empty by default, got %q", opts.DraftQuantize)
}
// Modelfile should default to nil
if opts.Modelfile != nil {
@ -518,6 +541,48 @@ func TestNewManifestWriter_PopulatesFileTypeFromQuantize(t *testing.T) {
}
}
// TestNewManifestWriter_PopulatesDraftMetadata checks that when the Modelfile
// names a draft model, the config blob written by the manifest writer carries
// Draft metadata with the expected tensor prefix and config path.
func TestNewManifestWriter_PopulatesDraftMetadata(t *testing.T) {
	t.Setenv("OLLAMA_MODELS", t.TempDir())

	opts := CreateOptions{
		ModelName: "test-draft",
		ModelDir:  t.TempDir(),
		Modelfile: &ModelfileConfig{Draft: "/tmp/assistant"},
	}
	write := newManifestWriter(opts, []string{"completion"}, "gemma4", "gemma4")
	if err := write(opts.ModelName, create.LayerInfo{}, nil); err != nil {
		t.Fatalf("newManifestWriter() error = %v", err)
	}

	// Locate and read back the config blob that was just written.
	parsed, err := manifest.ParseNamedManifest(model.ParseName(opts.ModelName))
	if err != nil {
		t.Fatalf("ParseNamedManifest() error = %v", err)
	}
	blobPath, err := manifest.BlobsPath(parsed.Config.Digest)
	if err != nil {
		t.Fatalf("BlobsPath() error = %v", err)
	}
	raw, err := os.ReadFile(blobPath)
	if err != nil {
		t.Fatalf("ReadFile() error = %v", err)
	}

	var cfg model.ConfigV2
	if err := json.Unmarshal(raw, &cfg); err != nil {
		t.Fatalf("Unmarshal() error = %v", err)
	}
	if cfg.Draft == nil {
		t.Fatal("Draft metadata missing")
	}
	if cfg.Draft.TensorPrefix != "draft." || cfg.Draft.Config != "draft/config.json" {
		t.Fatalf("Draft = %#v, want draft prefix/config", cfg.Draft)
	}
}
func TestSupportsThinking(t *testing.T) {
tests := []struct {
name string

View file

@ -7,6 +7,7 @@ import (
"io"
"math"
"os"
"path"
"path/filepath"
"regexp"
"slices"
@ -724,12 +725,16 @@ func detectModelOptQuantization(modelDir string) bool {
}
// resolveEffectiveQuantization determines the quantization to apply given the
// source model config and its detected pre-quantization state, reporting any
// error in terms of the user-facing --quantize flag.
func resolveEffectiveQuantization(cfg sourceModelConfig, sourceKind sourceQuantizedKind, requested string) (string, error) {
	return resolveEffectiveQuantizationForFlag(cfg, sourceKind, requested, "--quantize")
}
func resolveEffectiveQuantizationForFlag(cfg sourceModelConfig, sourceKind sourceQuantizedKind, requested, flagName string) (string, error) {
switch sourceKind {
case sourceQuantizedKindNone:
return requested, nil
case sourceQuantizedKindPrequantized:
if requested != "" {
return "", fmt.Errorf("cannot requantize already-quantized source model with --quantize %q", requested)
return "", fmt.Errorf("cannot requantize already-quantized source model with %s %q", flagName, requested)
}
return "", nil
case sourceQuantizedKindSourceFP8:
@ -746,7 +751,7 @@ func resolveEffectiveQuantization(cfg sourceModelConfig, sourceKind sourceQuanti
case "nvfp4", "mxfp4", "mxfp8":
return requested, nil
default:
return "", fmt.Errorf("cannot convert already-quantized fp8 source model with --quantize %q", requested)
return "", fmt.Errorf("cannot convert already-quantized fp8 source model with %s %q", flagName, requested)
}
}
return "mxfp8", nil
@ -810,6 +815,7 @@ var tensorImportTransformRegistry = map[string]tensorImportTransformFactory{
"Gemma4ForCausalLM": newGemma4ImportTransform,
"Gemma4ForConditionalGeneration": newGemma4ImportTransform,
"LagunaForCausalLM": newLagunaImportTransform,
"Gemma4AssistantForCausalLM": newGemma4ImportTransform,
}
func newTensorImportTransform(modelDir string, cfg sourceModelConfig) (tensorImportTransform, error) {
@ -1167,6 +1173,136 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
return nil
}
// normalizeRequestedQuantization trims and canonicalizes a user-supplied
// quantization name and rejects anything outside the supported set; the empty
// string (no quantization requested) is allowed through unchanged.
func normalizeRequestedQuantization(flagName, quantize string) (string, error) {
	normalized := normalizeQuantType(strings.TrimSpace(quantize))
	switch normalized {
	case "", "int4", "int8", "nvfp4", "mxfp4", "mxfp8":
		// Supported (or empty) — pass the canonical form through.
	default:
		return "", fmt.Errorf("unsupported %s %q: supported types are int4, int8, nvfp4, mxfp4, mxfp8", flagName, quantize)
	}
	return normalized, nil
}
// CreateDraftSafetensorsLayers imports an assistant/draft safetensors model
// into prefixed tensor and config layers. When draftQuantize is non-empty,
// eligible draft tensors are quantized with the same per-architecture policy
// used by target safetensors imports.
//
// Tensor layers are emitted first (names rewritten with tensorPrefix, e.g.
// "draft."), followed by one JSON layer per *.json file in modelDir (stored
// under configPrefix, e.g. "draft/config.json"); the safetensors index file
// is skipped. fn receives human-readable progress status strings.
func CreateDraftSafetensorsLayers(modelDir, tensorPrefix, configPrefix, draftQuantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, fn func(status string)) ([]LayerInfo, error) {
	// The prefixes are what distinguish draft layers from the base model's
	// layers in the combined manifest, so neither may be empty.
	if tensorPrefix == "" {
		return nil, fmt.Errorf("draft tensor prefix must not be empty")
	}
	if configPrefix == "" {
		return nil, fmt.Errorf("draft config prefix must not be empty")
	}
	effectiveQuantize, err := normalizeRequestedQuantization("--draft-quantize", draftQuantize)
	if err != nil {
		return nil, err
	}
	// Without quantization the tensors pass through untouched; with it, the
	// draft model's config drives the per-architecture import transform.
	var importTransform tensorImportTransform = noopImportTransform{}
	if effectiveQuantize != "" {
		sourceConfig, err := readSourceModelConfig(modelDir)
		if err != nil {
			return nil, fmt.Errorf("failed to read draft config.json: %w", err)
		}
		sourceQuantKind, err := inspectSourceQuantization(modelDir, sourceConfig)
		if err != nil {
			return nil, fmt.Errorf("failed to inspect draft quantization: %w", err)
		}
		effectiveQuantize, err = resolveEffectiveQuantizationForFlag(sourceConfig, sourceQuantKind, effectiveQuantize, "--draft-quantize")
		if err != nil {
			return nil, err
		}
		importTransform, err = newTensorImportTransform(modelDir, sourceConfig)
		if err != nil {
			return nil, fmt.Errorf("failed to construct draft import transform for architecture %q: %w", sourceConfig.Architecture(), err)
		}
	}
	entries, err := os.ReadDir(modelDir)
	if err != nil {
		return nil, fmt.Errorf("failed to read draft directory: %w", err)
	}
	var layers []LayerInfo
	// Pass 1: import every *.safetensors file into prefixed tensor layers.
	for _, entry := range entries {
		if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".safetensors") {
			continue
		}
		stPath := filepath.Join(modelDir, entry.Name())
		extractor, err := safetensors.OpenForExtraction(stPath)
		if err != nil {
			return nil, fmt.Errorf("failed to open draft %s: %w", stPath, err)
		}
		tensorNames := extractor.ListTensors()
		// NOTE(review): the status string always reports sourceQuantizedKindNone
		// even when the source was inspected above (that result is scoped to the
		// if-block) — presumably status-only; confirm this is intentional.
		fn(fmt.Sprintf("importing draft %s (%d tensors%s)", entry.Name(), len(tensorNames), importQuantizationStatus(sourceQuantizedKindNone, effectiveQuantize)))
		for _, tensorName := range tensorNames {
			if importTransform.skipTensor(tensorName) {
				continue
			}
			td, err := extractor.GetTensor(tensorName)
			if err != nil {
				// Close explicitly: a defer here would pile up across files.
				extractor.Close()
				return nil, fmt.Errorf("failed to get draft tensor %s: %w", tensorName, err)
			}
			// A transform may split one source tensor into several outputs.
			outTDs, err := importTransform.transformTensor(td)
			if err != nil {
				extractor.Close()
				return nil, fmt.Errorf("failed to transform draft tensor %s: %w", tensorName, err)
			}
			for _, transformedTD := range outTDs {
				if transformedTD == nil {
					continue
				}
				outTD := transformedTD.WithName(tensorPrefix + transformedTD.Name)
				quantizeType := ""
				if effectiveQuantize != "" {
					quantizeType = importTransform.quantizationType(outTD.Name, outTD.Shape, effectiveQuantize)
					// Token embeddings are never quantized for draft imports.
					if isEmbedTokensWeight(outTD.Name) {
						quantizeType = ""
					}
				}
				newLayers, err := createTensorLayer(outTD.SafetensorsReader(), outTD.Name, outTD.Dtype, outTD.Shape, quantizeType)
				if err != nil {
					extractor.Close()
					return nil, fmt.Errorf("failed to create draft layer for %s: %w", tensorName, err)
				}
				layers = append(layers, newLayers...)
			}
		}
		extractor.Close()
	}
	// Pass 2: store every *.json file (except the safetensors index) as a
	// JSON layer under the draft config prefix.
	for _, entry := range entries {
		if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".json") {
			continue
		}
		if entry.Name() == "model.safetensors.index.json" {
			continue
		}
		cfgPath := entry.Name()
		fullPath := filepath.Join(modelDir, cfgPath)
		fn(fmt.Sprintf("importing draft config %s", cfgPath))
		f, err := os.Open(fullPath)
		if err != nil {
			return nil, fmt.Errorf("failed to open draft %s: %w", cfgPath, err)
		}
		// path.Join (not filepath.Join): layer names are slash-separated.
		layer, err := createLayer(f, "application/vnd.ollama.image.json", path.Join(configPrefix, cfgPath))
		f.Close()
		if err != nil {
			return nil, fmt.Errorf("failed to create draft config layer for %s: %w", cfgPath, err)
		}
		layers = append(layers, layer)
	}
	return layers, nil
}
func shouldSkipSourceCompanion(name string, tensorSet map[string]struct{}, sourceTensorFiles map[string]string) bool {
switch {
case strings.HasSuffix(name, ".scales"):

View file

@ -452,6 +452,137 @@ func TestCreateSafetensorsModel(t *testing.T) {
}
}
// TestCreateDraftSafetensorsLayersPrefixesTensorsAndConfigs verifies that a
// draft import prefixes tensor names with "draft.", stores config files under
// "draft/", and applies no quantization when none is requested.
func TestCreateDraftSafetensorsLayersPrefixesTensorsAndConfigs(t *testing.T) {
	dir := t.TempDir()
	if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type":"gemma4_assistant"}`), 0o644); err != nil {
		t.Fatal(err)
	}
	createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{
		st.NewTensorDataFromBytes("model.layers.0.self_attn.q_proj.weight", "BF16", []int32{2, 2}, make([]byte, 8)),
	})

	jsonLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
		payload, err := io.ReadAll(r)
		if err != nil {
			return LayerInfo{}, err
		}
		return LayerInfo{Digest: "sha256:json_" + name, Size: int64(len(payload)), MediaType: mediaType, Name: name}, nil
	}
	var seenTensors []string
	tensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
		payload, err := io.ReadAll(r)
		if err != nil {
			return nil, err
		}
		seenTensors = append(seenTensors, name)
		// Each emitted blob must be a valid single-tensor safetensors file
		// whose key and shape match what the creator was told.
		gotName, gotShape := readSingleTensorNameAndShape(t, payload)
		if gotName != name {
			t.Fatalf("safetensors key = %q, want %q", gotName, name)
		}
		if !slices.Equal(gotShape, shape) {
			t.Fatalf("shape = %v, want %v", gotShape, shape)
		}
		if quantize != "" {
			t.Fatalf("draft quantize = %q, want empty", quantize)
		}
		return []LayerInfo{{Digest: "sha256:tensor_" + name, Size: int64(len(payload)), MediaType: "application/vnd.ollama.image.tensor", Name: name}}, nil
	}

	layers, err := CreateDraftSafetensorsLayers(dir, "draft.", "draft", "", jsonLayer, tensorLayer, func(string) {})
	if err != nil {
		t.Fatal(err)
	}
	if !slices.Contains(seenTensors, "draft.model.layers.0.self_attn.q_proj.weight") {
		t.Fatalf("draft tensor was not prefixed: %v", seenTensors)
	}
	hasDraftConfig := slices.ContainsFunc(layers, func(l LayerInfo) bool {
		return l.Name == "draft/config.json" && l.MediaType == "application/vnd.ollama.image.json"
	})
	if !hasDraftConfig {
		t.Fatalf("draft/config.json layer missing: %#v", layers)
	}
}
// TestCreateDraftSafetensorsLayersQuantizesEligibleTensors verifies that a
// requested draft quantization (normalized case-insensitively from "MXFP8")
// is applied to projection weights but skipped for norms and token embeddings.
func TestCreateDraftSafetensorsLayersQuantizesEligibleTensors(t *testing.T) {
	dir := t.TempDir()
	cfg := `{
		"architectures":["Gemma4AssistantForCausalLM"],
		"num_hidden_layers":8
	}`
	if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(cfg), 0o644); err != nil {
		t.Fatal(err)
	}
	createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{
		st.NewTensorDataFromBytes("model.embed_tokens.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
		st.NewTensorDataFromBytes("model.layers.0.self_attn.q_proj.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
		st.NewTensorDataFromBytes("model.layers.0.input_layernorm.weight", "BF16", []int32{64}, make([]byte, 64*2)),
	})

	jsonLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
		payload, err := io.ReadAll(r)
		if err != nil {
			return LayerInfo{}, err
		}
		return LayerInfo{Digest: "sha256:json_" + name, Size: int64(len(payload)), MediaType: mediaType, Name: name}, nil
	}
	// Record the quantization type requested for each emitted tensor.
	quantizeByName := map[string]string{}
	tensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
		payload, err := io.ReadAll(r)
		if err != nil {
			return nil, err
		}
		quantizeByName[name] = quantize
		return []LayerInfo{{Digest: "sha256:tensor_" + name, Size: int64(len(payload)), MediaType: "application/vnd.ollama.image.tensor", Name: name}}, nil
	}

	if _, err := CreateDraftSafetensorsLayers(dir, "draft.", "draft", "MXFP8", jsonLayer, tensorLayer, func(string) {}); err != nil {
		t.Fatal(err)
	}

	if got := quantizeByName["draft.model.layers.0.self_attn.q_proj.weight"]; got != "mxfp8" {
		t.Fatalf("q_proj draft quantize = %q, want mxfp8", got)
	}
	if got := quantizeByName["draft.model.layers.0.input_layernorm.weight"]; got != "" {
		t.Fatalf("norm draft quantize = %q, want empty", got)
	}
	if got := quantizeByName["draft.model.embed_tokens.weight"]; got != "" {
		t.Fatalf("embed_tokens draft quantize = %q, want empty", got)
	}
}
// TestCreateDraftSafetensorsLayersRejectsUnsupportedDraftQuantize verifies
// that an unknown --draft-quantize value is rejected before any work is done.
func TestCreateDraftSafetensorsLayersRejectsUnsupportedDraftQuantize(t *testing.T) {
	noProgress := func(string) {}
	_, err := CreateDraftSafetensorsLayers(t.TempDir(), "draft.", "draft", "bogus", nil, nil, noProgress)
	if err == nil || !strings.Contains(err.Error(), "unsupported --draft-quantize") {
		t.Fatalf("error = %v, want unsupported --draft-quantize", err)
	}
}
func readSingleTensorNameAndShape(t *testing.T, data []byte) (string, []int32) {
t.Helper()
var headerSize uint64
if err := binary.Read(bytes.NewReader(data[:8]), binary.LittleEndian, &headerSize); err != nil {
t.Fatalf("failed to read header size: %v", err)
}
var header map[string]struct {
Shape []int32 `json:"shape"`
}
if err := json.Unmarshal(data[8:8+headerSize], &header); err != nil {
t.Fatalf("failed to parse header: %v", err)
}
for name, info := range header {
if name != "__metadata__" {
return name, info.Shape
}
}
t.Fatal("no tensor entry found in header")
return "", nil
}
func TestCreateSafetensorsModel_NoConfigJson(t *testing.T) {
dir := t.TempDir()