Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions cmd/ateapi/internal/controlapi/report_node_image_cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package controlapi

import (
"context"
"sort"
"time"

"github.com/agent-substrate/substrate/pkg/proto/ateapipb"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
)

// nodeImageCacheTTL is the lifetime passed to the persistence layer for each
// stored node image-cache report. It is long enough that several reports can
// be missed without immediately discarding useful affinity information.
// Stale data affects performance only: atelet still pulls an image normally
// even when the scheduler's predicted cache hit turns out to be wrong.
const nodeImageCacheTTL = 2 * time.Minute

// ReportNodeImageCache records the image-cache snapshot reported for a node.
//
// The request must carry a non-empty cache.node_name and cache.atelet_pod_uid,
// and no entry of cache.image_digests may be empty; otherwise an
// InvalidArgument status error is returned. The snapshot is copied, its
// digests are sorted and de-duplicated, and the result is persisted with a
// lifetime of nodeImageCacheTTL. Errors from the persistence layer are
// returned unchanged.
func (s *Service) ReportNodeImageCache(ctx context.Context, req *ateapipb.ReportNodeImageCacheRequest) (*ateapipb.ReportNodeImageCacheResponse, error) {
	reported := req.GetCache()
	switch {
	case reported.GetNodeName() == "":
		return nil, status.Error(codes.InvalidArgument, "cache.node_name is required")
	case reported.GetAteletPodUid() == "":
		return nil, status.Error(codes.InvalidArgument, "cache.atelet_pod_uid is required")
	}

	// Work on a private copy so the caller's request message is never mutated.
	snapshot := proto.Clone(reported).(*ateapipb.NodeImageCache)

	// Each report is an authoritative snapshot. Sorting first places equal
	// digests next to each other, so a single pass can both reject empty
	// entries and drop duplicates while reusing the copy's backing array.
	sort.Strings(snapshot.ImageDigests)
	unique := snapshot.ImageDigests[:0]
	for _, digest := range snapshot.ImageDigests {
		if digest == "" {
			return nil, status.Error(codes.InvalidArgument, "cache.image_digests must not contain an empty digest")
		}
		if n := len(unique); n > 0 && unique[n-1] == digest {
			continue
		}
		unique = append(unique, digest)
	}
	snapshot.ImageDigests = unique

	if err := s.persistence.SetNodeImageCache(ctx, snapshot, nodeImageCacheTTL); err != nil {
		return nil, err
	}
	return &ateapipb.ReportNodeImageCacheResponse{}, nil
}
58 changes: 58 additions & 0 deletions cmd/ateapi/internal/controlapi/report_node_image_cache_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package controlapi

import (
"context"
"testing"

"github.com/agent-substrate/substrate/cmd/ateapi/internal/store/storetest"
"github.com/agent-substrate/substrate/pkg/proto/ateapipb"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

// TestReportNodeImageCache checks that a report with unsorted, duplicated
// digests is stored as a sorted list without duplicates.
func TestReportNodeImageCache(t *testing.T) {
	ctx := context.Background()
	persistence, cleanup := storetest.SetupTestStore(t)
	defer cleanup()
	service := &Service{persistence: persistence}

	req := &ateapipb.ReportNodeImageCacheRequest{
		Cache: &ateapipb.NodeImageCache{
			NodeName:     "node-1",
			AteletPodUid: "atelet-uid",
			ImageDigests: []string{"sha256:b", "sha256:a", "sha256:a"},
		},
	}
	if _, err := service.ReportNodeImageCache(ctx, req); err != nil {
		t.Fatalf("ReportNodeImageCache failed: %v", err)
	}

	got, err := persistence.GetNodeImageCache(ctx, "node-1")
	if err != nil {
		t.Fatalf("GetNodeImageCache failed: %v", err)
	}
	digests := got.GetImageDigests()
	normalized := len(digests) == 2 && digests[0] == "sha256:a" && digests[1] == "sha256:b"
	if !normalized {
		t.Fatalf("stored digests = %v, want [sha256:a sha256:b]", digests)
	}
}

// TestReportNodeImageCacheValidatesIdentity checks that a report lacking
// either identity field is rejected with InvalidArgument. Each identity check
// gets its own case so that neither the node-name nor the pod-UID branch can
// regress unnoticed.
func TestReportNodeImageCacheValidatesIdentity(t *testing.T) {
	// Validation returns before the persistence layer is used, so a zero
	// Service is sufficient here.
	service := &Service{}

	tests := []struct {
		name string
		req  *ateapipb.ReportNodeImageCacheRequest
	}{
		{
			name: "missing cache",
			req:  &ateapipb.ReportNodeImageCacheRequest{},
		},
		{
			name: "missing node name",
			req: &ateapipb.ReportNodeImageCacheRequest{
				Cache: &ateapipb.NodeImageCache{AteletPodUid: "atelet-uid"},
			},
		},
		{
			name: "missing atelet pod uid",
			req: &ateapipb.ReportNodeImageCacheRequest{
				Cache: &ateapipb.NodeImageCache{NodeName: "node-1"},
			},
		},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			_, err := service.ReportNodeImageCache(context.Background(), tc.req)
			if status.Code(err) != codes.InvalidArgument {
				t.Fatalf("ReportNodeImageCache status = %v, want InvalidArgument", status.Code(err))
			}
		})
	}
}
96 changes: 94 additions & 2 deletions cmd/ateapi/internal/controlapi/workflow_resume.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"log/slog"
"math/rand"
"slices"
"strings"
"time"

