mirror of
https://github.com/ollama/ollama.git
synced 2026-05-13 06:21:28 +00:00
The MLX runner now routes model work through a locked worker thread. The status handler also used that worker, solely to sample memory, so a scheduler health ping could wait behind a long prefill or generation until its 10s context expired; /v1/status then returned 500 and the server treated the runner as unhealthy. Metal does not change its VRAM reporting, but CUDA does, so status cannot simply stop sampling memory. This change caches the last memory sample and makes status perform only a short best-effort refresh. If the worker is busy, status returns the cached value while a single background refresh continues and updates the cache once the worker becomes available. The in-flight guard and lifecycle context keep this from spawning unbounded refreshes while preserving live VRAM refresh behavior for CUDA. Fixes #16081.
245 lines
6.3 KiB
Go
245 lines
6.3 KiB
Go
package mlxrunner
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"flag"
|
|
"fmt"
|
|
"io"
|
|
"log/slog"
|
|
"net/http"
|
|
"os"
|
|
"strconv"
|
|
"time"
|
|
|
|
"github.com/ollama/ollama/envconfig"
|
|
"github.com/ollama/ollama/logutil"
|
|
"github.com/ollama/ollama/x/internal/mlxthread"
|
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
|
"github.com/ollama/ollama/x/mlxrunner/sample"
|
|
)
|
|
|
|
func Execute(args []string) error {
|
|
slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))
|
|
|
|
var (
|
|
modelName string
|
|
port int
|
|
)
|
|
|
|
flagSet := flag.NewFlagSet("mlxrunner", flag.ExitOnError)
|
|
flagSet.StringVar(&modelName, "model", "", "Model name")
|
|
flagSet.IntVar(&port, "port", 0, "Port to listen on")
|
|
_ = flagSet.Bool("verbose", false, "Enable debug logging")
|
|
flagSet.Parse(args)
|
|
|
|
worker, err := mlxthread.Start("mlxrunner", func() error {
|
|
if err := mlx.CheckInit(); err != nil {
|
|
return fmt.Errorf("MLX not available: %w", err)
|
|
}
|
|
|
|
if mlx.GPUIsAvailable() {
|
|
mlx.SetDefaultDeviceGPU()
|
|
slog.Info("MLX engine initialized", "MLX version", mlx.Version(), "device", "gpu")
|
|
} else {
|
|
slog.Info("MLX engine initialized", "MLX version", mlx.Version(), "device", "cpu")
|
|
}
|
|
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer worker.Stop(context.Background(), func() {
|
|
mlx.Sweep()
|
|
mlx.ClearCache()
|
|
})
|
|
runnerCtx, cancelRunner := context.WithCancel(context.Background())
|
|
defer cancelRunner()
|
|
|
|
runner := Runner{
|
|
Requests: make(chan Request),
|
|
mlxThread: worker,
|
|
}
|
|
|
|
if err := worker.Do(context.Background(), func() error {
|
|
return runner.Load(modelName)
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
|
|
readMemory := func() (uint64, error) {
|
|
return uint64(mlx.ActiveMemory() + mlx.CacheMemory()), nil
|
|
}
|
|
initialMemory, err := mlxthread.Call(context.Background(), worker, readMemory)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
memoryCache := newStatusMemoryCache(
|
|
runnerCtx,
|
|
initialMemory,
|
|
time.Now(),
|
|
statusMemoryRefreshWait,
|
|
func() (uint64, error) {
|
|
return mlxthread.Call(runnerCtx, worker, readMemory)
|
|
},
|
|
)
|
|
|
|
mux := http.NewServeMux()
|
|
mux.HandleFunc("GET /v1/status", func(w http.ResponseWriter, r *http.Request) {
|
|
if err := json.NewEncoder(w).Encode(statusResponse{
|
|
Status: 0,
|
|
Progress: 100,
|
|
ContextLength: runner.contextLength,
|
|
Memory: memoryCache.Memory(),
|
|
}); err != nil {
|
|
slog.Error("Failed to encode response", "error", err)
|
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
})
|
|
|
|
mux.HandleFunc("/v1/models", func(w http.ResponseWriter, r *http.Request) {
|
|
switch r.Method {
|
|
case "POST":
|
|
fallthrough
|
|
case "GET":
|
|
if err := json.NewEncoder(w).Encode(map[string]any{
|
|
"Success": true,
|
|
}); err != nil {
|
|
slog.Error("Failed to encode response", "error", err)
|
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
case "DELETE":
|
|
// TODO: cleanup model and cache
|
|
}
|
|
})
|
|
|
|
mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) {
|
|
request := Request{Responses: make(chan CompletionResponse)}
|
|
|
|
if err := json.NewDecoder(r.Body).Decode(&request.CompletionRequest); err != nil {
|
|
slog.Error("Failed to decode request", "error", err)
|
|
http.Error(w, "Bad Request", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
request.Pipeline = runner.TextGenerationPipeline
|
|
request.SamplerOpts = sample.Options{
|
|
Temperature: request.Options.Temperature,
|
|
TopP: request.Options.TopP,
|
|
MinP: request.Options.MinP,
|
|
TopK: request.Options.TopK,
|
|
RepeatLastN: request.Options.RepeatLastN,
|
|
RepeatPenalty: request.Options.RepeatPenalty,
|
|
PresencePenalty: request.Options.PresencePenalty,
|
|
FrequencyPenalty: request.Options.FrequencyPenalty,
|
|
Logprobs: request.Logprobs,
|
|
TopLogprobs: request.TopLogprobs,
|
|
}
|
|
|
|
if err := runner.Prepare(&request); err != nil {
|
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
var cancel context.CancelFunc
|
|
request.Ctx, cancel = context.WithCancel(r.Context())
|
|
defer cancel()
|
|
|
|
select {
|
|
case <-r.Context().Done():
|
|
return
|
|
case runner.Requests <- request:
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/jsonl")
|
|
w.WriteHeader(http.StatusOK)
|
|
enc := json.NewEncoder(w)
|
|
for {
|
|
select {
|
|
case <-r.Context().Done():
|
|
return
|
|
case response, ok := <-request.Responses:
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
if err := enc.Encode(response); err != nil {
|
|
slog.Error("Failed to encode response", "error", err)
|
|
return
|
|
}
|
|
|
|
if f, ok := w.(http.Flusher); ok {
|
|
f.Flush()
|
|
}
|
|
}
|
|
}
|
|
})
|
|
|
|
mux.HandleFunc("POST /v1/tokenize", func(w http.ResponseWriter, r *http.Request) {
|
|
var b bytes.Buffer
|
|
if _, err := io.Copy(&b, r.Body); err != nil {
|
|
slog.Error("Failed to read request body", "error", err)
|
|
http.Error(w, "Bad Request", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
tokens := runner.Tokenizer.Encode(b.String(), runner.Tokenizer.AddBOS())
|
|
|
|
if err := json.NewEncoder(w).Encode(tokens); err != nil {
|
|
slog.Error("Failed to encode response", "error", err)
|
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
})
|
|
|
|
for source, target := range map[string]string{
|
|
"GET /health": "/v1/status",
|
|
"POST /load": "/v1/models",
|
|
"POST /completion": "/v1/completions",
|
|
} {
|
|
mux.Handle(source, http.RedirectHandler(target, http.StatusPermanentRedirect))
|
|
}
|
|
|
|
return runner.Run("127.0.0.1", strconv.Itoa(port), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
recorder := &statusRecorder{ResponseWriter: w, code: http.StatusOK}
|
|
t := time.Now()
|
|
mux.ServeHTTP(recorder, r)
|
|
|
|
var level slog.Level
|
|
switch {
|
|
case recorder.code >= 500:
|
|
level = slog.LevelError
|
|
case recorder.code >= 400:
|
|
level = slog.LevelWarn
|
|
case recorder.code >= 300:
|
|
return
|
|
}
|
|
|
|
slog.Log(r.Context(), level, "ServeHTTP", "method", r.Method, "path", r.URL.Path, "took", time.Since(t), "status", recorder.Status())
|
|
}))
|
|
}
|
|
|
|
// statusRecorder wraps an http.ResponseWriter and remembers the status code
// that was actually committed, so request-logging middleware can report it.
//
// Status tracking mirrors net/http:
//   - The first final (non-1xx) WriteHeader call fixes the status.
//   - A Write or Flush before any WriteHeader fixes it as the implicit 200.
//   - Later WriteHeader calls are still forwarded to the wrapped writer, which
//     ignores them, but they do not change the recorded code.
//
// Without this, an error path such as http.Error after a partial body would
// be logged as 500 even though the client already received 200.
type statusRecorder struct {
	http.ResponseWriter

	// code is the recorded status. Callers seed it with http.StatusOK so a
	// handler that never writes anything is reported as 200.
	code int

	// wroteHeader reports whether the final status has been committed.
	wroteHeader bool
}

// WriteHeader forwards code to the wrapped writer and records it if it is the
// first final status for this response. 1xx informational codes (other than
// 101 Switching Protocols) do not commit the final status in net/http, so they
// are forwarded but not recorded.
func (w *statusRecorder) WriteHeader(code int) {
	informational := code >= 100 && code <= 199 && code != http.StatusSwitchingProtocols
	if !w.wroteHeader && !informational {
		w.code = code
		w.wroteHeader = true
	}
	w.ResponseWriter.WriteHeader(code)
}

// Write forwards b to the wrapped writer. Writing before any WriteHeader makes
// net/http send an implicit 200, so that is what gets recorded.
func (w *statusRecorder) Write(b []byte) (int, error) {
	w.commitImplicitOK()
	return w.ResponseWriter.Write(b)
}

// Status formats the recorded code as "<code> <reason>", e.g. "200 OK".
func (w *statusRecorder) Status() string {
	return strconv.Itoa(w.code) + " " + http.StatusText(w.code)
}

// Flush flushes the wrapped writer when it implements http.Flusher. Flushing
// before WriteHeader commits an implicit 200 (as net/http does), so that is
// recorded too. It is a no-op when the wrapped writer cannot flush.
func (w *statusRecorder) Flush() {
	if f, ok := w.ResponseWriter.(http.Flusher); ok {
		w.commitImplicitOK()
		f.Flush()
	}
}

// Unwrap returns the wrapped writer so http.ResponseController can reach the
// optional interfaces (deadlines, hijacking, ...) of the underlying writer.
func (w *statusRecorder) Unwrap() http.ResponseWriter {
	return w.ResponseWriter
}

// commitImplicitOK records the implicit 200 that net/http sends when the body
// is written or flushed before any explicit WriteHeader call.
func (w *statusRecorder) commitImplicitOK() {
	if !w.wroteHeader {
		w.code = http.StatusOK
		w.wroteHeader = true
	}
}
|