ollama/x/mlxrunner/cache/cache.go
Patrick Devine 15e6076d79
mlx: Gemma4 MTP speculative decoding (#15980)
This change adds support for MTP (multi-token prediction) speculative decoding for the
gemma4 model family.

It includes:
  * support for importing safetensors-based gemma4 draft models with `ollama create`
  * a new DRAFT command in the Modelfile for specifying draft models
  * a --quantize-draft flag for the ollama create command to quantize the draft model
  * cache support for speculation
  * changes to the rotating cache so it handles MTP correctly
  * sampling support for draft model token prediction

---------

Co-authored-by: Daniel Hiltgen <daniel@ollama.com>
2026-05-05 08:55:04 -07:00

package cache
import (
"fmt"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/x/mlxrunner/batch"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/models/nn"
)
// Cache is common state management shared by every cache kind. Write
// paths live on the specific cache types.
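//
// A minimal snapshot round-trip (a sketch, not a prescribed calling
// sequence; c is any Cache, and from is an offset below c.Offset()):
//
//	snap := c.Snapshot(from) // page out [from, c.Offset()) into pinned VRAM
//	to := c.Offset()
//	c.Restore(nil, from)     // live rewind back to from
//	// ... later, to return to the snapshotted state ...
//	c.Restore(snap, to)      // replay the snapshot to reach to again
//	snap.Close()             // unpin once the state is live again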
type Cache interface {
// State returns the cache-owned state roots that should be kept/evaluated.
State() []*mlx.Array
Free()
Offset() int
// Snapshot copies cache state from fromOffset to current offset into
// pinned VRAM arrays. The active cache is unchanged.
Snapshot(fromOffset int) Snapshot
// Restore brings the cache to target. If snapshot is nil, rewinds
// using the cache's own live state. Returns false if the target is
// unreachable (e.g. target > current offset, or negative).
Restore(snapshot Snapshot, target int) bool
// Merge combines two sequential snapshots [a,b) and [b,c) into [a,c).
// Takes ownership of both inputs.
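// Example: merging snapshots covering [0,8) and [8,12) yields [0,12).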
Merge(parent, child Snapshot) Snapshot
// Split divides a snapshot [a,c) at offset b into [a,b) and [b,c).
// Takes ownership of the input. Cache types that cannot split
// (e.g. recurrent) return (nil, snapshot).
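// Example: splitting [0,12) at 8 yields [0,8) and [8,12).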
Split(snapshot Snapshot, at int) (parent, child Snapshot)
}
// Snapshot is paged-out cache state that can be restored later.
type Snapshot interface {
// Size returns the byte size of the paged-out data (in VRAM).
Size() int
// Close unpins the snapshot's arrays so they can be freed by Sweep.
Close()
}
// Attention is the contract for caches that back attention layers
// (KVCache, RotatingKVCache).
type Attention interface {
Cache
// Update appends (k, v) and returns an opaque nn.KVHistory for
// this layer's SDPA.
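//
// Sketch of the wiring inside a layer forward (projection and SDPA
// plumbing belong to the model, not this package):
//
//	hist := c.Update(b, keys, values) // keys/values: [B, H, L, D]
//	// pass hist to this layer's SDPA as its attention history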
Update(b *batch.Batch, keys, values *mlx.Array) *nn.KVHistory
}
// Viewer exposes a read-only attention history for a cache.
type Viewer interface {
View(b *batch.Batch) *nn.KVHistory
}
type speculativeCommitter interface {
Cache
commit(n int)
}
// Speculation is an isolated cache transaction for speculative target
// validation. Updates stage generated K/V past the accepted history without
// committing it; Commit keeps only the accepted prefix in the live caches.
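//
// Typical flow (a sketch; the forward pass and acceptance counting are the
// caller's responsibility, and liveCaches/accepted are illustrative names):
//
//	specCaches, spec, ok := BeginSpeculation(liveCaches)
//	if !ok {
//		// a cache kind is not speculation-aware; fall back to plain decode
//	}
//	// run the target forward over the draft tokens using specCaches ...
//	spec.Commit(accepted) // keep only the accepted prefix in the live caches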
type Speculation struct {
layers []speculativeCommitter
}
// BeginSpeculation returns cache wrappers suitable for a speculative target
// forward. The returned caches must only be used for that forward.
func BeginSpeculation(caches []Cache) ([]Cache, *Speculation, bool) {
specCaches := make([]Cache, len(caches))
layers := make([]speculativeCommitter, len(caches))
for i, c := range caches {
switch c := c.(type) {
case nil:
case *RotatingKVCache:
sc := newSpeculativeRotatingKVCache(c)
specCaches[i] = sc
layers[i] = sc
case *KVCache:
sc := newSpeculativeKVCache(c)
specCaches[i] = sc
layers[i] = sc
default:
return nil, nil, false
}
}
return specCaches, &Speculation{layers: layers}, true
}
// BeginIsolatedSpeculation returns cache wrappers that never mutate live cache
// state. It is intended for correctness instrumentation, not the hot path.
func BeginIsolatedSpeculation(caches []Cache) ([]Cache, bool) {
specCaches := make([]Cache, len(caches))
for i, c := range caches {
switch c := c.(type) {
case nil:
case *RotatingKVCache:
specCaches[i] = newSpeculativeRotatingKVCache(c)
case *KVCache:
specCaches[i] = newIsolatedKVCache(c)
default:
return nil, false
}
}
return specCaches, true
}
// Commit keeps only the accepted prefix from the speculative forward in the
// live caches. The target bonus token is intentionally not committed.
func (s *Speculation) Commit(n int) {
if s == nil {
return
}
for _, layer := range s.layers {
if layer != nil {
layer.commit(n)
}
}
}
type KVCache struct {
keys, values *mlx.Array
offset int // number of committed positions
step int // buffer growth quantum, in positions
}
func NewKVCache() *KVCache {
return &KVCache{step: 256}
}
// Assumes B = 1; heterogeneous batches are not supported.
func (c *KVCache) Update(_ *batch.Batch, keys, values *mlx.Array) *nn.KVHistory {
newK, newV := c.appendKV(keys, values)
return nn.NewKVHistory(newK, newV, nil)
}
// View returns the current cache contents as attention history without writing.
func (c *KVCache) View(_ *batch.Batch) *nn.KVHistory {
state := c.State()
if len(state) < 2 {
return nil
}
return nn.NewKVHistory(state[0], state[1], nil)
}
// appendKV is the raw write path shared by Update and Restore.
func (c *KVCache) appendKV(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
B, H, L, Dk, Dv := keys.Dim(0), keys.Dim(1), keys.Dim(2), keys.Dim(3), values.Dim(3)
prev := c.offset
// Grow buffer if needed
if c.keys == nil || (prev+L) > c.keys.Dim(2) {
steps := (c.step + L - 1) / c.step
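// steps = ceil(L/step); e.g. step=256, L=300 grows the buffer by 512 slots.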
newKeys := mlx.Zeros(keys.DType(), B, H, steps*c.step, Dk)
newValues := mlx.Zeros(values.DType(), B, H, steps*c.step, Dv)
if c.keys != nil {
if prev%c.step != 0 {
c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, prev), mlx.Slice()))
c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, prev), mlx.Slice()))
}
c.keys.Set(c.keys.Concatenate(2, newKeys))
c.values.Set(c.values.Concatenate(2, newValues))
} else {
c.keys, c.values = newKeys, newValues
mlx.Pin(c.keys, c.values)
}
}
c.offset += L
c.keys.Set(c.keys.SliceUpdate(keys, mlx.Slice(), mlx.Slice(), mlx.Slice(prev, c.offset), mlx.Slice()))
c.values.Set(c.values.SliceUpdate(values, mlx.Slice(), mlx.Slice(), mlx.Slice(prev, c.offset), mlx.Slice()))
return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
}
func (c *KVCache) State() []*mlx.Array {
if c.keys == nil || c.values == nil {
return nil
}
return []*mlx.Array{
c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
}
}
// kvSnapshot holds paged-out KV data for a range [fromOffset, toOffset).
type kvSnapshot struct {
keys, values *mlx.Array
fromOffset, toOffset int
}
func (s *kvSnapshot) Size() int { return s.keys.NumBytes() + s.values.NumBytes() }
func (s *kvSnapshot) Close() { mlx.Unpin(s.keys, s.values) }
func (c *KVCache) Snapshot(fromOffset int) Snapshot {
if c.keys == nil || c.offset <= fromOffset {
return nil
}
from := max(0, fromOffset)
to := c.offset
kSlice := c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(from, to), mlx.Slice())
vSlice := c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(from, to), mlx.Slice())
kCopy := mlx.Contiguous(kSlice, false)
vCopy := mlx.Contiguous(vSlice, false)
mlx.Pin(kCopy, vCopy)
mlx.AsyncEval(kCopy, vCopy)
return &kvSnapshot{
keys: kCopy,
values: vCopy,
fromOffset: from,
toOffset: to,
}
}
func (c *KVCache) Restore(snapshot Snapshot, target int) bool {
if target < 0 {
return false
}
if snapshot == nil {
if target > c.offset {
return false
}
c.offset = target
return true
}
snap := snapshot.(*kvSnapshot)
if target > snap.toOffset || c.offset < snap.fromOffset {
return false
}
// Rewind to snapshot start, then feed snapshot.
c.offset = snap.fromOffset
c.appendKV(snap.keys, snap.values)
// Clamp to target if needed (target may be less than full snapshot).
if target < c.offset {
c.offset = target
}
return true
}
func (c *KVCache) Merge(parent, child Snapshot) Snapshot {
if parent == nil || child == nil {
if parent != nil {
parent.Close()
}
if child != nil {
child.Close()
}
return nil
}
p := parent.(*kvSnapshot)
ch := child.(*kvSnapshot)
mk := p.keys.Concatenate(2, ch.keys)
mv := p.values.Concatenate(2, ch.values)
mlx.Pin(mk, mv)
mlx.AsyncEval(mk, mv)
p.Close()
ch.Close()
return &kvSnapshot{
keys: mk,
values: mv,
fromOffset: p.fromOffset,
toOffset: ch.toOffset,
}
}
func (c *KVCache) Split(snapshot Snapshot, at int) (Snapshot, Snapshot) {
if snapshot == nil {
return nil, nil
}
snap := snapshot.(*kvSnapshot)
splitIdx := at - snap.fromOffset
seqLen := snap.toOffset - snap.fromOffset
if splitIdx <= 0 {
return nil, snapshot
}
if splitIdx >= seqLen {
return snapshot, nil
}
pk := mlx.Contiguous(snap.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, splitIdx), mlx.Slice()), false)
pv := mlx.Contiguous(snap.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, splitIdx), mlx.Slice()), false)
ck := mlx.Contiguous(snap.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(splitIdx, seqLen), mlx.Slice()), false)
cv := mlx.Contiguous(snap.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(splitIdx, seqLen), mlx.Slice()), false)
mlx.Pin(pk, pv, ck, cv)
mlx.AsyncEval(pk, pv, ck, cv)
snap.Close()
p := &kvSnapshot{
keys: pk,
values: pv,
fromOffset: snap.fromOffset,
toOffset: at,
}
ch := &kvSnapshot{
keys: ck,
values: cv,
fromOffset: at,
toOffset: snap.toOffset,
}
return p, ch
}
func (c *KVCache) Free() {
mlx.Unpin(c.keys, c.values)
c.keys, c.values = nil, nil
c.offset = 0
}
func (c *KVCache) Offset() int { return c.offset }
// speculativeBase provides the no-op Cache plumbing shared by the
// speculative wrappers; snapshot, restore, merge, and split are
// unsupported on a wrapper.
type speculativeBase struct {
offset int
}
func (s *speculativeBase) Free() {}
func (s *speculativeBase) Offset() int { return s.offset }
func (s *speculativeBase) Snapshot(int) Snapshot { return nil }
func (s *speculativeBase) Restore(Snapshot, int) bool { return false }
func (s *speculativeBase) Merge(parent, child Snapshot) Snapshot { return nil }
func (s *speculativeBase) Split(snapshot Snapshot, at int) (Snapshot, Snapshot) {
return nil, snapshot
}
// speculativeKVCache writes draft K/V through to the target's buffer past
// its committed offset; commit then clamps the live offset to the accepted
// length, discarding any rejected suffix in place.
type speculativeKVCache struct {
speculativeBase
target *KVCache
start int
end int
}
func newSpeculativeKVCache(target *KVCache) *speculativeKVCache {
return &speculativeKVCache{
speculativeBase: speculativeBase{offset: target.Offset()},
target: target,
start: target.Offset(),
end: target.Offset(),
}
}
func (c *speculativeKVCache) Update(b *batch.Batch, keys, values *mlx.Array) *nn.KVHistory {
history := c.target.Update(b, keys, values)
c.offset = c.target.Offset()
c.end = c.target.Offset()
return history
}
func (c *speculativeKVCache) State() []*mlx.Array {
return c.target.State()
}
func (c *speculativeKVCache) commit(n int) {
target := max(c.start, c.start+n)
if target > c.end {
target = c.end
}
c.target.offset = target
c.offset = target
}
// isolatedKVCache buffers speculative K/V in side arrays and never writes
// to the target; returned histories are the target's committed state
// concatenated with the side buffers.
type isolatedKVCache struct {
speculativeBase
target *KVCache
keys, values *mlx.Array
}
func newIsolatedKVCache(target *KVCache) *isolatedKVCache {
return &isolatedKVCache{
speculativeBase: speculativeBase{offset: target.Offset()},
target: target,
}
}
func (c *isolatedKVCache) Update(_ *batch.Batch, keys, values *mlx.Array) *nn.KVHistory {
c.keys = concatKV(c.keys, keys)
c.values = concatKV(c.values, values)
c.offset += keys.Dim(2)
state := c.target.State()
if len(state) < 2 {
return nn.NewKVHistory(c.keys, c.values, nil)
}
return nn.NewKVHistory(state[0].Concatenate(2, c.keys), state[1].Concatenate(2, c.values), nil)
}
func (c *isolatedKVCache) State() []*mlx.Array {
if c.keys == nil || c.values == nil {
return c.target.State()
}
state := c.target.State()
if len(state) < 2 {
return []*mlx.Array{c.keys, c.values}
}
return []*mlx.Array{
state[0].Concatenate(2, c.keys),
state[1].Concatenate(2, c.values),
}
}
// RotatingKVCache implements sliding-window attention with bounded memory.
type RotatingKVCache struct {
maxSize int // window size: maximum number of live positions
idx int // ring-buffer write cursor
*KVCache
}
func NewRotatingKVCache(maxSize int) *RotatingKVCache {
return &RotatingKVCache{maxSize: maxSize, KVCache: NewKVCache()}
}
// Assumes B = 1; heterogeneous batches are not supported.
func (c *RotatingKVCache) Update(b *batch.Batch, keys, values *mlx.Array) *nn.KVHistory {
newK, newV := c.appendKV(keys, values)
return nn.NewKVHistory(newK, newV, rotatingApplier{
b: b,
K: newK.Dim(2),
L: keys.Dim(2),
window: c.maxSize,
ringIdx: c.idx,
dtype: keys.DType(),
})
}
// View returns the current rotating cache contents in logical order for
// assistant KV sharing.
func (c *RotatingKVCache) View(_ *batch.Batch) *nn.KVHistory {
k, v := c.logicalTail(c.maxSize - 1)
if k == nil || v == nil {
return nil
}
return nn.NewKVHistory(k, v, nil)
}
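// logicalTail returns the last keep positions of the cache in logical
// (position) order. Illustration: with maxSize = 4 and ring slots holding
// [p4 p5 p2 p3] (idx = 2, so slot 2 is the oldest), logical order is
// [p2 p3 p4 p5] and keep = 3 returns [p3 p4 p5].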
func (c *RotatingKVCache) logicalTail(keep int) (*mlx.Array, *mlx.Array) {
state := c.State()
if len(state) < 2 || keep <= 0 {
return nil, nil
}
keys, values := state[0], state[1]
K := keys.Dim(2)
if K == 0 {
return nil, nil
}
keep = min(keep, K)
if K > c.maxSize || c.offset < c.maxSize {
start := K - keep
return keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(start, K), mlx.Slice()),
values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(start, K), mlx.Slice())
}
oldest := c.idx % K
var logicalK, logicalV *mlx.Array
if oldest == 0 {
logicalK, logicalV = keys, values
} else {
tailK := keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(oldest, K), mlx.Slice())
tailV := values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(oldest, K), mlx.Slice())
headK := keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, oldest), mlx.Slice())
headV := values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, oldest), mlx.Slice())
logicalK = tailK.Concatenate(2, headK)
logicalV = tailV.Concatenate(2, headV)
}
start := K - keep
return logicalK.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(start, K), mlx.Slice()),
logicalV.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(start, K), mlx.Slice())
}
// speculativeRotatingKVCache buffers draft K/V in side arrays; histories
// are the target's logical tail concatenated with the side buffers, and
// commit appends only the accepted prefix to the target ring.
type speculativeRotatingKVCache struct {
speculativeBase
target *RotatingKVCache
keys, values *mlx.Array
}
func newSpeculativeRotatingKVCache(target *RotatingKVCache) *speculativeRotatingKVCache {
return &speculativeRotatingKVCache{
speculativeBase: speculativeBase{offset: target.Offset()},
target: target,
}
}
func (c *speculativeRotatingKVCache) Update(b *batch.Batch, keys, values *mlx.Array) *nn.KVHistory {
c.keys = concatKV(c.keys, keys)
c.values = concatKV(c.values, values)
c.offset += keys.Dim(2)
oldK, oldV := c.target.logicalTail(c.target.maxSize - 1)
histK, histV := c.keys, c.values
if oldK != nil && oldV != nil {
histK = oldK.Concatenate(2, c.keys)
histV = oldV.Concatenate(2, c.values)
}
return nn.NewKVHistory(histK, histV, logicalSlidingApplier{
b: b,
K: histK.Dim(2),
window: c.target.maxSize,
dtype: keys.DType(),
})
}
func (c *speculativeRotatingKVCache) State() []*mlx.Array {
if c.keys == nil || c.values == nil {
return c.target.State()
}
oldK, oldV := c.target.logicalTail(c.target.maxSize - 1)
if oldK == nil || oldV == nil {
return []*mlx.Array{c.keys, c.values}
}
return []*mlx.Array{oldK.Concatenate(2, c.keys), oldV.Concatenate(2, c.values)}
}
func (c *speculativeRotatingKVCache) commit(n int) {
if c.keys == nil || c.values == nil || n <= 0 {
return
}
n = min(n, c.keys.Dim(2))
c.target.appendKV(prefixKV(c.keys, n), prefixKV(c.values, n))
}
// logicalSlidingApplier intersects the caller's logical mask with a
// sliding-window mask over a history that is already in logical
// (position) order, so no ring gather is needed.
type logicalSlidingApplier struct {
b *batch.Batch
K int
window int
dtype mlx.DType
}
func (a logicalSlidingApplier) ApplyMask(logical nn.AttentionMask) nn.AttentionMask {
return logical.Intersect(nn.SlidingWindowMask(a.b, a.K, a.window, a.dtype))
}
// concatKV appends next to prev along the sequence axis (dim 2), treating
// a nil prev as empty.
func concatKV(prev, next *mlx.Array) *mlx.Array {
if prev == nil {
return next
}
return prev.Concatenate(2, next)
}
// prefixKV returns the first n positions of a along the sequence axis.
func prefixKV(a *mlx.Array, n int) *mlx.Array {
return a.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, n), mlx.Slice())
}
// appendKV is the raw write path shared by Update and Restore: it routes
// to concat for prefill (L > 1) and to update for single-token decode.
func (c *RotatingKVCache) appendKV(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
if keys.Dim(2) > 1 {
return c.concat(keys, values)
}
return c.update(keys, values)
}
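// concat linearizes any wrapped ring state, trims the old content to the
// last maxSize-1 positions, and appends the new prefill chunk.
// Illustration: with maxSize = 4 and wrapped slots [p4 p5 p2 p3] (idx = 2),
// linearizing gives [p2 p3 p4 p5] and the trim keeps [p3 p4 p5] before the
// new tokens are concatenated.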
func (c *RotatingKVCache) concat(keys, values *mlx.Array) (newK *mlx.Array, newV *mlx.Array) {
logutil.Trace("(*RotatingKVCache).concat", "keys_dim", keys.Dims(), "values_dim", values.Dims(), "offset", c.offset, "idx", c.idx, "max_size", c.maxSize)
if c.keys == nil {
c.keys, c.values = keys.Clone(), values.Clone()
mlx.Pin(c.keys, c.values)
} else {
if c.idx < c.keys.Dim(2) {
if c.offset <= c.maxSize {
// Not yet wrapped: slots [c.idx, Dim) are grow padding
// or stale post-rewind data, not live window content.
c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))
c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))
} else {
// Wrapped: logical order is slots[idx..Dim) then slots[0..idx).
// Linearize so the trim + concat below operate on contiguous
// positions and preserve the last (maxSize - 1) old tokens.
tailK := c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(c.idx, c.keys.Dim(2)), mlx.Slice())
tailV := c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(c.idx, c.values.Dim(2)), mlx.Slice())
headK := c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice())
headV := c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice())
c.keys.Set(tailK.Concatenate(2, headK))
c.values.Set(tailV.Concatenate(2, headV))
c.idx = c.keys.Dim(2)
}
}
// Trim to max_size to maintain sliding window
if trim := c.idx - c.maxSize + 1; trim > 0 {
c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(trim, c.keys.Dim(2)), mlx.Slice()))
c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(trim, c.values.Dim(2)), mlx.Slice()))
}
c.keys.Set(c.keys.Concatenate(2, keys))
c.values.Set(c.values.Concatenate(2, values))
c.idx = c.keys.Dim(2)
}
c.offset += keys.Dim(2)
c.idx = c.keys.Dim(2)
return c.keys, c.values
}
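// update writes L new positions at the ring cursor: the buffer grows until
// it reaches maxSize, after which idx wraps to 0 and the oldest slot is
// overwritten. Illustration: with maxSize = 4, offset = 4, idx = 4, the
// next token wraps to slot 0, leaving slots [p4 p1 p2 p3].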
func (c *RotatingKVCache) update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
logutil.Trace("(*RotatingKVCache).update", "keys_dim", keys.Dims(), "values_dim", values.Dims(), "offset", c.offset, "idx", c.idx, "max_size", c.maxSize)
B, H, L, Dk, Dv := keys.Dim(0), keys.Dim(1), keys.Dim(2), keys.Dim(3), values.Dim(3)
prev := c.offset
// Grow buffer if not yet at max
if c.keys == nil || (prev >= c.keys.Dim(2) && c.keys.Dim(2) < c.maxSize) {
newSize := min(c.step, c.maxSize-prev)
newKeys := mlx.Zeros(keys.DType(), B, H, newSize, Dk)
newValues := mlx.Zeros(values.DType(), B, H, newSize, Dv)
if c.keys != nil {
c.keys.Set(c.keys.Concatenate(2, newKeys))
c.values.Set(c.values.Concatenate(2, newValues))
} else {
c.keys, c.values = newKeys, newValues
mlx.Pin(c.keys, c.values)
}
c.idx = prev
}
// Trim to max_size to maintain sliding window
if trim := c.keys.Dim(2) - c.maxSize; trim > 0 {
c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(trim, c.keys.Dim(2)), mlx.Slice()))
c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(trim, c.values.Dim(2)), mlx.Slice()))
c.idx = c.maxSize
}
// Rotate when hitting max
if c.idx >= c.maxSize {
c.idx = 0
}
c.keys.Set(c.keys.SliceUpdate(keys, mlx.Slice(), mlx.Slice(), mlx.Slice(c.idx, c.idx+L), mlx.Slice()))
c.values.Set(c.values.SliceUpdate(values, mlx.Slice(), mlx.Slice(), mlx.Slice(c.idx, c.idx+L), mlx.Slice()))
c.offset += L
c.idx += L
validLen := min(c.offset, c.maxSize)
return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, validLen), mlx.Slice()),
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, validLen), mlx.Slice())
}
func (c *RotatingKVCache) State() []*mlx.Array {
if c.keys == nil || c.values == nil {
return nil
}
liveLen := min(c.offset, c.keys.Dim(2))
return []*mlx.Array{
c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, liveLen), mlx.Slice()),
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, liveLen), mlx.Slice()),
}
}
// rotatingSnapshot holds paged-out data for a RotatingKVCache.
type rotatingSnapshot struct {
kvSnapshot // embedded KV data
idx int // buffer write position at snapshot time
}
func (s *rotatingSnapshot) Size() int { return s.kvSnapshot.Size() }
func (s *rotatingSnapshot) Close() { s.kvSnapshot.Close() }
func (c *RotatingKVCache) Snapshot(fromOffset int) Snapshot {
if c.keys == nil || c.offset <= fromOffset {
return nil
}
state := c.State()
k := state[0].Clone()
v := state[1].Clone()
mlx.Pin(k, v)
return &rotatingSnapshot{
kvSnapshot: kvSnapshot{
keys: k,
values: v,
fromOffset: fromOffset,
toOffset: c.offset,
},
idx: c.idx,
}
}
func (c *RotatingKVCache) Restore(snapshot Snapshot, target int) bool {
if target < 0 {
return false
}
if snapshot == nil {
if target >= c.offset {
return target == c.offset
}
// Live rewind is only safe when the buffer hasn't filled yet
// (offset <= maxSize). Once the window has shifted, rewinding
// leaves fewer than maxSize trailing tokens to attend to —
// a snapshot is required to restore the full window.
if c.offset > c.maxSize {
return false
}
c.offset = target
c.idx = target
return true
}
snap := snapshot.(*rotatingSnapshot)
if target > snap.toOffset {
return false
}
// Reject if clamping would leave an incomplete window.
if target < snap.toOffset && snap.toOffset > c.maxSize {
return false
}
// Restore from snapshot: rebuild buffer state.
// Free existing state first.
if c.keys != nil {
mlx.Unpin(c.keys, c.values)
}
c.keys = snap.keys.Clone()
c.values = snap.values.Clone()
mlx.Pin(c.keys, c.values)
c.offset = snap.toOffset
c.idx = snap.idx
// Clamp to target if needed.
if target < c.offset {
c.offset = target
c.idx = target
}
return true
}
func (c *RotatingKVCache) Merge(parent, child Snapshot) Snapshot {
// For rotating caches, the child snapshot supersedes the parent
// since it contains the full window state.
if parent != nil {
parent.Close()
}
return child
}
func (c *RotatingKVCache) Split(snapshot Snapshot, at int) (Snapshot, Snapshot) {
// Rotating cache snapshots contain the full window state.
// Cannot cleanly split a ring buffer at an arbitrary point.
return nil, snapshot
}
func (c *RotatingKVCache) Free() {
c.KVCache.Free()
c.idx = 0
}
// rotatingApplier composes the sliding-window storage restriction
// onto the caller's logical mask.
//
// ringIdx is the cache's write cursor at Update time. At L=1 decode
// the ring buffer is not position-ordered — logical col j lives at
// storage slot (ringIdx+j) mod K — so tensor masks built in
// logical space must be gathered into this layout before the kernel
// sees them. At L>1 prefill the concat path has already linearised
// storage, so the gather is identity and ringIdx is unused.
type rotatingApplier struct {
b *batch.Batch
K int
L int
window int
ringIdx int
dtype mlx.DType
}
func (r rotatingApplier) ApplyMask(logical nn.AttentionMask) nn.AttentionMask {
if r.L == 1 {
// Single-query decode: storage already enforces the window
// (Update keeps the last maxSize tokens, all within
// [absQ-window+1, absQ]), and every stored key's absolute
// position <= absQ. For a zero or plain-causal logical mask
// both constraints reduce to "no mask", so return the zero
// mask and let SDPA dispatch to mode="".
if logical.IsZero() || logical.IsCausal() {
return nn.AttentionMask{}
}
// Tensor-backed mask (user ArrayMask, causal+Relax, causal
// with accumulated array): materialize in logical-position
// order then gather K cols into ring-slot order so they
// align with the cache output the kernel will index.
arr := logical.AsArray(r.b, r.K, r.dtype)
arr = gatherRingCols(arr, r.ringIdx, r.K)
return nn.ArrayMask(arr)
}
return logical.Intersect(nn.SlidingWindowMask(r.b, r.K, r.window, r.dtype))
}
// gatherRingCols reorders a [B, 1, L, K] mask's K axis from
// logical-position order (col 0 = oldest stored position) into the
// cache's ring-slot order (col 0 = buffer slot 0). Logical col j
// lives at slot (ringIdx+j) mod K, so storage slot s reads from
// logical col (s-ringIdx+K) mod K. Returns arr unchanged when the
// permutation is a no-op: ringIdx % K == 0 (layouts coincide), or
// the K axis broadcasts (dim 3 == 1, i.e. Q-padding-shaped masks
// where every key shares the same value).
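// Illustration: K = 4, ringIdx = 1 gives shift = 3, reordering columns
// from logical [0 1 2 3] to [3 0 1 2]; storage slot 0 reads logical col 3.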
func gatherRingCols(arr *mlx.Array, ringIdx, K int) *mlx.Array {
if w := arr.Dim(3); w != 1 && w != K {
panic(fmt.Sprintf("gatherRingCols: K-axis width %d must be 1 or %d", w, K))
}
ringIdx %= K
if ringIdx == 0 || arr.Dim(3) == 1 {
return arr
}
shift := K - ringIdx
tail := arr.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(), mlx.Slice(shift, K))
head := arr.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(), mlx.Slice(0, shift))
return tail.Concatenate(3, head)
}