diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index f609fedc9..2f62e0182 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -22,6 +22,7 @@ jobs: runs-on: ubuntu-latest outputs: changed: ${{ steps.changes.outputs.changed }} + app_changed: ${{ steps.changes.outputs.app_changed }} vendorsha: ${{ steps.changes.outputs.vendorsha }} steps: - uses: actions/checkout@v4 @@ -37,7 +38,8 @@ jobs: | xargs python3 -c "import sys; from pathlib import Path; print(any(Path(x).match(glob) for x in sys.argv[1:] for glob in '$*'.split(' ')))" } - echo changed=$(changed 'llama/server/**/*' 'LLAMA_CPP_VERSION' '.github/**/*') | tee -a $GITHUB_OUTPUT + echo changed=$(changed 'llama/server/**/*' 'llama/compat/**/*' 'LLAMA_CPP_VERSION' 'llama/llama.cpp/**/*' 'ml/backend/ggml/ggml/**/*' '.github/**/*') | tee -a $GITHUB_OUTPUT + echo app_changed=$(changed 'app/**' 'app/**/*') | tee -a $GITHUB_OUTPUT echo vendorsha=$(cat LLAMA_CPP_VERSION) | tee -a $GITHUB_OUTPUT linux: @@ -285,6 +287,7 @@ jobs: run: go mod tidy --diff || (echo "Please run 'go mod tidy'." && exit 1) test: + needs: [changes] strategy: matrix: os: [ubuntu-latest, macos-latest, windows-latest] @@ -319,6 +322,10 @@ jobs: if: always() run: go test -count=1 -benchtime=1x ./... + - name: go test app with live updater tag + if: ${{ needs.changes.outputs.app_changed == 'True' && contains(fromJSON('["macos-latest","windows-latest"]'), matrix.os) }} + run: go test -count=1 -tags updater_live ./app/... + - uses: golangci/golangci-lint-action@v9 with: only-new-issues: true diff --git a/app/updater/updater.go b/app/updater/updater.go index d2929d8a1..06b3fd8db 100644 --- a/app/updater/updater.go +++ b/app/updater/updater.go @@ -5,6 +5,8 @@ package updater import ( "context" "crypto/rand" + "crypto/sha256" + "encoding/hex" "encoding/json" "errors" "fmt" @@ -169,22 +171,20 @@ func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateRespo if err != nil { return fmt.Errorf("error checking update: %w", err) } + resp.Body.Close() if resp.StatusCode != http.StatusOK { return fmt.Errorf("unexpected status attempting to download update %d", resp.StatusCode) } - resp.Body.Close() - etag := strings.Trim(resp.Header.Get("etag"), "\"") - if etag == "" { - slog.Debug("no etag detected, falling back to filename based dedup") - etag = "_" - } filename := Installer _, params, err := mime.ParseMediaType(resp.Header.Get("content-disposition")) - if err == nil { + if err == nil && params["filename"] != "" { filename = params["filename"] } - stageFilename := filepath.Join(UpdateStageDir, etag, filename) + stageFilename, err := updateStagePath(UpdateStageDir, resp.Header.Get("etag"), filename) + if err != nil { + return err + } // Check to see if we already have it downloaded _, err = os.Stat(stageFilename) @@ -202,13 +202,14 @@ func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateRespo return fmt.Errorf("error checking update: %w", err) } defer resp.Body.Close() - etag = strings.Trim(resp.Header.Get("etag"), "\"") - if etag == "" { - slog.Debug("no etag detected, falling back to filename based dedup") // TODO probably can get rid of this redundant log - etag = "_" + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("unexpected status attempting to download update %d", resp.StatusCode) } - stageFilename = filepath.Join(UpdateStageDir, etag, filename) + stageFilename, err = updateStagePath(UpdateStageDir, resp.Header.Get("etag"), filename) + if err != nil { + return err + } _, err = os.Stat(filepath.Dir(stageFilename)) if errors.Is(err, os.ErrNotExist) { @@ -225,10 +226,13 @@ func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateRespo if err != nil { return fmt.Errorf("write payload %s: %w", stageFilename, err) } - defer fp.Close() if n, err := fp.Write(payload); err != nil || n != len(payload) { + _ = fp.Close() return fmt.Errorf("write payload %s: %d vs %d -- %w", stageFilename, n, len(payload), err) } + if err := fp.Close(); err != nil { + return fmt.Errorf("close payload %s: %w", stageFilename, err) + } slog.Info("new update downloaded " + stageFilename) if err := VerifyDownload(); err != nil { @@ -239,6 +243,61 @@ func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateRespo return nil } +func updateStagePath(stageDir, etag, filename string) (string, error) { + filename, err := safeUpdateFilename(filename) + if err != nil { + return "", err + } + + stageDir, err = filepath.Abs(stageDir) + if err != nil { + return "", fmt.Errorf("resolve update stage dir: %w", err) + } + + stageFilename := filepath.Join(stageDir, updateStageETagDir(etag), filename) + if err := ensurePathInDir(stageDir, stageFilename); err != nil { + return "", err + } + + return stageFilename, nil +} + +func safeUpdateFilename(filename string) (string, error) { + filename = strings.TrimSpace(filename) + if filename == "" { + return "", errors.New("missing update filename") + } + if filename == "." || filename == ".." || + filepath.IsAbs(filename) || path.IsAbs(filename) || + strings.ContainsAny(filename, `/\:`) || + filepath.Base(filename) != filename || path.Base(filename) != filename { + return "", fmt.Errorf("unsafe update filename %q", filename) + } + return filename, nil +} + +func updateStageETagDir(etag string) string { + etag = strings.Trim(strings.TrimSpace(etag), "\"") + if etag == "" { + slog.Debug("no etag detected, falling back to filename based dedup") + return "_" + } + + sum := sha256.Sum256([]byte(etag)) + return hex.EncodeToString(sum[:]) +} + +func ensurePathInDir(dir, name string) error { + rel, err := filepath.Rel(dir, name) + if err != nil { + return fmt.Errorf("resolve update staging path: %w", err) + } + if rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) || filepath.IsAbs(rel) { + return fmt.Errorf("update staging path escapes stage dir: %s", name) + } + return nil +} + func cleanupOldDownloads(stageDir string) { files, err := os.ReadDir(stageDir) if err != nil && errors.Is(err, os.ErrNotExist) { diff --git a/app/updater/updater_darwin.go b/app/updater/updater_darwin.go index 2a9dfb01b..d2159d350 100644 --- a/app/updater/updater_darwin.go +++ b/app/updater/updater_darwin.go @@ -22,6 +22,15 @@ import ( "golang.org/x/sys/unix" ) +const updateArchiveRoot = "Ollama.app" + +type bundleEntryScope int + +const ( + bundleEntryRelative bundleEntryScope = iota + bundleEntryWithArchiveRoot +) + var ( appBackupDir string SystemWidePath = "/Applications/Ollama.app" @@ -167,8 +176,12 @@ func DoUpgrade(interactive bool) error { } name := s[1] if strings.HasSuffix(name, "/") { - d := filepath.Join(BundlePath, name) - err := os.MkdirAll(d, 0o755) + d, err := bundleEntryPath(BundlePath, name, bundleEntryRelative) + if err != nil { + anyFailures = true + return err + } + err = os.MkdirAll(d, 0o755) if err != nil { anyFailures = true return fmt.Errorf("failed to mkdir %s: %w", d, err) @@ -181,30 +194,14 @@ func DoUpgrade(interactive bool) error { continue } - src, err := f.Open() + destName, err := bundleEntryPath(BundlePath, name, bundleEntryRelative) if err != nil { anyFailures = true - return fmt.Errorf("failed to open bundle file %s: %w", name, err) + return err } - destName := filepath.Join(BundlePath, name) - // Verify directory first - d := filepath.Dir(destName) - if _, err := os.Stat(d); err != nil { - err := os.MkdirAll(d, 0o755) - if err != nil { - anyFailures = true - return fmt.Errorf("failed to mkdir %s: %w", d, err) - } - } - destFile, err := os.OpenFile(destName, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755) - if err != nil { + if err := extractBundleFile(f, destName, name); err != nil { anyFailures = true - return fmt.Errorf("failed to open output file %s: %w", destName, err) - } - defer destFile.Close() - if _, err := io.Copy(destFile, src); err != nil { - anyFailures = true - return fmt.Errorf("failed to open extract file %s: %w", destName, err) + return err } } for _, f := range links { @@ -225,16 +222,24 @@ func DoUpgrade(interactive bool) error { return err } link := string(buf) - if link[0] == '/' { + if link == "" { + anyFailures = true + return fmt.Errorf("bundle contains empty symlink %s", f.Name) + } + if filepath.IsAbs(link) { anyFailures = true return fmt.Errorf("bundle contains absolute symlink %s -> %s", f.Name, link) } - // Don't allow links outside of Ollama.app - if strings.HasPrefix(filepath.Join(filepath.Dir(name), link), "..") { + if !validBundleLinkTarget(name, link, bundleEntryRelative) { anyFailures = true - return fmt.Errorf("bundle contains link outside of contents %s -> %s", f.Name, link) + return fmt.Errorf("bundle contains invalid symlink %s -> %s", f.Name, link) } - if err = os.Symlink(link, filepath.Join(BundlePath, name)); err != nil { + destName, err := bundleEntryPath(BundlePath, name, bundleEntryRelative) + if err != nil { + anyFailures = true + return err + } + if err = os.Symlink(link, destName); err != nil { anyFailures = true return err } @@ -282,8 +287,11 @@ func verifyDownload() error { links := []*zip.File{} for _, f := range r.File { if strings.HasSuffix(f.Name, "/") { - d := filepath.Join(dir, f.Name) - err := os.MkdirAll(d, 0o755) + d, err := bundleEntryPath(dir, f.Name, bundleEntryWithArchiveRoot) + if err != nil { + return err + } + err = os.MkdirAll(d, 0o755) if err != nil { return fmt.Errorf("failed to mkdir %s: %w", d, err) } @@ -294,26 +302,12 @@ func verifyDownload() error { links = append(links, f) continue } - src, err := f.Open() + destName, err := bundleEntryPath(dir, f.Name, bundleEntryWithArchiveRoot) if err != nil { - return fmt.Errorf("failed to open bundle file %s: %w", f.Name, err) + return err } - destName := filepath.Join(dir, f.Name) - // Verify directory first - d := filepath.Dir(destName) - if _, err := os.Stat(d); err != nil { - err := os.MkdirAll(d, 0o755) - if err != nil { - return fmt.Errorf("failed to mkdir %s: %w", d, err) - } - } - destFile, err := os.OpenFile(destName, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755) - if err != nil { - return fmt.Errorf("failed to open output file %s: %w", destName, err) - } - defer destFile.Close() - if _, err := io.Copy(destFile, src); err != nil { - return fmt.Errorf("failed to open extract file %s: %w", destName, err) + if err := extractBundleFile(f, destName, f.Name); err != nil { + return err } } for _, f := range links { @@ -326,13 +320,20 @@ func verifyDownload() error { return err } link := string(buf) - if link[0] == '/' { + if link == "" { + return fmt.Errorf("bundle contains empty symlink %s", f.Name) + } + if filepath.IsAbs(link) { return fmt.Errorf("bundle contains absolute symlink %s -> %s", f.Name, link) } - if strings.HasPrefix(filepath.Join(filepath.Dir(f.Name), link), "..") { - return fmt.Errorf("bundle contains link outside of contents %s -> %s", f.Name, link) + if !validBundleLinkTarget(f.Name, link, bundleEntryWithArchiveRoot) { + return fmt.Errorf("bundle contains invalid symlink %s -> %s", f.Name, link) } - if err = os.Symlink(link, filepath.Join(dir, f.Name)); err != nil { + destName, err := bundleEntryPath(dir, f.Name, bundleEntryWithArchiveRoot) + if err != nil { + return err + } + if err = os.Symlink(link, destName); err != nil { return err } } @@ -343,6 +344,53 @@ func verifyDownload() error { return nil } +func bundleEntryPath(root, name string, scope bundleEntryScope) (string, error) { + cleanName := filepath.Clean(filepath.FromSlash(name)) + if !filepath.IsLocal(cleanName) { + return "", fmt.Errorf("bundle contains invalid path: %s", name) + } + if scope == bundleEntryWithArchiveRoot && cleanName != updateArchiveRoot && + !strings.HasPrefix(cleanName, updateArchiveRoot+string(os.PathSeparator)) { + return "", fmt.Errorf("bundle contains invalid path: %s", name) + } + return filepath.Join(root, cleanName), nil +} + +func extractBundleFile(f *zip.File, destName, name string) error { + src, err := f.Open() + if err != nil { + return fmt.Errorf("failed to open bundle file %s: %w", name, err) + } + defer src.Close() + + d := filepath.Dir(destName) + if _, err := os.Stat(d); err != nil { + if err := os.MkdirAll(d, 0o755); err != nil { + return fmt.Errorf("failed to mkdir %s: %w", d, err) + } + } + + destFile, err := os.OpenFile(destName, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755) + if err != nil { + return fmt.Errorf("failed to open output file %s: %w", destName, err) + } + defer destFile.Close() + + if _, err := io.Copy(destFile, src); err != nil { + return fmt.Errorf("failed to open extract file %s: %w", destName, err) + } + return nil +} + +func validBundleLinkTarget(name, link string, scope bundleEntryScope) bool { + cleanTarget := filepath.Clean(filepath.Join(filepath.Dir(filepath.FromSlash(name)), filepath.FromSlash(link))) + if !filepath.IsLocal(cleanTarget) { + return false + } + return scope == bundleEntryRelative || cleanTarget == updateArchiveRoot || + strings.HasPrefix(cleanTarget, updateArchiveRoot+string(os.PathSeparator)) +} + // If we detect an upgrade bundle, attempt to upgrade at startup func DoUpgradeAtStartup() error { bundle := getStagedUpdate() diff --git a/app/updater/updater_darwin_test.go b/app/updater/updater_darwin_test.go index 74a323ba3..7cc80d617 100644 --- a/app/updater/updater_darwin_test.go +++ b/app/updater/updater_darwin_test.go @@ -2,6 +2,7 @@ package updater import ( "archive/zip" + "errors" "io/fs" "os" "path/filepath" @@ -146,6 +147,46 @@ func TestDoUpgrade(t *testing.T) { } } +func TestDoUpgradeRejectsInvalidBundlePath(t *testing.T) { + tmpDir := t.TempDir() + BundlePath = filepath.Join(tmpDir, "Ollama.app") + appBackupDir = filepath.Join(tmpDir, "backup") + UpdateStageDir = filepath.Join(tmpDir, "updates") + UpgradeMarkerFile = filepath.Join(tmpDir, "upgraded") + bundle := filepath.Join(UpdateStageDir, "foo", "ollama-darwin.zip") + invalidTarget := filepath.Join(tmpDir, "invalid-entry") + + if err := os.MkdirAll(filepath.Join(BundlePath, "Contents", "MacOS"), 0o755); err != nil { + t.Fatal("failed to create empty dirs") + } + if err := os.WriteFile(filepath.Join(BundlePath, "Contents", "MacOS", "Ollama"), []byte("old app"), 0o755); err != nil { + t.Fatal("failed to create old app") + } + if err := os.MkdirAll(filepath.Dir(bundle), 0o755); err != nil { + t.Fatal("failed to create empty dirs") + } + if err := zipCreationHelper(bundle, []testPayload{{ + Name: "Ollama.app/../invalid-entry", + Body: []byte("payload"), + }}); err != nil { + t.Fatal(err) + } + + if err := DoUpgrade(false); err == nil { + t.Fatal("expected failure with invalid bundle path") + } else if !strings.Contains(err.Error(), "bundle contains invalid path") { + t.Fatalf("unexpected error with invalid bundle path: %s", err) + } + if _, err := os.Stat(invalidTarget); err == nil { + t.Fatalf("invalid bundle path wrote %s", invalidTarget) + } else if !errors.Is(err, os.ErrNotExist) { + t.Fatalf("unexpected stat error for %s: %s", invalidTarget, err) + } + if _, err := os.Stat(filepath.Join(BundlePath, "Contents", "MacOS", "Ollama")); err != nil { + t.Fatalf("old app was not restored: %s", err) + } +} + func TestDoUpgradeAtStartup(t *testing.T) { tmpDir := t.TempDir() BundlePath = filepath.Join(tmpDir, "Ollama.app") @@ -203,7 +244,7 @@ func TestVerifyDownloadFailures(t *testing.T) { in []testPayload expected string }{ - {"breakout", []testPayload{ + {"invalid symlink target", []testPayload{ { Name: "Ollama.app/", Body: []byte{}, @@ -212,15 +253,34 @@ func TestVerifyDownloadFailures(t *testing.T) { Body: []byte("cli payload here"), }, { Name: "Ollama.app/Contents/MacOS/Ollama", - Body: []byte("../../../../breakout"), + Body: []byte("../../../../invalid-target"), Mode: os.ModeSymlink, }, - }, "bundle contains link outside"}, + }, "bundle contains invalid symlink"}, + {"invalid archive symlink target", []testPayload{ + { + Name: "Ollama.app/Contents/MacOS/Ollama", + Body: []byte("../../../invalid-target"), + Mode: os.ModeSymlink, + }, + }, "bundle contains invalid symlink"}, {"absolute", []testPayload{{ Name: "Ollama.app/Contents/MacOS/Ollama", Body: []byte("/etc/foo"), Mode: os.ModeSymlink, }}, "bundle contains absolute"}, + {"invalid relative file", []testPayload{{ + Name: "Ollama.app/../invalid-entry", + Body: []byte("payload"), + }}, "bundle contains invalid path"}, + {"invalid relative directory", []testPayload{{ + Name: "Ollama.app/../invalid-entry/", + Body: []byte{}, + }}, "bundle contains invalid path"}, + {"absolute file", []testPayload{{ + Name: filepath.Join(tmpDir, "invalid-entry"), + Body: []byte("payload"), + }}, "bundle contains invalid path"}, {"missing", []testPayload{{ Name: "Ollama.app/Contents/MacOS/Ollama", Body: []byte("../nothere"), @@ -242,6 +302,11 @@ func TestVerifyDownloadFailures(t *testing.T) { if err == nil || !strings.Contains(err.Error(), tt.expected) { t.Fatalf("expected \"%s\" got %s", tt.expected, err) } + if _, err := os.Stat(filepath.Join(tmpDir, "invalid-entry")); err == nil { + t.Fatal("invalid bundle path wrote unexpected file") + } else if !errors.Is(err, os.ErrNotExist) { + t.Fatalf("unexpected stat error for invalid file: %s", err) + } }) } } diff --git a/app/updater/updater_live_test.go b/app/updater/updater_live_test.go new file mode 100644 index 000000000..eb18f7a79 --- /dev/null +++ b/app/updater/updater_live_test.go @@ -0,0 +1,127 @@ +//go:build (windows || darwin) && updater_live + +package updater + +import ( + "context" + "os" + "path/filepath" + "runtime" + "strings" + "testing" + "time" + + "github.com/ollama/ollama/app/store" + "github.com/ollama/ollama/app/version" +) + +// TestLiveAppUpdate exercises the production update endpoint and downloads the +// current OS update artifact. It is intentionally excluded from normal test +// runs because it depends on ollama.com and downloads a release artifact. +// +// Run with: +// +// go test -tags updater_live -run TestLiveAppUpdate ./app/updater +func TestLiveAppUpdate(t *testing.T) { + const spoofedVersion = "0.20.0" + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) + defer cancel() + + oldUpdateStageDir := UpdateStageDir + oldUpdateDownloaded := UpdateDownloaded + oldVerifyDownload := VerifyDownload + oldVersion := version.Version + defer func() { + UpdateStageDir = oldUpdateStageDir + UpdateDownloaded = oldUpdateDownloaded + VerifyDownload = oldVerifyDownload + version.Version = oldVersion + }() + + version.Version = spoofedVersion + + expectedFilename := "" + switch runtime.GOOS { + case "windows": + t.Setenv("LOCALAPPDATA", t.TempDir()) + expectedFilename = "OllamaSetup.exe" + case "darwin": + expectedFilename = "Ollama-darwin.zip" + default: + t.Fatalf("unsupported updater live test OS %q", runtime.GOOS) + } + + UpdateStageDir = filepath.Join(t.TempDir(), "updates") + UpdateDownloaded = false + verifyCalled := false + VerifyDownload = func() error { + verifyCalled = true + return verifyDownload() + } + + updater := &Updater{Store: &store.Store{DBPath: filepath.Join(t.TempDir(), "db.sqlite")}} + defer updater.Store.Close() + + available, updateResp := updater.checkForUpdate(ctx) + if !available { + t.Fatalf("expected production update check to offer an update for spoofed version %s", spoofedVersion) + } + if updateResp.UpdateURL == "" { + t.Fatal("production update response did not include a download URL") + } + t.Logf("production update version=%q url=%q", updateResp.UpdateVersion, updateResp.UpdateURL) + + if err := updater.DownloadNewRelease(ctx, updateResp); err != nil { + t.Fatalf("download production update: %v", err) + } + + staged := getStagedUpdate() + if staged == "" { + t.Fatal("production update was not staged") + } + t.Logf("staged production update at %s", staged) + + assertPathInsideDir(t, UpdateStageDir, staged) + if filepath.Base(staged) != expectedFilename { + t.Fatalf("expected staged %s update filename to be %q, got %q", runtime.GOOS, expectedFilename, filepath.Base(staged)) + } + expectedExt := filepath.Ext(expectedFilename) + if filepath.Ext(staged) != expectedExt { + t.Fatalf("expected staged %s update to be a %s artifact, got %s", runtime.GOOS, expectedExt, staged) + } + + info, err := os.Stat(staged) + if err != nil { + t.Fatalf("stat staged update: %v", err) + } + if info.Size() == 0 { + t.Fatal("staged production update is empty") + } + + if !verifyCalled { + t.Fatal("DownloadNewRelease did not call VerifyDownload") + } + t.Logf("production updater download path verified staged %s update", runtime.GOOS) +} + +func assertPathInsideDir(t *testing.T, dir, name string) { + t.Helper() + + dir, err := filepath.Abs(dir) + if err != nil { + t.Fatal(err) + } + name, err = filepath.Abs(name) + if err != nil { + t.Fatal(err) + } + + rel, err := filepath.Rel(dir, name) + if err != nil { + t.Fatal(err) + } + if rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) || filepath.IsAbs(rel) { + t.Fatalf("staged update escaped update stage dir: %s", name) + } +} diff --git a/app/updater/updater_test.go b/app/updater/updater_test.go index ae8292ebf..7673e9942 100644 --- a/app/updater/updater_test.go +++ b/app/updater/updater_test.go @@ -11,7 +11,9 @@ import ( "log/slog" "net/http" "net/http/httptest" + "os" "path/filepath" + "strings" "sync/atomic" "testing" "time" @@ -19,6 +21,52 @@ import ( "github.com/ollama/ollama/app/store" ) +func TestUpdateStagePathRejectsUnsafeFilename(t *testing.T) { + stageDir := t.TempDir() + for _, tt := range []struct { + name string + filename string + }{ + {"empty", ""}, + {"dot", "."}, + {"dotdot", ".."}, + {"posix_parent", "../OllamaSetup.exe"}, + {"windows_parent", `..\OllamaSetup.exe`}, + {"posix_absolute_tmp", "/tmp/OllamaSetup.exe"}, + {"darwin_absolute_app", "/Applications/Ollama.app"}, + {"darwin_bundle_path", "Ollama.app/Contents/MacOS/Ollama"}, + {"darwin_user_download", "~/Downloads/Ollama-darwin.zip"}, + {"windows_absolute", `C:\Users\Public\OllamaSetup.exe`}, + {"colon", "Ollama:Setup.exe"}, + } { + t.Run(tt.name, func(t *testing.T) { + if _, err := updateStagePath(stageDir, "etag", tt.filename); err == nil { + t.Fatal("expected unsafe filename to be rejected") + } + }) + } +} + +func TestUpdateStagePathHashesETag(t *testing.T) { + stageDir := t.TempDir() + stageFilename, err := updateStagePath(stageDir, `../escaped`, "OllamaSetup.exe") + if err != nil { + t.Fatal(err) + } + + rel, err := filepath.Rel(stageDir, stageFilename) + if err != nil { + t.Fatal(err) + } + if rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) || filepath.IsAbs(rel) { + t.Fatalf("stage filename escaped stage dir: %s", stageFilename) + } + etagDir := filepath.Base(filepath.Dir(stageFilename)) + if etagDir == ".." || etagDir == "escaped" || strings.ContainsAny(etagDir, `/\`) { + t.Fatalf("stage filename used raw etag path component: %s", stageFilename) + } +} + func TestIsNewReleaseAvailable(t *testing.T) { slog.SetLogLoggerLevel(slog.LevelDebug) var server *httptest.Server @@ -47,6 +95,223 @@ func TestIsNewReleaseAvailable(t *testing.T) { } } +func TestDownloadNewReleaseRejectsUnsafeHeaderFilename(t *testing.T) { + UpdateStageDir = t.TempDir() + oldInstaller := Installer + oldVerifyDownload := VerifyDownload + oldUpdateDownloaded := UpdateDownloaded + defer func() { + Installer = oldInstaller + VerifyDownload = oldVerifyDownload + UpdateDownloaded = oldUpdateDownloaded + }() + Installer = "OllamaSetup.exe" + UpdateDownloaded = false + VerifyDownload = func() error { + t.Fatal("verification should not run for rejected downloads") + return nil + } + + var getAttempted atomic.Bool + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodHead { + w.Header().Set("ETag", `"safe"`) + w.Header().Set("Content-Disposition", `attachment; filename="../OllamaSetup.exe"`) + w.WriteHeader(http.StatusOK) + return + } + getAttempted.Store(true) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + updater := &Updater{} + err := updater.DownloadNewRelease(t.Context(), UpdateResponse{UpdateURL: server.URL + "/download"}) + if err == nil || !strings.Contains(err.Error(), "unsafe update filename") { + t.Fatalf("expected unsafe filename error, got %v", err) + } + if getAttempted.Load() { + t.Fatal("download should not continue after unsafe filename") + } + if _, err := os.Stat(filepath.Join(filepath.Dir(UpdateStageDir), "OllamaSetup.exe")); err == nil { + t.Fatal("download escaped update stage dir") + } +} + +func TestDownloadNewReleaseDoesNotUseRawETagAsPathComponent(t *testing.T) { + UpdateStageDir = t.TempDir() + oldInstaller := Installer + oldVerifyDownload := VerifyDownload + oldUpdateDownloaded := UpdateDownloaded + defer func() { + Installer = oldInstaller + VerifyDownload = oldVerifyDownload + UpdateDownloaded = oldUpdateDownloaded + }() + Installer = "OllamaSetup.exe" + UpdateDownloaded = false + VerifyDownload = func() error { + return nil + } + + payload := []byte("payload") + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("ETag", `"../escaped"`) + w.WriteHeader(http.StatusOK) + if r.Method == http.MethodGet { + _, _ = w.Write(payload) + } + })) + defer server.Close() + + updater := &Updater{} + if err := updater.DownloadNewRelease(t.Context(), UpdateResponse{UpdateURL: server.URL + "/download"}); err != nil { + t.Fatal(err) + } + + if _, err := os.Stat(filepath.Join(filepath.Dir(UpdateStageDir), "escaped", Installer)); err == nil { + t.Fatal("download escaped update stage dir via etag") + } + + entries, err := os.ReadDir(UpdateStageDir) + if err != nil { + t.Fatal(err) + } + if len(entries) != 1 { + t.Fatalf("expected one staged update dir, got %d", len(entries)) + } + stageFilename := filepath.Join(UpdateStageDir, entries[0].Name(), Installer) + got, err := os.ReadFile(stageFilename) + if err != nil { + t.Fatal(err) + } + if string(got) != string(payload) { + t.Fatalf("unexpected staged payload %q", got) + } +} + +func TestBackgroundCheckerSkipsAlreadyStagedETagDownload(t *testing.T) { + UpdateStageDir = t.TempDir() + oldInstaller := Installer + oldVerifyDownload := VerifyDownload + oldUpdateDownloaded := UpdateDownloaded + oldUpdateCheckInitialDelay := UpdateCheckInitialDelay + oldUpdateCheckInterval := UpdateCheckInterval + oldUpdateCheckURLBase := UpdateCheckURLBase + defer func() { + Installer = oldInstaller + VerifyDownload = oldVerifyDownload + UpdateDownloaded = oldUpdateDownloaded + UpdateCheckInitialDelay = oldUpdateCheckInitialDelay + UpdateCheckInterval = oldUpdateCheckInterval + UpdateCheckURLBase = oldUpdateCheckURLBase + }() + Installer = "OllamaSetup.exe" + UpdateDownloaded = false + UpdateCheckInitialDelay = time.Millisecond + UpdateCheckInterval = 5 * time.Millisecond + + var verifyCount atomic.Int32 + VerifyDownload = func() error { + verifyCount.Add(1) + return nil + } + + headETag := `"old-update"` + getETag := `"download-response-etag"` + payload := []byte("payload") + var headCount atomic.Int32 + var getCount atomic.Int32 + var server *httptest.Server + server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/update.json": + w.Write([]byte( + fmt.Sprintf(`{"version": "9.9.9", "url": "%s"}`, + server.URL+"/9.9.9/"+Installer))) + case "/9.9.9/" + Installer: + w.Header().Set("Content-Disposition", `attachment; filename="OllamaSetup.exe"`) + switch r.Method { + case http.MethodHead: + etag := headETag + if getCount.Load() > 0 { + etag = getETag + } + w.Header().Set("ETag", etag) + headCount.Add(1) + w.WriteHeader(http.StatusOK) + case http.MethodGet: + w.Header().Set("ETag", getETag) + getCount.Add(1) + w.WriteHeader(http.StatusOK) + _, _ = w.Write(payload) + default: + t.Errorf("unexpected request method %s", r.Method) + w.WriteHeader(http.StatusMethodNotAllowed) + } + default: + t.Errorf("unexpected request path %s", r.URL.Path) + w.WriteHeader(http.StatusNotFound) + } + })) + defer server.Close() + UpdateCheckURLBase = server.URL + "/update.json" + + updater := &Updater{Store: &store.Store{DBPath: filepath.Join(t.TempDir(), "test.db")}} + defer updater.Store.Close() + settings, err := updater.Store.Settings() + if err != nil { + t.Fatal(err) + } + settings.AutoUpdateEnabled = true + if err := updater.Store.SetSettings(settings); err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + callbacks := make(chan string, 4) + updater.StartBackgroundUpdaterChecker(ctx, func(ver string) error { + callbacks <- ver + return nil + }) + + for range 2 { + select { + case <-callbacks: + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for repeated update checks") + } + } + cancel() + + stageFilename, err := updateStagePath(UpdateStageDir, getETag, Installer) + if err != nil { + t.Fatal(err) + } + got, err := os.ReadFile(stageFilename) + if err != nil { + t.Fatal(err) + } + if string(got) != string(payload) { + t.Fatalf("unexpected staged payload %q", got) + } + + if headCount.Load() < 2 { + t.Fatalf("HEAD count = %d, want at least 2", headCount.Load()) + } + if getCount.Load() != 1 { + t.Fatalf("GET count = %d, want 1", getCount.Load()) + } + if verifyCount.Load() != 1 { + t.Fatalf("verification count = %d, want 1", verifyCount.Load()) + } + if !UpdateDownloaded { + t.Fatal("UpdateDownloaded should stay true for already staged update") + } +} + func TestBackgoundChecker(t *testing.T) { UpdateStageDir = t.TempDir() haveUpdate := false diff --git a/app/updater/updater_windows.go b/app/updater/updater_windows.go index e54aaa3b1..6fc28e6c9 100644 --- a/app/updater/updater_windows.go +++ b/app/updater/updater_windows.go @@ -1,6 +1,7 @@ package updater import ( + "crypto/x509" "errors" "fmt" "log/slog" @@ -18,6 +19,30 @@ import ( var runningInstaller string +var ( + crypt32 = windows.NewLazySystemDLL("crypt32.dll") + procCryptMsgGetParam = crypt32.NewProc("CryptMsgGetParam") + procCryptMsgClose = crypt32.NewProc("CryptMsgClose") +) + +const cmsgSignerInfoParam = 6 + +type cmsgSignerInfo struct { + Version uint32 + Issuer windows.CertNameBlob + SerialNumber windows.CryptIntegerBlob + HashAlgorithm windows.CryptAlgorithmIdentifier + HashEncryptionAlgorithm windows.CryptAlgorithmIdentifier + EncryptedHash windows.CryptDataBlob + AuthAttrs cryptAttributes + UnauthAttrs cryptAttributes +} + +type cryptAttributes struct { + Count uint32 + Attributes unsafe.Pointer +} + type OSVERSIONINFOEXW struct { dwOSVersionInfoSize uint32 dwMajorVersion uint32 @@ -99,6 +124,12 @@ func DoUpgrade(interactive bool) error { return fmt.Errorf("failed to lookup downloads") } + if err := VerifyDownload(); err != nil { + _ = os.Remove(bundle) + slog.Warn("verification failure", "bundle", bundle, "error", err) + return fmt.Errorf("staged update verification failed: %w", err) + } + // We move the installer to ensure we don't race with multiple apps starting in quick succession if err := os.Rename(bundle, runningInstaller); err != nil { return fmt.Errorf("unable to rename %s -> %s : %w", bundle, runningInstaller, err) @@ -184,6 +215,150 @@ func DoPostUpgradeCleanup() error { } func verifyDownload() error { + bundle := getStagedUpdate() + if bundle == "" { + return fmt.Errorf("failed to lookup downloads") + } + slog.Debug("verifying update", "bundle", bundle) + + if err := verifyWindowsInstallerSignature(bundle); err != nil { + return fmt.Errorf("signature verification failed: %w", err) + } + return nil +} + +func verifyWindowsInstallerSignature(filename string) error { + filename16, err := windows.UTF16PtrFromString(filename) + if err != nil { + return err + } + + data := &windows.WinTrustData{ + Size: uint32(unsafe.Sizeof(windows.WinTrustData{})), + UIChoice: windows.WTD_UI_NONE, + RevocationChecks: windows.WTD_REVOKE_WHOLECHAIN, + UnionChoice: windows.WTD_CHOICE_FILE, + StateAction: windows.WTD_STATEACTION_VERIFY, + UIContext: windows.WTD_UICONTEXT_INSTALL, + FileOrCatalogOrBlobOrSgnrOrCert: unsafe.Pointer(&windows.WinTrustFileInfo{ + Size: uint32(unsafe.Sizeof(windows.WinTrustFileInfo{})), + FilePath: filename16, + }), + } + + verifyErr := windows.WinVerifyTrustEx(windows.InvalidHWND, &windows.WINTRUST_ACTION_GENERIC_VERIFY_V2, data) + data.StateAction = windows.WTD_STATEACTION_CLOSE + closeErr := windows.WinVerifyTrustEx(windows.InvalidHWND, &windows.WINTRUST_ACTION_GENERIC_VERIFY_V2, data) + if verifyErr != nil { + return verifyErr + } + if closeErr != nil { + return fmt.Errorf("close WinVerifyTrust state: %w", closeErr) + } + + subject, err := windowsInstallerSignerSubject(filename) + if err != nil { + return err + } + slog.Debug("verified update signature", "subject", subject) + return nil +} + +func windowsInstallerSignerSubject(filename string) (string, error) { + filename16, err := windows.UTF16PtrFromString(filename) + if err != nil { + return "", err + } + + var certStore windows.Handle + var msg windows.Handle + if err := windows.CryptQueryObject( + windows.CERT_QUERY_OBJECT_FILE, + unsafe.Pointer(filename16), + windows.CERT_QUERY_CONTENT_FLAG_PKCS7_SIGNED_EMBED, + windows.CERT_QUERY_FORMAT_FLAG_BINARY, + 0, + nil, + nil, + nil, + &certStore, + &msg, + nil, + ); err != nil { + return "", err + } + defer windows.CertCloseStore(certStore, 0) //nolint:errcheck + defer cryptMsgClose(msg) //nolint:errcheck + + var signerInfoSize uint32 + if err := cryptMsgGetParam(msg, cmsgSignerInfoParam, 0, nil, &signerInfoSize); err != nil { + return "", err + } + if signerInfoSize == 0 { + return "", fmt.Errorf("missing signer info") + } + + signerInfoBuf := make([]byte, signerInfoSize) + if err := cryptMsgGetParam(msg, cmsgSignerInfoParam, 0, unsafe.Pointer(&signerInfoBuf[0]), &signerInfoSize); err != nil { + return "", err + } + signerInfo := (*cmsgSignerInfo)(unsafe.Pointer(&signerInfoBuf[0])) + certInfo := windows.CertInfo{ + Issuer: signerInfo.Issuer, + SerialNumber: signerInfo.SerialNumber, + } + + cert, err := windows.CertFindCertificateInStore( + certStore, + windows.X509_ASN_ENCODING|windows.PKCS_7_ASN_ENCODING, + 0, + windows.CERT_FIND_SUBJECT_CERT, + unsafe.Pointer(&certInfo), + nil, + ) + if err != nil { + return "", err + } + defer windows.CertFreeCertificateContext(cert) //nolint:errcheck + + parsed, err := x509.ParseCertificate(unsafe.Slice(cert.EncodedCert, cert.Length)) + if err != nil { + return "", err + } + + for _, org := range parsed.Subject.Organization { + if org == "Ollama Inc." { + return parsed.Subject.String(), nil + } + } + return "", fmt.Errorf("unexpected signer: %s", parsed.Subject.String()) +} + +func cryptMsgGetParam(msg windows.Handle, paramType, index uint32, data unsafe.Pointer, size *uint32) error { + r1, _, e1 := procCryptMsgGetParam.Call( + uintptr(msg), + uintptr(paramType), + uintptr(index), + uintptr(data), + uintptr(unsafe.Pointer(size)), + ) + if r1 == 0 { + if e1 != syscall.Errno(0) { + return e1 + } + return syscall.EINVAL + } + return nil +} + +func cryptMsgClose(msg windows.Handle) error { + r1, _, e1 := procCryptMsgClose.Call(uintptr(msg)) + if r1 == 0 { + if e1 != syscall.Errno(0) { + return e1 + } + return syscall.EINVAL + } return nil } diff --git a/app/updater/updater_windows_test.go b/app/updater/updater_windows_test.go index 9c665be35..de86b279f 100644 --- a/app/updater/updater_windows_test.go +++ b/app/updater/updater_windows_test.go @@ -1,13 +1,85 @@ -//go:build windows || darwin +//go:build windows package updater import ( "log/slog" + "os" + "path/filepath" + "strings" "testing" ) +func TestVerifyDownloadRejectsUnsignedWindowsInstaller(t *testing.T) { + oldUpdateStageDir := UpdateStageDir + defer func() { + UpdateStageDir = oldUpdateStageDir + }() + + t.Setenv("LOCALAPPDATA", t.TempDir()) + UpdateStageDir = t.TempDir() + bundle := filepath.Join(UpdateStageDir, "etag", "OllamaSetup.exe") + if err := os.MkdirAll(filepath.Dir(bundle), 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(bundle, []byte("not a signed installer"), 0o755); err != nil { + t.Fatal(err) + } + + err := verifyDownload() + if err == nil || !strings.Contains(err.Error(), "signature verification failed") { + t.Fatalf("expected signature verification failure, got %v", err) + } +} + +func TestDoUpgradeAtStartupRejectsUnsignedWindowsInstaller(t *testing.T) { + oldUpdateStageDir := UpdateStageDir + oldRunningInstaller := runningInstaller + oldUpgradeLogFile := UpgradeLogFile + oldUpgradeMarkerFile := UpgradeMarkerFile + oldVerifyDownload := VerifyDownload + defer func() { + UpdateStageDir = oldUpdateStageDir + runningInstaller = oldRunningInstaller + UpgradeLogFile = oldUpgradeLogFile + UpgradeMarkerFile = oldUpgradeMarkerFile + VerifyDownload = oldVerifyDownload + }() + + t.Setenv("LOCALAPPDATA", t.TempDir()) + UpdateStageDir = t.TempDir() + runDir := t.TempDir() + runningInstaller = filepath.Join(runDir, "OllamaSetup.exe") + UpgradeLogFile = filepath.Join(runDir, "upgrade.log") + UpgradeMarkerFile = filepath.Join(runDir, "upgraded") + VerifyDownload = verifyDownload + + bundle := filepath.Join(UpdateStageDir, "etag", "OllamaSetup.exe") + if err := os.MkdirAll(filepath.Dir(bundle), 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(bundle, []byte("not a signed installer"), 0o755); err != nil { + t.Fatal(err) + } + + err := DoUpgradeAtStartup() + if err == nil || !strings.Contains(err.Error(), "signature verification failed") { + t.Fatalf("expected signature verification failure, got %v", err) + } + if _, err := os.Stat(runningInstaller); !os.IsNotExist(err) { + t.Fatalf("unsigned installer was moved before verification failed: %v", err) + } + if _, err := os.Stat(bundle); !os.IsNotExist(err) { + t.Fatalf("unsigned staged installer was not removed after verification failure: %v", err) + } +} + func TestIsInstallerRunning(t *testing.T) { + oldInstaller := Installer + defer func() { + Installer = oldInstaller + }() + slog.SetLogLoggerLevel(slog.LevelDebug) Installer = "go.exe" if !isInstallerRunning() { diff --git a/envconfig/test_home_test.go b/envconfig/test_home_test.go index 993f1c0aa..9cae6bf9d 100644 --- a/envconfig/test_home_test.go +++ b/envconfig/test_home_test.go @@ -6,5 +6,6 @@ func setTestHome(t *testing.T, home string) { t.Helper() t.Setenv("HOME", home) t.Setenv("USERPROFILE", home) + t.Setenv("OLLAMA_MODELS", "") ReloadServerConfig() } diff --git a/server/routes.go b/server/routes.go index f64d5e7c7..fbfadf9a4 100644 --- a/server/routes.go +++ b/server/routes.go @@ -3127,8 +3127,12 @@ func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, mo s.sched.expireRunnersForRuntimeOOM(m, err) // Only send JSON error if streaming hasn't started yet // (once streaming starts, headers are committed and we can't change status code) - if !streamStarted { + if !isStreaming || !streamStarted { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + } else { + data, _ := json.Marshal(gin.H{"error": err.Error()}) + c.Writer.Write(append(data, '\n')) + c.Writer.Flush() } return } diff --git a/server/routes_generate_test.go b/server/routes_generate_test.go index 489e3e78b..56837b58e 100644 --- a/server/routes_generate_test.go +++ b/server/routes_generate_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "io" "net/http" "net/http/httptest" @@ -2989,3 +2990,120 @@ func TestImageGenerateStreamFalse(t *testing.T) { t.Errorf("expected done=true") } } + +func newImageGenerateTestServer(t *testing.T, mock *mockRunner) Server { + t.Helper() + + t.Setenv("OLLAMA_CONTEXT_LENGTH", "4096") + gin.SetMode(gin.TestMode) + + p := t.TempDir() + t.Setenv("OLLAMA_MODELS", p) + + n := model.ParseName("test-image") + cfg := model.ConfigV2{Capabilities: []string{"image"}} + var b bytes.Buffer + if err := json.NewEncoder(&b).Encode(&cfg); err != nil { + t.Fatal(err) + } + configLayer, err := manifest.NewLayer(&b, "application/vnd.docker.container.image.v1+json") + if err != nil { + t.Fatal(err) + } + if err := manifest.WriteManifest(n, configLayer, nil); err != nil { + t.Fatal(err) + } + + loadedModel, err := GetModel("test-image") + if err != nil { + t.Fatal(err) + } + + opts := api.DefaultOptions() + s := Server{ + sched: &Scheduler{ + pendingReqCh: make(chan *LlmRequest, 1), + finishedReqCh: make(chan *LlmRequest, 1), + expiredCh: make(chan *runnerRef, 1), + unloadedCh: make(chan any, 1), + loaded: map[string]*runnerRef{ + schedulerModelKey(loadedModel): { + llama: mock, + Options: &opts, + model: loadedModel, + isImagegen: true, + numParallel: 1, + }, + }, + newServerFn: newMockServer(mock), + getGpuFn: getGpuFn, + getSystemInfoFn: getSystemInfoFn, + }, + } + + go s.sched.Run(t.Context()) + return s +} + +func TestImageGenerateStreamFalseErrorAfterProgress(t *testing.T) { + mock := mockRunner{} + mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error { + fn(llm.CompletionResponse{Step: 1, TotalSteps: 3, Done: false}) + return errors.New("runner died") + } + s := newImageGenerateTestServer(t, &mock) + + streamFalse := false + w := createRequest(t, s.GenerateHandler, api.GenerateRequest{ + Model: "test-image", + Prompt: "test prompt", + Stream: &streamFalse, + }) + + if w.Code != http.StatusInternalServerError { + t.Fatalf("expected status 500, got %d: %s", w.Code, w.Body.String()) + } + if !strings.Contains(w.Body.String(), "runner died") { + t.Fatalf("expected runner error in body, got %q", w.Body.String()) + } +} + +func TestImageGenerateStreamingErrorAfterProgress(t *testing.T) { + mock := mockRunner{} + mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error { + fn(llm.CompletionResponse{Step: 1, TotalSteps: 3, Done: false}) + return errors.New("runner died") + } + s := newImageGenerateTestServer(t, &mock) + + w := createRequest(t, s.GenerateHandler, api.GenerateRequest{ + Model: "test-image", + Prompt: "test prompt", + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status 200 after streaming started, got %d: %s", w.Code, w.Body.String()) + } + lines := strings.Split(strings.TrimSpace(w.Body.String()), "\n") + if len(lines) != 2 { + t.Fatalf("expected progress and error lines, got %d:\n%s", len(lines), w.Body.String()) + } + + var progress api.GenerateResponse + if err := json.Unmarshal([]byte(lines[0]), &progress); err != nil { + t.Fatalf("failed to parse progress response: %v", err) + } + if progress.Completed != 1 || progress.Total != 3 || progress.Done { + t.Fatalf("progress response = %+v", progress) + } + + var errorResponse struct { + Error string `json:"error"` + } + if err := json.Unmarshal([]byte(lines[1]), &errorResponse); err != nil { + t.Fatalf("failed to parse error response: %v", err) + } + if errorResponse.Error != "runner died" { + t.Fatalf("error = %q, want runner died", errorResponse.Error) + } +} diff --git a/server/test_home_test.go b/server/test_home_test.go index 7a0393684..3f724c910 100644 --- a/server/test_home_test.go +++ b/server/test_home_test.go @@ -10,5 +10,6 @@ func setTestHome(t *testing.T, home string) { t.Helper() t.Setenv("HOME", home) t.Setenv("USERPROFILE", home) + t.Setenv("OLLAMA_MODELS", "") envconfig.ReloadServerConfig() } diff --git a/x/imagegen/runner.go b/x/imagegen/runner.go index d92b59059..59a84e915 100644 --- a/x/imagegen/runner.go +++ b/x/imagegen/runner.go @@ -15,6 +15,7 @@ import ( "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/x/imagegen/mlx" + "github.com/ollama/ollama/x/internal/mlxthread" ) // Execute is the entry point for the unified MLX runner subprocess. @@ -45,17 +46,30 @@ func Execute(args []string) error { return fmt.Errorf("imagegen runner only supports image generation models") } - // Initialize MLX only for image generation mode. - if err := mlx.InitMLX(); err != nil { - slog.Error("unable to initialize MLX", "error", err) + worker, err := mlxthread.Start("imagegen", func() error { + if err := mlx.InitMLX(); err != nil { + slog.Error("unable to initialize MLX", "error", err) + return err + } + slog.Info("MLX library initialized") + return nil + }) + if err != nil { return err } - slog.Info("MLX library initialized") // Create and start server - server, err := newServer(*modelName, *port) - if err != nil { - return fmt.Errorf("failed to create server: %w", err) + var server *server + if err := worker.Do(context.Background(), func() error { + var err error + server, err = newServer(*modelName, *port) + if err != nil { + return fmt.Errorf("failed to create server: %w", err) + } + server.mlxThread = worker + return nil + }); err != nil { + return err } // Set up HTTP handlers @@ -77,7 +91,17 @@ func Execute(args []string) error { slog.Info("shutting down mlx runner") ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - httpServer.Shutdown(ctx) + if err := httpServer.Shutdown(ctx); err != nil { + slog.Warn("graceful shutdown timed out", "error", err) + if err := httpServer.Close(); err != nil { + slog.Warn("failed to close http server", "error", err) + } + } + if err := worker.Stop(ctx, func() { + mlx.ClearCache() + }); err != nil { + slog.Warn("failed to stop mlx worker", "error", err) + } close(done) }() @@ -110,6 +134,7 @@ func detectModelMode(modelName string) ModelMode { type server struct { modelName string port int + mlxThread *mlxthread.Thread // Image generation model. imageModel ImageModel @@ -147,5 +172,10 @@ func (s *server) completionHandler(w http.ResponseWriter, r *http.Request) { return } - s.handleImageCompletion(w, r, req) + if err := s.mlxThread.Do(r.Context(), func() error { + s.handleImageCompletion(w, r, req) + return nil + }); err != nil && r.Context().Err() == nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } } diff --git a/x/imagegen/server.go b/x/imagegen/server.go index 3c74ae519..6017f0602 100644 --- a/x/imagegen/server.go +++ b/x/imagegen/server.go @@ -367,9 +367,16 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f // Check if subprocess is still alive if s.HasExited() { slog.Error("mlx subprocess has exited unexpectedly") + if errMsg := s.getLastErr(); errMsg != "" { + return fmt.Errorf("mlx runner closed response before completion: %s", errMsg) + } } - return scanErr + if scanErr != nil { + return scanErr + } + + return errors.New("mlx runner closed response before completion") } func (s *Server) Chat(ctx context.Context, req llm.ChatRequest, fn func(llm.ChatResponse)) error { diff --git a/x/imagegen/server_test.go b/x/imagegen/server_test.go new file mode 100644 index 000000000..8f2938297 --- /dev/null +++ b/x/imagegen/server_test.go @@ -0,0 +1,102 @@ +package imagegen + +import ( + "context" + "encoding/json" + "io" + "net/http" + "strings" + "testing" + + "github.com/ollama/ollama/llm" +) + +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (fn roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return fn(req) +} + +func newCompletionTestServer(handler func(*http.Request) string) *Server { + return &Server{ + port: 11434, + done: make(chan error, 1), + client: &http.Client{ + Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + body := handler(req) + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(body)), + Request: req, + }, nil + }), + }, + } +} + +func TestCompletionReturnsImageData(t *testing.T) { + s := newCompletionTestServer(func(r *http.Request) string { + if r.URL.Path != "/completion" { + t.Fatalf("path = %q, want /completion", r.URL.Path) + } + + var req Request + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + t.Fatal(err) + } + if req.Prompt != "test prompt" || req.Width != 512 || req.Height != 256 || req.Steps != 7 || req.Seed != 42 { + t.Fatalf("unexpected request: %+v", req) + } + if len(req.Images) != 1 || string(req.Images[0]) != "input-image" { + t.Fatalf("images = %q, want input-image", req.Images) + } + + return `{"step":1,"total":2}` + "\n" + + `{"done":true,"image":"base64png"}` + "\n" + }) + + var responses []llm.CompletionResponse + err := s.Completion(context.Background(), llm.CompletionRequest{ + Prompt: "test prompt", + Width: 512, + Height: 256, + Steps: 7, + Seed: 42, + Images: []llm.ImageData{{Data: []byte("input-image")}}, + }, func(resp llm.CompletionResponse) { + responses = append(responses, resp) + }) + if err != nil { + t.Fatal(err) + } + if len(responses) != 2 { + t.Fatalf("responses = %d, want 2", len(responses)) + } + if responses[0].Step != 1 || responses[0].TotalSteps != 2 || responses[0].Done { + t.Fatalf("progress response = %+v", responses[0]) + } + if !responses[1].Done || responses[1].Image != "base64png" { + t.Fatalf("final response = %+v", responses[1]) + } +} + +func TestCompletionEOFBeforeDoneReturnsError(t *testing.T) { + s := newCompletionTestServer(func(r *http.Request) string { + return `{"step":1,"total":2}` + "\n" + }) + + var responses []llm.CompletionResponse + err := s.Completion(context.Background(), llm.CompletionRequest{Prompt: "test prompt"}, func(resp llm.CompletionResponse) { + responses = append(responses, resp) + }) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "closed response before completion") { + t.Fatalf("error = %v", err) + } + if len(responses) != 1 || responses[0].Done { + t.Fatalf("responses = %+v, want one non-done progress response", responses) + } +}