compose/pkg/compose/model.go
Nicolas De Loof fd4f2f99cf register TTYWritter as an Event Processor
Signed-off-by: Nicolas De Loof <nicolas.deloof@gmail.com>
2025-10-29 10:11:30 +01:00

264 lines
6.4 KiB
Go

/*
Copyright 2020 Docker Compose CLI authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package compose
import (
"bufio"
"context"
"encoding/json"
"fmt"
"os/exec"
"slices"
"strconv"
"strings"
"github.com/compose-spec/compose-go/v2/types"
"github.com/containerd/errdefs"
"github.com/docker/cli/cli-plugins/manager"
"github.com/docker/compose/v2/pkg/progress"
"github.com/spf13/cobra"
"golang.org/x/sync/errgroup"
)
// ensureModels makes every model declared in the project usable before
// services start. It is a no-op when the project declares no models.
//
// It lists the models already known to the model plugin, then concurrently:
//   - injects the model name/endpoint variables into the services
//     (SetModelVariables), and
//   - for each declared model, pulls it when none of the listed tags matches
//     config.Model, then configures it.
//
// quietPull is forwarded to PullModel. The first error returned by any of the
// concurrent tasks, or by listing the models, is returned.
func (s *composeService) ensureModels(ctx context.Context, project *types.Project, quietPull bool) error {
	if len(project.Models) == 0 {
		return nil
	}

	api, err := s.newModelAPI(project)
	if err != nil {
		return err
	}
	defer api.Close()

	availableModels, err := api.ListModels(ctx)
	if err != nil {
		// Previously this error was silently dropped, which made a failing
		// plugin look like "no model available" and triggered a pull of
		// every model.
		return err
	}

	eg, ctx := errgroup.WithContext(ctx)
	eg.Go(func() error {
		return api.SetModelVariables(ctx, project)
	})

	for name, config := range project.Models {
		// A model without an explicit name is reported under its project key.
		if config.Name == "" {
			config.Name = name
		}
		eg.Go(func() error {
			if !slices.Contains(availableModels, config.Model) {
				// err is scoped to this goroutine: assigning to the
				// function-level err from several goroutines is a data race.
				if err := api.PullModel(ctx, config, quietPull, s.events); err != nil {
					return err
				}
			}
			return api.ConfigureModel(ctx, config, s.events)
		})
	}
	return eg.Wait()
}
// modelAPI bundles what is needed to shell out to the Docker Model CLI plugin.
// Instances are built by newModelAPI and must be released with Close.
type modelAPI struct {
// path is the plugin executable; commands are run as `path <subcommand> ...`.
path string
// env is filled by newModelAPI with the project environment values followed
// by the values returned by propagateDockerEndpoint.
env []string
// prepare is called on every exec.Cmd before it is started or run.
prepare func(ctx context.Context, cmd *exec.Cmd) error
// cleanup is invoked by Close.
cleanup func()
}
// newModelAPI locates the "model" CLI plugin and returns a modelAPI able to
// invoke it for the given project.
//
// A plugin lookup failure that errdefs reports as "not found" is translated
// into a message telling the user that the Docker Model plugin is required;
// any other lookup error, and any error from propagateDockerEndpoint, is
// returned unchanged. The caller owns the result and must call Close on it.
func (s *composeService) newModelAPI(project *types.Project) (*modelAPI, error) {
	plugin, err := manager.GetPlugin("model", s.dockerCli, &cobra.Command{})
	switch {
	case err == nil:
		// plugin found, carry on
	case errdefs.IsNotFound(err):
		return nil, fmt.Errorf("'models' support requires Docker Model plugin")
	default:
		return nil, err
	}

	endpointEnv, release, err := s.propagateDockerEndpoint()
	if err != nil {
		return nil, err
	}

	// Every command is prepared with the project environment before it runs.
	prepareCmd := func(ctx context.Context, cmd *exec.Cmd) error {
		return s.prepareShellOut(ctx, project.Environment, cmd)
	}

	api := &modelAPI{
		path:    plugin.Path,
		prepare: prepareCmd,
		cleanup: release,
	}
	api.env = append(project.Environment.Values(), endpointEnv...)
	return api, nil
}
// Close runs the cleanup function attached to the modelAPI, if any.
// It is safe to call on a modelAPI that was built without a cleanup function.
func (m *modelAPI) Close() {
	// Calling a nil func value panics; guard so a zero-value or hand-built
	// modelAPI can still be closed.
	if m.cleanup != nil {
		m.cleanup()
	}
}
// PullModel runs `<plugin> pull <model.Model>` and reports progress through
// events under the ID model.Name.
//
// A "Pulling" event is sent up front. Unless quietPull is set, every non-empty
// line the command writes to stdout is forwarded as the StatusText of a
// further "Pulling" event. When the command succeeds a "Pulled" event is sent;
// when it fails an error event is sent instead and the error is returned.
// Errors from preparing or starting the command are returned without an
// error event.
func (m *modelAPI) PullModel(ctx context.Context, model types.ModelConfig, quietPull bool, events progress.EventProcessor) error {
	events.On(progress.Event{
		ID:     model.Name,
		Status: progress.Working,
		Text:   "Pulling",
	})

	cmd := exec.CommandContext(ctx, m.path, "pull", model.Model)
	if err := m.prepare(ctx, cmd); err != nil {
		return err
	}

	stream, err := cmd.StdoutPipe()
	if err != nil {
		return err
	}
	if err := cmd.Start(); err != nil {
		return err
	}

	// Read with bufio.Reader rather than bufio.Scanner: Scanner gives up on a
	// line longer than its token limit, and the error used to be ignored. The
	// pipe was then left undrained, so the child could block writing to it and
	// cmd.Wait would never return. ReadString has no line-length limit and
	// keeps reading until the stream ends.
	reader := bufio.NewReader(stream)
	for {
		line, readErr := reader.ReadString('\n')
		// Same trimming as bufio.ScanLines: drop the newline and one optional
		// carriage return before it.
		msg := strings.TrimSuffix(strings.TrimSuffix(line, "\n"), "\r")
		if msg != "" && !quietPull {
			events.On(progress.Event{
				ID:         model.Name,
				Status:     progress.Working,
				Text:       "Pulling",
				StatusText: msg,
			})
		}
		if readErr != nil {
			// End of stream (or a broken pipe): the exit status obtained from
			// Wait below decides whether the pull succeeded.
			break
		}
	}

	if err := cmd.Wait(); err != nil {
		// Return right away so the error event is not immediately overwritten
		// by a "Pulled" event for the same ID.
		events.On(progress.ErrorMessageEvent(model.Name, err.Error()))
		return err
	}

	events.On(progress.Event{
		ID:     model.Name,
		Status: progress.Working,
		Text:   "Pulled",
	})
	return nil
}
// ConfigureModel announces a "Configuring" event for config.Name and then runs
// the plugin's `configure` command for config.Model, returning whatever error
// preparing or running the command produces.
//
// The command line has the shape:
//
//	configure [--context-size <n>] MODEL [-- <runtime-flags...>]
//
// where the context size is only passed when it is positive and the `--`
// separator only when runtime flags are present.
func (m *modelAPI) ConfigureModel(ctx context.Context, config types.ModelConfig, events progress.EventProcessor) error {
	events.On(progress.Event{
		ID:     config.Name,
		Status: progress.Working,
		Text:   "Configuring",
	})

	var sizeArgs, flagArgs []string
	if config.ContextSize > 0 {
		sizeArgs = []string{"--context-size", strconv.Itoa(config.ContextSize)}
	}
	if len(config.RuntimeFlags) > 0 {
		flagArgs = append([]string{"--"}, config.RuntimeFlags...)
	}

	argv := append([]string{"configure"}, sizeArgs...)
	argv = append(argv, config.Model)
	argv = append(argv, flagArgs...)

	configure := exec.CommandContext(ctx, m.path, argv...)
	if prepErr := m.prepare(ctx, configure); prepErr != nil {
		return prepErr
	}
	return configure.Run()
}
// SetModelVariables queries the plugin with `status --json` and, for every
// model referenced by a service, sets two variables in that service's
// environment:
//
//   - the model identifier, under modelConfig.ModelVariable when set,
//     otherwise <REF>_MODEL;
//   - the endpoint reported by the plugin, under modelConfig.EndpointVariable
//     when set, otherwise <REF>_URL.
//
// <REF> is the model reference upper-cased with "-" replaced by "_".
// It returns an error when the status command fails or its output is not
// valid JSON.
func (m *modelAPI) SetModelVariables(ctx context.Context, project *types.Project) error {
	cmd := exec.CommandContext(ctx, m.path, "status", "--json")
	if err := m.prepare(ctx, cmd); err != nil {
		return err
	}

	// Output, not CombinedOutput: only stdout carries the JSON document.
	// Mixing stderr into it makes json.Unmarshal fail as soon as the plugin
	// prints anything on stderr.
	statusOut, err := cmd.Output()
	if err != nil {
		return fmt.Errorf("error checking docker-model status: %w", err)
	}

	var status struct {
		Endpoint string `json:"endpoint"`
	}
	if err := json.Unmarshal(statusOut, &status); err != nil {
		return fmt.Errorf("error unmarshalling docker-model status: %w", err)
	}

	for _, service := range project.Services {
		for ref, modelConfig := range service.Models {
			// model is a fresh copy on every iteration, so taking the address
			// of its field below does not alias across services.
			model := project.Models[ref]
			varPrefix := strings.ReplaceAll(strings.ToUpper(ref), "-", "_")

			// NOTE(review): assumes service.Environment is non-nil; writing to
			// a nil map would panic — confirm the loader always initialises it.
			modelVar := varPrefix + "_MODEL"
			if modelConfig != nil && modelConfig.ModelVariable != "" {
				modelVar = modelConfig.ModelVariable
			}
			service.Environment[modelVar] = &model.Model

			endpointVar := varPrefix + "_URL"
			if modelConfig != nil && modelConfig.EndpointVariable != "" {
				endpointVar = modelConfig.EndpointVariable
			}
			service.Environment[endpointVar] = &status.Endpoint
		}
	}
	return nil
}
// Model describes a model as a JSON object with the keys given by the struct
// tags below.
// NOTE(review): presumably mirrors an entry of the model plugin's JSON output;
// it is not referenced in the visible part of this file — confirm usage.
type Model struct {
// Id is decoded from the "id" key.
Id string `json:"id"`
// Tags is decoded from the "tags" key.
Tags []string `json:"tags"`
// Created is decoded from the "created" key as an integer.
// NOTE(review): likely a Unix timestamp — confirm.
Created int `json:"created"`
// Config is decoded from the nested "config" object; all its fields are
// plain strings.
Config struct {
Format string `json:"format"`
Quantization string `json:"quantization"`
Parameters string `json:"parameters"`
Architecture string `json:"architecture"`
Size string `json:"size"`
} `json:"config"`
}
// ListModels runs `<plugin> ls --json` and returns the tags of every model the
// plugin reports, flattened into a single slice (nil when there are none).
// It returns an error when the command fails or its output cannot be decoded
// as a JSON array of models.
func (m *modelAPI) ListModels(ctx context.Context) ([]string, error) {
	cmd := exec.CommandContext(ctx, m.path, "ls", "--json")
	if err := m.prepare(ctx, cmd); err != nil {
		return nil, err
	}

	// Output, not CombinedOutput: only stdout carries the JSON document.
	// Mixing stderr into it makes json.Unmarshal fail as soon as the plugin
	// prints anything on stderr.
	output, err := cmd.Output()
	if err != nil {
		return nil, fmt.Errorf("error checking available models: %w", err)
	}

	type AvailableModel struct {
		Id      string   `json:"id"`
		Tags    []string `json:"tags"`
		Created int      `json:"created"`
	}

	var models []AvailableModel
	if err := json.Unmarshal(output, &models); err != nil {
		return nil, fmt.Errorf("error unmarshalling available models: %w", err)
	}

	var availableModels []string
	for _, model := range models {
		availableModels = append(availableModels, model.Tags...)
	}
	return availableModels, nil
}