mirror of
https://github.com/ollama/ollama.git
synced 2026-05-13 06:21:28 +00:00
102 lines
2.7 KiB
Go
102 lines
2.7 KiB
Go
package imagegen
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/ollama/ollama/llm"
|
|
)
|
|
|
|
// roundTripFunc is a function type whose signature matches the RoundTrip
// method of http.RoundTripper, letting a plain function be used as the
// Transport of an http.Client in tests.
type roundTripFunc func(req *http.Request) (*http.Response, error)
|
|
|
|
func (fn roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
return fn(req)
|
|
}
|
|
|
|
func newCompletionTestServer(handler func(*http.Request) string) *Server {
|
|
return &Server{
|
|
port: 11434,
|
|
done: make(chan error, 1),
|
|
client: &http.Client{
|
|
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
|
body := handler(req)
|
|
return &http.Response{
|
|
StatusCode: http.StatusOK,
|
|
Header: make(http.Header),
|
|
Body: io.NopCloser(strings.NewReader(body)),
|
|
Request: req,
|
|
}, nil
|
|
}),
|
|
},
|
|
}
|
|
}
|
|
|
|
func TestCompletionReturnsImageData(t *testing.T) {
|
|
s := newCompletionTestServer(func(r *http.Request) string {
|
|
if r.URL.Path != "/completion" {
|
|
t.Fatalf("path = %q, want /completion", r.URL.Path)
|
|
}
|
|
|
|
var req Request
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if req.Prompt != "test prompt" || req.Width != 512 || req.Height != 256 || req.Steps != 7 || req.Seed != 42 {
|
|
t.Fatalf("unexpected request: %+v", req)
|
|
}
|
|
if len(req.Images) != 1 || string(req.Images[0]) != "input-image" {
|
|
t.Fatalf("images = %q, want input-image", req.Images)
|
|
}
|
|
|
|
return `{"step":1,"total":2}` + "\n" +
|
|
`{"done":true,"image":"base64png"}` + "\n"
|
|
})
|
|
|
|
var responses []llm.CompletionResponse
|
|
err := s.Completion(context.Background(), llm.CompletionRequest{
|
|
Prompt: "test prompt",
|
|
Width: 512,
|
|
Height: 256,
|
|
Steps: 7,
|
|
Seed: 42,
|
|
Images: []llm.ImageData{{Data: []byte("input-image")}},
|
|
}, func(resp llm.CompletionResponse) {
|
|
responses = append(responses, resp)
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(responses) != 2 {
|
|
t.Fatalf("responses = %d, want 2", len(responses))
|
|
}
|
|
if responses[0].Step != 1 || responses[0].TotalSteps != 2 || responses[0].Done {
|
|
t.Fatalf("progress response = %+v", responses[0])
|
|
}
|
|
if !responses[1].Done || responses[1].Image != "base64png" {
|
|
t.Fatalf("final response = %+v", responses[1])
|
|
}
|
|
}
|
|
|
|
func TestCompletionEOFBeforeDoneReturnsError(t *testing.T) {
|
|
s := newCompletionTestServer(func(r *http.Request) string {
|
|
return `{"step":1,"total":2}` + "\n"
|
|
})
|
|
|
|
var responses []llm.CompletionResponse
|
|
err := s.Completion(context.Background(), llm.CompletionRequest{Prompt: "test prompt"}, func(resp llm.CompletionResponse) {
|
|
responses = append(responses, resp)
|
|
})
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
if !strings.Contains(err.Error(), "closed response before completion") {
|
|
t.Fatalf("error = %v", err)
|
|
}
|
|
if len(responses) != 1 || responses[0].Done {
|
|
t.Fatalf("responses = %+v, want one non-done progress response", responses)
|
|
}
|
|
}
|