ollama/x/mlxrunner/mtp.go
Patrick Devine 15e6076d79
mlx: Gemma4 MTP speculative decoding (#15980)
This change adds support for MTP (multi-token prediction) speculative decoding for the
gemma4 model family.

It includes:
  * support for importing safetensors based gemma4 draft models with `ollama create`
  * a new DRAFT command in the Modelfile for specifying draft models
  * a --quantize-draft flag for the ollama create command to quantize the draft model
  * cache support for speculation
  * changes to the rotating cache to be able to handle MTP correctly
  * sampling support for draft model token prediction

---------

Co-authored-by: Daniel Hiltgen <daniel@ollama.com>
2026-05-05 08:55:04 -07:00

959 lines
29 KiB
Go

package mlxrunner
import (
"context"
"fmt"
"log/slog"
"math"
"os"
"strconv"
"strings"
"time"
"github.com/ollama/ollama/x/mlxrunner/batch"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model/base"
sampler "github.com/ollama/ollama/x/mlxrunner/sample"
)
// Default draft-token budgets, used when neither the model nor the
// environment supplies its own values.
const (
	mtpDefaultInitialDraftTokens = 4
	mtpDefaultMaxDraftTokens     = 16
)
// mtpDraftSchedule selects how the per-iteration draft-token budget evolves
// over the course of a generation.
type mtpDraftSchedule string

const (
	// mtpDraftScheduleHeuristic grows or shrinks the draft budget based on
	// how many drafts were accepted in the previous iteration.
	mtpDraftScheduleHeuristic mtpDraftSchedule = "heuristic"
	// mtpDraftScheduleConstant keeps the draft budget fixed at its initial value.
	mtpDraftScheduleConstant mtpDraftSchedule = "constant"
)
// mtpStats accumulates counters for one MTP decode run; they are reported in
// a single slog line when the run completes.
type mtpStats struct {
	iterations int // speculation iterations (draft+validate rounds)
	drafted    int // total draft tokens proposed
	accepted   int // total draft tokens accepted by validation
	mismatches int // iterations where at least one draft was rejected
	allAccepted int // iterations where every draft was accepted
	batched    int // validations done in a single batched forward pass
	serial     int // validations done one token at a time
	compared   int // positions checked by batched-vs-serial comparison
	batchSerialMismatches int // positions where batched and serial validation disagreed
	maxDraft   int // high-water mark of the adaptive draft budget
	targetDuration   time.Duration // time in target-model forward passes
	draftDuration    time.Duration // time generating draft tokens
	validateDuration time.Duration // time validating drafts
}
// mtpOptions is the resolved MTP configuration: model defaults merged with
// OLLAMA_MLX_MTP_* environment overrides (see loadMTPOptions).
type mtpOptions struct {
	initialDraftTokens int // draft budget for the first iteration
	maxDraftTokens     int // upper bound for the adaptive draft budget
	draftSchedule      mtpDraftSchedule
	serialValidate     bool // force one-token-at-a-time validation
	compareSerialValidate bool // run serial validation alongside batched and log diffs
}
// mtpDefaults returns the MTP draft configuration for the loaded model,
// clamping any non-positive token budgets back to the package defaults.
// sample is forwarded to the model's MTPDraftDefaults when the model
// implements base.MTPDefaultsProvider.
func (r *Runner) mtpDefaults(sample bool) base.MTPDefaults {
	d := base.MTPDefaults{
		InitialDraftTokens: mtpDefaultInitialDraftTokens,
		MaxDraftTokens:     mtpDefaultMaxDraftTokens,
		Enabled:            true,
	}
	if provider, ok := r.Model.(base.MTPDefaultsProvider); ok {
		d = provider.MTPDraftDefaults(sample)
	}
	// The model may leave budgets unset (or set them to nonsense); fall back
	// to the package defaults in that case.
	if d.InitialDraftTokens < 1 {
		d.InitialDraftTokens = mtpDefaultInitialDraftTokens
	}
	if d.MaxDraftTokens < 1 {
		d.MaxDraftTokens = mtpDefaultMaxDraftTokens
	}
	return d
}
// loadMTPOptions resolves the effective MTP options: model-provided defaults
// first, then OLLAMA_MLX_MTP_* environment overrides. Invalid environment
// values are logged and ignored.
func (r *Runner) loadMTPOptions(sample bool) mtpOptions {
	defaults := r.mtpDefaults(sample)
	opts := mtpOptions{
		initialDraftTokens: defaults.InitialDraftTokens,
		maxDraftTokens:     defaults.MaxDraftTokens,
		draftSchedule:      mtpDraftScheduleConstant,
	}
	if n := positiveEnvInt("OLLAMA_MLX_MTP_MAX_DRAFT_TOKENS"); n > 0 {
		opts.maxDraftTokens = n
	}
	if n := positiveEnvInt("OLLAMA_MLX_MTP_INITIAL_DRAFT_TOKENS"); n > 0 {
		opts.initialDraftTokens = n
	}
	// The initial budget can never exceed the maximum.
	opts.initialDraftTokens = min(opts.initialDraftTokens, opts.maxDraftTokens)
	if serial, err := strconv.ParseBool(os.Getenv("OLLAMA_MLX_MTP_SERIAL_VALIDATE")); err == nil {
		opts.serialValidate = serial
	}
	if compare, err := strconv.ParseBool(os.Getenv("OLLAMA_MLX_MTP_COMPARE_SERIAL_VALIDATE")); err == nil {
		opts.compareSerialValidate = compare
	}
	schedule := strings.ToLower(strings.TrimSpace(os.Getenv("OLLAMA_MLX_MTP_DRAFT_SCHEDULE")))
	switch schedule {
	case string(mtpDraftScheduleHeuristic):
		opts.draftSchedule = mtpDraftScheduleHeuristic
	case "", string(mtpDraftScheduleConstant):
		// Already the default.
	default:
		slog.Warn("invalid MTP env setting", "key", "OLLAMA_MLX_MTP_DRAFT_SCHEDULE", "value", schedule)
	}
	return opts
}
func positiveEnvInt(key string) int {
raw := os.Getenv(key)
if raw == "" {
return 0
}
v, err := strconv.Atoi(raw)
if err != nil || v <= 0 {
slog.Warn("invalid MTP env setting", "key", key, "value", raw)
return 0
}
return v
}
// useGreedyMTP reports whether greedy MTP speculative decoding can be used
// for this request: a capable draft/target model pair must be loaded, MTP
// must be enabled, and every sampling option must be at a setting where
// greedy argmax decoding is equivalent to sampling.
func (r *Runner) useGreedyMTP(opts sampler.Options) bool {
	if r.Draft == nil {
		return false
	}
	_, draftOK := r.Draft.(base.MTPDraftModel)
	_, embedOK := r.Model.(base.MTPEmbeddingModel)
	if !draftOK || !embedOK || !r.mtpDefaults(false).Enabled {
		return false
	}
	if opts.Logprobs || opts.TopLogprobs > 0 {
		return false
	}
	// Greedy decoding only matches sampling at temperature zero with all
	// other sampling knobs in a neutral position.
	switch {
	case opts.Temperature != 0:
		return false
	case opts.RepeatPenalty > 0 && opts.RepeatPenalty != 1:
		return false
	case opts.PresencePenalty != 0 || opts.FrequencyPenalty != 0:
		return false
	case opts.TopP > 0 && opts.TopP < 1:
		return false
	case opts.TopK > 0:
		return false
	case opts.MinP != 0:
		return false
	}
	return true
}
// useSampleMTP reports whether sampled MTP speculative decoding can be used
// for this request. It requires temperature-based sampling to be active
// (temperature zero is the greedy path) and is disabled by the serial /
// comparison validation env overrides, which are only wired for greedy MTP.
func (r *Runner) useSampleMTP(opts sampler.Options) bool {
	for _, key := range []string{"OLLAMA_MLX_MTP_SERIAL_VALIDATE", "OLLAMA_MLX_MTP_COMPARE_SERIAL_VALIDATE"} {
		if v, err := strconv.ParseBool(os.Getenv(key)); err == nil && v {
			return false
		}
	}
	if r.Draft == nil {
		return false
	}
	if _, ok := r.Draft.(base.MTPDraftModel); !ok {
		return false
	}
	if _, ok := r.Model.(base.MTPEmbeddingModel); !ok {
		return false
	}
	if !r.mtpDefaults(true).Enabled {
		return false
	}
	if opts.Logprobs || opts.TopLogprobs > 0 {
		return false
	}
	return opts.Temperature != 0
}
// runGreedyMTPDecode runs the greedy (temperature-zero) MTP speculative
// decode loop: each iteration the target model advances one token, the draft
// model proposes up to draftLimit follow-up tokens, and acceptMTPDrafts
// accepts the longest prefix that matches the target's own greedy choices.
// seed is the prompt token window to prefill; position is the cache position,
// advanced in place. Emits streaming responses on request.Responses and a
// final done response before returning.
func (r *Runner) runGreedyMTPDecode(ctx context.Context, request Request, session *cacheSession, caches []cache.Cache, seed []int32, position *int, started time.Time) error {
	// Both assertions are guaranteed by useGreedyMTP before this is called.
	targetEmbeddings := r.Model.(base.MTPEmbeddingModel)
	draft := r.Draft.(base.MTPDraftModel)
	mtpOpts := r.loadMTPOptions(false)
	stats := mtpStats{maxDraft: mtpOpts.initialDraftTokens}
	draftLimit := mtpOpts.initialDraftTokens
	slog.Info("MTP greedy decode enabled", "initial_draft_tokens", mtpOpts.initialDraftTokens, "max_draft_tokens", mtpOpts.maxDraftTokens, "draft_schedule", mtpOpts.draftSchedule, "serial_validate", mtpOpts.serialValidate, "compare_serial_validate", mtpOpts.compareSerialValidate)
	// targetForward runs the target model on token(s) at the current cache
	// position and advances *position by the number of tokens consumed.
	targetForward := func(token *mlx.Array) *mlx.Array {
		fwd := r.Model.Forward(&batch.Batch{
			InputIDs:     token,
			SeqOffsets:   []int32{int32(*position)},
			SeqQueryLens: []int32{int32(token.Dim(1))},
		}, caches)
		*position += token.Dim(1)
		return fwd
	}
	// Prefill the prompt window and greedily pick the first token.
	hidden := targetForward(mlx.FromValues(seed, 1, len(seed)))
	current := sampler.Result{Token: greedyTokenFromLogits(r.lastLogits(hidden))}
	// Pin the live token arrays so Sweep does not free them; evaluation is
	// kicked off asynchronously and forced only when the value is needed.
	mlx.Pin(current.Arrays()...)
	mlx.Sweep()
	mlx.AsyncEval(current.Arrays()...)
	defer func() {
		mlx.Unpin(current.Arrays()...)
	}()
	dec := decoder{tokenizer: r.Tokenizer}
	final := CompletionResponse{Done: true, PromptEvalCount: len(request.Tokens), DoneReason: 1}
	now := started
	generated := 0
	for generated < request.Options.NumPredict {
		if err := ctx.Err(); err != nil {
			return err
		}
		t0 := time.Now()
		// Advance the target by the current token; its logits are the base
		// distribution the drafts will be validated against.
		hidden = targetForward(current.Token.ExpandDims(-1))
		baseLogits := r.lastLogits(hidden)
		stats.targetDuration += time.Since(t0)
		if generated == 0 {
			// First decode step: everything up to here was prompt eval.
			mlx.Eval(current.Arrays()...)
			final.PromptEvalDuration = time.Since(now)
			now = time.Now()
		}
		done, err := r.emitMTPToken(ctx, request, session, &dec, current, &final)
		if err != nil {
			return err
		}
		if !done {
			generated++
		}
		if done || generated >= request.Options.NumPredict {
			break
		}
		stats.iterations++
		// Never draft past the remaining generation budget.
		maxDraft := min(draftLimit, request.Options.NumPredict-generated)
		t0 = time.Now()
		draftTokens := r.generateMTPDrafts(draft, targetEmbeddings, current.Token, hidden, caches, int32(*position-1), maxDraft)
		draftCount := 0
		if draftTokens != nil {
			draftCount = draftTokens.Dim(1)
			mlx.Pin(baseLogits, draftTokens)
			mlx.Eval(draftTokens)
			mlx.Sweep()
		}
		stats.draftDuration += time.Since(t0)
		stats.drafted += draftCount
		var next sampler.Result
		if draftCount == 0 {
			next = sampler.Result{Token: greedyTokenFromLogits(baseLogits)}
		} else {
			var accepted int
			t0 = time.Now()
			next, accepted, done, err = r.acceptMTPDrafts(ctx, request, session, &dec, caches, position, baseLogits, draftTokens, &final, &generated, &stats, mtpOpts)
			stats.validateDuration += time.Since(t0)
			mlx.Unpin(baseLogits, draftTokens)
			if err != nil {
				return err
			}
			stats.accepted += accepted
			// Adaptive budget: only the heuristic schedule changes
			// draftLimit (the first case no-ops for constant).
			switch {
			case mtpOpts.draftSchedule == mtpDraftScheduleConstant:
			case accepted == draftCount:
				stats.allAccepted++
				draftLimit = min(mtpOpts.maxDraftTokens, draftLimit+2)
			default:
				stats.mismatches++
				draftLimit = max(1, draftLimit-1)
			}
			// Constant schedule still tallies accept/mismatch stats.
			if mtpOpts.draftSchedule == mtpDraftScheduleConstant {
				if accepted == draftCount {
					stats.allAccepted++
				} else {
					stats.mismatches++
				}
			}
			stats.maxDraft = max(stats.maxDraft, draftLimit)
			if next.Token == nil {
				mlx.Sweep()
			}
			if done || generated >= request.Options.NumPredict {
				break
			}
		}
		// Rotate the live token: pin the new one before unpinning the old so
		// a Sweep in between cannot free either.
		mlx.Pin(next.Arrays()...)
		old := current
		current = next
		mlx.Unpin(old.Arrays()...)
		mlx.Sweep()
		mlx.AsyncEval(current.Arrays()...)
		if generated%256 == 0 {
			mlx.ClearCache()
		}
	}
	final.EvalCount = generated
	final.EvalDuration = time.Since(now)
	acceptance := 0.0
	if stats.drafted > 0 {
		acceptance = float64(stats.accepted) / float64(stats.drafted)
	}
	avgDraft := 0.0
	avgAccepted := 0.0
	if stats.iterations > 0 {
		avgDraft = float64(stats.drafted) / float64(stats.iterations)
		avgAccepted = float64(stats.accepted) / float64(stats.iterations)
	}
	slog.Info("MTP decode stats", "generated", generated, "drafted", stats.drafted, "accepted", stats.accepted, "acceptance", acceptance, "iterations", stats.iterations, "avg_draft", avgDraft, "avg_accepted", avgAccepted, "batched", stats.batched, "serial", stats.serial, "compared", stats.compared, "batch_serial_mismatches", stats.batchSerialMismatches, "mismatches", stats.mismatches, "all_accepted", stats.allAccepted, "max_draft", stats.maxDraft, "draft_schedule", mtpOpts.draftSchedule, "target_duration", stats.targetDuration, "draft_duration", stats.draftDuration, "validate_duration", stats.validateDuration)
	select {
	case <-ctx.Done():
		return ctx.Err()
	case request.Responses <- final:
		return nil
	}
}
// runSampleMTPDecode is the sampled-temperature counterpart of
// runGreedyMTPDecode: drafts are proposed by sampling from the draft model's
// processed scores and validated with speculative-sampling acceptance
// (acceptSampleMTPDrafts) instead of exact greedy matching. The pin/sweep
// bookkeeping and adaptive draft-budget logic mirror the greedy path.
func (r *Runner) runSampleMTPDecode(ctx context.Context, request Request, session *cacheSession, caches []cache.Cache, seed []int32, position *int, started time.Time) error {
	// Guaranteed by useSampleMTP before this is called.
	targetEmbeddings := r.Model.(base.MTPEmbeddingModel)
	draft := r.Draft.(base.MTPDraftModel)
	mtpOpts := r.loadMTPOptions(true)
	stats := mtpStats{maxDraft: mtpOpts.initialDraftTokens}
	draftLimit := mtpOpts.initialDraftTokens
	slog.Info("MTP sample decode enabled", "initial_draft_tokens", mtpOpts.initialDraftTokens, "max_draft_tokens", mtpOpts.maxDraftTokens, "draft_schedule", mtpOpts.draftSchedule, "serial_validate", mtpOpts.serialValidate)
	// targetForward advances the target model and the cache position.
	targetForward := func(token *mlx.Array) *mlx.Array {
		fwd := r.Model.Forward(&batch.Batch{
			InputIDs:     token,
			SeqOffsets:   []int32{int32(*position)},
			SeqQueryLens: []int32{int32(token.Dim(1))},
		}, caches)
		*position += token.Dim(1)
		return fwd
	}
	// Prefill the prompt and sample the first token through the sampler.
	hidden := targetForward(mlx.FromValues(seed, 1, len(seed)))
	current := r.Sampler.Sample([]int{pipelineSlot}, r.lastLogits(hidden))
	mlx.Pin(current.Arrays()...)
	mlx.Sweep()
	mlx.AsyncEval(current.Arrays()...)
	defer func() {
		mlx.Unpin(current.Arrays()...)
	}()
	dec := decoder{tokenizer: r.Tokenizer}
	final := CompletionResponse{Done: true, PromptEvalCount: len(request.Tokens), DoneReason: 1}
	now := started
	generated := 0
	for generated < request.Options.NumPredict {
		if err := ctx.Err(); err != nil {
			return err
		}
		t0 := time.Now()
		// mtpTokenInput normalizes the sampled token to rank 2 for Forward.
		hidden = targetForward(mtpTokenInput(current.Token))
		baseLogits := r.lastLogits(hidden)
		stats.targetDuration += time.Since(t0)
		if generated == 0 {
			// First decode step: everything up to here was prompt eval.
			mlx.Eval(current.Arrays()...)
			final.PromptEvalDuration = time.Since(now)
			now = time.Now()
		}
		done, err := r.emitMTPToken(ctx, request, session, &dec, current, &final)
		if err != nil {
			return err
		}
		if !done {
			generated++
		}
		if done || generated >= request.Options.NumPredict {
			break
		}
		stats.iterations++
		// Never draft past the remaining generation budget.
		maxDraft := min(draftLimit, request.Options.NumPredict-generated)
		t0 = time.Now()
		candidates := r.generateMTPDraftCandidates(draft, targetEmbeddings, current.Token, hidden, caches, int32(*position-1), maxDraft)
		draftCount := 0
		if candidates != nil {
			draftCount = candidates.tokens.Dim(1)
			mlx.Pin(baseLogits, candidates.tokens, candidates.logits)
			mlx.Sweep()
		}
		stats.draftDuration += time.Since(t0)
		stats.drafted += draftCount
		var next sampler.Result
		if draftCount == 0 {
			next = r.Sampler.Sample([]int{pipelineSlot}, baseLogits)
		} else {
			var accepted int
			t0 = time.Now()
			next, accepted, done, err = r.acceptSampleMTPDrafts(ctx, request, session, &dec, caches, position, baseLogits, candidates, &final, &generated, &stats)
			stats.validateDuration += time.Since(t0)
			mlx.Unpin(baseLogits, candidates.tokens, candidates.logits)
			if err != nil {
				return err
			}
			stats.accepted += accepted
			// Adaptive budget: only the heuristic schedule changes
			// draftLimit (the first case no-ops for constant).
			switch {
			case mtpOpts.draftSchedule == mtpDraftScheduleConstant:
			case accepted == draftCount:
				stats.allAccepted++
				draftLimit = min(mtpOpts.maxDraftTokens, draftLimit+2)
			default:
				stats.mismatches++
				draftLimit = max(1, draftLimit-1)
			}
			// Constant schedule still tallies accept/mismatch stats.
			if mtpOpts.draftSchedule == mtpDraftScheduleConstant {
				if accepted == draftCount {
					stats.allAccepted++
				} else {
					stats.mismatches++
				}
			}
			stats.maxDraft = max(stats.maxDraft, draftLimit)
			if next.Token == nil {
				mlx.Sweep()
			}
			if done || generated >= request.Options.NumPredict {
				break
			}
		}
		// Rotate the live token: pin the new one before unpinning the old.
		mlx.Pin(next.Arrays()...)
		old := current
		current = next
		mlx.Unpin(old.Arrays()...)
		mlx.Sweep()
		mlx.AsyncEval(current.Arrays()...)
		if generated%256 == 0 {
			mlx.ClearCache()
		}
	}
	final.EvalCount = generated
	final.EvalDuration = time.Since(now)
	acceptance := 0.0
	if stats.drafted > 0 {
		acceptance = float64(stats.accepted) / float64(stats.drafted)
	}
	avgDraft := 0.0
	avgAccepted := 0.0
	if stats.iterations > 0 {
		avgDraft = float64(stats.drafted) / float64(stats.iterations)
		avgAccepted = float64(stats.accepted) / float64(stats.iterations)
	}
	slog.Info("MTP decode stats", "mode", "sample", "generated", generated, "drafted", stats.drafted, "accepted", stats.accepted, "acceptance", acceptance, "iterations", stats.iterations, "avg_draft", avgDraft, "avg_accepted", avgAccepted, "batched", stats.batched, "serial", stats.serial, "mismatches", stats.mismatches, "all_accepted", stats.allAccepted, "max_draft", stats.maxDraft, "draft_schedule", mtpOpts.draftSchedule, "target_duration", stats.targetDuration, "draft_duration", stats.draftDuration, "validate_duration", stats.validateDuration)
	select {
	case <-ctx.Done():
		return ctx.Err()
	case request.Responses <- final:
		return nil
	}
}
// mtpDraftCandidates holds one iteration's sampled draft tokens together
// with the scores they were drawn from, which the acceptance test needs.
type mtpDraftCandidates struct {
	// tokens are the sampled draft token IDs.
	tokens *mlx.Array
	// logits are the processed proposal scores used to sample tokens.
	logits *mlx.Array
}
// generateMTPDrafts greedily rolls the draft model forward up to maxDraft
// steps, feeding each step the target's token embedding concatenated with the
// previous projected hidden state. Returns the proposed token IDs
// concatenated along axis 1, or nil when maxDraft <= 0. The returned graph is
// lazy; the caller evaluates it.
func (r *Runner) generateMTPDrafts(draft base.MTPDraftModel, target base.MTPEmbeddingModel, token *mlx.Array, hidden *mlx.Array, caches []cache.Cache, position int32, maxDraft int) *mlx.Array {
	if maxDraft <= 0 {
		return nil
	}
	lastToken := token.ExpandDims(-1)
	lastHidden := hidden
	draftTokens := make([]*mlx.Array, 0, maxDraft)
	// Gemma4 assistant MTP is trained as "single-position" drafting:
	// keep the RoPE/cache position anchored at the last target-seen token
	// while the proposed token and projected hidden state advance.
	for range maxDraft {
		tokenEmbedding := target.TokenEmbeddings(lastToken)
		inputs := tokenEmbedding.Concatenate(-1, lastHidden)
		logits, projected := draft.Draft(inputs, position, caches)
		stepLogits := r.lastLogitsFromLogits(logits)
		// Greedy pick: this path is only used at temperature zero.
		nextToken := greedyTokenFromLogits(stepLogits)
		lastToken = nextToken.ExpandDims(-1)
		lastHidden = projected
		draftTokens = append(draftTokens, lastToken)
	}
	if len(draftTokens) == 0 {
		return nil
	}
	return mlx.Concatenate(draftTokens, 1)
}
// generateMTPDraftCandidates is the sampling counterpart of
// generateMTPDrafts: each draft step samples from the sampler's processed
// speculative scores (conditioned on the tokens drafted so far) instead of
// taking the argmax, and the scores themselves are kept so
// acceptSampleMTPDrafts can run the accept/reject test. Returns nil when
// maxDraft <= 0.
func (r *Runner) generateMTPDraftCandidates(draft base.MTPDraftModel, target base.MTPEmbeddingModel, token *mlx.Array, hidden *mlx.Array, caches []cache.Cache, position int32, maxDraft int) *mtpDraftCandidates {
	if maxDraft <= 0 {
		return nil
	}
	lastToken := mtpTokenInput(token)
	lastHidden := hidden
	draftTokens := make([]*mlx.Array, 0, maxDraft)
	draftLogits := make([]*mlx.Array, 0, maxDraft)
	// prefix accumulates the tokens drafted so far; SpeculativeScores uses it
	// to apply history-dependent processing (e.g. penalties) per step.
	var prefix *mlx.Array
	// Gemma4 assistant MTP is trained as "single-position" drafting:
	// keep the RoPE/cache position anchored at the last target-seen token
	// while the proposed token and projected hidden state advance.
	for range maxDraft {
		tokenEmbedding := target.TokenEmbeddings(lastToken)
		inputs := tokenEmbedding.Concatenate(-1, lastHidden)
		logits, projected := draft.Draft(inputs, position, caches)
		stepLogits := r.lastLogitsFromLogits(logits)
		stepScores := r.Sampler.SpeculativeScores(pipelineSlot, stepLogits, prefix)
		nextToken := stepScores.Categorical(-1).AsType(mlx.DTypeInt32)
		lastToken = mtpTokenInput(nextToken)
		lastHidden = projected
		draftTokens = append(draftTokens, lastToken)
		draftLogits = append(draftLogits, stepScores.ExpandDims(1))
		if prefix == nil {
			prefix = lastToken
		} else {
			prefix = prefix.Concatenate(1, lastToken)
		}
	}
	if len(draftTokens) == 0 {
		return nil
	}
	return &mtpDraftCandidates{
		tokens: mlx.Concatenate(draftTokens, 1),
		logits: mlx.Concatenate(draftLogits, 1),
	}
}
// acceptMTPDrafts validates the proposed draft tokens against the target
// model, preferring the batched path (one forward pass over speculative
// caches) and falling back to serial validation when serial validation was
// requested or the caches do not support speculation.
func (r *Runner) acceptMTPDrafts(ctx context.Context, request Request, session *cacheSession, dec *decoder, caches []cache.Cache, position *int, baseLogits *mlx.Array, draftTokens *mlx.Array, final *CompletionResponse, generated *int, stats *mtpStats, opts mtpOptions) (sampler.Result, int, bool, error) {
	if !opts.serialValidate {
		if specCaches, spec, ok := cache.BeginSpeculation(caches); ok {
			stats.batched++
			return r.acceptMTPDraftsBatched(ctx, request, session, dec, caches, specCaches, spec, position, baseLogits, draftTokens, final, generated, stats, opts)
		}
	}
	stats.serial++
	return r.acceptMTPDraftsSerial(ctx, request, session, dec, caches, position, baseLogits, draftTokens, final, generated)
}
// acceptMTPDraftsBatched validates all draft tokens with a single target
// forward pass over speculative caches. The target's greedy choice at each
// position (base logits plus one row per draft) is compared against the
// drafts; the longest matching prefix is accepted and committed to the cache,
// accepted tokens are emitted, and the next token is the target's choice at
// the first mismatch (or the bonus token when everything matched). Returns
// (next, accepted, done, err); next.Token is nil when generation finished.
func (r *Runner) acceptMTPDraftsBatched(ctx context.Context, request Request, session *cacheSession, dec *decoder, liveCaches []cache.Cache, caches []cache.Cache, spec *cache.Speculation, position *int, baseLogits *mlx.Array, draftTokens *mlx.Array, final *CompletionResponse, generated *int, stats *mtpStats, opts mtpOptions) (sampler.Result, int, bool, error) {
	before := *position
	draftCount := draftTokens.Dim(1)
	// One batched forward over the speculative caches scores every draft
	// position at once.
	hiddenSeq := r.Model.Forward(&batch.Batch{
		InputIDs:     draftTokens,
		SeqOffsets:   []int32{int32(before)},
		SeqQueryLens: []int32{int32(draftCount)},
	}, caches)
	accepted := 0
	var next sampler.Result
	done := false
	// selectedTokens holds the target's greedy token for the base position
	// plus each draft position: draftCount+1 entries.
	selectedTokens := r.mtpValidationTokens(baseLogits, hiddenSeq)
	mlx.Eval(draftTokens, selectedTokens)
	draftIDs := draftTokens.Ints()
	selectedIDs := selectedTokens.Ints()
	if len(selectedIDs) < draftCount+1 {
		return sampler.Result{}, accepted, false, fmt.Errorf("mtp validation produced %d tokens for %d draft tokens", len(selectedIDs), draftCount)
	}
	// Accept drafts while they match the target's greedy choice; stop early
	// on the first mismatch or an EOS among the accepted drafts.
	for i, id := range draftIDs {
		if selectedIDs[i] != id {
			next = sampler.Result{Token: mtpTokenAt(selectedTokens, i)}
			break
		}
		accepted++
		if r.Tokenizer.IsEOS(int32(id)) {
			done = true
			break
		}
	}
	if opts.compareSerialValidate {
		// NOTE(review): Commit(0) here followed by Commit(accepted) below —
		// presumably this settles the speculative state before the debug
		// comparison without accepting anything; confirm Speculation.Commit
		// tolerates being called twice.
		spec.Commit(0)
		r.compareMTPBatchedWithSerial(ctx, liveCaches, before, baseLogits, hiddenSeq, draftIDs, selectedIDs, accepted, draftCount, stats)
	}
	spec.Commit(accepted)
	*position = before + accepted
	// Emit the accepted tokens, respecting the generation budget.
	for _, id := range draftIDs[:accepted] {
		if *generated >= request.Options.NumPredict {
			done = true
			break
		}
		res := sampler.Result{Token: mlx.FromValues([]int32{int32(id)}, 1)}
		var err error
		done, err = r.emitMTPToken(ctx, request, session, dec, res, final)
		if err != nil {
			return sampler.Result{}, accepted, done, err
		}
		if !done {
			(*generated)++
		}
		if done {
			break
		}
	}
	if done || *generated >= request.Options.NumPredict {
		return sampler.Result{}, accepted, true, nil
	}
	if next.Token == nil {
		// Every draft matched: the bonus token is the target's choice at the
		// final validation position.
		next = sampler.Result{Token: mtpTokenAt(selectedTokens, draftCount)}
	}
	return next, accepted, false, nil
}
// acceptSampleMTPDrafts validates sampled draft candidates with the
// speculative-sampling accept/reject test: draft i is kept with probability
// min(1, p_target/p_draft). Accepted tokens are committed to the cache and
// the sampler, then the next token is either the bonus token (all accepted)
// or a sample from the residual distribution at the first rejection. Falls
// back to plain sampling when the caches cannot speculate.
func (r *Runner) acceptSampleMTPDrafts(ctx context.Context, request Request, session *cacheSession, dec *decoder, caches []cache.Cache, position *int, baseLogits *mlx.Array, candidates *mtpDraftCandidates, final *CompletionResponse, generated *int, stats *mtpStats) (sampler.Result, int, bool, error) {
	specCaches, spec, ok := cache.BeginSpeculation(caches)
	if !ok {
		// No speculation support: ignore the drafts and sample normally.
		stats.serial++
		return r.Sampler.Sample([]int{pipelineSlot}, baseLogits), 0, false, nil
	}
	stats.batched++
	before := *position
	draftCount := candidates.tokens.Dim(1)
	// Score all draft positions with one batched target forward pass.
	hiddenSeq := r.Model.Forward(&batch.Batch{
		InputIDs:     candidates.tokens,
		SeqOffsets:   []int32{int32(before)},
		SeqQueryLens: []int32{int32(draftCount)},
	}, specCaches)
	targetScores := r.Sampler.SpeculativeScores(pipelineSlot, r.mtpValidationLogits(baseLogits, hiddenSeq), candidates.tokens)
	draftScores := candidates.logits
	if draftScores.NumDims() == 3 {
		draftScores = draftScores.Squeeze(0)
	}
	acceptedMask := mtpSampleAcceptedMask(targetScores, draftScores, candidates.tokens, draftCount)
	mlx.Eval(candidates.tokens, acceptedMask)
	draftIDs := candidates.tokens.Ints()
	acceptedFlags := acceptedMask.Ints()
	// Acceptance is prefix-only: stop at the first rejected draft.
	accepted := 0
	for _, ok := range acceptedFlags {
		if ok == 0 {
			break
		}
		accepted++
	}
	if accepted > draftCount {
		return sampler.Result{}, 0, false, fmt.Errorf("mtp sample validation accepted %d tokens for %d draft tokens", accepted, draftCount)
	}
	// Truncate the accepted run at EOS, if one was drafted.
	commitIDs := make([]int32, 0, accepted+1)
	done := false
	for i, id := range draftIDs[:accepted] {
		commitIDs = append(commitIDs, int32(id))
		if r.Tokenizer.IsEOS(int32(id)) {
			done = true
			accepted = i + 1
			commitIDs = commitIDs[:accepted]
			break
		}
	}
	spec.Commit(accepted)
	*position = before + accepted
	// Emit the accepted tokens, respecting the generation budget.
	for _, id := range draftIDs[:accepted] {
		if *generated >= request.Options.NumPredict {
			done = true
			break
		}
		res := sampler.Result{Token: mlx.FromValues([]int32{int32(id)}, 1)}
		var err error
		done, err = r.emitMTPToken(ctx, request, session, dec, res, final)
		if err != nil {
			return sampler.Result{}, accepted, done, err
		}
		if !done {
			(*generated)++
		}
		if done {
			break
		}
	}
	if done || *generated >= request.Options.NumPredict {
		// Keep the sampler's history in sync with what was emitted.
		r.Sampler.Commit(pipelineSlot, commitIDs)
		return sampler.Result{}, accepted, true, nil
	}
	var nextToken *mlx.Array
	if accepted == draftCount {
		// All drafts accepted: sample the bonus token from the final row.
		nextToken = mtpSampleTokenAt(targetScores, draftCount)
	} else {
		// Rejected at index `accepted`: sample from the residual
		// distribution max(p - q, 0) at that position.
		nextToken = mtpSampleResidualToken(targetScores, draftScores, accepted)
	}
	mlx.Eval(nextToken)
	nextID := int32(tokenID(nextToken))
	commitIDs = append(commitIDs, nextID)
	r.Sampler.Commit(pipelineSlot, commitIDs)
	return sampler.Result{Token: nextToken}, accepted, false, nil
}
// mtpSampleAcceptedMask runs the speculative-sampling acceptance test: draft
// token i is accepted with probability min(1, p_target(i)/p_draft(i)).
// Returns an int32 mask (1 = accepted) drawn by Bernoulli sampling; the
// result is a lazy graph the caller evaluates.
func mtpSampleAcceptedMask(targetScores, draftScores, draftTokens *mlx.Array, draftCount int) *mlx.Array {
	// Only the first draftCount target rows correspond to draft positions;
	// the trailing row is the bonus-token distribution.
	targetProbs := mlx.SoftmaxAxis(targetScores.Slice(mlx.Slice(0, draftCount), mlx.Slice()), -1, true)
	draftProbs := mlx.SoftmaxAxis(draftScores, -1, true)
	if draftTokens.NumDims() == 2 {
		draftTokens = draftTokens.Squeeze(0)
	}
	indices := draftTokens.ExpandDims(-1)
	// Per-position probability of each drafted token under both models.
	p := targetProbs.TakeAlongAxis(indices, -1).Squeeze(-1)
	q := draftProbs.TakeAlongAxis(indices, -1).Squeeze(-1)
	acceptP := mlx.Minimum(p.Divide(q), mlx.FromValue(float32(1)))
	return mlx.Bernoulli(acceptP).AsType(mlx.DTypeInt32)
}
// mtpSampleTokenAt samples one token from the score row at index and returns
// it as a rank-1 int32 array.
func mtpSampleTokenAt(scores *mlx.Array, index int) *mlx.Array {
	row := scores.Slice(mlx.Slice(index, index+1), mlx.Slice())
	return mtpTokenVector(row.Categorical(-1).AsType(mlx.DTypeInt32))
}
// mtpSampleResidualToken samples from the speculative-sampling residual
// distribution max(p - q, 0) (renormalized implicitly by Categorical) at the
// given position, where p/q are the target/draft softmax distributions. This
// is the standard correction applied when a draft token is rejected so the
// combined procedure still samples from the target distribution.
func mtpSampleResidualToken(targetScores, draftScores *mlx.Array, index int) *mlx.Array {
	p := mlx.SoftmaxAxis(targetScores.Slice(mlx.Slice(index, index+1), mlx.Slice()), -1, true)
	q := mlx.SoftmaxAxis(draftScores.Slice(mlx.Slice(index, index+1), mlx.Slice()), -1, true)
	diff := p.Subtract(q)
	// Clamp before Log to avoid log(0); positions where q >= p are then
	// masked to -inf so Categorical never picks them.
	positive := mlx.Maximum(diff, mlx.FromValue(float32(1e-20)))
	logits := mlx.Log(positive)
	logits = mlx.Where(diff.LessEqual(mlx.FromValue(float32(0))), mlx.FromValue(float32(math.Inf(-1))), logits)
	return mtpTokenVector(logits.Categorical(-1).AsType(mlx.DTypeInt32))
}
// mtpTokenInput normalizes a token array to rank 2 (the shape model Forward
// expects), accepting rank 0, 1, or 2 input. Any other rank panics, since it
// indicates a programming error upstream.
func mtpTokenInput(token *mlx.Array) *mlx.Array {
	dims := token.NumDims()
	if dims == 2 {
		return token
	}
	if dims == 1 {
		return token.ExpandDims(-1)
	}
	if dims == 0 {
		return token.Reshape(1, 1)
	}
	panic(fmt.Sprintf("mtp token must be rank 0, 1, or 2, got rank %d", dims))
}
// mtpTokenVector normalizes a sampled token to a rank-1 array, accepting
// rank 0 or 1 input. Any other rank panics, since it indicates a programming
// error upstream.
func mtpTokenVector(token *mlx.Array) *mlx.Array {
	if dims := token.NumDims(); dims > 1 {
		panic(fmt.Sprintf("mtp sampled token must be rank 0 or 1, got rank %d", dims))
	}
	if token.NumDims() == 0 {
		return token.Reshape(1)
	}
	return token
}
// compareMTPBatchedWithSerial is a debug path (OLLAMA_MLX_MTP_COMPARE_SERIAL_VALIDATE)
// that re-validates the drafts one token at a time on isolated speculative
// caches and logs the first position where the batched and serial greedy
// choices disagree. It only updates stats and logs; it never changes the
// decode result.
func (r *Runner) compareMTPBatchedWithSerial(ctx context.Context, caches []cache.Cache, before int, baseLogits, hiddenSeq *mlx.Array, draftIDs, selectedIDs []int, accepted, draftCount int, stats *mtpStats) {
	serialCaches, ok := cache.BeginIsolatedSpeculation(caches)
	if !ok {
		return
	}
	compareCount := accepted + 1
	if accepted == draftCount {
		// Include the target bonus token when every draft was accepted.
		compareCount = draftCount + 1
	}
	serialLogits := baseLogits
	for i := range compareCount {
		if err := ctx.Err(); err != nil {
			return
		}
		if i >= len(selectedIDs) {
			return
		}
		// Position i's batched logits come from the base distribution (i==0)
		// or row i-1 of the batched hidden sequence.
		batchedLogits := baseLogits
		if i > 0 {
			batchedLogits = r.targetLogitsAt(hiddenSeq, i-1)
		}
		batchedToken := greedyTokenFromLogits(batchedLogits)
		serialToken := greedyTokenFromLogits(serialLogits)
		mlx.Eval(batchedToken, serialToken)
		batchedID := tokenID(batchedToken)
		vectorizedID := selectedIDs[i]
		serialID := tokenID(serialToken)
		stats.compared++
		if vectorizedID != serialID {
			// Log diagnostics only for the first mismatch of the run.
			firstMismatch := stats.batchSerialMismatches == 0
			stats.batchSerialMismatches++
			if !firstMismatch {
				return
			}
			draftID := -1
			if i < draftCount {
				draftID = draftIDs[i]
			}
			batchedTop := top2FromLogits(batchedLogits)
			serialTop := top2FromLogits(serialLogits)
			slog.Warn("MTP batched validation differs from serial validation",
				"position", before+i,
				"draft", draftID,
				"batched", vectorizedID,
				"batched_slice", batchedID,
				"serial", serialID,
				"batched_slice_top1", batchedTop.firstToken,
				"batched_slice_top2", batchedTop.secondToken,
				"batched_slice_margin", batchedTop.margin,
				"serial_top1", serialTop.firstToken,
				"serial_top2", serialTop.secondToken,
				"serial_margin", serialTop.margin,
			)
			return
		}
		if i >= draftCount || i >= accepted {
			return
		}
		// Advance the serial caches by one accepted draft token.
		hidden := r.Model.Forward(&batch.Batch{
			InputIDs:     mlx.FromValues([]int32{int32(draftIDs[i])}, 1, 1),
			SeqOffsets:   []int32{int32(before + i)},
			SeqQueryLens: []int32{1},
		}, serialCaches)
		serialLogits = r.lastLogits(hidden)
	}
}
// mtpTop2 describes a logit distribution's two highest-scoring tokens, used
// only for mismatch diagnostics in compareMTPBatchedWithSerial.
type mtpTop2 struct {
	firstToken  int     // highest-logit token ID
	secondToken int     // second-highest-logit token ID
	margin      float64 // logit gap between the two
}
// top2FromLogits extracts the top-two tokens and their logit margin from a
// logit row, for diagnostic logging. Returns the zero value when fewer than
// two entries come back.
func top2FromLogits(logits *mlx.Array) mtpTop2 {
	// Argsort of the negated logits puts the largest first; keep two.
	indices := logits.Negative().ArgsortAxis(-1).Slice(mlx.Slice(), mlx.Slice(0, 2))
	indices32 := indices.AsType(mlx.DTypeInt32)
	values := logits.TakeAlongAxis(indices, -1).AsType(mlx.DTypeFloat32)
	mlx.Eval(indices32, values)
	tokenIDs := indices32.Ints()
	logitValues := values.Floats()
	if len(tokenIDs) < 2 || len(logitValues) < 2 {
		return mtpTop2{}
	}
	return mtpTop2{
		firstToken:  tokenIDs[0],
		secondToken: tokenIDs[1],
		margin:      float64(logitValues[0] - logitValues[1]),
	}
}
// acceptMTPDraftsSerial validates drafts one token at a time: for each draft,
// compare it to the target's greedy choice, and on a match advance the cache
// with a single-token forward pass and emit the token. Stops at the first
// mismatch (returning the target's choice as the next token) or when
// generation finishes. Slower than the batched path but needs no cache
// speculation support.
func (r *Runner) acceptMTPDraftsSerial(ctx context.Context, request Request, session *cacheSession, dec *decoder, caches []cache.Cache, position *int, baseLogits *mlx.Array, draftTokens *mlx.Array, final *CompletionResponse, generated *int) (sampler.Result, int, bool, error) {
	logits := baseLogits
	accepted := 0
	draftIDs := draftTokens.Ints()
	for _, id := range draftIDs {
		selected := greedyTokenFromLogits(logits)
		mlx.Eval(selected)
		selectedID := tokenID(selected)
		if selectedID != id {
			// Mismatch: the target's own choice becomes the next token.
			return sampler.Result{Token: mlx.FromValues([]int32{int32(selectedID)}, 1)}, accepted, false, nil
		}
		// Accepted: advance the real caches by this token.
		hidden := r.Model.Forward(&batch.Batch{
			InputIDs:     mlx.FromValues([]int32{int32(id)}, 1, 1),
			SeqOffsets:   []int32{int32(*position)},
			SeqQueryLens: []int32{1},
		}, caches)
		(*position)++
		accepted++
		res := sampler.Result{Token: mlx.FromValues([]int32{int32(id)}, 1)}
		done, err := r.emitMTPToken(ctx, request, session, dec, res, final)
		if err != nil {
			return sampler.Result{}, accepted, done, err
		}
		if !done {
			(*generated)++
		}
		if done || *generated >= request.Options.NumPredict {
			return sampler.Result{}, accepted, true, nil
		}
		logits = r.lastLogits(hidden)
	}
	// Every draft matched: the bonus token comes from the last forward pass.
	return sampler.Result{Token: greedyTokenFromLogits(logits)}, accepted, false, nil
}
// emitMTPToken records one generated token in the session, streams its
// decoded text on request.Responses (when the decoder produces output), and
// reports whether generation is done. EOS sets final.DoneReason to 0 and
// returns done without emitting.
func (r *Runner) emitMTPToken(ctx context.Context, request Request, session *cacheSession, dec *decoder, res sampler.Result, final *CompletionResponse) (bool, error) {
	id := int32(tokenID(res.Token))
	session.outputs = append(session.outputs, id)
	if r.Tokenizer.IsEOS(id) {
		final.DoneReason = 0
		return true, nil
	}
	resp, ok := dec.decode(res)
	if !ok {
		// Decoder is holding back output (e.g. partial UTF-8); nothing to send.
		return false, nil
	}
	select {
	case <-ctx.Done():
		return false, ctx.Err()
	case request.Responses <- resp:
		return false, nil
	}
}
// lastLogits unembeds a hidden-state sequence and returns the logits for its
// final position.
func (r *Runner) lastLogits(hidden *mlx.Array) *mlx.Array {
	logits := r.Model.Unembed(hidden)
	return r.lastLogitsFromLogits(logits)
}
// targetLogitsAt returns the target logits for one position of a batched
// hidden-state sequence.
func (r *Runner) targetLogitsAt(hiddenSeq *mlx.Array, index int) *mlx.Array {
	hidden := hiddenSeq.Slice(mlx.Slice(), mlx.Slice(index), mlx.Slice())
	return r.lastLogits(hidden)
}
// mtpValidationTokens returns the target's greedy token choice at the base
// position and at every draft position (draftCount+1 tokens).
func (r *Runner) mtpValidationTokens(baseLogits, hiddenSeq *mlx.Array) *mlx.Array {
	return greedyTokenFromLogits(r.mtpValidationLogits(baseLogits, hiddenSeq))
}
// mtpValidationLogits stacks the base-position logits in front of the logits
// for each batched draft position, yielding one row per validation position.
func (r *Runner) mtpValidationLogits(baseLogits, hiddenSeq *mlx.Array) *mlx.Array {
	seqLogits := r.Model.Unembed(hiddenSeq)
	return baseLogits.ExpandDims(1).Concatenate(1, seqLogits)
}
// mtpTokenAt extracts the token at the given sequence index from a batched
// token array.
func mtpTokenAt(tokens *mlx.Array, index int) *mlx.Array {
	return tokens.Slice(mlx.Slice(), mlx.Slice(index)).Squeeze(0)
}
// lastLogitsFromLogits slices out the logits for the final sequence position
// and drops the sequence axis.
func (r *Runner) lastLogitsFromLogits(logits *mlx.Array) *mlx.Array {
	return logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
}
// greedyTokenFromLogits picks the argmax token along the vocabulary axis as
// an int32 array.
func greedyTokenFromLogits(logits *mlx.Array) *mlx.Array {
	return logits.Argmax(-1, false).AsType(mlx.DTypeInt32)
}
// tokenID extracts a single token ID from an array, returning -1 for nil.
// Int32 arrays are read via Ints to avoid a dtype conversion; everything
// else (including an empty int32 array) falls back to Int.
func tokenID(token *mlx.Array) int {
	switch {
	case token == nil:
		return -1
	case token.DType() == mlx.DTypeInt32:
		if ids := token.Ints(); len(ids) > 0 {
			return ids[0]
		}
	}
	return token.Int()
}