mirror of
https://github.com/ollama/ollama.git
synced 2026-07-05 07:11:10 +00:00
1826 lines
53 KiB
Go
1826 lines
53 KiB
Go
package agent
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/ollama/ollama/api"
|
|
)
|
|
|
|
type fakeClient struct {
|
|
calls int
|
|
responses [][]api.ChatResponse
|
|
requests []*api.ChatRequest
|
|
err error
|
|
}
|
|
|
|
func (c *fakeClient) Chat(ctx context.Context, req *api.ChatRequest, fn api.ChatResponseFunc) error {
|
|
c.requests = append(c.requests, req)
|
|
if c.calls >= len(c.responses) {
|
|
return nil
|
|
}
|
|
responses := c.responses[c.calls]
|
|
c.calls++
|
|
for _, response := range responses {
|
|
if err := fn(response); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return c.err
|
|
}
|
|
|
|
// staticTool is the "echo_tool" test double; its Execute returns a fixed string.
type staticTool struct{}
|
|
|
|
// approvalTestTool always requires approval. When called is non-nil, Execute
// sets *called to true so a test can observe whether the tool actually ran.
type approvalTestTool struct {
	called *bool
}
|
|
|
|
// cwdTestTool reads and reports working-directory changes via ToolContext and ToolResult.
type cwdTestTool struct{}
|
|
|
|
// largeTool returns output longer than maxToolResultRunes.
type largeTool struct{}
|
|
|
|
// preTruncatedTool returns output that already embeds a truncation marker.
type preTruncatedTool struct{}
|
|
|
|
// cancelingTool calls cancel from inside Execute and then waits for its
// context to finish.
type cancelingTool struct {
	cancel context.CancelFunc
}
|
|
|
|
// cancelAfterToolCallClient is a chat client that streams a single echo_tool
// call, then invokes cancel and reports context.Canceled.
type cancelAfterToolCallClient struct {
	cancel context.CancelFunc
}
|
|
|
|
type recordingCompactor struct {
|
|
requests []CompactionRequest
|
|
}
|
|
|
|
type oversizedCompactor struct {
|
|
requests []CompactionRequest
|
|
}
|
|
|
|
type recordingEventSink struct {
|
|
events []Event
|
|
}
|
|
|
|
func (s *recordingEventSink) Emit(event Event) error {
|
|
s.events = append(s.events, event)
|
|
return nil
|
|
}
|
|
|
|
func hasEventType(events []Event, eventType EventType) bool {
|
|
for _, event := range events {
|
|
if event.Type == eventType {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func hasEventWithTokens(events []Event, eventType EventType, tokens int) bool {
|
|
for _, event := range events {
|
|
if event.Type == eventType && event.Tokens == tokens {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func TestSessionEmitsToAllSinksAfterError(t *testing.T) {
|
|
errSink := EventSinkFunc(func(Event) error {
|
|
return errors.New("sink failed")
|
|
})
|
|
events := &recordingEventSink{}
|
|
session := &Session{EventSinks: []EventSink{errSink, events}}
|
|
|
|
err := session.emit(Event{Type: EventRunFinished})
|
|
if err == nil {
|
|
t.Fatal("emit should return the first sink error")
|
|
}
|
|
if !hasEventType(events.events, EventRunFinished) {
|
|
t.Fatalf("later sink did not receive event after earlier error: %#v", events.events)
|
|
}
|
|
}
|
|
|
|
func (c cancelAfterToolCallClient) Chat(ctx context.Context, req *api.ChatRequest, fn api.ChatResponseFunc) error {
|
|
args := api.NewToolCallFunctionArguments()
|
|
args.Set("value", "skip me")
|
|
if err := fn(api.ChatResponse{Message: api.Message{Role: "assistant", ToolCalls: []api.ToolCall{{
|
|
ID: "call-1",
|
|
Function: api.ToolCallFunction{
|
|
Name: "echo_tool",
|
|
Arguments: args,
|
|
},
|
|
}}}}); err != nil {
|
|
return err
|
|
}
|
|
c.cancel()
|
|
return context.Canceled
|
|
}
|
|
|
|
func (c *recordingCompactor) MaybeCompact(_ context.Context, req CompactionRequest) (CompactionResult, error) {
|
|
c.requests = append(c.requests, req)
|
|
result := CompactionResult{Messages: req.Messages, Due: true}
|
|
if len(req.Messages) > 0 && req.Messages[len(req.Messages)-1].Role == "tool" {
|
|
result.Messages = CompactionSummaryMessages("tool result summarized", false)
|
|
result.Compacted = true
|
|
result.Summary = "tool result summarized"
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
func (c *oversizedCompactor) MaybeCompact(_ context.Context, req CompactionRequest) (CompactionResult, error) {
|
|
c.requests = append(c.requests, req)
|
|
summary := strings.Repeat("oversized summary ", 300)
|
|
return CompactionResult{
|
|
Messages: CompactionSummaryMessages(summary, req.ContinueTask),
|
|
Compacted: true,
|
|
Due: true,
|
|
Summary: summary,
|
|
}, nil
|
|
}
|
|
|
|
type recordingApprovalPrompter struct {
|
|
requests []ApprovalRequest
|
|
results []Approval
|
|
}
|
|
|
|
func (staticTool) Name() string {
|
|
return "echo_tool"
|
|
}
|
|
|
|
func (staticTool) Description() string {
|
|
return "echoes a value"
|
|
}
|
|
|
|
func (staticTool) Schema() api.ToolFunction {
|
|
props := api.NewToolPropertiesMap()
|
|
props.Set("value", api.ToolProperty{Type: api.PropertyType{"string"}})
|
|
return api.ToolFunction{
|
|
Name: "echo_tool",
|
|
Description: "echoes a value",
|
|
Parameters: api.ToolFunctionParameters{
|
|
Type: "object",
|
|
Properties: props,
|
|
},
|
|
}
|
|
}
|
|
|
|
func (staticTool) Execute(context.Context, ToolContext, map[string]any) (ToolResult, error) {
|
|
return ToolResult{Content: "tool says hello"}, nil
|
|
}
|
|
|
|
func (largeTool) Name() string {
|
|
return "large_tool"
|
|
}
|
|
|
|
func (largeTool) Description() string {
|
|
return "returns a large result"
|
|
}
|
|
|
|
func (largeTool) Schema() api.ToolFunction {
|
|
return api.ToolFunction{
|
|
Name: "large_tool",
|
|
Description: "returns a large result",
|
|
Parameters: api.ToolFunctionParameters{
|
|
Type: "object",
|
|
},
|
|
}
|
|
}
|
|
|
|
func (largeTool) Execute(context.Context, ToolContext, map[string]any) (ToolResult, error) {
|
|
return ToolResult{Content: strings.Repeat("x", maxToolResultRunes+100)}, nil
|
|
}
|
|
|
|
func (preTruncatedTool) Name() string {
|
|
return "pre_truncated_tool"
|
|
}
|
|
|
|
func (preTruncatedTool) Description() string {
|
|
return "returns a large result that is already marked as truncated"
|
|
}
|
|
|
|
func (preTruncatedTool) Schema() api.ToolFunction {
|
|
return api.ToolFunction{
|
|
Name: "pre_truncated_tool",
|
|
Description: "returns a large result that is already marked as truncated",
|
|
Parameters: api.ToolFunctionParameters{
|
|
Type: "object",
|
|
},
|
|
}
|
|
}
|
|
|
|
func (preTruncatedTool) Execute(context.Context, ToolContext, map[string]any) (ToolResult, error) {
|
|
content := strings.Repeat("x", smallContextToolResultRunes) +
|
|
"\n\n[tool output truncated: showing first ~1500 tokens; omitted ~999 tokens. Use a narrower command, line range, or search query if more detail is needed.]\n\n" +
|
|
strings.Repeat("y", smallContextToolResultRunes)
|
|
return ToolResult{Content: content}, nil
|
|
}
|
|
|
|
func (t cancelingTool) Name() string {
|
|
return "cancel_tool"
|
|
}
|
|
|
|
func (t cancelingTool) Description() string {
|
|
return "cancels while running"
|
|
}
|
|
|
|
func (t cancelingTool) Schema() api.ToolFunction {
|
|
return api.ToolFunction{
|
|
Name: t.Name(),
|
|
Description: t.Description(),
|
|
Parameters: api.ToolFunctionParameters{
|
|
Type: "object",
|
|
},
|
|
}
|
|
}
|
|
|
|
func (t cancelingTool) Execute(ctx context.Context, _ ToolContext, _ map[string]any) (ToolResult, error) {
|
|
t.cancel()
|
|
<-ctx.Done()
|
|
return ToolResult{}, ctx.Err()
|
|
}
|
|
|
|
func (t approvalTestTool) Name() string {
|
|
return "approval_tool"
|
|
}
|
|
|
|
func (t approvalTestTool) Description() string {
|
|
return "requires approval"
|
|
}
|
|
|
|
func (t approvalTestTool) Schema() api.ToolFunction {
|
|
return api.ToolFunction{
|
|
Name: "approval_tool",
|
|
Description: "requires approval",
|
|
Parameters: api.ToolFunctionParameters{
|
|
Type: "object",
|
|
},
|
|
}
|
|
}
|
|
|
|
func (t approvalTestTool) RequiresApproval(map[string]any) bool {
|
|
return true
|
|
}
|
|
|
|
func (t approvalTestTool) Execute(context.Context, ToolContext, map[string]any) (ToolResult, error) {
|
|
if t.called != nil {
|
|
*t.called = true
|
|
}
|
|
return ToolResult{Content: "approved"}, nil
|
|
}
|
|
|
|
func (p *recordingApprovalPrompter) PromptApproval(_ context.Context, req ApprovalRequest) (Approval, error) {
|
|
p.requests = append(p.requests, req)
|
|
if len(p.results) == 0 {
|
|
return Approval{Allow: true}, nil
|
|
}
|
|
result := p.results[0]
|
|
p.results = p.results[1:]
|
|
return result, nil
|
|
}
|
|
|
|
func (cwdTestTool) Name() string {
|
|
return "cwd_tool"
|
|
}
|
|
|
|
func (cwdTestTool) Description() string {
|
|
return "tests cwd state"
|
|
}
|
|
|
|
func (cwdTestTool) Schema() api.ToolFunction {
|
|
props := api.NewToolPropertiesMap()
|
|
props.Set("mode", api.ToolProperty{Type: api.PropertyType{"string"}})
|
|
props.Set("path", api.ToolProperty{Type: api.PropertyType{"string"}})
|
|
return api.ToolFunction{
|
|
Name: "cwd_tool",
|
|
Description: "tests cwd state",
|
|
Parameters: api.ToolFunctionParameters{
|
|
Type: "object",
|
|
Properties: props,
|
|
},
|
|
}
|
|
}
|
|
|
|
func (cwdTestTool) RequiresApproval(map[string]any) bool {
|
|
return true
|
|
}
|
|
|
|
func (cwdTestTool) Execute(_ context.Context, toolCtx ToolContext, args map[string]any) (ToolResult, error) {
|
|
switch args["mode"] {
|
|
case "set":
|
|
path, _ := args["path"].(string)
|
|
return ToolResult{Content: "changed", WorkingDir: filepath.Join(toolCtx.WorkingDir, path)}, nil
|
|
case "escape":
|
|
return ToolResult{Content: "escaped", WorkingDir: filepath.Dir(toolCtx.WorkingDir)}, nil
|
|
default:
|
|
return ToolResult{Content: toolCtx.WorkingDir}, nil
|
|
}
|
|
}
|
|
|
|
func TestSessionRunsToolLoop(t *testing.T) {
|
|
args := api.NewToolCallFunctionArguments()
|
|
args.Set("value", "hello")
|
|
|
|
client := &fakeClient{
|
|
responses: [][]api.ChatResponse{
|
|
{
|
|
{Message: api.Message{Role: "assistant", ToolCalls: []api.ToolCall{{
|
|
ID: "call-1",
|
|
Function: api.ToolCallFunction{
|
|
Name: "echo_tool",
|
|
Arguments: args,
|
|
},
|
|
}}}},
|
|
},
|
|
{
|
|
{Message: api.Message{Role: "assistant", Content: "done"}},
|
|
},
|
|
},
|
|
}
|
|
|
|
registry := &Registry{}
|
|
registry.Register(staticTool{})
|
|
|
|
session := &Session{
|
|
Client: client,
|
|
Tools: registry,
|
|
AllowAllTools: true,
|
|
}
|
|
|
|
result, err := session.Run(context.Background(), RunOptions{
|
|
ChatID: "chat-1",
|
|
Model: "model",
|
|
NewMessages: []api.Message{{Role: "user", Content: "use a tool"}},
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if client.calls != 2 {
|
|
t.Fatalf("client calls = %d, want 2", client.calls)
|
|
}
|
|
if len(result.Messages) != 4 {
|
|
t.Fatalf("messages = %d, want 4", len(result.Messages))
|
|
}
|
|
if result.Messages[2].Role != "tool" || result.Messages[2].Content != "tool says hello" {
|
|
t.Fatalf("tool message = %#v", result.Messages[2])
|
|
}
|
|
if len(client.requests[0].Tools) != 1 {
|
|
t.Fatalf("first request tools = %d, want 1", len(client.requests[0].Tools))
|
|
}
|
|
if len(client.requests[1].Messages) != 3 {
|
|
t.Fatalf("second request messages = %d, want 3", len(client.requests[1].Messages))
|
|
}
|
|
}
|
|
|
|
func TestSessionAddsSystemPromptOnlyToRequest(t *testing.T) {
|
|
client := &fakeClient{
|
|
responses: [][]api.ChatResponse{
|
|
{{Message: api.Message{Role: "assistant", Content: "done"}}},
|
|
},
|
|
}
|
|
session := &Session{Client: client}
|
|
|
|
_, err := session.Run(context.Background(), RunOptions{
|
|
ChatID: "chat-1",
|
|
Model: "model",
|
|
SystemPrompt: "available context: go-code",
|
|
NewMessages: []api.Message{{Role: "user", Content: "hello"}},
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(client.requests) != 1 {
|
|
t.Fatalf("requests = %d, want 1", len(client.requests))
|
|
}
|
|
reqMessages := client.requests[0].Messages
|
|
if len(reqMessages) != 2 || reqMessages[0].Role != "system" || reqMessages[0].Content != "available context: go-code" {
|
|
t.Fatalf("request messages = %#v", reqMessages)
|
|
}
|
|
}
|
|
|
|
func TestSessionAccumulatesStreamingAssistantMessage(t *testing.T) {
|
|
args := api.NewToolCallFunctionArguments()
|
|
args.Set("value", "hello")
|
|
responses := make([]api.ChatResponse, 0, 100)
|
|
var wantContent, wantThinking string
|
|
for range 99 {
|
|
wantContent += "x"
|
|
wantThinking += "t"
|
|
responses = append(responses, api.ChatResponse{
|
|
Message: api.Message{Role: "assistant", Content: "x", Thinking: "t"},
|
|
})
|
|
}
|
|
toolCall := api.ToolCall{
|
|
ID: "call-1",
|
|
Function: api.ToolCallFunction{
|
|
Name: "echo_tool",
|
|
Arguments: args,
|
|
},
|
|
}
|
|
responses = append(responses, api.ChatResponse{
|
|
Message: api.Message{Role: "assistant", ToolCalls: []api.ToolCall{toolCall}},
|
|
})
|
|
|
|
session := &Session{
|
|
Client: &fakeClient{responses: [][]api.ChatResponse{responses}},
|
|
}
|
|
|
|
result, err := session.Run(context.Background(), RunOptions{
|
|
ChatID: "chat-1",
|
|
Model: "model",
|
|
NewMessages: []api.Message{{Role: "user", Content: "stream"}},
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if len(result.Messages) != 2 || result.Messages[1].Content != wantContent || result.Messages[1].Thinking != wantThinking || len(result.Messages[1].ToolCalls) != 1 {
|
|
t.Fatalf("result messages = %#v", result.Messages)
|
|
}
|
|
}
|
|
|
|
func TestSessionRequestHistoryKeepsThinkingAndServerToolCallIDs(t *testing.T) {
|
|
args := api.NewToolCallFunctionArguments()
|
|
args.Set("value", "hello")
|
|
client := &fakeClient{
|
|
responses: [][]api.ChatResponse{
|
|
{
|
|
{Message: api.Message{Role: "assistant", Thinking: "private chain"}},
|
|
{Message: api.Message{Role: "assistant", Content: "I'll check."}},
|
|
{Message: api.Message{Role: "assistant", ToolCalls: []api.ToolCall{{
|
|
ID: "volatile-random-id",
|
|
Function: api.ToolCallFunction{
|
|
Name: "echo_tool",
|
|
Arguments: args,
|
|
},
|
|
}}}},
|
|
},
|
|
{{Message: api.Message{Role: "assistant", Content: "done"}}},
|
|
},
|
|
}
|
|
registry := &Registry{}
|
|
registry.Register(staticTool{})
|
|
session := &Session{
|
|
Client: client,
|
|
Tools: registry,
|
|
AllowAllTools: true,
|
|
}
|
|
|
|
result, err := session.Run(context.Background(), RunOptions{
|
|
ChatID: "chat-1",
|
|
Model: "model",
|
|
NewMessages: []api.Message{{Role: "user", Content: "use a tool"}},
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(client.requests) != 2 {
|
|
t.Fatalf("requests = %d, want 2", len(client.requests))
|
|
}
|
|
|
|
secondRequestMessages := client.requests[1].Messages
|
|
if len(secondRequestMessages) != 3 {
|
|
t.Fatalf("second request messages = %#v", secondRequestMessages)
|
|
}
|
|
assistant := secondRequestMessages[1]
|
|
if assistant.Role != "assistant" {
|
|
t.Fatalf("second request assistant = %#v", assistant)
|
|
}
|
|
if assistant.Thinking != "private chain" {
|
|
t.Fatalf("assistant thinking = %q, want preserved", assistant.Thinking)
|
|
}
|
|
if len(assistant.ToolCalls) != 1 || assistant.ToolCalls[0].ID != "volatile-random-id" {
|
|
t.Fatalf("assistant tool calls = %#v", assistant.ToolCalls)
|
|
}
|
|
tool := secondRequestMessages[2]
|
|
if tool.Role != "tool" || tool.ToolCallID != "volatile-random-id" {
|
|
t.Fatalf("tool result message = %#v", tool)
|
|
}
|
|
if len(result.Messages) < 3 || result.Messages[1].Thinking != "private chain" {
|
|
t.Fatalf("visible result messages lost thinking: %#v", result.Messages)
|
|
}
|
|
}
|
|
|
|
func TestSessionKeepsPartialStreamOnCancellation(t *testing.T) {
|
|
session := &Session{
|
|
Client: &fakeClient{
|
|
responses: [][]api.ChatResponse{{
|
|
{Message: api.Message{Role: "assistant", Content: "partial "}},
|
|
{Message: api.Message{Role: "assistant", Content: "answer"}},
|
|
}},
|
|
err: context.Canceled,
|
|
},
|
|
}
|
|
|
|
result, err := session.Run(context.Background(), RunOptions{
|
|
ChatID: "chat-1",
|
|
Model: "model",
|
|
NewMessages: []api.Message{{Role: "user", Content: "cancel"}},
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(result.Messages) != 2 || result.Messages[1].Content != "partial answer" {
|
|
t.Fatalf("result messages = %#v", result.Messages)
|
|
}
|
|
}
|
|
|
|
func TestSessionCancellationKeepsPartialResultWhenUISinkCancels(t *testing.T) {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
cancel()
|
|
trace := &recordingEventSink{}
|
|
session := &Session{
|
|
Client: &fakeClient{
|
|
responses: [][]api.ChatResponse{{
|
|
{Message: api.Message{Role: "assistant", Content: "partial"}},
|
|
}},
|
|
err: context.Canceled,
|
|
},
|
|
EventSinks: []EventSink{
|
|
EventSinkFunc(func(event Event) error {
|
|
if event.Type == EventRunFinished {
|
|
return context.Canceled
|
|
}
|
|
return nil
|
|
}),
|
|
trace,
|
|
},
|
|
}
|
|
|
|
result, err := session.Run(ctx, RunOptions{
|
|
ChatID: "chat-1",
|
|
Model: "model",
|
|
NewMessages: []api.Message{{Role: "user", Content: "cancel"}},
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if result == nil || len(result.Messages) != 2 || result.Messages[1].Content != "partial" {
|
|
t.Fatalf("result messages = %#v, want partial assistant result", result)
|
|
}
|
|
if !hasEventType(trace.events, EventRunFinished) {
|
|
t.Fatalf("trace sink did not receive run finished event: %#v", trace.events)
|
|
}
|
|
}
|
|
|
|
func TestSessionTreatsHTTPContextCanceledStringAsCancellation(t *testing.T) {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
cancel()
|
|
client := &fakeClient{err: errors.New(`Post "http://127.0.0.1:11434/api/chat": context canceled`)}
|
|
session := &Session{Client: client}
|
|
|
|
result, err := session.Run(ctx, RunOptions{
|
|
Model: "model",
|
|
NewMessages: []api.Message{{Role: "user", Content: "hello"}},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Run returned error for canceled HTTP request: %v", err)
|
|
}
|
|
if result == nil {
|
|
t.Fatal("Run returned nil result")
|
|
}
|
|
if len(result.Messages) != 1 || result.Messages[0].Content != "hello" {
|
|
t.Fatalf("messages = %#v, want original user message only", result.Messages)
|
|
}
|
|
}
|
|
|
|
func TestSessionCancellationAfterToolCallAppendsSkippedToolMessage(t *testing.T) {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
registry := &Registry{}
|
|
registry.Register(staticTool{})
|
|
session := &Session{
|
|
Client: cancelAfterToolCallClient{cancel: cancel},
|
|
Tools: registry,
|
|
AllowAllTools: true,
|
|
}
|
|
|
|
result, err := session.Run(ctx, RunOptions{
|
|
ChatID: "chat-1",
|
|
Model: "model",
|
|
NewMessages: []api.Message{{Role: "user", Content: "cancel after tool"}},
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(result.Messages) != 3 {
|
|
t.Fatalf("messages = %#v", result.Messages)
|
|
}
|
|
if len(result.Messages[1].ToolCalls) != 1 {
|
|
t.Fatalf("assistant tool calls = %#v", result.Messages[1])
|
|
}
|
|
if result.Messages[2].Role != "tool" || result.Messages[2].ToolCallID != "call-1" {
|
|
t.Fatalf("skipped tool message = %#v", result.Messages[2])
|
|
}
|
|
if !strings.Contains(result.Messages[2].Content, "run was canceled") {
|
|
t.Fatalf("skipped content = %q", result.Messages[2].Content)
|
|
}
|
|
}
|
|
|
|
func TestSessionCancellationDuringToolExecutionAppendsToolMessage(t *testing.T) {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
events := &recordingEventSink{}
|
|
registry := &Registry{}
|
|
registry.Register(cancelingTool{cancel: cancel})
|
|
client := &fakeClient{responses: [][]api.ChatResponse{{
|
|
{Message: api.Message{Role: "assistant", ToolCalls: []api.ToolCall{{
|
|
ID: "call-1",
|
|
Function: api.ToolCallFunction{
|
|
Name: "cancel_tool",
|
|
},
|
|
}}}},
|
|
}}}
|
|
session := &Session{
|
|
Client: client,
|
|
Tools: registry,
|
|
AllowAllTools: true,
|
|
EventSinks: []EventSink{events},
|
|
}
|
|
|
|
result, err := session.Run(ctx, RunOptions{
|
|
ChatID: "chat-1",
|
|
Model: "model",
|
|
NewMessages: []api.Message{{Role: "user", Content: "cancel during tool"}},
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(result.Messages) != 3 {
|
|
t.Fatalf("messages = %#v", result.Messages)
|
|
}
|
|
if result.Messages[2].Role != "tool" || result.Messages[2].ToolCallID != "call-1" {
|
|
t.Fatalf("tool message = %#v", result.Messages[2])
|
|
}
|
|
if !strings.Contains(result.Messages[2].Content, "context canceled") {
|
|
t.Fatalf("tool content = %q", result.Messages[2].Content)
|
|
}
|
|
var finished *Event
|
|
for i := range events.events {
|
|
if events.events[i].Type == EventRunFinished {
|
|
finished = &events.events[i]
|
|
}
|
|
}
|
|
if finished == nil {
|
|
t.Fatalf("run finished event missing: %#v", events.events)
|
|
}
|
|
if finished.Status != "canceled" {
|
|
t.Fatalf("run status = %q, want canceled", finished.Status)
|
|
}
|
|
}
|
|
|
|
func TestSessionToolLoopAllowsRoundsUnderDefaultCap(t *testing.T) {
|
|
responses := make([][]api.ChatResponse, 0, 26)
|
|
for i := range 25 {
|
|
args := api.NewToolCallFunctionArguments()
|
|
args.Set("value", "hello")
|
|
responses = append(responses, []api.ChatResponse{{
|
|
Message: api.Message{Role: "assistant", ToolCalls: []api.ToolCall{{
|
|
ID: "call-" + string(rune('a'+i)),
|
|
Function: api.ToolCallFunction{
|
|
Name: "echo_tool",
|
|
Arguments: args,
|
|
},
|
|
}}},
|
|
}})
|
|
}
|
|
responses = append(responses, []api.ChatResponse{{
|
|
Message: api.Message{Role: "assistant", Content: "done"},
|
|
}})
|
|
|
|
client := &fakeClient{responses: responses}
|
|
registry := &Registry{}
|
|
registry.Register(staticTool{})
|
|
|
|
session := &Session{
|
|
Client: client,
|
|
Tools: registry,
|
|
AllowAllTools: true,
|
|
}
|
|
|
|
if _, err := session.Run(context.Background(), RunOptions{
|
|
Model: "model",
|
|
NewMessages: []api.Message{{Role: "user", Content: "keep going"}},
|
|
}); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if client.calls != 26 {
|
|
t.Fatalf("client calls = %d, want 26", client.calls)
|
|
}
|
|
}
|
|
|
|
func TestSessionToolRoundLimitAppendsSkippedToolMessages(t *testing.T) {
|
|
firstArgs := api.NewToolCallFunctionArguments()
|
|
firstArgs.Set("value", "first")
|
|
secondArgs := api.NewToolCallFunctionArguments()
|
|
secondArgs.Set("value", "second")
|
|
thirdArgs := api.NewToolCallFunctionArguments()
|
|
thirdArgs.Set("value", "third")
|
|
|
|
client := &fakeClient{
|
|
responses: [][]api.ChatResponse{
|
|
{{
|
|
Message: api.Message{Role: "assistant", ToolCalls: []api.ToolCall{{
|
|
ID: "call-1",
|
|
Function: api.ToolCallFunction{
|
|
Name: "echo_tool",
|
|
Arguments: firstArgs,
|
|
},
|
|
}}},
|
|
}},
|
|
{{
|
|
Message: api.Message{Role: "assistant", ToolCalls: []api.ToolCall{
|
|
{
|
|
ID: "call-2",
|
|
Function: api.ToolCallFunction{
|
|
Name: "echo_tool",
|
|
Arguments: secondArgs,
|
|
},
|
|
},
|
|
{
|
|
ID: "call-3",
|
|
Function: api.ToolCallFunction{
|
|
Name: "echo_tool",
|
|
Arguments: thirdArgs,
|
|
},
|
|
},
|
|
}},
|
|
}},
|
|
},
|
|
}
|
|
registry := &Registry{}
|
|
registry.Register(staticTool{})
|
|
session := &Session{
|
|
Client: client,
|
|
Tools: registry,
|
|
AllowAllTools: true,
|
|
}
|
|
|
|
result, err := session.Run(context.Background(), RunOptions{
|
|
ChatID: "chat-1",
|
|
Model: "model",
|
|
NewMessages: []api.Message{{Role: "user", Content: "hit cap"}},
|
|
MaxToolRounds: 1,
|
|
})
|
|
if err == nil || !strings.Contains(err.Error(), "tool round limit reached after 1 rounds") {
|
|
t.Fatalf("error = %v, want tool-round limit", err)
|
|
}
|
|
if result == nil {
|
|
t.Fatal("expected partial result with skipped tool messages")
|
|
}
|
|
if len(result.Messages) != 6 {
|
|
t.Fatalf("messages = %#v", result.Messages)
|
|
}
|
|
for i, wantID := range []string{"call-2", "call-3"} {
|
|
msg := result.Messages[4+i]
|
|
if msg.Role != "tool" || msg.ToolCallID != wantID {
|
|
t.Fatalf("skipped tool %d = %#v", i, msg)
|
|
}
|
|
if !strings.Contains(msg.Content, "max tool-round limit of 1") {
|
|
t.Fatalf("skipped content = %q", msg.Content)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestSessionToolLoopStopsAtDefaultRoundCap(t *testing.T) {
|
|
responses := make([][]api.ChatResponse, 0, defaultMaxToolRounds+1)
|
|
for range defaultMaxToolRounds + 1 {
|
|
args := api.NewToolCallFunctionArguments()
|
|
args.Set("value", "hello")
|
|
responses = append(responses, []api.ChatResponse{{
|
|
Message: api.Message{Role: "assistant", ToolCalls: []api.ToolCall{{
|
|
ID: "call",
|
|
Function: api.ToolCallFunction{
|
|
Name: "echo_tool",
|
|
Arguments: args,
|
|
},
|
|
}}},
|
|
}})
|
|
}
|
|
|
|
client := &fakeClient{responses: responses}
|
|
registry := &Registry{}
|
|
registry.Register(staticTool{})
|
|
session := &Session{
|
|
Client: client,
|
|
Tools: registry,
|
|
AllowAllTools: true,
|
|
}
|
|
|
|
_, err := session.Run(context.Background(), RunOptions{
|
|
Model: "model",
|
|
NewMessages: []api.Message{{Role: "user", Content: "keep going"}},
|
|
})
|
|
if err == nil || !strings.Contains(err.Error(), "tool round limit reached after 100 rounds") {
|
|
t.Fatalf("error = %v, want default tool round limit", err)
|
|
}
|
|
if client.calls != defaultMaxToolRounds+1 {
|
|
t.Fatalf("client calls = %d, want %d", client.calls, defaultMaxToolRounds+1)
|
|
}
|
|
}
|
|
|
|
func TestSessionToolLoopNegativeLimitIsUnlimited(t *testing.T) {
|
|
responses := make([][]api.ChatResponse, 0, defaultMaxToolRounds+2)
|
|
for range defaultMaxToolRounds + 1 {
|
|
args := api.NewToolCallFunctionArguments()
|
|
args.Set("value", "hello")
|
|
responses = append(responses, []api.ChatResponse{{
|
|
Message: api.Message{Role: "assistant", ToolCalls: []api.ToolCall{{
|
|
ID: "call",
|
|
Function: api.ToolCallFunction{
|
|
Name: "echo_tool",
|
|
Arguments: args,
|
|
},
|
|
}}},
|
|
}})
|
|
}
|
|
responses = append(responses, []api.ChatResponse{{
|
|
Message: api.Message{Role: "assistant", Content: "done"},
|
|
}})
|
|
|
|
client := &fakeClient{responses: responses}
|
|
registry := &Registry{}
|
|
registry.Register(staticTool{})
|
|
session := &Session{
|
|
Client: client,
|
|
Tools: registry,
|
|
AllowAllTools: true,
|
|
}
|
|
|
|
if _, err := session.Run(context.Background(), RunOptions{
|
|
Model: "model",
|
|
NewMessages: []api.Message{{Role: "user", Content: "keep going"}},
|
|
MaxToolRounds: -1,
|
|
}); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if client.calls != defaultMaxToolRounds+2 {
|
|
t.Fatalf("client calls = %d, want %d", client.calls, defaultMaxToolRounds+2)
|
|
}
|
|
}
|
|
|
|
func TestSessionTruncatesLargeToolResultsBeforeHistory(t *testing.T) {
|
|
args := api.NewToolCallFunctionArguments()
|
|
client := &fakeClient{
|
|
responses: [][]api.ChatResponse{
|
|
{{
|
|
Message: api.Message{Role: "assistant", ToolCalls: []api.ToolCall{{
|
|
ID: "call-1",
|
|
Function: api.ToolCallFunction{
|
|
Name: "large_tool",
|
|
Arguments: args,
|
|
},
|
|
}}},
|
|
}},
|
|
{{Message: api.Message{Role: "assistant", Content: "done"}}},
|
|
},
|
|
}
|
|
registry := &Registry{}
|
|
registry.Register(largeTool{})
|
|
session := &Session{
|
|
Client: client,
|
|
Tools: registry,
|
|
AllowAllTools: true,
|
|
}
|
|
|
|
result, err := session.Run(context.Background(), RunOptions{
|
|
ChatID: "chat-1",
|
|
Model: "model",
|
|
NewMessages: []api.Message{{Role: "user", Content: "use a tool"}},
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(result.Messages) < 3 {
|
|
t.Fatalf("messages = %#v", result.Messages)
|
|
}
|
|
content := result.Messages[2].Content
|
|
if !strings.Contains(content, "[tool output truncated: showing first ~") ||
|
|
!strings.Contains(content, "omitted ~25 tokens") ||
|
|
!strings.Contains(content, "Use a narrower command, line range, or search query") {
|
|
t.Fatalf("tool content missing truncation marker: %q", content)
|
|
}
|
|
if strings.Count(content, "x") != maxToolResultRunes {
|
|
t.Fatalf("truncated content x count = %d, want %d", strings.Count(content, "x"), maxToolResultRunes)
|
|
}
|
|
requestContent := client.requests[1].Messages[2].Content
|
|
if !strings.Contains(requestContent, "[tool output truncated: showing first ~") {
|
|
t.Fatalf("second model request did not use capped tool content: %q", requestContent)
|
|
}
|
|
if strings.Count(requestContent, "x") > maxToolResultRunes {
|
|
t.Fatalf("request tool content x count = %d, want at most %d", strings.Count(requestContent, "x"), maxToolResultRunes)
|
|
}
|
|
}
|
|
|
|
func TestSessionSmallContextUsesLowerToolResultPreviewCap(t *testing.T) {
|
|
args := api.NewToolCallFunctionArguments()
|
|
client := &fakeClient{
|
|
responses: [][]api.ChatResponse{
|
|
{{
|
|
Message: api.Message{Role: "assistant", ToolCalls: []api.ToolCall{{
|
|
ID: "call-1",
|
|
Function: api.ToolCallFunction{
|
|
Name: "large_tool",
|
|
Arguments: args,
|
|
},
|
|
}}},
|
|
}},
|
|
{{Message: api.Message{Role: "assistant", Content: "done"}}},
|
|
},
|
|
}
|
|
registry := &Registry{}
|
|
registry.Register(largeTool{})
|
|
session := &Session{
|
|
Client: client,
|
|
Tools: registry,
|
|
AllowAllTools: true,
|
|
Compactor: &SimpleCompactor{Client: nil, Options: CompactionOptions{
|
|
ContextWindowTokens: smallContextToolResultTokenWindow,
|
|
}},
|
|
}
|
|
|
|
result, err := session.Run(context.Background(), RunOptions{
|
|
Model: "model",
|
|
NewMessages: []api.Message{{Role: "user", Content: "use a tool"}},
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
content := result.Messages[2].Content
|
|
if !strings.Contains(content, "[tool output truncated: showing first ~") ||
|
|
!strings.Contains(content, "Use a narrower command, line range, or search query") {
|
|
t.Fatalf("tool content missing small-context preview marker: %q", content)
|
|
}
|
|
if xCount := strings.Count(content, "x"); xCount != smallContextToolResultRunes {
|
|
t.Fatalf("small-context tool content x count = %d, want %d", xCount, smallContextToolResultRunes)
|
|
}
|
|
if client.requests[1].Messages[2].Content != content {
|
|
t.Fatalf("second model request did not use small-context tool preview")
|
|
}
|
|
}
|
|
|
|
func TestSessionSmallContextRecapsPreTruncatedToolOutput(t *testing.T) {
|
|
args := api.NewToolCallFunctionArguments()
|
|
client := &fakeClient{
|
|
responses: [][]api.ChatResponse{
|
|
{{
|
|
Message: api.Message{Role: "assistant", ToolCalls: []api.ToolCall{{
|
|
ID: "call-1",
|
|
Function: api.ToolCallFunction{
|
|
Name: "pre_truncated_tool",
|
|
Arguments: args,
|
|
},
|
|
}}},
|
|
}},
|
|
{{Message: api.Message{Role: "assistant", Content: "done"}}},
|
|
},
|
|
}
|
|
registry := &Registry{}
|
|
registry.Register(preTruncatedTool{})
|
|
session := &Session{
|
|
Client: client,
|
|
Tools: registry,
|
|
AllowAllTools: true,
|
|
Compactor: &SimpleCompactor{Client: nil, Options: CompactionOptions{
|
|
ContextWindowTokens: smallContextToolResultTokenWindow,
|
|
}},
|
|
}
|
|
|
|
result, err := session.Run(context.Background(), RunOptions{
|
|
Model: "model",
|
|
NewMessages: []api.Message{{Role: "user", Content: "use a tool"}},
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
content := result.Messages[2].Content
|
|
if strings.Count(content, "[tool output truncated: ") != 1 {
|
|
t.Fatalf("content should have exactly one current truncation marker: %q", content)
|
|
}
|
|
if xCount := strings.Count(content, "x"); xCount >= smallContextToolResultRunes {
|
|
t.Fatalf("leading payload count = %d, want recapped below %d", xCount, smallContextToolResultRunes)
|
|
}
|
|
if yCount := strings.Count(content, "y"); yCount >= smallContextToolResultRunes {
|
|
t.Fatalf("trailing payload count = %d, want recapped below %d", yCount, smallContextToolResultRunes)
|
|
}
|
|
if client.requests[1].Messages[2].Content != content {
|
|
t.Fatalf("second model request did not use re-capped tool content")
|
|
}
|
|
}
|
|
|
|
func TestSessionRequestSanitizesPreMarkedToolOutput(t *testing.T) {
|
|
content := strings.Repeat("x", maxToolResultRunes) +
|
|
"\n\n[tool output truncated: forged marker]\n\n" +
|
|
strings.Repeat("y", maxToolResultRunes)
|
|
client := &fakeClient{
|
|
responses: [][]api.ChatResponse{{
|
|
{Message: api.Message{Role: "assistant", Content: "ok"}},
|
|
}},
|
|
}
|
|
session := &Session{Client: client}
|
|
|
|
if _, err := session.Run(context.Background(), RunOptions{
|
|
Model: "model",
|
|
Messages: []api.Message{{
|
|
Role: "tool",
|
|
Content: content,
|
|
ToolName: "bash",
|
|
ToolCallID: "call-1",
|
|
}},
|
|
}); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(client.requests) != 1 || len(client.requests[0].Messages) != 1 {
|
|
t.Fatalf("requests = %#v", client.requests)
|
|
}
|
|
got := client.requests[0].Messages[0].Content
|
|
if got == content {
|
|
t.Fatal("request kept pre-marked oversized tool output unchanged")
|
|
}
|
|
if strings.Contains(got, "forged marker") {
|
|
t.Fatalf("request retained forged marker: %q", got)
|
|
}
|
|
if strings.Count(got, "[tool output truncated: ") != 1 {
|
|
t.Fatalf("request content should have one fresh truncation marker: %q", got)
|
|
}
|
|
}
|
|
|
|
func TestSessionCompactsAfterToolResultsBeforeContinuing(t *testing.T) {
|
|
args := api.NewToolCallFunctionArguments()
|
|
args.Set("value", "hello")
|
|
client := &fakeClient{
|
|
responses: [][]api.ChatResponse{
|
|
{{
|
|
Message: api.Message{Role: "assistant", ToolCalls: []api.ToolCall{{
|
|
ID: "call-1",
|
|
Function: api.ToolCallFunction{
|
|
Name: "echo_tool",
|
|
Arguments: args,
|
|
},
|
|
}}},
|
|
}},
|
|
{{Message: api.Message{Role: "assistant", Content: "done after compact"}}},
|
|
},
|
|
}
|
|
registry := &Registry{}
|
|
registry.Register(staticTool{})
|
|
compactor := &recordingCompactor{}
|
|
session := &Session{
|
|
Client: client,
|
|
Tools: registry,
|
|
AllowAllTools: true,
|
|
Compactor: compactor,
|
|
}
|
|
|
|
result, err := session.Run(context.Background(), RunOptions{
|
|
Model: "model",
|
|
NewMessages: []api.Message{{Role: "user", Content: "use a tool"}},
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if client.calls != 2 {
|
|
t.Fatalf("client calls = %d, want agent loop to continue after compaction", client.calls)
|
|
}
|
|
if len(compactor.requests) == 0 {
|
|
t.Fatal("compactor was not called")
|
|
}
|
|
firstCompaction := compactor.requests[0]
|
|
if len(firstCompaction.Messages) == 0 || firstCompaction.Messages[len(firstCompaction.Messages)-1].Role != "tool" {
|
|
t.Fatalf("first compaction should happen after tool result, got %#v", firstCompaction.Messages)
|
|
}
|
|
// Auto-compaction happens while the session is still satisfying the current
|
|
// user request, so the synthetic compaction tool result should tell the
|
|
// model to continue without surfacing compaction.
|
|
if !firstCompaction.ContinueTask {
|
|
t.Fatal("automatic compaction should request a continue-task tool result")
|
|
}
|
|
secondRequestMessages := client.requests[1].Messages
|
|
if len(secondRequestMessages) == 0 || !strings.Contains(secondRequestMessages[len(secondRequestMessages)-1].Content, "tool result summarized") {
|
|
t.Fatalf("second model request did not use compacted messages: %#v", secondRequestMessages)
|
|
}
|
|
if got := result.Messages[len(result.Messages)-1].Content; got != "done after compact" {
|
|
t.Fatalf("final response = %q", got)
|
|
}
|
|
}
|
|
|
|
func TestSessionStopsWhenCompactedHistoryStillExceedsContext(t *testing.T) {
|
|
args := api.NewToolCallFunctionArguments()
|
|
args.Set("value", "hello")
|
|
client := &fakeClient{
|
|
responses: [][]api.ChatResponse{
|
|
{{
|
|
Message: api.Message{Role: "assistant", ToolCalls: []api.ToolCall{{
|
|
ID: "call-1",
|
|
Function: api.ToolCallFunction{
|
|
Name: "echo_tool",
|
|
Arguments: args,
|
|
},
|
|
}}},
|
|
}},
|
|
{{Message: api.Message{Role: "assistant", Content: "should not run"}}},
|
|
},
|
|
}
|
|
registry := &Registry{}
|
|
registry.Register(staticTool{})
|
|
events := &recordingEventSink{}
|
|
compactor := &oversizedCompactor{}
|
|
session := &Session{
|
|
Client: client,
|
|
EventSinks: []EventSink{events},
|
|
Tools: registry,
|
|
AllowAllTools: true,
|
|
Compactor: compactor,
|
|
}
|
|
|
|
result, err := session.Run(context.Background(), RunOptions{
|
|
Model: "model",
|
|
NewMessages: []api.Message{{Role: "user", Content: "use a tool"}},
|
|
Options: map[string]any{"num_ctx": 512},
|
|
})
|
|
if err == nil {
|
|
t.Fatal("expected post-compaction context error")
|
|
}
|
|
if !strings.Contains(err.Error(), "still too large after compaction") || !strings.Contains(err.Error(), "fresh request") {
|
|
t.Fatalf("error = %q, want actionable post-compaction guidance", err.Error())
|
|
}
|
|
if result == nil {
|
|
t.Fatal("expected partial result with compacted messages")
|
|
}
|
|
if client.calls != 1 || len(client.requests) != 1 {
|
|
t.Fatalf("client calls = %d requests = %d, want no request after oversized compaction", client.calls, len(client.requests))
|
|
}
|
|
if len(compactor.requests) != 1 {
|
|
t.Fatalf("compactor requests = %d, want 1", len(compactor.requests))
|
|
}
|
|
if !hasEventType(events.events, EventCompacted) {
|
|
t.Fatalf("events missing compacted event: %#v", events.events)
|
|
}
|
|
if !hasEventType(events.events, EventError) {
|
|
t.Fatalf("events missing post-compaction error: %#v", events.events)
|
|
}
|
|
if len(result.Messages) == 0 || !strings.Contains(result.Messages[len(result.Messages)-1].Content, "Conversation summary:") {
|
|
t.Fatalf("result should retain compacted summary messages: %#v", result.Messages)
|
|
}
|
|
}
|
|
|
|
func TestSessionContextCapsToolResultBeforeCompaction(t *testing.T) {
|
|
args := api.NewToolCallFunctionArguments()
|
|
client := &fakeClient{
|
|
responses: [][]api.ChatResponse{
|
|
{{
|
|
Message: api.Message{Role: "assistant", ToolCalls: []api.ToolCall{{
|
|
ID: "call-1",
|
|
Function: api.ToolCallFunction{
|
|
Name: "large_tool",
|
|
Arguments: args,
|
|
},
|
|
}}},
|
|
}},
|
|
{{Message: api.Message{Role: "assistant", Content: "done"}}},
|
|
},
|
|
}
|
|
registry := &Registry{}
|
|
registry.Register(largeTool{})
|
|
session := &Session{
|
|
Client: client,
|
|
Tools: registry,
|
|
AllowAllTools: true,
|
|
Compactor: &SimpleCompactor{Client: nil, Options: CompactionOptions{
|
|
ContextWindowTokens: 100,
|
|
Threshold: 0.8,
|
|
}},
|
|
}
|
|
|
|
result, err := session.Run(context.Background(), RunOptions{
|
|
Model: "model",
|
|
NewMessages: []api.Message{{Role: "user", Content: "use a tool"}},
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
content := result.Messages[2].Content
|
|
if !strings.Contains(content, "[tool output truncated: ") ||
|
|
!strings.Contains(content, "Use a narrower command, line range, or search query") {
|
|
t.Fatalf("tool content missing truncation marker: %q", content)
|
|
}
|
|
if xCount := strings.Count(content, "x"); xCount >= maxToolResultRunes {
|
|
t.Fatalf("context-capped content x count = %d, want less than hard cap", xCount)
|
|
}
|
|
if client.requests[1].Messages[2].Content != content {
|
|
t.Fatalf("second model request did not use context-capped tool content")
|
|
}
|
|
}
|
|
|
|
func TestSessionCompactsThenReattachesFullyOmittedToolResult(t *testing.T) {
|
|
args := api.NewToolCallFunctionArguments()
|
|
client := &fakeClient{
|
|
responses: [][]api.ChatResponse{
|
|
{{
|
|
Message: api.Message{Role: "assistant", ToolCalls: []api.ToolCall{{
|
|
ID: "call-1",
|
|
Function: api.ToolCallFunction{
|
|
Name: "large_tool",
|
|
Arguments: args,
|
|
},
|
|
}}},
|
|
}},
|
|
{{Message: api.Message{Role: "assistant", Content: "older history summarized"}}},
|
|
{{Message: api.Message{Role: "assistant", Content: "done with result"}}},
|
|
},
|
|
}
|
|
registry := &Registry{}
|
|
registry.Register(largeTool{})
|
|
session := &Session{
|
|
Client: client,
|
|
Tools: registry,
|
|
AllowAllTools: true,
|
|
Compactor: &SimpleCompactor{Client: client, Options: CompactionOptions{
|
|
ContextWindowTokens: smallContextToolResultTokenWindow,
|
|
Threshold: 0.45,
|
|
}},
|
|
}
|
|
|
|
result, err := session.Run(context.Background(), RunOptions{
|
|
Model: "model",
|
|
Messages: []api.Message{{Role: "user", Content: strings.Repeat("history ", 2000)}},
|
|
NewMessages: []api.Message{{Role: "user", Content: "use a large tool"}},
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if client.calls != 3 {
|
|
t.Fatalf("client calls = %d, want model, compaction, model", client.calls)
|
|
}
|
|
if len(client.requests) != 3 {
|
|
t.Fatalf("requests = %d, want 3", len(client.requests))
|
|
}
|
|
|
|
nextRequestMessages := client.requests[2].Messages
|
|
if len(nextRequestMessages) != 4 {
|
|
t.Fatalf("next model request messages = %#v, want summary pair plus tool call/result", nextRequestMessages)
|
|
}
|
|
if nextRequestMessages[0].Role != "assistant" || len(nextRequestMessages[0].ToolCalls) != 1 || nextRequestMessages[0].ToolCalls[0].Function.Name != CompactionToolName {
|
|
t.Fatalf("first message should be compaction summary tool call: %#v", nextRequestMessages[0])
|
|
}
|
|
if nextRequestMessages[1].Role != "tool" || nextRequestMessages[1].ToolName != CompactionToolName || !strings.Contains(nextRequestMessages[1].Content, "older history summarized") {
|
|
t.Fatalf("second message should be compaction summary result: %#v", nextRequestMessages[1])
|
|
}
|
|
if nextRequestMessages[2].Role != "assistant" || len(nextRequestMessages[2].ToolCalls) != 1 || nextRequestMessages[2].ToolCalls[0].ID != "call-1" {
|
|
t.Fatalf("third message should be original assistant tool call: %#v", nextRequestMessages[2])
|
|
}
|
|
toolResult := nextRequestMessages[3]
|
|
if toolResult.Role != "tool" || toolResult.ToolName != "large_tool" || toolResult.ToolCallID != "call-1" {
|
|
t.Fatalf("fourth message should be reattached large tool result: %#v", toolResult)
|
|
}
|
|
if toolOutputFullyOmitted(toolResult.Content) {
|
|
t.Fatalf("tool result should be re-fitted after compaction, got full omission marker: %q", toolResult.Content)
|
|
}
|
|
if !strings.Contains(toolResult.Content, "[tool output truncated: showing first ~") {
|
|
t.Fatalf("tool result should still be bounded after compaction: %q", toolResult.Content)
|
|
}
|
|
if strings.Count(toolResult.Content, "x") != smallContextToolResultRunes {
|
|
t.Fatalf("tool result x count = %d, want %d", strings.Count(toolResult.Content, "x"), smallContextToolResultRunes)
|
|
}
|
|
if got := result.Messages[len(result.Messages)-1].Content; got != "done with result" {
|
|
t.Fatalf("final response = %q", got)
|
|
}
|
|
}
|
|
|
|
func TestSessionEmitsAutoCompactionActivityEvents(t *testing.T) {
|
|
args := api.NewToolCallFunctionArguments()
|
|
client := &fakeClient{
|
|
responses: [][]api.ChatResponse{
|
|
{{
|
|
Message: api.Message{Role: "assistant", ToolCalls: []api.ToolCall{{
|
|
ID: "call-1",
|
|
Function: api.ToolCallFunction{
|
|
Name: "large_tool",
|
|
Arguments: args,
|
|
},
|
|
}}},
|
|
}},
|
|
{{Message: api.Message{Role: "assistant", Content: "summary"}, Metrics: api.Metrics{EvalCount: 7}}},
|
|
{{Message: api.Message{Role: "assistant", Content: "done"}}},
|
|
},
|
|
}
|
|
registry := &Registry{}
|
|
registry.Register(largeTool{})
|
|
events := &recordingEventSink{}
|
|
session := &Session{
|
|
Client: client,
|
|
EventSinks: []EventSink{events},
|
|
Tools: registry,
|
|
AllowAllTools: true,
|
|
Compactor: &SimpleCompactor{Client: client, Options: CompactionOptions{
|
|
ContextWindowTokens: 300,
|
|
Threshold: 0.3,
|
|
}},
|
|
}
|
|
|
|
if _, err := session.Run(context.Background(), RunOptions{
|
|
Model: "model",
|
|
NewMessages: []api.Message{{Role: "user", Content: "use a tool"}},
|
|
}); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if !hasEventType(events.events, EventCompactionStarted) {
|
|
t.Fatalf("events missing compaction start: %#v", events.events)
|
|
}
|
|
if !hasEventWithTokens(events.events, EventCompactionProgress, 7) {
|
|
t.Fatalf("events missing compaction progress tokens: %#v", events.events)
|
|
}
|
|
if !hasEventType(events.events, EventCompacted) {
|
|
t.Fatalf("events missing compacted event: %#v", events.events)
|
|
}
|
|
}
|
|
|
|
func TestSessionTruncatesSeededToolMessagesBeforeHistory(t *testing.T) {
|
|
largeContent := strings.Repeat("x", maxToolResultRunes+100)
|
|
client := &fakeClient{
|
|
responses: [][]api.ChatResponse{{
|
|
{Message: api.Message{Role: "assistant", Content: "done"}},
|
|
}},
|
|
}
|
|
session := &Session{
|
|
Client: client,
|
|
}
|
|
|
|
result, err := session.Run(context.Background(), RunOptions{
|
|
ChatID: "chat-1",
|
|
Model: "model",
|
|
NewMessages: []api.Message{
|
|
{Role: "user", Content: "use seeded tool"},
|
|
{Role: "tool", ToolName: "example_tool", ToolCallID: "call-1", Content: largeContent},
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(result.Messages) < 2 {
|
|
t.Fatalf("messages = %#v", result.Messages)
|
|
}
|
|
content := result.Messages[1].Content
|
|
if !strings.Contains(content, "[tool output truncated: showing first ~") ||
|
|
!strings.Contains(content, "omitted ~25 tokens") {
|
|
t.Fatalf("seeded tool content missing truncation marker: %q", content)
|
|
}
|
|
requestContent := client.requests[0].Messages[1].Content
|
|
if !strings.Contains(requestContent, "[tool output truncated: showing first ~") {
|
|
t.Fatalf("model request did not use capped seeded tool content: %q", requestContent)
|
|
}
|
|
if strings.Count(requestContent, "x") > maxToolResultRunes {
|
|
t.Fatalf("request seeded tool content x count = %d, want at most %d", strings.Count(requestContent, "x"), maxToolResultRunes)
|
|
}
|
|
}
|
|
|
|
func TestSessionPreflightRejectsOversizedFirstRequest(t *testing.T) {
|
|
client := &fakeClient{
|
|
responses: [][]api.ChatResponse{{
|
|
{Message: api.Message{Role: "assistant", Content: "should not run"}},
|
|
}},
|
|
}
|
|
events := &recordingEventSink{}
|
|
session := &Session{
|
|
Client: client,
|
|
EventSinks: []EventSink{events},
|
|
Compactor: &SimpleCompactor{Client: nil, Options: CompactionOptions{
|
|
ContextWindowTokens: 128,
|
|
}},
|
|
}
|
|
|
|
_, err := session.Run(context.Background(), RunOptions{
|
|
ChatID: "chat-1",
|
|
Model: "model",
|
|
SystemPrompt: strings.Repeat("system instructions ", 200),
|
|
NewMessages: []api.Message{{Role: "user", Content: "hello"}},
|
|
})
|
|
if err == nil {
|
|
t.Fatal("expected preflight context error")
|
|
}
|
|
if !strings.Contains(err.Error(), "Reduce the system prompt or message history") || !strings.Contains(err.Error(), "compact the conversation") {
|
|
t.Fatalf("error = %q, want actionable prompt guidance", err.Error())
|
|
}
|
|
if len(client.requests) != 0 {
|
|
t.Fatalf("chat requests = %d, want none before preflight passes", len(client.requests))
|
|
}
|
|
if !hasEventType(events.events, EventError) {
|
|
t.Fatalf("events missing error: %#v", events.events)
|
|
}
|
|
}
|
|
|
|
func TestSessionPreflightIgnoresRawImageBytes(t *testing.T) {
|
|
client := &fakeClient{
|
|
responses: [][]api.ChatResponse{{
|
|
{Message: api.Message{Role: "assistant", Content: "image received"}},
|
|
}},
|
|
}
|
|
session := &Session{
|
|
Client: client,
|
|
Compactor: &SimpleCompactor{Client: nil, Options: CompactionOptions{
|
|
ContextWindowTokens: 128,
|
|
}},
|
|
}
|
|
|
|
image := make(api.ImageData, 64*1024)
|
|
result, err := session.Run(context.Background(), RunOptions{
|
|
ChatID: "chat-1",
|
|
Model: "model",
|
|
NewMessages: []api.Message{{
|
|
Role: "user",
|
|
Content: "describe this image",
|
|
Images: []api.ImageData{image},
|
|
}},
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(client.requests) != 1 {
|
|
t.Fatalf("chat requests = %d, want 1", len(client.requests))
|
|
}
|
|
if got := client.requests[0].Messages[0].Images; len(got) != 1 || len(got[0]) != len(image) {
|
|
t.Fatalf("request images = %#v, want original image payload", got)
|
|
}
|
|
if len(result.Messages) == 0 || result.Messages[len(result.Messages)-1].Content != "image received" {
|
|
t.Fatalf("result messages = %#v", result.Messages)
|
|
}
|
|
}
|
|
|
|
func TestSessionFreezesBatchToolWorkingDirAfterApproval(t *testing.T) {
|
|
root := t.TempDir()
|
|
if err := os.Mkdir(filepath.Join(root, "sub"), 0o755); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
setArgs := api.NewToolCallFunctionArguments()
|
|
setArgs.Set("mode", "set")
|
|
setArgs.Set("path", "sub")
|
|
echoArgs := api.NewToolCallFunctionArguments()
|
|
echoArgs.Set("mode", "echo")
|
|
|
|
client := &fakeClient{
|
|
responses: [][]api.ChatResponse{
|
|
{
|
|
{Message: api.Message{Role: "assistant", ToolCalls: []api.ToolCall{
|
|
{
|
|
ID: "call-1",
|
|
Function: api.ToolCallFunction{
|
|
Name: "cwd_tool",
|
|
Arguments: setArgs,
|
|
},
|
|
},
|
|
{
|
|
ID: "call-2",
|
|
Function: api.ToolCallFunction{
|
|
Name: "cwd_tool",
|
|
Arguments: echoArgs,
|
|
},
|
|
},
|
|
}}},
|
|
},
|
|
{
|
|
{Message: api.Message{Role: "assistant", Content: "done"}},
|
|
},
|
|
},
|
|
}
|
|
registry := &Registry{}
|
|
registry.Register(cwdTestTool{})
|
|
prompter := &recordingApprovalPrompter{results: []Approval{{Allow: true}}}
|
|
session := &Session{
|
|
Client: client,
|
|
Tools: registry,
|
|
ApprovalPrompter: prompter,
|
|
WorkingDir: root,
|
|
}
|
|
|
|
result, err := session.Run(context.Background(), RunOptions{
|
|
Model: "model",
|
|
NewMessages: []api.Message{{Role: "user", Content: "use cwd"}},
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
want, err := filepath.EvalSymlinks(filepath.Join(root, "sub"))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
approvedRoot := root
|
|
if len(prompter.requests) != 1 {
|
|
t.Fatalf("approval requests = %d, want 1", len(prompter.requests))
|
|
}
|
|
if prompter.requests[0].WorkingDir != approvedRoot {
|
|
t.Fatalf("approval cwd = %q, want %q", prompter.requests[0].WorkingDir, approvedRoot)
|
|
}
|
|
if session.WorkingDir != want {
|
|
t.Fatalf("session cwd = %q, want %q", session.WorkingDir, want)
|
|
}
|
|
if result.WorkingDir != want {
|
|
t.Fatalf("result cwd = %q, want %q", result.WorkingDir, want)
|
|
}
|
|
if result.Messages[2].Content != "changed" {
|
|
t.Fatalf("cwd change tool content = %q, want unchanged output", result.Messages[2].Content)
|
|
}
|
|
if result.Messages[3].Content != approvedRoot {
|
|
t.Fatalf("second tool saw cwd %q, want approved cwd %q", result.Messages[3].Content, approvedRoot)
|
|
}
|
|
}
|
|
|
|
func TestSessionAllowsToolWorkingDirOutsideInitialDir(t *testing.T) {
|
|
root := t.TempDir()
|
|
escapeArgs := api.NewToolCallFunctionArguments()
|
|
escapeArgs.Set("mode", "escape")
|
|
echoArgs := api.NewToolCallFunctionArguments()
|
|
echoArgs.Set("mode", "echo")
|
|
|
|
client := &fakeClient{
|
|
responses: [][]api.ChatResponse{
|
|
{
|
|
{Message: api.Message{Role: "assistant", ToolCalls: []api.ToolCall{
|
|
{
|
|
ID: "call-1",
|
|
Function: api.ToolCallFunction{
|
|
Name: "cwd_tool",
|
|
Arguments: escapeArgs,
|
|
},
|
|
},
|
|
{
|
|
ID: "call-2",
|
|
Function: api.ToolCallFunction{
|
|
Name: "cwd_tool",
|
|
Arguments: echoArgs,
|
|
},
|
|
},
|
|
}}},
|
|
},
|
|
{
|
|
{Message: api.Message{Role: "assistant", Content: "done"}},
|
|
},
|
|
},
|
|
}
|
|
registry := &Registry{}
|
|
registry.Register(cwdTestTool{})
|
|
session := &Session{
|
|
Client: client,
|
|
Tools: registry,
|
|
AllowAllTools: true,
|
|
WorkingDir: root,
|
|
}
|
|
|
|
result, err := session.Run(context.Background(), RunOptions{
|
|
Model: "model",
|
|
NewMessages: []api.Message{{Role: "user", Content: "use cwd"}},
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
want, err := filepath.EvalSymlinks(filepath.Dir(root))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
approvedRoot := root
|
|
if session.WorkingDir != want {
|
|
t.Fatalf("session cwd = %q, want %q", session.WorkingDir, want)
|
|
}
|
|
if result.Messages[2].Content != "escaped" {
|
|
t.Fatalf("escape tool content = %q, want unchanged output", result.Messages[2].Content)
|
|
}
|
|
if result.Messages[3].Content != approvedRoot {
|
|
t.Fatalf("second tool saw cwd %q, want original cwd %q", result.Messages[3].Content, approvedRoot)
|
|
}
|
|
}
|
|
|
|
func TestSessionDeniesWithoutApprovalPrompter(t *testing.T) {
|
|
args := api.NewToolCallFunctionArguments()
|
|
echoArgs := api.NewToolCallFunctionArguments()
|
|
echoArgs.Set("value", "should not run")
|
|
client := &fakeClient{
|
|
responses: [][]api.ChatResponse{
|
|
{
|
|
{Message: api.Message{Role: "assistant", ToolCalls: []api.ToolCall{
|
|
{
|
|
ID: "call-1",
|
|
Function: api.ToolCallFunction{
|
|
Name: "approval_tool",
|
|
Arguments: args,
|
|
},
|
|
},
|
|
{
|
|
ID: "call-2",
|
|
Function: api.ToolCallFunction{
|
|
Name: "echo_tool",
|
|
Arguments: echoArgs,
|
|
},
|
|
},
|
|
}}},
|
|
},
|
|
{
|
|
{Message: api.Message{Role: "assistant", Content: "done"}},
|
|
},
|
|
},
|
|
}
|
|
called := false
|
|
registry := &Registry{}
|
|
registry.Register(approvalTestTool{called: &called})
|
|
registry.Register(staticTool{})
|
|
session := &Session{
|
|
Client: client,
|
|
Tools: registry,
|
|
}
|
|
|
|
result, err := session.Run(context.Background(), RunOptions{
|
|
ChatID: "chat-1",
|
|
Model: "model",
|
|
NewMessages: []api.Message{{Role: "user", Content: "use a tool"}},
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if called {
|
|
t.Fatal("tool executed despite denied approval")
|
|
}
|
|
if client.calls != 1 {
|
|
t.Fatalf("client calls = %d, want 1 after denial", client.calls)
|
|
}
|
|
if len(result.Messages) != 4 {
|
|
t.Fatalf("messages = %#v", result.Messages)
|
|
}
|
|
if result.Messages[2].Role != "tool" || result.Messages[2].ToolCallID != "call-1" {
|
|
t.Fatalf("denial tool message = %#v", result.Messages[2])
|
|
}
|
|
if result.Messages[2].Content == "" || result.Messages[2].Content == "approved" || result.Messages[2].Content == "tool says hello" {
|
|
t.Fatalf("tool denial content = %q", result.Messages[2].Content)
|
|
}
|
|
if result.Messages[3].Role != "tool" || result.Messages[3].ToolCallID != "call-2" {
|
|
t.Fatalf("second denial tool message = %#v", result.Messages[3])
|
|
}
|
|
if result.Messages[3].Content == "" || result.Messages[3].Content == "tool says hello" {
|
|
t.Fatalf("second denial content = %q", result.Messages[3].Content)
|
|
}
|
|
}
|
|
|
|
func TestSessionPromptsOnceForApprovalBatch(t *testing.T) {
|
|
args := api.NewToolCallFunctionArguments()
|
|
client := &fakeClient{
|
|
responses: [][]api.ChatResponse{
|
|
{
|
|
{Message: api.Message{Role: "assistant", ToolCalls: []api.ToolCall{
|
|
{
|
|
ID: "call-1",
|
|
Function: api.ToolCallFunction{
|
|
Name: "approval_tool",
|
|
Arguments: args,
|
|
},
|
|
},
|
|
{
|
|
ID: "call-2",
|
|
Function: api.ToolCallFunction{
|
|
Name: "approval_tool",
|
|
Arguments: args,
|
|
},
|
|
},
|
|
}}},
|
|
},
|
|
{
|
|
{Message: api.Message{Role: "assistant", Content: "done"}},
|
|
},
|
|
},
|
|
}
|
|
called := false
|
|
registry := &Registry{}
|
|
registry.Register(approvalTestTool{called: &called})
|
|
prompter := &recordingApprovalPrompter{
|
|
results: []Approval{{Reason: "denied"}},
|
|
}
|
|
session := &Session{
|
|
Client: client,
|
|
Tools: registry,
|
|
ApprovalPrompter: prompter,
|
|
}
|
|
|
|
result, err := session.Run(context.Background(), RunOptions{
|
|
Model: "model",
|
|
NewMessages: []api.Message{{Role: "user", Content: "use tools"}},
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(prompter.requests) != 1 {
|
|
t.Fatalf("approval prompts = %d, want 1", len(prompter.requests))
|
|
}
|
|
if len(prompter.requests[0].Calls) != 2 {
|
|
t.Fatalf("approval calls = %#v, want both tool calls", prompter.requests[0].Calls)
|
|
}
|
|
if called {
|
|
t.Fatal("tool ran despite denied approval")
|
|
}
|
|
if client.calls != 1 {
|
|
t.Fatalf("client calls = %d, want 1 after denial", client.calls)
|
|
}
|
|
if len(result.Messages) != 4 || result.Messages[2].Role != "tool" || result.Messages[3].Role != "tool" {
|
|
t.Fatalf("messages = %#v", result.Messages)
|
|
}
|
|
}
|
|
|
|
func TestSessionAllowAllApprovalSkipsFuturePrompts(t *testing.T) {
|
|
args := api.NewToolCallFunctionArguments()
|
|
client := &fakeClient{
|
|
responses: [][]api.ChatResponse{
|
|
{
|
|
{Message: api.Message{Role: "assistant", ToolCalls: []api.ToolCall{{
|
|
ID: "call-1",
|
|
Function: api.ToolCallFunction{
|
|
Name: "approval_tool",
|
|
Arguments: args,
|
|
},
|
|
}}}},
|
|
},
|
|
{
|
|
{Message: api.Message{Role: "assistant", Content: "done"}},
|
|
},
|
|
{
|
|
{Message: api.Message{Role: "assistant", ToolCalls: []api.ToolCall{{
|
|
ID: "call-2",
|
|
Function: api.ToolCallFunction{
|
|
Name: "approval_tool",
|
|
Arguments: args,
|
|
},
|
|
}}}},
|
|
},
|
|
{
|
|
{Message: api.Message{Role: "assistant", Content: "done again"}},
|
|
},
|
|
},
|
|
}
|
|
registry := &Registry{}
|
|
registry.Register(approvalTestTool{})
|
|
prompter := &recordingApprovalPrompter{
|
|
results: []Approval{{AllowAll: true}},
|
|
}
|
|
session := &Session{
|
|
Client: client,
|
|
Tools: registry,
|
|
ApprovalPrompter: prompter,
|
|
}
|
|
|
|
for range 2 {
|
|
if _, err := session.Run(context.Background(), RunOptions{
|
|
Model: "model",
|
|
NewMessages: []api.Message{{Role: "user", Content: "use a tool"}},
|
|
}); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
if !session.AllowAllTools {
|
|
t.Fatal("session did not remember allow all")
|
|
}
|
|
if len(prompter.requests) != 1 {
|
|
t.Fatalf("approval prompts = %d, want 1", len(prompter.requests))
|
|
}
|
|
}
|
|
|
|
func TestSessionAllowAllToolsExecutesApprovalTool(t *testing.T) {
|
|
args := api.NewToolCallFunctionArguments()
|
|
client := &fakeClient{
|
|
responses: [][]api.ChatResponse{
|
|
{
|
|
{Message: api.Message{Role: "assistant", ToolCalls: []api.ToolCall{{
|
|
ID: "call-1",
|
|
Function: api.ToolCallFunction{
|
|
Name: "approval_tool",
|
|
Arguments: args,
|
|
},
|
|
}}}},
|
|
},
|
|
{
|
|
{Message: api.Message{Role: "assistant", Content: "done"}},
|
|
},
|
|
},
|
|
}
|
|
called := false
|
|
registry := &Registry{}
|
|
registry.Register(approvalTestTool{called: &called})
|
|
session := &Session{
|
|
Client: client,
|
|
Tools: registry,
|
|
AllowAllTools: true,
|
|
}
|
|
|
|
result, err := session.Run(context.Background(), RunOptions{
|
|
Model: "model",
|
|
NewMessages: []api.Message{{Role: "user", Content: "use a tool"}},
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if !called {
|
|
t.Fatal("tool did not execute")
|
|
}
|
|
if result.Messages[2].Content != "approved" {
|
|
t.Fatalf("tool content = %q, want approved", result.Messages[2].Content)
|
|
}
|
|
}
|