anthropic: support tool_result images end-to-end

This commit is contained in:
Eva Ho 2026-05-06 15:17:00 -04:00
parent f81cad3373
commit 5286dfdbde
2 changed files with 148 additions and 8 deletions

View file

@ -348,10 +348,10 @@ func requiresCloudAnthropicChatFallback(path string, body []byte) bool {
return false
}
return hasAnthropicWebSearchTool(body) || hasAnthropicToolResultImage(body)
return hasAnthropicWebSearchTool(body) || hasAnthropicToolResultBase64Image(body)
}
func hasAnthropicToolResultImage(body []byte) bool {
func hasAnthropicToolResultBase64Image(body []byte) bool {
if len(body) == 0 {
return false
}
@ -378,7 +378,7 @@ func hasAnthropicToolResultImage(body []byte) bool {
if strings.TrimSpace(block.Type) != "tool_result" {
continue
}
if anthropicToolResultContentHasImage(block.Content) {
if anthropicToolResultContentHasBase64Image(block.Content) {
return true
}
}
@ -387,26 +387,32 @@ func hasAnthropicToolResultImage(body []byte) bool {
return false
}
func anthropicToolResultContentHasImage(raw json.RawMessage) bool {
func anthropicToolResultContentHasBase64Image(raw json.RawMessage) bool {
if len(raw) == 0 || bytes.Equal(bytes.TrimSpace(raw), []byte("null")) {
return false
}
var blocks []struct {
Type string `json:"type"`
Type string `json:"type"`
Source *struct {
Type string `json:"type"`
} `json:"source"`
}
if err := json.Unmarshal(raw, &blocks); err == nil {
for _, block := range blocks {
if strings.TrimSpace(block.Type) == "image" {
if strings.TrimSpace(block.Type) == "image" && block.Source != nil && strings.TrimSpace(block.Source.Type) == "base64" {
return true
}
}
}
var block struct {
Type string `json:"type"`
Type string `json:"type"`
Source *struct {
Type string `json:"type"`
} `json:"source"`
}
if err := json.Unmarshal(raw, &block); err == nil && strings.TrimSpace(block.Type) == "image" {
if err := json.Unmarshal(raw, &block); err == nil && strings.TrimSpace(block.Type) == "image" && block.Source != nil && strings.TrimSpace(block.Source.Type) == "base64" {
return true
}

View file

@ -863,6 +863,72 @@ func TestExplicitCloudPassthroughAPIAndV1(t *testing.T) {
}
})
t.Run("v1 messages tool_result url image bypasses conversion", func(t *testing.T) {
upstream, capture := newUpstream(t, `{"id":"msg_1","type":"message"}`)
defer upstream.Close()
original := cloudProxyBaseURL
cloudProxyBaseURL = upstream.URL
t.Cleanup(func() { cloudProxyBaseURL = original })
s := &Server{}
router, err := s.GenerateRoutes(nil)
if err != nil {
t.Fatal(err)
}
local := httptest.NewServer(router)
defer local.Close()
reqBody := `{
"model":"gpt-oss:120b-cloud",
"max_tokens":10,
"messages":[{
"role":"user",
"content":[{
"type":"tool_result",
"tool_use_id":"call_456",
"content":[
{"type":"text","text":"Here is the screenshot:"},
{"type":"image","source":{"type":"url","url":"https://example.com/image.png"}}
]
}]
}],
"stream":false
}`
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+"/v1/messages?beta=true", bytes.NewBufferString(reqBody))
if err != nil {
t.Fatal(err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := local.Client().Do(req)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status 200, got %d (%s)", resp.StatusCode, string(body))
}
if capture.path != "/v1/messages" {
t.Fatalf("expected upstream path /v1/messages for url image passthrough, got %q", capture.path)
}
if !strings.Contains(capture.body, `"type":"tool_result"`) {
t.Fatalf("expected original anthropic request body, got %q", capture.body)
}
if !strings.Contains(capture.body, `"type":"url"`) {
t.Fatalf("expected url image source in upstream body, got %q", capture.body)
}
if strings.Contains(capture.body, `"num_predict":10`) {
t.Fatalf("expected no converted ollama options in upstream body, got %q", capture.body)
}
})
t.Run("v1 model retrieve bypasses conversion", func(t *testing.T) {
upstream, capture := newUpstream(t, `{"id":"kimi-k2.5:cloud","object":"model","created":1,"owned_by":"ollama"}`)
defer upstream.Close()
@ -1248,6 +1314,74 @@ func TestCloudPassthroughSkipsAnthropicToolResultImages(t *testing.T) {
}
}
// TestCloudPassthroughDoesNotSkipAnthropicToolResultURLImages posts a
// /v1/messages request whose tool_result contains a URL-sourced image through
// cloudPassthroughMiddleware and expects the stub upstream to receive it on
// /v1/messages. The terminal local handler answers 418, so the expected 200
// must not have come from that handler.
func TestCloudPassthroughDoesNotSkipAnthropicToolResultURLImages(t *testing.T) {
	gin.SetMode(gin.TestMode)
	setTestHome(t, t.TempDir())

	// upstreamPath holds the URL path of the last request seen by the stub;
	// it stays empty if the stub is never called.
	var upstreamPath string
	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		upstreamPath = r.URL.Path
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte(`{"id":"msg_1","type":"message"}`))
	}))
	defer upstream.Close()

	// Swap cloudProxyBaseURL for the stub's URL; the previous value is
	// restored on test cleanup.
	savedBaseURL := cloudProxyBaseURL
	cloudProxyBaseURL = upstream.URL
	t.Cleanup(func() { cloudProxyBaseURL = savedBaseURL })

	engine := gin.New()
	engine.POST(
		"/v1/messages",
		cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable),
		middleware.AnthropicMessagesMiddleware(),
		func(c *gin.Context) { c.Status(http.StatusTeapot) },
	)

	local := httptest.NewServer(engine)
	defer local.Close()

	// The image block uses source.type "url" rather than "base64".
	const payload = `{
"model":"kimi-k2.5:cloud",
"max_tokens":10,
"messages":[{
"role":"user",
"content":[{
"type":"tool_result",
"tool_use_id":"call_456",
"content":[
{"type":"text","text":"Here is the screenshot:"},
{"type":"image","source":{"type":"url","url":"https://example.com/image.png"}}
]
}]
}]
}`

	endpoint := local.URL + "/v1/messages"
	req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, endpoint, bytes.NewBufferString(payload))
	if err != nil {
		t.Fatal(err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := local.Client().Do(req)
	if err != nil {
		t.Fatal(err)
	}
	defer resp.Body.Close()

	// The response body is read only to enrich the failure message below.
	respBody, _ := io.ReadAll(resp.Body)
	if got := resp.StatusCode; got != http.StatusOK {
		t.Fatalf("expected passthrough response status 200, got %d (%s)", got, string(respBody))
	}

	if upstreamPath != "/v1/messages" {
		t.Fatalf("expected passthrough to upstream /v1/messages for url images, got %q", upstreamPath)
	}
}
func TestCloudPassthroughSigningFailureReturnsUnauthorized(t *testing.T) {
gin.SetMode(gin.TestMode)
setTestHome(t, t.TempDir())