ollama/agent/tools/bash.go
2026-07-02 11:44:31 -07:00

388 lines
9.5 KiB
Go

package tools
import (
"context"
"errors"
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"time"
"unicode/utf8"
"github.com/ollama/ollama/agent"
"github.com/ollama/ollama/api"
)
const (
// bashTimeout is the deadline Execute applies to every command through
// context.WithTimeout.
bashTimeout = 3 * time.Minute
// bashWaitDelay is the value Execute assigns to exec.Cmd.WaitDelay.
bashWaitDelay = 1 * time.Second
// maxBashOutputBytes is the Limit Execute sets on each of the stdout and
// stderr capture buffers.
maxBashOutputBytes = 60_000
)
// Bash is the agent tool that runs a shell command. It has no fields, so the
// zero value is ready to use.
type Bash struct{}
// Name returns the tool name supplied by shellToolName.
func (b *Bash) Name() string {
return shellToolName()
}
// Description returns the tool description supplied by shellToolDescription.
func (b *Bash) Description() string {
return shellToolDescription()
}
// Schema describes the tool's call signature: an object with a single
// required string property, "command".
func (b *Bash) Schema() api.ToolFunction {
	properties := api.NewToolPropertiesMap()
	commandProperty := api.ToolProperty{
		Type:        api.PropertyType{"string"},
		Description: shellCommandDescription(),
	}
	properties.Set("command", commandProperty)

	var fn api.ToolFunction
	fn.Name = b.Name()
	fn.Description = b.Description()
	fn.Parameters = api.ToolFunctionParameters{
		Type:       "object",
		Properties: properties,
		Required:   []string{"command"},
	}
	return fn
}
// RequiresApproval always reports true; the call arguments are ignored.
func (b *Bash) RequiresApproval(map[string]any) bool {
return true
}
// Execute runs the "command" argument through the command built by
// newBashCommand and returns the captured output.
//
// The command must be a non-blank string and must pass
// rejectUnsafeShellCommand; otherwise an error is returned before anything
// runs. The run is bounded by bashTimeout, starts in toolCtx.WorkingDir when
// that is set, and retains at most maxBashOutputBytes of each output stream.
//
// Content holds stdout followed by stderr (introduced by a "stderr:" line),
// or "(no output)" when a successful command wrote nothing. Four kinds of
// failure are reported inside Content with a nil error: a timeout, a
// cancellation, output pipes still open after bashWaitDelay, and a non-zero
// exit status. Any other run error is returned wrapped, together with the
// output captured so far. WorkingDir is whatever readFinalWorkingDir recovers
// from the temporary file handed to newBashCommand.
func (b *Bash) Execute(ctx context.Context, toolCtx agent.ToolContext, args map[string]any) (agent.ToolResult, error) {
	command, ok := args["command"].(string)
	if !ok || strings.TrimSpace(command) == "" {
		return agent.ToolResult{}, fmt.Errorf("command parameter is required")
	}
	if err := rejectUnsafeShellCommand(command); err != nil {
		return agent.ToolResult{}, err
	}

	ctx, cancel := context.WithTimeout(ctx, bashTimeout)
	defer cancel()

	// Only the path of the temporary file is needed here, so the handle is
	// closed straight away.
	// NOTE(review): presumably the wrapper built by newBashCommand records the
	// shell's final working directory in this file; confirm there.
	cwdFile, err := os.CreateTemp("", "ollama-agent-cwd-*")
	if err != nil {
		return agent.ToolResult{}, err
	}
	cwdPath := cwdFile.Name()
	_ = cwdFile.Close()      // nothing was written; a close error carries no information
	defer os.Remove(cwdPath) // best-effort cleanup of the scratch file

	cmd := newBashCommand(ctx, command, cwdPath)
	cmd.WaitDelay = bashWaitDelay
	cmd.Cancel = func() error {
		return killBashCommand(cmd)
	}
	if toolCtx.WorkingDir != "" {
		cmd.Dir = toolCtx.WorkingDir
	}

	var stdout, stderr boundedOutput
	stdout.Limit = maxBashOutputBytes
	stderr.Limit = maxBashOutputBytes
	cmd.Stdout = &stdout
	cmd.Stderr = &stderr

	err = runBashCommand(cmd)
	finalWorkingDir := readFinalWorkingDir(cwdPath)

	var sb strings.Builder
	if stdout.Len() > 0 {
		sb.WriteString(stdout.String("stdout"))
	}
	if stderr.Len() > 0 {
		if sb.Len() > 0 {
			sb.WriteString("\n")
		}
		sb.WriteString("stderr:\n")
		sb.WriteString(stderr.String("stderr"))
	}
	output := sb.String()

	if err != nil {
		// The context is checked first: when the deadline or a cancellation
		// stopped the command, that is the cause worth reporting, whatever
		// error the run itself produced.
		var exitErr *exec.ExitError
		switch {
		case errors.Is(ctx.Err(), context.DeadlineExceeded):
			return agent.ToolResult{Content: bashContentWithError(output, "Error: command timed out after "+bashTimeout.String()), WorkingDir: finalWorkingDir}, nil
		case errors.Is(ctx.Err(), context.Canceled):
			return agent.ToolResult{Content: bashContentWithError(output, "Error: command was canceled"), WorkingDir: finalWorkingDir}, nil
		case errors.Is(err, exec.ErrWaitDelay):
			// Best effort: the result is reported the same way whether or
			// not the kill succeeds.
			_ = killBashCommand(cmd)
			return agent.ToolResult{Content: bashContentWithError(output, "Error: command output pipes did not close after "+bashWaitDelay.String()), WorkingDir: finalWorkingDir}, nil
		case errors.As(err, &exitErr):
			// errors.As also recognises an ExitError that was wrapped on the
			// way up, which a direct type assertion would miss.
			return agent.ToolResult{Content: bashContentWithError(output, fmt.Sprintf("Exit code: %d", exitErr.ExitCode())), WorkingDir: finalWorkingDir}, nil
		default:
			return agent.ToolResult{Content: output, WorkingDir: finalWorkingDir}, fmt.Errorf("executing command: %w", err)
		}
	}

	if output == "" {
		return agent.ToolResult{Content: "(no output)", WorkingDir: finalWorkingDir}, nil
	}
	return agent.ToolResult{Content: output, WorkingDir: finalWorkingDir}, nil
}
// bashContentWithError appends msg to content after a blank line. When
// content is empty, msg is returned on its own.
func bashContentWithError(content, msg string) string {
	if content != "" {
		return content + "\n\n" + msg
	}
	return msg
}
// rejectUnsafeShellCommand returns an error for a command that
// hasUnsafeRecursiveDelete or readsCredentialPath flags, and nil otherwise.
// The recursive-delete check is consulted first.
func rejectUnsafeShellCommand(command string) error {
	if hasUnsafeRecursiveDelete(command) {
		return fmt.Errorf("refusing to run unsafe command: recursive delete target is too broad")
	}
	if readsCredentialPath(command) {
		return fmt.Errorf("refusing to run unsafe command: credential file reads are not allowed")
	}
	return nil
}
// hasUnsafeRecursiveDelete reports whether command contains an rm-style or
// PowerShell-style delete whose following fields name an unsafe target.
//
// Every field after a delete verb is handed to the matching checker, so
// arguments of later commands on the same line are inspected as well.
func hasUnsafeRecursiveDelete(command string) bool {
	fields := shellSafetyFields(command)
	for i := range fields {
		verb, rest := fields[i], fields[i+1:]
		unsafeRM := isRMCommand(verb) && rmCommandDeletesUnsafeTarget(rest)
		if unsafeRM || (isPowerShellDeleteCommand(verb) && powerShellDeleteCommandDeletesUnsafeTarget(rest)) {
			return true
		}
	}
	return false
}
// rmCommandDeletesUnsafeTarget reports whether the fields following an rm
// command combine recursive and force flags with a target that
// isUnsafeDeleteTarget rejects.
//
// Flags are gathered from every dash-prefixed field before any target is
// judged, so the result does not depend on argument order: "rm / -rf" is
// caught just like "rm -rf /". This matches the PowerShell checker. A bare
// "--" is skipped. The flag test is deliberately coarse: any dash-prefixed
// field containing the letter "r" counts as recursive and any containing "f"
// counts as force. Fields are expected to be lower-cased already (as
// shellSafetyFields produces them), so "-R" arrives as "-r".
func rmCommandDeletesUnsafeTarget(fields []string) bool {
	var recursive, force bool
	targets := make([]string, 0, len(fields))
	for _, field := range fields {
		switch {
		case field == "--":
			// Option terminator: neither a flag nor a target.
		case strings.HasPrefix(field, "-"):
			recursive = recursive || strings.Contains(field, "r")
			force = force || strings.Contains(field, "f")
		default:
			targets = append(targets, field)
		}
	}
	if !recursive || !force {
		return false
	}
	for _, target := range targets {
		if isUnsafeDeleteTarget(target) {
			return true
		}
	}
	return false
}
// powerShellDeleteCommandDeletesUnsafeTarget reports whether the fields
// following a PowerShell-style delete command carry both a recurse switch
// ("-r", "-recurse", "-recursive") and a force switch ("-f", "-force"), and
// name at least one target that isUnsafeDeleteTarget rejects.
//
// Switches may appear anywhere among the fields. Every field that does not
// start with "-" is treated as a candidate target.
func powerShellDeleteCommandDeletesUnsafeTarget(fields []string) bool {
	sawRecurse, sawForce := false, false
	for _, arg := range fields {
		if arg == "-r" || arg == "-recurse" || arg == "-recursive" {
			sawRecurse = true
		} else if arg == "-f" || arg == "-force" {
			sawForce = true
		}
	}
	if !(sawRecurse && sawForce) {
		return false
	}

	for _, arg := range fields {
		if !strings.HasPrefix(arg, "-") && isUnsafeDeleteTarget(arg) {
			return true
		}
	}
	return false
}
func readsCredentialPath(command string) bool {
fields := shellSafetyFields(command)
if !hasCredentialReadVerb(fields) {
return false
}
normalized := shellSafetyText(command)
for _, fragment := range []string{
"/.ssh/id_rsa",
"/.ssh/id_dsa",
"/.ssh/id_ecdsa",
"/.ssh/id_ed25519",
"/.aws/credentials",
"/.config/gcloud/application_default_credentials.json",
"/.kube/config",
"/etc/shadow",
} {
if strings.Contains(normalized, fragment) {
return true
}
}
return false
}
// hasCredentialReadVerb reports whether any field is exactly one of the
// file-reading commands this filter recognises. The comparison is
// case-sensitive, so callers are expected to pass lower-cased fields.
func hasCredentialReadVerb(fields []string) bool {
	readVerbs := [...]string{
		"cat", "less", "more", "head", "tail", "type",
		"get-content", "gc", "select-string",
		"grep", "rg", "sed", "awk",
	}
	for _, candidate := range fields {
		for _, verb := range readVerbs {
			if candidate == verb {
				return true
			}
		}
	}
	return false
}
// isRMCommand reports whether field is "rm" itself or a slash-separated path
// whose last element is "rm".
func isRMCommand(field string) bool {
	if field == "rm" {
		return true
	}
	return strings.HasSuffix(field, "/rm")
}
// isPowerShellDeleteCommand reports whether field is exactly one of
// "remove-item", "del", "erase", "rd" or "rmdir".
func isPowerShellDeleteCommand(field string) bool {
	deleteVerbs := [...]string{"remove-item", "del", "erase", "rd", "rmdir"}
	for _, verb := range deleteVerbs {
		if field == verb {
			return true
		}
	}
	return false
}
// isUnsafeDeleteTarget reports whether target is too broad to delete
// recursively: the current directory or everything in it, the filesystem
// root, a home-directory reference or anything below one, or one of the
// listed system directories or anything below one.
//
// target is expected to be normalized already (lower-case, forward slashes),
// as shellSafetyFields produces it. A trailing "/*" is removed before the
// comparisons, so "dir/*" is judged like "dir"; in particular "./*" is
// treated like "." and rejected, the same as a bare "*".
func isUnsafeDeleteTarget(target string) bool {
	// "/*" must be matched before trimming: it would otherwise become "".
	switch target {
	case "*", "/*":
		return true
	}

	target = strings.TrimSuffix(target, "/*")

	// Checked after trimming so that "./*" is caught as well.
	switch target {
	case ".", "./":
		return true
	}

	for _, prefix := range []string{"~/", "$home/", "${home}/", "$env:home/", "$env:userprofile/", "%userprofile%/"} {
		if strings.HasPrefix(target, prefix) {
			return true
		}
	}
	for _, prefix := range []string{"/etc/", "/bin/", "/sbin/", "/usr/", "/var/", "/lib/", "/library/", "/system/", "/applications/", "c:/windows/", "c:/program files/"} {
		if strings.HasPrefix(target, prefix) {
			return true
		}
	}
	for _, exact := range []string{"/", "~", "$home", "${home}", "$env:home", "$env:userprofile", "%userprofile%", "c:", "c:/", "/etc", "/bin", "/sbin", "/usr", "/var", "/lib", "/library", "/system", "/applications", "c:/windows", "c:/program files"} {
		if target == exact {
			return true
		}
	}
	return false
}
// shellSafetyFields normalizes command with shellSafetyText and splits the
// result on whitespace.
func shellSafetyFields(command string) []string {
return strings.Fields(shellSafetyText(command))
}
// shellSafetyText normalizes command for the safety checks: it lower-cases
// the text, turns backslashes into forward slashes, replaces newlines, tabs
// and the characters ; & | ( ) with spaces, and removes double quotes,
// single quotes and backticks.
func shellSafetyText(command string) string {
	// Old/new pairs in the order strings.NewReplacer expects.
	pairs := []string{
		"\\", "/",
		"\n", " ",
		"\t", " ",
		";", " ",
		"&", " ",
		"|", " ",
		"(", " ",
		")", " ",
		"\"", "",
		"'", "",
		"`", "",
	}
	normalizer := strings.NewReplacer(pairs...)
	lowered := strings.ToLower(command)
	return normalizer.Replace(lowered)
}
// readFinalWorkingDir reads a directory path from the file at path and
// returns it in normalized form. A leading byte-order mark and surrounding
// whitespace are stripped first. It returns "" when the file cannot be read,
// holds no path, or names something that is not an existing directory.
func readFinalWorkingDir(path string) string {
	raw, err := os.ReadFile(path)
	if err != nil {
		return ""
	}

	dir := strings.TrimSpace(strings.TrimPrefix(string(raw), "\ufeff"))
	if dir == "" {
		return ""
	}

	dir = normalizeBashWorkingDir(dir)
	if info, statErr := os.Stat(dir); statErr == nil && info.IsDir() {
		return dir
	}
	return ""
}
// normalizeBashWorkingDir converts workingDir to the host's path syntax and
// cleans it with filepath.Clean.
//
// On Windows only, two extra steps apply: a leading "/x/" (single ASCII
// letter) is rewritten to "X:/" before cleaning, and the drive letter of a
// cleaned "x:" path is upper-cased.
func normalizeBashWorkingDir(workingDir string) string {
	onWindows := runtime.GOOS == "windows"

	if onWindows && len(workingDir) >= 3 && workingDir[0] == '/' && workingDir[2] == '/' && isASCIIAlpha(workingDir[1]) {
		drive := strings.ToUpper(workingDir[1:2])
		workingDir = drive + ":" + workingDir[2:]
	}

	cleaned := filepath.Clean(filepath.FromSlash(workingDir))
	if onWindows && len(cleaned) >= 2 && cleaned[1] == ':' && isASCIIAlpha(cleaned[0]) {
		cleaned = strings.ToUpper(cleaned[:1]) + cleaned[1:]
	}
	return cleaned
}
// isASCIIAlpha reports whether b is an ASCII letter, A-Z or a-z.
func isASCIIAlpha(b byte) bool {
	// Setting bit 0x20 maps 'A'-'Z' onto 'a'-'z' and maps no other byte
	// into that range, so one comparison covers both cases.
	folded := b | 0x20
	return 'a' <= folded && folded <= 'z'
}
// boundedOutput collects written bytes up to Limit and counts the bytes it
// does not keep.
type boundedOutput struct {
// Limit is the maximum number of bytes kept in buf; Write keeps nothing
// when it is zero or negative.
Limit int
// buf holds the bytes that were kept.
buf []byte
// omitted counts the bytes passed to Write that were not kept.
omitted int
}
// Write keeps as much of p as still fits under Limit and counts the rest as
// omitted. It always reports len(p) bytes written and a nil error, so the
// producer is never stopped by the cap.
//
// When p only partly fits, the kept part is cut back to the length
// utf8SafePrefixLen accepts, which avoids ending on a split rune. That scan
// looks at p alone: a chunk that overflows and begins with an invalid or
// continuation byte contributes nothing.
func (b *boundedOutput) Write(p []byte) (int, error) {
	room := b.Limit - len(b.buf)
	switch {
	case b.Limit <= 0 || room <= 0:
		// Capture disabled or already full: only count the bytes.
		b.omitted += len(p)
	case len(p) <= room:
		b.buf = append(b.buf, p...)
	default:
		keep := utf8SafePrefixLen(p[:room])
		b.buf = append(b.buf, p[:keep]...)
		b.omitted += len(p) - keep
	}
	return len(p), nil
}
// Len returns the number of kept bytes plus the number of omitted bytes.
func (b *boundedOutput) Len() int {
return len(b.buf) + b.omitted
}
// String returns the kept output as text, limited to the prefix that
// utf8SafePrefixLen accepts. When any bytes were left out (omitted by Write
// or cut off here), a trailing note names the stream with label and gives
// the approximate number of tokens omitted.
func (b *boundedOutput) String(label string) string {
	valid := utf8SafePrefixLen(b.buf)
	dropped := b.omitted + len(b.buf) - valid
	text := string(b.buf[:valid])
	if dropped == 0 {
		return text
	}
	return fmt.Sprintf("%s\n\n[%s truncated: omitted ~%d tokens]", text, label, approximateTokensFromBytes(dropped))
}
// utf8SafePrefixLen returns the length of the longest prefix of p that
// decodes as UTF-8 without error. Scanning stops at the first byte that
// utf8.DecodeRune reports as an invalid one-byte sequence, which includes a
// multi-byte rune cut short at the end of p.
func utf8SafePrefixLen(p []byte) int {
	offset := 0
	for offset < len(p) {
		r, width := utf8.DecodeRune(p[offset:])
		// A width of 1 distinguishes a decoding failure from a genuine,
		// correctly encoded U+FFFD, which is three bytes wide.
		if r == utf8.RuneError && width == 1 {
			break
		}
		offset += width
	}
	return offset
}
// approximateTokensFromBytes estimates a token count for n bytes at four
// bytes per token, rounded up. It returns 0 when n is zero or negative.
func approximateTokensFromBytes(n int) int {
	if n <= 0 {
		return 0
	}
	// Ceiling division written without "n+3", which would overflow for n
	// close to the largest int.
	tokens := n / 4
	if n%4 != 0 {
		tokens++
	}
	return tokens
}