mlxrunner: apply RoPE at per-row positions

Switch RoPE from the scalar-offset kernel (mlx_fast_rope) to the
array-offset one (mlx_fast_rope_dynamic) so each batch row can start
at its own position. The pipeline tracks the current position locally
and passes it to the model through Batch.SeqOffsets; each model
materializes that slice into an int32 array for the RoPE call.

Single-sequence behavior is unchanged; this is the wiring needed
before the runner can batch independent sequences.
This commit is contained in:
Jesse Gross 2026-04-21 16:15:59 -07:00
parent 088dfd89a8
commit bd21678b16
10 changed files with 75 additions and 108 deletions

View file

@ -6,4 +6,9 @@ import "github.com/ollama/ollama/x/mlxrunner/mlx"
// Batch carries the inputs for one model forward pass.
type Batch struct {
// InputIDs is the input token IDs for this forward pass, shape (B, L).
InputIDs *mlx.Array
// SeqOffsets gives each row's current position within its sequence —
// where the chunk in InputIDs starts. Length equals the batch dimension
// of InputIDs.
SeqOffsets []int32
}

View file

@ -44,29 +44,3 @@ func (r *RMSNorm) Forward(x *Array, eps float32) *Array {
return out
}
// RoPE holds the configuration for rotary position embeddings.
type RoPE struct {
// Dims is the number of feature dimensions to rotate, passed to the kernel.
Dims int
// Traditional selects the traditional (interleaved) rotation layout.
Traditional bool
// Base is the frequency base (theta). A zero Base is passed to the kernel
// as "no value", letting the kernel use its default (see Forward).
Base float32 `json:"rope_theta"`
// Scale is forwarded as the RoPE scale factor.
Scale float32
}
// Forward applies rotary position embeddings to t, with all positions
// starting at the scalar sequence offset. When r.Base is zero the optional
// base is passed without a value, so the kernel falls back to its default.
func (r RoPE) Forward(t *Array, offset int) *Array {
	// No custom frequencies: pass an empty array handle so the kernel
	// derives frequencies from the base (presumably New("") yields an
	// unset mlx_array — TODO confirm against New's definition).
	freqs := New("")
	out := New("FAST_ROPE")
	C.mlx_fast_rope(
		&out.ctx,
		t.ctx,
		C.int(r.Dims),
		C._Bool(r.Traditional),
		C.mlx_optional_float{
			value: C.float(r.Base),
			// Direct boolean expression; the IIFE here was redundant.
			has_value: C._Bool(r.Base != 0),
		},
		C.float(r.Scale),
		C.int(offset),
		freqs.ctx,
		DefaultStream().ctx,
	)
	return out
}

View file

