mirror of
https://github.com/ollama/ollama.git
synced 2026-05-13 06:21:28 +00:00
mlxrunner: use MaxAxis in the min-P sampler
One reduction op instead of Argmax + TakeAlongAxis.
This commit is contained in:
parent
24e038d56a
commit
ca01373b28
2 changed files with 7 additions and 1 deletion
|
|
@ -139,6 +139,12 @@ func (t *Array) Less(other *Array) *Array {
|
|||
return out
|
||||
}
|
||||
|
||||
// MaxAxis returns a new Array produced by the MLX max-reduction op
// (mlx_max_axis) applied to t along the given axis.
//
// axis and keepDims are forwarded unchanged to MLX, and the call is issued
// with the default stream's context. The result is a freshly created Array
// labelled "MAX_AXIS"; t itself is only passed as the op's input.
//
// NOTE(review): keepDims presumably keeps the reduced axis as a size-1
// dimension, so the result broadcasts against t — confirm against the MLX docs.
func (t *Array) MaxAxis(axis int, keepDims bool) *Array {
|
||||
out := New("MAX_AXIS")
|
||||
C.mlx_max_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Matmul(other *Array) *Array {
|
||||
out := New("MATMUL")
|
||||
C.mlx_matmul(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
||||
|
|
|
|||
|
|
@ -198,7 +198,7 @@ func minP(s *Sampler, scores *mlx.Array) *mlx.Array {
|
|||
return scores
|
||||
}
|
||||
|
||||
maxScore := scores.TakeAlongAxis(scores.Argmax(-1, true), -1)
|
||||
maxScore := scores.MaxAxis(-1, true)
|
||||
threshold := mlx.AddScalar(maxScore, float32(math.Log(float64(s.MinP))))
|
||||
|
||||
return mlx.Where(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue