mirror of
https://github.com/ollama/ollama.git
synced 2026-05-13 14:27:00 +00:00
Replace the MLX sampler transform chain with an explicit distribution pipeline that applies, in order: (1) penalties, (2) top-k, (3) temperature/softmax, (4) top-p, (5) min-p, (6) normalize, (7) categorical. The common top_k path now keeps sparse [B,K] token ids/probabilities on GPU instead of carrying full-vocab scores, and sampled MTP reuses those draft/target distributions for acceptance, bonus, and residual sampling. This change also fixes the seed parameter so that temperature sampling and sampled MTP are reproducible.
897 lines
29 KiB
Go
897 lines
29 KiB
Go
package sample
|
|
|
|
import (
|
|
"fmt"
|
|
"math"
|
|
"slices"
|
|
|
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
|
)
|
|
|
|
// Options is the per-sequence sampling configuration handed to Sampler.Add.
type Options struct {
	// Distribution shaping knobs.
	Temperature float32
	TopP        float32
	MinP        float32
	TopK        int

	// Penalty configuration. RepeatLastN is the capacity of the token
	// history ring the penalties are computed over.
	RepeatLastN      int
	RepeatPenalty    float32
	PresencePenalty  float32
	FrequencyPenalty float32

	// Seed is only consulted when UseSeed is true.
	Seed    int
	UseSeed bool

	// Logprobs asks Sample to fill Result.Logprob with the log-probability
	// of the selected token. A positive TopLogprobs additionally requests
	// top-K (token, logprob) pairs.
	Logprobs    bool
	TopLogprobs int
}
|
|
|
|
// Result bundles the outputs of one decode step. Logprob/TopTokens/
|
|
// TopLogprobs are populated whenever any registered slot has Logprobs
|
|
// (respectively TopLogprobs>0). Consumers need to filter by their
|
|
// per-slot Options.
|
|
type Result struct {
|
|
Token *mlx.Array // sampled token ids, shape [B]
|
|
Logprob *mlx.Array // sampled-token logprobs, shape [B,1]; nil unless any registered slot has Logprobs
|
|
TopTokens *mlx.Array // top-K token ids, shape [B,maxK]; nil unless any registered slot has TopLogprobs>0
|
|
TopLogprobs *mlx.Array // top-K logprobs, shape [B,maxK]; same
|
|
}
|
|
|
|
// Arrays returns the tensor fields as a slice so callers can drive the mlx
|
|
// lifecycle verbs (Pin, Unpin, Eval, AsyncEval) over the whole group. Unset
|
|
// fields stay nil; the mlx helpers skip them.
|
|
func (r Result) Arrays() []*mlx.Array {
|
|
return []*mlx.Array{r.Token, r.Logprob, r.TopTokens, r.TopLogprobs}
|
|
}
|
|
|
|
// Distribution is the filtered probability distribution used by the sampler.
|
|
// When IDs is nil, Probs is dense over the vocabulary. When IDs is set, Probs
|
|
// is sparse over the token ids in IDs, preserving GPU residency for the
|
|
// top-k-first path used by normal and speculative sampling.
|
|
type Distribution struct {
|
|
IDs *mlx.Array // sparse token ids, shape [B,K]; nil for dense distributions
|
|
Probs *mlx.Array // probabilities, shape [B,K] or [B,V]
|
|
}
|
|
|
|
// Arrays returns the tensor fields for mlx lifecycle management.
|
|
func (d Distribution) Arrays() []*mlx.Array {
|
|
return []*mlx.Array{d.IDs, d.Probs}
|
|
}
|
|
|
|
// Rows returns the number of rows in the distribution.
|
|
func (d Distribution) Rows() int {
|
|
if d.Probs == nil {
|
|
return 0
|
|
}
|
|
return d.Probs.Dim(0)
|
|
}
|
|
|
|
// SliceRows returns a row slice while preserving sparse/dense layout.
|
|
func (d Distribution) SliceRows(start, stop int) Distribution {
|
|
out := Distribution{Probs: d.Probs.Slice(mlx.Slice(start, stop), mlx.Slice())}
|
|
if d.IDs != nil {
|
|
out.IDs = d.IDs.Slice(mlx.Slice(start, stop), mlx.Slice())
|
|
}
|
|
return out
|
|
}
|
|
|
|
// SampleWithKey draws one token per row using key when supplied.
|
|
func (d Distribution) SampleWithKey(key *mlx.Array) *mlx.Array {
|
|
choice := logitsFromProbs(d.Probs).CategoricalWithKey(-1, key).AsType(mlx.DTypeInt32)
|
|
if d.IDs == nil {
|
|
return choice
|
|
}
|
|
return d.IDs.TakeAlongAxis(choice.ExpandDims(-1), -1).Squeeze(-1).AsType(mlx.DTypeInt32)
|
|
}
|
|
|
|
// Prob returns the probability assigned to one token per row.
|
|
func (d Distribution) Prob(tokens *mlx.Array) *mlx.Array {
|
|
switch tokens.NumDims() {
|
|
case 2:
|
|
if tokens.Dim(0) == 1 {
|
|
tokens = tokens.Squeeze(0)
|
|
} else if tokens.Dim(1) == 1 {
|
|
tokens = tokens.Squeeze(1)
|
|
}
|
|
case 0:
|
|
tokens = tokens.Reshape(1)
|
|
}
|
|
return d.ProbsForIDs(tokens.ExpandDims(-1)).Squeeze(-1)
|
|
}
|
|
|
|
// ProbsForIDs returns probabilities for each requested token id. ids must be
|
|
// rank-2 [B,N], matching the distribution rows.
|
|
func (d Distribution) ProbsForIDs(ids *mlx.Array) *mlx.Array {
|
|
if d.IDs == nil {
|
|
return d.Probs.TakeAlongAxis(ids, -1)
|
|
}
|
|
eq := d.IDs.ExpandDims(-1).Equal(ids.ExpandDims(1))
|
|
values := mlx.Where(eq, d.Probs.ExpandDims(-1), mlx.FromValue(float32(0)))
|
|
return values.SumAxis(1, false)
|
|
}
|
|
|
|
// ResidualAgainst returns the Leviathan/Chen rejection distribution
|
|
// proportional to max(target - draft, 0). Sparse target distributions stay
|
|
// sparse over the target support; tokens outside target support have zero mass.
|
|
func (d Distribution) ResidualAgainst(draft Distribution) Distribution {
|
|
if d.IDs != nil {
|
|
diff := d.Probs.Subtract(draft.ProbsForIDs(d.IDs))
|
|
return Distribution{IDs: d.IDs, Probs: normalizeProbs(mlx.Maximum(diff, mlx.FromValue(float32(0))))}
|
|
}
|
|
if draft.IDs != nil {
|
|
panic("sample.Distribution.ResidualAgainst: dense target with sparse draft is unsupported")
|
|
}
|
|
diff := d.Probs.Subtract(draft.Probs)
|
|
return Distribution{Probs: normalizeProbs(mlx.Maximum(diff, mlx.FromValue(float32(0))))}
|
|
}
|
|
|
|
// LogProbs returns dense log-probabilities, scattering sparse distributions
|
|
// into a full-vocabulary tensor when needed.
|
|
func (d Distribution) LogProbs(vocab int) *mlx.Array {
|
|
logProbs := logitsFromProbs(d.Probs)
|
|
if d.IDs == nil {
|
|
return logProbs
|
|
}
|
|
out := mlx.AddScalar(mlx.Zeros(mlx.DTypeFloat32, d.Probs.Dim(0), vocab), float32(math.Inf(-1)))
|
|
return out.PutAlongAxis(d.IDs, logProbs, -1)
|
|
}
|
|
|
|
// ConcatenateDistributions concatenates distribution rows. All inputs must use
|
|
// the same sparse/dense layout.
|
|
func ConcatenateDistributions(dists []Distribution) Distribution {
|
|
if len(dists) == 0 {
|
|
return Distribution{}
|
|
}
|
|
probs := make([]*mlx.Array, 0, len(dists))
|
|
ids := make([]*mlx.Array, 0, len(dists))
|
|
sparse := dists[0].IDs != nil
|
|
for _, d := range dists {
|
|
if (d.IDs != nil) != sparse {
|
|
panic("sample.ConcatenateDistributions: mixed sparse and dense distributions")
|
|
}
|
|
probs = append(probs, d.Probs)
|
|
if sparse {
|
|
ids = append(ids, d.IDs)
|
|
}
|
|
}
|
|
out := Distribution{Probs: mlx.Concatenate(probs, 0)}
|
|
if sparse {
|
|
out.IDs = mlx.Concatenate(ids, 0)
|
|
}
|
|
return out
|
|
}
|
|
|
|
// Sampler is a batched, slot-based sampler. Sequences are registered with
|
|
// Add and released with Remove. Each Sample call takes a subset of
|
|
// registered slots (in any order) with their [B,V] logits, samples one
|
|
// token per row, and appends it to that slot's ring-buffer history. Slots
|
|
// not named in a given call are untouched.
|
|
type Sampler struct {
|
|
slots []*slotState
|
|
byID map[int]*slotState
|
|
|
|
// history is the pooled ring-buffer storage, [B, W] int32. Row i
|
|
// belongs to slots[i]; W is max(RepeatLastN) across penalty slots.
|
|
// Allocated on the first penalty slot, rebuilt only in Add/Remove.
|
|
history *mlx.Array
|
|
|
|
// allSameOpts: every registered slot shares Options. When true the
|
|
// canonical shared value is s.slots[0].opts.
|
|
allSameOpts bool
|
|
|
|
// anyLogprobs / maxTopLogprobs: compute-for-all output config.
|
|
// Sample populates Logprob (and Top* when maxTopLogprobs>0) whenever
|
|
// any registered slot requests them, even if that slot isn't in the
|
|
// current call.
|
|
anyLogprobs bool
|
|
maxTopLogprobs int
|
|
|
|
// numCtx is the runner's context window; normalize uses it to
|
|
// resolve the repeat_last_n == -1 sentinel.
|
|
numCtx int
|
|
}
|
|
|
|
type slotState struct {
|
|
opts Options
|
|
historyLen int
|
|
randomCounter uint64
|
|
}
|
|
|
|
type slotCtx struct {
|
|
opts Options
|
|
history *mlx.Array // 2D [B, W] when penalties are configured; nil otherwise
|
|
}
|
|
|
|
// New constructs an empty sampler with no registered slots. numCtx is
|
|
// the runner's context window and must be positive.
|
|
func New(numCtx int) *Sampler {
|
|
return &Sampler{
|
|
byID: make(map[int]*slotState),
|
|
allSameOpts: true,
|
|
numCtx: numCtx,
|
|
}
|
|
}
|
|
|
|
// historyWidth returns the column count of the pooled history tensor,
|
|
// or 0 when no penalty slot has forced it to be allocated.
|
|
func (s *Sampler) historyWidth() int {
|
|
if s.history == nil {
|
|
return 0
|
|
}
|
|
return s.history.Dim(1)
|
|
}
|
|
|
|
func (o Options) usesHistory() bool {
|
|
// RepeatLastN == 0 disables the penalty ring per the repeat_last_n API
|
|
// contract (0 = disabled), overriding any penalty coefficients.
|
|
if o.RepeatLastN == 0 {
|
|
return false
|
|
}
|
|
return o.RepeatPenalty != 1 || o.PresencePenalty != 0 || o.FrequencyPenalty != 0
|
|
}
|
|
|
|
func (o Options) normalize(numCtx int) Options {
|
|
if o.RepeatPenalty <= 0 {
|
|
o.RepeatPenalty = 1
|
|
}
|
|
// Resolve the repeat_last_n == -1 sentinel ("-1 = num_ctx") against
|
|
// the caller's context window.
|
|
if o.RepeatLastN < 0 {
|
|
o.RepeatLastN = numCtx
|
|
}
|
|
if !o.usesHistory() {
|
|
// Zero the ring capacity so slots that differ only in a spurious
|
|
// RepeatLastN still batch together and don't inflate pool width.
|
|
o.RepeatLastN = 0
|
|
}
|
|
if o.Seed < 0 {
|
|
o.UseSeed = false
|
|
}
|
|
if !o.UseSeed {
|
|
// Keep unseeded callers on the same batching path even when a
|
|
// meaningless Seed value is present in an Options literal.
|
|
o.Seed = 0
|
|
}
|
|
return o
|
|
}
|
|
|
|
// Add registers a sequence under seqID. The last RepeatLastN entries of
|
|
// priorTokens seed the ring buffer.
|
|
func (s *Sampler) Add(seqID int, opts Options, priorTokens []int32) {
|
|
if _, dup := s.byID[seqID]; dup {
|
|
panic(fmt.Sprintf("sample.Sampler.Add: seqID %d already registered", seqID))
|
|
}
|
|
|
|
opts = opts.normalize(s.numCtx)
|
|
slot := &slotState{
|
|
opts: opts,
|
|
}
|
|
|
|
// Grow the pool to hold this slot's row. The pool is lazy — the first
|
|
// penalty slot allocates it — and thereafter every registered slot
|
|
// gets a row (rows for non-penalty slots are zero and never read).
|
|
// Invariant: s.history is pinned whenever non-nil.
|
|
if s.history != nil || opts.usesHistory() {
|
|
targetWidth := max(opts.RepeatLastN, s.historyWidth())
|
|
newRow := makeHistoryRow(priorTokens, opts.RepeatLastN, targetWidth)
|
|
|
|
var pool *mlx.Array
|
|
switch {
|
|
case s.history == nil && len(s.slots) == 0:
|
|
pool = newRow
|
|
case s.history == nil:
|
|
// First penalty slot with non-penalty slots already registered;
|
|
// seed zero rows so s.slots and pool row indices stay aligned.
|
|
zeros := mlx.Zeros(mlx.DTypeInt32, len(s.slots), targetWidth)
|
|
pool = zeros.Concatenate(0, newRow)
|
|
case targetWidth > s.historyWidth():
|
|
pad := mlx.Zeros(mlx.DTypeInt32, s.history.Dim(0), targetWidth-s.historyWidth())
|
|
pool = s.history.Concatenate(1, pad).Concatenate(0, newRow)
|
|
default:
|
|
pool = s.history.Concatenate(0, newRow)
|
|
}
|
|
|
|
mlx.Pin(pool)
|
|
mlx.Unpin(s.history)
|
|
s.history = pool
|
|
|
|
if opts.usesHistory() {
|
|
// Cap on seed so the next write's ring position
|
|
// (historyLen % RepeatLastN) lands at 0, overwriting the
|
|
// oldest entry when the ring was filled from priors.
|
|
slot.historyLen = min(len(priorTokens), opts.RepeatLastN)
|
|
}
|
|
}
|
|
|
|
s.slots = append(s.slots, slot)
|
|
s.byID[seqID] = slot
|
|
s.recomputeInvariants()
|
|
}
|
|
|
|
// makeHistoryRow builds a [1, width] int32 row with the last repeatLastN
|
|
// entries of priorTokens packed into [0, min(len, repeatLastN)), zeros
|
|
// elsewhere.
|
|
func makeHistoryRow(priorTokens []int32, repeatLastN, width int) *mlx.Array {
|
|
take := min(len(priorTokens), repeatLastN)
|
|
if take <= 0 {
|
|
return mlx.Zeros(mlx.DTypeInt32, 1, width)
|
|
}
|
|
row := make([]int32, width)
|
|
copy(row, priorTokens[len(priorTokens)-take:])
|
|
return mlx.NewArrayInt32(row, []int32{1, int32(width)})
|
|
}
|
|
|
|
// recomputeInvariants refreshes allSameOpts and anyLogprobs/maxTopLogprobs
|
|
// from s.slots. Called at the end of Add and Remove.
|
|
func (s *Sampler) recomputeInvariants() {
|
|
if len(s.slots) == 0 {
|
|
s.allSameOpts = true
|
|
s.anyLogprobs = false
|
|
s.maxTopLogprobs = 0
|
|
return
|
|
}
|
|
first := s.slots[0].opts
|
|
s.allSameOpts = true
|
|
s.anyLogprobs = false
|
|
s.maxTopLogprobs = 0
|
|
for _, slot := range s.slots {
|
|
if slot.opts != first {
|
|
s.allSameOpts = false
|
|
}
|
|
if slot.opts.Logprobs {
|
|
s.anyLogprobs = true
|
|
if slot.opts.TopLogprobs > s.maxTopLogprobs {
|
|
s.maxTopLogprobs = slot.opts.TopLogprobs
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Remove releases the slot. The pool tensor is rebuilt to drop the row.
|
|
func (s *Sampler) Remove(seqID int) {
|
|
slot, ok := s.byID[seqID]
|
|
if !ok {
|
|
return
|
|
}
|
|
delete(s.byID, seqID)
|
|
|
|
row := slices.Index(s.slots, slot)
|
|
s.slots = slices.Delete(s.slots, row, row+1)
|
|
s.recomputeInvariants()
|
|
|
|
if s.history == nil {
|
|
return
|
|
}
|
|
|
|
n := s.history.Dim(0)
|
|
var newHistory *mlx.Array
|
|
switch {
|
|
case n == 1:
|
|
newHistory = nil
|
|
case row == 0:
|
|
newHistory = s.history.Slice(mlx.Slice(1, n), mlx.Slice())
|
|
case row == n-1:
|
|
newHistory = s.history.Slice(mlx.Slice(0, row), mlx.Slice())
|
|
default:
|
|
before := s.history.Slice(mlx.Slice(0, row), mlx.Slice())
|
|
after := s.history.Slice(mlx.Slice(row+1, n), mlx.Slice())
|
|
newHistory = before.Concatenate(0, after)
|
|
}
|
|
|
|
mlx.Pin(newHistory)
|
|
mlx.Unpin(s.history)
|
|
s.history = newHistory
|
|
}
|
|
|
|
// Free releases the pooled history tensor and resets the sampler to the
|
|
// New-equivalent state so it may be reused.
|
|
func (s *Sampler) Free() {
|
|
mlx.Unpin(s.history)
|
|
*s = Sampler{
|
|
byID: make(map[int]*slotState),
|
|
allSameOpts: true,
|
|
numCtx: s.numCtx,
|
|
}
|
|
}
|
|
|
|
// Sample draws one token per row of logits ([B,V]); seqIDs[i] names the
|
|
// slot whose logits live at row i. Each sampled token is appended to its
|
|
// slot's ring. Slots not named in seqIDs are untouched.
|
|
func (s *Sampler) Sample(seqIDs []int, logits *mlx.Array) Result {
|
|
if len(seqIDs) == 0 {
|
|
return Result{}
|
|
}
|
|
|
|
slots := make([]*slotState, len(seqIDs))
|
|
for i, id := range seqIDs {
|
|
slot, ok := s.byID[id]
|
|
if !ok {
|
|
panic(fmt.Sprintf("sample.Sampler.Sample: seqID %d not registered", id))
|
|
}
|
|
slots[i] = slot
|
|
}
|
|
|
|
var token *mlx.Array
|
|
if opts0, ok := s.canBatch(slots); ok {
|
|
token = s.sampleTokensUniform(slots, opts0, logits)
|
|
} else {
|
|
token = s.sampleTokensSerial(slots, logits)
|
|
}
|
|
|
|
res := Result{Token: token}
|
|
if s.anyLogprobs {
|
|
// Log-softmax over original logits so every row holds a truthful
|
|
// value (compute-for-all; consumers filter per-slot). Subtract
|
|
// max first for numerical stability in the logsumexp.
|
|
lp := logits.AsType(mlx.DTypeFloat32)
|
|
lp = lp.Subtract(lp.MaxAxis(-1, true))
|
|
lp = lp.Subtract(lp.LogsumexpAxis(-1, true))
|
|
res.Logprob = lp.TakeAlongAxis(token.ExpandDims(-1), -1)
|
|
if s.maxTopLogprobs > 0 {
|
|
k := s.maxTopLogprobs
|
|
if vocab := lp.Dim(lp.NumDims() - 1); k > vocab {
|
|
k = vocab
|
|
}
|
|
// Argpartition on the negated values places the K largest
|
|
// (unsorted) in positions [0:K].
|
|
idx := lp.Negative().ArgpartitionAxis(k-1, -1).Slice(mlx.Slice(), mlx.Slice(0, k))
|
|
res.TopTokens = idx.AsType(mlx.DTypeInt32)
|
|
res.TopLogprobs = lp.TakeAlongAxis(idx, -1)
|
|
}
|
|
}
|
|
return res
|
|
}
|
|
|
|
// Distribution applies this slot's sampling transforms to logits without
|
|
// mutating sampler state. Row i is built as if draftTokens[:i] had already
|
|
// been appended to the slot history. logits must be [R,V] or [1,R,V].
|
|
func (s *Sampler) Distribution(seqID int, logits *mlx.Array, draftTokens *mlx.Array) Distribution {
|
|
slot, logits, draftTokens := s.speculativeInputs("Distribution", seqID, logits, draftTokens)
|
|
rows := logits.Dim(0)
|
|
|
|
var hist *mlx.Array
|
|
if slot.opts.usesHistory() {
|
|
if s.history == nil {
|
|
panic(fmt.Sprintf("sample.Sampler.Distribution: seqID %d has no history", seqID))
|
|
}
|
|
if slot.historyLen < slot.opts.RepeatLastN {
|
|
return s.speculativeDistributionSerial(slot, logits, draftTokens)
|
|
}
|
|
hist = s.speculativeHistory(slot, draftTokens, rows)
|
|
}
|
|
|
|
return slot.distribution(&slotCtx{opts: slot.opts, history: hist}, logits)
|
|
}
|
|
|
|
// SpeculativeScores applies this slot's sampling transforms to logits without
|
|
// mutating sampler state and returns dense log-probability scores for sampled
|
|
// decoding. Greedy decoding returns the penalty-adjusted logits.
|
|
func (s *Sampler) SpeculativeScores(seqID int, logits *mlx.Array, draftTokens *mlx.Array) *mlx.Array {
|
|
slot, logits, draftTokens := s.speculativeInputs("SpeculativeScores", seqID, logits, draftTokens)
|
|
rows := logits.Dim(0)
|
|
|
|
var hist *mlx.Array
|
|
if slot.opts.usesHistory() {
|
|
if s.history == nil {
|
|
panic(fmt.Sprintf("sample.Sampler.SpeculativeScores: seqID %d has no history", seqID))
|
|
}
|
|
if slot.historyLen < slot.opts.RepeatLastN {
|
|
return s.speculativeScoresSerial(slot, logits, draftTokens)
|
|
}
|
|
hist = s.speculativeHistory(slot, draftTokens, rows)
|
|
}
|
|
|
|
return slot.speculativeScores(&slotCtx{opts: slot.opts, history: hist}, logits)
|
|
}
|
|
|
|
// SampleDistribution draws from a precomputed distribution while advancing
|
|
// seqID's deterministic RNG stream when a seed is configured.
|
|
func (s *Sampler) SampleDistribution(seqID int, dist Distribution) *mlx.Array {
|
|
slot := s.mustSlot("SampleDistribution", seqID)
|
|
return dist.SampleWithKey(slot.nextRandomKey())
|
|
}
|
|
|
|
// Bernoulli samples boolean outcomes while advancing seqID's deterministic RNG
|
|
// stream when a seed is configured.
|
|
func (s *Sampler) Bernoulli(seqID int, p *mlx.Array) *mlx.Array {
|
|
slot := s.mustSlot("Bernoulli", seqID)
|
|
return mlx.BernoulliWithKey(p, slot.nextRandomKey())
|
|
}
|
|
|
|
func (s *Sampler) mustSlot(caller string, seqID int) *slotState {
|
|
slot, ok := s.byID[seqID]
|
|
if !ok {
|
|
panic(fmt.Sprintf("sample.Sampler.%s: seqID %d not registered", caller, seqID))
|
|
}
|
|
return slot
|
|
}
|
|
|
|
func (s *Sampler) speculativeInputs(caller string, seqID int, logits *mlx.Array, draftTokens *mlx.Array) (*slotState, *mlx.Array, *mlx.Array) {
|
|
slot := s.mustSlot(caller, seqID)
|
|
|
|
if logits.NumDims() == 3 {
|
|
if logits.Dim(0) != 1 {
|
|
panic(fmt.Sprintf("sample.Sampler.%s: only batch size 1 is supported", caller))
|
|
}
|
|
logits = logits.Squeeze(0)
|
|
}
|
|
if logits.NumDims() != 2 {
|
|
panic(fmt.Sprintf("sample.Sampler.%s: logits must be rank 2 or 3, got rank %d", caller, logits.NumDims()))
|
|
}
|
|
|
|
if draftTokens != nil && draftTokens.NumDims() == 1 {
|
|
draftTokens = draftTokens.ExpandDims(0)
|
|
}
|
|
return slot, logits, draftTokens
|
|
}
|
|
|
|
// Commit appends already-selected tokens to seqID's repeat-penalty history.
|
|
// It is used after speculative sampling once the accepted continuation is
|
|
// known. Normal Sample calls continue to mutate history themselves.
|
|
func (s *Sampler) Commit(seqID int, tokens []int32) {
|
|
if len(tokens) == 0 {
|
|
return
|
|
}
|
|
slot, ok := s.byID[seqID]
|
|
if !ok {
|
|
panic(fmt.Sprintf("sample.Sampler.Commit: seqID %d not registered", seqID))
|
|
}
|
|
if !slot.opts.usesHistory() {
|
|
return
|
|
}
|
|
if s.history == nil {
|
|
panic(fmt.Sprintf("sample.Sampler.Commit: seqID %d has no history", seqID))
|
|
}
|
|
|
|
row := slices.Index(s.slots, slot)
|
|
width := s.historyWidth()
|
|
take := min(len(tokens), slot.opts.RepeatLastN)
|
|
startLen := slot.historyLen + len(tokens) - take
|
|
writeTokens := tokens[len(tokens)-take:]
|
|
flatOffsets := make([]int32, take)
|
|
for i := range take {
|
|
ringPos := (startLen + i) % slot.opts.RepeatLastN
|
|
flatOffsets[i] = int32(row*width + ringPos)
|
|
}
|
|
|
|
flatIdx := mlx.NewArrayInt32(flatOffsets, []int32{int32(take), 1})
|
|
values := mlx.NewArrayInt32(writeTokens, []int32{int32(take), 1})
|
|
flatHist := s.history.Reshape(s.history.Dim(0)*width, 1)
|
|
s.history.Set(flatHist.PutAlongAxis(flatIdx, values, 0).Reshape(s.history.Dim(0), width))
|
|
slot.historyLen += len(tokens)
|
|
}
|
|
|
|
func (s *Sampler) speculativeDistributionSerial(slot *slotState, logits *mlx.Array, draftTokens *mlx.Array) Distribution {
|
|
rows := logits.Dim(0)
|
|
draftCount := 0
|
|
if draftTokens != nil {
|
|
draftCount = draftTokens.Dim(1)
|
|
}
|
|
row := slices.Index(s.slots, slot)
|
|
baseFill := min(slot.historyLen, slot.opts.RepeatLastN)
|
|
var base *mlx.Array
|
|
if baseFill > 0 {
|
|
base = s.history.Slice(mlx.Slice(row, row+1), mlx.Slice(0, baseFill))
|
|
}
|
|
|
|
dists := make([]Distribution, 0, rows)
|
|
for i := range rows {
|
|
rowLogits := logits.Slice(mlx.Slice(i, i+1), mlx.Slice())
|
|
hist := base
|
|
prefixLen := min(i, draftCount)
|
|
if prefixLen > 0 {
|
|
prefix := draftTokens.Slice(mlx.Slice(), mlx.Slice(0, prefixLen))
|
|
if hist == nil {
|
|
hist = prefix
|
|
} else {
|
|
hist = hist.Concatenate(1, prefix)
|
|
}
|
|
if hist.Dim(1) > slot.opts.RepeatLastN {
|
|
hist = hist.Slice(mlx.Slice(), mlx.Slice(hist.Dim(1)-slot.opts.RepeatLastN, mlx.End))
|
|
}
|
|
}
|
|
dists = append(dists, slot.distribution(&slotCtx{opts: slot.opts, history: hist}, rowLogits))
|
|
}
|
|
return ConcatenateDistributions(dists)
|
|
}
|
|
|
|
func (s *Sampler) speculativeScoresSerial(slot *slotState, logits *mlx.Array, draftTokens *mlx.Array) *mlx.Array {
|
|
return s.speculativeDistributionSerial(slot, logits, draftTokens).LogProbs(logits.Dim(logits.NumDims() - 1))
|
|
}
|
|
|
|
func (s *Sampler) speculativeHistory(slot *slotState, draftTokens *mlx.Array, rows int) *mlx.Array {
|
|
row := slices.Index(s.slots, slot)
|
|
width := slot.opts.RepeatLastN
|
|
base := s.history.Slice(mlx.Slice(row, row+1), mlx.Slice(0, width))
|
|
base = mlx.Tile(base, []int32{int32(rows), 1})
|
|
next := slot.historyLen % width
|
|
draftCount := 0
|
|
if draftTokens != nil {
|
|
draftCount = draftTokens.Dim(1)
|
|
}
|
|
if draftCount == 0 {
|
|
return base
|
|
}
|
|
|
|
sourceIdx := make([]int32, rows*width)
|
|
writeMask := make([]bool, rows*width)
|
|
for i := range rows {
|
|
prefixLen := min(i, draftCount)
|
|
for j := range prefixLen {
|
|
pos := (next + j) % width
|
|
sourceIdx[i*width+pos] = int32(j)
|
|
writeMask[i*width+pos] = true
|
|
}
|
|
}
|
|
|
|
draftRows := mlx.Tile(draftTokens, []int32{int32(rows), 1})
|
|
idx := mlx.NewArrayInt32(sourceIdx, []int32{int32(rows), int32(width)})
|
|
mask := mlx.FromValues(writeMask, rows, width)
|
|
values := draftRows.TakeAlongAxis(idx, 1)
|
|
return mlx.Where(mask, values, base)
|
|
}
|
|
|
|
func (slot *slotState) speculativeScores(ctx *slotCtx, logits *mlx.Array) *mlx.Array {
|
|
if slot.opts.Temperature == 0 {
|
|
return slot.baseScores(ctx, logits)
|
|
}
|
|
return slot.distribution(ctx, logits).LogProbs(logits.Dim(logits.NumDims() - 1))
|
|
}
|
|
|
|
// canBatch reports whether the call can take the uniform batched path.
|
|
// All slots must share Options; when penalties are active the call must
|
|
// additionally cover every registered slot in registration order with a
|
|
// full ring, because the uniform path indexes the pool positionally.
|
|
func (s *Sampler) canBatch(slots []*slotState) (Options, bool) {
|
|
if !s.allSameOpts {
|
|
return Options{}, false
|
|
}
|
|
// slots is non-empty (Sample guards) and every slot is registered,
|
|
// so s.slots[0].opts is the canonical shared value.
|
|
shared := s.slots[0].opts
|
|
// TODO: Before using multi-slot batching with seeded stochastic sampling,
|
|
// make sure each row gets its own per-slot random key instead of sharing
|
|
// slots[0]'s key through one batched categorical op.
|
|
if !shared.usesHistory() {
|
|
return shared, true
|
|
}
|
|
if len(slots) != len(s.slots) {
|
|
return Options{}, false
|
|
}
|
|
for i, slot := range slots {
|
|
if s.slots[i] != slot || slot.historyLen < shared.RepeatLastN {
|
|
return Options{}, false
|
|
}
|
|
}
|
|
return shared, true
|
|
}
|
|
|
|
// sampleTokensUniform runs one fused sampling pass over the whole batch.
|
|
// Reached only when canBatch is true, which lets the pool be used in place
|
|
// with a single PutAlongAxis write-back and no gather.
|
|
func (s *Sampler) sampleTokensUniform(slots []*slotState, opts Options, logits *mlx.Array) *mlx.Array {
|
|
B := len(slots)
|
|
|
|
var hist *mlx.Array
|
|
if opts.usesHistory() {
|
|
hist = s.history
|
|
if s.historyWidth() > opts.RepeatLastN {
|
|
hist = hist.Slice(mlx.Slice(), mlx.Slice(0, opts.RepeatLastN))
|
|
}
|
|
}
|
|
|
|
ctx := &slotCtx{opts: opts, history: hist}
|
|
token := slots[0].sample(ctx, logits)
|
|
if opts.UseSeed && opts.Temperature != 0 {
|
|
// TODO: This only keeps counters aligned; it does not give each slot
|
|
// an independent key for the batched draw.
|
|
for _, slot := range slots[1:] {
|
|
slot.randomCounter++
|
|
}
|
|
}
|
|
|
|
if !opts.usesHistory() {
|
|
return token
|
|
}
|
|
|
|
writeIdxData := make([]int32, B)
|
|
for i, slot := range slots {
|
|
writeIdxData[i] = int32(slot.historyLen % opts.RepeatLastN)
|
|
slot.historyLen++
|
|
}
|
|
writeIdx := mlx.NewArrayInt32(writeIdxData, []int32{int32(B), 1})
|
|
|
|
s.history.Set(s.history.PutAlongAxis(writeIdx, token.ExpandDims(-1), 1))
|
|
return token
|
|
}
|
|
|
|
// sampleTokensSerial samples each slot against its own row of logits.
|
|
func (s *Sampler) sampleTokensSerial(slots []*slotState, logits *mlx.Array) *mlx.Array {
|
|
perSlotTokens := make([]*mlx.Array, len(slots))
|
|
|
|
rowOf := make(map[*slotState]int, len(s.slots))
|
|
for i, slot := range s.slots {
|
|
rowOf[slot] = i
|
|
}
|
|
|
|
for i, slot := range slots {
|
|
row := logits.Slice(mlx.Slice(i, i+1), mlx.Slice())
|
|
|
|
var hist *mlx.Array
|
|
if slot.opts.usesHistory() && slot.historyLen > 0 && s.history != nil {
|
|
poolRow := rowOf[slot]
|
|
fill := min(slot.historyLen, slot.opts.RepeatLastN)
|
|
hist = s.history.Slice(
|
|
mlx.Slice(poolRow, poolRow+1),
|
|
mlx.Slice(0, fill),
|
|
)
|
|
}
|
|
|
|
ctx := &slotCtx{opts: slot.opts, history: hist}
|
|
perSlotTokens[i] = slot.sample(ctx, row)
|
|
}
|
|
|
|
token := mlx.Concatenate(perSlotTokens, 0)
|
|
|
|
if s.history != nil {
|
|
// For each writing slot collect its flat (row-major) pool offset
|
|
// and the call-order position of its token. One PutAlongAxis on a
|
|
// flat view of the pool scatters all writes in a single op.
|
|
flatOffsets := make([]int32, 0, len(slots))
|
|
tokenPos := make([]int32, 0, len(slots))
|
|
for i, slot := range slots {
|
|
if !slot.opts.usesHistory() {
|
|
continue
|
|
}
|
|
ringPos := slot.historyLen % slot.opts.RepeatLastN
|
|
flatOffsets = append(flatOffsets, int32(rowOf[slot]*s.historyWidth()+ringPos))
|
|
tokenPos = append(tokenPos, int32(i))
|
|
slot.historyLen++
|
|
}
|
|
|
|
if len(flatOffsets) > 0 {
|
|
m := len(flatOffsets)
|
|
flatIdx := mlx.NewArrayInt32(flatOffsets, []int32{int32(m), 1})
|
|
writingTokens := token
|
|
if m != len(slots) {
|
|
tokenPosIdx := mlx.NewArrayInt32(tokenPos, []int32{int32(m)})
|
|
writingTokens = token.TakeAxis(tokenPosIdx, 0)
|
|
}
|
|
flatHist := s.history.Reshape(s.history.Dim(0)*s.historyWidth(), 1)
|
|
s.history.Set(flatHist.PutAlongAxis(flatIdx, writingTokens.ExpandDims(-1), 0).Reshape(s.history.Dim(0), s.historyWidth()))
|
|
}
|
|
}
|
|
return token
|
|
}
|
|
|
|
func (slot *slotState) sample(ctx *slotCtx, logits *mlx.Array) *mlx.Array {
|
|
if slot.opts.Temperature == 0 {
|
|
return slot.baseScores(ctx, logits).Argmax(-1, false).AsType(mlx.DTypeInt32)
|
|
}
|
|
return slot.distribution(ctx, logits).SampleWithKey(slot.nextRandomKey())
|
|
}
|
|
|
|
func (slot *slotState) nextRandomKey() *mlx.Array {
|
|
if !slot.opts.UseSeed {
|
|
return nil
|
|
}
|
|
seed := mixSeed(uint64(slot.opts.Seed), slot.randomCounter)
|
|
slot.randomCounter++
|
|
return mlx.RandomKey(seed)
|
|
}
|
|
|
|
const (
	// SplitMix64 constants used to decorrelate nearby (seed, counter) pairs.
	splitMix64Weyl       = 0x9e3779b97f4a7c15
	splitMix64Mul1       = 0xbf58476d1ce4e5b9
	splitMix64Mul2       = 0x94d049bb133111eb
	splitMix64Shift1     = 30
	splitMix64Shift2     = 27
	splitMix64FinalShift = 31
)

// mixSeed hashes a (seed, counter) pair into a 64-bit value. The seed is
// advanced by counter+1 Weyl increments and the result is passed through the
// SplitMix64 xor-shift/multiply finalizer. All arithmetic wraps modulo 2^64.
func mixSeed(seed, counter uint64) uint64 {
	state := seed + (counter+1)*splitMix64Weyl
	state ^= state >> splitMix64Shift1
	state *= splitMix64Mul1
	state ^= state >> splitMix64Shift2
	state *= splitMix64Mul2
	state ^= state >> splitMix64FinalShift
	return state
}
|
|
|
|
func (slot *slotState) baseScores(ctx *slotCtx, logits *mlx.Array) *mlx.Array {
|
|
scores := logits
|
|
if slot.opts.usesHistory() {
|
|
scores = penalty(ctx, scores)
|
|
}
|
|
return scores
|
|
}
|
|
|
|
func (slot *slotState) distribution(ctx *slotCtx, logits *mlx.Array) Distribution {
|
|
scores := slot.baseScores(ctx, logits)
|
|
if slot.opts.Temperature <= 0 {
|
|
ids := scores.Argmax(-1, false).AsType(mlx.DTypeInt32).ExpandDims(-1)
|
|
probs := mlx.AddScalar(ids.AsType(mlx.DTypeFloat32).Multiply(mlx.FromValue(float32(0))), 1)
|
|
return Distribution{IDs: ids, Probs: probs}
|
|
}
|
|
|
|
vocab := scores.Dim(scores.NumDims() - 1)
|
|
if slot.opts.TopK > 0 && slot.opts.TopK < vocab {
|
|
return sparseDistribution(ctx.opts, scores)
|
|
}
|
|
return denseDistribution(ctx.opts, scores)
|
|
}
|
|
|
|
func sparseDistribution(opts Options, scores *mlx.Array) Distribution {
|
|
ids := scores.Negative().ArgpartitionAxis(opts.TopK-1, -1).Slice(mlx.Slice(), mlx.Slice(0, opts.TopK)).AsType(mlx.DTypeInt32)
|
|
topScores := scores.TakeAlongAxis(ids, -1).AsType(mlx.DTypeFloat32)
|
|
probs := mlx.SoftmaxAxis(mlx.DivScalar(topScores, opts.Temperature), -1, true)
|
|
probs = applyTopPProbs(probs, opts.TopP)
|
|
probs = applyMinPProbs(probs, opts.MinP)
|
|
return Distribution{IDs: ids, Probs: normalizeProbs(probs)}
|
|
}
|
|
|
|
func denseDistribution(opts Options, scores *mlx.Array) Distribution {
|
|
probs := mlx.SoftmaxAxis(mlx.DivScalar(scores.AsType(mlx.DTypeFloat32), opts.Temperature), -1, true)
|
|
probs = applyTopPProbs(probs, opts.TopP)
|
|
probs = applyMinPProbs(probs, opts.MinP)
|
|
return Distribution{Probs: normalizeProbs(probs)}
|
|
}
|
|
|
|
func applyTopPProbs(probs *mlx.Array, topP float32) *mlx.Array {
|
|
if topP <= 0 || topP >= 1 {
|
|
return probs
|
|
}
|
|
order := probs.Negative().ArgsortAxis(-1)
|
|
sorted := probs.TakeAlongAxis(order, -1)
|
|
prevCumProbs := sorted.Cumsum(-1, false, true).Subtract(sorted)
|
|
keep := prevCumProbs.Less(mlx.FromValue(topP))
|
|
filtered := mlx.Where(keep, sorted, mlx.FromValue(float32(0)))
|
|
return mlx.Zeros(probs.DType(), probs.Dims()...).PutAlongAxis(order, filtered, -1)
|
|
}
|
|
|
|
func applyMinPProbs(probs *mlx.Array, minP float32) *mlx.Array {
|
|
if minP <= 0 || minP > 1 {
|
|
return probs
|
|
}
|
|
threshold := mlx.MulScalar(probs.MaxAxis(-1, true), minP)
|
|
return mlx.Where(probs.Less(threshold), mlx.FromValue(float32(0)), probs)
|
|
}
|
|
|
|
func normalizeProbs(probs *mlx.Array) *mlx.Array {
|
|
sum := mlx.Maximum(probs.SumAxis(-1, true), mlx.FromValue(float32(1e-20)))
|
|
return probs.Divide(sum)
|
|
}
|
|
|
|
func logitsFromProbs(probs *mlx.Array) *mlx.Array {
|
|
positive := mlx.Maximum(probs, mlx.FromValue(float32(1e-20)))
|
|
logits := mlx.Log(positive)
|
|
return mlx.Where(probs.LessEqual(mlx.FromValue(float32(0))), mlx.FromValue(float32(math.Inf(-1))), logits)
|
|
}
|
|
|
|
func penalty(ctx *slotCtx, scores *mlx.Array) *mlx.Array {
|
|
tokenIndices := ctx.history
|
|
if tokenIndices == nil {
|
|
return scores
|
|
}
|
|
|
|
if ctx.opts.RepeatPenalty != 1 || ctx.opts.PresencePenalty != 0 {
|
|
adjusted := scores.TakeAlongAxis(tokenIndices, -1)
|
|
if ctx.opts.RepeatPenalty != 1 {
|
|
factor := mlx.Where(
|
|
adjusted.Less(mlx.FromValue(float32(0))),
|
|
mlx.FromValue(ctx.opts.RepeatPenalty),
|
|
mlx.FromValue(1/ctx.opts.RepeatPenalty),
|
|
)
|
|
adjusted = adjusted.Multiply(factor)
|
|
}
|
|
if ctx.opts.PresencePenalty != 0 {
|
|
adjusted = mlx.AddScalar(adjusted, -ctx.opts.PresencePenalty)
|
|
}
|
|
scores = scores.PutAlongAxis(tokenIndices, adjusted, -1)
|
|
}
|
|
|
|
if ctx.opts.FrequencyPenalty != 0 {
|
|
scores = scores.ScatterAddAxis(tokenIndices, mlx.FromValue(-ctx.opts.FrequencyPenalty), -1)
|
|
}
|
|
|
|
return scores
|
|
}
|