@ -407,15 +407,18 @@ func GatherMM(a, b *Array, lhsIndices, rhsIndices *Array, sortedIndices bool) *A
return a.GatherMM(b, lhsIndices, rhsIndices, sortedIndices)
}
func RoPEWithBase(x *Array, dims int, traditional bool, base, scale float32, offset int) *Array {
return RoPEWithFreqs(x, dims, traditional, base, scale, offset, nil)
// RoPEWithBase applies rotary position embeddings to x. offsets is an
// int32 array of shape [B] giving each batch row's starting position;
// the kernel applies positions offsets[b] + 0..T-1 per row.
// It delegates to RoPEWithFreqs with freqs nil, so the rotation
// frequencies are computed from base rather than supplied explicitly.
func RoPEWithBase(x *Array, dims int, traditional bool, base, scale float32, offsets *Array) *Array {
return RoPEWithFreqs(x, dims, traditional, base, scale, offsets, nil)
}
// RoPEWithFreqs applies RoPE with optional custom frequencies.
// When freqs is non-nil, it is used instead of computing from base.
// Note: MLX takes reciprocal(freqs) internally to get inv_freq, so pass
// the actual frequencies (base^(2i/dim)), not the inverse frequencies.
func RoPEWithFreqs(x *Array, dims int, traditional bool, base, scale float32, offset int, freqs *Array) *Array {
func RoPEWithFreqs(x *Array, dims int, traditional bool, base, scale float32, offsets *Array, freqs *Array) *Array {
var freqsCtx C.mlx_array
var optBase C.mlx_optional_float
if freqs != nil {
@ -430,14 +433,14 @@ func RoPEWithFreqs(x *Array, dims int, traditional bool, base, scale float32, of
}
}
out := New("FAST_ROPE")
C.mlx_fast_rope(
C.mlx_fast_rope_dynamic(
&out.ctx,
x.ctx,
C.int(dims),
C.bool(traditional),
optBase,
C.float(scale),
C.int(offset),
offsets.ctx,
freqsCtx,
DefaultStream().ctx,
)

View file

@ -106,6 +106,7 @@ func (r *Runner) TextGenerationPipeline(ctx context.Context, request Request) er
now := time.Now()
total, processed := len(tokens), 0
position := len(inputs) - len(tokens)
for total-processed > 1 {
if err := ctx.Err(); err != nil {
return err
@ -116,23 +117,25 @@ func (r *Runner) TextGenerationPipeline(ctx context.Context, request Request) er
// If there's a pending snapshot, split the batch so we can
// capture it at the exact offset.
if snapOffset := session.nextPendingSnapshot(); snapOffset > 0 {
baseOffset := len(session.inputs) - len(tokens)
tokensUntilSnapshot := snapOffset - (baseOffset + processed)
tokensUntilSnapshot := snapOffset - position
if tokensUntilSnapshot > 0 && tokensUntilSnapshot < n {
n = tokensUntilSnapshot
}
}
r.Model.Forward(&batch.Batch{InputIDs: mlx.FromValues(tokens[processed:processed+n], 1, n)}, caches)
r.Model.Forward(&batch.Batch{
InputIDs: mlx.FromValues(tokens[processed:processed+n], 1, n),
SeqOffsets: []int32{int32(position)},
}, caches)
mlx.Sweep()
materializeCaches()
processed += n
position += n
slog.Info("Prompt processing progress", "processed", processed, "total", total)
// Create snapshot if we've reached a pending offset.
if snapOffset := session.nextPendingSnapshot(); snapOffset > 0 {
baseOffset := len(session.inputs) - len(tokens)
if baseOffset+processed >= snapOffset {
if position >= snapOffset {
session.snapshot()
}
}
@ -144,7 +147,11 @@ func (r *Runner) TextGenerationPipeline(ctx context.Context, request Request) er
r.Sampler.Add(pipelineSlot, request.SamplerOpts, inputs)
step := func(token *mlx.Array) sampler.Result {
fwd := r.Model.Forward(&batch.Batch{InputIDs: token}, caches)
fwd := r.Model.Forward(&batch.Batch{
InputIDs: token,
SeqOffsets: []int32{int32(position)},
}, caches)
position += token.Dim(1)
logits := r.Model.Unembed(fwd)
logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)

View file

@ -406,6 +406,7 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
func (m *Model) Forward(b *batch.Batch, caches []cache.Cache) *mlx.Array {
dims := b.InputIDs.Dims()
B, L := int32(dims[0]), int32(dims[1])
positions := mlx.FromValues(b.SeqOffsets, len(b.SeqOffsets))
h := m.EmbedTokens.Forward(b.InputIDs)
h = mlx.MulScalar(h, float32(math.Sqrt(float64(m.HiddenSize))))
@ -415,7 +416,7 @@ func (m *Model) Forward(b *batch.Batch, caches []cache.Cache) *mlx.Array {
if caches != nil && i < len(caches) {
c = caches[i]
}
h = layer.Forward(h, c, B, L, m.TextConfig)
h = layer.Forward(h, c, positions, B, L, m.TextConfig)
}
return mlx.RMSNormFn(h, m.NormScaled, m.RMSNormEps)
@ -455,10 +456,10 @@ func (m *Model) FormatPrompt(prompt string) string {
return fmt.Sprintf("<start_of_turn>user\n%s<end_of_turn>\n<start_of_turn>model\n", prompt)
}
func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *TextConfig) *mlx.Array {
func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, positions *mlx.Array, B, L int32, cfg *TextConfig) *mlx.Array {
normed := mlx.RMSNormFn(x, l.InputNormScaled, cfg.RMSNormEps)
attnOut := l.Attention.Forward(normed, c, B, L, l.IsSliding, cfg)
attnOut := l.Attention.Forward(normed, c, positions, B, L, l.IsSliding, cfg)
attnOut = mlx.RMSNormFn(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps)
h := mlx.Add(x, attnOut)
@ -470,7 +471,7 @@ func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Tex
return mlx.Add(h, mlpOut)
}
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding bool, cfg *TextConfig) *mlx.Array {
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, positions *mlx.Array, B, L int32, isSliding bool, cfg *TextConfig) *mlx.Array {
q := a.QProj.Forward(x)
k := a.KProj.Forward(x)
v := a.VProj.Forward(x)
@ -492,12 +493,8 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding b
ropeTheta = cfg.RopeLocalBaseFreq
}
offset := 0
if c != nil {
offset = c.Offset()
}
q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, ropeTheta, 1.0, offset)
k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, ropeTheta, 1.0, offset)
q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, ropeTheta, 1.0, positions)
k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, ropeTheta, 1.0, positions)
if c != nil {
k, v = c.Update(k, v)

View file

@ -90,8 +90,7 @@ type TextConfig struct {
// sharedKVEntry stores cached KV state from a donor layer for KV sharing.
type sharedKVEntry struct {
K, V *mlx.Array
Offset int // RoPE offset from donor's cache
K, V *mlx.Array
}
// Attention implements Gemma 4 attention with Q/K normalization and v-norm.
@ -1017,6 +1016,7 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
func (m *Model) Forward(b *batch.Batch, caches []cache.Cache) *mlx.Array {
dims := b.InputIDs.Dims()
B, L := int32(dims[0]), int32(dims[1])
positions := mlx.FromValues(b.SeqOffsets, len(b.SeqOffsets))
h := m.EmbedTokens.Forward(b.InputIDs)
h = mlx.MulScalar(h, m.EmbedScale)
@ -1063,7 +1063,7 @@ func (m *Model) Forward(b *batch.Batch, caches []cache.Cache) *mlx.Array {
}
var donorKV *sharedKVEntry
h, donorKV = layer.Forward(h, c, B, L, m.TextConfig, pleInput, donorEntry, smc)
h, donorKV = layer.Forward(h, c, positions, B, L, m.TextConfig, pleInput, donorEntry, smc)
// If this layer is a donor, store its cached KV for later shared layers.
if layer.IsDonor && donorKV != nil {
@ -1189,9 +1189,9 @@ func sliceLayerDim(combined *mlx.Array, layerIdx, B, L, pleDim int32) *mlx.Array
return mlx.Squeeze(sliced, 2)
}
func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *TextConfig, pleInput *mlx.Array, donorEntry *sharedKVEntry, slidingMaskCache *slidingMaskCache) (*mlx.Array, *sharedKVEntry) {
func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, positions *mlx.Array, B, L int32, cfg *TextConfig, pleInput *mlx.Array, donorEntry *sharedKVEntry, slidingMaskCache *slidingMaskCache) (*mlx.Array, *sharedKVEntry) {
normed := mlx.RMSNormFn(x, l.InputNormScaled, cfg.RMSNormEps)
attnOut, donorKV := l.Attention.Forward(normed, c, B, L, l.IsSliding, cfg, donorEntry, slidingMaskCache)
attnOut, donorKV := l.Attention.Forward(normed, c, positions, B, L, l.IsSliding, cfg, donorEntry, slidingMaskCache)
attnOut = mlx.RMSNormFn(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps)
h := mlx.Add(x, attnOut)
@ -1239,7 +1239,7 @@ func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Tex
return h, donorKV
}
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding bool, cfg *TextConfig, donorEntry *sharedKVEntry, slidingMaskCache *slidingMaskCache) (*mlx.Array, *sharedKVEntry) {
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, positions *mlx.Array, B, L int32, isSliding bool, cfg *TextConfig, donorEntry *sharedKVEntry, slidingMaskCache *slidingMaskCache) (*mlx.Array, *sharedKVEntry) {
// Determine head dim and scale based on layer type.
headDim := cfg.HeadDim
scale := cfg.SlidingScale
@ -1259,18 +1259,11 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding b
// Apply Q norm.
q = mlx.RMSNormFn(q, a.QNormScaled, cfg.RMSNormEps)
// RoPE offset: use cache offset for non-shared layers, donor offset for shared.
offset := 0
if donorEntry != nil {
offset = donorEntry.Offset - int(L)
} else if c != nil {
offset = c.Offset()
}
var ropeFreqs *mlx.Array
if !isSliding {
ropeFreqs = cfg.FullRopeFreqs
}
q = mlx.RoPEWithFreqs(q, ropeDims, false, ropeBase, 1.0, offset, ropeFreqs)
q = mlx.RoPEWithFreqs(q, ropeDims, false, ropeBase, 1.0, positions, ropeFreqs)
var k, v *mlx.Array
var donorKV *sharedKVEntry
@ -1304,7 +1297,7 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding b
k = mlx.RMSNormFn(k, a.KNormScaled, cfg.RMSNormEps)
// Apply RoPE to K.
k = mlx.RoPEWithFreqs(k, ropeDims, false, ropeBase, 1.0, offset, ropeFreqs)
k = mlx.RoPEWithFreqs(k, ropeDims, false, ropeBase, 1.0, positions, ropeFreqs)
// Apply V norm (no learnable weight, pure RMS normalization).
v = mlx.RMSNormFn(v, nil, cfg.RMSNormEps)
@ -1312,7 +1305,7 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding b
// Update cache.
if c != nil {
k, v = c.Update(k, v)
donorKV = &sharedKVEntry{K: k, V: v, Offset: c.Offset()}
donorKV = &sharedKVEntry{K: k, V: v}
}
}

View file

@ -88,7 +88,7 @@ type MLAAttention struct {
}
// Forward computes absorbed MLA attention output.
func (a *MLAAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
func (a *MLAAttention) Forward(x *mlx.Array, c cache.Cache, positions *mlx.Array, B, L int32, cfg *Config) *mlx.Array {
q := a.QAProj.Forward(x)
q = a.QALayerNorm.Forward(q, cfg.RMSNormEps)
q = a.QBProj.Forward(q)
@ -110,12 +110,8 @@ func (a *MLAAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Con
kvLatent := a.KVALayerNorm.Forward(kvCompressed, cfg.RMSNormEps)
kvLatent = mlx.ExpandDims(kvLatent, 1)
offset := 0
if c != nil {
offset = c.Offset()
}
qPE = mlx.RoPEWithBase(qPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset)
kPE = mlx.RoPEWithBase(kPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset)
qPE = mlx.RoPEWithBase(qPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, positions)
kPE = mlx.RoPEWithBase(kPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, positions)
qLatent := a.EmbedQ.Forward(qNope)
@ -310,11 +306,11 @@ type DenseBlock struct {
}
// Forward applies the dense block
func (b *DenseBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg)
func (blk *DenseBlock) Forward(x *mlx.Array, c cache.Cache, positions *mlx.Array, B, L int32, cfg *Config) *mlx.Array {
r := blk.Attention.Forward(blk.InputLayerNorm.Forward(x, cfg.RMSNormEps), c, positions, B, L, cfg)
h := mlx.Add(x, r)
r = b.MLP.Forward(b.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps))
r = blk.MLP.Forward(blk.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps))
return mlx.Add(h, r)
}
@ -327,17 +323,17 @@ type MoEBlock struct {
}
// Forward applies the MoE block
func (b *MoEBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg)
func (blk *MoEBlock) Forward(x *mlx.Array, c cache.Cache, positions *mlx.Array, B, L int32, cfg *Config) *mlx.Array {
r := blk.Attention.Forward(blk.InputLayerNorm.Forward(x, cfg.RMSNormEps), c, positions, B, L, cfg)
h := mlx.Add(x, r)
r = b.MoE.Forward(b.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps), cfg)
r = blk.MoE.Forward(blk.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps), cfg)
return mlx.Add(h, r)
}
// Block interface for both dense and MoE blocks
type Block interface {
Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array
Forward(x *mlx.Array, c cache.Cache, positions *mlx.Array, B, L int32, cfg *Config) *mlx.Array
}
// Model represents the complete GLM4-MoE-Lite model
@ -702,6 +698,7 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
func (m *Model) Forward(b *batch.Batch, caches []cache.Cache) *mlx.Array {
dims := b.InputIDs.Dims()
B, L := int32(dims[0]), int32(dims[1])
positions := mlx.FromValues(b.SeqOffsets, len(b.SeqOffsets))
h := m.EmbedTokens.Forward(b.InputIDs)
@ -710,7 +707,7 @@ func (m *Model) Forward(b *batch.Batch, caches []cache.Cache) *mlx.Array {
if caches != nil {
c = caches[i]
}
h = layer.Forward(h, c, B, L, m.Config)
h = layer.Forward(h, c, positions, B, L, m.Config)
}
h = m.Norm.Forward(h, m.RMSNormEps)

View file

@ -240,6 +240,7 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
func (m *Model) Forward(b *batch.Batch, caches []cache.Cache) *mlx.Array {
dims := b.InputIDs.Dims()
B, L := int32(dims[0]), int32(dims[1])
positions := mlx.FromValues(b.SeqOffsets, len(b.SeqOffsets))
h := m.EmbedTokens.Forward(b.InputIDs)
for i, layer := range m.Layers {
@ -247,7 +248,7 @@ func (m *Model) Forward(b *batch.Batch, caches []cache.Cache) *mlx.Array {
if caches != nil && i < len(caches) {
c = caches[i]
}
h = layer.Forward(h, c, B, L, m.Config)
h = layer.Forward(h, c, positions, B, L, m.Config)
}
return m.Norm.Forward(h, m.RMSNormEps)
@ -277,12 +278,12 @@ func (m *Model) NewCaches() []cache.Cache {
return caches
}
func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
h := mlx.Add(x, l.Attention.Forward(l.AttentionNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg))
func (l *Layer) Forward(x *mlx.Array, c cache.Cache, positions *mlx.Array, B, L int32, cfg *Config) *mlx.Array {
h := mlx.Add(x, l.Attention.Forward(l.AttentionNorm.Forward(x, cfg.RMSNormEps), c, positions, B, L, cfg))
return mlx.Add(h, l.MLP.Forward(l.MLPNorm.Forward(h, cfg.RMSNormEps)))
}
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, positions *mlx.Array, B, L int32, cfg *Config) *mlx.Array {
q := a.QProj.Forward(x)
k := a.KProj.Forward(x)
v := a.VProj.Forward(x)
@ -296,12 +297,8 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config
v = mlx.Reshape(v, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
v = mlx.Transpose(v, 0, 2, 1, 3)
offset := 0
if c != nil {
offset = c.Offset()
}
q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, positions)
k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, positions)
if c != nil {
k, v = c.Update(k, v)

View file

@ -257,6 +257,7 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
func (m *Model) Forward(b *batch.Batch, caches []cache.Cache) *mlx.Array {
dims := b.InputIDs.Dims()
B, L := int32(dims[0]), int32(dims[1])
positions := mlx.FromValues(b.SeqOffsets, len(b.SeqOffsets))
h := m.EmbedTokens.Forward(b.InputIDs)
for i, layer := range m.Layers {
@ -264,7 +265,7 @@ func (m *Model) Forward(b *batch.Batch, caches []cache.Cache) *mlx.Array {
if caches != nil && i < len(caches) {
c = caches[i]
}
h = layer.Forward(h, c, B, L, m.Config)
h = layer.Forward(h, c, positions, B, L, m.Config)
}
return m.Norm.Forward(h, m.RMSNormEps)
@ -294,12 +295,12 @@ func (m *Model) NewCaches() []cache.Cache {
return caches
}
func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
h := mlx.Add(x, l.Attention.Forward(l.AttentionNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg))
func (l *Layer) Forward(x *mlx.Array, c cache.Cache, positions *mlx.Array, B, L int32, cfg *Config) *mlx.Array {
h := mlx.Add(x, l.Attention.Forward(l.AttentionNorm.Forward(x, cfg.RMSNormEps), c, positions, B, L, cfg))
return mlx.Add(h, l.MLP.Forward(l.MLPNorm.Forward(h, cfg.RMSNormEps)))
}
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, positions *mlx.Array, B, L int32, cfg *Config) *mlx.Array {
q := a.QProj.Forward(x)
k := a.KProj.Forward(x)
v := a.VProj.Forward(x)
@ -315,12 +316,8 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config
k = mlx.Transpose(k, 0, 2, 1, 3)
v = mlx.Transpose(v, 0, 2, 1, 3)
offset := 0
if c != nil {
offset = c.Offset()
}
q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, positions)
k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, positions)
if c != nil {
k, v = c.Update(k, v)

View file

@ -1127,7 +1127,7 @@ func splitQKVZBA(mixedQKVZ, mixedBA *mlx.Array, cfg *Config, B, L int32) (q, k,
return q, k, v, z, b, a
}
func (a *FullAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
func (a *FullAttention) Forward(x *mlx.Array, c cache.Cache, positions *mlx.Array, B, L int32, cfg *Config) *mlx.Array {
qg := a.QProj.Forward(x)
qg = mlx.Reshape(qg, B, L, cfg.NumAttentionHeads, cfg.HeadDim*2)
q := mlx.SliceStartStop(qg, []int32{0, 0, 0, 0}, []int32{B, L, cfg.NumAttentionHeads, cfg.HeadDim})
@ -1146,12 +1146,8 @@ func (a *FullAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Co
k = mlx.Transpose(k, 0, 2, 1, 3)
v = mlx.Transpose(v, 0, 2, 1, 3)
offset := 0
if c != nil {
offset = c.Offset()
}
q = mlx.RoPEWithBase(q, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, offset)
k = mlx.RoPEWithBase(k, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, offset)
q = mlx.RoPEWithBase(q, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, positions)
k = mlx.RoPEWithBase(k, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, positions)
if c != nil {
k, v = c.Update(k, v)
@ -1333,13 +1329,13 @@ func (m *SparseMoE) Forward(x *mlx.Array, cfg *Config) *mlx.Array {
return mlx.Reshape(y, B, L, cfg.HiddenSize)
}
func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
func (l *Layer) Forward(x *mlx.Array, c cache.Cache, positions *mlx.Array, B, L int32, cfg *Config) *mlx.Array {
var r *mlx.Array
normed := l.InputNorm.Forward(x, cfg.RMSNormEps)
if l.IsLinear {
r = l.Linear.Forward(normed, c, B, L, cfg)
} else {
r = l.FullAttn.Forward(normed, c, B, L, cfg)
r = l.FullAttn.Forward(normed, c, positions, B, L, cfg)
}
h := mlx.Add(x, r)
r = l.MLP.Forward(l.PostAttentionNorm.Forward(h, cfg.RMSNormEps), cfg)
@ -1349,6 +1345,7 @@ func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *m
func (m *Model) Forward(b *batch.Batch, caches []cache.Cache) *mlx.Array {
dims := b.InputIDs.Dims()
B, L := int32(dims[0]), int32(dims[1])
positions := mlx.FromValues(b.SeqOffsets, len(b.SeqOffsets))
h := m.EmbedTokens.Forward(b.InputIDs)
for i, layer := range m.Layers {
@ -1356,7 +1353,7 @@ func (m *Model) Forward(b *batch.Batch, caches []cache.Cache) *mlx.Array {
if caches != nil && i < len(caches) {
c = caches[i]
}
h = layer.Forward(h, c, B, L, m.Config)
h = layer.Forward(h, c, positions, B, L, m.Config)
}
out := m.Norm.Forward(h, m.RMSNormEps)
return out