mirror of
https://github.com/ollama/ollama.git
synced 2026-05-13 06:21:28 +00:00
server/launch: add model recommendations cache endpoint (#15868)
This commit is contained in:
parent
87288ced4f
commit
321cc8a2ba
10 changed files with 1250 additions and 38 deletions
|
|
@ -368,6 +368,16 @@ func (c *Client) List(ctx context.Context) (*ListResponse, error) {
|
|||
return &lr, nil
|
||||
}
|
||||
|
||||
// ModelRecommendationsExperimental lists model recommendations from the local
|
||||
// server's experimental recommendations endpoint.
|
||||
func (c *Client) ModelRecommendationsExperimental(ctx context.Context) (*ModelRecommendationsResponse, error) {
|
||||
var resp ModelRecommendationsResponse
|
||||
if err := c.do(ctx, http.MethodGet, "/api/experimental/model-recommendations", nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// ListRunning lists running models.
|
||||
func (c *Client) ListRunning(ctx context.Context) (*ProcessResponse, error) {
|
||||
var lr ProcessResponse
|
||||
|
|
|
|||
14
api/types.go
14
api/types.go
|
|
@ -802,6 +802,20 @@ type ListResponse struct {
|
|||
Models []ListModelResponse `json:"models"`
|
||||
}
|
||||
|
||||
// ModelRecommendationsResponse is the response from [Client.ModelRecommendationsExperimental].
type ModelRecommendationsResponse struct {
	// Recommendations holds the individual recommendation entries.
	Recommendations []ModelRecommendation `json:"recommendations"`
}

// ModelRecommendation is a single recommendation entry in [ModelRecommendationsResponse].
type ModelRecommendation struct {
	// Model is the recommended model name (e.g. "gemma4" or "qwen3.5:cloud").
	Model string `json:"model"`
	// Description is a short human-readable summary of the model.
	Description string `json:"description"`
	// ContextLength is the model's context window in tokens; zero when unknown.
	ContextLength int `json:"context_length,omitempty"`
	// MaxOutputTokens is the maximum output token count; zero when unknown.
	MaxOutputTokens int `json:"max_output_tokens,omitempty"`
	// VRAM is an approximate VRAM requirement string for local models
	// (e.g. "~16GB"); empty when not applicable.
	VRAM string `json:"vram,omitempty"`
}
|
||||
|
||||
// ProcessResponse is the response from [Client.Process].
|
||||
type ProcessResponse struct {
|
||||
Models []ProcessModelResponse `json:"models"`
|
||||
|
|
|
|||
|
|
@ -273,6 +273,8 @@ func TestLaunchCmdModelFlagClearsDisabledCloudOverride(t *testing.T) {
|
|||
switch r.URL.Path {
|
||||
case "/api/status":
|
||||
fmt.Fprintf(w, `{"cloud":{"disabled":true,"source":"config"}}`)
|
||||
case "/api/experimental/model-recommendations":
|
||||
fmt.Fprint(w, `{"recommendations":[]}`)
|
||||
case "/api/tags":
|
||||
fmt.Fprint(w, `{"models":[{"name":"llama3.2"}]}`)
|
||||
case "/api/show":
|
||||
|
|
@ -474,6 +476,8 @@ func TestLaunchCmdIntegrationArgPromptsForModelWithSavedSelection(t *testing.T)
|
|||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/experimental/model-recommendations":
|
||||
fmt.Fprint(w, `{"recommendations":[]}`)
|
||||
case "/api/tags":
|
||||
fmt.Fprint(w, `{"models":[{"name":"llama3.2"},{"name":"qwen3:8b"}]}`)
|
||||
case "/api/show":
|
||||
|
|
|
|||
|
|
@ -49,15 +49,6 @@ func withHermesUserHome(t *testing.T, dir string) {
|
|||
})
|
||||
}
|
||||
|
||||
func withHermesLookPath(t *testing.T, fn func(string) (string, error)) {
|
||||
t.Helper()
|
||||
old := hermesLookPath
|
||||
hermesLookPath = fn
|
||||
t.Cleanup(func() {
|
||||
hermesLookPath = old
|
||||
})
|
||||
}
|
||||
|
||||
func clearHermesMessagingEnvVars(t *testing.T) {
|
||||
t.Helper()
|
||||
for _, group := range hermesMessagingEnvGroups {
|
||||
|
|
@ -112,6 +103,8 @@ func TestHermesConfigurePreservesExistingConfigAndEnablesWeb(t *testing.T) {
|
|||
switch r.URL.Path {
|
||||
case "/api/show":
|
||||
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
|
||||
case "/api/experimental/model-recommendations":
|
||||
fmt.Fprint(w, `{"recommendations":[]}`)
|
||||
case "/api/tags":
|
||||
fmt.Fprint(w, `{"models":[{"name":"gemma4"},{"name":"qwen3.5"},{"name":"llama3.3"}]}`)
|
||||
default:
|
||||
|
|
@ -224,6 +217,8 @@ func TestHermesConfigureUpdatesMatchingCustomProviderWithoutDroppingFields(t *te
|
|||
switch r.URL.Path {
|
||||
case "/api/show":
|
||||
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
|
||||
case "/api/experimental/model-recommendations":
|
||||
fmt.Fprint(w, `{"recommendations":[]}`)
|
||||
case "/api/tags":
|
||||
fmt.Fprint(w, `{"models":[{"name":"gemma4"},{"name":"qwen3.5"},{"name":"llama3.3"}]}`)
|
||||
default:
|
||||
|
|
@ -300,6 +295,8 @@ func TestHermesConfigureUsesLaunchResolvedHostForModelDiscovery(t *testing.T) {
|
|||
switch r.URL.Path {
|
||||
case "/api/show":
|
||||
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
|
||||
case "/api/experimental/model-recommendations":
|
||||
fmt.Fprint(w, `{"recommendations":[]}`)
|
||||
case "/api/tags":
|
||||
fmt.Fprint(w, `{"models":[{"name":"gemma4"},{"name":"qwen3.5"},{"name":"llama3.3"}]}`)
|
||||
default:
|
||||
|
|
@ -365,6 +362,8 @@ func TestHermesConfigureMigratesLegacyManagedAliases(t *testing.T) {
|
|||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/experimental/model-recommendations":
|
||||
fmt.Fprint(w, `{"recommendations":[]}`)
|
||||
case "/api/tags":
|
||||
fmt.Fprint(w, `{"models":[{"name":"gemma4"},{"name":"qwen3.5"}]}`)
|
||||
default:
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"slices"
|
||||
|
|
@ -182,9 +183,12 @@ type ModelInfo = modelInfo
|
|||
|
||||
// ModelItem represents a model for selection UIs.
|
||||
type ModelItem struct {
|
||||
Name string
|
||||
Description string
|
||||
Recommended bool
|
||||
Name string
|
||||
Description string
|
||||
Recommended bool
|
||||
VRAM string
|
||||
ContextLength int
|
||||
MaxOutputTokens int
|
||||
}
|
||||
|
||||
// LaunchCmd returns the cobra command for launching integrations.
|
||||
|
|
@ -720,9 +724,10 @@ func (c *launcherClient) loadSelectableModels(ctx context.Context, preChecked []
|
|||
if err := c.loadModelInventoryOnce(ctx); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
recommendations := c.recommendations(ctx)
|
||||
|
||||
cloudDisabled, _ := cloudStatusDisabled(ctx, c.apiClient)
|
||||
items, orderedChecked, _, _ := buildModelList(c.modelInventory, preChecked, current)
|
||||
items, orderedChecked, _, _ := buildModelListWithRecommendations(c.modelInventory, recommendations, preChecked, current)
|
||||
if cloudDisabled {
|
||||
items = filterCloudItems(items)
|
||||
orderedChecked = c.filterDisabledCloudModels(ctx, orderedChecked)
|
||||
|
|
@ -733,6 +738,60 @@ func (c *launcherClient) loadSelectableModels(ctx context.Context, preChecked []
|
|||
return items, orderedChecked, nil
|
||||
}
|
||||
|
||||
func (c *launcherClient) recommendations(ctx context.Context) []ModelItem {
|
||||
recommendations, err := c.requestRecommendations(ctx)
|
||||
if err != nil || len(recommendations) == 0 {
|
||||
// Fail open: recommendation issues should not block launch flows.
|
||||
// Fall back to built-in recommendations until server data is available.
|
||||
fallback := append([]ModelItem(nil), recommendedModels...)
|
||||
setDynamicCloudModelLimits(cloudModelLimitsFromRecommendations(fallback))
|
||||
return fallback
|
||||
}
|
||||
setDynamicCloudModelLimits(cloudModelLimitsFromRecommendations(recommendations))
|
||||
return recommendations
|
||||
}
|
||||
|
||||
func (c *launcherClient) requestRecommendations(ctx context.Context) ([]ModelItem, error) {
|
||||
resp, err := c.apiClient.ModelRecommendationsExperimental(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items := make([]ModelItem, 0, len(resp.Recommendations))
|
||||
seen := make(map[string]struct{}, len(resp.Recommendations))
|
||||
|
||||
for _, rec := range resp.Recommendations {
|
||||
name := strings.TrimSpace(rec.Model)
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[name]; ok {
|
||||
continue
|
||||
}
|
||||
seen[name] = struct{}{}
|
||||
|
||||
if isCloudModelName(name) && (rec.ContextLength <= 0 || rec.MaxOutputTokens <= 0) {
|
||||
slog.Warn("skipping cloud recommendation with missing limits", "model", name)
|
||||
continue
|
||||
}
|
||||
|
||||
description := strings.TrimSpace(rec.Description)
|
||||
if description == "" {
|
||||
description = "Recommended model"
|
||||
}
|
||||
|
||||
items = append(items, ModelItem{
|
||||
Name: name,
|
||||
Description: description,
|
||||
Recommended: true,
|
||||
VRAM: strings.TrimSpace(rec.VRAM),
|
||||
ContextLength: rec.ContextLength,
|
||||
MaxOutputTokens: rec.MaxOutputTokens,
|
||||
})
|
||||
}
|
||||
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func (c *launcherClient) ensureModelsReady(ctx context.Context, models []string) error {
|
||||
models = dedupeModelList(models)
|
||||
if len(models) == 0 {
|
||||
|
|
|
|||
|
|
@ -197,6 +197,8 @@ func TestBuildLauncherState_ManagedSingleIntegrationUsesCurrentModel(t *testing.
|
|||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/experimental/model-recommendations":
|
||||
fmt.Fprint(w, `{"recommendations":[]}`)
|
||||
case "/api/tags":
|
||||
fmt.Fprint(w, `{"models":[{"name":"gemma4"}]}`)
|
||||
case "/api/show":
|
||||
|
|
@ -230,6 +232,8 @@ func TestBuildLauncherState_ManagedSingleIntegrationShowsSavedModelWhenLiveConfi
|
|||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/experimental/model-recommendations":
|
||||
fmt.Fprint(w, `{"recommendations":[]}`)
|
||||
case "/api/tags":
|
||||
fmt.Fprint(w, `{"models":[{"name":"gemma4"}]}`)
|
||||
case "/api/show":
|
||||
|
|
@ -269,6 +273,8 @@ func TestLaunchIntegration_ManagedSingleIntegrationConfiguresOnboardsAndRuns(t *
|
|||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/experimental/model-recommendations":
|
||||
fmt.Fprint(w, `{"recommendations":[]}`)
|
||||
case "/api/tags":
|
||||
fmt.Fprint(w, `{"models":[{"name":"gemma4"}]}`)
|
||||
case "/api/show":
|
||||
|
|
@ -326,6 +332,8 @@ func TestLaunchIntegration_ManagedSingleIntegrationReOnboardsWhenSavedFlagIsStal
|
|||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/experimental/model-recommendations":
|
||||
fmt.Fprint(w, `{"recommendations":[]}`)
|
||||
case "/api/tags":
|
||||
fmt.Fprint(w, `{"models":[{"name":"gemma4"}]}`)
|
||||
case "/api/show":
|
||||
|
|
@ -418,6 +426,8 @@ func TestLaunchIntegration_ManagedSingleIntegrationSkipsRewriteWhenSavedMatches(
|
|||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/experimental/model-recommendations":
|
||||
fmt.Fprint(w, `{"recommendations":[]}`)
|
||||
case "/api/tags":
|
||||
fmt.Fprint(w, `{"models":[{"name":"gemma4"}]}`)
|
||||
case "/api/show":
|
||||
|
|
@ -468,6 +478,8 @@ func TestLaunchIntegration_ManagedSingleIntegrationRewritesWhenSavedDiffers(t *t
|
|||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/experimental/model-recommendations":
|
||||
fmt.Fprint(w, `{"recommendations":[]}`)
|
||||
case "/api/tags":
|
||||
fmt.Fprint(w, `{"models":[{"name":"gemma4"}]}`)
|
||||
case "/api/show":
|
||||
|
|
@ -715,6 +727,8 @@ func TestBuildLauncherState_InstalledAndCloudDisabled(t *testing.T) {
|
|||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/experimental/model-recommendations":
|
||||
fmt.Fprint(w, `{"recommendations":[]}`)
|
||||
case "/api/tags":
|
||||
fmt.Fprint(w, `{"models":[{"name":"llama3.2"}]}`)
|
||||
case "/api/status":
|
||||
|
|
@ -767,6 +781,8 @@ func TestBuildLauncherState_MigratesLegacyOpenclawAliasConfig(t *testing.T) {
|
|||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/experimental/model-recommendations":
|
||||
fmt.Fprint(w, `{"recommendations":[]}`)
|
||||
case "/api/tags":
|
||||
fmt.Fprint(w, `{"models":[{"name":"llama3.2"}]}`)
|
||||
default:
|
||||
|
|
@ -812,6 +828,8 @@ func TestBuildLauncherState_ToleratesInventoryFailure(t *testing.T) {
|
|||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/experimental/model-recommendations":
|
||||
fmt.Fprint(w, `{"recommendations":[]}`)
|
||||
case "/api/tags":
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
fmt.Fprint(w, `{"error":"temporary failure"}`)
|
||||
|
|
@ -858,6 +876,8 @@ func TestResolveRunModel_UsesSavedModelWithoutSelector(t *testing.T) {
|
|||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/experimental/model-recommendations":
|
||||
fmt.Fprint(w, `{"recommendations":[]}`)
|
||||
case "/api/tags":
|
||||
fmt.Fprint(w, `{"models":[{"name":"llama3.2"}]}`)
|
||||
case "/api/show":
|
||||
|
|
@ -903,6 +923,8 @@ func TestResolveRunModel_HeadlessYesAutoPicksLastModel(t *testing.T) {
|
|||
modelPulled := false
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/experimental/model-recommendations":
|
||||
fmt.Fprint(w, `{"recommendations":[]}`)
|
||||
case "/api/tags":
|
||||
fmt.Fprint(w, `{"models":[{"name":"llama3.2"}]}`)
|
||||
case "/api/show":
|
||||
|
|
@ -965,6 +987,8 @@ func TestResolveRunModel_UsesRequestPolicy(t *testing.T) {
|
|||
modelPulled := false
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/experimental/model-recommendations":
|
||||
fmt.Fprint(w, `{"recommendations":[]}`)
|
||||
case "/api/tags":
|
||||
fmt.Fprint(w, `{"models":[{"name":"llama3.2"}]}`)
|
||||
case "/api/show":
|
||||
|
|
@ -1024,6 +1048,8 @@ func TestResolveRunModel_ForcePickerAlwaysUsesSelector(t *testing.T) {
|
|||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/experimental/model-recommendations":
|
||||
fmt.Fprint(w, `{"recommendations":[]}`)
|
||||
case "/api/tags":
|
||||
fmt.Fprint(w, `{"models":[{"name":"llama3.2"},{"name":"qwen3:8b"}]}`)
|
||||
case "/api/show":
|
||||
|
|
@ -1074,6 +1100,8 @@ func TestResolveRunModel_ForcePicker_DoesNotReorderByLastModel(t *testing.T) {
|
|||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/experimental/model-recommendations":
|
||||
fmt.Fprint(w, `{"recommendations":[]}`)
|
||||
case "/api/tags":
|
||||
fmt.Fprint(w, `{"models":[{"name":"qwen3.5"},{"name":"gemma4"}]}`)
|
||||
case "/api/show":
|
||||
|
|
@ -1124,6 +1152,8 @@ func TestResolveRunModel_UsesSignInHookForCloudModel(t *testing.T) {
|
|||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/experimental/model-recommendations":
|
||||
fmt.Fprint(w, `{"recommendations":[]}`)
|
||||
case "/api/tags":
|
||||
fmt.Fprint(w, `{"models":[]}`)
|
||||
case "/api/status":
|
||||
|
|
@ -1181,6 +1211,8 @@ func TestLaunchIntegration_EditorForceConfigure(t *testing.T) {
|
|||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/experimental/model-recommendations":
|
||||
fmt.Fprint(w, `{"recommendations":[]}`)
|
||||
case "/api/tags":
|
||||
fmt.Fprint(w, `{"models":[{"name":"llama3.2"},{"name":"qwen3:8b"}]}`)
|
||||
case "/api/show":
|
||||
|
|
@ -1250,6 +1282,8 @@ func TestLaunchIntegration_EditorForceConfigure_FloatsCheckedModelsInPicker(t *t
|
|||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/experimental/model-recommendations":
|
||||
fmt.Fprint(w, `{"recommendations":[]}`)
|
||||
case "/api/tags":
|
||||
fmt.Fprint(w, `{"models":[{"name":"qwen3.5:cloud","remote_model":"qwen3.5"},{"name":"qwen3.5"}]}`)
|
||||
case "/api/show":
|
||||
|
|
@ -1368,6 +1402,8 @@ func TestLaunchIntegration_EditorCloudDisabledFallsBackToSelector(t *testing.T)
|
|||
switch r.URL.Path {
|
||||
case "/api/status":
|
||||
fmt.Fprint(w, `{"cloud":{"disabled":true,"source":"config"}}`)
|
||||
case "/api/experimental/model-recommendations":
|
||||
fmt.Fprint(w, `{"recommendations":[]}`)
|
||||
case "/api/tags":
|
||||
fmt.Fprint(w, `{"models":[{"name":"llama3.2"}]}`)
|
||||
case "/api/show":
|
||||
|
|
@ -1415,6 +1451,8 @@ func TestLaunchIntegration_EditorConfigureMultiSkipsMissingLocalAndPersistsAccep
|
|||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/experimental/model-recommendations":
|
||||
fmt.Fprint(w, `{"recommendations":[]}`)
|
||||
case "/api/tags":
|
||||
fmt.Fprint(w, `{"models":[{"name":"glm-5:cloud","remote_model":"glm-5"}]}`)
|
||||
case "/api/status":
|
||||
|
|
@ -1497,6 +1535,8 @@ func TestLaunchIntegration_EditorConfigureMultiSkipsUnauthedCloudAndPersistsAcce
|
|||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/experimental/model-recommendations":
|
||||
fmt.Fprint(w, `{"recommendations":[]}`)
|
||||
case "/api/tags":
|
||||
fmt.Fprint(w, `{"models":[{"name":"llama3.2"},{"name":"glm-5:cloud","remote_model":"glm-5"}]}`)
|
||||
case "/api/status":
|
||||
|
|
@ -1582,6 +1622,8 @@ func TestLaunchIntegration_EditorConfigureMultiRemovesReselectedFailingModel(t *
|
|||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/experimental/model-recommendations":
|
||||
fmt.Fprint(w, `{"recommendations":[]}`)
|
||||
case "/api/tags":
|
||||
fmt.Fprint(w, `{"models":[{"name":"glm-5:cloud","remote_model":"glm-5"},{"name":"llama3.2"}]}`)
|
||||
case "/api/status":
|
||||
|
|
@ -1669,6 +1711,8 @@ func TestLaunchIntegration_EditorConfigureMultiAllFailuresKeepsExistingAndSkipsL
|
|||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/experimental/model-recommendations":
|
||||
fmt.Fprint(w, `{"recommendations":[]}`)
|
||||
case "/api/tags":
|
||||
fmt.Fprint(w, `{"models":[]}`)
|
||||
case "/api/show":
|
||||
|
|
@ -1975,6 +2019,8 @@ func TestLaunchIntegration_ConfigureOnlyDoesNotRequireInstalledBinary(t *testing
|
|||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/experimental/model-recommendations":
|
||||
fmt.Fprint(w, `{"recommendations":[]}`)
|
||||
case "/api/tags":
|
||||
fmt.Fprint(w, `{"models":[{"name":"llama3.2"}]}`)
|
||||
case "/api/show":
|
||||
|
|
@ -2018,6 +2064,8 @@ func TestLaunchIntegration_ClaudeSavesPrimaryModel(t *testing.T) {
|
|||
var aliasSyncCalled bool
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/experimental/model-recommendations":
|
||||
fmt.Fprint(w, `{"recommendations":[]}`)
|
||||
case "/api/tags":
|
||||
fmt.Fprint(w, `{"models":[]}`)
|
||||
case "/api/status":
|
||||
|
|
@ -2077,6 +2125,8 @@ func TestLaunchIntegration_ClaudeForceConfigureReprompts(t *testing.T) {
|
|||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/experimental/model-recommendations":
|
||||
fmt.Fprint(w, `{"recommendations":[]}`)
|
||||
case "/api/tags":
|
||||
fmt.Fprint(w, `{"models":[{"name":"qwen3:8b"}]}`)
|
||||
case "/api/show":
|
||||
|
|
@ -2134,6 +2184,8 @@ func TestLaunchIntegration_ClaudeForceConfigureMissingSelectionDoesNotSave(t *te
|
|||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/experimental/model-recommendations":
|
||||
fmt.Fprint(w, `{"recommendations":[]}`)
|
||||
case "/api/tags":
|
||||
fmt.Fprint(w, `{"models":[{"name":"llama3.2"}]}`)
|
||||
case "/api/show":
|
||||
|
|
@ -2260,6 +2312,8 @@ func TestLaunchIntegration_ConfigureOnlyPrompt(t *testing.T) {
|
|||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/experimental/model-recommendations":
|
||||
fmt.Fprint(w, `{"recommendations":[]}`)
|
||||
case "/api/tags":
|
||||
fmt.Fprint(w, `{"models":[{"name":"llama3.2"}]}`)
|
||||
case "/api/show":
|
||||
|
|
@ -2484,6 +2538,8 @@ func TestLaunchIntegration_HeadlessSelectorFlowFailsWithoutPrompt(t *testing.T)
|
|||
pullCalled := false
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/experimental/model-recommendations":
|
||||
fmt.Fprint(w, `{"recommendations":[]}`)
|
||||
case "/api/tags":
|
||||
fmt.Fprint(w, `{"models":[{"name":"llama3.2"}]}`)
|
||||
case "/api/show":
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ import (
|
|||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
|
|
@ -21,17 +22,12 @@ import (
|
|||
)
|
||||
|
||||
var recommendedModels = []ModelItem{
|
||||
{Name: "kimi-k2.6:cloud", Description: "State-of-the-art coding, long-horizon execution, and multimodal agent swarm capability", Recommended: true},
|
||||
{Name: "qwen3.5:cloud", Description: "Reasoning, coding, and agentic tool use with vision", Recommended: true},
|
||||
{Name: "glm-5.1:cloud", Description: "Reasoning and code generation", Recommended: true},
|
||||
{Name: "minimax-m2.7:cloud", Description: "Fast, efficient coding and real-world productivity", Recommended: true},
|
||||
{Name: "gemma4", Description: "Reasoning and code generation locally", Recommended: true},
|
||||
{Name: "qwen3.5", Description: "Reasoning, coding, and visual understanding locally", Recommended: true},
|
||||
}
|
||||
|
||||
var recommendedVRAM = map[string]string{
|
||||
"gemma4": "~16GB",
|
||||
"qwen3.5": "~11GB",
|
||||
{Name: "kimi-k2.6:cloud", Description: "State-of-the-art coding, long-horizon execution, and multimodal agent swarm capability", Recommended: true, ContextLength: 262_144, MaxOutputTokens: 262_144},
|
||||
{Name: "qwen3.5:cloud", Description: "Reasoning, coding, and agentic tool use with vision", Recommended: true, ContextLength: 262_144, MaxOutputTokens: 32_768},
|
||||
{Name: "glm-5.1:cloud", Description: "Reasoning and code generation", Recommended: true, ContextLength: 202_752, MaxOutputTokens: 131_072},
|
||||
{Name: "minimax-m2.7:cloud", Description: "Fast, efficient coding and real-world productivity", Recommended: true, ContextLength: 204_800, MaxOutputTokens: 128_000},
|
||||
{Name: "gemma4", Description: "Reasoning and code generation locally", Recommended: true, VRAM: "~16GB"},
|
||||
{Name: "qwen3.5", Description: "Reasoning, coding, and visual understanding locally", Recommended: true, VRAM: "~11GB"},
|
||||
}
|
||||
|
||||
// cloudModelLimit holds context and output token limits for a cloud model.
|
||||
|
|
@ -40,10 +36,10 @@ type cloudModelLimit struct {
|
|||
Output int
|
||||
}
|
||||
|
||||
// cloudModelLimits maps cloud model base names to their token limits.
|
||||
// extraCloudModelLimits maps cloud model base names to token limits for models
|
||||
// that are not already covered by recommendedModels fallback entries.
|
||||
// TODO(parthsareen): grab context/output limits from model info instead of hardcoding
|
||||
var cloudModelLimits = map[string]cloudModelLimit{
|
||||
"minimax-m2.7": {Context: 204_800, Output: 128_000},
|
||||
var extraCloudModelLimits = map[string]cloudModelLimit{
|
||||
"cogito-2.1:671b": {Context: 163_840, Output: 65_536},
|
||||
"deepseek-v3.1:671b": {Context: 163_840, Output: 163_840},
|
||||
"deepseek-v3.2": {Context: 163_840, Output: 65_536},
|
||||
|
|
@ -65,11 +61,24 @@ var cloudModelLimits = map[string]cloudModelLimit{
|
|||
"qwen3.5": {Context: 262_144, Output: 32_768},
|
||||
}
|
||||
|
||||
var cloudModelLimits = mergeCloudModelLimits(cloudModelLimitsFromRecommendations(recommendedModels), extraCloudModelLimits)
|
||||
|
||||
var (
|
||||
dynamicCloudModelLimitsMu sync.RWMutex
|
||||
dynamicCloudModelLimits = map[string]cloudModelLimit{}
|
||||
)
|
||||
|
||||
// lookupCloudModelLimit returns the token limits for a cloud model.
|
||||
// It normalizes explicit cloud source suffixes before checking the shared limit map.
|
||||
func lookupCloudModelLimit(name string) (cloudModelLimit, bool) {
|
||||
base, stripped := modelref.StripCloudSourceTag(name)
|
||||
if stripped {
|
||||
dynamicCloudModelLimitsMu.RLock()
|
||||
l, ok := dynamicCloudModelLimits[base]
|
||||
dynamicCloudModelLimitsMu.RUnlock()
|
||||
if ok {
|
||||
return l, true
|
||||
}
|
||||
if l, ok := cloudModelLimits[base]; ok {
|
||||
return l, true
|
||||
}
|
||||
|
|
@ -77,6 +86,49 @@ func lookupCloudModelLimit(name string) (cloudModelLimit, bool) {
|
|||
return cloudModelLimit{}, false
|
||||
}
|
||||
|
||||
func setDynamicCloudModelLimits(limits map[string]cloudModelLimit) {
|
||||
dynamicCloudModelLimitsMu.Lock()
|
||||
defer dynamicCloudModelLimitsMu.Unlock()
|
||||
if limits == nil {
|
||||
dynamicCloudModelLimits = map[string]cloudModelLimit{}
|
||||
return
|
||||
}
|
||||
cp := make(map[string]cloudModelLimit, len(limits))
|
||||
for k, v := range limits {
|
||||
cp[k] = v
|
||||
}
|
||||
dynamicCloudModelLimits = cp
|
||||
}
|
||||
|
||||
func cloudModelLimitsFromRecommendations(recommendations []ModelItem) map[string]cloudModelLimit {
|
||||
limits := make(map[string]cloudModelLimit, len(recommendations))
|
||||
for _, rec := range recommendations {
|
||||
if !isCloudModelName(rec.Name) || rec.ContextLength <= 0 || rec.MaxOutputTokens <= 0 {
|
||||
continue
|
||||
}
|
||||
base, stripped := modelref.StripCloudSourceTag(rec.Name)
|
||||
if !stripped || base == "" {
|
||||
continue
|
||||
}
|
||||
limits[base] = cloudModelLimit{
|
||||
Context: rec.ContextLength,
|
||||
Output: rec.MaxOutputTokens,
|
||||
}
|
||||
}
|
||||
return limits
|
||||
}
|
||||
|
||||
func mergeCloudModelLimits(base map[string]cloudModelLimit, overlay map[string]cloudModelLimit) map[string]cloudModelLimit {
|
||||
out := make(map[string]cloudModelLimit, len(base)+len(overlay))
|
||||
for name, limit := range base {
|
||||
out[name] = limit
|
||||
}
|
||||
for name, limit := range overlay {
|
||||
out[name] = limit
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// missingModelPolicy controls how model-not-found errors should be handled.
|
||||
type missingModelPolicy int
|
||||
|
||||
|
|
@ -276,13 +328,17 @@ func confirmConfigEdit(runner Runner, paths []string) (bool, error) {
|
|||
|
||||
// buildModelList merges existing models with recommendations for selection UIs.
|
||||
func buildModelList(existing []modelInfo, preChecked []string, current string) (items []ModelItem, orderedChecked []string, existingModels, cloudModels map[string]bool) {
|
||||
return buildModelListWithRecommendations(existing, recommendedModels, preChecked, current)
|
||||
}
|
||||
|
||||
func buildModelListWithRecommendations(existing []modelInfo, recommendations []ModelItem, preChecked []string, current string) (items []ModelItem, orderedChecked []string, existingModels, cloudModels map[string]bool) {
|
||||
existingModels = make(map[string]bool)
|
||||
cloudModels = make(map[string]bool)
|
||||
recommended := make(map[string]bool)
|
||||
var hasLocalModel, hasCloudModel bool
|
||||
|
||||
recDesc := make(map[string]string)
|
||||
for _, rec := range recommendedModels {
|
||||
for _, rec := range recommendations {
|
||||
recommended[rec.Name] = true
|
||||
recDesc[rec.Name] = rec.Description
|
||||
}
|
||||
|
|
@ -301,7 +357,7 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
|
|||
items = append(items, item)
|
||||
}
|
||||
|
||||
for _, rec := range recommendedModels {
|
||||
for _, rec := range recommendations {
|
||||
if existingModels[rec.Name] || existingModels[rec.Name+":latest"] {
|
||||
continue
|
||||
}
|
||||
|
|
@ -347,8 +403,8 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
|
|||
if items[i].Description != "" {
|
||||
parts = append(parts, items[i].Description)
|
||||
}
|
||||
if vram := recommendedVRAM[items[i].Name]; vram != "" {
|
||||
parts = append(parts, vram)
|
||||
if items[i].VRAM != "" {
|
||||
parts = append(parts, items[i].VRAM)
|
||||
}
|
||||
parts = append(parts, "(not downloaded)")
|
||||
items[i].Description = strings.Join(parts, ", ")
|
||||
|
|
@ -356,12 +412,12 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
|
|||
}
|
||||
|
||||
recRank := make(map[string]int)
|
||||
for i, rec := range recommendedModels {
|
||||
for i, rec := range recommendations {
|
||||
recRank[rec.Name] = i + 1
|
||||
}
|
||||
|
||||
if hasLocalModel || hasCloudModel {
|
||||
// Keep the Recommended section pinned to recommendedModels order. Checked
|
||||
// Keep the Recommended section pinned to recommendation order. Checked
|
||||
// and default-model priority only apply within the More section.
|
||||
slices.SortStableFunc(items, func(a, b ModelItem) int {
|
||||
ac, bc := checked[a.Name], checked[b.Name]
|
||||
|
|
|
|||
401
server/model_recommendations.go
Normal file
401
server/model_recommendations.go
Normal file
|
|
@ -0,0 +1,401 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"math/rand/v2"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// modelRecommendationsURL is the upstream endpoint recommendations are fetched from.
const modelRecommendationsURL = "https://ollama.com/api/experimental/model-recommendations"

var (
	// Periodic refresh cadence for the cache (presumably used by the background
	// loop started in Start — confirm against run, which is outside this view).
	modelRecommendationsRefreshInterval = 4 * time.Hour
	// Per-fetch timeout for a single upstream request.
	modelRecommendationsFetchTimeout = 3 * time.Second
	// Minimum gap between read-triggered refreshes; enforced via
	// nextReadRefreshAfter in endReadRefresh/beginReadRefresh.
	modelRecommendationsRefreshCooldown = modelRecommendationsReadRefreshCooldown
)

var (
	modelRecommendationsReadRefreshCooldown = 5 * time.Second
	// Escalating delays, presumably applied between failed fetch attempts —
	// the consuming retry loop is outside this view; confirm there.
	modelRecommendationsBackoffSteps = []time.Duration{
		5 * time.Minute,
		15 * time.Minute,
		time.Hour,
		4 * time.Hour,
	}

	// errModelRecommendationsNoCloud signals a refresh was skipped because
	// cloud access is disabled.
	errModelRecommendationsNoCloud = errors.New("cloud disabled")
)
|
||||
|
||||
// modelRecommendationsCache holds the current set of model recommendations,
// seeded from built-in defaults and refreshed from the upstream endpoint.
type modelRecommendationsCache struct {
	// mu guards recommendations, refreshing, and nextReadRefreshAfter.
	mu sync.RWMutex
	// recommendations is the cached list; always cloned on read and write.
	recommendations []api.ModelRecommendation
	// refreshing reports whether a refresh is currently in flight.
	refreshing bool
	// nextReadRefreshAfter rate-limits read-triggered refreshes.
	nextReadRefreshAfter time.Time
	// once ensures the background loop is started at most one time.
	once sync.Once
	// client is the HTTP client used for upstream fetches.
	client *http.Client
}
|
||||
|
||||
func newModelRecommendationsCache() *modelRecommendationsCache {
|
||||
return &modelRecommendationsCache{
|
||||
recommendations: cloneModelRecommendations(defaultModelRecommendations),
|
||||
client: http.DefaultClient,
|
||||
}
|
||||
}
|
||||
|
||||
// Start launches the background refresh loop exactly once; repeated calls are
// no-ops thanks to c.once. The loop (c.run) receives ctx for cancellation.
func (c *modelRecommendationsCache) Start(ctx context.Context) {
	c.once.Do(func() {
		slog.Debug("starting model recommendations cache",
			"default_recommendations", len(defaultModelRecommendations),
			"refresh_interval", modelRecommendationsRefreshInterval.String(),
			"fetch_timeout", modelRecommendationsFetchTimeout.String(),
		)
		go c.run(ctx)
	})
}
|
||||
|
||||
func (c *modelRecommendationsCache) Get() []api.ModelRecommendation {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return cloneModelRecommendations(c.recommendations)
|
||||
}
|
||||
|
||||
func (c *modelRecommendationsCache) GetSWR(ctx context.Context) []api.ModelRecommendation {
|
||||
recs := c.Get()
|
||||
c.triggerRefreshOnRead(ctx)
|
||||
return recs
|
||||
}
|
||||
|
||||
func (c *modelRecommendationsCache) set(recs []api.ModelRecommendation) {
|
||||
c.mu.Lock()
|
||||
c.recommendations = cloneModelRecommendations(recs)
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
func (c *modelRecommendationsCache) beginRefresh() bool {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.refreshing {
|
||||
return false
|
||||
}
|
||||
c.refreshing = true
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *modelRecommendationsCache) beginReadRefresh() bool {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
now := time.Now()
|
||||
if c.refreshing || now.Before(c.nextReadRefreshAfter) {
|
||||
return false
|
||||
}
|
||||
|
||||
c.refreshing = true
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *modelRecommendationsCache) endRefresh() {
|
||||
c.mu.Lock()
|
||||
c.refreshing = false
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
func (c *modelRecommendationsCache) endReadRefresh() {
|
||||
c.mu.Lock()
|
||||
c.refreshing = false
|
||||
c.nextReadRefreshAfter = time.Now().Add(modelRecommendationsReadRefreshCooldown)
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
func (c *modelRecommendationsCache) refreshIfIdle(ctx context.Context) (bool, error) {
|
||||
if !c.beginRefresh() {
|
||||
return false, nil
|
||||
}
|
||||
defer c.endRefresh()
|
||||
return true, c.refresh(ctx)
|
||||
}
|
||||
|
||||
func (c *modelRecommendationsCache) triggerRefreshOnRead(ctx context.Context) {
|
||||
if !c.beginReadRefresh() {
|
||||
return
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
ctx = context.WithoutCancel(ctx)
|
||||
|
||||
slog.Debug("triggering model recommendations refresh on read")
|
||||
go func() {
|
||||
defer c.endReadRefresh()
|
||||
|
||||
if err := c.refresh(ctx); err != nil {
|
||||
switch {
|
||||
case errors.Is(err, errModelRecommendationsNoCloud):
|
||||
slog.Debug("skipping model recommendations read refresh because cloud is disabled")
|
||||
default:
|
||||
slog.Warn("model recommendations read refresh failed", "error", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (c *modelRecommendationsCache) run(ctx context.Context) {
|
||||
c.loadSnapshot()
|
||||
|
||||
failures := 0
|
||||
for {
|
||||
started, err := c.refreshIfIdle(ctx)
|
||||
switch {
|
||||
case !started:
|
||||
failures = 0
|
||||
slog.Debug("skipping timer model recommendations refresh because refresh is already running")
|
||||
case err == nil:
|
||||
failures = 0
|
||||
case errors.Is(err, errModelRecommendationsNoCloud):
|
||||
failures = 0
|
||||
slog.Debug("skipping model recommendations refresh because cloud is disabled")
|
||||
default:
|
||||
failures++
|
||||
slog.Warn("model recommendations refresh failed", "error", err)
|
||||
}
|
||||
|
||||
var wait time.Duration
|
||||
if failures == 0 {
|
||||
wait = withJitter(modelRecommendationsRefreshInterval)
|
||||
} else {
|
||||
wait = withJitter(modelRecommendationsBackoffSteps[min(failures-1, len(modelRecommendationsBackoffSteps)-1)])
|
||||
}
|
||||
slog.Info("model recommendations cache sleep scheduled", "wait", wait.String(), "consecutive_failures", failures)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
slog.Debug("stopping model recommendations cache")
|
||||
return
|
||||
case <-time.After(wait):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *modelRecommendationsCache) refresh(ctx context.Context) error {
|
||||
if envconfig.NoCloud() {
|
||||
return errModelRecommendationsNoCloud
|
||||
}
|
||||
slog.Debug("refreshing model recommendations from remote", "url", modelRecommendationsURL)
|
||||
|
||||
reqCtx, cancel := context.WithTimeout(ctx, modelRecommendationsFetchTimeout)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, modelRecommendationsURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= http.StatusBadRequest {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
|
||||
return fmt.Errorf("status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
|
||||
var payload api.ModelRecommendationsResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
recs, err := validateModelRecommendations(payload.Recommendations)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.set(recs)
|
||||
slog.Debug("model recommendations refreshed", "count", len(recs))
|
||||
if err := c.persistSnapshot(recs); err != nil {
|
||||
slog.Warn("failed to persist model recommendations snapshot", "error", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *modelRecommendationsCache) loadSnapshot() {
|
||||
path, err := modelRecommendationsSnapshotPath()
|
||||
if err != nil {
|
||||
slog.Warn("failed to resolve model recommendations snapshot path", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
slog.Warn("failed to read model recommendations snapshot", "path", path, "error", err)
|
||||
} else {
|
||||
slog.Debug("model recommendations snapshot not found", "path", path)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var snap api.ModelRecommendationsResponse
|
||||
if err := json.Unmarshal(data, &snap); err != nil {
|
||||
slog.Warn("failed to parse model recommendations snapshot", "path", path, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
recs, err := validateModelRecommendations(snap.Recommendations)
|
||||
if err != nil {
|
||||
slog.Warn("ignoring invalid model recommendations snapshot", "path", path, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
c.set(recs)
|
||||
slog.Debug("loaded model recommendations snapshot", "path", path, "count", len(recs))
|
||||
}
|
||||
|
||||
func (c *modelRecommendationsCache) persistSnapshot(recs []api.ModelRecommendation) error {
|
||||
path, err := modelRecommendationsSnapshotPath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
payload := api.ModelRecommendationsResponse{Recommendations: recs}
|
||||
data, err := json.MarshalIndent(payload, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tmp, err := os.CreateTemp(filepath.Dir(path), ".model-recommendations-*.tmp")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tmpPath := tmp.Name()
|
||||
defer os.Remove(tmpPath)
|
||||
|
||||
if _, err := tmp.Write(data); err != nil {
|
||||
_ = tmp.Close()
|
||||
return err
|
||||
}
|
||||
if err := tmp.Sync(); err != nil {
|
||||
_ = tmp.Close()
|
||||
return err
|
||||
}
|
||||
if err := tmp.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := os.Rename(tmpPath, path); err != nil {
|
||||
return err
|
||||
}
|
||||
slog.Debug("persisted model recommendations snapshot", "path", path, "count", len(recs))
|
||||
return nil
|
||||
}
|
||||
|
||||
func modelRecommendationsSnapshotPath() (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(home, ".ollama", "cache", "model-recommendations.json"), nil
|
||||
}
|
||||
|
||||
func validateModelRecommendations(recs []api.ModelRecommendation) ([]api.ModelRecommendation, error) {
|
||||
if len(recs) == 0 {
|
||||
return nil, errors.New("empty recommendations")
|
||||
}
|
||||
|
||||
seen := make(map[string]struct{}, len(recs))
|
||||
valid := make([]api.ModelRecommendation, 0, len(recs))
|
||||
for _, rec := range recs {
|
||||
rec.Model = strings.TrimSpace(rec.Model)
|
||||
rec.Description = strings.TrimSpace(rec.Description)
|
||||
rec.VRAM = strings.TrimSpace(rec.VRAM)
|
||||
|
||||
if rec.Model == "" {
|
||||
return nil, errors.New("recommendation missing model")
|
||||
}
|
||||
if _, ok := seen[rec.Model]; ok {
|
||||
return nil, fmt.Errorf("duplicate recommendation %q", rec.Model)
|
||||
}
|
||||
seen[rec.Model] = struct{}{}
|
||||
|
||||
if isCloudRecommendation(rec.Model) && (rec.ContextLength <= 0 || rec.MaxOutputTokens <= 0) {
|
||||
slog.Warn("dropping cloud recommendation missing limits", "model", rec.Model)
|
||||
continue
|
||||
}
|
||||
valid = append(valid, rec)
|
||||
}
|
||||
|
||||
if len(valid) == 0 {
|
||||
return nil, errors.New("no valid recommendations")
|
||||
}
|
||||
|
||||
return valid, nil
|
||||
}
|
||||
|
||||
func isCloudRecommendation(modelName string) bool {
|
||||
return strings.HasSuffix(modelName, ":cloud") || strings.HasSuffix(modelName, "-cloud")
|
||||
}
|
||||
|
||||
func withJitter(d time.Duration) time.Duration {
|
||||
if d <= 0 {
|
||||
return d
|
||||
}
|
||||
// jitter in range [0.8x, 1.2x]
|
||||
factor := 0.8 + rand.Float64()*0.4
|
||||
return time.Duration(float64(d) * factor)
|
||||
}
|
||||
|
||||
func cloneModelRecommendations(in []api.ModelRecommendation) []api.ModelRecommendation {
|
||||
out := make([]api.ModelRecommendation, len(in))
|
||||
copy(out, in)
|
||||
return out
|
||||
}
|
||||
|
||||
var defaultModelRecommendations = []api.ModelRecommendation{
|
||||
{
|
||||
Model: "kimi-k2.6:cloud",
|
||||
Description: "State-of-the-art coding, long-horizon execution, and multimodal agent swarm capability",
|
||||
ContextLength: 262_144,
|
||||
MaxOutputTokens: 262_144,
|
||||
},
|
||||
{
|
||||
Model: "glm-5.1:cloud",
|
||||
Description: "Reasoning and code generation",
|
||||
ContextLength: 202_752,
|
||||
MaxOutputTokens: 131_072,
|
||||
},
|
||||
{
|
||||
Model: "qwen3.5:cloud",
|
||||
Description: "Reasoning, coding, and agentic tool use with vision",
|
||||
ContextLength: 262_144,
|
||||
MaxOutputTokens: 32_768,
|
||||
},
|
||||
{
|
||||
Model: "minimax-m2.7:cloud",
|
||||
Description: "Fast, efficient coding and real-world productivity",
|
||||
ContextLength: 204_800,
|
||||
MaxOutputTokens: 128_000,
|
||||
},
|
||||
{
|
||||
Model: "gemma4",
|
||||
Description: "Reasoning and code generation locally",
|
||||
VRAM: "~16GB",
|
||||
},
|
||||
{
|
||||
Model: "qwen3.5",
|
||||
Description: "Reasoning, coding, and visual understanding locally",
|
||||
VRAM: "~11GB",
|
||||
},
|
||||
}
|
||||
586
server/model_recommendations_test.go
Normal file
586
server/model_recommendations_test.go
Normal file
|
|
@ -0,0 +1,586 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
func TestModelRecommendationsDefaultOrder(t *testing.T) {
|
||||
want := []string{
|
||||
"kimi-k2.6:cloud",
|
||||
"glm-5.1:cloud",
|
||||
"qwen3.5:cloud",
|
||||
"minimax-m2.7:cloud",
|
||||
"gemma4",
|
||||
"qwen3.5",
|
||||
}
|
||||
|
||||
if got := modelRecommendationNames(defaultModelRecommendations); !slices.Equal(got, want) {
|
||||
t.Fatalf("recommendations = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelRecommendationsCacheRefreshAppliesServerSideChanges(t *testing.T) {
|
||||
setupModelRecommendationsTestEnv(t, "")
|
||||
|
||||
first := []api.ModelRecommendation{
|
||||
{Model: " first-cloud:cloud ", Description: " first ", ContextLength: 2048, MaxOutputTokens: 512},
|
||||
{Model: " first-local ", Description: " first local ", VRAM: " ~3GB "},
|
||||
}
|
||||
second := []api.ModelRecommendation{
|
||||
{Model: "second-cloud:cloud", Description: "second", ContextLength: 4096, MaxOutputTokens: 1024},
|
||||
{Model: "second-local", Description: "second local", VRAM: "~6GB"},
|
||||
}
|
||||
|
||||
calls := 0
|
||||
cache := newModelRecommendationsCache()
|
||||
cache.client = &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
if req.Method != http.MethodGet {
|
||||
t.Fatalf("method = %q, want GET", req.Method)
|
||||
}
|
||||
if req.URL.String() != modelRecommendationsURL {
|
||||
t.Fatalf("url = %q, want %q", req.URL.String(), modelRecommendationsURL)
|
||||
}
|
||||
|
||||
calls++
|
||||
payload := api.ModelRecommendationsResponse{Recommendations: first}
|
||||
if calls > 1 {
|
||||
payload.Recommendations = second
|
||||
}
|
||||
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal payload failed: %v", err)
|
||||
}
|
||||
return jsonHTTPResponse(http.StatusOK, string(data)), nil
|
||||
})}
|
||||
|
||||
if err := cache.refresh(context.Background()); err != nil {
|
||||
t.Fatalf("first refresh failed: %v", err)
|
||||
}
|
||||
if got, want := cache.Get(), []api.ModelRecommendation{
|
||||
{Model: "first-cloud:cloud", Description: "first", ContextLength: 2048, MaxOutputTokens: 512},
|
||||
{Model: "first-local", Description: "first local", VRAM: "~3GB"},
|
||||
}; !slices.Equal(got, want) {
|
||||
t.Fatalf("after first refresh recommendations = %#v, want %#v", got, want)
|
||||
}
|
||||
|
||||
if err := cache.refresh(context.Background()); err != nil {
|
||||
t.Fatalf("second refresh failed: %v", err)
|
||||
}
|
||||
if got, want := cache.Get(), second; !slices.Equal(got, want) {
|
||||
t.Fatalf("after second refresh recommendations = %#v, want %#v", got, want)
|
||||
}
|
||||
|
||||
path, err := modelRecommendationsSnapshotPath()
|
||||
if err != nil {
|
||||
t.Fatalf("snapshot path failed: %v", err)
|
||||
}
|
||||
snapshotData, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("read snapshot failed: %v", err)
|
||||
}
|
||||
var snapshot api.ModelRecommendationsResponse
|
||||
if err := json.Unmarshal(snapshotData, &snapshot); err != nil {
|
||||
t.Fatalf("unmarshal snapshot failed: %v", err)
|
||||
}
|
||||
if !slices.Equal(snapshot.Recommendations, second) {
|
||||
t.Fatalf("snapshot recommendations = %#v, want %#v", snapshot.Recommendations, second)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelRecommendationsCacheRefreshErrorCasesPreserveCurrentData(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
transport roundTripFunc
|
||||
errSubstr string
|
||||
}{
|
||||
{
|
||||
name: "transport error",
|
||||
transport: func(*http.Request) (*http.Response, error) {
|
||||
return nil, errors.New("network down")
|
||||
},
|
||||
errSubstr: "network down",
|
||||
},
|
||||
{
|
||||
name: "remote status error",
|
||||
transport: func(*http.Request) (*http.Response, error) {
|
||||
return jsonHTTPResponse(http.StatusInternalServerError, "upstream broken"), nil
|
||||
},
|
||||
errSubstr: "status 500: upstream broken",
|
||||
},
|
||||
{
|
||||
name: "invalid json payload",
|
||||
transport: func(*http.Request) (*http.Response, error) {
|
||||
return jsonHTTPResponse(http.StatusOK, "{"), nil
|
||||
},
|
||||
errSubstr: "unexpected EOF",
|
||||
},
|
||||
{
|
||||
name: "duplicate recommendations",
|
||||
transport: func(*http.Request) (*http.Response, error) {
|
||||
return jsonHTTPResponse(http.StatusOK, `{"recommendations":[{"model":"dup","description":"a"},{"model":"dup","description":"b"}]}`), nil
|
||||
},
|
||||
errSubstr: `duplicate recommendation "dup"`,
|
||||
},
|
||||
{
|
||||
name: "empty recommendations",
|
||||
transport: func(*http.Request) (*http.Response, error) {
|
||||
return jsonHTTPResponse(http.StatusOK, `{"recommendations":[]}`), nil
|
||||
},
|
||||
errSubstr: "empty recommendations",
|
||||
},
|
||||
{
|
||||
name: "only invalid cloud recommendations",
|
||||
transport: func(*http.Request) (*http.Response, error) {
|
||||
return jsonHTTPResponse(http.StatusOK, `{"recommendations":[{"model":"bad:cloud","description":"missing limits"}]}`), nil
|
||||
},
|
||||
errSubstr: "no valid recommendations",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
setupModelRecommendationsTestEnv(t, "")
|
||||
|
||||
cache := newModelRecommendationsCache()
|
||||
stable := []api.ModelRecommendation{{Model: "stable-local", Description: "stable desc", VRAM: "~2GB"}}
|
||||
cache.set(stable)
|
||||
cache.client = &http.Client{Transport: tc.transport}
|
||||
|
||||
err := cache.refresh(context.Background())
|
||||
if err == nil {
|
||||
t.Fatalf("refresh returned nil error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), tc.errSubstr) {
|
||||
t.Fatalf("error = %q, want substring %q", err.Error(), tc.errSubstr)
|
||||
}
|
||||
|
||||
if got := cache.Get(); !slices.Equal(got, stable) {
|
||||
t.Fatalf("recommendations changed on error: got %#v, want %#v", got, stable)
|
||||
}
|
||||
|
||||
path, pathErr := modelRecommendationsSnapshotPath()
|
||||
if pathErr != nil {
|
||||
t.Fatalf("snapshot path failed: %v", pathErr)
|
||||
}
|
||||
if _, statErr := os.Stat(path); !errors.Is(statErr, os.ErrNotExist) {
|
||||
t.Fatalf("snapshot file should not be written on error, stat err = %v", statErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelRecommendationsCacheRefreshNoCloudShortCircuits(t *testing.T) {
|
||||
setupModelRecommendationsTestEnv(t, "1")
|
||||
|
||||
called := false
|
||||
cache := newModelRecommendationsCache()
|
||||
cache.client = &http.Client{Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
|
||||
called = true
|
||||
return jsonHTTPResponse(http.StatusOK, `{"recommendations":[{"model":"should-not-be-used","description":"n/a"}]}`), nil
|
||||
})}
|
||||
|
||||
err := cache.refresh(context.Background())
|
||||
if !errors.Is(err, errModelRecommendationsNoCloud) {
|
||||
t.Fatalf("refresh error = %v, want %v", err, errModelRecommendationsNoCloud)
|
||||
}
|
||||
if called {
|
||||
t.Fatalf("remote endpoint should not be called when cloud is disabled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelRecommendationsSnapshotPersistAndLoad(t *testing.T) {
|
||||
setupModelRecommendationsTestEnv(t, "")
|
||||
|
||||
want := []api.ModelRecommendation{
|
||||
{Model: "persist-cloud:cloud", Description: "persisted", ContextLength: 8192, MaxOutputTokens: 2048},
|
||||
{Model: "persist-local", Description: "persisted local", VRAM: "~5GB"},
|
||||
}
|
||||
|
||||
writer := newModelRecommendationsCache()
|
||||
if err := writer.persistSnapshot(want); err != nil {
|
||||
t.Fatalf("persistSnapshot failed: %v", err)
|
||||
}
|
||||
|
||||
loader := newModelRecommendationsCache()
|
||||
loader.set([]api.ModelRecommendation{{Model: "old", Description: "old"}})
|
||||
loader.loadSnapshot()
|
||||
|
||||
if got := loader.Get(); !slices.Equal(got, want) {
|
||||
t.Fatalf("loaded recommendations = %#v, want %#v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelRecommendationsLoadSnapshotInvalidDoesNotOverwrite(t *testing.T) {
|
||||
setupModelRecommendationsTestEnv(t, "")
|
||||
|
||||
path, err := modelRecommendationsSnapshotPath()
|
||||
if err != nil {
|
||||
t.Fatalf("snapshot path failed: %v", err)
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
||||
t.Fatalf("mkdir failed: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(path, []byte("{invalid"), 0o644); err != nil {
|
||||
t.Fatalf("write invalid snapshot failed: %v", err)
|
||||
}
|
||||
|
||||
cache := newModelRecommendationsCache()
|
||||
existing := []api.ModelRecommendation{{Model: "existing", Description: "existing description"}}
|
||||
cache.set(existing)
|
||||
cache.loadSnapshot()
|
||||
|
||||
if got := cache.Get(); !slices.Equal(got, existing) {
|
||||
t.Fatalf("recommendations overwritten by invalid snapshot: got %#v, want %#v", got, existing)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateModelRecommendationsTrimsAndDropsInvalidCloudEntries(t *testing.T) {
|
||||
input := []api.ModelRecommendation{
|
||||
{Model: " good-cloud:cloud ", Description: " good cloud ", ContextLength: 1024, MaxOutputTokens: 256},
|
||||
{Model: "bad-cloud:cloud", Description: "missing limits"},
|
||||
{Model: " good-local ", Description: " good local ", VRAM: " ~2GB "},
|
||||
}
|
||||
|
||||
got, err := validateModelRecommendations(input)
|
||||
if err != nil {
|
||||
t.Fatalf("validateModelRecommendations failed: %v", err)
|
||||
}
|
||||
|
||||
want := []api.ModelRecommendation{
|
||||
{Model: "good-cloud:cloud", Description: "good cloud", ContextLength: 1024, MaxOutputTokens: 256},
|
||||
{Model: "good-local", Description: "good local", VRAM: "~2GB"},
|
||||
}
|
||||
if !slices.Equal(got, want) {
|
||||
t.Fatalf("validated recommendations = %#v, want %#v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelRecommendationsHandlerReturnsDefaults(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(w)
|
||||
ctx.Request = httptest.NewRequest(http.MethodGet, "/api/experimental/model-recommendations", nil)
|
||||
|
||||
s := &Server{}
|
||||
s.ModelRecommendationsExperimentalHandler(ctx)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
got := decodeRecommendationNames(t, w)
|
||||
want := modelRecommendationNames(defaultModelRecommendations)
|
||||
if !slices.Equal(got, want) {
|
||||
t.Fatalf("models = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelRecommendationsHandlerUsesCache(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setupModelRecommendationsTestEnv(t, "1")
|
||||
|
||||
cache := newModelRecommendationsCache()
|
||||
cache.set([]api.ModelRecommendation{{Model: "test-model", Description: "test description"}})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(w)
|
||||
ctx.Request = httptest.NewRequest(http.MethodGet, "/api/experimental/model-recommendations", nil)
|
||||
|
||||
s := &Server{modelRecommendations: cache}
|
||||
s.ModelRecommendationsExperimentalHandler(ctx)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
got := decodeRecommendationNames(t, w)
|
||||
if !slices.Equal(got, []string{"test-model"}) {
|
||||
t.Fatalf("models = %v, want %v", got, []string{"test-model"})
|
||||
}
|
||||
waitForCacheIdle(t, cache)
|
||||
}
|
||||
|
||||
func TestModelRecommendationsRouteRegistration(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setupModelRecommendationsTestEnv(t, "1")
|
||||
|
||||
cache := newModelRecommendationsCache()
|
||||
cache.set([]api.ModelRecommendation{{Model: "route-model", Description: "route description"}})
|
||||
s := &Server{modelRecommendations: cache}
|
||||
|
||||
router, err := s.GenerateRoutes(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateRoutes failed: %v", err)
|
||||
}
|
||||
|
||||
getReq := httptest.NewRequest(http.MethodGet, "/api/experimental/model-recommendations", nil)
|
||||
getResp := httptest.NewRecorder()
|
||||
router.ServeHTTP(getResp, getReq)
|
||||
if getResp.Code != http.StatusOK {
|
||||
t.Fatalf("GET status = %d, want %d", getResp.Code, http.StatusOK)
|
||||
}
|
||||
if got := decodeRecommendationNames(t, getResp); !slices.Equal(got, []string{"route-model"}) {
|
||||
t.Fatalf("GET models = %v, want %v", got, []string{"route-model"})
|
||||
}
|
||||
|
||||
postReq := httptest.NewRequest(http.MethodPost, "/api/experimental/model-recommendations", nil)
|
||||
postResp := httptest.NewRecorder()
|
||||
router.ServeHTTP(postResp, postReq)
|
||||
if postResp.Code != http.StatusMethodNotAllowed {
|
||||
t.Fatalf("POST status = %d, want %d", postResp.Code, http.StatusMethodNotAllowed)
|
||||
}
|
||||
waitForCacheIdle(t, cache)
|
||||
}
|
||||
|
||||
func TestModelRecommendationsGetSWRTriggersRefreshOnRead(t *testing.T) {
|
||||
setupModelRecommendationsTestEnv(t, "")
|
||||
|
||||
cache := newModelRecommendationsCache()
|
||||
old := []api.ModelRecommendation{{Model: "old", Description: "old"}}
|
||||
newRecs := []api.ModelRecommendation{{Model: "new-cloud:cloud", Description: "new", ContextLength: 1024, MaxOutputTokens: 256}}
|
||||
cache.set(old)
|
||||
|
||||
refreshDone := make(chan struct{})
|
||||
cache.client = &http.Client{Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
|
||||
defer close(refreshDone)
|
||||
return jsonHTTPResponse(http.StatusOK, `{"recommendations":[{"model":"new-cloud:cloud","description":"new","context_length":1024,"max_output_tokens":256}]}`), nil
|
||||
})}
|
||||
|
||||
gotImmediate := cache.GetSWR(context.Background())
|
||||
if !slices.Equal(gotImmediate, old) {
|
||||
t.Fatalf("GetSWR should return current cache immediately: got %#v, want %#v", gotImmediate, old)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-refreshDone:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for async refresh")
|
||||
}
|
||||
|
||||
waitForCondition(t, 2*time.Second, func() bool {
|
||||
return slices.Equal(cache.Get(), newRecs)
|
||||
})
|
||||
waitForCacheIdle(t, cache)
|
||||
}
|
||||
|
||||
func TestModelRecommendationsGetSWRSkipsWhenRefreshAlreadyInFlight(t *testing.T) {
|
||||
setupModelRecommendationsTestEnv(t, "")
|
||||
|
||||
cache := newModelRecommendationsCache()
|
||||
cache.set([]api.ModelRecommendation{{Model: "old", Description: "old"}})
|
||||
|
||||
started := make(chan struct{})
|
||||
release := make(chan struct{})
|
||||
var calls atomic.Int32
|
||||
|
||||
cache.client = &http.Client{Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
|
||||
n := calls.Add(1)
|
||||
if n == 1 {
|
||||
close(started)
|
||||
}
|
||||
<-release
|
||||
return jsonHTTPResponse(http.StatusOK, `{"recommendations":[{"model":"updated","description":"ok"}]}`), nil
|
||||
})}
|
||||
|
||||
cache.GetSWR(context.Background())
|
||||
select {
|
||||
case <-started:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for first refresh call")
|
||||
}
|
||||
|
||||
for range 5 {
|
||||
cache.GetSWR(context.Background())
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
if got := calls.Load(); got != 1 {
|
||||
t.Fatalf("calls during in-flight refresh = %d, want 1", got)
|
||||
}
|
||||
|
||||
close(release)
|
||||
waitForCacheIdle(t, cache)
|
||||
}
|
||||
|
||||
func TestModelRecommendationsGetSWRThrottlesRefreshAfterCompletion(t *testing.T) {
|
||||
setupModelRecommendationsTestEnv(t, "")
|
||||
withModelRecommendationsReadRefreshCooldown(t, 100*time.Millisecond)
|
||||
|
||||
cache := newModelRecommendationsCache()
|
||||
cache.set([]api.ModelRecommendation{{Model: "old", Description: "old"}})
|
||||
|
||||
started := make(chan struct{})
|
||||
release := make(chan struct{})
|
||||
var calls atomic.Int32
|
||||
cache.client = &http.Client{Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
|
||||
if calls.Add(1) == 1 {
|
||||
close(started)
|
||||
<-release
|
||||
}
|
||||
return jsonHTTPResponse(http.StatusOK, `{"recommendations":[{"model":"updated","description":"ok"}]}`), nil
|
||||
})}
|
||||
|
||||
cache.GetSWR(context.Background())
|
||||
select {
|
||||
case <-started:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for first refresh call")
|
||||
}
|
||||
|
||||
time.Sleep(2 * modelRecommendationsReadRefreshCooldown)
|
||||
close(release)
|
||||
waitForCacheIdle(t, cache)
|
||||
|
||||
cache.GetSWR(context.Background())
|
||||
time.Sleep(25 * time.Millisecond)
|
||||
if got := calls.Load(); got != 1 {
|
||||
t.Fatalf("calls during read refresh cooldown = %d, want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelRecommendationsGetSWRRetriesAfterReadRefreshCooldown(t *testing.T) {
|
||||
setupModelRecommendationsTestEnv(t, "")
|
||||
withModelRecommendationsReadRefreshCooldown(t, 100*time.Millisecond)
|
||||
|
||||
cache := newModelRecommendationsCache()
|
||||
old := []api.ModelRecommendation{{Model: "old", Description: "old"}}
|
||||
cache.set(old)
|
||||
|
||||
var calls atomic.Int32
|
||||
cache.client = &http.Client{Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
|
||||
if calls.Add(1) == 1 {
|
||||
return nil, errors.New("temporary upstream failure")
|
||||
}
|
||||
return jsonHTTPResponse(http.StatusOK, `{"recommendations":[{"model":"recovered","description":"ok"}]}`), nil
|
||||
})}
|
||||
|
||||
cache.GetSWR(context.Background())
|
||||
waitForCondition(t, 2*time.Second, func() bool { return calls.Load() >= 1 })
|
||||
waitForCacheIdle(t, cache)
|
||||
|
||||
if !slices.Equal(cache.Get(), old) {
|
||||
t.Fatalf("cache should remain unchanged after failed refresh, got %#v", cache.Get())
|
||||
}
|
||||
|
||||
cache.GetSWR(context.Background())
|
||||
time.Sleep(25 * time.Millisecond)
|
||||
if got := calls.Load(); got != 1 {
|
||||
t.Fatalf("calls during read refresh cooldown after failure = %d, want 1", got)
|
||||
}
|
||||
|
||||
waitForCondition(t, 2*time.Second, func() bool {
|
||||
cache.GetSWR(context.Background())
|
||||
return calls.Load() >= 2
|
||||
})
|
||||
waitForCondition(t, 2*time.Second, func() bool {
|
||||
return slices.Equal(cache.Get(), []api.ModelRecommendation{{Model: "recovered", Description: "ok"}})
|
||||
})
|
||||
waitForCacheIdle(t, cache)
|
||||
}
|
||||
|
||||
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return f(req)
|
||||
}
|
||||
|
||||
func jsonHTTPResponse(statusCode int, body string) *http.Response {
|
||||
return &http.Response{
|
||||
StatusCode: statusCode,
|
||||
Header: make(http.Header),
|
||||
Body: io.NopCloser(strings.NewReader(body)),
|
||||
}
|
||||
}
|
||||
|
||||
func setupModelRecommendationsTestEnv(t *testing.T, noCloudEnv string) {
|
||||
t.Helper()
|
||||
home := t.TempDir()
|
||||
t.Setenv("HOME", home)
|
||||
t.Setenv("USERPROFILE", home)
|
||||
t.Setenv("HOMEDRIVE", filepath.VolumeName(home))
|
||||
t.Setenv("HOMEPATH", strings.TrimPrefix(home, filepath.VolumeName(home)))
|
||||
|
||||
// Use explicit false rather than empty to avoid platform/env ambiguity.
|
||||
if noCloudEnv == "" {
|
||||
noCloudEnv = "false"
|
||||
}
|
||||
t.Setenv("OLLAMA_NO_CLOUD", noCloudEnv)
|
||||
envconfig.ReloadServerConfig()
|
||||
t.Cleanup(envconfig.ReloadServerConfig)
|
||||
}
|
||||
|
||||
func withModelRecommendationsReadRefreshCooldown(t *testing.T, d time.Duration) {
|
||||
t.Helper()
|
||||
old := modelRecommendationsReadRefreshCooldown
|
||||
modelRecommendationsReadRefreshCooldown = d
|
||||
t.Cleanup(func() {
|
||||
modelRecommendationsReadRefreshCooldown = old
|
||||
})
|
||||
}
|
||||
|
||||
func waitForCondition(t *testing.T, timeout time.Duration, cond func() bool) {
|
||||
t.Helper()
|
||||
deadline := time.Now().Add(timeout)
|
||||
for time.Now().Before(deadline) {
|
||||
if cond() {
|
||||
return
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
t.Fatal("timed out waiting for condition")
|
||||
}
|
||||
|
||||
func waitForCacheIdle(t *testing.T, cache *modelRecommendationsCache) {
|
||||
t.Helper()
|
||||
waitForCondition(t, 2*time.Second, func() bool {
|
||||
cache.mu.RLock()
|
||||
refreshing := cache.refreshing
|
||||
cache.mu.RUnlock()
|
||||
return !refreshing
|
||||
})
|
||||
}
|
||||
|
||||
func decodeRecommendationNames(t *testing.T, w *httptest.ResponseRecorder) []string {
|
||||
t.Helper()
|
||||
|
||||
var resp struct {
|
||||
Recommendations []struct {
|
||||
Model string `json:"model"`
|
||||
} `json:"recommendations"`
|
||||
}
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Fatalf("decode failed: %v", err)
|
||||
}
|
||||
|
||||
names := make([]string, 0, len(resp.Recommendations))
|
||||
for _, rec := range resp.Recommendations {
|
||||
names = append(names, rec.Model)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
func modelRecommendationNames(recs []api.ModelRecommendation) []string {
|
||||
names := make([]string, len(recs))
|
||||
for i, rec := range recs {
|
||||
names[i] = rec.Model
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
|
@ -98,10 +98,11 @@ var useClient2 = experimentEnabled("client2")
|
|||
var mode string = gin.DebugMode
|
||||
|
||||
type Server struct {
|
||||
addr net.Addr
|
||||
sched *Scheduler
|
||||
defaultNumCtx int
|
||||
requestLogger *inferenceRequestLogger
|
||||
addr net.Addr
|
||||
sched *Scheduler
|
||||
defaultNumCtx int
|
||||
requestLogger *inferenceRequestLogger
|
||||
modelRecommendations *modelRecommendationsCache
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
|
@ -1716,6 +1717,7 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
|||
r.POST("/api/copy", s.CopyHandler)
|
||||
r.POST("/api/experimental/web_search", s.WebSearchExperimentalHandler)
|
||||
r.POST("/api/experimental/web_fetch", s.WebFetchExperimentalHandler)
|
||||
r.GET("/api/experimental/model-recommendations", s.ModelRecommendationsExperimentalHandler)
|
||||
|
||||
// Inference
|
||||
r.GET("/api/ps", s.PsHandler)
|
||||
|
|
@ -1757,6 +1759,24 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
|||
return r, nil
|
||||
}
|
||||
|
||||
// ModelRecommendationsExperimentalHandler serves GET
// /api/experimental/model-recommendations. It responds with the cached
// recommendations when the server's cache is wired up, and falls back to
// the package-level defaults otherwise (e.g. servers constructed without
// Serve, as in some tests).
func (s *Server) ModelRecommendationsExperimentalHandler(c *gin.Context) {
	recs := defaultModelRecommendations
	source := "default"
	if s.modelRecommendations != nil {
		// Fall back to a background context if the request is absent
		// (defensive: gin normally always sets c.Request).
		ctx := context.Background()
		if c.Request != nil {
			ctx = c.Request.Context()
		}
		// GetSWR presumably returns cached data and refreshes in the
		// background (stale-while-revalidate) — confirm in cache impl.
		recs = s.modelRecommendations.GetSWR(ctx)
		source = "cache"
	}

	slog.Debug("serving model recommendations", "recommendation_source", source, "count", len(recs))
	c.JSON(http.StatusOK, api.ModelRecommendationsResponse{
		Recommendations: recs,
	})
}
|
||||
|
||||
func Serve(ln net.Listener) error {
|
||||
slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))
|
||||
slog.Info("server config", "env", envconfig.Values())
|
||||
|
|
@ -1791,7 +1811,13 @@ func Serve(ln net.Listener) error {
|
|||
}
|
||||
}
|
||||
|
||||
s := &Server{addr: ln.Addr()}
|
||||
// TODO(parthsareen): If we add more runtime caches, prefer introducing a
|
||||
// small cache manager owned by Server (for shared start/stop/health wiring)
|
||||
// instead of adding one top-level field per cache here in Serve.
|
||||
s := &Server{
|
||||
addr: ln.Addr(),
|
||||
modelRecommendations: newModelRecommendationsCache(),
|
||||
}
|
||||
if err := s.initRequestLogging(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -1816,6 +1842,7 @@ func Serve(ln net.Listener) error {
|
|||
schedCtx, schedDone := context.WithCancel(ctx)
|
||||
sched := InitScheduler(schedCtx)
|
||||
s.sched = sched
|
||||
s.modelRecommendations.Start(ctx)
|
||||
|
||||
slog.Info(fmt.Sprintf("Listening on %s (version %s)", ln.Addr(), version.Version))
|
||||
srvr := &http.Server{
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue