mirror of
https://github.com/ollama/ollama.git
synced 2026-05-13 14:27:00 +00:00
Merge 98c87255bd into 3af1a008e2
This commit is contained in:
commit
c20ec6506b
5 changed files with 494 additions and 184 deletions
|
|
@ -5,21 +5,39 @@ import "C"
|
|||
|
||||
import "unsafe"
|
||||
|
||||
func RandomKey(seed uint64) *Array {
|
||||
out := New("RANDOM_KEY")
|
||||
C.mlx_random_key(&out.ctx, C.uint64_t(seed))
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Categorical(axis int) *Array {
|
||||
key := New("")
|
||||
return t.CategoricalWithKey(axis, nil)
|
||||
}
|
||||
|
||||
func (t *Array) CategoricalWithKey(axis int, key *Array) *Array {
|
||||
if key == nil {
|
||||
key = New("")
|
||||
}
|
||||
out := New("")
|
||||
C.mlx_random_categorical(&out.ctx, t.ctx, C.int(axis), key.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func Bernoulli(p *Array) *Array {
|
||||
return BernoulliWithKey(p, nil)
|
||||
}
|
||||
|
||||
func BernoulliWithKey(p *Array, key *Array) *Array {
|
||||
dims := p.Dims()
|
||||
shape := make([]C.int, len(dims))
|
||||
for i, d := range dims {
|
||||
shape[i] = C.int(d)
|
||||
}
|
||||
|
||||
key := New("")
|
||||
if key == nil {
|
||||
key = New("")
|
||||
}
|
||||
out := New("BERNOULLI")
|
||||
C.mlx_random_bernoulli(&out.ctx, p.ctx, unsafe.SliceData(shape), C.size_t(len(shape)), key.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
|
@ -375,9 +374,11 @@ func (r *Runner) runSampleMTPDecode(ctx context.Context, request Request, sessio
|
|||
t0 = time.Now()
|
||||
candidates := r.generateMTPDraftCandidates(draft, targetEmbeddings, current.Token, hidden, caches, int32(*position-1), maxDraft)
|
||||
draftCount := 0
|
||||
var candidateArrays []*mlx.Array
|
||||
if candidates != nil {
|
||||
draftCount = candidates.tokens.Dim(1)
|
||||
mlx.Pin(baseLogits, candidates.tokens, candidates.logits)
|
||||
candidateArrays = append([]*mlx.Array{baseLogits}, candidates.Arrays()...)
|
||||
mlx.Pin(candidateArrays...)
|
||||
mlx.Sweep()
|
||||
}
|
||||
stats.draftDuration += time.Since(t0)
|
||||
|
|
@ -391,7 +392,7 @@ func (r *Runner) runSampleMTPDecode(ctx context.Context, request Request, sessio
|
|||
t0 = time.Now()
|
||||
next, accepted, done, err = r.acceptSampleMTPDrafts(ctx, request, session, &dec, caches, position, baseLogits, candidates, &final, &generated, &stats)
|
||||
stats.validateDuration += time.Since(t0)
|
||||
mlx.Unpin(baseLogits, candidates.tokens, candidates.logits)
|
||||
mlx.Unpin(candidateArrays...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -456,8 +457,15 @@ func (r *Runner) runSampleMTPDecode(ctx context.Context, request Request, sessio
|
|||
|
||||
type mtpDraftCandidates struct {
|
||||
tokens *mlx.Array
|
||||
// logits are the processed proposal scores used to sample tokens.
|
||||
logits *mlx.Array
|
||||
// dist is the proposal distribution used to sample each drafted token.
|
||||
dist sampler.Distribution
|
||||
}
|
||||
|
||||
func (c *mtpDraftCandidates) Arrays() []*mlx.Array {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
return append([]*mlx.Array{c.tokens}, c.dist.Arrays()...)
|
||||
}
|
||||
|
||||
func (r *Runner) generateMTPDrafts(draft base.MTPDraftModel, target base.MTPEmbeddingModel, token *mlx.Array, hidden *mlx.Array, caches []cache.Cache, position int32, maxDraft int) *mlx.Array {
|
||||
|
|
@ -497,7 +505,7 @@ func (r *Runner) generateMTPDraftCandidates(draft base.MTPDraftModel, target bas
|
|||
lastToken := mtpTokenInput(token)
|
||||
lastHidden := hidden
|
||||
draftTokens := make([]*mlx.Array, 0, maxDraft)
|
||||
draftLogits := make([]*mlx.Array, 0, maxDraft)
|
||||
draftDists := make([]sampler.Distribution, 0, maxDraft)
|
||||
var prefix *mlx.Array
|
||||
|
||||
// Gemma4 assistant MTP is trained as "single-position" drafting:
|
||||
|
|
@ -508,13 +516,13 @@ func (r *Runner) generateMTPDraftCandidates(draft base.MTPDraftModel, target bas
|
|||
inputs := tokenEmbedding.Concatenate(-1, lastHidden)
|
||||
logits, projected := draft.Draft(inputs, position, caches)
|
||||
stepLogits := r.lastLogitsFromLogits(logits)
|
||||
stepScores := r.Sampler.SpeculativeScores(pipelineSlot, stepLogits, prefix)
|
||||
nextToken := stepScores.Categorical(-1).AsType(mlx.DTypeInt32)
|
||||
dist := r.Sampler.Distribution(pipelineSlot, stepLogits, prefix)
|
||||
nextToken := r.Sampler.SampleDistribution(pipelineSlot, dist)
|
||||
|
||||
lastToken = mtpTokenInput(nextToken)
|
||||
lastHidden = projected
|
||||
draftTokens = append(draftTokens, lastToken)
|
||||
draftLogits = append(draftLogits, stepScores.ExpandDims(1))
|
||||
draftDists = append(draftDists, dist)
|
||||
if prefix == nil {
|
||||
prefix = lastToken
|
||||
} else {
|
||||
|
|
@ -526,7 +534,7 @@ func (r *Runner) generateMTPDraftCandidates(draft base.MTPDraftModel, target bas
|
|||
}
|
||||
return &mtpDraftCandidates{
|
||||
tokens: mlx.Concatenate(draftTokens, 1),
|
||||
logits: mlx.Concatenate(draftLogits, 1),
|
||||
dist: sampler.ConcatenateDistributions(draftDists),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -630,12 +638,9 @@ func (r *Runner) acceptSampleMTPDrafts(ctx context.Context, request Request, ses
|
|||
SeqQueryLens: []int32{int32(draftCount)},
|
||||
}, specCaches)
|
||||
|
||||
targetScores := r.Sampler.SpeculativeScores(pipelineSlot, r.mtpValidationLogits(baseLogits, hiddenSeq), candidates.tokens)
|
||||
draftScores := candidates.logits
|
||||
if draftScores.NumDims() == 3 {
|
||||
draftScores = draftScores.Squeeze(0)
|
||||
}
|
||||
acceptedMask := mtpSampleAcceptedMask(targetScores, draftScores, candidates.tokens, draftCount)
|
||||
targetDist := r.Sampler.Distribution(pipelineSlot, r.mtpValidationLogits(baseLogits, hiddenSeq), candidates.tokens)
|
||||
draftDist := candidates.dist
|
||||
acceptedMask := r.mtpSampleAcceptedMask(targetDist.SliceRows(0, draftCount), draftDist, candidates.tokens)
|
||||
mlx.Eval(candidates.tokens, acceptedMask)
|
||||
|
||||
draftIDs := candidates.tokens.Ints()
|
||||
|
|
@ -692,9 +697,9 @@ func (r *Runner) acceptSampleMTPDrafts(ctx context.Context, request Request, ses
|
|||
|
||||
var nextToken *mlx.Array
|
||||
if accepted == draftCount {
|
||||
nextToken = mtpSampleTokenAt(targetScores, draftCount)
|
||||
nextToken = r.mtpSampleTokenAt(targetDist, draftCount)
|
||||
} else {
|
||||
nextToken = mtpSampleResidualToken(targetScores, draftScores, accepted)
|
||||
nextToken = r.mtpSampleResidualToken(targetDist, draftDist, accepted)
|
||||
}
|
||||
mlx.Eval(nextToken)
|
||||
nextID := int32(tokenID(nextToken))
|
||||
|
|
@ -704,32 +709,20 @@ func (r *Runner) acceptSampleMTPDrafts(ctx context.Context, request Request, ses
|
|||
return sampler.Result{Token: nextToken}, accepted, false, nil
|
||||
}
|
||||
|
||||
func mtpSampleAcceptedMask(targetScores, draftScores, draftTokens *mlx.Array, draftCount int) *mlx.Array {
|
||||
targetProbs := mlx.SoftmaxAxis(targetScores.Slice(mlx.Slice(0, draftCount), mlx.Slice()), -1, true)
|
||||
draftProbs := mlx.SoftmaxAxis(draftScores, -1, true)
|
||||
if draftTokens.NumDims() == 2 {
|
||||
draftTokens = draftTokens.Squeeze(0)
|
||||
}
|
||||
indices := draftTokens.ExpandDims(-1)
|
||||
p := targetProbs.TakeAlongAxis(indices, -1).Squeeze(-1)
|
||||
q := draftProbs.TakeAlongAxis(indices, -1).Squeeze(-1)
|
||||
func (r *Runner) mtpSampleAcceptedMask(targetDist, draftDist sampler.Distribution, draftTokens *mlx.Array) *mlx.Array {
|
||||
p := targetDist.Prob(draftTokens)
|
||||
q := draftDist.Prob(draftTokens)
|
||||
acceptP := mlx.Minimum(p.Divide(q), mlx.FromValue(float32(1)))
|
||||
return mlx.Bernoulli(acceptP).AsType(mlx.DTypeInt32)
|
||||
return r.Sampler.Bernoulli(pipelineSlot, acceptP).AsType(mlx.DTypeInt32)
|
||||
}
|
||||
|
||||
func mtpSampleTokenAt(scores *mlx.Array, index int) *mlx.Array {
|
||||
row := scores.Slice(mlx.Slice(index, index+1), mlx.Slice())
|
||||
return mtpTokenVector(row.Categorical(-1).AsType(mlx.DTypeInt32))
|
||||
func (r *Runner) mtpSampleTokenAt(dist sampler.Distribution, index int) *mlx.Array {
|
||||
return mtpTokenVector(r.Sampler.SampleDistribution(pipelineSlot, dist.SliceRows(index, index+1)))
|
||||
}
|
||||
|
||||
func mtpSampleResidualToken(targetScores, draftScores *mlx.Array, index int) *mlx.Array {
|
||||
p := mlx.SoftmaxAxis(targetScores.Slice(mlx.Slice(index, index+1), mlx.Slice()), -1, true)
|
||||
q := mlx.SoftmaxAxis(draftScores.Slice(mlx.Slice(index, index+1), mlx.Slice()), -1, true)
|
||||
diff := p.Subtract(q)
|
||||
positive := mlx.Maximum(diff, mlx.FromValue(float32(1e-20)))
|
||||
logits := mlx.Log(positive)
|
||||
logits = mlx.Where(diff.LessEqual(mlx.FromValue(float32(0))), mlx.FromValue(float32(math.Inf(-1))), logits)
|
||||
return mtpTokenVector(logits.Categorical(-1).AsType(mlx.DTypeInt32))
|
||||
func (r *Runner) mtpSampleResidualToken(targetDist, draftDist sampler.Distribution, index int) *mlx.Array {
|
||||
residual := targetDist.SliceRows(index, index+1).ResidualAgainst(draftDist.SliceRows(index, index+1))
|
||||
return mtpTokenVector(r.Sampler.SampleDistribution(pipelineSlot, residual))
|
||||
}
|
||||
|
||||
func mtpTokenInput(token *mlx.Array) *mlx.Array {
|
||||
|
|
|
|||
|
|
@ -17,6 +17,8 @@ type Options struct {
|
|||
RepeatPenalty float32
|
||||
PresencePenalty float32
|
||||
FrequencyPenalty float32
|
||||
Seed int
|
||||
UseSeed bool
|
||||
|
||||
// Logprobs causes Sample to populate Result.Logprob with the selected
|
||||
// token's log-probability. TopLogprobs (when > 0) adds top-K pairs.
|
||||
|
|
@ -42,6 +44,123 @@ func (r Result) Arrays() []*mlx.Array {
|
|||
return []*mlx.Array{r.Token, r.Logprob, r.TopTokens, r.TopLogprobs}
|
||||
}
|
||||
|
||||
// Distribution is the filtered probability distribution used by the sampler.
|
||||
// When IDs is nil, Probs is dense over the vocabulary. When IDs is set, Probs
|
||||
// is sparse over the token ids in IDs, preserving GPU residency for the
|
||||
// top-k-first path used by normal and speculative sampling.
|
||||
type Distribution struct {
|
||||
IDs *mlx.Array // sparse token ids, shape [B,K]; nil for dense distributions
|
||||
Probs *mlx.Array // probabilities, shape [B,K] or [B,V]
|
||||
}
|
||||
|
||||
// Arrays returns the tensor fields for mlx lifecycle management.
|
||||
func (d Distribution) Arrays() []*mlx.Array {
|
||||
return []*mlx.Array{d.IDs, d.Probs}
|
||||
}
|
||||
|
||||
// Rows returns the number of rows in the distribution.
|
||||
func (d Distribution) Rows() int {
|
||||
if d.Probs == nil {
|
||||
return 0
|
||||
}
|
||||
return d.Probs.Dim(0)
|
||||
}
|
||||
|
||||
// SliceRows returns a row slice while preserving sparse/dense layout.
|
||||
func (d Distribution) SliceRows(start, stop int) Distribution {
|
||||
out := Distribution{Probs: d.Probs.Slice(mlx.Slice(start, stop), mlx.Slice())}
|
||||
if d.IDs != nil {
|
||||
out.IDs = d.IDs.Slice(mlx.Slice(start, stop), mlx.Slice())
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// SampleWithKey draws one token per row using key when supplied.
|
||||
func (d Distribution) SampleWithKey(key *mlx.Array) *mlx.Array {
|
||||
choice := logitsFromProbs(d.Probs).CategoricalWithKey(-1, key).AsType(mlx.DTypeInt32)
|
||||
if d.IDs == nil {
|
||||
return choice
|
||||
}
|
||||
return d.IDs.TakeAlongAxis(choice.ExpandDims(-1), -1).Squeeze(-1).AsType(mlx.DTypeInt32)
|
||||
}
|
||||
|
||||
// Prob returns the probability assigned to one token per row.
|
||||
func (d Distribution) Prob(tokens *mlx.Array) *mlx.Array {
|
||||
switch tokens.NumDims() {
|
||||
case 2:
|
||||
if tokens.Dim(0) == 1 {
|
||||
tokens = tokens.Squeeze(0)
|
||||
} else if tokens.Dim(1) == 1 {
|
||||
tokens = tokens.Squeeze(1)
|
||||
}
|
||||
case 0:
|
||||
tokens = tokens.Reshape(1)
|
||||
}
|
||||
return d.ProbsForIDs(tokens.ExpandDims(-1)).Squeeze(-1)
|
||||
}
|
||||
|
||||
// ProbsForIDs returns probabilities for each requested token id. ids must be
|
||||
// rank-2 [B,N], matching the distribution rows.
|
||||
func (d Distribution) ProbsForIDs(ids *mlx.Array) *mlx.Array {
|
||||
if d.IDs == nil {
|
||||
return d.Probs.TakeAlongAxis(ids, -1)
|
||||
}
|
||||
eq := d.IDs.ExpandDims(-1).Equal(ids.ExpandDims(1))
|
||||
values := mlx.Where(eq, d.Probs.ExpandDims(-1), mlx.FromValue(float32(0)))
|
||||
return values.SumAxis(1, false)
|
||||
}
|
||||
|
||||
// ResidualAgainst returns the Leviathan/Chen rejection distribution
|
||||
// proportional to max(target - draft, 0). Sparse target distributions stay
|
||||
// sparse over the target support; tokens outside target support have zero mass.
|
||||
func (d Distribution) ResidualAgainst(draft Distribution) Distribution {
|
||||
if d.IDs != nil {
|
||||
diff := d.Probs.Subtract(draft.ProbsForIDs(d.IDs))
|
||||
return Distribution{IDs: d.IDs, Probs: normalizeProbs(mlx.Maximum(diff, mlx.FromValue(float32(0))))}
|
||||
}
|
||||
if draft.IDs != nil {
|
||||
panic("sample.Distribution.ResidualAgainst: dense target with sparse draft is unsupported")
|
||||
}
|
||||
diff := d.Probs.Subtract(draft.Probs)
|
||||
return Distribution{Probs: normalizeProbs(mlx.Maximum(diff, mlx.FromValue(float32(0))))}
|
||||
}
|
||||
|
||||
// LogProbs returns dense log-probabilities, scattering sparse distributions
|
||||
// into a full-vocabulary tensor when needed.
|
||||
func (d Distribution) LogProbs(vocab int) *mlx.Array {
|
||||
logProbs := logitsFromProbs(d.Probs)
|
||||
if d.IDs == nil {
|
||||
return logProbs
|
||||
}
|
||||
out := mlx.AddScalar(mlx.Zeros(mlx.DTypeFloat32, d.Probs.Dim(0), vocab), float32(math.Inf(-1)))
|
||||
return out.PutAlongAxis(d.IDs, logProbs, -1)
|
||||
}
|
||||
|
||||
// ConcatenateDistributions concatenates distribution rows. All inputs must use
|
||||
// the same sparse/dense layout.
|
||||
func ConcatenateDistributions(dists []Distribution) Distribution {
|
||||
if len(dists) == 0 {
|
||||
return Distribution{}
|
||||
}
|
||||
probs := make([]*mlx.Array, 0, len(dists))
|
||||
ids := make([]*mlx.Array, 0, len(dists))
|
||||
sparse := dists[0].IDs != nil
|
||||
for _, d := range dists {
|
||||
if (d.IDs != nil) != sparse {
|
||||
panic("sample.ConcatenateDistributions: mixed sparse and dense distributions")
|
||||
}
|
||||
probs = append(probs, d.Probs)
|
||||
if sparse {
|
||||
ids = append(ids, d.IDs)
|
||||
}
|
||||
}
|
||||
out := Distribution{Probs: mlx.Concatenate(probs, 0)}
|
||||
if sparse {
|
||||
out.IDs = mlx.Concatenate(ids, 0)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// Sampler is a batched, slot-based sampler. Sequences are registered with
|
||||
// Add and released with Remove. Each Sample call takes a subset of
|
||||
// registered slots (in any order) with their [B,V] logits, samples one
|
||||
|
|
@ -73,9 +192,9 @@ type Sampler struct {
|
|||
}
|
||||
|
||||
type slotState struct {
|
||||
opts Options
|
||||
transforms []transform
|
||||
historyLen int
|
||||
opts Options
|
||||
historyLen int
|
||||
randomCounter uint64
|
||||
}
|
||||
|
||||
type slotCtx struct {
|
||||
|
|
@ -83,8 +202,6 @@ type slotCtx struct {
|
|||
history *mlx.Array // 2D [B, W] when penalties are configured; nil otherwise
|
||||
}
|
||||
|
||||
type transform func(*slotCtx, *mlx.Array) *mlx.Array
|
||||
|
||||
// New constructs an empty sampler with no registered slots. numCtx is
|
||||
// the runner's context window and must be positive.
|
||||
func New(numCtx int) *Sampler {
|
||||
|
|
@ -127,40 +244,17 @@ func (o Options) normalize(numCtx int) Options {
|
|||
// RepeatLastN still batch together and don't inflate pool width.
|
||||
o.RepeatLastN = 0
|
||||
}
|
||||
if o.Seed < 0 {
|
||||
o.UseSeed = false
|
||||
}
|
||||
if !o.UseSeed {
|
||||
// Keep unseeded callers on the same batching path even when a
|
||||
// meaningless Seed value is present in an Options literal.
|
||||
o.Seed = 0
|
||||
}
|
||||
return o
|
||||
}
|
||||
|
||||
func (o Options) buildTransforms() []transform {
|
||||
var ts []transform
|
||||
if o.usesHistory() {
|
||||
ts = append(ts, penalty)
|
||||
}
|
||||
|
||||
hasTopP := o.TopP > 0 && o.TopP < 1
|
||||
hasTopK := o.TopK > 0
|
||||
switch {
|
||||
case hasTopP:
|
||||
// topKTopP always does a full descending sort for the top-P
|
||||
// cumulative mask and opportunistically masks top-K during the
|
||||
// same pass when it is also configured.
|
||||
ts = append(ts, topKTopP)
|
||||
case hasTopK:
|
||||
// Argpartition (partial sort) is cheaper than a full sort.
|
||||
ts = append(ts, topK)
|
||||
}
|
||||
|
||||
if o.MinP != 0 {
|
||||
ts = append(ts, minP)
|
||||
}
|
||||
|
||||
if o.Temperature == 0 {
|
||||
ts = append(ts, greedy)
|
||||
} else {
|
||||
ts = append(ts, temperature)
|
||||
}
|
||||
return ts
|
||||
}
|
||||
|
||||
// Add registers a sequence under seqID. The last RepeatLastN entries of
|
||||
// priorTokens seed the ring buffer.
|
||||
func (s *Sampler) Add(seqID int, opts Options, priorTokens []int32) {
|
||||
|
|
@ -170,8 +264,7 @@ func (s *Sampler) Add(seqID int, opts Options, priorTokens []int32) {
|
|||
|
||||
opts = opts.normalize(s.numCtx)
|
||||
slot := &slotState{
|
||||
opts: opts,
|
||||
transforms: opts.buildTransforms(),
|
||||
opts: opts,
|
||||
}
|
||||
|
||||
// Grow the pool to hold this slot's row. The pool is lazy — the first
|
||||
|
|
@ -349,29 +442,34 @@ func (s *Sampler) Sample(seqIDs []int, logits *mlx.Array) Result {
|
|||
return res
|
||||
}
|
||||
|
||||
// SpeculativeScores applies this slot's non-sampling transforms to logits
|
||||
// without mutating sampler state. Row i is scored as if draftTokens[:i] had
|
||||
// already been appended to the slot history. logits must be [R,V] or [1,R,V].
|
||||
func (s *Sampler) SpeculativeScores(seqID int, logits *mlx.Array, draftTokens *mlx.Array) *mlx.Array {
|
||||
slot, ok := s.byID[seqID]
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("sample.Sampler.SpeculativeScores: seqID %d not registered", seqID))
|
||||
}
|
||||
|
||||
if logits.NumDims() == 3 {
|
||||
if logits.Dim(0) != 1 {
|
||||
panic("sample.Sampler.SpeculativeScores: only batch size 1 is supported")
|
||||
}
|
||||
logits = logits.Squeeze(0)
|
||||
}
|
||||
if logits.NumDims() != 2 {
|
||||
panic(fmt.Sprintf("sample.Sampler.SpeculativeScores: logits must be rank 2 or 3, got rank %d", logits.NumDims()))
|
||||
}
|
||||
|
||||
if draftTokens != nil && draftTokens.NumDims() == 1 {
|
||||
draftTokens = draftTokens.ExpandDims(0)
|
||||
}
|
||||
// Distribution applies this slot's sampling transforms to logits without
|
||||
// mutating sampler state. Row i is built as if draftTokens[:i] had already
|
||||
// been appended to the slot history. logits must be [R,V] or [1,R,V].
|
||||
func (s *Sampler) Distribution(seqID int, logits *mlx.Array, draftTokens *mlx.Array) Distribution {
|
||||
slot, logits, draftTokens := s.speculativeInputs("Distribution", seqID, logits, draftTokens)
|
||||
rows := logits.Dim(0)
|
||||
|
||||
var hist *mlx.Array
|
||||
if slot.opts.usesHistory() {
|
||||
if s.history == nil {
|
||||
panic(fmt.Sprintf("sample.Sampler.Distribution: seqID %d has no history", seqID))
|
||||
}
|
||||
if slot.historyLen < slot.opts.RepeatLastN {
|
||||
return s.speculativeDistributionSerial(slot, logits, draftTokens)
|
||||
}
|
||||
hist = s.speculativeHistory(slot, draftTokens, rows)
|
||||
}
|
||||
|
||||
return slot.distribution(&slotCtx{opts: slot.opts, history: hist}, logits)
|
||||
}
|
||||
|
||||
// SpeculativeScores applies this slot's sampling transforms to logits without
|
||||
// mutating sampler state and returns dense log-probability scores for sampled
|
||||
// decoding. Greedy decoding returns the penalty-adjusted logits.
|
||||
func (s *Sampler) SpeculativeScores(seqID int, logits *mlx.Array, draftTokens *mlx.Array) *mlx.Array {
|
||||
slot, logits, draftTokens := s.speculativeInputs("SpeculativeScores", seqID, logits, draftTokens)
|
||||
rows := logits.Dim(0)
|
||||
|
||||
var hist *mlx.Array
|
||||
if slot.opts.usesHistory() {
|
||||
if s.history == nil {
|
||||
|
|
@ -386,6 +484,47 @@ func (s *Sampler) SpeculativeScores(seqID int, logits *mlx.Array, draftTokens *m
|
|||
return slot.speculativeScores(&slotCtx{opts: slot.opts, history: hist}, logits)
|
||||
}
|
||||
|
||||
// SampleDistribution draws from a precomputed distribution while advancing
|
||||
// seqID's deterministic RNG stream when a seed is configured.
|
||||
func (s *Sampler) SampleDistribution(seqID int, dist Distribution) *mlx.Array {
|
||||
slot := s.mustSlot("SampleDistribution", seqID)
|
||||
return dist.SampleWithKey(slot.nextRandomKey())
|
||||
}
|
||||
|
||||
// Bernoulli samples boolean outcomes while advancing seqID's deterministic RNG
|
||||
// stream when a seed is configured.
|
||||
func (s *Sampler) Bernoulli(seqID int, p *mlx.Array) *mlx.Array {
|
||||
slot := s.mustSlot("Bernoulli", seqID)
|
||||
return mlx.BernoulliWithKey(p, slot.nextRandomKey())
|
||||
}
|
||||
|
||||
func (s *Sampler) mustSlot(caller string, seqID int) *slotState {
|
||||
slot, ok := s.byID[seqID]
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("sample.Sampler.%s: seqID %d not registered", caller, seqID))
|
||||
}
|
||||
return slot
|
||||
}
|
||||
|
||||
func (s *Sampler) speculativeInputs(caller string, seqID int, logits *mlx.Array, draftTokens *mlx.Array) (*slotState, *mlx.Array, *mlx.Array) {
|
||||
slot := s.mustSlot(caller, seqID)
|
||||
|
||||
if logits.NumDims() == 3 {
|
||||
if logits.Dim(0) != 1 {
|
||||
panic(fmt.Sprintf("sample.Sampler.%s: only batch size 1 is supported", caller))
|
||||
}
|
||||
logits = logits.Squeeze(0)
|
||||
}
|
||||
if logits.NumDims() != 2 {
|
||||
panic(fmt.Sprintf("sample.Sampler.%s: logits must be rank 2 or 3, got rank %d", caller, logits.NumDims()))
|
||||
}
|
||||
|
||||
if draftTokens != nil && draftTokens.NumDims() == 1 {
|
||||
draftTokens = draftTokens.ExpandDims(0)
|
||||
}
|
||||
return slot, logits, draftTokens
|
||||
}
|
||||
|
||||
// Commit appends already-selected tokens to seqID's repeat-penalty history.
|
||||
// It is used after speculative sampling once the accepted continuation is
|
||||
// known. Normal Sample calls continue to mutate history themselves.
|
||||
|
|
@ -422,7 +561,7 @@ func (s *Sampler) Commit(seqID int, tokens []int32) {
|
|||
slot.historyLen += len(tokens)
|
||||
}
|
||||
|
||||
func (s *Sampler) speculativeScoresSerial(slot *slotState, logits *mlx.Array, draftTokens *mlx.Array) *mlx.Array {
|
||||
func (s *Sampler) speculativeDistributionSerial(slot *slotState, logits *mlx.Array, draftTokens *mlx.Array) Distribution {
|
||||
rows := logits.Dim(0)
|
||||
draftCount := 0
|
||||
if draftTokens != nil {
|
||||
|
|
@ -435,7 +574,7 @@ func (s *Sampler) speculativeScoresSerial(slot *slotState, logits *mlx.Array, dr
|
|||
base = s.history.Slice(mlx.Slice(row, row+1), mlx.Slice(0, baseFill))
|
||||
}
|
||||
|
||||
scored := make([]*mlx.Array, 0, rows)
|
||||
dists := make([]Distribution, 0, rows)
|
||||
for i := range rows {
|
||||
rowLogits := logits.Slice(mlx.Slice(i, i+1), mlx.Slice())
|
||||
hist := base
|
||||
|
|
@ -451,9 +590,13 @@ func (s *Sampler) speculativeScoresSerial(slot *slotState, logits *mlx.Array, dr
|
|||
hist = hist.Slice(mlx.Slice(), mlx.Slice(hist.Dim(1)-slot.opts.RepeatLastN, mlx.End))
|
||||
}
|
||||
}
|
||||
scored = append(scored, slot.speculativeScores(&slotCtx{opts: slot.opts, history: hist}, rowLogits))
|
||||
dists = append(dists, slot.distribution(&slotCtx{opts: slot.opts, history: hist}, rowLogits))
|
||||
}
|
||||
return mlx.Concatenate(scored, 0)
|
||||
return ConcatenateDistributions(dists)
|
||||
}
|
||||
|
||||
func (s *Sampler) speculativeScoresSerial(slot *slotState, logits *mlx.Array, draftTokens *mlx.Array) *mlx.Array {
|
||||
return s.speculativeDistributionSerial(slot, logits, draftTokens).LogProbs(logits.Dim(logits.NumDims() - 1))
|
||||
}
|
||||
|
||||
func (s *Sampler) speculativeHistory(slot *slotState, draftTokens *mlx.Array, rows int) *mlx.Array {
|
||||
|
|
@ -489,17 +632,10 @@ func (s *Sampler) speculativeHistory(slot *slotState, draftTokens *mlx.Array, ro
|
|||
}
|
||||
|
||||
func (slot *slotState) speculativeScores(ctx *slotCtx, logits *mlx.Array) *mlx.Array {
|
||||
scores := logits
|
||||
// buildTransforms always appends the final selector transform
|
||||
// (greedy or temperature sampling). Speculative validation needs the
|
||||
// processed logits before that selector mutates the distribution.
|
||||
for _, t := range slot.transforms[:len(slot.transforms)-1] {
|
||||
scores = t(ctx, scores)
|
||||
if slot.opts.Temperature == 0 {
|
||||
return slot.baseScores(ctx, logits)
|
||||
}
|
||||
if slot.opts.Temperature > 0 {
|
||||
scores = mlx.DivScalar(scores, slot.opts.Temperature)
|
||||
}
|
||||
return scores
|
||||
return slot.distribution(ctx, logits).LogProbs(logits.Dim(logits.NumDims() - 1))
|
||||
}
|
||||
|
||||
// canBatch reports whether the call can take the uniform batched path.
|
||||
|
|
@ -513,6 +649,9 @@ func (s *Sampler) canBatch(slots []*slotState) (Options, bool) {
|
|||
// slots is non-empty (Sample guards) and every slot is registered,
|
||||
// so s.slots[0].opts is the canonical shared value.
|
||||
shared := s.slots[0].opts
|
||||
// TODO: Before using multi-slot batching with seeded stochastic sampling,
|
||||
// make sure each row gets its own per-slot random key instead of sharing
|
||||
// slots[0]'s key through one batched categorical op.
|
||||
if !shared.usesHistory() {
|
||||
return shared, true
|
||||
}
|
||||
|
|
@ -527,7 +666,7 @@ func (s *Sampler) canBatch(slots []*slotState) (Options, bool) {
|
|||
return shared, true
|
||||
}
|
||||
|
||||
// sampleTokensUniform runs one fused transform pass over the whole batch.
|
||||
// sampleTokensUniform runs one fused sampling pass over the whole batch.
|
||||
// Reached only when canBatch is true, which lets the pool be used in place
|
||||
// with a single PutAlongAxis write-back and no gather.
|
||||
func (s *Sampler) sampleTokensUniform(slots []*slotState, opts Options, logits *mlx.Array) *mlx.Array {
|
||||
|
|
@ -542,11 +681,14 @@ func (s *Sampler) sampleTokensUniform(slots []*slotState, opts Options, logits *
|
|||
}
|
||||
|
||||
ctx := &slotCtx{opts: opts, history: hist}
|
||||
scores := logits
|
||||
for _, t := range slots[0].transforms {
|
||||
scores = t(ctx, scores)
|
||||
token := slots[0].sample(ctx, logits)
|
||||
if opts.UseSeed && opts.Temperature != 0 {
|
||||
// TODO: This only keeps counters aligned; it does not give each slot
|
||||
// an independent key for the batched draw.
|
||||
for _, slot := range slots[1:] {
|
||||
slot.randomCounter++
|
||||
}
|
||||
}
|
||||
token := scores
|
||||
|
||||
if !opts.usesHistory() {
|
||||
return token
|
||||
|
|
@ -563,8 +705,7 @@ func (s *Sampler) sampleTokensUniform(slots []*slotState, opts Options, logits *
|
|||
return token
|
||||
}
|
||||
|
||||
// sampleTokensSerial runs each slot's transforms against its own row of
|
||||
// logits.
|
||||
// sampleTokensSerial samples each slot against its own row of logits.
|
||||
func (s *Sampler) sampleTokensSerial(slots []*slotState, logits *mlx.Array) *mlx.Array {
|
||||
perSlotTokens := make([]*mlx.Array, len(slots))
|
||||
|
||||
|
|
@ -587,11 +728,7 @@ func (s *Sampler) sampleTokensSerial(slots []*slotState, logits *mlx.Array) *mlx
|
|||
}
|
||||
|
||||
ctx := &slotCtx{opts: slot.opts, history: hist}
|
||||
scores := row
|
||||
for _, t := range slot.transforms {
|
||||
scores = t(ctx, scores)
|
||||
}
|
||||
perSlotTokens[i] = scores
|
||||
perSlotTokens[i] = slot.sample(ctx, row)
|
||||
}
|
||||
|
||||
token := mlx.Concatenate(perSlotTokens, 0)
|
||||
|
|
@ -627,74 +764,107 @@ func (s *Sampler) sampleTokensSerial(slots []*slotState, logits *mlx.Array) *mlx
|
|||
return token
|
||||
}
|
||||
|
||||
func greedy(_ *slotCtx, scores *mlx.Array) *mlx.Array {
|
||||
return scores.Argmax(-1, false).AsType(mlx.DTypeInt32)
|
||||
func (slot *slotState) sample(ctx *slotCtx, logits *mlx.Array) *mlx.Array {
|
||||
if slot.opts.Temperature == 0 {
|
||||
return slot.baseScores(ctx, logits).Argmax(-1, false).AsType(mlx.DTypeInt32)
|
||||
}
|
||||
return slot.distribution(ctx, logits).SampleWithKey(slot.nextRandomKey())
|
||||
}
|
||||
|
||||
func temperature(ctx *slotCtx, scores *mlx.Array) *mlx.Array {
|
||||
return mlx.DivScalar(scores, ctx.opts.Temperature).Categorical(-1).AsType(mlx.DTypeInt32)
|
||||
func (slot *slotState) nextRandomKey() *mlx.Array {
|
||||
if !slot.opts.UseSeed {
|
||||
return nil
|
||||
}
|
||||
seed := mixSeed(uint64(slot.opts.Seed), slot.randomCounter)
|
||||
slot.randomCounter++
|
||||
return mlx.RandomKey(seed)
|
||||
}
|
||||
|
||||
// topKTopP applies top-P in a descending sort pass and, when top-K is also
|
||||
// configured, masks any surviving value below the K-th largest in the same
|
||||
// pass. Callers dispatch here whenever top-P is enabled — the top-K-only case
|
||||
// uses a cheaper partial sort via the topK transform.
|
||||
func topKTopP(ctx *slotCtx, scores *mlx.Array) *mlx.Array {
|
||||
const (
|
||||
// SplitMix64 constants used to decorrelate nearby (seed, counter) pairs.
|
||||
splitMix64Weyl = 0x9e3779b97f4a7c15
|
||||
splitMix64Mul1 = 0xbf58476d1ce4e5b9
|
||||
splitMix64Mul2 = 0x94d049bb133111eb
|
||||
splitMix64Shift1 = 30
|
||||
splitMix64Shift2 = 27
|
||||
splitMix64FinalShift = 31
|
||||
)
|
||||
|
||||
func mixSeed(seed, counter uint64) uint64 {
|
||||
z := seed + splitMix64Weyl*(counter+1)
|
||||
z = (z ^ (z >> splitMix64Shift1)) * splitMix64Mul1
|
||||
z = (z ^ (z >> splitMix64Shift2)) * splitMix64Mul2
|
||||
return z ^ (z >> splitMix64FinalShift)
|
||||
}
|
||||
|
||||
func (slot *slotState) baseScores(ctx *slotCtx, logits *mlx.Array) *mlx.Array {
|
||||
scores := logits
|
||||
if slot.opts.usesHistory() {
|
||||
scores = penalty(ctx, scores)
|
||||
}
|
||||
return scores
|
||||
}
|
||||
|
||||
func (slot *slotState) distribution(ctx *slotCtx, logits *mlx.Array) Distribution {
|
||||
scores := slot.baseScores(ctx, logits)
|
||||
if slot.opts.Temperature <= 0 {
|
||||
ids := scores.Argmax(-1, false).AsType(mlx.DTypeInt32).ExpandDims(-1)
|
||||
probs := mlx.AddScalar(ids.AsType(mlx.DTypeFloat32).Multiply(mlx.FromValue(float32(0))), 1)
|
||||
return Distribution{IDs: ids, Probs: probs}
|
||||
}
|
||||
|
||||
vocab := scores.Dim(scores.NumDims() - 1)
|
||||
applyTopK := ctx.opts.TopK > 0 && ctx.opts.TopK < vocab
|
||||
|
||||
order := scores.Negative().ArgsortAxis(-1)
|
||||
sorted := scores.TakeAlongAxis(order, -1)
|
||||
negInf := mlx.FromValue(float32(math.Inf(-1)))
|
||||
|
||||
// Top-P: in descending order, keep tokens whose exclusive cumulative
|
||||
// probability is still below TopP.
|
||||
probs := mlx.SoftmaxAxis(sorted, -1, true)
|
||||
prevCumProbs := probs.Cumsum(-1, false, true).Subtract(probs)
|
||||
keep := prevCumProbs.Less(mlx.FromValue(ctx.opts.TopP))
|
||||
sorted = mlx.Where(keep, sorted, negInf)
|
||||
|
||||
out := scores.PutAlongAxis(order, sorted, -1)
|
||||
|
||||
// Top-K: sorted is already in descending order, so positions [K, V) are
|
||||
// the ones to drop. Scatter -inf through their original-layout indices
|
||||
// (order[K:]). Positional (not value-based) so exactly K tokens survive —
|
||||
// ties at the K-th logit get broken by the sort order rather than
|
||||
// promoted through the filter.
|
||||
if applyTopK {
|
||||
dropOrder := order.Slice(mlx.Slice(), mlx.Slice(ctx.opts.TopK, mlx.End))
|
||||
out = out.PutAlongAxis(dropOrder, negInf, -1)
|
||||
if slot.opts.TopK > 0 && slot.opts.TopK < vocab {
|
||||
return sparseDistribution(ctx.opts, scores)
|
||||
}
|
||||
|
||||
return out
|
||||
return denseDistribution(ctx.opts, scores)
|
||||
}
|
||||
|
||||
func minP(ctx *slotCtx, scores *mlx.Array) *mlx.Array {
|
||||
if ctx.opts.MinP <= 0 || ctx.opts.MinP > 1 {
|
||||
return scores
|
||||
}
|
||||
|
||||
maxScore := scores.MaxAxis(-1, true)
|
||||
threshold := mlx.AddScalar(maxScore, float32(math.Log(float64(ctx.opts.MinP))))
|
||||
|
||||
return mlx.Where(
|
||||
scores.Less(threshold),
|
||||
mlx.FromValue(float32(math.Inf(-1))),
|
||||
scores,
|
||||
)
|
||||
func sparseDistribution(opts Options, scores *mlx.Array) Distribution {
|
||||
ids := scores.Negative().ArgpartitionAxis(opts.TopK-1, -1).Slice(mlx.Slice(), mlx.Slice(0, opts.TopK)).AsType(mlx.DTypeInt32)
|
||||
topScores := scores.TakeAlongAxis(ids, -1).AsType(mlx.DTypeFloat32)
|
||||
probs := mlx.SoftmaxAxis(mlx.DivScalar(topScores, opts.Temperature), -1, true)
|
||||
probs = applyTopPProbs(probs, opts.TopP)
|
||||
probs = applyMinPProbs(probs, opts.MinP)
|
||||
return Distribution{IDs: ids, Probs: normalizeProbs(probs)}
|
||||
}
|
||||
|
||||
func topK(ctx *slotCtx, scores *mlx.Array) *mlx.Array {
|
||||
if ctx.opts.TopK <= 0 {
|
||||
return scores
|
||||
}
|
||||
vocab := scores.Dim(scores.NumDims() - 1)
|
||||
if ctx.opts.TopK >= vocab {
|
||||
return scores
|
||||
}
|
||||
func denseDistribution(opts Options, scores *mlx.Array) Distribution {
|
||||
probs := mlx.SoftmaxAxis(mlx.DivScalar(scores.AsType(mlx.DTypeFloat32), opts.Temperature), -1, true)
|
||||
probs = applyTopPProbs(probs, opts.TopP)
|
||||
probs = applyMinPProbs(probs, opts.MinP)
|
||||
return Distribution{Probs: normalizeProbs(probs)}
|
||||
}
|
||||
|
||||
mask := scores.Negative().ArgpartitionAxis(ctx.opts.TopK-1, -1).Slice(mlx.Slice(), mlx.Slice(ctx.opts.TopK, mlx.End))
|
||||
return scores.PutAlongAxis(mask, mlx.FromValue(float32(math.Inf(-1))), -1)
|
||||
func applyTopPProbs(probs *mlx.Array, topP float32) *mlx.Array {
|
||||
if topP <= 0 || topP >= 1 {
|
||||
return probs
|
||||
}
|
||||
order := probs.Negative().ArgsortAxis(-1)
|
||||
sorted := probs.TakeAlongAxis(order, -1)
|
||||
prevCumProbs := sorted.Cumsum(-1, false, true).Subtract(sorted)
|
||||
keep := prevCumProbs.Less(mlx.FromValue(topP))
|
||||
filtered := mlx.Where(keep, sorted, mlx.FromValue(float32(0)))
|
||||
return mlx.Zeros(probs.DType(), probs.Dims()...).PutAlongAxis(order, filtered, -1)
|
||||
}
|
||||
|
||||
func applyMinPProbs(probs *mlx.Array, minP float32) *mlx.Array {
|
||||
if minP <= 0 || minP > 1 {
|
||||
return probs
|
||||
}
|
||||
threshold := mlx.MulScalar(probs.MaxAxis(-1, true), minP)
|
||||
return mlx.Where(probs.Less(threshold), mlx.FromValue(float32(0)), probs)
|
||||
}
|
||||
|
||||
func normalizeProbs(probs *mlx.Array) *mlx.Array {
|
||||
sum := mlx.Maximum(probs.SumAxis(-1, true), mlx.FromValue(float32(1e-20)))
|
||||
return probs.Divide(sum)
|
||||
}
|
||||
|
||||
func logitsFromProbs(probs *mlx.Array) *mlx.Array {
|
||||
positive := mlx.Maximum(probs, mlx.FromValue(float32(1e-20)))
|
||||
logits := mlx.Log(positive)
|
||||
return mlx.Where(probs.LessEqual(mlx.FromValue(float32(0))), mlx.FromValue(float32(math.Inf(-1))), logits)
|
||||
}
|
||||
|
||||
func penalty(ctx *slotCtx, scores *mlx.Array) *mlx.Array {
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ package sample
|
|||
|
||||
import (
|
||||
"math"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
|
|
@ -141,6 +142,132 @@ func TestSampleSingleSlotOptions(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// TestDistributionAppliesTopKBeforeTopP checks filter ordering for a slot
// configured with both TopK and TopP: the distribution must have exactly
// TopK (= 2) sparse entries, and after top-p filtering plus renormalization
// all probability mass must sit on token 0 (the highest-logit token), with
// the other surviving id at ~0.
func TestDistributionAppliesTopKBeforeTopP(t *testing.T) {
	skipIfNoMLX(t)

	s := New(128)
	t.Cleanup(func() {
		s.Free()
		mlx.Sweep()
	})
	s.Add(0, Options{Temperature: 1, TopK: 2, TopP: 0.7}, nil)

	// Logits chosen so token 0 carries 0.6 of the mass, tokens 1 and 2
	// carry 0.2 each.
	dist := s.Distribution(0, slotLogits([]float32{logOf(0.6), logOf(0.2), logOf(0.2)}), nil)
	mlx.Eval(dist.Arrays()...)

	ids := dist.IDs.Ints()
	probs := dist.Probs.Floats()
	// Top-k must cut the support to exactly 2 sparse entries.
	if len(ids) != 2 || len(probs) != 2 {
		t.Fatalf("support = ids %v probs %v, want 2 sparse entries", ids, probs)
	}

	foundTop := false
	for i, id := range ids {
		switch id {
		case 0:
			foundTop = true
			// Token 0 must hold all the mass after top-p + renormalize.
			if math.Abs(float64(probs[i]-1)) > 1e-5 {
				t.Fatalf("top token prob = %v, want 1; ids=%v probs=%v", probs[i], ids, probs)
			}
		default:
			// Any other surviving id must have been zeroed by top-p.
			if math.Abs(float64(probs[i])) > 1e-5 {
				t.Fatalf("non-top token %d prob = %v, want 0; ids=%v probs=%v", id, probs[i], ids, probs)
			}
		}
	}
	if !foundTop {
		t.Fatalf("top-k support %v did not include token 0", ids)
	}
}
|
||||
|
||||
// TestDistributionResidualUsesTargetSupport checks the speculative-decoding
// residual distribution: given target {2: 0.7, 5: 0.3} and draft
// {2: 0.2, 4: 0.8}, the residual max(p-q, 0) on the target's support is
// {2: 0.5, 5: 0.3}, which normalizes to {2: 0.625, 5: 0.375}. Draft-only
// token 4 must not appear in the result.
func TestDistributionResidualUsesTargetSupport(t *testing.T) {
	skipIfNoMLX(t)

	target := Distribution{
		IDs:   mlx.NewArrayInt32([]int32{2, 5}, []int32{1, 2}),
		Probs: mlx.FromValues([]float32{0.7, 0.3}, 1, 2),
	}
	// Draft shares token 2 with the target; token 4 is outside the
	// target's support.
	draft := Distribution{
		IDs:   mlx.NewArrayInt32([]int32{2, 4}, []int32{1, 2}),
		Probs: mlx.FromValues([]float32{0.2, 0.8}, 1, 2),
	}

	residual := target.ResidualAgainst(draft)
	mlx.Eval(residual.Arrays()...)

	ids := residual.IDs.Ints()
	probs := residual.Probs.Floats()
	// Expected normalized residual: (0.7-0.2)/0.8 and (0.3-0)/0.8.
	want := map[int]float64{2: 0.625, 5: 0.375}
	if len(ids) != 2 || len(probs) != 2 {
		t.Fatalf("residual = ids %v probs %v, want 2 sparse entries", ids, probs)
	}
	for i, id := range ids {
		w, ok := want[id]
		if !ok {
			t.Fatalf("residual includes token %d outside target support: ids=%v probs=%v", id, ids, probs)
		}
		if math.Abs(float64(probs[i])-w) > 1e-5 {
			t.Fatalf("residual token %d prob = %v, want %v; ids=%v probs=%v", id, probs[i], w, ids, probs)
		}
	}
}
|
||||
|
||||
// TestSeededSamplingIsReproducible draws 32 tokens from a uniform 4-token
// distribution with an explicit seed, twice with the same seed and once
// with a different one. Same-seed runs must produce identical sequences;
// the different seed must diverge (32 uniform 4-way draws colliding by
// chance is ~4^-32, so a match indicates the seed is ignored).
func TestSeededSamplingIsReproducible(t *testing.T) {
	skipIfNoMLX(t)

	// seededSequence builds a fresh sampler with the given seed and draws
	// 32 tokens from uniform logits over a 4-token vocabulary.
	seededSequence := func(seed int) []int {
		s := New(128)
		t.Cleanup(func() {
			s.Free()
			mlx.Sweep()
		})
		s.Add(0, Options{Temperature: 1, TopK: 4, Seed: seed, UseSeed: true}, nil)

		logits := slotLogits([]float32{0, 0, 0, 0})
		out := make([]int, 32)
		for i := range out {
			token := s.Sample([]int{0}, logits).Token
			mlx.Eval(token)
			out[i] = token.Int()
		}
		return out
	}

	a := seededSequence(1234)
	b := seededSequence(1234)
	if !slices.Equal(a, b) {
		t.Fatalf("same seed produced different sequences:\n%v\n%v", a, b)
	}

	c := seededSequence(5678)
	if slices.Equal(a, c) {
		t.Fatalf("different seeds produced the same sequence: %v", a)
	}
}
|
||||
|
||||
// TestSeededBernoulliIsReproducible draws a 6-element Bernoulli(0.5) mask
// from two independently constructed samplers that share the same fixed
// seed, and requires the masks to be identical — i.e. the Bernoulli path
// must consume the seeded RNG key rather than an unseeded global one.
func TestSeededBernoulliIsReproducible(t *testing.T) {
	skipIfNoMLX(t)

	// seededMask builds a fresh sampler seeded with 99 and samples one
	// Bernoulli mask over six p=0.5 entries.
	seededMask := func() []int {
		s := New(128)
		t.Cleanup(func() {
			s.Free()
			mlx.Sweep()
		})
		s.Add(0, Options{Seed: 99, UseSeed: true}, nil)

		mask := s.Bernoulli(0, mlx.FromValues([]float32{0.5, 0.5, 0.5, 0.5, 0.5, 0.5}, 6)).AsType(mlx.DTypeInt32)
		mlx.Eval(mask)
		return mask.Ints()
	}

	a := seededMask()
	b := seededMask()
	if !slices.Equal(a, b) {
		t.Fatalf("same seed produced different bernoulli masks:\n%v\n%v", a, b)
	}
}
|
||||
|
||||
// TestSampleHistoryWindow verifies that penalty history respects the
|
||||
// RepeatLastN window: priors longer than RepeatLastN are trimmed on Add,
|
||||
// and once the ring wraps, tokens that rotate out no longer contribute
|
||||
|
|
|
|||
|
|
@ -136,6 +136,8 @@ func Execute(args []string) error {
|
|||
RepeatPenalty: request.Options.RepeatPenalty,
|
||||
PresencePenalty: request.Options.PresencePenalty,
|
||||
FrequencyPenalty: request.Options.FrequencyPenalty,
|
||||
Seed: request.Options.Seed,
|
||||
UseSeed: request.Options.Seed >= 0,
|
||||
Logprobs: request.Logprobs,
|
||||
TopLogprobs: request.TopLogprobs,
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue