server/launch: add model recommendations cache endpoint (#15868)

This commit is contained in:
Parth Sareen 2026-04-28 17:09:04 -07:00 committed by GitHub
parent 87288ced4f
commit 321cc8a2ba
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 1250 additions and 38 deletions

View file

@ -368,6 +368,16 @@ func (c *Client) List(ctx context.Context) (*ListResponse, error) {
return &lr, nil
}
// ModelRecommendationsExperimental lists model recommendations from the local
// server's experimental recommendations endpoint.
func (c *Client) ModelRecommendationsExperimental(ctx context.Context) (*ModelRecommendationsResponse, error) {
	var recommendations ModelRecommendationsResponse
	err := c.do(ctx, http.MethodGet, "/api/experimental/model-recommendations", nil, &recommendations)
	if err != nil {
		return nil, err
	}
	return &recommendations, nil
}
// ListRunning lists running models.
func (c *Client) ListRunning(ctx context.Context) (*ProcessResponse, error) {
var lr ProcessResponse

View file

@ -802,6 +802,20 @@ type ListResponse struct {
Models []ListModelResponse `json:"models"`
}
// ModelRecommendationsResponse is the response from [Client.ModelRecommendationsExperimental].
type ModelRecommendationsResponse struct {
	// Recommendations is the list of recommended models returned by the server.
	Recommendations []ModelRecommendation `json:"recommendations"`
}

// ModelRecommendation is a single recommendation entry in [ModelRecommendationsResponse].
type ModelRecommendation struct {
	// Model is the model name the recommendation refers to.
	Model string `json:"model"`
	// Description is a short human-readable summary of the model.
	Description string `json:"description"`
	// ContextLength is the model's context window in tokens; zero when unknown.
	ContextLength int `json:"context_length,omitempty"`
	// MaxOutputTokens is the model's maximum output length in tokens; zero when unknown.
	MaxOutputTokens int `json:"max_output_tokens,omitempty"`
	// VRAM is a human-readable VRAM estimate for local models (e.g. "~16GB");
	// empty when not applicable.
	VRAM string `json:"vram,omitempty"`
}
// ProcessResponse is the response from [Client.Process].
type ProcessResponse struct {
Models []ProcessModelResponse `json:"models"`

View file

@ -273,6 +273,8 @@ func TestLaunchCmdModelFlagClearsDisabledCloudOverride(t *testing.T) {
switch r.URL.Path {
case "/api/status":
fmt.Fprintf(w, `{"cloud":{"disabled":true,"source":"config"}}`)
case "/api/experimental/model-recommendations":
fmt.Fprint(w, `{"recommendations":[]}`)
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"llama3.2"}]}`)
case "/api/show":
@ -474,6 +476,8 @@ func TestLaunchCmdIntegrationArgPromptsForModelWithSavedSelection(t *testing.T)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/experimental/model-recommendations":
fmt.Fprint(w, `{"recommendations":[]}`)
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"llama3.2"},{"name":"qwen3:8b"}]}`)
case "/api/show":

View file

@ -49,15 +49,6 @@ func withHermesUserHome(t *testing.T, dir string) {
})
}
func withHermesLookPath(t *testing.T, fn func(string) (string, error)) {
t.Helper()
old := hermesLookPath
hermesLookPath = fn
t.Cleanup(func() {
hermesLookPath = old
})
}
func clearHermesMessagingEnvVars(t *testing.T) {
t.Helper()
for _, group := range hermesMessagingEnvGroups {
@ -112,6 +103,8 @@ func TestHermesConfigurePreservesExistingConfigAndEnablesWeb(t *testing.T) {
switch r.URL.Path {
case "/api/show":
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
case "/api/experimental/model-recommendations":
fmt.Fprint(w, `{"recommendations":[]}`)
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"gemma4"},{"name":"qwen3.5"},{"name":"llama3.3"}]}`)
default:
@ -224,6 +217,8 @@ func TestHermesConfigureUpdatesMatchingCustomProviderWithoutDroppingFields(t *te
switch r.URL.Path {
case "/api/show":
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
case "/api/experimental/model-recommendations":
fmt.Fprint(w, `{"recommendations":[]}`)
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"gemma4"},{"name":"qwen3.5"},{"name":"llama3.3"}]}`)
default:
@ -300,6 +295,8 @@ func TestHermesConfigureUsesLaunchResolvedHostForModelDiscovery(t *testing.T) {
switch r.URL.Path {
case "/api/show":
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
case "/api/experimental/model-recommendations":
fmt.Fprint(w, `{"recommendations":[]}`)
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"gemma4"},{"name":"qwen3.5"},{"name":"llama3.3"}]}`)
default:
@ -365,6 +362,8 @@ func TestHermesConfigureMigratesLegacyManagedAliases(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/experimental/model-recommendations":
fmt.Fprint(w, `{"recommendations":[]}`)
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"gemma4"},{"name":"qwen3.5"}]}`)
default:

View file

@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"log/slog"
"net/http"
"os"
"slices"
@ -182,9 +183,12 @@ type ModelInfo = modelInfo
// ModelItem represents a model for selection UIs.
type ModelItem struct {
Name string
Description string
Recommended bool
Name string
Description string
Recommended bool
VRAM string
ContextLength int
MaxOutputTokens int
}
// LaunchCmd returns the cobra command for launching integrations.
@ -720,9 +724,10 @@ func (c *launcherClient) loadSelectableModels(ctx context.Context, preChecked []
if err := c.loadModelInventoryOnce(ctx); err != nil {
return nil, nil, err
}
recommendations := c.recommendations(ctx)
cloudDisabled, _ := cloudStatusDisabled(ctx, c.apiClient)
items, orderedChecked, _, _ := buildModelList(c.modelInventory, preChecked, current)
items, orderedChecked, _, _ := buildModelListWithRecommendations(c.modelInventory, recommendations, preChecked, current)
if cloudDisabled {
items = filterCloudItems(items)
orderedChecked = c.filterDisabledCloudModels(ctx, orderedChecked)
@ -733,6 +738,60 @@ func (c *launcherClient) loadSelectableModels(ctx context.Context, preChecked []
return items, orderedChecked, nil
}
// recommendations returns the model recommendations to surface in selection
// UIs. It prefers the server's experimental endpoint; when that request fails
// or returns no entries, it falls back to the built-in recommendedModels
// list. Either way, the dynamic cloud model limits are refreshed from the
// recommendation set that is returned.
func (c *launcherClient) recommendations(ctx context.Context) []ModelItem {
	recommendations, err := c.requestRecommendations(ctx)
	if err != nil || len(recommendations) == 0 {
		// Fail open: recommendation issues should not block launch flows.
		// Fall back to built-in recommendations until server data is available.
		fallback := append([]ModelItem(nil), recommendedModels...)
		setDynamicCloudModelLimits(cloudModelLimitsFromRecommendations(fallback))
		return fallback
	}
	setDynamicCloudModelLimits(cloudModelLimitsFromRecommendations(recommendations))
	return recommendations
}
// requestRecommendations fetches model recommendations from the server's
// experimental endpoint and converts them to ModelItem values. Entries with a
// blank or duplicate name are dropped, as are cloud models that are missing
// context or output token limits.
func (c *launcherClient) requestRecommendations(ctx context.Context) ([]ModelItem, error) {
	resp, err := c.apiClient.ModelRecommendationsExperimental(ctx)
	if err != nil {
		return nil, err
	}
	out := make([]ModelItem, 0, len(resp.Recommendations))
	visited := make(map[string]struct{}, len(resp.Recommendations))
	for _, rec := range resp.Recommendations {
		model := strings.TrimSpace(rec.Model)
		if model == "" {
			continue
		}
		if _, dup := visited[model]; dup {
			continue
		}
		visited[model] = struct{}{}
		// Cloud models need both limits for downstream configuration; skip
		// entries that lack them.
		if isCloudModelName(model) && (rec.ContextLength <= 0 || rec.MaxOutputTokens <= 0) {
			slog.Warn("skipping cloud recommendation with missing limits", "model", model)
			continue
		}
		desc := strings.TrimSpace(rec.Description)
		if desc == "" {
			desc = "Recommended model"
		}
		out = append(out, ModelItem{
			Name:            model,
			Description:     desc,
			Recommended:     true,
			VRAM:            strings.TrimSpace(rec.VRAM),
			ContextLength:   rec.ContextLength,
			MaxOutputTokens: rec.MaxOutputTokens,
		})
	}
	return out, nil
}
func (c *launcherClient) ensureModelsReady(ctx context.Context, models []string) error {
models = dedupeModelList(models)
if len(models) == 0 {

View file

@ -197,6 +197,8 @@ func TestBuildLauncherState_ManagedSingleIntegrationUsesCurrentModel(t *testing.
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/experimental/model-recommendations":
fmt.Fprint(w, `{"recommendations":[]}`)
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"gemma4"}]}`)
case "/api/show":
@ -230,6 +232,8 @@ func TestBuildLauncherState_ManagedSingleIntegrationShowsSavedModelWhenLiveConfi
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/experimental/model-recommendations":
fmt.Fprint(w, `{"recommendations":[]}`)
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"gemma4"}]}`)
case "/api/show":
@ -269,6 +273,8 @@ func TestLaunchIntegration_ManagedSingleIntegrationConfiguresOnboardsAndRuns(t *
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/experimental/model-recommendations":
fmt.Fprint(w, `{"recommendations":[]}`)
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"gemma4"}]}`)
case "/api/show":
@ -326,6 +332,8 @@ func TestLaunchIntegration_ManagedSingleIntegrationReOnboardsWhenSavedFlagIsStal
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/experimental/model-recommendations":
fmt.Fprint(w, `{"recommendations":[]}`)
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"gemma4"}]}`)
case "/api/show":
@ -418,6 +426,8 @@ func TestLaunchIntegration_ManagedSingleIntegrationSkipsRewriteWhenSavedMatches(
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/experimental/model-recommendations":
fmt.Fprint(w, `{"recommendations":[]}`)
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"gemma4"}]}`)
case "/api/show":
@ -468,6 +478,8 @@ func TestLaunchIntegration_ManagedSingleIntegrationRewritesWhenSavedDiffers(t *t
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/experimental/model-recommendations":
fmt.Fprint(w, `{"recommendations":[]}`)
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"gemma4"}]}`)
case "/api/show":
@ -715,6 +727,8 @@ func TestBuildLauncherState_InstalledAndCloudDisabled(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/experimental/model-recommendations":
fmt.Fprint(w, `{"recommendations":[]}`)
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"llama3.2"}]}`)
case "/api/status":
@ -767,6 +781,8 @@ func TestBuildLauncherState_MigratesLegacyOpenclawAliasConfig(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/experimental/model-recommendations":
fmt.Fprint(w, `{"recommendations":[]}`)
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"llama3.2"}]}`)
default:
@ -812,6 +828,8 @@ func TestBuildLauncherState_ToleratesInventoryFailure(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/experimental/model-recommendations":
fmt.Fprint(w, `{"recommendations":[]}`)
case "/api/tags":
w.WriteHeader(http.StatusInternalServerError)
fmt.Fprint(w, `{"error":"temporary failure"}`)
@ -858,6 +876,8 @@ func TestResolveRunModel_UsesSavedModelWithoutSelector(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/experimental/model-recommendations":
fmt.Fprint(w, `{"recommendations":[]}`)
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"llama3.2"}]}`)
case "/api/show":
@ -903,6 +923,8 @@ func TestResolveRunModel_HeadlessYesAutoPicksLastModel(t *testing.T) {
modelPulled := false
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/experimental/model-recommendations":
fmt.Fprint(w, `{"recommendations":[]}`)
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"llama3.2"}]}`)
case "/api/show":
@ -965,6 +987,8 @@ func TestResolveRunModel_UsesRequestPolicy(t *testing.T) {
modelPulled := false
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/experimental/model-recommendations":
fmt.Fprint(w, `{"recommendations":[]}`)
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"llama3.2"}]}`)
case "/api/show":
@ -1024,6 +1048,8 @@ func TestResolveRunModel_ForcePickerAlwaysUsesSelector(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/experimental/model-recommendations":
fmt.Fprint(w, `{"recommendations":[]}`)
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"llama3.2"},{"name":"qwen3:8b"}]}`)
case "/api/show":
@ -1074,6 +1100,8 @@ func TestResolveRunModel_ForcePicker_DoesNotReorderByLastModel(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/experimental/model-recommendations":
fmt.Fprint(w, `{"recommendations":[]}`)
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"qwen3.5"},{"name":"gemma4"}]}`)
case "/api/show":
@ -1124,6 +1152,8 @@ func TestResolveRunModel_UsesSignInHookForCloudModel(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/experimental/model-recommendations":
fmt.Fprint(w, `{"recommendations":[]}`)
case "/api/tags":
fmt.Fprint(w, `{"models":[]}`)
case "/api/status":
@ -1181,6 +1211,8 @@ func TestLaunchIntegration_EditorForceConfigure(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/experimental/model-recommendations":
fmt.Fprint(w, `{"recommendations":[]}`)
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"llama3.2"},{"name":"qwen3:8b"}]}`)
case "/api/show":
@ -1250,6 +1282,8 @@ func TestLaunchIntegration_EditorForceConfigure_FloatsCheckedModelsInPicker(t *t
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/experimental/model-recommendations":
fmt.Fprint(w, `{"recommendations":[]}`)
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"qwen3.5:cloud","remote_model":"qwen3.5"},{"name":"qwen3.5"}]}`)
case "/api/show":
@ -1368,6 +1402,8 @@ func TestLaunchIntegration_EditorCloudDisabledFallsBackToSelector(t *testing.T)
switch r.URL.Path {
case "/api/status":
fmt.Fprint(w, `{"cloud":{"disabled":true,"source":"config"}}`)
case "/api/experimental/model-recommendations":
fmt.Fprint(w, `{"recommendations":[]}`)
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"llama3.2"}]}`)
case "/api/show":
@ -1415,6 +1451,8 @@ func TestLaunchIntegration_EditorConfigureMultiSkipsMissingLocalAndPersistsAccep
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/experimental/model-recommendations":
fmt.Fprint(w, `{"recommendations":[]}`)
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"glm-5:cloud","remote_model":"glm-5"}]}`)
case "/api/status":
@ -1497,6 +1535,8 @@ func TestLaunchIntegration_EditorConfigureMultiSkipsUnauthedCloudAndPersistsAcce
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/experimental/model-recommendations":
fmt.Fprint(w, `{"recommendations":[]}`)
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"llama3.2"},{"name":"glm-5:cloud","remote_model":"glm-5"}]}`)
case "/api/status":
@ -1582,6 +1622,8 @@ func TestLaunchIntegration_EditorConfigureMultiRemovesReselectedFailingModel(t *
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/experimental/model-recommendations":
fmt.Fprint(w, `{"recommendations":[]}`)
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"glm-5:cloud","remote_model":"glm-5"},{"name":"llama3.2"}]}`)
case "/api/status":
@ -1669,6 +1711,8 @@ func TestLaunchIntegration_EditorConfigureMultiAllFailuresKeepsExistingAndSkipsL
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/experimental/model-recommendations":
fmt.Fprint(w, `{"recommendations":[]}`)
case "/api/tags":
fmt.Fprint(w, `{"models":[]}`)
case "/api/show":
@ -1975,6 +2019,8 @@ func TestLaunchIntegration_ConfigureOnlyDoesNotRequireInstalledBinary(t *testing
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/experimental/model-recommendations":
fmt.Fprint(w, `{"recommendations":[]}`)
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"llama3.2"}]}`)
case "/api/show":
@ -2018,6 +2064,8 @@ func TestLaunchIntegration_ClaudeSavesPrimaryModel(t *testing.T) {
var aliasSyncCalled bool
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/experimental/model-recommendations":
fmt.Fprint(w, `{"recommendations":[]}`)
case "/api/tags":
fmt.Fprint(w, `{"models":[]}`)
case "/api/status":
@ -2077,6 +2125,8 @@ func TestLaunchIntegration_ClaudeForceConfigureReprompts(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/experimental/model-recommendations":
fmt.Fprint(w, `{"recommendations":[]}`)
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"qwen3:8b"}]}`)
case "/api/show":
@ -2134,6 +2184,8 @@ func TestLaunchIntegration_ClaudeForceConfigureMissingSelectionDoesNotSave(t *te
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/experimental/model-recommendations":
fmt.Fprint(w, `{"recommendations":[]}`)
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"llama3.2"}]}`)
case "/api/show":
@ -2260,6 +2312,8 @@ func TestLaunchIntegration_ConfigureOnlyPrompt(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/experimental/model-recommendations":
fmt.Fprint(w, `{"recommendations":[]}`)
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"llama3.2"}]}`)
case "/api/show":
@ -2484,6 +2538,8 @@ func TestLaunchIntegration_HeadlessSelectorFlowFailsWithoutPrompt(t *testing.T)
pullCalled := false
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/experimental/model-recommendations":
fmt.Fprint(w, `{"recommendations":[]}`)
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"llama3.2"}]}`)
case "/api/show":

View file

@ -10,6 +10,7 @@ import (
"runtime"
"slices"
"strings"
"sync"
"time"
"github.com/ollama/ollama/api"
@ -21,17 +22,12 @@ import (
)
var recommendedModels = []ModelItem{
{Name: "kimi-k2.6:cloud", Description: "State-of-the-art coding, long-horizon execution, and multimodal agent swarm capability", Recommended: true},
{Name: "qwen3.5:cloud", Description: "Reasoning, coding, and agentic tool use with vision", Recommended: true},
{Name: "glm-5.1:cloud", Description: "Reasoning and code generation", Recommended: true},
{Name: "minimax-m2.7:cloud", Description: "Fast, efficient coding and real-world productivity", Recommended: true},
{Name: "gemma4", Description: "Reasoning and code generation locally", Recommended: true},
{Name: "qwen3.5", Description: "Reasoning, coding, and visual understanding locally", Recommended: true},
}
var recommendedVRAM = map[string]string{
"gemma4": "~16GB",
"qwen3.5": "~11GB",
{Name: "kimi-k2.6:cloud", Description: "State-of-the-art coding, long-horizon execution, and multimodal agent swarm capability", Recommended: true, ContextLength: 262_144, MaxOutputTokens: 262_144},
{Name: "qwen3.5:cloud", Description: "Reasoning, coding, and agentic tool use with vision", Recommended: true, ContextLength: 262_144, MaxOutputTokens: 32_768},
{Name: "glm-5.1:cloud", Description: "Reasoning and code generation", Recommended: true, ContextLength: 202_752, MaxOutputTokens: 131_072},
{Name: "minimax-m2.7:cloud", Description: "Fast, efficient coding and real-world productivity", Recommended: true, ContextLength: 204_800, MaxOutputTokens: 128_000},
{Name: "gemma4", Description: "Reasoning and code generation locally", Recommended: true, VRAM: "~16GB"},
{Name: "qwen3.5", Description: "Reasoning, coding, and visual understanding locally", Recommended: true, VRAM: "~11GB"},
}
// cloudModelLimit holds context and output token limits for a cloud model.
@ -40,10 +36,10 @@ type cloudModelLimit struct {
Output int
}
// cloudModelLimits maps cloud model base names to their token limits.
// extraCloudModelLimits maps cloud model base names to token limits for models
// that are not already covered by recommendedModels fallback entries.
// TODO(parthsareen): grab context/output limits from model info instead of hardcoding
var cloudModelLimits = map[string]cloudModelLimit{
"minimax-m2.7": {Context: 204_800, Output: 128_000},
var extraCloudModelLimits = map[string]cloudModelLimit{
"cogito-2.1:671b": {Context: 163_840, Output: 65_536},
"deepseek-v3.1:671b": {Context: 163_840, Output: 163_840},
"deepseek-v3.2": {Context: 163_840, Output: 65_536},
@ -65,11 +61,24 @@ var cloudModelLimits = map[string]cloudModelLimit{
"qwen3.5": {Context: 262_144, Output: 32_768},
}
var cloudModelLimits = mergeCloudModelLimits(cloudModelLimitsFromRecommendations(recommendedModels), extraCloudModelLimits)
var (
dynamicCloudModelLimitsMu sync.RWMutex
dynamicCloudModelLimits = map[string]cloudModelLimit{}
)
// lookupCloudModelLimit returns the token limits for a cloud model.
// It normalizes explicit cloud source suffixes before checking the shared limit map.
func lookupCloudModelLimit(name string) (cloudModelLimit, bool) {
base, stripped := modelref.StripCloudSourceTag(name)
if stripped {
dynamicCloudModelLimitsMu.RLock()
l, ok := dynamicCloudModelLimits[base]
dynamicCloudModelLimitsMu.RUnlock()
if ok {
return l, true
}
if l, ok := cloudModelLimits[base]; ok {
return l, true
}
@ -77,6 +86,49 @@ func lookupCloudModelLimit(name string) (cloudModelLimit, bool) {
return cloudModelLimit{}, false
}
// setDynamicCloudModelLimits replaces the dynamic cloud model limit table
// with a defensive copy of limits. Passing nil clears the table.
func setDynamicCloudModelLimits(limits map[string]cloudModelLimit) {
	// Ranging over a nil map yields no entries, so nil resets to an empty table.
	next := make(map[string]cloudModelLimit, len(limits))
	for name, limit := range limits {
		next[name] = limit
	}
	dynamicCloudModelLimitsMu.Lock()
	defer dynamicCloudModelLimitsMu.Unlock()
	dynamicCloudModelLimits = next
}
// cloudModelLimitsFromRecommendations derives a cloud model limit table from
// recommendation entries, keyed by the base model name with the cloud source
// tag stripped. Non-cloud entries, entries without positive limits, and names
// that do not carry a cloud source tag are ignored.
func cloudModelLimitsFromRecommendations(recommendations []ModelItem) map[string]cloudModelLimit {
	limits := make(map[string]cloudModelLimit, len(recommendations))
	for _, rec := range recommendations {
		if !isCloudModelName(rec.Name) {
			continue
		}
		if rec.ContextLength <= 0 || rec.MaxOutputTokens <= 0 {
			continue
		}
		base, stripped := modelref.StripCloudSourceTag(rec.Name)
		if !stripped || base == "" {
			continue
		}
		limits[base] = cloudModelLimit{Context: rec.ContextLength, Output: rec.MaxOutputTokens}
	}
	return limits
}
// mergeCloudModelLimits returns a new map containing every entry from base,
// with overlay entries taking precedence on key collisions. Neither input is
// modified.
func mergeCloudModelLimits(base map[string]cloudModelLimit, overlay map[string]cloudModelLimit) map[string]cloudModelLimit {
	merged := make(map[string]cloudModelLimit, len(base)+len(overlay))
	for _, src := range []map[string]cloudModelLimit{base, overlay} {
		for name, limit := range src {
			merged[name] = limit
		}
	}
	return merged
}
// missingModelPolicy controls how model-not-found errors should be handled.
type missingModelPolicy int
@ -276,13 +328,17 @@ func confirmConfigEdit(runner Runner, paths []string) (bool, error) {
// buildModelList merges existing models with recommendations for selection UIs.
// It delegates to buildModelListWithRecommendations using the built-in
// recommendedModels fallback set.
func buildModelList(existing []modelInfo, preChecked []string, current string) (items []ModelItem, orderedChecked []string, existingModels, cloudModels map[string]bool) {
	return buildModelListWithRecommendations(existing, recommendedModels, preChecked, current)
}
func buildModelListWithRecommendations(existing []modelInfo, recommendations []ModelItem, preChecked []string, current string) (items []ModelItem, orderedChecked []string, existingModels, cloudModels map[string]bool) {
existingModels = make(map[string]bool)
cloudModels = make(map[string]bool)
recommended := make(map[string]bool)
var hasLocalModel, hasCloudModel bool
recDesc := make(map[string]string)
for _, rec := range recommendedModels {
for _, rec := range recommendations {
recommended[rec.Name] = true
recDesc[rec.Name] = rec.Description
}
@ -301,7 +357,7 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
items = append(items, item)
}
for _, rec := range recommendedModels {
for _, rec := range recommendations {
if existingModels[rec.Name] || existingModels[rec.Name+":latest"] {
continue
}
@ -347,8 +403,8 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
if items[i].Description != "" {
parts = append(parts, items[i].Description)
}
if vram := recommendedVRAM[items[i].Name]; vram != "" {
parts = append(parts, vram)
if items[i].VRAM != "" {
parts = append(parts, items[i].VRAM)
}
parts = append(parts, "(not downloaded)")
items[i].Description = strings.Join(parts, ", ")
@ -356,12 +412,12 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
}
recRank := make(map[string]int)
for i, rec := range recommendedModels {
for i, rec := range recommendations {
recRank[rec.Name] = i + 1
}
if hasLocalModel || hasCloudModel {
// Keep the Recommended section pinned to recommendedModels order. Checked
// Keep the Recommended section pinned to recommendation order. Checked
// and default-model priority only apply within the More section.
slices.SortStableFunc(items, func(a, b ModelItem) int {
ac, bc := checked[a.Name], checked[b.Name]

View file

@ -0,0 +1,401 @@
package server
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"math/rand/v2"
"net/http"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
)
const modelRecommendationsURL = "https://ollama.com/api/experimental/model-recommendations"
var (
modelRecommendationsRefreshInterval = 4 * time.Hour
modelRecommendationsFetchTimeout = 3 * time.Second
modelRecommendationsReadRefreshCooldown = 5 * time.Second
modelRecommendationsBackoffSteps = []time.Duration{
5 * time.Minute,
15 * time.Minute,
time.Hour,
4 * time.Hour,
}
errModelRecommendationsNoCloud = errors.New("cloud disabled")
)
// modelRecommendationsCache caches model recommendations served by the
// experimental endpoint. The cache is refreshed periodically by a background
// loop and opportunistically on reads (stale-while-revalidate).
type modelRecommendationsCache struct {
	mu sync.RWMutex
	// recommendations is the current cached set; guarded by mu.
	recommendations []api.ModelRecommendation
	// refreshing reports whether a refresh is in flight; guarded by mu.
	refreshing bool
	// nextReadRefreshAfter throttles read-triggered refreshes; guarded by mu.
	nextReadRefreshAfter time.Time
	// once ensures the background refresh loop is started at most once.
	once sync.Once
	// client performs the outbound HTTP fetches.
	client *http.Client
}
// newModelRecommendationsCache returns a cache seeded with a copy of the
// built-in default recommendations, using http.DefaultClient for fetches.
func newModelRecommendationsCache() *modelRecommendationsCache {
	return &modelRecommendationsCache{
		recommendations: cloneModelRecommendations(defaultModelRecommendations),
		client:          http.DefaultClient,
	}
}
// Start launches the background refresh loop. It is safe to call multiple
// times; only the first call starts the loop, which runs until ctx is
// canceled.
func (c *modelRecommendationsCache) Start(ctx context.Context) {
	c.once.Do(func() {
		slog.Debug("starting model recommendations cache",
			"default_recommendations", len(defaultModelRecommendations),
			"refresh_interval", modelRecommendationsRefreshInterval.String(),
			"fetch_timeout", modelRecommendationsFetchTimeout.String(),
		)
		go c.run(ctx)
	})
}
// Get returns a copy of the currently cached recommendations, so callers may
// mutate the result without affecting the cache.
func (c *modelRecommendationsCache) Get() []api.ModelRecommendation {
	c.mu.RLock()
	defer c.mu.RUnlock()
	return cloneModelRecommendations(c.recommendations)
}
// GetSWR returns the cached recommendations immediately (possibly stale) and
// triggers an asynchronous background refresh — a stale-while-revalidate read.
func (c *modelRecommendationsCache) GetSWR(ctx context.Context) []api.ModelRecommendation {
	recs := c.Get()
	c.triggerRefreshOnRead(ctx)
	return recs
}
// set stores a defensive copy of recs as the current cached recommendations.
func (c *modelRecommendationsCache) set(recs []api.ModelRecommendation) {
	// Copy outside the lock to keep the critical section minimal.
	snapshot := cloneModelRecommendations(recs)
	c.mu.Lock()
	defer c.mu.Unlock()
	c.recommendations = snapshot
}
// beginRefresh marks a refresh as in flight and reports whether the caller
// acquired it; it returns false when another refresh is already running.
func (c *modelRecommendationsCache) beginRefresh() bool {
	c.mu.Lock()
	defer c.mu.Unlock()
	acquired := !c.refreshing
	if acquired {
		c.refreshing = true
	}
	return acquired
}
// beginReadRefresh is like beginRefresh, but additionally enforces the
// read-triggered cooldown: it returns false when a refresh is already running
// or the cooldown window set by endReadRefresh has not yet elapsed.
func (c *modelRecommendationsCache) beginReadRefresh() bool {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.refreshing || time.Now().Before(c.nextReadRefreshAfter) {
		return false
	}
	c.refreshing = true
	return true
}
// endRefresh clears the in-flight refresh flag.
func (c *modelRecommendationsCache) endRefresh() {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.refreshing = false
}
// endReadRefresh clears the in-flight refresh flag and arms the cooldown that
// throttles subsequent read-triggered refreshes.
func (c *modelRecommendationsCache) endReadRefresh() {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.refreshing = false
	c.nextReadRefreshAfter = time.Now().Add(modelRecommendationsReadRefreshCooldown)
}
// refreshIfIdle performs a refresh if none is currently in flight. It reports
// whether a refresh was started and, if so, returns the refresh error.
func (c *modelRecommendationsCache) refreshIfIdle(ctx context.Context) (bool, error) {
	if !c.beginRefresh() {
		return false, nil
	}
	defer c.endRefresh()
	return true, c.refresh(ctx)
}
// triggerRefreshOnRead starts a background refresh in response to a read,
// unless a refresh is already running or the read-refresh cooldown is active.
// The refresh outlives the triggering request, so the caller's cancellation is
// dropped while its context values are kept.
func (c *modelRecommendationsCache) triggerRefreshOnRead(ctx context.Context) {
	if !c.beginReadRefresh() {
		return
	}
	if ctx == nil {
		ctx = context.Background()
	}
	// Detach from the caller's cancellation: an aborted read must not abort
	// the shared cache update.
	ctx = context.WithoutCancel(ctx)
	slog.Debug("triggering model recommendations refresh on read")
	go func() {
		defer c.endReadRefresh()
		if err := c.refresh(ctx); err != nil {
			switch {
			case errors.Is(err, errModelRecommendationsNoCloud):
				slog.Debug("skipping model recommendations read refresh because cloud is disabled")
			default:
				slog.Warn("model recommendations read refresh failed", "error", err)
			}
		}
	}()
}
// run is the background refresh loop. It first loads any persisted snapshot,
// then repeatedly refreshes from the remote endpoint until ctx is canceled,
// backing off with escalating delays after consecutive failures.
func (c *modelRecommendationsCache) run(ctx context.Context) {
	c.loadSnapshot()
	failures := 0
	for {
		started, err := c.refreshIfIdle(ctx)
		switch {
		case !started:
			// A read-triggered refresh is already running; not a failure.
			failures = 0
			slog.Debug("skipping timer model recommendations refresh because refresh is already running")
		case err == nil:
			failures = 0
		case errors.Is(err, errModelRecommendationsNoCloud):
			// Cloud disabled is an expected condition; reset the backoff.
			failures = 0
			slog.Debug("skipping model recommendations refresh because cloud is disabled")
		default:
			failures++
			slog.Warn("model recommendations refresh failed", "error", err)
		}
		// Sleep the normal interval on success, or a backoff step (capped at
		// the final step) after consecutive failures; both are jittered.
		var wait time.Duration
		if failures == 0 {
			wait = withJitter(modelRecommendationsRefreshInterval)
		} else {
			wait = withJitter(modelRecommendationsBackoffSteps[min(failures-1, len(modelRecommendationsBackoffSteps)-1)])
		}
		slog.Info("model recommendations cache sleep scheduled", "wait", wait.String(), "consecutive_failures", failures)
		select {
		case <-ctx.Done():
			slog.Debug("stopping model recommendations cache")
			return
		case <-time.After(wait):
		}
	}
}
// refresh fetches recommendations from the remote endpoint, validates them,
// installs them in the in-memory cache, and persists a snapshot to disk.
// When cloud is disabled it returns errModelRecommendationsNoCloud without
// making a request. Snapshot persistence is best-effort and does not fail the
// refresh.
func (c *modelRecommendationsCache) refresh(ctx context.Context) error {
	if envconfig.NoCloud() {
		return errModelRecommendationsNoCloud
	}
	slog.Debug("refreshing model recommendations from remote", "url", modelRecommendationsURL)
	reqCtx, cancel := context.WithTimeout(ctx, modelRecommendationsFetchTimeout)
	defer cancel()
	req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, modelRecommendationsURL, nil)
	if err != nil {
		return err
	}
	req.Header.Set("Accept", "application/json")
	resp, err := c.client.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if resp.StatusCode >= http.StatusBadRequest {
		// Read a bounded amount of the body for the error message only.
		body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
		return fmt.Errorf("status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
	}
	var payload api.ModelRecommendationsResponse
	if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
		return err
	}
	recs, err := validateModelRecommendations(payload.Recommendations)
	if err != nil {
		return err
	}
	c.set(recs)
	slog.Debug("model recommendations refreshed", "count", len(recs))
	if err := c.persistSnapshot(recs); err != nil {
		slog.Warn("failed to persist model recommendations snapshot", "error", err)
	}
	return nil
}
// loadSnapshot replaces the cached recommendations with a previously persisted
// snapshot from disk, if one exists and validates. Every failure is logged and
// otherwise ignored, leaving the current cache contents untouched.
func (c *modelRecommendationsCache) loadSnapshot() {
	path, err := modelRecommendationsSnapshotPath()
	if err != nil {
		slog.Warn("failed to resolve model recommendations snapshot path", "error", err)
		return
	}
	data, err := os.ReadFile(path)
	if err != nil {
		// A missing snapshot is the normal first-run case; only other read
		// errors are worth a warning.
		if !errors.Is(err, os.ErrNotExist) {
			slog.Warn("failed to read model recommendations snapshot", "path", path, "error", err)
		} else {
			slog.Debug("model recommendations snapshot not found", "path", path)
		}
		return
	}
	var snap api.ModelRecommendationsResponse
	if err := json.Unmarshal(data, &snap); err != nil {
		slog.Warn("failed to parse model recommendations snapshot", "path", path, "error", err)
		return
	}
	recs, err := validateModelRecommendations(snap.Recommendations)
	if err != nil {
		slog.Warn("ignoring invalid model recommendations snapshot", "path", path, "error", err)
		return
	}
	c.set(recs)
	slog.Debug("loaded model recommendations snapshot", "path", path, "count", len(recs))
}
// persistSnapshot writes recs to the on-disk snapshot file so a later process
// start can serve cached recommendations (via loadSnapshot) before the first
// remote refresh completes.
//
// The data is written to a temp file in the destination directory and then
// renamed into place, so readers never observe a partial snapshot.
func (c *modelRecommendationsCache) persistSnapshot(recs []api.ModelRecommendation) error {
    path, err := modelRecommendationsSnapshotPath()
    if err != nil {
        return err
    }
    if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
        return err
    }
    payload := api.ModelRecommendationsResponse{Recommendations: recs}
    data, err := json.MarshalIndent(payload, "", " ")
    if err != nil {
        return err
    }
    // Create the temp file alongside the final path so the rename below stays
    // within a single directory (atomic rename on typical filesystems).
    tmp, err := os.CreateTemp(filepath.Dir(path), ".model-recommendations-*.tmp")
    if err != nil {
        return err
    }
    tmpPath := tmp.Name()
    // Best-effort cleanup on failure; after a successful rename this removal
    // is a harmless no-op (the temp name no longer exists).
    defer os.Remove(tmpPath)
    if _, err := tmp.Write(data); err != nil {
        _ = tmp.Close()
        return err
    }
    // Flush to stable storage before renaming so a crash cannot leave an
    // empty or truncated snapshot at the final path.
    if err := tmp.Sync(); err != nil {
        _ = tmp.Close()
        return err
    }
    if err := tmp.Close(); err != nil {
        return err
    }
    if err := os.Rename(tmpPath, path); err != nil {
        return err
    }
    slog.Debug("persisted model recommendations snapshot", "path", path, "count", len(recs))
    return nil
}
func modelRecommendationsSnapshotPath() (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
return filepath.Join(home, ".ollama", "cache", "model-recommendations.json"), nil
}
// validateModelRecommendations normalizes and sanity-checks a recommendation
// list. It trims whitespace from string fields, rejects empty lists, missing
// model names, and duplicate models, and drops cloud entries that lack
// context/output-token limits. It errors if nothing valid remains.
func validateModelRecommendations(recs []api.ModelRecommendation) ([]api.ModelRecommendation, error) {
    if len(recs) == 0 {
        return nil, errors.New("empty recommendations")
    }
    names := make(map[string]struct{}, len(recs))
    out := make([]api.ModelRecommendation, 0, len(recs))
    for _, r := range recs {
        r.Model = strings.TrimSpace(r.Model)
        r.Description = strings.TrimSpace(r.Description)
        r.VRAM = strings.TrimSpace(r.VRAM)
        if r.Model == "" {
            return nil, errors.New("recommendation missing model")
        }
        if _, dup := names[r.Model]; dup {
            return nil, fmt.Errorf("duplicate recommendation %q", r.Model)
        }
        names[r.Model] = struct{}{}
        // Cloud entries must declare both limits; drop (don't fail) so one
        // bad entry doesn't invalidate the whole payload.
        missingLimits := r.ContextLength <= 0 || r.MaxOutputTokens <= 0
        if isCloudRecommendation(r.Model) && missingLimits {
            slog.Warn("dropping cloud recommendation missing limits", "model", r.Model)
            continue
        }
        out = append(out, r)
    }
    if len(out) == 0 {
        return nil, errors.New("no valid recommendations")
    }
    return out, nil
}
// isCloudRecommendation reports whether modelName refers to a cloud-hosted
// model, identified by a ":cloud" or "-cloud" suffix.
func isCloudRecommendation(modelName string) bool {
    for _, suffix := range []string{":cloud", "-cloud"} {
        if strings.HasSuffix(modelName, suffix) {
            return true
        }
    }
    return false
}
func withJitter(d time.Duration) time.Duration {
if d <= 0 {
return d
}
// jitter in range [0.8x, 1.2x]
factor := 0.8 + rand.Float64()*0.4
return time.Duration(float64(d) * factor)
}
// cloneModelRecommendations returns a defensive copy of in so callers cannot
// mutate the cache's backing slice. The result is always non-nil.
func cloneModelRecommendations(in []api.ModelRecommendation) []api.ModelRecommendation {
    out := make([]api.ModelRecommendation, len(in))
    for i := range in {
        out[i] = in[i]
    }
    return out
}
// defaultModelRecommendations is the built-in fallback list served when no
// fresher data is available from the cache or its on-disk snapshot. Cloud
// entries carry context/output-token limits; local entries carry an
// approximate VRAM requirement instead (see validateModelRecommendations).
var defaultModelRecommendations = []api.ModelRecommendation{
    {
        Model:           "kimi-k2.6:cloud",
        Description:     "State-of-the-art coding, long-horizon execution, and multimodal agent swarm capability",
        ContextLength:   262_144,
        MaxOutputTokens: 262_144,
    },
    {
        Model:           "glm-5.1:cloud",
        Description:     "Reasoning and code generation",
        ContextLength:   202_752,
        MaxOutputTokens: 131_072,
    },
    {
        Model:           "qwen3.5:cloud",
        Description:     "Reasoning, coding, and agentic tool use with vision",
        ContextLength:   262_144,
        MaxOutputTokens: 32_768,
    },
    {
        Model:           "minimax-m2.7:cloud",
        Description:     "Fast, efficient coding and real-world productivity",
        ContextLength:   204_800,
        MaxOutputTokens: 128_000,
    },
    {
        Model:       "gemma4",
        Description: "Reasoning and code generation locally",
        VRAM:        "~16GB",
    },
    {
        Model:       "qwen3.5",
        Description: "Reasoning, coding, and visual understanding locally",
        VRAM:        "~11GB",
    },
}

View file

@ -0,0 +1,586 @@
package server
import (
"context"
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"slices"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
)
// TestModelRecommendationsDefaultOrder pins the built-in fallback list's
// contents and ordering: cloud entries first, local entries after.
func TestModelRecommendationsDefaultOrder(t *testing.T) {
    want := []string{
        "kimi-k2.6:cloud", "glm-5.1:cloud", "qwen3.5:cloud",
        "minimax-m2.7:cloud", "gemma4", "qwen3.5",
    }
    got := modelRecommendationNames(defaultModelRecommendations)
    if !slices.Equal(got, want) {
        t.Fatalf("recommendations = %v, want %v", got, want)
    }
}
// TestModelRecommendationsCacheRefreshAppliesServerSideChanges verifies that
// refresh issues a GET to the remote endpoint, trims whitespace from fields,
// replaces the in-memory list on every call, and persists the most recent
// result to the on-disk snapshot.
func TestModelRecommendationsCacheRefreshAppliesServerSideChanges(t *testing.T) {
    setupModelRecommendationsTestEnv(t, "")
    // Payloads carry whitespace padding to confirm validation trims it.
    first := []api.ModelRecommendation{
        {Model: " first-cloud:cloud ", Description: " first ", ContextLength: 2048, MaxOutputTokens: 512},
        {Model: " first-local ", Description: " first local ", VRAM: " ~3GB "},
    }
    second := []api.ModelRecommendation{
        {Model: "second-cloud:cloud", Description: "second", ContextLength: 4096, MaxOutputTokens: 1024},
        {Model: "second-local", Description: "second local", VRAM: "~6GB"},
    }
    calls := 0
    cache := newModelRecommendationsCache()
    // Stub transport: serves `first` on the initial call, `second` afterwards.
    cache.client = &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
        if req.Method != http.MethodGet {
            t.Fatalf("method = %q, want GET", req.Method)
        }
        if req.URL.String() != modelRecommendationsURL {
            t.Fatalf("url = %q, want %q", req.URL.String(), modelRecommendationsURL)
        }
        calls++
        payload := api.ModelRecommendationsResponse{Recommendations: first}
        if calls > 1 {
            payload.Recommendations = second
        }
        data, err := json.Marshal(payload)
        if err != nil {
            t.Fatalf("marshal payload failed: %v", err)
        }
        return jsonHTTPResponse(http.StatusOK, string(data)), nil
    })}
    if err := cache.refresh(context.Background()); err != nil {
        t.Fatalf("first refresh failed: %v", err)
    }
    // After the first refresh, expect the trimmed form of `first`.
    if got, want := cache.Get(), []api.ModelRecommendation{
        {Model: "first-cloud:cloud", Description: "first", ContextLength: 2048, MaxOutputTokens: 512},
        {Model: "first-local", Description: "first local", VRAM: "~3GB"},
    }; !slices.Equal(got, want) {
        t.Fatalf("after first refresh recommendations = %#v, want %#v", got, want)
    }
    if err := cache.refresh(context.Background()); err != nil {
        t.Fatalf("second refresh failed: %v", err)
    }
    if got, want := cache.Get(), second; !slices.Equal(got, want) {
        t.Fatalf("after second refresh recommendations = %#v, want %#v", got, want)
    }
    // The snapshot on disk must reflect the latest successful refresh.
    path, err := modelRecommendationsSnapshotPath()
    if err != nil {
        t.Fatalf("snapshot path failed: %v", err)
    }
    snapshotData, err := os.ReadFile(path)
    if err != nil {
        t.Fatalf("read snapshot failed: %v", err)
    }
    var snapshot api.ModelRecommendationsResponse
    if err := json.Unmarshal(snapshotData, &snapshot); err != nil {
        t.Fatalf("unmarshal snapshot failed: %v", err)
    }
    if !slices.Equal(snapshot.Recommendations, second) {
        t.Fatalf("snapshot recommendations = %#v, want %#v", snapshot.Recommendations, second)
    }
}
// TestModelRecommendationsCacheRefreshErrorCasesPreserveCurrentData runs
// refresh against a range of failing backends (transport error, 5xx, bad
// JSON, invalid payloads) and asserts each failure surfaces an error with
// the expected text, leaves the in-memory list untouched, and writes no
// snapshot file.
func TestModelRecommendationsCacheRefreshErrorCasesPreserveCurrentData(t *testing.T) {
    cases := []struct {
        name      string
        transport roundTripFunc
        errSubstr string
    }{
        {
            name: "transport error",
            transport: func(*http.Request) (*http.Response, error) {
                return nil, errors.New("network down")
            },
            errSubstr: "network down",
        },
        {
            name: "remote status error",
            transport: func(*http.Request) (*http.Response, error) {
                return jsonHTTPResponse(http.StatusInternalServerError, "upstream broken"), nil
            },
            errSubstr: "status 500: upstream broken",
        },
        {
            name: "invalid json payload",
            transport: func(*http.Request) (*http.Response, error) {
                return jsonHTTPResponse(http.StatusOK, "{"), nil
            },
            errSubstr: "unexpected EOF",
        },
        {
            name: "duplicate recommendations",
            transport: func(*http.Request) (*http.Response, error) {
                return jsonHTTPResponse(http.StatusOK, `{"recommendations":[{"model":"dup","description":"a"},{"model":"dup","description":"b"}]}`), nil
            },
            errSubstr: `duplicate recommendation "dup"`,
        },
        {
            name: "empty recommendations",
            transport: func(*http.Request) (*http.Response, error) {
                return jsonHTTPResponse(http.StatusOK, `{"recommendations":[]}`), nil
            },
            errSubstr: "empty recommendations",
        },
        {
            name: "only invalid cloud recommendations",
            transport: func(*http.Request) (*http.Response, error) {
                return jsonHTTPResponse(http.StatusOK, `{"recommendations":[{"model":"bad:cloud","description":"missing limits"}]}`), nil
            },
            errSubstr: "no valid recommendations",
        },
    }
    for _, tc := range cases {
        t.Run(tc.name, func(t *testing.T) {
            setupModelRecommendationsTestEnv(t, "")
            cache := newModelRecommendationsCache()
            // Seed known-good data that must survive the failed refresh.
            stable := []api.ModelRecommendation{{Model: "stable-local", Description: "stable desc", VRAM: "~2GB"}}
            cache.set(stable)
            cache.client = &http.Client{Transport: tc.transport}
            err := cache.refresh(context.Background())
            if err == nil {
                t.Fatalf("refresh returned nil error")
            }
            if !strings.Contains(err.Error(), tc.errSubstr) {
                t.Fatalf("error = %q, want substring %q", err.Error(), tc.errSubstr)
            }
            if got := cache.Get(); !slices.Equal(got, stable) {
                t.Fatalf("recommendations changed on error: got %#v, want %#v", got, stable)
            }
            // A failed refresh must not create or update the on-disk snapshot.
            path, pathErr := modelRecommendationsSnapshotPath()
            if pathErr != nil {
                t.Fatalf("snapshot path failed: %v", pathErr)
            }
            if _, statErr := os.Stat(path); !errors.Is(statErr, os.ErrNotExist) {
                t.Fatalf("snapshot file should not be written on error, stat err = %v", statErr)
            }
        })
    }
}
// TestModelRecommendationsCacheRefreshNoCloudShortCircuits verifies that when
// cloud is disabled (OLLAMA_NO_CLOUD=1), refresh fails fast with
// errModelRecommendationsNoCloud and never contacts the remote endpoint.
func TestModelRecommendationsCacheRefreshNoCloudShortCircuits(t *testing.T) {
    setupModelRecommendationsTestEnv(t, "1")
    called := false
    cache := newModelRecommendationsCache()
    cache.client = &http.Client{Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
        // Reaching here means refresh did NOT short-circuit.
        called = true
        return jsonHTTPResponse(http.StatusOK, `{"recommendations":[{"model":"should-not-be-used","description":"n/a"}]}`), nil
    })}
    err := cache.refresh(context.Background())
    if !errors.Is(err, errModelRecommendationsNoCloud) {
        t.Fatalf("refresh error = %v, want %v", err, errModelRecommendationsNoCloud)
    }
    if called {
        t.Fatalf("remote endpoint should not be called when cloud is disabled")
    }
}
// TestModelRecommendationsSnapshotPersistAndLoad round-trips recommendations
// through the on-disk snapshot: one cache writes it, a second cache loads it
// and replaces its previous contents with the persisted data.
func TestModelRecommendationsSnapshotPersistAndLoad(t *testing.T) {
    setupModelRecommendationsTestEnv(t, "")
    want := []api.ModelRecommendation{
        {Model: "persist-cloud:cloud", Description: "persisted", ContextLength: 8192, MaxOutputTokens: 2048},
        {Model: "persist-local", Description: "persisted local", VRAM: "~5GB"},
    }
    src := newModelRecommendationsCache()
    if err := src.persistSnapshot(want); err != nil {
        t.Fatalf("persistSnapshot failed: %v", err)
    }
    dst := newModelRecommendationsCache()
    dst.set([]api.ModelRecommendation{{Model: "old", Description: "old"}})
    dst.loadSnapshot()
    got := dst.Get()
    if !slices.Equal(got, want) {
        t.Fatalf("loaded recommendations = %#v, want %#v", got, want)
    }
}
// TestModelRecommendationsLoadSnapshotInvalidDoesNotOverwrite writes a
// corrupt snapshot file and asserts that loadSnapshot keeps the cache's
// existing contents rather than clearing them.
func TestModelRecommendationsLoadSnapshotInvalidDoesNotOverwrite(t *testing.T) {
    setupModelRecommendationsTestEnv(t, "")
    path, err := modelRecommendationsSnapshotPath()
    if err != nil {
        t.Fatalf("snapshot path failed: %v", err)
    }
    if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
        t.Fatalf("mkdir failed: %v", err)
    }
    // Not valid JSON; loadSnapshot must reject it.
    if err := os.WriteFile(path, []byte("{invalid"), 0o644); err != nil {
        t.Fatalf("write invalid snapshot failed: %v", err)
    }
    cache := newModelRecommendationsCache()
    existing := []api.ModelRecommendation{{Model: "existing", Description: "existing description"}}
    cache.set(existing)
    cache.loadSnapshot()
    if got := cache.Get(); !slices.Equal(got, existing) {
        t.Fatalf("recommendations overwritten by invalid snapshot: got %#v, want %#v", got, existing)
    }
}
func TestValidateModelRecommendationsTrimsAndDropsInvalidCloudEntries(t *testing.T) {
input := []api.ModelRecommendation{
{Model: " good-cloud:cloud ", Description: " good cloud ", ContextLength: 1024, MaxOutputTokens: 256},
{Model: "bad-cloud:cloud", Description: "missing limits"},
{Model: " good-local ", Description: " good local ", VRAM: " ~2GB "},
}
got, err := validateModelRecommendations(input)
if err != nil {
t.Fatalf("validateModelRecommendations failed: %v", err)
}
want := []api.ModelRecommendation{
{Model: "good-cloud:cloud", Description: "good cloud", ContextLength: 1024, MaxOutputTokens: 256},
{Model: "good-local", Description: "good local", VRAM: "~2GB"},
}
if !slices.Equal(got, want) {
t.Fatalf("validated recommendations = %#v, want %#v", got, want)
}
}
// TestModelRecommendationsHandlerReturnsDefaults exercises the handler on a
// Server with no recommendations cache; it must fall back to serving the
// built-in default list.
func TestModelRecommendationsHandlerReturnsDefaults(t *testing.T) {
    gin.SetMode(gin.TestMode)
    w := httptest.NewRecorder()
    ctx, _ := gin.CreateTestContext(w)
    ctx.Request = httptest.NewRequest(http.MethodGet, "/api/experimental/model-recommendations", nil)
    // Zero-value Server: modelRecommendations is nil.
    s := &Server{}
    s.ModelRecommendationsExperimentalHandler(ctx)
    if w.Code != http.StatusOK {
        t.Fatalf("status = %d, want %d", w.Code, http.StatusOK)
    }
    got := decodeRecommendationNames(t, w)
    want := modelRecommendationNames(defaultModelRecommendations)
    if !slices.Equal(got, want) {
        t.Fatalf("models = %v, want %v", got, want)
    }
}
// TestModelRecommendationsHandlerUsesCache verifies the handler serves the
// cache's current contents when a cache is configured. Cloud is disabled
// ("1") so the background SWR refresh short-circuits and cannot replace the
// seeded data mid-test.
func TestModelRecommendationsHandlerUsesCache(t *testing.T) {
    gin.SetMode(gin.TestMode)
    setupModelRecommendationsTestEnv(t, "1")
    cache := newModelRecommendationsCache()
    cache.set([]api.ModelRecommendation{{Model: "test-model", Description: "test description"}})
    w := httptest.NewRecorder()
    ctx, _ := gin.CreateTestContext(w)
    ctx.Request = httptest.NewRequest(http.MethodGet, "/api/experimental/model-recommendations", nil)
    s := &Server{modelRecommendations: cache}
    s.ModelRecommendationsExperimentalHandler(ctx)
    if w.Code != http.StatusOK {
        t.Fatalf("status = %d, want %d", w.Code, http.StatusOK)
    }
    got := decodeRecommendationNames(t, w)
    if !slices.Equal(got, []string{"test-model"}) {
        t.Fatalf("models = %v, want %v", got, []string{"test-model"})
    }
    // Don't leave a background refresh running past the test body.
    waitForCacheIdle(t, cache)
}
// TestModelRecommendationsRouteRegistration verifies the experimental
// endpoint is wired into the full router for GET and rejects POST with 405.
func TestModelRecommendationsRouteRegistration(t *testing.T) {
    gin.SetMode(gin.TestMode)
    setupModelRecommendationsTestEnv(t, "1")
    cache := newModelRecommendationsCache()
    cache.set([]api.ModelRecommendation{{Model: "route-model", Description: "route description"}})
    s := &Server{modelRecommendations: cache}
    router, err := s.GenerateRoutes(nil)
    if err != nil {
        t.Fatalf("GenerateRoutes failed: %v", err)
    }
    getReq := httptest.NewRequest(http.MethodGet, "/api/experimental/model-recommendations", nil)
    getResp := httptest.NewRecorder()
    router.ServeHTTP(getResp, getReq)
    if getResp.Code != http.StatusOK {
        t.Fatalf("GET status = %d, want %d", getResp.Code, http.StatusOK)
    }
    if got := decodeRecommendationNames(t, getResp); !slices.Equal(got, []string{"route-model"}) {
        t.Fatalf("GET models = %v, want %v", got, []string{"route-model"})
    }
    // Only GET is registered for this path.
    postReq := httptest.NewRequest(http.MethodPost, "/api/experimental/model-recommendations", nil)
    postResp := httptest.NewRecorder()
    router.ServeHTTP(postResp, postReq)
    if postResp.Code != http.StatusMethodNotAllowed {
        t.Fatalf("POST status = %d, want %d", postResp.Code, http.StatusMethodNotAllowed)
    }
    waitForCacheIdle(t, cache)
}
// TestModelRecommendationsGetSWRTriggersRefreshOnRead covers the
// stale-while-revalidate read path: GetSWR returns the current (stale) list
// immediately and kicks off a background refresh that later replaces it.
func TestModelRecommendationsGetSWRTriggersRefreshOnRead(t *testing.T) {
    setupModelRecommendationsTestEnv(t, "")
    cache := newModelRecommendationsCache()
    old := []api.ModelRecommendation{{Model: "old", Description: "old"}}
    newRecs := []api.ModelRecommendation{{Model: "new-cloud:cloud", Description: "new", ContextLength: 1024, MaxOutputTokens: 256}}
    cache.set(old)
    refreshDone := make(chan struct{})
    cache.client = &http.Client{Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
        // Signals that the background refresh actually reached the network layer.
        defer close(refreshDone)
        return jsonHTTPResponse(http.StatusOK, `{"recommendations":[{"model":"new-cloud:cloud","description":"new","context_length":1024,"max_output_tokens":256}]}`), nil
    })}
    // The read must not block on the refresh it triggers.
    gotImmediate := cache.GetSWR(context.Background())
    if !slices.Equal(gotImmediate, old) {
        t.Fatalf("GetSWR should return current cache immediately: got %#v, want %#v", gotImmediate, old)
    }
    select {
    case <-refreshDone:
    case <-time.After(2 * time.Second):
        t.Fatal("timed out waiting for async refresh")
    }
    // The refreshed payload lands asynchronously after the transport returns.
    waitForCondition(t, 2*time.Second, func() bool {
        return slices.Equal(cache.Get(), newRecs)
    })
    waitForCacheIdle(t, cache)
}
// TestModelRecommendationsGetSWRSkipsWhenRefreshAlreadyInFlight holds the
// first background refresh open and asserts that additional GetSWR calls do
// not start a second concurrent fetch.
func TestModelRecommendationsGetSWRSkipsWhenRefreshAlreadyInFlight(t *testing.T) {
    setupModelRecommendationsTestEnv(t, "")
    cache := newModelRecommendationsCache()
    cache.set([]api.ModelRecommendation{{Model: "old", Description: "old"}})
    started := make(chan struct{})
    release := make(chan struct{})
    var calls atomic.Int32
    // The stub blocks on release so the first refresh stays in flight.
    cache.client = &http.Client{Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
        n := calls.Add(1)
        if n == 1 {
            close(started)
        }
        <-release
        return jsonHTTPResponse(http.StatusOK, `{"recommendations":[{"model":"updated","description":"ok"}]}`), nil
    })}
    cache.GetSWR(context.Background())
    select {
    case <-started:
    case <-time.After(2 * time.Second):
        t.Fatal("timed out waiting for first refresh call")
    }
    // These reads race against the in-flight refresh; none should fetch.
    for range 5 {
        cache.GetSWR(context.Background())
    }
    // Give any (incorrect) duplicate refresh a moment to reach the transport.
    time.Sleep(50 * time.Millisecond)
    if got := calls.Load(); got != 1 {
        t.Fatalf("calls during in-flight refresh = %d, want 1", got)
    }
    close(release)
    waitForCacheIdle(t, cache)
}
// TestModelRecommendationsGetSWRThrottlesRefreshAfterCompletion verifies
// that, within the read-refresh cooldown after a completed refresh, GetSWR
// does not start another remote fetch.
func TestModelRecommendationsGetSWRThrottlesRefreshAfterCompletion(t *testing.T) {
    setupModelRecommendationsTestEnv(t, "")
    withModelRecommendationsReadRefreshCooldown(t, 100*time.Millisecond)
    cache := newModelRecommendationsCache()
    cache.set([]api.ModelRecommendation{{Model: "old", Description: "old"}})
    started := make(chan struct{})
    release := make(chan struct{})
    var calls atomic.Int32
    // Only the first fetch blocks; any later (unexpected) fetch returns
    // immediately so the counter would expose it.
    cache.client = &http.Client{Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
        if calls.Add(1) == 1 {
            close(started)
            <-release
        }
        return jsonHTTPResponse(http.StatusOK, `{"recommendations":[{"model":"updated","description":"ok"}]}`), nil
    })}
    cache.GetSWR(context.Background())
    select {
    case <-started:
    case <-time.After(2 * time.Second):
        t.Fatal("timed out waiting for first refresh call")
    }
    // Hold the refresh open for longer than the cooldown before releasing it.
    time.Sleep(2 * modelRecommendationsReadRefreshCooldown)
    close(release)
    waitForCacheIdle(t, cache)
    // Inside the cooldown: this read must not trigger a second fetch.
    cache.GetSWR(context.Background())
    time.Sleep(25 * time.Millisecond)
    if got := calls.Load(); got != 1 {
        t.Fatalf("calls during read refresh cooldown = %d, want 1", got)
    }
}
// TestModelRecommendationsGetSWRRetriesAfterReadRefreshCooldown verifies that
// a failed refresh is throttled for the cooldown window and then retried on a
// later read, after which the cache picks up the recovered payload.
func TestModelRecommendationsGetSWRRetriesAfterReadRefreshCooldown(t *testing.T) {
    setupModelRecommendationsTestEnv(t, "")
    withModelRecommendationsReadRefreshCooldown(t, 100*time.Millisecond)
    cache := newModelRecommendationsCache()
    old := []api.ModelRecommendation{{Model: "old", Description: "old"}}
    cache.set(old)
    var calls atomic.Int32
    // First fetch fails; every subsequent fetch succeeds.
    cache.client = &http.Client{Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
        if calls.Add(1) == 1 {
            return nil, errors.New("temporary upstream failure")
        }
        return jsonHTTPResponse(http.StatusOK, `{"recommendations":[{"model":"recovered","description":"ok"}]}`), nil
    })}
    cache.GetSWR(context.Background())
    waitForCondition(t, 2*time.Second, func() bool { return calls.Load() >= 1 })
    waitForCacheIdle(t, cache)
    // A failed refresh must leave the existing data in place.
    if !slices.Equal(cache.Get(), old) {
        t.Fatalf("cache should remain unchanged after failed refresh, got %#v", cache.Get())
    }
    // Still inside the cooldown: no retry yet.
    cache.GetSWR(context.Background())
    time.Sleep(25 * time.Millisecond)
    if got := calls.Load(); got != 1 {
        t.Fatalf("calls during read refresh cooldown after failure = %d, want 1", got)
    }
    // Once the cooldown expires, a read triggers the retry...
    waitForCondition(t, 2*time.Second, func() bool {
        cache.GetSWR(context.Background())
        return calls.Load() >= 2
    })
    // ...and the successful payload eventually replaces the old data.
    waitForCondition(t, 2*time.Second, func() bool {
        return slices.Equal(cache.Get(), []api.ModelRecommendation{{Model: "recovered", Description: "ok"}})
    })
    waitForCacheIdle(t, cache)
}
// roundTripFunc adapts a plain function to http.RoundTripper so tests can
// stub the cache's HTTP client without a real network.
type roundTripFunc func(*http.Request) (*http.Response, error)

