diff --git a/x/mlxrunner/mlx/random.go b/x/mlxrunner/mlx/random.go index 009bddc39..291e8852f 100644 --- a/x/mlxrunner/mlx/random.go +++ b/x/mlxrunner/mlx/random.go @@ -5,21 +5,39 @@ import "C" import "unsafe" +func RandomKey(seed uint64) *Array { + out := New("RANDOM_KEY") + C.mlx_random_key(&out.ctx, C.uint64_t(seed)) + return out +} + func (t *Array) Categorical(axis int) *Array { - key := New("") + return t.CategoricalWithKey(axis, nil) +} + +func (t *Array) CategoricalWithKey(axis int, key *Array) *Array { + if key == nil { + key = New("") + } out := New("") C.mlx_random_categorical(&out.ctx, t.ctx, C.int(axis), key.ctx, DefaultStream().ctx) return out } func Bernoulli(p *Array) *Array { + return BernoulliWithKey(p, nil) +} + +func BernoulliWithKey(p *Array, key *Array) *Array { dims := p.Dims() shape := make([]C.int, len(dims)) for i, d := range dims { shape[i] = C.int(d) } - key := New("") + if key == nil { + key = New("") + } out := New("BERNOULLI") C.mlx_random_bernoulli(&out.ctx, p.ctx, unsafe.SliceData(shape), C.size_t(len(shape)), key.ctx, DefaultStream().ctx) return out diff --git a/x/mlxrunner/mtp.go b/x/mlxrunner/mtp.go index 491e2706d..e87489c06 100644 --- a/x/mlxrunner/mtp.go +++ b/x/mlxrunner/mtp.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "log/slog" - "math" "os" "strconv" "strings" @@ -375,9 +374,11 @@ func (r *Runner) runSampleMTPDecode(ctx context.Context, request Request, sessio t0 = time.Now() candidates := r.generateMTPDraftCandidates(draft, targetEmbeddings, current.Token, hidden, caches, int32(*position-1), maxDraft) draftCount := 0 + var candidateArrays []*mlx.Array if candidates != nil { draftCount = candidates.tokens.Dim(1) - mlx.Pin(baseLogits, candidates.tokens, candidates.logits) + candidateArrays = append([]*mlx.Array{baseLogits}, candidates.Arrays()...) + mlx.Pin(candidateArrays...) mlx.Sweep() } stats.draftDuration += time.Since(t0) @@ -391,7 +392,7 @@ func (r *Runner) runSampleMTPDecode(ctx context.Context, request Request, sessio t0 = time.Now() next, accepted, done, err = r.acceptSampleMTPDrafts(ctx, request, session, &dec, caches, position, baseLogits, candidates, &final, &generated, &stats) stats.validateDuration += time.Since(t0) - mlx.Unpin(baseLogits, candidates.tokens, candidates.logits) + mlx.Unpin(candidateArrays...) if err != nil { return err } @@ -456,8 +457,15 @@ func (r *Runner) runSampleMTPDecode(ctx context.Context, request Request, sessio type mtpDraftCandidates struct { tokens *mlx.Array - // logits are the processed proposal scores used to sample tokens. - logits *mlx.Array + // dist is the proposal distribution used to sample each drafted token. + dist sampler.Distribution +} + +func (c *mtpDraftCandidates) Arrays() []*mlx.Array { + if c == nil { + return nil + } + return append([]*mlx.Array{c.tokens}, c.dist.Arrays()...) } func (r *Runner) generateMTPDrafts(draft base.MTPDraftModel, target base.MTPEmbeddingModel, token *mlx.Array, hidden *mlx.Array, caches []cache.Cache, position int32, maxDraft int) *mlx.Array { @@ -497,7 +505,7 @@ func (r *Runner) generateMTPDraftCandidates(draft base.MTPDraftModel, target bas lastToken := mtpTokenInput(token) lastHidden := hidden draftTokens := make([]*mlx.Array, 0, maxDraft) - draftLogits := make([]*mlx.Array, 0, maxDraft) + draftDists := make([]sampler.Distribution, 0, maxDraft) var prefix *mlx.Array // Gemma4 assistant MTP is trained as "single-position" drafting: @@ -508,13 +516,13 @@ func (r *Runner) generateMTPDraftCandidates(draft base.MTPDraftModel, target bas inputs := tokenEmbedding.Concatenate(-1, lastHidden) logits, projected := draft.Draft(inputs, position, caches) stepLogits := r.lastLogitsFromLogits(logits) - stepScores := r.Sampler.SpeculativeScores(pipelineSlot, stepLogits, prefix) - nextToken := stepScores.Categorical(-1).AsType(mlx.DTypeInt32) + dist := r.Sampler.Distribution(pipelineSlot, stepLogits, prefix) + nextToken := r.Sampler.SampleDistribution(pipelineSlot, dist) lastToken = mtpTokenInput(nextToken) lastHidden = projected draftTokens = append(draftTokens, lastToken) - draftLogits = append(draftLogits, stepScores.ExpandDims(1)) + draftDists = append(draftDists, dist) if prefix == nil { prefix = lastToken } else { @@ -526,7 +534,7 @@ func (r *Runner) generateMTPDraftCandidates(draft base.MTPDraftModel, target bas } return &mtpDraftCandidates{ tokens: mlx.Concatenate(draftTokens, 1), - logits: mlx.Concatenate(draftLogits, 1), + dist: sampler.ConcatenateDistributions(draftDists), } } @@ -630,12 +638,9 @@ func (r *Runner) acceptSampleMTPDrafts(ctx context.Context, request Request, ses SeqQueryLens: []int32{int32(draftCount)}, }, specCaches) - targetScores := r.Sampler.SpeculativeScores(pipelineSlot, r.mtpValidationLogits(baseLogits, hiddenSeq), candidates.tokens) - draftScores := candidates.logits - if draftScores.NumDims() == 3 { - draftScores = draftScores.Squeeze(0) - } - acceptedMask := mtpSampleAcceptedMask(targetScores, draftScores, candidates.tokens, draftCount) + targetDist := r.Sampler.Distribution(pipelineSlot, r.mtpValidationLogits(baseLogits, hiddenSeq), candidates.tokens) + draftDist := candidates.dist + acceptedMask := r.mtpSampleAcceptedMask(targetDist.SliceRows(0, draftCount), draftDist, candidates.tokens) mlx.Eval(candidates.tokens, acceptedMask) draftIDs := candidates.tokens.Ints() @@ -692,9 +697,9 @@ func (r *Runner) acceptSampleMTPDrafts(ctx context.Context, request Request, ses var nextToken *mlx.Array if accepted == draftCount { - nextToken = mtpSampleTokenAt(targetScores, draftCount) + nextToken = r.mtpSampleTokenAt(targetDist, draftCount) } else { - nextToken = mtpSampleResidualToken(targetScores, draftScores, accepted) + nextToken = r.mtpSampleResidualToken(targetDist, draftDist, accepted) } mlx.Eval(nextToken) nextID := int32(tokenID(nextToken)) @@ -704,32 +709,20 @@ func (r *Runner) acceptSampleMTPDrafts(ctx context.Context, request Request, ses return sampler.Result{Token: nextToken}, accepted, false, nil } -func mtpSampleAcceptedMask(targetScores, draftScores, draftTokens *mlx.Array, draftCount int) *mlx.Array { - targetProbs := mlx.SoftmaxAxis(targetScores.Slice(mlx.Slice(0, draftCount), mlx.Slice()), -1, true) - draftProbs := mlx.SoftmaxAxis(draftScores, -1, true) - if draftTokens.NumDims() == 2 { - draftTokens = draftTokens.Squeeze(0) - } - indices := draftTokens.ExpandDims(-1) - p := targetProbs.TakeAlongAxis(indices, -1).Squeeze(-1) - q := draftProbs.TakeAlongAxis(indices, -1).Squeeze(-1) +func (r *Runner) mtpSampleAcceptedMask(targetDist, draftDist sampler.Distribution, draftTokens *mlx.Array) *mlx.Array { + p := targetDist.Prob(draftTokens) + q := draftDist.Prob(draftTokens) acceptP := mlx.Minimum(p.Divide(q), mlx.FromValue(float32(1))) - return mlx.Bernoulli(acceptP).AsType(mlx.DTypeInt32) + return r.Sampler.Bernoulli(pipelineSlot, acceptP).AsType(mlx.DTypeInt32) } -func mtpSampleTokenAt(scores *mlx.Array, index int) *mlx.Array { - row := scores.Slice(mlx.Slice(index, index+1), mlx.Slice()) - return mtpTokenVector(row.Categorical(-1).AsType(mlx.DTypeInt32)) +func (r *Runner) mtpSampleTokenAt(dist sampler.Distribution, index int) *mlx.Array { + return mtpTokenVector(r.Sampler.SampleDistribution(pipelineSlot, dist.SliceRows(index, index+1))) } -func mtpSampleResidualToken(targetScores, draftScores *mlx.Array, index int) *mlx.Array { - p := mlx.SoftmaxAxis(targetScores.Slice(mlx.Slice(index, index+1), mlx.Slice()), -1, true) - q := mlx.SoftmaxAxis(draftScores.Slice(mlx.Slice(index, index+1), mlx.Slice()), -1, true) - diff := p.Subtract(q) - positive := mlx.Maximum(diff, mlx.FromValue(float32(1e-20))) - logits := mlx.Log(positive) - logits = mlx.Where(diff.LessEqual(mlx.FromValue(float32(0))), mlx.FromValue(float32(math.Inf(-1))), logits) - return mtpTokenVector(logits.Categorical(-1).AsType(mlx.DTypeInt32)) +func (r *Runner) mtpSampleResidualToken(targetDist, draftDist sampler.Distribution, index int) *mlx.Array { + residual := targetDist.SliceRows(index, index+1).ResidualAgainst(draftDist.SliceRows(index, index+1)) + return mtpTokenVector(r.Sampler.SampleDistribution(pipelineSlot, residual)) } func mtpTokenInput(token *mlx.Array) *mlx.Array { diff --git a/x/mlxrunner/sample/sample.go b/x/mlxrunner/sample/sample.go index bc6baaa96..aee68998e 100644 --- a/x/mlxrunner/sample/sample.go +++ b/x/mlxrunner/sample/sample.go @@ -17,6 +17,8 @@ type Options struct { RepeatPenalty float32 PresencePenalty float32 FrequencyPenalty float32 + Seed int + UseSeed bool // Logprobs causes Sample to populate Result.Logprob with the selected // token's log-probability. TopLogprobs (when > 0) adds top-K pairs. @@ -42,6 +44,123 @@ func (r Result) Arrays() []*mlx.Array { return []*mlx.Array{r.Token, r.Logprob, r.TopTokens, r.TopLogprobs} } +// Distribution is the filtered probability distribution used by the sampler. +// When IDs is nil, Probs is dense over the vocabulary. When IDs is set, Probs +// is sparse over the token ids in IDs, preserving GPU residency for the +// top-k-first path used by normal and speculative sampling. +type Distribution struct { + IDs *mlx.Array // sparse token ids, shape [B,K]; nil for dense distributions + Probs *mlx.Array // probabilities, shape [B,K] or [B,V] +} + +// Arrays returns the tensor fields for mlx lifecycle management. +func (d Distribution) Arrays() []*mlx.Array { + return []*mlx.Array{d.IDs, d.Probs} +} + +// Rows returns the number of rows in the distribution. +func (d Distribution) Rows() int { + if d.Probs == nil { + return 0 + } + return d.Probs.Dim(0) +} + +// SliceRows returns a row slice while preserving sparse/dense layout. +func (d Distribution) SliceRows(start, stop int) Distribution { + out := Distribution{Probs: d.Probs.Slice(mlx.Slice(start, stop), mlx.Slice())} + if d.IDs != nil { + out.IDs = d.IDs.Slice(mlx.Slice(start, stop), mlx.Slice()) + } + return out +} + +// SampleWithKey draws one token per row using key when supplied. +func (d Distribution) SampleWithKey(key *mlx.Array) *mlx.Array { + choice := logitsFromProbs(d.Probs).CategoricalWithKey(-1, key).AsType(mlx.DTypeInt32) + if d.IDs == nil { + return choice + } + return d.IDs.TakeAlongAxis(choice.ExpandDims(-1), -1).Squeeze(-1).AsType(mlx.DTypeInt32) +} + +// Prob returns the probability assigned to one token per row. +func (d Distribution) Prob(tokens *mlx.Array) *mlx.Array { + switch tokens.NumDims() { + case 2: + if tokens.Dim(0) == 1 { + tokens = tokens.Squeeze(0) + } else if tokens.Dim(1) == 1 { + tokens = tokens.Squeeze(1) + } + case 0: + tokens = tokens.Reshape(1) + } + return d.ProbsForIDs(tokens.ExpandDims(-1)).Squeeze(-1) +} + +// ProbsForIDs returns probabilities for each requested token id. ids must be +// rank-2 [B,N], matching the distribution rows. +func (d Distribution) ProbsForIDs(ids *mlx.Array) *mlx.Array { + if d.IDs == nil { + return d.Probs.TakeAlongAxis(ids, -1) + } + eq := d.IDs.ExpandDims(-1).Equal(ids.ExpandDims(1)) + values := mlx.Where(eq, d.Probs.ExpandDims(-1), mlx.FromValue(float32(0))) + return values.SumAxis(1, false) +} + +// ResidualAgainst returns the Leviathan/Chen rejection distribution +// proportional to max(target - draft, 0). Sparse target distributions stay +// sparse over the target support; tokens outside target support have zero mass. +func (d Distribution) ResidualAgainst(draft Distribution) Distribution { + if d.IDs != nil { + diff := d.Probs.Subtract(draft.ProbsForIDs(d.IDs)) + return Distribution{IDs: d.IDs, Probs: normalizeProbs(mlx.Maximum(diff, mlx.FromValue(float32(0))))} + } + if draft.IDs != nil { + panic("sample.Distribution.ResidualAgainst: dense target with sparse draft is unsupported") + } + diff := d.Probs.Subtract(draft.Probs) + return Distribution{Probs: normalizeProbs(mlx.Maximum(diff, mlx.FromValue(float32(0))))} +} + +// LogProbs returns dense log-probabilities, scattering sparse distributions +// into a full-vocabulary tensor when needed. +func (d Distribution) LogProbs(vocab int) *mlx.Array { + logProbs := logitsFromProbs(d.Probs) + if d.IDs == nil { + return logProbs + } + out := mlx.AddScalar(mlx.Zeros(mlx.DTypeFloat32, d.Probs.Dim(0), vocab), float32(math.Inf(-1))) + return out.PutAlongAxis(d.IDs, logProbs, -1) +} + +// ConcatenateDistributions concatenates distribution rows. All inputs must use +// the same sparse/dense layout. +func ConcatenateDistributions(dists []Distribution) Distribution { + if len(dists) == 0 { + return Distribution{} + } + probs := make([]*mlx.Array, 0, len(dists)) + ids := make([]*mlx.Array, 0, len(dists)) + sparse := dists[0].IDs != nil + for _, d := range dists { + if (d.IDs != nil) != sparse { + panic("sample.ConcatenateDistributions: mixed sparse and dense distributions") + } + probs = append(probs, d.Probs) + if sparse { + ids = append(ids, d.IDs) + } + } + out := Distribution{Probs: mlx.Concatenate(probs, 0)} + if sparse { + out.IDs = mlx.Concatenate(ids, 0) + } + return out +} + // Sampler is a batched, slot-based sampler. Sequences are registered with // Add and released with Remove. Each Sample call takes a subset of // registered slots (in any order) with their [B,V] logits, samples one @@ -73,9 +192,9 @@ type Sampler struct { } type slotState struct { - opts Options - transforms []transform - historyLen int + opts Options + historyLen int + randomCounter uint64 } type slotCtx struct { @@ -83,8 +202,6 @@ type slotCtx struct { history *mlx.Array // 2D [B, W] when penalties are configured; nil otherwise } -type transform func(*slotCtx, *mlx.Array) *mlx.Array - // New constructs an empty sampler with no registered slots. numCtx is // the runner's context window and must be positive. func New(numCtx int) *Sampler { @@ -127,40 +244,17 @@ func (o Options) normalize(numCtx int) Options { // RepeatLastN still batch together and don't inflate pool width. o.RepeatLastN = 0 } + if o.Seed < 0 { + o.UseSeed = false + } + if !o.UseSeed { + // Keep unseeded callers on the same batching path even when a + // meaningless Seed value is present in an Options literal. + o.Seed = 0 + } return o } -func (o Options) buildTransforms() []transform { - var ts []transform - if o.usesHistory() { - ts = append(ts, penalty) - } - - hasTopP := o.TopP > 0 && o.TopP < 1 - hasTopK := o.TopK > 0 - switch { - case hasTopP: - // topKTopP always does a full descending sort for the top-P - // cumulative mask and opportunistically masks top-K during the - // same pass when it is also configured. - ts = append(ts, topKTopP) - case hasTopK: - // Argpartition (partial sort) is cheaper than a full sort. - ts = append(ts, topK) - } - - if o.MinP != 0 { - ts = append(ts, minP) - } - - if o.Temperature == 0 { - ts = append(ts, greedy) - } else { - ts = append(ts, temperature) - } - return ts -} - // Add registers a sequence under seqID. The last RepeatLastN entries of // priorTokens seed the ring buffer. func (s *Sampler) Add(seqID int, opts Options, priorTokens []int32) { @@ -170,8 +264,7 @@ func (s *Sampler) Add(seqID int, opts Options, priorTokens []int32) { opts = opts.normalize(s.numCtx) slot := &slotState{ - opts: opts, - transforms: opts.buildTransforms(), + opts: opts, } // Grow the pool to hold this slot's row. The pool is lazy — the first @@ -349,29 +442,34 @@ func (s *Sampler) Sample(seqIDs []int, logits *mlx.Array) Result { return res } -// SpeculativeScores applies this slot's non-sampling transforms to logits -// without mutating sampler state. Row i is scored as if draftTokens[:i] had -// already been appended to the slot history. logits must be [R,V] or [1,R,V]. -func (s *Sampler) SpeculativeScores(seqID int, logits *mlx.Array, draftTokens *mlx.Array) *mlx.Array { - slot, ok := s.byID[seqID] - if !ok { - panic(fmt.Sprintf("sample.Sampler.SpeculativeScores: seqID %d not registered", seqID)) - } - - if logits.NumDims() == 3 { - if logits.Dim(0) != 1 { - panic("sample.Sampler.SpeculativeScores: only batch size 1 is supported") - } - logits = logits.Squeeze(0) - } - if logits.NumDims() != 2 { - panic(fmt.Sprintf("sample.Sampler.SpeculativeScores: logits must be rank 2 or 3, got rank %d", logits.NumDims())) - } - - if draftTokens != nil && draftTokens.NumDims() == 1 { - draftTokens = draftTokens.ExpandDims(0) - } +// Distribution applies this slot's sampling transforms to logits without +// mutating sampler state. Row i is built as if draftTokens[:i] had already +// been appended to the slot history. logits must be [R,V] or [1,R,V]. +func (s *Sampler) Distribution(seqID int, logits *mlx.Array, draftTokens *mlx.Array) Distribution { + slot, logits, draftTokens := s.speculativeInputs("Distribution", seqID, logits, draftTokens) rows := logits.Dim(0) + + var hist *mlx.Array + if slot.opts.usesHistory() { + if s.history == nil { + panic(fmt.Sprintf("sample.Sampler.Distribution: seqID %d has no history", seqID)) + } + if slot.historyLen < slot.opts.RepeatLastN { + return s.speculativeDistributionSerial(slot, logits, draftTokens) + } + hist = s.speculativeHistory(slot, draftTokens, rows) + } + + return slot.distribution(&slotCtx{opts: slot.opts, history: hist}, logits) +} + +// SpeculativeScores applies this slot's sampling transforms to logits without +// mutating sampler state and returns dense log-probability scores for sampled +// decoding. Greedy decoding returns the penalty-adjusted logits. +func (s *Sampler) SpeculativeScores(seqID int, logits *mlx.Array, draftTokens *mlx.Array) *mlx.Array { + slot, logits, draftTokens := s.speculativeInputs("SpeculativeScores", seqID, logits, draftTokens) + rows := logits.Dim(0) + var hist *mlx.Array if slot.opts.usesHistory() { if s.history == nil { @@ -386,6 +484,47 @@ func (s *Sampler) SpeculativeScores(seqID int, logits *mlx.Array, draftTokens *m return slot.speculativeScores(&slotCtx{opts: slot.opts, history: hist}, logits) } +// SampleDistribution draws from a precomputed distribution while advancing +// seqID's deterministic RNG stream when a seed is configured. +func (s *Sampler) SampleDistribution(seqID int, dist Distribution) *mlx.Array { + slot := s.mustSlot("SampleDistribution", seqID) + return dist.SampleWithKey(slot.nextRandomKey()) +} + +// Bernoulli samples boolean outcomes while advancing seqID's deterministic RNG +// stream when a seed is configured. +func (s *Sampler) Bernoulli(seqID int, p *mlx.Array) *mlx.Array { + slot := s.mustSlot("Bernoulli", seqID) + return mlx.BernoulliWithKey(p, slot.nextRandomKey()) +} + +func (s *Sampler) mustSlot(caller string, seqID int) *slotState { + slot, ok := s.byID[seqID] + if !ok { + panic(fmt.Sprintf("sample.Sampler.%s: seqID %d not registered", caller, seqID)) + } + return slot +} + +func (s *Sampler) speculativeInputs(caller string, seqID int, logits *mlx.Array, draftTokens *mlx.Array) (*slotState, *mlx.Array, *mlx.Array) { + slot := s.mustSlot(caller, seqID) + + if logits.NumDims() == 3 { + if logits.Dim(0) != 1 { + panic(fmt.Sprintf("sample.Sampler.%s: only batch size 1 is supported", caller)) + } + logits = logits.Squeeze(0) + } + if logits.NumDims() != 2 { + panic(fmt.Sprintf("sample.Sampler.%s: logits must be rank 2 or 3, got rank %d", caller, logits.NumDims())) + } + + if draftTokens != nil && draftTokens.NumDims() == 1 { + draftTokens = draftTokens.ExpandDims(0) + } + return slot, logits, draftTokens +} + // Commit appends already-selected tokens to seqID's repeat-penalty history. // It is used after speculative sampling once the accepted continuation is // known. Normal Sample calls continue to mutate history themselves. @@ -422,7 +561,7 @@ func (s *Sampler) Commit(seqID int, tokens []int32) { slot.historyLen += len(tokens) } -func (s *Sampler) speculativeScoresSerial(slot *slotState, logits *mlx.Array, draftTokens *mlx.Array) *mlx.Array { +func (s *Sampler) speculativeDistributionSerial(slot *slotState, logits *mlx.Array, draftTokens *mlx.Array) Distribution { rows := logits.Dim(0) draftCount := 0 if draftTokens != nil { @@ -435,7 +574,7 @@ func (s *Sampler) speculativeScoresSerial(slot *slotState, logits *mlx.Array, dr base = s.history.Slice(mlx.Slice(row, row+1), mlx.Slice(0, baseFill)) } - scored := make([]*mlx.Array, 0, rows) + dists := make([]Distribution, 0, rows) for i := range rows { rowLogits := logits.Slice(mlx.Slice(i, i+1), mlx.Slice()) hist := base @@ -451,9 +590,13 @@ func (s *Sampler) speculativeScoresSerial(slot *slotState, logits *mlx.Array, dr hist = hist.Slice(mlx.Slice(), mlx.Slice(hist.Dim(1)-slot.opts.RepeatLastN, mlx.End)) } } - scored = append(scored, slot.speculativeScores(&slotCtx{opts: slot.opts, history: hist}, rowLogits)) + dists = append(dists, slot.distribution(&slotCtx{opts: slot.opts, history: hist}, rowLogits)) } - return mlx.Concatenate(scored, 0) + return ConcatenateDistributions(dists) +} + +func (s *Sampler) speculativeScoresSerial(slot *slotState, logits *mlx.Array, draftTokens *mlx.Array) *mlx.Array { + return s.speculativeDistributionSerial(slot, logits, draftTokens).LogProbs(logits.Dim(logits.NumDims() - 1)) } func (s *Sampler) speculativeHistory(slot *slotState, draftTokens *mlx.Array, rows int) *mlx.Array { @@ -489,17 +632,10 @@ func (s *Sampler) speculativeHistory(slot *slotState, draftTokens *mlx.Array, ro } func (slot *slotState) speculativeScores(ctx *slotCtx, logits *mlx.Array) *mlx.Array { - scores := logits - // buildTransforms always appends the final selector transform - // (greedy or temperature sampling). Speculative validation needs the - // processed logits before that selector mutates the distribution. - for _, t := range slot.transforms[:len(slot.transforms)-1] { - scores = t(ctx, scores) + if slot.opts.Temperature == 0 { + return slot.baseScores(ctx, logits) } - if slot.opts.Temperature > 0 { - scores = mlx.DivScalar(scores, slot.opts.Temperature) - } - return scores + return slot.distribution(ctx, logits).LogProbs(logits.Dim(logits.NumDims() - 1)) } // canBatch reports whether the call can take the uniform batched path. @@ -513,6 +649,9 @@ func (s *Sampler) canBatch(slots []*slotState) (Options, bool) { // slots is non-empty (Sample guards) and every slot is registered, // so s.slots[0].opts is the canonical shared value. shared := s.slots[0].opts + // TODO: Before using multi-slot batching with seeded stochastic sampling, + // make sure each row gets its own per-slot random key instead of sharing + // slots[0]'s key through one batched categorical op. if !shared.usesHistory() { return shared, true } @@ -527,7 +666,7 @@ func (s *Sampler) canBatch(slots []*slotState) (Options, bool) { return shared, true } -// sampleTokensUniform runs one fused transform pass over the whole batch. +// sampleTokensUniform runs one fused sampling pass over the whole batch. // Reached only when canBatch is true, which lets the pool be used in place // with a single PutAlongAxis write-back and no gather. func (s *Sampler) sampleTokensUniform(slots []*slotState, opts Options, logits *mlx.Array) *mlx.Array { @@ -542,11 +681,14 @@ func (s *Sampler) sampleTokensUniform(slots []*slotState, opts Options, logits * } ctx := &slotCtx{opts: opts, history: hist} - scores := logits - for _, t := range slots[0].transforms { - scores = t(ctx, scores) + token := slots[0].sample(ctx, logits) + if opts.UseSeed && opts.Temperature != 0 { + // TODO: This only keeps counters aligned; it does not give each slot + // an independent key for the batched draw. + for _, slot := range slots[1:] { + slot.randomCounter++ + } } - token := scores if !opts.usesHistory() { return token @@ -563,8 +705,7 @@ func (s *Sampler) sampleTokensUniform(slots []*slotState, opts Options, logits * return token } -// sampleTokensSerial runs each slot's transforms against its own row of -// logits. +// sampleTokensSerial samples each slot against its own row of logits. func (s *Sampler) sampleTokensSerial(slots []*slotState, logits *mlx.Array) *mlx.Array { perSlotTokens := make([]*mlx.Array, len(slots)) @@ -587,11 +728,7 @@ func (s *Sampler) sampleTokensSerial(slots []*slotState, logits *mlx.Array) *mlx } ctx := &slotCtx{opts: slot.opts, history: hist} - scores := row - for _, t := range slot.transforms { - scores = t(ctx, scores) - } - perSlotTokens[i] = scores + perSlotTokens[i] = slot.sample(ctx, row) } token := mlx.Concatenate(perSlotTokens, 0) @@ -627,74 +764,107 @@ func (s *Sampler) sampleTokensSerial(slots []*slotState, logits *mlx.Array) *mlx return token } -func greedy(_ *slotCtx, scores *mlx.Array) *mlx.Array { - return scores.Argmax(-1, false).AsType(mlx.DTypeInt32) +func (slot *slotState) sample(ctx *slotCtx, logits *mlx.Array) *mlx.Array { + if slot.opts.Temperature == 0 { + return slot.baseScores(ctx, logits).Argmax(-1, false).AsType(mlx.DTypeInt32) + } + return slot.distribution(ctx, logits).SampleWithKey(slot.nextRandomKey()) } -func temperature(ctx *slotCtx, scores *mlx.Array) *mlx.Array { - return mlx.DivScalar(scores, ctx.opts.Temperature).Categorical(-1).AsType(mlx.DTypeInt32) +func (slot *slotState) nextRandomKey() *mlx.Array { + if !slot.opts.UseSeed { + return nil + } + seed := mixSeed(uint64(slot.opts.Seed), slot.randomCounter) + slot.randomCounter++ + return mlx.RandomKey(seed) } -// topKTopP applies top-P in a descending sort pass and, when top-K is also -// configured, masks any surviving value below the K-th largest in the same -// pass. Callers dispatch here whenever top-P is enabled — the top-K-only case -// uses a cheaper partial sort via the topK transform. -func topKTopP(ctx *slotCtx, scores *mlx.Array) *mlx.Array { +const ( + // SplitMix64 constants used to decorrelate nearby (seed, counter) pairs. + splitMix64Weyl = 0x9e3779b97f4a7c15 + splitMix64Mul1 = 0xbf58476d1ce4e5b9 + splitMix64Mul2 = 0x94d049bb133111eb + splitMix64Shift1 = 30 + splitMix64Shift2 = 27 + splitMix64FinalShift = 31 +) + +func mixSeed(seed, counter uint64) uint64 { + z := seed + splitMix64Weyl*(counter+1) + z = (z ^ (z >> splitMix64Shift1)) * splitMix64Mul1 + z = (z ^ (z >> splitMix64Shift2)) * splitMix64Mul2 + return z ^ (z >> splitMix64FinalShift) +} + +func (slot *slotState) baseScores(ctx *slotCtx, logits *mlx.Array) *mlx.Array { + scores := logits + if slot.opts.usesHistory() { + scores = penalty(ctx, scores) + } + return scores +} + +func (slot *slotState) distribution(ctx *slotCtx, logits *mlx.Array) Distribution { + scores := slot.baseScores(ctx, logits) + if slot.opts.Temperature <= 0 { + ids := scores.Argmax(-1, false).AsType(mlx.DTypeInt32).ExpandDims(-1) + probs := mlx.AddScalar(ids.AsType(mlx.DTypeFloat32).Multiply(mlx.FromValue(float32(0))), 1) + return Distribution{IDs: ids, Probs: probs} + } + vocab := scores.Dim(scores.NumDims() - 1) - applyTopK := ctx.opts.TopK > 0 && ctx.opts.TopK < vocab - - order := scores.Negative().ArgsortAxis(-1) - sorted := scores.TakeAlongAxis(order, -1) - negInf := mlx.FromValue(float32(math.Inf(-1))) - - // Top-P: in descending order, keep tokens whose exclusive cumulative - // probability is still below TopP. - probs := mlx.SoftmaxAxis(sorted, -1, true) - prevCumProbs := probs.Cumsum(-1, false, true).Subtract(probs) - keep := prevCumProbs.Less(mlx.FromValue(ctx.opts.TopP)) - sorted = mlx.Where(keep, sorted, negInf) - - out := scores.PutAlongAxis(order, sorted, -1) - - // Top-K: sorted is already in descending order, so positions [K, V) are - // the ones to drop. Scatter -inf through their original-layout indices - // (order[K:]). Positional (not value-based) so exactly K tokens survive — - // ties at the K-th logit get broken by the sort order rather than - // promoted through the filter. - if applyTopK { - dropOrder := order.Slice(mlx.Slice(), mlx.Slice(ctx.opts.TopK, mlx.End)) - out = out.PutAlongAxis(dropOrder, negInf, -1) + if slot.opts.TopK > 0 && slot.opts.TopK < vocab { + return sparseDistribution(ctx.opts, scores) } - - return out + return denseDistribution(ctx.opts, scores) } -func minP(ctx *slotCtx, scores *mlx.Array) *mlx.Array { - if ctx.opts.MinP <= 0 || ctx.opts.MinP > 1 { - return scores - } - - maxScore := scores.MaxAxis(-1, true) - threshold := mlx.AddScalar(maxScore, float32(math.Log(float64(ctx.opts.MinP)))) - - return mlx.Where( - scores.Less(threshold), - mlx.FromValue(float32(math.Inf(-1))), - scores, - ) +func sparseDistribution(opts Options, scores *mlx.Array) Distribution { + ids := scores.Negative().ArgpartitionAxis(opts.TopK-1, -1).Slice(mlx.Slice(), mlx.Slice(0, opts.TopK)).AsType(mlx.DTypeInt32) + topScores := scores.TakeAlongAxis(ids, -1).AsType(mlx.DTypeFloat32) + probs := mlx.SoftmaxAxis(mlx.DivScalar(topScores, opts.Temperature), -1, true) + probs = applyTopPProbs(probs, opts.TopP) + probs = applyMinPProbs(probs, opts.MinP) + return Distribution{IDs: ids, Probs: normalizeProbs(probs)} } -func topK(ctx *slotCtx, scores *mlx.Array) *mlx.Array { - if ctx.opts.TopK <= 0 { - return scores - } - vocab := scores.Dim(scores.NumDims() - 1) - if ctx.opts.TopK >= vocab { - return scores - } +func denseDistribution(opts Options, scores *mlx.Array) Distribution { + probs := mlx.SoftmaxAxis(mlx.DivScalar(scores.AsType(mlx.DTypeFloat32), opts.Temperature), -1, true) + probs = applyTopPProbs(probs, opts.TopP) + probs = applyMinPProbs(probs, opts.MinP) + return Distribution{Probs: normalizeProbs(probs)} +} - mask := scores.Negative().ArgpartitionAxis(ctx.opts.TopK-1, -1).Slice(mlx.Slice(), mlx.Slice(ctx.opts.TopK, mlx.End)) - return scores.PutAlongAxis(mask, mlx.FromValue(float32(math.Inf(-1))), -1) +func applyTopPProbs(probs *mlx.Array, topP float32) *mlx.Array { + if topP <= 0 || topP >= 1 { + return probs + } + order := probs.Negative().ArgsortAxis(-1) + sorted := probs.TakeAlongAxis(order, -1) + prevCumProbs := sorted.Cumsum(-1, false, true).Subtract(sorted) + keep := prevCumProbs.Less(mlx.FromValue(topP)) + filtered := mlx.Where(keep, sorted, mlx.FromValue(float32(0))) + return mlx.Zeros(probs.DType(), probs.Dims()...).PutAlongAxis(order, filtered, -1) +} + +func applyMinPProbs(probs *mlx.Array, minP float32) *mlx.Array { + if minP <= 0 || minP > 1 { + return probs + } + threshold := mlx.MulScalar(probs.MaxAxis(-1, true), minP) + return mlx.Where(probs.Less(threshold), mlx.FromValue(float32(0)), probs) +} + +func normalizeProbs(probs *mlx.Array) *mlx.Array { + sum := mlx.Maximum(probs.SumAxis(-1, true), mlx.FromValue(float32(1e-20))) + return probs.Divide(sum) +} + +func logitsFromProbs(probs *mlx.Array) *mlx.Array { + positive := mlx.Maximum(probs, mlx.FromValue(float32(1e-20))) + logits := mlx.Log(positive) + return mlx.Where(probs.LessEqual(mlx.FromValue(float32(0))), mlx.FromValue(float32(math.Inf(-1))), logits) } func penalty(ctx *slotCtx, scores *mlx.Array) *mlx.Array { diff --git a/x/mlxrunner/sample/sample_test.go b/x/mlxrunner/sample/sample_test.go index b11c6b84f..af6920fab 100644 --- a/x/mlxrunner/sample/sample_test.go +++ b/x/mlxrunner/sample/sample_test.go @@ -4,6 +4,7 @@ package sample import ( "math" + "slices" "testing" "github.com/ollama/ollama/x/mlxrunner/mlx" @@ -141,6 +142,132 @@ func TestSampleSingleSlotOptions(t *testing.T) { } } +func TestDistributionAppliesTopKBeforeTopP(t *testing.T) { + skipIfNoMLX(t) + + s := New(128) + t.Cleanup(func() { + s.Free() + mlx.Sweep() + }) + s.Add(0, Options{Temperature: 1, TopK: 2, TopP: 0.7}, nil) + + dist := s.Distribution(0, slotLogits([]float32{logOf(0.6), logOf(0.2), logOf(0.2)}), nil) + mlx.Eval(dist.Arrays()...) + + ids := dist.IDs.Ints() + probs := dist.Probs.Floats() + if len(ids) != 2 || len(probs) != 2 { + t.Fatalf("support = ids %v probs %v, want 2 sparse entries", ids, probs) + } + + foundTop := false + for i, id := range ids { + switch id { + case 0: + foundTop = true + if math.Abs(float64(probs[i]-1)) > 1e-5 { + t.Fatalf("top token prob = %v, want 1; ids=%v probs=%v", probs[i], ids, probs) + } + default: + if math.Abs(float64(probs[i])) > 1e-5 { + t.Fatalf("non-top token %d prob = %v, want 0; ids=%v probs=%v", id, probs[i], ids, probs) + } + } + } + if !foundTop { + t.Fatalf("top-k support %v did not include token 0", ids) + } +} + +func TestDistributionResidualUsesTargetSupport(t *testing.T) { + skipIfNoMLX(t) + + target := Distribution{ + IDs: mlx.NewArrayInt32([]int32{2, 5}, []int32{1, 2}), + Probs: mlx.FromValues([]float32{0.7, 0.3}, 1, 2), + } + draft := Distribution{ + IDs: mlx.NewArrayInt32([]int32{2, 4}, []int32{1, 2}), + Probs: mlx.FromValues([]float32{0.2, 0.8}, 1, 2), + } + + residual := target.ResidualAgainst(draft) + mlx.Eval(residual.Arrays()...) + + ids := residual.IDs.Ints() + probs := residual.Probs.Floats() + want := map[int]float64{2: 0.625, 5: 0.375} + if len(ids) != 2 || len(probs) != 2 { + t.Fatalf("residual = ids %v probs %v, want 2 sparse entries", ids, probs) + } + for i, id := range ids { + w, ok := want[id] + if !ok { + t.Fatalf("residual includes token %d outside target support: ids=%v probs=%v", id, ids, probs) + } + if math.Abs(float64(probs[i])-w) > 1e-5 { + t.Fatalf("residual token %d prob = %v, want %v; ids=%v probs=%v", id, probs[i], w, ids, probs) + } + } +} + +func TestSeededSamplingIsReproducible(t *testing.T) { + skipIfNoMLX(t) + + seededSequence := func(seed int) []int { + s := New(128) + t.Cleanup(func() { + s.Free() + mlx.Sweep() + }) + s.Add(0, Options{Temperature: 1, TopK: 4, Seed: seed, UseSeed: true}, nil) + + logits := slotLogits([]float32{0, 0, 0, 0}) + out := make([]int, 32) + for i := range out { + token := s.Sample([]int{0}, logits).Token + mlx.Eval(token) + out[i] = token.Int() + } + return out + } + + a := seededSequence(1234) + b := seededSequence(1234) + if !slices.Equal(a, b) { + t.Fatalf("same seed produced different sequences:\n%v\n%v", a, b) + } + + c := seededSequence(5678) + if slices.Equal(a, c) { + t.Fatalf("different seeds produced the same sequence: %v", a) + } +} + +func TestSeededBernoulliIsReproducible(t *testing.T) { + skipIfNoMLX(t) + + seededMask := func() []int { + s := New(128) + t.Cleanup(func() { + s.Free() + mlx.Sweep() + }) + s.Add(0, Options{Seed: 99, UseSeed: true}, nil) + + mask := s.Bernoulli(0, mlx.FromValues([]float32{0.5, 0.5, 0.5, 0.5, 0.5, 0.5}, 6)).AsType(mlx.DTypeInt32) + mlx.Eval(mask) + return mask.Ints() + } + + a := seededMask() + b := seededMask() + if !slices.Equal(a, b) { + t.Fatalf("same seed produced different bernoulli masks:\n%v\n%v", a, b) + } +} + // TestSampleHistoryWindow verifies that penalty history respects the // RepeatLastN window: priors longer than RepeatLastN are trimmed on Add, // and once the ring wraps, tokens that rotate out no longer contribute diff --git a/x/mlxrunner/server.go b/x/mlxrunner/server.go index aa37c82e0..22b608d5e 100644 --- a/x/mlxrunner/server.go +++ b/x/mlxrunner/server.go @@ -136,6 +136,8 @@ func Execute(args []string) error { RepeatPenalty: request.Options.RepeatPenalty, PresencePenalty: request.Options.PresencePenalty, FrequencyPenalty: request.Options.FrequencyPenalty, + Seed: request.Options.Seed, + UseSeed: request.Options.Seed >= 0, Logprobs: request.Logprobs, TopLogprobs: request.TopLogprobs, }