mlxrunner: use MaxAxis in the min-P sampler

One reduction op instead of Argmax + TakeAlongAxis.
This commit is contained in:
Jesse Gross 2026-04-16 13:41:59 -07:00
parent 24e038d56a
commit ca01373b28
2 changed files with 7 additions and 1 deletions

View file

@ -139,6 +139,12 @@ func (t *Array) Less(other *Array) *Array {
return out
}
func (t *Array) MaxAxis(axis int, keepDims bool) *Array {
out := New("MAX_AXIS")
C.mlx_max_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
return out
}
func (t *Array) Matmul(other *Array) *Array {
out := New("MATMUL")
C.mlx_matmul(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)

View file

@ -198,7 +198,7 @@ func minP(s *Sampler, scores *mlx.Array) *mlx.Array {
return scores
}
maxScore := scores.TakeAlongAxis(scores.Argmax(-1, true), -1)
maxScore := scores.MaxAxis(-1, true)
threshold := mlx.AddScalar(maxScore, float32(math.Log(float64(s.MinP))))
return mlx.Where(