mirror of
https://github.com/ollama/ollama.git
synced 2026-05-13 06:21:28 +00:00
mlxrunner: apply RoPE at per-row positions
Switch RoPE from the scalar-offset kernel (mlx_fast_rope) to the array-offset one (mlx_fast_rope_dynamic) so each batch row can start at its own position. The pipeline tracks the current position locally and passes it to the model through Batch.SeqOffsets; each model materializes that slice into an int32 array for the RoPE call. Single-sequence behavior is unchanged; this is the wiring needed before the runner can batch independent sequences.
This commit is contained in:
parent
088dfd89a8
commit
bd21678b16
10 changed files with 75 additions and 108 deletions
|
|
@ -6,4 +6,9 @@ import "github.com/ollama/ollama/x/mlxrunner/mlx"
|
|||
// Batch carries the inputs for a single model forward pass: the token IDs
// and, for each batch row, the position at which that row's tokens start.
// Models turn SeqOffsets into an array with mlx.FromValues and pass it to
// the RoPE call as the per-row starting positions.
type Batch struct {
|
||||
// InputIDs is the input token IDs for this forward pass, shape (B, L).
|
||||
InputIDs *mlx.Array
|
||||
|
||||
// SeqOffsets gives each row's current position within its sequence —
|
||||
// where the chunk in InputIDs starts. Length equals the batch dimension
|
||||
// of InputIDs.
|
||||
SeqOffsets []int32
|
||||
}
|
||||
|
|
|
|||
|
|
@ -44,29 +44,3 @@ func (r *RMSNorm) Forward(x *Array, eps float32) *Array {
|
|||
return out
|
||||
}
|
||||
|
||||
type RoPE struct {
|
||||
Dims int
|
||||
Traditional bool
|
||||
Base float32 `json:"rope_theta"`
|
||||
Scale float32
|
||||
}
|
||||
|
||||
func (r RoPE) Forward(t *Array, offset int) *Array {
|
||||
freqs := New("")
|
||||
out := New("FAST_ROPE")
|
||||
C.mlx_fast_rope(
|
||||
&out.ctx,
|
||||
t.ctx,
|
||||
C.int(r.Dims),
|
||||
C._Bool(r.Traditional),
|
||||
C.mlx_optional_float{
|
||||
value: C.float(r.Base),
|
||||
has_value: C._Bool(func() bool { return r.Base != 0 }()),
|
||||
},
|
||||
C.float(r.Scale),
|
||||
C.int(offset),
|
||||
freqs.ctx,
|
||||
DefaultStream().ctx,
|
||||
)
|
||||
return out
|
||||
}
|
||||
|
|
|
|||
|
|
@ -407,15 +407,18 @@ func GatherMM(a, b *Array, lhsIndices, rhsIndices *Array, sortedIndices bool) *A
|
|||
return a.GatherMM(b, lhsIndices, rhsIndices, sortedIndices)
|
||||
}
|
||||
|
||||
func RoPEWithBase(x *Array, dims int, traditional bool, base, scale float32, offset int) *Array {
|
||||
return RoPEWithFreqs(x, dims, traditional, base, scale, offset, nil)
|
||||
// RoPEWithBase applies rotary position embeddings to x. offsets is an
|
||||
// int32 array of shape [B] giving each batch row's starting position;
|
||||
// the kernel applies positions offsets[b] + 0..T-1 per row.
|
||||
func RoPEWithBase(x *Array, dims int, traditional bool, base, scale float32, offsets *Array) *Array {
|
||||
return RoPEWithFreqs(x, dims, traditional, base, scale, offsets, nil)
|
||||
}
|
||||
|
||||
// RoPEWithFreqs applies RoPE with optional custom frequencies.
|
||||
// When freqs is non-nil, it is used instead of computing from base.
|
||||
// Note: MLX takes reciprocal(freqs) internally to get inv_freq, so pass
|
||||
// the actual frequencies (base^(2i/dim)), not the inverse frequencies.
|
||||
func RoPEWithFreqs(x *Array, dims int, traditional bool, base, scale float32, offset int, freqs *Array) *Array {
|
||||
func RoPEWithFreqs(x *Array, dims int, traditional bool, base, scale float32, offsets *Array, freqs *Array) *Array {
|
||||
var freqsCtx C.mlx_array
|
||||
var optBase C.mlx_optional_float
|
||||
if freqs != nil {
|
||||
|
|
@ -430,14 +433,14 @@ func RoPEWithFreqs(x *Array, dims int, traditional bool, base, scale float32, of
|
|||
}
|
||||
}
|
||||
out := New("FAST_ROPE")
|
||||
C.mlx_fast_rope(
|
||||
C.mlx_fast_rope_dynamic(
|
||||
&out.ctx,
|
||||
x.ctx,
|
||||
C.int(dims),
|
||||
C.bool(traditional),
|
||||
optBase,
|
||||
C.float(scale),
|
||||
C.int(offset),
|
||||
offsets.ctx,
|
||||
freqsCtx,
|
||||
DefaultStream().ctx,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -106,6 +106,7 @@ func (r *Runner) TextGenerationPipeline(ctx context.Context, request Request) er
|
|||
|
||||
now := time.Now()
|
||||
total, processed := len(tokens), 0
|
||||
position := len(inputs) - len(tokens)
|
||||
for total-processed > 1 {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
|
|
@ -116,23 +117,25 @@ func (r *Runner) TextGenerationPipeline(ctx context.Context, request Request) er
|
|||
// If there's a pending snapshot, split the batch so we can
|
||||
// capture it at the exact offset.
|
||||
if snapOffset := session.nextPendingSnapshot(); snapOffset > 0 {
|
||||
baseOffset := len(session.inputs) - len(tokens)
|
||||
tokensUntilSnapshot := snapOffset - (baseOffset + processed)
|
||||
tokensUntilSnapshot := snapOffset - position
|
||||
if tokensUntilSnapshot > 0 && tokensUntilSnapshot < n {
|
||||
n = tokensUntilSnapshot
|
||||
}
|
||||
}
|
||||
|
||||
r.Model.Forward(&batch.Batch{InputIDs: mlx.FromValues(tokens[processed:processed+n], 1, n)}, caches)
|
||||
r.Model.Forward(&batch.Batch{
|
||||
InputIDs: mlx.FromValues(tokens[processed:processed+n], 1, n),
|
||||
SeqOffsets: []int32{int32(position)},
|
||||
}, caches)
|
||||
mlx.Sweep()
|
||||
materializeCaches()
|
||||
processed += n
|
||||
position += n
|
||||
slog.Info("Prompt processing progress", "processed", processed, "total", total)
|
||||
|
||||
// Create snapshot if we've reached a pending offset.
|
||||
if snapOffset := session.nextPendingSnapshot(); snapOffset > 0 {
|
||||
baseOffset := len(session.inputs) - len(tokens)
|
||||
if baseOffset+processed >= snapOffset {
|
||||
if position >= snapOffset {
|
||||
session.snapshot()
|
||||
}
|
||||
}
|
||||
|
|
@ -144,7 +147,11 @@ func (r *Runner) TextGenerationPipeline(ctx context.Context, request Request) er
|
|||
r.Sampler.Add(pipelineSlot, request.SamplerOpts, inputs)
|
||||
|
||||
step := func(token *mlx.Array) sampler.Result {
|
||||
fwd := r.Model.Forward(&batch.Batch{InputIDs: token}, caches)
|
||||
fwd := r.Model.Forward(&batch.Batch{
|
||||
InputIDs: token,
|
||||
SeqOffsets: []int32{int32(position)},
|
||||
}, caches)
|
||||
position += token.Dim(1)
|
||||
logits := r.Model.Unembed(fwd)
|
||||
logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
|
||||
|
||||
|
|
|
|||
|
|
@ -406,6 +406,7 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
|||
func (m *Model) Forward(b *batch.Batch, caches []cache.Cache) *mlx.Array {
|
||||
dims := b.InputIDs.Dims()
|
||||
B, L := int32(dims[0]), int32(dims[1])
|
||||
positions := mlx.FromValues(b.SeqOffsets, len(b.SeqOffsets))
|
||||
|
||||
h := m.EmbedTokens.Forward(b.InputIDs)
|
||||
h = mlx.MulScalar(h, float32(math.Sqrt(float64(m.HiddenSize))))
|
||||
|
|
@ -415,7 +416,7 @@ func (m *Model) Forward(b *batch.Batch, caches []cache.Cache) *mlx.Array {
|
|||
if caches != nil && i < len(caches) {
|
||||
c = caches[i]
|
||||
}
|
||||
h = layer.Forward(h, c, B, L, m.TextConfig)
|
||||
h = layer.Forward(h, c, positions, B, L, m.TextConfig)
|
||||
}
|
||||
|
||||
return mlx.RMSNormFn(h, m.NormScaled, m.RMSNormEps)
|
||||
|
|
@ -455,10 +456,10 @@ func (m *Model) FormatPrompt(prompt string) string {
|
|||
return fmt.Sprintf("<start_of_turn>user\n%s<end_of_turn>\n<start_of_turn>model\n", prompt)
|
||||
}
|
||||
|
||||
func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *TextConfig) *mlx.Array {
|
||||
func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, positions *mlx.Array, B, L int32, cfg *TextConfig) *mlx.Array {
|
||||
normed := mlx.RMSNormFn(x, l.InputNormScaled, cfg.RMSNormEps)
|
||||
|
||||
attnOut := l.Attention.Forward(normed, c, B, L, l.IsSliding, cfg)
|
||||
attnOut := l.Attention.Forward(normed, c, positions, B, L, l.IsSliding, cfg)
|
||||
attnOut = mlx.RMSNormFn(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps)
|
||||
h := mlx.Add(x, attnOut)
|
||||
|
||||
|
|
@ -470,7 +471,7 @@ func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Tex
|
|||
return mlx.Add(h, mlpOut)
|
||||
}
|
||||
|
||||
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding bool, cfg *TextConfig) *mlx.Array {
|
||||
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, positions *mlx.Array, B, L int32, isSliding bool, cfg *TextConfig) *mlx.Array {
|
||||
q := a.QProj.Forward(x)
|
||||
k := a.KProj.Forward(x)
|
||||
v := a.VProj.Forward(x)
|
||||
|
|
@ -492,12 +493,8 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding b
|
|||
ropeTheta = cfg.RopeLocalBaseFreq
|
||||
}
|
||||
|
||||
offset := 0
|
||||
if c != nil {
|
||||
offset = c.Offset()
|
||||
}
|
||||
q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, ropeTheta, 1.0, offset)
|
||||
k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, ropeTheta, 1.0, offset)
|
||||
q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, ropeTheta, 1.0, positions)
|
||||
k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, ropeTheta, 1.0, positions)
|
||||
|
||||
if c != nil {
|
||||
k, v = c.Update(k, v)
|
||||
|
|
|
|||
|
|
@ -90,8 +90,7 @@ type TextConfig struct {
|
|||
|
||||
// sharedKVEntry stores cached KV state from a donor layer for KV sharing.
|
||||
type sharedKVEntry struct {
|
||||
K, V *mlx.Array
|
||||
Offset int // RoPE offset from donor's cache
|
||||
K, V *mlx.Array
|
||||
}
|
||||
|
||||
// Attention implements Gemma 4 attention with Q/K normalization and v-norm.
|
||||
|
|
@ -1017,6 +1016,7 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
|||
func (m *Model) Forward(b *batch.Batch, caches []cache.Cache) *mlx.Array {
|
||||
dims := b.InputIDs.Dims()
|
||||
B, L := int32(dims[0]), int32(dims[1])
|
||||
positions := mlx.FromValues(b.SeqOffsets, len(b.SeqOffsets))
|
||||
h := m.EmbedTokens.Forward(b.InputIDs)
|
||||
h = mlx.MulScalar(h, m.EmbedScale)
|
||||
|
||||
|
|
@ -1063,7 +1063,7 @@ func (m *Model) Forward(b *batch.Batch, caches []cache.Cache) *mlx.Array {
|
|||
}
|
||||
|
||||
var donorKV *sharedKVEntry
|
||||
h, donorKV = layer.Forward(h, c, B, L, m.TextConfig, pleInput, donorEntry, smc)
|
||||
h, donorKV = layer.Forward(h, c, positions, B, L, m.TextConfig, pleInput, donorEntry, smc)
|
||||
|
||||
// If this layer is a donor, store its cached KV for later shared layers.
|
||||
if layer.IsDonor && donorKV != nil {
|
||||
|
|
@ -1189,9 +1189,9 @@ func sliceLayerDim(combined *mlx.Array, layerIdx, B, L, pleDim int32) *mlx.Array
|
|||
return mlx.Squeeze(sliced, 2)
|
||||
}
|
||||
|
||||
func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *TextConfig, pleInput *mlx.Array, donorEntry *sharedKVEntry, slidingMaskCache *slidingMaskCache) (*mlx.Array, *sharedKVEntry) {
|
||||
func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, positions *mlx.Array, B, L int32, cfg *TextConfig, pleInput *mlx.Array, donorEntry *sharedKVEntry, slidingMaskCache *slidingMaskCache) (*mlx.Array, *sharedKVEntry) {
|
||||
normed := mlx.RMSNormFn(x, l.InputNormScaled, cfg.RMSNormEps)
|
||||
attnOut, donorKV := l.Attention.Forward(normed, c, B, L, l.IsSliding, cfg, donorEntry, slidingMaskCache)
|
||||
attnOut, donorKV := l.Attention.Forward(normed, c, positions, B, L, l.IsSliding, cfg, donorEntry, slidingMaskCache)
|
||||
attnOut = mlx.RMSNormFn(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps)
|
||||
h := mlx.Add(x, attnOut)
|
||||
|
||||
|
|
@ -1239,7 +1239,7 @@ func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Tex
|
|||
return h, donorKV
|
||||
}
|
||||
|
||||
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding bool, cfg *TextConfig, donorEntry *sharedKVEntry, slidingMaskCache *slidingMaskCache) (*mlx.Array, *sharedKVEntry) {
|
||||
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, positions *mlx.Array, B, L int32, isSliding bool, cfg *TextConfig, donorEntry *sharedKVEntry, slidingMaskCache *slidingMaskCache) (*mlx.Array, *sharedKVEntry) {
|
||||
// Determine head dim and scale based on layer type.
|
||||
headDim := cfg.HeadDim
|
||||
scale := cfg.SlidingScale
|
||||
|
|
@ -1259,18 +1259,11 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding b
|
|||
// Apply Q norm.
|
||||
q = mlx.RMSNormFn(q, a.QNormScaled, cfg.RMSNormEps)
|
||||
|
||||
// RoPE offset: use cache offset for non-shared layers, donor offset for shared.
|
||||
offset := 0
|
||||
if donorEntry != nil {
|
||||
offset = donorEntry.Offset - int(L)
|
||||
} else if c != nil {
|
||||
offset = c.Offset()
|
||||
}
|
||||
var ropeFreqs *mlx.Array
|
||||
if !isSliding {
|
||||
ropeFreqs = cfg.FullRopeFreqs
|
||||
}
|
||||
q = mlx.RoPEWithFreqs(q, ropeDims, false, ropeBase, 1.0, offset, ropeFreqs)
|
||||
q = mlx.RoPEWithFreqs(q, ropeDims, false, ropeBase, 1.0, positions, ropeFreqs)
|
||||
|
||||
var k, v *mlx.Array
|
||||
var donorKV *sharedKVEntry
|
||||
|
|
@ -1304,7 +1297,7 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding b
|
|||
k = mlx.RMSNormFn(k, a.KNormScaled, cfg.RMSNormEps)
|
||||
|
||||
// Apply RoPE to K.
|
||||
k = mlx.RoPEWithFreqs(k, ropeDims, false, ropeBase, 1.0, offset, ropeFreqs)
|
||||
k = mlx.RoPEWithFreqs(k, ropeDims, false, ropeBase, 1.0, positions, ropeFreqs)
|
||||
|
||||
// Apply V norm (no learnable weight, pure RMS normalization).
|
||||
v = mlx.RMSNormFn(v, nil, cfg.RMSNormEps)
|
||||
|
|
@ -1312,7 +1305,7 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding b
|
|||
// Update cache.
|
||||
if c != nil {
|
||||
k, v = c.Update(k, v)
|
||||
donorKV = &sharedKVEntry{K: k, V: v, Offset: c.Offset()}
|
||||
donorKV = &sharedKVEntry{K: k, V: v}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -88,7 +88,7 @@ type MLAAttention struct {
|
|||
}
|
||||
|
||||
// Forward computes absorbed MLA attention output.
|
||||
func (a *MLAAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
func (a *MLAAttention) Forward(x *mlx.Array, c cache.Cache, positions *mlx.Array, B, L int32, cfg *Config) *mlx.Array {
|
||||
q := a.QAProj.Forward(x)
|
||||
q = a.QALayerNorm.Forward(q, cfg.RMSNormEps)
|
||||
q = a.QBProj.Forward(q)
|
||||
|
|
@ -110,12 +110,8 @@ func (a *MLAAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Con
|
|||
kvLatent := a.KVALayerNorm.Forward(kvCompressed, cfg.RMSNormEps)
|
||||
kvLatent = mlx.ExpandDims(kvLatent, 1)
|
||||
|
||||
offset := 0
|
||||
if c != nil {
|
||||
offset = c.Offset()
|
||||
}
|
||||
qPE = mlx.RoPEWithBase(qPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset)
|
||||
kPE = mlx.RoPEWithBase(kPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset)
|
||||
qPE = mlx.RoPEWithBase(qPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, positions)
|
||||
kPE = mlx.RoPEWithBase(kPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, positions)
|
||||
|
||||
qLatent := a.EmbedQ.Forward(qNope)
|
||||
|
||||
|
|
@ -310,11 +306,11 @@ type DenseBlock struct {
|
|||
}
|
||||
|
||||
// Forward applies the dense block
|
||||
func (b *DenseBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg)
|
||||
func (blk *DenseBlock) Forward(x *mlx.Array, c cache.Cache, positions *mlx.Array, B, L int32, cfg *Config) *mlx.Array {
|
||||
r := blk.Attention.Forward(blk.InputLayerNorm.Forward(x, cfg.RMSNormEps), c, positions, B, L, cfg)
|
||||
h := mlx.Add(x, r)
|
||||
|
||||
r = b.MLP.Forward(b.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps))
|
||||
r = blk.MLP.Forward(blk.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps))
|
||||
return mlx.Add(h, r)
|
||||
}
|
||||
|
||||
|
|
@ -327,17 +323,17 @@ type MoEBlock struct {
|
|||
}
|
||||
|
||||
// Forward applies the MoE block
|
||||
func (b *MoEBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg)
|
||||
func (blk *MoEBlock) Forward(x *mlx.Array, c cache.Cache, positions *mlx.Array, B, L int32, cfg *Config) *mlx.Array {
|
||||
r := blk.Attention.Forward(blk.InputLayerNorm.Forward(x, cfg.RMSNormEps), c, positions, B, L, cfg)
|
||||
h := mlx.Add(x, r)
|
||||
|
||||
r = b.MoE.Forward(b.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps), cfg)
|
||||
r = blk.MoE.Forward(blk.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps), cfg)
|
||||
return mlx.Add(h, r)
|
||||
}
|
||||
|
||||
// Block interface for both dense and MoE blocks
|
||||
type Block interface {
|
||||
Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array
|
||||
Forward(x *mlx.Array, c cache.Cache, positions *mlx.Array, B, L int32, cfg *Config) *mlx.Array
|
||||
}
|
||||
|
||||
// Model represents the complete GLM4-MoE-Lite model
|
||||
|
|
@ -702,6 +698,7 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
|||
func (m *Model) Forward(b *batch.Batch, caches []cache.Cache) *mlx.Array {
|
||||
dims := b.InputIDs.Dims()
|
||||
B, L := int32(dims[0]), int32(dims[1])
|
||||
positions := mlx.FromValues(b.SeqOffsets, len(b.SeqOffsets))
|
||||
|
||||
h := m.EmbedTokens.Forward(b.InputIDs)
|
||||
|
||||
|
|
@ -710,7 +707,7 @@ func (m *Model) Forward(b *batch.Batch, caches []cache.Cache) *mlx.Array {
|
|||
if caches != nil {
|
||||
c = caches[i]
|
||||
}
|
||||
h = layer.Forward(h, c, B, L, m.Config)
|
||||
h = layer.Forward(h, c, positions, B, L, m.Config)
|
||||
}
|
||||
|
||||
h = m.Norm.Forward(h, m.RMSNormEps)
|
||||
|
|
|
|||
|
|
@ -240,6 +240,7 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
|||
func (m *Model) Forward(b *batch.Batch, caches []cache.Cache) *mlx.Array {
|
||||
dims := b.InputIDs.Dims()
|
||||
B, L := int32(dims[0]), int32(dims[1])
|
||||
positions := mlx.FromValues(b.SeqOffsets, len(b.SeqOffsets))
|
||||
|
||||
h := m.EmbedTokens.Forward(b.InputIDs)
|
||||
for i, layer := range m.Layers {
|
||||
|
|
@ -247,7 +248,7 @@ func (m *Model) Forward(b *batch.Batch, caches []cache.Cache) *mlx.Array {
|
|||
if caches != nil && i < len(caches) {
|
||||
c = caches[i]
|
||||
}
|
||||
h = layer.Forward(h, c, B, L, m.Config)
|
||||
h = layer.Forward(h, c, positions, B, L, m.Config)
|
||||
}
|
||||
|
||||
return m.Norm.Forward(h, m.RMSNormEps)
|
||||
|
|
@ -277,12 +278,12 @@ func (m *Model) NewCaches() []cache.Cache {
|
|||
return caches
|
||||
}
|
||||
|
||||
func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
h := mlx.Add(x, l.Attention.Forward(l.AttentionNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg))
|
||||
func (l *Layer) Forward(x *mlx.Array, c cache.Cache, positions *mlx.Array, B, L int32, cfg *Config) *mlx.Array {
|
||||
h := mlx.Add(x, l.Attention.Forward(l.AttentionNorm.Forward(x, cfg.RMSNormEps), c, positions, B, L, cfg))
|
||||
return mlx.Add(h, l.MLP.Forward(l.MLPNorm.Forward(h, cfg.RMSNormEps)))
|
||||
}
|
||||
|
||||
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, positions *mlx.Array, B, L int32, cfg *Config) *mlx.Array {
|
||||
q := a.QProj.Forward(x)
|
||||
k := a.KProj.Forward(x)
|
||||
v := a.VProj.Forward(x)
|
||||
|
|
@ -296,12 +297,8 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config
|
|||
v = mlx.Reshape(v, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
|
||||
v = mlx.Transpose(v, 0, 2, 1, 3)
|
||||
|
||||
offset := 0
|
||||
if c != nil {
|
||||
offset = c.Offset()
|
||||
}
|
||||
q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
|
||||
k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
|
||||
q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, positions)
|
||||
k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, positions)
|
||||
|
||||
if c != nil {
|
||||
k, v = c.Update(k, v)
|
||||
|
|
|
|||
|
|
@ -257,6 +257,7 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
|||
func (m *Model) Forward(b *batch.Batch, caches []cache.Cache) *mlx.Array {
|
||||
dims := b.InputIDs.Dims()
|
||||
B, L := int32(dims[0]), int32(dims[1])
|
||||
positions := mlx.FromValues(b.SeqOffsets, len(b.SeqOffsets))
|
||||
|
||||
h := m.EmbedTokens.Forward(b.InputIDs)
|
||||
for i, layer := range m.Layers {
|
||||
|
|
@ -264,7 +265,7 @@ func (m *Model) Forward(b *batch.Batch, caches []cache.Cache) *mlx.Array {
|
|||
if caches != nil && i < len(caches) {
|
||||
c = caches[i]
|
||||
}
|
||||
h = layer.Forward(h, c, B, L, m.Config)
|
||||
h = layer.Forward(h, c, positions, B, L, m.Config)
|
||||
}
|
||||
|
||||
return m.Norm.Forward(h, m.RMSNormEps)
|
||||
|
|
@ -294,12 +295,12 @@ func (m *Model) NewCaches() []cache.Cache {
|
|||
return caches
|
||||
}
|
||||
|
||||
func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
h := mlx.Add(x, l.Attention.Forward(l.AttentionNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg))
|
||||
func (l *Layer) Forward(x *mlx.Array, c cache.Cache, positions *mlx.Array, B, L int32, cfg *Config) *mlx.Array {
|
||||
h := mlx.Add(x, l.Attention.Forward(l.AttentionNorm.Forward(x, cfg.RMSNormEps), c, positions, B, L, cfg))
|
||||
return mlx.Add(h, l.MLP.Forward(l.MLPNorm.Forward(h, cfg.RMSNormEps)))
|
||||
}
|
||||
|
||||
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, positions *mlx.Array, B, L int32, cfg *Config) *mlx.Array {
|
||||
q := a.QProj.Forward(x)
|
||||
k := a.KProj.Forward(x)
|
||||
v := a.VProj.Forward(x)
|
||||
|
|
@ -315,12 +316,8 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config
|
|||
k = mlx.Transpose(k, 0, 2, 1, 3)
|
||||
v = mlx.Transpose(v, 0, 2, 1, 3)
|
||||
|
||||
offset := 0
|
||||
if c != nil {
|
||||
offset = c.Offset()
|
||||
}
|
||||
q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
|
||||
k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
|
||||
q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, positions)
|
||||
k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, positions)
|
||||
|
||||
if c != nil {
|
||||
k, v = c.Update(k, v)
|
||||
|
|
|
|||
|
|
@ -1127,7 +1127,7 @@ func splitQKVZBA(mixedQKVZ, mixedBA *mlx.Array, cfg *Config, B, L int32) (q, k,
|
|||
return q, k, v, z, b, a
|
||||
}
|
||||
|
||||
func (a *FullAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
func (a *FullAttention) Forward(x *mlx.Array, c cache.Cache, positions *mlx.Array, B, L int32, cfg *Config) *mlx.Array {
|
||||
qg := a.QProj.Forward(x)
|
||||
qg = mlx.Reshape(qg, B, L, cfg.NumAttentionHeads, cfg.HeadDim*2)
|
||||
q := mlx.SliceStartStop(qg, []int32{0, 0, 0, 0}, []int32{B, L, cfg.NumAttentionHeads, cfg.HeadDim})
|
||||
|
|
@ -1146,12 +1146,8 @@ func (a *FullAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Co
|
|||
k = mlx.Transpose(k, 0, 2, 1, 3)
|
||||
v = mlx.Transpose(v, 0, 2, 1, 3)
|
||||
|
||||
offset := 0
|
||||
if c != nil {
|
||||
offset = c.Offset()
|
||||
}
|
||||
q = mlx.RoPEWithBase(q, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, offset)
|
||||
k = mlx.RoPEWithBase(k, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, offset)
|
||||
q = mlx.RoPEWithBase(q, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, positions)
|
||||
k = mlx.RoPEWithBase(k, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, positions)
|
||||
|
||||
if c != nil {
|
||||
k, v = c.Update(k, v)
|
||||
|
|
@ -1333,13 +1329,13 @@ func (m *SparseMoE) Forward(x *mlx.Array, cfg *Config) *mlx.Array {
|
|||
return mlx.Reshape(y, B, L, cfg.HiddenSize)
|
||||
}
|
||||
|
||||
func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
func (l *Layer) Forward(x *mlx.Array, c cache.Cache, positions *mlx.Array, B, L int32, cfg *Config) *mlx.Array {
|
||||
var r *mlx.Array
|
||||
normed := l.InputNorm.Forward(x, cfg.RMSNormEps)
|
||||
if l.IsLinear {
|
||||
r = l.Linear.Forward(normed, c, B, L, cfg)
|
||||
} else {
|
||||
r = l.FullAttn.Forward(normed, c, B, L, cfg)
|
||||
r = l.FullAttn.Forward(normed, c, positions, B, L, cfg)
|
||||
}
|
||||
h := mlx.Add(x, r)
|
||||
r = l.MLP.Forward(l.PostAttentionNorm.Forward(h, cfg.RMSNormEps), cfg)
|
||||
|
|
@ -1349,6 +1345,7 @@ func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *m
|
|||
func (m *Model) Forward(b *batch.Batch, caches []cache.Cache) *mlx.Array {
|
||||
dims := b.InputIDs.Dims()
|
||||
B, L := int32(dims[0]), int32(dims[1])
|
||||
positions := mlx.FromValues(b.SeqOffsets, len(b.SeqOffsets))
|
||||
|
||||
h := m.EmbedTokens.Forward(b.InputIDs)
|
||||
for i, layer := range m.Layers {
|
||||
|
|
@ -1356,7 +1353,7 @@ func (m *Model) Forward(b *batch.Batch, caches []cache.Cache) *mlx.Array {
|
|||
if caches != nil && i < len(caches) {
|
||||
c = caches[i]
|
||||
}
|
||||
h = layer.Forward(h, c, B, L, m.Config)
|
||||
h = layer.Forward(h, c, positions, B, L, m.Config)
|
||||
}
|
||||
out := m.Norm.Forward(h, m.RMSNormEps)
|
||||
return out
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue