mirror of
https://github.com/ollama/ollama.git
synced 2026-05-13 06:21:28 +00:00
metal: harden for ggml initialization failures (#15755)
* metal: harden for ggml initialization failures ggml_metal_device_init performs a probe to verify the tensor API compiles. On some systems this passes, even though kernel coverage isn't complete, which results in a later crash when compiling the real kernels. This change adds a single retry if any of the error strings match this failure mode to disable the tensor API. It also hardens an error case in the Go initDevices to detect device initialization failures and panic instead of crashing later on a nil array entry. Fixes #15734 * review comments * review comments
This commit is contained in:
parent
917324bb4d
commit
4fe5609563
11 changed files with 453 additions and 112 deletions
|
|
@ -112,10 +112,9 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
|
|||
}
|
||||
|
||||
ctx1stPass, cancel := context.WithTimeout(ctx, bootstrapTimeout)
|
||||
defer cancel()
|
||||
|
||||
// For this pass, we retain duplicates in case any are incompatible with some libraries
|
||||
devices = append(devices, bootstrapDevices(ctx1stPass, dirs, nil)...)
|
||||
devices = append(devices, bootstrapDevicesWithMetalRetry(ctx1stPass, ctx, bootstrapTimeout, dirs, nil)...)
|
||||
cancel()
|
||||
}
|
||||
|
||||
// In the second pass, we more deeply initialize the GPUs to weed out devices that
|
||||
|
|
@ -147,9 +146,9 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
|
|||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
extraEnvs := ml.GetVisibleDevicesEnv(devices[i:i+1], true)
|
||||
extraEnvs := ml.GetDevicesEnv(devices[i:i+1], true)
|
||||
devices[i].AddInitValidation(extraEnvs)
|
||||
if len(bootstrapDevices(ctx2ndPass, devices[i].LibraryPath, extraEnvs)) == 0 {
|
||||
if len(bootstrapDevicesWithMetalRetry(ctx2ndPass, ctx, 30*time.Second, devices[i].LibraryPath, extraEnvs)) == 0 {
|
||||
slog.Debug("filtering device which didn't fully initialize",
|
||||
"id", devices[i].ID,
|
||||
"libdir", devices[i].LibraryPath[len(devices[i].LibraryPath)-1],
|
||||
|
|
@ -334,7 +333,7 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
|
|||
|
||||
// Apply any dev filters to avoid re-discovering unsupported devices, and get IDs correct
|
||||
// We avoid CUDA filters here to keep ROCm from failing to discover GPUs in a mixed environment
|
||||
devFilter := ml.GetVisibleDevicesEnv(devices, false)
|
||||
devFilter := ml.GetDevicesEnv(devices, false)
|
||||
|
||||
for dir := range libDirs {
|
||||
updatedDevices := bootstrapDevices(ctx, []string{ml.LibOllamaPath, dir}, devFilter)
|
||||
|
|
@ -427,27 +426,84 @@ func (r *bootstrapRunner) HasExited() bool {
|
|||
return false
|
||||
}
|
||||
|
||||
func bootstrapDevices(ctx context.Context, ollamaLibDirs []string, extraEnvs map[string]string) []ml.DeviceInfo {
|
||||
var out io.Writer
|
||||
if envconfig.LogLevel() == logutil.LevelTrace {
|
||||
out = os.Stderr
|
||||
func bootstrapDevicesWithMetalRetry(firstAttemptCtx, retryParentCtx context.Context, timeout time.Duration, ollamaLibDirs []string, extraEnvs map[string]string) []ml.DeviceInfo {
|
||||
runDiscovery := func(ctx context.Context, extraEnvs map[string]string) ([]ml.DeviceInfo, *llm.StatusWriter, int, error) {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
slog.Debug("bootstrap discovery took", "duration", time.Since(start), "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "extra_envs", extraEnvs)
|
||||
}()
|
||||
return bootstrapDevicesWithStatus(ctx, ollamaLibDirs, extraEnvs)
|
||||
}
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
slog.Debug("bootstrap discovery took", "duration", time.Since(start), "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "extra_envs", extraEnvs)
|
||||
}()
|
||||
|
||||
logutil.Trace("starting runner for device discovery", "libDirs", ollamaLibDirs, "extraEnvs", extraEnvs)
|
||||
devices, status, exitCode, err := runDiscovery(firstAttemptCtx, extraEnvs)
|
||||
if err == nil {
|
||||
recordPersistentRunnerEnv(devices, extraEnvs)
|
||||
}
|
||||
if err != nil && llm.ShouldRetryWithMetalTensorDisabled(err, status) && (extraEnvs == nil || extraEnvs["GGML_METAL_TENSOR_DISABLE"] != "1") {
|
||||
retryEnvs := map[string]string{}
|
||||
for k, v := range extraEnvs {
|
||||
retryEnvs[k] = v
|
||||
}
|
||||
retryEnvs["GGML_METAL_TENSOR_DISABLE"] = "1"
|
||||
slog.Warn("retrying GPU discovery with Metal tensor API disabled", "error", err)
|
||||
retryCtx, cancel := context.WithTimeout(retryParentCtx, timeout)
|
||||
defer cancel()
|
||||
devices, status, exitCode, err = runDiscovery(retryCtx, retryEnvs)
|
||||
if err == nil {
|
||||
recordPersistentRunnerEnv(devices, retryEnvs)
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if exitCode >= 0 {
|
||||
// Expected during bootstrapping while we filter out unsupported GPUs.
|
||||
logutil.Trace("runner exited", "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "extra_envs", extraEnvs, "code", exitCode, "detail", status.LastError())
|
||||
} else {
|
||||
slog.Info("failure during GPU discovery", "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "extra_envs", extraEnvs, "error", err, "detail", status.LastError())
|
||||
}
|
||||
}
|
||||
|
||||
return devices
|
||||
}
|
||||
|
||||
func recordPersistentRunnerEnv(devices []ml.DeviceInfo, extraEnvs map[string]string) {
|
||||
if extraEnvs["GGML_METAL_TENSOR_DISABLE"] != "1" {
|
||||
return
|
||||
}
|
||||
|
||||
for i := range devices {
|
||||
if devices[i].Library != "Metal" {
|
||||
continue
|
||||
}
|
||||
if devices[i].RunnerEnvOverrides == nil {
|
||||
devices[i].RunnerEnvOverrides = map[string]string{}
|
||||
}
|
||||
devices[i].RunnerEnvOverrides["GGML_METAL_TENSOR_DISABLE"] = "1"
|
||||
}
|
||||
}
|
||||
|
||||
func bootstrapDevices(ctx context.Context, ollamaLibDirs []string, extraEnvs map[string]string) []ml.DeviceInfo {
|
||||
devices, _, _, _ := bootstrapDevicesWithStatus(ctx, ollamaLibDirs, extraEnvs)
|
||||
return devices
|
||||
}
|
||||
|
||||
func bootstrapDevicesWithStatus(ctx context.Context, ollamaLibDirs []string, extraEnvs map[string]string) ([]ml.DeviceInfo, *llm.StatusWriter, int, error) {
|
||||
var baseOut io.Writer = io.Discard
|
||||
if envconfig.LogLevel() == logutil.LevelTrace {
|
||||
baseOut = os.Stderr
|
||||
}
|
||||
|
||||
status := llm.NewStatusWriter(baseOut)
|
||||
cmd, port, err := llm.StartRunner(
|
||||
true, // ollama engine
|
||||
"", // no model
|
||||
ollamaLibDirs,
|
||||
out,
|
||||
status,
|
||||
extraEnvs,
|
||||
)
|
||||
if err != nil {
|
||||
slog.Debug("failed to start runner to discovery GPUs", "error", err)
|
||||
return nil
|
||||
return nil, status, -1, err
|
||||
}
|
||||
|
||||
go func() {
|
||||
|
|
@ -455,18 +511,13 @@ func bootstrapDevices(ctx context.Context, ollamaLibDirs []string, extraEnvs map
|
|||
}()
|
||||
|
||||
defer cmd.Process.Kill()
|
||||
devices, err := ml.GetDevicesFromRunner(ctx, &bootstrapRunner{port: port, cmd: cmd})
|
||||
if err != nil {
|
||||
if cmd.ProcessState != nil && cmd.ProcessState.ExitCode() >= 0 {
|
||||
// Expected during bootstrapping while we filter out unsupported AMD GPUs
|
||||
logutil.Trace("runner exited", "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "extra_envs", extraEnvs, "code", cmd.ProcessState.ExitCode())
|
||||
} else {
|
||||
slog.Info("failure during GPU discovery", "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "extra_envs", extraEnvs, "error", err)
|
||||
}
|
||||
}
|
||||
logutil.Trace("runner enumerated devices", "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "devices", devices)
|
||||
|
||||
return devices
|
||||
devices, err := ml.GetDevicesFromRunner(ctx, &bootstrapRunner{port: port, cmd: cmd})
|
||||
exitCode := -1
|
||||
if cmd.ProcessState != nil {
|
||||
exitCode = cmd.ProcessState.ExitCode()
|
||||
}
|
||||
return devices, status, exitCode, err
|
||||
}
|
||||
|
||||
func overrideWarnings() {
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@ import (
|
|||
"log/slog"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/ml"
|
||||
)
|
||||
|
||||
func init() {
|
||||
|
|
@ -107,3 +109,27 @@ func TestFilterOverlapByLibrary(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordPersistentRunnerEnv(t *testing.T) {
|
||||
devices := []ml.DeviceInfo{
|
||||
{DeviceID: ml.DeviceID{Library: "Metal", ID: "0"}},
|
||||
{DeviceID: ml.DeviceID{Library: "CUDA", ID: "1"}},
|
||||
}
|
||||
|
||||
recordPersistentRunnerEnv(devices, map[string]string{
|
||||
"GGML_METAL_TENSOR_DISABLE": "1",
|
||||
"CUDA_VISIBLE_DEVICES": "1",
|
||||
})
|
||||
|
||||
if got := devices[0].RunnerEnvOverrides["GGML_METAL_TENSOR_DISABLE"]; got != "1" {
|
||||
t.Fatalf("Metal RunnerEnvOverrides = %q, want %q", got, "1")
|
||||
}
|
||||
|
||||
if _, ok := devices[0].RunnerEnvOverrides["CUDA_VISIBLE_DEVICES"]; ok {
|
||||
t.Fatal("unexpected CUDA_VISIBLE_DEVICES in Metal RunnerEnvOverrides")
|
||||
}
|
||||
|
||||
if devices[1].RunnerEnvOverrides != nil {
|
||||
t.Fatalf("unexpected RunnerEnvOverrides recorded for non-Metal device: %#v", devices[1].RunnerEnvOverrides)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -278,7 +278,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
|
|||
modelPath,
|
||||
gpuLibs,
|
||||
status,
|
||||
ml.GetVisibleDevicesEnv(gpus, false),
|
||||
ml.GetDevicesEnv(gpus, false),
|
||||
)
|
||||
|
||||
s := llmServer{
|
||||
|
|
@ -298,8 +298,8 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
|
|||
|
||||
if err != nil {
|
||||
var msg string
|
||||
if s.status != nil && s.status.LastErrMsg != "" {
|
||||
msg = s.status.LastErrMsg
|
||||
if s.status != nil && s.status.LastError() != "" {
|
||||
msg = s.status.LastError()
|
||||
}
|
||||
err := fmt.Errorf("error starting runner: %v %s", err, msg)
|
||||
if llamaModel != nil {
|
||||
|
|
@ -312,12 +312,12 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
|
|||
go func() {
|
||||
err := s.cmd.Wait()
|
||||
// Favor a more detailed message over the process exit status
|
||||
if err != nil && s.status != nil && s.status.LastErrMsg != "" {
|
||||
if err != nil && s.status != nil && s.status.LastError() != "" {
|
||||
slog.Error("llama runner terminated", "error", err)
|
||||
if strings.Contains(s.status.LastErrMsg, "unknown model") {
|
||||
s.status.LastErrMsg = "this model is not supported by your version of Ollama. You may need to upgrade"
|
||||
if strings.Contains(s.status.LastError(), "unknown model") {
|
||||
s.status.SetLastError("this model is not supported by your version of Ollama. You may need to upgrade")
|
||||
}
|
||||
s.doneErr = errors.New(s.status.LastErrMsg)
|
||||
s.doneErr = errors.New(s.status.LastError())
|
||||
} else {
|
||||
s.doneErr = err
|
||||
}
|
||||
|
|
@ -385,20 +385,9 @@ func StartRunner(ollamaEngine bool, modelPath string, gpuLibs []string, out io.W
|
|||
cmd.Env = os.Environ()
|
||||
|
||||
if out != nil {
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to spawn server stdout pipe: %w", err)
|
||||
}
|
||||
stderr, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to spawn server stderr pipe: %w", err)
|
||||
}
|
||||
go func() {
|
||||
io.Copy(out, stdout) //nolint:errcheck
|
||||
}()
|
||||
go func() {
|
||||
io.Copy(out, stderr) //nolint:errcheck
|
||||
}()
|
||||
// os/exec serializes Write calls when shared
|
||||
cmd.Stdout = out
|
||||
cmd.Stderr = out
|
||||
}
|
||||
cmd.SysProcAttr = LlamaServerSysProcAttr
|
||||
|
||||
|
|
@ -451,6 +440,38 @@ func StartRunner(ollamaEngine bool, modelPath string, gpuLibs []string, out io.W
|
|||
return
|
||||
}
|
||||
|
||||
// Workaround possible runtime crash where the probe incorrectly
|
||||
// enables metal tensor, but fails at runtime
|
||||
func ShouldRetryWithMetalTensorDisabled(err error, status *StatusWriter) bool {
|
||||
if runtime.GOOS != "darwin" {
|
||||
return false
|
||||
}
|
||||
|
||||
var msg strings.Builder
|
||||
msg.WriteString(strings.ToLower(err.Error()))
|
||||
if status != nil && status.LastError() != "" {
|
||||
msg.WriteByte(' ')
|
||||
msg.WriteString(strings.ToLower(status.LastError()))
|
||||
}
|
||||
text := msg.String()
|
||||
|
||||
for _, needle := range []string{
|
||||
"failed to initialize ggml backend device: metal",
|
||||
"failed to initialize metal backend",
|
||||
"failed to initialize the metal library",
|
||||
"failed to allocate context",
|
||||
"unable to create llama context",
|
||||
"signal arrived during cgo execution",
|
||||
"input types must match cooperative tensor types",
|
||||
} {
|
||||
if strings.Contains(text, needle) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *llmServer) ModelPath() string {
|
||||
return s.modelPath
|
||||
}
|
||||
|
|
@ -723,7 +744,7 @@ func (s *llamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, system
|
|||
|
||||
// The llama engine does its memory allocations together with model loading, so we
|
||||
// need to wait until it is done to ensure that we have accurate memory data before
|
||||
// loading the next model
|
||||
// loading the next model.
|
||||
return uniqueDeviceIDs(s.loadRequest.GPULayers), s.WaitUntilRunning(ctx)
|
||||
}
|
||||
|
||||
|
|
@ -1276,8 +1297,8 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
|
|||
// Fail fast if its exited
|
||||
if s.cmd.ProcessState != nil {
|
||||
msg := ""
|
||||
if s.status != nil && s.status.LastErrMsg != "" {
|
||||
msg = s.status.LastErrMsg
|
||||
if s.status != nil && s.status.LastError() != "" {
|
||||
msg = s.status.LastError()
|
||||
}
|
||||
if s.cmd.ProcessState.ExitCode() == -1 {
|
||||
// Most likely a signal killed it, log some more details to try to help troubleshoot
|
||||
|
|
@ -1371,21 +1392,30 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
|
|||
slog.Warn("client connection closed before server finished loading, aborting load")
|
||||
return fmt.Errorf("timed out waiting for llama runner to start: %w", ctx.Err())
|
||||
case <-s.done:
|
||||
return fmt.Errorf("llama runner process has terminated: %w", s.doneErr)
|
||||
if s.status != nil && s.status.LastError() != "" {
|
||||
return fmt.Errorf("llama runner process has terminated: %s", s.status.LastError())
|
||||
}
|
||||
if s.doneErr != nil {
|
||||
return fmt.Errorf("llama runner process has terminated: %w", s.doneErr)
|
||||
}
|
||||
if s.cmd != nil && s.cmd.ProcessState != nil {
|
||||
return fmt.Errorf("llama runner process has terminated with exit code %d", s.cmd.ProcessState.ExitCode())
|
||||
}
|
||||
return errors.New("llama runner process has terminated")
|
||||
default:
|
||||
}
|
||||
if time.Now().After(stallTimer) {
|
||||
// timeout
|
||||
msg := ""
|
||||
if s.status != nil && s.status.LastErrMsg != "" {
|
||||
msg = s.status.LastErrMsg
|
||||
if s.status != nil && s.status.LastError() != "" {
|
||||
msg = s.status.LastError()
|
||||
}
|
||||
return fmt.Errorf("timed out waiting for llama runner to start - progress %0.2f - %s", s.loadProgress, msg)
|
||||
}
|
||||
if s.cmd.ProcessState != nil {
|
||||
msg := ""
|
||||
if s.status != nil && s.status.LastErrMsg != "" {
|
||||
msg = s.status.LastErrMsg
|
||||
if s.status != nil && s.status.LastError() != "" {
|
||||
msg = s.status.LastError()
|
||||
}
|
||||
return fmt.Errorf("llama runner process no longer running: %d %s", s.cmd.ProcessState.ExitCode(), msg)
|
||||
}
|
||||
|
|
@ -1695,8 +1725,8 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
|||
if strings.Contains(err.Error(), "unexpected EOF") || strings.Contains(err.Error(), "forcibly closed") {
|
||||
s.Close()
|
||||
var msg string
|
||||
if s.status != nil && s.status.LastErrMsg != "" {
|
||||
msg = s.status.LastErrMsg
|
||||
if s.status != nil && s.status.LastError() != "" {
|
||||
msg = s.status.LastError()
|
||||
} else {
|
||||
msg = err.Error()
|
||||
}
|
||||
|
|
|
|||
31
llm/server_wait_test.go
Normal file
31
llm/server_wait_test.go
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestWaitUntilRunningUsesStatusMessageWhenDoneErrIsNil(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
close(done)
|
||||
|
||||
status := &StatusWriter{}
|
||||
status.SetLastError("llama_init_from_model: failed to initialize the context: failed to initialize Metal backend")
|
||||
|
||||
s := &llmServer{
|
||||
done: done,
|
||||
status: status,
|
||||
}
|
||||
|
||||
err := s.WaitUntilRunning(context.Background())
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
if strings.Contains(err.Error(), "%!w(<nil>)") {
|
||||
t.Fatalf("unexpected wrapped nil error: %q", err)
|
||||
}
|
||||
if !strings.Contains(err.Error(), s.status.LastError()) {
|
||||
t.Fatalf("error %q does not include status message %q", err, s.status.LastError())
|
||||
}
|
||||
}
|
||||
|
|
@ -2,23 +2,67 @@ package llm
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"os"
|
||||
"io"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// StatusWriter is a writer that captures error messages from the llama runner process
|
||||
type StatusWriter struct {
|
||||
LastErrMsg string
|
||||
out *os.File
|
||||
out io.Writer
|
||||
// StartRunner wires both Stdout and Stderr to the same StatusWriter, and
|
||||
// os/exec serializes Write calls in that case.
|
||||
lastErrMsg atomic.Value
|
||||
}
|
||||
|
||||
func NewStatusWriter(out *os.File) *StatusWriter {
|
||||
const maxCapturedErrorBytes = 8 * 1024
|
||||
|
||||
func NewStatusWriter(out io.Writer) *StatusWriter {
|
||||
return &StatusWriter{
|
||||
out: out,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *StatusWriter) LastError() string {
|
||||
if w == nil {
|
||||
return ""
|
||||
}
|
||||
if v := w.lastErrMsg.Load(); v != nil {
|
||||
return v.(string)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (w *StatusWriter) SetLastError(msg string) {
|
||||
if w == nil {
|
||||
return
|
||||
}
|
||||
w.lastErrMsg.Store(msg)
|
||||
}
|
||||
|
||||
func (w *StatusWriter) AppendError(msg string) {
|
||||
if w == nil || msg == "" {
|
||||
return
|
||||
}
|
||||
|
||||
if current := w.LastError(); current != "" {
|
||||
msg = current + "\n" + msg
|
||||
}
|
||||
|
||||
if len(msg) > maxCapturedErrorBytes {
|
||||
msg = msg[len(msg)-maxCapturedErrorBytes:]
|
||||
if i := strings.IndexByte(msg, '\n'); i >= 0 {
|
||||
msg = msg[i+1:]
|
||||
}
|
||||
}
|
||||
|
||||
w.SetLastError(msg)
|
||||
}
|
||||
|
||||
// TODO - regex matching to detect errors like
|
||||
// libcublasLt.so.11: cannot open shared object file: No such file or directory
|
||||
// TODO - if we later see error lines split across multiple Write calls in real
|
||||
// logs, add a small rolling buffer here to capture those fragments.
|
||||
|
||||
var errorPrefixes = []string{
|
||||
"error:",
|
||||
|
|
@ -29,17 +73,23 @@ var errorPrefixes = []string{
|
|||
"error loading model",
|
||||
"GGML_ASSERT",
|
||||
"Deepseek2 does not support K-shift",
|
||||
"signal arrived during cgo execution",
|
||||
"llama_init_from_model:",
|
||||
}
|
||||
|
||||
func (w *StatusWriter) Write(b []byte) (int, error) {
|
||||
var errMsg string
|
||||
for _, prefix := range errorPrefixes {
|
||||
if _, after, ok := bytes.Cut(b, []byte(prefix)); ok {
|
||||
errMsg = prefix + string(bytes.TrimSpace(after))
|
||||
line := after
|
||||
if j := bytes.IndexByte(line, '\n'); j >= 0 {
|
||||
line = line[:j]
|
||||
}
|
||||
errMsg = prefix + string(bytes.TrimRight(line, " \t\r"))
|
||||
}
|
||||
}
|
||||
if errMsg != "" {
|
||||
w.LastErrMsg = errMsg
|
||||
w.AppendError(errMsg)
|
||||
}
|
||||
|
||||
return w.out.Write(b)
|
||||
|
|
|
|||
44
llm/status_test.go
Normal file
44
llm/status_test.go
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
package llm
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestStatusWriterCapturesErrorLine(t *testing.T) {
|
||||
f, err := os.CreateTemp(t.TempDir(), "status-writer")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
w := NewStatusWriter(f)
|
||||
if _, err := w.Write([]byte("llama_init_from_model: failed to initialize the context: failed to initialize Metal backend\n")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if got, want := w.LastError(), "llama_init_from_model: failed to initialize the context: failed to initialize Metal backend"; got != want {
|
||||
t.Fatalf("LastError = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatusWriterAccumulatesErrorLines(t *testing.T) {
|
||||
f, err := os.CreateTemp(t.TempDir(), "status-writer")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
w := NewStatusWriter(f)
|
||||
if _, err := w.Write([]byte("error: failed to initialize the Metal library\n")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := w.Write([]byte("GGML_ASSERT([rsets->data count] == 0) failed\n")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
want := "error: failed to initialize the Metal library\nGGML_ASSERT([rsets->data count] == 0) failed"
|
||||
if got := w.LastError(); got != want {
|
||||
t.Fatalf("LastError = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
|
@ -50,8 +50,18 @@ var initDevices = sync.OnceFunc(func() {
|
|||
backends = make(map[C.ggml_backend_dev_t]C.ggml_backend_t)
|
||||
for i := range C.ggml_backend_dev_count() {
|
||||
d := C.ggml_backend_dev_get(i)
|
||||
t := C.ggml_backend_dev_type(d)
|
||||
name := C.GoString(C.ggml_backend_dev_name(d))
|
||||
|
||||
switch C.ggml_backend_dev_type(d) {
|
||||
b := C.ggml_backend_dev_init(d, nil)
|
||||
if b == nil {
|
||||
slog.Error("failed to initialize ggml backend device", "device", name, "type", t)
|
||||
panic(fmt.Sprintf("failed to initialize ggml backend device: %s", name))
|
||||
}
|
||||
|
||||
backends[d] = b
|
||||
|
||||
switch t {
|
||||
case C.GGML_BACKEND_DEVICE_TYPE_CPU:
|
||||
if len(cpus) == 0 {
|
||||
// only the first cpu device should be used
|
||||
|
|
@ -63,8 +73,6 @@ var initDevices = sync.OnceFunc(func() {
|
|||
C.GGML_BACKEND_DEVICE_TYPE_IGPU:
|
||||
gpus = append(gpus, d)
|
||||
}
|
||||
|
||||
backends[d] = C.ggml_backend_dev_init(d, nil)
|
||||
}
|
||||
})
|
||||
|
||||
|
|
|
|||
19
ml/device.go
19
ml/device.go
|
|
@ -314,6 +314,11 @@ type DeviceInfo struct {
|
|||
|
||||
// Where backends were loaded from
|
||||
LibraryPath []string
|
||||
|
||||
// RunnerEnvOverrides stores exceptional per-device runner environment
|
||||
// overrides discovered during bootstrap. This is internal server state and
|
||||
// is not serialized.
|
||||
RunnerEnvOverrides map[string]string `json:"-"`
|
||||
}
|
||||
|
||||
type SystemInfo struct {
|
||||
|
|
@ -519,16 +524,24 @@ func (f FlashAttentionType) String() string {
|
|||
}
|
||||
|
||||
// Given the list of GPUs this instantiation is targeted for,
|
||||
// figure out the visible devices environment variables
|
||||
// Set mustFilter true to enable filtering of CUDA devices
|
||||
func GetVisibleDevicesEnv(l []DeviceInfo, mustFilter bool) map[string]string {
|
||||
// figure out the device environment variables and any recorded
|
||||
// per-device runner environment overrides. Set mustFilter true to enable
|
||||
// filtering of CUDA devices.
|
||||
func GetDevicesEnv(l []DeviceInfo, mustFilter bool) map[string]string {
|
||||
if len(l) == 0 {
|
||||
return nil
|
||||
}
|
||||
env := map[string]string{}
|
||||
for _, d := range l {
|
||||
d.updateVisibleDevicesEnv(env, mustFilter)
|
||||
for k, v := range d.RunnerEnvOverrides {
|
||||
if existing, ok := env[k]; ok && existing != v {
|
||||
slog.Warn("conflicting device environment override", "key", k, "existing", existing, "new", v, "library", d.Library, "id", d.ID)
|
||||
}
|
||||
env[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
return env
|
||||
}
|
||||
|
||||
|
|
|
|||
60
ml/device_test.go
Normal file
60
ml/device_test.go
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
package ml
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMergeEnvWithRunnerEnvOverrides(t *testing.T) {
|
||||
devices := []DeviceInfo{
|
||||
{
|
||||
DeviceID: DeviceID{Library: "Metal", ID: "0"},
|
||||
RunnerEnvOverrides: map[string]string{"GGML_METAL_TENSOR_DISABLE": "1"},
|
||||
},
|
||||
{
|
||||
DeviceID: DeviceID{Library: "CUDA", ID: "3"},
|
||||
},
|
||||
}
|
||||
|
||||
env := GetDevicesEnv(devices, true)
|
||||
|
||||
if got, want := env["GGML_METAL_TENSOR_DISABLE"], "1"; got != want {
|
||||
t.Fatalf("GGML_METAL_TENSOR_DISABLE = %q, want %q", got, want)
|
||||
}
|
||||
|
||||
if got, want := env["CUDA_VISIBLE_DEVICES"], "3"; got != want {
|
||||
t.Fatalf("CUDA_VISIBLE_DEVICES = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetDevicesEnvWarnsOnConflictingOverrides(t *testing.T) {
|
||||
var logs bytes.Buffer
|
||||
oldLogger := slog.Default()
|
||||
slog.SetDefault(slog.New(slog.NewTextHandler(&logs, &slog.HandlerOptions{Level: slog.LevelDebug})))
|
||||
t.Cleanup(func() {
|
||||
slog.SetDefault(oldLogger)
|
||||
})
|
||||
|
||||
devices := []DeviceInfo{
|
||||
{
|
||||
DeviceID: DeviceID{Library: "Metal", ID: "0"},
|
||||
RunnerEnvOverrides: map[string]string{"TEST_OVERRIDE": "one"},
|
||||
},
|
||||
{
|
||||
DeviceID: DeviceID{Library: "Metal", ID: "1"},
|
||||
RunnerEnvOverrides: map[string]string{"TEST_OVERRIDE": "two"},
|
||||
},
|
||||
}
|
||||
|
||||
env := GetDevicesEnv(devices, false)
|
||||
|
||||
if got, want := env["TEST_OVERRIDE"], "two"; got != want {
|
||||
t.Fatalf("TEST_OVERRIDE = %q, want %q", got, want)
|
||||
}
|
||||
|
||||
if !strings.Contains(logs.String(), "conflicting device environment override") {
|
||||
t.Fatalf("expected warning log, got %q", logs.String())
|
||||
}
|
||||
}
|
||||
|
|
@ -335,6 +335,13 @@ type Server struct {
|
|||
// loadMu prevents more than one load attempt from occurring at a time
|
||||
loadMu sync.Mutex
|
||||
|
||||
// infoInitialized caches the result of the dummy /info backend
|
||||
// initialization. loadMu already serializes callers, so a simple cached
|
||||
// result avoids repeated dummy loads without needing sync.Once.
|
||||
infoInitialized bool
|
||||
infoModel model.Model
|
||||
infoErr error
|
||||
|
||||
// lastLoad is the load request from the previous load attempt. Used to
|
||||
// detect if we can reuse an existing memory allocation.
|
||||
lastLoad llm.LoadRequest
|
||||
|
|
@ -1362,44 +1369,70 @@ func (s *Server) info(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
m := s.model
|
||||
|
||||
if m == nil {
|
||||
startLoad := time.Now()
|
||||
|
||||
// Dummy load to get the backend wired up
|
||||
f, err := os.CreateTemp("", "*.bin")
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to initialize backend: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer f.Close()
|
||||
defer os.Remove(f.Name())
|
||||
|
||||
if err := ggml.WriteGGUF(f, ggml.KV{
|
||||
"general.architecture": "llama",
|
||||
"tokenizer.ggml.model": "gpt2",
|
||||
}, nil); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to initialize backend: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
m, err = model.New(f.Name(), ml.BackendParams{NumThreads: runtime.NumCPU(), AllocMemory: false, GPULayers: ml.GPULayersList{{}}})
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to initialize backend: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
slog.Debug("dummy model load took", "duration", time.Since(startLoad))
|
||||
m, err := s.infoModelLocked()
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to initialize backend: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
startDevices := time.Now()
|
||||
infos := m.Backend().BackendDevices()
|
||||
slog.Debug("gathering device infos took", "duration", time.Since(startDevices))
|
||||
if err := json.NewEncoder(w).Encode(&infos); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) infoModelLocked() (model.Model, error) {
|
||||
if s.model != nil {
|
||||
return s.model, nil
|
||||
}
|
||||
|
||||
if !s.infoInitialized {
|
||||
s.infoInitialized = true
|
||||
|
||||
func() {
|
||||
startLoad := time.Now()
|
||||
defer func() {
|
||||
if rec := recover(); rec != nil {
|
||||
s.infoErr = fmt.Errorf("panic during dummy backend initialization: %v", rec)
|
||||
}
|
||||
}()
|
||||
|
||||
// Dummy load to get the backend wired up.
|
||||
f, err := os.CreateTemp("", "*.bin")
|
||||
if err != nil {
|
||||
s.infoErr = err
|
||||
return
|
||||
}
|
||||
defer f.Close()
|
||||
defer os.Remove(f.Name())
|
||||
|
||||
if err := ggml.WriteGGUF(f, ggml.KV{
|
||||
"general.architecture": "llama",
|
||||
"tokenizer.ggml.model": "gpt2",
|
||||
}, nil); err != nil {
|
||||
s.infoErr = err
|
||||
return
|
||||
}
|
||||
|
||||
s.infoModel, s.infoErr = model.New(f.Name(), ml.BackendParams{NumThreads: runtime.NumCPU(), AllocMemory: false, GPULayers: ml.GPULayersList{{}}})
|
||||
if s.infoErr == nil {
|
||||
slog.Debug("dummy model load took", "duration", time.Since(startLoad))
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if s.infoErr != nil {
|
||||
return nil, s.infoErr
|
||||
}
|
||||
if s.infoModel == nil {
|
||||
return nil, fmt.Errorf("dummy backend initialization did not produce a model")
|
||||
}
|
||||
|
||||
return s.infoModel, nil
|
||||
}
|
||||
|
||||
func Execute(args []string) error {
|
||||
fs := flag.NewFlagSet("runner", flag.ExitOnError)
|
||||
mpath := fs.String("model", "", "Path to model binary file")
|
||||
|
|
|
|||
|
|
@ -45,9 +45,9 @@ type Client struct {
|
|||
cmd *exec.Cmd
|
||||
}
|
||||
|
||||
// statusWriter captures the last stderr line from the subprocess while
|
||||
// forwarding all output to os.Stderr. Lines longer than maxStatusLen are
|
||||
// truncated to the first maxStatusLen bytes.
|
||||
// statusWriter captures the last subprocess line while forwarding all output
|
||||
// to os.Stderr. Lines longer than maxStatusLen are truncated to the first
|
||||
// maxStatusLen bytes.
|
||||
type statusWriter struct {
|
||||
lastErrMsg string
|
||||
buf []byte
|
||||
|
|
@ -405,17 +405,12 @@ func (c *Client) Load(ctx context.Context, _ ml.SystemInfo, gpus []ml.DeviceInfo
|
|||
|
||||
c.cmd = cmd
|
||||
|
||||
// Forward subprocess stdout/stderr to server logs
|
||||
stdout, _ := cmd.StdoutPipe()
|
||||
stderr, _ := cmd.StderrPipe()
|
||||
status := &statusWriter{out: os.Stderr}
|
||||
c.status = status
|
||||
go func() {
|
||||
io.Copy(os.Stderr, stdout) //nolint:errcheck
|
||||
}()
|
||||
go func() {
|
||||
io.Copy(status, stderr) //nolint:errcheck
|
||||
}()
|
||||
// os/exec serializes Write calls when shared, which keeps the status writer
|
||||
// from seeing concurrent stdout/stderr fragments.
|
||||
cmd.Stdout = status
|
||||
cmd.Stderr = status
|
||||
|
||||
slog.Info("starting mlx runner subprocess", "model", c.modelName, "port", c.port)
|
||||
if err := cmd.Start(); err != nil {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue