metal: harden for ggml initialization failures (#15755)

* metal: harden for ggml initialization failures

ggml_metal_device_init performs a probe to verify the tensor API compiles.  On
some systems this passes, even though kernel coverage isn't complete, which
results in a later crash when compiling the real kernels.  This change adds a
single retry, with the tensor API disabled, when any of the error strings
match this failure mode.  It also hardens the Go initDevices path to detect
device initialization failures and panic with a clear message instead of
crashing later on a nil array entry.

Fixes #15734

* review comments

* review comments
This commit is contained in:
Daniel Hiltgen 2026-04-30 16:28:03 -07:00 committed by GitHub
parent 917324bb4d
commit 4fe5609563
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 453 additions and 112 deletions

View file

@ -112,10 +112,9 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
}
ctx1stPass, cancel := context.WithTimeout(ctx, bootstrapTimeout)
defer cancel()
// For this pass, we retain duplicates in case any are incompatible with some libraries
devices = append(devices, bootstrapDevices(ctx1stPass, dirs, nil)...)
devices = append(devices, bootstrapDevicesWithMetalRetry(ctx1stPass, ctx, bootstrapTimeout, dirs, nil)...)
cancel()
}
// In the second pass, we more deeply initialize the GPUs to weed out devices that
@ -147,9 +146,9 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
wg.Add(1)
go func(i int) {
defer wg.Done()
extraEnvs := ml.GetVisibleDevicesEnv(devices[i:i+1], true)
extraEnvs := ml.GetDevicesEnv(devices[i:i+1], true)
devices[i].AddInitValidation(extraEnvs)
if len(bootstrapDevices(ctx2ndPass, devices[i].LibraryPath, extraEnvs)) == 0 {
if len(bootstrapDevicesWithMetalRetry(ctx2ndPass, ctx, 30*time.Second, devices[i].LibraryPath, extraEnvs)) == 0 {
slog.Debug("filtering device which didn't fully initialize",
"id", devices[i].ID,
"libdir", devices[i].LibraryPath[len(devices[i].LibraryPath)-1],
@ -334,7 +333,7 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
// Apply any dev filters to avoid re-discovering unsupported devices, and get IDs correct
// We avoid CUDA filters here to keep ROCm from failing to discover GPUs in a mixed environment
devFilter := ml.GetVisibleDevicesEnv(devices, false)
devFilter := ml.GetDevicesEnv(devices, false)
for dir := range libDirs {
updatedDevices := bootstrapDevices(ctx, []string{ml.LibOllamaPath, dir}, devFilter)
@ -427,27 +426,84 @@ func (r *bootstrapRunner) HasExited() bool {
return false
}
func bootstrapDevices(ctx context.Context, ollamaLibDirs []string, extraEnvs map[string]string) []ml.DeviceInfo {
var out io.Writer
if envconfig.LogLevel() == logutil.LevelTrace {
out = os.Stderr
func bootstrapDevicesWithMetalRetry(firstAttemptCtx, retryParentCtx context.Context, timeout time.Duration, ollamaLibDirs []string, extraEnvs map[string]string) []ml.DeviceInfo {
runDiscovery := func(ctx context.Context, extraEnvs map[string]string) ([]ml.DeviceInfo, *llm.StatusWriter, int, error) {
start := time.Now()
defer func() {
slog.Debug("bootstrap discovery took", "duration", time.Since(start), "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "extra_envs", extraEnvs)
}()
return bootstrapDevicesWithStatus(ctx, ollamaLibDirs, extraEnvs)
}
start := time.Now()
defer func() {
slog.Debug("bootstrap discovery took", "duration", time.Since(start), "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "extra_envs", extraEnvs)
}()
logutil.Trace("starting runner for device discovery", "libDirs", ollamaLibDirs, "extraEnvs", extraEnvs)
devices, status, exitCode, err := runDiscovery(firstAttemptCtx, extraEnvs)
if err == nil {
recordPersistentRunnerEnv(devices, extraEnvs)
}
if err != nil && llm.ShouldRetryWithMetalTensorDisabled(err, status) && (extraEnvs == nil || extraEnvs["GGML_METAL_TENSOR_DISABLE"] != "1") {
retryEnvs := map[string]string{}
for k, v := range extraEnvs {
retryEnvs[k] = v
}
retryEnvs["GGML_METAL_TENSOR_DISABLE"] = "1"
slog.Warn("retrying GPU discovery with Metal tensor API disabled", "error", err)
retryCtx, cancel := context.WithTimeout(retryParentCtx, timeout)
defer cancel()
devices, status, exitCode, err = runDiscovery(retryCtx, retryEnvs)
if err == nil {
recordPersistentRunnerEnv(devices, retryEnvs)
}
}
if err != nil {
if exitCode >= 0 {
// Expected during bootstrapping while we filter out unsupported GPUs.
logutil.Trace("runner exited", "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "extra_envs", extraEnvs, "code", exitCode, "detail", status.LastError())
} else {
slog.Info("failure during GPU discovery", "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "extra_envs", extraEnvs, "error", err, "detail", status.LastError())
}
}
return devices
}
// recordPersistentRunnerEnv persists the Metal tensor-API kill switch onto
// every Metal device so that later runner launches for those devices inherit
// it. It is a no-op unless the bootstrap attempt actually ran with
// GGML_METAL_TENSOR_DISABLE=1; no other environment keys are recorded.
func recordPersistentRunnerEnv(devices []ml.DeviceInfo, extraEnvs map[string]string) {
	const disableKey = "GGML_METAL_TENSOR_DISABLE"
	if extraEnvs[disableKey] != "1" {
		return
	}
	for i := range devices {
		d := &devices[i]
		if d.Library != "Metal" {
			continue
		}
		if d.RunnerEnvOverrides == nil {
			d.RunnerEnvOverrides = make(map[string]string)
		}
		d.RunnerEnvOverrides[disableKey] = "1"
	}
}
// bootstrapDevices runs device discovery and returns only the discovered
// device list, discarding the runner status, exit code, and error details.
// Callers that need failure diagnostics or the Metal retry behavior should
// use bootstrapDevicesWithStatus or bootstrapDevicesWithMetalRetry instead.
func bootstrapDevices(ctx context.Context, ollamaLibDirs []string, extraEnvs map[string]string) []ml.DeviceInfo {
	devices, _, _, _ := bootstrapDevicesWithStatus(ctx, ollamaLibDirs, extraEnvs)
	return devices
}
func bootstrapDevicesWithStatus(ctx context.Context, ollamaLibDirs []string, extraEnvs map[string]string) ([]ml.DeviceInfo, *llm.StatusWriter, int, error) {
var baseOut io.Writer = io.Discard
if envconfig.LogLevel() == logutil.LevelTrace {
baseOut = os.Stderr
}
status := llm.NewStatusWriter(baseOut)
cmd, port, err := llm.StartRunner(
true, // ollama engine
"", // no model
ollamaLibDirs,
out,
status,
extraEnvs,
)
if err != nil {
slog.Debug("failed to start runner to discovery GPUs", "error", err)
return nil
return nil, status, -1, err
}
go func() {
@ -455,18 +511,13 @@ func bootstrapDevices(ctx context.Context, ollamaLibDirs []string, extraEnvs map
}()
defer cmd.Process.Kill()
devices, err := ml.GetDevicesFromRunner(ctx, &bootstrapRunner{port: port, cmd: cmd})
if err != nil {
if cmd.ProcessState != nil && cmd.ProcessState.ExitCode() >= 0 {
// Expected during bootstrapping while we filter out unsupported AMD GPUs
logutil.Trace("runner exited", "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "extra_envs", extraEnvs, "code", cmd.ProcessState.ExitCode())
} else {
slog.Info("failure during GPU discovery", "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "extra_envs", extraEnvs, "error", err)
}
}
logutil.Trace("runner enumerated devices", "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "devices", devices)
return devices
devices, err := ml.GetDevicesFromRunner(ctx, &bootstrapRunner{port: port, cmd: cmd})
exitCode := -1
if cmd.ProcessState != nil {
exitCode = cmd.ProcessState.ExitCode()
}
return devices, status, exitCode, err
}
func overrideWarnings() {

View file

@ -4,6 +4,8 @@ import (
"log/slog"
"os"
"testing"
"github.com/ollama/ollama/ml"
)
func init() {
@ -107,3 +109,27 @@ func TestFilterOverlapByLibrary(t *testing.T) {
})
}
}
// TestRecordPersistentRunnerEnv verifies that only the Metal tensor kill
// switch is persisted, and only onto Metal devices — other keys and other
// device libraries must remain untouched.
func TestRecordPersistentRunnerEnv(t *testing.T) {
	devices := []ml.DeviceInfo{
		{DeviceID: ml.DeviceID{Library: "Metal", ID: "0"}},
		{DeviceID: ml.DeviceID{Library: "CUDA", ID: "1"}},
	}

	recordPersistentRunnerEnv(devices, map[string]string{
		"GGML_METAL_TENSOR_DISABLE": "1",
		"CUDA_VISIBLE_DEVICES":      "1",
	})

	metalOverrides := devices[0].RunnerEnvOverrides
	if v := metalOverrides["GGML_METAL_TENSOR_DISABLE"]; v != "1" {
		t.Fatalf("Metal RunnerEnvOverrides = %q, want %q", v, "1")
	}
	if _, leaked := metalOverrides["CUDA_VISIBLE_DEVICES"]; leaked {
		t.Fatal("unexpected CUDA_VISIBLE_DEVICES in Metal RunnerEnvOverrides")
	}
	if overrides := devices[1].RunnerEnvOverrides; overrides != nil {
		t.Fatalf("unexpected RunnerEnvOverrides recorded for non-Metal device: %#v", overrides)
	}
}

View file

@ -278,7 +278,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
modelPath,
gpuLibs,
status,
ml.GetVisibleDevicesEnv(gpus, false),
ml.GetDevicesEnv(gpus, false),
)
s := llmServer{
@ -298,8 +298,8 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
if err != nil {
var msg string
if s.status != nil && s.status.LastErrMsg != "" {
msg = s.status.LastErrMsg
if s.status != nil && s.status.LastError() != "" {
msg = s.status.LastError()
}
err := fmt.Errorf("error starting runner: %v %s", err, msg)
if llamaModel != nil {
@ -312,12 +312,12 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
go func() {
err := s.cmd.Wait()
// Favor a more detailed message over the process exit status
if err != nil && s.status != nil && s.status.LastErrMsg != "" {
if err != nil && s.status != nil && s.status.LastError() != "" {
slog.Error("llama runner terminated", "error", err)
if strings.Contains(s.status.LastErrMsg, "unknown model") {
s.status.LastErrMsg = "this model is not supported by your version of Ollama. You may need to upgrade"
if strings.Contains(s.status.LastError(), "unknown model") {
s.status.SetLastError("this model is not supported by your version of Ollama. You may need to upgrade")
}
s.doneErr = errors.New(s.status.LastErrMsg)
s.doneErr = errors.New(s.status.LastError())
} else {
s.doneErr = err
}
@ -385,20 +385,9 @@ func StartRunner(ollamaEngine bool, modelPath string, gpuLibs []string, out io.W
cmd.Env = os.Environ()
if out != nil {
stdout, err := cmd.StdoutPipe()
if err != nil {
return nil, 0, fmt.Errorf("failed to spawn server stdout pipe: %w", err)
}
stderr, err := cmd.StderrPipe()
if err != nil {
return nil, 0, fmt.Errorf("failed to spawn server stderr pipe: %w", err)
}
go func() {
io.Copy(out, stdout) //nolint:errcheck
}()
go func() {
io.Copy(out, stderr) //nolint:errcheck
}()
// os/exec serializes Write calls when shared
cmd.Stdout = out
cmd.Stderr = out
}
cmd.SysProcAttr = LlamaServerSysProcAttr
@ -451,6 +440,38 @@ func StartRunner(ollamaEngine bool, modelPath string, gpuLibs []string, out io.W
return
}
// Workaround possible runtime crash where the probe incorrectly
// enables metal tensor, but fails at runtime.
//
// ShouldRetryWithMetalTensorDisabled reports whether a failed runner start on
// macOS matches a known Metal tensor-API failure signature, in which case
// discovery should be retried once with GGML_METAL_TENSOR_DISABLE=1.
// Both err and status may be nil; the original version dereferenced err
// unconditionally and would panic on a nil error.
func ShouldRetryWithMetalTensorDisabled(err error, status *StatusWriter) bool {
	if runtime.GOOS != "darwin" {
		return false
	}
	// Collect everything we know about the failure into one lowercased
	// haystack: the Go-side error plus the last captured runner output.
	var msg strings.Builder
	if err != nil {
		msg.WriteString(strings.ToLower(err.Error()))
	}
	if status != nil && status.LastError() != "" {
		if msg.Len() > 0 {
			msg.WriteByte(' ')
		}
		msg.WriteString(strings.ToLower(status.LastError()))
	}
	text := msg.String()
	// Signatures observed when the Metal tensor probe passes but compiling
	// the real kernels or setting up the context later fails.
	for _, needle := range []string{
		"failed to initialize ggml backend device: metal",
		"failed to initialize metal backend",
		"failed to initialize the metal library",
		"failed to allocate context",
		"unable to create llama context",
		"signal arrived during cgo execution",
		"input types must match cooperative tensor types",
	} {
		if strings.Contains(text, needle) {
			return true
		}
	}
	return false
}
func (s *llmServer) ModelPath() string {
return s.modelPath
}
@ -723,7 +744,7 @@ func (s *llamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, system
// The llama engine does its memory allocations together with model loading, so we
// need to wait until it is done to ensure that we have accurate memory data before
// loading the next model
// loading the next model.
return uniqueDeviceIDs(s.loadRequest.GPULayers), s.WaitUntilRunning(ctx)
}
@ -1276,8 +1297,8 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
// Fail fast if its exited
if s.cmd.ProcessState != nil {
msg := ""
if s.status != nil && s.status.LastErrMsg != "" {
msg = s.status.LastErrMsg
if s.status != nil && s.status.LastError() != "" {
msg = s.status.LastError()
}
if s.cmd.ProcessState.ExitCode() == -1 {
// Most likely a signal killed it, log some more details to try to help troubleshoot
@ -1371,21 +1392,30 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
slog.Warn("client connection closed before server finished loading, aborting load")
return fmt.Errorf("timed out waiting for llama runner to start: %w", ctx.Err())
case <-s.done:
return fmt.Errorf("llama runner process has terminated: %w", s.doneErr)
if s.status != nil && s.status.LastError() != "" {
return fmt.Errorf("llama runner process has terminated: %s", s.status.LastError())
}
if s.doneErr != nil {
return fmt.Errorf("llama runner process has terminated: %w", s.doneErr)
}
if s.cmd != nil && s.cmd.ProcessState != nil {
return fmt.Errorf("llama runner process has terminated with exit code %d", s.cmd.ProcessState.ExitCode())
}
return errors.New("llama runner process has terminated")
default:
}
if time.Now().After(stallTimer) {
// timeout
msg := ""
if s.status != nil && s.status.LastErrMsg != "" {
msg = s.status.LastErrMsg
if s.status != nil && s.status.LastError() != "" {
msg = s.status.LastError()
}
return fmt.Errorf("timed out waiting for llama runner to start - progress %0.2f - %s", s.loadProgress, msg)
}
if s.cmd.ProcessState != nil {
msg := ""
if s.status != nil && s.status.LastErrMsg != "" {
msg = s.status.LastErrMsg
if s.status != nil && s.status.LastError() != "" {
msg = s.status.LastError()
}
return fmt.Errorf("llama runner process no longer running: %d %s", s.cmd.ProcessState.ExitCode(), msg)
}
@ -1695,8 +1725,8 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
if strings.Contains(err.Error(), "unexpected EOF") || strings.Contains(err.Error(), "forcibly closed") {
s.Close()
var msg string
if s.status != nil && s.status.LastErrMsg != "" {
msg = s.status.LastErrMsg
if s.status != nil && s.status.LastError() != "" {
msg = s.status.LastError()
} else {
msg = err.Error()
}

31
llm/server_wait_test.go Normal file
View file

@ -0,0 +1,31 @@
package llm
import (
"context"
"strings"
"testing"
)
// TestWaitUntilRunningUsesStatusMessageWhenDoneErrIsNil ensures that when the
// runner terminates without a recorded doneErr, WaitUntilRunning surfaces the
// captured status message instead of wrapping a nil error.
func TestWaitUntilRunningUsesStatusMessageWhenDoneErrIsNil(t *testing.T) {
	terminated := make(chan struct{})
	close(terminated)

	status := &StatusWriter{}
	status.SetLastError("llama_init_from_model: failed to initialize the context: failed to initialize Metal backend")

	srv := &llmServer{
		done:   terminated,
		status: status,
	}

	err := srv.WaitUntilRunning(context.Background())
	switch {
	case err == nil:
		t.Fatal("expected error")
	case strings.Contains(err.Error(), "%!w(<nil>)"):
		t.Fatalf("unexpected wrapped nil error: %q", err)
	case !strings.Contains(err.Error(), srv.status.LastError()):
		t.Fatalf("error %q does not include status message %q", err, srv.status.LastError())
	}
}

View file

@ -2,23 +2,67 @@ package llm
import (
"bytes"
"os"
"io"
"strings"
"sync/atomic"
)
// StatusWriter is a writer that captures error messages from the llama runner process
type StatusWriter struct {
LastErrMsg string
out *os.File
out io.Writer
// StartRunner wires both Stdout and Stderr to the same StatusWriter, and
// os/exec serializes Write calls in that case.
lastErrMsg atomic.Value
}
func NewStatusWriter(out *os.File) *StatusWriter {
const maxCapturedErrorBytes = 8 * 1024
func NewStatusWriter(out io.Writer) *StatusWriter {
return &StatusWriter{
out: out,
}
}
// LastError returns the most recently captured error text, or the empty
// string when nothing has been recorded yet. It is safe to call on a nil
// receiver, which also yields the empty string.
func (w *StatusWriter) LastError() string {
	if w == nil {
		return ""
	}
	// Load returns a nil interface before the first Store; the comma-ok
	// assertion turns that into the empty string.
	msg, ok := w.lastErrMsg.Load().(string)
	if !ok {
		return ""
	}
	return msg
}
// SetLastError replaces any previously captured error message with msg.
// It is safe to call on a nil receiver, in which case it does nothing.
func (w *StatusWriter) SetLastError(msg string) {
	if w == nil {
		return
	}
	w.lastErrMsg.Store(msg)
}
// AppendError records msg as the latest error line, newline-appending it to
// any previously captured text. The accumulated message is capped at
// maxCapturedErrorBytes by dropping the oldest content; after the byte-level
// cut, any leading partial line is discarded so the buffer starts on a line
// boundary. Empty messages and nil receivers are ignored.
func (w *StatusWriter) AppendError(msg string) {
	if w == nil || len(msg) == 0 {
		return
	}
	combined := msg
	if prev := w.LastError(); prev != "" {
		combined = prev + "\n" + msg
	}
	if excess := len(combined) - maxCapturedErrorBytes; excess > 0 {
		combined = combined[excess:]
		// Drop the leading fragment left by the byte-level truncation.
		if nl := strings.IndexByte(combined, '\n'); nl >= 0 {
			combined = combined[nl+1:]
		}
	}
	w.SetLastError(combined)
}
// TODO - regex matching to detect errors like
// libcublasLt.so.11: cannot open shared object file: No such file or directory
// TODO - if we later see error lines split across multiple Write calls in real
// logs, add a small rolling buffer here to capture those fragments.
var errorPrefixes = []string{
"error:",
@ -29,17 +73,23 @@ var errorPrefixes = []string{
"error loading model",
"GGML_ASSERT",
"Deepseek2 does not support K-shift",
"signal arrived during cgo execution",
"llama_init_from_model:",
}
func (w *StatusWriter) Write(b []byte) (int, error) {
var errMsg string
for _, prefix := range errorPrefixes {
if _, after, ok := bytes.Cut(b, []byte(prefix)); ok {
errMsg = prefix + string(bytes.TrimSpace(after))
line := after
if j := bytes.IndexByte(line, '\n'); j >= 0 {
line = line[:j]
}
errMsg = prefix + string(bytes.TrimRight(line, " \t\r"))
}
}
if errMsg != "" {
w.LastErrMsg = errMsg
w.AppendError(errMsg)
}
return w.out.Write(b)

44
llm/status_test.go Normal file
View file

@ -0,0 +1,44 @@
package llm
import (
"os"
"testing"
)
// TestStatusWriterCapturesErrorLine verifies that a single write containing a
// known error prefix is captured as the last error, without the trailing
// newline.
func TestStatusWriterCapturesErrorLine(t *testing.T) {
	f, err := os.CreateTemp(t.TempDir(), "status-writer")
	if err != nil {
		t.Fatal(err)
	}
	defer f.Close()

	w := NewStatusWriter(f)
	line := "llama_init_from_model: failed to initialize the context: failed to initialize Metal backend"
	if _, err := w.Write([]byte(line + "\n")); err != nil {
		t.Fatal(err)
	}
	if got := w.LastError(); got != line {
		t.Fatalf("LastError = %q, want %q", got, line)
	}
}
// TestStatusWriterAccumulatesErrorLines verifies that successive error lines
// are appended together, newline-separated, rather than each write
// overwriting the previously captured message.
func TestStatusWriterAccumulatesErrorLines(t *testing.T) {
	f, err := os.CreateTemp(t.TempDir(), "status-writer")
	if err != nil {
		t.Fatal(err)
	}
	defer f.Close()

	w := NewStatusWriter(f)
	for _, line := range []string{
		"error: failed to initialize the Metal library\n",
		"GGML_ASSERT([rsets->data count] == 0) failed\n",
	} {
		if _, err := w.Write([]byte(line)); err != nil {
			t.Fatal(err)
		}
	}

	want := "error: failed to initialize the Metal library\nGGML_ASSERT([rsets->data count] == 0) failed"
	if got := w.LastError(); got != want {
		t.Fatalf("LastError = %q, want %q", got, want)
	}
}

View file

@ -50,8 +50,18 @@ var initDevices = sync.OnceFunc(func() {
backends = make(map[C.ggml_backend_dev_t]C.ggml_backend_t)
for i := range C.ggml_backend_dev_count() {
d := C.ggml_backend_dev_get(i)
t := C.ggml_backend_dev_type(d)
name := C.GoString(C.ggml_backend_dev_name(d))
switch C.ggml_backend_dev_type(d) {
b := C.ggml_backend_dev_init(d, nil)
if b == nil {
slog.Error("failed to initialize ggml backend device", "device", name, "type", t)
panic(fmt.Sprintf("failed to initialize ggml backend device: %s", name))
}
backends[d] = b
switch t {
case C.GGML_BACKEND_DEVICE_TYPE_CPU:
if len(cpus) == 0 {
// only the first cpu device should be used
@ -63,8 +73,6 @@ var initDevices = sync.OnceFunc(func() {
C.GGML_BACKEND_DEVICE_TYPE_IGPU:
gpus = append(gpus, d)
}
backends[d] = C.ggml_backend_dev_init(d, nil)
}
})

View file

@ -314,6 +314,11 @@ type DeviceInfo struct {
// Where backends were loaded from
LibraryPath []string
// RunnerEnvOverrides stores exceptional per-device runner environment
// overrides discovered during bootstrap. This is internal server state and
// is not serialized.
RunnerEnvOverrides map[string]string `json:"-"`
}
type SystemInfo struct {
@ -519,16 +524,24 @@ func (f FlashAttentionType) String() string {
}
// Given the list of GPUs this instantiation is targeted for,
// figure out the visible devices environment variables
// Set mustFilter true to enable filtering of CUDA devices
func GetVisibleDevicesEnv(l []DeviceInfo, mustFilter bool) map[string]string {
// figure out the device environment variables and any recorded
// per-device runner environment overrides. Set mustFilter true to enable
// filtering of CUDA devices.
// GetDevicesEnv builds the environment variables needed to target the given
// devices, merging in any per-device RunnerEnvOverrides recorded during
// bootstrap. Set mustFilter true to enable filtering of CUDA devices.
// Returns nil when no devices are supplied. When two devices record
// conflicting values for the same override key, the later device wins and a
// warning is logged.
func GetDevicesEnv(l []DeviceInfo, mustFilter bool) map[string]string {
	if len(l) == 0 {
		return nil
	}
	env := make(map[string]string)
	for i := range l {
		// Work on a copy so updateVisibleDevicesEnv cannot mutate the
		// caller's slice elements, matching the original range-value loop.
		d := l[i]
		d.updateVisibleDevicesEnv(env, mustFilter)
		for key, value := range d.RunnerEnvOverrides {
			if prev, ok := env[key]; ok && prev != value {
				slog.Warn("conflicting device environment override", "key", key, "existing", prev, "new", value, "library", d.Library, "id", d.ID)
			}
			env[key] = value
		}
	}
	return env
}

60
ml/device_test.go Normal file
View file

@ -0,0 +1,60 @@
package ml
import (
"bytes"
"log/slog"
"strings"
"testing"
)
// TestMergeEnvWithRunnerEnvOverrides verifies that per-device runner
// environment overrides are merged into the device environment alongside the
// usual visible-devices variables.
func TestMergeEnvWithRunnerEnvOverrides(t *testing.T) {
	devices := []DeviceInfo{
		{
			DeviceID:           DeviceID{Library: "Metal", ID: "0"},
			RunnerEnvOverrides: map[string]string{"GGML_METAL_TENSOR_DISABLE": "1"},
		},
		{
			DeviceID: DeviceID{Library: "CUDA", ID: "3"},
		},
	}

	env := GetDevicesEnv(devices, true)

	for key, want := range map[string]string{
		"GGML_METAL_TENSOR_DISABLE": "1",
		"CUDA_VISIBLE_DEVICES":      "3",
	} {
		if got := env[key]; got != want {
			t.Fatalf("%s = %q, want %q", key, got, want)
		}
	}
}
// TestGetDevicesEnvWarnsOnConflictingOverrides verifies that when two devices
// record different values for the same override key, the later value wins and
// a warning is emitted to the default logger.
func TestGetDevicesEnvWarnsOnConflictingOverrides(t *testing.T) {
	var logs bytes.Buffer
	previous := slog.Default()
	slog.SetDefault(slog.New(slog.NewTextHandler(&logs, &slog.HandlerOptions{Level: slog.LevelDebug})))
	t.Cleanup(func() {
		slog.SetDefault(previous)
	})

	devices := []DeviceInfo{
		{
			DeviceID:           DeviceID{Library: "Metal", ID: "0"},
			RunnerEnvOverrides: map[string]string{"TEST_OVERRIDE": "one"},
		},
		{
			DeviceID:           DeviceID{Library: "Metal", ID: "1"},
			RunnerEnvOverrides: map[string]string{"TEST_OVERRIDE": "two"},
		},
	}

	env := GetDevicesEnv(devices, false)
	if got, want := env["TEST_OVERRIDE"], "two"; got != want {
		t.Fatalf("TEST_OVERRIDE = %q, want %q", got, want)
	}
	if !strings.Contains(logs.String(), "conflicting device environment override") {
		t.Fatalf("expected warning log, got %q", logs.String())
	}
}

View file

@ -335,6 +335,13 @@ type Server struct {
// loadMu prevents more than one load attempt from occurring at a time
loadMu sync.Mutex
// infoInitialized caches the result of the dummy /info backend
// initialization. loadMu already serializes callers, so a simple cached
// result avoids repeated dummy loads without needing sync.Once.
infoInitialized bool
infoModel model.Model
infoErr error
// lastLoad is the load request from the previous load attempt. Used to
// detect if we can reuse an existing memory allocation.
lastLoad llm.LoadRequest
@ -1362,44 +1369,70 @@ func (s *Server) info(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
m := s.model
if m == nil {
startLoad := time.Now()
// Dummy load to get the backend wired up
f, err := os.CreateTemp("", "*.bin")
if err != nil {
http.Error(w, fmt.Sprintf("failed to initialize backend: %v", err), http.StatusInternalServerError)
return
}
defer f.Close()
defer os.Remove(f.Name())
if err := ggml.WriteGGUF(f, ggml.KV{
"general.architecture": "llama",
"tokenizer.ggml.model": "gpt2",
}, nil); err != nil {
http.Error(w, fmt.Sprintf("failed to initialize backend: %v", err), http.StatusInternalServerError)
return
}
m, err = model.New(f.Name(), ml.BackendParams{NumThreads: runtime.NumCPU(), AllocMemory: false, GPULayers: ml.GPULayersList{{}}})
if err != nil {
http.Error(w, fmt.Sprintf("failed to initialize backend: %v", err), http.StatusInternalServerError)
return
}
slog.Debug("dummy model load took", "duration", time.Since(startLoad))
m, err := s.infoModelLocked()
if err != nil {
http.Error(w, fmt.Sprintf("failed to initialize backend: %v", err), http.StatusInternalServerError)
return
}
startDevices := time.Now()
infos := m.Backend().BackendDevices()
slog.Debug("gathering device infos took", "duration", time.Since(startDevices))
if err := json.NewEncoder(w).Encode(&infos); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
return
}
}
// infoModelLocked returns a model whose backend can be queried for device
// info: the already-loaded model if one exists, otherwise a cached dummy
// model built from a minimal temporary GGUF file. The dummy load is attempted
// at most once — infoInitialized is set before the attempt, so a failure (or
// panic) is cached in infoErr and returned on every subsequent call rather
// than retried. Callers must hold loadMu (see the field comments on Server);
// there is no internal locking here.
func (s *Server) infoModelLocked() (model.Model, error) {
	if s.model != nil {
		return s.model, nil
	}
	if !s.infoInitialized {
		// Mark initialized up front so a panic below still counts as the
		// one allowed attempt.
		s.infoInitialized = true
		func() {
			startLoad := time.Now()
			defer func() {
				// Backend init can panic (e.g. device init failures);
				// convert that into a cached error instead of crashing.
				if rec := recover(); rec != nil {
					s.infoErr = fmt.Errorf("panic during dummy backend initialization: %v", rec)
				}
			}()
			// Dummy load to get the backend wired up.
			f, err := os.CreateTemp("", "*.bin")
			if err != nil {
				s.infoErr = err
				return
			}
			defer f.Close()
			defer os.Remove(f.Name())
			// Minimal GGUF metadata — just enough for model.New to accept it.
			if err := ggml.WriteGGUF(f, ggml.KV{
				"general.architecture": "llama",
				"tokenizer.ggml.model": "gpt2",
			}, nil); err != nil {
				s.infoErr = err
				return
			}
			// AllocMemory=false: we only need the backend wired up, not a
			// usable inference context.
			s.infoModel, s.infoErr = model.New(f.Name(), ml.BackendParams{NumThreads: runtime.NumCPU(), AllocMemory: false, GPULayers: ml.GPULayersList{{}}})
			if s.infoErr == nil {
				slog.Debug("dummy model load took", "duration", time.Since(startLoad))
			}
		}()
	}
	if s.infoErr != nil {
		return nil, s.infoErr
	}
	if s.infoModel == nil {
		// Defensive: shouldn't happen, but guards against returning a nil
		// model with a nil error.
		return nil, fmt.Errorf("dummy backend initialization did not produce a model")
	}
	return s.infoModel, nil
}
func Execute(args []string) error {
fs := flag.NewFlagSet("runner", flag.ExitOnError)
mpath := fs.String("model", "", "Path to model binary file")

View file

@ -45,9 +45,9 @@ type Client struct {
cmd *exec.Cmd
}
// statusWriter captures the last stderr line from the subprocess while
// forwarding all output to os.Stderr. Lines longer than maxStatusLen are
// truncated to the first maxStatusLen bytes.
// statusWriter captures the last subprocess line while forwarding all output
// to os.Stderr. Lines longer than maxStatusLen are truncated to the first
// maxStatusLen bytes.
type statusWriter struct {
lastErrMsg string
buf []byte
@ -405,17 +405,12 @@ func (c *Client) Load(ctx context.Context, _ ml.SystemInfo, gpus []ml.DeviceInfo
c.cmd = cmd
// Forward subprocess stdout/stderr to server logs
stdout, _ := cmd.StdoutPipe()
stderr, _ := cmd.StderrPipe()
status := &statusWriter{out: os.Stderr}
c.status = status
go func() {
io.Copy(os.Stderr, stdout) //nolint:errcheck
}()
go func() {
io.Copy(status, stderr) //nolint:errcheck
}()
// os/exec serializes Write calls when shared, which keeps the status writer
// from seeing concurrent stdout/stderr fragments.
cmd.Stdout = status
cmd.Stderr = status
slog.Info("starting mlx runner subprocess", "model", c.modelName, "port", c.port)
if err := cmd.Start(); err != nil {