mirror of
https://github.com/ollama/ollama.git
synced 2026-05-13 06:21:28 +00:00
mlx: Gemma4 MTP speculative decoding (#15980)
This change adds support for MTP (multi-token prediction) speculative decoding for the gemma4 model family. It includes:

* support for importing safetensors-based gemma4 draft models with `ollama create`
* a new DRAFT command in the Modelfile for specifying draft models
* a --quantize-draft flag for the ollama create command to quantize the draft model
* cache support for speculation
* changes to the rotating cache so that it handles MTP correctly
* sampling support for draft model token prediction

---------

Co-authored-by: Daniel Hiltgen <daniel@ollama.com>
This commit is contained in:
parent
4017af96cd
commit
15e6076d79
28 changed files with 2928 additions and 42 deletions
|
|
@ -23,6 +23,7 @@ import (
|
|||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/x/create"
|
||||
imagemanifest "github.com/ollama/ollama/x/imagegen/manifest"
|
||||
"github.com/ollama/ollama/x/safetensors"
|
||||
)
|
||||
|
||||
|
|
@ -34,6 +35,7 @@ type ModelfileConfig struct {
|
|||
Template string
|
||||
System string
|
||||
License string
|
||||
Draft string
|
||||
Parser string
|
||||
Renderer string
|
||||
Parameters map[string]any
|
||||
|
|
@ -67,6 +69,8 @@ func ConfigFromModelfile(modelfile *parser.Modelfile) (string, *ModelfileConfig,
|
|||
mfConfig.System = cmd.Args
|
||||
case "license":
|
||||
mfConfig.License = cmd.Args
|
||||
case "draft":
|
||||
mfConfig.Draft = cmd.Args
|
||||
case "parser":
|
||||
mfConfig.Parser = cmd.Args
|
||||
case "renderer":
|
||||
|
|
@ -108,10 +112,12 @@ func ConfigFromModelfile(modelfile *parser.Modelfile) (string, *ModelfileConfig,
|
|||
|
||||
// CreateOptions holds all options for model creation.
|
||||
type CreateOptions struct {
|
||||
ModelName string
|
||||
ModelDir string
|
||||
Quantize string // "int4", "int8", "nvfp4", "mxfp4", or "mxfp8" for quantization
|
||||
Modelfile *ModelfileConfig // template/system/license/parser/renderer/parameters from Modelfile
|
||||
ModelName string
|
||||
ModelDir string
|
||||
Quantize string // "int4", "int8", "nvfp4", "mxfp4", or "mxfp8" for quantization
|
||||
DraftQuantize string // optional quantization level for draft model tensors
|
||||
Modelfile *ModelfileConfig // template/system/license/parser/renderer/parameters from Modelfile
|
||||
BaseConfig *model.ConfigV2
|
||||
}
|
||||
|
||||
// CreateModel imports a model from a local directory.
|
||||
|
|
@ -121,11 +127,23 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
|
|||
// Detect model type
|
||||
isSafetensors := create.IsSafetensorsModelDir(opts.ModelDir)
|
||||
isImageGen := create.IsTensorModelDir(opts.ModelDir)
|
||||
hasDraft := opts.Modelfile != nil && opts.Modelfile.Draft != ""
|
||||
isBaseModelWithDraft := hasDraft && !isSafetensors && create.IsSafetensorsLLMModel(opts.ModelDir)
|
||||
if opts.DraftQuantize != "" && !hasDraft {
|
||||
return fmt.Errorf("--draft-quantize requires a DRAFT model")
|
||||
}
|
||||
|
||||
if !isSafetensors && !isImageGen {
|
||||
if !isSafetensors && !isImageGen && !isBaseModelWithDraft {
|
||||
return fmt.Errorf("%s is not a supported model directory (needs config.json + *.safetensors or model_index.json)", opts.ModelDir)
|
||||
}
|
||||
|
||||
if hasDraft && !create.IsSafetensorsModelDir(opts.Modelfile.Draft) {
|
||||
return fmt.Errorf("draft %s is not a supported safetensors model directory", opts.Modelfile.Draft)
|
||||
}
|
||||
if hasDraft && isImageGen {
|
||||
return fmt.Errorf("draft models are only supported for safetensors LLM models")
|
||||
}
|
||||
|
||||
// Determine model type settings
|
||||
var modelType, spinnerKey string
|
||||
var capabilities []string
|
||||
|
|
@ -138,6 +156,9 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
|
|||
parserName = getParserName(opts.ModelDir)
|
||||
rendererName = getRendererName(opts.ModelDir)
|
||||
capabilities = inferSafetensorsCapabilities(opts.ModelDir, resolveParserName(opts.Modelfile, parserName))
|
||||
} else if isBaseModelWithDraft {
|
||||
modelType = "safetensors model"
|
||||
spinnerKey = "create"
|
||||
} else {
|
||||
modelType = "image generation model"
|
||||
spinnerKey = "imagegen"
|
||||
|
|
@ -156,13 +177,44 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
|
|||
p.Add(spinnerKey, spinner)
|
||||
}
|
||||
|
||||
// Create the model using shared callbacks
|
||||
var draftLayers []create.LayerInfo
|
||||
var err error
|
||||
if hasDraft {
|
||||
draftLayers, err = create.CreateDraftSafetensorsLayers(
|
||||
opts.Modelfile.Draft,
|
||||
"draft.",
|
||||
"draft",
|
||||
opts.DraftQuantize,
|
||||
newLayerCreator(),
|
||||
newTensorLayerCreator(),
|
||||
progressFn,
|
||||
)
|
||||
if err != nil {
|
||||
spinner.Stop()
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if isBaseModelWithDraft {
|
||||
err = createModelFromBaseWithDraft(opts, draftLayers, progressFn)
|
||||
spinner.Stop()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Printf("Created safetensors model '%s'\n", opts.ModelName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create the model using shared callbacks
|
||||
if isSafetensors {
|
||||
writer := newManifestWriter(opts, capabilities, parserName, rendererName)
|
||||
if len(draftLayers) > 0 {
|
||||
writer = appendLayersManifestWriter(writer, draftLayers)
|
||||
}
|
||||
err = create.CreateSafetensorsModel(
|
||||
opts.ModelName, opts.ModelDir, opts.Quantize,
|
||||
newLayerCreator(), newTensorLayerCreator(),
|
||||
newManifestWriter(opts, capabilities, parserName, rendererName),
|
||||
writer,
|
||||
progressFn,
|
||||
newPackedTensorLayerCreator(),
|
||||
)
|
||||
|
|
@ -184,6 +236,68 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func appendLayersManifestWriter(next create.ManifestWriter, extra []create.LayerInfo) create.ManifestWriter {
|
||||
return func(modelName string, config create.LayerInfo, layers []create.LayerInfo) error {
|
||||
layers = append(layers, extra...)
|
||||
return next(modelName, config, layers)
|
||||
}
|
||||
}
|
||||
|
||||
func createModelFromBaseWithDraft(opts CreateOptions, draftLayers []create.LayerInfo, progressFn func(string)) error {
|
||||
progressFn(fmt.Sprintf("loading base model %s", opts.ModelDir))
|
||||
baseManifest, err := imagemanifest.LoadManifest(opts.ModelDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
baseConfig, err := readConfigV2(baseManifest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
opts.BaseConfig = baseConfig
|
||||
|
||||
configLayer := baseManifest.GetConfigLayer("config.json")
|
||||
if configLayer == nil {
|
||||
return fmt.Errorf("base model %s does not contain config.json", opts.ModelDir)
|
||||
}
|
||||
|
||||
layers := make([]create.LayerInfo, 0, len(baseManifest.Manifest.Layers)+len(draftLayers))
|
||||
for _, layer := range baseManifest.Manifest.Layers {
|
||||
layers = append(layers, create.LayerInfo{
|
||||
Digest: layer.Digest,
|
||||
Size: layer.Size,
|
||||
MediaType: layer.MediaType,
|
||||
Name: layer.Name,
|
||||
})
|
||||
}
|
||||
layers = append(layers, draftLayers...)
|
||||
|
||||
progressFn(fmt.Sprintf("writing manifest for %s", opts.ModelName))
|
||||
return newManifestWriter(opts, baseConfig.Capabilities, baseConfig.Parser, baseConfig.Renderer)(
|
||||
opts.ModelName,
|
||||
create.LayerInfo{
|
||||
Digest: configLayer.Digest,
|
||||
Size: configLayer.Size,
|
||||
MediaType: configLayer.MediaType,
|
||||
Name: configLayer.Name,
|
||||
},
|
||||
layers,
|
||||
)
|
||||
}
|
||||
|
||||
func readConfigV2(m *imagemanifest.ModelManifest) (*model.ConfigV2, error) {
|
||||
data, err := os.ReadFile(m.BlobPath(m.Manifest.Config.Digest))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read base config: %w", err)
|
||||
}
|
||||
|
||||
var cfg model.ConfigV2
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse base config: %w", err)
|
||||
}
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
func inferSafetensorsCapabilities(modelDir, parserName string) []string {
|
||||
capabilities := []string{"completion"}
|
||||
|
||||
|
|
@ -359,14 +473,26 @@ func newManifestWriter(opts CreateOptions, capabilities []string, parserName, re
|
|||
}
|
||||
}
|
||||
|
||||
// Create config blob with version requirement
|
||||
configData := model.ConfigV2{
|
||||
ModelFormat: "safetensors",
|
||||
FileType: strings.ToLower(strings.TrimSpace(opts.Quantize)),
|
||||
Capabilities: caps,
|
||||
Requires: MinOllamaVersion,
|
||||
Parser: resolveParserName(opts.Modelfile, parserName),
|
||||
Renderer: resolveRendererName(opts.Modelfile, rendererName),
|
||||
// Create config blob with version requirement.
|
||||
configData := model.ConfigV2{}
|
||||
if opts.BaseConfig != nil {
|
||||
configData = *opts.BaseConfig
|
||||
}
|
||||
configData.ModelFormat = "safetensors"
|
||||
if opts.Quantize != "" || configData.FileType == "" {
|
||||
configData.FileType = strings.ToLower(strings.TrimSpace(opts.Quantize))
|
||||
}
|
||||
configData.Capabilities = caps
|
||||
configData.Requires = MinOllamaVersion
|
||||
configData.Parser = resolveParserName(opts.Modelfile, parserName)
|
||||
configData.Renderer = resolveRendererName(opts.Modelfile, rendererName)
|
||||
if opts.Modelfile != nil && opts.Modelfile.Draft != "" {
|
||||
configData.Draft = &model.Draft{
|
||||
ModelFormat: "safetensors",
|
||||
Architecture: "Gemma4AssistantForCausalLM",
|
||||
TensorPrefix: "draft.",
|
||||
Config: "draft/config.json",
|
||||
}
|
||||
}
|
||||
configJSON, err := json.Marshal(configData)
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -44,6 +44,7 @@ func TestModelfileConfig(t *testing.T) {
|
|||
func TestConfigFromModelfile(t *testing.T) {
|
||||
modelfile, err := parser.ParseFile(strings.NewReader(`
|
||||
FROM ./model
|
||||
DRAFT ./assistant
|
||||
TEMPLATE {{ .Prompt }}
|
||||
PARAMETER temperature 0.7
|
||||
PARAMETER stop USER:
|
||||
|
|
@ -66,6 +67,10 @@ PARAMETER stop ASSISTANT:
|
|||
t.Fatalf("Template = %q, want %q", mfConfig.Template, "{{ .Prompt }}")
|
||||
}
|
||||
|
||||
if mfConfig.Draft != "./assistant" {
|
||||
t.Fatalf("Draft = %q, want %q", mfConfig.Draft, "./assistant")
|
||||
}
|
||||
|
||||
if got := mfConfig.Parameters["temperature"]; got != float32(0.7) {
|
||||
t.Fatalf("temperature = %#v, want %v", got, float32(0.7))
|
||||
}
|
||||
|
|
@ -153,11 +158,23 @@ func TestCreateModel_NotSafetensorsDir(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestCreateModel_DraftQuantizeRequiresDraft(t *testing.T) {
|
||||
err := CreateModel(CreateOptions{
|
||||
ModelName: "test-model",
|
||||
ModelDir: t.TempDir(),
|
||||
DraftQuantize: "mxfp8",
|
||||
}, nil)
|
||||
if err == nil || !strings.Contains(err.Error(), "--draft-quantize requires a DRAFT model") {
|
||||
t.Fatalf("error = %v, want draft-quantize requires DRAFT", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateOptions(t *testing.T) {
|
||||
opts := CreateOptions{
|
||||
ModelName: "my-model",
|
||||
ModelDir: "/path/to/model",
|
||||
Quantize: "fp8",
|
||||
ModelName: "my-model",
|
||||
ModelDir: "/path/to/model",
|
||||
Quantize: "fp8",
|
||||
DraftQuantize: "mxfp8",
|
||||
Modelfile: &ModelfileConfig{
|
||||
Template: "test",
|
||||
System: "system",
|
||||
|
|
@ -179,6 +196,9 @@ func TestCreateOptions(t *testing.T) {
|
|||
if opts.Quantize != "fp8" {
|
||||
t.Errorf("Quantize = %q, want %q", opts.Quantize, "fp8")
|
||||
}
|
||||
if opts.DraftQuantize != "mxfp8" {
|
||||
t.Errorf("DraftQuantize = %q, want %q", opts.DraftQuantize, "mxfp8")
|
||||
}
|
||||
if opts.Modelfile == nil {
|
||||
t.Error("Modelfile should not be nil")
|
||||
}
|
||||
|
|
@ -286,6 +306,9 @@ func TestCreateOptions_Defaults(t *testing.T) {
|
|||
if opts.Quantize != "" {
|
||||
t.Errorf("Quantize should be empty by default, got %q", opts.Quantize)
|
||||
}
|
||||
if opts.DraftQuantize != "" {
|
||||
t.Errorf("DraftQuantize should be empty by default, got %q", opts.DraftQuantize)
|
||||
}
|
||||
|
||||
// Modelfile should default to nil
|
||||
if opts.Modelfile != nil {
|
||||
|
|
@ -518,6 +541,48 @@ func TestNewManifestWriter_PopulatesFileTypeFromQuantize(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestNewManifestWriter_PopulatesDraftMetadata(t *testing.T) {
|
||||
t.Setenv("OLLAMA_MODELS", t.TempDir())
|
||||
|
||||
opts := CreateOptions{
|
||||
ModelName: "test-draft",
|
||||
ModelDir: t.TempDir(),
|
||||
Modelfile: &ModelfileConfig{Draft: "/tmp/assistant"},
|
||||
}
|
||||
|
||||
writer := newManifestWriter(opts, []string{"completion"}, "gemma4", "gemma4")
|
||||
if err := writer(opts.ModelName, create.LayerInfo{}, nil); err != nil {
|
||||
t.Fatalf("newManifestWriter() error = %v", err)
|
||||
}
|
||||
|
||||
name := model.ParseName(opts.ModelName)
|
||||
mf, err := manifest.ParseNamedManifest(name)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseNamedManifest() error = %v", err)
|
||||
}
|
||||
|
||||
configPath, err := manifest.BlobsPath(mf.Config.Digest)
|
||||
if err != nil {
|
||||
t.Fatalf("BlobsPath() error = %v", err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFile() error = %v", err)
|
||||
}
|
||||
|
||||
var cfg model.ConfigV2
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
t.Fatalf("Unmarshal() error = %v", err)
|
||||
}
|
||||
if cfg.Draft == nil {
|
||||
t.Fatal("Draft metadata missing")
|
||||
}
|
||||
if cfg.Draft.TensorPrefix != "draft." || cfg.Draft.Config != "draft/config.json" {
|
||||
t.Fatalf("Draft = %#v, want draft prefix/config", cfg.Draft)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSupportsThinking(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import (
|
|||
"io"
|
||||
"math"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"slices"
|
||||
|
|
@ -724,12 +725,16 @@ func detectModelOptQuantization(modelDir string) bool {
|
|||
}
|
||||
|
||||
func resolveEffectiveQuantization(cfg sourceModelConfig, sourceKind sourceQuantizedKind, requested string) (string, error) {
|
||||
return resolveEffectiveQuantizationForFlag(cfg, sourceKind, requested, "--quantize")
|
||||
}
|
||||
|
||||
func resolveEffectiveQuantizationForFlag(cfg sourceModelConfig, sourceKind sourceQuantizedKind, requested, flagName string) (string, error) {
|
||||
switch sourceKind {
|
||||
case sourceQuantizedKindNone:
|
||||
return requested, nil
|
||||
case sourceQuantizedKindPrequantized:
|
||||
if requested != "" {
|
||||
return "", fmt.Errorf("cannot requantize already-quantized source model with --quantize %q", requested)
|
||||
return "", fmt.Errorf("cannot requantize already-quantized source model with %s %q", flagName, requested)
|
||||
}
|
||||
return "", nil
|
||||
case sourceQuantizedKindSourceFP8:
|
||||
|
|
@ -746,7 +751,7 @@ func resolveEffectiveQuantization(cfg sourceModelConfig, sourceKind sourceQuanti
|
|||
case "nvfp4", "mxfp4", "mxfp8":
|
||||
return requested, nil
|
||||
default:
|
||||
return "", fmt.Errorf("cannot convert already-quantized fp8 source model with --quantize %q", requested)
|
||||
return "", fmt.Errorf("cannot convert already-quantized fp8 source model with %s %q", flagName, requested)
|
||||
}
|
||||
}
|
||||
return "mxfp8", nil
|
||||
|
|
@ -810,6 +815,7 @@ var tensorImportTransformRegistry = map[string]tensorImportTransformFactory{
|
|||
"Gemma4ForCausalLM": newGemma4ImportTransform,
|
||||
"Gemma4ForConditionalGeneration": newGemma4ImportTransform,
|
||||
"LagunaForCausalLM": newLagunaImportTransform,
|
||||
"Gemma4AssistantForCausalLM": newGemma4ImportTransform,
|
||||
}
|
||||
|
||||
func newTensorImportTransform(modelDir string, cfg sourceModelConfig) (tensorImportTransform, error) {
|
||||
|
|
@ -1167,6 +1173,136 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
|
|||
return nil
|
||||
}
|
||||
|
||||
func normalizeRequestedQuantization(flagName, quantize string) (string, error) {
|
||||
q := normalizeQuantType(strings.TrimSpace(quantize))
|
||||
switch q {
|
||||
case "", "int4", "int8", "nvfp4", "mxfp4", "mxfp8":
|
||||
return q, nil
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported %s %q: supported types are int4, int8, nvfp4, mxfp4, mxfp8", flagName, quantize)
|
||||
}
|
||||
}
|
||||
|
||||
// CreateDraftSafetensorsLayers imports an assistant/draft safetensors model
|
||||
// into prefixed tensor and config layers. When draftQuantize is non-empty,
|
||||
// eligible draft tensors are quantized with the same per-architecture policy
|
||||
// used by target safetensors imports.
|
||||
func CreateDraftSafetensorsLayers(modelDir, tensorPrefix, configPrefix, draftQuantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, fn func(status string)) ([]LayerInfo, error) {
|
||||
if tensorPrefix == "" {
|
||||
return nil, fmt.Errorf("draft tensor prefix must not be empty")
|
||||
}
|
||||
if configPrefix == "" {
|
||||
return nil, fmt.Errorf("draft config prefix must not be empty")
|
||||
}
|
||||
effectiveQuantize, err := normalizeRequestedQuantization("--draft-quantize", draftQuantize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var importTransform tensorImportTransform = noopImportTransform{}
|
||||
if effectiveQuantize != "" {
|
||||
sourceConfig, err := readSourceModelConfig(modelDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read draft config.json: %w", err)
|
||||
}
|
||||
sourceQuantKind, err := inspectSourceQuantization(modelDir, sourceConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to inspect draft quantization: %w", err)
|
||||
}
|
||||
effectiveQuantize, err = resolveEffectiveQuantizationForFlag(sourceConfig, sourceQuantKind, effectiveQuantize, "--draft-quantize")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
importTransform, err = newTensorImportTransform(modelDir, sourceConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to construct draft import transform for architecture %q: %w", sourceConfig.Architecture(), err)
|
||||
}
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(modelDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read draft directory: %w", err)
|
||||
}
|
||||
|
||||
var layers []LayerInfo
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".safetensors") {
|
||||
continue
|
||||
}
|
||||
|
||||
stPath := filepath.Join(modelDir, entry.Name())
|
||||
extractor, err := safetensors.OpenForExtraction(stPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open draft %s: %w", stPath, err)
|
||||
}
|
||||
|
||||
tensorNames := extractor.ListTensors()
|
||||
fn(fmt.Sprintf("importing draft %s (%d tensors%s)", entry.Name(), len(tensorNames), importQuantizationStatus(sourceQuantizedKindNone, effectiveQuantize)))
|
||||
for _, tensorName := range tensorNames {
|
||||
if importTransform.skipTensor(tensorName) {
|
||||
continue
|
||||
}
|
||||
td, err := extractor.GetTensor(tensorName)
|
||||
if err != nil {
|
||||
extractor.Close()
|
||||
return nil, fmt.Errorf("failed to get draft tensor %s: %w", tensorName, err)
|
||||
}
|
||||
|
||||
outTDs, err := importTransform.transformTensor(td)
|
||||
if err != nil {
|
||||
extractor.Close()
|
||||
return nil, fmt.Errorf("failed to transform draft tensor %s: %w", tensorName, err)
|
||||
}
|
||||
for _, transformedTD := range outTDs {
|
||||
if transformedTD == nil {
|
||||
continue
|
||||
}
|
||||
outTD := transformedTD.WithName(tensorPrefix + transformedTD.Name)
|
||||
quantizeType := ""
|
||||
if effectiveQuantize != "" {
|
||||
quantizeType = importTransform.quantizationType(outTD.Name, outTD.Shape, effectiveQuantize)
|
||||
if isEmbedTokensWeight(outTD.Name) {
|
||||
quantizeType = ""
|
||||
}
|
||||
}
|
||||
newLayers, err := createTensorLayer(outTD.SafetensorsReader(), outTD.Name, outTD.Dtype, outTD.Shape, quantizeType)
|
||||
if err != nil {
|
||||
extractor.Close()
|
||||
return nil, fmt.Errorf("failed to create draft layer for %s: %w", tensorName, err)
|
||||
}
|
||||
layers = append(layers, newLayers...)
|
||||
}
|
||||
}
|
||||
extractor.Close()
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".json") {
|
||||
continue
|
||||
}
|
||||
if entry.Name() == "model.safetensors.index.json" {
|
||||
continue
|
||||
}
|
||||
|
||||
cfgPath := entry.Name()
|
||||
fullPath := filepath.Join(modelDir, cfgPath)
|
||||
fn(fmt.Sprintf("importing draft config %s", cfgPath))
|
||||
|
||||
f, err := os.Open(fullPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open draft %s: %w", cfgPath, err)
|
||||
}
|
||||
layer, err := createLayer(f, "application/vnd.ollama.image.json", path.Join(configPrefix, cfgPath))
|
||||
f.Close()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create draft config layer for %s: %w", cfgPath, err)
|
||||
}
|
||||
layers = append(layers, layer)
|
||||
}
|
||||
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
func shouldSkipSourceCompanion(name string, tensorSet map[string]struct{}, sourceTensorFiles map[string]string) bool {
|
||||
switch {
|
||||
case strings.HasSuffix(name, ".scales"):
|
||||
|
|
|
|||
|
|
@ -452,6 +452,137 @@ func TestCreateSafetensorsModel(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestCreateDraftSafetensorsLayersPrefixesTensorsAndConfigs(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type":"gemma4_assistant"}`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{
|
||||
st.NewTensorDataFromBytes("model.layers.0.self_attn.q_proj.weight", "BF16", []int32{2, 2}, make([]byte, 8)),
|
||||
})
|
||||
|
||||
var tensorNames []string
|
||||
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
||||
data, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return LayerInfo{}, err
|
||||
}
|
||||
return LayerInfo{Digest: "sha256:json_" + name, Size: int64(len(data)), MediaType: mediaType, Name: name}, nil
|
||||
}
|
||||
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
|
||||
data, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tensorNames = append(tensorNames, name)
|
||||
tensorName, tensorShape := readSingleTensorNameAndShape(t, data)
|
||||
if tensorName != name {
|
||||
t.Fatalf("safetensors key = %q, want %q", tensorName, name)
|
||||
}
|
||||
if !slices.Equal(tensorShape, shape) {
|
||||
t.Fatalf("shape = %v, want %v", tensorShape, shape)
|
||||
}
|
||||
if quantize != "" {
|
||||
t.Fatalf("draft quantize = %q, want empty", quantize)
|
||||
}
|
||||
return []LayerInfo{{Digest: "sha256:tensor_" + name, Size: int64(len(data)), MediaType: "application/vnd.ollama.image.tensor", Name: name}}, nil
|
||||
}
|
||||
|
||||
layers, err := CreateDraftSafetensorsLayers(dir, "draft.", "draft", "", createLayer, createTensorLayer, func(string) {})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !slices.Contains(tensorNames, "draft.model.layers.0.self_attn.q_proj.weight") {
|
||||
t.Fatalf("draft tensor was not prefixed: %v", tensorNames)
|
||||
}
|
||||
var hasDraftConfig bool
|
||||
for _, layer := range layers {
|
||||
if layer.Name == "draft/config.json" && layer.MediaType == "application/vnd.ollama.image.json" {
|
||||
hasDraftConfig = true
|
||||
}
|
||||
}
|
||||
if !hasDraftConfig {
|
||||
t.Fatalf("draft/config.json layer missing: %#v", layers)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateDraftSafetensorsLayersQuantizesEligibleTensors(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{
|
||||
"architectures":["Gemma4AssistantForCausalLM"],
|
||||
"num_hidden_layers":8
|
||||
}`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{
|
||||
st.NewTensorDataFromBytes("model.embed_tokens.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
|
||||
st.NewTensorDataFromBytes("model.layers.0.self_attn.q_proj.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
|
||||
st.NewTensorDataFromBytes("model.layers.0.input_layernorm.weight", "BF16", []int32{64}, make([]byte, 64*2)),
|
||||
})
|
||||
|
||||
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
||||
data, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return LayerInfo{}, err
|
||||
}
|
||||
return LayerInfo{Digest: "sha256:json_" + name, Size: int64(len(data)), MediaType: mediaType, Name: name}, nil
|
||||
}
|
||||
quantizeByName := make(map[string]string)
|
||||
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
|
||||
data, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
quantizeByName[name] = quantize
|
||||
return []LayerInfo{{Digest: "sha256:tensor_" + name, Size: int64(len(data)), MediaType: "application/vnd.ollama.image.tensor", Name: name}}, nil
|
||||
}
|
||||
|
||||
if _, err := CreateDraftSafetensorsLayers(dir, "draft.", "draft", "MXFP8", createLayer, createTensorLayer, func(string) {}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if got := quantizeByName["draft.model.layers.0.self_attn.q_proj.weight"]; got != "mxfp8" {
|
||||
t.Fatalf("q_proj draft quantize = %q, want mxfp8", got)
|
||||
}
|
||||
if got := quantizeByName["draft.model.layers.0.input_layernorm.weight"]; got != "" {
|
||||
t.Fatalf("norm draft quantize = %q, want empty", got)
|
||||
}
|
||||
if got := quantizeByName["draft.model.embed_tokens.weight"]; got != "" {
|
||||
t.Fatalf("embed_tokens draft quantize = %q, want empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateDraftSafetensorsLayersRejectsUnsupportedDraftQuantize(t *testing.T) {
|
||||
_, err := CreateDraftSafetensorsLayers(t.TempDir(), "draft.", "draft", "bogus", nil, nil, func(string) {})
|
||||
if err == nil || !strings.Contains(err.Error(), "unsupported --draft-quantize") {
|
||||
t.Fatalf("error = %v, want unsupported --draft-quantize", err)
|
||||
}
|
||||
}
|
||||
|
||||
// readSingleTensorNameAndShape parses the header of a serialized safetensors
// blob and returns the name and shape of a tensor entry in it.
//
// The blob must start with an 8-byte little-endian header length followed by
// that many bytes of JSON. The "__metadata__" key is ignored. The helper is
// meant for blobs holding exactly one tensor: with several entries the one
// returned depends on map iteration order. Truncated or malformed input fails
// the test with a message instead of panicking on an out-of-range slice.
func readSingleTensorNameAndShape(t *testing.T, data []byte) (string, []int32) {
	t.Helper()

	const sizeLen = 8
	if len(data) < sizeLen {
		t.Fatalf("failed to read header size: only %d bytes of data", len(data))
	}
	headerSize := binary.LittleEndian.Uint64(data[:sizeLen])

	// Compare against the remaining length rather than computing
	// sizeLen+headerSize, which could overflow for a corrupt size field.
	remaining := uint64(len(data) - sizeLen)
	if headerSize > remaining {
		t.Fatalf("failed to parse header: header size %d exceeds remaining %d bytes", headerSize, remaining)
	}

	var header map[string]struct {
		Shape []int32 `json:"shape"`
	}
	if err := json.Unmarshal(data[sizeLen:sizeLen+int(headerSize)], &header); err != nil {
		t.Fatalf("failed to parse header: %v", err)
	}
	for name, info := range header {
		if name != "__metadata__" {
			return name, info.Shape
		}
	}
	t.Fatal("no tensor entry found in header")
	return "", nil
}
|
||||
|
||||
func TestCreateSafetensorsModel_NoConfigJson(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue