From 261d89f288d5314147c219098e6be63706ac4ace Mon Sep 17 00:00:00 2001 From: bobzetian Date: Wed, 5 Nov 2025 06:36:43 +0000 Subject: [PATCH 1/6] infModelRewrite reconciler logic. --- .../inferencemodelrewrite_reconciler.go | 87 +++++++ .../inferencemodelrewrite_reconciler_test.go | 235 ++++++++++++++++++ pkg/epp/datastore/datastore.go | 31 +++ 3 files changed, 353 insertions(+) create mode 100644 pkg/epp/controller/inferencemodelrewrite_reconciler.go create mode 100644 pkg/epp/controller/inferencemodelrewrite_reconciler_test.go diff --git a/pkg/epp/controller/inferencemodelrewrite_reconciler.go b/pkg/epp/controller/inferencemodelrewrite_reconciler.go new file mode 100644 index 000000000..384e0d345 --- /dev/null +++ b/pkg/epp/controller/inferencemodelrewrite_reconciler.go @@ -0,0 +1,87 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package controller + +import ( + "context" + "fmt" + + "k8s.io/apimachinery/pkg/api/errors" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/event" + "sigs.k8s.io/controller-runtime/pkg/log" + "sigs.k8s.io/controller-runtime/pkg/predicate" + + "sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2" + "sigs.k8s.io/gateway-api-inference-extension/pkg/common" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" +) + +type InferenceModelRewriteReconciler struct { + client.Reader + Datastore datastore.Datastore + PoolGKNN common.GKNN +} + +func (c *InferenceModelRewriteReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { + logger := log.FromContext(ctx).V(logutil.DEFAULT) + ctx = ctrl.LoggerInto(ctx, logger) + + logger.Info("Reconciling InferenceModelRewrite") + + infModelRewrite := &v1alpha2.InferenceModelRewrite{} + notFound := false + if err := c.Get(ctx, req.NamespacedName, infModelRewrite); err != nil { + if !errors.IsNotFound(err) { + return ctrl.Result{}, fmt.Errorf("unable to get InferenceModelRewrite - %w", err) + } + notFound = true + } + + if notFound || !infModelRewrite.DeletionTimestamp.IsZero() || infModelRewrite.Spec.PoolRef.Name != v1alpha2.ObjectName(c.PoolGKNN.Name) { + // InferenceModelRewrite object got deleted or changed the referenced pool. + c.Datastore.RewriteDelete(req.NamespacedName) + return ctrl.Result{}, nil + } + + // Add the InferenceModelRewrite to the datastore, or overwrite the existing entry stored under the same namespaced name. 
+ logger = logger.WithValues("poolRef", infModelRewrite.Spec.PoolRef) + c.Datastore.RewriteSet(infModelRewrite) + logger.Info("Added/Updated InferenceModelRewrite") + + return ctrl.Result{}, nil +} + +func (c *InferenceModelRewriteReconciler) SetupWithManager(ctx context.Context, mgr ctrl.Manager) error { + return ctrl.NewControllerManagedBy(mgr). + For(&v1alpha2.InferenceModelRewrite{}). + WithEventFilter(predicate.Funcs{ + CreateFunc: func(e event.CreateEvent) bool { return c.eventPredicate(e.Object.(*v1alpha2.InferenceModelRewrite)) }, + UpdateFunc: func(e event.UpdateEvent) bool { + return c.eventPredicate(e.ObjectOld.(*v1alpha2.InferenceModelRewrite)) || c.eventPredicate(e.ObjectNew.(*v1alpha2.InferenceModelRewrite)) + }, + DeleteFunc: func(e event.DeleteEvent) bool { return c.eventPredicate(e.Object.(*v1alpha2.InferenceModelRewrite)) }, + GenericFunc: func(e event.GenericEvent) bool { return c.eventPredicate(e.Object.(*v1alpha2.InferenceModelRewrite)) }, + }). + Complete(c) +} + +func (c *InferenceModelRewriteReconciler) eventPredicate(infModelRewrite *v1alpha2.InferenceModelRewrite) bool { + return string(infModelRewrite.Spec.PoolRef.Name) == c.PoolGKNN.Name && string(infModelRewrite.Spec.PoolRef.Group) == c.PoolGKNN.Group +} diff --git a/pkg/epp/controller/inferencemodelrewrite_reconciler_test.go b/pkg/epp/controller/inferencemodelrewrite_reconciler_test.go new file mode 100644 index 000000000..d42644402 --- /dev/null +++ b/pkg/epp/controller/inferencemodelrewrite_reconciler_test.go @@ -0,0 +1,235 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package controller + +import ( + "context" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/apimachinery/pkg/types" + clientgoscheme "k8s.io/client-go/kubernetes/scheme" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/fake" + + v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1" + "sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2" + "sigs.k8s.io/gateway-api-inference-extension/pkg/common" + backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" + utiltest "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/testing" +) + +var ( + poolForRewrite = utiltest.MakeInferencePool("test-pool1").Namespace("ns1").ObjRef() + rewrite1 = makeInferenceModelRewrite("rewrite1"). + Namespace(poolForRewrite.Namespace). + PoolName(poolForRewrite.Name). + CreationTimestamp(metav1.Unix(1000, 0)). + ObjRef() + rewrite1Pool2 = makeInferenceModelRewrite(rewrite1.Name). + Namespace(rewrite1.Namespace). + PoolName("test-pool2"). + CreationTimestamp(metav1.Unix(1001, 0)). + ObjRef() + rewrite1Updated = makeInferenceModelRewrite(rewrite1.Name). + Namespace(rewrite1.Namespace). + PoolName(poolForRewrite.Name). + CreationTimestamp(metav1.Unix(1003, 0)). + Rules([]v1alpha2.InferenceModelRewriteRule{{}}). 
+ ObjRef() + rewrite1Deleted = makeInferenceModelRewrite(rewrite1.Name). + Namespace(rewrite1.Namespace). + PoolName(poolForRewrite.Name). + CreationTimestamp(metav1.Unix(1004, 0)). + DeletionTimestamp(). + ObjRef() + rewrite2 = makeInferenceModelRewrite("rewrite2"). + Namespace(poolForRewrite.Namespace). + PoolName(poolForRewrite.Name). + CreationTimestamp(metav1.Unix(1000, 0)). + ObjRef() +) + +type inferenceModelRewriteBuilder struct { + *v1alpha2.InferenceModelRewrite +} + +func makeInferenceModelRewrite(name string) *inferenceModelRewriteBuilder { + return &inferenceModelRewriteBuilder{ + &v1alpha2.InferenceModelRewrite{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + }, + }, + } +} + +func (b *inferenceModelRewriteBuilder) Namespace(ns string) *inferenceModelRewriteBuilder { + b.ObjectMeta.Namespace = ns + return b +} + +func (b *inferenceModelRewriteBuilder) PoolName(name string) *inferenceModelRewriteBuilder { + b.Spec.PoolRef.Name = v1alpha2.ObjectName(name) + return b +} + +func (b *inferenceModelRewriteBuilder) CreationTimestamp(t metav1.Time) *inferenceModelRewriteBuilder { + b.ObjectMeta.CreationTimestamp = t + return b +} + +func (b *inferenceModelRewriteBuilder) DeletionTimestamp() *inferenceModelRewriteBuilder { + now := metav1.Now() + b.ObjectMeta.DeletionTimestamp = &now + return b +} + +func (b *inferenceModelRewriteBuilder) Rules(rules []v1alpha2.InferenceModelRewriteRule) *inferenceModelRewriteBuilder { + b.Spec.Rules = rules + return b +} + +func (b *inferenceModelRewriteBuilder) ObjRef() *v1alpha2.InferenceModelRewrite { + return b.InferenceModelRewrite +} + +func TestInferenceModelRewriteReconciler(t *testing.T) { + tests := []struct { + name string + rewritesInStore []*v1alpha2.InferenceModelRewrite + rewritesInAPIServer []*v1alpha2.InferenceModelRewrite + rewrite *v1alpha2.InferenceModelRewrite + incomingReq *types.NamespacedName + wantRewrites []*v1alpha2.InferenceModelRewrite + wantResult ctrl.Result + }{ + { + name: "Empty store, add 
new rewrite", + rewrite: rewrite1, + wantRewrites: []*v1alpha2.InferenceModelRewrite{rewrite1}, + }, + { + name: "Existing rewrite changed pools", + rewritesInStore: []*v1alpha2.InferenceModelRewrite{rewrite1}, + rewrite: rewrite1Pool2, + wantRewrites: []*v1alpha2.InferenceModelRewrite{}, + }, + { + name: "Not found, delete existing rewrite", + rewritesInStore: []*v1alpha2.InferenceModelRewrite{rewrite1}, + incomingReq: &types.NamespacedName{Name: rewrite1.Name, Namespace: rewrite1.Namespace}, + wantRewrites: []*v1alpha2.InferenceModelRewrite{}, + }, + { + name: "Deletion timestamp set, delete existing rewrite", + rewritesInStore: []*v1alpha2.InferenceModelRewrite{rewrite1}, + rewrite: rewrite1Deleted, + incomingReq: &types.NamespacedName{Name: rewrite1Deleted.Name, Namespace: rewrite1Deleted.Namespace}, + wantRewrites: []*v1alpha2.InferenceModelRewrite{}, + }, + { + name: "Rewrite updated", + rewritesInStore: []*v1alpha2.InferenceModelRewrite{rewrite1}, + rewrite: rewrite1Updated, + wantRewrites: []*v1alpha2.InferenceModelRewrite{rewrite1Updated}, + }, + { + name: "Rewrite not found, no matching existing rewrite to delete", + rewritesInStore: []*v1alpha2.InferenceModelRewrite{rewrite1}, + incomingReq: &types.NamespacedName{Name: "non-existent-rewrite", Namespace: poolForRewrite.Namespace}, + wantRewrites: []*v1alpha2.InferenceModelRewrite{rewrite1}, + }, + { + name: "Add to existing", + rewritesInStore: []*v1alpha2.InferenceModelRewrite{rewrite1}, + rewrite: rewrite2, + wantRewrites: []*v1alpha2.InferenceModelRewrite{rewrite1, rewrite2}, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + scheme := runtime.NewScheme() + _ = clientgoscheme.AddToScheme(scheme) + _ = v1alpha2.Install(scheme) + _ = v1.Install(scheme) + initObjs := []client.Object{} + if test.rewrite != nil && test.rewrite.DeletionTimestamp.IsZero() { + initObjs = append(initObjs, test.rewrite) + } + for _, r := range test.rewritesInAPIServer { + initObjs = 
append(initObjs, r) + } + fakeClient := fake.NewClientBuilder(). + WithScheme(scheme). + WithObjects(initObjs...). + Build() + pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second) + ds := datastore.NewDatastore(t.Context(), pmf, 0) + for _, r := range test.rewritesInStore { + ds.RewriteSet(r) + } + _ = ds.PoolSet(context.Background(), fakeClient, poolForRewrite) + reconciler := &InferenceModelRewriteReconciler{ + Reader: fakeClient, + Datastore: ds, + PoolGKNN: common.GKNN{ + NamespacedName: types.NamespacedName{Name: poolForRewrite.Name, Namespace: poolForRewrite.Namespace}, + GroupKind: schema.GroupKind{Group: poolForRewrite.GroupVersionKind().Group, Kind: poolForRewrite.GroupVersionKind().Kind}, + }, + } + if test.incomingReq == nil { + test.incomingReq = &types.NamespacedName{Name: test.rewrite.Name, Namespace: test.rewrite.Namespace} + } + + result, err := reconciler.Reconcile(context.Background(), ctrl.Request{NamespacedName: *test.incomingReq}) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + if diff := cmp.Diff(result, test.wantResult); diff != "" { + t.Errorf("Unexpected result diff (+got/-want): %s", diff) + } + + if len(test.wantRewrites) != len(ds.RewriteGetAll()) { + t.Errorf("Unexpected number of rewrites; want: %d, got:%d", len(test.wantRewrites), len(ds.RewriteGetAll())) + } + + if diff := diffStoreRewrites(ds, test.wantRewrites); diff != "" { + t.Errorf("Unexpected diff (+got/-want): %s", diff) + } + }) + } +} + +func diffStoreRewrites(ds datastore.Datastore, wantRewrites []*v1alpha2.InferenceModelRewrite) string { + if wantRewrites == nil { + wantRewrites = []*v1alpha2.InferenceModelRewrite{} + } + + gotRewrites := ds.RewriteGetAll() + if diff := cmp.Diff(wantRewrites, gotRewrites); diff != "" { + return "rewrites:" + diff + } + return "" +} diff --git a/pkg/epp/datastore/datastore.go b/pkg/epp/datastore/datastore.go index 2ab2e98cb..2cbdbd031 100644 --- 
a/pkg/epp/datastore/datastore.go +++ b/pkg/epp/datastore/datastore.go @@ -59,6 +59,11 @@ type Datastore interface { ObjectiveDelete(namespacedName types.NamespacedName) ObjectiveGetAll() []*v1alpha2.InferenceObjective + // InferenceModelRewrite operations + RewriteSet(infModelRewrite *v1alpha2.InferenceModelRewrite) + RewriteDelete(namespacedName types.NamespacedName) + RewriteGetAll() []*v1alpha2.InferenceModelRewrite + // PodList lists pods matching the given predicate. PodList(predicate func(backendmetrics.PodMetrics) bool) []backendmetrics.PodMetrics PodUpdateOrAddIfNotExist(pod *corev1.Pod) bool @@ -75,6 +80,7 @@ func NewDatastore(parentCtx context.Context, epFactory datalayer.EndpointFactory poolAndObjectivesMu: sync.RWMutex{}, pool: nil, objectives: make(map[string]*v1alpha2.InferenceObjective), + rewrites: make(map[types.NamespacedName]*v1alpha2.InferenceModelRewrite), pods: &sync.Map{}, modelServerMetricsPort: modelServerMetricsPort, epf: epFactory, @@ -96,6 +102,8 @@ type datastore struct { pool *datalayer.EndpointPool // key: InferenceObjective.Spec.ModelName, value: *InferenceObjective objectives map[string]*v1alpha2.InferenceObjective + // key: types.NamespacedName, value: *v1alpha2.InferenceModelRewrite + rewrites map[types.NamespacedName]*v1alpha2.InferenceModelRewrite // key: types.NamespacedName, value: backendmetrics.PodMetrics pods *sync.Map // modelServerMetricsPort metrics port from EPP command line argument @@ -109,6 +117,7 @@ func (ds *datastore) Clear() { defer ds.poolAndObjectivesMu.Unlock() ds.pool = nil ds.objectives = make(map[string]*v1alpha2.InferenceObjective) + ds.rewrites = make(map[types.NamespacedName]*v1alpha2.InferenceModelRewrite) // stop all pods go routines before clearing the pods map. 
ds.pods.Range(func(_, v any) bool { ds.epf.ReleaseEndpoint(v.(backendmetrics.PodMetrics)) @@ -204,6 +213,28 @@ func (ds *datastore) ObjectiveGetAll() []*v1alpha2.InferenceObjective { return res } +func (ds *datastore) RewriteSet(infModelRewrite *v1alpha2.InferenceModelRewrite) { + ds.poolAndObjectivesMu.Lock() + defer ds.poolAndObjectivesMu.Unlock() + ds.rewrites[types.NamespacedName{Name: infModelRewrite.Name, Namespace: infModelRewrite.Namespace}] = infModelRewrite +} + +func (ds *datastore) RewriteDelete(namespacedName types.NamespacedName) { + ds.poolAndObjectivesMu.Lock() + defer ds.poolAndObjectivesMu.Unlock() + delete(ds.rewrites, namespacedName) +} + +func (ds *datastore) RewriteGetAll() []*v1alpha2.InferenceModelRewrite { + ds.poolAndObjectivesMu.RLock() + defer ds.poolAndObjectivesMu.RUnlock() + res := []*v1alpha2.InferenceModelRewrite{} + for _, v := range ds.rewrites { + res = append(res, v) + } + return res +} + // /// Pods/endpoints APIs /// // TODO: add a flag for callers to specify the staleness threshold for metrics. // ref: https://github.com/kubernetes-sigs/gateway-api-inference-extension/pull/1046#discussion_r2246351694 From 520f274c0a3d4ecd25a9b60a3732a47f266f3fee Mon Sep 17 00:00:00 2001 From: bobzetian Date: Wed, 5 Nov 2025 16:15:35 +0000 Subject: [PATCH 2/6] implements model rewrite and traffic splitting. 
--- .../inferencemodelrewrite_reconciler.go | 2 +- .../inferencemodelrewrite_reconciler_test.go | 1 + pkg/epp/requestcontrol/director.go | 76 +++- pkg/epp/requestcontrol/director_test.go | 342 +++++++++++++++++- 4 files changed, 393 insertions(+), 28 deletions(-) diff --git a/pkg/epp/controller/inferencemodelrewrite_reconciler.go b/pkg/epp/controller/inferencemodelrewrite_reconciler.go index 384e0d345..8935e4115 100644 --- a/pkg/epp/controller/inferencemodelrewrite_reconciler.go +++ b/pkg/epp/controller/inferencemodelrewrite_reconciler.go @@ -83,5 +83,5 @@ func (c *InferenceModelRewriteReconciler) SetupWithManager(ctx context.Context, } func (c *InferenceModelRewriteReconciler) eventPredicate(infModelRewrite *v1alpha2.InferenceModelRewrite) bool { - return string(infModelRewrite.Spec.PoolRef.Name) == c.PoolGKNN.Name && string(infModelRewrite.Spec.PoolRef.Group) == c.PoolGKNN.Group + return infModelRewrite.Spec.PoolRef != nil && string(infModelRewrite.Spec.PoolRef.Name) == c.PoolGKNN.Name && string(infModelRewrite.Spec.PoolRef.Group) == c.PoolGKNN.Group } diff --git a/pkg/epp/controller/inferencemodelrewrite_reconciler_test.go b/pkg/epp/controller/inferencemodelrewrite_reconciler_test.go index d42644402..0649c7dd8 100644 --- a/pkg/epp/controller/inferencemodelrewrite_reconciler_test.go +++ b/pkg/epp/controller/inferencemodelrewrite_reconciler_test.go @@ -90,6 +90,7 @@ func (b *inferenceModelRewriteBuilder) Namespace(ns string) *inferenceModelRewri } func (b *inferenceModelRewriteBuilder) PoolName(name string) *inferenceModelRewriteBuilder { + b.Spec.PoolRef = &v1alpha2.PoolObjectReference{} b.Spec.PoolRef.Name = v1alpha2.ObjectName(name) return b } diff --git a/pkg/epp/requestcontrol/director.go b/pkg/epp/requestcontrol/director.go index c4f4f1c1b..dd4b3424b 100644 --- a/pkg/epp/requestcontrol/director.go +++ b/pkg/epp/requestcontrol/director.go @@ -23,6 +23,7 @@ import ( "fmt" "math/rand" "net" + "sort" "strings" "time" @@ -50,6 +51,7 @@ type Datastore interface { 
PoolGet() (*datalayer.EndpointPool, error) ObjectiveGet(objectiveName string) *v1alpha2.InferenceObjective PodList(predicate func(backendmetrics.PodMetrics) bool) []backendmetrics.PodMetrics + RewriteGetAll() []*v1alpha2.InferenceModelRewrite } // Scheduler defines the interface required by the Director for scheduling. @@ -110,11 +112,16 @@ func (d *Director) getInferenceObjective(ctx context.Context, reqCtx *handlers.R return infObjective } -// resolveTargetModel is a helper to update reqCtx with target model based on request. -func (d *Director) resolveTargetModel(reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) { +// HandleRequest orchestrates the request lifecycle. +// It always returns the requestContext even in the error case, as the request context is used in error handling. +func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) { + logger := log.FromContext(ctx) + + // Parse Request, Resolve Target Models, and Determine Parameters requestBodyMap := reqCtx.Request.Body var ok bool reqCtx.IncomingModelName, ok = requestBodyMap["model"].(string) + if !ok { return reqCtx, errutil.Error{Code: errutil.BadRequest, Msg: "model not found in request body"} } @@ -122,22 +129,11 @@ func (d *Director) resolveTargetModel(reqCtx *handlers.RequestContext) (*handler // Default to incoming model name reqCtx.TargetModelName = reqCtx.IncomingModelName } - reqCtx.Request.Body["model"] = reqCtx.TargetModelName - return reqCtx, nil -} -// HandleRequest orchestrates the request lifecycle. -// It always returns the requestContext even in the error case, as the request context is used in error handling. -func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) { - logger := log.FromContext(ctx) + d.applyWeightedModelRewrite(reqCtx) - // Resolve target model and update req context. 
- reqCtx, err := d.resolveTargetModel(reqCtx) - if err != nil { - return reqCtx, err - } + reqCtx.Request.Body["model"] = reqCtx.TargetModelName - // Parse request body. requestBody, err := requtil.ExtractRequestBody(reqCtx.Request.Body) if err != nil { return reqCtx, errutil.Error{Code: errutil.BadRequest, Msg: fmt.Errorf("failed to extract request data: %w", err).Error()} @@ -198,6 +194,56 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo return reqCtx, nil } +func (d *Director) applyWeightedModelRewrite(reqCtx *handlers.RequestContext) { + rewrites := d.datastore.RewriteGetAll() + if len(rewrites) == 0 { + return + } + + sort.Slice(rewrites, func(i, j int) bool { + return rewrites[i].CreationTimestamp.Before(&rewrites[j].CreationTimestamp) + }) + + for _, rewrite := range rewrites { + for _, rule := range rewrite.Spec.Rules { + for _, match := range rule.Matches { + if match.Model != nil && match.Model.Value == reqCtx.IncomingModelName { + reqCtx.TargetModelName = d.selectWeightedModel(rule.Targets) + return + } + } + } + } +} + +func (d *Director) selectWeightedModel(models []v1alpha2.TargetModel) string { + if len(models) == 0 { + return "" + } + + var totalWeight int32 + for _, model := range models { + totalWeight += model.Weight + } + + if totalWeight == 0 { + // If total weight is 0, distribute evenly + return models[rand.Intn(len(models))].ModelRewrite + } + + randomNum := rand.Intn(int(totalWeight)) + var currentWeight int32 + for _, model := range models { + currentWeight += model.Weight + if randomNum < int(currentWeight) { + return model.ModelRewrite + } + } + + // Should not happen + return models[len(models)-1].ModelRewrite +} + // getCandidatePodsForScheduling gets the list of relevant endpoints for the scheduling cycle from the datastore. 
// according to EPP protocol, if "x-gateway-destination-endpoint-subset" is set on the request metadata and specifies // a subset of endpoints, only these endpoints will be considered as candidates for the scheduler. diff --git a/pkg/epp/requestcontrol/director_test.go b/pkg/epp/requestcontrol/director_test.go index f361303c8..d31bf6f06 100644 --- a/pkg/epp/requestcontrol/director_test.go +++ b/pkg/epp/requestcontrol/director_test.go @@ -87,7 +87,8 @@ func (m *mockScheduler) Schedule(_ context.Context, _ *schedulingtypes.LLMReques } type mockDatastore struct { - pods []backendmetrics.PodMetrics + pods []backendmetrics.PodMetrics + rewrites []*v1alpha2.InferenceModelRewrite } func (ds *mockDatastore) PoolGet() (*datalayer.EndpointPool, error) { @@ -167,6 +168,10 @@ func (m mockProducedDataType) Clone() datalayer.Cloneable { return mockProducedDataType{value: m.value} } +func (ds *mockDatastore) RewriteGetAll() []*v1alpha2.InferenceModelRewrite { + return ds.rewrites +} + func TestDirector_HandleRequest(t *testing.T) { ctx := logutil.NewTestLoggerIntoContext(context.Background()) @@ -174,6 +179,8 @@ func TestDirector_HandleRequest(t *testing.T) { model := "food-review" modelSheddable := "food-review-sheddable" modelWithResolvedTarget := "food-review-resolve" + modelToBeRewritten := "food-review-to-be-rewritten" + modelRewritten := "food-review-rewritten" objectiveName := "ioFoodReview" objectiveNameSheddable := "imFoodReviewSheddable" @@ -191,6 +198,34 @@ func TestDirector_HandleRequest(t *testing.T) { CreationTimestamp(metav1.Unix(1000, 0)). Priority(1). 
ObjRef() + + rewrite := &v1alpha2.InferenceModelRewrite{ + ObjectMeta: metav1.ObjectMeta{ + Name: "rewrite-rule", + CreationTimestamp: metav1.Now(), + }, + Spec: v1alpha2.InferenceModelRewriteSpec{ + Rules: []v1alpha2.InferenceModelRewriteRule{ + { + Matches: []v1alpha2.Match{ + { + Model: &v1alpha2.ModelMatch{ + Value: modelToBeRewritten, + }, + }, + }, + Targets: []v1alpha2.TargetModel{ + { + ModelRewrite: modelRewritten, + Weight: 100, + }, + }, + }, + }, + }, + } + + pool := &v1.InferencePool{ ObjectMeta: metav1.ObjectMeta{Name: "test-pool", Namespace: "default"}, Spec: v1.InferencePoolSpec{ @@ -209,6 +244,7 @@ func TestDirector_HandleRequest(t *testing.T) { ds.ObjectiveSet(ioFoodReview) ds.ObjectiveSet(ioFoodReviewResolve) ds.ObjectiveSet(ioFoodReviewSheddable) + ds.RewriteSet(rewrite) scheme := runtime.NewScheme() _ = clientgoscheme.AddToScheme(scheme) @@ -284,6 +320,7 @@ func TestDirector_HandleRequest(t *testing.T) { mockAdmissionController *mockAdmissionController inferenceObjectiveName string schedulerMockSetup func(m *mockScheduler) + initialTargetModelName string // Initial target model in the reqCtx. 
wantErrCode string // Expected errutil code string wantReqCtx *handlers.RequestContext // Fields to check in the returned RequestContext wantMutatedBodyModel string // Expected model in reqCtx.Request.Body after PostDispatch @@ -301,6 +338,7 @@ func TestDirector_HandleRequest(t *testing.T) { schedulerMockSetup: func(m *mockScheduler) { m.scheduleResults = defaultSuccessfulScheduleResults }, + initialTargetModelName: model, wantReqCtx: &handlers.RequestContext{ ObjectiveKey: objectiveName, TargetModelName: model, @@ -314,9 +352,31 @@ func TestDirector_HandleRequest(t *testing.T) { }, wantMutatedBodyModel: model, inferenceObjectiveName: objectiveName, - targetModelName: model, - }, - { + }, { + name: "successful request with model rewrite", + reqBodyMap: map[string]any{ + "model": modelToBeRewritten, + "prompt": "some prompt", + }, + mockAdmissionController: &mockAdmissionController{admitErr: nil}, + schedulerMockSetup: func(m *mockScheduler) { + m.scheduleResults = defaultSuccessfulScheduleResults + }, + initialTargetModelName: model, + wantReqCtx: &handlers.RequestContext{ + ObjectiveKey: model, + TargetModelName: modelRewritten, + TargetPod: &backend.Pod{ + NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"}, + Address: "192.168.1.100", + Port: "8000", + MetricsHost: "192.168.1.100:8000", + }, + TargetEndpoint: "192.168.1.100:8000,192.168.2.100:8000,192.168.4.100:8000", + }, + wantMutatedBodyModel: modelRewritten, + inferenceObjectiveName: model, + }, { name: "successful chat completions request", reqBodyMap: map[string]any{ "model": model, @@ -331,6 +391,7 @@ func TestDirector_HandleRequest(t *testing.T) { schedulerMockSetup: func(m *mockScheduler) { m.scheduleResults = defaultSuccessfulScheduleResults }, + initialTargetModelName: model, wantReqCtx: &handlers.RequestContext{ TargetModelName: model, TargetPod: &backend.Pod{ @@ -442,6 +503,7 @@ func TestDirector_HandleRequest(t *testing.T) { schedulerMockSetup: func(m *mockScheduler) { 
m.scheduleResults = defaultSuccessfulScheduleResults }, + initialTargetModelName: model, wantReqCtx: &handlers.RequestContext{ ObjectiveKey: objectiveName, TargetModelName: model, @@ -453,11 +515,8 @@ func TestDirector_HandleRequest(t *testing.T) { }, TargetEndpoint: "192.168.1.100:8000,192.168.2.100:8000,192.168.4.100:8000", }, - wantMutatedBodyModel: model, inferenceObjectiveName: objectiveName, - targetModelName: model, - }, - { + }, { name: "successful request with target model resolution", reqBodyMap: map[string]any{ "model": modelWithResolvedTarget, @@ -467,6 +526,7 @@ func TestDirector_HandleRequest(t *testing.T) { schedulerMockSetup: func(m *mockScheduler) { m.scheduleResults = defaultSuccessfulScheduleResults }, + initialTargetModelName: "resolved-target-model-A", wantReqCtx: &handlers.RequestContext{ ObjectiveKey: objectiveNameResolve, TargetModelName: "resolved-target-model-A", @@ -480,13 +540,13 @@ func TestDirector_HandleRequest(t *testing.T) { }, wantMutatedBodyModel: "resolved-target-model-A", inferenceObjectiveName: objectiveNameResolve, - targetModelName: "resolved-target-model-A", }, { name: "nonexistent target defined, use default inference model", schedulerMockSetup: func(m *mockScheduler) { m.scheduleResults = defaultSuccessfulScheduleResults }, + initialTargetModelName: "food-review-1", wantReqCtx: &handlers.RequestContext{ ObjectiveKey: "food-review-1", TargetModelName: "food-review-1", @@ -505,10 +565,8 @@ func TestDirector_HandleRequest(t *testing.T) { }, mockAdmissionController: &mockAdmissionController{admitErr: nil}, inferenceObjectiveName: "food-review-1", - targetModelName: "food-review-1", }, { - name: "request rejected by admission controller", reqBodyMap: map[string]any{ "model": modelSheddable, @@ -588,7 +646,7 @@ func TestDirector_HandleRequest(t *testing.T) { }, }, ObjectiveKey: test.inferenceObjectiveName, - TargetModelName: test.targetModelName, + TargetModelName: test.initialTargetModelName, } // Deep copy the body map. 
maps.Copy(reqCtx.Request.Body, test.reqBodyMap) @@ -777,6 +835,266 @@ func TestGetRandomPod(t *testing.T) { } } +func TestDirector_ApplyWeightedModelRewrite(t *testing.T) { + _ = logutil.NewTestLoggerIntoContext(context.Background()) + + // Mock InferenceModelRewrite objects + rewriteOld := &v1alpha2.InferenceModelRewrite{ + ObjectMeta: metav1.ObjectMeta{ + Name: "rewrite-old", + CreationTimestamp: metav1.Unix(1000, 0), + }, + Spec: v1alpha2.InferenceModelRewriteSpec{ + Rules: []v1alpha2.InferenceModelRewriteRule{ + { + Matches: []v1alpha2.Match{ + { + Model: &v1alpha2.ModelMatch{ + Value: "model-a", + }, + }, + }, + Targets: []v1alpha2.TargetModel{ + { + ModelRewrite: "model-a-old-tuned", + Weight: 100, + }, + }, + }, + }, + }, + } + + rewriteNew := &v1alpha2.InferenceModelRewrite{ + ObjectMeta: metav1.ObjectMeta{ + Name: "rewrite-new", + CreationTimestamp: metav1.Unix(2000, 0), + }, + Spec: v1alpha2.InferenceModelRewriteSpec{ + Rules: []v1alpha2.InferenceModelRewriteRule{ + { + Matches: []v1alpha2.Match{ + { + Model: &v1alpha2.ModelMatch{ + Value: "model-a", + }, + }, + }, + Targets: []v1alpha2.TargetModel{ + { + ModelRewrite: "model-a-new-tuned", + Weight: 100, + }, + }, + }, + }, + }, + } + + rewriteB := &v1alpha2.InferenceModelRewrite{ + ObjectMeta: metav1.ObjectMeta{ + Name: "rewrite-b", + CreationTimestamp: metav1.Unix(1500, 0), + }, + Spec: v1alpha2.InferenceModelRewriteSpec{ + Rules: []v1alpha2.InferenceModelRewriteRule{ + { + Matches: []v1alpha2.Match{ + { + Model: &v1alpha2.ModelMatch{ + Value: "model-b", + }, + }, + }, + Targets: []v1alpha2.TargetModel{ + { + ModelRewrite: "model-b-tuned", + Weight: 100, + }, + }, + }, + }, + }, + } + + rewriteWeighted := &v1alpha2.InferenceModelRewrite{ + ObjectMeta: metav1.ObjectMeta{ + Name: "rewrite-weighted", + CreationTimestamp: metav1.Unix(1200, 0), + }, + Spec: v1alpha2.InferenceModelRewriteSpec{ + Rules: []v1alpha2.InferenceModelRewriteRule{ + { + Matches: []v1alpha2.Match{ + { + Model: &v1alpha2.ModelMatch{ + 
Value: "model-c", + }, + }, + }, + Targets: []v1alpha2.TargetModel{ + { + ModelRewrite: "model-c-v1", + Weight: 70, + }, + { + ModelRewrite: "model-c-v2", + Weight: 30, + }, + }, + }, + }, + }, + } + + tests := []struct { + name string + rewrites []*v1alpha2.InferenceModelRewrite + incomingModel string + expectedTarget []string + initialTarget string // Initial value of reqCtx.TargetModelName + }{ + { + name: "no rewrites", + rewrites: []*v1alpha2.InferenceModelRewrite{}, + incomingModel: "model-x", + expectedTarget: []string{"model-x"}, + initialTarget: "model-x", + }, + { + name: "single matching rewrite", + rewrites: []*v1alpha2.InferenceModelRewrite{rewriteB}, + incomingModel: "model-b", + expectedTarget: []string{"model-b-tuned"}, + initialTarget: "model-b", + }, + { + name: "no matching rewrite", + rewrites: []*v1alpha2.InferenceModelRewrite{rewriteB}, + incomingModel: "model-x", + expectedTarget: []string{"model-x"}, + initialTarget: "model-x", + }, + { + name: "oldest rewrite wins for duplicate model", + rewrites: []*v1alpha2.InferenceModelRewrite{rewriteNew, rewriteOld}, // New is first, but Old has older timestamp + incomingModel: "model-a", + expectedTarget: []string{"model-a-old-tuned"}, + initialTarget: "model-a", + }, + { + name: "weighted rewrite applied (probabilistic check)", + rewrites: []*v1alpha2.InferenceModelRewrite{rewriteWeighted}, + incomingModel: "model-c", + initialTarget: "model-c", + expectedTarget: []string{"model-c-v1", "model-c-v2"}, + }, + { + name: "initial TargetModelName is respected if no rewrite matches", + rewrites: []*v1alpha2.InferenceModelRewrite{rewriteB}, + incomingModel: "model-x", + initialTarget: "pre-existing-target", + expectedTarget: []string{"pre-existing-target"}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + mockDs := &mockDatastore{rewrites: test.rewrites} + director := NewDirectorWithConfig(mockDs, &mockScheduler{}, &mockAdmissionController{}, NewConfig()) + + reqCtx := 
&handlers.RequestContext{ + IncomingModelName: test.incomingModel, + TargetModelName: test.initialTarget, + } + + director.applyWeightedModelRewrite(reqCtx) + assert.Contains(t, test.expectedTarget, reqCtx.TargetModelName, "TargetModelName mismatch") + }) + } +} + +func TestDirector_SelectWeightedModel(t *testing.T) { + tests := []struct { + name string + targets []v1alpha2.TargetModel + possibleModels map[string]bool // For probabilistic cases + }{ + { + name: "single target", + targets: []v1alpha2.TargetModel{ + {ModelRewrite: "model-a", Weight: 100}, + }, + possibleModels: map[string]bool{"model-a": true}, + }, + { + name: "multiple targets, equal weight", + targets: []v1alpha2.TargetModel{ + {ModelRewrite: "model-a", Weight: 50}, + {ModelRewrite: "model-b", Weight: 50}, + }, + possibleModels: map[string]bool{"model-a": true, "model-b": true}, + }, + { + name: "multiple targets, different weights", + targets: []v1alpha2.TargetModel{ + {ModelRewrite: "model-x", Weight: 70}, + {ModelRewrite: "model-y", Weight: 30}, + }, + possibleModels: map[string]bool{"model-x": true, "model-y": true}, + }, + { + name: "zero total weight, distribute evenly", + targets: []v1alpha2.TargetModel{ + {ModelRewrite: "model-z1", Weight: 0}, + {ModelRewrite: "model-z2", Weight: 0}, + }, + possibleModels: map[string]bool{"model-z1": true, "model-z2": true}, + }, + } + + director := &Director{} + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // Run multiple times to check distribution + counter := make(map[string]int) + numRuns := 1000 + for i := 0; i < numRuns; i++ { + selected := director.selectWeightedModel(test.targets) + counter[selected]++ + } + + // Assert that all selected models are within the possible models + for model := range counter { + if _, ok := test.possibleModels[model]; !ok { + t.Errorf("Selected model %s is not in possible models %v", model, test.possibleModels) + } + } + + // Basic check for distribution (e.g., if 70/30, expect roughly 
700/300) + if len(test.targets) > 1 { + totalWeight := int32(0) + for _, target := range test.targets { + totalWeight += target.Weight + } + + if totalWeight == 0 { // Special case for zero total weight + for _, target := range test.targets { + expectedCount := numRuns / len(test.targets) + assert.InDelta(t, expectedCount, counter[target.ModelRewrite], float64(numRuns)/float64(len(test.targets))*0.2, "Distribution for %s is off", target.ModelRewrite) + } + } else { + for _, target := range test.targets { + expectedCount := float64(numRuns) * (float64(target.Weight) / float64(totalWeight)) + assert.InDelta(t, expectedCount, float64(counter[target.ModelRewrite]), expectedCount*0.2, "Distribution for %s is off", target.ModelRewrite) + } + } + } + }) + } +} + func TestDirector_HandleResponseReceived(t *testing.T) { pr1 := newTestResponseReceived("pr1") From c3c32f619887b91e7ea7ade0228e75518a93f51d Mon Sep 17 00:00:00 2001 From: bobzetian Date: Wed, 19 Nov 2025 01:00:11 +0000 Subject: [PATCH 3/6] more efficient rewrite fetching per request. 
--- apix/v1alpha2/inferencemodelrewrite_types.go | 23 ++- ...rking.x-k8s.io_inferencemodelrewrites.yaml | 8 +- .../1816-inferenceomodelrewrite/README.md | 23 ++- .../inferencemodelrewrite_reconciler_test.go | 122 +++++------- pkg/epp/datastore/datastore.go | 96 ++++----- pkg/epp/datastore/modelrewritestore.go | 187 ++++++++++++++++++ pkg/epp/datastore/modelrewritestore_test.go | 185 +++++++++++++++++ pkg/epp/requestcontrol/director.go | 23 +-- pkg/epp/requestcontrol/director_test.go | 33 +++- 9 files changed, 536 insertions(+), 164 deletions(-) create mode 100644 pkg/epp/datastore/modelrewritestore.go create mode 100644 pkg/epp/datastore/modelrewritestore_test.go diff --git a/apix/v1alpha2/inferencemodelrewrite_types.go b/apix/v1alpha2/inferencemodelrewrite_types.go index ef68d8366..262238c28 100644 --- a/apix/v1alpha2/inferencemodelrewrite_types.go +++ b/apix/v1alpha2/inferencemodelrewrite_types.go @@ -57,20 +57,25 @@ type InferenceModelRewriteSpec struct { // If multiple InferenceModelRewrite resources target the same // InferencePool, the controller will merge them based on precedence. // - // **Timestamp Wins:** If two rules from different rewrites all matches, - // the rule from the *oldest* - // InferenceModelRewrite resource (determined by - // metadata.creationTimestamp) will be used. + // Across all rules specified on applicable rewrites, precedence MUST be + // given to the match having an "Exact" model match over a generic match + // (a rule with an empty `matches` array). + // + // If ties still exist across multiple InferenceModelRewrite resources (e.g. + // two rewrites both have an exact match for the same model), matching + // precedence MUST be determined by the oldest resource based on + // creation timestamp. + // + // If ties still exist within a single InferenceModelRewrite resource, the + // FIRST matching rule (in list order) is used. 
// +required Rules []InferenceModelRewriteRule `json:"rules"` } // InferenceModelRewriteRule defines the match criteria and corresponding action. -// -// A specific model name can only be matched by one rule across all -// rules attached to the same InferencePool. If multiple rules attempt -// to match the same model name, the oldest rule (by creationTimestamp) -// will be the only one considered valid. +// For details on how precedence is determined across multiple rules and +// InferenceModelRewrite resources, see the "Precedence and Conflict Resolution" +// section in InferenceModelRewriteSpec. type InferenceModelRewriteRule struct { // Matches defines the criteria for matching a request. // If multiple match criteria are specified, a request matches if diff --git a/config/crd/bases/inference.networking.x-k8s.io_inferencemodelrewrites.yaml b/config/crd/bases/inference.networking.x-k8s.io_inferencemodelrewrites.yaml index bb9b3e6cf..2680ea091 100644 --- a/config/crd/bases/inference.networking.x-k8s.io_inferencemodelrewrites.yaml +++ b/config/crd/bases/inference.networking.x-k8s.io_inferencemodelrewrites.yaml @@ -74,11 +74,9 @@ spec: items: description: |- InferenceModelRewriteRule defines the match criteria and corresponding action. - - A specific model name can only be matched by one rule across all - rules attached to the same InferencePool. If multiple rules attempt - to match the same model name, the oldest rule (by creationTimestamp) - will be the only one considered valid. + For details on how precedence is determined across multiple rules and + InferenceModelRewrite resources, see the "Precedence and Conflict Resolution" + section in InferenceModelRewriteSpec. 
properties: matches: items: diff --git a/docs/proposals/1816-inferenceomodelrewrite/README.md b/docs/proposals/1816-inferenceomodelrewrite/README.md index 9f20b36fe..93cfd6d4c 100644 --- a/docs/proposals/1816-inferenceomodelrewrite/README.md +++ b/docs/proposals/1816-inferenceomodelrewrite/README.md @@ -64,20 +64,25 @@ type InferenceModelRewriteSpec struct { // If multiple InferenceModelRewrite resources target the same // InferencePool, the controller will merge them based on precedence. // - // **Timestamp Wins:** If two rules from different rewrite all matches, - // the rule from the *oldest* - // InferenceModelRewrite resource (determined by - // metadata.creationTimestamp) will be used. + // Across all rules specified on applicable rewrites, precedence MUST be + // given to the match having an "Exact" model match over a generic match + // (a rule with an empty `matches` array). + // + // If ties still exist across multiple InferenceModelRewrite resources (e.g. + // two rewrites both have an exact match for the same model), matching + // precedence MUST be determined by the oldest resource based on + // creation timestamp. + // + // If ties still exist within a single InferenceModelRewrite resource, the + // FIRST matching rule (in list order) is used. // +required Rules []InferenceModelRewriteRule `json:"rules"` } // InferenceModelRewriteRule defines the match criteria and corresponding action. -// -// A specific model name can only be matched by one rule across all -// rewrites attached to the same InferencePool. If multiple rules attempt -// to match the same model name, the oldest rule (by creationTimestamp) -// will be the only one considered valid. +// For details on how precedence is determined across multiple rules and +// InferenceModelRewrite resources, see the "Precedence and Conflict Resolution" +// section in InferenceModelRewriteSpec. type InferenceModelRewriteRule struct { // Matches defines the criteria for matching a request. 
// If multiple match criteria are specified, a request matches if diff --git a/pkg/epp/controller/inferencemodelrewrite_reconciler_test.go b/pkg/epp/controller/inferencemodelrewrite_reconciler_test.go index 0649c7dd8..8365aff0b 100644 --- a/pkg/epp/controller/inferencemodelrewrite_reconciler_test.go +++ b/pkg/epp/controller/inferencemodelrewrite_reconciler_test.go @@ -41,79 +41,59 @@ import ( var ( poolForRewrite = utiltest.MakeInferencePool("test-pool1").Namespace("ns1").ObjRef() - rewrite1 = makeInferenceModelRewrite("rewrite1"). - Namespace(poolForRewrite.Namespace). - PoolName(poolForRewrite.Name). - CreationTimestamp(metav1.Unix(1000, 0)). - ObjRef() - rewrite1Pool2 = makeInferenceModelRewrite(rewrite1.Name). - Namespace(rewrite1.Namespace). - PoolName("test-pool2"). - CreationTimestamp(metav1.Unix(1001, 0)). - ObjRef() - rewrite1Updated = makeInferenceModelRewrite(rewrite1.Name). - Namespace(rewrite1.Namespace). - PoolName(poolForRewrite.Name). - CreationTimestamp(metav1.Unix(1003, 0)). - Rules([]v1alpha2.InferenceModelRewriteRule{{}}). - ObjRef() - rewrite1Deleted = makeInferenceModelRewrite(rewrite1.Name). - Namespace(rewrite1.Namespace). - PoolName(poolForRewrite.Name). - CreationTimestamp(metav1.Unix(1004, 0)). - DeletionTimestamp(). - ObjRef() - rewrite2 = makeInferenceModelRewrite("rewrite2"). - Namespace(poolForRewrite.Namespace). - PoolName(poolForRewrite.Name). - CreationTimestamp(metav1.Unix(1000, 0)). 
- ObjRef() -) - -type inferenceModelRewriteBuilder struct { - *v1alpha2.InferenceModelRewrite -} - -func makeInferenceModelRewrite(name string) *inferenceModelRewriteBuilder { - return &inferenceModelRewriteBuilder{ - &v1alpha2.InferenceModelRewrite{ - ObjectMeta: metav1.ObjectMeta{ - Name: name, - }, + rewrite1 = &v1alpha2.InferenceModelRewrite{ + ObjectMeta: metav1.ObjectMeta{ + Name: "rewrite1", + Namespace: poolForRewrite.Namespace, + CreationTimestamp: metav1.Unix(1000, 0), + }, + Spec: v1alpha2.InferenceModelRewriteSpec{ + PoolRef: &v1alpha2.PoolObjectReference{Name: v1alpha2.ObjectName(poolForRewrite.Name)}, }, } -} - -func (b *inferenceModelRewriteBuilder) Namespace(ns string) *inferenceModelRewriteBuilder { - b.ObjectMeta.Namespace = ns - return b -} - -func (b *inferenceModelRewriteBuilder) PoolName(name string) *inferenceModelRewriteBuilder { - b.Spec.PoolRef = &v1alpha2.PoolObjectReference{} - b.Spec.PoolRef.Name = v1alpha2.ObjectName(name) - return b -} - -func (b *inferenceModelRewriteBuilder) CreationTimestamp(t metav1.Time) *inferenceModelRewriteBuilder { - b.ObjectMeta.CreationTimestamp = t - return b -} - -func (b *inferenceModelRewriteBuilder) DeletionTimestamp() *inferenceModelRewriteBuilder { - now := metav1.Now() - b.ObjectMeta.DeletionTimestamp = &now - return b -} - -func (b *inferenceModelRewriteBuilder) Rules(rules []v1alpha2.InferenceModelRewriteRule) *inferenceModelRewriteBuilder { - b.Spec.Rules = rules - return b -} - -func (b *inferenceModelRewriteBuilder) ObjRef() *v1alpha2.InferenceModelRewrite { - return b.InferenceModelRewrite -} + rewrite1Pool2 = &v1alpha2.InferenceModelRewrite{ + ObjectMeta: metav1.ObjectMeta{ + Name: rewrite1.Name, + Namespace: rewrite1.Namespace, + CreationTimestamp: metav1.Unix(1001, 0), + }, + Spec: v1alpha2.InferenceModelRewriteSpec{ + PoolRef: &v1alpha2.PoolObjectReference{Name: "test-pool2"}, + }, + } + rewrite1Updated = &v1alpha2.InferenceModelRewrite{ + ObjectMeta: metav1.ObjectMeta{ + Name: 
rewrite1.Name, + Namespace: rewrite1.Namespace, + CreationTimestamp: metav1.Unix(1003, 0), + }, + Spec: v1alpha2.InferenceModelRewriteSpec{ + PoolRef: &v1alpha2.PoolObjectReference{Name: v1alpha2.ObjectName(poolForRewrite.Name)}, + Rules: []v1alpha2.InferenceModelRewriteRule{{}}, + }, + } + rewrite1Deleted = &v1alpha2.InferenceModelRewrite{ + ObjectMeta: metav1.ObjectMeta{ + Name: rewrite1.Name, + Namespace: rewrite1.Namespace, + CreationTimestamp: metav1.Unix(1004, 0), + DeletionTimestamp: &metav1.Time{Time: time.Now()}, + }, + Spec: v1alpha2.InferenceModelRewriteSpec{ + PoolRef: &v1alpha2.PoolObjectReference{Name: v1alpha2.ObjectName(poolForRewrite.Name)}, + }, + } + rewrite2 = &v1alpha2.InferenceModelRewrite{ + ObjectMeta: metav1.ObjectMeta{ + Name: "rewrite2", + Namespace: poolForRewrite.Namespace, + CreationTimestamp: metav1.Unix(1001, 0), + }, + Spec: v1alpha2.InferenceModelRewriteSpec{ + PoolRef: &v1alpha2.PoolObjectReference{Name: v1alpha2.ObjectName(poolForRewrite.Name)}, + }, + } +) func TestInferenceModelRewriteReconciler(t *testing.T) { tests := []struct { diff --git a/pkg/epp/datastore/datastore.go b/pkg/epp/datastore/datastore.go index 2cbdbd031..b036d2acf 100644 --- a/pkg/epp/datastore/datastore.go +++ b/pkg/epp/datastore/datastore.go @@ -31,6 +31,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/log" + v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1" "sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer" @@ -62,6 +63,7 @@ type Datastore interface { // InferenceModelRewrite operations RewriteSet(infModelRewrite *v1alpha2.InferenceModelRewrite) RewriteDelete(namespacedName types.NamespacedName) + RewriteGet(modelName string) *v1alpha2.InferenceModelRewriteRule RewriteGetAll() []*v1alpha2.InferenceModelRewrite // PodList lists pods matching the given 
predicate. @@ -77,10 +79,10 @@ func NewDatastore(parentCtx context.Context, epFactory datalayer.EndpointFactory // Initialize with defaults store := &datastore{ parentCtx: parentCtx, - poolAndObjectivesMu: sync.RWMutex{}, pool: nil, + mu: sync.RWMutex{}, objectives: make(map[string]*v1alpha2.InferenceObjective), - rewrites: make(map[types.NamespacedName]*v1alpha2.InferenceModelRewrite), + rewrites: NewModelRewriteStore(), pods: &sync.Map{}, modelServerMetricsPort: modelServerMetricsPort, epf: epFactory, @@ -97,13 +99,13 @@ func NewDatastore(parentCtx context.Context, epFactory datalayer.EndpointFactory type datastore struct { // parentCtx controls the lifecycle of the background metrics goroutines that spawn up by the datastore. parentCtx context.Context - // poolAndObjectivesMu is used to synchronize access to pool and the objectives map. - poolAndObjectivesMu sync.RWMutex - pool *datalayer.EndpointPool - // key: InferenceObjective.Spec.ModelName, value: *InferenceObjective + // mu is used to synchronize access to pool, objectives, and rewrites. + mu sync.RWMutex + pool *v1.InferencePool + // key: InferenceObjective name, value: *InferenceObjective objectives map[string]*v1alpha2.InferenceObjective - // key: types.NamespacedName, value: *v1alpha2.InferenceModelRewrite - rewrites map[types.NamespacedName]*v1alpha2.InferenceModelRewrite + // rewrites store for InferenceModelRewrite objects. 
+ rewrites *ModelRewriteStore // key: types.NamespacedName, value: backendmetrics.PodMetrics pods *sync.Map // modelServerMetricsPort metrics port from EPP command line argument @@ -113,11 +115,11 @@ type datastore struct { } func (ds *datastore) Clear() { - ds.poolAndObjectivesMu.Lock() - defer ds.poolAndObjectivesMu.Unlock() + ds.mu.Lock() + defer ds.mu.Unlock() ds.pool = nil ds.objectives = make(map[string]*v1alpha2.InferenceObjective) - ds.rewrites = make(map[types.NamespacedName]*v1alpha2.InferenceModelRewrite) + ds.rewrites = NewModelRewriteStore() // stop all pods go routines before clearing the pods map. ds.pods.Range(func(_, v any) bool { ds.epf.ReleaseEndpoint(v.(backendmetrics.PodMetrics)) @@ -133,8 +135,8 @@ func (ds *datastore) PoolSet(ctx context.Context, reader client.Reader, endpoint return nil } logger := log.FromContext(ctx) - ds.poolAndObjectivesMu.Lock() - defer ds.poolAndObjectivesMu.Unlock() + ds.mu.Lock() + defer ds.mu.Unlock() oldEndpointPool := ds.pool ds.pool = endpointPool @@ -154,9 +156,9 @@ func (ds *datastore) PoolSet(ctx context.Context, reader client.Reader, endpoint return nil } -func (ds *datastore) PoolGet() (*datalayer.EndpointPool, error) { - ds.poolAndObjectivesMu.RLock() - defer ds.poolAndObjectivesMu.RUnlock() +func (ds *datastore) PoolGet() (*v1.InferencePool, error) { + ds.mu.RLock() + defer ds.mu.RUnlock() if !ds.PoolHasSynced() { return nil, errPoolNotSynced } @@ -164,14 +166,14 @@ func (ds *datastore) PoolGet() (*datalayer.EndpointPool, error) { } func (ds *datastore) PoolHasSynced() bool { - ds.poolAndObjectivesMu.RLock() - defer ds.poolAndObjectivesMu.RUnlock() + ds.mu.RLock() + defer ds.mu.RUnlock() return ds.pool != nil } func (ds *datastore) PoolLabelsMatch(podLabels map[string]string) bool { - ds.poolAndObjectivesMu.RLock() - defer ds.poolAndObjectivesMu.RUnlock() + ds.mu.RLock() + defer ds.mu.RUnlock() if ds.pool == nil { return false } @@ -180,33 +182,29 @@ func (ds *datastore) PoolLabelsMatch(podLabels 
map[string]string) bool { return poolSelector.Matches(podSet) } +// /// InferenceObjective APIs /// func (ds *datastore) ObjectiveSet(infObjective *v1alpha2.InferenceObjective) { - ds.poolAndObjectivesMu.Lock() - defer ds.poolAndObjectivesMu.Unlock() - // Set the objective. + ds.mu.Lock() + defer ds.mu.Unlock() ds.objectives[infObjective.Name] = infObjective } func (ds *datastore) ObjectiveGet(objectiveName string) *v1alpha2.InferenceObjective { - ds.poolAndObjectivesMu.RLock() - defer ds.poolAndObjectivesMu.RUnlock() - iObj, ok := ds.objectives[objectiveName] - if !ok { - return nil - } - return iObj + ds.mu.RLock() + defer ds.mu.RUnlock() + return ds.objectives[objectiveName] } func (ds *datastore) ObjectiveDelete(namespacedName types.NamespacedName) { - ds.poolAndObjectivesMu.Lock() - defer ds.poolAndObjectivesMu.Unlock() + ds.mu.Lock() + defer ds.mu.Unlock() delete(ds.objectives, namespacedName.Name) } func (ds *datastore) ObjectiveGetAll() []*v1alpha2.InferenceObjective { - ds.poolAndObjectivesMu.RLock() - defer ds.poolAndObjectivesMu.RUnlock() - res := []*v1alpha2.InferenceObjective{} + ds.mu.RLock() + defer ds.mu.RUnlock() + res := make([]*v1alpha2.InferenceObjective, 0, len(ds.objectives)) for _, v := range ds.objectives { res = append(res, v) } @@ -214,25 +212,27 @@ func (ds *datastore) ObjectiveGetAll() []*v1alpha2.InferenceObjective { } func (ds *datastore) RewriteSet(infModelRewrite *v1alpha2.InferenceModelRewrite) { - ds.poolAndObjectivesMu.Lock() - defer ds.poolAndObjectivesMu.Unlock() - ds.rewrites[types.NamespacedName{Name: infModelRewrite.Name, Namespace: infModelRewrite.Namespace}] = infModelRewrite + ds.mu.Lock() + defer ds.mu.Unlock() + ds.rewrites.Set(infModelRewrite) } func (ds *datastore) RewriteDelete(namespacedName types.NamespacedName) { - ds.poolAndObjectivesMu.Lock() - defer ds.poolAndObjectivesMu.Unlock() - delete(ds.rewrites, namespacedName) + ds.mu.Lock() + defer ds.mu.Unlock() + ds.rewrites.Delete(namespacedName) +} + +func (ds 
*datastore) RewriteGet(modelName string) *v1alpha2.InferenceModelRewriteRule { + ds.mu.RLock() + defer ds.mu.RUnlock() + return ds.rewrites.GetRule(modelName) } func (ds *datastore) RewriteGetAll() []*v1alpha2.InferenceModelRewrite { - ds.poolAndObjectivesMu.RLock() - defer ds.poolAndObjectivesMu.RUnlock() - res := []*v1alpha2.InferenceModelRewrite{} - for _, v := range ds.rewrites { - res = append(res, v) - } - return res + ds.mu.RLock() + defer ds.mu.RUnlock() + return ds.rewrites.GetAll() } // /// Pods/endpoints APIs /// diff --git a/pkg/epp/datastore/modelrewritestore.go b/pkg/epp/datastore/modelrewritestore.go new file mode 100644 index 000000000..2ad21d41f --- /dev/null +++ b/pkg/epp/datastore/modelrewritestore.go @@ -0,0 +1,187 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package datastore + +import ( + "sort" + "time" + + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2" +) + +// ModelRewriteStore encapsulates the logic for storing and retrieving +// InferenceModelRewrite rules, handling precedence correctly. This struct is not +// thread-safe; concurrency must be managed by its consumer. 
+type ModelRewriteStore struct { + genericRules []*rewriteRuleWithMetadata + rulesByExactModelMatch map[string][]*rewriteRuleWithMetadata + allReWrites map[types.NamespacedName]*v1alpha2.InferenceModelRewrite +} + +func NewModelRewriteStore() *ModelRewriteStore { + return &ModelRewriteStore{ + genericRules: []*rewriteRuleWithMetadata{}, + rulesByExactModelMatch: map[string][]*rewriteRuleWithMetadata{}, + allReWrites: map[types.NamespacedName]*v1alpha2.InferenceModelRewrite{}, + } +} + +// Set adds or updates an InferenceModelRewrite in the store. It deconstructs the +// object into individual rules and stores them in the appropriate data structures, +// ensuring they remain sorted by precedence. +func (ms *ModelRewriteStore) Set(infModelRewrite *v1alpha2.InferenceModelRewrite) { + nn := getNN(infModelRewrite) + + // If the rewrite object already exists, remove its old rules before adding new ones. + if _, ok := ms.allReWrites[nn]; ok { + ms.deleteInternal(nn) + } + ms.allReWrites[nn] = infModelRewrite + + for i := range infModelRewrite.Spec.Rules { + ruleWithMetadata := newRuleWithMetadata(infModelRewrite, i) + if ruleWithMetadata == nil { + continue + } + + if ruleWithMetadata.isGeneric() { + ms.genericRules = append(ms.genericRules, ruleWithMetadata) + } else { + for model := range ruleWithMetadata.exactModels() { + ms.rulesByExactModelMatch[model] = append(ms.rulesByExactModelMatch[model], ruleWithMetadata) + } + } + } + + // Sort all rule lists by timestamp to maintain precedence. + sort.Slice(ms.genericRules, func(i, j int) bool { + return ms.genericRules[i].createTimestamp.Before(ms.genericRules[j].createTimestamp) + }) + + for model := range ms.rulesByExactModelMatch { + sort.Slice(ms.rulesByExactModelMatch[model], func(i, j int) bool { + return ms.rulesByExactModelMatch[model][i].createTimestamp.Before(ms.rulesByExactModelMatch[model][j].createTimestamp) + }) + } +} + +// Delete removes an InferenceModelRewrite and all its associated rules from the store. 
+func (ms *ModelRewriteStore) Delete(nn types.NamespacedName) { + ms.deleteInternal(nn) +} + +// deleteInternal is the non-locking implementation for deleting a rewrite. +func (ms *ModelRewriteStore) deleteInternal(nn types.NamespacedName) { + if _, ok := ms.allReWrites[nn]; !ok { + return + } + delete(ms.allReWrites, nn) + + // Filter out the generic rules associated with the deleted rewrite. + newGenericRules := make([]*rewriteRuleWithMetadata, 0, len(ms.genericRules)) + for _, ruleWithMd := range ms.genericRules { + if ruleWithMd.parentNN() != nn { + newGenericRules = append(newGenericRules, ruleWithMd) + } + } + ms.genericRules = newGenericRules + + // Filter out the exact-match rules associated with the deleted rewrite. + for modelName, rulesWithMd := range ms.rulesByExactModelMatch { + newRules := make([]*rewriteRuleWithMetadata, 0, len(rulesWithMd)) + for _, r := range rulesWithMd { + if r.parentNN() != nn { + newRules = append(newRules, r) + } + } + + if len(newRules) == 0 { + delete(ms.rulesByExactModelMatch, modelName) + } else { + ms.rulesByExactModelMatch[modelName] = newRules + } + } +} + +// GetRule returns the single, highest-precedence rule for a given model name. +// It prioritizes exact matches over generic ones, and among those, the oldest rule wins. +func (ms *ModelRewriteStore) GetRule(modelName string) *v1alpha2.InferenceModelRewriteRule { + // Exact matches have the highest precedence. + if rulesWithMd, ok := ms.rulesByExactModelMatch[modelName]; ok && len(rulesWithMd) > 0 { + return &rulesWithMd[0].rule // The list is pre-sorted, so the first element is the oldest. + } + + // If no exact match, fall back to the oldest generic rule. + if len(ms.genericRules) > 0 { + return &ms.genericRules[0].rule // The list is pre-sorted. + } + return nil +} + +// GetAll returns a slice of all InferenceModelRewrite objects currently in the store. 
+func (ms *ModelRewriteStore) GetAll() []*v1alpha2.InferenceModelRewrite { + rewrites := make([]*v1alpha2.InferenceModelRewrite, 0, len(ms.allReWrites)) + for _, rewrite := range ms.allReWrites { + rewrites = append(rewrites, rewrite) + } + return rewrites +} + +func getNN(infModelRewrite *v1alpha2.InferenceModelRewrite) types.NamespacedName { + return types.NamespacedName{ + Namespace: infModelRewrite.Namespace, + Name: infModelRewrite.Name, + } +} + +// rewriteRuleWithMetadata decorates a rule with metadata from its parent object +// to be used in precedence sorting. +type rewriteRuleWithMetadata struct { + rule v1alpha2.InferenceModelRewriteRule + createTimestamp time.Time + parentRewriteNN types.NamespacedName +} + +func newRuleWithMetadata(infModelRewrite *v1alpha2.InferenceModelRewrite, ruleIdx int) *rewriteRuleWithMetadata { + if ruleIdx >= len(infModelRewrite.Spec.Rules) { + return nil + } + return &rewriteRuleWithMetadata{ + rule: infModelRewrite.Spec.Rules[ruleIdx], + createTimestamp: infModelRewrite.CreationTimestamp.Time, + parentRewriteNN: getNN(infModelRewrite), + } +} + +func (rr rewriteRuleWithMetadata) isGeneric() bool { + return len(rr.rule.Matches) == 0 +} + +func (rr rewriteRuleWithMetadata) exactModels() map[string]bool { + modelSet := map[string]bool{} + for _, match := range rr.rule.Matches { + if match.Model != nil { + modelSet[match.Model.Value] = true + } + } + return modelSet +} + +func (rr rewriteRuleWithMetadata) parentNN() types.NamespacedName { + return rr.parentRewriteNN +} diff --git a/pkg/epp/datastore/modelrewritestore_test.go b/pkg/epp/datastore/modelrewritestore_test.go new file mode 100644 index 000000000..41a1dc088 --- /dev/null +++ b/pkg/epp/datastore/modelrewritestore_test.go @@ -0,0 +1,185 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package datastore + +import ( + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + + "sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2" +) + +func TestModelRewriteStore(t *testing.T) { + now := time.Now() + oneMinuteAgo := now.Add(-1 * time.Minute) + + // Define common rules with generic names + ruleModel1V1 := v1alpha2.InferenceModelRewriteRule{ + Matches: []v1alpha2.Match{{Model: &v1alpha2.ModelMatch{Value: "model1"}}}, + Targets: []v1alpha2.TargetModel{{ModelRewrite: "model1-v1"}}, + } + ruleModel1V2 := v1alpha2.InferenceModelRewriteRule{ + Matches: []v1alpha2.Match{{Model: &v1alpha2.ModelMatch{Value: "model1"}}}, + Targets: []v1alpha2.TargetModel{{ModelRewrite: "model1-v2"}}, + } + ruleGeneric := v1alpha2.InferenceModelRewriteRule{ + Targets: []v1alpha2.TargetModel{{ModelRewrite: "generic-fallback"}}, + } + + // Define rewrite objects using plain structs + rewriteOld := &v1alpha2.InferenceModelRewrite{ + ObjectMeta: metav1.ObjectMeta{Name: "rewrite-old", Namespace: "default", CreationTimestamp: metav1.NewTime(oneMinuteAgo)}, + Spec: v1alpha2.InferenceModelRewriteSpec{Rules: []v1alpha2.InferenceModelRewriteRule{ruleModel1V1}}, + } + rewriteNew := &v1alpha2.InferenceModelRewrite{ + ObjectMeta: metav1.ObjectMeta{Name: "rewrite-new", Namespace: "default", CreationTimestamp: metav1.NewTime(now)}, + Spec: v1alpha2.InferenceModelRewriteSpec{Rules: []v1alpha2.InferenceModelRewriteRule{ruleModel1V2}}, + } + rewriteGenericOld := 
&v1alpha2.InferenceModelRewrite{ + ObjectMeta: metav1.ObjectMeta{Name: "rewrite-generic-old", Namespace: "default", CreationTimestamp: metav1.NewTime(oneMinuteAgo)}, + Spec: v1alpha2.InferenceModelRewriteSpec{Rules: []v1alpha2.InferenceModelRewriteRule{ruleGeneric}}, + } + rewriteGenericNew := &v1alpha2.InferenceModelRewrite{ + ObjectMeta: metav1.ObjectMeta{Name: "rewrite-generic-new", Namespace: "default", CreationTimestamp: metav1.NewTime(now)}, + Spec: v1alpha2.InferenceModelRewriteSpec{Rules: []v1alpha2.InferenceModelRewriteRule{{Targets: []v1alpha2.TargetModel{{ModelRewrite: "new-generic"}}}}}, + } + rewriteUpdated := &v1alpha2.InferenceModelRewrite{ + ObjectMeta: metav1.ObjectMeta{Name: "rewrite-old", Namespace: "default", CreationTimestamp: metav1.NewTime(now)}, // Same name as rewriteOld + Spec: v1alpha2.InferenceModelRewriteSpec{Rules: []v1alpha2.InferenceModelRewriteRule{ruleModel1V2}}, + } + + tests := []struct { + name string + initialState []*v1alpha2.InferenceModelRewrite + op func(store *ModelRewriteStore) + modelToGet string + wantRule *v1alpha2.InferenceModelRewriteRule + wantGetAll []*v1alpha2.InferenceModelRewrite + }{ + { + name: "Simple exact match", + initialState: []*v1alpha2.InferenceModelRewrite{rewriteOld}, + modelToGet: "model1", + wantRule: &ruleModel1V1, + wantGetAll: []*v1alpha2.InferenceModelRewrite{rewriteOld}, + }, + { + name: "Simple generic match", + initialState: []*v1alpha2.InferenceModelRewrite{rewriteGenericOld}, + modelToGet: "model2", // A different model to test generic fallback + wantRule: &ruleGeneric, + wantGetAll: []*v1alpha2.InferenceModelRewrite{rewriteGenericOld}, + }, + { + name: "No match", + initialState: []*v1alpha2.InferenceModelRewrite{rewriteOld}, + modelToGet: "model2", + wantRule: nil, + wantGetAll: []*v1alpha2.InferenceModelRewrite{rewriteOld}, + }, + { + name: "Precedence: Exact match wins over generic", + initialState: []*v1alpha2.InferenceModelRewrite{rewriteOld, rewriteGenericOld}, + modelToGet: 
"model1", + wantRule: &ruleModel1V1, + wantGetAll: []*v1alpha2.InferenceModelRewrite{rewriteOld, rewriteGenericOld}, + }, + { + name: "Precedence: Fallback to generic when no exact match", + initialState: []*v1alpha2.InferenceModelRewrite{rewriteOld, rewriteGenericOld}, + modelToGet: "model2", + wantRule: &ruleGeneric, + wantGetAll: []*v1alpha2.InferenceModelRewrite{rewriteOld, rewriteGenericOld}, + }, + { + name: "Precedence: Oldest exact match wins", + initialState: []*v1alpha2.InferenceModelRewrite{rewriteNew, rewriteOld}, + modelToGet: "model1", + wantRule: &ruleModel1V1, + wantGetAll: []*v1alpha2.InferenceModelRewrite{rewriteNew, rewriteOld}, + }, + { + name: "Precedence: Oldest generic match wins", + initialState: []*v1alpha2.InferenceModelRewrite{rewriteGenericNew, rewriteGenericOld}, + modelToGet: "any-model", + wantRule: &ruleGeneric, + wantGetAll: []*v1alpha2.InferenceModelRewrite{rewriteGenericNew, rewriteGenericOld}, + }, + { + name: "Delete: successfully deletes a rewrite", + initialState: []*v1alpha2.InferenceModelRewrite{rewriteOld, rewriteGenericOld}, + op: func(store *ModelRewriteStore) { + store.Delete(getNN(rewriteOld)) + }, + modelToGet: "model1", + wantRule: &ruleGeneric, // Falls back to generic + wantGetAll: []*v1alpha2.InferenceModelRewrite{rewriteGenericOld}, + }, + { + name: "Delete: non-existent rewrite does nothing", + initialState: []*v1alpha2.InferenceModelRewrite{rewriteOld}, + op: func(store *ModelRewriteStore) { + store.Delete(types.NamespacedName{Name: "non-existent", Namespace: "default"}) + }, + modelToGet: "model1", + wantRule: &ruleModel1V1, + wantGetAll: []*v1alpha2.InferenceModelRewrite{rewriteOld}, + }, + { + name: "Update: Setting a rewrite with the same name replaces the old one", + initialState: []*v1alpha2.InferenceModelRewrite{rewriteOld}, + op: func(store *ModelRewriteStore) { + store.Set(rewriteUpdated) + }, + modelToGet: "model1", + wantRule: &ruleModel1V2, + wantGetAll: 
[]*v1alpha2.InferenceModelRewrite{rewriteUpdated}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + store := NewModelRewriteStore() + for _, r := range tc.initialState { + store.Set(r) + } + + if tc.op != nil { + tc.op(store) + } + + gotRule := store.GetRule(tc.modelToGet) + if diff := cmp.Diff(tc.wantRule, gotRule); diff != "" { + t.Errorf("GetRule() mismatch (-want +got):\n%s", diff) + } + + if tc.wantGetAll != nil { + gotAll := store.GetAll() + if diff := cmp.Diff(tc.wantGetAll, gotAll, cmpopts.SortSlices(func(a, b *v1alpha2.InferenceModelRewrite) bool { + return getNN(a).String() < getNN(b).String() + })); diff != "" { + t.Errorf("GetAll() mismatch (-want +got):\n%s", diff) + } + } + }) + } +} diff --git a/pkg/epp/requestcontrol/director.go b/pkg/epp/requestcontrol/director.go index dd4b3424b..3d13c0f7e 100644 --- a/pkg/epp/requestcontrol/director.go +++ b/pkg/epp/requestcontrol/director.go @@ -23,7 +23,6 @@ import ( "fmt" "math/rand" "net" - "sort" "strings" "time" @@ -51,7 +50,7 @@ type Datastore interface { PoolGet() (*datalayer.EndpointPool, error) ObjectiveGet(objectiveName string) *v1alpha2.InferenceObjective PodList(predicate func(backendmetrics.PodMetrics) bool) []backendmetrics.PodMetrics - RewriteGetAll() []*v1alpha2.InferenceModelRewrite + RewriteGet(modelName string) *v1alpha2.InferenceModelRewriteRule } // Scheduler defines the interface required by the Director for scheduling. 
@@ -195,25 +194,11 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo } func (d *Director) applyWeightedModelRewrite(reqCtx *handlers.RequestContext) { - rewrites := d.datastore.RewriteGetAll() - if len(rewrites) == 0 { + rewriteRule := d.datastore.RewriteGet(reqCtx.IncomingModelName) + if rewriteRule == nil { return } - - sort.Slice(rewrites, func(i, j int) bool { - return rewrites[i].CreationTimestamp.Before(&rewrites[j].CreationTimestamp) - }) - - for _, rewrite := range rewrites { - for _, rule := range rewrite.Spec.Rules { - for _, match := range rule.Matches { - if match.Model != nil && match.Model.Value == reqCtx.IncomingModelName { - reqCtx.TargetModelName = d.selectWeightedModel(rule.Targets) - return - } - } - } - } + reqCtx.TargetModelName = d.selectWeightedModel(rewriteRule.Targets) } func (d *Director) selectWeightedModel(models []v1alpha2.TargetModel) string { diff --git a/pkg/epp/requestcontrol/director_test.go b/pkg/epp/requestcontrol/director_test.go index d31bf6f06..9fde14bc8 100644 --- a/pkg/epp/requestcontrol/director_test.go +++ b/pkg/epp/requestcontrol/director_test.go @@ -21,6 +21,7 @@ import ( "errors" "fmt" "maps" + "sort" "testing" "time" @@ -168,8 +169,32 @@ func (m mockProducedDataType) Clone() datalayer.Cloneable { return mockProducedDataType{value: m.value} } -func (ds *mockDatastore) RewriteGetAll() []*v1alpha2.InferenceModelRewrite { - return ds.rewrites +func (ds *mockDatastore) RewriteGet(modelName string) *v1alpha2.InferenceModelRewriteRule { + // This mock implementation simulates the precedence logic for simplicity. + // It finds the oldest rewrite that has a rule matching the modelName. 
+ var matchingRewrites []*v1alpha2.InferenceModelRewrite + for _, r := range ds.rewrites { + for _, rule := range r.Spec.Rules { + for _, match := range rule.Matches { + if match.Model != nil && match.Model.Value == modelName { + matchingRewrites = append(matchingRewrites, r) + break // break inner loop + } + } + } + } + + if len(matchingRewrites) == 0 { + return nil + } + + // Sort by timestamp to find the oldest. + sort.Slice(matchingRewrites, func(i, j int) bool { + return matchingRewrites[i].CreationTimestamp.Before(&matchingRewrites[j].CreationTimestamp) + }) + + // Return the first rule from the oldest rewrite. + return &matchingRewrites[0].Spec.Rules[0] } func TestDirector_HandleRequest(t *testing.T) { @@ -224,7 +249,6 @@ func TestDirector_HandleRequest(t *testing.T) { }, }, } - pool := &v1.InferencePool{ ObjectMeta: metav1.ObjectMeta{Name: "test-pool", Namespace: "default"}, @@ -636,6 +660,9 @@ func TestDirector_HandleRequest(t *testing.T) { } config = config.WithAdmissionPlugins(newMockAdmissionPlugin("test-admit-plugin", test.admitRequestDenialError)) director := NewDirectorWithConfig(ds, mockSched, test.mockAdmissionController, config) + if test.name == "successful request with model rewrite" { + director.datastore = &mockDatastore{pods: ds.PodList(backendmetrics.AllPodsPredicate), rewrites: []*v1alpha2.InferenceModelRewrite{rewrite}} + } reqCtx := &handlers.RequestContext{ Request: &handlers.Request{ From 19d12d647c39e1e0f702676832733fa6adca6d6c Mon Sep 17 00:00:00 2001 From: bobzetian Date: Thu, 20 Nov 2025 11:17:06 +0000 Subject: [PATCH 4/6] register controller. 
--- .../charts/inferencepool/templates/rbac.yaml | 2 +- .../inferencemodelrewrite_reconciler.go | 4 +- .../inferencemodelrewrite_reconciler_test.go | 34 ++++++-- pkg/epp/datastore/datastore.go | 19 ++-- pkg/epp/datastore/modelrewritestore.go | 86 ++++++++----------- pkg/epp/datastore/modelrewritestore_test.go | 24 +++--- pkg/epp/server/controller_manager.go | 10 +++ pkg/epp/server/runserver.go | 8 ++ test/testdata/inferencepool-e2e.yaml | 2 +- .../inferencepool-leader-election-e2e.yaml | 2 +- 10 files changed, 105 insertions(+), 86 deletions(-) diff --git a/config/charts/inferencepool/templates/rbac.yaml b/config/charts/inferencepool/templates/rbac.yaml index ff66aebb1..dc6b3e0c4 100644 --- a/config/charts/inferencepool/templates/rbac.yaml +++ b/config/charts/inferencepool/templates/rbac.yaml @@ -46,7 +46,7 @@ metadata: {{- include "gateway-api-inference-extension.labels" . | nindent 4 }} rules: - apiGroups: ["inference.networking.x-k8s.io"] - resources: ["inferenceobjectives"] + resources: ["inferenceobjectives", "inferencemodelrewrites"] verbs: ["get", "watch", "list"] - apiGroups: ["{{ (split "/" .Values.inferencePool.apiVersion)._0 }}"] resources: ["inferencepools"] diff --git a/pkg/epp/controller/inferencemodelrewrite_reconciler.go b/pkg/epp/controller/inferencemodelrewrite_reconciler.go index 8935e4115..f3029ecc6 100644 --- a/pkg/epp/controller/inferencemodelrewrite_reconciler.go +++ b/pkg/epp/controller/inferencemodelrewrite_reconciler.go @@ -54,7 +54,9 @@ func (c *InferenceModelRewriteReconciler) Reconcile(ctx context.Context, req ctr notFound = true } - if notFound || !infModelRewrite.DeletionTimestamp.IsZero() || infModelRewrite.Spec.PoolRef.Name != v1alpha2.ObjectName(c.PoolGKNN.Name) { + if notFound || !infModelRewrite.DeletionTimestamp.IsZero() || infModelRewrite.Spec.PoolRef == nil || + infModelRewrite.Spec.PoolRef.Name != v1alpha2.ObjectName(c.PoolGKNN.Name) || + (infModelRewrite.Spec.PoolRef.Group != v1alpha2.Group(c.PoolGKNN.Group) && 
infModelRewrite.Spec.PoolRef.Group != "inference.networking.x-k8s.io") { // InferenceModelRewrite object got deleted or changed the referenced pool. c.Datastore.RewriteDelete(req.NamespacedName) return ctrl.Result{}, nil diff --git a/pkg/epp/controller/inferencemodelrewrite_reconciler_test.go b/pkg/epp/controller/inferencemodelrewrite_reconciler_test.go index 8365aff0b..346d0a9e8 100644 --- a/pkg/epp/controller/inferencemodelrewrite_reconciler_test.go +++ b/pkg/epp/controller/inferencemodelrewrite_reconciler_test.go @@ -36,6 +36,7 @@ import ( "sigs.k8s.io/gateway-api-inference-extension/pkg/common" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" + poolutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/pool" utiltest "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/testing" ) @@ -48,7 +49,10 @@ var ( CreationTimestamp: metav1.Unix(1000, 0), }, Spec: v1alpha2.InferenceModelRewriteSpec{ - PoolRef: &v1alpha2.PoolObjectReference{Name: v1alpha2.ObjectName(poolForRewrite.Name)}, + PoolRef: &v1alpha2.PoolObjectReference{ + Name: v1alpha2.ObjectName(poolForRewrite.Name), + Group: v1alpha2.Group(poolForRewrite.GroupVersionKind().Group), + }, }, } rewrite1Pool2 = &v1alpha2.InferenceModelRewrite{ @@ -58,7 +62,10 @@ var ( CreationTimestamp: metav1.Unix(1001, 0), }, Spec: v1alpha2.InferenceModelRewriteSpec{ - PoolRef: &v1alpha2.PoolObjectReference{Name: "test-pool2"}, + PoolRef: &v1alpha2.PoolObjectReference{ + Name: "test-pool2", + Group: v1alpha2.Group(poolForRewrite.GroupVersionKind().Group), + }, }, } rewrite1Updated = &v1alpha2.InferenceModelRewrite{ @@ -68,8 +75,11 @@ var ( CreationTimestamp: metav1.Unix(1003, 0), }, Spec: v1alpha2.InferenceModelRewriteSpec{ - PoolRef: &v1alpha2.PoolObjectReference{Name: v1alpha2.ObjectName(poolForRewrite.Name)}, - Rules: []v1alpha2.InferenceModelRewriteRule{{}}, + PoolRef: &v1alpha2.PoolObjectReference{ + Name: 
v1alpha2.ObjectName(poolForRewrite.Name), + Group: v1alpha2.Group(poolForRewrite.GroupVersionKind().Group), + }, + Rules: []v1alpha2.InferenceModelRewriteRule{{}}, }, } rewrite1Deleted = &v1alpha2.InferenceModelRewrite{ @@ -78,9 +88,13 @@ var ( Namespace: rewrite1.Namespace, CreationTimestamp: metav1.Unix(1004, 0), DeletionTimestamp: &metav1.Time{Time: time.Now()}, + Finalizers: []string{"deleted"}, }, Spec: v1alpha2.InferenceModelRewriteSpec{ - PoolRef: &v1alpha2.PoolObjectReference{Name: v1alpha2.ObjectName(poolForRewrite.Name)}, + PoolRef: &v1alpha2.PoolObjectReference{ + Name: v1alpha2.ObjectName(poolForRewrite.Name), + Group: v1alpha2.Group(poolForRewrite.GroupVersionKind().Group), + }, }, } rewrite2 = &v1alpha2.InferenceModelRewrite{ @@ -90,7 +104,10 @@ var ( CreationTimestamp: metav1.Unix(1001, 0), }, Spec: v1alpha2.InferenceModelRewriteSpec{ - PoolRef: &v1alpha2.PoolObjectReference{Name: v1alpha2.ObjectName(poolForRewrite.Name)}, + PoolRef: &v1alpha2.PoolObjectReference{ + Name: v1alpha2.ObjectName(poolForRewrite.Name), + Group: v1alpha2.Group(poolForRewrite.GroupVersionKind().Group), + }, }, } ) @@ -155,7 +172,7 @@ func TestInferenceModelRewriteReconciler(t *testing.T) { _ = v1alpha2.Install(scheme) _ = v1.Install(scheme) initObjs := []client.Object{} - if test.rewrite != nil && test.rewrite.DeletionTimestamp.IsZero() { + if test.rewrite != nil { initObjs = append(initObjs, test.rewrite) } for _, r := range test.rewritesInAPIServer { @@ -170,7 +187,8 @@ func TestInferenceModelRewriteReconciler(t *testing.T) { for _, r := range test.rewritesInStore { ds.RewriteSet(r) } - _ = ds.PoolSet(context.Background(), fakeClient, poolForRewrite) + endpointPool := poolutil.InferencePoolToEndpointPool(poolForRewrite) + _ = ds.PoolSet(context.Background(), fakeClient, endpointPool) reconciler := &InferenceModelRewriteReconciler{ Reader: fakeClient, Datastore: ds, diff --git a/pkg/epp/datastore/datastore.go b/pkg/epp/datastore/datastore.go index b036d2acf..6b3f4e3f1 
100644 --- a/pkg/epp/datastore/datastore.go +++ b/pkg/epp/datastore/datastore.go @@ -31,7 +31,6 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/log" - v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1" "sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer" @@ -82,7 +81,7 @@ func NewDatastore(parentCtx context.Context, epFactory datalayer.EndpointFactory pool: nil, mu: sync.RWMutex{}, objectives: make(map[string]*v1alpha2.InferenceObjective), - rewrites: NewModelRewriteStore(), + rewrites: newModelRewriteStore(), pods: &sync.Map{}, modelServerMetricsPort: modelServerMetricsPort, epf: epFactory, @@ -101,11 +100,11 @@ type datastore struct { parentCtx context.Context // mu is used to synchronize access to pool, objectives, and rewrites. mu sync.RWMutex - pool *v1.InferencePool + pool *datalayer.EndpointPool // key: InferenceObjective name, value: *InferenceObjective objectives map[string]*v1alpha2.InferenceObjective // rewrites store for InferenceModelRewrite objects. - rewrites *ModelRewriteStore + rewrites *modelRewriteStore // key: types.NamespacedName, value: backendmetrics.PodMetrics pods *sync.Map // modelServerMetricsPort metrics port from EPP command line argument @@ -119,7 +118,7 @@ func (ds *datastore) Clear() { defer ds.mu.Unlock() ds.pool = nil ds.objectives = make(map[string]*v1alpha2.InferenceObjective) - ds.rewrites = NewModelRewriteStore() + ds.rewrites = newModelRewriteStore() // stop all pods go routines before clearing the pods map. 
ds.pods.Range(func(_, v any) bool { ds.epf.ReleaseEndpoint(v.(backendmetrics.PodMetrics)) @@ -156,7 +155,7 @@ func (ds *datastore) PoolSet(ctx context.Context, reader client.Reader, endpoint return nil } -func (ds *datastore) PoolGet() (*v1.InferencePool, error) { +func (ds *datastore) PoolGet() (*datalayer.EndpointPool, error) { ds.mu.RLock() defer ds.mu.RUnlock() if !ds.PoolHasSynced() { @@ -214,25 +213,25 @@ func (ds *datastore) ObjectiveGetAll() []*v1alpha2.InferenceObjective { func (ds *datastore) RewriteSet(infModelRewrite *v1alpha2.InferenceModelRewrite) { ds.mu.Lock() defer ds.mu.Unlock() - ds.rewrites.Set(infModelRewrite) + ds.rewrites.set(infModelRewrite) } func (ds *datastore) RewriteDelete(namespacedName types.NamespacedName) { ds.mu.Lock() defer ds.mu.Unlock() - ds.rewrites.Delete(namespacedName) + ds.rewrites.delete(namespacedName) } func (ds *datastore) RewriteGet(modelName string) *v1alpha2.InferenceModelRewriteRule { ds.mu.RLock() defer ds.mu.RUnlock() - return ds.rewrites.GetRule(modelName) + return ds.rewrites.getRule(modelName) } func (ds *datastore) RewriteGetAll() []*v1alpha2.InferenceModelRewrite { ds.mu.RLock() defer ds.mu.RUnlock() - return ds.rewrites.GetAll() + return ds.rewrites.getAll() } // /// Pods/endpoints APIs /// diff --git a/pkg/epp/datastore/modelrewritestore.go b/pkg/epp/datastore/modelrewritestore.go index 2ad21d41f..e85e52c1a 100644 --- a/pkg/epp/datastore/modelrewritestore.go +++ b/pkg/epp/datastore/modelrewritestore.go @@ -24,39 +24,39 @@ import ( "sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2" ) -// ModelRewriteStore encapsulates the logic for storing and retrieving +// modelRewriteStore encapsulates the logic for storing and retrieving // InferenceModelRewrite rules, handling precedence correctly. This struct is not // thread-safe; concurrency must be managed by its consumer. 
-type ModelRewriteStore struct { +type modelRewriteStore struct { genericRules []*rewriteRuleWithMetadata rulesByExactModelMatch map[string][]*rewriteRuleWithMetadata - allReWrites map[types.NamespacedName]*v1alpha2.InferenceModelRewrite + allReWrites map[string]*v1alpha2.InferenceModelRewrite } -func NewModelRewriteStore() *ModelRewriteStore { - return &ModelRewriteStore{ +func newModelRewriteStore() *modelRewriteStore { + return &modelRewriteStore{ genericRules: []*rewriteRuleWithMetadata{}, - rulesByExactModelMatch: map[string][]*rewriteRuleWithMetadata{}, - allReWrites: map[types.NamespacedName]*v1alpha2.InferenceModelRewrite{}, + rulesByExactModelMatch: map[string][]*rewriteRuleWithMetadata{}, // Key is the exact model name. + allReWrites: map[string]*v1alpha2.InferenceModelRewrite{}, // Key is the rewrites name. } } -// Set adds or updates an InferenceModelRewrite in the store. It deconstructs the +// set adds or updates an InferenceModelRewrite in the store. It deconstructs the // object into individual rules and stores them in the appropriate data structures, // ensuring they remain sorted by precedence. -func (ms *ModelRewriteStore) Set(infModelRewrite *v1alpha2.InferenceModelRewrite) { - nn := getNN(infModelRewrite) - +func (ms *modelRewriteStore) set(infModelRewrite *v1alpha2.InferenceModelRewrite) { + name := infModelRewrite.Name // If the rewrite object already exists, remove its old rules before adding new ones. 
- if _, ok := ms.allReWrites[nn]; ok { - ms.deleteInternal(nn) + if _, ok := ms.allReWrites[name]; ok { + ms.deleteInternal(infModelRewrite.Name) } - ms.allReWrites[nn] = infModelRewrite + ms.allReWrites[name] = infModelRewrite for i := range infModelRewrite.Spec.Rules { - ruleWithMetadata := newRuleWithMetadata(infModelRewrite, i) - if ruleWithMetadata == nil { - continue + ruleWithMetadata := &rewriteRuleWithMetadata{ + rule: infModelRewrite.Spec.Rules[i], + createTimestamp: infModelRewrite.CreationTimestamp.Time, + parentRewriteName: name, } if ruleWithMetadata.isGeneric() { @@ -80,22 +80,22 @@ func (ms *ModelRewriteStore) Set(infModelRewrite *v1alpha2.InferenceModelRewrite } } -// Delete removes an InferenceModelRewrite and all its associated rules from the store. -func (ms *ModelRewriteStore) Delete(nn types.NamespacedName) { - ms.deleteInternal(nn) +// delete removes an InferenceModelRewrite and all its associated rules from the store. +func (ms *modelRewriteStore) delete(nn types.NamespacedName) { + ms.deleteInternal(nn.Name) } // deleteInternal is the non-locking implementation for deleting a rewrite. -func (ms *ModelRewriteStore) deleteInternal(nn types.NamespacedName) { - if _, ok := ms.allReWrites[nn]; !ok { +func (ms *modelRewriteStore) deleteInternal(n string) { + if _, ok := ms.allReWrites[n]; !ok { return } - delete(ms.allReWrites, nn) + delete(ms.allReWrites, n) // Filter out the generic rules associated with the deleted rewrite. 
newGenericRules := make([]*rewriteRuleWithMetadata, 0, len(ms.genericRules)) for _, ruleWithMd := range ms.genericRules { - if ruleWithMd.parentNN() != nn { + if ruleWithMd.parentName() != n { newGenericRules = append(newGenericRules, ruleWithMd) } } @@ -105,7 +105,7 @@ func (ms *ModelRewriteStore) deleteInternal(nn types.NamespacedName) { for modelName, rulesWithMd := range ms.rulesByExactModelMatch { newRules := make([]*rewriteRuleWithMetadata, 0, len(rulesWithMd)) for _, r := range rulesWithMd { - if r.parentNN() != nn { + if r.parentName() != n { newRules = append(newRules, r) } } @@ -118,9 +118,9 @@ func (ms *ModelRewriteStore) deleteInternal(nn types.NamespacedName) { } } -// GetRule returns the single, highest-precedence rule for a given model name. +// getRule returns the single, highest-precedence rule for a given model name. // It prioritizes exact matches over generic ones, and among those, the oldest rule wins. -func (ms *ModelRewriteStore) GetRule(modelName string) *v1alpha2.InferenceModelRewriteRule { +func (ms *modelRewriteStore) getRule(modelName string) *v1alpha2.InferenceModelRewriteRule { // Exact matches have the highest precedence. if rulesWithMd, ok := ms.rulesByExactModelMatch[modelName]; ok && len(rulesWithMd) > 0 { return &rulesWithMd[0].rule // The list is pre-sorted, so the first element is the oldest. @@ -133,8 +133,8 @@ func (ms *ModelRewriteStore) GetRule(modelName string) *v1alpha2.InferenceModelR return nil } -// GetAll returns a slice of all InferenceModelRewrite objects currently in the store. -func (ms *ModelRewriteStore) GetAll() []*v1alpha2.InferenceModelRewrite { +// getAll returns a slice of all InferenceModelRewrite objects currently in the store. 
+func (ms *modelRewriteStore) getAll() []*v1alpha2.InferenceModelRewrite { rewrites := make([]*v1alpha2.InferenceModelRewrite, 0, len(ms.allReWrites)) for _, rewrite := range ms.allReWrites { rewrites = append(rewrites, rewrite) @@ -142,30 +142,12 @@ func (ms *ModelRewriteStore) GetAll() []*v1alpha2.InferenceModelRewrite { return rewrites } -func getNN(infModelRewrite *v1alpha2.InferenceModelRewrite) types.NamespacedName { - return types.NamespacedName{ - Namespace: infModelRewrite.Namespace, - Name: infModelRewrite.Name, - } -} - // rewriteRuleWithMetadata decorates a rule with metadata from its parent object // to be used in precedence sorting. type rewriteRuleWithMetadata struct { - rule v1alpha2.InferenceModelRewriteRule - createTimestamp time.Time - parentRewriteNN types.NamespacedName -} - -func newRuleWithMetadata(infModelRewrite *v1alpha2.InferenceModelRewrite, ruleIdx int) *rewriteRuleWithMetadata { - if ruleIdx >= len(infModelRewrite.Spec.Rules) { - return nil - } - return &rewriteRuleWithMetadata{ - rule: infModelRewrite.Spec.Rules[ruleIdx], - createTimestamp: infModelRewrite.CreationTimestamp.Time, - parentRewriteNN: getNN(infModelRewrite), - } + rule v1alpha2.InferenceModelRewriteRule + createTimestamp time.Time + parentRewriteName string } func (rr rewriteRuleWithMetadata) isGeneric() bool { @@ -182,6 +164,6 @@ func (rr rewriteRuleWithMetadata) exactModels() map[string]bool { return modelSet } -func (rr rewriteRuleWithMetadata) parentNN() types.NamespacedName { - return rr.parentRewriteNN +func (rr rewriteRuleWithMetadata) parentName() string { + return rr.parentRewriteName } diff --git a/pkg/epp/datastore/modelrewritestore_test.go b/pkg/epp/datastore/modelrewritestore_test.go index 41a1dc088..caf1f2f3d 100644 --- a/pkg/epp/datastore/modelrewritestore_test.go +++ b/pkg/epp/datastore/modelrewritestore_test.go @@ -70,7 +70,7 @@ func TestModelRewriteStore(t *testing.T) { tests := []struct { name string initialState []*v1alpha2.InferenceModelRewrite - op 
func(store *ModelRewriteStore) + op func(store *modelRewriteStore) modelToGet string wantRule *v1alpha2.InferenceModelRewriteRule wantGetAll []*v1alpha2.InferenceModelRewrite @@ -127,8 +127,8 @@ func TestModelRewriteStore(t *testing.T) { { name: "Delete: successfully deletes a rewrite", initialState: []*v1alpha2.InferenceModelRewrite{rewriteOld, rewriteGenericOld}, - op: func(store *ModelRewriteStore) { - store.Delete(getNN(rewriteOld)) + op: func(store *modelRewriteStore) { + store.delete(types.NamespacedName{Namespace: rewriteOld.Namespace, Name: rewriteOld.Name}) }, modelToGet: "model1", wantRule: &ruleGeneric, // Falls back to generic @@ -137,8 +137,8 @@ func TestModelRewriteStore(t *testing.T) { { name: "Delete: non-existent rewrite does nothing", initialState: []*v1alpha2.InferenceModelRewrite{rewriteOld}, - op: func(store *ModelRewriteStore) { - store.Delete(types.NamespacedName{Name: "non-existent", Namespace: "default"}) + op: func(store *modelRewriteStore) { + store.delete(types.NamespacedName{Name: "non-existent", Namespace: "default"}) }, modelToGet: "model1", wantRule: &ruleModel1V1, @@ -147,8 +147,8 @@ func TestModelRewriteStore(t *testing.T) { { name: "Update: Setting a rewrite with the same name replaces the old one", initialState: []*v1alpha2.InferenceModelRewrite{rewriteOld}, - op: func(store *ModelRewriteStore) { - store.Set(rewriteUpdated) + op: func(store *modelRewriteStore) { + store.set(rewriteUpdated) }, modelToGet: "model1", wantRule: &ruleModel1V2, @@ -158,24 +158,24 @@ func TestModelRewriteStore(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - store := NewModelRewriteStore() + store := newModelRewriteStore() for _, r := range tc.initialState { - store.Set(r) + store.set(r) } if tc.op != nil { tc.op(store) } - gotRule := store.GetRule(tc.modelToGet) + gotRule := store.getRule(tc.modelToGet) if diff := cmp.Diff(tc.wantRule, gotRule); diff != "" { t.Errorf("GetRule() mismatch (-want +got):\n%s", diff) } if 
tc.wantGetAll != nil { - gotAll := store.GetAll() + gotAll := store.getAll() if diff := cmp.Diff(tc.wantGetAll, gotAll, cmpopts.SortSlices(func(a, b *v1alpha2.InferenceModelRewrite) bool { - return getNN(a).String() < getNN(b).String() + return a.Name < b.Name })); diff != "" { t.Errorf("GetAll() mismatch (-want +got):\n%s", diff) } diff --git a/pkg/epp/server/controller_manager.go b/pkg/epp/server/controller_manager.go index c82b0bcb9..acc0bd51b 100644 --- a/pkg/epp/server/controller_manager.go +++ b/pkg/epp/server/controller_manager.go @@ -55,6 +55,16 @@ func defaultManagerOptions(disableK8sCrdReconcile bool, gknn common.GKNN, metric gknn.Namespace: {}, }, }, + &v1alpha2.InferenceObjective{}: { + Namespaces: map[string]cache.Config{ + gknn.Namespace: {}, + }, + }, + &v1alpha2.InferenceModelRewrite{}: { + Namespaces: map[string]cache.Config{ + gknn.Namespace: {}, + }, + }, }, }, Metrics: metricsServerOptions, diff --git a/pkg/epp/server/runserver.go b/pkg/epp/server/runserver.go index e43d84923..83d5c40bc 100644 --- a/pkg/epp/server/runserver.go +++ b/pkg/epp/server/runserver.go @@ -131,6 +131,14 @@ func (r *ExtProcServerRunner) SetupWithManager(ctx context.Context, mgr ctrl.Man } } + if err := (&controller.InferenceModelRewriteReconciler{ + Datastore: r.Datastore, + Reader: mgr.GetClient(), + PoolGKNN: r.GKNN, + }).SetupWithManager(ctx, mgr); err != nil { + return fmt.Errorf("failed setting up InferenceModelRewriteReconciler: %w", err) + } + if err := (&controller.PodReconciler{ Datastore: r.Datastore, Reader: mgr.GetClient(), diff --git a/test/testdata/inferencepool-e2e.yaml b/test/testdata/inferencepool-e2e.yaml index 77d454e6f..4aea814ee 100644 --- a/test/testdata/inferencepool-e2e.yaml +++ b/test/testdata/inferencepool-e2e.yaml @@ -44,7 +44,7 @@ metadata: namespace: $E2E_NS rules: - apiGroups: [ "inference.networking.x-k8s.io" ] - resources: [ "inferenceobjectives", "inferencepools" ] + resources: [ "inferenceobjectives", "inferencepools", 
"inferencemodelrewrites" ] verbs: [ "get", "watch", "list" ] - apiGroups: [ "inference.networking.k8s.io" ] resources: [ "inferencepools" ] diff --git a/test/testdata/inferencepool-leader-election-e2e.yaml b/test/testdata/inferencepool-leader-election-e2e.yaml index 976fbbd02..ede2d6a40 100644 --- a/test/testdata/inferencepool-leader-election-e2e.yaml +++ b/test/testdata/inferencepool-leader-election-e2e.yaml @@ -42,7 +42,7 @@ metadata: namespace: $E2E_NS rules: - apiGroups: [ "inference.networking.x-k8s.io" ] - resources: [ "inferenceobjectives", "inferencepools" ] + resources: [ "inferenceobjectives", "inferencepools", "inferencemodelrewrites" ] verbs: [ "get", "watch", "list" ] - apiGroups: [ "inference.networking.k8s.io" ] resources: [ "inferencepools" ] From 8923424ffd3d414842be26c1b082f821d2b92412 Mon Sep 17 00:00:00 2001 From: bobzetian Date: Mon, 24 Nov 2025 19:22:05 +0000 Subject: [PATCH 5/6] fix e2e. --- test/e2e/epp/e2e_suite_test.go | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/test/e2e/epp/e2e_suite_test.go b/test/e2e/epp/e2e_suite_test.go index 504e59e2e..bb8c0eb6b 100644 --- a/test/e2e/epp/e2e_suite_test.go +++ b/test/e2e/epp/e2e_suite_test.go @@ -68,6 +68,8 @@ const ( xInferPoolManifest = "../../../config/crd/bases/inference.networking.x-k8s.io_inferencepools.yaml" // xInferObjectiveManifest is the manifest for the inference model CRD with 'inference.networking.x-k8s.io' group. xInferObjectiveManifest = "../../../config/crd/bases/inference.networking.x-k8s.io_inferenceobjectives.yaml" + // xInferenceModelRewritesManifest is the manifest for the inference rewrites CRD with 'inference.networking.x-k8s.io' group. + xInferenceModelRewritesManifest = "../../../config/crd/bases/inference.networking.x-k8s.io_inferencemodelrewrites.yaml" // inferPoolManifest is the manifest for the inference pool CRD with 'inference.networking.k8s.io' group. 
inferPoolManifest = "../../../config/crd/bases/inference.networking.k8s.io_inferencepools.yaml" // inferExtManifestDefault is the manifest for the default inference extension test resources (single replica). @@ -132,9 +134,10 @@ func setupInfra() { createHfSecret(testConfig, modelServerSecretManifest) } crds := map[string]string{ - "inferencepools.inference.networking.x-k8s.io": xInferPoolManifest, - "inferenceobjectives.inference.networking.x-k8s.io": xInferObjectiveManifest, - "inferencepools.inference.networking.k8s.io": inferPoolManifest, + "inferencepools.inference.networking.x-k8s.io": xInferPoolManifest, + "inferenceobjectives.inference.networking.x-k8s.io": xInferObjectiveManifest, + "inferencemodelrewrites.inference.networking.x-k8s.io": xInferenceModelRewritesManifest, + "inferencepools.inference.networking.k8s.io": inferPoolManifest, } createCRDs(testConfig, crds) From a109e9030cf402040bc5ccab0c0e108de4b41662 Mon Sep 17 00:00:00 2001 From: bobzetian Date: Mon, 24 Nov 2025 22:31:40 +0000 Subject: [PATCH 6/6] add integration test, rename datastore interface, change split to targets.
--- apix/v1alpha2/inferencemodelrewrite_types.go | 2 +- .../v1alpha2/inferencemodelrewriterule.go | 2 +- ...rking.x-k8s.io_inferencemodelrewrites.yaml | 2 +- .../inferencemodelrewrite_reconciler.go | 11 +++-- .../inferencemodelrewrite_reconciler_test.go | 8 ++-- pkg/epp/datastore/datastore.go | 32 +++++++------- pkg/epp/requestcontrol/director.go | 4 +- pkg/epp/requestcontrol/director_test.go | 4 +- test/integration/epp/hermetic_test.go | 44 +++++++++++++++++++ .../inferencepool-with-model-hermetic.yaml | 15 +++++++ 10 files changed, 93 insertions(+), 31 deletions(-) diff --git a/apix/v1alpha2/inferencemodelrewrite_types.go b/apix/v1alpha2/inferencemodelrewrite_types.go index 262238c28..43cd6dc4b 100644 --- a/apix/v1alpha2/inferencemodelrewrite_types.go +++ b/apix/v1alpha2/inferencemodelrewrite_types.go @@ -92,7 +92,7 @@ type InferenceModelRewriteRule struct { // +optional // +kubebuilder:validation:MinItems=1 // - Targets []TargetModel `json:"split,omitempty"` + Targets []TargetModel `json:"targets,omitempty"` } // TargetModel defines a weighted model destination for traffic distribution. diff --git a/client-go/applyconfiguration/apix/v1alpha2/inferencemodelrewriterule.go b/client-go/applyconfiguration/apix/v1alpha2/inferencemodelrewriterule.go index 7e03c192d..80a69a83e 100644 --- a/client-go/applyconfiguration/apix/v1alpha2/inferencemodelrewriterule.go +++ b/client-go/applyconfiguration/apix/v1alpha2/inferencemodelrewriterule.go @@ -22,7 +22,7 @@ package v1alpha2 // with apply. 
type InferenceModelRewriteRuleApplyConfiguration struct { Matches []MatchApplyConfiguration `json:"matches,omitempty"` - Targets []TargetModelApplyConfiguration `json:"split,omitempty"` + Targets []TargetModelApplyConfiguration `json:"targets,omitempty"` } // InferenceModelRewriteRuleApplyConfiguration constructs a declarative configuration of the InferenceModelRewriteRule type for use with diff --git a/config/crd/bases/inference.networking.x-k8s.io_inferencemodelrewrites.yaml b/config/crd/bases/inference.networking.x-k8s.io_inferencemodelrewrites.yaml index 2680ea091..22dff8af8 100644 --- a/config/crd/bases/inference.networking.x-k8s.io_inferencemodelrewrites.yaml +++ b/config/crd/bases/inference.networking.x-k8s.io_inferencemodelrewrites.yaml @@ -108,7 +108,7 @@ spec: - model type: object type: array - split: + targets: items: description: TargetModel defines a weighted model destination for traffic distribution. diff --git a/pkg/epp/controller/inferencemodelrewrite_reconciler.go b/pkg/epp/controller/inferencemodelrewrite_reconciler.go index f3029ecc6..caa0df014 100644 --- a/pkg/epp/controller/inferencemodelrewrite_reconciler.go +++ b/pkg/epp/controller/inferencemodelrewrite_reconciler.go @@ -54,17 +54,20 @@ func (c *InferenceModelRewriteReconciler) Reconcile(ctx context.Context, req ctr notFound = true } - if notFound || !infModelRewrite.DeletionTimestamp.IsZero() || infModelRewrite.Spec.PoolRef == nil || + isDeleted := !infModelRewrite.DeletionTimestamp.IsZero() + isPooRefUnmatch := infModelRewrite.Spec.PoolRef == nil || infModelRewrite.Spec.PoolRef.Name != v1alpha2.ObjectName(c.PoolGKNN.Name) || - (infModelRewrite.Spec.PoolRef.Group != v1alpha2.Group(c.PoolGKNN.Group) && infModelRewrite.Spec.PoolRef.Group != "inference.networking.x-k8s.io") { + infModelRewrite.Spec.PoolRef.Group != v1alpha2.Group(c.PoolGKNN.Group) + + if notFound || isDeleted || isPooRefUnmatch { // InferenceModelRewrite object got deleted or changed the referenced pool. 
- c.Datastore.RewriteDelete(req.NamespacedName) + c.Datastore.ModelRewriteDelete(req.NamespacedName) return ctrl.Result{}, nil } // Add or update if the InferenceModelRewrite instance has a creation timestamp older than the existing entry of the model. logger = logger.WithValues("poolRef", infModelRewrite.Spec.PoolRef) - c.Datastore.RewriteSet(infModelRewrite) + c.Datastore.ModelRewriteSet(infModelRewrite) logger.Info("Added/Updated InferenceModelRewrite") return ctrl.Result{}, nil diff --git a/pkg/epp/controller/inferencemodelrewrite_reconciler_test.go b/pkg/epp/controller/inferencemodelrewrite_reconciler_test.go index 346d0a9e8..d349f46df 100644 --- a/pkg/epp/controller/inferencemodelrewrite_reconciler_test.go +++ b/pkg/epp/controller/inferencemodelrewrite_reconciler_test.go @@ -185,7 +185,7 @@ func TestInferenceModelRewriteReconciler(t *testing.T) { pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second) ds := datastore.NewDatastore(t.Context(), pmf, 0) for _, r := range test.rewritesInStore { - ds.RewriteSet(r) + ds.ModelRewriteSet(r) } endpointPool := poolutil.InferencePoolToEndpointPool(poolForRewrite) _ = ds.PoolSet(context.Background(), fakeClient, endpointPool) @@ -210,8 +210,8 @@ func TestInferenceModelRewriteReconciler(t *testing.T) { t.Errorf("Unexpected result diff (+got/-want): %s", diff) } - if len(test.wantRewrites) != len(ds.RewriteGetAll()) { - t.Errorf("Unexpected number of rewrites; want: %d, got:%d", len(test.wantRewrites), len(ds.RewriteGetAll())) + if len(test.wantRewrites) != len(ds.ModelRewriteGetAll()) { + t.Errorf("Unexpected number of rewrites; want: %d, got:%d", len(test.wantRewrites), len(ds.ModelRewriteGetAll())) } if diff := diffStoreRewrites(ds, test.wantRewrites); diff != "" { @@ -226,7 +226,7 @@ func diffStoreRewrites(ds datastore.Datastore, wantRewrites []*v1alpha2.Inferenc wantRewrites = []*v1alpha2.InferenceModelRewrite{} } - gotRewrites := ds.RewriteGetAll() + gotRewrites := 
ds.ModelRewriteGetAll() if diff := cmp.Diff(wantRewrites, gotRewrites); diff != "" { return "rewrites:" + diff } diff --git a/pkg/epp/datastore/datastore.go b/pkg/epp/datastore/datastore.go index 6b3f4e3f1..d550ad27b 100644 --- a/pkg/epp/datastore/datastore.go +++ b/pkg/epp/datastore/datastore.go @@ -60,10 +60,10 @@ type Datastore interface { ObjectiveGetAll() []*v1alpha2.InferenceObjective // InferenceModelRewrite operations - RewriteSet(infModelRewrite *v1alpha2.InferenceModelRewrite) - RewriteDelete(namespacedName types.NamespacedName) - RewriteGet(modelName string) *v1alpha2.InferenceModelRewriteRule - RewriteGetAll() []*v1alpha2.InferenceModelRewrite + ModelRewriteSet(infModelRewrite *v1alpha2.InferenceModelRewrite) + ModelRewriteDelete(namespacedName types.NamespacedName) + ModelRewriteGet(modelName string) *v1alpha2.InferenceModelRewriteRule + ModelRewriteGetAll() []*v1alpha2.InferenceModelRewrite // PodList lists pods matching the given predicate. PodList(predicate func(backendmetrics.PodMetrics) bool) []backendmetrics.PodMetrics @@ -81,7 +81,7 @@ func NewDatastore(parentCtx context.Context, epFactory datalayer.EndpointFactory pool: nil, mu: sync.RWMutex{}, objectives: make(map[string]*v1alpha2.InferenceObjective), - rewrites: newModelRewriteStore(), + modelRewrites: newModelRewriteStore(), pods: &sync.Map{}, modelServerMetricsPort: modelServerMetricsPort, epf: epFactory, @@ -103,8 +103,8 @@ type datastore struct { pool *datalayer.EndpointPool // key: InferenceObjective name, value: *InferenceObjective objectives map[string]*v1alpha2.InferenceObjective - // rewrites store for InferenceModelRewrite objects. - rewrites *modelRewriteStore + // modelRewrites store for InferenceModelRewrite objects. 
+ modelRewrites *modelRewriteStore // key: types.NamespacedName, value: backendmetrics.PodMetrics pods *sync.Map // modelServerMetricsPort metrics port from EPP command line argument @@ -118,7 +118,7 @@ func (ds *datastore) Clear() { defer ds.mu.Unlock() ds.pool = nil ds.objectives = make(map[string]*v1alpha2.InferenceObjective) - ds.rewrites = newModelRewriteStore() + ds.modelRewrites = newModelRewriteStore() // stop all pods go routines before clearing the pods map. ds.pods.Range(func(_, v any) bool { ds.epf.ReleaseEndpoint(v.(backendmetrics.PodMetrics)) @@ -210,28 +210,28 @@ func (ds *datastore) ObjectiveGetAll() []*v1alpha2.InferenceObjective { return res } -func (ds *datastore) RewriteSet(infModelRewrite *v1alpha2.InferenceModelRewrite) { +func (ds *datastore) ModelRewriteSet(infModelRewrite *v1alpha2.InferenceModelRewrite) { ds.mu.Lock() defer ds.mu.Unlock() - ds.rewrites.set(infModelRewrite) + ds.modelRewrites.set(infModelRewrite) } -func (ds *datastore) RewriteDelete(namespacedName types.NamespacedName) { +func (ds *datastore) ModelRewriteDelete(namespacedName types.NamespacedName) { ds.mu.Lock() defer ds.mu.Unlock() - ds.rewrites.delete(namespacedName) + ds.modelRewrites.delete(namespacedName) } -func (ds *datastore) RewriteGet(modelName string) *v1alpha2.InferenceModelRewriteRule { +func (ds *datastore) ModelRewriteGet(modelName string) *v1alpha2.InferenceModelRewriteRule { ds.mu.RLock() defer ds.mu.RUnlock() - return ds.rewrites.getRule(modelName) + return ds.modelRewrites.getRule(modelName) } -func (ds *datastore) RewriteGetAll() []*v1alpha2.InferenceModelRewrite { +func (ds *datastore) ModelRewriteGetAll() []*v1alpha2.InferenceModelRewrite { ds.mu.RLock() defer ds.mu.RUnlock() - return ds.rewrites.getAll() + return ds.modelRewrites.getAll() } // /// Pods/endpoints APIs /// diff --git a/pkg/epp/requestcontrol/director.go b/pkg/epp/requestcontrol/director.go index 3d13c0f7e..dc910d779 100644 --- a/pkg/epp/requestcontrol/director.go +++ 
b/pkg/epp/requestcontrol/director.go @@ -50,7 +50,7 @@ type Datastore interface { PoolGet() (*datalayer.EndpointPool, error) ObjectiveGet(objectiveName string) *v1alpha2.InferenceObjective PodList(predicate func(backendmetrics.PodMetrics) bool) []backendmetrics.PodMetrics - RewriteGet(modelName string) *v1alpha2.InferenceModelRewriteRule + ModelRewriteGet(modelName string) *v1alpha2.InferenceModelRewriteRule } // Scheduler defines the interface required by the Director for scheduling. @@ -194,7 +194,7 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo } func (d *Director) applyWeightedModelRewrite(reqCtx *handlers.RequestContext) { - rewriteRule := d.datastore.RewriteGet(reqCtx.IncomingModelName) + rewriteRule := d.datastore.ModelRewriteGet(reqCtx.IncomingModelName) if rewriteRule == nil { return } diff --git a/pkg/epp/requestcontrol/director_test.go b/pkg/epp/requestcontrol/director_test.go index 9fde14bc8..5236adc52 100644 --- a/pkg/epp/requestcontrol/director_test.go +++ b/pkg/epp/requestcontrol/director_test.go @@ -169,7 +169,7 @@ func (m mockProducedDataType) Clone() datalayer.Cloneable { return mockProducedDataType{value: m.value} } -func (ds *mockDatastore) RewriteGet(modelName string) *v1alpha2.InferenceModelRewriteRule { +func (ds *mockDatastore) ModelRewriteGet(modelName string) *v1alpha2.InferenceModelRewriteRule { // This mock implementation simulates the precedence logic for simplicity. // It finds the oldest rewrite that has a rule matching the modelName. 
var matchingRewrites []*v1alpha2.InferenceModelRewrite @@ -268,7 +268,7 @@ func TestDirector_HandleRequest(t *testing.T) { ds.ObjectiveSet(ioFoodReview) ds.ObjectiveSet(ioFoodReviewResolve) ds.ObjectiveSet(ioFoodReviewSheddable) - ds.RewriteSet(rewrite) + ds.ModelRewriteSet(rewrite) scheme := runtime.NewScheme() _ = clientgoscheme.AddToScheme(scheme) diff --git a/test/integration/epp/hermetic_test.go b/test/integration/epp/hermetic_test.go index f1f32d58b..59bf12cd4 100644 --- a/test/integration/epp/hermetic_test.go +++ b/test/integration/epp/hermetic_test.go @@ -90,6 +90,8 @@ const ( // Model Names modelMyModel = "my-model" modelMyModelTarget = "my-model-12345" + modelToBeWritten = "model-to-be-rewritten" + modelAfterRewrite = "rewritten-model" modelSQLLora = "sql-lora" modelSQLLoraTarget = "sql-lora-1fdg2" modelSheddable = "sql-lora-sheddable" @@ -981,6 +983,42 @@ func TestFullDuplexStreamed_KubeInferenceObjectiveRequest(t *testing.T) { }, }, }, + { + name: "rewrite request model", + requests: integrationutils.GenerateStreamedRequestSet(logger, "test-rewrite", modelToBeWritten, modelToBeWritten, nil), + // Pod 0 will be picked. + // Expected flow: + // 1. Request asks for "model-to-be-rewritten" + // 2. Rewrite rule transforms "model-to-be-rewritten" -> "rewritten-model" + // 3. EPP sends request to backend with model "rewritten-model" + pods: newPodStates( + podState{index: 0, queueSize: 0, kvCacheUsage: 0.1, activeModels: []string{"foo", "rewritten-model"}}, + ), + wantMetrics: map[string]string{ + "inference_objective_request_total": inferenceObjectiveRequestTotal([]label{ + {"model_name", modelToBeWritten}, + {"target_model_name", modelAfterRewrite}, + }), + }, + wantErr: false, + wantResponses: integrationutils.NewRequestBufferedResponse( + "192.168.1.1:8000", + // Note: The prompt remains "test-rewrite", but the model in the JSON body is updated to the *rewritten target* model. 
+ fmt.Sprintf(`{"max_tokens":100,"model":%q,"prompt":"test-rewrite","temperature":0}`, modelAfterRewrite), + &configPb.HeaderValueOption{ + Header: &configPb.HeaderValue{ + Key: "hi", + RawValue: []byte("mom"), + }, + }, + &configPb.HeaderValueOption{ + Header: &configPb.HeaderValue{ + Key: requtil.RequestIdHeaderKey, + RawValue: []byte("test-request-id"), + }, + }, + ), + }, } for _, test := range tests { @@ -1247,6 +1285,7 @@ func BeforeSuite() func() { _ = testEnv.Stop() _ = k8sClient.DeleteAllOf(context.Background(), &v1.InferencePool{}) _ = k8sClient.DeleteAllOf(context.Background(), &v1alpha2.InferenceObjective{}) + _ = k8sClient.DeleteAllOf(context.Background(), &v1alpha2.InferenceModelRewrite{}) } } @@ -1299,6 +1338,11 @@ func managerTestOptions(namespace, name string, metricsServerOptions metricsserv namespace: {}, }, }, + &v1alpha2.InferenceModelRewrite{}: { + Namespaces: map[string]cache.Config{ + namespace: {}, + }, + }, }, }, Controller: crconfig.Controller{ diff --git a/test/testdata/inferencepool-with-model-hermetic.yaml b/test/testdata/inferencepool-with-model-hermetic.yaml index df6ae30db..af5a82a10 100644 --- a/test/testdata/inferencepool-with-model-hermetic.yaml +++ b/test/testdata/inferencepool-with-model-hermetic.yaml @@ -62,3 +62,18 @@ spec: priority: 2 poolRef: name: vllm-llama3-8b-instruct-pool +--- +apiVersion: inference.networking.x-k8s.io/v1alpha2 +kind: InferenceModelRewrite +metadata: + name: rewrite-test + namespace: default +spec: + poolRef: + name: vllm-llama3-8b-instruct-pool + rules: + - matches: + - model: + value: model-to-be-rewritten + targets: + - modelRewrite: rewritten-model