diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index e3384bbe2..8abbde44c 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -423,8 +423,8 @@ jobs: lib/ollama/*.so*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;; lib/ollama/cuda_v*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;; lib/ollama/vulkan*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;; - lib/ollama/mlx*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;; - lib/ollama/include*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;; + lib/ollama/mlx*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-mlx.tar.in ;; + lib/ollama/include*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-mlx.tar.in ;; lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;; lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;; lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;; @@ -458,12 +458,14 @@ jobs: CGO_CFLAGS CGO_CXXFLAGS GOFLAGS + APT_MIRROR=http://azure.archive.ubuntu.com/ubuntu - os: linux arch: amd64 build-args: | CGO_CFLAGS CGO_CXXFLAGS GOFLAGS + APT_MIRROR=http://azure.archive.ubuntu.com/ubuntu - os: linux arch: amd64 suffix: '-rocm' @@ -472,6 +474,7 @@ jobs: CGO_CXXFLAGS GOFLAGS FLAVOR=rocm + APT_MIRROR=http://azure.archive.ubuntu.com/ubuntu runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }} environment: release needs: setup-environment diff --git a/CMakePresets.json b/CMakePresets.json index d099d3f16..0fdbc1442 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -41,7 +41,7 @@ "inherits": [ "CUDA" ], "cacheVariables": { "CMAKE_CUDA_ARCHITECTURES": 
"75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual;90a-virtual;100-virtual;103-virtual;110-virtual;120-virtual;121-virtual", - "CMAKE_CUDA_FLAGS": "-t 4", + "CMAKE_CUDA_FLAGS": "-t 2", "OLLAMA_RUNNER_DIR": "cuda_v13" } }, diff --git a/Dockerfile b/Dockerfile index 885faf250..0485b09c4 100644 --- a/Dockerfile +++ b/Dockerfile @@ -212,8 +212,11 @@ COPY --from=cpu dist/lib/ollama /lib/ollama COPY --from=build /bin/ollama /bin/ollama FROM ubuntu:24.04 -RUN apt-get update \ +ARG APT_MIRROR=http://archive.ubuntu.com/ubuntu +RUN sed -i "s|http://archive.ubuntu.com/ubuntu|$APT_MIRROR|g" /etc/apt/sources.list.d/ubuntu.sources \ + && apt-get update \ && apt-get install -y ca-certificates libvulkan1 libopenblas0 \ + && sed -i "s|$APT_MIRROR|http://archive.ubuntu.com/ubuntu|g" /etc/apt/sources.list.d/ubuntu.sources \ && apt-get clean \ && rm -rf /var/lib/apt/lists/* COPY --from=archive /bin /usr/bin diff --git a/cmd/cmd.go b/cmd/cmd.go index ad3bdfd7f..7448a3e7d 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -54,6 +54,7 @@ import ( "github.com/ollama/ollama/types/syncmap" "github.com/ollama/ollama/version" xcmd "github.com/ollama/ollama/x/cmd" + xcreate "github.com/ollama/ollama/x/create" xcreateclient "github.com/ollama/ollama/x/create/client" "github.com/ollama/ollama/x/imagegen" ) @@ -145,6 +146,39 @@ func isLocalhost() bool { return ip != nil && (ip.IsLoopback() || ip.IsUnspecified()) } +func resolveExperimentalLocalModelDir(ref, filename string) string { + if ref == "" || filepath.IsAbs(ref) || filename == "" { + return ref + } + + candidate := filepath.Join(filepath.Dir(filename), ref) + if xcreate.IsSafetensorsModelDir(candidate) || xcreate.IsTensorModelDir(candidate) { + return candidate + } + + return ref +} + +func resolveExperimentalDraftDir(ref, filename string) (string, error) { + if ref == "" { + return "", nil + } + if filepath.IsAbs(ref) { + if xcreate.IsSafetensorsModelDir(ref) { + return ref, nil + } + return "", fmt.Errorf("draft %s is not a 
supported safetensors model directory", ref) + } + if filename != "" { + candidate := filepath.Join(filepath.Dir(filename), ref) + if xcreate.IsSafetensorsModelDir(candidate) { + return candidate, nil + } + } + + return "", fmt.Errorf("DRAFT model references are not supported with --experimental yet: %s", ref) +} + func CreateHandler(cmd *cobra.Command, args []string) error { p := progress.NewProgress(os.Stderr) defer p.Stop() @@ -159,6 +193,10 @@ func CreateHandler(cmd *cobra.Command, args []string) error { // Check for --experimental flag for safetensors model creation // This gates both safetensors LLM and imagegen model creation experimental, _ := cmd.Flags().GetBool("experimental") + draftQuantize, _ := cmd.Flags().GetString("draft-quantize") + if draftQuantize != "" && !experimental { + return errors.New("--draft-quantize requires --experimental") + } if experimental { if !isLocalhost() { return errors.New("remote safetensor model creation not yet supported") @@ -192,17 +230,22 @@ func CreateHandler(cmd *cobra.Command, args []string) error { return err } - // Resolve relative paths based on Modelfile location - if !filepath.IsAbs(modelDir) && filename != "" { - modelDir = filepath.Join(filepath.Dir(filename), modelDir) + modelDir = resolveExperimentalLocalModelDir(modelDir, filename) + if mfConfig.Draft != "" { + draftDir, err := resolveExperimentalDraftDir(mfConfig.Draft, filename) + if err != nil { + return err + } + mfConfig.Draft = draftDir } quantize, _ := cmd.Flags().GetString("quantize") return xcreateclient.CreateModel(xcreateclient.CreateOptions{ - ModelName: modelName, - ModelDir: modelDir, - Quantize: quantize, - Modelfile: mfConfig, + ModelName: modelName, + ModelDir: modelDir, + Quantize: quantize, + DraftQuantize: draftQuantize, + Modelfile: mfConfig, }, p) } @@ -2176,6 +2219,9 @@ func NewCLI() *cobra.Command { if experimental, _ := cmd.Flags().GetBool("experimental"); experimental { return nil } + if draftQuantize, _ := 
cmd.Flags().GetString("draft-quantize"); draftQuantize != "" { + return errors.New("--draft-quantize requires --experimental") + } return checkServerHeartbeat(cmd, args) }, RunE: CreateHandler, @@ -2183,6 +2229,7 @@ func NewCLI() *cobra.Command { createCmd.Flags().StringP("file", "f", "", "Name of the Modelfile (default \"Modelfile\")") createCmd.Flags().StringP("quantize", "q", "", "Quantize model to this level (e.g. q4_K_M)") + createCmd.Flags().String("draft-quantize", "", "Quantize draft model to this level") createCmd.Flags().Bool("experimental", false, "Enable experimental safetensors model creation") showCmd := &cobra.Command{ diff --git a/cmd/cmd_test.go b/cmd/cmd_test.go index d9c565630..fe2930ea6 100644 --- a/cmd/cmd_test.go +++ b/cmd/cmd_test.go @@ -9,6 +9,7 @@ import ( "net/http" "net/http/httptest" "os" + "path/filepath" "reflect" "strings" "testing" @@ -1524,6 +1525,87 @@ func TestCreateHandler(t *testing.T) { } } +func TestCreateHandlerDraftQuantizeRequiresExperimental(t *testing.T) { + cmd := &cobra.Command{} + cmd.Flags().Bool("experimental", false, "") + cmd.Flags().String("draft-quantize", "mxfp8", "") + cmd.SetContext(t.Context()) + + err := CreateHandler(cmd, []string{"test-model"}) + if err == nil || !strings.Contains(err.Error(), "--draft-quantize requires --experimental") { + t.Fatalf("error = %v, want draft-quantize requires experimental", err) + } +} + +func TestCreateHandlerDraftRequiresExperimental(t *testing.T) { + dir := t.TempDir() + modelfile := filepath.Join(dir, "Modelfile") + if err := os.WriteFile(modelfile, []byte("FROM base\nDRAFT ./assistant\n"), 0o644); err != nil { + t.Fatal(err) + } + + cmd := &cobra.Command{} + cmd.Flags().Bool("experimental", false, "") + cmd.Flags().String("draft-quantize", "", "") + cmd.Flags().String("file", modelfile, "") + cmd.SetContext(t.Context()) + + err := CreateHandler(cmd, []string{"test-model"}) + if err == nil || !strings.Contains(err.Error(), "DRAFT requires --experimental") { + 
t.Fatalf("error = %v, want DRAFT requires --experimental", err) + } +} + +func TestResolveExperimentalLocalModelDir(t *testing.T) { + dir := t.TempDir() + modelfile := filepath.Join(dir, "Modelfile") + modelDir := filepath.Join(dir, "model") + if err := os.Mkdir(modelDir, 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(modelDir, "config.json"), []byte(`{}`), 0o644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(modelDir, "model.safetensors"), []byte("dummy"), 0o644); err != nil { + t.Fatal(err) + } + + if got := resolveExperimentalLocalModelDir("gemma4", modelfile); got != "gemma4" { + t.Fatalf("resolveExperimentalLocalModelDir(model name) = %q, want gemma4", got) + } + if got := resolveExperimentalLocalModelDir("./model", modelfile); got != modelDir { + t.Fatalf("resolveExperimentalLocalModelDir(local dir) = %q, want %q", got, modelDir) + } +} + +func TestResolveExperimentalDraftDir(t *testing.T) { + dir := t.TempDir() + modelfile := filepath.Join(dir, "Modelfile") + draftDir := filepath.Join(dir, "assistant") + if err := os.Mkdir(draftDir, 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(draftDir, "config.json"), []byte(`{}`), 0o644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(draftDir, "model.safetensors"), []byte("dummy"), 0o644); err != nil { + t.Fatal(err) + } + + got, err := resolveExperimentalDraftDir("./assistant", modelfile) + if err != nil { + t.Fatal(err) + } + if got != draftDir { + t.Fatalf("resolveExperimentalDraftDir(local dir) = %q, want %q", got, draftDir) + } + + _, err = resolveExperimentalDraftDir("assistant-model", modelfile) + if err == nil || !strings.Contains(err.Error(), "DRAFT model references are not supported with --experimental yet") { + t.Fatalf("error = %v, want unsupported draft model reference", err) + } +} + func TestNewCreateRequest(t *testing.T) { tests := []struct { name string diff --git a/parser/parser.go 
b/parser/parser.go index f3b6dcb55..3d3ea330d 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -83,6 +83,8 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error) req.Files[k] = v } } + case "draft": + return nil, errors.New("DRAFT requires --experimental") case "adapter": path, err := expandPath(c.Args, relativeDir) if err != nil { @@ -336,7 +338,7 @@ func (c Command) String() string { switch c.Name { case "model": fmt.Fprintf(&sb, "FROM %s", c.Args) - case "license", "template", "system", "adapter", "renderer", "parser", "requires": + case "license", "template", "system", "adapter", "renderer", "parser", "requires", "draft": fmt.Fprintf(&sb, "%s %s", strings.ToUpper(c.Name), quote(c.Args)) case "message": role, message, _ := strings.Cut(c.Args, ": ") @@ -362,7 +364,7 @@ const ( var ( errMissingFrom = errors.New("no FROM line") errInvalidMessageRole = errors.New("message role must be one of \"system\", \"user\", or \"assistant\"") - errInvalidCommand = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"renderer\", \"parser\", \"parameter\", \"message\", or \"requires\"") + errInvalidCommand = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"draft\", \"renderer\", \"parser\", \"parameter\", \"message\", or \"requires\"") ) type ParserError struct { @@ -622,7 +624,7 @@ func isValidMessageRole(role string) bool { func isValidCommand(cmd string) bool { switch strings.ToLower(cmd) { - case "from", "license", "template", "system", "adapter", "renderer", "parser", "parameter", "message", "requires": + case "from", "license", "template", "system", "adapter", "draft", "renderer", "parser", "parameter", "message", "requires": return true default: return false diff --git a/parser/parser_test.go b/parser/parser_test.go index 4dcfed0cb..7abebb44e 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -58,6 +58,32 @@ TEMPLATE """{{ if 
.System }}<|start_header_id|>system<|end_header_id|> assert.Equal(t, expectedCommands, modelfile.Commands) } +func TestParseFileDraft(t *testing.T) { + modelfile, err := ParseFile(strings.NewReader(` +FROM base +DRAFT ./assistant +`)) + require.NoError(t, err) + + expectedCommands := []Command{ + {Name: "model", Args: "base"}, + {Name: "draft", Args: "./assistant"}, + } + assert.Equal(t, expectedCommands, modelfile.Commands) + assert.Contains(t, modelfile.String(), "DRAFT ./assistant") +} + +func TestCreateRequestDraftRequiresExperimental(t *testing.T) { + modelfile, err := ParseFile(strings.NewReader(` +FROM base +DRAFT ./assistant +`)) + require.NoError(t, err) + + _, err = modelfile.CreateRequest("") + require.ErrorContains(t, err, "DRAFT requires --experimental") +} + func TestParseFileTrimSpace(t *testing.T) { input := ` FROM " model 1" diff --git a/scripts/build_linux.sh b/scripts/build_linux.sh index e06ba38b5..5421d70c7 100755 --- a/scripts/build_linux.sh +++ b/scripts/build_linux.sh @@ -62,13 +62,15 @@ if echo $PLATFORM | grep "," > /dev/null ; then tar c -C ./dist/linux_arm64 --exclude cuda_jetpack5 --exclude cuda_jetpack6 . | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64.tar.zst tar c -C ./dist/linux_arm64 ./lib/ollama/cuda_jetpack5 | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64-jetpack5.tar.zst tar c -C ./dist/linux_arm64 ./lib/ollama/cuda_jetpack6 | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64-jetpack6.tar.zst - tar c -C ./dist/linux_amd64 --exclude rocm . | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64.tar.zst + tar c -C ./dist/linux_amd64 --exclude rocm --exclude 'mlx*' --exclude include . 
| zstd --ultra -22 -T0 >./dist/ollama-linux-amd64.tar.zst tar c -C ./dist/linux_amd64 ./lib/ollama/rocm | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64-rocm.tar.zst + tar c -C ./dist/linux_amd64 ./lib/ollama/mlx_cuda_v13 ./lib/ollama/include | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64-mlx.tar.zst elif echo $PLATFORM | grep "arm64" > /dev/null ; then tar c -C ./dist/ --exclude cuda_jetpack5 --exclude cuda_jetpack6 bin lib | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64.tar.zst tar c -C ./dist/ ./lib/ollama/cuda_jetpack5 | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64-jetpack5.tar.zst tar c -C ./dist/ ./lib/ollama/cuda_jetpack6 | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64-jetpack6.tar.zst elif echo $PLATFORM | grep "amd64" > /dev/null ; then - tar c -C ./dist/ --exclude rocm bin lib | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64.tar.zst + tar c -C ./dist/ --exclude rocm --exclude 'mlx*' --exclude include bin lib | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64.tar.zst tar c -C ./dist/ ./lib/ollama/rocm | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64-rocm.tar.zst + tar c -C ./dist/ ./lib/ollama/mlx_cuda_v13 ./lib/ollama/include | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64-mlx.tar.zst fi diff --git a/types/model/config.go b/types/model/config.go index 96aec8cb1..794c9758b 100644 --- a/types/model/config.go +++ b/types/model/config.go @@ -19,6 +19,7 @@ type ConfigV2 struct { ContextLen int `json:"context_length,omitempty"` EmbedLen int `json:"embedding_length,omitempty"` BaseName string `json:"base_name,omitempty"` + Draft *Draft `json:"draft,omitempty"` // required by spec Architecture string `json:"architecture"` @@ -26,6 +27,14 @@ type ConfigV2 struct { RootFS RootFS `json:"rootfs"` } +// Draft describes an auxiliary draft model stored in the same manifest. 
+type Draft struct { + ModelFormat string `json:"model_format,omitempty"` + Architecture string `json:"architecture,omitempty"` + TensorPrefix string `json:"tensor_prefix,omitempty"` + Config string `json:"config,omitempty"` +} + // RootFS represents the root filesystem configuration for a model. type RootFS struct { Type string `json:"type"` diff --git a/x/create/client/create.go b/x/create/client/create.go index f386813e2..d8a1a8f80 100644 --- a/x/create/client/create.go +++ b/x/create/client/create.go @@ -23,6 +23,7 @@ import ( "github.com/ollama/ollama/progress" "github.com/ollama/ollama/types/model" "github.com/ollama/ollama/x/create" + imagemanifest "github.com/ollama/ollama/x/imagegen/manifest" "github.com/ollama/ollama/x/safetensors" ) @@ -34,6 +35,7 @@ type ModelfileConfig struct { Template string System string License string + Draft string Parser string Renderer string Parameters map[string]any @@ -67,6 +69,8 @@ func ConfigFromModelfile(modelfile *parser.Modelfile) (string, *ModelfileConfig, mfConfig.System = cmd.Args case "license": mfConfig.License = cmd.Args + case "draft": + mfConfig.Draft = cmd.Args case "parser": mfConfig.Parser = cmd.Args case "renderer": @@ -108,10 +112,12 @@ func ConfigFromModelfile(modelfile *parser.Modelfile) (string, *ModelfileConfig, // CreateOptions holds all options for model creation. type CreateOptions struct { - ModelName string - ModelDir string - Quantize string // "int4", "int8", "nvfp4", "mxfp4", or "mxfp8" for quantization - Modelfile *ModelfileConfig // template/system/license/parser/renderer/parameters from Modelfile + ModelName string + ModelDir string + Quantize string // "int4", "int8", "nvfp4", "mxfp4", or "mxfp8" for quantization + DraftQuantize string // optional quantization level for draft model tensors + Modelfile *ModelfileConfig // template/system/license/parser/renderer/parameters from Modelfile + BaseConfig *model.ConfigV2 } // CreateModel imports a model from a local directory. 
@@ -121,11 +127,23 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error { // Detect model type isSafetensors := create.IsSafetensorsModelDir(opts.ModelDir) isImageGen := create.IsTensorModelDir(opts.ModelDir) + hasDraft := opts.Modelfile != nil && opts.Modelfile.Draft != "" + isBaseModelWithDraft := hasDraft && !isSafetensors && create.IsSafetensorsLLMModel(opts.ModelDir) + if opts.DraftQuantize != "" && !hasDraft { + return fmt.Errorf("--draft-quantize requires a DRAFT model") + } - if !isSafetensors && !isImageGen { + if !isSafetensors && !isImageGen && !isBaseModelWithDraft { return fmt.Errorf("%s is not a supported model directory (needs config.json + *.safetensors or model_index.json)", opts.ModelDir) } + if hasDraft && !create.IsSafetensorsModelDir(opts.Modelfile.Draft) { + return fmt.Errorf("draft %s is not a supported safetensors model directory", opts.Modelfile.Draft) + } + if hasDraft && isImageGen { + return fmt.Errorf("draft models are only supported for safetensors LLM models") + } + // Determine model type settings var modelType, spinnerKey string var capabilities []string @@ -138,6 +156,9 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error { parserName = getParserName(opts.ModelDir) rendererName = getRendererName(opts.ModelDir) capabilities = inferSafetensorsCapabilities(opts.ModelDir, resolveParserName(opts.Modelfile, parserName)) + } else if isBaseModelWithDraft { + modelType = "safetensors model" + spinnerKey = "create" } else { modelType = "image generation model" spinnerKey = "imagegen" @@ -156,13 +177,44 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error { p.Add(spinnerKey, spinner) } - // Create the model using shared callbacks + var draftLayers []create.LayerInfo var err error + if hasDraft { + draftLayers, err = create.CreateDraftSafetensorsLayers( + opts.Modelfile.Draft, + "draft.", + "draft", + opts.DraftQuantize, + newLayerCreator(), + newTensorLayerCreator(), + progressFn, + ) + if err != 
nil { + spinner.Stop() + return err + } + } + + if isBaseModelWithDraft { + err = createModelFromBaseWithDraft(opts, draftLayers, progressFn) + spinner.Stop() + if err != nil { + return err + } + fmt.Printf("Created safetensors model '%s'\n", opts.ModelName) + return nil + } + + // Create the model using shared callbacks if isSafetensors { + writer := newManifestWriter(opts, capabilities, parserName, rendererName) + if len(draftLayers) > 0 { + writer = appendLayersManifestWriter(writer, draftLayers) + } err = create.CreateSafetensorsModel( opts.ModelName, opts.ModelDir, opts.Quantize, newLayerCreator(), newTensorLayerCreator(), - newManifestWriter(opts, capabilities, parserName, rendererName), + writer, progressFn, newPackedTensorLayerCreator(), ) @@ -184,6 +236,68 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error { return nil } +func appendLayersManifestWriter(next create.ManifestWriter, extra []create.LayerInfo) create.ManifestWriter { + return func(modelName string, config create.LayerInfo, layers []create.LayerInfo) error { + layers = append(layers, extra...) 
+ return next(modelName, config, layers) + } +} + +func createModelFromBaseWithDraft(opts CreateOptions, draftLayers []create.LayerInfo, progressFn func(string)) error { + progressFn(fmt.Sprintf("loading base model %s", opts.ModelDir)) + baseManifest, err := imagemanifest.LoadManifest(opts.ModelDir) + if err != nil { + return err + } + + baseConfig, err := readConfigV2(baseManifest) + if err != nil { + return err + } + opts.BaseConfig = baseConfig + + configLayer := baseManifest.GetConfigLayer("config.json") + if configLayer == nil { + return fmt.Errorf("base model %s does not contain config.json", opts.ModelDir) + } + + layers := make([]create.LayerInfo, 0, len(baseManifest.Manifest.Layers)+len(draftLayers)) + for _, layer := range baseManifest.Manifest.Layers { + layers = append(layers, create.LayerInfo{ + Digest: layer.Digest, + Size: layer.Size, + MediaType: layer.MediaType, + Name: layer.Name, + }) + } + layers = append(layers, draftLayers...) + + progressFn(fmt.Sprintf("writing manifest for %s", opts.ModelName)) + return newManifestWriter(opts, baseConfig.Capabilities, baseConfig.Parser, baseConfig.Renderer)( + opts.ModelName, + create.LayerInfo{ + Digest: configLayer.Digest, + Size: configLayer.Size, + MediaType: configLayer.MediaType, + Name: configLayer.Name, + }, + layers, + ) +} + +func readConfigV2(m *imagemanifest.ModelManifest) (*model.ConfigV2, error) { + data, err := os.ReadFile(m.BlobPath(m.Manifest.Config.Digest)) + if err != nil { + return nil, fmt.Errorf("failed to read base config: %w", err) + } + + var cfg model.ConfigV2 + if err := json.Unmarshal(data, &cfg); err != nil { + return nil, fmt.Errorf("failed to parse base config: %w", err) + } + return &cfg, nil +} + func inferSafetensorsCapabilities(modelDir, parserName string) []string { capabilities := []string{"completion"} @@ -359,14 +473,26 @@ func newManifestWriter(opts CreateOptions, capabilities []string, parserName, re } } - // Create config blob with version requirement - configData := 
model.ConfigV2{ - ModelFormat: "safetensors", - FileType: strings.ToLower(strings.TrimSpace(opts.Quantize)), - Capabilities: caps, - Requires: MinOllamaVersion, - Parser: resolveParserName(opts.Modelfile, parserName), - Renderer: resolveRendererName(opts.Modelfile, rendererName), + // Create config blob with version requirement. + configData := model.ConfigV2{} + if opts.BaseConfig != nil { + configData = *opts.BaseConfig + } + configData.ModelFormat = "safetensors" + if opts.Quantize != "" || configData.FileType == "" { + configData.FileType = strings.ToLower(strings.TrimSpace(opts.Quantize)) + } + configData.Capabilities = caps + configData.Requires = MinOllamaVersion + configData.Parser = resolveParserName(opts.Modelfile, parserName) + configData.Renderer = resolveRendererName(opts.Modelfile, rendererName) + if opts.Modelfile != nil && opts.Modelfile.Draft != "" { + configData.Draft = &model.Draft{ + ModelFormat: "safetensors", + Architecture: "Gemma4AssistantForCausalLM", + TensorPrefix: "draft.", + Config: "draft/config.json", + } } configJSON, err := json.Marshal(configData) if err != nil { diff --git a/x/create/client/create_test.go b/x/create/client/create_test.go index ead7e4233..bc8f4d6e9 100644 --- a/x/create/client/create_test.go +++ b/x/create/client/create_test.go @@ -44,6 +44,7 @@ func TestModelfileConfig(t *testing.T) { func TestConfigFromModelfile(t *testing.T) { modelfile, err := parser.ParseFile(strings.NewReader(` FROM ./model +DRAFT ./assistant TEMPLATE {{ .Prompt }} PARAMETER temperature 0.7 PARAMETER stop USER: @@ -66,6 +67,10 @@ PARAMETER stop ASSISTANT: t.Fatalf("Template = %q, want %q", mfConfig.Template, "{{ .Prompt }}") } + if mfConfig.Draft != "./assistant" { + t.Fatalf("Draft = %q, want %q", mfConfig.Draft, "./assistant") + } + if got := mfConfig.Parameters["temperature"]; got != float32(0.7) { t.Fatalf("temperature = %#v, want %v", got, float32(0.7)) } @@ -153,11 +158,23 @@ func TestCreateModel_NotSafetensorsDir(t *testing.T) { } } 
+func TestCreateModel_DraftQuantizeRequiresDraft(t *testing.T) { + err := CreateModel(CreateOptions{ + ModelName: "test-model", + ModelDir: t.TempDir(), + DraftQuantize: "mxfp8", + }, nil) + if err == nil || !strings.Contains(err.Error(), "--draft-quantize requires a DRAFT model") { + t.Fatalf("error = %v, want draft-quantize requires DRAFT", err) + } +} + func TestCreateOptions(t *testing.T) { opts := CreateOptions{ - ModelName: "my-model", - ModelDir: "/path/to/model", - Quantize: "fp8", + ModelName: "my-model", + ModelDir: "/path/to/model", + Quantize: "fp8", + DraftQuantize: "mxfp8", Modelfile: &ModelfileConfig{ Template: "test", System: "system", @@ -179,6 +196,9 @@ func TestCreateOptions(t *testing.T) { if opts.Quantize != "fp8" { t.Errorf("Quantize = %q, want %q", opts.Quantize, "fp8") } + if opts.DraftQuantize != "mxfp8" { + t.Errorf("DraftQuantize = %q, want %q", opts.DraftQuantize, "mxfp8") + } if opts.Modelfile == nil { t.Error("Modelfile should not be nil") } @@ -286,6 +306,9 @@ func TestCreateOptions_Defaults(t *testing.T) { if opts.Quantize != "" { t.Errorf("Quantize should be empty by default, got %q", opts.Quantize) } + if opts.DraftQuantize != "" { + t.Errorf("DraftQuantize should be empty by default, got %q", opts.DraftQuantize) + } // Modelfile should default to nil if opts.Modelfile != nil { @@ -518,6 +541,48 @@ func TestNewManifestWriter_PopulatesFileTypeFromQuantize(t *testing.T) { } } +func TestNewManifestWriter_PopulatesDraftMetadata(t *testing.T) { + t.Setenv("OLLAMA_MODELS", t.TempDir()) + + opts := CreateOptions{ + ModelName: "test-draft", + ModelDir: t.TempDir(), + Modelfile: &ModelfileConfig{Draft: "/tmp/assistant"}, + } + + writer := newManifestWriter(opts, []string{"completion"}, "gemma4", "gemma4") + if err := writer(opts.ModelName, create.LayerInfo{}, nil); err != nil { + t.Fatalf("newManifestWriter() error = %v", err) + } + + name := model.ParseName(opts.ModelName) + mf, err := manifest.ParseNamedManifest(name) + if err != nil { + 
t.Fatalf("ParseNamedManifest() error = %v", err) + } + + configPath, err := manifest.BlobsPath(mf.Config.Digest) + if err != nil { + t.Fatalf("BlobsPath() error = %v", err) + } + + data, err := os.ReadFile(configPath) + if err != nil { + t.Fatalf("ReadFile() error = %v", err) + } + + var cfg model.ConfigV2 + if err := json.Unmarshal(data, &cfg); err != nil { + t.Fatalf("Unmarshal() error = %v", err) + } + if cfg.Draft == nil { + t.Fatal("Draft metadata missing") + } + if cfg.Draft.TensorPrefix != "draft." || cfg.Draft.Config != "draft/config.json" { + t.Fatalf("Draft = %#v, want draft prefix/config", cfg.Draft) + } +} + func TestSupportsThinking(t *testing.T) { tests := []struct { name string diff --git a/x/create/create.go b/x/create/create.go index 79b174487..ed07be07c 100644 --- a/x/create/create.go +++ b/x/create/create.go @@ -7,6 +7,7 @@ import ( "io" "math" "os" + "path" "path/filepath" "regexp" "slices" @@ -724,12 +725,16 @@ func detectModelOptQuantization(modelDir string) bool { } func resolveEffectiveQuantization(cfg sourceModelConfig, sourceKind sourceQuantizedKind, requested string) (string, error) { + return resolveEffectiveQuantizationForFlag(cfg, sourceKind, requested, "--quantize") +} + +func resolveEffectiveQuantizationForFlag(cfg sourceModelConfig, sourceKind sourceQuantizedKind, requested, flagName string) (string, error) { switch sourceKind { case sourceQuantizedKindNone: return requested, nil case sourceQuantizedKindPrequantized: if requested != "" { - return "", fmt.Errorf("cannot requantize already-quantized source model with --quantize %q", requested) + return "", fmt.Errorf("cannot requantize already-quantized source model with %s %q", flagName, requested) } return "", nil case sourceQuantizedKindSourceFP8: @@ -746,7 +751,7 @@ func resolveEffectiveQuantization(cfg sourceModelConfig, sourceKind sourceQuanti case "nvfp4", "mxfp4", "mxfp8": return requested, nil default: - return "", fmt.Errorf("cannot convert already-quantized fp8 source model 
with --quantize %q", requested) + return "", fmt.Errorf("cannot convert already-quantized fp8 source model with %s %q", flagName, requested) } } return "mxfp8", nil @@ -810,6 +815,7 @@ var tensorImportTransformRegistry = map[string]tensorImportTransformFactory{ "Gemma4ForCausalLM": newGemma4ImportTransform, "Gemma4ForConditionalGeneration": newGemma4ImportTransform, "LagunaForCausalLM": newLagunaImportTransform, + "Gemma4AssistantForCausalLM": newGemma4ImportTransform, } func newTensorImportTransform(modelDir string, cfg sourceModelConfig) (tensorImportTransform, error) { @@ -1167,6 +1173,136 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La return nil } +func normalizeRequestedQuantization(flagName, quantize string) (string, error) { + q := normalizeQuantType(strings.TrimSpace(quantize)) + switch q { + case "", "int4", "int8", "nvfp4", "mxfp4", "mxfp8": + return q, nil + default: + return "", fmt.Errorf("unsupported %s %q: supported types are int4, int8, nvfp4, mxfp4, mxfp8", flagName, quantize) + } +} + +// CreateDraftSafetensorsLayers imports an assistant/draft safetensors model +// into prefixed tensor and config layers. When draftQuantize is non-empty, +// eligible draft tensors are quantized with the same per-architecture policy +// used by target safetensors imports. 
+func CreateDraftSafetensorsLayers(modelDir, tensorPrefix, configPrefix, draftQuantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, fn func(status string)) ([]LayerInfo, error) { + if tensorPrefix == "" { + return nil, fmt.Errorf("draft tensor prefix must not be empty") + } + if configPrefix == "" { + return nil, fmt.Errorf("draft config prefix must not be empty") + } + effectiveQuantize, err := normalizeRequestedQuantization("--draft-quantize", draftQuantize) + if err != nil { + return nil, err + } + + var importTransform tensorImportTransform = noopImportTransform{} + if effectiveQuantize != "" { + sourceConfig, err := readSourceModelConfig(modelDir) + if err != nil { + return nil, fmt.Errorf("failed to read draft config.json: %w", err) + } + sourceQuantKind, err := inspectSourceQuantization(modelDir, sourceConfig) + if err != nil { + return nil, fmt.Errorf("failed to inspect draft quantization: %w", err) + } + effectiveQuantize, err = resolveEffectiveQuantizationForFlag(sourceConfig, sourceQuantKind, effectiveQuantize, "--draft-quantize") + if err != nil { + return nil, err + } + importTransform, err = newTensorImportTransform(modelDir, sourceConfig) + if err != nil { + return nil, fmt.Errorf("failed to construct draft import transform for architecture %q: %w", sourceConfig.Architecture(), err) + } + } + + entries, err := os.ReadDir(modelDir) + if err != nil { + return nil, fmt.Errorf("failed to read draft directory: %w", err) + } + + var layers []LayerInfo + for _, entry := range entries { + if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".safetensors") { + continue + } + + stPath := filepath.Join(modelDir, entry.Name()) + extractor, err := safetensors.OpenForExtraction(stPath) + if err != nil { + return nil, fmt.Errorf("failed to open draft %s: %w", stPath, err) + } + + tensorNames := extractor.ListTensors() + fn(fmt.Sprintf("importing draft %s (%d tensors%s)", entry.Name(), len(tensorNames), 
importQuantizationStatus(sourceQuantizedKindNone, effectiveQuantize))) + for _, tensorName := range tensorNames { + if importTransform.skipTensor(tensorName) { + continue + } + td, err := extractor.GetTensor(tensorName) + if err != nil { + extractor.Close() + return nil, fmt.Errorf("failed to get draft tensor %s: %w", tensorName, err) + } + + outTDs, err := importTransform.transformTensor(td) + if err != nil { + extractor.Close() + return nil, fmt.Errorf("failed to transform draft tensor %s: %w", tensorName, err) + } + for _, transformedTD := range outTDs { + if transformedTD == nil { + continue + } + outTD := transformedTD.WithName(tensorPrefix + transformedTD.Name) + quantizeType := "" + if effectiveQuantize != "" { + quantizeType = importTransform.quantizationType(outTD.Name, outTD.Shape, effectiveQuantize) + if isEmbedTokensWeight(outTD.Name) { + quantizeType = "" + } + } + newLayers, err := createTensorLayer(outTD.SafetensorsReader(), outTD.Name, outTD.Dtype, outTD.Shape, quantizeType) + if err != nil { + extractor.Close() + return nil, fmt.Errorf("failed to create draft layer for %s: %w", tensorName, err) + } + layers = append(layers, newLayers...) 
+ } + } + extractor.Close() + } + + for _, entry := range entries { + if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".json") { + continue + } + if entry.Name() == "model.safetensors.index.json" { + continue + } + + cfgPath := entry.Name() + fullPath := filepath.Join(modelDir, cfgPath) + fn(fmt.Sprintf("importing draft config %s", cfgPath)) + + f, err := os.Open(fullPath) + if err != nil { + return nil, fmt.Errorf("failed to open draft %s: %w", cfgPath, err) + } + layer, err := createLayer(f, "application/vnd.ollama.image.json", path.Join(configPrefix, cfgPath)) + f.Close() + if err != nil { + return nil, fmt.Errorf("failed to create draft config layer for %s: %w", cfgPath, err) + } + layers = append(layers, layer) + } + + return layers, nil +} + func shouldSkipSourceCompanion(name string, tensorSet map[string]struct{}, sourceTensorFiles map[string]string) bool { switch { case strings.HasSuffix(name, ".scales"): diff --git a/x/create/create_test.go b/x/create/create_test.go index 3dfd18756..35b7e4236 100644 --- a/x/create/create_test.go +++ b/x/create/create_test.go @@ -452,6 +452,137 @@ func TestCreateSafetensorsModel(t *testing.T) { } } +func TestCreateDraftSafetensorsLayersPrefixesTensorsAndConfigs(t *testing.T) { + dir := t.TempDir() + if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type":"gemma4_assistant"}`), 0o644); err != nil { + t.Fatal(err) + } + createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{ + st.NewTensorDataFromBytes("model.layers.0.self_attn.q_proj.weight", "BF16", []int32{2, 2}, make([]byte, 8)), + }) + + var tensorNames []string + createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) { + data, err := io.ReadAll(r) + if err != nil { + return LayerInfo{}, err + } + return LayerInfo{Digest: "sha256:json_" + name, Size: int64(len(data)), MediaType: mediaType, Name: name}, nil + } + createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize 
string) ([]LayerInfo, error) { + data, err := io.ReadAll(r) + if err != nil { + return nil, err + } + tensorNames = append(tensorNames, name) + tensorName, tensorShape := readSingleTensorNameAndShape(t, data) + if tensorName != name { + t.Fatalf("safetensors key = %q, want %q", tensorName, name) + } + if !slices.Equal(tensorShape, shape) { + t.Fatalf("shape = %v, want %v", tensorShape, shape) + } + if quantize != "" { + t.Fatalf("draft quantize = %q, want empty", quantize) + } + return []LayerInfo{{Digest: "sha256:tensor_" + name, Size: int64(len(data)), MediaType: "application/vnd.ollama.image.tensor", Name: name}}, nil + } + + layers, err := CreateDraftSafetensorsLayers(dir, "draft.", "draft", "", createLayer, createTensorLayer, func(string) {}) + if err != nil { + t.Fatal(err) + } + + if !slices.Contains(tensorNames, "draft.model.layers.0.self_attn.q_proj.weight") { + t.Fatalf("draft tensor was not prefixed: %v", tensorNames) + } + var hasDraftConfig bool + for _, layer := range layers { + if layer.Name == "draft/config.json" && layer.MediaType == "application/vnd.ollama.image.json" { + hasDraftConfig = true + } + } + if !hasDraftConfig { + t.Fatalf("draft/config.json layer missing: %#v", layers) + } +} + +func TestCreateDraftSafetensorsLayersQuantizesEligibleTensors(t *testing.T) { + dir := t.TempDir() + if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{ + "architectures":["Gemma4AssistantForCausalLM"], + "num_hidden_layers":8 + }`), 0o644); err != nil { + t.Fatal(err) + } + createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{ + st.NewTensorDataFromBytes("model.embed_tokens.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)), + st.NewTensorDataFromBytes("model.layers.0.self_attn.q_proj.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)), + st.NewTensorDataFromBytes("model.layers.0.input_layernorm.weight", "BF16", []int32{64}, make([]byte, 64*2)), + }) + + createLayer := func(r io.Reader, mediaType, name 
string) (LayerInfo, error) { + data, err := io.ReadAll(r) + if err != nil { + return LayerInfo{}, err + } + return LayerInfo{Digest: "sha256:json_" + name, Size: int64(len(data)), MediaType: mediaType, Name: name}, nil + } + quantizeByName := make(map[string]string) + createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) { + data, err := io.ReadAll(r) + if err != nil { + return nil, err + } + quantizeByName[name] = quantize + return []LayerInfo{{Digest: "sha256:tensor_" + name, Size: int64(len(data)), MediaType: "application/vnd.ollama.image.tensor", Name: name}}, nil + } + + if _, err := CreateDraftSafetensorsLayers(dir, "draft.", "draft", "MXFP8", createLayer, createTensorLayer, func(string) {}); err != nil { + t.Fatal(err) + } + + if got := quantizeByName["draft.model.layers.0.self_attn.q_proj.weight"]; got != "mxfp8" { + t.Fatalf("q_proj draft quantize = %q, want mxfp8", got) + } + if got := quantizeByName["draft.model.layers.0.input_layernorm.weight"]; got != "" { + t.Fatalf("norm draft quantize = %q, want empty", got) + } + if got := quantizeByName["draft.model.embed_tokens.weight"]; got != "" { + t.Fatalf("embed_tokens draft quantize = %q, want empty", got) + } +} + +func TestCreateDraftSafetensorsLayersRejectsUnsupportedDraftQuantize(t *testing.T) { + _, err := CreateDraftSafetensorsLayers(t.TempDir(), "draft.", "draft", "bogus", nil, nil, func(string) {}) + if err == nil || !strings.Contains(err.Error(), "unsupported --draft-quantize") { + t.Fatalf("error = %v, want unsupported --draft-quantize", err) + } +} + +func readSingleTensorNameAndShape(t *testing.T, data []byte) (string, []int32) { + t.Helper() + + var headerSize uint64 + if err := binary.Read(bytes.NewReader(data[:8]), binary.LittleEndian, &headerSize); err != nil { + t.Fatalf("failed to read header size: %v", err) + } + + var header map[string]struct { + Shape []int32 `json:"shape"` + } + if err := json.Unmarshal(data[8:8+headerSize], 
&header); err != nil { + t.Fatalf("failed to parse header: %v", err) + } + for name, info := range header { + if name != "__metadata__" { + return name, info.Shape + } + } + t.Fatal("no tensor entry found in header") + return "", nil +} + func TestCreateSafetensorsModel_NoConfigJson(t *testing.T) { dir := t.TempDir() diff --git a/x/mlxrunner/cache/cache.go b/x/mlxrunner/cache/cache.go index 08d90a211..d4485df50 100644 --- a/x/mlxrunner/cache/cache.go +++ b/x/mlxrunner/cache/cache.go @@ -54,6 +54,81 @@ type Attention interface { Update(b *batch.Batch, keys, values *mlx.Array) *nn.KVHistory } +// Viewer exposes a read-only attention history for a cache. +type Viewer interface { + View(b *batch.Batch) *nn.KVHistory +} + +type speculativeCommitter interface { + Cache + commit(n int) +} + +// Speculation is an isolated cache transaction for speculative target +// validation. Updates record generated K/V without mutating the live caches; +// Commit appends only the accepted prefix to the live caches. +type Speculation struct { + layers []speculativeCommitter +} + +// BeginSpeculation returns cache wrappers suitable for a speculative target +// forward. The returned caches must only be used for that forward. +func BeginSpeculation(caches []Cache) ([]Cache, *Speculation, bool) { + specCaches := make([]Cache, len(caches)) + layers := make([]speculativeCommitter, len(caches)) + + for i, c := range caches { + switch c := c.(type) { + case nil: + case *RotatingKVCache: + sc := newSpeculativeRotatingKVCache(c) + specCaches[i] = sc + layers[i] = sc + case *KVCache: + sc := newSpeculativeKVCache(c) + specCaches[i] = sc + layers[i] = sc + default: + return nil, nil, false + } + } + + return specCaches, &Speculation{layers: layers}, true +} + +// BeginIsolatedSpeculation returns cache wrappers that never mutate live cache +// state. It is intended for correctness instrumentation, not the hot path. 
+func BeginIsolatedSpeculation(caches []Cache) ([]Cache, bool) { + specCaches := make([]Cache, len(caches)) + + for i, c := range caches { + switch c := c.(type) { + case nil: + case *RotatingKVCache: + specCaches[i] = newSpeculativeRotatingKVCache(c) + case *KVCache: + specCaches[i] = newIsolatedKVCache(c) + default: + return nil, false + } + } + + return specCaches, true +} + +// Commit appends the accepted prefix from the speculative forward to the live +// caches. The target bonus token is intentionally not committed. +func (s *Speculation) Commit(n int) { + if s == nil { + return + } + for _, layer := range s.layers { + if layer != nil { + layer.commit(n) + } + } +} + type KVCache struct { keys, values *mlx.Array offset int @@ -70,6 +145,15 @@ func (c *KVCache) Update(_ *batch.Batch, keys, values *mlx.Array) *nn.KVHistory return nn.NewKVHistory(newK, newV, nil) } +// View returns the current cache contents as attention history without writing. +func (c *KVCache) View(_ *batch.Batch) *nn.KVHistory { + state := c.State() + if len(state) < 2 { + return nil + } + return nn.NewKVHistory(state[0], state[1], nil) +} + // appendKV is the raw write path shared by Update and Restore. 
func (c *KVCache) appendKV(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) { B, H, L, Dk, Dv := keys.Dim(0), keys.Dim(1), keys.Dim(2), keys.Dim(3), values.Dim(3) @@ -250,6 +334,94 @@ func (c *KVCache) Free() { func (c *KVCache) Offset() int { return c.offset } +type speculativeBase struct { + offset int +} + +func (s *speculativeBase) Free() {} +func (s *speculativeBase) Offset() int { return s.offset } +func (s *speculativeBase) Snapshot(int) Snapshot { return nil } +func (s *speculativeBase) Restore(Snapshot, int) bool { return false } +func (s *speculativeBase) Merge(parent, child Snapshot) Snapshot { return nil } +func (s *speculativeBase) Split(snapshot Snapshot, at int) (Snapshot, Snapshot) { + return nil, snapshot +} + +type speculativeKVCache struct { + speculativeBase + target *KVCache + start int + end int +} + +func newSpeculativeKVCache(target *KVCache) *speculativeKVCache { + return &speculativeKVCache{ + speculativeBase: speculativeBase{offset: target.Offset()}, + target: target, + start: target.Offset(), + end: target.Offset(), + } +} + +func (c *speculativeKVCache) Update(b *batch.Batch, keys, values *mlx.Array) *nn.KVHistory { + history := c.target.Update(b, keys, values) + c.offset = c.target.Offset() + c.end = c.target.Offset() + return history +} + +func (c *speculativeKVCache) State() []*mlx.Array { + return c.target.State() +} + +func (c *speculativeKVCache) commit(n int) { + target := max(c.start, c.start+n) + if target > c.end { + target = c.end + } + c.target.offset = target + c.offset = target +} + +type isolatedKVCache struct { + speculativeBase + target *KVCache + keys, values *mlx.Array +} + +func newIsolatedKVCache(target *KVCache) *isolatedKVCache { + return &isolatedKVCache{ + speculativeBase: speculativeBase{offset: target.Offset()}, + target: target, + } +} + +func (c *isolatedKVCache) Update(_ *batch.Batch, keys, values *mlx.Array) *nn.KVHistory { + c.keys = concatKV(c.keys, keys) + c.values = concatKV(c.values, values) + 
c.offset += keys.Dim(2) + + state := c.target.State() + if len(state) < 2 { + return nn.NewKVHistory(c.keys, c.values, nil) + } + return nn.NewKVHistory(state[0].Concatenate(2, c.keys), state[1].Concatenate(2, c.values), nil) +} + +func (c *isolatedKVCache) State() []*mlx.Array { + if c.keys == nil || c.values == nil { + return c.target.State() + } + state := c.target.State() + if len(state) < 2 { + return []*mlx.Array{c.keys, c.values} + } + return []*mlx.Array{ + state[0].Concatenate(2, c.keys), + state[1].Concatenate(2, c.values), + } +} + // RotatingKVCache implements sliding window attention with bounded memory type RotatingKVCache struct { maxSize int @@ -275,6 +447,127 @@ func (c *RotatingKVCache) Update(b *batch.Batch, keys, values *mlx.Array) *nn.KV }) } +// View returns the current rotating cache contents in logical order for +// assistant KV sharing. +func (c *RotatingKVCache) View(_ *batch.Batch) *nn.KVHistory { + k, v := c.logicalTail(c.maxSize - 1) + if k == nil || v == nil { + return nil + } + return nn.NewKVHistory(k, v, nil) +} + +func (c *RotatingKVCache) logicalTail(keep int) (*mlx.Array, *mlx.Array) { + state := c.State() + if len(state) < 2 || keep <= 0 { + return nil, nil + } + + keys, values := state[0], state[1] + K := keys.Dim(2) + if K == 0 { + return nil, nil + } + + keep = min(keep, K) + if K > c.maxSize || c.offset < c.maxSize { + start := K - keep + return keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(start, K), mlx.Slice()), + values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(start, K), mlx.Slice()) + } + + oldest := c.idx % K + var logicalK, logicalV *mlx.Array + if oldest == 0 { + logicalK, logicalV = keys, values + } else { + tailK := keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(oldest, K), mlx.Slice()) + tailV := values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(oldest, K), mlx.Slice()) + headK := keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, oldest), mlx.Slice()) + headV := values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, 
oldest), mlx.Slice()) + logicalK = tailK.Concatenate(2, headK) + logicalV = tailV.Concatenate(2, headV) + } + + start := K - keep + return logicalK.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(start, K), mlx.Slice()), + logicalV.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(start, K), mlx.Slice()) +} + +type speculativeRotatingKVCache struct { + speculativeBase + target *RotatingKVCache + keys, values *mlx.Array +} + +func newSpeculativeRotatingKVCache(target *RotatingKVCache) *speculativeRotatingKVCache { + return &speculativeRotatingKVCache{ + speculativeBase: speculativeBase{offset: target.Offset()}, + target: target, + } +} + +func (c *speculativeRotatingKVCache) Update(b *batch.Batch, keys, values *mlx.Array) *nn.KVHistory { + c.keys = concatKV(c.keys, keys) + c.values = concatKV(c.values, values) + c.offset += keys.Dim(2) + + oldK, oldV := c.target.logicalTail(c.target.maxSize - 1) + histK, histV := c.keys, c.values + if oldK != nil && oldV != nil { + histK = oldK.Concatenate(2, c.keys) + histV = oldV.Concatenate(2, c.values) + } + + return nn.NewKVHistory(histK, histV, logicalSlidingApplier{ + b: b, + K: histK.Dim(2), + window: c.target.maxSize, + dtype: keys.DType(), + }) +} + +func (c *speculativeRotatingKVCache) State() []*mlx.Array { + if c.keys == nil || c.values == nil { + return c.target.State() + } + oldK, oldV := c.target.logicalTail(c.target.maxSize - 1) + if oldK == nil || oldV == nil { + return []*mlx.Array{c.keys, c.values} + } + return []*mlx.Array{oldK.Concatenate(2, c.keys), oldV.Concatenate(2, c.values)} +} + +func (c *speculativeRotatingKVCache) commit(n int) { + if c.keys == nil || c.values == nil || n <= 0 { + return + } + n = min(n, c.keys.Dim(2)) + c.target.appendKV(prefixKV(c.keys, n), prefixKV(c.values, n)) +} + +type logicalSlidingApplier struct { + b *batch.Batch + K int + window int + dtype mlx.DType +} + +func (a logicalSlidingApplier) ApplyMask(logical nn.AttentionMask) nn.AttentionMask { + return 
logical.Intersect(nn.SlidingWindowMask(a.b, a.K, a.window, a.dtype)) +} + +func concatKV(prev, next *mlx.Array) *mlx.Array { + if prev == nil { + return next + } + return prev.Concatenate(2, next) +} + +func prefixKV(a *mlx.Array, n int) *mlx.Array { + return a.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, n), mlx.Slice()) +} + // appendKV is the raw write path shared by Update and Restore — // routes to concat for prefill (L > 1) and update for decode. func (c *RotatingKVCache) appendKV(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) { diff --git a/x/mlxrunner/cache/rotating_attention_test.go b/x/mlxrunner/cache/rotating_attention_test.go index 38499d828..e1cef998b 100644 --- a/x/mlxrunner/cache/rotating_attention_test.go +++ b/x/mlxrunner/cache/rotating_attention_test.go @@ -114,6 +114,61 @@ func TestRotatingKVCacheDecodeParity(t *testing.T) { } } +func TestAssistantSharedHistoryL1MasksMatchNoMask(t *testing.T) { + skipIfNoMLX(t) + if !mlx.MetalIsAvailable() { + t.Skip("MLX Metal not available") + } + const H, D = 1, 4 + const window = 4 + const total = 7 + const scale = 1.0 + + q := mlx.FromValues([]float32{0.7, -0.4, 0.2, 0.9}, 1, H, 1, D) + mlx.Eval(q) + + full := NewKVCache() + sliding := NewRotatingKVCache(window) + for pos := range total { + kVals := make([]float32, H*D) + vVals := make([]float32, H*D) + for i := range kVals { + kVals[i] = 0.1*float32(pos+1) + 0.01*float32(i) + vVals[i] = -0.1*float32(pos+1) + 0.01*float32(i) + } + k := mlx.FromValues(kVals, 1, H, 1, D) + v := mlx.FromValues(vVals, 1, H, 1, D) + full.Update(newKVBatch(full.Offset(), 1), k, v) + sliding.Update(newKVBatch(sliding.Offset(), 1), k, v) + } + + b := newKVBatch(total-1, 1) + slidingHistory := sliding.View(b) + cases := []struct { + name string + h *nn.KVHistory + mask nn.AttentionMask + }{ + {name: "full", h: full.View(b), mask: nn.CausalMask()}, + {name: "sliding", h: slidingHistory, mask: nn.CausalMask().Intersect(nn.SlidingWindowMask(b, slidingHistory.K().Dim(2), window, 
q.DType()))}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := nn.ScaledDotProductAttention(b, q, scale, nn.WithKVHistory(tc.h), nn.WithMask(tc.mask)) + want := mlx.FastScaledDotProductAttention(q, tc.h.K(), tc.h.V(), scale, "", nil) + + mlx.Eval(got, want) + gs, ws := got.Floats(), want.Floats() + for i := range ws { + if math.Abs(float64(gs[i]-ws[i])) > 1e-5 { + t.Fatalf("index %d: got %v, want %v", i, gs[i], ws[i]) + } + } + }) + } +} + // TestRotatingKVCachePrefillParity drives an L>1 prefill into a // rotating cache and verifies SDPA output through WithKVHistory // matches a reference computed from the same K/V with the model mask diff --git a/x/mlxrunner/mlx/array_test.go b/x/mlxrunner/mlx/array_test.go index 375e674d9..7bcf35c2b 100644 --- a/x/mlxrunner/mlx/array_test.go +++ b/x/mlxrunner/mlx/array_test.go @@ -41,3 +41,36 @@ func TestFromValues(t *testing.T) { } }) } + +func TestComparisonOpsAndBernoulli(t *testing.T) { + skipIfNoMLX(t) + + a := FromValues([]float32{1, 2, 3}, 3) + b := FromValues([]float32{1, 1, 4}, 3) + eq := a.Equal(b).AsType(DTypeInt32) + gt := a.Greater(b).AsType(DTypeInt32) + le := a.LessEqual(b).AsType(DTypeInt32) + bern := Bernoulli(FromValues([]float32{1, 0}, 2)).AsType(DTypeInt32) + Eval(eq, gt, le, bern) + + for name, tc := range map[string]struct { + got []int + want []int + }{ + "equal": {eq.Ints(), []int{1, 0, 0}}, + "greater": {gt.Ints(), []int{0, 1, 0}}, + "lessEqual": {le.Ints(), []int{1, 0, 1}}, + "bernoulli": {bern.Ints(), []int{1, 0}}, + } { + t.Run(name, func(t *testing.T) { + if len(tc.got) != len(tc.want) { + t.Fatalf("got %v, want %v", tc.got, tc.want) + } + for i := range tc.want { + if tc.got[i] != tc.want[i] { + t.Fatalf("got %v, want %v", tc.got, tc.want) + } + } + }) + } +} diff --git a/x/mlxrunner/mlx/ops.go b/x/mlxrunner/mlx/ops.go index faa44898c..411b118e3 100644 --- a/x/mlxrunner/mlx/ops.go +++ b/x/mlxrunner/mlx/ops.go @@ -137,12 +137,30 @@ func (t *Array) LogsumexpAxis(axis 
int, keepDims bool) *Array { return out } +func (t *Array) Equal(other *Array) *Array { + out := New("EQUAL") + C.mlx_equal(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx) + return out +} + +func (t *Array) Greater(other *Array) *Array { + out := New("GREATER") + C.mlx_greater(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx) + return out +} + func (t *Array) Less(other *Array) *Array { out := New("LESS") C.mlx_less(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx) return out } +func (t *Array) LessEqual(other *Array) *Array { + out := New("LESS_EQUAL") + C.mlx_less_equal(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx) + return out +} + func (t *Array) MaxAxis(axis int, keepDims bool) *Array { out := New("MAX_AXIS") C.mlx_max_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx) diff --git a/x/mlxrunner/mlx/random.go b/x/mlxrunner/mlx/random.go index 82c3d6785..009bddc39 100644 --- a/x/mlxrunner/mlx/random.go +++ b/x/mlxrunner/mlx/random.go @@ -3,9 +3,24 @@ package mlx // #include "generated.h" import "C" +import "unsafe" + func (t *Array) Categorical(axis int) *Array { key := New("") out := New("") C.mlx_random_categorical(&out.ctx, t.ctx, C.int(axis), key.ctx, DefaultStream().ctx) return out } + +func Bernoulli(p *Array) *Array { + dims := p.Dims() + shape := make([]C.int, len(dims)) + for i, d := range dims { + shape[i] = C.int(d) + } + + key := New("") + out := New("BERNOULLI") + C.mlx_random_bernoulli(&out.ctx, p.ctx, unsafe.SliceData(shape), C.size_t(len(shape)), key.ctx, DefaultStream().ctx) + return out +} diff --git a/x/mlxrunner/model/base/base.go b/x/mlxrunner/model/base/base.go index 8fa5a2347..4f43d74b6 100644 --- a/x/mlxrunner/model/base/base.go +++ b/x/mlxrunner/model/base/base.go @@ -27,9 +27,40 @@ type Model interface { LoadWeights(tensors map[string]*mlx.Array) error } +// DraftModel is an auxiliary model stored alongside a target model. 
+type DraftModel interface { + LoadWeights(tensors map[string]*mlx.Array) error +} + +// MTPDefaults holds model-provided draft-token defaults for speculative +// decoding. Environment settings in the runner may override these values. +type MTPDefaults struct { + InitialDraftTokens int + MaxDraftTokens int + Enabled bool +} + +// MTPDefaultsProvider lets a model provide MTP policy defaults from its own +// config without teaching the runner model-specific shape heuristics. +type MTPDefaultsProvider interface { + MTPDraftDefaults(sample bool) MTPDefaults +} + +// MTPDraftModel is a draft model capable of Gemma-style multi-token +// prediction from target token embeddings, target hidden states, and target KV. +type MTPDraftModel interface { + Draft(inputEmbeds *mlx.Array, position int32, caches []cache.Cache) (logits, hidden *mlx.Array) +} + +// MTPEmbeddingModel exposes the target token embedding path used by MTP drafts. +type MTPEmbeddingModel interface { + TokenEmbeddings(inputIDs *mlx.Array) *mlx.Array +} + var ( - mu sync.Mutex - registry = make(map[string]func(root *model.Root) (Model, error)) + mu sync.Mutex + registry = make(map[string]func(root *model.Root) (Model, error)) + draftRegistry = make(map[string]func(root *model.Root, target Model) (DraftModel, error)) ) // Register registers a model constructor by architecture name. @@ -44,6 +75,17 @@ func Register(arch string, fn func(root *model.Root) (Model, error)) { registry[arch] = fn } +// RegisterDraft registers a draft model constructor by architecture name. 
+func RegisterDraft(arch string, fn func(root *model.Root, target Model) (DraftModel, error)) { + mu.Lock() + defer mu.Unlock() + + if _, exists := draftRegistry[arch]; exists { + panic(fmt.Sprintf("draft model architecture %q already registered", arch)) + } + draftRegistry[arch] = fn +} + // New reads config.json from the manifest, detects the architecture, looks up // the registered constructor, and calls it to create the model (with config // parsed and struct created, but weights not yet loaded). @@ -78,6 +120,51 @@ func New(root *model.Root) (Model, error) { return fn(root) } +// NewDraft constructs the draft model described by the manifest config, if any. +func NewDraft(root *model.Root, target Model) (DraftModel, error) { + if root == nil || root.Draft == nil { + return nil, nil + } + + configPath := root.Draft.Config + if configPath == "" { + configPath = "draft/config.json" + } + configData, err := root.Manifest.ReadConfig(configPath) + if err != nil { + return nil, fmt.Errorf("failed to read %s: %w", configPath, err) + } + + var archConfig struct { + Architectures []string `json:"architectures"` + ModelType string `json:"model_type"` + } + if err := json.Unmarshal(configData, &archConfig); err != nil { + return nil, fmt.Errorf("failed to parse %s: %w", configPath, err) + } + + arch := root.Draft.Architecture + if arch == "" && len(archConfig.Architectures) > 0 { + arch = archConfig.Architectures[0] + } + if arch == "" { + arch = archConfig.ModelType + } + if arch == "" { + return nil, fmt.Errorf("no draft architecture found in %s", configPath) + } + slog.Info("Draft model architecture", "arch", arch) + + mu.Lock() + fn, ok := draftRegistry[arch] + mu.Unlock() + if !ok { + return nil, fmt.Errorf("unsupported draft architecture: %s", arch) + } + + return fn(root, target) +} + // Weights returns a function that loads model weights, then pins all // arrays reachable from the model struct and sweeps everything else. 
func Weights(m Model) func(map[string]*mlx.Array) error { diff --git a/x/mlxrunner/model/root.go b/x/mlxrunner/model/root.go index 43f6426fb..4a89ca302 100644 --- a/x/mlxrunner/model/root.go +++ b/x/mlxrunner/model/root.go @@ -10,6 +10,7 @@ import ( "strconv" "strings" + modeltypes "github.com/ollama/ollama/types/model" "github.com/ollama/ollama/x/imagegen/manifest" ) @@ -22,6 +23,7 @@ type TensorQuantInfo struct { // Root wraps a ModelManifest with pre-scanned quantization metadata. type Root struct { Manifest *manifest.ModelManifest + Draft *modeltypes.Draft // Backwards-compatible model-level quant metadata (first tensor blob). quantType string @@ -43,6 +45,7 @@ func Open(modelName string) (*Root, error) { Manifest: m, tensorQuant: make(map[string]*TensorQuantInfo), } + root.Draft = readDraftConfig(m) for _, layer := range m.GetTensorLayers("") { blobPath := m.BlobPath(layer.Digest) @@ -68,6 +71,34 @@ func Open(modelName string) (*Root, error) { return root, nil } +func readDraftConfig(m *manifest.ModelManifest) *modeltypes.Draft { + if m == nil || m.Manifest == nil || m.Manifest.Config.Digest == "" { + return nil + } + + data, err := os.ReadFile(m.BlobPath(m.Manifest.Config.Digest)) + if err != nil { + return nil + } + + var cfg modeltypes.ConfigV2 + if err := json.Unmarshal(data, &cfg); err != nil { + return nil + } + if cfg.Draft != nil { + return cfg.Draft + } + + if m.GetConfigLayer("draft/config.json") != nil { + return &modeltypes.Draft{ + ModelFormat: "safetensors", + TensorPrefix: "draft.", + Config: "draft/config.json", + } + } + return nil +} + // Close is a no-op for now (future: release resources). 
func (r *Root) Close() {} diff --git a/x/mlxrunner/mtp.go b/x/mlxrunner/mtp.go new file mode 100644 index 000000000..491e2706d --- /dev/null +++ b/x/mlxrunner/mtp.go @@ -0,0 +1,959 @@ +package mlxrunner + +import ( + "context" + "fmt" + "log/slog" + "math" + "os" + "strconv" + "strings" + "time" + + "github.com/ollama/ollama/x/mlxrunner/batch" + "github.com/ollama/ollama/x/mlxrunner/cache" + "github.com/ollama/ollama/x/mlxrunner/mlx" + "github.com/ollama/ollama/x/mlxrunner/model/base" + sampler "github.com/ollama/ollama/x/mlxrunner/sample" +) + +const ( + mtpDefaultInitialDraftTokens = 4 + mtpDefaultMaxDraftTokens = 16 +) + +type mtpDraftSchedule string + +const ( + mtpDraftScheduleHeuristic mtpDraftSchedule = "heuristic" + mtpDraftScheduleConstant mtpDraftSchedule = "constant" +) + +type mtpStats struct { + iterations int + drafted int + accepted int + mismatches int + allAccepted int + batched int + serial int + compared int + batchSerialMismatches int + maxDraft int + targetDuration time.Duration + draftDuration time.Duration + validateDuration time.Duration +} + +type mtpOptions struct { + initialDraftTokens int + maxDraftTokens int + draftSchedule mtpDraftSchedule + serialValidate bool + compareSerialValidate bool +} + +func (r *Runner) mtpDefaults(sample bool) base.MTPDefaults { + defaults := base.MTPDefaults{ + InitialDraftTokens: mtpDefaultInitialDraftTokens, + MaxDraftTokens: mtpDefaultMaxDraftTokens, + Enabled: true, + } + if p, ok := r.Model.(base.MTPDefaultsProvider); ok { + defaults = p.MTPDraftDefaults(sample) + } + if defaults.InitialDraftTokens <= 0 { + defaults.InitialDraftTokens = mtpDefaultInitialDraftTokens + } + if defaults.MaxDraftTokens <= 0 { + defaults.MaxDraftTokens = mtpDefaultMaxDraftTokens + } + return defaults +} + +func (r *Runner) loadMTPOptions(sample bool) mtpOptions { + defaults := r.mtpDefaults(sample) + + opts := mtpOptions{ + initialDraftTokens: defaults.InitialDraftTokens, + maxDraftTokens: defaults.MaxDraftTokens, + 
draftSchedule: mtpDraftScheduleConstant, + } + if v := positiveEnvInt("OLLAMA_MLX_MTP_MAX_DRAFT_TOKENS"); v > 0 { + opts.maxDraftTokens = v + } + if v := positiveEnvInt("OLLAMA_MLX_MTP_INITIAL_DRAFT_TOKENS"); v > 0 { + opts.initialDraftTokens = v + } + if opts.initialDraftTokens > opts.maxDraftTokens { + opts.initialDraftTokens = opts.maxDraftTokens + } + if b, err := strconv.ParseBool(os.Getenv("OLLAMA_MLX_MTP_SERIAL_VALIDATE")); err == nil { + opts.serialValidate = b + } + if b, err := strconv.ParseBool(os.Getenv("OLLAMA_MLX_MTP_COMPARE_SERIAL_VALIDATE")); err == nil { + opts.compareSerialValidate = b + } + switch schedule := strings.ToLower(strings.TrimSpace(os.Getenv("OLLAMA_MLX_MTP_DRAFT_SCHEDULE"))); schedule { + case "", string(mtpDraftScheduleConstant): + opts.draftSchedule = mtpDraftScheduleConstant + case string(mtpDraftScheduleHeuristic): + opts.draftSchedule = mtpDraftScheduleHeuristic + default: + slog.Warn("invalid MTP env setting", "key", "OLLAMA_MLX_MTP_DRAFT_SCHEDULE", "value", schedule) + } + return opts +} + +func positiveEnvInt(key string) int { + raw := os.Getenv(key) + if raw == "" { + return 0 + } + v, err := strconv.Atoi(raw) + if err != nil || v <= 0 { + slog.Warn("invalid MTP env setting", "key", key, "value", raw) + return 0 + } + return v +} + +func (r *Runner) useGreedyMTP(opts sampler.Options) bool { + if r.Draft == nil { + return false + } + if _, ok := r.Draft.(base.MTPDraftModel); !ok { + return false + } + if _, ok := r.Model.(base.MTPEmbeddingModel); !ok { + return false + } + if !r.mtpDefaults(false).Enabled { + return false + } + if opts.Logprobs || opts.TopLogprobs > 0 { + return false + } + if opts.Temperature != 0 { + return false + } + repeatPenaltyNeutral := opts.RepeatPenalty <= 0 || opts.RepeatPenalty == 1 + topPNeutral := opts.TopP <= 0 || opts.TopP >= 1 + topKNeutral := opts.TopK <= 0 + return repeatPenaltyNeutral && opts.PresencePenalty == 0 && opts.FrequencyPenalty == 0 && topPNeutral && topKNeutral && opts.MinP == 0 
+} + +func (r *Runner) useSampleMTP(opts sampler.Options) bool { + if serial, err := strconv.ParseBool(os.Getenv("OLLAMA_MLX_MTP_SERIAL_VALIDATE")); err == nil && serial { + return false + } + if compare, err := strconv.ParseBool(os.Getenv("OLLAMA_MLX_MTP_COMPARE_SERIAL_VALIDATE")); err == nil && compare { + return false + } + if r.Draft == nil { + return false + } + if _, ok := r.Draft.(base.MTPDraftModel); !ok { + return false + } + if _, ok := r.Model.(base.MTPEmbeddingModel); !ok { + return false + } + if !r.mtpDefaults(true).Enabled { + return false + } + if opts.Logprobs || opts.TopLogprobs > 0 { + return false + } + return opts.Temperature != 0 +} + +func (r *Runner) runGreedyMTPDecode(ctx context.Context, request Request, session *cacheSession, caches []cache.Cache, seed []int32, position *int, started time.Time) error { + targetEmbeddings := r.Model.(base.MTPEmbeddingModel) + draft := r.Draft.(base.MTPDraftModel) + mtpOpts := r.loadMTPOptions(false) + stats := mtpStats{maxDraft: mtpOpts.initialDraftTokens} + draftLimit := mtpOpts.initialDraftTokens + slog.Info("MTP greedy decode enabled", "initial_draft_tokens", mtpOpts.initialDraftTokens, "max_draft_tokens", mtpOpts.maxDraftTokens, "draft_schedule", mtpOpts.draftSchedule, "serial_validate", mtpOpts.serialValidate, "compare_serial_validate", mtpOpts.compareSerialValidate) + + targetForward := func(token *mlx.Array) *mlx.Array { + fwd := r.Model.Forward(&batch.Batch{ + InputIDs: token, + SeqOffsets: []int32{int32(*position)}, + SeqQueryLens: []int32{int32(token.Dim(1))}, + }, caches) + *position += token.Dim(1) + return fwd + } + + hidden := targetForward(mlx.FromValues(seed, 1, len(seed))) + current := sampler.Result{Token: greedyTokenFromLogits(r.lastLogits(hidden))} + mlx.Pin(current.Arrays()...) + mlx.Sweep() + mlx.AsyncEval(current.Arrays()...) + defer func() { + mlx.Unpin(current.Arrays()...) 
+ }() + + dec := decoder{tokenizer: r.Tokenizer} + final := CompletionResponse{Done: true, PromptEvalCount: len(request.Tokens), DoneReason: 1} + now := started + + generated := 0 + for generated < request.Options.NumPredict { + if err := ctx.Err(); err != nil { + return err + } + + t0 := time.Now() + hidden = targetForward(current.Token.ExpandDims(-1)) + baseLogits := r.lastLogits(hidden) + stats.targetDuration += time.Since(t0) + + if generated == 0 { + mlx.Eval(current.Arrays()...) + final.PromptEvalDuration = time.Since(now) + now = time.Now() + } + + done, err := r.emitMTPToken(ctx, request, session, &dec, current, &final) + if err != nil { + return err + } + if !done { + generated++ + } + if done || generated >= request.Options.NumPredict { + break + } + + stats.iterations++ + maxDraft := min(draftLimit, request.Options.NumPredict-generated) + t0 = time.Now() + draftTokens := r.generateMTPDrafts(draft, targetEmbeddings, current.Token, hidden, caches, int32(*position-1), maxDraft) + draftCount := 0 + if draftTokens != nil { + draftCount = draftTokens.Dim(1) + mlx.Pin(baseLogits, draftTokens) + mlx.Eval(draftTokens) + mlx.Sweep() + } + stats.draftDuration += time.Since(t0) + stats.drafted += draftCount + var next sampler.Result + if draftCount == 0 { + next = sampler.Result{Token: greedyTokenFromLogits(baseLogits)} + } else { + var accepted int + t0 = time.Now() + next, accepted, done, err = r.acceptMTPDrafts(ctx, request, session, &dec, caches, position, baseLogits, draftTokens, &final, &generated, &stats, mtpOpts) + stats.validateDuration += time.Since(t0) + mlx.Unpin(baseLogits, draftTokens) + if err != nil { + return err + } + stats.accepted += accepted + switch { + case mtpOpts.draftSchedule == mtpDraftScheduleConstant: + case accepted == draftCount: + stats.allAccepted++ + draftLimit = min(mtpOpts.maxDraftTokens, draftLimit+2) + default: + stats.mismatches++ + draftLimit = max(1, draftLimit-1) + } + if mtpOpts.draftSchedule == mtpDraftScheduleConstant { + 
if accepted == draftCount { + stats.allAccepted++ + } else { + stats.mismatches++ + } + } + stats.maxDraft = max(stats.maxDraft, draftLimit) + if next.Token == nil { + mlx.Sweep() + } + if done || generated >= request.Options.NumPredict { + break + } + } + + mlx.Pin(next.Arrays()...) + old := current + current = next + mlx.Unpin(old.Arrays()...) + mlx.Sweep() + mlx.AsyncEval(current.Arrays()...) + + if generated%256 == 0 { + mlx.ClearCache() + } + } + + final.EvalCount = generated + final.EvalDuration = time.Since(now) + acceptance := 0.0 + if stats.drafted > 0 { + acceptance = float64(stats.accepted) / float64(stats.drafted) + } + avgDraft := 0.0 + avgAccepted := 0.0 + if stats.iterations > 0 { + avgDraft = float64(stats.drafted) / float64(stats.iterations) + avgAccepted = float64(stats.accepted) / float64(stats.iterations) + } + slog.Info("MTP decode stats", "generated", generated, "drafted", stats.drafted, "accepted", stats.accepted, "acceptance", acceptance, "iterations", stats.iterations, "avg_draft", avgDraft, "avg_accepted", avgAccepted, "batched", stats.batched, "serial", stats.serial, "compared", stats.compared, "batch_serial_mismatches", stats.batchSerialMismatches, "mismatches", stats.mismatches, "all_accepted", stats.allAccepted, "max_draft", stats.maxDraft, "draft_schedule", mtpOpts.draftSchedule, "target_duration", stats.targetDuration, "draft_duration", stats.draftDuration, "validate_duration", stats.validateDuration) + select { + case <-ctx.Done(): + return ctx.Err() + case request.Responses <- final: + return nil + } +} + +func (r *Runner) runSampleMTPDecode(ctx context.Context, request Request, session *cacheSession, caches []cache.Cache, seed []int32, position *int, started time.Time) error { + targetEmbeddings := r.Model.(base.MTPEmbeddingModel) + draft := r.Draft.(base.MTPDraftModel) + mtpOpts := r.loadMTPOptions(true) + stats := mtpStats{maxDraft: mtpOpts.initialDraftTokens} + draftLimit := mtpOpts.initialDraftTokens + slog.Info("MTP sample 
decode enabled", "initial_draft_tokens", mtpOpts.initialDraftTokens, "max_draft_tokens", mtpOpts.maxDraftTokens, "draft_schedule", mtpOpts.draftSchedule, "serial_validate", mtpOpts.serialValidate) + + targetForward := func(token *mlx.Array) *mlx.Array { + fwd := r.Model.Forward(&batch.Batch{ + InputIDs: token, + SeqOffsets: []int32{int32(*position)}, + SeqQueryLens: []int32{int32(token.Dim(1))}, + }, caches) + *position += token.Dim(1) + return fwd + } + + hidden := targetForward(mlx.FromValues(seed, 1, len(seed))) + current := r.Sampler.Sample([]int{pipelineSlot}, r.lastLogits(hidden)) + mlx.Pin(current.Arrays()...) + mlx.Sweep() + mlx.AsyncEval(current.Arrays()...) + defer func() { + mlx.Unpin(current.Arrays()...) + }() + + dec := decoder{tokenizer: r.Tokenizer} + final := CompletionResponse{Done: true, PromptEvalCount: len(request.Tokens), DoneReason: 1} + now := started + + generated := 0 + for generated < request.Options.NumPredict { + if err := ctx.Err(); err != nil { + return err + } + + t0 := time.Now() + hidden = targetForward(mtpTokenInput(current.Token)) + baseLogits := r.lastLogits(hidden) + stats.targetDuration += time.Since(t0) + + if generated == 0 { + mlx.Eval(current.Arrays()...) 
+ final.PromptEvalDuration = time.Since(now) + now = time.Now() + } + + done, err := r.emitMTPToken(ctx, request, session, &dec, current, &final) + if err != nil { + return err + } + if !done { + generated++ + } + if done || generated >= request.Options.NumPredict { + break + } + + stats.iterations++ + maxDraft := min(draftLimit, request.Options.NumPredict-generated) + t0 = time.Now() + candidates := r.generateMTPDraftCandidates(draft, targetEmbeddings, current.Token, hidden, caches, int32(*position-1), maxDraft) + draftCount := 0 + if candidates != nil { + draftCount = candidates.tokens.Dim(1) + mlx.Pin(baseLogits, candidates.tokens, candidates.logits) + mlx.Sweep() + } + stats.draftDuration += time.Since(t0) + stats.drafted += draftCount + + var next sampler.Result + if draftCount == 0 { + next = r.Sampler.Sample([]int{pipelineSlot}, baseLogits) + } else { + var accepted int + t0 = time.Now() + next, accepted, done, err = r.acceptSampleMTPDrafts(ctx, request, session, &dec, caches, position, baseLogits, candidates, &final, &generated, &stats) + stats.validateDuration += time.Since(t0) + mlx.Unpin(baseLogits, candidates.tokens, candidates.logits) + if err != nil { + return err + } + stats.accepted += accepted + switch { + case mtpOpts.draftSchedule == mtpDraftScheduleConstant: + case accepted == draftCount: + stats.allAccepted++ + draftLimit = min(mtpOpts.maxDraftTokens, draftLimit+2) + default: + stats.mismatches++ + draftLimit = max(1, draftLimit-1) + } + if mtpOpts.draftSchedule == mtpDraftScheduleConstant { + if accepted == draftCount { + stats.allAccepted++ + } else { + stats.mismatches++ + } + } + stats.maxDraft = max(stats.maxDraft, draftLimit) + if next.Token == nil { + mlx.Sweep() + } + if done || generated >= request.Options.NumPredict { + break + } + } + + mlx.Pin(next.Arrays()...) + old := current + current = next + mlx.Unpin(old.Arrays()...) + mlx.Sweep() + mlx.AsyncEval(current.Arrays()...) 
+ + if generated%256 == 0 { + mlx.ClearCache() + } + } + + final.EvalCount = generated + final.EvalDuration = time.Since(now) + acceptance := 0.0 + if stats.drafted > 0 { + acceptance = float64(stats.accepted) / float64(stats.drafted) + } + avgDraft := 0.0 + avgAccepted := 0.0 + if stats.iterations > 0 { + avgDraft = float64(stats.drafted) / float64(stats.iterations) + avgAccepted = float64(stats.accepted) / float64(stats.iterations) + } + slog.Info("MTP decode stats", "mode", "sample", "generated", generated, "drafted", stats.drafted, "accepted", stats.accepted, "acceptance", acceptance, "iterations", stats.iterations, "avg_draft", avgDraft, "avg_accepted", avgAccepted, "batched", stats.batched, "serial", stats.serial, "mismatches", stats.mismatches, "all_accepted", stats.allAccepted, "max_draft", stats.maxDraft, "draft_schedule", mtpOpts.draftSchedule, "target_duration", stats.targetDuration, "draft_duration", stats.draftDuration, "validate_duration", stats.validateDuration) + select { + case <-ctx.Done(): + return ctx.Err() + case request.Responses <- final: + return nil + } +} + +type mtpDraftCandidates struct { + tokens *mlx.Array + // logits are the processed proposal scores used to sample tokens. + logits *mlx.Array +} + +func (r *Runner) generateMTPDrafts(draft base.MTPDraftModel, target base.MTPEmbeddingModel, token *mlx.Array, hidden *mlx.Array, caches []cache.Cache, position int32, maxDraft int) *mlx.Array { + if maxDraft <= 0 { + return nil + } + + lastToken := token.ExpandDims(-1) + lastHidden := hidden + draftTokens := make([]*mlx.Array, 0, maxDraft) + + // Gemma4 assistant MTP is trained as "single-position" drafting: + // keep the RoPE/cache position anchored at the last target-seen token + // while the proposed token and projected hidden state advance. 
+ for range maxDraft { + tokenEmbedding := target.TokenEmbeddings(lastToken) + inputs := tokenEmbedding.Concatenate(-1, lastHidden) + logits, projected := draft.Draft(inputs, position, caches) + stepLogits := r.lastLogitsFromLogits(logits) + nextToken := greedyTokenFromLogits(stepLogits) + + lastToken = nextToken.ExpandDims(-1) + lastHidden = projected + draftTokens = append(draftTokens, lastToken) + } + if len(draftTokens) == 0 { + return nil + } + return mlx.Concatenate(draftTokens, 1) +} + +func (r *Runner) generateMTPDraftCandidates(draft base.MTPDraftModel, target base.MTPEmbeddingModel, token *mlx.Array, hidden *mlx.Array, caches []cache.Cache, position int32, maxDraft int) *mtpDraftCandidates { + if maxDraft <= 0 { + return nil + } + + lastToken := mtpTokenInput(token) + lastHidden := hidden + draftTokens := make([]*mlx.Array, 0, maxDraft) + draftLogits := make([]*mlx.Array, 0, maxDraft) + var prefix *mlx.Array + + // Gemma4 assistant MTP is trained as "single-position" drafting: + // keep the RoPE/cache position anchored at the last target-seen token + // while the proposed token and projected hidden state advance. 
+ for range maxDraft { + tokenEmbedding := target.TokenEmbeddings(lastToken) + inputs := tokenEmbedding.Concatenate(-1, lastHidden) + logits, projected := draft.Draft(inputs, position, caches) + stepLogits := r.lastLogitsFromLogits(logits) + stepScores := r.Sampler.SpeculativeScores(pipelineSlot, stepLogits, prefix) + nextToken := stepScores.Categorical(-1).AsType(mlx.DTypeInt32) + + lastToken = mtpTokenInput(nextToken) + lastHidden = projected + draftTokens = append(draftTokens, lastToken) + draftLogits = append(draftLogits, stepScores.ExpandDims(1)) + if prefix == nil { + prefix = lastToken + } else { + prefix = prefix.Concatenate(1, lastToken) + } + } + if len(draftTokens) == 0 { + return nil + } + return &mtpDraftCandidates{ + tokens: mlx.Concatenate(draftTokens, 1), + logits: mlx.Concatenate(draftLogits, 1), + } +} + +func (r *Runner) acceptMTPDrafts(ctx context.Context, request Request, session *cacheSession, dec *decoder, caches []cache.Cache, position *int, baseLogits *mlx.Array, draftTokens *mlx.Array, final *CompletionResponse, generated *int, stats *mtpStats, opts mtpOptions) (sampler.Result, int, bool, error) { + if opts.serialValidate { + stats.serial++ + return r.acceptMTPDraftsSerial(ctx, request, session, dec, caches, position, baseLogits, draftTokens, final, generated) + } + + specCaches, spec, ok := cache.BeginSpeculation(caches) + if ok { + stats.batched++ + return r.acceptMTPDraftsBatched(ctx, request, session, dec, caches, specCaches, spec, position, baseLogits, draftTokens, final, generated, stats, opts) + } + + stats.serial++ + return r.acceptMTPDraftsSerial(ctx, request, session, dec, caches, position, baseLogits, draftTokens, final, generated) +} + +func (r *Runner) acceptMTPDraftsBatched(ctx context.Context, request Request, session *cacheSession, dec *decoder, liveCaches []cache.Cache, caches []cache.Cache, spec *cache.Speculation, position *int, baseLogits *mlx.Array, draftTokens *mlx.Array, final *CompletionResponse, generated *int, 
stats *mtpStats, opts mtpOptions) (sampler.Result, int, bool, error) { + before := *position + draftCount := draftTokens.Dim(1) + hiddenSeq := r.Model.Forward(&batch.Batch{ + InputIDs: draftTokens, + SeqOffsets: []int32{int32(before)}, + SeqQueryLens: []int32{int32(draftCount)}, + }, caches) + + accepted := 0 + var next sampler.Result + done := false + + selectedTokens := r.mtpValidationTokens(baseLogits, hiddenSeq) + mlx.Eval(draftTokens, selectedTokens) + draftIDs := draftTokens.Ints() + selectedIDs := selectedTokens.Ints() + if len(selectedIDs) < draftCount+1 { + return sampler.Result{}, accepted, false, fmt.Errorf("mtp validation produced %d tokens for %d draft tokens", len(selectedIDs), draftCount) + } + + for i, id := range draftIDs { + if selectedIDs[i] != id { + next = sampler.Result{Token: mtpTokenAt(selectedTokens, i)} + break + } + accepted++ + if r.Tokenizer.IsEOS(int32(id)) { + done = true + break + } + } + + if opts.compareSerialValidate { + spec.Commit(0) + r.compareMTPBatchedWithSerial(ctx, liveCaches, before, baseLogits, hiddenSeq, draftIDs, selectedIDs, accepted, draftCount, stats) + } + spec.Commit(accepted) + *position = before + accepted + + for _, id := range draftIDs[:accepted] { + if *generated >= request.Options.NumPredict { + done = true + break + } + res := sampler.Result{Token: mlx.FromValues([]int32{int32(id)}, 1)} + var err error + done, err = r.emitMTPToken(ctx, request, session, dec, res, final) + if err != nil { + return sampler.Result{}, accepted, done, err + } + if !done { + (*generated)++ + } + if done { + break + } + } + + if done || *generated >= request.Options.NumPredict { + return sampler.Result{}, accepted, true, nil + } + if next.Token == nil { + next = sampler.Result{Token: mtpTokenAt(selectedTokens, draftCount)} + } + return next, accepted, false, nil +} + +func (r *Runner) acceptSampleMTPDrafts(ctx context.Context, request Request, session *cacheSession, dec *decoder, caches []cache.Cache, position *int, baseLogits 
*mlx.Array, candidates *mtpDraftCandidates, final *CompletionResponse, generated *int, stats *mtpStats) (sampler.Result, int, bool, error) { + specCaches, spec, ok := cache.BeginSpeculation(caches) + if !ok { + stats.serial++ + return r.Sampler.Sample([]int{pipelineSlot}, baseLogits), 0, false, nil + } + stats.batched++ + + before := *position + draftCount := candidates.tokens.Dim(1) + hiddenSeq := r.Model.Forward(&batch.Batch{ + InputIDs: candidates.tokens, + SeqOffsets: []int32{int32(before)}, + SeqQueryLens: []int32{int32(draftCount)}, + }, specCaches) + + targetScores := r.Sampler.SpeculativeScores(pipelineSlot, r.mtpValidationLogits(baseLogits, hiddenSeq), candidates.tokens) + draftScores := candidates.logits + if draftScores.NumDims() == 3 { + draftScores = draftScores.Squeeze(0) + } + acceptedMask := mtpSampleAcceptedMask(targetScores, draftScores, candidates.tokens, draftCount) + mlx.Eval(candidates.tokens, acceptedMask) + + draftIDs := candidates.tokens.Ints() + acceptedFlags := acceptedMask.Ints() + accepted := 0 + for _, ok := range acceptedFlags { + if ok == 0 { + break + } + accepted++ + } + if accepted > draftCount { + return sampler.Result{}, 0, false, fmt.Errorf("mtp sample validation accepted %d tokens for %d draft tokens", accepted, draftCount) + } + + commitIDs := make([]int32, 0, accepted+1) + done := false + for i, id := range draftIDs[:accepted] { + commitIDs = append(commitIDs, int32(id)) + if r.Tokenizer.IsEOS(int32(id)) { + done = true + accepted = i + 1 + commitIDs = commitIDs[:accepted] + break + } + } + + spec.Commit(accepted) + *position = before + accepted + + for _, id := range draftIDs[:accepted] { + if *generated >= request.Options.NumPredict { + done = true + break + } + res := sampler.Result{Token: mlx.FromValues([]int32{int32(id)}, 1)} + var err error + done, err = r.emitMTPToken(ctx, request, session, dec, res, final) + if err != nil { + return sampler.Result{}, accepted, done, err + } + if !done { + (*generated)++ + } + if done 
{ + break + } + } + + if done || *generated >= request.Options.NumPredict { + r.Sampler.Commit(pipelineSlot, commitIDs) + return sampler.Result{}, accepted, true, nil + } + + var nextToken *mlx.Array + if accepted == draftCount { + nextToken = mtpSampleTokenAt(targetScores, draftCount) + } else { + nextToken = mtpSampleResidualToken(targetScores, draftScores, accepted) + } + mlx.Eval(nextToken) + nextID := int32(tokenID(nextToken)) + commitIDs = append(commitIDs, nextID) + r.Sampler.Commit(pipelineSlot, commitIDs) + + return sampler.Result{Token: nextToken}, accepted, false, nil +} + +func mtpSampleAcceptedMask(targetScores, draftScores, draftTokens *mlx.Array, draftCount int) *mlx.Array { + targetProbs := mlx.SoftmaxAxis(targetScores.Slice(mlx.Slice(0, draftCount), mlx.Slice()), -1, true) + draftProbs := mlx.SoftmaxAxis(draftScores, -1, true) + if draftTokens.NumDims() == 2 { + draftTokens = draftTokens.Squeeze(0) + } + indices := draftTokens.ExpandDims(-1) + p := targetProbs.TakeAlongAxis(indices, -1).Squeeze(-1) + q := draftProbs.TakeAlongAxis(indices, -1).Squeeze(-1) + acceptP := mlx.Minimum(p.Divide(q), mlx.FromValue(float32(1))) + return mlx.Bernoulli(acceptP).AsType(mlx.DTypeInt32) +} + +func mtpSampleTokenAt(scores *mlx.Array, index int) *mlx.Array { + row := scores.Slice(mlx.Slice(index, index+1), mlx.Slice()) + return mtpTokenVector(row.Categorical(-1).AsType(mlx.DTypeInt32)) +} + +func mtpSampleResidualToken(targetScores, draftScores *mlx.Array, index int) *mlx.Array { + p := mlx.SoftmaxAxis(targetScores.Slice(mlx.Slice(index, index+1), mlx.Slice()), -1, true) + q := mlx.SoftmaxAxis(draftScores.Slice(mlx.Slice(index, index+1), mlx.Slice()), -1, true) + diff := p.Subtract(q) + positive := mlx.Maximum(diff, mlx.FromValue(float32(1e-20))) + logits := mlx.Log(positive) + logits = mlx.Where(diff.LessEqual(mlx.FromValue(float32(0))), mlx.FromValue(float32(math.Inf(-1))), logits) + return mtpTokenVector(logits.Categorical(-1).AsType(mlx.DTypeInt32)) +} + +func 
mtpTokenInput(token *mlx.Array) *mlx.Array { + switch token.NumDims() { + case 0: + return token.Reshape(1, 1) + case 1: + return token.ExpandDims(-1) + case 2: + return token + default: + panic(fmt.Sprintf("mtp token must be rank 0, 1, or 2, got rank %d", token.NumDims())) + } +} + +func mtpTokenVector(token *mlx.Array) *mlx.Array { + switch token.NumDims() { + case 0: + return token.Reshape(1) + case 1: + return token + default: + panic(fmt.Sprintf("mtp sampled token must be rank 0 or 1, got rank %d", token.NumDims())) + } +} + +func (r *Runner) compareMTPBatchedWithSerial(ctx context.Context, caches []cache.Cache, before int, baseLogits, hiddenSeq *mlx.Array, draftIDs, selectedIDs []int, accepted, draftCount int, stats *mtpStats) { + serialCaches, ok := cache.BeginIsolatedSpeculation(caches) + if !ok { + return + } + + compareCount := accepted + 1 + if accepted == draftCount { + // Include the target bonus token when every draft was accepted. + compareCount = draftCount + 1 + } + + serialLogits := baseLogits + for i := range compareCount { + if err := ctx.Err(); err != nil { + return + } + if i >= len(selectedIDs) { + return + } + + batchedLogits := baseLogits + if i > 0 { + batchedLogits = r.targetLogitsAt(hiddenSeq, i-1) + } + + batchedToken := greedyTokenFromLogits(batchedLogits) + serialToken := greedyTokenFromLogits(serialLogits) + mlx.Eval(batchedToken, serialToken) + + batchedID := tokenID(batchedToken) + vectorizedID := selectedIDs[i] + serialID := tokenID(serialToken) + stats.compared++ + if vectorizedID != serialID { + firstMismatch := stats.batchSerialMismatches == 0 + stats.batchSerialMismatches++ + if !firstMismatch { + return + } + + draftID := -1 + if i < draftCount { + draftID = draftIDs[i] + } + batchedTop := top2FromLogits(batchedLogits) + serialTop := top2FromLogits(serialLogits) + slog.Warn("MTP batched validation differs from serial validation", + "position", before+i, + "draft", draftID, + "batched", vectorizedID, + "batched_slice", 
batchedID, + "serial", serialID, + "batched_slice_top1", batchedTop.firstToken, + "batched_slice_top2", batchedTop.secondToken, + "batched_slice_margin", batchedTop.margin, + "serial_top1", serialTop.firstToken, + "serial_top2", serialTop.secondToken, + "serial_margin", serialTop.margin, + ) + return + } + + if i >= draftCount || i >= accepted { + return + } + + hidden := r.Model.Forward(&batch.Batch{ + InputIDs: mlx.FromValues([]int32{int32(draftIDs[i])}, 1, 1), + SeqOffsets: []int32{int32(before + i)}, + SeqQueryLens: []int32{1}, + }, serialCaches) + serialLogits = r.lastLogits(hidden) + } +} + +type mtpTop2 struct { + firstToken int + secondToken int + margin float64 +} + +func top2FromLogits(logits *mlx.Array) mtpTop2 { + indices := logits.Negative().ArgsortAxis(-1).Slice(mlx.Slice(), mlx.Slice(0, 2)) + indices32 := indices.AsType(mlx.DTypeInt32) + values := logits.TakeAlongAxis(indices, -1).AsType(mlx.DTypeFloat32) + mlx.Eval(indices32, values) + + tokenIDs := indices32.Ints() + logitValues := values.Floats() + if len(tokenIDs) < 2 || len(logitValues) < 2 { + return mtpTop2{} + } + return mtpTop2{ + firstToken: tokenIDs[0], + secondToken: tokenIDs[1], + margin: float64(logitValues[0] - logitValues[1]), + } +} + +func (r *Runner) acceptMTPDraftsSerial(ctx context.Context, request Request, session *cacheSession, dec *decoder, caches []cache.Cache, position *int, baseLogits *mlx.Array, draftTokens *mlx.Array, final *CompletionResponse, generated *int) (sampler.Result, int, bool, error) { + logits := baseLogits + accepted := 0 + draftIDs := draftTokens.Ints() + + for _, id := range draftIDs { + selected := greedyTokenFromLogits(logits) + mlx.Eval(selected) + selectedID := tokenID(selected) + if selectedID != id { + return sampler.Result{Token: mlx.FromValues([]int32{int32(selectedID)}, 1)}, accepted, false, nil + } + + hidden := r.Model.Forward(&batch.Batch{ + InputIDs: mlx.FromValues([]int32{int32(id)}, 1, 1), + SeqOffsets: []int32{int32(*position)}, + 
SeqQueryLens: []int32{1}, + }, caches) + (*position)++ + accepted++ + + res := sampler.Result{Token: mlx.FromValues([]int32{int32(id)}, 1)} + done, err := r.emitMTPToken(ctx, request, session, dec, res, final) + if err != nil { + return sampler.Result{}, accepted, done, err + } + if !done { + (*generated)++ + } + if done || *generated >= request.Options.NumPredict { + return sampler.Result{}, accepted, true, nil + } + + logits = r.lastLogits(hidden) + } + + return sampler.Result{Token: greedyTokenFromLogits(logits)}, accepted, false, nil +} + +func (r *Runner) emitMTPToken(ctx context.Context, request Request, session *cacheSession, dec *decoder, res sampler.Result, final *CompletionResponse) (bool, error) { + output := int32(tokenID(res.Token)) + session.outputs = append(session.outputs, output) + + if r.Tokenizer.IsEOS(output) { + final.DoneReason = 0 + return true, nil + } + + if resp, ok := dec.decode(res); ok { + select { + case <-ctx.Done(): + return false, ctx.Err() + case request.Responses <- resp: + } + } + return false, nil +} + +func (r *Runner) lastLogits(hidden *mlx.Array) *mlx.Array { + logits := r.Model.Unembed(hidden) + return r.lastLogitsFromLogits(logits) +} + +func (r *Runner) targetLogitsAt(hiddenSeq *mlx.Array, index int) *mlx.Array { + hidden := hiddenSeq.Slice(mlx.Slice(), mlx.Slice(index), mlx.Slice()) + return r.lastLogits(hidden) +} + +func (r *Runner) mtpValidationTokens(baseLogits, hiddenSeq *mlx.Array) *mlx.Array { + return greedyTokenFromLogits(r.mtpValidationLogits(baseLogits, hiddenSeq)) +} + +func (r *Runner) mtpValidationLogits(baseLogits, hiddenSeq *mlx.Array) *mlx.Array { + seqLogits := r.Model.Unembed(hiddenSeq) + return baseLogits.ExpandDims(1).Concatenate(1, seqLogits) +} + +func mtpTokenAt(tokens *mlx.Array, index int) *mlx.Array { + return tokens.Slice(mlx.Slice(), mlx.Slice(index)).Squeeze(0) +} + +func (r *Runner) lastLogitsFromLogits(logits *mlx.Array) *mlx.Array { + return logits.Slice(mlx.Slice(), 
mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1) +} + +func greedyTokenFromLogits(logits *mlx.Array) *mlx.Array { + return logits.Argmax(-1, false).AsType(mlx.DTypeInt32) +} + +func tokenID(token *mlx.Array) int { + if token == nil { + return -1 + } + if token.DType() == mlx.DTypeInt32 { + ids := token.Ints() + if len(ids) > 0 { + return ids[0] + } + } + return token.Int() +} diff --git a/x/mlxrunner/pipeline.go b/x/mlxrunner/pipeline.go index 414015f6a..89afcf6ab 100644 --- a/x/mlxrunner/pipeline.go +++ b/x/mlxrunner/pipeline.go @@ -146,6 +146,12 @@ func (r *Runner) TextGenerationPipeline(ctx context.Context, request Request) er // Register the sampler after prefill completes. r.Sampler.Add(pipelineSlot, request.SamplerOpts, inputs) + if r.useGreedyMTP(request.SamplerOpts) { + return r.runGreedyMTPDecode(ctx, request, session, caches, tokens[processed:], &position, now) + } + if r.useSampleMTP(request.SamplerOpts) { + return r.runSampleMTPDecode(ctx, request, session, caches, tokens[processed:], &position, now) + } step := func(token *mlx.Array) sampler.Result { fwd := r.Model.Forward(&batch.Batch{ diff --git a/x/mlxrunner/runner.go b/x/mlxrunner/runner.go index 8f90c295e..ccde6ff05 100644 --- a/x/mlxrunner/runner.go +++ b/x/mlxrunner/runner.go @@ -35,6 +35,7 @@ type Request struct { type Runner struct { Model base.Model + Draft base.DraftModel Tokenizer *tokenizer.Tokenizer Requests chan Request Sampler *sample.Sampler @@ -61,12 +62,41 @@ func (r *Runner) Load(modelName string) error { return err } - // Assign weights to model (model-specific logic) - loadWeights := base.Weights(m) - if err := loadWeights(tensors); err != nil { + // Assign weights to model (model-specific logic). Target and draft weights + // must be loaded before sweeping so tensors from a combined manifest are + // not discarded before the draft model can retain them. 
+ if err := m.LoadWeights(tensors); err != nil { return err } + r.Draft = nil + draft, err := base.NewDraft(root, m) + if err != nil { + return err + } + if draft != nil { + if err := draft.LoadWeights(tensors); err != nil { + return err + } + r.Draft = draft + } + + collected := mlx.Collect(m) + if draft != nil { + draftArrays := mlx.Collect(draft) + collected = append(collected, draftArrays...) + if root.Draft != nil { + slog.Info("Loaded draft model", "tensor_prefix", root.Draft.TensorPrefix, "config", root.Draft.Config, "arrays", len(draftArrays)) + } else { + slog.Info("Loaded draft model", "arrays", len(draftArrays)) + } + } + for _, arr := range collected { + mlx.Pin(arr) + } + mlx.Sweep() + mlx.Eval(collected...) + r.Model = m r.Tokenizer = m.Tokenizer() r.contextLength = m.MaxContextLength() diff --git a/x/mlxrunner/sample/sample.go b/x/mlxrunner/sample/sample.go index b61b68fcb..bc6baaa96 100644 --- a/x/mlxrunner/sample/sample.go +++ b/x/mlxrunner/sample/sample.go @@ -349,6 +349,159 @@ func (s *Sampler) Sample(seqIDs []int, logits *mlx.Array) Result { return res } +// SpeculativeScores applies this slot's non-sampling transforms to logits +// without mutating sampler state. Row i is scored as if draftTokens[:i] had +// already been appended to the slot history. logits must be [R,V] or [1,R,V]. 
+func (s *Sampler) SpeculativeScores(seqID int, logits *mlx.Array, draftTokens *mlx.Array) *mlx.Array { + slot, ok := s.byID[seqID] + if !ok { + panic(fmt.Sprintf("sample.Sampler.SpeculativeScores: seqID %d not registered", seqID)) + } + + if logits.NumDims() == 3 { + if logits.Dim(0) != 1 { + panic("sample.Sampler.SpeculativeScores: only batch size 1 is supported") + } + logits = logits.Squeeze(0) + } + if logits.NumDims() != 2 { + panic(fmt.Sprintf("sample.Sampler.SpeculativeScores: logits must be rank 2 or 3, got rank %d", logits.NumDims())) + } + + if draftTokens != nil && draftTokens.NumDims() == 1 { + draftTokens = draftTokens.ExpandDims(0) + } + rows := logits.Dim(0) + var hist *mlx.Array + if slot.opts.usesHistory() { + if s.history == nil { + panic(fmt.Sprintf("sample.Sampler.SpeculativeScores: seqID %d has no history", seqID)) + } + if slot.historyLen < slot.opts.RepeatLastN { + return s.speculativeScoresSerial(slot, logits, draftTokens) + } + hist = s.speculativeHistory(slot, draftTokens, rows) + } + + return slot.speculativeScores(&slotCtx{opts: slot.opts, history: hist}, logits) +} + +// Commit appends already-selected tokens to seqID's repeat-penalty history. +// It is used after speculative sampling once the accepted continuation is +// known. Normal Sample calls continue to mutate history themselves. 
func (s *Sampler) Commit(seqID int, tokens []int32) {
	if len(tokens) == 0 {
		return
	}
	slot, ok := s.byID[seqID]
	if !ok {
		panic(fmt.Sprintf("sample.Sampler.Commit: seqID %d not registered", seqID))
	}
	// Slots without repeat penalties keep no history; nothing to record.
	if !slot.opts.usesHistory() {
		return
	}
	if s.history == nil {
		panic(fmt.Sprintf("sample.Sampler.Commit: seqID %d has no history", seqID))
	}

	row := slices.Index(s.slots, slot)
	width := s.historyWidth()
	// Only the last RepeatLastN tokens can survive in the ring; any earlier
	// ones would be overwritten by later writes, so skip them entirely.
	take := min(len(tokens), slot.opts.RepeatLastN)
	startLen := slot.historyLen + len(tokens) - take
	writeTokens := tokens[len(tokens)-take:]
	flatOffsets := make([]int32, take)
	for i := range take {
		// Ring position is taken modulo this slot's RepeatLastN window; the
		// row stride uses the shared history width, which may be wider when
		// other slots use a larger RepeatLastN.
		ringPos := (startLen + i) % slot.opts.RepeatLastN
		flatOffsets[i] = int32(row*width + ringPos)
	}

	// Perform all ring writes as one scatter over a flattened
	// [rows*width, 1] view rather than one device write per token.
	flatIdx := mlx.NewArrayInt32(flatOffsets, []int32{int32(take), 1})
	values := mlx.NewArrayInt32(writeTokens, []int32{int32(take), 1})
	flatHist := s.history.Reshape(s.history.Dim(0)*width, 1)
	s.history.Set(flatHist.PutAlongAxis(flatIdx, values, 0).Reshape(s.history.Dim(0), width))
	slot.historyLen += len(tokens)
}

// speculativeScoresSerial scores each logits row individually, materializing
// the exact history (committed ring contents plus the draft prefix accepted
// before that row) per row. It is the fallback used while the ring is not yet
// full and the batched overlay in speculativeHistory cannot be built.
func (s *Sampler) speculativeScoresSerial(slot *slotState, logits *mlx.Array, draftTokens *mlx.Array) *mlx.Array {
	rows := logits.Dim(0)
	draftCount := 0
	if draftTokens != nil {
		draftCount = draftTokens.Dim(1)
	}
	row := slices.Index(s.slots, slot)
	// baseFill is how many committed tokens actually occupy the ring so far.
	baseFill := min(slot.historyLen, slot.opts.RepeatLastN)
	var base *mlx.Array
	if baseFill > 0 {
		base = s.history.Slice(mlx.Slice(row, row+1), mlx.Slice(0, baseFill))
	}

	scored := make([]*mlx.Array, 0, rows)
	for i := range rows {
		rowLogits := logits.Slice(mlx.Slice(i, i+1), mlx.Slice())
		hist := base
		// Row i sees the committed history plus draftTokens[:i].
		prefixLen := min(i, draftCount)
		if prefixLen > 0 {
			prefix := draftTokens.Slice(mlx.Slice(), mlx.Slice(0, prefixLen))
			if hist == nil {
				hist = prefix
			} else {
				hist = hist.Concatenate(1, prefix)
			}
			// Keep only the newest RepeatLastN entries of the combined window.
			if hist.Dim(1) > slot.opts.RepeatLastN {
				hist = hist.Slice(mlx.Slice(), mlx.Slice(hist.Dim(1)-slot.opts.RepeatLastN, mlx.End))
			}
		}
		scored = append(scored, slot.speculativeScores(&slotCtx{opts: slot.opts, history: hist}, rowLogits))
	}
	return mlx.Concatenate(scored, 0)
}

// speculativeHistory builds a [rows, RepeatLastN] history matrix whose row i
// is this slot's ring with draftTokens[:i] overlaid at the next ring write
// positions. Callers ensure the ring is full (historyLen >= RepeatLastN), so
// the base slice covers the whole window.
func (s *Sampler) speculativeHistory(slot *slotState, draftTokens *mlx.Array, rows int) *mlx.Array {
	row := slices.Index(s.slots, slot)
	width := slot.opts.RepeatLastN
	base := s.history.Slice(mlx.Slice(row, row+1), mlx.Slice(0, width))
	base = mlx.Tile(base, []int32{int32(rows), 1})
	// next is the ring position the next committed token would land in.
	next := slot.historyLen % width
	draftCount := 0
	if draftTokens != nil {
		draftCount = draftTokens.Dim(1)
	}
	if draftCount == 0 {
		return base
	}

	// For row i, overlay draft tokens 0..min(i, draftCount)-1 at successive
	// ring positions; sourceIdx maps each overlaid cell back to its column in
	// draftTokens, writeMask marks which cells take the overlay.
	sourceIdx := make([]int32, rows*width)
	writeMask := make([]bool, rows*width)
	for i := range rows {
		prefixLen := min(i, draftCount)
		for j := range prefixLen {
			pos := (next + j) % width
			sourceIdx[i*width+pos] = int32(j)
			writeMask[i*width+pos] = true
		}
	}

	draftRows := mlx.Tile(draftTokens, []int32{int32(rows), 1})
	idx := mlx.NewArrayInt32(sourceIdx, []int32{int32(rows), int32(width)})
	mask := mlx.FromValues(writeMask, rows, width)
	values := draftRows.TakeAlongAxis(idx, 1)
	return mlx.Where(mask, values, base)
}

// speculativeScores applies every transform except the final token selector,
// returning processed (penalized and, when Temperature > 0, temperature-
// scaled) logits instead of a sampled token.
func (slot *slotState) speculativeScores(ctx *slotCtx, logits *mlx.Array) *mlx.Array {
	scores := logits
	// buildTransforms always appends the final selector transform
	// (greedy or temperature sampling). Speculative validation needs the
	// processed logits before that selector mutates the distribution.
	for _, t := range slot.transforms[:len(slot.transforms)-1] {
		scores = t(ctx, scores)
	}
	// The dropped selector would have divided by Temperature before sampling;
	// reproduce that scaling here so scores match what sampling would see.
	if slot.opts.Temperature > 0 {
		scores = mlx.DivScalar(scores, slot.opts.Temperature)
	}
	return scores
}

// canBatch reports whether the call can take the uniform batched path.
// All slots must share Options; when penalties are active the call must // additionally cover every registered slot in registration order with a diff --git a/x/mlxrunner/sample/sample_test.go b/x/mlxrunner/sample/sample_test.go index 3871cc6bf..b11c6b84f 100644 --- a/x/mlxrunner/sample/sample_test.go +++ b/x/mlxrunner/sample/sample_test.go @@ -176,6 +176,65 @@ func TestSampleHistoryWindow(t *testing.T) { } } +func TestSpeculativeScoresUsesDraftHistoryWithoutCommit(t *testing.T) { + skipIfNoMLX(t) + + s := New(128) + t.Cleanup(func() { + s.Free() + mlx.Sweep() + }) + + s.Add(0, Options{RepeatLastN: 2, RepeatPenalty: 10}, []int32{1, 2}) + draftTokens := mlx.NewArrayInt32([]int32{3, 4}, []int32{1, 2}) + scores := s.SpeculativeScores(0, batchLogits( + []float32{0, 9, 9, 8, 0}, // history {1,2}; token 3 wins + []float32{0, 0, 9, 9, 8}, // history {2,3}; token 4 wins + []float32{0, 0, 9, 9, 8}, // history {3,4}; token 2 wins + ), draftTokens) + tokens := scores.Argmax(-1, false).AsType(mlx.DTypeInt32) + mlx.Eval(tokens) + + if got, want := tokens.Ints(), []int{3, 4, 2}; len(got) != len(want) { + t.Fatalf("tokens = %v, want %v", got, want) + } else { + for i := range want { + if got[i] != want[i] { + t.Fatalf("tokens = %v, want %v", got, want) + } + } + } + if s.byID[0].historyLen != 2 { + t.Fatalf("historyLen = %d, want 2", s.byID[0].historyLen) + } +} + +func TestCommitBatchesRingWrites(t *testing.T) { + skipIfNoMLX(t) + + s := New(128) + t.Cleanup(func() { + s.Free() + mlx.Sweep() + }) + + s.Add(0, Options{RepeatLastN: 4, RepeatPenalty: 1.1}, []int32{10, 11, 12}) + s.Commit(0, []int32{20, 21, 22}) + s.Commit(0, []int32{30, 31, 32, 33, 34}) + mlx.Eval(s.history) + + got := s.history.Ints() + want := []int{32, 33, 34, 31} + for i := range want { + if got[i] != want[i] { + t.Fatalf("history = %v, want %v", got, want) + } + } + if s.byID[0].historyLen != 11 { + t.Fatalf("historyLen = %d, want 11", s.byID[0].historyLen) + } +} + // TestBatchSamplingPreservesPerSlotBehavior 
is the core equivalence test: // for every representative dispatch branch (uniform, serial on mixed opts, // serial on partial ring, subset/out-of-order), a batched Sample call must diff --git a/x/models/gemma4/assistant.go b/x/models/gemma4/assistant.go new file mode 100644 index 000000000..fef13eb84 --- /dev/null +++ b/x/models/gemma4/assistant.go @@ -0,0 +1,390 @@ +package gemma4 + +import ( + "encoding/json" + "fmt" + + "github.com/ollama/ollama/x/mlxrunner/batch" + "github.com/ollama/ollama/x/mlxrunner/cache" + "github.com/ollama/ollama/x/mlxrunner/mlx" + "github.com/ollama/ollama/x/mlxrunner/model" + "github.com/ollama/ollama/x/mlxrunner/model/base" + "github.com/ollama/ollama/x/models/nn" +) + +var ( + _ base.DraftModel = (*AssistantModel)(nil) + _ base.MTPDraftModel = (*AssistantModel)(nil) +) + +type AssistantConfig struct { + TextConfig TextConfig `json:"text_config"` + BackboneHiddenSize int32 `json:"backbone_hidden_size"` + UseOrderedEmbeddings bool `json:"use_ordered_embeddings"` + NumCentroids int32 `json:"num_centroids"` + CentroidIntermediateTopK int32 `json:"centroid_intermediate_top_k"` +} + +type AssistantModel struct { + PreProjection nn.LinearLayer + PostProjection nn.LinearLayer + EmbedTokens nn.EmbeddingLayer + LMHead nn.LinearLayer + Centroids nn.LinearLayer + TokenOrdering *mlx.Array + Layers []*AssistantLayer + Norm *nn.RMSNorm + + NormScaled *mlx.Array + + *AssistantConfig + tensorPrefix string + + QuantGroupSize int + QuantBits int + QuantMode string + TensorQuant map[string]*model.TensorQuantInfo +} + +type AssistantLayer struct { + InputNorm *nn.RMSNorm + PostAttnNorm *nn.RMSNorm + PreFFNorm *nn.RMSNorm + PostFFNorm *nn.RMSNorm + + InputNormScaled *mlx.Array + PostAttnNormScaled *mlx.Array + PreFFNormScaled *mlx.Array + PostFFNormScaled *mlx.Array + + Attention *AssistantAttention + MLP *MLP + LayerScalar *mlx.Array + IsSliding bool +} + +type AssistantAttention struct { + QProj nn.LinearLayer + OProj nn.LinearLayer + QNorm *nn.RMSNorm 
+ + QNormScaled *mlx.Array +} + +func parseAssistantConfig(configData []byte) (AssistantConfig, error) { + var raw struct { + TextConfig json.RawMessage `json:"text_config"` + + BackboneHiddenSize int32 `json:"backbone_hidden_size"` + UseOrderedEmbeddings bool `json:"use_ordered_embeddings"` + NumCentroids int32 `json:"num_centroids"` + CentroidIntermediateTopK int32 `json:"centroid_intermediate_top_k"` + } + if err := json.Unmarshal(configData, &raw); err != nil { + return AssistantConfig{}, fmt.Errorf("parse assistant config: %w", err) + } + if len(raw.TextConfig) == 0 { + return AssistantConfig{}, fmt.Errorf("assistant config missing text_config") + } + + text, err := parseTextConfig(raw.TextConfig) + if err != nil { + return AssistantConfig{}, err + } + if raw.NumCentroids == 0 { + raw.NumCentroids = 2048 + } + if raw.CentroidIntermediateTopK == 0 { + raw.CentroidIntermediateTopK = 32 + } + + return AssistantConfig{ + TextConfig: text, + BackboneHiddenSize: raw.BackboneHiddenSize, + UseOrderedEmbeddings: raw.UseOrderedEmbeddings, + NumCentroids: raw.NumCentroids, + CentroidIntermediateTopK: raw.CentroidIntermediateTopK, + }, nil +} + +func newAssistantModel(root *model.Root, target base.Model) (base.DraftModel, error) { + if root == nil || root.Draft == nil { + return nil, fmt.Errorf("draft metadata missing") + } + + configPath := root.Draft.Config + if configPath == "" { + configPath = "draft/config.json" + } + configData, err := root.Manifest.ReadConfig(configPath) + if err != nil { + return nil, fmt.Errorf("load draft config: %w", err) + } + + cfg, err := parseAssistantConfig(configData) + if err != nil { + return nil, err + } + + targetGemma, ok := target.(*Model) + if !ok { + return nil, fmt.Errorf("gemma4 assistant requires gemma4 target, got %T", target) + } + if cfg.BackboneHiddenSize != 0 && cfg.BackboneHiddenSize != targetGemma.HiddenSize { + return nil, fmt.Errorf("assistant backbone hidden size %d does not match target hidden size %d", 
cfg.BackboneHiddenSize, targetGemma.HiddenSize) + } + if cfg.TextConfig.VocabSize != targetGemma.VocabSize { + return nil, fmt.Errorf("assistant vocab size %d does not match target vocab size %d", cfg.TextConfig.VocabSize, targetGemma.VocabSize) + } + + tensorPrefix := root.Draft.TensorPrefix + if tensorPrefix == "" { + tensorPrefix = "draft." + } + + m := &AssistantModel{ + AssistantConfig: &cfg, + tensorPrefix: tensorPrefix, + Layers: make([]*AssistantLayer, cfg.TextConfig.NumHiddenLayers), + TensorQuant: root.AllTensorQuant(), + } + if qt := root.QuantType(); qt != "" { + m.QuantGroupSize, m.QuantBits, m.QuantMode = model.QuantizationParams(qt) + if gs := root.GroupSize(); gs > 0 { + m.QuantGroupSize = gs + } + } + for i := range m.Layers { + m.Layers[i] = &AssistantLayer{ + IsSliding: isLayerSliding(int32(i), &m.TextConfig), + Attention: &AssistantAttention{}, + MLP: &MLP{}, + } + } + return m, nil +} + +func (m *AssistantModel) LoadWeights(tensors map[string]*mlx.Array) error { + prefix := m.tensorPrefix + linears := model.NewLinearFactory(tensors, m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant) + + m.PreProjection = linears.Make(prefix + "pre_projection") + m.PostProjection = linears.Make(prefix + "post_projection") + if m.PreProjection == nil || m.PostProjection == nil { + return fmt.Errorf("missing assistant projection weights") + } + + m.EmbedTokens = model.MakeEmbeddingLayer(tensors, prefix+"model.embed_tokens", m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant) + if m.EmbedTokens == nil { + return fmt.Errorf("missing assistant embedding weight") + } + m.LMHead = m.EmbedTokens.AsLinear() + + if m.UseOrderedEmbeddings { + m.Centroids = linears.Make(prefix + "masked_embedding.centroids") + m.TokenOrdering = tensors[prefix+"masked_embedding.token_ordering"] + if m.Centroids == nil || m.TokenOrdering == nil { + return fmt.Errorf("missing ordered embedding tensors: %smasked_embedding.centroids.weight and %smasked_embedding.token_ordering", 
prefix, prefix) + } + m.TokenOrdering = m.TokenOrdering.AsType(mlx.DTypeInt32) + } + + normWeight := tensors[prefix+"model.norm.weight"] + if normWeight == nil { + return fmt.Errorf("missing assistant final norm") + } + m.Norm = nn.NewRMSNorm(normWeight, m.TextConfig.RMSNormEps) + + for i := range m.TextConfig.NumHiddenLayers { + layerPrefix := fmt.Sprintf("%smodel.layers.%d", prefix, i) + layer := &AssistantLayer{ + IsSliding: isLayerSliding(i, &m.TextConfig), + Attention: &AssistantAttention{ + QProj: linears.Make(layerPrefix + ".self_attn.q_proj"), + OProj: linears.Make(layerPrefix + ".self_attn.o_proj"), + }, + MLP: &MLP{ + GateProj: linears.Make(layerPrefix + ".mlp.gate_proj"), + UpProj: linears.Make(layerPrefix + ".mlp.up_proj"), + DownProj: linears.Make(layerPrefix + ".mlp.down_proj"), + }, + LayerScalar: tensors[layerPrefix+".layer_scalar"], + } + + if w := tensors[layerPrefix+".input_layernorm.weight"]; w != nil { + layer.InputNorm = nn.NewRMSNorm(w, m.TextConfig.RMSNormEps) + } + if w := tensors[layerPrefix+".post_attention_layernorm.weight"]; w != nil { + layer.PostAttnNorm = nn.NewRMSNorm(w, m.TextConfig.RMSNormEps) + } + if w := tensors[layerPrefix+".pre_feedforward_layernorm.weight"]; w != nil { + layer.PreFFNorm = nn.NewRMSNorm(w, m.TextConfig.RMSNormEps) + } + if w := tensors[layerPrefix+".post_feedforward_layernorm.weight"]; w != nil { + layer.PostFFNorm = nn.NewRMSNorm(w, m.TextConfig.RMSNormEps) + } + if w := tensors[layerPrefix+".self_attn.q_norm.weight"]; w != nil { + layer.Attention.QNorm = nn.NewRMSNorm(w, m.TextConfig.RMSNormEps) + } + + if layer.InputNorm == nil || layer.PostAttnNorm == nil || layer.PreFFNorm == nil || layer.PostFFNorm == nil { + return fmt.Errorf("assistant layer %d: missing norm weights", i) + } + if layer.Attention.QProj == nil || layer.Attention.OProj == nil || layer.Attention.QNorm == nil { + return fmt.Errorf("assistant layer %d: missing attention weights", i) + } + if layer.MLP.GateProj == nil || layer.MLP.UpProj == 
nil || layer.MLP.DownProj == nil { + return fmt.Errorf("assistant layer %d: missing mlp weights", i) + } + + m.Layers[i] = layer + } + + m.precomputeScaledWeights() + return nil +} + +func (m *AssistantModel) precomputeScaledWeights() { + if m.Norm != nil { + m.NormScaled = m.Norm.Weight + } + for _, layer := range m.Layers { + if layer.InputNorm != nil { + layer.InputNormScaled = layer.InputNorm.Weight + } + if layer.PostAttnNorm != nil { + layer.PostAttnNormScaled = layer.PostAttnNorm.Weight + } + if layer.PreFFNorm != nil { + layer.PreFFNormScaled = layer.PreFFNorm.Weight + } + if layer.PostFFNorm != nil { + layer.PostFFNormScaled = layer.PostFFNorm.Weight + } + if layer.Attention != nil && layer.Attention.QNorm != nil { + layer.Attention.QNormScaled = layer.Attention.QNorm.Weight + } + } +} + +func (m *AssistantModel) Draft(inputsEmbeds *mlx.Array, position int32, caches []cache.Cache) (logits, hidden *mlx.Array) { + dims := inputsEmbeds.Dims() + B, L := int32(dims[0]), int32(dims[1]) + b := &batch.Batch{ + InputIDs: mlx.Zeros(mlx.DTypeInt32, int(B), int(L)), + SeqOffsets: []int32{position}, + SeqQueryLens: []int32{L}, + } + + sliding, full := m.sharedHistories(b, caches) + h := m.PreProjection.Forward(inputsEmbeds) + + positions := mlx.FromValues([]int32{position}, 1) + for _, layer := range m.Layers { + h = layer.Forward(h, b, positions, B, L, &m.TextConfig, sliding, full) + } + + hidden = mlx.RMSNormFn(h, m.NormScaled, m.TextConfig.RMSNormEps) + projected := m.PostProjection.Forward(hidden) + return m.unembed(hidden), projected +} + +func (m *AssistantModel) sharedHistories(b *batch.Batch, caches []cache.Cache) (sliding, full *nn.KVHistory) { + if len(caches) < 2 { + return nil, nil + } + if v, ok := caches[len(caches)-2].(cache.Viewer); ok { + sliding = v.View(b) + } + if v, ok := caches[len(caches)-1].(cache.Viewer); ok { + full = v.View(b) + } + return sliding, full +} + +func (m *AssistantModel) unembed(hidden *mlx.Array) *mlx.Array { + if 
m.UseOrderedEmbeddings { + return m.applyCentroidMasking(hidden) + } + return m.LMHead.Forward(hidden) +} + +func (m *AssistantModel) applyCentroidMasking(hidden *mlx.Array) *mlx.Array { + B, L := hidden.Dim(0), hidden.Dim(1) + vocab := int(m.TextConfig.VocabSize) + numCentroids := int(m.NumCentroids) + vocabPerCentroid := vocab / numCentroids + topK := int(m.CentroidIntermediateTopK) + + centroidLogits := m.Centroids.Forward(hidden) + topKIndices := centroidLogits.Negative().ArgpartitionAxis(topK-1, -1).Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, topK)) + ordering := m.TokenOrdering.Reshape(numCentroids, vocabPerCentroid) + selectedCanonical := ordering.TakeAxis(topKIndices, 0) + selectedFlat := selectedCanonical.Reshape(B * L * topK * vocabPerCentroid) + + embeddings := m.EmbedTokens.Forward(selectedFlat) + embeddings = embeddings.Reshape(B, L, topK*vocabPerCentroid, int(m.TextConfig.HiddenSize)) + selectedLogits := hidden.ExpandDims(2).Matmul(embeddings.Transpose(0, 1, 3, 2)).Squeeze(2) + + out := mlx.Zeros(selectedLogits.DType(), B, L, vocab) + out = mlx.AddScalar(out, -1.0e30) + return out.PutAlongAxis(selectedCanonical.Reshape(B, L, topK*vocabPerCentroid), selectedLogits, -1) +} + +func (l *AssistantLayer) Forward(x *mlx.Array, b *batch.Batch, positions *mlx.Array, B, L int32, cfg *TextConfig, sliding, full *nn.KVHistory) *mlx.Array { + normed := mlx.RMSNormFn(x, l.InputNormScaled, cfg.RMSNormEps) + attnOut := l.Attention.Forward(normed, b, positions, B, L, l.IsSliding, cfg, sliding, full) + attnOut = mlx.RMSNormFn(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps) + h := mlx.Add(x, attnOut) + + normed = mlx.RMSNormFn(h, l.PreFFNormScaled, cfg.RMSNormEps) + mlpOut := l.MLP.Forward(normed) + mlpOut = mlx.RMSNormFn(mlpOut, l.PostFFNormScaled, cfg.RMSNormEps) + h = mlx.Add(h, mlpOut) + + if l.LayerScalar != nil { + h = mlx.Mul(h, l.LayerScalar) + } + return h +} + +func (a *AssistantAttention) Forward(x *mlx.Array, b *batch.Batch, positions *mlx.Array, B, L int32, 
isSliding bool, cfg *TextConfig, sliding, full *nn.KVHistory) *mlx.Array { + headDim := cfg.HeadDim + scale := cfg.SlidingScale + ropeDims := cfg.SlidingRopeDims + ropeBase := cfg.SlidingRopeBase + history := sliding + if !isSliding { + headDim = cfg.GlobalHeadDim + scale = cfg.FullScale + ropeDims = cfg.FullRopeDims + ropeBase = cfg.FullRopeBase + history = full + } + if history == nil { + panic("gemma4 assistant missing shared target KV history") + } + + q := a.QProj.Forward(x) + q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, headDim) + q = mlx.Transpose(q, 0, 2, 1, 3) + q = mlx.RMSNormFn(q, a.QNormScaled, cfg.RMSNormEps) + + var ropeFreqs *mlx.Array + if !isSliding { + ropeFreqs = cfg.FullRopeFreqs + } + q = mlx.RoPEWithFreqs(q, ropeDims, false, ropeBase, 1.0, positions, ropeFreqs) + + mask := nn.CausalMask() + if isSliding && cfg.SlidingWindow > 0 { + mask = mask.Intersect(nn.SlidingWindowMask(b, history.K().Dim(2), int(cfg.SlidingWindow), q.DType())) + } + + out := nn.ScaledDotProductAttention(b, q, scale, nn.WithKVHistory(history), nn.WithMask(mask)) + out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*headDim) + if !mlx.MetalIsAvailable() { + out = mlx.Contiguous(out, false) + } + return a.OProj.Forward(out) +} diff --git a/x/models/gemma4/gemma4.go b/x/models/gemma4/gemma4.go index df1d1b271..adfc7ab55 100644 --- a/x/models/gemma4/gemma4.go +++ b/x/models/gemma4/gemma4.go @@ -18,10 +18,15 @@ import ( func init() { base.Register("Gemma4ForCausalLM", newModel) base.Register("Gemma4ForConditionalGeneration", newModel) + base.RegisterDraft("Gemma4AssistantForCausalLM", newAssistantModel) + base.RegisterDraft("gemma4_assistant", newAssistantModel) } // Compile-time interface checks. -var _ base.Model = (*Model)(nil) +var ( + _ base.Model = (*Model)(nil) + _ base.MTPDefaultsProvider = (*Model)(nil) +) // RopeParams holds per-layer-type RoPE settings. 
type RopeParams struct { @@ -466,6 +471,24 @@ func (m *Model) EnableCompile() bool { return true } +func (m *Model) MTPDraftDefaults(_ bool) base.MTPDefaults { + defaults := base.MTPDefaults{ + InitialDraftTokens: 4, + MaxDraftTokens: 16, + Enabled: true, + } + if m == nil || m.TextConfig == nil { + return defaults + } + switch { + case !m.EnableMoeBlock && m.HiddenSize == 5376 && m.NumHiddenLayers == 60: + defaults.InitialDraftTokens = 14 + case m.EnableMoeBlock && m.HiddenSize == 2816 && m.NumHiddenLayers == 30: + defaults.InitialDraftTokens = 8 + } + return defaults +} + func resolveWeightPrefix(tensors map[string]*mlx.Array) string { for _, prefix := range []string{"", "language_model.", "model.language_model."} { if tensors[prefix+"embed_tokens.weight"] != nil { @@ -1068,6 +1091,11 @@ func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok } +// TokenEmbeddings returns the target model's scaled token embeddings for MTP. +func (m *Model) TokenEmbeddings(inputIDs *mlx.Array) *mlx.Array { + return mlx.MulScalar(m.EmbedTokens.Forward(inputIDs), m.EmbedScale) +} + // NewCaches creates cache objects for layers that own KV state. 
func (m *Model) NewCaches() []cache.Cache { cacheLayers := len(m.Layers) diff --git a/x/models/gemma4/gemma4_test.go b/x/models/gemma4/gemma4_test.go index 4f674ca66..d00864e10 100644 --- a/x/models/gemma4/gemma4_test.go +++ b/x/models/gemma4/gemma4_test.go @@ -434,6 +434,54 @@ func TestLayerTypeDetection(t *testing.T) { } } +func TestMTPDraftDefaults(t *testing.T) { + tests := []struct { + name string + cfg *TextConfig + wantInitial int + wantMax int + }{ + { + name: "nil config", + wantInitial: 4, + wantMax: 16, + }, + { + name: "31b bf16", + cfg: &TextConfig{HiddenSize: 5376, NumHiddenLayers: 60}, + wantInitial: 14, + wantMax: 16, + }, + { + name: "31b quantized", + cfg: &TextConfig{HiddenSize: 5376, NumHiddenLayers: 60, QuantBits: 4}, + wantInitial: 14, + wantMax: 16, + }, + { + name: "26b-a4b moe", + cfg: &TextConfig{HiddenSize: 2816, NumHiddenLayers: 30, EnableMoeBlock: true}, + wantInitial: 8, + wantMax: 16, + }, + { + name: "generic default", + cfg: &TextConfig{HiddenSize: 2560, NumHiddenLayers: 42, HiddenSizePerLayer: 256}, + wantInitial: 4, + wantMax: 16, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := (&Model{TextConfig: tt.cfg}).MTPDraftDefaults(false) + if got.InitialDraftTokens != tt.wantInitial || got.MaxDraftTokens != tt.wantMax || !got.Enabled { + t.Fatalf("MTPDraftDefaults() = %+v, want initial=%d max=%d enabled=true", got, tt.wantInitial, tt.wantMax) + } + }) + } +} + func TestNewCachesOmitsSharedKVLayers(t *testing.T) { m := &Model{ Layers: []*DecoderLayer{ @@ -467,6 +515,55 @@ func TestNewCachesIncludesAllNonSharedLayers(t *testing.T) { } } +func TestNewCachesAssistantSharedHistoryOrdering(t *testing.T) { + cases := []struct { + name string + totalLayers int + slidingBeforeFull int + cacheLayers int + }{ + {name: "31B", totalLayers: 60, slidingBeforeFull: 5, cacheLayers: 60}, + {name: "26B-A4B", totalLayers: 30, slidingBeforeFull: 5, cacheLayers: 30}, + {name: "E4B", totalLayers: 42, slidingBeforeFull: 5, 
cacheLayers: 24}, + {name: "E2B", totalLayers: 35, slidingBeforeFull: 4, cacheLayers: 15}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + groupSize := tc.slidingBeforeFull + 1 + layers := make([]*DecoderLayer, tc.totalLayers) + for i := range layers { + donor := int32(-1) + if i >= tc.cacheLayers { + donor = 0 + } + layers[i] = &DecoderLayer{ + IsSliding: i%groupSize < tc.slidingBeforeFull, + KVShareDonor: donor, + } + } + + m := &Model{ + Layers: layers, + TextConfig: &TextConfig{SlidingWindow: 512}, + } + caches := m.NewCaches() + if got := len(caches); got != tc.cacheLayers { + t.Fatalf("len(NewCaches()) = %d, want %d", got, tc.cacheLayers) + } + + gotSliding := len(caches) - 2 + gotFull := len(caches) - 1 + if !m.Layers[gotSliding].IsSliding { + t.Fatalf("cache %d should be sliding attention", gotSliding) + } + if m.Layers[gotFull].IsSliding { + t.Fatalf("cache %d should be full attention", gotFull) + } + }) + } +} + func TestResolveWeightPrefix(t *testing.T) { if err := mlx.CheckInit(); err != nil { t.Skipf("MLX not available: %v", err)