// RoundTrip implements http.RoundTripper by invoking f.
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
    return f(req)
}
func jsonHTTPResponse(statusCode int, body string) *http.Response {
return &http.Response{
StatusCode: statusCode,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader(body)),
}
}
// setupModelRecommendationsTestEnv points home-directory resolution (and thus
// the snapshot path) at a fresh temp dir and pins OLLAMA_NO_CLOUD for the
// test. Pass "1" to simulate cloud being disabled; "" means cloud enabled.
func setupModelRecommendationsTestEnv(t *testing.T, noCloudEnv string) {
    t.Helper()
    home := t.TempDir()
    // Cover os.UserHomeDir on both Unix (HOME) and Windows
    // (USERPROFILE / HOMEDRIVE+HOMEPATH).
    t.Setenv("HOME", home)
    t.Setenv("USERPROFILE", home)
    t.Setenv("HOMEDRIVE", filepath.VolumeName(home))
    t.Setenv("HOMEPATH", strings.TrimPrefix(home, filepath.VolumeName(home)))
    // Use explicit false rather than empty to avoid platform/env ambiguity.
    if noCloudEnv == "" {
        noCloudEnv = "false"
    }
    t.Setenv("OLLAMA_NO_CLOUD", noCloudEnv)
    // Re-read env now and again at cleanup so the override doesn't leak into
    // other tests.
    envconfig.ReloadServerConfig()
    t.Cleanup(envconfig.ReloadServerConfig)
}
// withModelRecommendationsReadRefreshCooldown overrides the package-level
// read-refresh cooldown for the duration of the test, restoring the previous
// value during cleanup.
func withModelRecommendationsReadRefreshCooldown(t *testing.T, d time.Duration) {
    t.Helper()
    prev := modelRecommendationsReadRefreshCooldown
    t.Cleanup(func() {
        modelRecommendationsReadRefreshCooldown = prev
    })
    modelRecommendationsReadRefreshCooldown = d
}
func waitForCondition(t *testing.T, timeout time.Duration, cond func() bool) {
t.Helper()
deadline := time.Now().Add(timeout)
for time.Now().Before(deadline) {
if cond() {
return
}
time.Sleep(10 * time.Millisecond)
}
t.Fatal("timed out waiting for condition")
}
// waitForCacheIdle blocks until no background refresh is in flight, so a test
// does not return while a goroutine still uses its stubbed transport.
func waitForCacheIdle(t *testing.T, cache *modelRecommendationsCache) {
    t.Helper()
    waitForCondition(t, 2*time.Second, func() bool {
        // Read the in-flight flag under the cache's own lock.
        cache.mu.RLock()
        refreshing := cache.refreshing
        cache.mu.RUnlock()
        return !refreshing
    })
}
func decodeRecommendationNames(t *testing.T, w *httptest.ResponseRecorder) []string {
t.Helper()
var resp struct {
Recommendations []struct {
Model string `json:"model"`
} `json:"recommendations"`
}
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatalf("decode failed: %v", err)
}
names := make([]string, 0, len(resp.Recommendations))
for _, rec := range resp.Recommendations {
names = append(names, rec.Model)
}
return names
}
// modelRecommendationNames projects recs onto just the model names,
// preserving order.
func modelRecommendationNames(recs []api.ModelRecommendation) []string {
    names := make([]string, 0, len(recs))
    for _, rec := range recs {
        names = append(names, rec.Model)
    }
    return names
}

