mirror of
https://github.com/ollama/ollama.git
synced 2026-05-13 06:21:28 +00:00
app: harden update flows (#16100)
* app: harden update flows This hardens the windows update flows and adds a new opt-in and CI triggered unit test to verify Mac/Windows updates with verification. * test: harden unit tests for OLLAMA_MODELS being set * app: harden updater
This commit is contained in:
parent
c2f2d90a67
commit
3d5a011a2e
10 changed files with 889 additions and 69 deletions
7
.github/workflows/test.yaml
vendored
7
.github/workflows/test.yaml
vendored
|
|
@ -22,6 +22,7 @@ jobs:
|
|||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
changed: ${{ steps.changes.outputs.changed }}
|
||||
app_changed: ${{ steps.changes.outputs.app_changed }}
|
||||
vendorsha: ${{ steps.changes.outputs.vendorsha }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
|
@ -38,6 +39,7 @@ jobs:
|
|||
}
|
||||
|
||||
echo changed=$(changed 'llama/llama.cpp/**/*' 'ml/backend/ggml/ggml/**/*' '.github/**/*') | tee -a $GITHUB_OUTPUT
|
||||
echo app_changed=$(changed 'app/**' 'app/**/*') | tee -a $GITHUB_OUTPUT
|
||||
echo vendorsha=$(make -f Makefile.sync print-base) | tee -a $GITHUB_OUTPUT
|
||||
|
||||
linux:
|
||||
|
|
@ -250,6 +252,7 @@ jobs:
|
|||
run: go mod tidy --diff || (echo "Please run 'go mod tidy'." && exit 1)
|
||||
|
||||
test:
|
||||
needs: [changes]
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest, macos-latest, windows-latest]
|
||||
|
|
@ -284,6 +287,10 @@ jobs:
|
|||
if: always()
|
||||
run: go test -count=1 -benchtime=1x ./...
|
||||
|
||||
- name: go test app with live updater tag
|
||||
if: ${{ needs.changes.outputs.app_changed == 'True' && contains(fromJSON('["macos-latest","windows-latest"]'), matrix.os) }}
|
||||
run: go test -count=1 -tags updater_live ./app/...
|
||||
|
||||
- uses: golangci/golangci-lint-action@v9
|
||||
with:
|
||||
only-new-issues: true
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@ package updater
|
|||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
|
@ -169,22 +171,20 @@ func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateRespo
|
|||
if err != nil {
|
||||
return fmt.Errorf("error checking update: %w", err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("unexpected status attempting to download update %d", resp.StatusCode)
|
||||
}
|
||||
resp.Body.Close()
|
||||
etag := strings.Trim(resp.Header.Get("etag"), "\"")
|
||||
if etag == "" {
|
||||
slog.Debug("no etag detected, falling back to filename based dedup")
|
||||
etag = "_"
|
||||
}
|
||||
filename := Installer
|
||||
_, params, err := mime.ParseMediaType(resp.Header.Get("content-disposition"))
|
||||
if err == nil {
|
||||
if err == nil && params["filename"] != "" {
|
||||
filename = params["filename"]
|
||||
}
|
||||
|
||||
stageFilename := filepath.Join(UpdateStageDir, etag, filename)
|
||||
stageFilename, err := updateStagePath(UpdateStageDir, resp.Header.Get("etag"), filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check to see if we already have it downloaded
|
||||
_, err = os.Stat(stageFilename)
|
||||
|
|
@ -202,13 +202,14 @@ func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateRespo
|
|||
return fmt.Errorf("error checking update: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
etag = strings.Trim(resp.Header.Get("etag"), "\"")
|
||||
if etag == "" {
|
||||
slog.Debug("no etag detected, falling back to filename based dedup") // TODO probably can get rid of this redundant log
|
||||
etag = "_"
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("unexpected status attempting to download update %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
stageFilename = filepath.Join(UpdateStageDir, etag, filename)
|
||||
stageFilename, err = updateStagePath(UpdateStageDir, resp.Header.Get("etag"), filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = os.Stat(filepath.Dir(stageFilename))
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
|
|
@ -225,10 +226,13 @@ func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateRespo
|
|||
if err != nil {
|
||||
return fmt.Errorf("write payload %s: %w", stageFilename, err)
|
||||
}
|
||||
defer fp.Close()
|
||||
if n, err := fp.Write(payload); err != nil || n != len(payload) {
|
||||
_ = fp.Close()
|
||||
return fmt.Errorf("write payload %s: %d vs %d -- %w", stageFilename, n, len(payload), err)
|
||||
}
|
||||
if err := fp.Close(); err != nil {
|
||||
return fmt.Errorf("close payload %s: %w", stageFilename, err)
|
||||
}
|
||||
slog.Info("new update downloaded " + stageFilename)
|
||||
|
||||
if err := VerifyDownload(); err != nil {
|
||||
|
|
@ -239,6 +243,61 @@ func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateRespo
|
|||
return nil
|
||||
}
|
||||
|
||||
func updateStagePath(stageDir, etag, filename string) (string, error) {
|
||||
filename, err := safeUpdateFilename(filename)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
stageDir, err = filepath.Abs(stageDir)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("resolve update stage dir: %w", err)
|
||||
}
|
||||
|
||||
stageFilename := filepath.Join(stageDir, updateStageETagDir(etag), filename)
|
||||
if err := ensurePathInDir(stageDir, stageFilename); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return stageFilename, nil
|
||||
}
|
||||
|
||||
func safeUpdateFilename(filename string) (string, error) {
|
||||
filename = strings.TrimSpace(filename)
|
||||
if filename == "" {
|
||||
return "", errors.New("missing update filename")
|
||||
}
|
||||
if filename == "." || filename == ".." ||
|
||||
filepath.IsAbs(filename) || path.IsAbs(filename) ||
|
||||
strings.ContainsAny(filename, `/\:`) ||
|
||||
filepath.Base(filename) != filename || path.Base(filename) != filename {
|
||||
return "", fmt.Errorf("unsafe update filename %q", filename)
|
||||
}
|
||||
return filename, nil
|
||||
}
|
||||
|
||||
func updateStageETagDir(etag string) string {
|
||||
etag = strings.Trim(strings.TrimSpace(etag), "\"")
|
||||
if etag == "" {
|
||||
slog.Debug("no etag detected, falling back to filename based dedup")
|
||||
return "_"
|
||||
}
|
||||
|
||||
sum := sha256.Sum256([]byte(etag))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func ensurePathInDir(dir, name string) error {
|
||||
rel, err := filepath.Rel(dir, name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("resolve update staging path: %w", err)
|
||||
}
|
||||
if rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) || filepath.IsAbs(rel) {
|
||||
return fmt.Errorf("update staging path escapes stage dir: %s", name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func cleanupOldDownloads(stageDir string) {
|
||||
files, err := os.ReadDir(stageDir)
|
||||
if err != nil && errors.Is(err, os.ErrNotExist) {
|
||||
|
|
|
|||
|
|
@ -22,6 +22,15 @@ import (
|
|||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
const updateArchiveRoot = "Ollama.app"
|
||||
|
||||
type bundleEntryScope int
|
||||
|
||||
const (
|
||||
bundleEntryRelative bundleEntryScope = iota
|
||||
bundleEntryWithArchiveRoot
|
||||
)
|
||||
|
||||
var (
|
||||
appBackupDir string
|
||||
SystemWidePath = "/Applications/Ollama.app"
|
||||
|
|
@ -167,8 +176,12 @@ func DoUpgrade(interactive bool) error {
|
|||
}
|
||||
name := s[1]
|
||||
if strings.HasSuffix(name, "/") {
|
||||
d := filepath.Join(BundlePath, name)
|
||||
err := os.MkdirAll(d, 0o755)
|
||||
d, err := bundleEntryPath(BundlePath, name, bundleEntryRelative)
|
||||
if err != nil {
|
||||
anyFailures = true
|
||||
return err
|
||||
}
|
||||
err = os.MkdirAll(d, 0o755)
|
||||
if err != nil {
|
||||
anyFailures = true
|
||||
return fmt.Errorf("failed to mkdir %s: %w", d, err)
|
||||
|
|
@ -181,30 +194,14 @@ func DoUpgrade(interactive bool) error {
|
|||
continue
|
||||
}
|
||||
|
||||
src, err := f.Open()
|
||||
destName, err := bundleEntryPath(BundlePath, name, bundleEntryRelative)
|
||||
if err != nil {
|
||||
anyFailures = true
|
||||
return fmt.Errorf("failed to open bundle file %s: %w", name, err)
|
||||
return err
|
||||
}
|
||||
destName := filepath.Join(BundlePath, name)
|
||||
// Verify directory first
|
||||
d := filepath.Dir(destName)
|
||||
if _, err := os.Stat(d); err != nil {
|
||||
err := os.MkdirAll(d, 0o755)
|
||||
if err != nil {
|
||||
anyFailures = true
|
||||
return fmt.Errorf("failed to mkdir %s: %w", d, err)
|
||||
}
|
||||
}
|
||||
destFile, err := os.OpenFile(destName, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755)
|
||||
if err != nil {
|
||||
if err := extractBundleFile(f, destName, name); err != nil {
|
||||
anyFailures = true
|
||||
return fmt.Errorf("failed to open output file %s: %w", destName, err)
|
||||
}
|
||||
defer destFile.Close()
|
||||
if _, err := io.Copy(destFile, src); err != nil {
|
||||
anyFailures = true
|
||||
return fmt.Errorf("failed to open extract file %s: %w", destName, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
for _, f := range links {
|
||||
|
|
@ -225,16 +222,24 @@ func DoUpgrade(interactive bool) error {
|
|||
return err
|
||||
}
|
||||
link := string(buf)
|
||||
if link[0] == '/' {
|
||||
if link == "" {
|
||||
anyFailures = true
|
||||
return fmt.Errorf("bundle contains empty symlink %s", f.Name)
|
||||
}
|
||||
if filepath.IsAbs(link) {
|
||||
anyFailures = true
|
||||
return fmt.Errorf("bundle contains absolute symlink %s -> %s", f.Name, link)
|
||||
}
|
||||
// Don't allow links outside of Ollama.app
|
||||
if strings.HasPrefix(filepath.Join(filepath.Dir(name), link), "..") {
|
||||
if !validBundleLinkTarget(name, link, bundleEntryRelative) {
|
||||
anyFailures = true
|
||||
return fmt.Errorf("bundle contains link outside of contents %s -> %s", f.Name, link)
|
||||
return fmt.Errorf("bundle contains invalid symlink %s -> %s", f.Name, link)
|
||||
}
|
||||
if err = os.Symlink(link, filepath.Join(BundlePath, name)); err != nil {
|
||||
destName, err := bundleEntryPath(BundlePath, name, bundleEntryRelative)
|
||||
if err != nil {
|
||||
anyFailures = true
|
||||
return err
|
||||
}
|
||||
if err = os.Symlink(link, destName); err != nil {
|
||||
anyFailures = true
|
||||
return err
|
||||
}
|
||||
|
|
@ -282,8 +287,11 @@ func verifyDownload() error {
|
|||
links := []*zip.File{}
|
||||
for _, f := range r.File {
|
||||
if strings.HasSuffix(f.Name, "/") {
|
||||
d := filepath.Join(dir, f.Name)
|
||||
err := os.MkdirAll(d, 0o755)
|
||||
d, err := bundleEntryPath(dir, f.Name, bundleEntryWithArchiveRoot)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = os.MkdirAll(d, 0o755)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to mkdir %s: %w", d, err)
|
||||
}
|
||||
|
|
@ -294,26 +302,12 @@ func verifyDownload() error {
|
|||
links = append(links, f)
|
||||
continue
|
||||
}
|
||||
src, err := f.Open()
|
||||
destName, err := bundleEntryPath(dir, f.Name, bundleEntryWithArchiveRoot)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open bundle file %s: %w", f.Name, err)
|
||||
return err
|
||||
}
|
||||
destName := filepath.Join(dir, f.Name)
|
||||
// Verify directory first
|
||||
d := filepath.Dir(destName)
|
||||
if _, err := os.Stat(d); err != nil {
|
||||
err := os.MkdirAll(d, 0o755)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to mkdir %s: %w", d, err)
|
||||
}
|
||||
}
|
||||
destFile, err := os.OpenFile(destName, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open output file %s: %w", destName, err)
|
||||
}
|
||||
defer destFile.Close()
|
||||
if _, err := io.Copy(destFile, src); err != nil {
|
||||
return fmt.Errorf("failed to open extract file %s: %w", destName, err)
|
||||
if err := extractBundleFile(f, destName, f.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
for _, f := range links {
|
||||
|
|
@ -326,13 +320,20 @@ func verifyDownload() error {
|
|||
return err
|
||||
}
|
||||
link := string(buf)
|
||||
if link[0] == '/' {
|
||||
if link == "" {
|
||||
return fmt.Errorf("bundle contains empty symlink %s", f.Name)
|
||||
}
|
||||
if filepath.IsAbs(link) {
|
||||
return fmt.Errorf("bundle contains absolute symlink %s -> %s", f.Name, link)
|
||||
}
|
||||
if strings.HasPrefix(filepath.Join(filepath.Dir(f.Name), link), "..") {
|
||||
return fmt.Errorf("bundle contains link outside of contents %s -> %s", f.Name, link)
|
||||
if !validBundleLinkTarget(f.Name, link, bundleEntryWithArchiveRoot) {
|
||||
return fmt.Errorf("bundle contains invalid symlink %s -> %s", f.Name, link)
|
||||
}
|
||||
if err = os.Symlink(link, filepath.Join(dir, f.Name)); err != nil {
|
||||
destName, err := bundleEntryPath(dir, f.Name, bundleEntryWithArchiveRoot)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err = os.Symlink(link, destName); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
|
@ -343,6 +344,53 @@ func verifyDownload() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func bundleEntryPath(root, name string, scope bundleEntryScope) (string, error) {
|
||||
cleanName := filepath.Clean(filepath.FromSlash(name))
|
||||
if !filepath.IsLocal(cleanName) {
|
||||
return "", fmt.Errorf("bundle contains invalid path: %s", name)
|
||||
}
|
||||
if scope == bundleEntryWithArchiveRoot && cleanName != updateArchiveRoot &&
|
||||
!strings.HasPrefix(cleanName, updateArchiveRoot+string(os.PathSeparator)) {
|
||||
return "", fmt.Errorf("bundle contains invalid path: %s", name)
|
||||
}
|
||||
return filepath.Join(root, cleanName), nil
|
||||
}
|
||||
|
||||
func extractBundleFile(f *zip.File, destName, name string) error {
|
||||
src, err := f.Open()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open bundle file %s: %w", name, err)
|
||||
}
|
||||
defer src.Close()
|
||||
|
||||
d := filepath.Dir(destName)
|
||||
if _, err := os.Stat(d); err != nil {
|
||||
if err := os.MkdirAll(d, 0o755); err != nil {
|
||||
return fmt.Errorf("failed to mkdir %s: %w", d, err)
|
||||
}
|
||||
}
|
||||
|
||||
destFile, err := os.OpenFile(destName, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open output file %s: %w", destName, err)
|
||||
}
|
||||
defer destFile.Close()
|
||||
|
||||
if _, err := io.Copy(destFile, src); err != nil {
|
||||
return fmt.Errorf("failed to open extract file %s: %w", destName, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validBundleLinkTarget(name, link string, scope bundleEntryScope) bool {
|
||||
cleanTarget := filepath.Clean(filepath.Join(filepath.Dir(filepath.FromSlash(name)), filepath.FromSlash(link)))
|
||||
if !filepath.IsLocal(cleanTarget) {
|
||||
return false
|
||||
}
|
||||
return scope == bundleEntryRelative || cleanTarget == updateArchiveRoot ||
|
||||
strings.HasPrefix(cleanTarget, updateArchiveRoot+string(os.PathSeparator))
|
||||
}
|
||||
|
||||
// If we detect an upgrade bundle, attempt to upgrade at startup
|
||||
func DoUpgradeAtStartup() error {
|
||||
bundle := getStagedUpdate()
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ package updater
|
|||
|
||||
import (
|
||||
"archive/zip"
|
||||
"errors"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
|
@ -146,6 +147,46 @@ func TestDoUpgrade(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestDoUpgradeRejectsInvalidBundlePath(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
BundlePath = filepath.Join(tmpDir, "Ollama.app")
|
||||
appBackupDir = filepath.Join(tmpDir, "backup")
|
||||
UpdateStageDir = filepath.Join(tmpDir, "updates")
|
||||
UpgradeMarkerFile = filepath.Join(tmpDir, "upgraded")
|
||||
bundle := filepath.Join(UpdateStageDir, "foo", "ollama-darwin.zip")
|
||||
invalidTarget := filepath.Join(tmpDir, "invalid-entry")
|
||||
|
||||
if err := os.MkdirAll(filepath.Join(BundlePath, "Contents", "MacOS"), 0o755); err != nil {
|
||||
t.Fatal("failed to create empty dirs")
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(BundlePath, "Contents", "MacOS", "Ollama"), []byte("old app"), 0o755); err != nil {
|
||||
t.Fatal("failed to create old app")
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(bundle), 0o755); err != nil {
|
||||
t.Fatal("failed to create empty dirs")
|
||||
}
|
||||
if err := zipCreationHelper(bundle, []testPayload{{
|
||||
Name: "Ollama.app/../invalid-entry",
|
||||
Body: []byte("payload"),
|
||||
}}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := DoUpgrade(false); err == nil {
|
||||
t.Fatal("expected failure with invalid bundle path")
|
||||
} else if !strings.Contains(err.Error(), "bundle contains invalid path") {
|
||||
t.Fatalf("unexpected error with invalid bundle path: %s", err)
|
||||
}
|
||||
if _, err := os.Stat(invalidTarget); err == nil {
|
||||
t.Fatalf("invalid bundle path wrote %s", invalidTarget)
|
||||
} else if !errors.Is(err, os.ErrNotExist) {
|
||||
t.Fatalf("unexpected stat error for %s: %s", invalidTarget, err)
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(BundlePath, "Contents", "MacOS", "Ollama")); err != nil {
|
||||
t.Fatalf("old app was not restored: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDoUpgradeAtStartup(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
BundlePath = filepath.Join(tmpDir, "Ollama.app")
|
||||
|
|
@ -203,7 +244,7 @@ func TestVerifyDownloadFailures(t *testing.T) {
|
|||
in []testPayload
|
||||
expected string
|
||||
}{
|
||||
{"breakout", []testPayload{
|
||||
{"invalid symlink target", []testPayload{
|
||||
{
|
||||
Name: "Ollama.app/",
|
||||
Body: []byte{},
|
||||
|
|
@ -212,15 +253,34 @@ func TestVerifyDownloadFailures(t *testing.T) {
|
|||
Body: []byte("cli payload here"),
|
||||
}, {
|
||||
Name: "Ollama.app/Contents/MacOS/Ollama",
|
||||
Body: []byte("../../../../breakout"),
|
||||
Body: []byte("../../../../invalid-target"),
|
||||
Mode: os.ModeSymlink,
|
||||
},
|
||||
}, "bundle contains link outside"},
|
||||
}, "bundle contains invalid symlink"},
|
||||
{"invalid archive symlink target", []testPayload{
|
||||
{
|
||||
Name: "Ollama.app/Contents/MacOS/Ollama",
|
||||
Body: []byte("../../../invalid-target"),
|
||||
Mode: os.ModeSymlink,
|
||||
},
|
||||
}, "bundle contains invalid symlink"},
|
||||
{"absolute", []testPayload{{
|
||||
Name: "Ollama.app/Contents/MacOS/Ollama",
|
||||
Body: []byte("/etc/foo"),
|
||||
Mode: os.ModeSymlink,
|
||||
}}, "bundle contains absolute"},
|
||||
{"invalid relative file", []testPayload{{
|
||||
Name: "Ollama.app/../invalid-entry",
|
||||
Body: []byte("payload"),
|
||||
}}, "bundle contains invalid path"},
|
||||
{"invalid relative directory", []testPayload{{
|
||||
Name: "Ollama.app/../invalid-entry/",
|
||||
Body: []byte{},
|
||||
}}, "bundle contains invalid path"},
|
||||
{"absolute file", []testPayload{{
|
||||
Name: filepath.Join(tmpDir, "invalid-entry"),
|
||||
Body: []byte("payload"),
|
||||
}}, "bundle contains invalid path"},
|
||||
{"missing", []testPayload{{
|
||||
Name: "Ollama.app/Contents/MacOS/Ollama",
|
||||
Body: []byte("../nothere"),
|
||||
|
|
@ -242,6 +302,11 @@ func TestVerifyDownloadFailures(t *testing.T) {
|
|||
if err == nil || !strings.Contains(err.Error(), tt.expected) {
|
||||
t.Fatalf("expected \"%s\" got %s", tt.expected, err)
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(tmpDir, "invalid-entry")); err == nil {
|
||||
t.Fatal("invalid bundle path wrote unexpected file")
|
||||
} else if !errors.Is(err, os.ErrNotExist) {
|
||||
t.Fatalf("unexpected stat error for invalid file: %s", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
127
app/updater/updater_live_test.go
Normal file
127
app/updater/updater_live_test.go
Normal file
|
|
@ -0,0 +1,127 @@
|
|||
//go:build (windows || darwin) && updater_live
|
||||
|
||||
package updater
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/app/store"
|
||||
"github.com/ollama/ollama/app/version"
|
||||
)
|
||||
|
||||
// TestLiveAppUpdate exercises the production update endpoint and downloads the
|
||||
// current OS update artifact. It is intentionally excluded from normal test
|
||||
// runs because it depends on ollama.com and downloads a release artifact.
|
||||
//
|
||||
// Run with:
|
||||
//
|
||||
// go test -tags updater_live -run TestLiveAppUpdate ./app/updater
|
||||
func TestLiveAppUpdate(t *testing.T) {
|
||||
const spoofedVersion = "0.20.0"
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
oldUpdateStageDir := UpdateStageDir
|
||||
oldUpdateDownloaded := UpdateDownloaded
|
||||
oldVerifyDownload := VerifyDownload
|
||||
oldVersion := version.Version
|
||||
defer func() {
|
||||
UpdateStageDir = oldUpdateStageDir
|
||||
UpdateDownloaded = oldUpdateDownloaded
|
||||
VerifyDownload = oldVerifyDownload
|
||||
version.Version = oldVersion
|
||||
}()
|
||||
|
||||
version.Version = spoofedVersion
|
||||
|
||||
expectedFilename := ""
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
t.Setenv("LOCALAPPDATA", t.TempDir())
|
||||
expectedFilename = "OllamaSetup.exe"
|
||||
case "darwin":
|
||||
expectedFilename = "Ollama-darwin.zip"
|
||||
default:
|
||||
t.Fatalf("unsupported updater live test OS %q", runtime.GOOS)
|
||||
}
|
||||
|
||||
UpdateStageDir = filepath.Join(t.TempDir(), "updates")
|
||||
UpdateDownloaded = false
|
||||
verifyCalled := false
|
||||
VerifyDownload = func() error {
|
||||
verifyCalled = true
|
||||
return verifyDownload()
|
||||
}
|
||||
|
||||
updater := &Updater{Store: &store.Store{DBPath: filepath.Join(t.TempDir(), "db.sqlite")}}
|
||||
defer updater.Store.Close()
|
||||
|
||||
available, updateResp := updater.checkForUpdate(ctx)
|
||||
if !available {
|
||||
t.Fatalf("expected production update check to offer an update for spoofed version %s", spoofedVersion)
|
||||
}
|
||||
if updateResp.UpdateURL == "" {
|
||||
t.Fatal("production update response did not include a download URL")
|
||||
}
|
||||
t.Logf("production update version=%q url=%q", updateResp.UpdateVersion, updateResp.UpdateURL)
|
||||
|
||||
if err := updater.DownloadNewRelease(ctx, updateResp); err != nil {
|
||||
t.Fatalf("download production update: %v", err)
|
||||
}
|
||||
|
||||
staged := getStagedUpdate()
|
||||
if staged == "" {
|
||||
t.Fatal("production update was not staged")
|
||||
}
|
||||
t.Logf("staged production update at %s", staged)
|
||||
|
||||
assertPathInsideDir(t, UpdateStageDir, staged)
|
||||
if filepath.Base(staged) != expectedFilename {
|
||||
t.Fatalf("expected staged %s update filename to be %q, got %q", runtime.GOOS, expectedFilename, filepath.Base(staged))
|
||||
}
|
||||
expectedExt := filepath.Ext(expectedFilename)
|
||||
if filepath.Ext(staged) != expectedExt {
|
||||
t.Fatalf("expected staged %s update to be a %s artifact, got %s", runtime.GOOS, expectedExt, staged)
|
||||
}
|
||||
|
||||
info, err := os.Stat(staged)
|
||||
if err != nil {
|
||||
t.Fatalf("stat staged update: %v", err)
|
||||
}
|
||||
if info.Size() == 0 {
|
||||
t.Fatal("staged production update is empty")
|
||||
}
|
||||
|
||||
if !verifyCalled {
|
||||
t.Fatal("DownloadNewRelease did not call VerifyDownload")
|
||||
}
|
||||
t.Logf("production updater download path verified staged %s update", runtime.GOOS)
|
||||
}
|
||||
|
||||
func assertPathInsideDir(t *testing.T, dir, name string) {
|
||||
t.Helper()
|
||||
|
||||
dir, err := filepath.Abs(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
name, err = filepath.Abs(name)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
rel, err := filepath.Rel(dir, name)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) || filepath.IsAbs(rel) {
|
||||
t.Fatalf("staged update escaped update stage dir: %s", name)
|
||||
}
|
||||
}
|
||||
|
|
@ -11,7 +11,9 @@ import (
|
|||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
|
@ -19,6 +21,52 @@ import (
|
|||
"github.com/ollama/ollama/app/store"
|
||||
)
|
||||
|
||||
func TestUpdateStagePathRejectsUnsafeFilename(t *testing.T) {
|
||||
stageDir := t.TempDir()
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
filename string
|
||||
}{
|
||||
{"empty", ""},
|
||||
{"dot", "."},
|
||||
{"dotdot", ".."},
|
||||
{"posix_parent", "../OllamaSetup.exe"},
|
||||
{"windows_parent", `..\OllamaSetup.exe`},
|
||||
{"posix_absolute_tmp", "/tmp/OllamaSetup.exe"},
|
||||
{"darwin_absolute_app", "/Applications/Ollama.app"},
|
||||
{"darwin_bundle_path", "Ollama.app/Contents/MacOS/Ollama"},
|
||||
{"darwin_user_download", "~/Downloads/Ollama-darwin.zip"},
|
||||
{"windows_absolute", `C:\Users\Public\OllamaSetup.exe`},
|
||||
{"colon", "Ollama:Setup.exe"},
|
||||
} {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if _, err := updateStagePath(stageDir, "etag", tt.filename); err == nil {
|
||||
t.Fatal("expected unsafe filename to be rejected")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateStagePathHashesETag(t *testing.T) {
|
||||
stageDir := t.TempDir()
|
||||
stageFilename, err := updateStagePath(stageDir, `../escaped`, "OllamaSetup.exe")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
rel, err := filepath.Rel(stageDir, stageFilename)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) || filepath.IsAbs(rel) {
|
||||
t.Fatalf("stage filename escaped stage dir: %s", stageFilename)
|
||||
}
|
||||
etagDir := filepath.Base(filepath.Dir(stageFilename))
|
||||
if etagDir == ".." || etagDir == "escaped" || strings.ContainsAny(etagDir, `/\`) {
|
||||
t.Fatalf("stage filename used raw etag path component: %s", stageFilename)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsNewReleaseAvailable(t *testing.T) {
|
||||
slog.SetLogLoggerLevel(slog.LevelDebug)
|
||||
var server *httptest.Server
|
||||
|
|
@ -47,6 +95,223 @@ func TestIsNewReleaseAvailable(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestDownloadNewReleaseRejectsUnsafeHeaderFilename(t *testing.T) {
|
||||
UpdateStageDir = t.TempDir()
|
||||
oldInstaller := Installer
|
||||
oldVerifyDownload := VerifyDownload
|
||||
oldUpdateDownloaded := UpdateDownloaded
|
||||
defer func() {
|
||||
Installer = oldInstaller
|
||||
VerifyDownload = oldVerifyDownload
|
||||
UpdateDownloaded = oldUpdateDownloaded
|
||||
}()
|
||||
Installer = "OllamaSetup.exe"
|
||||
UpdateDownloaded = false
|
||||
VerifyDownload = func() error {
|
||||
t.Fatal("verification should not run for rejected downloads")
|
||||
return nil
|
||||
}
|
||||
|
||||
var getAttempted atomic.Bool
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == http.MethodHead {
|
||||
w.Header().Set("ETag", `"safe"`)
|
||||
w.Header().Set("Content-Disposition", `attachment; filename="../OllamaSetup.exe"`)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
getAttempted.Store(true)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
updater := &Updater{}
|
||||
err := updater.DownloadNewRelease(t.Context(), UpdateResponse{UpdateURL: server.URL + "/download"})
|
||||
if err == nil || !strings.Contains(err.Error(), "unsafe update filename") {
|
||||
t.Fatalf("expected unsafe filename error, got %v", err)
|
||||
}
|
||||
if getAttempted.Load() {
|
||||
t.Fatal("download should not continue after unsafe filename")
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(filepath.Dir(UpdateStageDir), "OllamaSetup.exe")); err == nil {
|
||||
t.Fatal("download escaped update stage dir")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDownloadNewReleaseDoesNotUseRawETagAsPathComponent(t *testing.T) {
|
||||
UpdateStageDir = t.TempDir()
|
||||
oldInstaller := Installer
|
||||
oldVerifyDownload := VerifyDownload
|
||||
oldUpdateDownloaded := UpdateDownloaded
|
||||
defer func() {
|
||||
Installer = oldInstaller
|
||||
VerifyDownload = oldVerifyDownload
|
||||
UpdateDownloaded = oldUpdateDownloaded
|
||||
}()
|
||||
Installer = "OllamaSetup.exe"
|
||||
UpdateDownloaded = false
|
||||
VerifyDownload = func() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
payload := []byte("payload")
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("ETag", `"../escaped"`)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if r.Method == http.MethodGet {
|
||||
_, _ = w.Write(payload)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
updater := &Updater{}
|
||||
if err := updater.DownloadNewRelease(t.Context(), UpdateResponse{UpdateURL: server.URL + "/download"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(filepath.Join(filepath.Dir(UpdateStageDir), "escaped", Installer)); err == nil {
|
||||
t.Fatal("download escaped update stage dir via etag")
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(UpdateStageDir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(entries) != 1 {
|
||||
t.Fatalf("expected one staged update dir, got %d", len(entries))
|
||||
}
|
||||
stageFilename := filepath.Join(UpdateStageDir, entries[0].Name(), Installer)
|
||||
got, err := os.ReadFile(stageFilename)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(got) != string(payload) {
|
||||
t.Fatalf("unexpected staged payload %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackgroundCheckerSkipsAlreadyStagedETagDownload(t *testing.T) {
|
||||
UpdateStageDir = t.TempDir()
|
||||
oldInstaller := Installer
|
||||
oldVerifyDownload := VerifyDownload
|
||||
oldUpdateDownloaded := UpdateDownloaded
|
||||
oldUpdateCheckInitialDelay := UpdateCheckInitialDelay
|
||||
oldUpdateCheckInterval := UpdateCheckInterval
|
||||
oldUpdateCheckURLBase := UpdateCheckURLBase
|
||||
defer func() {
|
||||
Installer = oldInstaller
|
||||
VerifyDownload = oldVerifyDownload
|
||||
UpdateDownloaded = oldUpdateDownloaded
|
||||
UpdateCheckInitialDelay = oldUpdateCheckInitialDelay
|
||||
UpdateCheckInterval = oldUpdateCheckInterval
|
||||
UpdateCheckURLBase = oldUpdateCheckURLBase
|
||||
}()
|
||||
Installer = "OllamaSetup.exe"
|
||||
UpdateDownloaded = false
|
||||
UpdateCheckInitialDelay = time.Millisecond
|
||||
UpdateCheckInterval = 5 * time.Millisecond
|
||||
|
||||
var verifyCount atomic.Int32
|
||||
VerifyDownload = func() error {
|
||||
verifyCount.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
headETag := `"old-update"`
|
||||
getETag := `"download-response-etag"`
|
||||
payload := []byte("payload")
|
||||
var headCount atomic.Int32
|
||||
var getCount atomic.Int32
|
||||
var server *httptest.Server
|
||||
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/update.json":
|
||||
w.Write([]byte(
|
||||
fmt.Sprintf(`{"version": "9.9.9", "url": "%s"}`,
|
||||
server.URL+"/9.9.9/"+Installer)))
|
||||
case "/9.9.9/" + Installer:
|
||||
w.Header().Set("Content-Disposition", `attachment; filename="OllamaSetup.exe"`)
|
||||
switch r.Method {
|
||||
case http.MethodHead:
|
||||
etag := headETag
|
||||
if getCount.Load() > 0 {
|
||||
etag = getETag
|
||||
}
|
||||
w.Header().Set("ETag", etag)
|
||||
headCount.Add(1)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
case http.MethodGet:
|
||||
w.Header().Set("ETag", getETag)
|
||||
getCount.Add(1)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(payload)
|
||||
default:
|
||||
t.Errorf("unexpected request method %s", r.Method)
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
}
|
||||
default:
|
||||
t.Errorf("unexpected request path %s", r.URL.Path)
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
UpdateCheckURLBase = server.URL + "/update.json"
|
||||
|
||||
updater := &Updater{Store: &store.Store{DBPath: filepath.Join(t.TempDir(), "test.db")}}
|
||||
defer updater.Store.Close()
|
||||
settings, err := updater.Store.Settings()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
settings.AutoUpdateEnabled = true
|
||||
if err := updater.Store.SetSettings(settings); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
defer cancel()
|
||||
|
||||
callbacks := make(chan string, 4)
|
||||
updater.StartBackgroundUpdaterChecker(ctx, func(ver string) error {
|
||||
callbacks <- ver
|
||||
return nil
|
||||
})
|
||||
|
||||
for range 2 {
|
||||
select {
|
||||
case <-callbacks:
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timed out waiting for repeated update checks")
|
||||
}
|
||||
}
|
||||
cancel()
|
||||
|
||||
stageFilename, err := updateStagePath(UpdateStageDir, getETag, Installer)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got, err := os.ReadFile(stageFilename)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(got) != string(payload) {
|
||||
t.Fatalf("unexpected staged payload %q", got)
|
||||
}
|
||||
|
||||
if headCount.Load() < 2 {
|
||||
t.Fatalf("HEAD count = %d, want at least 2", headCount.Load())
|
||||
}
|
||||
if getCount.Load() != 1 {
|
||||
t.Fatalf("GET count = %d, want 1", getCount.Load())
|
||||
}
|
||||
if verifyCount.Load() != 1 {
|
||||
t.Fatalf("verification count = %d, want 1", verifyCount.Load())
|
||||
}
|
||||
if !UpdateDownloaded {
|
||||
t.Fatal("UpdateDownloaded should stay true for already staged update")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackgoundChecker(t *testing.T) {
|
||||
UpdateStageDir = t.TempDir()
|
||||
haveUpdate := false
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package updater
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
|
|
@ -18,6 +19,30 @@ import (
|
|||
|
||||
var runningInstaller string
|
||||
|
||||
var (
|
||||
crypt32 = windows.NewLazySystemDLL("crypt32.dll")
|
||||
procCryptMsgGetParam = crypt32.NewProc("CryptMsgGetParam")
|
||||
procCryptMsgClose = crypt32.NewProc("CryptMsgClose")
|
||||
)
|
||||
|
||||
const cmsgSignerInfoParam = 6
|
||||
|
||||
type cmsgSignerInfo struct {
|
||||
Version uint32
|
||||
Issuer windows.CertNameBlob
|
||||
SerialNumber windows.CryptIntegerBlob
|
||||
HashAlgorithm windows.CryptAlgorithmIdentifier
|
||||
HashEncryptionAlgorithm windows.CryptAlgorithmIdentifier
|
||||
EncryptedHash windows.CryptDataBlob
|
||||
AuthAttrs cryptAttributes
|
||||
UnauthAttrs cryptAttributes
|
||||
}
|
||||
|
||||
type cryptAttributes struct {
|
||||
Count uint32
|
||||
Attributes unsafe.Pointer
|
||||
}
|
||||
|
||||
type OSVERSIONINFOEXW struct {
|
||||
dwOSVersionInfoSize uint32
|
||||
dwMajorVersion uint32
|
||||
|
|
@ -99,6 +124,12 @@ func DoUpgrade(interactive bool) error {
|
|||
return fmt.Errorf("failed to lookup downloads")
|
||||
}
|
||||
|
||||
if err := VerifyDownload(); err != nil {
|
||||
_ = os.Remove(bundle)
|
||||
slog.Warn("verification failure", "bundle", bundle, "error", err)
|
||||
return fmt.Errorf("staged update verification failed: %w", err)
|
||||
}
|
||||
|
||||
// We move the installer to ensure we don't race with multiple apps starting in quick succession
|
||||
if err := os.Rename(bundle, runningInstaller); err != nil {
|
||||
return fmt.Errorf("unable to rename %s -> %s : %w", bundle, runningInstaller, err)
|
||||
|
|
@ -184,6 +215,150 @@ func DoPostUpgradeCleanup() error {
|
|||
}
|
||||
|
||||
func verifyDownload() error {
|
||||
bundle := getStagedUpdate()
|
||||
if bundle == "" {
|
||||
return fmt.Errorf("failed to lookup downloads")
|
||||
}
|
||||
slog.Debug("verifying update", "bundle", bundle)
|
||||
|
||||
if err := verifyWindowsInstallerSignature(bundle); err != nil {
|
||||
return fmt.Errorf("signature verification failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func verifyWindowsInstallerSignature(filename string) error {
|
||||
filename16, err := windows.UTF16PtrFromString(filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
data := &windows.WinTrustData{
|
||||
Size: uint32(unsafe.Sizeof(windows.WinTrustData{})),
|
||||
UIChoice: windows.WTD_UI_NONE,
|
||||
RevocationChecks: windows.WTD_REVOKE_WHOLECHAIN,
|
||||
UnionChoice: windows.WTD_CHOICE_FILE,
|
||||
StateAction: windows.WTD_STATEACTION_VERIFY,
|
||||
UIContext: windows.WTD_UICONTEXT_INSTALL,
|
||||
FileOrCatalogOrBlobOrSgnrOrCert: unsafe.Pointer(&windows.WinTrustFileInfo{
|
||||
Size: uint32(unsafe.Sizeof(windows.WinTrustFileInfo{})),
|
||||
FilePath: filename16,
|
||||
}),
|
||||
}
|
||||
|
||||
verifyErr := windows.WinVerifyTrustEx(windows.InvalidHWND, &windows.WINTRUST_ACTION_GENERIC_VERIFY_V2, data)
|
||||
data.StateAction = windows.WTD_STATEACTION_CLOSE
|
||||
closeErr := windows.WinVerifyTrustEx(windows.InvalidHWND, &windows.WINTRUST_ACTION_GENERIC_VERIFY_V2, data)
|
||||
if verifyErr != nil {
|
||||
return verifyErr
|
||||
}
|
||||
if closeErr != nil {
|
||||
return fmt.Errorf("close WinVerifyTrust state: %w", closeErr)
|
||||
}
|
||||
|
||||
subject, err := windowsInstallerSignerSubject(filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
slog.Debug("verified update signature", "subject", subject)
|
||||
return nil
|
||||
}
|
||||
|
||||
func windowsInstallerSignerSubject(filename string) (string, error) {
|
||||
filename16, err := windows.UTF16PtrFromString(filename)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var certStore windows.Handle
|
||||
var msg windows.Handle
|
||||
if err := windows.CryptQueryObject(
|
||||
windows.CERT_QUERY_OBJECT_FILE,
|
||||
unsafe.Pointer(filename16),
|
||||
windows.CERT_QUERY_CONTENT_FLAG_PKCS7_SIGNED_EMBED,
|
||||
windows.CERT_QUERY_FORMAT_FLAG_BINARY,
|
||||
0,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
&certStore,
|
||||
&msg,
|
||||
nil,
|
||||
); err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer windows.CertCloseStore(certStore, 0) //nolint:errcheck
|
||||
defer cryptMsgClose(msg) //nolint:errcheck
|
||||
|
||||
var signerInfoSize uint32
|
||||
if err := cryptMsgGetParam(msg, cmsgSignerInfoParam, 0, nil, &signerInfoSize); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if signerInfoSize == 0 {
|
||||
return "", fmt.Errorf("missing signer info")
|
||||
}
|
||||
|
||||
signerInfoBuf := make([]byte, signerInfoSize)
|
||||
if err := cryptMsgGetParam(msg, cmsgSignerInfoParam, 0, unsafe.Pointer(&signerInfoBuf[0]), &signerInfoSize); err != nil {
|
||||
return "", err
|
||||
}
|
||||
signerInfo := (*cmsgSignerInfo)(unsafe.Pointer(&signerInfoBuf[0]))
|
||||
certInfo := windows.CertInfo{
|
||||
Issuer: signerInfo.Issuer,
|
||||
SerialNumber: signerInfo.SerialNumber,
|
||||
}
|
||||
|
||||
cert, err := windows.CertFindCertificateInStore(
|
||||
certStore,
|
||||
windows.X509_ASN_ENCODING|windows.PKCS_7_ASN_ENCODING,
|
||||
0,
|
||||
windows.CERT_FIND_SUBJECT_CERT,
|
||||
unsafe.Pointer(&certInfo),
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer windows.CertFreeCertificateContext(cert) //nolint:errcheck
|
||||
|
||||
parsed, err := x509.ParseCertificate(unsafe.Slice(cert.EncodedCert, cert.Length))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
for _, org := range parsed.Subject.Organization {
|
||||
if org == "Ollama Inc." {
|
||||
return parsed.Subject.String(), nil
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("unexpected signer: %s", parsed.Subject.String())
|
||||
}
|
||||
|
||||
func cryptMsgGetParam(msg windows.Handle, paramType, index uint32, data unsafe.Pointer, size *uint32) error {
|
||||
r1, _, e1 := procCryptMsgGetParam.Call(
|
||||
uintptr(msg),
|
||||
uintptr(paramType),
|
||||
uintptr(index),
|
||||
uintptr(data),
|
||||
uintptr(unsafe.Pointer(size)),
|
||||
)
|
||||
if r1 == 0 {
|
||||
if e1 != syscall.Errno(0) {
|
||||
return e1
|
||||
}
|
||||
return syscall.EINVAL
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func cryptMsgClose(msg windows.Handle) error {
|
||||
r1, _, e1 := procCryptMsgClose.Call(uintptr(msg))
|
||||
if r1 == 0 {
|
||||
if e1 != syscall.Errno(0) {
|
||||
return e1
|
||||
}
|
||||
return syscall.EINVAL
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,13 +1,85 @@
|
|||
//go:build windows || darwin
|
||||
//go:build windows
|
||||
|
||||
package updater
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestVerifyDownloadRejectsUnsignedWindowsInstaller(t *testing.T) {
|
||||
oldUpdateStageDir := UpdateStageDir
|
||||
defer func() {
|
||||
UpdateStageDir = oldUpdateStageDir
|
||||
}()
|
||||
|
||||
t.Setenv("LOCALAPPDATA", t.TempDir())
|
||||
UpdateStageDir = t.TempDir()
|
||||
bundle := filepath.Join(UpdateStageDir, "etag", "OllamaSetup.exe")
|
||||
if err := os.MkdirAll(filepath.Dir(bundle), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(bundle, []byte("not a signed installer"), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err := verifyDownload()
|
||||
if err == nil || !strings.Contains(err.Error(), "signature verification failed") {
|
||||
t.Fatalf("expected signature verification failure, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDoUpgradeAtStartupRejectsUnsignedWindowsInstaller(t *testing.T) {
|
||||
oldUpdateStageDir := UpdateStageDir
|
||||
oldRunningInstaller := runningInstaller
|
||||
oldUpgradeLogFile := UpgradeLogFile
|
||||
oldUpgradeMarkerFile := UpgradeMarkerFile
|
||||
oldVerifyDownload := VerifyDownload
|
||||
defer func() {
|
||||
UpdateStageDir = oldUpdateStageDir
|
||||
runningInstaller = oldRunningInstaller
|
||||
UpgradeLogFile = oldUpgradeLogFile
|
||||
UpgradeMarkerFile = oldUpgradeMarkerFile
|
||||
VerifyDownload = oldVerifyDownload
|
||||
}()
|
||||
|
||||
t.Setenv("LOCALAPPDATA", t.TempDir())
|
||||
UpdateStageDir = t.TempDir()
|
||||
runDir := t.TempDir()
|
||||
runningInstaller = filepath.Join(runDir, "OllamaSetup.exe")
|
||||
UpgradeLogFile = filepath.Join(runDir, "upgrade.log")
|
||||
UpgradeMarkerFile = filepath.Join(runDir, "upgraded")
|
||||
VerifyDownload = verifyDownload
|
||||
|
||||
bundle := filepath.Join(UpdateStageDir, "etag", "OllamaSetup.exe")
|
||||
if err := os.MkdirAll(filepath.Dir(bundle), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(bundle, []byte("not a signed installer"), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err := DoUpgradeAtStartup()
|
||||
if err == nil || !strings.Contains(err.Error(), "signature verification failed") {
|
||||
t.Fatalf("expected signature verification failure, got %v", err)
|
||||
}
|
||||
if _, err := os.Stat(runningInstaller); !os.IsNotExist(err) {
|
||||
t.Fatalf("unsigned installer was moved before verification failed: %v", err)
|
||||
}
|
||||
if _, err := os.Stat(bundle); !os.IsNotExist(err) {
|
||||
t.Fatalf("unsigned staged installer was not removed after verification failure: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsInstallerRunning(t *testing.T) {
|
||||
oldInstaller := Installer
|
||||
defer func() {
|
||||
Installer = oldInstaller
|
||||
}()
|
||||
|
||||
slog.SetLogLoggerLevel(slog.LevelDebug)
|
||||
Installer = "go.exe"
|
||||
if !isInstallerRunning() {
|
||||
|
|
|
|||
|
|
@ -6,5 +6,6 @@ func setTestHome(t *testing.T, home string) {
|
|||
t.Helper()
|
||||
t.Setenv("HOME", home)
|
||||
t.Setenv("USERPROFILE", home)
|
||||
t.Setenv("OLLAMA_MODELS", "")
|
||||
ReloadServerConfig()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,5 +10,6 @@ func setTestHome(t *testing.T, home string) {
|
|||
t.Helper()
|
||||
t.Setenv("HOME", home)
|
||||
t.Setenv("USERPROFILE", home)
|
||||
t.Setenv("OLLAMA_MODELS", "")
|
||||
envconfig.ReloadServerConfig()
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue