mirror of
https://github.com/ollama/ollama.git
synced 2026-05-13 06:21:28 +00:00
mlx: Gemma4 MTP speculative decoding (#15980)
This change adds support for MTP (multi-token prediction) speculative decoding for the gemma4 model family. It includes:
* support for importing safetensors-based gemma4 draft models with `ollama create`
* a new DRAFT command in the Modelfile for specifying draft models
* a --draft-quantize flag for the ollama create command to quantize the draft model
* cache support for speculation
* changes to the rotating cache to be able to handle MTP correctly
* sampling support for draft model token prediction
---------
Co-authored-by: Daniel Hiltgen <daniel@ollama.com>
This commit is contained in:
parent
4017af96cd
commit
15e6076d79
28 changed files with 2928 additions and 42 deletions
7
.github/workflows/release.yaml
vendored
7
.github/workflows/release.yaml
vendored
|
|
@ -423,8 +423,8 @@ jobs:
|
|||
lib/ollama/*.so*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/cuda_v*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/vulkan*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/mlx*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/include*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/mlx*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-mlx.tar.in ;;
|
||||
lib/ollama/include*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-mlx.tar.in ;;
|
||||
lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
|
||||
lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
|
||||
lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;
|
||||
|
|
@ -458,12 +458,14 @@ jobs:
|
|||
CGO_CFLAGS
|
||||
CGO_CXXFLAGS
|
||||
GOFLAGS
|
||||
APT_MIRROR=http://azure.archive.ubuntu.com/ubuntu
|
||||
- os: linux
|
||||
arch: amd64
|
||||
build-args: |
|
||||
CGO_CFLAGS
|
||||
CGO_CXXFLAGS
|
||||
GOFLAGS
|
||||
APT_MIRROR=http://azure.archive.ubuntu.com/ubuntu
|
||||
- os: linux
|
||||
arch: amd64
|
||||
suffix: '-rocm'
|
||||
|
|
@ -472,6 +474,7 @@ jobs:
|
|||
CGO_CXXFLAGS
|
||||
GOFLAGS
|
||||
FLAVOR=rocm
|
||||
APT_MIRROR=http://azure.archive.ubuntu.com/ubuntu
|
||||
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
|
||||
environment: release
|
||||
needs: setup-environment
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@
|
|||
"inherits": [ "CUDA" ],
|
||||
"cacheVariables": {
|
||||
"CMAKE_CUDA_ARCHITECTURES": "75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual;90a-virtual;100-virtual;103-virtual;110-virtual;120-virtual;121-virtual",
|
||||
"CMAKE_CUDA_FLAGS": "-t 4",
|
||||
"CMAKE_CUDA_FLAGS": "-t 2",
|
||||
"OLLAMA_RUNNER_DIR": "cuda_v13"
|
||||
}
|
||||
},
|
||||
|
|
|
|||
|
|
@ -212,8 +212,11 @@ COPY --from=cpu dist/lib/ollama /lib/ollama
|
|||
COPY --from=build /bin/ollama /bin/ollama
|
||||
|
||||
FROM ubuntu:24.04
|
||||
RUN apt-get update \
|
||||
ARG APT_MIRROR=http://archive.ubuntu.com/ubuntu
|
||||
RUN sed -i "s|http://archive.ubuntu.com/ubuntu|$APT_MIRROR|g" /etc/apt/sources.list.d/ubuntu.sources \
|
||||
&& apt-get update \
|
||||
&& apt-get install -y ca-certificates libvulkan1 libopenblas0 \
|
||||
&& sed -i "s|$APT_MIRROR|http://archive.ubuntu.com/ubuntu|g" /etc/apt/sources.list.d/ubuntu.sources \
|
||||
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
COPY --from=archive /bin /usr/bin
|
||||
|
|
|
|||
61
cmd/cmd.go
61
cmd/cmd.go
|
|
@ -54,6 +54,7 @@ import (
|
|||
"github.com/ollama/ollama/types/syncmap"
|
||||
"github.com/ollama/ollama/version"
|
||||
xcmd "github.com/ollama/ollama/x/cmd"
|
||||
xcreate "github.com/ollama/ollama/x/create"
|
||||
xcreateclient "github.com/ollama/ollama/x/create/client"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
)
|
||||
|
|
@ -145,6 +146,39 @@ func isLocalhost() bool {
|
|||
return ip != nil && (ip.IsLoopback() || ip.IsUnspecified())
|
||||
}
|
||||
|
||||
func resolveExperimentalLocalModelDir(ref, filename string) string {
|
||||
if ref == "" || filepath.IsAbs(ref) || filename == "" {
|
||||
return ref
|
||||
}
|
||||
|
||||
candidate := filepath.Join(filepath.Dir(filename), ref)
|
||||
if xcreate.IsSafetensorsModelDir(candidate) || xcreate.IsTensorModelDir(candidate) {
|
||||
return candidate
|
||||
}
|
||||
|
||||
return ref
|
||||
}
|
||||
|
||||
func resolveExperimentalDraftDir(ref, filename string) (string, error) {
|
||||
if ref == "" {
|
||||
return "", nil
|
||||
}
|
||||
if filepath.IsAbs(ref) {
|
||||
if xcreate.IsSafetensorsModelDir(ref) {
|
||||
return ref, nil
|
||||
}
|
||||
return "", fmt.Errorf("draft %s is not a supported safetensors model directory", ref)
|
||||
}
|
||||
if filename != "" {
|
||||
candidate := filepath.Join(filepath.Dir(filename), ref)
|
||||
if xcreate.IsSafetensorsModelDir(candidate) {
|
||||
return candidate, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("DRAFT model references are not supported with --experimental yet: %s", ref)
|
||||
}
|
||||
|
||||
func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
defer p.Stop()
|
||||
|
|
@ -159,6 +193,10 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
|||
// Check for --experimental flag for safetensors model creation
|
||||
// This gates both safetensors LLM and imagegen model creation
|
||||
experimental, _ := cmd.Flags().GetBool("experimental")
|
||||
draftQuantize, _ := cmd.Flags().GetString("draft-quantize")
|
||||
if draftQuantize != "" && !experimental {
|
||||
return errors.New("--draft-quantize requires --experimental")
|
||||
}
|
||||
if experimental {
|
||||
if !isLocalhost() {
|
||||
return errors.New("remote safetensor model creation not yet supported")
|
||||
|
|
@ -192,17 +230,22 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
|||
return err
|
||||
}
|
||||
|
||||
// Resolve relative paths based on Modelfile location
|
||||
if !filepath.IsAbs(modelDir) && filename != "" {
|
||||
modelDir = filepath.Join(filepath.Dir(filename), modelDir)
|
||||
modelDir = resolveExperimentalLocalModelDir(modelDir, filename)
|
||||
if mfConfig.Draft != "" {
|
||||
draftDir, err := resolveExperimentalDraftDir(mfConfig.Draft, filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
mfConfig.Draft = draftDir
|
||||
}
|
||||
|
||||
quantize, _ := cmd.Flags().GetString("quantize")
|
||||
return xcreateclient.CreateModel(xcreateclient.CreateOptions{
|
||||
ModelName: modelName,
|
||||
ModelDir: modelDir,
|
||||
Quantize: quantize,
|
||||
Modelfile: mfConfig,
|
||||
ModelName: modelName,
|
||||
ModelDir: modelDir,
|
||||
Quantize: quantize,
|
||||
DraftQuantize: draftQuantize,
|
||||
Modelfile: mfConfig,
|
||||
}, p)
|
||||
}
|
||||
|
||||
|
|
@ -2176,6 +2219,9 @@ func NewCLI() *cobra.Command {
|
|||
if experimental, _ := cmd.Flags().GetBool("experimental"); experimental {
|
||||
return nil
|
||||
}
|
||||
if draftQuantize, _ := cmd.Flags().GetString("draft-quantize"); draftQuantize != "" {
|
||||
return errors.New("--draft-quantize requires --experimental")
|
||||
}
|
||||
return checkServerHeartbeat(cmd, args)
|
||||
},
|
||||
RunE: CreateHandler,
|
||||
|
|
@ -2183,6 +2229,7 @@ func NewCLI() *cobra.Command {
|
|||
|
||||
createCmd.Flags().StringP("file", "f", "", "Name of the Modelfile (default \"Modelfile\")")
|
||||
createCmd.Flags().StringP("quantize", "q", "", "Quantize model to this level (e.g. q4_K_M)")
|
||||
createCmd.Flags().String("draft-quantize", "", "Quantize draft model to this level")
|
||||
createCmd.Flags().Bool("experimental", false, "Enable experimental safetensors model creation")
|
||||
|
||||
showCmd := &cobra.Command{
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import (
|
|||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
|
@ -1524,6 +1525,87 @@ func TestCreateHandler(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestCreateHandlerDraftQuantizeRequiresExperimental(t *testing.T) {
|
||||
cmd := &cobra.Command{}
|
||||
cmd.Flags().Bool("experimental", false, "")
|
||||
cmd.Flags().String("draft-quantize", "mxfp8", "")
|
||||
cmd.SetContext(t.Context())
|
||||
|
||||
err := CreateHandler(cmd, []string{"test-model"})
|
||||
if err == nil || !strings.Contains(err.Error(), "--draft-quantize requires --experimental") {
|
||||
t.Fatalf("error = %v, want draft-quantize requires experimental", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateHandlerDraftRequiresExperimental(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
modelfile := filepath.Join(dir, "Modelfile")
|
||||
if err := os.WriteFile(modelfile, []byte("FROM base\nDRAFT ./assistant\n"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.Flags().Bool("experimental", false, "")
|
||||
cmd.Flags().String("draft-quantize", "", "")
|
||||
cmd.Flags().String("file", modelfile, "")
|
||||
cmd.SetContext(t.Context())
|
||||
|
||||
err := CreateHandler(cmd, []string{"test-model"})
|
||||
if err == nil || !strings.Contains(err.Error(), "DRAFT requires --experimental") {
|
||||
t.Fatalf("error = %v, want DRAFT requires --experimental", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveExperimentalLocalModelDir(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
modelfile := filepath.Join(dir, "Modelfile")
|
||||
modelDir := filepath.Join(dir, "model")
|
||||
if err := os.Mkdir(modelDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(modelDir, "config.json"), []byte(`{}`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(modelDir, "model.safetensors"), []byte("dummy"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if got := resolveExperimentalLocalModelDir("gemma4", modelfile); got != "gemma4" {
|
||||
t.Fatalf("resolveExperimentalLocalModelDir(model name) = %q, want gemma4", got)
|
||||
}
|
||||
if got := resolveExperimentalLocalModelDir("./model", modelfile); got != modelDir {
|
||||
t.Fatalf("resolveExperimentalLocalModelDir(local dir) = %q, want %q", got, modelDir)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveExperimentalDraftDir(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
modelfile := filepath.Join(dir, "Modelfile")
|
||||
draftDir := filepath.Join(dir, "assistant")
|
||||
if err := os.Mkdir(draftDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(draftDir, "config.json"), []byte(`{}`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(draftDir, "model.safetensors"), []byte("dummy"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
got, err := resolveExperimentalDraftDir("./assistant", modelfile)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got != draftDir {
|
||||
t.Fatalf("resolveExperimentalDraftDir(local dir) = %q, want %q", got, draftDir)
|
||||
}
|
||||
|
||||
_, err = resolveExperimentalDraftDir("assistant-model", modelfile)
|
||||
if err == nil || !strings.Contains(err.Error(), "DRAFT model references are not supported with --experimental yet") {
|
||||
t.Fatalf("error = %v, want unsupported draft model reference", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewCreateRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
|
|
|||
|
|
@ -83,6 +83,8 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
|
|||
req.Files[k] = v
|
||||
}
|
||||
}
|
||||
case "draft":
|
||||
return nil, errors.New("DRAFT requires --experimental")
|
||||
case "adapter":
|
||||
path, err := expandPath(c.Args, relativeDir)
|
||||
if err != nil {
|
||||
|
|
@ -336,7 +338,7 @@ func (c Command) String() string {
|
|||
switch c.Name {
|
||||
case "model":
|
||||
fmt.Fprintf(&sb, "FROM %s", c.Args)
|
||||
case "license", "template", "system", "adapter", "renderer", "parser", "requires":
|
||||
case "license", "template", "system", "adapter", "renderer", "parser", "requires", "draft":
|
||||
fmt.Fprintf(&sb, "%s %s", strings.ToUpper(c.Name), quote(c.Args))
|
||||
case "message":
|
||||
role, message, _ := strings.Cut(c.Args, ": ")
|
||||
|
|
@ -362,7 +364,7 @@ const (
|
|||
var (
|
||||
errMissingFrom = errors.New("no FROM line")
|
||||
errInvalidMessageRole = errors.New("message role must be one of \"system\", \"user\", or \"assistant\"")
|
||||
errInvalidCommand = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"renderer\", \"parser\", \"parameter\", \"message\", or \"requires\"")
|
||||
errInvalidCommand = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"draft\", \"renderer\", \"parser\", \"parameter\", \"message\", or \"requires\"")
|
||||
)
|
||||
|
||||
type ParserError struct {
|
||||
|
|
@ -622,7 +624,7 @@ func isValidMessageRole(role string) bool {
|
|||
|
||||
func isValidCommand(cmd string) bool {
|
||||
switch strings.ToLower(cmd) {
|
||||
case "from", "license", "template", "system", "adapter", "renderer", "parser", "parameter", "message", "requires":
|
||||
case "from", "license", "template", "system", "adapter", "draft", "renderer", "parser", "parameter", "message", "requires":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
|
|
|
|||
|
|
@ -58,6 +58,32 @@ TEMPLATE """{{ if .System }}<|start_header_id|>system<|end_header_id|>
|
|||
assert.Equal(t, expectedCommands, modelfile.Commands)
|
||||
}
|
||||
|
||||
func TestParseFileDraft(t *testing.T) {
|
||||
modelfile, err := ParseFile(strings.NewReader(`
|
||||
FROM base
|
||||
DRAFT ./assistant
|
||||
`))
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedCommands := []Command{
|
||||
{Name: "model", Args: "base"},
|
||||
{Name: "draft", Args: "./assistant"},
|
||||
}
|
||||
assert.Equal(t, expectedCommands, modelfile.Commands)
|
||||
assert.Contains(t, modelfile.String(), "DRAFT ./assistant")
|
||||
}
|
||||
|
||||
func TestCreateRequestDraftRequiresExperimental(t *testing.T) {
|
||||
modelfile, err := ParseFile(strings.NewReader(`
|
||||
FROM base
|
||||
DRAFT ./assistant
|
||||
`))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = modelfile.CreateRequest("")
|
||||
require.ErrorContains(t, err, "DRAFT requires --experimental")
|
||||
}
|
||||
|
||||
func TestParseFileTrimSpace(t *testing.T) {
|
||||
input := `
|
||||
FROM " model 1"
|
||||
|
|
|
|||
|
|
@ -62,13 +62,15 @@ if echo $PLATFORM | grep "," > /dev/null ; then
|
|||
tar c -C ./dist/linux_arm64 --exclude cuda_jetpack5 --exclude cuda_jetpack6 . | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64.tar.zst
|
||||
tar c -C ./dist/linux_arm64 ./lib/ollama/cuda_jetpack5 | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64-jetpack5.tar.zst
|
||||
tar c -C ./dist/linux_arm64 ./lib/ollama/cuda_jetpack6 | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64-jetpack6.tar.zst
|
||||
tar c -C ./dist/linux_amd64 --exclude rocm . | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64.tar.zst
|
||||
tar c -C ./dist/linux_amd64 --exclude rocm --exclude 'mlx*' --exclude include . | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64.tar.zst
|
||||
tar c -C ./dist/linux_amd64 ./lib/ollama/rocm | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64-rocm.tar.zst
|
||||
tar c -C ./dist/linux_amd64 ./lib/ollama/mlx_cuda_v13 ./lib/ollama/include | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64-mlx.tar.zst
|
||||
elif echo $PLATFORM | grep "arm64" > /dev/null ; then
|
||||
tar c -C ./dist/ --exclude cuda_jetpack5 --exclude cuda_jetpack6 bin lib | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64.tar.zst
|
||||
tar c -C ./dist/ ./lib/ollama/cuda_jetpack5 | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64-jetpack5.tar.zst
|
||||
tar c -C ./dist/ ./lib/ollama/cuda_jetpack6 | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64-jetpack6.tar.zst
|
||||
elif echo $PLATFORM | grep "amd64" > /dev/null ; then
|
||||
tar c -C ./dist/ --exclude rocm bin lib | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64.tar.zst
|
||||
tar c -C ./dist/ --exclude rocm --exclude 'mlx*' --exclude include bin lib | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64.tar.zst
|
||||
tar c -C ./dist/ ./lib/ollama/rocm | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64-rocm.tar.zst
|
||||
tar c -C ./dist/ ./lib/ollama/mlx_cuda_v13 ./lib/ollama/include | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64-mlx.tar.zst
|
||||
fi
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ type ConfigV2 struct {
|
|||
ContextLen int `json:"context_length,omitempty"`
|
||||
EmbedLen int `json:"embedding_length,omitempty"`
|
||||
BaseName string `json:"base_name,omitempty"`
|
||||
Draft *Draft `json:"draft,omitempty"`
|
||||
|
||||
// required by spec
|
||||
Architecture string `json:"architecture"`
|
||||
|
|
@ -26,6 +27,14 @@ type ConfigV2 struct {
|
|||
RootFS RootFS `json:"rootfs"`
|
||||
}
|
||||
|
||||
// Draft describes an auxiliary draft model stored in the same manifest.
//
// Every field is optional in the serialized form: empty values are omitted
// from the JSON encoding.
type Draft struct {
	// ModelFormat is serialized as "model_format" (e.g. "safetensors").
	ModelFormat string `json:"model_format,omitempty"`
	// Architecture is serialized as "architecture".
	Architecture string `json:"architecture,omitempty"`
	// TensorPrefix is serialized as "tensor_prefix" (e.g. "draft.").
	TensorPrefix string `json:"tensor_prefix,omitempty"`
	// Config is serialized as "config" (e.g. "draft/config.json").
	Config string `json:"config,omitempty"`
}
|
||||
|
||||
// RootFS represents the root filesystem configuration for a model.
|
||||
type RootFS struct {
|
||||
Type string `json:"type"`
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ import (
|
|||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/x/create"
|
||||
imagemanifest "github.com/ollama/ollama/x/imagegen/manifest"
|
||||
"github.com/ollama/ollama/x/safetensors"
|
||||
)
|
||||
|
||||
|
|
@ -34,6 +35,7 @@ type ModelfileConfig struct {
|
|||
Template string
|
||||
System string
|
||||
License string
|
||||
Draft string
|
||||
Parser string
|
||||
Renderer string
|
||||
Parameters map[string]any
|
||||
|
|
@ -67,6 +69,8 @@ func ConfigFromModelfile(modelfile *parser.Modelfile) (string, *ModelfileConfig,
|
|||
mfConfig.System = cmd.Args
|
||||
case "license":
|
||||
mfConfig.License = cmd.Args
|
||||
case "draft":
|
||||
mfConfig.Draft = cmd.Args
|
||||
case "parser":
|
||||
mfConfig.Parser = cmd.Args
|
||||
case "renderer":
|
||||
|
|
@ -108,10 +112,12 @@ func ConfigFromModelfile(modelfile *parser.Modelfile) (string, *ModelfileConfig,
|
|||
|
||||
// CreateOptions holds all options for model creation.
|
||||
type CreateOptions struct {
|
||||
ModelName string
|
||||
ModelDir string
|
||||
Quantize string // "int4", "int8", "nvfp4", "mxfp4", or "mxfp8" for quantization
|
||||
Modelfile *ModelfileConfig // template/system/license/parser/renderer/parameters from Modelfile
|
||||
ModelName string
|
||||
ModelDir string
|
||||
Quantize string // "int4", "int8", "nvfp4", "mxfp4", or "mxfp8" for quantization
|
||||
DraftQuantize string // optional quantization level for draft model tensors
|
||||
Modelfile *ModelfileConfig // template/system/license/parser/renderer/parameters from Modelfile
|
||||
BaseConfig *model.ConfigV2
|
||||
}
|
||||
|
||||
// CreateModel imports a model from a local directory.
|
||||
|
|
@ -121,11 +127,23 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
|
|||
// Detect model type
|
||||
isSafetensors := create.IsSafetensorsModelDir(opts.ModelDir)
|
||||
isImageGen := create.IsTensorModelDir(opts.ModelDir)
|
||||
hasDraft := opts.Modelfile != nil && opts.Modelfile.Draft != ""
|
||||
isBaseModelWithDraft := hasDraft && !isSafetensors && create.IsSafetensorsLLMModel(opts.ModelDir)
|
||||
if opts.DraftQuantize != "" && !hasDraft {
|
||||
return fmt.Errorf("--draft-quantize requires a DRAFT model")
|
||||
}
|
||||
|
||||
if !isSafetensors && !isImageGen {
|
||||
if !isSafetensors && !isImageGen && !isBaseModelWithDraft {
|
||||
return fmt.Errorf("%s is not a supported model directory (needs config.json + *.safetensors or model_index.json)", opts.ModelDir)
|
||||
}
|
||||
|
||||
if hasDraft && !create.IsSafetensorsModelDir(opts.Modelfile.Draft) {
|
||||
return fmt.Errorf("draft %s is not a supported safetensors model directory", opts.Modelfile.Draft)
|
||||
}
|
||||
if hasDraft && isImageGen {
|
||||
return fmt.Errorf("draft models are only supported for safetensors LLM models")
|
||||
}
|
||||
|
||||
// Determine model type settings
|
||||
var modelType, spinnerKey string
|
||||
var capabilities []string
|
||||
|
|
@ -138,6 +156,9 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
|
|||
parserName = getParserName(opts.ModelDir)
|
||||
rendererName = getRendererName(opts.ModelDir)
|
||||
capabilities = inferSafetensorsCapabilities(opts.ModelDir, resolveParserName(opts.Modelfile, parserName))
|
||||
} else if isBaseModelWithDraft {
|
||||
modelType = "safetensors model"
|
||||
spinnerKey = "create"
|
||||
} else {
|
||||
modelType = "image generation model"
|
||||
spinnerKey = "imagegen"
|
||||
|
|
@ -156,13 +177,44 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
|
|||
p.Add(spinnerKey, spinner)
|
||||
}
|
||||
|
||||
// Create the model using shared callbacks
|
||||
var draftLayers []create.LayerInfo
|
||||
var err error
|
||||
if hasDraft {
|
||||
draftLayers, err = create.CreateDraftSafetensorsLayers(
|
||||
opts.Modelfile.Draft,
|
||||
"draft.",
|
||||
"draft",
|
||||
opts.DraftQuantize,
|
||||
newLayerCreator(),
|
||||
newTensorLayerCreator(),
|
||||
progressFn,
|
||||
)
|
||||
if err != nil {
|
||||
spinner.Stop()
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if isBaseModelWithDraft {
|
||||
err = createModelFromBaseWithDraft(opts, draftLayers, progressFn)
|
||||
spinner.Stop()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Printf("Created safetensors model '%s'\n", opts.ModelName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create the model using shared callbacks
|
||||
if isSafetensors {
|
||||
writer := newManifestWriter(opts, capabilities, parserName, rendererName)
|
||||
if len(draftLayers) > 0 {
|
||||
writer = appendLayersManifestWriter(writer, draftLayers)
|
||||
}
|
||||
err = create.CreateSafetensorsModel(
|
||||
opts.ModelName, opts.ModelDir, opts.Quantize,
|
||||
newLayerCreator(), newTensorLayerCreator(),
|
||||
newManifestWriter(opts, capabilities, parserName, rendererName),
|
||||
writer,
|
||||
progressFn,
|
||||
newPackedTensorLayerCreator(),
|
||||
)
|
||||
|
|
@ -184,6 +236,68 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func appendLayersManifestWriter(next create.ManifestWriter, extra []create.LayerInfo) create.ManifestWriter {
|
||||
return func(modelName string, config create.LayerInfo, layers []create.LayerInfo) error {
|
||||
layers = append(layers, extra...)
|
||||
return next(modelName, config, layers)
|
||||
}
|
||||
}
|
||||
|
||||
func createModelFromBaseWithDraft(opts CreateOptions, draftLayers []create.LayerInfo, progressFn func(string)) error {
|
||||
progressFn(fmt.Sprintf("loading base model %s", opts.ModelDir))
|
||||
baseManifest, err := imagemanifest.LoadManifest(opts.ModelDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
baseConfig, err := readConfigV2(baseManifest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
opts.BaseConfig = baseConfig
|
||||
|
||||
configLayer := baseManifest.GetConfigLayer("config.json")
|
||||
if configLayer == nil {
|
||||
return fmt.Errorf("base model %s does not contain config.json", opts.ModelDir)
|
||||
}
|
||||
|
||||
layers := make([]create.LayerInfo, 0, len(baseManifest.Manifest.Layers)+len(draftLayers))
|
||||
for _, layer := range baseManifest.Manifest.Layers {
|
||||
layers = append(layers, create.LayerInfo{
|
||||
Digest: layer.Digest,
|
||||
Size: layer.Size,
|
||||
MediaType: layer.MediaType,
|
||||
Name: layer.Name,
|
||||
})
|
||||
}
|
||||
layers = append(layers, draftLayers...)
|
||||
|
||||
progressFn(fmt.Sprintf("writing manifest for %s", opts.ModelName))
|
||||
return newManifestWriter(opts, baseConfig.Capabilities, baseConfig.Parser, baseConfig.Renderer)(
|
||||
opts.ModelName,
|
||||
create.LayerInfo{
|
||||
Digest: configLayer.Digest,
|
||||
Size: configLayer.Size,
|
||||
MediaType: configLayer.MediaType,
|
||||
Name: configLayer.Name,
|
||||
},
|
||||
layers,
|
||||
)
|
||||
}
|
||||
|
||||
func readConfigV2(m *imagemanifest.ModelManifest) (*model.ConfigV2, error) {
|
||||
data, err := os.ReadFile(m.BlobPath(m.Manifest.Config.Digest))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read base config: %w", err)
|
||||
}
|
||||
|
||||
var cfg model.ConfigV2
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse base config: %w", err)
|
||||
}
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
func inferSafetensorsCapabilities(modelDir, parserName string) []string {
|
||||
capabilities := []string{"completion"}
|
||||
|
||||
|
|
@ -359,14 +473,26 @@ func newManifestWriter(opts CreateOptions, capabilities []string, parserName, re
|
|||
}
|
||||
}
|
||||
|
||||
// Create config blob with version requirement
|
||||
configData := model.ConfigV2{
|
||||
ModelFormat: "safetensors",
|
||||
FileType: strings.ToLower(strings.TrimSpace(opts.Quantize)),
|
||||
Capabilities: caps,
|
||||
Requires: MinOllamaVersion,
|
||||
Parser: resolveParserName(opts.Modelfile, parserName),
|
||||
Renderer: resolveRendererName(opts.Modelfile, rendererName),
|
||||
// Create config blob with version requirement.
|
||||
configData := model.ConfigV2{}
|
||||
if opts.BaseConfig != nil {
|
||||
configData = *opts.BaseConfig
|
||||
}
|
||||
configData.ModelFormat = "safetensors"
|
||||
if opts.Quantize != "" || configData.FileType == "" {
|
||||
configData.FileType = strings.ToLower(strings.TrimSpace(opts.Quantize))
|
||||
}
|
||||
configData.Capabilities = caps
|
||||
configData.Requires = MinOllamaVersion
|
||||
configData.Parser = resolveParserName(opts.Modelfile, parserName)
|
||||
configData.Renderer = resolveRendererName(opts.Modelfile, rendererName)
|
||||
if opts.Modelfile != nil && opts.Modelfile.Draft != "" {
|
||||
configData.Draft = &model.Draft{
|
||||
ModelFormat: "safetensors",
|
||||
Architecture: "Gemma4AssistantForCausalLM",
|
||||
TensorPrefix: "draft.",
|
||||
Config: "draft/config.json",
|
||||
}
|
||||
}
|
||||
configJSON, err := json.Marshal(configData)
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -44,6 +44,7 @@ func TestModelfileConfig(t *testing.T) {
|
|||
func TestConfigFromModelfile(t *testing.T) {
|
||||
modelfile, err := parser.ParseFile(strings.NewReader(`
|
||||
FROM ./model
|
||||
DRAFT ./assistant
|
||||
TEMPLATE {{ .Prompt }}
|
||||
PARAMETER temperature 0.7
|
||||
PARAMETER stop USER:
|
||||
|
|
@ -66,6 +67,10 @@ PARAMETER stop ASSISTANT:
|
|||
t.Fatalf("Template = %q, want %q", mfConfig.Template, "{{ .Prompt }}")
|
||||
}
|
||||
|
||||
if mfConfig.Draft != "./assistant" {
|
||||
t.Fatalf("Draft = %q, want %q", mfConfig.Draft, "./assistant")
|
||||
}
|
||||
|
||||
if got := mfConfig.Parameters["temperature"]; got != float32(0.7) {
|
||||
t.Fatalf("temperature = %#v, want %v", got, float32(0.7))
|
||||
}
|
||||
|
|
@ -153,11 +158,23 @@ func TestCreateModel_NotSafetensorsDir(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestCreateModel_DraftQuantizeRequiresDraft(t *testing.T) {
|
||||
err := CreateModel(CreateOptions{
|
||||
ModelName: "test-model",
|
||||
ModelDir: t.TempDir(),
|
||||
DraftQuantize: "mxfp8",
|
||||
}, nil)
|
||||
if err == nil || !strings.Contains(err.Error(), "--draft-quantize requires a DRAFT model") {
|
||||
t.Fatalf("error = %v, want draft-quantize requires DRAFT", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateOptions(t *testing.T) {
|
||||
opts := CreateOptions{
|
||||
ModelName: "my-model",
|
||||
ModelDir: "/path/to/model",
|
||||
Quantize: "fp8",
|
||||
ModelName: "my-model",
|
||||
ModelDir: "/path/to/model",
|
||||
Quantize: "fp8",
|
||||
DraftQuantize: "mxfp8",
|
||||
Modelfile: &ModelfileConfig{
|
||||
Template: "test",
|
||||
System: "system",
|
||||
|
|
@ -179,6 +196,9 @@ func TestCreateOptions(t *testing.T) {
|
|||
if opts.Quantize != "fp8" {
|
||||
t.Errorf("Quantize = %q, want %q", opts.Quantize, "fp8")
|
||||
}
|
||||
if opts.DraftQuantize != "mxfp8" {
|
||||
t.Errorf("DraftQuantize = %q, want %q", opts.DraftQuantize, "mxfp8")
|
||||
}
|
||||
if opts.Modelfile == nil {
|
||||
t.Error("Modelfile should not be nil")
|
||||
}
|
||||
|
|
@ -286,6 +306,9 @@ func TestCreateOptions_Defaults(t *testing.T) {
|
|||
if opts.Quantize != "" {
|
||||
t.Errorf("Quantize should be empty by default, got %q", opts.Quantize)
|
||||
}
|
||||
if opts.DraftQuantize != "" {
|
||||
t.Errorf("DraftQuantize should be empty by default, got %q", opts.DraftQuantize)
|
||||
}
|
||||
|
||||
// Modelfile should default to nil
|
||||
if opts.Modelfile != nil {
|
||||
|
|
@ -518,6 +541,48 @@ func TestNewManifestWriter_PopulatesFileTypeFromQuantize(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestNewManifestWriter_PopulatesDraftMetadata(t *testing.T) {
|
||||
t.Setenv("OLLAMA_MODELS", t.TempDir())
|
||||
|
||||
opts := CreateOptions{
|
||||
ModelName: "test-draft",
|
||||
ModelDir: t.TempDir(),
|
||||
Modelfile: &ModelfileConfig{Draft: "/tmp/assistant"},
|
||||
}
|
||||
|
||||
writer := newManifestWriter(opts, []string{"completion"}, "gemma4", "gemma4")
|
||||
if err := writer(opts.ModelName, create.LayerInfo{}, nil); err != nil {
|
||||
t.Fatalf("newManifestWriter() error = %v", err)
|
||||
}
|
||||
|
||||
name := model.ParseName(opts.ModelName)
|
||||
mf, err := manifest.ParseNamedManifest(name)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseNamedManifest() error = %v", err)
|
||||
}
|
||||
|
||||
configPath, err := manifest.BlobsPath(mf.Config.Digest)
|
||||
if err != nil {
|
||||
t.Fatalf("BlobsPath() error = %v", err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFile() error = %v", err)
|
||||
}
|
||||
|
||||
var cfg model.ConfigV2
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
t.Fatalf("Unmarshal() error = %v", err)
|
||||
}
|
||||
if cfg.Draft == nil {
|
||||
t.Fatal("Draft metadata missing")
|
||||
}
|
||||
if cfg.Draft.TensorPrefix != "draft." || cfg.Draft.Config != "draft/config.json" {
|
||||
t.Fatalf("Draft = %#v, want draft prefix/config", cfg.Draft)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSupportsThinking(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import (
|
|||
"io"
|
||||
"math"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"slices"
|
||||
|
|
@ -724,12 +725,16 @@ func detectModelOptQuantization(modelDir string) bool {
|
|||
}
|
||||
|
||||
func resolveEffectiveQuantization(cfg sourceModelConfig, sourceKind sourceQuantizedKind, requested string) (string, error) {
|
||||
return resolveEffectiveQuantizationForFlag(cfg, sourceKind, requested, "--quantize")
|
||||
}
|
||||
|
||||
func resolveEffectiveQuantizationForFlag(cfg sourceModelConfig, sourceKind sourceQuantizedKind, requested, flagName string) (string, error) {
|
||||
switch sourceKind {
|
||||
case sourceQuantizedKindNone:
|
||||
return requested, nil
|
||||
case sourceQuantizedKindPrequantized:
|
||||
if requested != "" {
|
||||
return "", fmt.Errorf("cannot requantize already-quantized source model with --quantize %q", requested)
|
||||
return "", fmt.Errorf("cannot requantize already-quantized source model with %s %q", flagName, requested)
|
||||
}
|
||||
return "", nil
|
||||
case sourceQuantizedKindSourceFP8:
|
||||
|
|
@ -746,7 +751,7 @@ func resolveEffectiveQuantization(cfg sourceModelConfig, sourceKind sourceQuanti
|
|||
case "nvfp4", "mxfp4", "mxfp8":
|
||||
return requested, nil
|
||||
default:
|
||||
return "", fmt.Errorf("cannot convert already-quantized fp8 source model with --quantize %q", requested)
|
||||
return "", fmt.Errorf("cannot convert already-quantized fp8 source model with %s %q", flagName, requested)
|
||||
}
|
||||
}
|
||||
return "mxfp8", nil
|
||||
|
|
@ -810,6 +815,7 @@ var tensorImportTransformRegistry = map[string]tensorImportTransformFactory{
|
|||
"Gemma4ForCausalLM": newGemma4ImportTransform,
|
||||
"Gemma4ForConditionalGeneration": newGemma4ImportTransform,
|
||||
"LagunaForCausalLM": newLagunaImportTransform,
|
||||
"Gemma4AssistantForCausalLM": newGemma4ImportTransform,
|
||||
}
|
||||
|
||||
func newTensorImportTransform(modelDir string, cfg sourceModelConfig) (tensorImportTransform, error) {
|
||||
|
|
@ -1167,6 +1173,136 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
|
|||
return nil
|
||||
}
|
||||
|
||||
func normalizeRequestedQuantization(flagName, quantize string) (string, error) {
|
||||
q := normalizeQuantType(strings.TrimSpace(quantize))
|
||||
switch q {
|
||||
case "", "int4", "int8", "nvfp4", "mxfp4", "mxfp8":
|
||||
return q, nil
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported %s %q: supported types are int4, int8, nvfp4, mxfp4, mxfp8", flagName, quantize)
|
||||
}
|
||||
}
|
||||
|
||||
// CreateDraftSafetensorsLayers imports an assistant/draft safetensors model
|
||||
// into prefixed tensor and config layers. When draftQuantize is non-empty,
|
||||
// eligible draft tensors are quantized with the same per-architecture policy
|
||||
// used by target safetensors imports.
|
||||
func CreateDraftSafetensorsLayers(modelDir, tensorPrefix, configPrefix, draftQuantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, fn func(status string)) ([]LayerInfo, error) {
|
||||
if tensorPrefix == "" {
|
||||
return nil, fmt.Errorf("draft tensor prefix must not be empty")
|
||||
}
|
||||
if configPrefix == "" {
|
||||
return nil, fmt.Errorf("draft config prefix must not be empty")
|
||||
}
|
||||
effectiveQuantize, err := normalizeRequestedQuantization("--draft-quantize", draftQuantize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var importTransform tensorImportTransform = noopImportTransform{}
|
||||
if effectiveQuantize != "" {
|
||||
sourceConfig, err := readSourceModelConfig(modelDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read draft config.json: %w", err)
|
||||
}
|
||||
sourceQuantKind, err := inspectSourceQuantization(modelDir, sourceConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to inspect draft quantization: %w", err)
|
||||
}
|
||||
effectiveQuantize, err = resolveEffectiveQuantizationForFlag(sourceConfig, sourceQuantKind, effectiveQuantize, "--draft-quantize")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
importTransform, err = newTensorImportTransform(modelDir, sourceConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to construct draft import transform for architecture %q: %w", sourceConfig.Architecture(), err)
|
||||
}
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(modelDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read draft directory: %w", err)
|
||||
}
|
||||
|
||||
var layers []LayerInfo
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".safetensors") {
|
||||
continue
|
||||
}
|
||||
|
||||
stPath := filepath.Join(modelDir, entry.Name())
|
||||
extractor, err := safetensors.OpenForExtraction(stPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open draft %s: %w", stPath, err)
|
||||
}
|
||||
|
||||
tensorNames := extractor.ListTensors()
|
||||
fn(fmt.Sprintf("importing draft %s (%d tensors%s)", entry.Name(), len(tensorNames), importQuantizationStatus(sourceQuantizedKindNone, effectiveQuantize)))
|
||||
for _, tensorName := range tensorNames {
|
||||
if importTransform.skipTensor(tensorName) {
|
||||
continue
|
||||
}
|
||||
td, err := extractor.GetTensor(tensorName)
|
||||
if err != nil {
|
||||
extractor.Close()
|
||||
return nil, fmt.Errorf("failed to get draft tensor %s: %w", tensorName, err)
|
||||
}
|
||||
|
||||
outTDs, err := importTransform.transformTensor(td)
|
||||
if err != nil {
|
||||
extractor.Close()
|
||||
return nil, fmt.Errorf("failed to transform draft tensor %s: %w", tensorName, err)
|
||||
}
|
||||
for _, transformedTD := range outTDs {
|
||||
if transformedTD == nil {
|
||||
continue
|
||||
}
|
||||
outTD := transformedTD.WithName(tensorPrefix + transformedTD.Name)
|
||||
quantizeType := ""
|
||||
if effectiveQuantize != "" {
|
||||
quantizeType = importTransform.quantizationType(outTD.Name, outTD.Shape, effectiveQuantize)
|
||||
if isEmbedTokensWeight(outTD.Name) {
|
||||
quantizeType = ""
|
||||
}
|
||||
}
|
||||
newLayers, err := createTensorLayer(outTD.SafetensorsReader(), outTD.Name, outTD.Dtype, outTD.Shape, quantizeType)
|
||||
if err != nil {
|
||||
extractor.Close()
|
||||
return nil, fmt.Errorf("failed to create draft layer for %s: %w", tensorName, err)
|
||||
}
|
||||
layers = append(layers, newLayers...)
|
||||
}
|
||||
}
|
||||
extractor.Close()
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".json") {
|
||||
continue
|
||||
}
|
||||
if entry.Name() == "model.safetensors.index.json" {
|
||||
continue
|
||||
}
|
||||
|
||||
cfgPath := entry.Name()
|
||||
fullPath := filepath.Join(modelDir, cfgPath)
|
||||
fn(fmt.Sprintf("importing draft config %s", cfgPath))
|
||||
|
||||
f, err := os.Open(fullPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open draft %s: %w", cfgPath, err)
|
||||
}
|
||||
layer, err := createLayer(f, "application/vnd.ollama.image.json", path.Join(configPrefix, cfgPath))
|
||||
f.Close()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create draft config layer for %s: %w", cfgPath, err)
|
||||
}
|
||||
layers = append(layers, layer)
|
||||
}
|
||||
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
func shouldSkipSourceCompanion(name string, tensorSet map[string]struct{}, sourceTensorFiles map[string]string) bool {
|
||||
switch {
|
||||
case strings.HasSuffix(name, ".scales"):
|
||||
|
|
|
|||
|
|
@ -452,6 +452,137 @@ func TestCreateSafetensorsModel(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestCreateDraftSafetensorsLayersPrefixesTensorsAndConfigs(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type":"gemma4_assistant"}`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{
|
||||
st.NewTensorDataFromBytes("model.layers.0.self_attn.q_proj.weight", "BF16", []int32{2, 2}, make([]byte, 8)),
|
||||
})
|
||||
|
||||
var tensorNames []string
|
||||
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
||||
data, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return LayerInfo{}, err
|
||||
}
|
||||
return LayerInfo{Digest: "sha256:json_" + name, Size: int64(len(data)), MediaType: mediaType, Name: name}, nil
|
||||
}
|
||||
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
|
||||
data, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tensorNames = append(tensorNames, name)
|
||||
tensorName, tensorShape := readSingleTensorNameAndShape(t, data)
|
||||
if tensorName != name {
|
||||
t.Fatalf("safetensors key = %q, want %q", tensorName, name)
|
||||
}
|
||||
if !slices.Equal(tensorShape, shape) {
|
||||
t.Fatalf("shape = %v, want %v", tensorShape, shape)
|
||||
}
|
||||
if quantize != "" {
|
||||
t.Fatalf("draft quantize = %q, want empty", quantize)
|
||||
}
|
||||
return []LayerInfo{{Digest: "sha256:tensor_" + name, Size: int64(len(data)), MediaType: "application/vnd.ollama.image.tensor", Name: name}}, nil
|
||||
}
|
||||
|
||||
layers, err := CreateDraftSafetensorsLayers(dir, "draft.", "draft", "", createLayer, createTensorLayer, func(string) {})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !slices.Contains(tensorNames, "draft.model.layers.0.self_attn.q_proj.weight") {
|
||||
t.Fatalf("draft tensor was not prefixed: %v", tensorNames)
|
||||
}
|
||||
var hasDraftConfig bool
|
||||
for _, layer := range layers {
|
||||
if layer.Name == "draft/config.json" && layer.MediaType == "application/vnd.ollama.image.json" {
|
||||
hasDraftConfig = true
|
||||
}
|
||||
}
|
||||
if !hasDraftConfig {
|
||||
t.Fatalf("draft/config.json layer missing: %#v", layers)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateDraftSafetensorsLayersQuantizesEligibleTensors(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{
|
||||
"architectures":["Gemma4AssistantForCausalLM"],
|
||||
"num_hidden_layers":8
|
||||
}`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
createTestSafetensors(t, filepath.Join(dir, "model.safetensors"), []*st.TensorData{
|
||||
st.NewTensorDataFromBytes("model.embed_tokens.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
|
||||
st.NewTensorDataFromBytes("model.layers.0.self_attn.q_proj.weight", "BF16", []int32{64, 64}, make([]byte, 64*64*2)),
|
||||
st.NewTensorDataFromBytes("model.layers.0.input_layernorm.weight", "BF16", []int32{64}, make([]byte, 64*2)),
|
||||
})
|
||||
|
||||
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
||||
data, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return LayerInfo{}, err
|
||||
}
|
||||
return LayerInfo{Digest: "sha256:json_" + name, Size: int64(len(data)), MediaType: mediaType, Name: name}, nil
|
||||
}
|
||||
quantizeByName := make(map[string]string)
|
||||
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
|
||||
data, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
quantizeByName[name] = quantize
|
||||
return []LayerInfo{{Digest: "sha256:tensor_" + name, Size: int64(len(data)), MediaType: "application/vnd.ollama.image.tensor", Name: name}}, nil
|
||||
}
|
||||
|
||||
if _, err := CreateDraftSafetensorsLayers(dir, "draft.", "draft", "MXFP8", createLayer, createTensorLayer, func(string) {}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if got := quantizeByName["draft.model.layers.0.self_attn.q_proj.weight"]; got != "mxfp8" {
|
||||
t.Fatalf("q_proj draft quantize = %q, want mxfp8", got)
|
||||
}
|
||||
if got := quantizeByName["draft.model.layers.0.input_layernorm.weight"]; got != "" {
|
||||
t.Fatalf("norm draft quantize = %q, want empty", got)
|
||||
}
|
||||
if got := quantizeByName["draft.model.embed_tokens.weight"]; got != "" {
|
||||
t.Fatalf("embed_tokens draft quantize = %q, want empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateDraftSafetensorsLayersRejectsUnsupportedDraftQuantize(t *testing.T) {
|
||||
_, err := CreateDraftSafetensorsLayers(t.TempDir(), "draft.", "draft", "bogus", nil, nil, func(string) {})
|
||||
if err == nil || !strings.Contains(err.Error(), "unsupported --draft-quantize") {
|
||||
t.Fatalf("error = %v, want unsupported --draft-quantize", err)
|
||||
}
|
||||
}
|
||||
|
||||
// readSingleTensorNameAndShape parses the header of a serialized safetensors
// blob and returns the name and shape of a tensor entry, skipping the
// "__metadata__" key. The layout read here is an 8-byte little-endian header
// length followed by that many bytes of JSON.
//
// Truncated or inconsistent input fails the test through t.Fatalf rather than
// panicking on an out-of-range slice. If the header holds more than one
// tensor, which entry is returned depends on map iteration order.
func readSingleTensorNameAndShape(t *testing.T, data []byte) (string, []int32) {
	t.Helper()

	const sizeLen = 8
	if len(data) < sizeLen {
		t.Fatalf("safetensors data too short: got %d bytes, need at least %d", len(data), sizeLen)
	}

	var headerSize uint64
	if err := binary.Read(bytes.NewReader(data[:sizeLen]), binary.LittleEndian, &headerSize); err != nil {
		t.Fatalf("failed to read header size: %v", err)
	}
	// Compare against the remaining length instead of computing
	// sizeLen+headerSize, so a huge declared size cannot overflow.
	if remaining := uint64(len(data) - sizeLen); headerSize > remaining {
		t.Fatalf("header size %d exceeds remaining %d bytes", headerSize, remaining)
	}

	var header map[string]struct {
		Shape []int32 `json:"shape"`
	}
	if err := json.Unmarshal(data[sizeLen:sizeLen+int(headerSize)], &header); err != nil {
		t.Fatalf("failed to parse header: %v", err)
	}
	for name, info := range header {
		if name != "__metadata__" {
			return name, info.Shape
		}
	}
	t.Fatal("no tensor entry found in header")
	return "", nil
}
|
||||
|
||||
func TestCreateSafetensorsModel_NoConfigJson(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
|
|
|
|||
293
x/mlxrunner/cache/cache.go
vendored
293
x/mlxrunner/cache/cache.go
vendored
|
|
@ -54,6 +54,81 @@ type Attention interface {
|
|||
Update(b *batch.Batch, keys, values *mlx.Array) *nn.KVHistory
|
||||
}
|
||||
|
||||
// Viewer exposes a read-only attention history for a cache.
|
||||
type Viewer interface {
|
||||
View(b *batch.Batch) *nn.KVHistory
|
||||
}
|
||||
|
||||
type speculativeCommitter interface {
|
||||
Cache
|
||||
commit(n int)
|
||||
}
|
||||
|
||||
// Speculation is an isolated cache transaction for speculative target
|
||||
// validation. Updates record generated K/V without mutating the live caches;
|
||||
// Commit appends only the accepted prefix to the live caches.
|
||||
type Speculation struct {
|
||||
layers []speculativeCommitter
|
||||
}
|
||||
|
||||
// BeginSpeculation returns cache wrappers suitable for a speculative target
|
||||
// forward. The returned caches must only be used for that forward.
|
||||
func BeginSpeculation(caches []Cache) ([]Cache, *Speculation, bool) {
|
||||
specCaches := make([]Cache, len(caches))
|
||||
layers := make([]speculativeCommitter, len(caches))
|
||||
|
||||
for i, c := range caches {
|
||||
switch c := c.(type) {
|
||||
case nil:
|
||||
case *RotatingKVCache:
|
||||
sc := newSpeculativeRotatingKVCache(c)
|
||||
specCaches[i] = sc
|
||||
layers[i] = sc
|
||||
case *KVCache:
|
||||
sc := newSpeculativeKVCache(c)
|
||||
specCaches[i] = sc
|
||||
layers[i] = sc
|
||||
default:
|
||||
return nil, nil, false
|
||||
}
|
||||
}
|
||||
|
||||
return specCaches, &Speculation{layers: layers}, true
|
||||
}
|
||||
|
||||
// BeginIsolatedSpeculation returns cache wrappers that never mutate live cache
|
||||
// state. It is intended for correctness instrumentation, not the hot path.
|
||||
func BeginIsolatedSpeculation(caches []Cache) ([]Cache, bool) {
|
||||
specCaches := make([]Cache, len(caches))
|
||||
|
||||
for i, c := range caches {
|
||||
switch c := c.(type) {
|
||||
case nil:
|
||||
case *RotatingKVCache:
|
||||
specCaches[i] = newSpeculativeRotatingKVCache(c)
|
||||
case *KVCache:
|
||||
specCaches[i] = newIsolatedKVCache(c)
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
return specCaches, true
|
||||
}
|
||||
|
||||
// Commit appends the accepted prefix from the speculative forward to the live
|
||||
// caches. The target bonus token is intentionally not committed.
|
||||
func (s *Speculation) Commit(n int) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
for _, layer := range s.layers {
|
||||
if layer != nil {
|
||||
layer.commit(n)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type KVCache struct {
|
||||
keys, values *mlx.Array
|
||||
offset int
|
||||
|
|
@ -70,6 +145,15 @@ func (c *KVCache) Update(_ *batch.Batch, keys, values *mlx.Array) *nn.KVHistory
|
|||
return nn.NewKVHistory(newK, newV, nil)
|
||||
}
|
||||
|
||||
// View returns the current cache contents as attention history without writing.
|
||||
func (c *KVCache) View(_ *batch.Batch) *nn.KVHistory {
|
||||
state := c.State()
|
||||
if len(state) < 2 {
|
||||
return nil
|
||||
}
|
||||
return nn.NewKVHistory(state[0], state[1], nil)
|
||||
}
|
||||
|
||||
// appendKV is the raw write path shared by Update and Restore.
|
||||
func (c *KVCache) appendKV(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
B, H, L, Dk, Dv := keys.Dim(0), keys.Dim(1), keys.Dim(2), keys.Dim(3), values.Dim(3)
|
||||
|
|
@ -250,6 +334,94 @@ func (c *KVCache) Free() {
|
|||
|
||||
func (c *KVCache) Offset() int { return c.offset }
|
||||
|
||||
type speculativeBase struct {
|
||||
offset int
|
||||
}
|
||||
|
||||
func (s *speculativeBase) Free() {}
|
||||
func (s *speculativeBase) Offset() int { return s.offset }
|
||||
func (s *speculativeBase) Snapshot(int) Snapshot { return nil }
|
||||
func (s *speculativeBase) Restore(Snapshot, int) bool { return false }
|
||||
func (s *speculativeBase) Merge(parent, child Snapshot) Snapshot { return nil }
|
||||
func (s *speculativeBase) Split(snapshot Snapshot, at int) (Snapshot, Snapshot) {
|
||||
return nil, snapshot
|
||||
}
|
||||
|
||||
type speculativeKVCache struct {
|
||||
speculativeBase
|
||||
target *KVCache
|
||||
start int
|
||||
end int
|
||||
}
|
||||
|
||||
func newSpeculativeKVCache(target *KVCache) *speculativeKVCache {
|
||||
return &speculativeKVCache{
|
||||
speculativeBase: speculativeBase{offset: target.Offset()},
|
||||
target: target,
|
||||
start: target.Offset(),
|
||||
end: target.Offset(),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *speculativeKVCache) Update(b *batch.Batch, keys, values *mlx.Array) *nn.KVHistory {
|
||||
history := c.target.Update(b, keys, values)
|
||||
c.offset = c.target.Offset()
|
||||
c.end = c.target.Offset()
|
||||
return history
|
||||
}
|
||||
|
||||
func (c *speculativeKVCache) State() []*mlx.Array {
|
||||
return c.target.State()
|
||||
}
|
||||
|
||||
func (c *speculativeKVCache) commit(n int) {
|
||||
target := max(c.start, c.start+n)
|
||||
if target > c.end {
|
||||
target = c.end
|
||||
}
|
||||
c.target.offset = target
|
||||
c.offset = target
|
||||
}
|
||||
|
||||
type isolatedKVCache struct {
|
||||
speculativeBase
|
||||
target *KVCache
|
||||
keys, values *mlx.Array
|
||||
}
|
||||
|
||||
func newIsolatedKVCache(target *KVCache) *isolatedKVCache {
|
||||
return &isolatedKVCache{
|
||||
speculativeBase: speculativeBase{offset: target.Offset()},
|
||||
target: target,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *isolatedKVCache) Update(_ *batch.Batch, keys, values *mlx.Array) *nn.KVHistory {
|
||||
c.keys = concatKV(c.keys, keys)
|
||||
c.values = concatKV(c.values, values)
|
||||
c.offset += keys.Dim(2)
|
||||
|
||||
state := c.target.State()
|
||||
if len(state) < 2 {
|
||||
return nn.NewKVHistory(c.keys, c.values, nil)
|
||||
}
|
||||
return nn.NewKVHistory(state[0].Concatenate(2, c.keys), state[1].Concatenate(2, c.values), nil)
|
||||
}
|
||||
|
||||
func (c *isolatedKVCache) State() []*mlx.Array {
|
||||
if c.keys == nil || c.values == nil {
|
||||
return c.target.State()
|
||||
}
|
||||
state := c.target.State()
|
||||
if len(state) < 2 {
|
||||
return []*mlx.Array{c.keys, c.values}
|
||||
}
|
||||
return []*mlx.Array{
|
||||
state[0].Concatenate(2, c.keys),
|
||||
state[1].Concatenate(2, c.values),
|
||||
}
|
||||
}
|
||||
|
||||
// RotatingKVCache implements sliding window attention with bounded memory
|
||||
type RotatingKVCache struct {
|
||||
maxSize int
|
||||
|
|
@ -275,6 +447,127 @@ func (c *RotatingKVCache) Update(b *batch.Batch, keys, values *mlx.Array) *nn.KV
|
|||
})
|
||||
}
|
||||
|
||||
// View returns the current rotating cache contents in logical order for
|
||||
// assistant KV sharing.
|
||||
func (c *RotatingKVCache) View(_ *batch.Batch) *nn.KVHistory {
|
||||
k, v := c.logicalTail(c.maxSize - 1)
|
||||
if k == nil || v == nil {
|
||||
return nil
|
||||
}
|
||||
return nn.NewKVHistory(k, v, nil)
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) logicalTail(keep int) (*mlx.Array, *mlx.Array) {
|
||||
state := c.State()
|
||||
if len(state) < 2 || keep <= 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
keys, values := state[0], state[1]
|
||||
K := keys.Dim(2)
|
||||
if K == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
keep = min(keep, K)
|
||||
if K > c.maxSize || c.offset < c.maxSize {
|
||||
start := K - keep
|
||||
return keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(start, K), mlx.Slice()),
|
||||
values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(start, K), mlx.Slice())
|
||||
}
|
||||
|
||||
oldest := c.idx % K
|
||||
var logicalK, logicalV *mlx.Array
|
||||
if oldest == 0 {
|
||||
logicalK, logicalV = keys, values
|
||||
} else {
|
||||
tailK := keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(oldest, K), mlx.Slice())
|
||||
tailV := values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(oldest, K), mlx.Slice())
|
||||
headK := keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, oldest), mlx.Slice())
|
||||
headV := values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, oldest), mlx.Slice())
|
||||
logicalK = tailK.Concatenate(2, headK)
|
||||
logicalV = tailV.Concatenate(2, headV)
|
||||
}
|
||||
|
||||
start := K - keep
|
||||
return logicalK.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(start, K), mlx.Slice()),
|
||||
logicalV.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(start, K), mlx.Slice())
|
||||
}
|
||||
|
||||
type speculativeRotatingKVCache struct {
|
||||
speculativeBase
|
||||
target *RotatingKVCache
|
||||
keys, values *mlx.Array
|
||||
}
|
||||
|
||||
func newSpeculativeRotatingKVCache(target *RotatingKVCache) *speculativeRotatingKVCache {
|
||||
return &speculativeRotatingKVCache{
|
||||
speculativeBase: speculativeBase{offset: target.Offset()},
|
||||
target: target,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *speculativeRotatingKVCache) Update(b *batch.Batch, keys, values *mlx.Array) *nn.KVHistory {
|
||||
c.keys = concatKV(c.keys, keys)
|
||||
c.values = concatKV(c.values, values)
|
||||
c.offset += keys.Dim(2)
|
||||
|
||||
oldK, oldV := c.target.logicalTail(c.target.maxSize - 1)
|
||||
histK, histV := c.keys, c.values
|
||||
if oldK != nil && oldV != nil {
|
||||
histK = oldK.Concatenate(2, c.keys)
|
||||
histV = oldV.Concatenate(2, c.values)
|
||||
}
|
||||
|
||||
return nn.NewKVHistory(histK, histV, logicalSlidingApplier{
|
||||
b: b,
|
||||
K: histK.Dim(2),
|
||||
window: c.target.maxSize,
|
||||
dtype: keys.DType(),
|
||||
})
|
||||
}
|
||||
|
||||
func (c *speculativeRotatingKVCache) State() []*mlx.Array {
|
||||
if c.keys == nil || c.values == nil {
|
||||
return c.target.State()
|
||||
}
|
||||
oldK, oldV := c.target.logicalTail(c.target.maxSize - 1)
|
||||
if oldK == nil || oldV == nil {
|
||||
return []*mlx.Array{c.keys, c.values}
|
||||
}
|
||||
return []*mlx.Array{oldK.Concatenate(2, c.keys), oldV.Concatenate(2, c.values)}
|
||||
}
|
||||
|
||||
func (c *speculativeRotatingKVCache) commit(n int) {
|
||||
if c.keys == nil || c.values == nil || n <= 0 {
|
||||
return
|
||||
}
|
||||
n = min(n, c.keys.Dim(2))
|
||||
c.target.appendKV(prefixKV(c.keys, n), prefixKV(c.values, n))
|
||||
}
|
||||
|
||||
type logicalSlidingApplier struct {
|
||||
b *batch.Batch
|
||||
K int
|
||||
window int
|
||||
dtype mlx.DType
|
||||
}
|
||||
|
||||
func (a logicalSlidingApplier) ApplyMask(logical nn.AttentionMask) nn.AttentionMask {
|
||||
return logical.Intersect(nn.SlidingWindowMask(a.b, a.K, a.window, a.dtype))
|
||||
}
|
||||
|
||||
func concatKV(prev, next *mlx.Array) *mlx.Array {
|
||||
if prev == nil {
|
||||
return next
|
||||
}
|
||||
return prev.Concatenate(2, next)
|
||||
}
|
||||
|
||||
func prefixKV(a *mlx.Array, n int) *mlx.Array {
|
||||
return a.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, n), mlx.Slice())
|
||||
}
|
||||
|
||||
// appendKV is the raw write path shared by Update and Restore —
|
||||
// routes to concat for prefill (L > 1) and update for decode.
|
||||
func (c *RotatingKVCache) appendKV(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
|
|
|
|||
55
x/mlxrunner/cache/rotating_attention_test.go
vendored
55
x/mlxrunner/cache/rotating_attention_test.go
vendored
|
|
@ -114,6 +114,61 @@ func TestRotatingKVCacheDecodeParity(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestAssistantSharedHistoryL1MasksMatchNoMask(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
if !mlx.MetalIsAvailable() {
|
||||
t.Skip("MLX Metal not available")
|
||||
}
|
||||
const H, D = 1, 4
|
||||
const window = 4
|
||||
const total = 7
|
||||
const scale = 1.0
|
||||
|
||||
q := mlx.FromValues([]float32{0.7, -0.4, 0.2, 0.9}, 1, H, 1, D)
|
||||
mlx.Eval(q)
|
||||
|
||||
full := NewKVCache()
|
||||
sliding := NewRotatingKVCache(window)
|
||||
for pos := range total {
|
||||
kVals := make([]float32, H*D)
|
||||
vVals := make([]float32, H*D)
|
||||
for i := range kVals {
|
||||
kVals[i] = 0.1*float32(pos+1) + 0.01*float32(i)
|
||||
vVals[i] = -0.1*float32(pos+1) + 0.01*float32(i)
|
||||
}
|
||||
k := mlx.FromValues(kVals, 1, H, 1, D)
|
||||
v := mlx.FromValues(vVals, 1, H, 1, D)
|
||||
full.Update(newKVBatch(full.Offset(), 1), k, v)
|
||||
sliding.Update(newKVBatch(sliding.Offset(), 1), k, v)
|
||||
}
|
||||
|
||||
b := newKVBatch(total-1, 1)
|
||||
slidingHistory := sliding.View(b)
|
||||
cases := []struct {
|
||||
name string
|
||||
h *nn.KVHistory
|
||||
mask nn.AttentionMask
|
||||
}{
|
||||
{name: "full", h: full.View(b), mask: nn.CausalMask()},
|
||||
{name: "sliding", h: slidingHistory, mask: nn.CausalMask().Intersect(nn.SlidingWindowMask(b, slidingHistory.K().Dim(2), window, q.DType()))},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := nn.ScaledDotProductAttention(b, q, scale, nn.WithKVHistory(tc.h), nn.WithMask(tc.mask))
|
||||
want := mlx.FastScaledDotProductAttention(q, tc.h.K(), tc.h.V(), scale, "", nil)
|
||||
|
||||
mlx.Eval(got, want)
|
||||
gs, ws := got.Floats(), want.Floats()
|
||||
for i := range ws {
|
||||
if math.Abs(float64(gs[i]-ws[i])) > 1e-5 {
|
||||
t.Fatalf("index %d: got %v, want %v", i, gs[i], ws[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRotatingKVCachePrefillParity drives an L>1 prefill into a
|
||||
// rotating cache and verifies SDPA output through WithKVHistory
|
||||
// matches a reference computed from the same K/V with the model mask
|
||||
|
|
|
|||
|
|
@ -41,3 +41,36 @@ func TestFromValues(t *testing.T) {
|
|||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestComparisonOpsAndBernoulli(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
a := FromValues([]float32{1, 2, 3}, 3)
|
||||
b := FromValues([]float32{1, 1, 4}, 3)
|
||||
eq := a.Equal(b).AsType(DTypeInt32)
|
||||
gt := a.Greater(b).AsType(DTypeInt32)
|
||||
le := a.LessEqual(b).AsType(DTypeInt32)
|
||||
bern := Bernoulli(FromValues([]float32{1, 0}, 2)).AsType(DTypeInt32)
|
||||
Eval(eq, gt, le, bern)
|
||||
|
||||
for name, tc := range map[string]struct {
|
||||
got []int
|
||||
want []int
|
||||
}{
|
||||
"equal": {eq.Ints(), []int{1, 0, 0}},
|
||||
"greater": {gt.Ints(), []int{0, 1, 0}},
|
||||
"lessEqual": {le.Ints(), []int{1, 0, 1}},
|
||||
"bernoulli": {bern.Ints(), []int{1, 0}},
|
||||
} {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
if len(tc.got) != len(tc.want) {
|
||||
t.Fatalf("got %v, want %v", tc.got, tc.want)
|
||||
}
|
||||
for i := range tc.want {
|
||||
if tc.got[i] != tc.want[i] {
|
||||
t.Fatalf("got %v, want %v", tc.got, tc.want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -137,12 +137,30 @@ func (t *Array) LogsumexpAxis(axis int, keepDims bool) *Array {
|
|||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Equal(other *Array) *Array {
|
||||
out := New("EQUAL")
|
||||
C.mlx_equal(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Greater(other *Array) *Array {
|
||||
out := New("GREATER")
|
||||
C.mlx_greater(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Less(other *Array) *Array {
|
||||
out := New("LESS")
|
||||
C.mlx_less(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) LessEqual(other *Array) *Array {
|
||||
out := New("LESS_EQUAL")
|
||||
C.mlx_less_equal(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) MaxAxis(axis int, keepDims bool) *Array {
|
||||
out := New("MAX_AXIS")
|
||||
C.mlx_max_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
|
||||
|
|
|
|||
|
|
@ -3,9 +3,24 @@ package mlx
|
|||
// #include "generated.h"
|
||||
import "C"
|
||||
|
||||
import "unsafe"
|
||||
|
||||
func (t *Array) Categorical(axis int) *Array {
|
||||
key := New("")
|
||||
out := New("")
|
||||
C.mlx_random_categorical(&out.ctx, t.ctx, C.int(axis), key.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func Bernoulli(p *Array) *Array {
|
||||
dims := p.Dims()
|
||||
shape := make([]C.int, len(dims))
|
||||
for i, d := range dims {
|
||||
shape[i] = C.int(d)
|
||||
}
|
||||
|
||||
key := New("")
|
||||
out := New("BERNOULLI")
|
||||
C.mlx_random_bernoulli(&out.ctx, p.ctx, unsafe.SliceData(shape), C.size_t(len(shape)), key.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
|
|
|||
|
|
@ -27,9 +27,40 @@ type Model interface {
|
|||
LoadWeights(tensors map[string]*mlx.Array) error
|
||||
}
|
||||
|
||||
// DraftModel is an auxiliary model stored alongside a target model.
|
||||
type DraftModel interface {
|
||||
LoadWeights(tensors map[string]*mlx.Array) error
|
||||
}
|
||||
|
||||
// MTPDefaults holds model-provided draft-token defaults for speculative
|
||||
// decoding. Environment settings in the runner may override these values.
|
||||
type MTPDefaults struct {
|
||||
InitialDraftTokens int
|
||||
MaxDraftTokens int
|
||||
Enabled bool
|
||||
}
|
||||
|
||||
// MTPDefaultsProvider lets a model provide MTP policy defaults from its own
|
||||
// config without teaching the runner model-specific shape heuristics.
|
||||
type MTPDefaultsProvider interface {
|
||||
MTPDraftDefaults(sample bool) MTPDefaults
|
||||
}
|
||||
|
||||
// MTPDraftModel is a draft model capable of Gemma-style multi-token
|
||||
// prediction from target token embeddings, target hidden states, and target KV.
|
||||
type MTPDraftModel interface {
|
||||
Draft(inputEmbeds *mlx.Array, position int32, caches []cache.Cache) (logits, hidden *mlx.Array)
|
||||
}
|
||||
|
||||
// MTPEmbeddingModel exposes the target token embedding path used by MTP drafts.
|
||||
type MTPEmbeddingModel interface {
|
||||
TokenEmbeddings(inputIDs *mlx.Array) *mlx.Array
|
||||
}
|
||||
|
||||
var (
|
||||
mu sync.Mutex
|
||||
registry = make(map[string]func(root *model.Root) (Model, error))
|
||||
mu sync.Mutex
|
||||
registry = make(map[string]func(root *model.Root) (Model, error))
|
||||
draftRegistry = make(map[string]func(root *model.Root, target Model) (DraftModel, error))
|
||||
)
|
||||
|
||||
// Register registers a model constructor by architecture name.
|
||||
|
|
@ -44,6 +75,17 @@ func Register(arch string, fn func(root *model.Root) (Model, error)) {
|
|||
registry[arch] = fn
|
||||
}
|
||||
|
||||
// RegisterDraft registers a draft model constructor by architecture name.
|
||||
func RegisterDraft(arch string, fn func(root *model.Root, target Model) (DraftModel, error)) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
if _, exists := draftRegistry[arch]; exists {
|
||||
panic(fmt.Sprintf("draft model architecture %q already registered", arch))
|
||||
}
|
||||
draftRegistry[arch] = fn
|
||||
}
|
||||
|
||||
// New reads config.json from the manifest, detects the architecture, looks up
|
||||
// the registered constructor, and calls it to create the model (with config
|
||||
// parsed and struct created, but weights not yet loaded).
|
||||
|
|
@ -78,6 +120,51 @@ func New(root *model.Root) (Model, error) {
|
|||
return fn(root)
|
||||
}
|
||||
|
||||
// NewDraft constructs the draft model described by the manifest config, if any.
|
||||
func NewDraft(root *model.Root, target Model) (DraftModel, error) {
|
||||
if root == nil || root.Draft == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
configPath := root.Draft.Config
|
||||
if configPath == "" {
|
||||
configPath = "draft/config.json"
|
||||
}
|
||||
configData, err := root.Manifest.ReadConfig(configPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read %s: %w", configPath, err)
|
||||
}
|
||||
|
||||
var archConfig struct {
|
||||
Architectures []string `json:"architectures"`
|
||||
ModelType string `json:"model_type"`
|
||||
}
|
||||
if err := json.Unmarshal(configData, &archConfig); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse %s: %w", configPath, err)
|
||||
}
|
||||
|
||||
arch := root.Draft.Architecture
|
||||
if arch == "" && len(archConfig.Architectures) > 0 {
|
||||
arch = archConfig.Architectures[0]
|
||||
}
|
||||
if arch == "" {
|
||||
arch = archConfig.ModelType
|
||||
}
|
||||
if arch == "" {
|
||||
return nil, fmt.Errorf("no draft architecture found in %s", configPath)
|
||||
}
|
||||
slog.Info("Draft model architecture", "arch", arch)
|
||||
|
||||
mu.Lock()
|
||||
fn, ok := draftRegistry[arch]
|
||||
mu.Unlock()
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unsupported draft architecture: %s", arch)
|
||||
}
|
||||
|
||||
return fn(root, target)
|
||||
}
|
||||
|
||||
// Weights returns a function that loads model weights, then pins all
|
||||
// arrays reachable from the model struct and sweeps everything else.
|
||||
func Weights(m Model) func(map[string]*mlx.Array) error {
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ import (
|
|||
"strconv"
|
||||
"strings"
|
||||
|
||||
modeltypes "github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/x/imagegen/manifest"
|
||||
)
|
||||
|
||||
|
|
@ -22,6 +23,7 @@ type TensorQuantInfo struct {
|
|||
// Root wraps a ModelManifest with pre-scanned quantization metadata.
|
||||
type Root struct {
|
||||
Manifest *manifest.ModelManifest
|
||||
Draft *modeltypes.Draft
|
||||
|
||||
// Backwards-compatible model-level quant metadata (first tensor blob).
|
||||
quantType string
|
||||
|
|
@ -43,6 +45,7 @@ func Open(modelName string) (*Root, error) {
|
|||
Manifest: m,
|
||||
tensorQuant: make(map[string]*TensorQuantInfo),
|
||||
}
|
||||
root.Draft = readDraftConfig(m)
|
||||
|
||||
for _, layer := range m.GetTensorLayers("") {
|
||||
blobPath := m.BlobPath(layer.Digest)
|
||||
|
|
@ -68,6 +71,34 @@ func Open(modelName string) (*Root, error) {
|
|||
return root, nil
|
||||
}
|
||||
|
||||
func readDraftConfig(m *manifest.ModelManifest) *modeltypes.Draft {
|
||||
if m == nil || m.Manifest == nil || m.Manifest.Config.Digest == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(m.BlobPath(m.Manifest.Config.Digest))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var cfg modeltypes.ConfigV2
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return nil
|
||||
}
|
||||
if cfg.Draft != nil {
|
||||
return cfg.Draft
|
||||
}
|
||||
|
||||
if m.GetConfigLayer("draft/config.json") != nil {
|
||||
return &modeltypes.Draft{
|
||||
ModelFormat: "safetensors",
|
||||
TensorPrefix: "draft.",
|
||||
Config: "draft/config.json",
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close is a no-op for now (future: release resources).
// The body is intentionally empty: nothing is released here at present.
|
||||
func (r *Root) Close() {}
|
||||
|
||||
|
|
|
|||
959
x/mlxrunner/mtp.go
Normal file
959
x/mlxrunner/mtp.go
Normal file
|
|
@ -0,0 +1,959 @@
|
|||
package mlxrunner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/batch"
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
sampler "github.com/ollama/ollama/x/mlxrunner/sample"
|
||||
)
|
||||
|
||||
// mtpDefaultInitialDraftTokens is the fallback number of draft tokens a decode
// starts with when no positive value is configured.
const mtpDefaultInitialDraftTokens = 4

// mtpDefaultMaxDraftTokens is the fallback upper bound on draft tokens per
// iteration when no positive value is configured.
const mtpDefaultMaxDraftTokens = 16
|
||||
|
||||
// mtpDraftSchedule names the policy that decides how many draft tokens are
// requested per decode iteration.
type mtpDraftSchedule string

const (
	// mtpDraftScheduleHeuristic adjusts the draft length from the
	// acceptance outcome of each iteration.
	mtpDraftScheduleHeuristic = mtpDraftSchedule("heuristic")
	// mtpDraftScheduleConstant keeps the draft length fixed.
	mtpDraftScheduleConstant = mtpDraftSchedule("constant")
)
|
||||
|
||||
// mtpStats accumulates counters and timings for one MTP decode; they are
// logged when the decode finishes.
type mtpStats struct {
	// iterations counts decode rounds that reached the drafting step;
	// drafted and accepted are the totals of proposed and accepted tokens.
	iterations, drafted, accepted int
	// mismatches counts rounds with at least one rejected draft,
	// allAccepted rounds where every draft was accepted.
	mismatches, allAccepted int
	// batched and serial count validation rounds by the path taken.
	batched, serial int
	// compared and batchSerialMismatches are filled by the optional
	// batched-versus-serial comparison.
	compared, batchSerialMismatches int
	// maxDraft is the largest draft limit seen during the decode.
	maxDraft int

	// Wall-clock time spent in the target forward, drafting, and
	// validation phases.
	targetDuration, draftDuration, validateDuration time.Duration
}
|
||||
|
||||
type mtpOptions struct {
|
||||
initialDraftTokens int
|
||||
maxDraftTokens int
|
||||
draftSchedule mtpDraftSchedule
|
||||
serialValidate bool
|
||||
compareSerialValidate bool
|
||||
}
|
||||
|
||||
func (r *Runner) mtpDefaults(sample bool) base.MTPDefaults {
|
||||
defaults := base.MTPDefaults{
|
||||
InitialDraftTokens: mtpDefaultInitialDraftTokens,
|
||||
MaxDraftTokens: mtpDefaultMaxDraftTokens,
|
||||
Enabled: true,
|
||||
}
|
||||
if p, ok := r.Model.(base.MTPDefaultsProvider); ok {
|
||||
defaults = p.MTPDraftDefaults(sample)
|
||||
}
|
||||
if defaults.InitialDraftTokens <= 0 {
|
||||
defaults.InitialDraftTokens = mtpDefaultInitialDraftTokens
|
||||
}
|
||||
if defaults.MaxDraftTokens <= 0 {
|
||||
defaults.MaxDraftTokens = mtpDefaultMaxDraftTokens
|
||||
}
|
||||
return defaults
|
||||
}
|
||||
|
||||
func (r *Runner) loadMTPOptions(sample bool) mtpOptions {
|
||||
defaults := r.mtpDefaults(sample)
|
||||
|
||||
opts := mtpOptions{
|
||||
initialDraftTokens: defaults.InitialDraftTokens,
|
||||
maxDraftTokens: defaults.MaxDraftTokens,
|
||||
draftSchedule: mtpDraftScheduleConstant,
|
||||
}
|
||||
if v := positiveEnvInt("OLLAMA_MLX_MTP_MAX_DRAFT_TOKENS"); v > 0 {
|
||||
opts.maxDraftTokens = v
|
||||
}
|
||||
if v := positiveEnvInt("OLLAMA_MLX_MTP_INITIAL_DRAFT_TOKENS"); v > 0 {
|
||||
opts.initialDraftTokens = v
|
||||
}
|
||||
if opts.initialDraftTokens > opts.maxDraftTokens {
|
||||
opts.initialDraftTokens = opts.maxDraftTokens
|
||||
}
|
||||
if b, err := strconv.ParseBool(os.Getenv("OLLAMA_MLX_MTP_SERIAL_VALIDATE")); err == nil {
|
||||
opts.serialValidate = b
|
||||
}
|
||||
if b, err := strconv.ParseBool(os.Getenv("OLLAMA_MLX_MTP_COMPARE_SERIAL_VALIDATE")); err == nil {
|
||||
opts.compareSerialValidate = b
|
||||
}
|
||||
switch schedule := strings.ToLower(strings.TrimSpace(os.Getenv("OLLAMA_MLX_MTP_DRAFT_SCHEDULE"))); schedule {
|
||||
case "", string(mtpDraftScheduleConstant):
|
||||
opts.draftSchedule = mtpDraftScheduleConstant
|
||||
case string(mtpDraftScheduleHeuristic):
|
||||
opts.draftSchedule = mtpDraftScheduleHeuristic
|
||||
default:
|
||||
slog.Warn("invalid MTP env setting", "key", "OLLAMA_MLX_MTP_DRAFT_SCHEDULE", "value", schedule)
|
||||
}
|
||||
return opts
|
||||
}
|
||||
|
||||
// positiveEnvInt reads the environment variable key as a strictly positive
// integer. Surrounding whitespace is ignored, matching how the draft schedule
// setting is parsed. It returns 0 when the variable is unset or empty, and
// logs a warning and returns 0 when the value is not a positive integer.
func positiveEnvInt(key string) int {
	raw := os.Getenv(key)
	if raw == "" {
		return 0
	}
	v, err := strconv.Atoi(strings.TrimSpace(raw))
	if err != nil || v <= 0 {
		// Log the value exactly as supplied so the operator can spot it.
		slog.Warn("invalid MTP env setting", "key", key, "value", raw)
		return 0
	}
	return v
}
|
||||
|
||||
func (r *Runner) useGreedyMTP(opts sampler.Options) bool {
|
||||
if r.Draft == nil {
|
||||
return false
|
||||
}
|
||||
if _, ok := r.Draft.(base.MTPDraftModel); !ok {
|
||||
return false
|
||||
}
|
||||
if _, ok := r.Model.(base.MTPEmbeddingModel); !ok {
|
||||
return false
|
||||
}
|
||||
if !r.mtpDefaults(false).Enabled {
|
||||
return false
|
||||
}
|
||||
if opts.Logprobs || opts.TopLogprobs > 0 {
|
||||
return false
|
||||
}
|
||||
if opts.Temperature != 0 {
|
||||
return false
|
||||
}
|
||||
repeatPenaltyNeutral := opts.RepeatPenalty <= 0 || opts.RepeatPenalty == 1
|
||||
topPNeutral := opts.TopP <= 0 || opts.TopP >= 1
|
||||
topKNeutral := opts.TopK <= 0
|
||||
return repeatPenaltyNeutral && opts.PresencePenalty == 0 && opts.FrequencyPenalty == 0 && topPNeutral && topKNeutral && opts.MinP == 0
|
||||
}
|
||||
|
||||
func (r *Runner) useSampleMTP(opts sampler.Options) bool {
|
||||
if serial, err := strconv.ParseBool(os.Getenv("OLLAMA_MLX_MTP_SERIAL_VALIDATE")); err == nil && serial {
|
||||
return false
|
||||
}
|
||||
if compare, err := strconv.ParseBool(os.Getenv("OLLAMA_MLX_MTP_COMPARE_SERIAL_VALIDATE")); err == nil && compare {
|
||||
return false
|
||||
}
|
||||
if r.Draft == nil {
|
||||
return false
|
||||
}
|
||||
if _, ok := r.Draft.(base.MTPDraftModel); !ok {
|
||||
return false
|
||||
}
|
||||
if _, ok := r.Model.(base.MTPEmbeddingModel); !ok {
|
||||
return false
|
||||
}
|
||||
if !r.mtpDefaults(true).Enabled {
|
||||
return false
|
||||
}
|
||||
if opts.Logprobs || opts.TopLogprobs > 0 {
|
||||
return false
|
||||
}
|
||||
return opts.Temperature != 0
|
||||
}
|
||||
|
||||
func (r *Runner) runGreedyMTPDecode(ctx context.Context, request Request, session *cacheSession, caches []cache.Cache, seed []int32, position *int, started time.Time) error {
|
||||
targetEmbeddings := r.Model.(base.MTPEmbeddingModel)
|
||||
draft := r.Draft.(base.MTPDraftModel)
|
||||
mtpOpts := r.loadMTPOptions(false)
|
||||
stats := mtpStats{maxDraft: mtpOpts.initialDraftTokens}
|
||||
draftLimit := mtpOpts.initialDraftTokens
|
||||
slog.Info("MTP greedy decode enabled", "initial_draft_tokens", mtpOpts.initialDraftTokens, "max_draft_tokens", mtpOpts.maxDraftTokens, "draft_schedule", mtpOpts.draftSchedule, "serial_validate", mtpOpts.serialValidate, "compare_serial_validate", mtpOpts.compareSerialValidate)
|
||||
|
||||
targetForward := func(token *mlx.Array) *mlx.Array {
|
||||
fwd := r.Model.Forward(&batch.Batch{
|
||||
InputIDs: token,
|
||||
SeqOffsets: []int32{int32(*position)},
|
||||
SeqQueryLens: []int32{int32(token.Dim(1))},
|
||||
}, caches)
|
||||
*position += token.Dim(1)
|
||||
return fwd
|
||||
}
|
||||
|
||||
hidden := targetForward(mlx.FromValues(seed, 1, len(seed)))
|
||||
current := sampler.Result{Token: greedyTokenFromLogits(r.lastLogits(hidden))}
|
||||
mlx.Pin(current.Arrays()...)
|
||||
mlx.Sweep()
|
||||
mlx.AsyncEval(current.Arrays()...)
|
||||
defer func() {
|
||||
mlx.Unpin(current.Arrays()...)
|
||||
}()
|
||||
|
||||
dec := decoder{tokenizer: r.Tokenizer}
|
||||
final := CompletionResponse{Done: true, PromptEvalCount: len(request.Tokens), DoneReason: 1}
|
||||
now := started
|
||||
|
||||
generated := 0
|
||||
for generated < request.Options.NumPredict {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t0 := time.Now()
|
||||
hidden = targetForward(current.Token.ExpandDims(-1))
|
||||
baseLogits := r.lastLogits(hidden)
|
||||
stats.targetDuration += time.Since(t0)
|
||||
|
||||
if generated == 0 {
|
||||
mlx.Eval(current.Arrays()...)
|
||||
final.PromptEvalDuration = time.Since(now)
|
||||
now = time.Now()
|
||||
}
|
||||
|
||||
done, err := r.emitMTPToken(ctx, request, session, &dec, current, &final)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !done {
|
||||
generated++
|
||||
}
|
||||
if done || generated >= request.Options.NumPredict {
|
||||
break
|
||||
}
|
||||
|
||||
stats.iterations++
|
||||
maxDraft := min(draftLimit, request.Options.NumPredict-generated)
|
||||
t0 = time.Now()
|
||||
draftTokens := r.generateMTPDrafts(draft, targetEmbeddings, current.Token, hidden, caches, int32(*position-1), maxDraft)
|
||||
draftCount := 0
|
||||
if draftTokens != nil {
|
||||
draftCount = draftTokens.Dim(1)
|
||||
mlx.Pin(baseLogits, draftTokens)
|
||||
mlx.Eval(draftTokens)
|
||||
mlx.Sweep()
|
||||
}
|
||||
stats.draftDuration += time.Since(t0)
|
||||
stats.drafted += draftCount
|
||||
var next sampler.Result
|
||||
if draftCount == 0 {
|
||||
next = sampler.Result{Token: greedyTokenFromLogits(baseLogits)}
|
||||
} else {
|
||||
var accepted int
|
||||
t0 = time.Now()
|
||||
next, accepted, done, err = r.acceptMTPDrafts(ctx, request, session, &dec, caches, position, baseLogits, draftTokens, &final, &generated, &stats, mtpOpts)
|
||||
stats.validateDuration += time.Since(t0)
|
||||
mlx.Unpin(baseLogits, draftTokens)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stats.accepted += accepted
|
||||
switch {
|
||||
case mtpOpts.draftSchedule == mtpDraftScheduleConstant:
|
||||
case accepted == draftCount:
|
||||
stats.allAccepted++
|
||||
draftLimit = min(mtpOpts.maxDraftTokens, draftLimit+2)
|
||||
default:
|
||||
stats.mismatches++
|
||||
draftLimit = max(1, draftLimit-1)
|
||||
}
|
||||
if mtpOpts.draftSchedule == mtpDraftScheduleConstant {
|
||||
if accepted == draftCount {
|
||||
stats.allAccepted++
|
||||
} else {
|
||||
stats.mismatches++
|
||||
}
|
||||
}
|
||||
stats.maxDraft = max(stats.maxDraft, draftLimit)
|
||||
if next.Token == nil {
|
||||
mlx.Sweep()
|
||||
}
|
||||
if done || generated >= request.Options.NumPredict {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
mlx.Pin(next.Arrays()...)
|
||||
old := current
|
||||
current = next
|
||||
mlx.Unpin(old.Arrays()...)
|
||||
mlx.Sweep()
|
||||
mlx.AsyncEval(current.Arrays()...)
|
||||
|
||||
if generated%256 == 0 {
|
||||
mlx.ClearCache()
|
||||
}
|
||||
}
|
||||
|
||||
final.EvalCount = generated
|
||||
final.EvalDuration = time.Since(now)
|
||||
acceptance := 0.0
|
||||
if stats.drafted > 0 {
|
||||
acceptance = float64(stats.accepted) / float64(stats.drafted)
|
||||
}
|
||||
avgDraft := 0.0
|
||||
avgAccepted := 0.0
|
||||
if stats.iterations > 0 {
|
||||
avgDraft = float64(stats.drafted) / float64(stats.iterations)
|
||||
avgAccepted = float64(stats.accepted) / float64(stats.iterations)
|
||||
}
|
||||
slog.Info("MTP decode stats", "generated", generated, "drafted", stats.drafted, "accepted", stats.accepted, "acceptance", acceptance, "iterations", stats.iterations, "avg_draft", avgDraft, "avg_accepted", avgAccepted, "batched", stats.batched, "serial", stats.serial, "compared", stats.compared, "batch_serial_mismatches", stats.batchSerialMismatches, "mismatches", stats.mismatches, "all_accepted", stats.allAccepted, "max_draft", stats.maxDraft, "draft_schedule", mtpOpts.draftSchedule, "target_duration", stats.targetDuration, "draft_duration", stats.draftDuration, "validate_duration", stats.validateDuration)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case request.Responses <- final:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Runner) runSampleMTPDecode(ctx context.Context, request Request, session *cacheSession, caches []cache.Cache, seed []int32, position *int, started time.Time) error {
|
||||
targetEmbeddings := r.Model.(base.MTPEmbeddingModel)
|
||||
draft := r.Draft.(base.MTPDraftModel)
|
||||
mtpOpts := r.loadMTPOptions(true)
|
||||
stats := mtpStats{maxDraft: mtpOpts.initialDraftTokens}
|
||||
draftLimit := mtpOpts.initialDraftTokens
|
||||
slog.Info("MTP sample decode enabled", "initial_draft_tokens", mtpOpts.initialDraftTokens, "max_draft_tokens", mtpOpts.maxDraftTokens, "draft_schedule", mtpOpts.draftSchedule, "serial_validate", mtpOpts.serialValidate)
|
||||
|
||||
targetForward := func(token *mlx.Array) *mlx.Array {
|
||||
fwd := r.Model.Forward(&batch.Batch{
|
||||
InputIDs: token,
|
||||
SeqOffsets: []int32{int32(*position)},
|
||||
SeqQueryLens: []int32{int32(token.Dim(1))},
|
||||
}, caches)
|
||||
*position += token.Dim(1)
|
||||
return fwd
|
||||
}
|
||||
|
||||
hidden := targetForward(mlx.FromValues(seed, 1, len(seed)))
|
||||
current := r.Sampler.Sample([]int{pipelineSlot}, r.lastLogits(hidden))
|
||||
mlx.Pin(current.Arrays()...)
|
||||
mlx.Sweep()
|
||||
mlx.AsyncEval(current.Arrays()...)
|
||||
defer func() {
|
||||
mlx.Unpin(current.Arrays()...)
|
||||
}()
|
||||
|
||||
dec := decoder{tokenizer: r.Tokenizer}
|
||||
final := CompletionResponse{Done: true, PromptEvalCount: len(request.Tokens), DoneReason: 1}
|
||||
now := started
|
||||
|
||||
generated := 0
|
||||
for generated < request.Options.NumPredict {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t0 := time.Now()
|
||||
hidden = targetForward(mtpTokenInput(current.Token))
|
||||
baseLogits := r.lastLogits(hidden)
|
||||
stats.targetDuration += time.Since(t0)
|
||||
|
||||
if generated == 0 {
|
||||
mlx.Eval(current.Arrays()...)
|
||||
final.PromptEvalDuration = time.Since(now)
|
||||
now = time.Now()
|
||||
}
|
||||
|
||||
done, err := r.emitMTPToken(ctx, request, session, &dec, current, &final)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !done {
|
||||
generated++
|
||||
}
|
||||
if done || generated >= request.Options.NumPredict {
|
||||
break
|
||||
}
|
||||
|
||||
stats.iterations++
|
||||
maxDraft := min(draftLimit, request.Options.NumPredict-generated)
|
||||
t0 = time.Now()
|
||||
candidates := r.generateMTPDraftCandidates(draft, targetEmbeddings, current.Token, hidden, caches, int32(*position-1), maxDraft)
|
||||
draftCount := 0
|
||||
if candidates != nil {
|
||||
draftCount = candidates.tokens.Dim(1)
|
||||
mlx.Pin(baseLogits, candidates.tokens, candidates.logits)
|
||||
mlx.Sweep()
|
||||
}
|
||||
stats.draftDuration += time.Since(t0)
|
||||
stats.drafted += draftCount
|
||||
|
||||
var next sampler.Result
|
||||
if draftCount == 0 {
|
||||
next = r.Sampler.Sample([]int{pipelineSlot}, baseLogits)
|
||||
} else {
|
||||
var accepted int
|
||||
t0 = time.Now()
|
||||
next, accepted, done, err = r.acceptSampleMTPDrafts(ctx, request, session, &dec, caches, position, baseLogits, candidates, &final, &generated, &stats)
|
||||
stats.validateDuration += time.Since(t0)
|
||||
mlx.Unpin(baseLogits, candidates.tokens, candidates.logits)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stats.accepted += accepted
|
||||
switch {
|
||||
case mtpOpts.draftSchedule == mtpDraftScheduleConstant:
|
||||
case accepted == draftCount:
|
||||
stats.allAccepted++
|
||||
draftLimit = min(mtpOpts.maxDraftTokens, draftLimit+2)
|
||||
default:
|
||||
stats.mismatches++
|
||||
draftLimit = max(1, draftLimit-1)
|
||||
}
|
||||
if mtpOpts.draftSchedule == mtpDraftScheduleConstant {
|
||||
if accepted == draftCount {
|
||||
stats.allAccepted++
|
||||
} else {
|
||||
stats.mismatches++
|
||||
}
|
||||
}
|
||||
stats.maxDraft = max(stats.maxDraft, draftLimit)
|
||||
if next.Token == nil {
|
||||
mlx.Sweep()
|
||||
}
|
||||
if done || generated >= request.Options.NumPredict {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
mlx.Pin(next.Arrays()...)
|
||||
old := current
|
||||
current = next
|
||||
mlx.Unpin(old.Arrays()...)
|
||||
mlx.Sweep()
|
||||
mlx.AsyncEval(current.Arrays()...)
|
||||
|
||||
if generated%256 == 0 {
|
||||
mlx.ClearCache()
|
||||
}
|
||||
}
|
||||
|
||||
final.EvalCount = generated
|
||||
final.EvalDuration = time.Since(now)
|
||||
acceptance := 0.0
|
||||
if stats.drafted > 0 {
|
||||
acceptance = float64(stats.accepted) / float64(stats.drafted)
|
||||
}
|
||||
avgDraft := 0.0
|
||||
avgAccepted := 0.0
|
||||
if stats.iterations > 0 {
|
||||
avgDraft = float64(stats.drafted) / float64(stats.iterations)
|
||||
avgAccepted = float64(stats.accepted) / float64(stats.iterations)
|
||||
}
|
||||
slog.Info("MTP decode stats", "mode", "sample", "generated", generated, "drafted", stats.drafted, "accepted", stats.accepted, "acceptance", acceptance, "iterations", stats.iterations, "avg_draft", avgDraft, "avg_accepted", avgAccepted, "batched", stats.batched, "serial", stats.serial, "mismatches", stats.mismatches, "all_accepted", stats.allAccepted, "max_draft", stats.maxDraft, "draft_schedule", mtpOpts.draftSchedule, "target_duration", stats.targetDuration, "draft_duration", stats.draftDuration, "validate_duration", stats.validateDuration)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case request.Responses <- final:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
type mtpDraftCandidates struct {
|
||||
tokens *mlx.Array
|
||||
// logits are the processed proposal scores used to sample tokens.
|
||||
logits *mlx.Array
|
||||
}
|
||||
|
||||
func (r *Runner) generateMTPDrafts(draft base.MTPDraftModel, target base.MTPEmbeddingModel, token *mlx.Array, hidden *mlx.Array, caches []cache.Cache, position int32, maxDraft int) *mlx.Array {
|
||||
if maxDraft <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
lastToken := token.ExpandDims(-1)
|
||||
lastHidden := hidden
|
||||
draftTokens := make([]*mlx.Array, 0, maxDraft)
|
||||
|
||||
// Gemma4 assistant MTP is trained as "single-position" drafting:
|
||||
// keep the RoPE/cache position anchored at the last target-seen token
|
||||
// while the proposed token and projected hidden state advance.
|
||||
for range maxDraft {
|
||||
tokenEmbedding := target.TokenEmbeddings(lastToken)
|
||||
inputs := tokenEmbedding.Concatenate(-1, lastHidden)
|
||||
logits, projected := draft.Draft(inputs, position, caches)
|
||||
stepLogits := r.lastLogitsFromLogits(logits)
|
||||
nextToken := greedyTokenFromLogits(stepLogits)
|
||||
|
||||
lastToken = nextToken.ExpandDims(-1)
|
||||
lastHidden = projected
|
||||
draftTokens = append(draftTokens, lastToken)
|
||||
}
|
||||
if len(draftTokens) == 0 {
|
||||
return nil
|
||||
}
|
||||
return mlx.Concatenate(draftTokens, 1)
|
||||
}
|
||||
|
||||
func (r *Runner) generateMTPDraftCandidates(draft base.MTPDraftModel, target base.MTPEmbeddingModel, token *mlx.Array, hidden *mlx.Array, caches []cache.Cache, position int32, maxDraft int) *mtpDraftCandidates {
|
||||
if maxDraft <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
lastToken := mtpTokenInput(token)
|
||||
lastHidden := hidden
|
||||
draftTokens := make([]*mlx.Array, 0, maxDraft)
|
||||
draftLogits := make([]*mlx.Array, 0, maxDraft)
|
||||
var prefix *mlx.Array
|
||||
|
||||
// Gemma4 assistant MTP is trained as "single-position" drafting:
|
||||
// keep the RoPE/cache position anchored at the last target-seen token
|
||||
// while the proposed token and projected hidden state advance.
|
||||
for range maxDraft {
|
||||
tokenEmbedding := target.TokenEmbeddings(lastToken)
|
||||
inputs := tokenEmbedding.Concatenate(-1, lastHidden)
|
||||
logits, projected := draft.Draft(inputs, position, caches)
|
||||
stepLogits := r.lastLogitsFromLogits(logits)
|
||||
stepScores := r.Sampler.SpeculativeScores(pipelineSlot, stepLogits, prefix)
|
||||
nextToken := stepScores.Categorical(-1).AsType(mlx.DTypeInt32)
|
||||
|
||||
lastToken = mtpTokenInput(nextToken)
|
||||
lastHidden = projected
|
||||
draftTokens = append(draftTokens, lastToken)
|
||||
draftLogits = append(draftLogits, stepScores.ExpandDims(1))
|
||||
if prefix == nil {
|
||||
prefix = lastToken
|
||||
} else {
|
||||
prefix = prefix.Concatenate(1, lastToken)
|
||||
}
|
||||
}
|
||||
if len(draftTokens) == 0 {
|
||||
return nil
|
||||
}
|
||||
return &mtpDraftCandidates{
|
||||
tokens: mlx.Concatenate(draftTokens, 1),
|
||||
logits: mlx.Concatenate(draftLogits, 1),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Runner) acceptMTPDrafts(ctx context.Context, request Request, session *cacheSession, dec *decoder, caches []cache.Cache, position *int, baseLogits *mlx.Array, draftTokens *mlx.Array, final *CompletionResponse, generated *int, stats *mtpStats, opts mtpOptions) (sampler.Result, int, bool, error) {
|
||||
if opts.serialValidate {
|
||||
stats.serial++
|
||||
return r.acceptMTPDraftsSerial(ctx, request, session, dec, caches, position, baseLogits, draftTokens, final, generated)
|
||||
}
|
||||
|
||||
specCaches, spec, ok := cache.BeginSpeculation(caches)
|
||||
if ok {
|
||||
stats.batched++
|
||||
return r.acceptMTPDraftsBatched(ctx, request, session, dec, caches, specCaches, spec, position, baseLogits, draftTokens, final, generated, stats, opts)
|
||||
}
|
||||
|
||||
stats.serial++
|
||||
return r.acceptMTPDraftsSerial(ctx, request, session, dec, caches, position, baseLogits, draftTokens, final, generated)
|
||||
}
|
||||
|
||||
func (r *Runner) acceptMTPDraftsBatched(ctx context.Context, request Request, session *cacheSession, dec *decoder, liveCaches []cache.Cache, caches []cache.Cache, spec *cache.Speculation, position *int, baseLogits *mlx.Array, draftTokens *mlx.Array, final *CompletionResponse, generated *int, stats *mtpStats, opts mtpOptions) (sampler.Result, int, bool, error) {
|
||||
before := *position
|
||||
draftCount := draftTokens.Dim(1)
|
||||
hiddenSeq := r.Model.Forward(&batch.Batch{
|
||||
InputIDs: draftTokens,
|
||||
SeqOffsets: []int32{int32(before)},
|
||||
SeqQueryLens: []int32{int32(draftCount)},
|
||||
}, caches)
|
||||
|
||||
accepted := 0
|
||||
var next sampler.Result
|
||||
done := false
|
||||
|
||||
selectedTokens := r.mtpValidationTokens(baseLogits, hiddenSeq)
|
||||
mlx.Eval(draftTokens, selectedTokens)
|
||||
draftIDs := draftTokens.Ints()
|
||||
selectedIDs := selectedTokens.Ints()
|
||||
if len(selectedIDs) < draftCount+1 {
|
||||
return sampler.Result{}, accepted, false, fmt.Errorf("mtp validation produced %d tokens for %d draft tokens", len(selectedIDs), draftCount)
|
||||
}
|
||||
|
||||
for i, id := range draftIDs {
|
||||
if selectedIDs[i] != id {
|
||||
next = sampler.Result{Token: mtpTokenAt(selectedTokens, i)}
|
||||
break
|
||||
}
|
||||
accepted++
|
||||
if r.Tokenizer.IsEOS(int32(id)) {
|
||||
done = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if opts.compareSerialValidate {
|
||||
spec.Commit(0)
|
||||
r.compareMTPBatchedWithSerial(ctx, liveCaches, before, baseLogits, hiddenSeq, draftIDs, selectedIDs, accepted, draftCount, stats)
|
||||
}
|
||||
spec.Commit(accepted)
|
||||
*position = before + accepted
|
||||
|
||||
for _, id := range draftIDs[:accepted] {
|
||||
if *generated >= request.Options.NumPredict {
|
||||
done = true
|
||||
break
|
||||
}
|
||||
res := sampler.Result{Token: mlx.FromValues([]int32{int32(id)}, 1)}
|
||||
var err error
|
||||
done, err = r.emitMTPToken(ctx, request, session, dec, res, final)
|
||||
if err != nil {
|
||||
return sampler.Result{}, accepted, done, err
|
||||
}
|
||||
if !done {
|
||||
(*generated)++
|
||||
}
|
||||
if done {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if done || *generated >= request.Options.NumPredict {
|
||||
return sampler.Result{}, accepted, true, nil
|
||||
}
|
||||
if next.Token == nil {
|
||||
next = sampler.Result{Token: mtpTokenAt(selectedTokens, draftCount)}
|
||||
}
|
||||
return next, accepted, false, nil
|
||||
}
|
||||
|
||||
func (r *Runner) acceptSampleMTPDrafts(ctx context.Context, request Request, session *cacheSession, dec *decoder, caches []cache.Cache, position *int, baseLogits *mlx.Array, candidates *mtpDraftCandidates, final *CompletionResponse, generated *int, stats *mtpStats) (sampler.Result, int, bool, error) {
|
||||
specCaches, spec, ok := cache.BeginSpeculation(caches)
|
||||
if !ok {
|
||||
stats.serial++
|
||||
return r.Sampler.Sample([]int{pipelineSlot}, baseLogits), 0, false, nil
|
||||
}
|
||||
stats.batched++
|
||||
|
||||
before := *position
|
||||
draftCount := candidates.tokens.Dim(1)
|
||||
hiddenSeq := r.Model.Forward(&batch.Batch{
|
||||
InputIDs: candidates.tokens,
|
||||
SeqOffsets: []int32{int32(before)},
|
||||
SeqQueryLens: []int32{int32(draftCount)},
|
||||
}, specCaches)
|
||||
|
||||
targetScores := r.Sampler.SpeculativeScores(pipelineSlot, r.mtpValidationLogits(baseLogits, hiddenSeq), candidates.tokens)
|
||||
draftScores := candidates.logits
|
||||
if draftScores.NumDims() == 3 {
|
||||
draftScores = draftScores.Squeeze(0)
|
||||
}
|
||||
acceptedMask := mtpSampleAcceptedMask(targetScores, draftScores, candidates.tokens, draftCount)
|
||||
mlx.Eval(candidates.tokens, acceptedMask)
|
||||
|
||||
draftIDs := candidates.tokens.Ints()
|
||||
acceptedFlags := acceptedMask.Ints()
|
||||
accepted := 0
|
||||
for _, ok := range acceptedFlags {
|
||||
if ok == 0 {
|
||||
break
|
||||
}
|
||||
accepted++
|
||||
}
|
||||
if accepted > draftCount {
|
||||
return sampler.Result{}, 0, false, fmt.Errorf("mtp sample validation accepted %d tokens for %d draft tokens", accepted, draftCount)
|
||||
}
|
||||
|
||||
commitIDs := make([]int32, 0, accepted+1)
|
||||
done := false
|
||||
for i, id := range draftIDs[:accepted] {
|
||||
commitIDs = append(commitIDs, int32(id))
|
||||
if r.Tokenizer.IsEOS(int32(id)) {
|
||||
done = true
|
||||
accepted = i + 1
|
||||
commitIDs = commitIDs[:accepted]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
spec.Commit(accepted)
|
||||
*position = before + accepted
|
||||
|
||||
for _, id := range draftIDs[:accepted] {
|
||||
if *generated >= request.Options.NumPredict {
|
||||
done = true
|
||||
break
|
||||
}
|
||||
res := sampler.Result{Token: mlx.FromValues([]int32{int32(id)}, 1)}
|
||||
var err error
|
||||
done, err = r.emitMTPToken(ctx, request, session, dec, res, final)
|
||||
if err != nil {
|
||||
return sampler.Result{}, accepted, done, err
|
||||
}
|
||||
if !done {
|
||||
(*generated)++
|
||||
}
|
||||
if done {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if done || *generated >= request.Options.NumPredict {
|
||||
r.Sampler.Commit(pipelineSlot, commitIDs)
|
||||
return sampler.Result{}, accepted, true, nil
|
||||
}
|
||||
|
||||
var nextToken *mlx.Array
|
||||
if accepted == draftCount {
|
||||
nextToken = mtpSampleTokenAt(targetScores, draftCount)
|
||||
} else {
|
||||
nextToken = mtpSampleResidualToken(targetScores, draftScores, accepted)
|
||||
}
|
||||
mlx.Eval(nextToken)
|
||||
nextID := int32(tokenID(nextToken))
|
||||
commitIDs = append(commitIDs, nextID)
|
||||
r.Sampler.Commit(pipelineSlot, commitIDs)
|
||||
|
||||
return sampler.Result{Token: nextToken}, accepted, false, nil
|
||||
}
|
||||
|
||||
func mtpSampleAcceptedMask(targetScores, draftScores, draftTokens *mlx.Array, draftCount int) *mlx.Array {
|
||||
targetProbs := mlx.SoftmaxAxis(targetScores.Slice(mlx.Slice(0, draftCount), mlx.Slice()), -1, true)
|
||||
draftProbs := mlx.SoftmaxAxis(draftScores, -1, true)
|
||||
if draftTokens.NumDims() == 2 {
|
||||
draftTokens = draftTokens.Squeeze(0)
|
||||
}
|
||||
indices := draftTokens.ExpandDims(-1)
|
||||
p := targetProbs.TakeAlongAxis(indices, -1).Squeeze(-1)
|
||||
q := draftProbs.TakeAlongAxis(indices, -1).Squeeze(-1)
|
||||
acceptP := mlx.Minimum(p.Divide(q), mlx.FromValue(float32(1)))
|
||||
return mlx.Bernoulli(acceptP).AsType(mlx.DTypeInt32)
|
||||
}
|
||||
|
||||
func mtpSampleTokenAt(scores *mlx.Array, index int) *mlx.Array {
|
||||
row := scores.Slice(mlx.Slice(index, index+1), mlx.Slice())
|
||||
return mtpTokenVector(row.Categorical(-1).AsType(mlx.DTypeInt32))
|
||||
}
|
||||
|
||||
func mtpSampleResidualToken(targetScores, draftScores *mlx.Array, index int) *mlx.Array {
|
||||
p := mlx.SoftmaxAxis(targetScores.Slice(mlx.Slice(index, index+1), mlx.Slice()), -1, true)
|
||||
q := mlx.SoftmaxAxis(draftScores.Slice(mlx.Slice(index, index+1), mlx.Slice()), -1, true)
|
||||
diff := p.Subtract(q)
|
||||
positive := mlx.Maximum(diff, mlx.FromValue(float32(1e-20)))
|
||||
logits := mlx.Log(positive)
|
||||
logits = mlx.Where(diff.LessEqual(mlx.FromValue(float32(0))), mlx.FromValue(float32(math.Inf(-1))), logits)
|
||||
return mtpTokenVector(logits.Categorical(-1).AsType(mlx.DTypeInt32))
|
||||
}
|
||||
|
||||
func mtpTokenInput(token *mlx.Array) *mlx.Array {
|
||||
switch token.NumDims() {
|
||||
case 0:
|
||||
return token.Reshape(1, 1)
|
||||
case 1:
|
||||
return token.ExpandDims(-1)
|
||||
case 2:
|
||||
return token
|
||||
default:
|
||||
panic(fmt.Sprintf("mtp token must be rank 0, 1, or 2, got rank %d", token.NumDims()))
|
||||
}
|
||||
}
|
||||
|
||||
func mtpTokenVector(token *mlx.Array) *mlx.Array {
|
||||
switch token.NumDims() {
|
||||
case 0:
|
||||
return token.Reshape(1)
|
||||
case 1:
|
||||
return token
|
||||
default:
|
||||
panic(fmt.Sprintf("mtp sampled token must be rank 0 or 1, got rank %d", token.NumDims()))
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Runner) compareMTPBatchedWithSerial(ctx context.Context, caches []cache.Cache, before int, baseLogits, hiddenSeq *mlx.Array, draftIDs, selectedIDs []int, accepted, draftCount int, stats *mtpStats) {
|
||||
serialCaches, ok := cache.BeginIsolatedSpeculation(caches)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
compareCount := accepted + 1
|
||||
if accepted == draftCount {
|
||||
// Include the target bonus token when every draft was accepted.
|
||||
compareCount = draftCount + 1
|
||||
}
|
||||
|
||||
serialLogits := baseLogits
|
||||
for i := range compareCount {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return
|
||||
}
|
||||
if i >= len(selectedIDs) {
|
||||
return
|
||||
}
|
||||
|
||||
batchedLogits := baseLogits
|
||||
if i > 0 {
|
||||
batchedLogits = r.targetLogitsAt(hiddenSeq, i-1)
|
||||
}
|
||||
|
||||
batchedToken := greedyTokenFromLogits(batchedLogits)
|
||||
serialToken := greedyTokenFromLogits(serialLogits)
|
||||
mlx.Eval(batchedToken, serialToken)
|
||||
|
||||
batchedID := tokenID(batchedToken)
|
||||
vectorizedID := selectedIDs[i]
|
||||
serialID := tokenID(serialToken)
|
||||
stats.compared++
|
||||
if vectorizedID != serialID {
|
||||
firstMismatch := stats.batchSerialMismatches == 0
|
||||
stats.batchSerialMismatches++
|
||||
if !firstMismatch {
|
||||
return
|
||||
}
|
||||
|
||||
draftID := -1
|
||||
if i < draftCount {
|
||||
draftID = draftIDs[i]
|
||||
}
|
||||
batchedTop := top2FromLogits(batchedLogits)
|
||||
serialTop := top2FromLogits(serialLogits)
|
||||
slog.Warn("MTP batched validation differs from serial validation",
|
||||
"position", before+i,
|
||||
"draft", draftID,
|
||||
"batched", vectorizedID,
|
||||
"batched_slice", batchedID,
|
||||
"serial", serialID,
|
||||
"batched_slice_top1", batchedTop.firstToken,
|
||||
"batched_slice_top2", batchedTop.secondToken,
|
||||
"batched_slice_margin", batchedTop.margin,
|
||||
"serial_top1", serialTop.firstToken,
|
||||
"serial_top2", serialTop.secondToken,
|
||||
"serial_margin", serialTop.margin,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
if i >= draftCount || i >= accepted {
|
||||
return
|
||||
}
|
||||
|
||||
hidden := r.Model.Forward(&batch.Batch{
|
||||
InputIDs: mlx.FromValues([]int32{int32(draftIDs[i])}, 1, 1),
|
||||
SeqOffsets: []int32{int32(before + i)},
|
||||
SeqQueryLens: []int32{1},
|
||||
}, serialCaches)
|
||||
serialLogits = r.lastLogits(hidden)
|
||||
}
|
||||
}
|
||||
|
||||
// mtpTop2 holds the two highest-scoring token IDs of a logits row and the gap
// between their logits (first minus second).
type mtpTop2 struct {
	firstToken, secondToken int
	margin                  float64
}
|
||||
|
||||
func top2FromLogits(logits *mlx.Array) mtpTop2 {
|
||||
indices := logits.Negative().ArgsortAxis(-1).Slice(mlx.Slice(), mlx.Slice(0, 2))
|
||||
indices32 := indices.AsType(mlx.DTypeInt32)
|
||||
values := logits.TakeAlongAxis(indices, -1).AsType(mlx.DTypeFloat32)
|
||||
mlx.Eval(indices32, values)
|
||||
|
||||
tokenIDs := indices32.Ints()
|
||||
logitValues := values.Floats()
|
||||
if len(tokenIDs) < 2 || len(logitValues) < 2 {
|
||||
return mtpTop2{}
|
||||
}
|
||||
return mtpTop2{
|
||||
firstToken: tokenIDs[0],
|
||||
secondToken: tokenIDs[1],
|
||||
margin: float64(logitValues[0] - logitValues[1]),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Runner) acceptMTPDraftsSerial(ctx context.Context, request Request, session *cacheSession, dec *decoder, caches []cache.Cache, position *int, baseLogits *mlx.Array, draftTokens *mlx.Array, final *CompletionResponse, generated *int) (sampler.Result, int, bool, error) {
|
||||
logits := baseLogits
|
||||
accepted := 0
|
||||
draftIDs := draftTokens.Ints()
|
||||
|
||||
for _, id := range draftIDs {
|
||||
selected := greedyTokenFromLogits(logits)
|
||||
mlx.Eval(selected)
|
||||
selectedID := tokenID(selected)
|
||||
if selectedID != id {
|
||||
return sampler.Result{Token: mlx.FromValues([]int32{int32(selectedID)}, 1)}, accepted, false, nil
|
||||
}
|
||||
|
||||
hidden := r.Model.Forward(&batch.Batch{
|
||||
InputIDs: mlx.FromValues([]int32{int32(id)}, 1, 1),
|
||||
SeqOffsets: []int32{int32(*position)},
|
||||
SeqQueryLens: []int32{1},
|
||||
}, caches)
|
||||
(*position)++
|
||||
accepted++
|
||||
|
||||
res := sampler.Result{Token: mlx.FromValues([]int32{int32(id)}, 1)}
|
||||
done, err := r.emitMTPToken(ctx, request, session, dec, res, final)
|
||||
if err != nil {
|
||||
return sampler.Result{}, accepted, done, err
|
||||
}
|
||||
if !done {
|
||||
(*generated)++
|
||||
}
|
||||
if done || *generated >= request.Options.NumPredict {
|
||||
return sampler.Result{}, accepted, true, nil
|
||||
}
|
||||
|
||||
logits = r.lastLogits(hidden)
|
||||
}
|
||||
|
||||
return sampler.Result{Token: greedyTokenFromLogits(logits)}, accepted, false, nil
|
||||
}
|
||||
|
||||
func (r *Runner) emitMTPToken(ctx context.Context, request Request, session *cacheSession, dec *decoder, res sampler.Result, final *CompletionResponse) (bool, error) {
|
||||
output := int32(tokenID(res.Token))
|
||||
session.outputs = append(session.outputs, output)
|
||||
|
||||
if r.Tokenizer.IsEOS(output) {
|
||||
final.DoneReason = 0
|
||||
return true, nil
|
||||
}
|
||||
|
||||
if resp, ok := dec.decode(res); ok {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false, ctx.Err()
|
||||
case request.Responses <- resp:
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (r *Runner) lastLogits(hidden *mlx.Array) *mlx.Array {
|
||||
logits := r.Model.Unembed(hidden)
|
||||
return r.lastLogitsFromLogits(logits)
|
||||
}
|
||||
|
||||
func (r *Runner) targetLogitsAt(hiddenSeq *mlx.Array, index int) *mlx.Array {
|
||||
hidden := hiddenSeq.Slice(mlx.Slice(), mlx.Slice(index), mlx.Slice())
|
||||
return r.lastLogits(hidden)
|
||||
}
|
||||
|
||||
func (r *Runner) mtpValidationTokens(baseLogits, hiddenSeq *mlx.Array) *mlx.Array {
|
||||
return greedyTokenFromLogits(r.mtpValidationLogits(baseLogits, hiddenSeq))
|
||||
}
|
||||
|
||||
func (r *Runner) mtpValidationLogits(baseLogits, hiddenSeq *mlx.Array) *mlx.Array {
|
||||
seqLogits := r.Model.Unembed(hiddenSeq)
|
||||
return baseLogits.ExpandDims(1).Concatenate(1, seqLogits)
|
||||
}
|
||||
|
||||
func mtpTokenAt(tokens *mlx.Array, index int) *mlx.Array {
|
||||
return tokens.Slice(mlx.Slice(), mlx.Slice(index)).Squeeze(0)
|
||||
}
|
||||
|
||||
func (r *Runner) lastLogitsFromLogits(logits *mlx.Array) *mlx.Array {
|
||||
return logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
|
||||
}
|
||||
|
||||
func greedyTokenFromLogits(logits *mlx.Array) *mlx.Array {
|
||||
return logits.Argmax(-1, false).AsType(mlx.DTypeInt32)
|
||||
}
|
||||
|
||||
func tokenID(token *mlx.Array) int {
|
||||
if token == nil {
|
||||
return -1
|
||||
}
|
||||
if token.DType() == mlx.DTypeInt32 {
|
||||
ids := token.Ints()
|
||||
if len(ids) > 0 {
|
||||
return ids[0]
|
||||
}
|
||||
}
|
||||
return token.Int()
|
||||
}
|
||||
|
|
@ -146,6 +146,12 @@ func (r *Runner) TextGenerationPipeline(ctx context.Context, request Request) er
|
|||
|
||||
// Register the sampler after prefill completes.
|
||||
r.Sampler.Add(pipelineSlot, request.SamplerOpts, inputs)
|
||||
if r.useGreedyMTP(request.SamplerOpts) {
|
||||
return r.runGreedyMTPDecode(ctx, request, session, caches, tokens[processed:], &position, now)
|
||||
}
|
||||
if r.useSampleMTP(request.SamplerOpts) {
|
||||
return r.runSampleMTPDecode(ctx, request, session, caches, tokens[processed:], &position, now)
|
||||
}
|
||||
|
||||
step := func(token *mlx.Array) sampler.Result {
|
||||
fwd := r.Model.Forward(&batch.Batch{
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ type Request struct {
|
|||
|
||||
type Runner struct {
|
||||
Model base.Model
|
||||
Draft base.DraftModel
|
||||
Tokenizer *tokenizer.Tokenizer
|
||||
Requests chan Request
|
||||
Sampler *sample.Sampler
|
||||
|
|
@ -61,12 +62,41 @@ func (r *Runner) Load(modelName string) error {
|
|||
return err
|
||||
}
|
||||
|
||||
// Assign weights to model (model-specific logic)
|
||||
loadWeights := base.Weights(m)
|
||||
if err := loadWeights(tensors); err != nil {
|
||||
// Assign weights to model (model-specific logic). Target and draft weights
|
||||
// must be loaded before sweeping so tensors from a combined manifest are
|
||||
// not discarded before the draft model can retain them.
|
||||
if err := m.LoadWeights(tensors); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r.Draft = nil
|
||||
draft, err := base.NewDraft(root, m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if draft != nil {
|
||||
if err := draft.LoadWeights(tensors); err != nil {
|
||||
return err
|
||||
}
|
||||
r.Draft = draft
|
||||
}
|
||||
|
||||
collected := mlx.Collect(m)
|
||||
if draft != nil {
|
||||
draftArrays := mlx.Collect(draft)
|
||||
collected = append(collected, draftArrays...)
|
||||
if root.Draft != nil {
|
||||
slog.Info("Loaded draft model", "tensor_prefix", root.Draft.TensorPrefix, "config", root.Draft.Config, "arrays", len(draftArrays))
|
||||
} else {
|
||||
slog.Info("Loaded draft model", "arrays", len(draftArrays))
|
||||
}
|
||||
}
|
||||
for _, arr := range collected {
|
||||
mlx.Pin(arr)
|
||||
}
|
||||
mlx.Sweep()
|
||||
mlx.Eval(collected...)
|
||||
|
||||
r.Model = m
|
||||
r.Tokenizer = m.Tokenizer()
|
||||
r.contextLength = m.MaxContextLength()
|
||||
|
|
|
|||
|
|
@ -349,6 +349,159 @@ func (s *Sampler) Sample(seqIDs []int, logits *mlx.Array) Result {
|
|||
return res
|
||||
}
|
||||
|
||||
// SpeculativeScores applies this slot's non-sampling transforms to logits
|
||||
// without mutating sampler state. Row i is scored as if draftTokens[:i] had
|
||||
// already been appended to the slot history. logits must be [R,V] or [1,R,V].
|
||||
func (s *Sampler) SpeculativeScores(seqID int, logits *mlx.Array, draftTokens *mlx.Array) *mlx.Array {
|
||||
slot, ok := s.byID[seqID]
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("sample.Sampler.SpeculativeScores: seqID %d not registered", seqID))
|
||||
}
|
||||
|
||||
if logits.NumDims() == 3 {
|
||||
if logits.Dim(0) != 1 {
|
||||
panic("sample.Sampler.SpeculativeScores: only batch size 1 is supported")
|
||||
}
|
||||
logits = logits.Squeeze(0)
|
||||
}
|
||||
if logits.NumDims() != 2 {
|
||||
panic(fmt.Sprintf("sample.Sampler.SpeculativeScores: logits must be rank 2 or 3, got rank %d", logits.NumDims()))
|
||||
}
|
||||
|
||||
if draftTokens != nil && draftTokens.NumDims() == 1 {
|
||||
draftTokens = draftTokens.ExpandDims(0)
|
||||
}
|
||||
rows := logits.Dim(0)
|
||||
var hist *mlx.Array
|
||||
if slot.opts.usesHistory() {
|
||||
if s.history == nil {
|
||||
panic(fmt.Sprintf("sample.Sampler.SpeculativeScores: seqID %d has no history", seqID))
|
||||
}
|
||||
if slot.historyLen < slot.opts.RepeatLastN {
|
||||
return s.speculativeScoresSerial(slot, logits, draftTokens)
|
||||
}
|
||||
hist = s.speculativeHistory(slot, draftTokens, rows)
|
||||
}
|
||||
|
||||
return slot.speculativeScores(&slotCtx{opts: slot.opts, history: hist}, logits)
|
||||
}
|
||||
|
||||
// Commit appends already-selected tokens to seqID's repeat-penalty history.
|
||||
// It is used after speculative sampling once the accepted continuation is
|
||||
// known. Normal Sample calls continue to mutate history themselves.
|
||||
func (s *Sampler) Commit(seqID int, tokens []int32) {
|
||||
if len(tokens) == 0 {
|
||||
return
|
||||
}
|
||||
slot, ok := s.byID[seqID]
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("sample.Sampler.Commit: seqID %d not registered", seqID))
|
||||
}
|
||||
if !slot.opts.usesHistory() {
|
||||
return
|
||||
}
|
||||
if s.history == nil {
|
||||
panic(fmt.Sprintf("sample.Sampler.Commit: seqID %d has no history", seqID))
|
||||
}
|
||||
|
||||
row := slices.Index(s.slots, slot)
|
||||
width := s.historyWidth()
|
||||
take := min(len(tokens), slot.opts.RepeatLastN)
|
||||
startLen := slot.historyLen + len(tokens) - take
|
||||
writeTokens := tokens[len(tokens)-take:]
|
||||
flatOffsets := make([]int32, take)
|
||||
for i := range take {
|
||||
ringPos := (startLen + i) % slot.opts.RepeatLastN
|
||||
flatOffsets[i] = int32(row*width + ringPos)
|
||||
}
|
||||
|
||||
flatIdx := mlx.NewArrayInt32(flatOffsets, []int32{int32(take), 1})
|
||||
values := mlx.NewArrayInt32(writeTokens, []int32{int32(take), 1})
|
||||
flatHist := s.history.Reshape(s.history.Dim(0)*width, 1)
|
||||
s.history.Set(flatHist.PutAlongAxis(flatIdx, values, 0).Reshape(s.history.Dim(0), width))
|
||||
slot.historyLen += len(tokens)
|
||||
}
|
||||
|
||||
func (s *Sampler) speculativeScoresSerial(slot *slotState, logits *mlx.Array, draftTokens *mlx.Array) *mlx.Array {
|
||||
rows := logits.Dim(0)
|
||||
draftCount := 0
|
||||
if draftTokens != nil {
|
||||
draftCount = draftTokens.Dim(1)
|
||||
}
|
||||
row := slices.Index(s.slots, slot)
|
||||
baseFill := min(slot.historyLen, slot.opts.RepeatLastN)
|
||||
var base *mlx.Array
|
||||
if baseFill > 0 {
|
||||
base = s.history.Slice(mlx.Slice(row, row+1), mlx.Slice(0, baseFill))
|
||||
}
|
||||
|
||||
scored := make([]*mlx.Array, 0, rows)
|
||||
for i := range rows {
|
||||
rowLogits := logits.Slice(mlx.Slice(i, i+1), mlx.Slice())
|
||||
hist := base
|
||||
prefixLen := min(i, draftCount)
|
||||
if prefixLen > 0 {
|
||||
prefix := draftTokens.Slice(mlx.Slice(), mlx.Slice(0, prefixLen))
|
||||
if hist == nil {
|
||||
hist = prefix
|
||||
} else {
|
||||
hist = hist.Concatenate(1, prefix)
|
||||
}
|
||||
if hist.Dim(1) > slot.opts.RepeatLastN {
|
||||
hist = hist.Slice(mlx.Slice(), mlx.Slice(hist.Dim(1)-slot.opts.RepeatLastN, mlx.End))
|
||||
}
|
||||
}
|
||||
scored = append(scored, slot.speculativeScores(&slotCtx{opts: slot.opts, history: hist}, rowLogits))
|
||||
}
|
||||
return mlx.Concatenate(scored, 0)
|
||||
}
|
||||
|
||||
func (s *Sampler) speculativeHistory(slot *slotState, draftTokens *mlx.Array, rows int) *mlx.Array {
|
||||
row := slices.Index(s.slots, slot)
|
||||
width := slot.opts.RepeatLastN
|
||||
base := s.history.Slice(mlx.Slice(row, row+1), mlx.Slice(0, width))
|
||||
base = mlx.Tile(base, []int32{int32(rows), 1})
|
||||
next := slot.historyLen % width
|
||||
draftCount := 0
|
||||
if draftTokens != nil {
|
||||
draftCount = draftTokens.Dim(1)
|
||||
}
|
||||
if draftCount == 0 {
|
||||
return base
|
||||
}
|
||||
|
||||
sourceIdx := make([]int32, rows*width)
|
||||
writeMask := make([]bool, rows*width)
|
||||
for i := range rows {
|
||||
prefixLen := min(i, draftCount)
|
||||
for j := range prefixLen {
|
||||
pos := (next + j) % width
|
||||
sourceIdx[i*width+pos] = int32(j)
|
||||
writeMask[i*width+pos] = true
|
||||
}
|
||||
}
|
||||
|
||||
draftRows := mlx.Tile(draftTokens, []int32{int32(rows), 1})
|
||||
idx := mlx.NewArrayInt32(sourceIdx, []int32{int32(rows), int32(width)})
|
||||
mask := mlx.FromValues(writeMask, rows, width)
|
||||
values := draftRows.TakeAlongAxis(idx, 1)
|
||||
return mlx.Where(mask, values, base)
|
||||
}
|
||||
|
||||
func (slot *slotState) speculativeScores(ctx *slotCtx, logits *mlx.Array) *mlx.Array {
|
||||
scores := logits
|
||||
// buildTransforms always appends the final selector transform
|
||||
// (greedy or temperature sampling). Speculative validation needs the
|
||||
// processed logits before that selector mutates the distribution.
|
||||
for _, t := range slot.transforms[:len(slot.transforms)-1] {
|
||||
scores = t(ctx, scores)
|
||||
}
|
||||
if slot.opts.Temperature > 0 {
|
||||
scores = mlx.DivScalar(scores, slot.opts.Temperature)
|
||||
}
|
||||
return scores
|
||||
}
|
||||
|
||||
// canBatch reports whether the call can take the uniform batched path.
|
||||
// All slots must share Options; when penalties are active the call must
|
||||
// additionally cover every registered slot in registration order with a
|
||||
|
|
|
|||
|
|
@ -176,6 +176,65 @@ func TestSampleHistoryWindow(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestSpeculativeScoresUsesDraftHistoryWithoutCommit(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
s := New(128)
|
||||
t.Cleanup(func() {
|
||||
s.Free()
|
||||
mlx.Sweep()
|
||||
})
|
||||
|
||||
s.Add(0, Options{RepeatLastN: 2, RepeatPenalty: 10}, []int32{1, 2})
|
||||
draftTokens := mlx.NewArrayInt32([]int32{3, 4}, []int32{1, 2})
|
||||
scores := s.SpeculativeScores(0, batchLogits(
|
||||
[]float32{0, 9, 9, 8, 0}, // history {1,2}; token 3 wins
|
||||
[]float32{0, 0, 9, 9, 8}, // history {2,3}; token 4 wins
|
||||
[]float32{0, 0, 9, 9, 8}, // history {3,4}; token 2 wins
|
||||
), draftTokens)
|
||||
tokens := scores.Argmax(-1, false).AsType(mlx.DTypeInt32)
|
||||
mlx.Eval(tokens)
|
||||
|
||||
if got, want := tokens.Ints(), []int{3, 4, 2}; len(got) != len(want) {
|
||||
t.Fatalf("tokens = %v, want %v", got, want)
|
||||
} else {
|
||||
for i := range want {
|
||||
if got[i] != want[i] {
|
||||
t.Fatalf("tokens = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
if s.byID[0].historyLen != 2 {
|
||||
t.Fatalf("historyLen = %d, want 2", s.byID[0].historyLen)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommitBatchesRingWrites(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
s := New(128)
|
||||
t.Cleanup(func() {
|
||||
s.Free()
|
||||
mlx.Sweep()
|
||||
})
|
||||
|
||||
s.Add(0, Options{RepeatLastN: 4, RepeatPenalty: 1.1}, []int32{10, 11, 12})
|
||||
s.Commit(0, []int32{20, 21, 22})
|
||||
s.Commit(0, []int32{30, 31, 32, 33, 34})
|
||||
mlx.Eval(s.history)
|
||||
|
||||
got := s.history.Ints()
|
||||
want := []int{32, 33, 34, 31}
|
||||
for i := range want {
|
||||
if got[i] != want[i] {
|
||||
t.Fatalf("history = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
if s.byID[0].historyLen != 11 {
|
||||
t.Fatalf("historyLen = %d, want 11", s.byID[0].historyLen)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBatchSamplingPreservesPerSlotBehavior is the core equivalence test:
|
||||
// for every representative dispatch branch (uniform, serial on mixed opts,
|
||||
// serial on partial ring, subset/out-of-order), a batched Sample call must
|
||||
|
|
|
|||
390
x/models/gemma4/assistant.go
Normal file
390
x/models/gemma4/assistant.go
Normal file
|
|
@ -0,0 +1,390 @@
|
|||
package gemma4
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/batch"
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
"github.com/ollama/ollama/x/models/nn"
|
||||
)
|
||||
|
||||
var (
|
||||
_ base.DraftModel = (*AssistantModel)(nil)
|
||||
_ base.MTPDraftModel = (*AssistantModel)(nil)
|
||||
)
|
||||
|
||||
// AssistantConfig is the configuration of a Gemma4 assistant (draft) model as
// produced by parseAssistantConfig.
type AssistantConfig struct {
	// TextConfig is the parsed "text_config" section.
	TextConfig TextConfig `json:"text_config"`
	// BackboneHiddenSize, when non-zero, must equal the target model's
	// hidden size (checked in newAssistantModel).
	BackboneHiddenSize int32 `json:"backbone_hidden_size"`
	// UseOrderedEmbeddings selects the centroid-masked unembedding path and
	// requires the masked_embedding tensors at load time.
	UseOrderedEmbeddings bool `json:"use_ordered_embeddings"`
	// NumCentroids defaults to 2048 when absent or zero in the config.
	NumCentroids int32 `json:"num_centroids"`
	// CentroidIntermediateTopK defaults to 32 when absent or zero.
	CentroidIntermediateTopK int32 `json:"centroid_intermediate_top_k"`
}
|
||||
|
||||
type AssistantModel struct {
|
||||
PreProjection nn.LinearLayer
|
||||
PostProjection nn.LinearLayer
|
||||
EmbedTokens nn.EmbeddingLayer
|
||||
LMHead nn.LinearLayer
|
||||
Centroids nn.LinearLayer
|
||||
TokenOrdering *mlx.Array
|
||||
Layers []*AssistantLayer
|
||||
Norm *nn.RMSNorm
|
||||
|
||||
NormScaled *mlx.Array
|
||||
|
||||
*AssistantConfig
|
||||
tensorPrefix string
|
||||
|
||||
QuantGroupSize int
|
||||
QuantBits int
|
||||
QuantMode string
|
||||
TensorQuant map[string]*model.TensorQuantInfo
|
||||
}
|
||||
|
||||
type AssistantLayer struct {
|
||||
InputNorm *nn.RMSNorm
|
||||
PostAttnNorm *nn.RMSNorm
|
||||
PreFFNorm *nn.RMSNorm
|
||||
PostFFNorm *nn.RMSNorm
|
||||
|
||||
InputNormScaled *mlx.Array
|
||||
PostAttnNormScaled *mlx.Array
|
||||
PreFFNormScaled *mlx.Array
|
||||
PostFFNormScaled *mlx.Array
|
||||
|
||||
Attention *AssistantAttention
|
||||
MLP *MLP
|
||||
LayerScalar *mlx.Array
|
||||
IsSliding bool
|
||||
}
|
||||
|
||||
type AssistantAttention struct {
|
||||
QProj nn.LinearLayer
|
||||
OProj nn.LinearLayer
|
||||
QNorm *nn.RMSNorm
|
||||
|
||||
QNormScaled *mlx.Array
|
||||
}
|
||||
|
||||
func parseAssistantConfig(configData []byte) (AssistantConfig, error) {
|
||||
var raw struct {
|
||||
TextConfig json.RawMessage `json:"text_config"`
|
||||
|
||||
BackboneHiddenSize int32 `json:"backbone_hidden_size"`
|
||||
UseOrderedEmbeddings bool `json:"use_ordered_embeddings"`
|
||||
NumCentroids int32 `json:"num_centroids"`
|
||||
CentroidIntermediateTopK int32 `json:"centroid_intermediate_top_k"`
|
||||
}
|
||||
if err := json.Unmarshal(configData, &raw); err != nil {
|
||||
return AssistantConfig{}, fmt.Errorf("parse assistant config: %w", err)
|
||||
}
|
||||
if len(raw.TextConfig) == 0 {
|
||||
return AssistantConfig{}, fmt.Errorf("assistant config missing text_config")
|
||||
}
|
||||
|
||||
text, err := parseTextConfig(raw.TextConfig)
|
||||
if err != nil {
|
||||
return AssistantConfig{}, err
|
||||
}
|
||||
if raw.NumCentroids == 0 {
|
||||
raw.NumCentroids = 2048
|
||||
}
|
||||
if raw.CentroidIntermediateTopK == 0 {
|
||||
raw.CentroidIntermediateTopK = 32
|
||||
}
|
||||
|
||||
return AssistantConfig{
|
||||
TextConfig: text,
|
||||
BackboneHiddenSize: raw.BackboneHiddenSize,
|
||||
UseOrderedEmbeddings: raw.UseOrderedEmbeddings,
|
||||
NumCentroids: raw.NumCentroids,
|
||||
CentroidIntermediateTopK: raw.CentroidIntermediateTopK,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newAssistantModel(root *model.Root, target base.Model) (base.DraftModel, error) {
|
||||
if root == nil || root.Draft == nil {
|
||||
return nil, fmt.Errorf("draft metadata missing")
|
||||
}
|
||||
|
||||
configPath := root.Draft.Config
|
||||
if configPath == "" {
|
||||
configPath = "draft/config.json"
|
||||
}
|
||||
configData, err := root.Manifest.ReadConfig(configPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load draft config: %w", err)
|
||||
}
|
||||
|
||||
cfg, err := parseAssistantConfig(configData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
targetGemma, ok := target.(*Model)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("gemma4 assistant requires gemma4 target, got %T", target)
|
||||
}
|
||||
if cfg.BackboneHiddenSize != 0 && cfg.BackboneHiddenSize != targetGemma.HiddenSize {
|
||||
return nil, fmt.Errorf("assistant backbone hidden size %d does not match target hidden size %d", cfg.BackboneHiddenSize, targetGemma.HiddenSize)
|
||||
}
|
||||
if cfg.TextConfig.VocabSize != targetGemma.VocabSize {
|
||||
return nil, fmt.Errorf("assistant vocab size %d does not match target vocab size %d", cfg.TextConfig.VocabSize, targetGemma.VocabSize)
|
||||
}
|
||||
|
||||
tensorPrefix := root.Draft.TensorPrefix
|
||||
if tensorPrefix == "" {
|
||||
tensorPrefix = "draft."
|
||||
}
|
||||
|
||||
m := &AssistantModel{
|
||||
AssistantConfig: &cfg,
|
||||
tensorPrefix: tensorPrefix,
|
||||
Layers: make([]*AssistantLayer, cfg.TextConfig.NumHiddenLayers),
|
||||
TensorQuant: root.AllTensorQuant(),
|
||||
}
|
||||
if qt := root.QuantType(); qt != "" {
|
||||
m.QuantGroupSize, m.QuantBits, m.QuantMode = model.QuantizationParams(qt)
|
||||
if gs := root.GroupSize(); gs > 0 {
|
||||
m.QuantGroupSize = gs
|
||||
}
|
||||
}
|
||||
for i := range m.Layers {
|
||||
m.Layers[i] = &AssistantLayer{
|
||||
IsSliding: isLayerSliding(int32(i), &m.TextConfig),
|
||||
Attention: &AssistantAttention{},
|
||||
MLP: &MLP{},
|
||||
}
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *AssistantModel) LoadWeights(tensors map[string]*mlx.Array) error {
|
||||
prefix := m.tensorPrefix
|
||||
linears := model.NewLinearFactory(tensors, m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant)
|
||||
|
||||
m.PreProjection = linears.Make(prefix + "pre_projection")
|
||||
m.PostProjection = linears.Make(prefix + "post_projection")
|
||||
if m.PreProjection == nil || m.PostProjection == nil {
|
||||
return fmt.Errorf("missing assistant projection weights")
|
||||
}
|
||||
|
||||
m.EmbedTokens = model.MakeEmbeddingLayer(tensors, prefix+"model.embed_tokens", m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant)
|
||||
if m.EmbedTokens == nil {
|
||||
return fmt.Errorf("missing assistant embedding weight")
|
||||
}
|
||||
m.LMHead = m.EmbedTokens.AsLinear()
|
||||
|
||||
if m.UseOrderedEmbeddings {
|
||||
m.Centroids = linears.Make(prefix + "masked_embedding.centroids")
|
||||
m.TokenOrdering = tensors[prefix+"masked_embedding.token_ordering"]
|
||||
if m.Centroids == nil || m.TokenOrdering == nil {
|
||||
return fmt.Errorf("missing ordered embedding tensors: %smasked_embedding.centroids.weight and %smasked_embedding.token_ordering", prefix, prefix)
|
||||
}
|
||||
m.TokenOrdering = m.TokenOrdering.AsType(mlx.DTypeInt32)
|
||||
}
|
||||
|
||||
normWeight := tensors[prefix+"model.norm.weight"]
|
||||
if normWeight == nil {
|
||||
return fmt.Errorf("missing assistant final norm")
|
||||
}
|
||||
m.Norm = nn.NewRMSNorm(normWeight, m.TextConfig.RMSNormEps)
|
||||
|
||||
for i := range m.TextConfig.NumHiddenLayers {
|
||||
layerPrefix := fmt.Sprintf("%smodel.layers.%d", prefix, i)
|
||||
layer := &AssistantLayer{
|
||||
IsSliding: isLayerSliding(i, &m.TextConfig),
|
||||
Attention: &AssistantAttention{
|
||||
QProj: linears.Make(layerPrefix + ".self_attn.q_proj"),
|
||||
OProj: linears.Make(layerPrefix + ".self_attn.o_proj"),
|
||||
},
|
||||
MLP: &MLP{
|
||||
GateProj: linears.Make(layerPrefix + ".mlp.gate_proj"),
|
||||
UpProj: linears.Make(layerPrefix + ".mlp.up_proj"),
|
||||
DownProj: linears.Make(layerPrefix + ".mlp.down_proj"),
|
||||
},
|
||||
LayerScalar: tensors[layerPrefix+".layer_scalar"],
|
||||
}
|
||||
|
||||
if w := tensors[layerPrefix+".input_layernorm.weight"]; w != nil {
|
||||
layer.InputNorm = nn.NewRMSNorm(w, m.TextConfig.RMSNormEps)
|
||||
}
|
||||
if w := tensors[layerPrefix+".post_attention_layernorm.weight"]; w != nil {
|
||||
layer.PostAttnNorm = nn.NewRMSNorm(w, m.TextConfig.RMSNormEps)
|
||||
}
|
||||
if w := tensors[layerPrefix+".pre_feedforward_layernorm.weight"]; w != nil {
|
||||
layer.PreFFNorm = nn.NewRMSNorm(w, m.TextConfig.RMSNormEps)
|
||||
}
|
||||
if w := tensors[layerPrefix+".post_feedforward_layernorm.weight"]; w != nil {
|
||||
layer.PostFFNorm = nn.NewRMSNorm(w, m.TextConfig.RMSNormEps)
|
||||
}
|
||||
if w := tensors[layerPrefix+".self_attn.q_norm.weight"]; w != nil {
|
||||
layer.Attention.QNorm = nn.NewRMSNorm(w, m.TextConfig.RMSNormEps)
|
||||
}
|
||||
|
||||
if layer.InputNorm == nil || layer.PostAttnNorm == nil || layer.PreFFNorm == nil || layer.PostFFNorm == nil {
|
||||
return fmt.Errorf("assistant layer %d: missing norm weights", i)
|
||||
}
|
||||
if layer.Attention.QProj == nil || layer.Attention.OProj == nil || layer.Attention.QNorm == nil {
|
||||
return fmt.Errorf("assistant layer %d: missing attention weights", i)
|
||||
}
|
||||
if layer.MLP.GateProj == nil || layer.MLP.UpProj == nil || layer.MLP.DownProj == nil {
|
||||
return fmt.Errorf("assistant layer %d: missing mlp weights", i)
|
||||
}
|
||||
|
||||
m.Layers[i] = layer
|
||||
}
|
||||
|
||||
m.precomputeScaledWeights()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *AssistantModel) precomputeScaledWeights() {
|
||||
if m.Norm != nil {
|
||||
m.NormScaled = m.Norm.Weight
|
||||
}
|
||||
for _, layer := range m.Layers {
|
||||
if layer.InputNorm != nil {
|
||||
layer.InputNormScaled = layer.InputNorm.Weight
|
||||
}
|
||||
if layer.PostAttnNorm != nil {
|
||||
layer.PostAttnNormScaled = layer.PostAttnNorm.Weight
|
||||
}
|
||||
if layer.PreFFNorm != nil {
|
||||
layer.PreFFNormScaled = layer.PreFFNorm.Weight
|
||||
}
|
||||
if layer.PostFFNorm != nil {
|
||||
layer.PostFFNormScaled = layer.PostFFNorm.Weight
|
||||
}
|
||||
if layer.Attention != nil && layer.Attention.QNorm != nil {
|
||||
layer.Attention.QNormScaled = layer.Attention.QNorm.Weight
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *AssistantModel) Draft(inputsEmbeds *mlx.Array, position int32, caches []cache.Cache) (logits, hidden *mlx.Array) {
|
||||
dims := inputsEmbeds.Dims()
|
||||
B, L := int32(dims[0]), int32(dims[1])
|
||||
b := &batch.Batch{
|
||||
InputIDs: mlx.Zeros(mlx.DTypeInt32, int(B), int(L)),
|
||||
SeqOffsets: []int32{position},
|
||||
SeqQueryLens: []int32{L},
|
||||
}
|
||||
|
||||
sliding, full := m.sharedHistories(b, caches)
|
||||
h := m.PreProjection.Forward(inputsEmbeds)
|
||||
|
||||
positions := mlx.FromValues([]int32{position}, 1)
|
||||
for _, layer := range m.Layers {
|
||||
h = layer.Forward(h, b, positions, B, L, &m.TextConfig, sliding, full)
|
||||
}
|
||||
|
||||
hidden = mlx.RMSNormFn(h, m.NormScaled, m.TextConfig.RMSNormEps)
|
||||
projected := m.PostProjection.Forward(hidden)
|
||||
return m.unembed(hidden), projected
|
||||
}
|
||||
|
||||
func (m *AssistantModel) sharedHistories(b *batch.Batch, caches []cache.Cache) (sliding, full *nn.KVHistory) {
|
||||
if len(caches) < 2 {
|
||||
return nil, nil
|
||||
}
|
||||
if v, ok := caches[len(caches)-2].(cache.Viewer); ok {
|
||||
sliding = v.View(b)
|
||||
}
|
||||
if v, ok := caches[len(caches)-1].(cache.Viewer); ok {
|
||||
full = v.View(b)
|
||||
}
|
||||
return sliding, full
|
||||
}
|
||||
|
||||
func (m *AssistantModel) unembed(hidden *mlx.Array) *mlx.Array {
|
||||
if m.UseOrderedEmbeddings {
|
||||
return m.applyCentroidMasking(hidden)
|
||||
}
|
||||
return m.LMHead.Forward(hidden)
|
||||
}
|
||||
|
||||
func (m *AssistantModel) applyCentroidMasking(hidden *mlx.Array) *mlx.Array {
|
||||
B, L := hidden.Dim(0), hidden.Dim(1)
|
||||
vocab := int(m.TextConfig.VocabSize)
|
||||
numCentroids := int(m.NumCentroids)
|
||||
vocabPerCentroid := vocab / numCentroids
|
||||
topK := int(m.CentroidIntermediateTopK)
|
||||
|
||||
centroidLogits := m.Centroids.Forward(hidden)
|
||||
topKIndices := centroidLogits.Negative().ArgpartitionAxis(topK-1, -1).Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, topK))
|
||||
ordering := m.TokenOrdering.Reshape(numCentroids, vocabPerCentroid)
|
||||
selectedCanonical := ordering.TakeAxis(topKIndices, 0)
|
||||
selectedFlat := selectedCanonical.Reshape(B * L * topK * vocabPerCentroid)
|
||||
|
||||
embeddings := m.EmbedTokens.Forward(selectedFlat)
|
||||
embeddings = embeddings.Reshape(B, L, topK*vocabPerCentroid, int(m.TextConfig.HiddenSize))
|
||||
selectedLogits := hidden.ExpandDims(2).Matmul(embeddings.Transpose(0, 1, 3, 2)).Squeeze(2)
|
||||
|
||||
out := mlx.Zeros(selectedLogits.DType(), B, L, vocab)
|
||||
out = mlx.AddScalar(out, -1.0e30)
|
||||
return out.PutAlongAxis(selectedCanonical.Reshape(B, L, topK*vocabPerCentroid), selectedLogits, -1)
|
||||
}
|
||||
|
||||
func (l *AssistantLayer) Forward(x *mlx.Array, b *batch.Batch, positions *mlx.Array, B, L int32, cfg *TextConfig, sliding, full *nn.KVHistory) *mlx.Array {
|
||||
normed := mlx.RMSNormFn(x, l.InputNormScaled, cfg.RMSNormEps)
|
||||
attnOut := l.Attention.Forward(normed, b, positions, B, L, l.IsSliding, cfg, sliding, full)
|
||||
attnOut = mlx.RMSNormFn(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps)
|
||||
h := mlx.Add(x, attnOut)
|
||||
|
||||
normed = mlx.RMSNormFn(h, l.PreFFNormScaled, cfg.RMSNormEps)
|
||||
mlpOut := l.MLP.Forward(normed)
|
||||
mlpOut = mlx.RMSNormFn(mlpOut, l.PostFFNormScaled, cfg.RMSNormEps)
|
||||
h = mlx.Add(h, mlpOut)
|
||||
|
||||
if l.LayerScalar != nil {
|
||||
h = mlx.Mul(h, l.LayerScalar)
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
func (a *AssistantAttention) Forward(x *mlx.Array, b *batch.Batch, positions *mlx.Array, B, L int32, isSliding bool, cfg *TextConfig, sliding, full *nn.KVHistory) *mlx.Array {
|
||||
headDim := cfg.HeadDim
|
||||
scale := cfg.SlidingScale
|
||||
ropeDims := cfg.SlidingRopeDims
|
||||
ropeBase := cfg.SlidingRopeBase
|
||||
history := sliding
|
||||
if !isSliding {
|
||||
headDim = cfg.GlobalHeadDim
|
||||
scale = cfg.FullScale
|
||||
ropeDims = cfg.FullRopeDims
|
||||
ropeBase = cfg.FullRopeBase
|
||||
history = full
|
||||
}
|
||||
if history == nil {
|
||||
panic("gemma4 assistant missing shared target KV history")
|
||||
}
|
||||
|
||||
q := a.QProj.Forward(x)
|
||||
q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, headDim)
|
||||
q = mlx.Transpose(q, 0, 2, 1, 3)
|
||||
q = mlx.RMSNormFn(q, a.QNormScaled, cfg.RMSNormEps)
|
||||
|
||||
var ropeFreqs *mlx.Array
|
||||
if !isSliding {
|
||||
ropeFreqs = cfg.FullRopeFreqs
|
||||
}
|
||||
q = mlx.RoPEWithFreqs(q, ropeDims, false, ropeBase, 1.0, positions, ropeFreqs)
|
||||
|
||||
mask := nn.CausalMask()
|
||||
if isSliding && cfg.SlidingWindow > 0 {
|
||||
mask = mask.Intersect(nn.SlidingWindowMask(b, history.K().Dim(2), int(cfg.SlidingWindow), q.DType()))
|
||||
}
|
||||
|
||||
out := nn.ScaledDotProductAttention(b, q, scale, nn.WithKVHistory(history), nn.WithMask(mask))
|
||||
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*headDim)
|
||||
if !mlx.MetalIsAvailable() {
|
||||
out = mlx.Contiguous(out, false)
|
||||
}
|
||||
return a.OProj.Forward(out)
|
||||
}
|
||||
|
|
@ -18,10 +18,15 @@ import (
|
|||
func init() {
|
||||
base.Register("Gemma4ForCausalLM", newModel)
|
||||
base.Register("Gemma4ForConditionalGeneration", newModel)
|
||||
base.RegisterDraft("Gemma4AssistantForCausalLM", newAssistantModel)
|
||||
base.RegisterDraft("gemma4_assistant", newAssistantModel)
|
||||
}
|
||||
|
||||
// Compile-time interface checks.
|
||||
var _ base.Model = (*Model)(nil)
|
||||
var (
|
||||
_ base.Model = (*Model)(nil)
|
||||
_ base.MTPDefaultsProvider = (*Model)(nil)
|
||||
)
|
||||
|
||||
// RopeParams holds per-layer-type RoPE settings.
|
||||
type RopeParams struct {
|
||||
|
|
@ -466,6 +471,24 @@ func (m *Model) EnableCompile() bool {
|
|||
return true
|
||||
}
|
||||
|
||||
func (m *Model) MTPDraftDefaults(_ bool) base.MTPDefaults {
|
||||
defaults := base.MTPDefaults{
|
||||
InitialDraftTokens: 4,
|
||||
MaxDraftTokens: 16,
|
||||
Enabled: true,
|
||||
}
|
||||
if m == nil || m.TextConfig == nil {
|
||||
return defaults
|
||||
}
|
||||
switch {
|
||||
case !m.EnableMoeBlock && m.HiddenSize == 5376 && m.NumHiddenLayers == 60:
|
||||
defaults.InitialDraftTokens = 14
|
||||
case m.EnableMoeBlock && m.HiddenSize == 2816 && m.NumHiddenLayers == 30:
|
||||
defaults.InitialDraftTokens = 8
|
||||
}
|
||||
return defaults
|
||||
}
|
||||
|
||||
func resolveWeightPrefix(tensors map[string]*mlx.Array) string {
|
||||
for _, prefix := range []string{"", "language_model.", "model.language_model."} {
|
||||
if tensors[prefix+"embed_tokens.weight"] != nil {
|
||||
|
|
@ -1068,6 +1091,11 @@ func (m *Model) Tokenizer() *tokenizer.Tokenizer {
|
|||
return m.tok
|
||||
}
|
||||
|
||||
// TokenEmbeddings returns the target model's scaled token embeddings for MTP.
|
||||
func (m *Model) TokenEmbeddings(inputIDs *mlx.Array) *mlx.Array {
|
||||
return mlx.MulScalar(m.EmbedTokens.Forward(inputIDs), m.EmbedScale)
|
||||
}
|
||||
|
||||
// NewCaches creates cache objects for layers that own KV state.
|
||||
func (m *Model) NewCaches() []cache.Cache {
|
||||
cacheLayers := len(m.Layers)
|
||||
|
|
|
|||
|
|
@ -434,6 +434,54 @@ func TestLayerTypeDetection(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestMTPDraftDefaults(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg *TextConfig
|
||||
wantInitial int
|
||||
wantMax int
|
||||
}{
|
||||
{
|
||||
name: "nil config",
|
||||
wantInitial: 4,
|
||||
wantMax: 16,
|
||||
},
|
||||
{
|
||||
name: "31b bf16",
|
||||
cfg: &TextConfig{HiddenSize: 5376, NumHiddenLayers: 60},
|
||||
wantInitial: 14,
|
||||
wantMax: 16,
|
||||
},
|
||||
{
|
||||
name: "31b quantized",
|
||||
cfg: &TextConfig{HiddenSize: 5376, NumHiddenLayers: 60, QuantBits: 4},
|
||||
wantInitial: 14,
|
||||
wantMax: 16,
|
||||
},
|
||||
{
|
||||
name: "26b-a4b moe",
|
||||
cfg: &TextConfig{HiddenSize: 2816, NumHiddenLayers: 30, EnableMoeBlock: true},
|
||||
wantInitial: 8,
|
||||
wantMax: 16,
|
||||
},
|
||||
{
|
||||
name: "generic default",
|
||||
cfg: &TextConfig{HiddenSize: 2560, NumHiddenLayers: 42, HiddenSizePerLayer: 256},
|
||||
wantInitial: 4,
|
||||
wantMax: 16,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := (&Model{TextConfig: tt.cfg}).MTPDraftDefaults(false)
|
||||
if got.InitialDraftTokens != tt.wantInitial || got.MaxDraftTokens != tt.wantMax || !got.Enabled {
|
||||
t.Fatalf("MTPDraftDefaults() = %+v, want initial=%d max=%d enabled=true", got, tt.wantInitial, tt.wantMax)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewCachesOmitsSharedKVLayers(t *testing.T) {
|
||||
m := &Model{
|
||||
Layers: []*DecoderLayer{
|
||||
|
|
@ -467,6 +515,55 @@ func TestNewCachesIncludesAllNonSharedLayers(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestNewCachesAssistantSharedHistoryOrdering(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
totalLayers int
|
||||
slidingBeforeFull int
|
||||
cacheLayers int
|
||||
}{
|
||||
{name: "31B", totalLayers: 60, slidingBeforeFull: 5, cacheLayers: 60},
|
||||
{name: "26B-A4B", totalLayers: 30, slidingBeforeFull: 5, cacheLayers: 30},
|
||||
{name: "E4B", totalLayers: 42, slidingBeforeFull: 5, cacheLayers: 24},
|
||||
{name: "E2B", totalLayers: 35, slidingBeforeFull: 4, cacheLayers: 15},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
groupSize := tc.slidingBeforeFull + 1
|
||||
layers := make([]*DecoderLayer, tc.totalLayers)
|
||||
for i := range layers {
|
||||
donor := int32(-1)
|
||||
if i >= tc.cacheLayers {
|
||||
donor = 0
|
||||
}
|
||||
layers[i] = &DecoderLayer{
|
||||
IsSliding: i%groupSize < tc.slidingBeforeFull,
|
||||
KVShareDonor: donor,
|
||||
}
|
||||
}
|
||||
|
||||
m := &Model{
|
||||
Layers: layers,
|
||||
TextConfig: &TextConfig{SlidingWindow: 512},
|
||||
}
|
||||
caches := m.NewCaches()
|
||||
if got := len(caches); got != tc.cacheLayers {
|
||||
t.Fatalf("len(NewCaches()) = %d, want %d", got, tc.cacheLayers)
|
||||
}
|
||||
|
||||
gotSliding := len(caches) - 2
|
||||
gotFull := len(caches) - 1
|
||||
if !m.Layers[gotSliding].IsSliding {
|
||||
t.Fatalf("cache %d should be sliding attention", gotSliding)
|
||||
}
|
||||
if m.Layers[gotFull].IsSliding {
|
||||
t.Fatalf("cache %d should be full attention", gotFull)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveWeightPrefix(t *testing.T) {
|
||||
if err := mlx.CheckInit(); err != nil {
|
||||
t.Skipf("MLX not available: %v", err)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue