diff --git a/x/mlxrunner/mlx/ops.go b/x/mlxrunner/mlx/ops.go index 8bc0e47f9..faa44898c 100644 --- a/x/mlxrunner/mlx/ops.go +++ b/x/mlxrunner/mlx/ops.go @@ -72,6 +72,10 @@ func (t *Array) AsStrided(shape []int, strides []int, offset int) *Array { } func (t *Array) Concatenate(axis int, others ...*Array) *Array { + if len(others) == 0 { + return t.Clone() + } + vector := C.mlx_vector_array_new() defer C.mlx_vector_array_free(vector) @@ -127,9 +131,9 @@ func (t *Array) GatherMM(other, lhs, rhs *Array, sorted bool) *Array { return out } -func (t *Array) Logsumexp(keepDims bool) *Array { - out := New("LOGSUMEXP") - C.mlx_logsumexp(&out.ctx, t.ctx, C.bool(keepDims), DefaultStream().ctx) +func (t *Array) LogsumexpAxis(axis int, keepDims bool) *Array { + out := New("LOGSUMEXP_AXIS") + C.mlx_logsumexp_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx) return out } diff --git a/x/mlxrunner/mlx/ops_extra.go b/x/mlxrunner/mlx/ops_extra.go index 409f71263..ce7a5b246 100644 --- a/x/mlxrunner/mlx/ops_extra.go +++ b/x/mlxrunner/mlx/ops_extra.go @@ -376,6 +376,9 @@ func Concatenate(arrays []*Array, axis int) *Array { if len(arrays) == 0 { return nil } + if len(arrays) == 1 { + return arrays[0].Clone() + } return arrays[0].Concatenate(axis, arrays[1:]...) } diff --git a/x/mlxrunner/pipeline.go b/x/mlxrunner/pipeline.go index 6297c220b..34d3e3d13 100644 --- a/x/mlxrunner/pipeline.go +++ b/x/mlxrunner/pipeline.go @@ -49,14 +49,15 @@ func (r *Runner) Prepare(request *Request) error { return nil } +// The runner serializes requests today so we just use a fixed slot ID. +const pipelineSlot = 0 + func (r *Runner) TextGenerationPipeline(ctx context.Context, request Request) error { mlx.ResetPeakMemory() var sample, nextSample sampler.Result defer func() { - if request.Sampler != nil { - request.Sampler.Free() - } + r.Sampler.Remove(pipelineSlot) mlx.Unpin(sample.Arrays()...) mlx.Unpin(nextSample.Arrays()...) 
mlx.Sweep() @@ -70,7 +71,6 @@ func (r *Runner) TextGenerationPipeline(ctx context.Context, request Request) er }() inputs := request.Tokens - request.Sampler.ResetHistory(inputs) session := r.cache.begin(r.Model, inputs) defer session.close() @@ -122,7 +122,7 @@ func (r *Runner) TextGenerationPipeline(ctx context.Context, request Request) er } } - r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches) + r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], 1, n), caches) mlx.Sweep() materializeCaches() processed += n @@ -139,21 +139,28 @@ func (r *Runner) TextGenerationPipeline(ctx context.Context, request Request) er mlx.ClearCache() } + // Register the sampler after prefill completes. + r.Sampler.Add(pipelineSlot, request.SamplerOpts, inputs) + step := func(token *mlx.Array) sampler.Result { - fwd := r.Model.Forward(token.ExpandDims(0), caches) + fwd := r.Model.Forward(token, caches) logits := r.Model.Unembed(fwd) logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1) - sample := request.Sampler.Sample(logits) + sample := r.Sampler.Sample([]int{pipelineSlot}, logits) mlx.Pin(sample.Arrays()...) mlx.Sweep() mlx.AsyncEval(sample.Arrays()...) 
return sample } - sample = step(mlx.FromValues(tokens[processed:], total-processed)) + sample = step(mlx.FromValues(tokens[processed:], 1, total-processed)) - dec := decoder{tokenizer: r.Tokenizer} + dec := decoder{ + tokenizer: r.Tokenizer, + wantLogprobs: request.SamplerOpts.Logprobs, + wantTopLogprobs: request.SamplerOpts.TopLogprobs, + } final := CompletionResponse{Done: true, PromptEvalCount: len(inputs), EvalCount: request.Options.NumPredict, DoneReason: 1} for i := range request.Options.NumPredict { @@ -161,8 +168,7 @@ func (r *Runner) TextGenerationPipeline(ctx context.Context, request Request) er return err } - request.Sampler.AppendToken(sample.Token) - nextSample = step(sample.Token) + nextSample = step(sample.Token.ExpandDims(-1)) if i == 0 { mlx.Eval(sample.Arrays()...) @@ -209,15 +215,17 @@ func (r *Runner) TextGenerationPipeline(ctx context.Context, request Request) er // with those bytes so Content and Logprobs stay aligned when a chunk does // flush. type decoder struct { - tokenizer *tokenizer.Tokenizer - buf bytes.Buffer - logprobs []llm.Logprob + tokenizer *tokenizer.Tokenizer + buf bytes.Buffer + logprobs []llm.Logprob + wantLogprobs bool + wantTopLogprobs int } func (d *decoder) decode(res sampler.Result) (CompletionResponse, bool) { output := int32(res.Token.Int()) d.buf.WriteString(d.tokenizer.Decode([]int32{output})) - d.logprobs = append(d.logprobs, buildLogprob(res, d.tokenizer.Decode)...) + d.logprobs = append(d.logprobs, buildLogprob(res, d.wantLogprobs, d.wantTopLogprobs, d.tokenizer.Decode)...) content := flushValidUTF8Prefix(&d.buf) if content == "" { @@ -228,8 +236,13 @@ func (d *decoder) decode(res sampler.Result) (CompletionResponse, bool) { return resp, true } -func buildLogprob(sample sampler.Result, decode func([]int32) string) []llm.Logprob { - if sample.Logprob == nil { +// buildLogprob converts the sampler's logprob tensors into the wire-format +// llm.Logprob entries the caller wants. 
The sampler populates its logprob +// tensors whenever any registered slot requested them, so the caller must +// gate emission on its own request config (wantLogprobs / wantTopLogprobs) +// rather than on whether the tensors happen to be non-nil. +func buildLogprob(sample sampler.Result, wantLogprobs bool, wantTopLogprobs int, decode func([]int32) string) []llm.Logprob { + if !wantLogprobs || sample.Logprob == nil { return nil } tok := func(id int32) string { return decode([]int32{id}) } @@ -241,7 +254,7 @@ func buildLogprob(sample sampler.Result, decode func([]int32) string) []llm.Logp }, } - if sample.TopTokens != nil { + if wantTopLogprobs > 0 && sample.TopTokens != nil { ids := sample.TopTokens.Ints() vals := sample.TopLogprobs.Floats() pairs := make([]llm.TokenLogprob, len(ids)) @@ -251,9 +264,14 @@ func buildLogprob(sample sampler.Result, decode func([]int32) string) []llm.Logp Logprob: float64(vals[i]), } } + // The sampler emits the top maxK across registered slots via + // Argpartition, which leaves entries unsorted. 
sort.Slice(pairs, func(i, j int) bool { return pairs[i].Logprob > pairs[j].Logprob }) + if wantTopLogprobs < len(pairs) { + pairs = pairs[:wantTopLogprobs] + } out.TopLogprobs = pairs } return []llm.Logprob{out} diff --git a/x/mlxrunner/runner.go b/x/mlxrunner/runner.go index 2ab5e323a..6e23471b1 100644 --- a/x/mlxrunner/runner.go +++ b/x/mlxrunner/runner.go @@ -27,15 +27,16 @@ type Request struct { Responses chan CompletionResponse Pipeline func(context.Context, Request) error - Ctx context.Context //nolint:containedctx - Tokens []int32 - Sampler *sample.Sampler + Ctx context.Context //nolint:containedctx + Tokens []int32 + SamplerOpts sample.Options } type Runner struct { Model base.Model Tokenizer *tokenizer.Tokenizer Requests chan Request + Sampler *sample.Sampler cache kvCache contextLength int } @@ -67,6 +68,7 @@ func (r *Runner) Load(modelName string) error { r.Model = m r.Tokenizer = m.Tokenizer() r.contextLength = m.MaxContextLength() + r.Sampler = sample.New(r.contextLength) mlx.EnableCompile() return nil diff --git a/x/mlxrunner/sample/logprob_test.go b/x/mlxrunner/sample/logprob_test.go index fa46d6389..9e37dd817 100644 --- a/x/mlxrunner/sample/logprob_test.go +++ b/x/mlxrunner/sample/logprob_test.go @@ -24,14 +24,15 @@ type logprobEntry struct { func runSampleLogprobs(t *testing.T, logits []float32, topK int) (int, float64, []logprobEntry) { t.Helper() - s := New(Options{Logprobs: true, TopLogprobs: topK}) + s := New(128) defer func() { s.Free() mlx.Sweep() }() + s.Add(0, Options{Logprobs: true, TopLogprobs: topK}, nil) tensor := mlx.FromValues(logits, 1, len(logits)) - res := s.Sample(tensor) + res := s.Sample([]int{0}, tensor) mlx.Pin(res.Arrays()...) defer mlx.Unpin(res.Arrays()...) 
@@ -55,6 +56,8 @@ func runSampleLogprobs(t *testing.T, logits []float32, topK int) (int, float64, } func TestSampleLogprobsBasic(t *testing.T) { + skipIfNoMLX(t) + tests := []struct { name string logits []float32 @@ -92,6 +95,8 @@ func TestSampleLogprobsBasic(t *testing.T) { } func TestSampleLogprobsNumericalStability(t *testing.T) { + skipIfNoMLX(t) + logits := []float32{1000.0, 999.0, 998.0} _, selLP, top := runSampleLogprobs(t, logits, 3) @@ -111,6 +116,8 @@ func TestSampleLogprobsNumericalStability(t *testing.T) { } func TestSampleLogprobsProbabilityCorrectness(t *testing.T) { + skipIfNoMLX(t) + tests := []struct { name string logits []float32 @@ -167,6 +174,8 @@ func TestSampleLogprobsProbabilityCorrectness(t *testing.T) { } func TestSampleLogprobsSoftmaxCorrectness(t *testing.T) { + skipIfNoMLX(t) + tests := []struct { name string logits []float32 @@ -202,6 +211,8 @@ func TestSampleLogprobsSoftmaxCorrectness(t *testing.T) { } func TestSampleLogprobsSelectedTokenCorrectness(t *testing.T) { + skipIfNoMLX(t) + logits := []float32{3.0, 1.0, 2.0, 0.5} maxIdx := 0 @@ -225,7 +236,47 @@ func TestSampleLogprobsSelectedTokenCorrectness(t *testing.T) { } } +// TestBatchedLogprobsPerRow verifies that per-row logprobs in a batched +// sample call match the per-slot reference. The numerically-stable softmax +// must reduce along the last axis only, not over the whole batch. +func TestBatchedLogprobsPerRow(t *testing.T) { + skipIfNoMLX(t) + + rowA := []float32{2, 1, 0} + rowB := []float32{0, 5, 0} + + _, wantA, _ := runSampleLogprobs(t, rowA, 0) + _, wantB, _ := runSampleLogprobs(t, rowB, 0) + + s := New(128) + t.Cleanup(func() { + s.Free() + mlx.Sweep() + }) + s.Add(1, Options{Logprobs: true}, nil) + s.Add(2, Options{Logprobs: true}, nil) + + logits := mlx.FromValues(append(append([]float32{}, rowA...), rowB...), 2, 3) + res := s.Sample([]int{1, 2}, logits) + mlx.Pin(res.Arrays()...) + t.Cleanup(func() { mlx.Unpin(res.Arrays()...) }) + mlx.Eval(res.Arrays()...) 
+ + got := res.Logprob.Floats() + if len(got) != 2 { + t.Fatalf("Logprob length = %d, want 2", len(got)) + } + if math.Abs(float64(got[0])-wantA) > 1e-5 { + t.Errorf("row 0 logprob = %f, want %f (per-slot reference)", got[0], wantA) + } + if math.Abs(float64(got[1])-wantB) > 1e-5 { + t.Errorf("row 1 logprob = %f, want %f (per-slot reference)", got[1], wantB) + } +} + func TestSampleLogprobsTopKOrdering(t *testing.T) { + skipIfNoMLX(t) + // Logits chosen so argmax order differs from index order. logits := []float32{2.0, 5.0, 1.0, 4.0, 3.0} wantOrder := []int{1, 3, 4, 0, 2} diff --git a/x/mlxrunner/sample/sample.go b/x/mlxrunner/sample/sample.go index 8f3987cdc..b61b68fcb 100644 --- a/x/mlxrunner/sample/sample.go +++ b/x/mlxrunner/sample/sample.go @@ -1,13 +1,13 @@ package sample import ( + "fmt" "math" + "slices" "github.com/ollama/ollama/x/mlxrunner/mlx" ) -type Transform func(*Sampler, *mlx.Array) *mlx.Array - type Options struct { Temperature float32 TopP float32 @@ -24,24 +24,15 @@ type Options struct { TopLogprobs int } -type Sampler struct { - Options - - // history is a ring buffer of the last RepeatLastN sampled tokens. - // historyLen counts total appends; the live slots are [0, historyLen) - // while filling and all of history once historyLen >= RepeatLastN. - history *mlx.Array - historyLen int - transforms []Transform -} - -// Result bundles the outputs of one decode step. The logprob tensors are -// populated only when the sampler is configured to report them. +// Result bundles the outputs of one decode step. Logprob/TopTokens/ +// TopLogprobs are populated whenever any registered slot has Logprobs +// (respectively TopLogprobs>0). Consumers need to filter by their +// per-slot Options. 
type Result struct { - Token *mlx.Array // sampled token id, shape [B] - Logprob *mlx.Array // sampled-token logprob, shape [B,1]; nil unless Logprobs - TopTokens *mlx.Array // top-K token ids, shape [B,K]; nil unless TopLogprobs > 0 - TopLogprobs *mlx.Array // top-K logprobs, shape [B,K]; nil unless TopLogprobs > 0 + Token *mlx.Array // sampled token ids, shape [B] + Logprob *mlx.Array // sampled-token logprobs, shape [B,1]; nil unless any registered slot has Logprobs + TopTokens *mlx.Array // top-K token ids, shape [B,maxK]; nil unless any registered slot has TopLogprobs>0 + TopLogprobs *mlx.Array // top-K logprobs, shape [B,maxK]; same } // Arrays returns the tensor fields as a slice so callers can drive the mlx @@ -51,119 +42,300 @@ func (r Result) Arrays() []*mlx.Array { return []*mlx.Array{r.Token, r.Logprob, r.TopTokens, r.TopLogprobs} } -func New(opts Options) *Sampler { - if opts.RepeatPenalty <= 0 { - opts.RepeatPenalty = 1 +// Sampler is a batched, slot-based sampler. Sequences are registered with +// Add and released with Remove. Each Sample call takes a subset of +// registered slots (in any order) with their [B,V] logits, samples one +// token per row, and appends it to that slot's ring-buffer history. Slots +// not named in a given call are untouched. +type Sampler struct { + slots []*slotState + byID map[int]*slotState + + // history is the pooled ring-buffer storage, [B, W] int32. Row i + // belongs to slots[i]; W is max(RepeatLastN) across penalty slots. + // Allocated on the first penalty slot, rebuilt only in Add/Remove. + history *mlx.Array + + // allSameOpts: every registered slot shares Options. When true the + // canonical shared value is s.slots[0].opts. + allSameOpts bool + + // anyLogprobs / maxTopLogprobs: compute-for-all output config. + // Sample populates Logprob (and Top* when maxTopLogprobs>0) whenever + // any registered slot requests them, even if that slot isn't in the + // current call. 
+ anyLogprobs bool + maxTopLogprobs int + + // numCtx is the runner's context window; normalize uses it to + // resolve the repeat_last_n == -1 sentinel. + numCtx int +} + +type slotState struct { + opts Options + transforms []transform + historyLen int +} + +type slotCtx struct { + opts Options + history *mlx.Array // 2D [B, W] when penalties are configured; nil otherwise +} + +type transform func(*slotCtx, *mlx.Array) *mlx.Array + +// New constructs an empty sampler with no registered slots. numCtx is +// the runner's context window and must be positive. +func New(numCtx int) *Sampler { + return &Sampler{ + byID: make(map[int]*slotState), + allSameOpts: true, + numCtx: numCtx, + } +} + +// historyWidth returns the column count of the pooled history tensor, +// or 0 when no penalty slot has forced it to be allocated. +func (s *Sampler) historyWidth() int { + if s.history == nil { + return 0 + } + return s.history.Dim(1) +} + +func (o Options) usesHistory() bool { + // RepeatLastN == 0 disables the penalty ring per the repeat_last_n API + // contract (0 = disabled), overriding any penalty coefficients. + if o.RepeatLastN == 0 { + return false + } + return o.RepeatPenalty != 1 || o.PresencePenalty != 0 || o.FrequencyPenalty != 0 +} + +func (o Options) normalize(numCtx int) Options { + if o.RepeatPenalty <= 0 { + o.RepeatPenalty = 1 + } + // Resolve the repeat_last_n == -1 sentinel ("-1 = num_ctx") against + // the caller's context window. + if o.RepeatLastN < 0 { + o.RepeatLastN = numCtx + } + if !o.usesHistory() { + // Zero the ring capacity so slots that differ only in a spurious + // RepeatLastN still batch together and don't inflate pool width. 
+ o.RepeatLastN = 0 + } + return o +} + +func (o Options) buildTransforms() []transform { + var ts []transform + if o.usesHistory() { + ts = append(ts, penalty) } - s := &Sampler{Options: opts} - - var transforms []Transform - if s.usesHistory() { - // Ring buffer needs a bounded capacity; fall back to a - // default when the caller hasn't configured one. - if s.RepeatLastN <= 0 { - s.RepeatLastN = 64 - } - transforms = append(transforms, penalty) - } - - hasTopP := opts.TopP > 0 && opts.TopP < 1 - hasTopK := opts.TopK > 0 + hasTopP := o.TopP > 0 && o.TopP < 1 + hasTopK := o.TopK > 0 switch { case hasTopP: // topKTopP always does a full descending sort for the top-P // cumulative mask and opportunistically masks top-K during the // same pass when it is also configured. - transforms = append(transforms, topKTopP) + ts = append(ts, topKTopP) case hasTopK: // Argpartition (partial sort) is cheaper than a full sort. - transforms = append(transforms, topK) + ts = append(ts, topK) } - if opts.MinP != 0 { - transforms = append(transforms, minP) + if o.MinP != 0 { + ts = append(ts, minP) } - if opts.Temperature == 0 { - transforms = append(transforms, greedy) + if o.Temperature == 0 { + ts = append(ts, greedy) } else { - transforms = append(transforms, temperature) + ts = append(ts, temperature) + } + return ts +} + +// Add registers a sequence under seqID. The last RepeatLastN entries of +// priorTokens seed the ring buffer. +func (s *Sampler) Add(seqID int, opts Options, priorTokens []int32) { + if _, dup := s.byID[seqID]; dup { + panic(fmt.Sprintf("sample.Sampler.Add: seqID %d already registered", seqID)) } - s.transforms = transforms - return s + opts = opts.normalize(s.numCtx) + slot := &slotState{ + opts: opts, + transforms: opts.buildTransforms(), + } + + // Grow the pool to hold this slot's row. The pool is lazy — the first + // penalty slot allocates it — and thereafter every registered slot + // gets a row (rows for non-penalty slots are zero and never read). 
+ // Invariant: s.history is pinned whenever non-nil. + if s.history != nil || opts.usesHistory() { + targetWidth := max(opts.RepeatLastN, s.historyWidth()) + newRow := makeHistoryRow(priorTokens, opts.RepeatLastN, targetWidth) + + var pool *mlx.Array + switch { + case s.history == nil && len(s.slots) == 0: + pool = newRow + case s.history == nil: + // First penalty slot with non-penalty slots already registered; + // seed zero rows so s.slots and pool row indices stay aligned. + zeros := mlx.Zeros(mlx.DTypeInt32, len(s.slots), targetWidth) + pool = zeros.Concatenate(0, newRow) + case targetWidth > s.historyWidth(): + pad := mlx.Zeros(mlx.DTypeInt32, s.history.Dim(0), targetWidth-s.historyWidth()) + pool = s.history.Concatenate(1, pad).Concatenate(0, newRow) + default: + pool = s.history.Concatenate(0, newRow) + } + + mlx.Pin(pool) + mlx.Unpin(s.history) + s.history = pool + + if opts.usesHistory() { + // Cap on seed so the next write's ring position + // (historyLen % RepeatLastN) lands at 0, overwriting the + // oldest entry when the ring was filled from priors. + slot.historyLen = min(len(priorTokens), opts.RepeatLastN) + } + } + + s.slots = append(s.slots, slot) + s.byID[seqID] = slot + s.recomputeInvariants() } -func (s *Sampler) usesHistory() bool { - return s.RepeatPenalty != 1 || s.PresencePenalty != 0 || s.FrequencyPenalty != 0 +// makeHistoryRow builds a [1, width] int32 row with the last repeatLastN +// entries of priorTokens packed into [0, min(len, repeatLastN)), zeros +// elsewhere. 
+func makeHistoryRow(priorTokens []int32, repeatLastN, width int) *mlx.Array { + take := min(len(priorTokens), repeatLastN) + if take <= 0 { + return mlx.Zeros(mlx.DTypeInt32, 1, width) + } + row := make([]int32, width) + copy(row, priorTokens[len(priorTokens)-take:]) + return mlx.NewArrayInt32(row, []int32{1, int32(width)}) } -func (s *Sampler) ResetHistory(history []int32) { - if !s.usesHistory() { +// recomputeInvariants refreshes allSameOpts and anyLogprobs/maxTopLogprobs +// from s.slots. Called at the end of Add and Remove. +func (s *Sampler) recomputeInvariants() { + if len(s.slots) == 0 { + s.allSameOpts = true + s.anyLogprobs = false + s.maxTopLogprobs = 0 return } - if len(history) > s.RepeatLastN { - history = history[len(history)-s.RepeatLastN:] + first := s.slots[0].opts + s.allSameOpts = true + s.anyLogprobs = false + s.maxTopLogprobs = 0 + for _, slot := range s.slots { + if slot.opts != first { + s.allSameOpts = false + } + if slot.opts.Logprobs { + s.anyLogprobs = true + if slot.opts.TopLogprobs > s.maxTopLogprobs { + s.maxTopLogprobs = slot.opts.TopLogprobs + } + } + } +} + +// Remove releases the slot. The pool tensor is rebuilt to drop the row. 
+func (s *Sampler) Remove(seqID int) { + slot, ok := s.byID[seqID] + if !ok { + return + } + delete(s.byID, seqID) + + row := slices.Index(s.slots, slot) + s.slots = slices.Delete(s.slots, row, row+1) + s.recomputeInvariants() + + if s.history == nil { + return } - var ring *mlx.Array - switch gap := s.RepeatLastN - len(history); { - case len(history) == 0: - ring = mlx.Zeros(mlx.DTypeInt32, s.RepeatLastN) - case gap == 0: - ring = mlx.NewArrayInt32(history, []int32{int32(len(history))}) + n := s.history.Dim(0) + var newHistory *mlx.Array + switch { + case n == 1: + newHistory = nil + case row == 0: + newHistory = s.history.Slice(mlx.Slice(1, n), mlx.Slice()) + case row == n-1: + newHistory = s.history.Slice(mlx.Slice(0, row), mlx.Slice()) default: - init := mlx.NewArrayInt32(history, []int32{int32(len(history))}) - ring = init.Concatenate(0, mlx.Zeros(mlx.DTypeInt32, gap)) + before := s.history.Slice(mlx.Slice(0, row), mlx.Slice()) + after := s.history.Slice(mlx.Slice(row+1, n), mlx.Slice()) + newHistory = before.Concatenate(0, after) } - mlx.Pin(ring) - if s.history != nil { - mlx.Unpin(s.history) - } - s.history = ring - s.historyLen = len(history) -} - -// AppendToken records token in the ring buffer. ResetHistory must have been -// called first to allocate the buffer. -func (s *Sampler) AppendToken(token *mlx.Array) { - if !s.usesHistory() || token == nil { - return - } - - writeIdx := s.historyLen % s.RepeatLastN - s.history.Set(s.history.SliceUpdate(token, mlx.Slice(writeIdx, writeIdx+1))) - s.historyLen++ + + mlx.Pin(newHistory) + mlx.Unpin(s.history) + s.history = newHistory } +// Free releases the pooled history tensor and resets the sampler to the +// New-equivalent state so it may be reused. 
func (s *Sampler) Free() { - if s.history != nil { - mlx.Unpin(s.history) + mlx.Unpin(s.history) + *s = Sampler{ + byID: make(map[int]*slotState), + allSameOpts: true, + numCtx: s.numCtx, } - s.history = nil - s.historyLen = 0 } -// Sample runs the configured transform chain on the raw per-token logits -// and returns the sampled token id plus, when configured, the reported -// log-probability tensors for the selected token and the top-K tokens. -func (s *Sampler) Sample(logits *mlx.Array) Result { - scores := logits - for _, transform := range s.transforms { - scores = transform(s, scores) +// Sample draws one token per row of logits ([B,V]); seqIDs[i] names the +// slot whose logits live at row i. Each sampled token is appended to its +// slot's ring. Slots not named in seqIDs are untouched. +func (s *Sampler) Sample(seqIDs []int, logits *mlx.Array) Result { + if len(seqIDs) == 0 { + return Result{} } - res := Result{Token: scores} - if s.Logprobs { - // Compute log_softmax in fp32 and subtract the max before - // logsumexp so the final subtraction stays on small values. - // Otherwise it cancels two large numbers and loses precision. + slots := make([]*slotState, len(seqIDs)) + for i, id := range seqIDs { + slot, ok := s.byID[id] + if !ok { + panic(fmt.Sprintf("sample.Sampler.Sample: seqID %d not registered", id)) + } + slots[i] = slot + } + + var token *mlx.Array + if opts0, ok := s.canBatch(slots); ok { + token = s.sampleTokensUniform(slots, opts0, logits) + } else { + token = s.sampleTokensSerial(slots, logits) + } + + res := Result{Token: token} + if s.anyLogprobs { + // Log-softmax over original logits so every row holds a truthful + // value (compute-for-all; consumers filter per-slot). Subtract + // max first for numerical stability in the logsumexp. 
lp := logits.AsType(mlx.DTypeFloat32) lp = lp.Subtract(lp.MaxAxis(-1, true)) - lp = lp.Subtract(lp.Logsumexp(true)) - res.Logprob = lp.TakeAlongAxis(res.Token.ExpandDims(-1), -1) - if k := s.TopLogprobs; k > 0 { + lp = lp.Subtract(lp.LogsumexpAxis(-1, true)) + res.Logprob = lp.TakeAlongAxis(token.ExpandDims(-1), -1) + if s.maxTopLogprobs > 0 { + k := s.maxTopLogprobs if vocab := lp.Dim(lp.NumDims() - 1); k > vocab { k = vocab } @@ -177,55 +349,180 @@ func (s *Sampler) Sample(logits *mlx.Array) Result { return res } -func greedy(_ *Sampler, scores *mlx.Array) *mlx.Array { - return scores.Argmax(-1, false) +// canBatch reports whether the call can take the uniform batched path. +// All slots must share Options; when penalties are active the call must +// additionally cover every registered slot in registration order with a +// full ring, because the uniform path indexes the pool positionally. +func (s *Sampler) canBatch(slots []*slotState) (Options, bool) { + if !s.allSameOpts { + return Options{}, false + } + // slots is non-empty (Sample guards) and every slot is registered, + // so s.slots[0].opts is the canonical shared value. + shared := s.slots[0].opts + if !shared.usesHistory() { + return shared, true + } + if len(slots) != len(s.slots) { + return Options{}, false + } + for i, slot := range slots { + if s.slots[i] != slot || slot.historyLen < shared.RepeatLastN { + return Options{}, false + } + } + return shared, true } -func temperature(s *Sampler, scores *mlx.Array) *mlx.Array { - return mlx.DivScalar(scores, s.Temperature).Categorical(-1) +// sampleTokensUniform runs one fused transform pass over the whole batch. +// Reached only when canBatch is true, which lets the pool be used in place +// with a single PutAlongAxis write-back and no gather. 
+func (s *Sampler) sampleTokensUniform(slots []*slotState, opts Options, logits *mlx.Array) *mlx.Array { + B := len(slots) + + var hist *mlx.Array + if opts.usesHistory() { + hist = s.history + if s.historyWidth() > opts.RepeatLastN { + hist = hist.Slice(mlx.Slice(), mlx.Slice(0, opts.RepeatLastN)) + } + } + + ctx := &slotCtx{opts: opts, history: hist} + scores := logits + for _, t := range slots[0].transforms { + scores = t(ctx, scores) + } + token := scores + + if !opts.usesHistory() { + return token + } + + writeIdxData := make([]int32, B) + for i, slot := range slots { + writeIdxData[i] = int32(slot.historyLen % opts.RepeatLastN) + slot.historyLen++ + } + writeIdx := mlx.NewArrayInt32(writeIdxData, []int32{int32(B), 1}) + + s.history.Set(s.history.PutAlongAxis(writeIdx, token.ExpandDims(-1), 1)) + return token +} + +// sampleTokensSerial runs each slot's transforms against its own row of +// logits. +func (s *Sampler) sampleTokensSerial(slots []*slotState, logits *mlx.Array) *mlx.Array { + perSlotTokens := make([]*mlx.Array, len(slots)) + + rowOf := make(map[*slotState]int, len(s.slots)) + for i, slot := range s.slots { + rowOf[slot] = i + } + + for i, slot := range slots { + row := logits.Slice(mlx.Slice(i, i+1), mlx.Slice()) + + var hist *mlx.Array + if slot.opts.usesHistory() && slot.historyLen > 0 && s.history != nil { + poolRow := rowOf[slot] + fill := min(slot.historyLen, slot.opts.RepeatLastN) + hist = s.history.Slice( + mlx.Slice(poolRow, poolRow+1), + mlx.Slice(0, fill), + ) + } + + ctx := &slotCtx{opts: slot.opts, history: hist} + scores := row + for _, t := range slot.transforms { + scores = t(ctx, scores) + } + perSlotTokens[i] = scores + } + + token := mlx.Concatenate(perSlotTokens, 0) + + if s.history != nil { + // For each writing slot collect its flat (row-major) pool offset + // and the call-order position of its token. One PutAlongAxis on a + // flat view of the pool scatters all writes in a single op. 
+ flatOffsets := make([]int32, 0, len(slots)) + tokenPos := make([]int32, 0, len(slots)) + for i, slot := range slots { + if !slot.opts.usesHistory() { + continue + } + ringPos := slot.historyLen % slot.opts.RepeatLastN + flatOffsets = append(flatOffsets, int32(rowOf[slot]*s.historyWidth()+ringPos)) + tokenPos = append(tokenPos, int32(i)) + slot.historyLen++ + } + + if len(flatOffsets) > 0 { + m := len(flatOffsets) + flatIdx := mlx.NewArrayInt32(flatOffsets, []int32{int32(m), 1}) + writingTokens := token + if m != len(slots) { + tokenPosIdx := mlx.NewArrayInt32(tokenPos, []int32{int32(m)}) + writingTokens = token.TakeAxis(tokenPosIdx, 0) + } + flatHist := s.history.Reshape(s.history.Dim(0)*s.historyWidth(), 1) + s.history.Set(flatHist.PutAlongAxis(flatIdx, writingTokens.ExpandDims(-1), 0).Reshape(s.history.Dim(0), s.historyWidth())) + } + } + return token +} + +func greedy(_ *slotCtx, scores *mlx.Array) *mlx.Array { + return scores.Argmax(-1, false).AsType(mlx.DTypeInt32) +} + +func temperature(ctx *slotCtx, scores *mlx.Array) *mlx.Array { + return mlx.DivScalar(scores, ctx.opts.Temperature).Categorical(-1).AsType(mlx.DTypeInt32) } // topKTopP applies top-P in a descending sort pass and, when top-K is also // configured, masks any surviving value below the K-th largest in the same -// pass. Callers dispatch here whenever top-P is enabled — the top-K-only -// case uses a cheaper partial sort via the topK transform. -func topKTopP(s *Sampler, scores *mlx.Array) *mlx.Array { +// pass. Callers dispatch here whenever top-P is enabled — the top-K-only case +// uses a cheaper partial sort via the topK transform. 
+func topKTopP(ctx *slotCtx, scores *mlx.Array) *mlx.Array { vocab := scores.Dim(scores.NumDims() - 1) - applyTopK := s.TopK > 0 && s.TopK < vocab + applyTopK := ctx.opts.TopK > 0 && ctx.opts.TopK < vocab order := scores.Negative().ArgsortAxis(-1) sorted := scores.TakeAlongAxis(order, -1) negInf := mlx.FromValue(float32(math.Inf(-1))) // Top-P: in descending order, keep tokens whose exclusive cumulative - // probability is still below s.TopP. + // probability is still below TopP. probs := mlx.SoftmaxAxis(sorted, -1, true) prevCumProbs := probs.Cumsum(-1, false, true).Subtract(probs) - keep := prevCumProbs.Less(mlx.FromValue(s.TopP)) + keep := prevCumProbs.Less(mlx.FromValue(ctx.opts.TopP)) sorted = mlx.Where(keep, sorted, negInf) out := scores.PutAlongAxis(order, sorted, -1) - // Top-K: sorted is already in descending order, so positions [K, V) - // are the ones to drop. Scatter -inf through their original-layout - // indices (order[K:]). Positional (not value-based) so exactly K - // tokens survive — ties at the K-th logit get broken by the sort - // order rather than promoted through the filter. + // Top-K: sorted is already in descending order, so positions [K, V) are + // the ones to drop. Scatter -inf through their original-layout indices + // (order[K:]). Positional (not value-based) so exactly K tokens survive — + // ties at the K-th logit get broken by the sort order rather than + // promoted through the filter. 
if applyTopK { - dropOrder := order.Slice(mlx.Slice(), mlx.Slice(s.TopK, mlx.End)) + dropOrder := order.Slice(mlx.Slice(), mlx.Slice(ctx.opts.TopK, mlx.End)) out = out.PutAlongAxis(dropOrder, negInf, -1) } return out } -func minP(s *Sampler, scores *mlx.Array) *mlx.Array { - if s.MinP <= 0 || s.MinP > 1 { +func minP(ctx *slotCtx, scores *mlx.Array) *mlx.Array { + if ctx.opts.MinP <= 0 || ctx.opts.MinP > 1 { return scores } maxScore := scores.MaxAxis(-1, true) - threshold := mlx.AddScalar(maxScore, float32(math.Log(float64(s.MinP)))) + threshold := mlx.AddScalar(maxScore, float32(math.Log(float64(ctx.opts.MinP)))) return mlx.Where( scores.Less(threshold), @@ -234,52 +531,43 @@ func minP(s *Sampler, scores *mlx.Array) *mlx.Array { ) } -func topK(s *Sampler, scores *mlx.Array) *mlx.Array { - if s.TopK <= 0 { +func topK(ctx *slotCtx, scores *mlx.Array) *mlx.Array { + if ctx.opts.TopK <= 0 { return scores } - vocab := scores.Dim(scores.NumDims() - 1) - if s.TopK >= vocab { + if ctx.opts.TopK >= vocab { return scores } - mask := scores.Negative().ArgpartitionAxis(s.TopK-1, -1).Slice(mlx.Slice(), mlx.Slice(s.TopK, mlx.End)) + mask := scores.Negative().ArgpartitionAxis(ctx.opts.TopK-1, -1).Slice(mlx.Slice(), mlx.Slice(ctx.opts.TopK, mlx.End)) return scores.PutAlongAxis(mask, mlx.FromValue(float32(math.Inf(-1))), -1) } -func penalty(s *Sampler, scores *mlx.Array) *mlx.Array { - if s.historyLen == 0 { +func penalty(ctx *slotCtx, scores *mlx.Array) *mlx.Array { + tokenIndices := ctx.history + if tokenIndices == nil { return scores } - tokenIndices := s.history - if s.historyLen < s.RepeatLastN { - tokenIndices = tokenIndices.Slice(mlx.Slice(0, s.historyLen)) - } - - if scores.NumDims() > 1 { - tokenIndices = tokenIndices.ExpandDims(0) - } - - if s.RepeatPenalty != 1 || s.PresencePenalty != 0 { + if ctx.opts.RepeatPenalty != 1 || ctx.opts.PresencePenalty != 0 { adjusted := scores.TakeAlongAxis(tokenIndices, -1) - if s.RepeatPenalty != 1 { + if ctx.opts.RepeatPenalty != 1 { 
factor := mlx.Where( adjusted.Less(mlx.FromValue(float32(0))), - mlx.FromValue(s.RepeatPenalty), - mlx.FromValue(1/s.RepeatPenalty), + mlx.FromValue(ctx.opts.RepeatPenalty), + mlx.FromValue(1/ctx.opts.RepeatPenalty), ) adjusted = adjusted.Multiply(factor) } - if s.PresencePenalty != 0 { - adjusted = mlx.AddScalar(adjusted, -s.PresencePenalty) + if ctx.opts.PresencePenalty != 0 { + adjusted = mlx.AddScalar(adjusted, -ctx.opts.PresencePenalty) } scores = scores.PutAlongAxis(tokenIndices, adjusted, -1) } - if s.FrequencyPenalty != 0 { - scores = scores.ScatterAddAxis(tokenIndices, mlx.FromValue(-s.FrequencyPenalty), -1) + if ctx.opts.FrequencyPenalty != 0 { + scores = scores.ScatterAddAxis(tokenIndices, mlx.FromValue(-ctx.opts.FrequencyPenalty), -1) } return scores diff --git a/x/mlxrunner/sample/sample_test.go b/x/mlxrunner/sample/sample_test.go index 456207bcf..3871cc6bf 100644 --- a/x/mlxrunner/sample/sample_test.go +++ b/x/mlxrunner/sample/sample_test.go @@ -9,117 +9,298 @@ import ( "github.com/ollama/ollama/x/mlxrunner/mlx" ) -func TestPresencePenaltyUsesAppendedTokenImmediately(t *testing.T) { - s := New(Options{RepeatLastN: 1, PresencePenalty: 6}) - defer func() { - s.Free() - mlx.Sweep() - }() - - s.ResetHistory([]int32{0}) - s.AppendToken(mlx.NewArrayInt32([]int32{1}, []int32{1})) - - logits := mlx.FromValues([]float32{0, 5, 4}, 3) - got := s.Sample(logits).Token - mlx.Eval(got) - - // logits will be [0, -1, 4] after the penalty - // and then (index) 2 after the greedy sampler - gotInt := got.Int() - if gotInt != 2 { - t.Fatalf("got %d, want 2", gotInt) +func skipIfNoMLX(t *testing.T) { + t.Helper() + if err := mlx.CheckInit(); err != nil { + t.Skipf("MLX not available: %v", err) } } -func TestRepeatPenaltyUsesHistoryWithoutPresencePenalty(t *testing.T) { - s := New(Options{RepeatLastN: 1, RepeatPenalty: 2}) - defer func() { +// slotLogits builds a [1, V] logits tensor for a single-slot Sample call. 
+func slotLogits(values []float32) *mlx.Array { + return mlx.FromValues(values, 1, len(values)) +} + +// batchLogits stacks per-row float32 slices of equal length into a [B, V] +// logits tensor. +func batchLogits(rows ...[]float32) *mlx.Array { + v := len(rows[0]) + flat := make([]float32, 0, len(rows)*v) + for _, r := range rows { + if len(r) != v { + panic("batchLogits: rows must share vocab size") + } + flat = append(flat, r...) + } + return mlx.FromValues(flat, len(rows), v) +} + +// sampleOne runs Sample on a freshly-added single slot and returns the +// sampled token id. Used both for the single-slot options table and as the +// reference oracle for the batched-equivalence test. +func sampleOne(t *testing.T, opts Options, priorTokens []int32, values []float32) int { + t.Helper() + s := New(128) + t.Cleanup(func() { s.Free() mlx.Sweep() - }() + }) + s.Add(0, opts, priorTokens) - s.ResetHistory([]int32{1}) - - logits := mlx.FromValues([]float32{0, 5, 4}, 3) - got := s.Sample(logits).Token + got := s.Sample([]int{0}, slotLogits(values)).Token mlx.Eval(got) + return got.Int() +} - // token 1 is repeated and positive, so 5 / 2 falls below token 2. - gotInt := got.Int() - if gotInt != 2 { - t.Fatalf("got %d, want 2", gotInt) +// logOf returns log(p) as a float32 so tests can build logits that softmax to +// a chosen probability distribution. +func logOf(p float64) float32 { return float32(math.Log(p)) } + +// TestSampleSingleSlotOptions pins the per-slot behavior of each Options +// knob against a concrete expected token. Expected values are worked out by +// hand from the math of each transform, not from a second call into the +// sampler — so a regression in any single transform shows up here. 
+func TestSampleSingleSlotOptions(t *testing.T) { + skipIfNoMLX(t) + + cases := []struct { + name string + opts Options + priors []int32 + logits []float32 + want int + }{ + { + name: "presence penalty", + opts: Options{RepeatLastN: 1, PresencePenalty: 6}, + priors: []int32{1}, + logits: []float32{0, 5, 4}, + want: 2, // token 1: 5 - 6 = -1, argmax shifts to 2 + }, + { + name: "repeat penalty on positive logits", + opts: Options{RepeatLastN: 1, RepeatPenalty: 2}, + priors: []int32{1}, + logits: []float32{0, 5, 4}, + want: 2, // token 1 positive → divided: 5/2 = 2.5, argmax shifts to 2 + }, + { + name: "repeat penalty on negative logits", + opts: Options{RepeatLastN: 1, RepeatPenalty: 4}, + priors: []int32{1}, + logits: []float32{-5, -1, -3}, + want: 2, // token 1 negative → multiplied: -1*4 = -4, argmax shifts to 2 + }, + { + name: "frequency penalty", + opts: Options{RepeatLastN: 4, FrequencyPenalty: 2}, + priors: []int32{1, 1}, + logits: []float32{0, 5, 4}, + want: 2, // token 1 appears twice: 5 - 2*2 = 1, argmax shifts to 2 + }, + { + name: "top-k", + opts: Options{Temperature: 1, TopK: 1}, + logits: []float32{1, 5, 4}, + want: 1, // only argmax survives → deterministic even with temperature + }, + { + name: "top-p", + opts: Options{Temperature: 1, TopP: 0.4}, + logits: []float32{logOf(0.5), logOf(0.3), logOf(0.2)}, + want: 0, // exclusive cumsum below 0.4 keeps only token 0 + }, + { + name: "min-p", + opts: Options{Temperature: 1, MinP: 0.7}, + logits: []float32{logOf(0.5), logOf(0.3), logOf(0.2)}, + want: 0, // threshold 0.5*0.7=0.35 drops all but the top token + }, + { + name: "RepeatLastN=0 disables penalties", + opts: Options{RepeatLastN: 0, RepeatPenalty: 2, PresencePenalty: 10}, + priors: []int32{1}, + logits: []float32{0, 5, 4}, + want: 1, // 0 = disabled per API contract, argmax unchanged + }, + { + name: "RepeatLastN=-1 resolves to num_ctx", + opts: Options{RepeatLastN: -1, PresencePenalty: 6}, + priors: []int32{1}, + logits: []float32{0, 5, 4}, + want: 2, // -1
→ num_ctx (128); penalty applies, argmax shifts + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if got := sampleOne(t, tc.opts, tc.priors, tc.logits); got != tc.want { + t.Errorf("got %d, want %d", got, tc.want) + } + }) } } -func TestFrequencyPenaltyUsesTokenCounts(t *testing.T) { - s := New(Options{RepeatLastN: 4, FrequencyPenalty: 2}) - defer func() { +// TestSampleHistoryWindow verifies that penalty history respects the +// RepeatLastN window: priors longer than RepeatLastN are trimmed on Add, +// and once the ring wraps, tokens that rotate out no longer contribute +// to penalties. +func TestSampleHistoryWindow(t *testing.T) { + skipIfNoMLX(t) + + s := New(128) + t.Cleanup(func() { s.Free() mlx.Sweep() - }() + }) - s.ResetHistory([]int32{1, 1}) + // RepeatLastN=2 with priors {1, 2, 3}: makeHistoryRow keeps only + // {2, 3}. Token 1 was trimmed — its penalty is NOT active. + s.Add(0, Options{RepeatLastN: 2, PresencePenalty: 10}, []int32{1, 2, 3}) - logits := mlx.FromValues([]float32{0, 5, 4}, 3) - got := s.Sample(logits).Token - mlx.Eval(got) + // Step 1: logits favor token 1 (trimmed). If the trim were broken it + // would be penalized and the argmax would move. + step1 := s.Sample([]int{0}, slotLogits([]float32{0, 5, 0, 0, 0})).Token + mlx.Eval(step1) + if got := step1.Int(); got != 1 { + t.Fatalf("step 1 = %d, want 1 (token 1 trimmed from priors)", got) + } + // After step 1 the ring holds {1, 3}; token 2 has rotated out. - // token 1 appears twice, so 5 - (2 * 2) falls below token 2. - gotInt := got.Int() - if gotInt != 2 { - t.Fatalf("got %d, want 2", gotInt) + // Step 2: logits favor token 2 (rotated out). If the ring wrap were + // wrong, token 2 would still be penalized. 
+ step2 := s.Sample([]int{0}, slotLogits([]float32{0, 0, 5, 0, 0})).Token + mlx.Eval(step2) + if got := step2.Int(); got != 2 { + t.Fatalf("step 2 = %d, want 2 (token 2 rotated out of ring)", got) } } -func TestHistoryRingBufferWraps(t *testing.T) { - s := New(Options{RepeatLastN: 2, PresencePenalty: 10}) - defer func() { - s.Free() - mlx.Sweep() - }() +// TestBatchSamplingPreservesPerSlotBehavior is the core equivalence test: +// for every representative dispatch branch (uniform, serial on mixed opts, +// serial on partial ring, subset/out-of-order), a batched Sample call must +// produce the same token per row as running the same slot alone. +func TestBatchSamplingPreservesPerSlotBehavior(t *testing.T) { + skipIfNoMLX(t) - s.ResetHistory(nil) - s.AppendToken(mlx.NewArrayInt32([]int32{1}, []int32{1})) - s.AppendToken(mlx.NewArrayInt32([]int32{2}, []int32{1})) - s.AppendToken(mlx.NewArrayInt32([]int32{3}, []int32{1})) + type slot struct { + id int + opts Options + priors []int32 + } - // After three appends the ring holds {3, 2}; token 1 has been overwritten - // and must not be penalized. 
- logits := mlx.FromValues([]float32{0, 1, 5, 5, 0}, 5) - got := s.Sample(logits).Token - mlx.Eval(got) + cases := []struct { + name string + slots []slot + sample []int + rows [][]float32 + }{ + { + name: "uniform", + slots: []slot{ + {10, Options{RepeatLastN: 2, PresencePenalty: 5}, []int32{1, 2}}, + {20, Options{RepeatLastN: 2, PresencePenalty: 5}, []int32{0, 2}}, + }, + sample: []int{10, 20}, + rows: [][]float32{{0, 5, 4}, {3, 0, 0}}, + }, + { + name: "serial — mixed opts", + slots: []slot{ + {1, Options{RepeatLastN: 1, RepeatPenalty: 2}, []int32{1}}, + {2, Options{Temperature: 1, TopK: 1}, nil}, + }, + sample: []int{1, 2}, + rows: [][]float32{{0, 5, 4, 1}, {2, 1, 5, 3}}, + }, + { + name: "serial — partial ring", + slots: []slot{ + {1, Options{RepeatLastN: 4, PresencePenalty: 5}, []int32{1, 1, 1, 1}}, + {2, Options{RepeatLastN: 4, PresencePenalty: 5}, []int32{2}}, + }, + sample: []int{1, 2}, + rows: [][]float32{{0, 5, 4}, {0, 4, 5}}, + }, + { + name: "subset out-of-order", + slots: []slot{ + {10, Options{RepeatLastN: 2, PresencePenalty: 10}, []int32{1, 1}}, + {20, Options{RepeatLastN: 2, PresencePenalty: 10}, []int32{2, 2}}, + {30, Options{RepeatLastN: 2, PresencePenalty: 10}, []int32{3, 3}}, + }, + sample: []int{30, 10}, + rows: [][]float32{{5, 5, 5, 0, 5, 5}, {5, 0, 5, 5, 0, 5}}, + }, + } - gotInt := got.Int() - if gotInt != 1 { - t.Fatalf("got %d, want 1 (token 1 should survive after wrap)", gotInt) + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + // Per-slot reference for each sampled seq. + want := make([]int, len(tc.sample)) + for i, id := range tc.sample { + var spec slot + for _, s := range tc.slots { + if s.id == id { + spec = s + break + } + } + want[i] = sampleOne(t, spec.opts, spec.priors, tc.rows[i]) + } + + // Batched call. 
+ s := New(128) + t.Cleanup(func() { + s.Free() + mlx.Sweep() + }) + for _, spec := range tc.slots { + s.Add(spec.id, spec.opts, spec.priors) + } + res := s.Sample(tc.sample, batchLogits(tc.rows...)) + mlx.Eval(res.Token) + got := res.Token.Ints() + + for i, id := range tc.sample { + if got[i] != want[i] { + t.Errorf("seq %d: batched = %d, per-slot = %d", id, got[i], want[i]) + } + } + }) } } -func TestMinPMasksTokensBelowThreshold(t *testing.T) { - s := New(Options{MinP: 0.5}) - defer func() { +// TestRemoveDoesNotLeakHistory: after Remove, a newly-added slot at the +// recycled row must start from its own priors only — no carryover from +// the removed slot's history. +func TestRemoveDoesNotLeakHistory(t *testing.T) { + skipIfNoMLX(t) + + opts := Options{RepeatLastN: 1, PresencePenalty: 10} + s := New(128) + t.Cleanup(func() { s.Free() mlx.Sweep() - }() + }) + s.Add(1, opts, []int32{1}) + s.Add(2, opts, []int32{2}) + s.Remove(1) + s.Add(3, opts, []int32{0}) - logits := mlx.FromValues([]float32{ - float32(math.Log(0.5)), - float32(math.Log(0.3)), - float32(math.Log(0.2)), - }, 3) - got := minP(s, logits) - mlx.Eval(got) - - gotFloats := got.Floats() - if len(gotFloats) != 3 { - t.Fatalf("got %d scores, want 3", len(gotFloats)) + // Slot 2 retains history {2}; slot 3 retains history {0}. Tokens 0 + // and 1 tie for the highest logit, so with PresencePenalty=10 the + // argmax lands on the first unpenalized of the two.
+ res := s.Sample([]int{2, 3}, batchLogits( + []float32{3, 3, 0}, + []float32{3, 3, 0}, + )) + mlx.Eval(res.Token) + tokens := res.Token.Ints() + if tokens[0] != 0 { + t.Errorf("slot 2 = %d, want 0 (token 2 penalized)", tokens[0]) } - - if math.IsInf(float64(gotFloats[0]), -1) || math.IsInf(float64(gotFloats[1]), -1) { - t.Fatalf("kept tokens were masked: %v", gotFloats) - } - - if !math.IsInf(float64(gotFloats[2]), -1) { - t.Fatalf("lowest-probability token should be masked, got %v", gotFloats) + if tokens[1] != 1 { + t.Errorf("slot 3 = %d, want 1 (token 0 penalized, no slot-1 carryover)", tokens[1]) } } diff --git a/x/mlxrunner/server.go b/x/mlxrunner/server.go index 41d8c976f..d44de3d2b 100644 --- a/x/mlxrunner/server.go +++ b/x/mlxrunner/server.go @@ -93,7 +93,7 @@ func Execute(args []string) error { } request.Pipeline = runner.TextGenerationPipeline - request.Sampler = sample.New(sample.Options{ + request.SamplerOpts = sample.Options{ Temperature: request.Options.Temperature, TopP: request.Options.TopP, MinP: request.Options.MinP, @@ -104,7 +104,7 @@ func Execute(args []string) error { FrequencyPenalty: request.Options.FrequencyPenalty, Logprobs: request.Logprobs, TopLogprobs: request.TopLogprobs, - }) + } if err := runner.Prepare(&request); err != nil { http.Error(w, err.Error(), http.StatusBadRequest)