Merge remote-tracking branch 'upstream/main' into llama-runner-phase-0

# Conflicts:
#	.github/workflows/test.yaml
This commit is contained in:
Daniel Hiltgen 2026-05-11 13:27:29 -07:00
commit 6dbb3483ee
15 changed files with 1162 additions and 81 deletions

View file

@ -22,6 +22,7 @@ jobs:
runs-on: ubuntu-latest
outputs:
changed: ${{ steps.changes.outputs.changed }}
app_changed: ${{ steps.changes.outputs.app_changed }}
vendorsha: ${{ steps.changes.outputs.vendorsha }}
steps:
- uses: actions/checkout@v4
@ -37,7 +38,8 @@ jobs:
| xargs python3 -c "import sys; from pathlib import Path; print(any(Path(x).match(glob) for x in sys.argv[1:] for glob in '$*'.split(' ')))"
}
echo changed=$(changed 'llama/server/**/*' 'LLAMA_CPP_VERSION' '.github/**/*') | tee -a $GITHUB_OUTPUT
echo changed=$(changed 'llama/server/**/*' 'llama/compat/**/*' 'LLAMA_CPP_VERSION' 'llama/llama.cpp/**/*' 'ml/backend/ggml/ggml/**/*' '.github/**/*') | tee -a $GITHUB_OUTPUT
echo app_changed=$(changed 'app/**' 'app/**/*') | tee -a $GITHUB_OUTPUT
echo vendorsha=$(cat LLAMA_CPP_VERSION) | tee -a $GITHUB_OUTPUT
linux:
@ -285,6 +287,7 @@ jobs:
run: go mod tidy --diff || (echo "Please run 'go mod tidy'." && exit 1)
test:
needs: [changes]
strategy:
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
@ -319,6 +322,10 @@ jobs:
if: always()
run: go test -count=1 -benchtime=1x ./...
- name: go test app with live updater tag
if: ${{ needs.changes.outputs.app_changed == 'True' && contains(fromJSON('["macos-latest","windows-latest"]'), matrix.os) }}
run: go test -count=1 -tags updater_live ./app/...
- uses: golangci/golangci-lint-action@v9
with:
only-new-issues: true

View file

@ -5,6 +5,8 @@ package updater
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
@ -169,22 +171,20 @@ func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateRespo
if err != nil {
return fmt.Errorf("error checking update: %w", err)
}
resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("unexpected status attempting to download update %d", resp.StatusCode)
}
resp.Body.Close()
etag := strings.Trim(resp.Header.Get("etag"), "\"")
if etag == "" {
slog.Debug("no etag detected, falling back to filename based dedup")
etag = "_"
}
filename := Installer
_, params, err := mime.ParseMediaType(resp.Header.Get("content-disposition"))
if err == nil {
if err == nil && params["filename"] != "" {
filename = params["filename"]
}
stageFilename := filepath.Join(UpdateStageDir, etag, filename)
stageFilename, err := updateStagePath(UpdateStageDir, resp.Header.Get("etag"), filename)
if err != nil {
return err
}
// Check to see if we already have it downloaded
_, err = os.Stat(stageFilename)
@ -202,13 +202,14 @@ func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateRespo
return fmt.Errorf("error checking update: %w", err)
}
defer resp.Body.Close()
etag = strings.Trim(resp.Header.Get("etag"), "\"")
if etag == "" {
slog.Debug("no etag detected, falling back to filename based dedup") // TODO probably can get rid of this redundant log
etag = "_"
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("unexpected status attempting to download update %d", resp.StatusCode)
}
stageFilename = filepath.Join(UpdateStageDir, etag, filename)
stageFilename, err = updateStagePath(UpdateStageDir, resp.Header.Get("etag"), filename)
if err != nil {
return err
}
_, err = os.Stat(filepath.Dir(stageFilename))
if errors.Is(err, os.ErrNotExist) {
@ -225,10 +226,13 @@ func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateRespo
if err != nil {
return fmt.Errorf("write payload %s: %w", stageFilename, err)
}
defer fp.Close()
if n, err := fp.Write(payload); err != nil || n != len(payload) {
_ = fp.Close()
return fmt.Errorf("write payload %s: %d vs %d -- %w", stageFilename, n, len(payload), err)
}
if err := fp.Close(); err != nil {
return fmt.Errorf("close payload %s: %w", stageFilename, err)
}
slog.Info("new update downloaded " + stageFilename)
if err := VerifyDownload(); err != nil {
@ -239,6 +243,61 @@ func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateRespo
return nil
}
func updateStagePath(stageDir, etag, filename string) (string, error) {
filename, err := safeUpdateFilename(filename)
if err != nil {
return "", err
}
stageDir, err = filepath.Abs(stageDir)
if err != nil {
return "", fmt.Errorf("resolve update stage dir: %w", err)
}
stageFilename := filepath.Join(stageDir, updateStageETagDir(etag), filename)
if err := ensurePathInDir(stageDir, stageFilename); err != nil {
return "", err
}
return stageFilename, nil
}
func safeUpdateFilename(filename string) (string, error) {
filename = strings.TrimSpace(filename)
if filename == "" {
return "", errors.New("missing update filename")
}
if filename == "." || filename == ".." ||
filepath.IsAbs(filename) || path.IsAbs(filename) ||
strings.ContainsAny(filename, `/\:`) ||
filepath.Base(filename) != filename || path.Base(filename) != filename {
return "", fmt.Errorf("unsafe update filename %q", filename)
}
return filename, nil
}
func updateStageETagDir(etag string) string {
etag = strings.Trim(strings.TrimSpace(etag), "\"")
if etag == "" {
slog.Debug("no etag detected, falling back to filename based dedup")
return "_"
}
sum := sha256.Sum256([]byte(etag))
return hex.EncodeToString(sum[:])
}
func ensurePathInDir(dir, name string) error {
rel, err := filepath.Rel(dir, name)
if err != nil {
return fmt.Errorf("resolve update staging path: %w", err)
}
if rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) || filepath.IsAbs(rel) {
return fmt.Errorf("update staging path escapes stage dir: %s", name)
}
return nil
}
func cleanupOldDownloads(stageDir string) {
files, err := os.ReadDir(stageDir)
if err != nil && errors.Is(err, os.ErrNotExist) {

View file

@ -22,6 +22,15 @@ import (
"golang.org/x/sys/unix"
)
const updateArchiveRoot = "Ollama.app"
type bundleEntryScope int
const (
bundleEntryRelative bundleEntryScope = iota
bundleEntryWithArchiveRoot
)
var (
appBackupDir string
SystemWidePath = "/Applications/Ollama.app"
@ -167,8 +176,12 @@ func DoUpgrade(interactive bool) error {
}
name := s[1]
if strings.HasSuffix(name, "/") {
d := filepath.Join(BundlePath, name)
err := os.MkdirAll(d, 0o755)
d, err := bundleEntryPath(BundlePath, name, bundleEntryRelative)
if err != nil {
anyFailures = true
return err
}
err = os.MkdirAll(d, 0o755)
if err != nil {
anyFailures = true
return fmt.Errorf("failed to mkdir %s: %w", d, err)
@ -181,30 +194,14 @@ func DoUpgrade(interactive bool) error {
continue
}
src, err := f.Open()
destName, err := bundleEntryPath(BundlePath, name, bundleEntryRelative)
if err != nil {
anyFailures = true
return fmt.Errorf("failed to open bundle file %s: %w", name, err)
return err
}
destName := filepath.Join(BundlePath, name)
// Verify directory first
d := filepath.Dir(destName)
if _, err := os.Stat(d); err != nil {
err := os.MkdirAll(d, 0o755)
if err != nil {
anyFailures = true
return fmt.Errorf("failed to mkdir %s: %w", d, err)
}
}
destFile, err := os.OpenFile(destName, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755)
if err != nil {
if err := extractBundleFile(f, destName, name); err != nil {
anyFailures = true
return fmt.Errorf("failed to open output file %s: %w", destName, err)
}
defer destFile.Close()
if _, err := io.Copy(destFile, src); err != nil {
anyFailures = true
return fmt.Errorf("failed to open extract file %s: %w", destName, err)
return err
}
}
for _, f := range links {
@ -225,16 +222,24 @@ func DoUpgrade(interactive bool) error {
return err
}
link := string(buf)
if link[0] == '/' {
if link == "" {
anyFailures = true
return fmt.Errorf("bundle contains empty symlink %s", f.Name)
}
if filepath.IsAbs(link) {
anyFailures = true
return fmt.Errorf("bundle contains absolute symlink %s -> %s", f.Name, link)
}
// Don't allow links outside of Ollama.app
if strings.HasPrefix(filepath.Join(filepath.Dir(name), link), "..") {
if !validBundleLinkTarget(name, link, bundleEntryRelative) {
anyFailures = true
return fmt.Errorf("bundle contains link outside of contents %s -> %s", f.Name, link)
return fmt.Errorf("bundle contains invalid symlink %s -> %s", f.Name, link)
}
if err = os.Symlink(link, filepath.Join(BundlePath, name)); err != nil {
destName, err := bundleEntryPath(BundlePath, name, bundleEntryRelative)
if err != nil {
anyFailures = true
return err
}
if err = os.Symlink(link, destName); err != nil {
anyFailures = true
return err
}
@ -282,8 +287,11 @@ func verifyDownload() error {
links := []*zip.File{}
for _, f := range r.File {
if strings.HasSuffix(f.Name, "/") {
d := filepath.Join(dir, f.Name)
err := os.MkdirAll(d, 0o755)
d, err := bundleEntryPath(dir, f.Name, bundleEntryWithArchiveRoot)
if err != nil {
return err
}
err = os.MkdirAll(d, 0o755)
if err != nil {
return fmt.Errorf("failed to mkdir %s: %w", d, err)
}
@ -294,26 +302,12 @@ func verifyDownload() error {
links = append(links, f)
continue
}
src, err := f.Open()
destName, err := bundleEntryPath(dir, f.Name, bundleEntryWithArchiveRoot)
if err != nil {
return fmt.Errorf("failed to open bundle file %s: %w", f.Name, err)
return err
}
destName := filepath.Join(dir, f.Name)
// Verify directory first
d := filepath.Dir(destName)
if _, err := os.Stat(d); err != nil {
err := os.MkdirAll(d, 0o755)
if err != nil {
return fmt.Errorf("failed to mkdir %s: %w", d, err)
}
}
destFile, err := os.OpenFile(destName, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755)
if err != nil {
return fmt.Errorf("failed to open output file %s: %w", destName, err)
}
defer destFile.Close()
if _, err := io.Copy(destFile, src); err != nil {
return fmt.Errorf("failed to open extract file %s: %w", destName, err)
if err := extractBundleFile(f, destName, f.Name); err != nil {
return err
}
}
for _, f := range links {
@ -326,13 +320,20 @@ func verifyDownload() error {
return err
}
link := string(buf)
if link[0] == '/' {
if link == "" {
return fmt.Errorf("bundle contains empty symlink %s", f.Name)
}
if filepath.IsAbs(link) {
return fmt.Errorf("bundle contains absolute symlink %s -> %s", f.Name, link)
}
if strings.HasPrefix(filepath.Join(filepath.Dir(f.Name), link), "..") {
return fmt.Errorf("bundle contains link outside of contents %s -> %s", f.Name, link)
if !validBundleLinkTarget(f.Name, link, bundleEntryWithArchiveRoot) {
return fmt.Errorf("bundle contains invalid symlink %s -> %s", f.Name, link)
}
if err = os.Symlink(link, filepath.Join(dir, f.Name)); err != nil {
destName, err := bundleEntryPath(dir, f.Name, bundleEntryWithArchiveRoot)
if err != nil {
return err
}
if err = os.Symlink(link, destName); err != nil {
return err
}
}
@ -343,6 +344,53 @@ func verifyDownload() error {
return nil
}
func bundleEntryPath(root, name string, scope bundleEntryScope) (string, error) {
cleanName := filepath.Clean(filepath.FromSlash(name))
if !filepath.IsLocal(cleanName) {
return "", fmt.Errorf("bundle contains invalid path: %s", name)
}
if scope == bundleEntryWithArchiveRoot && cleanName != updateArchiveRoot &&
!strings.HasPrefix(cleanName, updateArchiveRoot+string(os.PathSeparator)) {
return "", fmt.Errorf("bundle contains invalid path: %s", name)
}
return filepath.Join(root, cleanName), nil
}
func extractBundleFile(f *zip.File, destName, name string) error {
src, err := f.Open()
if err != nil {
return fmt.Errorf("failed to open bundle file %s: %w", name, err)
}
defer src.Close()
d := filepath.Dir(destName)
if _, err := os.Stat(d); err != nil {
if err := os.MkdirAll(d, 0o755); err != nil {
return fmt.Errorf("failed to mkdir %s: %w", d, err)
}
}
destFile, err := os.OpenFile(destName, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755)
if err != nil {
return fmt.Errorf("failed to open output file %s: %w", destName, err)
}
defer destFile.Close()
if _, err := io.Copy(destFile, src); err != nil {
return fmt.Errorf("failed to open extract file %s: %w", destName, err)
}
return nil
}
func validBundleLinkTarget(name, link string, scope bundleEntryScope) bool {
cleanTarget := filepath.Clean(filepath.Join(filepath.Dir(filepath.FromSlash(name)), filepath.FromSlash(link)))
if !filepath.IsLocal(cleanTarget) {
return false
}
return scope == bundleEntryRelative || cleanTarget == updateArchiveRoot ||
strings.HasPrefix(cleanTarget, updateArchiveRoot+string(os.PathSeparator))
}
// If we detect an upgrade bundle, attempt to upgrade at startup
func DoUpgradeAtStartup() error {
bundle := getStagedUpdate()

View file

@ -2,6 +2,7 @@ package updater
import (
"archive/zip"
"errors"
"io/fs"
"os"
"path/filepath"
@ -146,6 +147,46 @@ func TestDoUpgrade(t *testing.T) {
}
}
func TestDoUpgradeRejectsInvalidBundlePath(t *testing.T) {
tmpDir := t.TempDir()
BundlePath = filepath.Join(tmpDir, "Ollama.app")
appBackupDir = filepath.Join(tmpDir, "backup")
UpdateStageDir = filepath.Join(tmpDir, "updates")
UpgradeMarkerFile = filepath.Join(tmpDir, "upgraded")
bundle := filepath.Join(UpdateStageDir, "foo", "ollama-darwin.zip")
invalidTarget := filepath.Join(tmpDir, "invalid-entry")
if err := os.MkdirAll(filepath.Join(BundlePath, "Contents", "MacOS"), 0o755); err != nil {
t.Fatal("failed to create empty dirs")
}
if err := os.WriteFile(filepath.Join(BundlePath, "Contents", "MacOS", "Ollama"), []byte("old app"), 0o755); err != nil {
t.Fatal("failed to create old app")
}
if err := os.MkdirAll(filepath.Dir(bundle), 0o755); err != nil {
t.Fatal("failed to create empty dirs")
}
if err := zipCreationHelper(bundle, []testPayload{{
Name: "Ollama.app/../invalid-entry",
Body: []byte("payload"),
}}); err != nil {
t.Fatal(err)
}
if err := DoUpgrade(false); err == nil {
t.Fatal("expected failure with invalid bundle path")
} else if !strings.Contains(err.Error(), "bundle contains invalid path") {
t.Fatalf("unexpected error with invalid bundle path: %s", err)
}
if _, err := os.Stat(invalidTarget); err == nil {
t.Fatalf("invalid bundle path wrote %s", invalidTarget)
} else if !errors.Is(err, os.ErrNotExist) {
t.Fatalf("unexpected stat error for %s: %s", invalidTarget, err)
}
if _, err := os.Stat(filepath.Join(BundlePath, "Contents", "MacOS", "Ollama")); err != nil {
t.Fatalf("old app was not restored: %s", err)
}
}
func TestDoUpgradeAtStartup(t *testing.T) {
tmpDir := t.TempDir()
BundlePath = filepath.Join(tmpDir, "Ollama.app")
@ -203,7 +244,7 @@ func TestVerifyDownloadFailures(t *testing.T) {
in []testPayload
expected string
}{
{"breakout", []testPayload{
{"invalid symlink target", []testPayload{
{
Name: "Ollama.app/",
Body: []byte{},
@ -212,15 +253,34 @@ func TestVerifyDownloadFailures(t *testing.T) {
Body: []byte("cli payload here"),
}, {
Name: "Ollama.app/Contents/MacOS/Ollama",
Body: []byte("../../../../breakout"),
Body: []byte("../../../../invalid-target"),
Mode: os.ModeSymlink,
},
}, "bundle contains link outside"},
}, "bundle contains invalid symlink"},
{"invalid archive symlink target", []testPayload{
{
Name: "Ollama.app/Contents/MacOS/Ollama",
Body: []byte("../../../invalid-target"),
Mode: os.ModeSymlink,
},
}, "bundle contains invalid symlink"},
{"absolute", []testPayload{{
Name: "Ollama.app/Contents/MacOS/Ollama",
Body: []byte("/etc/foo"),
Mode: os.ModeSymlink,
}}, "bundle contains absolute"},
{"invalid relative file", []testPayload{{
Name: "Ollama.app/../invalid-entry",
Body: []byte("payload"),
}}, "bundle contains invalid path"},
{"invalid relative directory", []testPayload{{
Name: "Ollama.app/../invalid-entry/",
Body: []byte{},
}}, "bundle contains invalid path"},
{"absolute file", []testPayload{{
Name: filepath.Join(tmpDir, "invalid-entry"),
Body: []byte("payload"),
}}, "bundle contains invalid path"},
{"missing", []testPayload{{
Name: "Ollama.app/Contents/MacOS/Ollama",
Body: []byte("../nothere"),
@ -242,6 +302,11 @@ func TestVerifyDownloadFailures(t *testing.T) {
if err == nil || !strings.Contains(err.Error(), tt.expected) {
t.Fatalf("expected \"%s\" got %s", tt.expected, err)
}
if _, err := os.Stat(filepath.Join(tmpDir, "invalid-entry")); err == nil {
t.Fatal("invalid bundle path wrote unexpected file")
} else if !errors.Is(err, os.ErrNotExist) {
t.Fatalf("unexpected stat error for invalid file: %s", err)
}
})
}
}

View file

@ -0,0 +1,127 @@
//go:build (windows || darwin) && updater_live
package updater
import (
"context"
"os"
"path/filepath"
"runtime"
"strings"
"testing"
"time"
"github.com/ollama/ollama/app/store"
"github.com/ollama/ollama/app/version"
)
// TestLiveAppUpdate exercises the production update endpoint and downloads the
// current OS update artifact. It is intentionally excluded from normal test
// runs because it depends on ollama.com and downloads a release artifact.
//
// Run with:
//
// go test -tags updater_live -run TestLiveAppUpdate ./app/updater
func TestLiveAppUpdate(t *testing.T) {
const spoofedVersion = "0.20.0"
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
defer cancel()
oldUpdateStageDir := UpdateStageDir
oldUpdateDownloaded := UpdateDownloaded
oldVerifyDownload := VerifyDownload
oldVersion := version.Version
defer func() {
UpdateStageDir = oldUpdateStageDir
UpdateDownloaded = oldUpdateDownloaded
VerifyDownload = oldVerifyDownload
version.Version = oldVersion
}()
version.Version = spoofedVersion
expectedFilename := ""
switch runtime.GOOS {
case "windows":
t.Setenv("LOCALAPPDATA", t.TempDir())
expectedFilename = "OllamaSetup.exe"
case "darwin":
expectedFilename = "Ollama-darwin.zip"
default:
t.Fatalf("unsupported updater live test OS %q", runtime.GOOS)
}
UpdateStageDir = filepath.Join(t.TempDir(), "updates")
UpdateDownloaded = false
verifyCalled := false
VerifyDownload = func() error {
verifyCalled = true
return verifyDownload()
}
updater := &Updater{Store: &store.Store{DBPath: filepath.Join(t.TempDir(), "db.sqlite")}}
defer updater.Store.Close()
available, updateResp := updater.checkForUpdate(ctx)
if !available {
t.Fatalf("expected production update check to offer an update for spoofed version %s", spoofedVersion)
}
if updateResp.UpdateURL == "" {
t.Fatal("production update response did not include a download URL")
}
t.Logf("production update version=%q url=%q", updateResp.UpdateVersion, updateResp.UpdateURL)
if err := updater.DownloadNewRelease(ctx, updateResp); err != nil {
t.Fatalf("download production update: %v", err)
}
staged := getStagedUpdate()
if staged == "" {
t.Fatal("production update was not staged")
}
t.Logf("staged production update at %s", staged)
assertPathInsideDir(t, UpdateStageDir, staged)
if filepath.Base(staged) != expectedFilename {
t.Fatalf("expected staged %s update filename to be %q, got %q", runtime.GOOS, expectedFilename, filepath.Base(staged))
}
expectedExt := filepath.Ext(expectedFilename)
if filepath.Ext(staged) != expectedExt {
t.Fatalf("expected staged %s update to be a %s artifact, got %s", runtime.GOOS, expectedExt, staged)
}
info, err := os.Stat(staged)
if err != nil {
t.Fatalf("stat staged update: %v", err)
}
if info.Size() == 0 {
t.Fatal("staged production update is empty")
}
if !verifyCalled {
t.Fatal("DownloadNewRelease did not call VerifyDownload")
}
t.Logf("production updater download path verified staged %s update", runtime.GOOS)
}
func assertPathInsideDir(t *testing.T, dir, name string) {
t.Helper()
dir, err := filepath.Abs(dir)
if err != nil {
t.Fatal(err)
}
name, err = filepath.Abs(name)
if err != nil {
t.Fatal(err)
}
rel, err := filepath.Rel(dir, name)
if err != nil {
t.Fatal(err)
}
if rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) || filepath.IsAbs(rel) {
t.Fatalf("staged update escaped update stage dir: %s", name)
}
}

View file

@ -11,7 +11,9 @@ import (
"log/slog"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"sync/atomic"
"testing"
"time"
@ -19,6 +21,52 @@ import (
"github.com/ollama/ollama/app/store"
)
func TestUpdateStagePathRejectsUnsafeFilename(t *testing.T) {
stageDir := t.TempDir()
for _, tt := range []struct {
name string
filename string
}{
{"empty", ""},
{"dot", "."},
{"dotdot", ".."},
{"posix_parent", "../OllamaSetup.exe"},
{"windows_parent", `..\OllamaSetup.exe`},
{"posix_absolute_tmp", "/tmp/OllamaSetup.exe"},
{"darwin_absolute_app", "/Applications/Ollama.app"},
{"darwin_bundle_path", "Ollama.app/Contents/MacOS/Ollama"},
{"darwin_user_download", "~/Downloads/Ollama-darwin.zip"},
{"windows_absolute", `C:\Users\Public\OllamaSetup.exe`},
{"colon", "Ollama:Setup.exe"},
} {
t.Run(tt.name, func(t *testing.T) {
if _, err := updateStagePath(stageDir, "etag", tt.filename); err == nil {
t.Fatal("expected unsafe filename to be rejected")
}
})
}
}
func TestUpdateStagePathHashesETag(t *testing.T) {
stageDir := t.TempDir()
stageFilename, err := updateStagePath(stageDir, `../escaped`, "OllamaSetup.exe")
if err != nil {
t.Fatal(err)
}
rel, err := filepath.Rel(stageDir, stageFilename)
if err != nil {
t.Fatal(err)
}
if rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) || filepath.IsAbs(rel) {
t.Fatalf("stage filename escaped stage dir: %s", stageFilename)
}
etagDir := filepath.Base(filepath.Dir(stageFilename))
if etagDir == ".." || etagDir == "escaped" || strings.ContainsAny(etagDir, `/\`) {
t.Fatalf("stage filename used raw etag path component: %s", stageFilename)
}
}
func TestIsNewReleaseAvailable(t *testing.T) {
slog.SetLogLoggerLevel(slog.LevelDebug)
var server *httptest.Server
@ -47,6 +95,223 @@ func TestIsNewReleaseAvailable(t *testing.T) {
}
}
func TestDownloadNewReleaseRejectsUnsafeHeaderFilename(t *testing.T) {
UpdateStageDir = t.TempDir()
oldInstaller := Installer
oldVerifyDownload := VerifyDownload
oldUpdateDownloaded := UpdateDownloaded
defer func() {
Installer = oldInstaller
VerifyDownload = oldVerifyDownload
UpdateDownloaded = oldUpdateDownloaded
}()
Installer = "OllamaSetup.exe"
UpdateDownloaded = false
VerifyDownload = func() error {
t.Fatal("verification should not run for rejected downloads")
return nil
}
var getAttempted atomic.Bool
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodHead {
w.Header().Set("ETag", `"safe"`)
w.Header().Set("Content-Disposition", `attachment; filename="../OllamaSetup.exe"`)
w.WriteHeader(http.StatusOK)
return
}
getAttempted.Store(true)
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
updater := &Updater{}
err := updater.DownloadNewRelease(t.Context(), UpdateResponse{UpdateURL: server.URL + "/download"})
if err == nil || !strings.Contains(err.Error(), "unsafe update filename") {
t.Fatalf("expected unsafe filename error, got %v", err)
}
if getAttempted.Load() {
t.Fatal("download should not continue after unsafe filename")
}
if _, err := os.Stat(filepath.Join(filepath.Dir(UpdateStageDir), "OllamaSetup.exe")); err == nil {
t.Fatal("download escaped update stage dir")
}
}
func TestDownloadNewReleaseDoesNotUseRawETagAsPathComponent(t *testing.T) {
UpdateStageDir = t.TempDir()
oldInstaller := Installer
oldVerifyDownload := VerifyDownload
oldUpdateDownloaded := UpdateDownloaded
defer func() {
Installer = oldInstaller
VerifyDownload = oldVerifyDownload
UpdateDownloaded = oldUpdateDownloaded
}()
Installer = "OllamaSetup.exe"
UpdateDownloaded = false
VerifyDownload = func() error {
return nil
}
payload := []byte("payload")
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("ETag", `"../escaped"`)
w.WriteHeader(http.StatusOK)
if r.Method == http.MethodGet {
_, _ = w.Write(payload)
}
}))
defer server.Close()
updater := &Updater{}
if err := updater.DownloadNewRelease(t.Context(), UpdateResponse{UpdateURL: server.URL + "/download"}); err != nil {
t.Fatal(err)
}
if _, err := os.Stat(filepath.Join(filepath.Dir(UpdateStageDir), "escaped", Installer)); err == nil {
t.Fatal("download escaped update stage dir via etag")
}
entries, err := os.ReadDir(UpdateStageDir)
if err != nil {
t.Fatal(err)
}
if len(entries) != 1 {
t.Fatalf("expected one staged update dir, got %d", len(entries))
}
stageFilename := filepath.Join(UpdateStageDir, entries[0].Name(), Installer)
got, err := os.ReadFile(stageFilename)
if err != nil {
t.Fatal(err)
}
if string(got) != string(payload) {
t.Fatalf("unexpected staged payload %q", got)
}
}
func TestBackgroundCheckerSkipsAlreadyStagedETagDownload(t *testing.T) {
UpdateStageDir = t.TempDir()
oldInstaller := Installer
oldVerifyDownload := VerifyDownload
oldUpdateDownloaded := UpdateDownloaded
oldUpdateCheckInitialDelay := UpdateCheckInitialDelay
oldUpdateCheckInterval := UpdateCheckInterval
oldUpdateCheckURLBase := UpdateCheckURLBase
defer func() {
Installer = oldInstaller
VerifyDownload = oldVerifyDownload
UpdateDownloaded = oldUpdateDownloaded
UpdateCheckInitialDelay = oldUpdateCheckInitialDelay
UpdateCheckInterval = oldUpdateCheckInterval
UpdateCheckURLBase = oldUpdateCheckURLBase
}()
Installer = "OllamaSetup.exe"
UpdateDownloaded = false
UpdateCheckInitialDelay = time.Millisecond
UpdateCheckInterval = 5 * time.Millisecond
var verifyCount atomic.Int32
VerifyDownload = func() error {
verifyCount.Add(1)
return nil
}
headETag := `"old-update"`
getETag := `"download-response-etag"`
payload := []byte("payload")
var headCount atomic.Int32
var getCount atomic.Int32
var server *httptest.Server
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/update.json":
w.Write([]byte(
fmt.Sprintf(`{"version": "9.9.9", "url": "%s"}`,
server.URL+"/9.9.9/"+Installer)))
case "/9.9.9/" + Installer:
w.Header().Set("Content-Disposition", `attachment; filename="OllamaSetup.exe"`)
switch r.Method {
case http.MethodHead:
etag := headETag
if getCount.Load() > 0 {
etag = getETag
}
w.Header().Set("ETag", etag)
headCount.Add(1)
w.WriteHeader(http.StatusOK)
case http.MethodGet:
w.Header().Set("ETag", getETag)
getCount.Add(1)
w.WriteHeader(http.StatusOK)
_, _ = w.Write(payload)
default:
t.Errorf("unexpected request method %s", r.Method)
w.WriteHeader(http.StatusMethodNotAllowed)
}
default:
t.Errorf("unexpected request path %s", r.URL.Path)
w.WriteHeader(http.StatusNotFound)
}
}))
defer server.Close()
UpdateCheckURLBase = server.URL + "/update.json"
updater := &Updater{Store: &store.Store{DBPath: filepath.Join(t.TempDir(), "test.db")}}
defer updater.Store.Close()
settings, err := updater.Store.Settings()
if err != nil {
t.Fatal(err)
}
settings.AutoUpdateEnabled = true
if err := updater.Store.SetSettings(settings); err != nil {
t.Fatal(err)
}
ctx, cancel := context.WithCancel(t.Context())
defer cancel()
callbacks := make(chan string, 4)
updater.StartBackgroundUpdaterChecker(ctx, func(ver string) error {
callbacks <- ver
return nil
})
for range 2 {
select {
case <-callbacks:
case <-time.After(5 * time.Second):
t.Fatal("timed out waiting for repeated update checks")
}
}
cancel()
stageFilename, err := updateStagePath(UpdateStageDir, getETag, Installer)
if err != nil {
t.Fatal(err)
}
got, err := os.ReadFile(stageFilename)
if err != nil {
t.Fatal(err)
}
if string(got) != string(payload) {
t.Fatalf("unexpected staged payload %q", got)
}
if headCount.Load() < 2 {
t.Fatalf("HEAD count = %d, want at least 2", headCount.Load())
}
if getCount.Load() != 1 {
t.Fatalf("GET count = %d, want 1", getCount.Load())
}
if verifyCount.Load() != 1 {
t.Fatalf("verification count = %d, want 1", verifyCount.Load())
}
if !UpdateDownloaded {
t.Fatal("UpdateDownloaded should stay true for already staged update")
}
}
func TestBackgoundChecker(t *testing.T) {
UpdateStageDir = t.TempDir()
haveUpdate := false

View file

@ -1,6 +1,7 @@
package updater
import (
"crypto/x509"
"errors"
"fmt"
"log/slog"
@ -18,6 +19,30 @@ import (
var runningInstaller string
var (
crypt32 = windows.NewLazySystemDLL("crypt32.dll")
procCryptMsgGetParam = crypt32.NewProc("CryptMsgGetParam")
procCryptMsgClose = crypt32.NewProc("CryptMsgClose")
)
const cmsgSignerInfoParam = 6
type cmsgSignerInfo struct {
Version uint32
Issuer windows.CertNameBlob
SerialNumber windows.CryptIntegerBlob
HashAlgorithm windows.CryptAlgorithmIdentifier
HashEncryptionAlgorithm windows.CryptAlgorithmIdentifier
EncryptedHash windows.CryptDataBlob
AuthAttrs cryptAttributes
UnauthAttrs cryptAttributes
}
type cryptAttributes struct {
Count uint32
Attributes unsafe.Pointer
}
type OSVERSIONINFOEXW struct {
dwOSVersionInfoSize uint32
dwMajorVersion uint32
@ -99,6 +124,12 @@ func DoUpgrade(interactive bool) error {
return fmt.Errorf("failed to lookup downloads")
}
if err := VerifyDownload(); err != nil {
_ = os.Remove(bundle)
slog.Warn("verification failure", "bundle", bundle, "error", err)
return fmt.Errorf("staged update verification failed: %w", err)
}
// We move the installer to ensure we don't race with multiple apps starting in quick succession
if err := os.Rename(bundle, runningInstaller); err != nil {
return fmt.Errorf("unable to rename %s -> %s : %w", bundle, runningInstaller, err)
@ -184,6 +215,150 @@ func DoPostUpgradeCleanup() error {
}
func verifyDownload() error {
bundle := getStagedUpdate()
if bundle == "" {
return fmt.Errorf("failed to lookup downloads")
}
slog.Debug("verifying update", "bundle", bundle)
if err := verifyWindowsInstallerSignature(bundle); err != nil {
return fmt.Errorf("signature verification failed: %w", err)
}
return nil
}
func verifyWindowsInstallerSignature(filename string) error {
filename16, err := windows.UTF16PtrFromString(filename)
if err != nil {
return err
}
data := &windows.WinTrustData{
Size: uint32(unsafe.Sizeof(windows.WinTrustData{})),
UIChoice: windows.WTD_UI_NONE,
RevocationChecks: windows.WTD_REVOKE_WHOLECHAIN,
UnionChoice: windows.WTD_CHOICE_FILE,
StateAction: windows.WTD_STATEACTION_VERIFY,
UIContext: windows.WTD_UICONTEXT_INSTALL,
FileOrCatalogOrBlobOrSgnrOrCert: unsafe.Pointer(&windows.WinTrustFileInfo{
Size: uint32(unsafe.Sizeof(windows.WinTrustFileInfo{})),
FilePath: filename16,
}),
}
verifyErr := windows.WinVerifyTrustEx(windows.InvalidHWND, &windows.WINTRUST_ACTION_GENERIC_VERIFY_V2, data)
data.StateAction = windows.WTD_STATEACTION_CLOSE
closeErr := windows.WinVerifyTrustEx(windows.InvalidHWND, &windows.WINTRUST_ACTION_GENERIC_VERIFY_V2, data)
if verifyErr != nil {
return verifyErr
}
if closeErr != nil {
return fmt.Errorf("close WinVerifyTrust state: %w", closeErr)
}
subject, err := windowsInstallerSignerSubject(filename)
if err != nil {
return err
}
slog.Debug("verified update signature", "subject", subject)
return nil
}
func windowsInstallerSignerSubject(filename string) (string, error) {
filename16, err := windows.UTF16PtrFromString(filename)
if err != nil {
return "", err
}
var certStore windows.Handle
var msg windows.Handle
if err := windows.CryptQueryObject(
windows.CERT_QUERY_OBJECT_FILE,
unsafe.Pointer(filename16),
windows.CERT_QUERY_CONTENT_FLAG_PKCS7_SIGNED_EMBED,
windows.CERT_QUERY_FORMAT_FLAG_BINARY,
0,
nil,
nil,
nil,
&certStore,
&msg,
nil,
); err != nil {
return "", err
}
defer windows.CertCloseStore(certStore, 0) //nolint:errcheck
defer cryptMsgClose(msg) //nolint:errcheck
var signerInfoSize uint32
if err := cryptMsgGetParam(msg, cmsgSignerInfoParam, 0, nil, &signerInfoSize); err != nil {
return "", err
}
if signerInfoSize == 0 {
return "", fmt.Errorf("missing signer info")
}
signerInfoBuf := make([]byte, signerInfoSize)
if err := cryptMsgGetParam(msg, cmsgSignerInfoParam, 0, unsafe.Pointer(&signerInfoBuf[0]), &signerInfoSize); err != nil {
return "", err
}
signerInfo := (*cmsgSignerInfo)(unsafe.Pointer(&signerInfoBuf[0]))
certInfo := windows.CertInfo{
Issuer: signerInfo.Issuer,
SerialNumber: signerInfo.SerialNumber,
}
cert, err := windows.CertFindCertificateInStore(
certStore,
windows.X509_ASN_ENCODING|windows.PKCS_7_ASN_ENCODING,
0,
windows.CERT_FIND_SUBJECT_CERT,
unsafe.Pointer(&certInfo),
nil,
)
if err != nil {
return "", err
}
defer windows.CertFreeCertificateContext(cert) //nolint:errcheck
parsed, err := x509.ParseCertificate(unsafe.Slice(cert.EncodedCert, cert.Length))
if err != nil {
return "", err
}
for _, org := range parsed.Subject.Organization {
if org == "Ollama Inc." {
return parsed.Subject.String(), nil
}
}
return "", fmt.Errorf("unexpected signer: %s", parsed.Subject.String())
}
func cryptMsgGetParam(msg windows.Handle, paramType, index uint32, data unsafe.Pointer, size *uint32) error {
r1, _, e1 := procCryptMsgGetParam.Call(
uintptr(msg),
uintptr(paramType),
uintptr(index),
uintptr(data),
uintptr(unsafe.Pointer(size)),
)
if r1 == 0 {
if e1 != syscall.Errno(0) {
return e1
}
return syscall.EINVAL
}
return nil
}
func cryptMsgClose(msg windows.Handle) error {
r1, _, e1 := procCryptMsgClose.Call(uintptr(msg))
if r1 == 0 {
if e1 != syscall.Errno(0) {
return e1
}
return syscall.EINVAL
}
return nil
}

View file

@ -1,13 +1,85 @@
//go:build windows || darwin
//go:build windows
package updater
import (
"log/slog"
"os"
"path/filepath"
"strings"
"testing"
)
func TestVerifyDownloadRejectsUnsignedWindowsInstaller(t *testing.T) {
oldUpdateStageDir := UpdateStageDir
defer func() {
UpdateStageDir = oldUpdateStageDir
}()
t.Setenv("LOCALAPPDATA", t.TempDir())
UpdateStageDir = t.TempDir()
bundle := filepath.Join(UpdateStageDir, "etag", "OllamaSetup.exe")
if err := os.MkdirAll(filepath.Dir(bundle), 0o755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(bundle, []byte("not a signed installer"), 0o755); err != nil {
t.Fatal(err)
}
err := verifyDownload()
if err == nil || !strings.Contains(err.Error(), "signature verification failed") {
t.Fatalf("expected signature verification failure, got %v", err)
}
}
func TestDoUpgradeAtStartupRejectsUnsignedWindowsInstaller(t *testing.T) {
oldUpdateStageDir := UpdateStageDir
oldRunningInstaller := runningInstaller
oldUpgradeLogFile := UpgradeLogFile
oldUpgradeMarkerFile := UpgradeMarkerFile
oldVerifyDownload := VerifyDownload
defer func() {
UpdateStageDir = oldUpdateStageDir
runningInstaller = oldRunningInstaller
UpgradeLogFile = oldUpgradeLogFile
UpgradeMarkerFile = oldUpgradeMarkerFile
VerifyDownload = oldVerifyDownload
}()
t.Setenv("LOCALAPPDATA", t.TempDir())
UpdateStageDir = t.TempDir()
runDir := t.TempDir()
runningInstaller = filepath.Join(runDir, "OllamaSetup.exe")
UpgradeLogFile = filepath.Join(runDir, "upgrade.log")
UpgradeMarkerFile = filepath.Join(runDir, "upgraded")
VerifyDownload = verifyDownload
bundle := filepath.Join(UpdateStageDir, "etag", "OllamaSetup.exe")
if err := os.MkdirAll(filepath.Dir(bundle), 0o755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(bundle, []byte("not a signed installer"), 0o755); err != nil {
t.Fatal(err)
}
err := DoUpgradeAtStartup()
if err == nil || !strings.Contains(err.Error(), "signature verification failed") {
t.Fatalf("expected signature verification failure, got %v", err)
}
if _, err := os.Stat(runningInstaller); !os.IsNotExist(err) {
t.Fatalf("unsigned installer was moved before verification failed: %v", err)
}
if _, err := os.Stat(bundle); !os.IsNotExist(err) {
t.Fatalf("unsigned staged installer was not removed after verification failure: %v", err)
}
}
func TestIsInstallerRunning(t *testing.T) {
oldInstaller := Installer
defer func() {
Installer = oldInstaller
}()
slog.SetLogLoggerLevel(slog.LevelDebug)
Installer = "go.exe"
if !isInstallerRunning() {

View file

@ -6,5 +6,6 @@ func setTestHome(t *testing.T, home string) {
t.Helper()
t.Setenv("HOME", home)
t.Setenv("USERPROFILE", home)
t.Setenv("OLLAMA_MODELS", "")
ReloadServerConfig()
}

View file

@ -3127,8 +3127,12 @@ func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, mo
s.sched.expireRunnersForRuntimeOOM(m, err)
// Only send JSON error if streaming hasn't started yet
// (once streaming starts, headers are committed and we can't change status code)
if !streamStarted {
if !isStreaming || !streamStarted {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
} else {
data, _ := json.Marshal(gin.H{"error": err.Error()})
c.Writer.Write(append(data, '\n'))
c.Writer.Flush()
}
return
}

View file

@ -4,6 +4,7 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
@ -2989,3 +2990,120 @@ func TestImageGenerateStreamFalse(t *testing.T) {
t.Errorf("expected done=true")
}
}
func newImageGenerateTestServer(t *testing.T, mock *mockRunner) Server {
t.Helper()
t.Setenv("OLLAMA_CONTEXT_LENGTH", "4096")
gin.SetMode(gin.TestMode)
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
n := model.ParseName("test-image")
cfg := model.ConfigV2{Capabilities: []string{"image"}}
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(&cfg); err != nil {
t.Fatal(err)
}
configLayer, err := manifest.NewLayer(&b, "application/vnd.docker.container.image.v1+json")
if err != nil {
t.Fatal(err)
}
if err := manifest.WriteManifest(n, configLayer, nil); err != nil {
t.Fatal(err)
}
loadedModel, err := GetModel("test-image")
if err != nil {
t.Fatal(err)
}
opts := api.DefaultOptions()
s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: map[string]*runnerRef{
schedulerModelKey(loadedModel): {
llama: mock,
Options: &opts,
model: loadedModel,
isImagegen: true,
numParallel: 1,
},
},
newServerFn: newMockServer(mock),
getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn,
},
}
go s.sched.Run(t.Context())
return s
}
func TestImageGenerateStreamFalseErrorAfterProgress(t *testing.T) {
mock := mockRunner{}
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
fn(llm.CompletionResponse{Step: 1, TotalSteps: 3, Done: false})
return errors.New("runner died")
}
s := newImageGenerateTestServer(t, &mock)
streamFalse := false
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-image",
Prompt: "test prompt",
Stream: &streamFalse,
})
if w.Code != http.StatusInternalServerError {
t.Fatalf("expected status 500, got %d: %s", w.Code, w.Body.String())
}
if !strings.Contains(w.Body.String(), "runner died") {
t.Fatalf("expected runner error in body, got %q", w.Body.String())
}
}
func TestImageGenerateStreamingErrorAfterProgress(t *testing.T) {
mock := mockRunner{}
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
fn(llm.CompletionResponse{Step: 1, TotalSteps: 3, Done: false})
return errors.New("runner died")
}
s := newImageGenerateTestServer(t, &mock)
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-image",
Prompt: "test prompt",
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200 after streaming started, got %d: %s", w.Code, w.Body.String())
}
lines := strings.Split(strings.TrimSpace(w.Body.String()), "\n")
if len(lines) != 2 {
t.Fatalf("expected progress and error lines, got %d:\n%s", len(lines), w.Body.String())
}
var progress api.GenerateResponse
if err := json.Unmarshal([]byte(lines[0]), &progress); err != nil {
t.Fatalf("failed to parse progress response: %v", err)
}
if progress.Completed != 1 || progress.Total != 3 || progress.Done {
t.Fatalf("progress response = %+v", progress)
}
var errorResponse struct {
Error string `json:"error"`
}
if err := json.Unmarshal([]byte(lines[1]), &errorResponse); err != nil {
t.Fatalf("failed to parse error response: %v", err)
}
if errorResponse.Error != "runner died" {
t.Fatalf("error = %q, want runner died", errorResponse.Error)
}
}

View file

@ -10,5 +10,6 @@ func setTestHome(t *testing.T, home string) {
t.Helper()
t.Setenv("HOME", home)
t.Setenv("USERPROFILE", home)
t.Setenv("OLLAMA_MODELS", "")
envconfig.ReloadServerConfig()
}

View file

@ -15,6 +15,7 @@ import (
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/internal/mlxthread"
)
// Execute is the entry point for the unified MLX runner subprocess.
@ -45,17 +46,30 @@ func Execute(args []string) error {
return fmt.Errorf("imagegen runner only supports image generation models")
}
// Initialize MLX only for image generation mode.
if err := mlx.InitMLX(); err != nil {
slog.Error("unable to initialize MLX", "error", err)
worker, err := mlxthread.Start("imagegen", func() error {
if err := mlx.InitMLX(); err != nil {
slog.Error("unable to initialize MLX", "error", err)
return err
}
slog.Info("MLX library initialized")
return nil
})
if err != nil {
return err
}
slog.Info("MLX library initialized")
// Create and start server
server, err := newServer(*modelName, *port)
if err != nil {
return fmt.Errorf("failed to create server: %w", err)
var server *server
if err := worker.Do(context.Background(), func() error {
var err error
server, err = newServer(*modelName, *port)
if err != nil {
return fmt.Errorf("failed to create server: %w", err)
}
server.mlxThread = worker
return nil
}); err != nil {
return err
}
// Set up HTTP handlers
@ -77,7 +91,17 @@ func Execute(args []string) error {
slog.Info("shutting down mlx runner")
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
httpServer.Shutdown(ctx)
if err := httpServer.Shutdown(ctx); err != nil {
slog.Warn("graceful shutdown timed out", "error", err)
if err := httpServer.Close(); err != nil {
slog.Warn("failed to close http server", "error", err)
}
}
if err := worker.Stop(ctx, func() {
mlx.ClearCache()
}); err != nil {
slog.Warn("failed to stop mlx worker", "error", err)
}
close(done)
}()
@ -110,6 +134,7 @@ func detectModelMode(modelName string) ModelMode {
type server struct {
modelName string
port int
mlxThread *mlxthread.Thread
// Image generation model.
imageModel ImageModel
@ -147,5 +172,10 @@ func (s *server) completionHandler(w http.ResponseWriter, r *http.Request) {
return
}
s.handleImageCompletion(w, r, req)
if err := s.mlxThread.Do(r.Context(), func() error {
s.handleImageCompletion(w, r, req)
return nil
}); err != nil && r.Context().Err() == nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}

View file

@ -367,9 +367,16 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f
// Check if subprocess is still alive
if s.HasExited() {
slog.Error("mlx subprocess has exited unexpectedly")
if errMsg := s.getLastErr(); errMsg != "" {
return fmt.Errorf("mlx runner closed response before completion: %s", errMsg)
}
}
return scanErr
if scanErr != nil {
return scanErr
}
return errors.New("mlx runner closed response before completion")
}
func (s *Server) Chat(ctx context.Context, req llm.ChatRequest, fn func(llm.ChatResponse)) error {

102
x/imagegen/server_test.go Normal file
View file

@ -0,0 +1,102 @@
package imagegen
import (
"context"
"encoding/json"
"io"
"net/http"
"strings"
"testing"
"github.com/ollama/ollama/llm"
)
type roundTripFunc func(*http.Request) (*http.Response, error)
func (fn roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return fn(req)
}
func newCompletionTestServer(handler func(*http.Request) string) *Server {
return &Server{
port: 11434,
done: make(chan error, 1),
client: &http.Client{
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
body := handler(req)
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader(body)),
Request: req,
}, nil
}),
},
}
}
func TestCompletionReturnsImageData(t *testing.T) {
s := newCompletionTestServer(func(r *http.Request) string {
if r.URL.Path != "/completion" {
t.Fatalf("path = %q, want /completion", r.URL.Path)
}
var req Request
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
t.Fatal(err)
}
if req.Prompt != "test prompt" || req.Width != 512 || req.Height != 256 || req.Steps != 7 || req.Seed != 42 {
t.Fatalf("unexpected request: %+v", req)
}
if len(req.Images) != 1 || string(req.Images[0]) != "input-image" {
t.Fatalf("images = %q, want input-image", req.Images)
}
return `{"step":1,"total":2}` + "\n" +
`{"done":true,"image":"base64png"}` + "\n"
})
var responses []llm.CompletionResponse
err := s.Completion(context.Background(), llm.CompletionRequest{
Prompt: "test prompt",
Width: 512,
Height: 256,
Steps: 7,
Seed: 42,
Images: []llm.ImageData{{Data: []byte("input-image")}},
}, func(resp llm.CompletionResponse) {
responses = append(responses, resp)
})
if err != nil {
t.Fatal(err)
}
if len(responses) != 2 {
t.Fatalf("responses = %d, want 2", len(responses))
}
if responses[0].Step != 1 || responses[0].TotalSteps != 2 || responses[0].Done {
t.Fatalf("progress response = %+v", responses[0])
}
if !responses[1].Done || responses[1].Image != "base64png" {
t.Fatalf("final response = %+v", responses[1])
}
}
func TestCompletionEOFBeforeDoneReturnsError(t *testing.T) {
s := newCompletionTestServer(func(r *http.Request) string {
return `{"step":1,"total":2}` + "\n"
})
var responses []llm.CompletionResponse
err := s.Completion(context.Background(), llm.CompletionRequest{Prompt: "test prompt"}, func(resp llm.CompletionResponse) {
responses = append(responses, resp)
})
if err == nil {
t.Fatal("expected error")
}
if !strings.Contains(err.Error(), "closed response before completion") {
t.Fatalf("error = %v", err)
}
if len(responses) != 1 || responses[0].Done {
t.Fatalf("responses = %+v, want one non-done progress response", responses)
}
}