View file

@ -98,10 +98,11 @@ var useClient2 = experimentEnabled("client2")
var mode string = gin.DebugMode
type Server struct {
addr net.Addr
sched *Scheduler
defaultNumCtx int
requestLogger *inferenceRequestLogger
addr net.Addr
sched *Scheduler
defaultNumCtx int
requestLogger *inferenceRequestLogger
modelRecommendations *modelRecommendationsCache
}
func init() {
@ -1716,6 +1717,7 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.POST("/api/copy", s.CopyHandler)
r.POST("/api/experimental/web_search", s.WebSearchExperimentalHandler)
r.POST("/api/experimental/web_fetch", s.WebFetchExperimentalHandler)
r.GET("/api/experimental/model-recommendations", s.ModelRecommendationsExperimentalHandler)
// Inference
r.GET("/api/ps", s.PsHandler)
@ -1757,6 +1759,24 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
return r, nil
}
// ModelRecommendationsExperimentalHandler serves the current model
// recommendation list. When a recommendations cache is configured it returns
// the cached data via the stale-while-revalidate read path; otherwise it
// falls back to the built-in defaults.
func (s *Server) ModelRecommendationsExperimentalHandler(c *gin.Context) {
    recs, source := defaultModelRecommendations, "default"
    if cache := s.modelRecommendations; cache != nil {
        // Prefer the request's context so a dropped client cancels any
        // work tied to it; fall back for contexts built without a request.
        ctx := context.Background()
        if req := c.Request; req != nil {
            ctx = req.Context()
        }
        recs, source = cache.GetSWR(ctx), "cache"
    }
    slog.Debug("serving model recommendations", "recommendation_source", source, "count", len(recs))
    c.JSON(http.StatusOK, api.ModelRecommendationsResponse{
        Recommendations: recs,
    })
}
func Serve(ln net.Listener) error {
slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))
slog.Info("server config", "env", envconfig.Values())
@ -1791,7 +1811,13 @@ func Serve(ln net.Listener) error {
}
}
s := &Server{addr: ln.Addr()}
// TODO(parthsareen): If we add more runtime caches, prefer introducing a
// small cache manager owned by Server (for shared start/stop/health wiring)
// instead of adding one top-level field per cache here in Serve.
s := &Server{
addr: ln.Addr(),
modelRecommendations: newModelRecommendationsCache(),
}
if err := s.initRequestLogging(); err != nil {
return err
}
@ -1816,6 +1842,7 @@ func Serve(ln net.Listener) error {
schedCtx, schedDone := context.WithCancel(ctx)
sched := InitScheduler(schedCtx)
s.sched = sched
s.modelRecommendations.Start(ctx)
slog.Info(fmt.Sprintf("Listening on %s (version %s)", ln.Addr(), version.Version))
srvr := &http.Server{