mirror of
https://github.com/ollama/ollama.git
synced 2026-05-13 14:27:00 +00:00
anthropic: Preserve Claude local image-path tool results in renderer-owned prompt formatting (#16047)
This commit is contained in:
parent
421faa0263
commit
6bdb73073b
16 changed files with 670 additions and 102 deletions
|
|
@ -78,6 +78,11 @@ type MessagesRequest struct {
|
|||
ToolChoice *ToolChoice `json:"tool_choice,omitempty"`
|
||||
Thinking *ThinkingConfig `json:"thinking,omitempty"`
|
||||
Metadata *Metadata `json:"metadata,omitempty"`
|
||||
OutputConfig *OutputConfig `json:"output_config,omitempty"`
|
||||
}
|
||||
|
||||
// OutputConfig carries the optional "output_config" object of a messages
// request.
type OutputConfig struct {
|
||||
// Effort is the requested effort level. FromMessagesRequest trims and
// lower-cases it, maps "xhigh" to "high", and uses "high", "medium", "low"
// or "max" as the think level when the thinking config is neither
// "enabled" nor "disabled".
Effort string `json:"effort,omitempty"`
|
||||
}
|
||||
|
||||
// MessageParam represents a message in the request
|
||||
|
|
@ -161,7 +166,7 @@ type WebSearchToolResultError struct {
|
|||
|
||||
// ImageSource represents the source of an image
|
||||
type ImageSource struct {
|
||||
Type string `json:"type"` // "base64" or "url"
|
||||
Type string `json:"type"` // "base64"
|
||||
MediaType string `json:"media_type,omitempty"`
|
||||
Data string `json:"data,omitempty"`
|
||||
URL string `json:"url,omitempty"`
|
||||
|
|
@ -373,9 +378,26 @@ func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) {
|
|||
}
|
||||
|
||||
var think *api.ThinkValue
|
||||
normalizedEffort := ""
|
||||
if r.OutputConfig != nil {
|
||||
normalizedEffort = strings.ToLower(strings.TrimSpace(r.OutputConfig.Effort))
|
||||
if normalizedEffort == "xhigh" {
|
||||
normalizedEffort = "high"
|
||||
}
|
||||
}
|
||||
|
||||
if r.Thinking != nil && r.Thinking.Type == "enabled" {
|
||||
think = &api.ThinkValue{Value: true}
|
||||
}
|
||||
if r.Thinking != nil && r.Thinking.Type == "disabled" {
|
||||
think = &api.ThinkValue{Value: false}
|
||||
}
|
||||
if think == nil && r.OutputConfig != nil {
|
||||
switch normalizedEffort {
|
||||
case "high", "medium", "low", "max":
|
||||
think = &api.ThinkValue{Value: normalizedEffort}
|
||||
}
|
||||
}
|
||||
|
||||
stream := r.Stream
|
||||
convertedRequest := &api.ChatRequest{
|
||||
|
|
@ -425,17 +447,12 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
|
|||
return nil, errors.New("invalid image source")
|
||||
}
|
||||
|
||||
if block.Source.Type == "base64" {
|
||||
decoded, err := base64.StdEncoding.DecodeString(block.Source.Data)
|
||||
if err != nil {
|
||||
logutil.Trace("anthropic: invalid base64 image data", "role", role, "error", err)
|
||||
return nil, fmt.Errorf("invalid base64 image data: %w", err)
|
||||
}
|
||||
images = append(images, decoded)
|
||||
} else {
|
||||
logutil.Trace("anthropic: unsupported image source type", "role", role, "source_type", block.Source.Type)
|
||||
return nil, fmt.Errorf("invalid image source type: %s. Only base64 images are supported.", block.Source.Type)
|
||||
decoded, err := resolveImageSource(block.Source)
|
||||
if err != nil {
|
||||
logutil.Trace("anthropic: unsupported image source", "role", role, "source_type", block.Source.Type, "error", err)
|
||||
return nil, err
|
||||
}
|
||||
images = append(images, decoded)
|
||||
|
||||
case "tool_use":
|
||||
toolUseBlocks++
|
||||
|
|
@ -457,26 +474,16 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
|
|||
|
||||
case "tool_result":
|
||||
toolResultBlocks++
|
||||
var resultContent string
|
||||
|
||||
switch c := block.Content.(type) {
|
||||
case string:
|
||||
resultContent = c
|
||||
case []any:
|
||||
for _, cb := range c {
|
||||
if cbMap, ok := cb.(map[string]any); ok {
|
||||
if cbMap["type"] == "text" {
|
||||
if text, ok := cbMap["text"].(string); ok {
|
||||
resultContent += text
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
resultContent, resultImages, err := convertToolResultContent(block.Content)
|
||||
if err != nil {
|
||||
logutil.Trace("anthropic: invalid tool_result content", "role", role, "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
toolResults = append(toolResults, api.Message{
|
||||
Role: "tool",
|
||||
Content: resultContent,
|
||||
Images: resultImages,
|
||||
ToolCallID: block.ToolUseID,
|
||||
})
|
||||
|
||||
|
|
@ -508,6 +515,10 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
|
|||
}
|
||||
}
|
||||
|
||||
if role == "user" && len(toolResults) > 0 {
|
||||
messages = append(messages, toolResults...)
|
||||
}
|
||||
|
||||
if textContent.Len() > 0 || len(images) > 0 || len(toolCalls) > 0 || thinking != "" {
|
||||
m := api.Message{
|
||||
Role: role,
|
||||
|
|
@ -519,8 +530,10 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
|
|||
messages = append(messages, m)
|
||||
}
|
||||
|
||||
// Add tool results as separate messages
|
||||
messages = append(messages, toolResults...)
|
||||
// Add tool results as separate messages.
|
||||
if role != "user" || len(toolResults) == 0 {
|
||||
messages = append(messages, toolResults...)
|
||||
}
|
||||
logutil.Trace("anthropic: converted block message",
|
||||
"role", role,
|
||||
"blocks", len(msg.Content),
|
||||
|
|
@ -969,6 +982,71 @@ func GenerateMessageID() string {
|
|||
return generateID("msg")
|
||||
}
|
||||
|
||||
func resolveImageSource(source *ImageSource) (api.ImageData, error) {
|
||||
if source.Type != "base64" {
|
||||
return nil, fmt.Errorf("invalid image source type: %s. Only base64 images are supported.", source.Type)
|
||||
}
|
||||
|
||||
decoded, err := base64.StdEncoding.DecodeString(source.Data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid base64 image data: %w", err)
|
||||
}
|
||||
|
||||
return decoded, nil
|
||||
}
|
||||
|
||||
func convertToolResultContent(content any) (string, []api.ImageData, error) {
|
||||
switch c := content.(type) {
|
||||
case nil:
|
||||
return "", nil, nil
|
||||
case string:
|
||||
return c, nil, nil
|
||||
case []any:
|
||||
var text strings.Builder
|
||||
var images []api.ImageData
|
||||
|
||||
for _, cb := range c {
|
||||
cbMap, ok := cb.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
switch cbMap["type"] {
|
||||
case "text":
|
||||
if t, ok := cbMap["text"].(string); ok {
|
||||
text.WriteString(t)
|
||||
}
|
||||
case "image":
|
||||
rawSource, ok := cbMap["source"].(map[string]any)
|
||||
if !ok {
|
||||
return "", nil, errors.New("invalid tool_result image source")
|
||||
}
|
||||
|
||||
var source ImageSource
|
||||
if rawType, ok := rawSource["type"].(string); ok {
|
||||
source.Type = rawType
|
||||
}
|
||||
if rawMediaType, ok := rawSource["media_type"].(string); ok {
|
||||
source.MediaType = rawMediaType
|
||||
}
|
||||
if rawData, ok := rawSource["data"].(string); ok {
|
||||
source.Data = rawData
|
||||
}
|
||||
|
||||
img, err := resolveImageSource(&source)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
images = append(images, img)
|
||||
}
|
||||
}
|
||||
|
||||
return text.String(), images, nil
|
||||
default:
|
||||
return "", nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
// ptr returns a pointer to the given string value
|
||||
func ptr(s string) *string {
|
||||
return &s
|
||||
|
|
|
|||
|
|
@ -271,6 +271,241 @@ func TestFromMessagesRequest_WithToolResult(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithToolResultImage(t *testing.T) {
|
||||
imgData, _ := base64.StdEncoding.DecodeString(testImage)
|
||||
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{
|
||||
{
|
||||
Role: "user",
|
||||
Content: []ContentBlock{
|
||||
{
|
||||
Type: "tool_result",
|
||||
ToolUseID: "call_img",
|
||||
Content: []any{
|
||||
map[string]any{"type": "text", "text": "Attached image"},
|
||||
map[string]any{
|
||||
"type": "image",
|
||||
"source": map[string]any{
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": testImage,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Messages) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(result.Messages))
|
||||
}
|
||||
|
||||
msg := result.Messages[0]
|
||||
if msg.Role != "tool" {
|
||||
t.Errorf("expected role 'tool', got %q", msg.Role)
|
||||
}
|
||||
if msg.ToolCallID != "call_img" {
|
||||
t.Errorf("expected tool_call_id 'call_img', got %q", msg.ToolCallID)
|
||||
}
|
||||
if msg.Content != "Attached image" {
|
||||
t.Errorf("unexpected content: %q", msg.Content)
|
||||
}
|
||||
if len(msg.Images) != 1 {
|
||||
t.Fatalf("expected 1 image, got %d", len(msg.Images))
|
||||
}
|
||||
if string(msg.Images[0]) != string(imgData) {
|
||||
t.Error("image data mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithToolResultFollowedByUserText(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: []ContentBlock{
|
||||
{
|
||||
Type: "tool_use",
|
||||
ID: "call_read",
|
||||
Name: "Read",
|
||||
Input: makeArgs("file_path", "/Users/hoyyeva/Desktop/aaa.png"),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
Content: []ContentBlock{
|
||||
{
|
||||
Type: "tool_result",
|
||||
ToolUseID: "call_read",
|
||||
Content: "Read image (311.5KB)",
|
||||
},
|
||||
{
|
||||
Type: "text",
|
||||
Text: ptr("Please describe it."),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Messages) != 3 {
|
||||
t.Fatalf("expected 3 messages, got %d", len(result.Messages))
|
||||
}
|
||||
|
||||
if result.Messages[1].Role != "tool" {
|
||||
t.Fatalf("expected second message to be tool, got %q", result.Messages[1].Role)
|
||||
}
|
||||
if result.Messages[1].ToolCallID != "call_read" {
|
||||
t.Fatalf("expected tool_call_id 'call_read', got %q", result.Messages[1].ToolCallID)
|
||||
}
|
||||
if result.Messages[2].Role != "user" {
|
||||
t.Fatalf("expected third message to be user, got %q", result.Messages[2].Role)
|
||||
}
|
||||
if result.Messages[2].Content != "Please describe it." {
|
||||
t.Fatalf("unexpected user content: %q", result.Messages[2].Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithOutputConfigEffort(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "gemma4",
|
||||
MaxTokens: 32000,
|
||||
Messages: []MessageParam{
|
||||
{
|
||||
Role: "user",
|
||||
Content: textContent("Describe the image."),
|
||||
},
|
||||
},
|
||||
OutputConfig: &OutputConfig{
|
||||
Effort: "high",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Think == nil {
|
||||
t.Fatal("expected think to be set from output_config.effort")
|
||||
}
|
||||
|
||||
if got := result.Think.String(); got != "high" {
|
||||
t.Fatalf("expected think level 'high', got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithOutputConfigEffortXHighMapsToHigh(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "gemma4",
|
||||
MaxTokens: 32000,
|
||||
Messages: []MessageParam{
|
||||
{
|
||||
Role: "user",
|
||||
Content: textContent("Describe the image."),
|
||||
},
|
||||
},
|
||||
OutputConfig: &OutputConfig{
|
||||
Effort: "xhigh",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Think == nil {
|
||||
t.Fatal("expected think to be set from output_config.effort")
|
||||
}
|
||||
|
||||
if got := result.Think.String(); got != "high" {
|
||||
t.Fatalf("expected think level 'high' for xhigh effort, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_ThinkingDisabledOverridesOutputConfigEffort(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "gemma4",
|
||||
MaxTokens: 32000,
|
||||
Messages: []MessageParam{
|
||||
{
|
||||
Role: "user",
|
||||
Content: textContent("Describe the image."),
|
||||
},
|
||||
},
|
||||
Thinking: &ThinkingConfig{
|
||||
Type: "disabled",
|
||||
},
|
||||
OutputConfig: &OutputConfig{
|
||||
Effort: "high",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Think == nil {
|
||||
t.Fatal("expected think to be set")
|
||||
}
|
||||
|
||||
if got := result.Think.Value; got != false {
|
||||
t.Fatalf("expected think=false when thinking is disabled, got %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_ThinkingAdaptiveUsesOutputConfigEffort(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "gemma4",
|
||||
MaxTokens: 32000,
|
||||
Messages: []MessageParam{
|
||||
{
|
||||
Role: "user",
|
||||
Content: textContent("Describe the image."),
|
||||
},
|
||||
},
|
||||
Thinking: &ThinkingConfig{
|
||||
Type: "adaptive",
|
||||
},
|
||||
OutputConfig: &OutputConfig{
|
||||
Effort: "high",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Think == nil {
|
||||
t.Fatal("expected think to be set from output_config.effort")
|
||||
}
|
||||
|
||||
if got := result.Think.String(); got != "high" {
|
||||
t.Fatalf("expected think level 'high' for adaptive thinking, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithTools(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
|
|
|
|||
|
|
@ -98,7 +98,8 @@ func (r *Gemma4Renderer) Render(messages []api.Message, tools []api.Tool, thinkV
|
|||
toolResponsesEmitted := false
|
||||
if len(message.ToolCalls) > 0 {
|
||||
for k := i + 1; k < len(loopMessages) && loopMessages[k].Role == "tool"; k++ {
|
||||
sb.WriteString(r.formatToolResponseBlock(r.toolResponseName(loopMessages[k], message.ToolCalls), loopMessages[k].Content))
|
||||
response := r.renderToolResponseContent(loopMessages[k], &imageOffset)
|
||||
sb.WriteString(r.formatToolResponseBlock(r.toolResponseName(loopMessages[k], message.ToolCalls), response))
|
||||
toolResponsesEmitted = true
|
||||
prevMessageType = "tool_response"
|
||||
}
|
||||
|
|
@ -160,19 +161,22 @@ func stripThinking(text string) string {
|
|||
// When trim is true, leading/trailing whitespace is stripped (matching the Jinja2
|
||||
// template's | trim filter applied to non-model content).
|
||||
func (r *Gemma4Renderer) renderContent(sb *strings.Builder, msg api.Message, imageOffset *int, trim bool) {
|
||||
if len(msg.Images) > 0 && r.useImgTags {
|
||||
for range msg.Images {
|
||||
sb.WriteString(fmt.Sprintf("[img-%d]", *imageOffset))
|
||||
*imageOffset++
|
||||
}
|
||||
}
|
||||
content := msg.Content
|
||||
if trim {
|
||||
content = strings.TrimSpace(content)
|
||||
}
|
||||
if len(msg.Images) > 0 && r.useImgTags {
|
||||
content, *imageOffset = renderContentWithImageTags(content, len(msg.Images), *imageOffset)
|
||||
}
|
||||
sb.WriteString(content)
|
||||
}
|
||||
|
||||
func (r *Gemma4Renderer) renderToolResponseContent(msg api.Message, imageOffset *int) string {
|
||||
var sb strings.Builder
|
||||
r.renderContent(&sb, msg, imageOffset, false)
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func (r *Gemma4Renderer) previousNonToolRole(messages []api.Message, idx int) string {
|
||||
for i := idx - 1; i >= 0; i-- {
|
||||
if messages[i].Role != "tool" {
|
||||
|
|
|
|||
|
|
@ -13,15 +13,11 @@ type GlmOcrRenderer struct {
|
|||
}
|
||||
|
||||
func (r *GlmOcrRenderer) renderContent(message api.Message, imageOffset int) (string, int) {
|
||||
var sb strings.Builder
|
||||
for range message.Images {
|
||||
if r.useImgTags {
|
||||
sb.WriteString(fmt.Sprintf("[img-%d]", imageOffset))
|
||||
imageOffset++
|
||||
}
|
||||
if r.useImgTags {
|
||||
return renderContentWithImageTags(message.Content, len(message.Images), imageOffset)
|
||||
}
|
||||
sb.WriteString(message.Content)
|
||||
return sb.String(), imageOffset
|
||||
|
||||
return message.Content, imageOffset
|
||||
}
|
||||
|
||||
func (r *GlmOcrRenderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
|
||||
|
|
@ -85,8 +81,10 @@ func (r *GlmOcrRenderer) Render(messages []api.Message, tools []api.Tool, thinkV
|
|||
if i == 0 || messages[i-1].Role != "tool" {
|
||||
sb.WriteString("<|observation|>")
|
||||
}
|
||||
content, nextOffset := r.renderContent(message, imageOffset)
|
||||
imageOffset = nextOffset
|
||||
sb.WriteString("\n<tool_response>\n")
|
||||
sb.WriteString(message.Content)
|
||||
sb.WriteString(content)
|
||||
sb.WriteString("\n</tool_response>\n")
|
||||
case "system":
|
||||
sb.WriteString("<|system|>\n")
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ func TestGlmOcrRenderer_Images(t *testing.T) {
|
|||
Images: []api.ImageData{api.ImageData("img1")},
|
||||
},
|
||||
},
|
||||
expected: "[gMASK]<sop><|user|>\n[img-0]Describe this image.<|assistant|>\n",
|
||||
expected: "[gMASK]<sop><|user|>\n[img-0] Describe this image.<|assistant|>\n",
|
||||
},
|
||||
{
|
||||
name: "use_img_tags_multiple_images",
|
||||
|
|
@ -37,7 +37,7 @@ func TestGlmOcrRenderer_Images(t *testing.T) {
|
|||
Images: []api.ImageData{api.ImageData("img1"), api.ImageData("img2")},
|
||||
},
|
||||
},
|
||||
expected: "[gMASK]<sop><|user|>\n[img-0][img-1]Describe these images.<|assistant|>\n",
|
||||
expected: "[gMASK]<sop><|user|>\n[img-0][img-1] Describe these images.<|assistant|>\n",
|
||||
},
|
||||
{
|
||||
name: "multi_turn_increments_image_offset",
|
||||
|
|
@ -58,7 +58,7 @@ func TestGlmOcrRenderer_Images(t *testing.T) {
|
|||
Images: []api.ImageData{api.ImageData("img2")},
|
||||
},
|
||||
},
|
||||
expected: "[gMASK]<sop><|user|>\n[img-0]First image<|assistant|>\n<think></think>\nProcessed.\n<|user|>\n[img-1]Second image<|assistant|>\n",
|
||||
expected: "[gMASK]<sop><|user|>\n[img-0] First image<|assistant|>\n<think></think>\nProcessed.\n<|user|>\n[img-1] Second image<|assistant|>\n",
|
||||
},
|
||||
{
|
||||
name: "default_no_img_tags",
|
||||
|
|
|
|||
39
model/renderers/image_tags.go
Normal file
39
model/renderers/image_tags.go
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
package renderers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// renderContentWithImageTags inserts numbered image tags ("[img-N]") into
// content for imageCount images, numbering from imageOffset, and returns the
// resulting text together with the offset to use for the next message.
//
// Rules, kept compatible with the legacy server-side placeholder handling:
//   - with no images the content and offset are returned untouched;
//   - content that already contains a numbered tag ("[img-") is left alone,
//     but the offset still advances by imageCount;
//   - otherwise each image consumes the first remaining "[img]" placeholder,
//     and images without a placeholder are emitted as a prefix, separated
//     from non-empty content by a single space unless the content already
//     starts with whitespace.
func renderContentWithImageTags(content string, imageCount int, imageOffset int) (string, int) {
	const placeholder = "[img]"

	nextOffset := imageOffset + imageCount
	switch {
	case imageCount == 0:
		return content, imageOffset
	case strings.Contains(content, "[img-"):
		return content, nextOffset
	}

	var leading strings.Builder
	for n := imageOffset; n < nextOffset; n++ {
		tag := fmt.Sprintf("[img-%d]", n)
		at := strings.Index(content, placeholder)
		if at < 0 {
			// No placeholder left for this image: it goes in front.
			leading.WriteString(tag)
			continue
		}
		content = content[:at] + tag + content[at+len(placeholder):]
	}

	if leading.Len() > 0 && content != "" {
		first, _ := utf8.DecodeRuneInString(content)
		// Add a separator unless the first rune is invalid or already whitespace.
		if !(first == utf8.RuneError || unicode.IsSpace(first)) {
			leading.WriteByte(' ')
		}
	}

	return leading.String() + content, nextOffset
}
|
||||
67
model/renderers/image_tags_test.go
Normal file
67
model/renderers/image_tags_test.go
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
package renderers
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestRenderContentWithImageTags(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
imageCount int
|
||||
imageOffset int
|
||||
want string
|
||||
wantOffset int
|
||||
}{
|
||||
{
|
||||
name: "prefixes when there are no placeholders",
|
||||
content: "describe this image",
|
||||
imageCount: 2,
|
||||
imageOffset: 0,
|
||||
want: "[img-0][img-1] describe this image",
|
||||
wantOffset: 2,
|
||||
},
|
||||
{
|
||||
name: "replaces explicit placeholders in order",
|
||||
content: "compare [img] and [img]",
|
||||
imageCount: 2,
|
||||
imageOffset: 3,
|
||||
want: "compare [img-3] and [img-4]",
|
||||
wantOffset: 5,
|
||||
},
|
||||
{
|
||||
name: "prefixes extra images after placeholders are exhausted",
|
||||
content: "compare [img]",
|
||||
imageCount: 2,
|
||||
imageOffset: 0,
|
||||
want: "[img-1] compare [img-0]",
|
||||
wantOffset: 2,
|
||||
},
|
||||
{
|
||||
name: "leaves leftover placeholders when there are fewer images",
|
||||
content: "compare [img] and [img]",
|
||||
imageCount: 1,
|
||||
imageOffset: 0,
|
||||
want: "compare [img-0] and [img]",
|
||||
wantOffset: 1,
|
||||
},
|
||||
{
|
||||
name: "preserves already-numbered placeholders",
|
||||
content: "compare [img-0] and [img-1]",
|
||||
imageCount: 2,
|
||||
imageOffset: 0,
|
||||
want: "compare [img-0] and [img-1]",
|
||||
wantOffset: 2,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, gotOffset := renderContentWithImageTags(tt.content, tt.imageCount, tt.imageOffset)
|
||||
if got != tt.want {
|
||||
t.Fatalf("content = %q, want %q", got, tt.want)
|
||||
}
|
||||
if gotOffset != tt.wantOffset {
|
||||
t.Fatalf("offset = %d, want %d", gotOffset, tt.wantOffset)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -3,7 +3,6 @@ package renderers
|
|||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
|
|
@ -199,19 +198,18 @@ func (r *LFM2Renderer) renderMessageContent(message api.Message, imageOffset int
|
|||
return content
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
if r.useImgTags {
|
||||
for i := range message.Images {
|
||||
sb.WriteString(fmt.Sprintf("[img-%d]", imageOffset+i))
|
||||
}
|
||||
} else {
|
||||
placeholder := lfm2ImagePlaceholder(false)
|
||||
if strings.Contains(content, placeholder) {
|
||||
return content
|
||||
}
|
||||
for range message.Images {
|
||||
sb.WriteString(placeholder)
|
||||
}
|
||||
content, _ = renderContentWithImageTags(content, len(message.Images), imageOffset)
|
||||
return content
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
placeholder := lfm2ImagePlaceholder(false)
|
||||
if strings.Contains(content, placeholder) {
|
||||
return content
|
||||
}
|
||||
for range message.Images {
|
||||
sb.WriteString(placeholder)
|
||||
}
|
||||
sb.WriteString(content)
|
||||
return sb.String()
|
||||
|
|
|
|||
|
|
@ -236,7 +236,7 @@ func TestLFM2Renderer_Images(t *testing.T) {
|
|||
Content: "Describe this image.",
|
||||
Images: []api.ImageData{api.ImageData("img1")},
|
||||
},
|
||||
expected: "<|startoftext|><|im_start|>user\n[img-0]Describe this image.<|im_end|>\n<|im_start|>assistant\n",
|
||||
expected: "<|startoftext|><|im_start|>user\n[img-0] Describe this image.<|im_end|>\n<|im_start|>assistant\n",
|
||||
},
|
||||
{
|
||||
name: "existing_template_image_placeholder_not_duplicated",
|
||||
|
|
|
|||
|
|
@ -79,12 +79,14 @@ func (r *Nemotron3NanoRenderer) Render(messages []api.Message, tools []api.Tool,
|
|||
// Check if previous message was also a tool message
|
||||
prevWasTool := i > 0 && loopMessages[i-1].Role == "tool"
|
||||
nextIsTool := i+1 < len(loopMessages) && loopMessages[i+1].Role == "tool"
|
||||
content := r.renderMessageContent(message, imageOffset)
|
||||
imageOffset += len(message.Images)
|
||||
|
||||
if !prevWasTool {
|
||||
sb.WriteString("<|im_start|>user\n")
|
||||
}
|
||||
sb.WriteString("<tool_response>\n")
|
||||
sb.WriteString(message.Content)
|
||||
sb.WriteString(content)
|
||||
sb.WriteString("\n</tool_response>\n")
|
||||
|
||||
if !nextIsTool {
|
||||
|
|
@ -237,23 +239,8 @@ func (r *Nemotron3NanoRenderer) renderMessageContent(message api.Message, imageO
|
|||
return content
|
||||
}
|
||||
|
||||
if strings.Contains(content, "[img-") {
|
||||
return content
|
||||
}
|
||||
|
||||
if strings.Contains(content, "[img]") {
|
||||
for i := range message.Images {
|
||||
content = strings.Replace(content, "[img]", fmt.Sprintf("[img-%d]", imageOffset+i), 1)
|
||||
}
|
||||
return content
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
for i := range message.Images {
|
||||
sb.WriteString(fmt.Sprintf("[img-%d]", imageOffset+i))
|
||||
}
|
||||
sb.WriteString(content)
|
||||
return sb.String()
|
||||
content, _ = renderContentWithImageTags(content, len(message.Images), imageOffset)
|
||||
return content
|
||||
}
|
||||
|
||||
func nemotron3NanoRenderContent(content any) string {
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ func TestNemotron3NanoRenderer_Images(t *testing.T) {
|
|||
msgs: []api.Message{
|
||||
{Role: "user", Content: "Describe this image.", Images: []api.ImageData{api.ImageData("img1")}},
|
||||
},
|
||||
expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\n[img-0]Describe this image.<|im_end|>\n\n<|im_start|>assistant\n<think>\n",
|
||||
expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\n[img-0] Describe this image.<|im_end|>\n\n<|im_start|>assistant\n<think>\n",
|
||||
},
|
||||
{
|
||||
name: "generic image placeholder is rewritten",
|
||||
|
|
@ -35,7 +35,7 @@ func TestNemotron3NanoRenderer_Images(t *testing.T) {
|
|||
{Role: "assistant", Content: "It shows something."},
|
||||
{Role: "user", Content: "Compare these.", Images: []api.ImageData{api.ImageData("img2"), api.ImageData("img3")}},
|
||||
},
|
||||
expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\n[img-0]Describe the first image.<|im_end|>\n<|im_start|>assistant\n<think></think>It shows something.<|im_end|>\n<|im_start|>user\n[img-1][img-2]Compare these.<|im_end|>\n\n<|im_start|>assistant\n<think>\n",
|
||||
expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\n[img-0] Describe the first image.<|im_end|>\n<|im_start|>assistant\n<think></think>It shows something.<|im_end|>\n<|im_start|>user\n[img-1][img-2] Compare these.<|im_end|>\n\n<|im_start|>assistant\n<think>\n",
|
||||
},
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
package renderers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
|
|
@ -45,15 +44,14 @@ type Qwen35Renderer struct {
|
|||
}
|
||||
|
||||
func (r *Qwen35Renderer) renderContent(content api.Message, imageOffset int) (string, int) {
|
||||
if r.useImgTags {
|
||||
return renderContentWithImageTags(content.Content, len(content.Images), imageOffset)
|
||||
}
|
||||
|
||||
// This assumes all images are at the front of the message - same assumption as ollama/ollama/runner.go
|
||||
var subSb strings.Builder
|
||||
for range content.Images {
|
||||
if r.useImgTags {
|
||||
subSb.WriteString(fmt.Sprintf("[img-%d]", imageOffset))
|
||||
imageOffset++
|
||||
} else {
|
||||
subSb.WriteString("<|vision_start|><|image_pad|><|vision_end|>")
|
||||
}
|
||||
subSb.WriteString("<|vision_start|><|image_pad|><|vision_end|>")
|
||||
}
|
||||
// TODO: support videos
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
package renderers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
|
|
@ -15,18 +14,17 @@ type Qwen3VLRenderer struct {
|
|||
}
|
||||
|
||||
func (r *Qwen3VLRenderer) renderContent(content api.Message, imageOffset int) (string, int) {
|
||||
if r.useImgTags {
|
||||
return renderContentWithImageTags(content.Content, len(content.Images), imageOffset)
|
||||
}
|
||||
|
||||
// This assumes all images are at the front of the message - same assumption as ollama/ollama/runner.go
|
||||
var subSb strings.Builder
|
||||
for range content.Images {
|
||||
// TODO: (jmorganca): how to render this is different for different
|
||||
// model backends, and so we should eventually parameterize this or
|
||||
// only output a placeholder such as [img]
|
||||
if r.useImgTags {
|
||||
subSb.WriteString(fmt.Sprintf("[img-%d]", imageOffset))
|
||||
imageOffset++
|
||||
} else {
|
||||
subSb.WriteString("<|vision_start|><|image_pad|><|vision_end|>")
|
||||
}
|
||||
subSb.WriteString("<|vision_start|><|image_pad|><|vision_end|>")
|
||||
}
|
||||
// TODO: support videos
|
||||
|
||||
|
|
@ -126,7 +124,7 @@ func (r *Qwen3VLRenderer) Render(messages []api.Message, tools []api.Tool, think
|
|||
if i == 0 || messages[i-1].Role != "tool" {
|
||||
sb.WriteString("<|im_start|>user")
|
||||
}
|
||||
sb.WriteString("\n<tool_response>\n" + message.Content + "\n</tool_response>")
|
||||
sb.WriteString("\n<tool_response>\n" + content + "\n</tool_response>")
|
||||
if i == len(messages)-1 || messages[i+1].Role != "tool" {
|
||||
sb.WriteString("<|im_end|>\n")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -101,7 +101,7 @@ Let me analyze this image.`,
|
|||
},
|
||||
useImgTags: true,
|
||||
expected: `<|im_start|>user
|
||||
[img-0]Describe this image.<|im_end|>
|
||||
[img-0] Describe this image.<|im_end|>
|
||||
<|im_start|>assistant
|
||||
Let me analyze this image.`,
|
||||
},
|
||||
|
|
@ -123,7 +123,7 @@ Let me analyze this image.`,
|
|||
},
|
||||
useImgTags: true,
|
||||
expected: `<|im_start|>user
|
||||
[img-0][img-1]Describe these images.<|im_end|>
|
||||
[img-0][img-1] Describe these images.<|im_end|>
|
||||
<|im_start|>assistant
|
||||
Let me analyze this image.`,
|
||||
},
|
||||
|
|
|
|||
|
|
@ -75,7 +75,9 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
|||
slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[currMsgIdx:]))
|
||||
}
|
||||
|
||||
for cnt, msg := range msgs[currMsgIdx:] {
|
||||
renderMsgs := slices.Clone(msgs)
|
||||
|
||||
for cnt, msg := range renderMsgs[currMsgIdx:] {
|
||||
if slices.Contains(m.Config.ModelFamilies, "mllama") && len(msg.Images) > 1 {
|
||||
return "", nil, errors.New("this model only supports one image while more than one image requested")
|
||||
}
|
||||
|
|
@ -101,11 +103,16 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
|||
prompt = strings.Replace(prompt, "[img]", imgTag, 1)
|
||||
}
|
||||
}
|
||||
msgs[currMsgIdx+cnt].Content = prefix + prompt
|
||||
|
||||
if m.Config.Renderer != "" {
|
||||
continue
|
||||
}
|
||||
|
||||
renderMsgs[currMsgIdx+cnt].Content = prefix + prompt
|
||||
}
|
||||
|
||||
// truncate any messages that do not fit into the context window
|
||||
p, err := renderPrompt(m, append(system, msgs[currMsgIdx:]...), tools, think)
|
||||
p, err := renderPrompt(m, append(system, renderMsgs[currMsgIdx:]...), tools, think)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -401,11 +401,170 @@ func TestChatPromptGLMOcrRendererAddsImageTags(t *testing.T) {
|
|||
t.Fatalf("len(images) = %d, want %d", got, want)
|
||||
}
|
||||
|
||||
if !strings.Contains(prompt, "<|user|>\n[img-0][img-1]extract text") {
|
||||
if !strings.Contains(prompt, "<|user|>\n[img-0][img-1] extract text") {
|
||||
t.Fatalf("prompt missing glm-ocr image tags, got: %q", prompt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatPromptRendererAddsToolImageTags(t *testing.T) {
|
||||
msgs := []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "look at this file",
|
||||
Images: []api.ImageData{[]byte("img-1")},
|
||||
},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
ID: "call_read",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "Read",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: "tool",
|
||||
Content: "attached image",
|
||||
Images: []api.ImageData{[]byte("img-2")},
|
||||
ToolCallID: "call_read",
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
renderer string
|
||||
wantUserTag string
|
||||
wantToolContent string
|
||||
}{
|
||||
{
|
||||
name: "gemma4",
|
||||
renderer: "gemma4",
|
||||
wantUserTag: "<|turn>user\n[img-0] look at this file<turn|>\n",
|
||||
wantToolContent: "[img-1] attached image",
|
||||
},
|
||||
{
|
||||
name: "qwen3-vl",
|
||||
renderer: "qwen3-vl-instruct",
|
||||
wantUserTag: "<|im_start|>user\n[img-0] look at this file<|im_end|>\n",
|
||||
wantToolContent: "<tool_response>\n[img-1] attached image\n</tool_response>",
|
||||
},
|
||||
{
|
||||
name: "qwen3.5",
|
||||
renderer: "qwen3.5",
|
||||
wantUserTag: "<|im_start|>user\n[img-0] look at this file<|im_end|>\n",
|
||||
wantToolContent: "<tool_response>\n[img-1] attached image\n</tool_response>",
|
||||
},
|
||||
{
|
||||
name: "glm-ocr",
|
||||
renderer: "glm-ocr",
|
||||
wantUserTag: "<|user|>\n[img-0] look at this file",
|
||||
wantToolContent: "<tool_response>\n[img-1] attached image\n</tool_response>",
|
||||
},
|
||||
{
|
||||
name: "nemotron-3-nano",
|
||||
renderer: "nemotron-3-nano",
|
||||
wantUserTag: "<|im_start|>user\n[img-0] look at this file<|im_end|>\n",
|
||||
wantToolContent: "<tool_response>\n[img-1] attached image\n</tool_response>",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m := Model{
|
||||
Config: model.ConfigV2{Renderer: tt.renderer},
|
||||
ProjectorPaths: []string{"vision"},
|
||||
}
|
||||
opts := api.Options{Runner: api.Runner{NumCtx: 8192}}
|
||||
think := false
|
||||
|
||||
prompt, images, err := chatPrompt(t.Context(), &m, mockRunner{}.Tokenize, &opts, msgs, nil, &api.ThinkValue{Value: think}, true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if got, want := len(images), 2; got != want {
|
||||
t.Fatalf("len(images) = %d, want %d", got, want)
|
||||
}
|
||||
|
||||
if !strings.Contains(prompt, tt.wantUserTag) {
|
||||
t.Fatalf("prompt missing user image tag, got: %q", prompt)
|
||||
}
|
||||
|
||||
if !strings.Contains(prompt, tt.wantToolContent) {
|
||||
t.Fatalf("prompt missing tool image tag, got: %q", prompt)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatPromptRendererPreservesExplicitImagePlaceholders(t *testing.T) {
|
||||
msgs := []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "compare [img] and [img]",
|
||||
Images: []api.ImageData{[]byte("img-1"), []byte("img-2")},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
renderer string
|
||||
wantSnippet string
|
||||
}{
|
||||
{
|
||||
name: "gemma4",
|
||||
renderer: "gemma4",
|
||||
wantSnippet: "<|turn>user\ncompare [img-0] and [img-1]<turn|>\n",
|
||||
},
|
||||
{
|
||||
name: "qwen3-vl",
|
||||
renderer: "qwen3-vl-instruct",
|
||||
wantSnippet: "<|im_start|>user\ncompare [img-0] and [img-1]<|im_end|>\n",
|
||||
},
|
||||
{
|
||||
name: "qwen3.5",
|
||||
renderer: "qwen3.5",
|
||||
wantSnippet: "<|im_start|>user\ncompare [img-0] and [img-1]<|im_end|>\n",
|
||||
},
|
||||
{
|
||||
name: "glm-ocr",
|
||||
renderer: "glm-ocr",
|
||||
wantSnippet: "<|user|>\ncompare [img-0] and [img-1]",
|
||||
},
|
||||
{
|
||||
name: "nemotron-3-nano",
|
||||
renderer: "nemotron-3-nano",
|
||||
wantSnippet: "<|im_start|>user\ncompare [img-0] and [img-1]<|im_end|>\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m := Model{
|
||||
Config: model.ConfigV2{Renderer: tt.renderer},
|
||||
ProjectorPaths: []string{"vision"},
|
||||
}
|
||||
opts := api.Options{Runner: api.Runner{NumCtx: 8192}}
|
||||
think := false
|
||||
|
||||
prompt, images, err := chatPrompt(t.Context(), &m, mockRunner{}.Tokenize, &opts, msgs, nil, &api.ThinkValue{Value: think}, true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if got, want := len(images), 2; got != want {
|
||||
t.Fatalf("len(images) = %d, want %d", got, want)
|
||||
}
|
||||
|
||||
if !strings.Contains(prompt, tt.wantSnippet) {
|
||||
t.Fatalf("prompt missing replaced placeholders, got: %q", prompt)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderPromptResolvesDynamicGemma4Renderer(t *testing.T) {
|
||||
msgs := []api.Message{{Role: "user", Content: "Hello"}}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue