mlx: Gemma4 MTP speculative decoding (#15980)

This change adds support for MTP (multi-token prediction) speculative decoding for the
gemma4 model family.

It includes:
  * support for importing safetensors-based gemma4 draft models with `ollama create`
  * a new DRAFT command in the Modelfile for specifying draft models
  * a --draft-quantize flag for the `ollama create` command to quantize the draft model (see the example below)
  * cache support for speculation
  * changes to the rotating cache so it can handle MTP correctly
  * sampling support for draft model token prediction
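
For example, based on the parser and CLI tests in this change (directory names
are placeholders):

    FROM ./gemma4
    DRAFT ./gemma4-draft

    ollama create my-gemma4 -f Modelfile --experimental --draft-quantize mxfp8

DRAFT must point at a local safetensors model directory, and both DRAFT and
--draft-quantize are gated behind --experimental.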

---------

Co-authored-by: Daniel Hiltgen <daniel@ollama.com>
Patrick Devine 2026-05-05 08:55:04 -07:00 (committed by GitHub)
parent 4017af96cd
commit 15e6076d79
28 changed files with 2928 additions and 42 deletions


@@ -423,8 +423,8 @@ jobs:
lib/ollama/*.so*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/cuda_v*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/vulkan*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/mlx*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/include*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/mlx*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-mlx.tar.in ;;
lib/ollama/include*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-mlx.tar.in ;;
lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;
@@ -458,12 +458,14 @@ jobs:
CGO_CFLAGS
CGO_CXXFLAGS
GOFLAGS
APT_MIRROR=http://azure.archive.ubuntu.com/ubuntu
- os: linux
arch: amd64
build-args: |
CGO_CFLAGS
CGO_CXXFLAGS
GOFLAGS
APT_MIRROR=http://azure.archive.ubuntu.com/ubuntu
- os: linux
arch: amd64
suffix: '-rocm'
@@ -472,6 +474,7 @@ jobs:
CGO_CXXFLAGS
GOFLAGS
FLAVOR=rocm
APT_MIRROR=http://azure.archive.ubuntu.com/ubuntu
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
environment: release
needs: setup-environment


@@ -41,7 +41,7 @@
"inherits": [ "CUDA" ],
"cacheVariables": {
"CMAKE_CUDA_ARCHITECTURES": "75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual;90a-virtual;100-virtual;103-virtual;110-virtual;120-virtual;121-virtual",
"CMAKE_CUDA_FLAGS": "-t 4",
"CMAKE_CUDA_FLAGS": "-t 2",
"OLLAMA_RUNNER_DIR": "cuda_v13"
}
},


@@ -212,8 +212,11 @@ COPY --from=cpu dist/lib/ollama /lib/ollama
COPY --from=build /bin/ollama /bin/ollama
FROM ubuntu:24.04
RUN apt-get update \
ARG APT_MIRROR=http://archive.ubuntu.com/ubuntu
RUN sed -i "s|http://archive.ubuntu.com/ubuntu|$APT_MIRROR|g" /etc/apt/sources.list.d/ubuntu.sources \
&& apt-get update \
&& apt-get install -y ca-certificates libvulkan1 libopenblas0 \
&& sed -i "s|$APT_MIRROR|http://archive.ubuntu.com/ubuntu|g" /etc/apt/sources.list.d/ubuntu.sources \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
COPY --from=archive /bin /usr/bin


@@ -54,6 +54,7 @@ import (
"github.com/ollama/ollama/types/syncmap"
"github.com/ollama/ollama/version"
xcmd "github.com/ollama/ollama/x/cmd"
xcreate "github.com/ollama/ollama/x/create"
xcreateclient "github.com/ollama/ollama/x/create/client"
"github.com/ollama/ollama/x/imagegen"
)
@@ -145,6 +146,39 @@ func isLocalhost() bool {
return ip != nil && (ip.IsLoopback() || ip.IsUnspecified())
}
func resolveExperimentalLocalModelDir(ref, filename string) string {
if ref == "" || filepath.IsAbs(ref) || filename == "" {
return ref
}
candidate := filepath.Join(filepath.Dir(filename), ref)
if xcreate.IsSafetensorsModelDir(candidate) || xcreate.IsTensorModelDir(candidate) {
return candidate
}
return ref
}
func resolveExperimentalDraftDir(ref, filename string) (string, error) {
if ref == "" {
return "", nil
}
if filepath.IsAbs(ref) {
if xcreate.IsSafetensorsModelDir(ref) {
return ref, nil
}
return "", fmt.Errorf("draft %s is not a supported safetensors model directory", ref)
}
if filename != "" {
candidate := filepath.Join(filepath.Dir(filename), ref)
if xcreate.IsSafetensorsModelDir(candidate) {
return candidate, nil
}
}
return "", fmt.Errorf("DRAFT model references are not supported with --experimental yet: %s", ref)
}
func CreateHandler(cmd *cobra.Command, args []string) error {
p := progress.NewProgress(os.Stderr)
defer p.Stop()
@@ -159,6 +193,10 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
// Check for --experimental flag for safetensors model creation
// This gates both safetensors LLM and imagegen model creation
experimental, _ := cmd.Flags().GetBool("experimental")
draftQuantize, _ := cmd.Flags().GetString("draft-quantize")
if draftQuantize != "" && !experimental {
return errors.New("--draft-quantize requires --experimental")
}
if experimental {
if !isLocalhost() {
return errors.New("remote safetensor model creation not yet supported")
@@ -192,17 +230,22 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return err
}
// Resolve relative paths based on Modelfile location
if !filepath.IsAbs(modelDir) && filename != "" {
modelDir = filepath.Join(filepath.Dir(filename), modelDir)
modelDir = resolveExperimentalLocalModelDir(modelDir, filename)
if mfConfig.Draft != "" {
draftDir, err := resolveExperimentalDraftDir(mfConfig.Draft, filename)
if err != nil {
return err
}
mfConfig.Draft = draftDir
}
quantize, _ := cmd.Flags().GetString("quantize")
return xcreateclient.CreateModel(xcreateclient.CreateOptions{
ModelName: modelName,
ModelDir: modelDir,
Quantize: quantize,
Modelfile: mfConfig,
ModelName: modelName,
ModelDir: modelDir,
Quantize: quantize,
DraftQuantize: draftQuantize,
Modelfile: mfConfig,
}, p)
}
@@ -2176,6 +2219,9 @@ func NewCLI() *cobra.Command {
if experimental, _ := cmd.Flags().GetBool("experimental"); experimental {
return nil
}
if draftQuantize, _ := cmd.Flags().GetString("draft-quantize"); draftQuantize != "" {
return errors.New("--draft-quantize requires --experimental")
}
return checkServerHeartbeat(cmd, args)
},
RunE: CreateHandler,
@@ -2183,6 +2229,7 @@ func NewCLI() *cobra.Command {
createCmd.Flags().StringP("file", "f", "", "Name of the Modelfile (default \"Modelfile\")")
createCmd.Flags().StringP("quantize", "q", "", "Quantize model to this level (e.g. q4_K_M)")
createCmd.Flags().String("draft-quantize", "", "Quantize draft model to this level")
createCmd.Flags().Bool("experimental", false, "Enable experimental safetensors model creation")
showCmd := &cobra.Command{


@@ -9,6 +9,7 @@ import (
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"reflect"
"strings"
"testing"
@@ -1524,6 +1525,87 @@ func TestCreateHandler(t *testing.T) {
}
}
func TestCreateHandlerDraftQuantizeRequiresExperimental(t *testing.T) {
cmd := &cobra.Command{}
cmd.Flags().Bool("experimental", false, "")
cmd.Flags().String("draft-quantize", "mxfp8", "")
cmd.SetContext(t.Context())
err := CreateHandler(cmd, []string{"test-model"})
if err == nil || !strings.Contains(err.Error(), "--draft-quantize requires --experimental") {
t.Fatalf("error = %v, want draft-quantize requires experimental", err)
}
}
func TestCreateHandlerDraftRequiresExperimental(t *testing.T) {
dir := t.TempDir()
modelfile := filepath.Join(dir, "Modelfile")
if err := os.WriteFile(modelfile, []byte("FROM base\nDRAFT ./assistant\n"), 0o644); err != nil {
t.Fatal(err)
}
cmd := &cobra.Command{}
cmd.Flags().Bool("experimental", false, "")
cmd.Flags().String("draft-quantize", "", "")
cmd.Flags().String("file", modelfile, "")
cmd.SetContext(t.Context())
err := CreateHandler(cmd, []string{"test-model"})
if err == nil || !strings.Contains(err.Error(), "DRAFT requires --experimental") {
t.Fatalf("error = %v, want DRAFT requires --experimental", err)
}
}
func TestResolveExperimentalLocalModelDir(t *testing.T) {
dir := t.TempDir()
modelfile := filepath.Join(dir, "Modelfile")
modelDir := filepath.Join(dir, "model")
if err := os.Mkdir(modelDir, 0o755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(modelDir, "config.json"), []byte(`{}`), 0o644); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(modelDir, "model.safetensors"), []byte("dummy"), 0o644); err != nil {
t.Fatal(err)
}
if got := resolveExperimentalLocalModelDir("gemma4", modelfile); got != "gemma4" {
t.Fatalf("resolveExperimentalLocalModelDir(model name) = %q, want gemma4", got)
}
if got := resolveExperimentalLocalModelDir("./model", modelfile); got != modelDir {
t.Fatalf("resolveExperimentalLocalModelDir(local dir) = %q, want %q", got, modelDir)
}
}
func TestResolveExperimentalDraftDir(t *testing.T) {
dir := t.TempDir()
modelfile := filepath.Join(dir, "Modelfile")
draftDir := filepath.Join(dir, "assistant")
if err := os.Mkdir(draftDir, 0o755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(draftDir, "config.json"), []byte(`{}`), 0o644); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(draftDir, "model.safetensors"), []byte("dummy"), 0o644); err != nil {
t.Fatal(err)
}
got, err := resolveExperimentalDraftDir("./assistant", modelfile)
if err != nil {
t.Fatal(err)
}
if got != draftDir {
t.Fatalf("resolveExperimentalDraftDir(local dir) = %q, want %q", got, draftDir)
}
_, err = resolveExperimentalDraftDir("assistant-model", modelfile)
if err == nil || !strings.Contains(err.Error(), "DRAFT model references are not supported with --experimental yet") {
t.Fatalf("error = %v, want unsupported draft model reference", err)
}
}
func TestNewCreateRequest(t *testing.T) {
tests := []struct {
name string


@@ -83,6 +83,8 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
req.Files[k] = v
}
}
case "draft":
return nil, errors.New("DRAFT requires --experimental")
case "adapter":
path, err := expandPath(c.Args, relativeDir)
if err != nil {
@@ -336,7 +338,7 @@ func (c Command) String() string {
switch c.Name {
case "model":
fmt.Fprintf(&sb, "FROM %s", c.Args)
case "license", "template", "system", "adapter", "renderer", "parser", "requires":
case "license", "template", "system", "adapter", "renderer", "parser", "requires", "draft":
fmt.Fprintf(&sb, "%s %s", strings.ToUpper(c.Name), quote(c.Args))
case "message":
role, message, _ := strings.Cut(c.Args, ": ")
@@ -362,7 +364,7 @@ const (
var (
errMissingFrom = errors.New("no FROM line")
errInvalidMessageRole = errors.New("message role must be one of \"system\", \"user\", or \"assistant\"")
errInvalidCommand = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"renderer\", \"parser\", \"parameter\", \"message\", or \"requires\"")
errInvalidCommand = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"draft\", \"renderer\", \"parser\", \"parameter\", \"message\", or \"requires\"")
)
type ParserError struct {
@@ -622,7 +624,7 @@ func isValidMessageRole(role string) bool {
func isValidCommand(cmd string) bool {
switch strings.ToLower(cmd) {
case "from", "license", "template", "system", "adapter", "renderer", "parser", "parameter", "message", "requires":
case "from", "license", "template", "system", "adapter", "draft", "renderer", "parser", "parameter", "message", "requires":
return true
default:
return false


@@ -58,6 +58,32 @@ TEMPLATE """{{ if .System }}<|start_header_id|>system<|end_header_id|>
assert.Equal(t, expectedCommands, modelfile.Commands)
}
func TestParseFileDraft(t *testing.T) {
modelfile, err := ParseFile(strings.NewReader(`
FROM base
DRAFT ./assistant
`))
require.NoError(t, err)
expectedCommands := []Command{
{Name: "model", Args: "base"},
{Name: "draft", Args: "./assistant"},
}
assert.Equal(t, expectedCommands, modelfile.Commands)
assert.Contains(t, modelfile.String(), "DRAFT ./assistant")
}
func TestCreateRequestDraftRequiresExperimental(t *testing.T) {
modelfile, err := ParseFile(strings.NewReader(`
FROM base
DRAFT ./assistant
`))
require.NoError(t, err)
_, err = modelfile.CreateRequest("")
require.ErrorContains(t, err, "DRAFT requires --experimental")
}
func TestParseFileTrimSpace(t *testing.T) {
input := `
FROM " model 1"


@@ -62,13 +62,15 @@ if echo $PLATFORM | grep "," > /dev/null ; then
tar c -C ./dist/linux_arm64 --exclude cuda_jetpack5 --exclude cuda_jetpack6 . | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64.tar.zst
tar c -C ./dist/linux_arm64 ./lib/ollama/cuda_jetpack5 | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64-jetpack5.tar.zst
tar c -C ./dist/linux_arm64 ./lib/ollama/cuda_jetpack6 | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64-jetpack6.tar.zst
tar c -C ./dist/linux_amd64 --exclude rocm . | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64.tar.zst
tar c -C ./dist/linux_amd64 --exclude rocm --exclude 'mlx*' --exclude include . | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64.tar.zst
tar c -C ./dist/linux_amd64 ./lib/ollama/rocm | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64-rocm.tar.zst
tar c -C ./dist/linux_amd64 ./lib/ollama/mlx_cuda_v13 ./lib/ollama/include | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64-mlx.tar.zst
elif echo $PLATFORM | grep "arm64" > /dev/null ; then
tar c -C ./dist/ --exclude cuda_jetpack5 --exclude cuda_jetpack6 bin lib | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64.tar.zst
tar c -C ./dist/ ./lib/ollama/cuda_jetpack5 | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64-jetpack5.tar.zst
tar c -C ./dist/ ./lib/ollama/cuda_jetpack6 | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64-jetpack6.tar.zst
elif echo $PLATFORM | grep "amd64" > /dev/null ; then
tar c -C ./dist/ --exclude rocm bin lib | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64.tar.zst
tar c -C ./dist/ --exclude rocm --exclude 'mlx*' --exclude include bin lib | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64.tar.zst
tar c -C ./dist/ ./lib/ollama/rocm | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64-rocm.tar.zst
tar c -C ./dist/ ./lib/ollama/mlx_cuda_v13 ./lib/ollama/include | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64-mlx.tar.zst
fi


@@ -19,6 +19,7 @@ type ConfigV2 struct {
ContextLen int `json:"context_length,omitempty"`
EmbedLen int `json:"embedding_length,omitempty"`
BaseName string `json:"base_name,omitempty"`
Draft *Draft `json:"draft,omitempty"`
// required by spec
Architecture string `json:"architecture"`
@@ -26,6 +27,14 @@ type ConfigV2 struct {
RootFS RootFS `json:"rootfs"`
}
// Draft describes an auxiliary draft model stored in the same manifest.
type Draft struct {
ModelFormat string `json:"model_format,omitempty"`
Architecture string `json:"architecture,omitempty"`
TensorPrefix string `json:"tensor_prefix,omitempty"`
Config string `json:"config,omitempty"`
}
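// Example (trimmed sketch): in the written config blob this marshals to
//
//	"draft": {
//	    "model_format": "safetensors",
//	    "architecture": "Gemma4AssistantForCausalLM",
//	    "tensor_prefix": "draft.",
//	    "config": "draft/config.json"
//	}
//
// with the values matching the defaults written by newManifestWriter later
// in this change.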
// RootFS represents the root filesystem configuration for a model.
type RootFS struct {
Type string `json:"type"`


@@ -23,6 +23,7 @@ import (
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/create"
imagemanifest "github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/safetensors"
)
@@ -34,6 +35,7 @@ type ModelfileConfig struct {
Template string
System string
License string
Draft string
Parser string
Renderer string
Parameters map[string]any
@@ -67,6 +69,8 @@ func ConfigFromModelfile(modelfile *parser.Modelfile) (string, *ModelfileConfig,
mfConfig.System = cmd.Args
case "license":
mfConfig.License = cmd.Args
case "draft":
mfConfig.Draft = cmd.Args
case "parser":
mfConfig.Parser = cmd.Args
case "renderer":
@@ -108,10 +112,12 @@ func ConfigFromModelfile(modelfile *parser.Modelfile) (string, *ModelfileConfig,
// CreateOptions holds all options for model creation.
type CreateOptions struct {
ModelName string
ModelDir string
Quantize string // "int4", "int8", "nvfp4", "mxfp4", or "mxfp8" for quantization
Modelfile *ModelfileConfig // template/system/license/parser/renderer/parameters from Modelfile
ModelName string
ModelDir string
Quantize string // "int4", "int8", "nvfp4", "mxfp4", or "mxfp8" for quantization
DraftQuantize string // optional quantization level for draft model tensors
Modelfile *ModelfileConfig // template/system/license/parser/renderer/parameters from Modelfile
BaseConfig *model.ConfigV2
}
// CreateModel imports a model from a local directory.
@@ -121,11 +127,23 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
// Detect model type
isSafetensors := create.IsSafetensorsModelDir(opts.ModelDir)
isImageGen := create.IsTensorModelDir(opts.ModelDir)
hasDraft := opts.Modelfile != nil && opts.Modelfile.Draft != ""
isBaseModelWithDraft := hasDraft && !isSafetensors && create.IsSafetensorsLLMModel(opts.ModelDir)
if opts.DraftQuantize != "" && !hasDraft {
return fmt.Errorf("--draft-quantize requires a DRAFT model")
}
if !isSafetensors && !isImageGen {
if !isSafetensors && !isImageGen && !isBaseModelWithDraft {
return fmt.Errorf("%s is not a supported model directory (needs config.json + *.safetensors or model_index.json)", opts.ModelDir)
}
if hasDraft && !create.IsSafetensorsModelDir(opts.Modelfile.Draft) {
return fmt.Errorf("draft %s is not a supported safetensors model directory", opts.Modelfile.Draft)
}
if hasDraft && isImageGen {
return fmt.Errorf("draft models are only supported for safetensors LLM models")
}
// Determine model type settings
var modelType, spinnerKey string
var capabilities []string
@@ -138,6 +156,9 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
parserName = getParserName(opts.ModelDir)
rendererName = getRendererName(opts.ModelDir)
capabilities = inferSafetensorsCapabilities(opts.ModelDir, resolveParserName(opts.Modelfile, parserName))
} else if isBaseModelWithDraft {
modelType = "safetensors model"
spinnerKey = "create"
} else {
modelType = "image generation model"
spinnerKey = "imagegen"
@@ -156,13 +177,44 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
p.Add(spinnerKey, spinner)
}
// Create the model using shared callbacks
var draftLayers []create.LayerInfo
var err error
if hasDraft {
draftLayers, err = create.CreateDraftSafetensorsLayers(
opts.Modelfile.Draft,
"draft.",
"draft",
opts.DraftQuantize,
newLayerCreator(),
newTensorLayerCreator(),
progressFn,
)
if err != nil {
spinner.Stop()
return err
}
}
if isBaseModelWithDraft {
err = createModelFromBaseWithDraft(opts, draftLayers, progressFn)
spinner.Stop()
if err != nil {
return err
}
fmt.Printf("Created safetensors model '%s'\n", opts.ModelName)
return nil
}
// Create the model using shared callbacks
if isSafetensors {
writer := newManifestWriter(opts, capabilities, parserName, rendererName)
if len(draftLayers) > 0 {
writer = appendLayersManifestWriter(writer, draftLayers)
}
err = create.CreateSafetensorsModel(
opts.ModelName, opts.ModelDir, opts.Quantize,
newLayerCreator(), newTensorLayerCreator(),
newManifestWriter(opts, capabilities, parserName, rendererName),
writer,
progressFn,
newPackedTensorLayerCreator(),
)
@@ -184,6 +236,68 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
return nil
}
func appendLayersManifestWriter(next create.ManifestWriter, extra []create.LayerInfo) create.ManifestWriter {
return func(modelName string, config create.LayerInfo, layers []create.LayerInfo) error {
layers = append(layers, extra...)
return next(modelName, config, layers)
}
}
func createModelFromBaseWithDraft(opts CreateOptions, draftLayers []create.LayerInfo, progressFn func(string)) error {
progressFn(fmt.Sprintf("loading base model %s", opts.ModelDir))
baseManifest, err := imagemanifest.LoadManifest(opts.ModelDir)
if err != nil {
return err
}
baseConfig, err := readConfigV2(baseManifest)
if err != nil {
return err
}
opts.BaseConfig = baseConfig
configLayer := baseManifest.GetConfigLayer("config.json")
if configLayer == nil {
return fmt.Errorf("base model %s does not contain config.json", opts.ModelDir)
}
layers := make([]create.LayerInfo, 0, len(baseManifest.Manifest.Layers)+len(draftLayers))
for _, layer := range baseManifest.Manifest.Layers {
layers = append(layers, create.LayerInfo{
Digest: layer.Digest,
Size: layer.Size,
MediaType: layer.MediaType,
Name: layer.Name,
})
}
layers = append(layers, draftLayers...)
progressFn(fmt.Sprintf("writing manifest for %s", opts.ModelName))
return newManifestWriter(opts, baseConfig.Capabilities, baseConfig.Parser, baseConfig.Renderer)(
opts.ModelName,
create.LayerInfo{
Digest: configLayer.Digest,
Size: configLayer.Size,
MediaType: configLayer.MediaType,
Name: configLayer.Name,
},
layers,
)
}
func readConfigV2(m *imagemanifest.ModelManifest) (*model.ConfigV2, error) {
data, err := os.ReadFile(m.BlobPath(m.Manifest.Config.Digest))
if err != nil {
return nil, fmt.Errorf("failed to read base config: %w", err)
}
var cfg model.ConfigV2
if err := json.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("failed to parse base config: %w", err)
}
return &cfg, nil
}
func inferSafetensorsCapabilities(modelDir, parserName string) []string {
capabilities := []string{"completion"}
@@ -359,14 +473,26 @@ func newManifestWriter(opts CreateOptions, capabilities []string, parserName, re
}
}
// Create config blob with version requirement
configData := model.ConfigV2{
ModelFormat: "safetensors",
FileType: strings.ToLower(strings.TrimSpace(opts.Quantize)),
Capabilities: caps,
Requires: MinOllamaVersion,
Parser: resolveParserName(opts.Modelfile, parserName),
Renderer: resolveRendererName(opts.Modelfile, rendererName),
// Create config blob with version requirement.
configData := model.ConfigV2{}
if opts.BaseConfig != nil {
configData = *opts.BaseConfig
}
configData.ModelFormat = "safetensors"
if opts.Quantize != "" || configData.FileType == "" {
configData.FileType = strings.ToLower(strings.TrimSpace(opts.Quantize))
}
configData.Capabilities = caps
configData.Requires = MinOllamaVersion
configData.Parser = resolveParserName(opts.Modelfile, parserName)
configData.Renderer = resolveRendererName(opts.Modelfile, rendererName)
if opts.Modelfile != nil && opts.Modelfile.Draft != "" {
configData.Draft = &model.Draft{
ModelFormat: "safetensors",
Architecture: "Gemma4AssistantForCausalLM",
TensorPrefix: "draft.",
Config: "draft/config.json",
}
}
configJSON, err := json.Marshal(configData)
if err != nil {


@@ -44,6 +44,7 @@ func TestModelfileConfig(t *testing.T) {
func TestConfigFromModelfile(t *testing.T) {
modelfile, err := parser.ParseFile(strings.NewReader(`
FROM ./model
DRAFT ./assistant
TEMPLATE {{ .Prompt }}
PARAMETER temperature 0.7
PARAMETER stop USER:
@@ -66,6 +67,10 @@ PARAMETER stop ASSISTANT:
t.Fatalf("Template = %q, want %q", mfConfig.Template, "{{ .Prompt }}")
}
if mfConfig.Draft != "./assistant" {
t.Fatalf("Draft = %q, want %q", mfConfig.Draft, "./assistant")
}
if got := mfConfig.Parameters["temperature"]; got != float32(0.7) {
t.Fatalf("temperature = %#v, want %v", got, float32(0.7))
}
@@ -153,11 +158,23 @@ func TestCreateModel_NotSafetensorsDir(t *testing.T) {
}
}
func TestCreateModel_DraftQuantizeRequiresDraft(t *testing.T) {
err := CreateModel(CreateOptions{
ModelName: "test-model",
ModelDir: t.TempDir(),
DraftQuantize: "mxfp8",
}, nil)
if err == nil || !strings.Contains(err.Error(), "--draft-quantize requires a DRAFT model") {
t.Fatalf("error = %v, want draft-quantize requires DRAFT", err)
}
}
func TestCreateOptions(t *testing.T) {
opts := CreateOptions{
ModelName: "my-model",
ModelDir: "/path/to/model",
Quantize: "fp8",
ModelName: "my-model",
ModelDir: "/path/to/model",
Quantize: "fp8",
DraftQuantize: "mxfp8",
Modelfile: &ModelfileConfig{
Template: "test",
System: "system",
@@ -179,6 +196,9 @@ func TestCreateOptions(t *testing.T) {
if opts.Quantize != "fp8" {
t.Errorf("Quantize = %q, want %q", opts.Quantize, "fp8")
}
if opts.DraftQuantize != "mxfp8" {
t.Errorf("DraftQuantize = %q, want %q", opts.DraftQuantize, "mxfp8")
}
if opts.Modelfile == nil {
t.Error("Modelfile should not be nil")
}
@@ -286,6 +306,9 @@ func TestCreateOptions_Defaults(t *testing.T) {
if opts.Quantize != "" {
t.Errorf("Quantize should be empty by default, got %q", opts.Quantize)
}
if opts.DraftQuantize != "" {
t.Errorf("DraftQuantize should be empty by default, got %q", opts.DraftQuantize)
}
// Modelfile should default to nil
if opts.Modelfile != nil {
@@ -518,6 +541,48 @@ func TestNewManifestWriter_PopulatesFileTypeFromQuantize(t *testing.T) {
}
}
func TestNewManifestWriter_PopulatesDraftMetadata(t *testing.T) {
t.Setenv("OLLAMA_MODELS", t.TempDir())
opts := CreateOptions{
ModelName: "test-draft",
ModelDir: t.TempDir(),
Modelfile: &ModelfileConfig{Draft: "/tmp/assistant"},
}
writer := newManifestWriter(opts, []string{"completion"}, "gemma4", "gemma4")
if err := writer(opts.ModelName, create.LayerInfo{}, nil); err != nil {
t.Fatalf("newManifestWriter() error = %v", err)
}
name := model.ParseName(opts.ModelName)
mf, err := manifest.ParseNamedManifest(name)
if err != nil {
t.Fatalf("ParseNamedManifest() error = %v", err)
}
configPath, err := manifest.BlobsPath(mf.Config.Digest)
if err != nil {
t.Fatalf("BlobsPath() error = %v", err)
}
data, err := os.ReadFile(configPath)
if err != nil {
t.Fatalf("ReadFile() error = %v", err)
}
var cfg model.ConfigV2
if err := json.Unmarshal(data, &cfg); err != nil {
t.Fatalf("Unmarshal() error = %v", err)
}
if cfg.Draft == nil {
t.Fatal("Draft metadata missing")
}
if cfg.Draft.TensorPrefix != "draft." || cfg.Draft.Config != "draft/config.json" {
t.Fatalf("Draft = %#v, want draft prefix/config", cfg.Draft)
}
}
func TestSupportsThinking(t *testing.T) {
tests := []struct {
name string


@@ -7,6 +7,7 @@ import (
"io"
"math"
"os"
"path"
"path/filepath"
"regexp"
"slices"
@@ -724,12 +725,16 @@ func detectModelOptQuantization(modelDir string) bool {
}
func resolveEffectiveQuantization(cfg sourceModelConfig, sourceKind sourceQuantizedKind, requested string) (string, error) {
return resolveEffectiveQuantizationForFlag(cfg, sourceKind, requested, "--quantize")
}
func resolveEffectiveQuantizationForFlag(cfg sourceModelConfig, sourceKind sourceQuantizedKind, requested, flagName string) (string, error) {
switch sourceKind {
case sourceQuantizedKindNone:
return requested, nil
case sourceQuantizedKindPrequantized:
if requested != "" {
return "", fmt.Errorf("cannot requantize already-quantized source model with --quantize %q", requested)
return "", fmt.Errorf("cannot requantize already-quantized source model with %s %q", flagName, requested)
}
return "", nil
case sourceQuantizedKindSourceFP8:
@@ -746,7 +751,7 @@ func resolveEffectiveQuantization(cfg sourceModelConfig, sourceKind sourceQuanti
case "nvfp4", "mxfp4", "mxfp8":
return requested, nil
default:
return "", fmt.Errorf("cannot convert already-quantized fp8 source model with --quantize %q", requested)
return "", fmt.Errorf("cannot convert already-quantized fp8 source model with %s %q", flagName, requested)
}
}
return "mxfp8", nil
@@ -810,6 +815,7 @@ var tensorImportTransformRegistry = map[string]tensorImportTransformFactory{
"Gemma4ForCausalLM": newGemma4ImportTransform,
"Gemma4ForConditionalGeneration": newGemma4ImportTransform,
"LagunaForCausalLM": newLagunaImportTransform,
"Gemma4AssistantForCausalLM": newGemma4ImportTransform,
}
func newTensorImportTransform(modelDir string, cfg sourceModelConfig) (tensorImportTransform, error) {
@@ -1167,6 +1173,136 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
return nil
}
func normalizeRequestedQuantization(flagName, quantize string) (string, error) {
q := normalizeQuantType(strings.TrimSpace(quantize))
switch q {
case "", "int4", "int8", "nvfp4", "mxfp4", "mxfp8":
return q, nil
default:
return "", fmt.Errorf("unsupported %s %q: supported types are int4, int8, nvfp4, mxfp4, mxfp8", flagName, quantize)
}
}
// CreateDraftSafetensorsLayers imports an assistant/draft safetensors model
// into prefixed tensor and config layers. When draftQuantize is non-empty,
// eligible draft tensors are quantized with the same per-architecture policy
// used by target safetensors imports.
func CreateDraftSafetensorsLayers(modelDir, tensorPrefix, configPrefix, draftQuantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, fn func(status string)) ([]LayerInfo, error) {
if tensorPrefix == "" {
return nil, fmt.Errorf("draft tensor prefix must not be empty")
}
if configPrefix == "" {
return nil, fmt.Errorf("draft config prefix must not be empty")
}
effectiveQuantize, err := normalizeRequestedQuantization("--draft-quantize", draftQuantize)
if err != nil {
return nil, err
}
var importTransform tensorImportTransform = noopImportTransform{}
if effectiveQuantize != "" {
sourceConfig, err := readSourceModelConfig(modelDir)
if err != nil {
return nil, fmt.Errorf("failed to read draft config.json: %w", err)
}
sourceQuantKind, err := inspectSourceQuantization(modelDir, sourceConfig)
if err != nil {
return nil, fmt.Errorf("failed to inspect draft quantization: %w", err)
}
effectiveQuantize, err = resolveEffectiveQuantizationForFlag(sourceConfig, sourceQuantKind, effectiveQuantize, "--draft-quantize")
if err != nil {
return nil, err
}
importTransform, err = newTensorImportTransform(modelDir, sourceConfig)
if err != nil {
return nil, fmt.Errorf("failed to construct draft import transform for architecture %q: %w", sourceConfig.Architecture(), err)
}
}
entries, err := os.ReadDir(modelDir)
if err != nil {
return nil, fmt.Errorf("failed to read draft directory: %w", err)
}
var layers []LayerInfo
for _, entry := range entries {
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".safetensors") {
continue
}
stPath := filepath.Join(modelDir, entry.Name())
extractor, err := safetensors.OpenForExtraction(stPath)
if err != nil {
return nil, fmt.Errorf("failed to open draft %s: %w", stPath, err)
}
tensorNames := extractor.ListTensors()
fn(fmt.Sprintf("importing draft %s (%d tensors%s)", entry.Name(), len(tensorNames), importQuantizationStatus(sourceQuantizedKindNone, effectiveQuantize)))
for _, tensorName := range tensorNames {
if importTransform.skipTensor(tensorName) {
continue
}
td, err := extractor.GetTensor(tensorName)
if err != nil {
extractor.Close()
return nil, fmt.Errorf("failed to get draft tensor %s: %w", tensorName, err)
}
outTDs, err := importTransform.transformTensor(td)
if err != nil {
extractor.Close()
return nil, fmt.Errorf("failed to transform draft tensor %s: %w", tensorName, err)
}
for _, transformedTD := range outTDs {
if transformedTD == nil {
continue
}
outTD := transformedTD.WithName(tensorPrefix + transformedTD.Name)
quantizeType := ""
if effectiveQuantize != "" {
quantizeType = importTransform.quantizationType(outTD.Name, outTD.Shape, effectiveQuantize)
if isEmbedTokensWeight(outTD.Name) {
quantizeType = ""
}
}
newLayers, err := createTensorLayer(outTD.SafetensorsReader(), outTD.Name, outTD.Dtype, outTD.Shape, quantizeType)
if err != nil {
extractor.Close()
return nil, fmt.Errorf("failed to create draft layer for %s: %w", tensorName, err)
}
layers = append(layers, newLayers...)
}
}
extractor.Close()
}
for _, entry := range entries {
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".json") {
continue
}
if entry.Name() == "model.safetensors.index.json" {
continue
}
cfgPath := entry.Name()
fullPath := filepath.Join(modelDir, cfgPath)
fn(fmt.Sprintf("importing draft config %s", cfgPath))
f, err := os.Open(fullPath)
if err != nil {
return nil, fmt.Errorf("failed to open draft %s: %w", cfgPath, err)
}
layer, err := createLayer(f, "application/vnd.ollama.image.json", path.Join(configPrefix, cfgPath))
f.Close()
if err != nil {
return nil, fmt.Errorf("failed to create draft config layer for %s: %w", cfgPath, err)
}
layers = append(layers, layer)
}
return layers, nil
}
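// Net effect on naming, as exercised by the tests below: tensor layers gain
// the tensor prefix and config layers land under the config prefix, e.g.
//
//	model.layers.0.self_attn.q_proj.weight -> draft.model.layers.0.self_attn.q_proj.weight
//	config.json                            -> draft/config.json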
func shouldSkipSourceCompanion(name string, tensorSet map[string]struct{}, sourceTensorFiles map[string]string) bool {
switch {
case strings.HasSuffix(name, ".scales"):


@@ -452,6 +452,137 @@ func TestCreateSafetensorsModel(t *testing.T) {
}
}
func TestCreateDraftSafetensorsLayersPrefixesTensorsAndConfigs(t *testing.T) {
dir := t.TempDir()
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type":"gemma4_assistant"}`), 0o644); err != nil {
t.Fatal(err)
}
createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{
st.NewTensorDataFromBytes("model.layers.0.self_attn.q_proj.weight", "BF16", []int32{2, 2}, make([]byte, 8)),
})
var tensorNames []string
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
data, err := io.ReadAll(r)
if err != nil {
return LayerInfo{}, err
}
return LayerInfo{Digest: "sha256:json_" + name, Size: int64(len(data)), MediaType: mediaType, Name: name}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
data, err := io.ReadAll(r)
if err != nil {
return nil, err
}
tensorNames = append(tensorNames, name)
tensorName, tensorShape := readSingleTensorNameAndShape(t, data)
if tensorName != name {
t.Fatalf("safetensors key = %q, want %q", tensorName, name)
}
if !slices.Equal(tensorShape, shape) {
t.Fatalf("shape = %v, want %v", tensorShape, shape)
}
if quantize != "" {
t.Fatalf("draft quantize = %q, want empty", quantize)
}
return []LayerInfo{{Digest: "sha256:tensor_" + name, Size: int64(len(data)), MediaType: "application/vnd.ollama.image.tensor", Name: name}}, nil
}
layers, err := CreateDraftSafetensorsLayers(dir, "draft.", "draft", "", createLayer, createTensorLayer, func(string) {})
if err != nil {
t.Fatal(err)
}
if !slices.Contains(tensorNames, "draft.model.layers.0.self_attn.q_proj.weight") {
t.Fatalf("draft tensor was not prefixed: %v", tensorNames)
}
var hasDraftConfig bool
for _, layer := range layers {
if layer.Name == "draft/config.json" && layer.MediaType == "application/vnd.ollama.image.json" {
hasDraftConfig = true
}
}
if !hasDraftConfig {
t.Fatalf("draft/config.json layer missing: %#v", layers)
}
}
func TestCreateDraftSafetensorsLayersQuantizesEligibleTensors(t *testing.T) {
dir := t.TempDir()
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{
"architectures":["Gemma4AssistantForCausalLM"],
"num_hidden_layers":8
}`), 0o644); err != nil {
t.Fatal(err)
}
createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{
st.NewTensorDataFromBytes("model.embed_tokens.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
st.NewTensorDataFromBytes("model.layers.0.self_attn.q_proj.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
st.NewTensorDataFromBytes("model.layers.0.input_layernorm.weight", "BF16", []int32{64}, make([]byte, 64*2)),
})
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
data, err := io.ReadAll(r)
if err != nil {
return LayerInfo{}, err
}
return LayerInfo{Digest: "sha256:json_" + name, Size: int64(len(data)), MediaType: mediaType, Name: name}, nil
}
quantizeByName := make(map[string]string)
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
data, err := io.ReadAll(r)
if err != nil {
return nil, err
}
quantizeByName[name] = quantize
return []LayerInfo{{Digest: "sha256:tensor_" + name, Size: int64(len(data)), MediaType: "application/vnd.ollama.image.tensor", Name: name}}, nil
}
if _, err := CreateDraftSafetensorsLayers(dir, "draft.", "draft", "MXFP8", createLayer, createTensorLayer, func(string) {}); err != nil {
t.Fatal(err)
}
if got := quantizeByName["draft.model.layers.0.self_attn.q_proj.weight"]; got != "mxfp8" {
t.Fatalf("q_proj draft quantize = %q, want mxfp8", got)
}
if got := quantizeByName["draft.model.layers.0.input_layernorm.weight"]; got != "" {
t.Fatalf("norm draft quantize = %q, want empty", got)
}
if got := quantizeByName["draft.model.embed_tokens.weight"]; got != "" {
t.Fatalf("embed_tokens draft quantize = %q, want empty", got)
}
}
func TestCreateDraftSafetensorsLayersRejectsUnsupportedDraftQuantize(t *testing.T) {
_, err := CreateDraftSafetensorsLayers(t.TempDir(), "draft.", "draft", "bogus", nil, nil, func(string) {})
if err == nil || !strings.Contains(err.Error(), "unsupported --draft-quantize") {
t.Fatalf("error = %v, want unsupported --draft-quantize", err)
}
}
func readSingleTensorNameAndShape(t *testing.T, data []byte) (string, []int32) {
t.Helper()
var headerSize uint64
if err := binary.Read(bytes.NewReader(data[:8]), binary.LittleEndian, &headerSize); err != nil {
t.Fatalf("failed to read header size: %v", err)
}
var header map[string]struct {
Shape []int32 `json:"shape"`
}
if err := json.Unmarshal(data[8:8+headerSize], &header); err != nil {
t.Fatalf("failed to parse header: %v", err)
}
for name, info := range header {
if name != "__metadata__" {
return name, info.Shape
}
}
t.Fatal("no tensor entry found in header")
return "", nil
}
func TestCreateSafetensorsModel_NoConfigJson(t *testing.T) {
dir := t.TempDir()


@@ -54,6 +54,81 @@ type Attention interface {
Update(b *batch.Batch, keys, values *mlx.Array) *nn.KVHistory
}
// Viewer exposes a read-only attention history for a cache.
type Viewer interface {
View(b *batch.Batch) *nn.KVHistory
}
type speculativeCommitter interface {
Cache
commit(n int)
}
// Speculation is an isolated cache transaction for speculative target
// validation. Updates record generated K/V without mutating the live caches;
// Commit appends only the accepted prefix to the live caches.
type Speculation struct {
layers []speculativeCommitter
}
// BeginSpeculation returns cache wrappers suitable for a speculative target
// forward. The returned caches must only be used for that forward.
func BeginSpeculation(caches []Cache) ([]Cache, *Speculation, bool) {
specCaches := make([]Cache, len(caches))
layers := make([]speculativeCommitter, len(caches))
for i, c := range caches {
switch c := c.(type) {
case nil:
case *RotatingKVCache:
sc := newSpeculativeRotatingKVCache(c)
specCaches[i] = sc
layers[i] = sc
case *KVCache:
sc := newSpeculativeKVCache(c)
specCaches[i] = sc
layers[i] = sc
default:
return nil, nil, false
}
}
return specCaches, &Speculation{layers: layers}, true
}
// BeginIsolatedSpeculation returns cache wrappers that never mutate live cache
// state. It is intended for correctness instrumentation, not the hot path.
func BeginIsolatedSpeculation(caches []Cache) ([]Cache, bool) {
specCaches := make([]Cache, len(caches))
for i, c := range caches {
switch c := c.(type) {
case nil:
case *RotatingKVCache:
specCaches[i] = newSpeculativeRotatingKVCache(c)
case *KVCache:
specCaches[i] = newIsolatedKVCache(c)
default:
return nil, false
}
}
return specCaches, true
}
// Commit appends the accepted prefix from the speculative forward to the live
// caches. The target bonus token is intentionally not committed.
func (s *Speculation) Commit(n int) {
if s == nil {
return
}
for _, layer := range s.layers {
if layer != nil {
layer.commit(n)
}
}
}
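// A minimal sketch of the intended call pattern (hypothetical caller; the
// forward call, batch, and greedy-token helper stand in for the runner's
// own plumbing in mtp.go):
//
//	specCaches, spec, ok := BeginSpeculation(caches)
//	if !ok { /* fall back to plain decode */ }
//	hidden := target.Forward(specBatch, specCaches) // K/V recorded via wrappers only
//	targetTokens := greedyTokens(hidden)            // hypothetical argmax helper
//	accepted := 0
//	for i, tok := range draftTokens {
//		if targetTokens[i] != tok {
//			break
//		}
//		accepted++
//	}
//	spec.Commit(accepted) // append only the accepted prefix to the live caches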
type KVCache struct {
keys, values *mlx.Array
offset int
@@ -70,6 +145,15 @@ func (c *KVCache) Update(_ *batch.Batch, keys, values *mlx.Array) *nn.KVHistory
return nn.NewKVHistory(newK, newV, nil)
}
// View returns the current cache contents as attention history without writing.
func (c *KVCache) View(_ *batch.Batch) *nn.KVHistory {
state := c.State()
if len(state) < 2 {
return nil
}
return nn.NewKVHistory(state[0], state[1], nil)
}
// appendKV is the raw write path shared by Update and Restore.
func (c *KVCache) appendKV(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
B, H, L, Dk, Dv := keys.Dim(0), keys.Dim(1), keys.Dim(2), keys.Dim(3), values.Dim(3)
@@ -250,6 +334,94 @@ func (c *KVCache) Free() {
func (c *KVCache) Offset() int { return c.offset }
type speculativeBase struct {
offset int
}
func (s *speculativeBase) Free() {}
func (s *speculativeBase) Offset() int { return s.offset }
func (s *speculativeBase) Snapshot(int) Snapshot { return nil }
func (s *speculativeBase) Restore(Snapshot, int) bool { return false }
func (s *speculativeBase) Merge(parent, child Snapshot) Snapshot { return nil }
func (s *speculativeBase) Split(snapshot Snapshot, at int) (Snapshot, Snapshot) {
return nil, snapshot
}
type speculativeKVCache struct {
speculativeBase
target *KVCache
start int
end int
}
func newSpeculativeKVCache(target *KVCache) *speculativeKVCache {
return &speculativeKVCache{
speculativeBase: speculativeBase{offset: target.Offset()},
target: target,
start: target.Offset(),
end: target.Offset(),
}
}
func (c *speculativeKVCache) Update(b *batch.Batch, keys, values *mlx.Array) *nn.KVHistory {
history := c.target.Update(b, keys, values)
c.offset = c.target.Offset()
c.end = c.target.Offset()
return history
}
func (c *speculativeKVCache) State() []*mlx.Array {
return c.target.State()
}
func (c *speculativeKVCache) commit(n int) {
target := max(c.start, c.start+n)
if target > c.end {
target = c.end
}
c.target.offset = target
c.offset = target
}
type isolatedKVCache struct {
speculativeBase
target *KVCache
keys, values *mlx.Array
}
func newIsolatedKVCache(target *KVCache) *isolatedKVCache {
return &isolatedKVCache{
speculativeBase: speculativeBase{offset: target.Offset()},
target: target,
}
}
func (c *isolatedKVCache) Update(_ *batch.Batch, keys, values *mlx.Array) *nn.KVHistory {
c.keys = concatKV(c.keys, keys)
c.values = concatKV(c.values, values)
c.offset += keys.Dim(2)
state := c.target.State()
if len(state) < 2 {
return nn.NewKVHistory(c.keys, c.values, nil)
}
return nn.NewKVHistory(state[0].Concatenate(2, c.keys), state[1].Concatenate(2, c.values), nil)
}
func (c *isolatedKVCache) State() []*mlx.Array {
if c.keys == nil || c.values == nil {
return c.target.State()
}
state := c.target.State()
if len(state) < 2 {
return []*mlx.Array{c.keys, c.values}
}
return []*mlx.Array{
state[0].Concatenate(2, c.keys),
state[1].Concatenate(2, c.values),
}
}
// RotatingKVCache implements sliding window attention with bounded memory
type RotatingKVCache struct {
maxSize int
@@ -275,6 +447,127 @@ func (c *RotatingKVCache) Update(b *batch.Batch, keys, values *mlx.Array) *nn.KV
})
}
// View returns the current rotating cache contents in logical order for
// assistant KV sharing.
func (c *RotatingKVCache) View(_ *batch.Batch) *nn.KVHistory {
k, v := c.logicalTail(c.maxSize - 1)
if k == nil || v == nil {
return nil
}
return nn.NewKVHistory(k, v, nil)
}
func (c *RotatingKVCache) logicalTail(keep int) (*mlx.Array, *mlx.Array) {
state := c.State()
if len(state) < 2 || keep <= 0 {
return nil, nil
}
keys, values := state[0], state[1]
K := keys.Dim(2)
if K == 0 {
return nil, nil
}
keep = min(keep, K)
if K > c.maxSize || c.offset < c.maxSize {
start := K - keep
return keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(start, K), mlx.Slice()),
values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(start, K), mlx.Slice())
}
oldest := c.idx % K
var logicalK, logicalV *mlx.Array
if oldest == 0 {
logicalK, logicalV = keys, values
} else {
tailK := keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(oldest, K), mlx.Slice())
tailV := values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(oldest, K), mlx.Slice())
headK := keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, oldest), mlx.Slice())
headV := values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, oldest), mlx.Slice())
logicalK = tailK.Concatenate(2, headK)
logicalV = tailV.Concatenate(2, headV)
}
start := K - keep
return logicalK.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(start, K), mlx.Slice()),
logicalV.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(start, K), mlx.Slice())
}
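// Worked example (illustrative): with maxSize = 4 after tokens t1..t6, the
// ring holds physical order [t5 t6 t3 t4] and idx % K == 2, so the tail
// [t3 t4] is concatenated before the head [t5 t6] to give logical order
// [t3 t4 t5 t6], from which the last `keep` entries are returned.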
type speculativeRotatingKVCache struct {
speculativeBase
target *RotatingKVCache
keys, values *mlx.Array
}
func newSpeculativeRotatingKVCache(target *RotatingKVCache) *speculativeRotatingKVCache {
return &speculativeRotatingKVCache{
speculativeBase: speculativeBase{offset: target.Offset()},
target: target,
}
}
func (c *speculativeRotatingKVCache) Update(b *batch.Batch, keys, values *mlx.Array) *nn.KVHistory {
c.keys = concatKV(c.keys, keys)
c.values = concatKV(c.values, values)
c.offset += keys.Dim(2)
oldK, oldV := c.target.logicalTail(c.target.maxSize - 1)
histK, histV := c.keys, c.values
if oldK != nil && oldV != nil {
histK = oldK.Concatenate(2, c.keys)
histV = oldV.Concatenate(2, c.values)
}
return nn.NewKVHistory(histK, histV, logicalSlidingApplier{
b: b,
K: histK.Dim(2),
window: c.target.maxSize,
dtype: keys.DType(),
})
}
func (c *speculativeRotatingKVCache) State() []*mlx.Array {
if c.keys == nil || c.values == nil {
return c.target.State()
}
oldK, oldV := c.target.logicalTail(c.target.maxSize - 1)
if oldK == nil || oldV == nil {
return []*mlx.Array{c.keys, c.values}
}
return []*mlx.Array{oldK.Concatenate(2, c.keys), oldV.Concatenate(2, c.values)}
}
func (c *speculativeRotatingKVCache) commit(n int) {
if c.keys == nil || c.values == nil || n <= 0 {
return
}
n = min(n, c.keys.Dim(2))
c.target.appendKV(prefixKV(c.keys, n), prefixKV(c.values, n))
}
type logicalSlidingApplier struct {
b *batch.Batch
K int
window int
dtype mlx.DType
}
func (a logicalSlidingApplier) ApplyMask(logical nn.AttentionMask) nn.AttentionMask {
return logical.Intersect(nn.SlidingWindowMask(a.b, a.K, a.window, a.dtype))
}
func concatKV(prev, next *mlx.Array) *mlx.Array {
if prev == nil {
return next
}
return prev.Concatenate(2, next)
}
func prefixKV(a *mlx.Array, n int) *mlx.Array {
return a.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, n), mlx.Slice())
}
// appendKV is the raw write path shared by Update and Restore —
// routes to concat for prefill (L > 1) and update for decode.
func (c *RotatingKVCache) appendKV(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {


@@ -114,6 +114,61 @@ func TestRotatingKVCacheDecodeParity(t *testing.T) {
}
}
func TestAssistantSharedHistoryL1MasksMatchNoMask(t *testing.T) {
skipIfNoMLX(t)
if !mlx.MetalIsAvailable() {
t.Skip("MLX Metal not available")
}
const H, D = 1, 4
const window = 4
const total = 7
const scale = 1.0
q := mlx.FromValues([]float32{0.7, -0.4, 0.2, 0.9}, 1, H, 1, D)
mlx.Eval(q)
full := NewKVCache()
sliding := NewRotatingKVCache(window)
for pos := range total {
kVals := make([]float32, H*D)
vVals := make([]float32, H*D)
for i := range kVals {
kVals[i] = 0.1*float32(pos+1) + 0.01*float32(i)
vVals[i] = -0.1*float32(pos+1) + 0.01*float32(i)
}
k := mlx.FromValues(kVals, 1, H, 1, D)
v := mlx.FromValues(vVals, 1, H, 1, D)
full.Update(newKVBatch(full.Offset(), 1), k, v)
sliding.Update(newKVBatch(sliding.Offset(), 1), k, v)
}
b := newKVBatch(total-1, 1)
slidingHistory := sliding.View(b)
cases := []struct {
name string
h *nn.KVHistory
mask nn.AttentionMask
}{
{name: "full", h: full.View(b), mask: nn.CausalMask()},
{name: "sliding", h: slidingHistory, mask: nn.CausalMask().Intersect(nn.SlidingWindowMask(b, slidingHistory.K().Dim(2), window, q.DType()))},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
got := nn.ScaledDotProductAttention(b, q, scale, nn.WithKVHistory(tc.h), nn.WithMask(tc.mask))
want := mlx.FastScaledDotProductAttention(q, tc.h.K(), tc.h.V(), scale, "", nil)
mlx.Eval(got, want)
gs, ws := got.Floats(), want.Floats()
for i := range ws {
if math.Abs(float64(gs[i]-ws[i])) > 1e-5 {
t.Fatalf("index %d: got %v, want %v", i, gs[i], ws[i])
}
}
})
}
}
// TestRotatingKVCachePrefillParity drives an L>1 prefill into a
// rotating cache and verifies SDPA output through WithKVHistory
// matches a reference computed from the same K/V with the model mask


@@ -41,3 +41,36 @@ func TestFromValues(t *testing.T) {
}
})
}
func TestComparisonOpsAndBernoulli(t *testing.T) {
skipIfNoMLX(t)
a := FromValues([]float32{1, 2, 3}, 3)
b := FromValues([]float32{1, 1, 4}, 3)
eq := a.Equal(b).AsType(DTypeInt32)
gt := a.Greater(b).AsType(DTypeInt32)
le := a.LessEqual(b).AsType(DTypeInt32)
bern := Bernoulli(FromValues([]float32{1, 0}, 2)).AsType(DTypeInt32)
Eval(eq, gt, le, bern)
for name, tc := range map[string]struct {
got []int
want []int
}{
"equal": {eq.Ints(), []int{1, 0, 0}},
"greater": {gt.Ints(), []int{0, 1, 0}},
"lessEqual": {le.Ints(), []int{1, 0, 1}},
"bernoulli": {bern.Ints(), []int{1, 0}},
} {
t.Run(name, func(t *testing.T) {
if len(tc.got) != len(tc.want) {
t.Fatalf("got %v, want %v", tc.got, tc.want)
}
for i := range tc.want {
if tc.got[i] != tc.want[i] {
t.Fatalf("got %v, want %v", tc.got, tc.want)
}
}
})
}
}


@@ -137,12 +137,30 @@ func (t *Array) LogsumexpAxis(axis int, keepDims bool) *Array {
return out
}
func (t *Array) Equal(other *Array) *Array {
out := New("EQUAL")
C.mlx_equal(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Greater(other *Array) *Array {
out := New("GREATER")
C.mlx_greater(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Less(other *Array) *Array {
out := New("LESS")
C.mlx_less(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) LessEqual(other *Array) *Array {
out := New("LESS_EQUAL")
C.mlx_less_equal(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) MaxAxis(axis int, keepDims bool) *Array {
out := New("MAX_AXIS")
C.mlx_max_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)


@@ -3,9 +3,24 @@ package mlx
// #include "generated.h"
import "C"
import "unsafe"
func (t *Array) Categorical(axis int) *Array {
key := New("")
out := New("")
C.mlx_random_categorical(&out.ctx, t.ctx, C.int(axis), key.ctx, DefaultStream().ctx)
return out
}
func Bernoulli(p *Array) *Array {
dims := p.Dims()
shape := make([]C.int, len(dims))
for i, d := range dims {
shape[i] = C.int(d)
}
key := New("")
out := New("BERNOULLI")
C.mlx_random_bernoulli(&out.ctx, p.ctx, unsafe.SliceData(shape), C.size_t(len(shape)), key.ctx, DefaultStream().ctx)
return out
}


@@ -27,9 +27,40 @@ type Model interface {
LoadWeights(tensors map[string]*mlx.Array) error
}
// DraftModel is an auxiliary model stored alongside a target model.
type DraftModel interface {
LoadWeights(tensors map[string]*mlx.Array) error
}
// MTPDefaults holds model-provided draft-token defaults for speculative
// decoding. Environment settings in the runner may override these values.
type MTPDefaults struct {
InitialDraftTokens int
MaxDraftTokens int
Enabled bool
}
// MTPDefaultsProvider lets a model provide MTP policy defaults from its own
// config without teaching the runner model-specific shape heuristics.
type MTPDefaultsProvider interface {
MTPDraftDefaults(sample bool) MTPDefaults
}
// MTPDraftModel is a draft model capable of Gemma-style multi-token
// prediction from target token embeddings, target hidden states, and target KV.
type MTPDraftModel interface {
Draft(inputEmbeds *mlx.Array, position int32, caches []cache.Cache) (logits, hidden *mlx.Array)
}
// MTPEmbeddingModel exposes the target token embedding path used by MTP drafts.
type MTPEmbeddingModel interface {
TokenEmbeddings(inputIDs *mlx.Array) *mlx.Array
}
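// A sketch of how a model could satisfy MTPDefaultsProvider; the numbers
// mirror the runner-side fallbacks in mtp.go, and the body is illustrative
// rather than the gemma4 implementation:
//
//	func (m *Model) MTPDraftDefaults(sample bool) MTPDefaults {
//		return MTPDefaults{
//			InitialDraftTokens: 4,  // runner default
//			MaxDraftTokens:     16, // runner default
//			Enabled:            true,
//		}
//	}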
var (
mu sync.Mutex
registry = make(map[string]func(root *model.Root) (Model, error))
mu sync.Mutex
registry = make(map[string]func(root *model.Root) (Model, error))
draftRegistry = make(map[string]func(root *model.Root, target Model) (DraftModel, error))
)
// Register registers a model constructor by architecture name.
@@ -44,6 +75,17 @@ func Register(arch string, fn func(root *model.Root) (Model, error)) {
registry[arch] = fn
}
// RegisterDraft registers a draft model constructor by architecture name.
func RegisterDraft(arch string, fn func(root *model.Root, target Model) (DraftModel, error)) {
mu.Lock()
defer mu.Unlock()
if _, exists := draftRegistry[arch]; exists {
panic(fmt.Sprintf("draft model architecture %q already registered", arch))
}
draftRegistry[arch] = fn
}
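// Mirroring Register, a draft architecture would be wired up from a model
// package's init; newGemma4Draft is a hypothetical constructor name:
//
//	func init() {
//		base.RegisterDraft("Gemma4AssistantForCausalLM",
//			func(root *model.Root, target base.Model) (base.DraftModel, error) {
//				return newGemma4Draft(root, target)
//			})
//	}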
// New reads config.json from the manifest, detects the architecture, looks up
// the registered constructor, and calls it to create the model (with config
// parsed and struct created, but weights not yet loaded).
@@ -78,6 +120,51 @@ func New(root *model.Root) (Model, error) {
return fn(root)
}
// NewDraft constructs the draft model described by the manifest config, if any.
func NewDraft(root *model.Root, target Model) (DraftModel, error) {
if root == nil || root.Draft == nil {
return nil, nil
}
configPath := root.Draft.Config
if configPath == "" {
configPath = "draft/config.json"
}
configData, err := root.Manifest.ReadConfig(configPath)
if err != nil {
return nil, fmt.Errorf("failed to read %s: %w", configPath, err)
}
var archConfig struct {
Architectures []string `json:"architectures"`
ModelType string `json:"model_type"`
}
if err := json.Unmarshal(configData, &archConfig); err != nil {
return nil, fmt.Errorf("failed to parse %s: %w", configPath, err)
}
arch := root.Draft.Architecture
if arch == "" && len(archConfig.Architectures) > 0 {
arch = archConfig.Architectures[0]
}
if arch == "" {
arch = archConfig.ModelType
}
if arch == "" {
return nil, fmt.Errorf("no draft architecture found in %s", configPath)
}
slog.Info("Draft model architecture", "arch", arch)
mu.Lock()
fn, ok := draftRegistry[arch]
mu.Unlock()
if !ok {
return nil, fmt.Errorf("unsupported draft architecture: %s", arch)
}
return fn(root, target)
}
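// Resolution order above: the manifest's Draft.Architecture wins, then the
// config's architectures list, then model_type. For illustration, a
// draft/config.json assembled from this change's tests would resolve to
// Gemma4AssistantForCausalLM:
//
//	{
//	  "architectures": ["Gemma4AssistantForCausalLM"],
//	  "model_type": "gemma4_assistant",
//	  "num_hidden_layers": 8
//	}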
// Weights returns a function that loads model weights, then pins all
// arrays reachable from the model struct and sweeps everything else.
func Weights(m Model) func(map[string]*mlx.Array) error {


@@ -10,6 +10,7 @@ import (
"strconv"
"strings"
modeltypes "github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/imagegen/manifest"
)
@@ -22,6 +23,7 @@ type TensorQuantInfo struct {
// Root wraps a ModelManifest with pre-scanned quantization metadata.
type Root struct {
Manifest *manifest.ModelManifest
Draft *modeltypes.Draft
// Backwards-compatible model-level quant metadata (first tensor blob).
quantType string
@@ -43,6 +45,7 @@ func Open(modelName string) (*Root, error) {
Manifest: m,
tensorQuant: make(map[string]*TensorQuantInfo),
}
root.Draft = readDraftConfig(m)
for _, layer := range m.GetTensorLayers("") {
blobPath := m.BlobPath(layer.Digest)
@@ -68,6 +71,34 @@ func Open(modelName string) (*Root, error) {
return root, nil
}
func readDraftConfig(m *manifest.ModelManifest) *modeltypes.Draft {
if m == nil || m.Manifest == nil || m.Manifest.Config.Digest == "" {
return nil
}
data, err := os.ReadFile(m.BlobPath(m.Manifest.Config.Digest))
if err != nil {
return nil
}
var cfg modeltypes.ConfigV2
if err := json.Unmarshal(data, &cfg); err != nil {
return nil
}
if cfg.Draft != nil {
return cfg.Draft
}
if m.GetConfigLayer("draft/config.json") != nil {
return &modeltypes.Draft{
ModelFormat: "safetensors",
TensorPrefix: "draft.",
Config: "draft/config.json",
}
}
return nil
}
// Close is a no-op for now (future: release resources).
func (r *Root) Close() {}

x/mlxrunner/mtp.go (new file, 959 lines)

@@ -0,0 +1,959 @@
package mlxrunner
import (
"context"
"fmt"
"log/slog"
"math"
"os"
"strconv"
"strings"
"time"
"github.com/ollama/ollama/x/mlxrunner/batch"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model/base"
sampler "github.com/ollama/ollama/x/mlxrunner/sample"
)
const (
mtpDefaultInitialDraftTokens = 4
mtpDefaultMaxDraftTokens = 16
)
type mtpDraftSchedule string
const (
mtpDraftScheduleHeuristic mtpDraftSchedule = "heuristic"
mtpDraftScheduleConstant mtpDraftSchedule = "constant"
)
type mtpStats struct {
iterations int
drafted int
accepted int
mismatches int
allAccepted int
batched int
serial int
compared int
batchSerialMismatches int
maxDraft int
targetDuration time.Duration
draftDuration time.Duration
validateDuration time.Duration
}
type mtpOptions struct {
initialDraftTokens int
maxDraftTokens int
draftSchedule mtpDraftSchedule
serialValidate bool
compareSerialValidate bool
}
func (r *Runner) mtpDefaults(sample bool) base.MTPDefaults {
defaults := base.MTPDefaults{
InitialDraftTokens: mtpDefaultInitialDraftTokens,
MaxDraftTokens: mtpDefaultMaxDraftTokens,
Enabled: true,
}
if p, ok := r.Model.(base.MTPDefaultsProvider); ok {
defaults = p.MTPDraftDefaults(sample)
}
if defaults.InitialDraftTokens <= 0 {
defaults.InitialDraftTokens = mtpDefaultInitialDraftTokens
}
if defaults.MaxDraftTokens <= 0 {
defaults.MaxDraftTokens = mtpDefaultMaxDraftTokens
}
return defaults
}
func (r *Runner) loadMTPOptions(sample bool) mtpOptions {
defaults := r.mtpDefaults(sample)
opts := mtpOptions{
initialDraftTokens: defaults.InitialDraftTokens,
maxDraftTokens: defaults.MaxDraftTokens,
draftSchedule: mtpDraftScheduleConstant,
}
if v := positiveEnvInt("OLLAMA_MLX_MTP_MAX_DRAFT_TOKENS"); v > 0 {
opts.maxDraftTokens = v
}
if v := positiveEnvInt("OLLAMA_MLX_MTP_INITIAL_DRAFT_TOKENS"); v > 0 {
opts.initialDraftTokens = v
}
if opts.initialDraftTokens > opts.maxDraftTokens {
opts.initialDraftTokens = opts.maxDraftTokens
}
if b, err := strconv.ParseBool(os.Getenv("OLLAMA_MLX_MTP_SERIAL_VALIDATE")); err == nil {
opts.serialValidate = b
}
if b, err := strconv.ParseBool(os.Getenv("OLLAMA_MLX_MTP_COMPARE_SERIAL_VALIDATE")); err == nil {
opts.compareSerialValidate = b
}
switch schedule := strings.ToLower(strings.TrimSpace(os.Getenv("OLLAMA_MLX_MTP_DRAFT_SCHEDULE"))); schedule {
case "", string(mtpDraftScheduleConstant):
opts.draftSchedule = mtpDraftScheduleConstant
case string(mtpDraftScheduleHeuristic):
opts.draftSchedule = mtpDraftScheduleHeuristic
default:
slog.Warn("invalid MTP env setting", "key", "OLLAMA_MLX_MTP_DRAFT_SCHEDULE", "value", schedule)
}
return opts
}
func positiveEnvInt(key string) int {
raw := os.Getenv(key)
if raw == "" {
return 0
}
v, err := strconv.Atoi(raw)
if err != nil || v <= 0 {
slog.Warn("invalid MTP env setting", "key", key, "value", raw)
return 0
}
return v
}
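// A minimal sketch of how the env knobs compose with model defaults (the
// values below are illustrative, not recommendations): positive integer
// overrides win, and the initial budget is clamped to the max.
//
//	os.Setenv("OLLAMA_MLX_MTP_MAX_DRAFT_TOKENS", "8")
//	os.Setenv("OLLAMA_MLX_MTP_INITIAL_DRAFT_TOKENS", "14")
//	opts := r.loadMTPOptions(false)
//	// opts.maxDraftTokens == 8; opts.initialDraftTokens clamped from 14 to 8
//	_ = positiveEnvInt("OLLAMA_MLX_MTP_MAX_DRAFT_TOKENS") // 8; non-positive or non-numeric values yield 0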
func (r *Runner) useGreedyMTP(opts sampler.Options) bool {
if r.Draft == nil {
return false
}
if _, ok := r.Draft.(base.MTPDraftModel); !ok {
return false
}
if _, ok := r.Model.(base.MTPEmbeddingModel); !ok {
return false
}
if !r.mtpDefaults(false).Enabled {
return false
}
if opts.Logprobs || opts.TopLogprobs > 0 {
return false
}
if opts.Temperature != 0 {
return false
}
repeatPenaltyNeutral := opts.RepeatPenalty <= 0 || opts.RepeatPenalty == 1
topPNeutral := opts.TopP <= 0 || opts.TopP >= 1
topKNeutral := opts.TopK <= 0
return repeatPenaltyNeutral && opts.PresencePenalty == 0 && opts.FrequencyPenalty == 0 && topPNeutral && topKNeutral && opts.MinP == 0
}
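// A hedged illustration of the greedy gate: it engages only when every
// sampling transform is a no-op, so the target's argmax token is exactly
// what normal decoding would have emitted.
//
//	sampler.Options{Temperature: 0, TopP: 1, RepeatPenalty: 1} // greedy MTP
//	sampler.Options{Temperature: 0.7}                          // handled by useSampleMTP instead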
func (r *Runner) useSampleMTP(opts sampler.Options) bool {
if serial, err := strconv.ParseBool(os.Getenv("OLLAMA_MLX_MTP_SERIAL_VALIDATE")); err == nil && serial {
return false
}
if compare, err := strconv.ParseBool(os.Getenv("OLLAMA_MLX_MTP_COMPARE_SERIAL_VALIDATE")); err == nil && compare {
return false
}
if r.Draft == nil {
return false
}
if _, ok := r.Draft.(base.MTPDraftModel); !ok {
return false
}
if _, ok := r.Model.(base.MTPEmbeddingModel); !ok {
return false
}
if !r.mtpDefaults(true).Enabled {
return false
}
if opts.Logprobs || opts.TopLogprobs > 0 {
return false
}
return opts.Temperature != 0
}
func (r *Runner) runGreedyMTPDecode(ctx context.Context, request Request, session *cacheSession, caches []cache.Cache, seed []int32, position *int, started time.Time) error {
targetEmbeddings := r.Model.(base.MTPEmbeddingModel)
draft := r.Draft.(base.MTPDraftModel)
mtpOpts := r.loadMTPOptions(false)
stats := mtpStats{maxDraft: mtpOpts.initialDraftTokens}
draftLimit := mtpOpts.initialDraftTokens
slog.Info("MTP greedy decode enabled", "initial_draft_tokens", mtpOpts.initialDraftTokens, "max_draft_tokens", mtpOpts.maxDraftTokens, "draft_schedule", mtpOpts.draftSchedule, "serial_validate", mtpOpts.serialValidate, "compare_serial_validate", mtpOpts.compareSerialValidate)
targetForward := func(token *mlx.Array) *mlx.Array {
fwd := r.Model.Forward(&batch.Batch{
InputIDs: token,
SeqOffsets: []int32{int32(*position)},
SeqQueryLens: []int32{int32(token.Dim(1))},
}, caches)
*position += token.Dim(1)
return fwd
}
hidden := targetForward(mlx.FromValues(seed, 1, len(seed)))
current := sampler.Result{Token: greedyTokenFromLogits(r.lastLogits(hidden))}
mlx.Pin(current.Arrays()...)
mlx.Sweep()
mlx.AsyncEval(current.Arrays()...)
defer func() {
mlx.Unpin(current.Arrays()...)
}()
dec := decoder{tokenizer: r.Tokenizer}
final := CompletionResponse{Done: true, PromptEvalCount: len(request.Tokens), DoneReason: 1}
now := started
generated := 0
for generated < request.Options.NumPredict {
if err := ctx.Err(); err != nil {
return err
}
t0 := time.Now()
hidden = targetForward(current.Token.ExpandDims(-1))
baseLogits := r.lastLogits(hidden)
stats.targetDuration += time.Since(t0)
if generated == 0 {
mlx.Eval(current.Arrays()...)
final.PromptEvalDuration = time.Since(now)
now = time.Now()
}
done, err := r.emitMTPToken(ctx, request, session, &dec, current, &final)
if err != nil {
return err
}
if !done {
generated++
}
if done || generated >= request.Options.NumPredict {
break
}
stats.iterations++
maxDraft := min(draftLimit, request.Options.NumPredict-generated)
t0 = time.Now()
draftTokens := r.generateMTPDrafts(draft, targetEmbeddings, current.Token, hidden, caches, int32(*position-1), maxDraft)
draftCount := 0
if draftTokens != nil {
draftCount = draftTokens.Dim(1)
mlx.Pin(baseLogits, draftTokens)
mlx.Eval(draftTokens)
mlx.Sweep()
}
stats.draftDuration += time.Since(t0)
stats.drafted += draftCount
var next sampler.Result
if draftCount == 0 {
next = sampler.Result{Token: greedyTokenFromLogits(baseLogits)}
} else {
var accepted int
t0 = time.Now()
next, accepted, done, err = r.acceptMTPDrafts(ctx, request, session, &dec, caches, position, baseLogits, draftTokens, &final, &generated, &stats, mtpOpts)
stats.validateDuration += time.Since(t0)
mlx.Unpin(baseLogits, draftTokens)
if err != nil {
return err
}
stats.accepted += accepted
if accepted == draftCount {
	stats.allAccepted++
} else {
	stats.mismatches++
}
if mtpOpts.draftSchedule != mtpDraftScheduleConstant {
	if accepted == draftCount {
		draftLimit = min(mtpOpts.maxDraftTokens, draftLimit+2)
	} else {
		draftLimit = max(1, draftLimit-1)
	}
}
stats.maxDraft = max(stats.maxDraft, draftLimit)
if next.Token == nil {
mlx.Sweep()
}
if done || generated >= request.Options.NumPredict {
break
}
}
mlx.Pin(next.Arrays()...)
old := current
current = next
mlx.Unpin(old.Arrays()...)
mlx.Sweep()
mlx.AsyncEval(current.Arrays()...)
if generated%256 == 0 {
mlx.ClearCache()
}
}
final.EvalCount = generated
final.EvalDuration = time.Since(now)
acceptance := 0.0
if stats.drafted > 0 {
acceptance = float64(stats.accepted) / float64(stats.drafted)
}
avgDraft := 0.0
avgAccepted := 0.0
if stats.iterations > 0 {
avgDraft = float64(stats.drafted) / float64(stats.iterations)
avgAccepted = float64(stats.accepted) / float64(stats.iterations)
}
slog.Info("MTP decode stats", "generated", generated, "drafted", stats.drafted, "accepted", stats.accepted, "acceptance", acceptance, "iterations", stats.iterations, "avg_draft", avgDraft, "avg_accepted", avgAccepted, "batched", stats.batched, "serial", stats.serial, "compared", stats.compared, "batch_serial_mismatches", stats.batchSerialMismatches, "mismatches", stats.mismatches, "all_accepted", stats.allAccepted, "max_draft", stats.maxDraft, "draft_schedule", mtpOpts.draftSchedule, "target_duration", stats.targetDuration, "draft_duration", stats.draftDuration, "validate_duration", stats.validateDuration)
select {
case <-ctx.Done():
return ctx.Err()
case request.Responses <- final:
return nil
}
}
func (r *Runner) runSampleMTPDecode(ctx context.Context, request Request, session *cacheSession, caches []cache.Cache, seed []int32, position *int, started time.Time) error {
targetEmbeddings := r.Model.(base.MTPEmbeddingModel)
draft := r.Draft.(base.MTPDraftModel)
mtpOpts := r.loadMTPOptions(true)
stats := mtpStats{maxDraft: mtpOpts.initialDraftTokens}
draftLimit := mtpOpts.initialDraftTokens
slog.Info("MTP sample decode enabled", "initial_draft_tokens", mtpOpts.initialDraftTokens, "max_draft_tokens", mtpOpts.maxDraftTokens, "draft_schedule", mtpOpts.draftSchedule, "serial_validate", mtpOpts.serialValidate)
targetForward := func(token *mlx.Array) *mlx.Array {
fwd := r.Model.Forward(&batch.Batch{
InputIDs: token,
SeqOffsets: []int32{int32(*position)},
SeqQueryLens: []int32{int32(token.Dim(1))},
}, caches)
*position += token.Dim(1)
return fwd
}
hidden := targetForward(mlx.FromValues(seed, 1, len(seed)))
current := r.Sampler.Sample([]int{pipelineSlot}, r.lastLogits(hidden))
mlx.Pin(current.Arrays()...)
mlx.Sweep()
mlx.AsyncEval(current.Arrays()...)
defer func() {
mlx.Unpin(current.Arrays()...)
}()
dec := decoder{tokenizer: r.Tokenizer}
final := CompletionResponse{Done: true, PromptEvalCount: len(request.Tokens), DoneReason: 1}
now := started
generated := 0
for generated < request.Options.NumPredict {
if err := ctx.Err(); err != nil {
return err
}
t0 := time.Now()
hidden = targetForward(mtpTokenInput(current.Token))
baseLogits := r.lastLogits(hidden)
stats.targetDuration += time.Since(t0)
if generated == 0 {
mlx.Eval(current.Arrays()...)
final.PromptEvalDuration = time.Since(now)
now = time.Now()
}
done, err := r.emitMTPToken(ctx, request, session, &dec, current, &final)
if err != nil {
return err
}
if !done {
generated++
}
if done || generated >= request.Options.NumPredict {
break
}
stats.iterations++
maxDraft := min(draftLimit, request.Options.NumPredict-generated)
t0 = time.Now()
candidates := r.generateMTPDraftCandidates(draft, targetEmbeddings, current.Token, hidden, caches, int32(*position-1), maxDraft)
draftCount := 0
if candidates != nil {
draftCount = candidates.tokens.Dim(1)
mlx.Pin(baseLogits, candidates.tokens, candidates.logits)
mlx.Sweep()
}
stats.draftDuration += time.Since(t0)
stats.drafted += draftCount
var next sampler.Result
if draftCount == 0 {
next = r.Sampler.Sample([]int{pipelineSlot}, baseLogits)
} else {
var accepted int
t0 = time.Now()
next, accepted, done, err = r.acceptSampleMTPDrafts(ctx, request, session, &dec, caches, position, baseLogits, candidates, &final, &generated, &stats)
stats.validateDuration += time.Since(t0)
mlx.Unpin(baseLogits, candidates.tokens, candidates.logits)
if err != nil {
return err
}
stats.accepted += accepted
if accepted == draftCount {
	stats.allAccepted++
} else {
	stats.mismatches++
}
if mtpOpts.draftSchedule != mtpDraftScheduleConstant {
	if accepted == draftCount {
		draftLimit = min(mtpOpts.maxDraftTokens, draftLimit+2)
	} else {
		draftLimit = max(1, draftLimit-1)
	}
}
stats.maxDraft = max(stats.maxDraft, draftLimit)
if next.Token == nil {
mlx.Sweep()
}
if done || generated >= request.Options.NumPredict {
break
}
}
mlx.Pin(next.Arrays()...)
old := current
current = next
mlx.Unpin(old.Arrays()...)
mlx.Sweep()
mlx.AsyncEval(current.Arrays()...)
if generated%256 == 0 {
mlx.ClearCache()
}
}
final.EvalCount = generated
final.EvalDuration = time.Since(now)
acceptance := 0.0
if stats.drafted > 0 {
acceptance = float64(stats.accepted) / float64(stats.drafted)
}
avgDraft := 0.0
avgAccepted := 0.0
if stats.iterations > 0 {
avgDraft = float64(stats.drafted) / float64(stats.iterations)
avgAccepted = float64(stats.accepted) / float64(stats.iterations)
}
slog.Info("MTP decode stats", "mode", "sample", "generated", generated, "drafted", stats.drafted, "accepted", stats.accepted, "acceptance", acceptance, "iterations", stats.iterations, "avg_draft", avgDraft, "avg_accepted", avgAccepted, "batched", stats.batched, "serial", stats.serial, "mismatches", stats.mismatches, "all_accepted", stats.allAccepted, "max_draft", stats.maxDraft, "draft_schedule", mtpOpts.draftSchedule, "target_duration", stats.targetDuration, "draft_duration", stats.draftDuration, "validate_duration", stats.validateDuration)
select {
case <-ctx.Done():
return ctx.Err()
case request.Responses <- final:
return nil
}
}
type mtpDraftCandidates struct {
tokens *mlx.Array
// logits are the processed proposal scores used to sample tokens.
logits *mlx.Array
}
func (r *Runner) generateMTPDrafts(draft base.MTPDraftModel, target base.MTPEmbeddingModel, token *mlx.Array, hidden *mlx.Array, caches []cache.Cache, position int32, maxDraft int) *mlx.Array {
if maxDraft <= 0 {
return nil
}
lastToken := token.ExpandDims(-1)
lastHidden := hidden
draftTokens := make([]*mlx.Array, 0, maxDraft)
// Gemma4 assistant MTP is trained as "single-position" drafting:
// keep the RoPE/cache position anchored at the last target-seen token
// while the proposed token and projected hidden state advance.
for range maxDraft {
tokenEmbedding := target.TokenEmbeddings(lastToken)
inputs := tokenEmbedding.Concatenate(-1, lastHidden)
logits, projected := draft.Draft(inputs, position, caches)
stepLogits := r.lastLogitsFromLogits(logits)
nextToken := greedyTokenFromLogits(stepLogits)
lastToken = nextToken.ExpandDims(-1)
lastHidden = projected
draftTokens = append(draftTokens, lastToken)
}
if len(draftTokens) == 0 {
return nil
}
return mlx.Concatenate(draftTokens, 1)
}
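// Worked illustration of the single-position invariant above, with position p
// anchored and maxDraft = 3 (token names are illustrative):
//
//	step 0: concat(embed(t0), hidden0) at position p -> t1, hidden1
//	step 1: concat(embed(t1), hidden1) at position p -> t2, hidden2
//	step 2: concat(embed(t2), hidden2) at position p -> t3, hidden3
//
// The drafted tokens only enter the real KV cache later, via the target's
// batched validation forward.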
func (r *Runner) generateMTPDraftCandidates(draft base.MTPDraftModel, target base.MTPEmbeddingModel, token *mlx.Array, hidden *mlx.Array, caches []cache.Cache, position int32, maxDraft int) *mtpDraftCandidates {
if maxDraft <= 0 {
return nil
}
lastToken := mtpTokenInput(token)
lastHidden := hidden
draftTokens := make([]*mlx.Array, 0, maxDraft)
draftLogits := make([]*mlx.Array, 0, maxDraft)
var prefix *mlx.Array
// Gemma4 assistant MTP is trained as "single-position" drafting:
// keep the RoPE/cache position anchored at the last target-seen token
// while the proposed token and projected hidden state advance.
for range maxDraft {
tokenEmbedding := target.TokenEmbeddings(lastToken)
inputs := tokenEmbedding.Concatenate(-1, lastHidden)
logits, projected := draft.Draft(inputs, position, caches)
stepLogits := r.lastLogitsFromLogits(logits)
stepScores := r.Sampler.SpeculativeScores(pipelineSlot, stepLogits, prefix)
nextToken := stepScores.Categorical(-1).AsType(mlx.DTypeInt32)
lastToken = mtpTokenInput(nextToken)
lastHidden = projected
draftTokens = append(draftTokens, lastToken)
draftLogits = append(draftLogits, stepScores.ExpandDims(1))
if prefix == nil {
prefix = lastToken
} else {
prefix = prefix.Concatenate(1, lastToken)
}
}
if len(draftTokens) == 0 {
return nil
}
return &mtpDraftCandidates{
tokens: mlx.Concatenate(draftTokens, 1),
logits: mlx.Concatenate(draftLogits, 1),
}
}
func (r *Runner) acceptMTPDrafts(ctx context.Context, request Request, session *cacheSession, dec *decoder, caches []cache.Cache, position *int, baseLogits *mlx.Array, draftTokens *mlx.Array, final *CompletionResponse, generated *int, stats *mtpStats, opts mtpOptions) (sampler.Result, int, bool, error) {
if opts.serialValidate {
stats.serial++
return r.acceptMTPDraftsSerial(ctx, request, session, dec, caches, position, baseLogits, draftTokens, final, generated)
}
specCaches, spec, ok := cache.BeginSpeculation(caches)
if ok {
stats.batched++
return r.acceptMTPDraftsBatched(ctx, request, session, dec, caches, specCaches, spec, position, baseLogits, draftTokens, final, generated, stats, opts)
}
stats.serial++
return r.acceptMTPDraftsSerial(ctx, request, session, dec, caches, position, baseLogits, draftTokens, final, generated)
}
func (r *Runner) acceptMTPDraftsBatched(ctx context.Context, request Request, session *cacheSession, dec *decoder, liveCaches []cache.Cache, caches []cache.Cache, spec *cache.Speculation, position *int, baseLogits *mlx.Array, draftTokens *mlx.Array, final *CompletionResponse, generated *int, stats *mtpStats, opts mtpOptions) (sampler.Result, int, bool, error) {
before := *position
draftCount := draftTokens.Dim(1)
hiddenSeq := r.Model.Forward(&batch.Batch{
InputIDs: draftTokens,
SeqOffsets: []int32{int32(before)},
SeqQueryLens: []int32{int32(draftCount)},
}, caches)
accepted := 0
var next sampler.Result
done := false
selectedTokens := r.mtpValidationTokens(baseLogits, hiddenSeq)
mlx.Eval(draftTokens, selectedTokens)
draftIDs := draftTokens.Ints()
selectedIDs := selectedTokens.Ints()
if len(selectedIDs) < draftCount+1 {
return sampler.Result{}, accepted, false, fmt.Errorf("mtp validation produced %d tokens for %d draft tokens", len(selectedIDs), draftCount)
}
for i, id := range draftIDs {
if selectedIDs[i] != id {
next = sampler.Result{Token: mtpTokenAt(selectedTokens, i)}
break
}
accepted++
if r.Tokenizer.IsEOS(int32(id)) {
done = true
break
}
}
if opts.compareSerialValidate {
spec.Commit(0)
r.compareMTPBatchedWithSerial(ctx, liveCaches, before, baseLogits, hiddenSeq, draftIDs, selectedIDs, accepted, draftCount, stats)
}
spec.Commit(accepted)
*position = before + accepted
for _, id := range draftIDs[:accepted] {
if *generated >= request.Options.NumPredict {
done = true
break
}
res := sampler.Result{Token: mlx.FromValues([]int32{int32(id)}, 1)}
var err error
done, err = r.emitMTPToken(ctx, request, session, dec, res, final)
if err != nil {
return sampler.Result{}, accepted, done, err
}
if !done {
(*generated)++
}
if done {
break
}
}
if done || *generated >= request.Options.NumPredict {
return sampler.Result{}, accepted, true, nil
}
if next.Token == nil {
next = sampler.Result{Token: mtpTokenAt(selectedTokens, draftCount)}
}
return next, accepted, false, nil
}
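// A small worked example of the accept loop, with illustrative token IDs:
// draftIDs = [17, 42, 7], selectedIDs = [17, 42, 9, 31].
//
//	i=0: 17 == 17 -> accept
//	i=1: 42 == 42 -> accept
//	i=2:  7 != 9  -> stop; the corrected token 9 becomes next
//
// accepted = 2, so spec.Commit(2) keeps exactly the validated cache entries;
// had all three drafts matched, selectedIDs[3] = 31 would be the bonus token.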
func (r *Runner) acceptSampleMTPDrafts(ctx context.Context, request Request, session *cacheSession, dec *decoder, caches []cache.Cache, position *int, baseLogits *mlx.Array, candidates *mtpDraftCandidates, final *CompletionResponse, generated *int, stats *mtpStats) (sampler.Result, int, bool, error) {
specCaches, spec, ok := cache.BeginSpeculation(caches)
if !ok {
stats.serial++
return r.Sampler.Sample([]int{pipelineSlot}, baseLogits), 0, false, nil
}
stats.batched++
before := *position
draftCount := candidates.tokens.Dim(1)
hiddenSeq := r.Model.Forward(&batch.Batch{
InputIDs: candidates.tokens,
SeqOffsets: []int32{int32(before)},
SeqQueryLens: []int32{int32(draftCount)},
}, specCaches)
targetScores := r.Sampler.SpeculativeScores(pipelineSlot, r.mtpValidationLogits(baseLogits, hiddenSeq), candidates.tokens)
draftScores := candidates.logits
if draftScores.NumDims() == 3 {
draftScores = draftScores.Squeeze(0)
}
acceptedMask := mtpSampleAcceptedMask(targetScores, draftScores, candidates.tokens, draftCount)
mlx.Eval(candidates.tokens, acceptedMask)
draftIDs := candidates.tokens.Ints()
acceptedFlags := acceptedMask.Ints()
accepted := 0
for _, ok := range acceptedFlags {
if ok == 0 {
break
}
accepted++
}
if accepted > draftCount {
return sampler.Result{}, 0, false, fmt.Errorf("mtp sample validation accepted %d tokens for %d draft tokens", accepted, draftCount)
}
commitIDs := make([]int32, 0, accepted+1)
done := false
for i, id := range draftIDs[:accepted] {
commitIDs = append(commitIDs, int32(id))
if r.Tokenizer.IsEOS(int32(id)) {
done = true
accepted = i + 1
commitIDs = commitIDs[:accepted]
break
}
}
spec.Commit(accepted)
*position = before + accepted
for _, id := range draftIDs[:accepted] {
if *generated >= request.Options.NumPredict {
done = true
break
}
res := sampler.Result{Token: mlx.FromValues([]int32{int32(id)}, 1)}
var err error
done, err = r.emitMTPToken(ctx, request, session, dec, res, final)
if err != nil {
return sampler.Result{}, accepted, done, err
}
if !done {
(*generated)++
}
if done {
break
}
}
if done || *generated >= request.Options.NumPredict {
r.Sampler.Commit(pipelineSlot, commitIDs)
return sampler.Result{}, accepted, true, nil
}
var nextToken *mlx.Array
if accepted == draftCount {
nextToken = mtpSampleTokenAt(targetScores, draftCount)
} else {
nextToken = mtpSampleResidualToken(targetScores, draftScores, accepted)
}
mlx.Eval(nextToken)
nextID := int32(tokenID(nextToken))
commitIDs = append(commitIDs, nextID)
r.Sampler.Commit(pipelineSlot, commitIDs)
return sampler.Result{Token: nextToken}, accepted, false, nil
}
func mtpSampleAcceptedMask(targetScores, draftScores, draftTokens *mlx.Array, draftCount int) *mlx.Array {
targetProbs := mlx.SoftmaxAxis(targetScores.Slice(mlx.Slice(0, draftCount), mlx.Slice()), -1, true)
draftProbs := mlx.SoftmaxAxis(draftScores, -1, true)
if draftTokens.NumDims() == 2 {
draftTokens = draftTokens.Squeeze(0)
}
indices := draftTokens.ExpandDims(-1)
p := targetProbs.TakeAlongAxis(indices, -1).Squeeze(-1)
q := draftProbs.TakeAlongAxis(indices, -1).Squeeze(-1)
acceptP := mlx.Minimum(p.Divide(q), mlx.FromValue(float32(1)))
return mlx.Bernoulli(acceptP).AsType(mlx.DTypeInt32)
}
func mtpSampleTokenAt(scores *mlx.Array, index int) *mlx.Array {
row := scores.Slice(mlx.Slice(index, index+1), mlx.Slice())
return mtpTokenVector(row.Categorical(-1).AsType(mlx.DTypeInt32))
}
func mtpSampleResidualToken(targetScores, draftScores *mlx.Array, index int) *mlx.Array {
p := mlx.SoftmaxAxis(targetScores.Slice(mlx.Slice(index, index+1), mlx.Slice()), -1, true)
q := mlx.SoftmaxAxis(draftScores.Slice(mlx.Slice(index, index+1), mlx.Slice()), -1, true)
diff := p.Subtract(q)
positive := mlx.Maximum(diff, mlx.FromValue(float32(1e-20)))
logits := mlx.Log(positive)
logits = mlx.Where(diff.LessEqual(mlx.FromValue(float32(0))), mlx.FromValue(float32(math.Inf(-1))), logits)
return mtpTokenVector(logits.Categorical(-1).AsType(mlx.DTypeInt32))
}
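// The two helpers above implement the standard speculative-sampling accept
// rule: draft token x_i drawn from the draft distribution q is kept with
// probability
//
//	a_i = min(1, p(x_i) / q(x_i))
//
// where p is the target distribution, and on the first rejection the
// replacement is drawn from the residual
//
//	p'(x) ∝ max(0, p(x) - q(x))
//
// mtpSampleResidualToken realizes this by sampling a categorical over
// log(max(p - q, 1e-20)) with non-positive entries masked to -inf, which
// preserves the target distribution despite the draft proposals.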
func mtpTokenInput(token *mlx.Array) *mlx.Array {
switch token.NumDims() {
case 0:
return token.Reshape(1, 1)
case 1:
return token.ExpandDims(-1)
case 2:
return token
default:
panic(fmt.Sprintf("mtp token must be rank 0, 1, or 2, got rank %d", token.NumDims()))
}
}
func mtpTokenVector(token *mlx.Array) *mlx.Array {
switch token.NumDims() {
case 0:
return token.Reshape(1)
case 1:
return token
default:
panic(fmt.Sprintf("mtp sampled token must be rank 0 or 1, got rank %d", token.NumDims()))
}
}
func (r *Runner) compareMTPBatchedWithSerial(ctx context.Context, caches []cache.Cache, before int, baseLogits, hiddenSeq *mlx.Array, draftIDs, selectedIDs []int, accepted, draftCount int, stats *mtpStats) {
serialCaches, ok := cache.BeginIsolatedSpeculation(caches)
if !ok {
return
}
// Also compare the token after the accepted prefix; when every draft was
// accepted this is the target's bonus token.
compareCount := accepted + 1
serialLogits := baseLogits
for i := range compareCount {
if err := ctx.Err(); err != nil {
return
}
if i >= len(selectedIDs) {
return
}
batchedLogits := baseLogits
if i > 0 {
batchedLogits = r.targetLogitsAt(hiddenSeq, i-1)
}
batchedToken := greedyTokenFromLogits(batchedLogits)
serialToken := greedyTokenFromLogits(serialLogits)
mlx.Eval(batchedToken, serialToken)
batchedID := tokenID(batchedToken)
vectorizedID := selectedIDs[i]
serialID := tokenID(serialToken)
stats.compared++
if vectorizedID != serialID {
firstMismatch := stats.batchSerialMismatches == 0
stats.batchSerialMismatches++
if !firstMismatch {
return
}
draftID := -1
if i < draftCount {
draftID = draftIDs[i]
}
batchedTop := top2FromLogits(batchedLogits)
serialTop := top2FromLogits(serialLogits)
slog.Warn("MTP batched validation differs from serial validation",
"position", before+i,
"draft", draftID,
"batched", vectorizedID,
"batched_slice", batchedID,
"serial", serialID,
"batched_slice_top1", batchedTop.firstToken,
"batched_slice_top2", batchedTop.secondToken,
"batched_slice_margin", batchedTop.margin,
"serial_top1", serialTop.firstToken,
"serial_top2", serialTop.secondToken,
"serial_margin", serialTop.margin,
)
return
}
if i >= draftCount || i >= accepted {
return
}
hidden := r.Model.Forward(&batch.Batch{
InputIDs: mlx.FromValues([]int32{int32(draftIDs[i])}, 1, 1),
SeqOffsets: []int32{int32(before + i)},
SeqQueryLens: []int32{1},
}, serialCaches)
serialLogits = r.lastLogits(hidden)
}
}
type mtpTop2 struct {
firstToken int
secondToken int
margin float64
}
func top2FromLogits(logits *mlx.Array) mtpTop2 {
indices := logits.Negative().ArgsortAxis(-1).Slice(mlx.Slice(), mlx.Slice(0, 2))
indices32 := indices.AsType(mlx.DTypeInt32)
values := logits.TakeAlongAxis(indices, -1).AsType(mlx.DTypeFloat32)
mlx.Eval(indices32, values)
tokenIDs := indices32.Ints()
logitValues := values.Floats()
if len(tokenIDs) < 2 || len(logitValues) < 2 {
return mtpTop2{}
}
return mtpTop2{
firstToken: tokenIDs[0],
secondToken: tokenIDs[1],
margin: float64(logitValues[0] - logitValues[1]),
}
}
func (r *Runner) acceptMTPDraftsSerial(ctx context.Context, request Request, session *cacheSession, dec *decoder, caches []cache.Cache, position *int, baseLogits *mlx.Array, draftTokens *mlx.Array, final *CompletionResponse, generated *int) (sampler.Result, int, bool, error) {
logits := baseLogits
accepted := 0
draftIDs := draftTokens.Ints()
for _, id := range draftIDs {
selected := greedyTokenFromLogits(logits)
mlx.Eval(selected)
selectedID := tokenID(selected)
if selectedID != id {
return sampler.Result{Token: mlx.FromValues([]int32{int32(selectedID)}, 1)}, accepted, false, nil
}
hidden := r.Model.Forward(&batch.Batch{
InputIDs: mlx.FromValues([]int32{int32(id)}, 1, 1),
SeqOffsets: []int32{int32(*position)},
SeqQueryLens: []int32{1},
}, caches)
(*position)++
accepted++
res := sampler.Result{Token: mlx.FromValues([]int32{int32(id)}, 1)}
done, err := r.emitMTPToken(ctx, request, session, dec, res, final)
if err != nil {
return sampler.Result{}, accepted, done, err
}
if !done {
(*generated)++
}
if done || *generated >= request.Options.NumPredict {
return sampler.Result{}, accepted, true, nil
}
logits = r.lastLogits(hidden)
}
return sampler.Result{Token: greedyTokenFromLogits(logits)}, accepted, false, nil
}
func (r *Runner) emitMTPToken(ctx context.Context, request Request, session *cacheSession, dec *decoder, res sampler.Result, final *CompletionResponse) (bool, error) {
output := int32(tokenID(res.Token))
session.outputs = append(session.outputs, output)
if r.Tokenizer.IsEOS(output) {
final.DoneReason = 0
return true, nil
}
if resp, ok := dec.decode(res); ok {
select {
case <-ctx.Done():
return false, ctx.Err()
case request.Responses <- resp:
}
}
return false, nil
}
func (r *Runner) lastLogits(hidden *mlx.Array) *mlx.Array {
logits := r.Model.Unembed(hidden)
return r.lastLogitsFromLogits(logits)
}
func (r *Runner) targetLogitsAt(hiddenSeq *mlx.Array, index int) *mlx.Array {
hidden := hiddenSeq.Slice(mlx.Slice(), mlx.Slice(index), mlx.Slice())
return r.lastLogits(hidden)
}
func (r *Runner) mtpValidationTokens(baseLogits, hiddenSeq *mlx.Array) *mlx.Array {
return greedyTokenFromLogits(r.mtpValidationLogits(baseLogits, hiddenSeq))
}
func (r *Runner) mtpValidationLogits(baseLogits, hiddenSeq *mlx.Array) *mlx.Array {
seqLogits := r.Model.Unembed(hiddenSeq)
return baseLogits.ExpandDims(1).Concatenate(1, seqLogits)
}
func mtpTokenAt(tokens *mlx.Array, index int) *mlx.Array {
return tokens.Slice(mlx.Slice(), mlx.Slice(index)).Squeeze(0)
}
func (r *Runner) lastLogitsFromLogits(logits *mlx.Array) *mlx.Array {
return logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
}
func greedyTokenFromLogits(logits *mlx.Array) *mlx.Array {
return logits.Argmax(-1, false).AsType(mlx.DTypeInt32)
}
func tokenID(token *mlx.Array) int {
if token == nil {
return -1
}
if token.DType() == mlx.DTypeInt32 {
ids := token.Ints()
if len(ids) > 0 {
return ids[0]
}
}
return token.Int()
}

View file

@ -146,6 +146,12 @@ func (r *Runner) TextGenerationPipeline(ctx context.Context, request Request) er
// Register the sampler after prefill completes.
r.Sampler.Add(pipelineSlot, request.SamplerOpts, inputs)
if r.useGreedyMTP(request.SamplerOpts) {
return r.runGreedyMTPDecode(ctx, request, session, caches, tokens[processed:], &position, now)
}
if r.useSampleMTP(request.SamplerOpts) {
return r.runSampleMTPDecode(ctx, request, session, caches, tokens[processed:], &position, now)
}
step := func(token *mlx.Array) sampler.Result {
fwd := r.Model.Forward(&batch.Batch{

View file

@ -35,6 +35,7 @@ type Request struct {
type Runner struct {
Model base.Model
Draft base.DraftModel
Tokenizer *tokenizer.Tokenizer
Requests chan Request
Sampler *sample.Sampler
@ -61,12 +62,41 @@ func (r *Runner) Load(modelName string) error {
return err
}
// Assign weights to model (model-specific logic)
loadWeights := base.Weights(m)
if err := loadWeights(tensors); err != nil {
// Assign weights to model (model-specific logic). Target and draft weights
// must be loaded before sweeping so tensors from a combined manifest are
// not discarded before the draft model can retain them.
if err := m.LoadWeights(tensors); err != nil {
return err
}
r.Draft = nil
draft, err := base.NewDraft(root, m)
if err != nil {
return err
}
if draft != nil {
if err := draft.LoadWeights(tensors); err != nil {
return err
}
r.Draft = draft
}
collected := mlx.Collect(m)
if draft != nil {
draftArrays := mlx.Collect(draft)
collected = append(collected, draftArrays...)
if root.Draft != nil {
slog.Info("Loaded draft model", "tensor_prefix", root.Draft.TensorPrefix, "config", root.Draft.Config, "arrays", len(draftArrays))
} else {
slog.Info("Loaded draft model", "arrays", len(draftArrays))
}
}
for _, arr := range collected {
mlx.Pin(arr)
}
mlx.Sweep()
mlx.Eval(collected...)
r.Model = m
r.Tokenizer = m.Tokenizer()
r.contextLength = m.MaxContextLength()

View file

@ -349,6 +349,159 @@ func (s *Sampler) Sample(seqIDs []int, logits *mlx.Array) Result {
return res
}
// SpeculativeScores applies this slot's non-sampling transforms to logits
// without mutating sampler state. Row i is scored as if draftTokens[:i] had
// already been appended to the slot history. logits must be [R,V] or [1,R,V].
func (s *Sampler) SpeculativeScores(seqID int, logits *mlx.Array, draftTokens *mlx.Array) *mlx.Array {
slot, ok := s.byID[seqID]
if !ok {
panic(fmt.Sprintf("sample.Sampler.SpeculativeScores: seqID %d not registered", seqID))
}
if logits.NumDims() == 3 {
if logits.Dim(0) != 1 {
panic("sample.Sampler.SpeculativeScores: only batch size 1 is supported")
}
logits = logits.Squeeze(0)
}
if logits.NumDims() != 2 {
panic(fmt.Sprintf("sample.Sampler.SpeculativeScores: logits must be rank 2 or 3, got rank %d", logits.NumDims()))
}
if draftTokens != nil && draftTokens.NumDims() == 1 {
draftTokens = draftTokens.ExpandDims(0)
}
rows := logits.Dim(0)
var hist *mlx.Array
if slot.opts.usesHistory() {
if s.history == nil {
panic(fmt.Sprintf("sample.Sampler.SpeculativeScores: seqID %d has no history", seqID))
}
if slot.historyLen < slot.opts.RepeatLastN {
return s.speculativeScoresSerial(slot, logits, draftTokens)
}
hist = s.speculativeHistory(slot, draftTokens, rows)
}
return slot.speculativeScores(&slotCtx{opts: slot.opts, history: hist}, logits)
}
// Commit appends already-selected tokens to seqID's repeat-penalty history.
// It is used after speculative sampling once the accepted continuation is
// known. Normal Sample calls continue to mutate history themselves.
func (s *Sampler) Commit(seqID int, tokens []int32) {
if len(tokens) == 0 {
return
}
slot, ok := s.byID[seqID]
if !ok {
panic(fmt.Sprintf("sample.Sampler.Commit: seqID %d not registered", seqID))
}
if !slot.opts.usesHistory() {
return
}
if s.history == nil {
panic(fmt.Sprintf("sample.Sampler.Commit: seqID %d has no history", seqID))
}
row := slices.Index(s.slots, slot)
width := s.historyWidth()
take := min(len(tokens), slot.opts.RepeatLastN)
startLen := slot.historyLen + len(tokens) - take
writeTokens := tokens[len(tokens)-take:]
flatOffsets := make([]int32, take)
for i := range take {
ringPos := (startLen + i) % slot.opts.RepeatLastN
flatOffsets[i] = int32(row*width + ringPos)
}
flatIdx := mlx.NewArrayInt32(flatOffsets, []int32{int32(take), 1})
values := mlx.NewArrayInt32(writeTokens, []int32{int32(take), 1})
flatHist := s.history.Reshape(s.history.Dim(0)*width, 1)
s.history.Set(flatHist.PutAlongAxis(flatIdx, values, 0).Reshape(s.history.Dim(0), width))
slot.historyLen += len(tokens)
}
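// Worked example of the ring arithmetic (mirrored by
// TestCommitBatchesRingWrites below): with RepeatLastN = 4 and historyLen = 6,
// Commit of {30,31,32,33,34} keeps take = 4 tokens {31,32,33,34};
// startLen = 6+5-4 = 7, so writes land at ring slots 7%4=3, then 0, 1, 2,
// leaving the row as [32, 33, 34, 31] with historyLen = 11.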
func (s *Sampler) speculativeScoresSerial(slot *slotState, logits *mlx.Array, draftTokens *mlx.Array) *mlx.Array {
rows := logits.Dim(0)
draftCount := 0
if draftTokens != nil {
draftCount = draftTokens.Dim(1)
}
row := slices.Index(s.slots, slot)
baseFill := min(slot.historyLen, slot.opts.RepeatLastN)
var base *mlx.Array
if baseFill > 0 {
base = s.history.Slice(mlx.Slice(row, row+1), mlx.Slice(0, baseFill))
}
scored := make([]*mlx.Array, 0, rows)
for i := range rows {
rowLogits := logits.Slice(mlx.Slice(i, i+1), mlx.Slice())
hist := base
prefixLen := min(i, draftCount)
if prefixLen > 0 {
prefix := draftTokens.Slice(mlx.Slice(), mlx.Slice(0, prefixLen))
if hist == nil {
hist = prefix
} else {
hist = hist.Concatenate(1, prefix)
}
if hist.Dim(1) > slot.opts.RepeatLastN {
hist = hist.Slice(mlx.Slice(), mlx.Slice(hist.Dim(1)-slot.opts.RepeatLastN, mlx.End))
}
}
scored = append(scored, slot.speculativeScores(&slotCtx{opts: slot.opts, history: hist}, rowLogits))
}
return mlx.Concatenate(scored, 0)
}
func (s *Sampler) speculativeHistory(slot *slotState, draftTokens *mlx.Array, rows int) *mlx.Array {
row := slices.Index(s.slots, slot)
width := slot.opts.RepeatLastN
base := s.history.Slice(mlx.Slice(row, row+1), mlx.Slice(0, width))
base = mlx.Tile(base, []int32{int32(rows), 1})
next := slot.historyLen % width
draftCount := 0
if draftTokens != nil {
draftCount = draftTokens.Dim(1)
}
if draftCount == 0 {
return base
}
sourceIdx := make([]int32, rows*width)
writeMask := make([]bool, rows*width)
for i := range rows {
prefixLen := min(i, draftCount)
for j := range prefixLen {
pos := (next + j) % width
sourceIdx[i*width+pos] = int32(j)
writeMask[i*width+pos] = true
}
}
draftRows := mlx.Tile(draftTokens, []int32{int32(rows), 1})
idx := mlx.NewArrayInt32(sourceIdx, []int32{int32(rows), int32(width)})
mask := mlx.FromValues(writeMask, rows, width)
values := draftRows.TakeAlongAxis(idx, 1)
return mlx.Where(mask, values, base)
}
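// Illustration with rows = 3 (draftCount = 2), width = 4, next = 2: row i
// sees the committed ring h plus the first min(i, draftCount) draft tokens d
// written at slots (next+j)%width:
//
//	row 0: [h0, h1, h2, h3]   (no draft history)
//	row 1: [h0, h1, d0, h3]   (d0 overwrites slot 2)
//	row 2: [h0, h1, d0, d1]   (d0 at slot 2, d1 at slot 3)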
func (slot *slotState) speculativeScores(ctx *slotCtx, logits *mlx.Array) *mlx.Array {
scores := logits
// buildTransforms always appends the final selector transform
// (greedy or temperature sampling). Speculative validation needs the
// processed logits before that selector mutates the distribution.
for _, t := range slot.transforms[:len(slot.transforms)-1] {
scores = t(ctx, scores)
}
if slot.opts.Temperature > 0 {
scores = mlx.DivScalar(scores, slot.opts.Temperature)
}
return scores
}
// canBatch reports whether the call can take the uniform batched path.
// All slots must share Options; when penalties are active the call must
// additionally cover every registered slot in registration order with a

View file

@ -176,6 +176,65 @@ func TestSampleHistoryWindow(t *testing.T) {
}
}
func TestSpeculativeScoresUsesDraftHistoryWithoutCommit(t *testing.T) {
skipIfNoMLX(t)
s := New(128)
t.Cleanup(func() {
s.Free()
mlx.Sweep()
})
s.Add(0, Options{RepeatLastN: 2, RepeatPenalty: 10}, []int32{1, 2})
draftTokens := mlx.NewArrayInt32([]int32{3, 4}, []int32{1, 2})
scores := s.SpeculativeScores(0, batchLogits(
[]float32{0, 9, 9, 8, 0}, // history {1,2}; token 3 wins
[]float32{0, 0, 9, 9, 8}, // history {2,3}; token 4 wins
[]float32{0, 0, 9, 9, 8}, // history {3,4}; token 2 wins
), draftTokens)
tokens := scores.Argmax(-1, false).AsType(mlx.DTypeInt32)
mlx.Eval(tokens)
if got, want := tokens.Ints(), []int{3, 4, 2}; len(got) != len(want) {
t.Fatalf("tokens = %v, want %v", got, want)
} else {
for i := range want {
if got[i] != want[i] {
t.Fatalf("tokens = %v, want %v", got, want)
}
}
}
if s.byID[0].historyLen != 2 {
t.Fatalf("historyLen = %d, want 2", s.byID[0].historyLen)
}
}
func TestCommitBatchesRingWrites(t *testing.T) {
skipIfNoMLX(t)
s := New(128)
t.Cleanup(func() {
s.Free()
mlx.Sweep()
})
s.Add(0, Options{RepeatLastN: 4, RepeatPenalty: 1.1}, []int32{10, 11, 12})
s.Commit(0, []int32{20, 21, 22})
s.Commit(0, []int32{30, 31, 32, 33, 34})
mlx.Eval(s.history)
got := s.history.Ints()
want := []int{32, 33, 34, 31}
for i := range want {
if got[i] != want[i] {
t.Fatalf("history = %v, want %v", got, want)
}
}
if s.byID[0].historyLen != 11 {
t.Fatalf("historyLen = %d, want 11", s.byID[0].historyLen)
}
}
// TestBatchSamplingPreservesPerSlotBehavior is the core equivalence test:
// for every representative dispatch branch (uniform, serial on mixed opts,
// serial on partial ring, subset/out-of-order), a batched Sample call must

View file

@ -0,0 +1,390 @@
package gemma4
import (
"encoding/json"
"fmt"
"github.com/ollama/ollama/x/mlxrunner/batch"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/models/nn"
)
var (
_ base.DraftModel = (*AssistantModel)(nil)
_ base.MTPDraftModel = (*AssistantModel)(nil)
)
type AssistantConfig struct {
TextConfig TextConfig `json:"text_config"`
BackboneHiddenSize int32 `json:"backbone_hidden_size"`
UseOrderedEmbeddings bool `json:"use_ordered_embeddings"`
NumCentroids int32 `json:"num_centroids"`
CentroidIntermediateTopK int32 `json:"centroid_intermediate_top_k"`
}
type AssistantModel struct {
PreProjection nn.LinearLayer
PostProjection nn.LinearLayer
EmbedTokens nn.EmbeddingLayer
LMHead nn.LinearLayer
Centroids nn.LinearLayer
TokenOrdering *mlx.Array
Layers []*AssistantLayer
Norm *nn.RMSNorm
NormScaled *mlx.Array
*AssistantConfig
tensorPrefix string
QuantGroupSize int
QuantBits int
QuantMode string
TensorQuant map[string]*model.TensorQuantInfo
}
type AssistantLayer struct {
InputNorm *nn.RMSNorm
PostAttnNorm *nn.RMSNorm
PreFFNorm *nn.RMSNorm
PostFFNorm *nn.RMSNorm
InputNormScaled *mlx.Array
PostAttnNormScaled *mlx.Array
PreFFNormScaled *mlx.Array
PostFFNormScaled *mlx.Array
Attention *AssistantAttention
MLP *MLP
LayerScalar *mlx.Array
IsSliding bool
}
type AssistantAttention struct {
QProj nn.LinearLayer
OProj nn.LinearLayer
QNorm *nn.RMSNorm
QNormScaled *mlx.Array
}
func parseAssistantConfig(configData []byte) (AssistantConfig, error) {
var raw struct {
TextConfig json.RawMessage `json:"text_config"`
BackboneHiddenSize int32 `json:"backbone_hidden_size"`
UseOrderedEmbeddings bool `json:"use_ordered_embeddings"`
NumCentroids int32 `json:"num_centroids"`
CentroidIntermediateTopK int32 `json:"centroid_intermediate_top_k"`
}
if err := json.Unmarshal(configData, &raw); err != nil {
return AssistantConfig{}, fmt.Errorf("parse assistant config: %w", err)
}
if len(raw.TextConfig) == 0 {
return AssistantConfig{}, fmt.Errorf("assistant config missing text_config")
}
text, err := parseTextConfig(raw.TextConfig)
if err != nil {
return AssistantConfig{}, err
}
if raw.NumCentroids == 0 {
raw.NumCentroids = 2048
}
if raw.CentroidIntermediateTopK == 0 {
raw.CentroidIntermediateTopK = 32
}
return AssistantConfig{
TextConfig: text,
BackboneHiddenSize: raw.BackboneHiddenSize,
UseOrderedEmbeddings: raw.UseOrderedEmbeddings,
NumCentroids: raw.NumCentroids,
CentroidIntermediateTopK: raw.CentroidIntermediateTopK,
}, nil
}
func newAssistantModel(root *model.Root, target base.Model) (base.DraftModel, error) {
if root == nil || root.Draft == nil {
return nil, fmt.Errorf("draft metadata missing")
}
configPath := root.Draft.Config
if configPath == "" {
configPath = "draft/config.json"
}
configData, err := root.Manifest.ReadConfig(configPath)
if err != nil {
return nil, fmt.Errorf("load draft config: %w", err)
}
cfg, err := parseAssistantConfig(configData)
if err != nil {
return nil, err
}
targetGemma, ok := target.(*Model)
if !ok {
return nil, fmt.Errorf("gemma4 assistant requires gemma4 target, got %T", target)
}
if cfg.BackboneHiddenSize != 0 && cfg.BackboneHiddenSize != targetGemma.HiddenSize {
return nil, fmt.Errorf("assistant backbone hidden size %d does not match target hidden size %d", cfg.BackboneHiddenSize, targetGemma.HiddenSize)
}
if cfg.TextConfig.VocabSize != targetGemma.VocabSize {
return nil, fmt.Errorf("assistant vocab size %d does not match target vocab size %d", cfg.TextConfig.VocabSize, targetGemma.VocabSize)
}
tensorPrefix := root.Draft.TensorPrefix
if tensorPrefix == "" {
tensorPrefix = "draft."
}
m := &AssistantModel{
AssistantConfig: &cfg,
tensorPrefix: tensorPrefix,
Layers: make([]*AssistantLayer, cfg.TextConfig.NumHiddenLayers),
TensorQuant: root.AllTensorQuant(),
}
if qt := root.QuantType(); qt != "" {
m.QuantGroupSize, m.QuantBits, m.QuantMode = model.QuantizationParams(qt)
if gs := root.GroupSize(); gs > 0 {
m.QuantGroupSize = gs
}
}
for i := range m.Layers {
m.Layers[i] = &AssistantLayer{
IsSliding: isLayerSliding(int32(i), &m.TextConfig),
Attention: &AssistantAttention{},
MLP: &MLP{},
}
}
return m, nil
}
func (m *AssistantModel) LoadWeights(tensors map[string]*mlx.Array) error {
prefix := m.tensorPrefix
linears := model.NewLinearFactory(tensors, m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant)
m.PreProjection = linears.Make(prefix + "pre_projection")
m.PostProjection = linears.Make(prefix + "post_projection")
if m.PreProjection == nil || m.PostProjection == nil {
return fmt.Errorf("missing assistant projection weights")
}
m.EmbedTokens = model.MakeEmbeddingLayer(tensors, prefix+"model.embed_tokens", m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant)
if m.EmbedTokens == nil {
return fmt.Errorf("missing assistant embedding weight")
}
m.LMHead = m.EmbedTokens.AsLinear()
if m.UseOrderedEmbeddings {
m.Centroids = linears.Make(prefix + "masked_embedding.centroids")
m.TokenOrdering = tensors[prefix+"masked_embedding.token_ordering"]
if m.Centroids == nil || m.TokenOrdering == nil {
return fmt.Errorf("missing ordered embedding tensors: %smasked_embedding.centroids.weight and %smasked_embedding.token_ordering", prefix, prefix)
}
m.TokenOrdering = m.TokenOrdering.AsType(mlx.DTypeInt32)
}
normWeight := tensors[prefix+"model.norm.weight"]
if normWeight == nil {
return fmt.Errorf("missing assistant final norm")
}
m.Norm = nn.NewRMSNorm(normWeight, m.TextConfig.RMSNormEps)
for i := range m.TextConfig.NumHiddenLayers {
layerPrefix := fmt.Sprintf("%smodel.layers.%d", prefix, i)
layer := &AssistantLayer{
IsSliding: isLayerSliding(i, &m.TextConfig),
Attention: &AssistantAttention{
QProj: linears.Make(layerPrefix + ".self_attn.q_proj"),
OProj: linears.Make(layerPrefix + ".self_attn.o_proj"),
},
MLP: &MLP{
GateProj: linears.Make(layerPrefix + ".mlp.gate_proj"),
UpProj: linears.Make(layerPrefix + ".mlp.up_proj"),
DownProj: linears.Make(layerPrefix + ".mlp.down_proj"),
},
LayerScalar: tensors[layerPrefix+".layer_scalar"],
}
if w := tensors[layerPrefix+".input_layernorm.weight"]; w != nil {
layer.InputNorm = nn.NewRMSNorm(w, m.TextConfig.RMSNormEps)
}
if w := tensors[layerPrefix+".post_attention_layernorm.weight"]; w != nil {
layer.PostAttnNorm = nn.NewRMSNorm(w, m.TextConfig.RMSNormEps)
}
if w := tensors[layerPrefix+".pre_feedforward_layernorm.weight"]; w != nil {
layer.PreFFNorm = nn.NewRMSNorm(w, m.TextConfig.RMSNormEps)
}
if w := tensors[layerPrefix+".post_feedforward_layernorm.weight"]; w != nil {
layer.PostFFNorm = nn.NewRMSNorm(w, m.TextConfig.RMSNormEps)
}
if w := tensors[layerPrefix+".self_attn.q_norm.weight"]; w != nil {
layer.Attention.QNorm = nn.NewRMSNorm(w, m.TextConfig.RMSNormEps)
}
if layer.InputNorm == nil || layer.PostAttnNorm == nil || layer.PreFFNorm == nil || layer.PostFFNorm == nil {
return fmt.Errorf("assistant layer %d: missing norm weights", i)
}
if layer.Attention.QProj == nil || layer.Attention.OProj == nil || layer.Attention.QNorm == nil {
return fmt.Errorf("assistant layer %d: missing attention weights", i)
}
if layer.MLP.GateProj == nil || layer.MLP.UpProj == nil || layer.MLP.DownProj == nil {
return fmt.Errorf("assistant layer %d: missing mlp weights", i)
}
m.Layers[i] = layer
}
m.precomputeScaledWeights()
return nil
}
func (m *AssistantModel) precomputeScaledWeights() {
if m.Norm != nil {
m.NormScaled = m.Norm.Weight
}
for _, layer := range m.Layers {
if layer.InputNorm != nil {
layer.InputNormScaled = layer.InputNorm.Weight
}
if layer.PostAttnNorm != nil {
layer.PostAttnNormScaled = layer.PostAttnNorm.Weight
}
if layer.PreFFNorm != nil {
layer.PreFFNormScaled = layer.PreFFNorm.Weight
}
if layer.PostFFNorm != nil {
layer.PostFFNormScaled = layer.PostFFNorm.Weight
}
if layer.Attention != nil && layer.Attention.QNorm != nil {
layer.Attention.QNormScaled = layer.Attention.QNorm.Weight
}
}
}
func (m *AssistantModel) Draft(inputsEmbeds *mlx.Array, position int32, caches []cache.Cache) (logits, hidden *mlx.Array) {
dims := inputsEmbeds.Dims()
B, L := int32(dims[0]), int32(dims[1])
b := &batch.Batch{
InputIDs: mlx.Zeros(mlx.DTypeInt32, int(B), int(L)),
SeqOffsets: []int32{position},
SeqQueryLens: []int32{L},
}
sliding, full := m.sharedHistories(b, caches)
h := m.PreProjection.Forward(inputsEmbeds)
positions := mlx.FromValues([]int32{position}, 1)
for _, layer := range m.Layers {
h = layer.Forward(h, b, positions, B, L, &m.TextConfig, sliding, full)
}
hidden = mlx.RMSNormFn(h, m.NormScaled, m.TextConfig.RMSNormEps)
projected := m.PostProjection.Forward(hidden)
return m.unembed(hidden), projected
}
func (m *AssistantModel) sharedHistories(b *batch.Batch, caches []cache.Cache) (sliding, full *nn.KVHistory) {
if len(caches) < 2 {
return nil, nil
}
if v, ok := caches[len(caches)-2].(cache.Viewer); ok {
sliding = v.View(b)
}
if v, ok := caches[len(caches)-1].(cache.Viewer); ok {
full = v.View(b)
}
return sliding, full
}
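// This leans on an ordering guarantee from the target's NewCaches: the last
// two caches belong to a sliding-attention layer and a full-attention layer
// respectively (exercised by TestNewCachesAssistantSharedHistoryOrdering
// below), so the assistant can attend over the target's KV history without
// owning any cache of its own.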
func (m *AssistantModel) unembed(hidden *mlx.Array) *mlx.Array {
if m.UseOrderedEmbeddings {
return m.applyCentroidMasking(hidden)
}
return m.LMHead.Forward(hidden)
}
func (m *AssistantModel) applyCentroidMasking(hidden *mlx.Array) *mlx.Array {
B, L := hidden.Dim(0), hidden.Dim(1)
vocab := int(m.TextConfig.VocabSize)
numCentroids := int(m.NumCentroids)
vocabPerCentroid := vocab / numCentroids
topK := int(m.CentroidIntermediateTopK)
centroidLogits := m.Centroids.Forward(hidden)
topKIndices := centroidLogits.Negative().ArgpartitionAxis(topK-1, -1).Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, topK))
ordering := m.TokenOrdering.Reshape(numCentroids, vocabPerCentroid)
selectedCanonical := ordering.TakeAxis(topKIndices, 0)
selectedFlat := selectedCanonical.Reshape(B * L * topK * vocabPerCentroid)
embeddings := m.EmbedTokens.Forward(selectedFlat)
embeddings = embeddings.Reshape(B, L, topK*vocabPerCentroid, int(m.TextConfig.HiddenSize))
selectedLogits := hidden.ExpandDims(2).Matmul(embeddings.Transpose(0, 1, 3, 2)).Squeeze(2)
out := mlx.Zeros(selectedLogits.DType(), B, L, vocab)
out = mlx.AddScalar(out, -1.0e30)
return out.PutAlongAxis(selectedCanonical.Reshape(B, L, topK*vocabPerCentroid), selectedLogits, -1)
}
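// Back-of-the-envelope with the defaults (numCentroids = 2048, topK = 32) and
// an illustrative vocab of 262144: each centroid owns 262144/2048 = 128
// tokens, so only 32*128 = 4096 logits (~1.6% of the vocab) are computed via
// the embedding matmul; every other position is masked to -1e30 before the
// scatter, making those tokens effectively unsampleable.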
func (l *AssistantLayer) Forward(x *mlx.Array, b *batch.Batch, positions *mlx.Array, B, L int32, cfg *TextConfig, sliding, full *nn.KVHistory) *mlx.Array {
normed := mlx.RMSNormFn(x, l.InputNormScaled, cfg.RMSNormEps)
attnOut := l.Attention.Forward(normed, b, positions, B, L, l.IsSliding, cfg, sliding, full)
attnOut = mlx.RMSNormFn(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps)
h := mlx.Add(x, attnOut)
normed = mlx.RMSNormFn(h, l.PreFFNormScaled, cfg.RMSNormEps)
mlpOut := l.MLP.Forward(normed)
mlpOut = mlx.RMSNormFn(mlpOut, l.PostFFNormScaled, cfg.RMSNormEps)
h = mlx.Add(h, mlpOut)
if l.LayerScalar != nil {
h = mlx.Mul(h, l.LayerScalar)
}
return h
}
func (a *AssistantAttention) Forward(x *mlx.Array, b *batch.Batch, positions *mlx.Array, B, L int32, isSliding bool, cfg *TextConfig, sliding, full *nn.KVHistory) *mlx.Array {
headDim := cfg.HeadDim
scale := cfg.SlidingScale
ropeDims := cfg.SlidingRopeDims
ropeBase := cfg.SlidingRopeBase
history := sliding
if !isSliding {
headDim = cfg.GlobalHeadDim
scale = cfg.FullScale
ropeDims = cfg.FullRopeDims
ropeBase = cfg.FullRopeBase
history = full
}
if history == nil {
panic("gemma4 assistant missing shared target KV history")
}
q := a.QProj.Forward(x)
q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, headDim)
q = mlx.Transpose(q, 0, 2, 1, 3)
q = mlx.RMSNormFn(q, a.QNormScaled, cfg.RMSNormEps)
var ropeFreqs *mlx.Array
if !isSliding {
ropeFreqs = cfg.FullRopeFreqs
}
q = mlx.RoPEWithFreqs(q, ropeDims, false, ropeBase, 1.0, positions, ropeFreqs)
mask := nn.CausalMask()
if isSliding && cfg.SlidingWindow > 0 {
mask = mask.Intersect(nn.SlidingWindowMask(b, history.K().Dim(2), int(cfg.SlidingWindow), q.DType()))
}
out := nn.ScaledDotProductAttention(b, q, scale, nn.WithKVHistory(history), nn.WithMask(mask))
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*headDim)
if !mlx.MetalIsAvailable() {
out = mlx.Contiguous(out, false)
}
return a.OProj.Forward(out)
}

View file

@ -18,10 +18,15 @@ import (
func init() {
base.Register("Gemma4ForCausalLM", newModel)
base.Register("Gemma4ForConditionalGeneration", newModel)
base.RegisterDraft("Gemma4AssistantForCausalLM", newAssistantModel)
base.RegisterDraft("gemma4_assistant", newAssistantModel)
}
// Compile-time interface checks.
var _ base.Model = (*Model)(nil)
var (
_ base.Model = (*Model)(nil)
_ base.MTPDefaultsProvider = (*Model)(nil)
)
// RopeParams holds per-layer-type RoPE settings.
type RopeParams struct {
@ -466,6 +471,24 @@ func (m *Model) EnableCompile() bool {
return true
}
func (m *Model) MTPDraftDefaults(_ bool) base.MTPDefaults {
defaults := base.MTPDefaults{
InitialDraftTokens: 4,
MaxDraftTokens: 16,
Enabled: true,
}
if m == nil || m.TextConfig == nil {
return defaults
}
switch {
case !m.EnableMoeBlock && m.HiddenSize == 5376 && m.NumHiddenLayers == 60:
defaults.InitialDraftTokens = 14
case m.EnableMoeBlock && m.HiddenSize == 2816 && m.NumHiddenLayers == 30:
defaults.InitialDraftTokens = 8
}
return defaults
}
func resolveWeightPrefix(tensors map[string]*mlx.Array) string {
for _, prefix := range []string{"", "language_model.", "model.language_model."} {
if tensors[prefix+"embed_tokens.weight"] != nil {
@ -1068,6 +1091,11 @@ func (m *Model) Tokenizer() *tokenizer.Tokenizer {
return m.tok
}
// TokenEmbeddings returns the target model's scaled token embeddings for MTP.
func (m *Model) TokenEmbeddings(inputIDs *mlx.Array) *mlx.Array {
return mlx.MulScalar(m.EmbedTokens.Forward(inputIDs), m.EmbedScale)
}
// NewCaches creates cache objects for layers that own KV state.
func (m *Model) NewCaches() []cache.Cache {
cacheLayers := len(m.Layers)

View file

@ -434,6 +434,54 @@ func TestLayerTypeDetection(t *testing.T) {
}
}
func TestMTPDraftDefaults(t *testing.T) {
tests := []struct {
name string
cfg *TextConfig
wantInitial int
wantMax int
}{
{
name: "nil config",
wantInitial: 4,
wantMax: 16,
},
{
name: "31b bf16",
cfg: &TextConfig{HiddenSize: 5376, NumHiddenLayers: 60},
wantInitial: 14,
wantMax: 16,
},
{
name: "31b quantized",
cfg: &TextConfig{HiddenSize: 5376, NumHiddenLayers: 60, QuantBits: 4},
wantInitial: 14,
wantMax: 16,
},
{
name: "26b-a4b moe",
cfg: &TextConfig{HiddenSize: 2816, NumHiddenLayers: 30, EnableMoeBlock: true},
wantInitial: 8,
wantMax: 16,
},
{
name: "generic default",
cfg: &TextConfig{HiddenSize: 2560, NumHiddenLayers: 42, HiddenSizePerLayer: 256},
wantInitial: 4,
wantMax: 16,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := (&Model{TextConfig: tt.cfg}).MTPDraftDefaults(false)
if got.InitialDraftTokens != tt.wantInitial || got.MaxDraftTokens != tt.wantMax || !got.Enabled {
t.Fatalf("MTPDraftDefaults() = %+v, want initial=%d max=%d enabled=true", got, tt.wantInitial, tt.wantMax)
}
})
}
}
func TestNewCachesOmitsSharedKVLayers(t *testing.T) {
m := &Model{
Layers: []*DecoderLayer{
@ -467,6 +515,55 @@ func TestNewCachesIncludesAllNonSharedLayers(t *testing.T) {
}
}
func TestNewCachesAssistantSharedHistoryOrdering(t *testing.T) {
cases := []struct {
name string
totalLayers int
slidingBeforeFull int
cacheLayers int
}{
{name: "31B", totalLayers: 60, slidingBeforeFull: 5, cacheLayers: 60},
{name: "26B-A4B", totalLayers: 30, slidingBeforeFull: 5, cacheLayers: 30},
{name: "E4B", totalLayers: 42, slidingBeforeFull: 5, cacheLayers: 24},
{name: "E2B", totalLayers: 35, slidingBeforeFull: 4, cacheLayers: 15},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
groupSize := tc.slidingBeforeFull + 1
layers := make([]*DecoderLayer, tc.totalLayers)
for i := range layers {
donor := int32(-1)
if i >= tc.cacheLayers {
donor = 0
}
layers[i] = &DecoderLayer{
IsSliding: i%groupSize < tc.slidingBeforeFull,
KVShareDonor: donor,
}
}
m := &Model{
Layers: layers,
TextConfig: &TextConfig{SlidingWindow: 512},
}
caches := m.NewCaches()
if got := len(caches); got != tc.cacheLayers {
t.Fatalf("len(NewCaches()) = %d, want %d", got, tc.cacheLayers)
}
gotSliding := len(caches) - 2
gotFull := len(caches) - 1
if !m.Layers[gotSliding].IsSliding {
t.Fatalf("cache %d should be sliding attention", gotSliding)
}
if m.Layers[gotFull].IsSliding {
t.Fatalf("cache %d should be full attention", gotFull)
}
})
}
}
func TestResolveWeightPrefix(t *testing.T) {
if err := mlx.CheckInit(); err != nil {
t.Skipf("MLX not available: %v", err)