mlx: rework the MLX sampler

Replace the MLX sampler transform chain with an explicit distribution pipeline that applies:
  1. penalties
  2. top-k
  3. temperature/softmax
  4. top-p
  5. min-p
  6. normalize
  7. categorical

The common top_k path now keeps sparse [B,K] token ids/probabilities on GPU instead of carrying full-vocab
scores, and sampled MTP reuses those draft/target distributions for acceptance, bonus, and residual sampling.

This change also fixes the seed parameter so that temperature sampling and sampled MTP are reproducible.
This commit is contained in:
Patrick Devine 2026-05-12 13:57:52 -07:00
parent 6bdb73073b
commit 98c87255bd
5 changed files with 494 additions and 184 deletions

View file

@ -5,21 +5,39 @@ import "C"
import "unsafe"
// RandomKey builds an MLX PRNG key array from the given 64-bit seed.
// The same seed always yields the same key, which is what makes seeded
// sampling reproducible downstream.
func RandomKey(seed uint64) *Array {
	key := New("RANDOM_KEY")
	C.mlx_random_key(&key.ctx, C.uint64_t(seed))
	return key
}
// Categorical samples one index per row along axis using MLX's implicit
// global PRNG stream. Use CategoricalWithKey for deterministic draws.
func (t *Array) Categorical(axis int) *Array {
	// Fix: the previous body declared `key := New("")` and never used it,
	// which does not compile in Go (declared and not used). Delegate
	// straight to the keyed variant; a nil key selects the default stream.
	return t.CategoricalWithKey(axis, nil)
}
// CategoricalWithKey samples one index per row along axis. When key is
// nil, a fresh default key array stands in so the C call always receives
// a valid handle.
func (t *Array) CategoricalWithKey(axis int, key *Array) *Array {
	if key == nil {
		key = New("")
	}
	result := New("")
	C.mlx_random_categorical(&result.ctx, t.ctx, C.int(axis), key.ctx, DefaultStream().ctx)
	return result
}
// Bernoulli draws boolean samples with per-element success probabilities p
// using an unkeyed (non-deterministic) PRNG stream.
func Bernoulli(p *Array) *Array {
	return BernoulliWithKey(p, nil)
}
func BernoulliWithKey(p *Array, key *Array) *Array {
dims := p.Dims()
shape := make([]C.int, len(dims))
for i, d := range dims {
shape[i] = C.int(d)
}
key := New("")
if key == nil {
key = New("")
}
out := New("BERNOULLI")
C.mlx_random_bernoulli(&out.ctx, p.ctx, unsafe.SliceData(shape), C.size_t(len(shape)), key.ctx, DefaultStream().ctx)
return out

View file

@ -4,7 +4,6 @@ import (
"context"
"fmt"
"log/slog"
"math"
"os"
"strconv"
"strings"
@ -375,9 +374,11 @@ func (r *Runner) runSampleMTPDecode(ctx context.Context, request Request, sessio
t0 = time.Now()
candidates := r.generateMTPDraftCandidates(draft, targetEmbeddings, current.Token, hidden, caches, int32(*position-1), maxDraft)
draftCount := 0
var candidateArrays []*mlx.Array
if candidates != nil {
draftCount = candidates.tokens.Dim(1)
mlx.Pin(baseLogits, candidates.tokens, candidates.logits)
candidateArrays = append([]*mlx.Array{baseLogits}, candidates.Arrays()...)
mlx.Pin(candidateArrays...)
mlx.Sweep()
}
stats.draftDuration += time.Since(t0)
@ -391,7 +392,7 @@ func (r *Runner) runSampleMTPDecode(ctx context.Context, request Request, sessio
t0 = time.Now()
next, accepted, done, err = r.acceptSampleMTPDrafts(ctx, request, session, &dec, caches, position, baseLogits, candidates, &final, &generated, &stats)
stats.validateDuration += time.Since(t0)
mlx.Unpin(baseLogits, candidates.tokens, candidates.logits)
mlx.Unpin(candidateArrays...)
if err != nil {
return err
}
@ -456,8 +457,15 @@ func (r *Runner) runSampleMTPDecode(ctx context.Context, request Request, sessio
type mtpDraftCandidates struct {
tokens *mlx.Array
// logits are the processed proposal scores used to sample tokens.
logits *mlx.Array
// dist is the proposal distribution used to sample each drafted token.
dist sampler.Distribution
}
// Arrays returns every tensor owned by the candidate set — the drafted
// token ids plus the proposal distribution's tensors — for mlx lifecycle
// management (Pin/Unpin/Sweep). Safe to call on a nil receiver, in which
// case it returns nil so variadic Pin/Unpin calls degrade gracefully.
func (c *mtpDraftCandidates) Arrays() []*mlx.Array {
	if c == nil {
		return nil
	}
	return append([]*mlx.Array{c.tokens}, c.dist.Arrays()...)
}
func (r *Runner) generateMTPDrafts(draft base.MTPDraftModel, target base.MTPEmbeddingModel, token *mlx.Array, hidden *mlx.Array, caches []cache.Cache, position int32, maxDraft int) *mlx.Array {
@ -497,7 +505,7 @@ func (r *Runner) generateMTPDraftCandidates(draft base.MTPDraftModel, target bas
lastToken := mtpTokenInput(token)
lastHidden := hidden
draftTokens := make([]*mlx.Array, 0, maxDraft)
draftLogits := make([]*mlx.Array, 0, maxDraft)
draftDists := make([]sampler.Distribution, 0, maxDraft)
var prefix *mlx.Array
// Gemma4 assistant MTP is trained as "single-position" drafting:
@ -508,13 +516,13 @@ func (r *Runner) generateMTPDraftCandidates(draft base.MTPDraftModel, target bas
inputs := tokenEmbedding.Concatenate(-1, lastHidden)
logits, projected := draft.Draft(inputs, position, caches)
stepLogits := r.lastLogitsFromLogits(logits)
stepScores := r.Sampler.SpeculativeScores(pipelineSlot, stepLogits, prefix)
nextToken := stepScores.Categorical(-1).AsType(mlx.DTypeInt32)
dist := r.Sampler.Distribution(pipelineSlot, stepLogits, prefix)
nextToken := r.Sampler.SampleDistribution(pipelineSlot, dist)
lastToken = mtpTokenInput(nextToken)
lastHidden = projected
draftTokens = append(draftTokens, lastToken)
draftLogits = append(draftLogits, stepScores.ExpandDims(1))
draftDists = append(draftDists, dist)
if prefix == nil {
prefix = lastToken
} else {
@ -526,7 +534,7 @@ func (r *Runner) generateMTPDraftCandidates(draft base.MTPDraftModel, target bas
}
return &mtpDraftCandidates{
tokens: mlx.Concatenate(draftTokens, 1),
logits: mlx.Concatenate(draftLogits, 1),
dist: sampler.ConcatenateDistributions(draftDists),
}
}
@ -630,12 +638,9 @@ func (r *Runner) acceptSampleMTPDrafts(ctx context.Context, request Request, ses
SeqQueryLens: []int32{int32(draftCount)},
}, specCaches)
targetScores := r.Sampler.SpeculativeScores(pipelineSlot, r.mtpValidationLogits(baseLogits, hiddenSeq), candidates.tokens)
draftScores := candidates.logits
if draftScores.NumDims() == 3 {
draftScores = draftScores.Squeeze(0)
}
acceptedMask := mtpSampleAcceptedMask(targetScores, draftScores, candidates.tokens, draftCount)
targetDist := r.Sampler.Distribution(pipelineSlot, r.mtpValidationLogits(baseLogits, hiddenSeq), candidates.tokens)
draftDist := candidates.dist
acceptedMask := r.mtpSampleAcceptedMask(targetDist.SliceRows(0, draftCount), draftDist, candidates.tokens)
mlx.Eval(candidates.tokens, acceptedMask)
draftIDs := candidates.tokens.Ints()
@ -692,9 +697,9 @@ func (r *Runner) acceptSampleMTPDrafts(ctx context.Context, request Request, ses
var nextToken *mlx.Array
if accepted == draftCount {
nextToken = mtpSampleTokenAt(targetScores, draftCount)
nextToken = r.mtpSampleTokenAt(targetDist, draftCount)
} else {
nextToken = mtpSampleResidualToken(targetScores, draftScores, accepted)
nextToken = r.mtpSampleResidualToken(targetDist, draftDist, accepted)
}
mlx.Eval(nextToken)
nextID := int32(tokenID(nextToken))
@ -704,32 +709,20 @@ func (r *Runner) acceptSampleMTPDrafts(ctx context.Context, request Request, ses
return sampler.Result{Token: nextToken}, accepted, false, nil
}
func mtpSampleAcceptedMask(targetScores, draftScores, draftTokens *mlx.Array, draftCount int) *mlx.Array {
targetProbs := mlx.SoftmaxAxis(targetScores.Slice(mlx.Slice(0, draftCount), mlx.Slice()), -1, true)
draftProbs := mlx.SoftmaxAxis(draftScores, -1, true)
if draftTokens.NumDims() == 2 {
draftTokens = draftTokens.Squeeze(0)
}
indices := draftTokens.ExpandDims(-1)
p := targetProbs.TakeAlongAxis(indices, -1).Squeeze(-1)
q := draftProbs.TakeAlongAxis(indices, -1).Squeeze(-1)
func (r *Runner) mtpSampleAcceptedMask(targetDist, draftDist sampler.Distribution, draftTokens *mlx.Array) *mlx.Array {
p := targetDist.Prob(draftTokens)
q := draftDist.Prob(draftTokens)
acceptP := mlx.Minimum(p.Divide(q), mlx.FromValue(float32(1)))
return mlx.Bernoulli(acceptP).AsType(mlx.DTypeInt32)
return r.Sampler.Bernoulli(pipelineSlot, acceptP).AsType(mlx.DTypeInt32)
}
func mtpSampleTokenAt(scores *mlx.Array, index int) *mlx.Array {
row := scores.Slice(mlx.Slice(index, index+1), mlx.Slice())
return mtpTokenVector(row.Categorical(-1).AsType(mlx.DTypeInt32))
func (r *Runner) mtpSampleTokenAt(dist sampler.Distribution, index int) *mlx.Array {
return mtpTokenVector(r.Sampler.SampleDistribution(pipelineSlot, dist.SliceRows(index, index+1)))
}
func mtpSampleResidualToken(targetScores, draftScores *mlx.Array, index int) *mlx.Array {
p := mlx.SoftmaxAxis(targetScores.Slice(mlx.Slice(index, index+1), mlx.Slice()), -1, true)
q := mlx.SoftmaxAxis(draftScores.Slice(mlx.Slice(index, index+1), mlx.Slice()), -1, true)
diff := p.Subtract(q)
positive := mlx.Maximum(diff, mlx.FromValue(float32(1e-20)))
logits := mlx.Log(positive)
logits = mlx.Where(diff.LessEqual(mlx.FromValue(float32(0))), mlx.FromValue(float32(math.Inf(-1))), logits)
return mtpTokenVector(logits.Categorical(-1).AsType(mlx.DTypeInt32))
func (r *Runner) mtpSampleResidualToken(targetDist, draftDist sampler.Distribution, index int) *mlx.Array {
residual := targetDist.SliceRows(index, index+1).ResidualAgainst(draftDist.SliceRows(index, index+1))
return mtpTokenVector(r.Sampler.SampleDistribution(pipelineSlot, residual))
}
func mtpTokenInput(token *mlx.Array) *mlx.Array {

View file

@ -17,6 +17,8 @@ type Options struct {
RepeatPenalty float32
PresencePenalty float32
FrequencyPenalty float32
Seed int
UseSeed bool
// Logprobs causes Sample to populate Result.Logprob with the selected
// token's log-probability. TopLogprobs (when > 0) adds top-K pairs.
@ -42,6 +44,123 @@ func (r Result) Arrays() []*mlx.Array {
return []*mlx.Array{r.Token, r.Logprob, r.TopTokens, r.TopLogprobs}
}
// Distribution is the filtered probability distribution used by the sampler.
// When IDs is nil, Probs is dense over the vocabulary. When IDs is set, Probs
// is sparse over the token ids in IDs, preserving GPU residency for the
// top-k-first path used by normal and speculative sampling.
// Rows of IDs and Probs correspond one-to-one; Probs rows are expected to
// already be normalized by the constructors in this package.
type Distribution struct {
	IDs   *mlx.Array // sparse token ids, shape [B,K]; nil for dense distributions
	Probs *mlx.Array // probabilities, shape [B,K] or [B,V]
}
// Arrays returns the tensor fields for mlx lifecycle management.
// For dense distributions the first entry is nil — presumably the mlx
// lifecycle helpers tolerate nil entries (TODO confirm against Pin/Eval).
func (d Distribution) Arrays() []*mlx.Array {
	return []*mlx.Array{d.IDs, d.Probs}
}
// Rows reports how many batch rows the distribution holds; a zero-value
// Distribution has no rows.
func (d Distribution) Rows() int {
	if p := d.Probs; p != nil {
		return p.Dim(0)
	}
	return 0
}
// SliceRows narrows the distribution to rows [start, stop) while keeping
// the same sparse or dense layout as the receiver.
func (d Distribution) SliceRows(start, stop int) Distribution {
	sliced := Distribution{
		Probs: d.Probs.Slice(mlx.Slice(start, stop), mlx.Slice()),
	}
	if d.IDs != nil {
		sliced.IDs = d.IDs.Slice(mlx.Slice(start, stop), mlx.Slice())
	}
	return sliced
}
// SampleWithKey draws one token per row, using key for determinism when
// it is non-nil. Sparse distributions sample a position in the support
// first, then gather the actual token id from IDs.
func (d Distribution) SampleWithKey(key *mlx.Array) *mlx.Array {
	scores := logitsFromProbs(d.Probs)
	choice := scores.CategoricalWithKey(-1, key).AsType(mlx.DTypeInt32)
	if d.IDs == nil {
		return choice
	}
	gathered := d.IDs.TakeAlongAxis(choice.ExpandDims(-1), -1)
	return gathered.Squeeze(-1).AsType(mlx.DTypeInt32)
}
// Prob returns the probability assigned to one token per row.
//
// tokens is first normalized to a rank-1 [B] vector: a [1,N] or [B,1]
// input has its singleton axis squeezed (the leading axis wins when both
// are 1), and a scalar is reshaped to length 1. A rank-2 input with both
// dims > 1 passes through unchanged — presumably callers never construct
// that case (TODO confirm at call sites).
func (d Distribution) Prob(tokens *mlx.Array) *mlx.Array {
	switch tokens.NumDims() {
	case 2:
		if tokens.Dim(0) == 1 {
			tokens = tokens.Squeeze(0)
		} else if tokens.Dim(1) == 1 {
			tokens = tokens.Squeeze(1)
		}
	case 0:
		tokens = tokens.Reshape(1)
	}
	// Look up one id per row via the rank-2 [B,1] form ProbsForIDs expects.
	return d.ProbsForIDs(tokens.ExpandDims(-1)).Squeeze(-1)
}
// ProbsForIDs returns probabilities for each requested token id. ids must be
// rank-2 [B,N], matching the distribution rows.
//
// Dense distributions gather directly. Sparse distributions broadcast the
// [B,K] support against the requested ids ([B,K,1] vs [B,1,N]), mask the
// probabilities where ids match, and sum over the support axis — so any
// id outside the support comes back with probability 0.
func (d Distribution) ProbsForIDs(ids *mlx.Array) *mlx.Array {
	if d.IDs == nil {
		return d.Probs.TakeAlongAxis(ids, -1)
	}
	eq := d.IDs.ExpandDims(-1).Equal(ids.ExpandDims(1))
	values := mlx.Where(eq, d.Probs.ExpandDims(-1), mlx.FromValue(float32(0)))
	return values.SumAxis(1, false)
}
// ResidualAgainst returns the Leviathan/Chen rejection distribution
// proportional to max(target - draft, 0). Sparse target distributions stay
// sparse over the target support; tokens outside target support have zero mass.
// A dense target paired with a sparse draft panics — that combination is
// not needed by the current callers.
func (d Distribution) ResidualAgainst(draft Distribution) Distribution {
	if d.IDs != nil {
		// Sparse target: evaluate the draft's mass on the target's support ids.
		diff := d.Probs.Subtract(draft.ProbsForIDs(d.IDs))
		return Distribution{IDs: d.IDs, Probs: normalizeProbs(mlx.Maximum(diff, mlx.FromValue(float32(0))))}
	}
	if draft.IDs != nil {
		panic("sample.Distribution.ResidualAgainst: dense target with sparse draft is unsupported")
	}
	// Dense target and dense draft: subtract elementwise over the vocabulary.
	diff := d.Probs.Subtract(draft.Probs)
	return Distribution{Probs: normalizeProbs(mlx.Maximum(diff, mlx.FromValue(float32(0))))}
}
// LogProbs returns dense log-probabilities, scattering sparse distributions
// into a full-vocabulary tensor when needed.
//
// vocab is the full vocabulary width. Sparse entries scatter into a
// [B,vocab] tensor initialized to -inf, so tokens outside the support
// carry zero probability mass in log space.
func (d Distribution) LogProbs(vocab int) *mlx.Array {
	logProbs := logitsFromProbs(d.Probs)
	if d.IDs == nil {
		return logProbs
	}
	out := mlx.AddScalar(mlx.Zeros(mlx.DTypeFloat32, d.Probs.Dim(0), vocab), float32(math.Inf(-1)))
	return out.PutAlongAxis(d.IDs, logProbs, -1)
}
// ConcatenateDistributions stacks distribution rows into one distribution.
// Every input must share the first input's layout (all sparse or all
// dense); mixing layouts panics. An empty input yields the zero value.
func ConcatenateDistributions(dists []Distribution) Distribution {
	if len(dists) == 0 {
		return Distribution{}
	}
	sparse := dists[0].IDs != nil
	probs := make([]*mlx.Array, 0, len(dists))
	var ids []*mlx.Array
	if sparse {
		ids = make([]*mlx.Array, 0, len(dists))
	}
	for _, d := range dists {
		if (d.IDs != nil) != sparse {
			panic("sample.ConcatenateDistributions: mixed sparse and dense distributions")
		}
		probs = append(probs, d.Probs)
		if sparse {
			ids = append(ids, d.IDs)
		}
	}
	result := Distribution{Probs: mlx.Concatenate(probs, 0)}
	if sparse {
		result.IDs = mlx.Concatenate(ids, 0)
	}
	return result
}
// Sampler is a batched, slot-based sampler. Sequences are registered with
// Add and released with Remove. Each Sample call takes a subset of
// registered slots (in any order) with their [B,V] logits, samples one
@ -73,9 +192,9 @@ type Sampler struct {
}
type slotState struct {
opts Options
transforms []transform
historyLen int
opts Options
historyLen int
randomCounter uint64
}
type slotCtx struct {
@ -83,8 +202,6 @@ type slotCtx struct {
history *mlx.Array // 2D [B, W] when penalties are configured; nil otherwise
}
type transform func(*slotCtx, *mlx.Array) *mlx.Array
// New constructs an empty sampler with no registered slots. numCtx is
// the runner's context window and must be positive.
func New(numCtx int) *Sampler {
@ -127,40 +244,17 @@ func (o Options) normalize(numCtx int) Options {
// RepeatLastN still batch together and don't inflate pool width.
o.RepeatLastN = 0
}
if o.Seed < 0 {
o.UseSeed = false
}
if !o.UseSeed {
// Keep unseeded callers on the same batching path even when a
// meaningless Seed value is present in an Options literal.
o.Seed = 0
}
return o
}
func (o Options) buildTransforms() []transform {
var ts []transform
if o.usesHistory() {
ts = append(ts, penalty)
}
hasTopP := o.TopP > 0 && o.TopP < 1
hasTopK := o.TopK > 0
switch {
case hasTopP:
// topKTopP always does a full descending sort for the top-P
// cumulative mask and opportunistically masks top-K during the
// same pass when it is also configured.
ts = append(ts, topKTopP)
case hasTopK:
// Argpartition (partial sort) is cheaper than a full sort.
ts = append(ts, topK)
}
if o.MinP != 0 {
ts = append(ts, minP)
}
if o.Temperature == 0 {
ts = append(ts, greedy)
} else {
ts = append(ts, temperature)
}
return ts
}
// Add registers a sequence under seqID. The last RepeatLastN entries of
// priorTokens seed the ring buffer.
func (s *Sampler) Add(seqID int, opts Options, priorTokens []int32) {
@ -170,8 +264,7 @@ func (s *Sampler) Add(seqID int, opts Options, priorTokens []int32) {
opts = opts.normalize(s.numCtx)
slot := &slotState{
opts: opts,
transforms: opts.buildTransforms(),
opts: opts,
}
// Grow the pool to hold this slot's row. The pool is lazy — the first
@ -349,29 +442,34 @@ func (s *Sampler) Sample(seqIDs []int, logits *mlx.Array) Result {
return res
}
// SpeculativeScores applies this slot's non-sampling transforms to logits
// without mutating sampler state. Row i is scored as if draftTokens[:i] had
// already been appended to the slot history. logits must be [R,V] or [1,R,V].
func (s *Sampler) SpeculativeScores(seqID int, logits *mlx.Array, draftTokens *mlx.Array) *mlx.Array {
slot, ok := s.byID[seqID]
if !ok {
panic(fmt.Sprintf("sample.Sampler.SpeculativeScores: seqID %d not registered", seqID))
}
if logits.NumDims() == 3 {
if logits.Dim(0) != 1 {
panic("sample.Sampler.SpeculativeScores: only batch size 1 is supported")
}
logits = logits.Squeeze(0)
}
if logits.NumDims() != 2 {
panic(fmt.Sprintf("sample.Sampler.SpeculativeScores: logits must be rank 2 or 3, got rank %d", logits.NumDims()))
}
if draftTokens != nil && draftTokens.NumDims() == 1 {
draftTokens = draftTokens.ExpandDims(0)
}
// Distribution applies this slot's sampling transforms to logits without
// mutating sampler state. Row i is built as if draftTokens[:i] had already
// been appended to the slot history. logits must be [R,V] or [1,R,V].
func (s *Sampler) Distribution(seqID int, logits *mlx.Array, draftTokens *mlx.Array) Distribution {
slot, logits, draftTokens := s.speculativeInputs("Distribution", seqID, logits, draftTokens)
rows := logits.Dim(0)
var hist *mlx.Array
if slot.opts.usesHistory() {
if s.history == nil {
panic(fmt.Sprintf("sample.Sampler.Distribution: seqID %d has no history", seqID))
}
if slot.historyLen < slot.opts.RepeatLastN {
return s.speculativeDistributionSerial(slot, logits, draftTokens)
}
hist = s.speculativeHistory(slot, draftTokens, rows)
}
return slot.distribution(&slotCtx{opts: slot.opts, history: hist}, logits)
}
// SpeculativeScores applies this slot's sampling transforms to logits without
// mutating sampler state and returns dense log-probability scores for sampled
// decoding. Greedy decoding returns the penalty-adjusted logits.
func (s *Sampler) SpeculativeScores(seqID int, logits *mlx.Array, draftTokens *mlx.Array) *mlx.Array {
slot, logits, draftTokens := s.speculativeInputs("SpeculativeScores", seqID, logits, draftTokens)
rows := logits.Dim(0)
var hist *mlx.Array
if slot.opts.usesHistory() {
if s.history == nil {
@ -386,6 +484,47 @@ func (s *Sampler) SpeculativeScores(seqID int, logits *mlx.Array, draftTokens *m
return slot.speculativeScores(&slotCtx{opts: slot.opts, history: hist}, logits)
}
// SampleDistribution draws from a precomputed distribution while advancing
// seqID's deterministic RNG stream when a seed is configured.
// For unseeded slots, nextRandomKey returns nil and the draw falls back to
// MLX's global PRNG stream. Panics if seqID was never registered.
func (s *Sampler) SampleDistribution(seqID int, dist Distribution) *mlx.Array {
	slot := s.mustSlot("SampleDistribution", seqID)
	return dist.SampleWithKey(slot.nextRandomKey())
}
// Bernoulli samples boolean outcomes while advancing seqID's deterministic RNG
// stream when a seed is configured.
// p holds per-element success probabilities; unseeded slots use a nil key
// and MLX's global stream. Panics if seqID was never registered.
func (s *Sampler) Bernoulli(seqID int, p *mlx.Array) *mlx.Array {
	slot := s.mustSlot("Bernoulli", seqID)
	return mlx.BernoulliWithKey(p, slot.nextRandomKey())
}
// mustSlot looks up seqID's slot state, panicking with the calling
// method's name when the sequence was never registered via Add.
func (s *Sampler) mustSlot(caller string, seqID int) *slotState {
	if slot, ok := s.byID[seqID]; ok {
		return slot
	}
	panic(fmt.Sprintf("sample.Sampler.%s: seqID %d not registered", caller, seqID))
}
// speculativeInputs validates and normalizes the shared inputs of the
// speculative entry points: logits are reduced to rank-2 [R,V] (a leading
// batch axis of exactly 1 is squeezed away; anything else panics) and
// rank-1 draftTokens gain a leading batch axis.
func (s *Sampler) speculativeInputs(caller string, seqID int, logits *mlx.Array, draftTokens *mlx.Array) (*slotState, *mlx.Array, *mlx.Array) {
	slot := s.mustSlot(caller, seqID)
	switch logits.NumDims() {
	case 3:
		if logits.Dim(0) != 1 {
			panic(fmt.Sprintf("sample.Sampler.%s: only batch size 1 is supported", caller))
		}
		logits = logits.Squeeze(0)
	case 2:
		// already [R,V]; nothing to normalize
	default:
		panic(fmt.Sprintf("sample.Sampler.%s: logits must be rank 2 or 3, got rank %d", caller, logits.NumDims()))
	}
	if draftTokens != nil && draftTokens.NumDims() == 1 {
		draftTokens = draftTokens.ExpandDims(0)
	}
	return slot, logits, draftTokens
}
// Commit appends already-selected tokens to seqID's repeat-penalty history.
// It is used after speculative sampling once the accepted continuation is
// known. Normal Sample calls continue to mutate history themselves.
@ -422,7 +561,7 @@ func (s *Sampler) Commit(seqID int, tokens []int32) {
slot.historyLen += len(tokens)
}
func (s *Sampler) speculativeScoresSerial(slot *slotState, logits *mlx.Array, draftTokens *mlx.Array) *mlx.Array {
func (s *Sampler) speculativeDistributionSerial(slot *slotState, logits *mlx.Array, draftTokens *mlx.Array) Distribution {
rows := logits.Dim(0)
draftCount := 0
if draftTokens != nil {
@ -435,7 +574,7 @@ func (s *Sampler) speculativeScoresSerial(slot *slotState, logits *mlx.Array, dr
base = s.history.Slice(mlx.Slice(row, row+1), mlx.Slice(0, baseFill))
}
scored := make([]*mlx.Array, 0, rows)
dists := make([]Distribution, 0, rows)
for i := range rows {
rowLogits := logits.Slice(mlx.Slice(i, i+1), mlx.Slice())
hist := base
@ -451,9 +590,13 @@ func (s *Sampler) speculativeScoresSerial(slot *slotState, logits *mlx.Array, dr
hist = hist.Slice(mlx.Slice(), mlx.Slice(hist.Dim(1)-slot.opts.RepeatLastN, mlx.End))
}
}
scored = append(scored, slot.speculativeScores(&slotCtx{opts: slot.opts, history: hist}, rowLogits))
dists = append(dists, slot.distribution(&slotCtx{opts: slot.opts, history: hist}, rowLogits))
}
return mlx.Concatenate(scored, 0)
return ConcatenateDistributions(dists)
}
// speculativeScoresSerial is the slow-path score builder used when history
// is still filling: it builds per-row distributions serially and densifies
// them into full-vocabulary log-probabilities via LogProbs.
func (s *Sampler) speculativeScoresSerial(slot *slotState, logits *mlx.Array, draftTokens *mlx.Array) *mlx.Array {
	return s.speculativeDistributionSerial(slot, logits, draftTokens).LogProbs(logits.Dim(logits.NumDims() - 1))
}
func (s *Sampler) speculativeHistory(slot *slotState, draftTokens *mlx.Array, rows int) *mlx.Array {
@ -489,17 +632,10 @@ func (s *Sampler) speculativeHistory(slot *slotState, draftTokens *mlx.Array, ro
}
func (slot *slotState) speculativeScores(ctx *slotCtx, logits *mlx.Array) *mlx.Array {
scores := logits
// buildTransforms always appends the final selector transform
// (greedy or temperature sampling). Speculative validation needs the
// processed logits before that selector mutates the distribution.
for _, t := range slot.transforms[:len(slot.transforms)-1] {
scores = t(ctx, scores)
if slot.opts.Temperature == 0 {
return slot.baseScores(ctx, logits)
}
if slot.opts.Temperature > 0 {
scores = mlx.DivScalar(scores, slot.opts.Temperature)
}
return scores
return slot.distribution(ctx, logits).LogProbs(logits.Dim(logits.NumDims() - 1))
}
// canBatch reports whether the call can take the uniform batched path.
@ -513,6 +649,9 @@ func (s *Sampler) canBatch(slots []*slotState) (Options, bool) {
// slots is non-empty (Sample guards) and every slot is registered,
// so s.slots[0].opts is the canonical shared value.
shared := s.slots[0].opts
// TODO: Before using multi-slot batching with seeded stochastic sampling,
// make sure each row gets its own per-slot random key instead of sharing
// slots[0]'s key through one batched categorical op.
if !shared.usesHistory() {
return shared, true
}
@ -527,7 +666,7 @@ func (s *Sampler) canBatch(slots []*slotState) (Options, bool) {
return shared, true
}
// sampleTokensUniform runs one fused transform pass over the whole batch.
// sampleTokensUniform runs one fused sampling pass over the whole batch.
// Reached only when canBatch is true, which lets the pool be used in place
// with a single PutAlongAxis write-back and no gather.
func (s *Sampler) sampleTokensUniform(slots []*slotState, opts Options, logits *mlx.Array) *mlx.Array {
@ -542,11 +681,14 @@ func (s *Sampler) sampleTokensUniform(slots []*slotState, opts Options, logits *
}
ctx := &slotCtx{opts: opts, history: hist}
scores := logits
for _, t := range slots[0].transforms {
scores = t(ctx, scores)
token := slots[0].sample(ctx, logits)
if opts.UseSeed && opts.Temperature != 0 {
// TODO: This only keeps counters aligned; it does not give each slot
// an independent key for the batched draw.
for _, slot := range slots[1:] {
slot.randomCounter++
}
}
token := scores
if !opts.usesHistory() {
return token
@ -563,8 +705,7 @@ func (s *Sampler) sampleTokensUniform(slots []*slotState, opts Options, logits *
return token
}
// sampleTokensSerial runs each slot's transforms against its own row of
// logits.
// sampleTokensSerial samples each slot against its own row of logits.
func (s *Sampler) sampleTokensSerial(slots []*slotState, logits *mlx.Array) *mlx.Array {
perSlotTokens := make([]*mlx.Array, len(slots))
@ -587,11 +728,7 @@ func (s *Sampler) sampleTokensSerial(slots []*slotState, logits *mlx.Array) *mlx
}
ctx := &slotCtx{opts: slot.opts, history: hist}
scores := row
for _, t := range slot.transforms {
scores = t(ctx, scores)
}
perSlotTokens[i] = scores
perSlotTokens[i] = slot.sample(ctx, row)
}
token := mlx.Concatenate(perSlotTokens, 0)
@ -627,74 +764,107 @@ func (s *Sampler) sampleTokensSerial(slots []*slotState, logits *mlx.Array) *mlx
return token
}
func greedy(_ *slotCtx, scores *mlx.Array) *mlx.Array {
return scores.Argmax(-1, false).AsType(mlx.DTypeInt32)
// sample selects one token per row. Temperature 0 means greedy argmax over
// the penalty-adjusted logits; any other temperature draws from the
// filtered distribution, with a deterministic key when a seed is set.
func (slot *slotState) sample(ctx *slotCtx, logits *mlx.Array) *mlx.Array {
	if slot.opts.Temperature == 0 {
		return slot.baseScores(ctx, logits).Argmax(-1, false).AsType(mlx.DTypeInt32)
	}
	return slot.distribution(ctx, logits).SampleWithKey(slot.nextRandomKey())
}
func temperature(ctx *slotCtx, scores *mlx.Array) *mlx.Array {
return mlx.DivScalar(scores, ctx.opts.Temperature).Categorical(-1).AsType(mlx.DTypeInt32)
// nextRandomKey derives the slot's next deterministic PRNG key and advances
// the per-slot draw counter. Unseeded slots return nil so callers fall back
// to MLX's global stream.
func (slot *slotState) nextRandomKey() *mlx.Array {
	if !slot.opts.UseSeed {
		return nil
	}
	counter := slot.randomCounter
	slot.randomCounter = counter + 1
	return mlx.RandomKey(mixSeed(uint64(slot.opts.Seed), counter))
}
// topKTopP applies top-P in a descending sort pass and, when top-K is also
// configured, masks any surviving value below the K-th largest in the same
// pass. Callers dispatch here whenever top-P is enabled — the top-K-only case
// uses a cheaper partial sort via the topK transform.
func topKTopP(ctx *slotCtx, scores *mlx.Array) *mlx.Array {
const (
	// SplitMix64 constants used to decorrelate nearby (seed, counter) pairs.
	// See Steele, Lea & Flood, "Fast Splittable Pseudorandom Number
	// Generators" (the java.util.SplittableRandom mixing function).
	splitMix64Weyl       = 0x9e3779b97f4a7c15
	splitMix64Mul1       = 0xbf58476d1ce4e5b9
	splitMix64Mul2       = 0x94d049bb133111eb
	splitMix64Shift1     = 30
	splitMix64Shift2     = 27
	splitMix64FinalShift = 31
)

// mixSeed hashes a (seed, counter) pair into a well-mixed 64-bit value by
// stepping a Weyl sequence from seed and applying the SplitMix64 finalizer,
// so consecutive counters yield decorrelated keys.
func mixSeed(seed, counter uint64) uint64 {
	x := seed + (counter+1)*splitMix64Weyl
	x ^= x >> splitMix64Shift1
	x *= splitMix64Mul1
	x ^= x >> splitMix64Shift2
	x *= splitMix64Mul2
	return x ^ (x >> splitMix64FinalShift)
}
// baseScores applies the penalty transform when repeat/presence/frequency
// penalties are configured for this slot; otherwise the logits pass
// through untouched.
func (slot *slotState) baseScores(ctx *slotCtx, logits *mlx.Array) *mlx.Array {
	if !slot.opts.usesHistory() {
		return logits
	}
	return penalty(ctx, logits)
}
func (slot *slotState) distribution(ctx *slotCtx, logits *mlx.Array) Distribution {
scores := slot.baseScores(ctx, logits)
if slot.opts.Temperature <= 0 {
ids := scores.Argmax(-1, false).AsType(mlx.DTypeInt32).ExpandDims(-1)
probs := mlx.AddScalar(ids.AsType(mlx.DTypeFloat32).Multiply(mlx.FromValue(float32(0))), 1)
return Distribution{IDs: ids, Probs: probs}
}
vocab := scores.Dim(scores.NumDims() - 1)
applyTopK := ctx.opts.TopK > 0 && ctx.opts.TopK < vocab
order := scores.Negative().ArgsortAxis(-1)
sorted := scores.TakeAlongAxis(order, -1)
negInf := mlx.FromValue(float32(math.Inf(-1)))
// Top-P: in descending order, keep tokens whose exclusive cumulative
// probability is still below TopP.
probs := mlx.SoftmaxAxis(sorted, -1, true)
prevCumProbs := probs.Cumsum(-1, false, true).Subtract(probs)
keep := prevCumProbs.Less(mlx.FromValue(ctx.opts.TopP))
sorted = mlx.Where(keep, sorted, negInf)
out := scores.PutAlongAxis(order, sorted, -1)
// Top-K: sorted is already in descending order, so positions [K, V) are
// the ones to drop. Scatter -inf through their original-layout indices
// (order[K:]). Positional (not value-based) so exactly K tokens survive —
// ties at the K-th logit get broken by the sort order rather than
// promoted through the filter.
if applyTopK {
dropOrder := order.Slice(mlx.Slice(), mlx.Slice(ctx.opts.TopK, mlx.End))
out = out.PutAlongAxis(dropOrder, negInf, -1)
if slot.opts.TopK > 0 && slot.opts.TopK < vocab {
return sparseDistribution(ctx.opts, scores)
}
return out
return denseDistribution(ctx.opts, scores)
}
func minP(ctx *slotCtx, scores *mlx.Array) *mlx.Array {
if ctx.opts.MinP <= 0 || ctx.opts.MinP > 1 {
return scores
}
maxScore := scores.MaxAxis(-1, true)
threshold := mlx.AddScalar(maxScore, float32(math.Log(float64(ctx.opts.MinP))))
return mlx.Where(
scores.Less(threshold),
mlx.FromValue(float32(math.Inf(-1))),
scores,
)
// sparseDistribution builds a sparse [B,K] distribution for the top-K-first
// path: argpartition (a partial sort, cheaper than a full sort) picks the K
// largest logits per row, then temperature/softmax, top-P, and min-P run
// over only those K columns before the final renormalize.
func sparseDistribution(opts Options, scores *mlx.Array) Distribution {
	// Negate so argpartition's smallest-first prefix [0,K) is the K largest scores.
	ids := scores.Negative().ArgpartitionAxis(opts.TopK-1, -1).Slice(mlx.Slice(), mlx.Slice(0, opts.TopK)).AsType(mlx.DTypeInt32)
	topScores := scores.TakeAlongAxis(ids, -1).AsType(mlx.DTypeFloat32)
	probs := mlx.SoftmaxAxis(mlx.DivScalar(topScores, opts.Temperature), -1, true)
	probs = applyTopPProbs(probs, opts.TopP)
	probs = applyMinPProbs(probs, opts.MinP)
	return Distribution{IDs: ids, Probs: normalizeProbs(probs)}
}
func topK(ctx *slotCtx, scores *mlx.Array) *mlx.Array {
if ctx.opts.TopK <= 0 {
return scores
}
vocab := scores.Dim(scores.NumDims() - 1)
if ctx.opts.TopK >= vocab {
return scores
}
// denseDistribution builds a full-vocabulary distribution when top-K is not
// in play: temperature/softmax over all logits, then top-P and min-P
// filtering, then a final renormalize.
func denseDistribution(opts Options, scores *mlx.Array) Distribution {
	scaled := mlx.DivScalar(scores.AsType(mlx.DTypeFloat32), opts.Temperature)
	p := mlx.SoftmaxAxis(scaled, -1, true)
	p = applyMinPProbs(applyTopPProbs(p, opts.TopP), opts.MinP)
	return Distribution{Probs: normalizeProbs(p)}
}
mask := scores.Negative().ArgpartitionAxis(ctx.opts.TopK-1, -1).Slice(mlx.Slice(), mlx.Slice(ctx.opts.TopK, mlx.End))
return scores.PutAlongAxis(mask, mlx.FromValue(float32(math.Inf(-1))), -1)
// applyTopPProbs zeroes tokens outside the top-P nucleus. In descending
// probability order, a token is kept while the cumulative mass of the
// strictly-larger tokens is still below topP, so the boundary token that
// crosses the threshold survives. topP outside (0,1) disables the filter.
func applyTopPProbs(probs *mlx.Array, topP float32) *mlx.Array {
	if topP <= 0 || topP >= 1 {
		return probs
	}
	order := probs.Negative().ArgsortAxis(-1)
	sorted := probs.TakeAlongAxis(order, -1)
	// Exclusive prefix sum: mass accumulated before each sorted position.
	prevCumProbs := sorted.Cumsum(-1, false, true).Subtract(sorted)
	keep := prevCumProbs.Less(mlx.FromValue(topP))
	filtered := mlx.Where(keep, sorted, mlx.FromValue(float32(0)))
	// Scatter survivors back to their original token positions.
	return mlx.Zeros(probs.DType(), probs.Dims()...).PutAlongAxis(order, filtered, -1)
}
// applyMinPProbs zeroes tokens whose probability falls below minP times
// the row's maximum probability. minP outside (0,1] disables the filter.
func applyMinPProbs(probs *mlx.Array, minP float32) *mlx.Array {
	if minP <= 0 || minP > 1 {
		return probs
	}
	cutoff := mlx.MulScalar(probs.MaxAxis(-1, true), minP)
	zero := mlx.FromValue(float32(0))
	return mlx.Where(probs.Less(cutoff), zero, probs)
}
// normalizeProbs rescales each row to sum to 1, clamping the denominator
// at 1e-20 so a fully-filtered (all-zero) row does not divide by zero.
func normalizeProbs(probs *mlx.Array) *mlx.Array {
	total := probs.SumAxis(-1, true)
	safeTotal := mlx.Maximum(total, mlx.FromValue(float32(1e-20)))
	return probs.Divide(safeTotal)
}
// logitsFromProbs converts probabilities back to log space for categorical
// sampling. Entries with probability <= 0 map to -inf (not log of the
// clamp epsilon) so filtered tokens can never be drawn; the 1e-20 clamp
// only exists to keep the Log call itself finite.
func logitsFromProbs(probs *mlx.Array) *mlx.Array {
	positive := mlx.Maximum(probs, mlx.FromValue(float32(1e-20)))
	logits := mlx.Log(positive)
	return mlx.Where(probs.LessEqual(mlx.FromValue(float32(0))), mlx.FromValue(float32(math.Inf(-1))), logits)
}
func penalty(ctx *slotCtx, scores *mlx.Array) *mlx.Array {

View file

@ -4,6 +4,7 @@ package sample
import (
"math"
"slices"
"testing"
"github.com/ollama/ollama/x/mlxrunner/mlx"
@ -141,6 +142,132 @@ func TestSampleSingleSlotOptions(t *testing.T) {
}
}
// TestDistributionAppliesTopKBeforeTopP pins the pipeline order: TopK=2
// shrinks the support to two ids first, then TopP=0.7 keeps only the
// dominant token (p=0.6) inside that support, so after renormalization the
// top token carries probability 1 and the other survivor carries 0.
func TestDistributionAppliesTopKBeforeTopP(t *testing.T) {
	skipIfNoMLX(t)
	s := New(128)
	t.Cleanup(func() {
		s.Free()
		mlx.Sweep()
	})
	s.Add(0, Options{Temperature: 1, TopK: 2, TopP: 0.7}, nil)
	dist := s.Distribution(0, slotLogits([]float32{logOf(0.6), logOf(0.2), logOf(0.2)}), nil)
	mlx.Eval(dist.Arrays()...)
	ids := dist.IDs.Ints()
	probs := dist.Probs.Floats()
	if len(ids) != 2 || len(probs) != 2 {
		t.Fatalf("support = ids %v probs %v, want 2 sparse entries", ids, probs)
	}
	// Tokens 1 and 2 tie at p=0.2, so the second surviving id is not
	// deterministic — only token 0's presence and the probabilities matter.
	foundTop := false
	for i, id := range ids {
		switch id {
		case 0:
			foundTop = true
			if math.Abs(float64(probs[i]-1)) > 1e-5 {
				t.Fatalf("top token prob = %v, want 1; ids=%v probs=%v", probs[i], ids, probs)
			}
		default:
			if math.Abs(float64(probs[i])) > 1e-5 {
				t.Fatalf("non-top token %d prob = %v, want 0; ids=%v probs=%v", id, probs[i], ids, probs)
			}
		}
	}
	if !foundTop {
		t.Fatalf("top-k support %v did not include token 0", ids)
	}
}
// TestDistributionResidualUsesTargetSupport checks the sparse residual:
// max(target-draft, 0) on the target support gives token 2 mass
// 0.7-0.2=0.5 and token 5 mass 0.3-0=0.3 (token 5 is outside the draft
// support), normalizing to 0.625 and 0.375. Draft-only token 4 must not
// appear in the residual.
func TestDistributionResidualUsesTargetSupport(t *testing.T) {
	skipIfNoMLX(t)
	target := Distribution{
		IDs:   mlx.NewArrayInt32([]int32{2, 5}, []int32{1, 2}),
		Probs: mlx.FromValues([]float32{0.7, 0.3}, 1, 2),
	}
	draft := Distribution{
		IDs:   mlx.NewArrayInt32([]int32{2, 4}, []int32{1, 2}),
		Probs: mlx.FromValues([]float32{0.2, 0.8}, 1, 2),
	}
	residual := target.ResidualAgainst(draft)
	mlx.Eval(residual.Arrays()...)
	ids := residual.IDs.Ints()
	probs := residual.Probs.Floats()
	want := map[int]float64{2: 0.625, 5: 0.375}
	if len(ids) != 2 || len(probs) != 2 {
		t.Fatalf("residual = ids %v probs %v, want 2 sparse entries", ids, probs)
	}
	for i, id := range ids {
		w, ok := want[id]
		if !ok {
			t.Fatalf("residual includes token %d outside target support: ids=%v probs=%v", id, ids, probs)
		}
		if math.Abs(float64(probs[i])-w) > 1e-5 {
			t.Fatalf("residual token %d prob = %v, want %v; ids=%v probs=%v", id, probs[i], w, ids, probs)
		}
	}
}
// TestSeededSamplingIsReproducible draws 32 tokens from uniform logits
// under a fixed seed and checks that equal seeds replay the exact token
// sequence while a different seed diverges (uniform logits make any
// accidental determinism from the logits themselves impossible).
func TestSeededSamplingIsReproducible(t *testing.T) {
	skipIfNoMLX(t)
	seededSequence := func(seed int) []int {
		s := New(128)
		t.Cleanup(func() {
			s.Free()
			mlx.Sweep()
		})
		s.Add(0, Options{Temperature: 1, TopK: 4, Seed: seed, UseSeed: true}, nil)
		logits := slotLogits([]float32{0, 0, 0, 0})
		out := make([]int, 32)
		for i := range out {
			token := s.Sample([]int{0}, logits).Token
			mlx.Eval(token)
			out[i] = token.Int()
		}
		return out
	}
	a := seededSequence(1234)
	b := seededSequence(1234)
	if !slices.Equal(a, b) {
		t.Fatalf("same seed produced different sequences:\n%v\n%v", a, b)
	}
	// Distinct seeds matching across all 32 draws over 4 tokens is
	// astronomically unlikely, so a collision indicates the seed is ignored.
	c := seededSequence(5678)
	if slices.Equal(a, c) {
		t.Fatalf("different seeds produced the same sequence: %v", a)
	}
}
// TestSeededBernoulliIsReproducible checks that the sampler's keyed
// Bernoulli path replays the same 6-element coin-flip mask when the slot
// is re-created with the same seed.
func TestSeededBernoulliIsReproducible(t *testing.T) {
	skipIfNoMLX(t)
	seededMask := func() []int {
		s := New(128)
		t.Cleanup(func() {
			s.Free()
			mlx.Sweep()
		})
		s.Add(0, Options{Seed: 99, UseSeed: true}, nil)
		mask := s.Bernoulli(0, mlx.FromValues([]float32{0.5, 0.5, 0.5, 0.5, 0.5, 0.5}, 6)).AsType(mlx.DTypeInt32)
		mlx.Eval(mask)
		return mask.Ints()
	}
	a := seededMask()
	b := seededMask()
	if !slices.Equal(a, b) {
		t.Fatalf("same seed produced different bernoulli masks:\n%v\n%v", a, b)
	}
}
// TestSampleHistoryWindow verifies that penalty history respects the
// RepeatLastN window: priors longer than RepeatLastN are trimmed on Add,
// and once the ring wraps, tokens that rotate out no longer contribute

View file

@ -136,6 +136,8 @@ func Execute(args []string) error {
RepeatPenalty: request.Options.RepeatPenalty,
PresencePenalty: request.Options.PresencePenalty,
FrequencyPenalty: request.Options.FrequencyPenalty,
Seed: request.Options.Seed,
UseSeed: request.Options.Seed >= 0,
Logprobs: request.Logprobs,
TopLogprobs: request.TopLogprobs,
}