"github.com/agent-substrate/substrate/cmd/ateapi/internal/store"
Expand Down Expand Up @@ -162,7 +163,18 @@ func (s *AssignWorkerStep) Execute(ctx context.Context, input *ResumeInput, stat

// If not, find a free one using randomized shuffling
if assignedWorker == nil {
pickedWorker := s.findFreeWorker(workers, eligible, state.Actor.GetLatestSnapshotInfo().GetLocal().GetNodeVmsWithLocalSnapshots())
requiredDigests, err := actorTemplateImageDigests(state.ActorTemplate)
if err != nil {
return fmt.Errorf("while resolving actor image digests: %w", err)
}
nodeCaches := s.loadNodeImageCaches(ctx, workers)
pickedWorker := s.findFreeWorker(
workers,
eligible,
state.Actor.GetLatestSnapshotInfo().GetLocal().GetNodeVmsWithLocalSnapshots(),
requiredDigests,
nodeCaches,
)
if pickedWorker == nil {
return status.Errorf(codes.FailedPrecondition, "no free workers available")
}
Expand Down Expand Up @@ -201,8 +213,79 @@ func (s *AssignWorkerStep) RetryBackoff() *wait.Backoff {
}
}

func (s *AssignWorkerStep) findFreeWorker(workers []*ateapipb.Worker, eligible map[types.NamespacedName]struct{}, nodesRestrictions []string) *ateapipb.Worker {
// actorTemplateImageDigests returns the sorted, de-duplicated digests of the
// pause image and every container image named in actorTemplate.
//
// The digest of a reference is the text after its last '@'. A reference with
// no '@', or with nothing after it, is reported as an error; references are
// checked in the order pause image first, then containers.
func actorTemplateImageDigests(actorTemplate *atev1alpha1.ActorTemplate) ([]string, error) {
	digestOf := func(ref string) (string, error) {
		idx := strings.LastIndex(ref, "@")
		if idx == -1 || idx+1 == len(ref) {
			return "", fmt.Errorf("image reference %q is not digest-pinned", ref)
		}
		return ref[idx+1:], nil
	}

	containers := actorTemplate.Spec.Containers
	digests := make([]string, 0, len(containers)+1)

	pause, err := digestOf(actorTemplate.Spec.PauseImage)
	if err != nil {
		return nil, err
	}
	digests = append(digests, pause)

	for i := range containers {
		digest, err := digestOf(containers[i].Image)
		if err != nil {
			return nil, err
		}
		digests = append(digests, digest)
	}

	// Sorting makes equal digests adjacent, so Compact removes every duplicate.
	slices.Sort(digests)
	return slices.Compact(digests), nil
}

// loadNodeImageCaches looks up the stored image-cache report for each distinct
// node name found in workers and returns the reports keyed by node name.
//
// Workers without a node name are skipped and each node is queried at most
// once. Lookups are best effort: a node whose lookup fails is simply absent
// from the result. Failures other than store.ErrNotFound are logged as
// warnings; a missing report is not logged.
func (s *AssignWorkerStep) loadNodeImageCaches(ctx context.Context, workers []*ateapipb.Worker) map[string]*ateapipb.NodeImageCache {
	visited := make(map[string]struct{})
	reports := make(map[string]*ateapipb.NodeImageCache)
	for _, w := range workers {
		node := w.GetNodeName()
		if node == "" {
			continue
		}
		if _, dup := visited[node]; dup {
			continue
		}
		visited[node] = struct{}{}

		report, err := s.store.GetNodeImageCache(ctx, node)
		switch {
		case err == nil:
			reports[node] = report
		case errors.Is(err, store.ErrNotFound):
			// No report stored for this node; nothing to record or log.
		default:
			slog.WarnContext(ctx, "Ignoring unavailable node image-cache report", "node", node, "err", err)
		}
	}
	return reports
}

// hasAllImageDigests reports whether cache lists every digest in
// requiredDigests. It returns false when cache is nil or when requiredDigests
// is empty, so an empty requirement never counts as a cache hit.
func hasAllImageDigests(cache *ateapipb.NodeImageCache, requiredDigests []string) bool {
	if len(requiredDigests) == 0 {
		return false
	}
	if cache == nil {
		return false
	}

	available := cache.GetImageDigests()
	present := make(map[string]struct{}, len(available))
	for _, digest := range available {
		present[digest] = struct{}{}
	}

	anyMissing := slices.ContainsFunc(requiredDigests, func(digest string) bool {
		_, found := present[digest]
		return !found
	})
	return !anyMissing
}

func (s *AssignWorkerStep) findFreeWorker(
workers []*ateapipb.Worker,
eligible map[types.NamespacedName]struct{},
nodesRestrictions []string,
requiredDigests []string,
nodeCaches map[string]*ateapipb.NodeImageCache,
) *ateapipb.Worker {
var freeWorkers []*ateapipb.Worker
var cacheHitWorkers []*ateapipb.Worker
for _, worker := range workers {
if worker.GetActorId() != "" {
continue
Expand All @@ -212,9 +295,18 @@ func (s *AssignWorkerStep) findFreeWorker(workers []*ateapipb.Worker, eligible m
}
if len(nodesRestrictions) == 0 || slices.Contains(nodesRestrictions, worker.GetNodeName()) {
freeWorkers = append(freeWorkers, worker)
if hasAllImageDigests(nodeCaches[worker.GetNodeName()], requiredDigests) {
cacheHitWorkers = append(cacheHitWorkers, worker)
}
}
}

if len(cacheHitWorkers) > 0 {
rand.Shuffle(len(cacheHitWorkers), func(i, j int) {
cacheHitWorkers[i], cacheHitWorkers[j] = cacheHitWorkers[j], cacheHitWorkers[i]
})
return cacheHitWorkers[0]
}
if len(freeWorkers) > 0 {
rand.Shuffle(len(freeWorkers), func(i, j int) {
freeWorkers[i], freeWorkers[j] = freeWorkers[j], freeWorkers[i]
Expand Down
85 changes: 85 additions & 0 deletions cmd/ateapi/internal/controlapi/workflow_resume_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@
package controlapi

import (
"context"
"testing"
"time"

"github.com/agent-substrate/substrate/cmd/ateapi/internal/store/storetest"
atev1alpha1 "github.com/agent-substrate/substrate/pkg/api/v1alpha1"
"github.com/agent-substrate/substrate/pkg/proto/ateapipb"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
Expand All @@ -39,6 +42,88 @@ func poolWithClass(namespace, name string, class atev1alpha1.SandboxClass, label
return p
}

// testDigestA and testDigestB are two distinct placeholder sha256-style
// digests used by the image-digest and worker-selection tests in this file.
// testDigestA sorts before testDigestB.
const (
testDigestA = "sha256:aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
testDigestB = "sha256:bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb"
)

// TestActorTemplateImageDigests covers both outcomes of
// actorTemplateImageDigests: digest-pinned references yield a sorted list
// without duplicates, and a reference that is not digest-pinned is rejected.
func TestActorTemplateImageDigests(t *testing.T) {
	t.Run("sorts and deduplicates digests", func(t *testing.T) {
		actorTemplate := &atev1alpha1.ActorTemplate{
			Spec: atev1alpha1.ActorTemplateSpec{
				PauseImage: "example.com/pause@" + testDigestA,
				Containers: []atev1alpha1.Container{
					{Image: "example.com/app@" + testDigestB},
					{Image: "example.com/duplicate@" + testDigestA},
				},
			},
		}
		got, err := actorTemplateImageDigests(actorTemplate)
		if err != nil {
			t.Fatalf("actorTemplateImageDigests failed: %v", err)
		}
		if len(got) != 2 || got[0] != testDigestA || got[1] != testDigestB {
			t.Fatalf("actorTemplateImageDigests = %v, want [%s %s]", got, testDigestA, testDigestB)
		}
	})

	// A tag-only reference has no '@'; a trailing '@' has an empty digest.
	// Both must be reported as errors rather than producing a digest.
	unpinned := []string{"example.com/app:latest", "example.com/app@"}
	for _, ref := range unpinned {
		t.Run("rejects "+ref, func(t *testing.T) {
			actorTemplate := &atev1alpha1.ActorTemplate{
				Spec: atev1alpha1.ActorTemplateSpec{
					PauseImage: "example.com/pause@" + testDigestA,
					Containers: []atev1alpha1.Container{{Image: ref}},
				},
			}
			got, err := actorTemplateImageDigests(actorTemplate)
			if err == nil {
				t.Fatalf("actorTemplateImageDigests = %v, want error for reference %q", got, ref)
			}
		})
	}
}

func TestFindFreeWorkerPrefersCompleteImageCacheHit(t *testing.T) {
step := &AssignWorkerStep{}
eligible := map[types.NamespacedName]struct{}{{Namespace: "ns", Name: "pool"}: {}}
workers := []*ateapipb.Worker{
{WorkerNamespace: "ns", WorkerPool: "pool", WorkerPod: "cold", NodeName: "node-cold"},
{WorkerNamespace: "ns", WorkerPool: "pool", WorkerPod: "warm", NodeName: "node-warm"},
}
caches := map[string]*ateapipb.NodeImageCache{
"node-warm": {NodeName: "node-warm", ImageDigests: []string{testDigestA, testDigestB}},
}

got := step.findFreeWorker(workers, eligible, nil, []string{testDigestA, testDigestB}, caches)
if got.GetWorkerPod() != "warm" {
t.Fatalf("findFreeWorker selected %q, want warm", got.GetWorkerPod())
}
}

func TestFindFreeWorkerSnapshotRestrictionPrecedesImageAffinity(t *testing.T) {
step := &AssignWorkerStep{}
eligible := map[types.NamespacedName]struct{}{{Namespace: "ns", Name: "pool"}: {}}
workers := []*ateapipb.Worker{
{WorkerNamespace: "ns", WorkerPool: "pool", WorkerPod: "snapshot-local", NodeName: "node-snapshot"},
{WorkerNamespace: "ns", WorkerPool: "pool", WorkerPod: "warm", NodeName: "node-warm"},
}
caches := map[string]*ateapipb.NodeImageCache{
"node-warm": {NodeName: "node-warm", ImageDigests: []string{testDigestA}},
}

got := step.findFreeWorker(workers, eligible, []string{"node-snapshot"}, []string{testDigestA}, caches)
if got.GetWorkerPod() != "snapshot-local" {
t.Fatalf("findFreeWorker selected %q, want snapshot-local", got.GetWorkerPod())
}
}

// TestHasAllImageDigestsRejectsPartialHit checks that a cache holding only
// one of two required digests is not reported as a complete hit.
func TestHasAllImageDigestsRejectsPartialHit(t *testing.T) {
	partial := &ateapipb.NodeImageCache{ImageDigests: []string{testDigestA}}
	required := []string{testDigestA, testDigestB}

	complete := hasAllImageDigests(partial, required)
	if complete {
		t.Fatal("partial image-cache hit was treated as a complete hit")
	}
}

// TestLoadNodeImageCachesDeduplicatesNodes checks that a node listed by
// several workers yields a single map entry and that a node with no stored
// report is left out of the result.
func TestLoadNodeImageCachesDeduplicatesNodes(t *testing.T) {
	ctx := context.Background()
	testStore, cleanup := storetest.SetupTestStore(t)
	defer cleanup()

	seeded := &ateapipb.NodeImageCache{
		NodeName: "node-1", ImageDigests: []string{testDigestA},
	}
	if err := testStore.SetNodeImageCache(ctx, seeded, time.Minute); err != nil {
		t.Fatalf("SetNodeImageCache failed: %v", err)
	}

	workers := []*ateapipb.Worker{
		{NodeName: "node-1"}, {NodeName: "node-1"}, {NodeName: "node-missing"},
	}
	step := &AssignWorkerStep{store: testStore}
	got := step.loadNodeImageCaches(ctx, workers)
	if cache := got["node-1"]; cache == nil || len(got) != 1 {
		t.Fatalf("loadNodeImageCaches = %v, want only node-1", got)
	}
}

func TestEligibleWorkerPools(t *testing.T) {
tests := []struct {
name string
Expand Down
Loading