From 36280de4276fcf3f4bfdbf75628696ff329dab17 Mon Sep 17 00:00:00 2001 From: Tiago Terenas Almeida <45994793+tiagomta@users.noreply.github.com> Date: Fri, 12 Dec 2025 19:07:47 +0000 Subject: [PATCH] . --- server/routes.go | 5 +-- server/webproxy.go | 11 +++-- server/webproxy_test.go | 97 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 107 insertions(+), 6 deletions(-) create mode 100644 server/webproxy_test.go diff --git a/server/routes.go b/server/routes.go index 2dbfcb6e1..05c1e081a 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1521,13 +1521,12 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) { // Inference r.GET("/api/ps", s.PsHandler) - // Web proxy endpoints: forward web search/fetch to main server (ollama.com) - r.POST("/api/web_search", s.WebSearchHandler) - r.POST("/api/web_fetch", s.WebFetchHandler) r.POST("/api/generate", s.GenerateHandler) r.POST("/api/chat", s.ChatHandler) r.POST("/api/embed", s.EmbedHandler) r.POST("/api/embeddings", s.EmbeddingsHandler) + r.POST("/api/web_search", s.WebSearchHandler) + r.POST("/api/web_fetch", s.WebFetchHandler) // Inference (OpenAI compatibility) r.POST("/v1/chat/completions", middleware.ChatMiddleware(), s.ChatHandler) diff --git a/server/webproxy.go b/server/webproxy.go index 70a46ad20..3fe2e8f44 100644 --- a/server/webproxy.go +++ b/server/webproxy.go @@ -13,6 +13,12 @@ import ( "github.com/ollama/ollama/auth" ) +// signFunc is a variable to allow tests to override signing behavior. +var signFunc = auth.Sign + +// httpClient is injectable for tests to capture outbound requests. +var httpClient = &http.Client{Timeout: 30 * time.Second} + // proxyToMain forwards the incoming request body to the main ollama server // and sets an Authorization token if signing is available locally. func (s *Server) proxyToMain(c *gin.Context, path string) { @@ -27,7 +33,7 @@ func (s *Server) proxyToMain(c *gin.Context, path string) { now := strconv.FormatInt(time.Now().Unix(), 10) chal := fmt.Sprintf("%s,%s?ts=%s", http.MethodPost, path, now) - token, err := auth.Sign(ctx, []byte(chal)) + token, err := signFunc(ctx, []byte(chal)) if err != nil { // If signing fails, return an error so callers know the proxy couldn't // obtain a token. Clients may fallback to asking the user for a key. @@ -57,8 +63,7 @@ func (s *Server) proxyToMain(c *gin.Context, path string) { req.Header.Set("Authorization", token) } - client := &http.Client{Timeout: 30 * time.Second} - resp, err := client.Do(req) + resp, err := httpClient.Do(req) if err != nil { c.JSON(http.StatusBadGateway, gin.H{"error": "failed to contact main server"}) return diff --git a/server/webproxy_test.go b/server/webproxy_test.go new file mode 100644 index 000000000..e0e17b6d0 --- /dev/null +++ b/server/webproxy_test.go @@ -0,0 +1,97 @@ +package server + +import ( + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/gin-gonic/gin" +) + +type rtStub struct { + fn func(req *http.Request) *http.Response +} + +func (r *rtStub) RoundTrip(req *http.Request) (*http.Response, error) { + if r.fn == nil { + return nil, errors.New("no stub") + } + return r.fn(req), nil +} + +func TestProxyToMain_WithToken_SetsAuthorizationAndForwards(t *testing.T) { + t.Parallel() + + // backup globals + origSign := signFunc + origClient := httpClient + defer func() { signFunc = origSign; httpClient = origClient }() + + var captured *http.Request + + // stub signing + signFunc = func(_ any, _ []byte) (string, error) { return "Bearer testtoken", nil } + + // stub transport + stub := &rtStub{fn: func(req *http.Request) *http.Response { + captured = req + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader("ok")), + Header: http.Header{"Content-Type": {"application/json"}}, + } + }} + httpClient = &http.Client{Transport: stub} + + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + req := httptest.NewRequest(http.MethodPost, "/api/web_search", strings.NewReader(`{"q":"hi"}`)) + req.Header.Set("Content-Type", "application/json") + c.Request = req + + s := &Server{} + s.WebSearchHandler(c) + + if w.Code != 200 { + t.Fatalf("expected 200 status, got %d", w.Code) + } + if strings.TrimSpace(w.Body.String()) != "ok" { + t.Fatalf("unexpected body: %q", w.Body.String()) + } + if captured == nil { + t.Fatal("no outbound request captured") + } + if got := captured.Header.Get("Authorization"); got != "Bearer testtoken" { + t.Fatalf("expected Authorization header set, got %q", got) + } + if captured.URL.Path != "/api/web_search" { + t.Fatalf("expected path /api/web_search, got %s", captured.URL.Path) + } +} + +func TestProxyToMain_SignFails_Returns500(t *testing.T) { + t.Parallel() + + origSign := signFunc + defer func() { signFunc = origSign }() + + signFunc = func(_ any, _ []byte) (string, error) { return "", errors.New("no key") } + + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + req := httptest.NewRequest(http.MethodPost, "/api/web_fetch", strings.NewReader(`{"url":"https://example.com"}`)) + req.Header.Set("Content-Type", "application/json") + c.Request = req + + s := &Server{} + s.WebFetchHandler(c) + + if w.Code != http.StatusInternalServerError { + t.Fatalf("expected 500 status, got %d", w.Code) + } +}