Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 44 additions & 24 deletions cli/compose/convert/service.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
// FIXME(thaJeztah): remove once we are a module; the go:build directive prevents go from downgrading language version to go1.16:
//go:build go1.24

package convert

import (
"cmp"
"context"
"errors"
"fmt"
"net/netip"
"os"
"slices"
"sort"
"strings"
"time"
Expand Down Expand Up @@ -567,21 +572,42 @@ func convertResources(source composetypes.Resources) (*swarm.ResourceRequirement
return resources, nil
}

// compareSwarmPortConfig returns the lexical ordering of a and b, and can be used
// with [slices.SortFunc].
//
// The comparison is performed in the following priority order:
//
// 1. PublishedPort (host port)
// 2. TargetPort (container port)
// 3. Protocol
// 4. PublishMode
//
// TODO(thaJeztah): define this on swarm.PortConfig itself to allow re-use.
func compareSwarmPortConfig(a, b swarm.PortConfig) int {
if n := cmp.Compare(a.PublishedPort, b.PublishedPort); n != 0 {
return n
}
if n := cmp.Compare(a.TargetPort, b.TargetPort); n != 0 {
return n
}
if n := cmp.Compare(a.Protocol, b.Protocol); n != 0 {
return n
}
return cmp.Compare(a.PublishMode, b.PublishMode)
}

func convertEndpointSpec(endpointMode string, source []composetypes.ServicePortConfig) *swarm.EndpointSpec {
portConfigs := make([]swarm.PortConfig, 0, len(source))
for _, port := range source {
portConfig := swarm.PortConfig{
portConfigs = append(portConfigs, swarm.PortConfig{
Protocol: network.IPProtocol(port.Protocol),
TargetPort: port.Target,
PublishedPort: port.Published,
PublishMode: swarm.PortConfigPublishMode(port.Mode),
}
portConfigs = append(portConfigs, portConfig)
})
}

sort.Slice(portConfigs, func(i, j int) bool {
return portConfigs[i].PublishedPort < portConfigs[j].PublishedPort
})
slices.SortFunc(portConfigs, compareSwarmPortConfig)

return &swarm.EndpointSpec{
Mode: swarm.ResolutionMode(strings.ToLower(endpointMode)),
Expand Down Expand Up @@ -702,28 +728,22 @@ func convertCredentialSpec(namespace Namespace, spec composetypes.CredentialSpec
}

func convertUlimits(origUlimits map[string]*composetypes.UlimitsConfig) []*container.Ulimit {
newUlimits := make(map[string]*container.Ulimit)
ulimits := make([]*container.Ulimit, 0, len(origUlimits))
for name, u := range origUlimits {
soft, hard := int64(u.Soft), int64(u.Hard)
if u.Single != 0 {
newUlimits[name] = &container.Ulimit{
Name: name,
Soft: int64(u.Single),
Hard: int64(u.Single),
}
} else {
newUlimits[name] = &container.Ulimit{
Name: name,
Soft: int64(u.Soft),
Hard: int64(u.Hard),
}
soft, hard = int64(u.Single), int64(u.Single)
}

ulimits = append(ulimits, &container.Ulimit{
Name: name,
Soft: soft,
Hard: hard,
})
}
ulimits := make([]*container.Ulimit, 0, len(newUlimits))
for _, ulimit := range newUlimits {
ulimits = append(ulimits, ulimit)
}
sort.SliceStable(ulimits, func(i, j int) bool {
return ulimits[i].Name < ulimits[j].Name

slices.SortFunc(ulimits, func(a, b *container.Ulimit) int {
return cmp.Compare(a.Name, b.Name)
})
return ulimits
}
59 changes: 31 additions & 28 deletions cli/compose/loader/merge.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
package loader

import (
"cmp"
"fmt"
"reflect"
"slices"
"sort"

"dario.cat/mergo"
Expand Down Expand Up @@ -52,10 +54,10 @@ func merge(configs []*types.Config) (*types.Config, error) {
}

func mergeServices(base, override []types.ServiceConfig) ([]types.ServiceConfig, error) {
baseServices := mapByName(base)
overrideServices := mapByName(override)
specials := &specials{
m: map[reflect.Type]func(dst, src reflect.Value) error{
mergeOpts := []func(*mergo.Config){
mergo.WithAppendSlice,
mergo.WithOverride,
mergo.WithTransformers(&specials{m: map[reflect.Type]func(dst, src reflect.Value) error{
reflect.PointerTo(reflect.TypeFor[types.LoggingConfig]()): safelyMerge(mergeLoggingConfig),
reflect.TypeFor[[]types.ServicePortConfig](): mergeSlice(toServicePortConfigsMap, toServicePortConfigsSlice),
reflect.TypeFor[[]types.ServiceSecretConfig](): mergeSlice(toServiceSecretConfigsMap, toServiceSecretConfigsSlice),
Expand All @@ -65,23 +67,34 @@ func mergeServices(base, override []types.ServiceConfig) ([]types.ServiceConfig,
reflect.TypeFor[types.ShellCommand](): mergeShellCommand,
reflect.PointerTo(reflect.TypeFor[types.ServiceNetworkConfig]()): mergeServiceNetworkConfig,
reflect.PointerTo(reflect.TypeFor[uint64]()): mergeUint64,
},
}}),
}
for name, overrideService := range overrideServices {
if baseService, ok := baseServices[name]; ok {
if err := mergo.Merge(&baseService, &overrideService, mergo.WithAppendSlice, mergo.WithOverride, mergo.WithTransformers(specials)); err != nil {
return base, fmt.Errorf("cannot merge service %s: %w", name, err)

baseServices := make(map[string]types.ServiceConfig, len(base))
for _, s := range base {
baseServices[s.Name] = s
}

for _, overrideService := range override {
if baseService, ok := baseServices[overrideService.Name]; ok {
if err := mergo.Merge(&baseService, &overrideService, mergeOpts...); err != nil {
return base, fmt.Errorf("cannot merge service %s: %w", overrideService.Name, err)
}
baseServices[name] = baseService
baseServices[overrideService.Name] = baseService
continue
}
baseServices[name] = overrideService
baseServices[overrideService.Name] = overrideService
}

services := make([]types.ServiceConfig, 0, len(baseServices))
for _, baseService := range baseServices {
services = append(services, baseService)
}
sort.Slice(services, func(i, j int) bool { return services[i].Name < services[j].Name })

slices.SortFunc(services, func(a, b types.ServiceConfig) int {
return cmp.Compare(a.Name, b.Name)
})

return services, nil
}

Expand Down Expand Up @@ -217,11 +230,13 @@ func sliceToMap(tomap tomapFn, v reflect.Value) (map[any]any, error) {
}

func mergeLoggingConfig(dst, src reflect.Value) error {
dstDriver := dst.Elem().FieldByName("Driver").String()
srcDriver := src.Elem().FieldByName("Driver").String()

// Same driver, merging options
if getLoggingDriver(dst.Elem()) == getLoggingDriver(src.Elem()) ||
getLoggingDriver(dst.Elem()) == "" || getLoggingDriver(src.Elem()) == "" {
if getLoggingDriver(dst.Elem()) == "" {
dst.Elem().FieldByName("Driver").SetString(getLoggingDriver(src.Elem()))
if dstDriver == srcDriver || dstDriver == "" || srcDriver == "" {
if dstDriver == "" {
dst.Elem().FieldByName("Driver").SetString(srcDriver)
}
dstOptions := dst.Elem().FieldByName("Options").Interface().(map[string]string)
srcOptions := src.Elem().FieldByName("Options").Interface().(map[string]string)
Expand Down Expand Up @@ -270,18 +285,6 @@ func mergeUint64(dst, src reflect.Value) error {
return nil
}

func getLoggingDriver(v reflect.Value) string {
return v.FieldByName("Driver").String()
}

func mapByName(services []types.ServiceConfig) map[string]types.ServiceConfig {
m := map[string]types.ServiceConfig{}
for _, service := range services {
m[service.Name] = service
}
return m
}

func mergeVolumes(base, override map[string]types.VolumeConfig) (map[string]types.VolumeConfig, error) {
err := mergo.Map(&base, &override, mergo.WithOverride)
return base, err
Expand Down
Loading