From 4fe56095639ece0112270a0529d401bf031306d5 Mon Sep 17 00:00:00 2001
From: Daniel Hiltgen
Date: Thu, 30 Apr 2026 16:28:03 -0700
Subject: [PATCH] metal: harden for ggml initialization failures (#15755)

* metal: harden for ggml initialization failures

ggml_metal_device_init performs a probe to verify that the tensor API
compiles. On some systems this probe passes even though kernel coverage is
incomplete, which results in a later crash when the real kernels are
compiled. If any of the error strings match this failure mode, this change
disables the tensor API and retries discovery once. It also hardens the Go
initDevices path to detect device initialization failures and panic with a
clear message instead of crashing later on a nil array entry.

Fixes #15734

* review comments

* review comments
---
 discover/runner.go            | 107 +++++++++++++++++++++++++---------
 discover/runner_test.go       |  26 +++++++++
 llm/server.go                 |  92 +++++++++++++++++++----------
 llm/server_wait_test.go       |  31 ++++++++++
 llm/status.go                 |  62 ++++++++++++++++++--
 llm/status_test.go            |  44 ++++++++++++++
 ml/backend/ggml/ggml.go       |  14 ++++-
 ml/device.go                  |  19 +++++-
 ml/device_test.go             |  60 +++++++++++++++++++
 runner/ollamarunner/runner.go |  91 ++++++++++++++++++++---------
 x/mlxrunner/client.go         |  19 +++---
 11 files changed, 453 insertions(+), 112 deletions(-)
 create mode 100644 llm/server_wait_test.go
 create mode 100644 llm/status_test.go
 create mode 100644 ml/device_test.go
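Before the per-file diffs, here is a minimal standalone sketch of the retry flow the commit message describes. The helper names (`probeDevices`, `looksLikeMetalTensorFailure`, `discoverWithRetry`) are hypothetical stand-ins; the real implementation is `bootstrapDevicesWithMetalRetry` in discover/runner.go below.

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"strings"
	"time"
)

// probeDevices stands in for one bootstrap discovery run (hypothetical).
type probeDevices func(ctx context.Context, envs map[string]string) ([]string, error)

// looksLikeMetalTensorFailure mimics the error-string matching that gates the
// retry; the real match list lives in llm.ShouldRetryWithMetalTensorDisabled.
func looksLikeMetalTensorFailure(err error) bool {
	return err != nil && strings.Contains(strings.ToLower(err.Error()), "metal")
}

// discoverWithRetry runs discovery once and, on a matching failure, retries a
// single time with the Metal tensor API disabled.
func discoverWithRetry(parent context.Context, timeout time.Duration, envs map[string]string, probe probeDevices) ([]string, error) {
	ctx, cancel := context.WithTimeout(parent, timeout)
	devices, err := probe(ctx, envs)
	cancel()
	if err == nil || !looksLikeMetalTensorFailure(err) || envs["GGML_METAL_TENSOR_DISABLE"] == "1" {
		return devices, err
	}

	// Copy the caller's map instead of mutating it, then disable the tensor API.
	retryEnvs := make(map[string]string, len(envs)+1)
	for k, v := range envs {
		retryEnvs[k] = v
	}
	retryEnvs["GGML_METAL_TENSOR_DISABLE"] = "1"

	retryCtx, cancel := context.WithTimeout(parent, timeout)
	defer cancel()
	return probe(retryCtx, retryEnvs)
}

func main() {
	attempts := 0
	probe := func(ctx context.Context, envs map[string]string) ([]string, error) {
		attempts++
		if envs["GGML_METAL_TENSOR_DISABLE"] != "1" {
			return nil, errors.New("failed to initialize Metal backend")
		}
		return []string{"Metal:0"}, nil
	}
	devices, err := discoverWithRetry(context.Background(), time.Second, nil, probe)
	fmt.Println(devices, err, attempts) // [Metal:0] <nil> 2
}
```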
diff --git a/discover/runner.go b/discover/runner.go
index 433a531a1..7d26a9008 100644
--- a/discover/runner.go
+++ b/discover/runner.go
@@ -112,10 +112,9 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
 		}
 
 		ctx1stPass, cancel := context.WithTimeout(ctx, bootstrapTimeout)
-		defer cancel()
-
 		// For this pass, we retain duplicates in case any are incompatible with some libraries
-		devices = append(devices, bootstrapDevices(ctx1stPass, dirs, nil)...)
+		devices = append(devices, bootstrapDevicesWithMetalRetry(ctx1stPass, ctx, bootstrapTimeout, dirs, nil)...)
+		cancel()
 	}
 
 	// In the second pass, we more deeply initialize the GPUs to weed out devices that
@@ -147,9 +146,9 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
 			wg.Add(1)
 			go func(i int) {
 				defer wg.Done()
-				extraEnvs := ml.GetVisibleDevicesEnv(devices[i:i+1], true)
+				extraEnvs := ml.GetDevicesEnv(devices[i:i+1], true)
 				devices[i].AddInitValidation(extraEnvs)
-				if len(bootstrapDevices(ctx2ndPass, devices[i].LibraryPath, extraEnvs)) == 0 {
+				if len(bootstrapDevicesWithMetalRetry(ctx2ndPass, ctx, 30*time.Second, devices[i].LibraryPath, extraEnvs)) == 0 {
 					slog.Debug("filtering device which didn't fully initialize",
 						"id", devices[i].ID,
 						"libdir", devices[i].LibraryPath[len(devices[i].LibraryPath)-1],
@@ -334,7 +333,7 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
 
 	// Apply any dev filters to avoid re-discovering unsupported devices, and get IDs correct
 	// We avoid CUDA filters here to keep ROCm from failing to discover GPUs in a mixed environment
-	devFilter := ml.GetVisibleDevicesEnv(devices, false)
+	devFilter := ml.GetDevicesEnv(devices, false)
 
 	for dir := range libDirs {
 		updatedDevices := bootstrapDevices(ctx, []string{ml.LibOllamaPath, dir}, devFilter)
@@ -427,27 +426,84 @@ func (r *bootstrapRunner) HasExited() bool {
 	return false
 }
 
-func bootstrapDevices(ctx context.Context, ollamaLibDirs []string, extraEnvs map[string]string) []ml.DeviceInfo {
-	var out io.Writer
-	if envconfig.LogLevel() == logutil.LevelTrace {
-		out = os.Stderr
+func bootstrapDevicesWithMetalRetry(firstAttemptCtx, retryParentCtx context.Context, timeout time.Duration, ollamaLibDirs []string, extraEnvs map[string]string) []ml.DeviceInfo {
+	runDiscovery := func(ctx context.Context, extraEnvs map[string]string) ([]ml.DeviceInfo, *llm.StatusWriter, int, error) {
+		start := time.Now()
+		defer func() {
+			slog.Debug("bootstrap discovery took", "duration", time.Since(start), "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "extra_envs", extraEnvs)
+		}()
+		return bootstrapDevicesWithStatus(ctx, ollamaLibDirs, extraEnvs)
 	}
-	start := time.Now()
-	defer func() {
-		slog.Debug("bootstrap discovery took", "duration", time.Since(start), "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "extra_envs", extraEnvs)
-	}()
 
-	logutil.Trace("starting runner for device discovery", "libDirs", ollamaLibDirs, "extraEnvs", extraEnvs)
+	devices, status, exitCode, err := runDiscovery(firstAttemptCtx, extraEnvs)
+	if err == nil {
+		recordPersistentRunnerEnv(devices, extraEnvs)
+	}
+	if err != nil && llm.ShouldRetryWithMetalTensorDisabled(err, status) && (extraEnvs == nil || extraEnvs["GGML_METAL_TENSOR_DISABLE"] != "1") {
+		retryEnvs := map[string]string{}
+		for k, v := range extraEnvs {
+			retryEnvs[k] = v
+		}
+		retryEnvs["GGML_METAL_TENSOR_DISABLE"] = "1"
+		slog.Warn("retrying GPU discovery with Metal tensor API disabled", "error", err)
+		retryCtx, cancel := context.WithTimeout(retryParentCtx, timeout)
+		defer cancel()
+		devices, status, exitCode, err = runDiscovery(retryCtx, retryEnvs)
+		if err == nil {
+			recordPersistentRunnerEnv(devices, retryEnvs)
+		}
+	}
+
+	if err != nil {
+		if exitCode >= 0 {
+			// Expected during bootstrapping while we filter out unsupported GPUs.
+			logutil.Trace("runner exited", "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "extra_envs", extraEnvs, "code", exitCode, "detail", status.LastError())
+		} else {
+			slog.Info("failure during GPU discovery", "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "extra_envs", extraEnvs, "error", err, "detail", status.LastError())
+		}
+	}
+
+	return devices
+}
+
+func recordPersistentRunnerEnv(devices []ml.DeviceInfo, extraEnvs map[string]string) {
+	if extraEnvs["GGML_METAL_TENSOR_DISABLE"] != "1" {
+		return
+	}
+
+	for i := range devices {
+		if devices[i].Library != "Metal" {
+			continue
+		}
+		if devices[i].RunnerEnvOverrides == nil {
+			devices[i].RunnerEnvOverrides = map[string]string{}
+		}
+		devices[i].RunnerEnvOverrides["GGML_METAL_TENSOR_DISABLE"] = "1"
+	}
+}
+
+func bootstrapDevices(ctx context.Context, ollamaLibDirs []string, extraEnvs map[string]string) []ml.DeviceInfo {
+	devices, _, _, _ := bootstrapDevicesWithStatus(ctx, ollamaLibDirs, extraEnvs)
+	return devices
+}
+
+func bootstrapDevicesWithStatus(ctx context.Context, ollamaLibDirs []string, extraEnvs map[string]string) ([]ml.DeviceInfo, *llm.StatusWriter, int, error) {
+	var baseOut io.Writer = io.Discard
+	if envconfig.LogLevel() == logutil.LevelTrace {
+		baseOut = os.Stderr
+	}
+
+	status := llm.NewStatusWriter(baseOut)
 	cmd, port, err := llm.StartRunner(
 		true, // ollama engine
 		"",   // no model
 		ollamaLibDirs,
-		out,
+		status,
 		extraEnvs,
 	)
 	if err != nil {
 		slog.Debug("failed to start runner to discovery GPUs", "error", err)
-		return nil
+		return nil, status, -1, err
 	}
 
 	go func() {
@@ -455,18 +511,13 @@ func bootstrapDevices(ctx context.Context, ollamaLibDirs []string, extraEnvs map
 	}()
 	defer cmd.Process.Kill()
 
-	devices, err := ml.GetDevicesFromRunner(ctx, &bootstrapRunner{port: port, cmd: cmd})
-	if err != nil {
-		if cmd.ProcessState != nil && cmd.ProcessState.ExitCode() >= 0 {
-			// Expected during bootstrapping while we filter out unsupported AMD GPUs
-			logutil.Trace("runner exited", "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "extra_envs", extraEnvs, "code", cmd.ProcessState.ExitCode())
-		} else {
-			slog.Info("failure during GPU discovery", "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "extra_envs", extraEnvs, "error", err)
-		}
-	}
+	devices, err := ml.GetDevicesFromRunner(ctx, &bootstrapRunner{port: port, cmd: cmd})
 	logutil.Trace("runner enumerated devices", "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "devices", devices)
-	return devices
+
+	exitCode := -1
+	if cmd.ProcessState != nil {
+		exitCode = cmd.ProcessState.ExitCode()
+	}
+	return devices, status, exitCode, err
 }
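The `exitCode >= 0` branch above relies on documented os/exec behavior: `ProcessState.ExitCode` returns -1 when the process was terminated by a signal (or has not exited), and the real exit status otherwise. A small illustration (Unix shell assumed):

```go
package main

import (
	"fmt"
	"os/exec"
)

func main() {
	// A clean (if unhappy) exit: ExitCode reports the status, here 3, which
	// discovery treats as an expected "unsupported GPU" style exit.
	clean := exec.Command("sh", "-c", "exit 3")
	_ = clean.Run()
	fmt.Println(clean.ProcessState.ExitCode()) // 3

	// A signal-killed process: ExitCode reports -1, which the code above
	// logs at Info level as a real discovery failure.
	killed := exec.Command("sleep", "60")
	_ = killed.Start()
	_ = killed.Process.Kill()
	_ = killed.Wait()
	fmt.Println(killed.ProcessState.ExitCode()) // -1
}
```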
diff --git a/discover/runner_test.go b/discover/runner_test.go
index 74ce629e4..51edf0b36 100644
--- a/discover/runner_test.go
+++ b/discover/runner_test.go
@@ -4,6 +4,8 @@ import (
 	"log/slog"
 	"os"
 	"testing"
+
+	"github.com/ollama/ollama/ml"
 )
 
 func init() {
@@ -107,3 +109,27 @@ func TestFilterOverlapByLibrary(t *testing.T) {
 		})
 	}
 }
+
+func TestRecordPersistentRunnerEnv(t *testing.T) {
+	devices := []ml.DeviceInfo{
+		{DeviceID: ml.DeviceID{Library: "Metal", ID: "0"}},
+		{DeviceID: ml.DeviceID{Library: "CUDA", ID: "1"}},
+	}
+
+	recordPersistentRunnerEnv(devices, map[string]string{
+		"GGML_METAL_TENSOR_DISABLE": "1",
+		"CUDA_VISIBLE_DEVICES":      "1",
+	})
+
+	if got := devices[0].RunnerEnvOverrides["GGML_METAL_TENSOR_DISABLE"]; got != "1" {
+		t.Fatalf("Metal RunnerEnvOverrides = %q, want %q", got, "1")
+	}
+
+	if _, ok := devices[0].RunnerEnvOverrides["CUDA_VISIBLE_DEVICES"]; ok {
+		t.Fatal("unexpected CUDA_VISIBLE_DEVICES in Metal RunnerEnvOverrides")
+	}
+
+	if devices[1].RunnerEnvOverrides != nil {
+		t.Fatalf("unexpected RunnerEnvOverrides recorded for non-Metal device: %#v", devices[1].RunnerEnvOverrides)
+	}
+}
diff --git a/llm/server.go b/llm/server.go
index a8104f79f..0b907e0d4 100644
--- a/llm/server.go
+++ b/llm/server.go
@@ -278,7 +278,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
 		modelPath,
 		gpuLibs,
 		status,
-		ml.GetVisibleDevicesEnv(gpus, false),
+		ml.GetDevicesEnv(gpus, false),
 	)
 
 	s := llmServer{
@@ -298,8 +298,8 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
 	if err != nil {
 		var msg string
-		if s.status != nil && s.status.LastErrMsg != "" {
-			msg = s.status.LastErrMsg
+		if s.status != nil && s.status.LastError() != "" {
+			msg = s.status.LastError()
 		}
 		err := fmt.Errorf("error starting runner: %v %s", err, msg)
 		if llamaModel != nil {
@@ -312,12 +312,12 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
 	go func() {
 		err := s.cmd.Wait()
 		// Favor a more detailed message over the process exit status
-		if err != nil && s.status != nil && s.status.LastErrMsg != "" {
+		if err != nil && s.status != nil && s.status.LastError() != "" {
 			slog.Error("llama runner terminated", "error", err)
-			if strings.Contains(s.status.LastErrMsg, "unknown model") {
-				s.status.LastErrMsg = "this model is not supported by your version of Ollama. You may need to upgrade"
+			if strings.Contains(s.status.LastError(), "unknown model") {
+				s.status.SetLastError("this model is not supported by your version of Ollama. You may need to upgrade")
 			}
-			s.doneErr = errors.New(s.status.LastErrMsg)
+			s.doneErr = errors.New(s.status.LastError())
 		} else {
 			s.doneErr = err
 		}
@@ -385,20 +385,9 @@ func StartRunner(ollamaEngine bool, modelPath string, gpuLibs []string, out io.W
 	cmd.Env = os.Environ()
 
 	if out != nil {
-		stdout, err := cmd.StdoutPipe()
-		if err != nil {
-			return nil, 0, fmt.Errorf("failed to spawn server stdout pipe: %w", err)
-		}
-		stderr, err := cmd.StderrPipe()
-		if err != nil {
-			return nil, 0, fmt.Errorf("failed to spawn server stderr pipe: %w", err)
-		}
-		go func() {
-			io.Copy(out, stdout) //nolint:errcheck
-		}()
-		go func() {
-			io.Copy(out, stderr) //nolint:errcheck
-		}()
+		// os/exec serializes Write calls when stdout and stderr share one writer
+		cmd.Stdout = out
+		cmd.Stderr = out
 	}
 
 	cmd.SysProcAttr = LlamaServerSysProcAttr
@@ -451,6 +440,38 @@ func StartRunner(ollamaEngine bool, modelPath string, gpuLibs []string, out io.W
 	return
 }
 
+// ShouldRetryWithMetalTensorDisabled works around a possible runtime crash
+// where the probe incorrectly enables the Metal tensor API, which then fails at runtime.
+func ShouldRetryWithMetalTensorDisabled(err error, status *StatusWriter) bool {
+	if runtime.GOOS != "darwin" {
+		return false
+	}
+
+	var msg strings.Builder
+	msg.WriteString(strings.ToLower(err.Error()))
+	if status != nil && status.LastError() != "" {
+		msg.WriteByte(' ')
+		msg.WriteString(strings.ToLower(status.LastError()))
+	}
+	text := msg.String()
+
+	for _, needle := range []string{
+		"failed to initialize ggml backend device: metal",
+		"failed to initialize metal backend",
+		"failed to initialize the metal library",
+		"failed to allocate context",
+		"unable to create llama context",
+		"signal arrived during cgo execution",
+		"input types must match cooperative tensor types",
+	} {
+		if strings.Contains(text, needle) {
+			return true
+		}
+	}
+
+	return false
+}
+
 func (s *llmServer) ModelPath() string {
 	return s.modelPath
 }
@@ -723,7 +744,7 @@ func (s *llamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, system
 
 	// The llama engine does its memory allocations together with model loading, so we
 	// need to wait until it is done to ensure that we have accurate memory data before
-	// loading the next model
+	// loading the next model.
 	return uniqueDeviceIDs(s.loadRequest.GPULayers), s.WaitUntilRunning(ctx)
 }
 
@@ -1276,8 +1297,8 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
 	// Fail fast if its exited
 	if s.cmd.ProcessState != nil {
 		msg := ""
-		if s.status != nil && s.status.LastErrMsg != "" {
-			msg = s.status.LastErrMsg
+		if s.status != nil && s.status.LastError() != "" {
+			msg = s.status.LastError()
 		}
 		if s.cmd.ProcessState.ExitCode() == -1 {
 			// Most likely a signal killed it, log some more details to try to help troubleshoot
@@ -1371,21 +1392,30 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
 			slog.Warn("client connection closed before server finished loading, aborting load")
 			return fmt.Errorf("timed out waiting for llama runner to start: %w", ctx.Err())
 		case <-s.done:
-			return fmt.Errorf("llama runner process has terminated: %w", s.doneErr)
+			if s.status != nil && s.status.LastError() != "" {
+				return fmt.Errorf("llama runner process has terminated: %s", s.status.LastError())
+			}
+			if s.doneErr != nil {
+				return fmt.Errorf("llama runner process has terminated: %w", s.doneErr)
+			}
+			if s.cmd != nil && s.cmd.ProcessState != nil {
+				return fmt.Errorf("llama runner process has terminated with exit code %d", s.cmd.ProcessState.ExitCode())
+			}
+			return errors.New("llama runner process has terminated")
 		default:
 		}
 		if time.Now().After(stallTimer) {
 			// timeout
 			msg := ""
-			if s.status != nil && s.status.LastErrMsg != "" {
-				msg = s.status.LastErrMsg
+			if s.status != nil && s.status.LastError() != "" {
+				msg = s.status.LastError()
 			}
 			return fmt.Errorf("timed out waiting for llama runner to start - progress %0.2f - %s", s.loadProgress, msg)
 		}
 		if s.cmd.ProcessState != nil {
 			msg := ""
-			if s.status != nil && s.status.LastErrMsg != "" {
-				msg = s.status.LastErrMsg
+			if s.status != nil && s.status.LastError() != "" {
+				msg = s.status.LastError()
 			}
 			return fmt.Errorf("llama runner process no longer running: %d %s", s.cmd.ProcessState.ExitCode(), msg)
 		}
@@ -1695,8 +1725,8 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 			if strings.Contains(err.Error(), "unexpected EOF") || strings.Contains(err.Error(), "forcibly closed") {
 				s.Close()
 				var msg string
-				if s.status != nil && s.status.LastErrMsg != "" {
-					msg = s.status.LastErrMsg
+				if s.status != nil && s.status.LastError() != "" {
+					msg = s.status.LastError()
 				} else {
 					msg = err.Error()
 				}
diff --git a/llm/server_wait_test.go b/llm/server_wait_test.go
new file mode 100644
index 000000000..7548338af
--- /dev/null
+++ b/llm/server_wait_test.go
@@ -0,0 +1,31 @@
+package llm
+
+import (
+	"context"
+	"strings"
+	"testing"
+)
+
+func TestWaitUntilRunningUsesStatusMessageWhenDoneErrIsNil(t *testing.T) {
+	done := make(chan struct{})
+	close(done)
+
+	status := &StatusWriter{}
+	status.SetLastError("llama_init_from_model: failed to initialize the context: failed to initialize Metal backend")
+
+	s := &llmServer{
+		done:   done,
+		status: status,
+	}
+
+	err := s.WaitUntilRunning(context.Background())
+	if err == nil {
+		t.Fatal("expected error")
+	}
+	if strings.Contains(err.Error(), "%!w()") {
+		t.Fatalf("unexpected wrapped nil error: %q", err)
+	}
+	if !strings.Contains(err.Error(), s.status.LastError()) {
+		t.Fatalf("error %q does not include status message %q", err, s.status.LastError())
+	}
+}
diff --git a/llm/status.go b/llm/status.go
index fdb94954c..c2dd6ac26 100644
--- a/llm/status.go
+++ b/llm/status.go
@@ -2,23 +2,67 @@ package llm
 
 import (
 	"bytes"
-	"os"
+	"io"
+	"strings"
+	"sync/atomic"
 )
 
 // StatusWriter is a writer that captures error messages from the llama runner process
 type StatusWriter struct {
-	LastErrMsg string
-	out        *os.File
+	out io.Writer
+	// StartRunner wires both Stdout and Stderr to the same StatusWriter, and
+	// os/exec serializes Write calls in that case.
+	lastErrMsg atomic.Value
 }
 
-func NewStatusWriter(out *os.File) *StatusWriter {
+const maxCapturedErrorBytes = 8 * 1024
+
+func NewStatusWriter(out io.Writer) *StatusWriter {
 	return &StatusWriter{
 		out: out,
 	}
 }
 
+func (w *StatusWriter) LastError() string {
+	if w == nil {
+		return ""
+	}
+	if v := w.lastErrMsg.Load(); v != nil {
+		return v.(string)
+	}
+	return ""
+}
+
+func (w *StatusWriter) SetLastError(msg string) {
+	if w == nil {
+		return
+	}
+	w.lastErrMsg.Store(msg)
+}
+
+func (w *StatusWriter) AppendError(msg string) {
+	if w == nil || msg == "" {
+		return
+	}
+
+	if current := w.LastError(); current != "" {
+		msg = current + "\n" + msg
+	}
+
+	if len(msg) > maxCapturedErrorBytes {
+		msg = msg[len(msg)-maxCapturedErrorBytes:]
+		if i := strings.IndexByte(msg, '\n'); i >= 0 {
+			msg = msg[i+1:]
+		}
+	}
+
+	w.SetLastError(msg)
+}
+
 // TODO - regex matching to detect errors like
 // libcublasLt.so.11: cannot open shared object file: No such file or directory
+// TODO - if we later see error lines split across multiple Write calls in real
+// logs, add a small rolling buffer here to capture those fragments.
 
 var errorPrefixes = []string{
 	"error:",
@@ -29,17 +73,23 @@ var errorPrefixes = []string{
 	"error loading model",
 	"GGML_ASSERT",
 	"Deepseek2 does not support K-shift",
+	"signal arrived during cgo execution",
+	"llama_init_from_model:",
 }
 
 func (w *StatusWriter) Write(b []byte) (int, error) {
 	var errMsg string
 	for _, prefix := range errorPrefixes {
 		if _, after, ok := bytes.Cut(b, []byte(prefix)); ok {
-			errMsg = prefix + string(bytes.TrimSpace(after))
+			line := after
+			if j := bytes.IndexByte(line, '\n'); j >= 0 {
+				line = line[:j]
+			}
+			errMsg = prefix + string(bytes.TrimRight(line, " \t\r"))
 		}
 	}
 
 	if errMsg != "" {
-		w.LastErrMsg = errMsg
+		w.AppendError(errMsg)
 	}
 
 	return w.out.Write(b)
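Taken together, the new StatusWriter surface behaves roughly as sketched below (a usage example based on the methods added above; forwarding goes to any io.Writer, io.Discard here):

```go
package main

import (
	"fmt"
	"io"

	"github.com/ollama/ollama/llm"
)

func main() {
	w := llm.NewStatusWriter(io.Discard)

	// Write scans each chunk for known error prefixes and captures only the
	// matching line, dropping anything after the newline.
	w.Write([]byte("error: failed to initialize the Metal library\n")) //nolint:errcheck

	// Later matches are accumulated rather than overwritten, capped at
	// maxCapturedErrorBytes (8 KiB) with the oldest lines dropped first.
	w.Write([]byte("GGML_ASSERT([rsets->data count] == 0) failed\n")) //nolint:errcheck

	fmt.Println(w.LastError())
	// error: failed to initialize the Metal library
	// GGML_ASSERT([rsets->data count] == 0) failed
}
```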
diff --git a/llm/status_test.go b/llm/status_test.go
new file mode 100644
index 000000000..2297ddd25
--- /dev/null
+++ b/llm/status_test.go
@@ -0,0 +1,44 @@
+package llm
+
+import (
+	"os"
+	"testing"
+)
+
+func TestStatusWriterCapturesErrorLine(t *testing.T) {
+	f, err := os.CreateTemp(t.TempDir(), "status-writer")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer f.Close()
+
+	w := NewStatusWriter(f)
+	if _, err := w.Write([]byte("llama_init_from_model: failed to initialize the context: failed to initialize Metal backend\n")); err != nil {
+		t.Fatal(err)
+	}
+
+	if got, want := w.LastError(), "llama_init_from_model: failed to initialize the context: failed to initialize Metal backend"; got != want {
+		t.Fatalf("LastError = %q, want %q", got, want)
+	}
+}
+
+func TestStatusWriterAccumulatesErrorLines(t *testing.T) {
+	f, err := os.CreateTemp(t.TempDir(), "status-writer")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer f.Close()
+
+	w := NewStatusWriter(f)
+	if _, err := w.Write([]byte("error: failed to initialize the Metal library\n")); err != nil {
+		t.Fatal(err)
+	}
+	if _, err := w.Write([]byte("GGML_ASSERT([rsets->data count] == 0) failed\n")); err != nil {
+		t.Fatal(err)
+	}
+
+	want := "error: failed to initialize the Metal library\nGGML_ASSERT([rsets->data count] == 0) failed"
+	if got := w.LastError(); got != want {
+		t.Fatalf("LastError = %q, want %q", got, want)
+	}
+}
diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go
index 8ec8c94f9..c82c8ef54 100644
--- a/ml/backend/ggml/ggml.go
+++ b/ml/backend/ggml/ggml.go
@@ -50,8 +50,18 @@ var initDevices = sync.OnceFunc(func() {
 	backends = make(map[C.ggml_backend_dev_t]C.ggml_backend_t)
 	for i := range C.ggml_backend_dev_count() {
 		d := C.ggml_backend_dev_get(i)
+		t := C.ggml_backend_dev_type(d)
+		name := C.GoString(C.ggml_backend_dev_name(d))
 
-		switch C.ggml_backend_dev_type(d) {
+		b := C.ggml_backend_dev_init(d, nil)
+		if b == nil {
+			slog.Error("failed to initialize ggml backend device", "device", name, "type", t)
+			panic(fmt.Sprintf("failed to initialize ggml backend device: %s", name))
+		}
+
+		backends[d] = b
+
+		switch t {
 		case C.GGML_BACKEND_DEVICE_TYPE_CPU:
 			if len(cpus) == 0 {
 				// only the first cpu device should be used
@@ -63,8 +73,6 @@ var initDevices = sync.OnceFunc(func() {
 			C.GGML_BACKEND_DEVICE_TYPE_IGPU:
 			gpus = append(gpus, d)
 		}
-
-		backends[d] = C.ggml_backend_dev_init(d, nil)
 	}
 })
diff --git a/ml/device.go b/ml/device.go
index 47e180d30..708a6bf30 100644
--- a/ml/device.go
+++ b/ml/device.go
@@ -314,6 +314,11 @@ type DeviceInfo struct {
 
 	// Where backends were loaded from
 	LibraryPath []string
+
+	// RunnerEnvOverrides stores exceptional per-device runner environment
+	// overrides discovered during bootstrap. This is internal server state and
+	// is not serialized.
+	RunnerEnvOverrides map[string]string `json:"-"`
 }
 
 type SystemInfo struct {
@@ -519,16 +524,24 @@ func (f FlashAttentionType) String() string {
 }
 
 // Given the list of GPUs this instantiation is targeted for,
-// figure out the visible devices environment variables
-// Set mustFilter true to enable filtering of CUDA devices
-func GetVisibleDevicesEnv(l []DeviceInfo, mustFilter bool) map[string]string {
+// figure out the device environment variables and any recorded
+// per-device runner environment overrides. Set mustFilter true to enable
+// filtering of CUDA devices.
+func GetDevicesEnv(l []DeviceInfo, mustFilter bool) map[string]string {
 	if len(l) == 0 {
 		return nil
 	}
 	env := map[string]string{}
 	for _, d := range l {
 		d.updateVisibleDevicesEnv(env, mustFilter)
+		for k, v := range d.RunnerEnvOverrides {
+			if existing, ok := env[k]; ok && existing != v {
+				slog.Warn("conflicting device environment override", "key", k, "existing", existing, "new", v, "library", d.Library, "id", d.ID)
+			}
+			env[k] = v
+		}
 	}
+
 	return env
 }
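For context, a sketch of how a recorded override is expected to reach a runner's environment at spawn time. The call site shown is hypothetical (the real code wires this through NewLlamaServer and StartRunner):

```go
package main

import (
	"fmt"
	"os"
	"os/exec"

	"github.com/ollama/ollama/ml"
)

func main() {
	// A Metal device whose bootstrap probe recorded the tensor-API override.
	devices := []ml.DeviceInfo{{
		DeviceID:           ml.DeviceID{Library: "Metal", ID: "0"},
		RunnerEnvOverrides: map[string]string{"GGML_METAL_TENSOR_DISABLE": "1"},
	}}

	// GetDevicesEnv merges visible-device variables with any per-device
	// runner overrides recorded during bootstrap.
	env := ml.GetDevicesEnv(devices, false)

	// Hypothetical spawn: the override rides along into the child process.
	cmd := exec.Command("/usr/bin/true")
	cmd.Env = os.Environ()
	for k, v := range env {
		cmd.Env = append(cmd.Env, k+"="+v)
	}
	fmt.Println(env["GGML_METAL_TENSOR_DISABLE"]) // 1
}
```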
diff --git a/ml/device_test.go b/ml/device_test.go
new file mode 100644
index 000000000..9e98a9f2b
--- /dev/null
+++ b/ml/device_test.go
@@ -0,0 +1,60 @@
+package ml
+
+import (
+	"bytes"
+	"log/slog"
+	"strings"
+	"testing"
+)
+
+func TestMergeEnvWithRunnerEnvOverrides(t *testing.T) {
+	devices := []DeviceInfo{
+		{
+			DeviceID:           DeviceID{Library: "Metal", ID: "0"},
+			RunnerEnvOverrides: map[string]string{"GGML_METAL_TENSOR_DISABLE": "1"},
+		},
+		{
+			DeviceID: DeviceID{Library: "CUDA", ID: "3"},
+		},
+	}
+
+	env := GetDevicesEnv(devices, true)
+
+	if got, want := env["GGML_METAL_TENSOR_DISABLE"], "1"; got != want {
+		t.Fatalf("GGML_METAL_TENSOR_DISABLE = %q, want %q", got, want)
+	}
+
+	if got, want := env["CUDA_VISIBLE_DEVICES"], "3"; got != want {
+		t.Fatalf("CUDA_VISIBLE_DEVICES = %q, want %q", got, want)
+	}
+}
+
+func TestGetDevicesEnvWarnsOnConflictingOverrides(t *testing.T) {
+	var logs bytes.Buffer
+	oldLogger := slog.Default()
+	slog.SetDefault(slog.New(slog.NewTextHandler(&logs, &slog.HandlerOptions{Level: slog.LevelDebug})))
+	t.Cleanup(func() {
+		slog.SetDefault(oldLogger)
+	})
+
+	devices := []DeviceInfo{
+		{
+			DeviceID:           DeviceID{Library: "Metal", ID: "0"},
+			RunnerEnvOverrides: map[string]string{"TEST_OVERRIDE": "one"},
+		},
+		{
+			DeviceID:           DeviceID{Library: "Metal", ID: "1"},
+			RunnerEnvOverrides: map[string]string{"TEST_OVERRIDE": "two"},
+		},
+	}
+
+	env := GetDevicesEnv(devices, false)
+
+	if got, want := env["TEST_OVERRIDE"], "two"; got != want {
+		t.Fatalf("TEST_OVERRIDE = %q, want %q", got, want)
+	}
+
+	if !strings.Contains(logs.String(), "conflicting device environment override") {
+		t.Fatalf("expected warning log, got %q", logs.String())
+	}
+}
diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go
index ccf646539..7920796cc 100644
--- a/runner/ollamarunner/runner.go
+++ b/runner/ollamarunner/runner.go
@@ -335,6 +335,13 @@ type Server struct {
 	// loadMu prevents more than one load attempt from occurring at a time
 	loadMu sync.Mutex
 
+	// infoInitialized caches the result of the dummy /info backend
+	// initialization. loadMu already serializes callers, so a simple cached
+	// result avoids repeated dummy loads without needing sync.Once.
+	infoInitialized bool
+	infoModel       model.Model
+	infoErr         error
+
 	// lastLoad is the load request from the previous load attempt. Used to
 	// detect if we can reuse an existing memory allocation.
 	lastLoad llm.LoadRequest
@@ -1362,44 +1369,70 @@ func (s *Server) info(w http.ResponseWriter, r *http.Request) {
 
 	w.Header().Set("Content-Type", "application/json")
 
-	m := s.model
-
-	if m == nil {
-		startLoad := time.Now()
-
-		// Dummy load to get the backend wired up
-		f, err := os.CreateTemp("", "*.bin")
-		if err != nil {
-			http.Error(w, fmt.Sprintf("failed to initialize backend: %v", err), http.StatusInternalServerError)
-			return
-		}
-		defer f.Close()
-		defer os.Remove(f.Name())
-
-		if err := ggml.WriteGGUF(f, ggml.KV{
-			"general.architecture": "llama",
-			"tokenizer.ggml.model": "gpt2",
-		}, nil); err != nil {
-			http.Error(w, fmt.Sprintf("failed to initialize backend: %v", err), http.StatusInternalServerError)
-			return
-		}
-
-		m, err = model.New(f.Name(), ml.BackendParams{NumThreads: runtime.NumCPU(), AllocMemory: false, GPULayers: ml.GPULayersList{{}}})
-		if err != nil {
-			http.Error(w, fmt.Sprintf("failed to initialize backend: %v", err), http.StatusInternalServerError)
-			return
-		}
-		slog.Debug("dummy model load took", "duration", time.Since(startLoad))
+	m, err := s.infoModelLocked()
+	if err != nil {
+		http.Error(w, fmt.Sprintf("failed to initialize backend: %v", err), http.StatusInternalServerError)
+		return
 	}
 
 	startDevices := time.Now()
 	infos := m.Backend().BackendDevices()
 	slog.Debug("gathering device infos took", "duration", time.Since(startDevices))
 
 	if err := json.NewEncoder(w).Encode(&infos); err != nil {
 		http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
+		return
 	}
 }
 
+func (s *Server) infoModelLocked() (model.Model, error) {
+	if s.model != nil {
+		return s.model, nil
+	}
+
+	if !s.infoInitialized {
+		s.infoInitialized = true
+
+		func() {
+			startLoad := time.Now()
+			defer func() {
+				if rec := recover(); rec != nil {
+					s.infoErr = fmt.Errorf("panic during dummy backend initialization: %v", rec)
+				}
+			}()
+
+			// Dummy load to get the backend wired up.
+			f, err := os.CreateTemp("", "*.bin")
+			if err != nil {
+				s.infoErr = err
+				return
+			}
+			defer f.Close()
+			defer os.Remove(f.Name())
+
+			if err := ggml.WriteGGUF(f, ggml.KV{
+				"general.architecture": "llama",
+				"tokenizer.ggml.model": "gpt2",
+			}, nil); err != nil {
+				s.infoErr = err
+				return
+			}
+
+			s.infoModel, s.infoErr = model.New(f.Name(), ml.BackendParams{NumThreads: runtime.NumCPU(), AllocMemory: false, GPULayers: ml.GPULayersList{{}}})
+			if s.infoErr == nil {
+				slog.Debug("dummy model load took", "duration", time.Since(startLoad))
+			}
+		}()
+	}
+
+	if s.infoErr != nil {
+		return nil, s.infoErr
+	}
+	if s.infoModel == nil {
+		return nil, fmt.Errorf("dummy backend initialization did not produce a model")
+	}
+
+	return s.infoModel, nil
+}
+
 func Execute(args []string) error {
 	fs := flag.NewFlagSet("runner", flag.ExitOnError)
 	mpath := fs.String("model", "", "Path to model binary file")
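The shape of infoModelLocked generalizes to any fallible, possibly panicking initialization that should run at most once while callers are already serialized by an outer lock. A minimal sketch of that pattern (names hypothetical):

```go
package main

import (
	"errors"
	"fmt"
)

// lazyInit caches the outcome of a one-shot setup. Unlike sync.Once, a
// cached error (or recovered panic) is replayed to every later caller.
// An outer mutex, like the runner's loadMu, is assumed to serialize calls.
type lazyInit struct {
	done   bool
	result string
	err    error
}

func (l *lazyInit) get(setup func() (string, error)) (string, error) {
	if !l.done {
		l.done = true
		func() {
			defer func() {
				if rec := recover(); rec != nil {
					l.err = fmt.Errorf("panic during initialization: %v", rec)
				}
			}()
			l.result, l.err = setup()
		}()
	}
	return l.result, l.err
}

func main() {
	var l lazyInit
	calls := 0
	setup := func() (string, error) {
		calls++
		return "", errors.New("backend unavailable")
	}
	_, err1 := l.get(setup)
	_, err2 := l.get(setup) // cached failure, setup not re-run
	fmt.Println(err1, err2, calls) // backend unavailable backend unavailable 1
}
```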
diff --git a/x/mlxrunner/client.go b/x/mlxrunner/client.go
index 4730a10d0..e2774dced 100644
--- a/x/mlxrunner/client.go
+++ b/x/mlxrunner/client.go
@@ -45,9 +45,9 @@ type Client struct {
 	cmd *exec.Cmd
 }
 
-// statusWriter captures the last stderr line from the subprocess while
-// forwarding all output to os.Stderr. Lines longer than maxStatusLen are
-// truncated to the first maxStatusLen bytes.
+// statusWriter captures the last subprocess line while forwarding all output
+// to os.Stderr. Lines longer than maxStatusLen are truncated to the first
+// maxStatusLen bytes.
 type statusWriter struct {
 	lastErrMsg string
 	buf        []byte
@@ -405,17 +405,12 @@ func (c *Client) Load(ctx context.Context, _ ml.SystemInfo, gpus []ml.DeviceInfo
 
 	c.cmd = cmd
 
-	// Forward subprocess stdout/stderr to server logs
-	stdout, _ := cmd.StdoutPipe()
-	stderr, _ := cmd.StderrPipe()
 	status := &statusWriter{out: os.Stderr}
 	c.status = status
-	go func() {
-		io.Copy(os.Stderr, stdout) //nolint:errcheck
-	}()
-	go func() {
-		io.Copy(status, stderr) //nolint:errcheck
-	}()
+	// os/exec serializes Write calls when shared, which keeps the status writer
+	// from seeing concurrent stdout/stderr fragments.
+	cmd.Stdout = status
+	cmd.Stderr = status
 
 	slog.Info("starting mlx runner subprocess", "model", c.modelName, "port", c.port)
 	if err := cmd.Start(); err != nil {
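Both this change and the StartRunner change lean on a documented os/exec guarantee: if Stdout and Stderr are the same writer (and its type is comparable with ==), at most one goroutine calls Write at a time. A small demonstration of the pattern (Unix shell assumed; the capture type is a simplified stand-in for statusWriter):

```go
package main

import (
	"fmt"
	"os/exec"
	"strings"
)

// capture records the last line containing "error:" and discards the rest.
// No mutex is needed: because it is assigned to both Stdout and Stderr
// below, os/exec serializes calls to Write.
type capture struct {
	lastErr string
}

func (c *capture) Write(b []byte) (int, error) {
	for _, line := range strings.Split(string(b), "\n") {
		if strings.Contains(line, "error:") {
			c.lastErr = line
		}
	}
	return len(b), nil
}

func main() {
	c := &capture{}
	cmd := exec.Command("sh", "-c",
		`echo "loading model"; echo "error: failed to initialize Metal backend" 1>&2`)
	cmd.Stdout = c // same value for both streams, so writes are
	cmd.Stderr = c // serialized and share a single pipe to the child
	if err := cmd.Run(); err != nil {
		fmt.Println("run:", err)
	}
	fmt.Println("captured:", c.lastErr)
}
```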