Skip to content

Commit 06259e9

Browse files
[ML] Allow a single model name to map to multiple inference endpoints (#138059)
* Switching to list so we can support completion and chat completion * Adding test for empty list
1 parent 10f73dd commit 06259e9

File tree

3 files changed

+83
-42
lines changed

3 files changed

+83
-42
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/InternalPreconfiguredEndpoints.java

Lines changed: 58 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@
1616
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankServiceSettings;
1717
import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings;
1818

19+
import java.util.List;
1920
import java.util.Map;
2021
import java.util.Set;
22+
import java.util.function.Function;
2123

2224
import static java.util.stream.Collectors.toMap;
2325

@@ -63,65 +65,80 @@ public record MinimalModel(
6365
private static final ElasticInferenceServiceRerankServiceSettings RERANK_SERVICE_SETTINGS =
6466
new ElasticInferenceServiceRerankServiceSettings(DEFAULT_RERANK_MODEL_ID_V1);
6567

66-
private static final Map<String, MinimalModel> MODEL_NAME_TO_MINIMAL_MODEL = Map.of(
68+
// A single model name can map to multiple inference endpoints, so we need a String to a List
69+
private static final Map<String, List<MinimalModel>> MODEL_NAME_TO_MINIMAL_MODELS = Map.of(
6770
DEFAULT_CHAT_COMPLETION_MODEL_ID_V1,
68-
new MinimalModel(
69-
new ModelConfigurations(
70-
DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1,
71-
TaskType.CHAT_COMPLETION,
72-
ElasticInferenceService.NAME,
73-
COMPLETION_SERVICE_SETTINGS,
74-
ChunkingSettingsBuilder.DEFAULT_SETTINGS
75-
),
76-
COMPLETION_SERVICE_SETTINGS
71+
List.of(
72+
new MinimalModel(
73+
new ModelConfigurations(
74+
DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1,
75+
TaskType.CHAT_COMPLETION,
76+
ElasticInferenceService.NAME,
77+
COMPLETION_SERVICE_SETTINGS,
78+
ChunkingSettingsBuilder.DEFAULT_SETTINGS
79+
),
80+
COMPLETION_SERVICE_SETTINGS
81+
)
7782
),
7883
DEFAULT_ELSER_2_MODEL_ID,
79-
new MinimalModel(
80-
new ModelConfigurations(
81-
DEFAULT_ELSER_ENDPOINT_ID_V2,
82-
TaskType.SPARSE_EMBEDDING,
83-
ElasticInferenceService.NAME,
84-
SPARSE_EMBEDDINGS_SERVICE_SETTINGS,
85-
ChunkingSettingsBuilder.DEFAULT_SETTINGS
86-
),
87-
SPARSE_EMBEDDINGS_SERVICE_SETTINGS
84+
List.of(
85+
new MinimalModel(
86+
new ModelConfigurations(
87+
DEFAULT_ELSER_ENDPOINT_ID_V2,
88+
TaskType.SPARSE_EMBEDDING,
89+
ElasticInferenceService.NAME,
90+
SPARSE_EMBEDDINGS_SERVICE_SETTINGS,
91+
ChunkingSettingsBuilder.DEFAULT_SETTINGS
92+
),
93+
SPARSE_EMBEDDINGS_SERVICE_SETTINGS
94+
)
8895
),
8996
DEFAULT_MULTILINGUAL_EMBED_MODEL_ID,
90-
new MinimalModel(
91-
new ModelConfigurations(
92-
DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID,
93-
TaskType.TEXT_EMBEDDING,
94-
ElasticInferenceService.NAME,
95-
DENSE_TEXT_EMBEDDINGS_SERVICE_SETTINGS,
96-
ChunkingSettingsBuilder.DEFAULT_SETTINGS
97-
),
98-
DENSE_TEXT_EMBEDDINGS_SERVICE_SETTINGS
97+
List.of(
98+
new MinimalModel(
99+
new ModelConfigurations(
100+
DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID,
101+
TaskType.TEXT_EMBEDDING,
102+
ElasticInferenceService.NAME,
103+
DENSE_TEXT_EMBEDDINGS_SERVICE_SETTINGS,
104+
ChunkingSettingsBuilder.DEFAULT_SETTINGS
105+
),
106+
DENSE_TEXT_EMBEDDINGS_SERVICE_SETTINGS
107+
)
99108
),
100109
DEFAULT_RERANK_MODEL_ID_V1,
101-
new MinimalModel(
102-
new ModelConfigurations(
103-
DEFAULT_RERANK_ENDPOINT_ID_V1,
104-
TaskType.RERANK,
105-
ElasticInferenceService.NAME,
106-
RERANK_SERVICE_SETTINGS,
107-
ChunkingSettingsBuilder.DEFAULT_SETTINGS
108-
),
109-
RERANK_SERVICE_SETTINGS
110+
List.of(
111+
new MinimalModel(
112+
new ModelConfigurations(
113+
DEFAULT_RERANK_ENDPOINT_ID_V1,
114+
TaskType.RERANK,
115+
ElasticInferenceService.NAME,
116+
RERANK_SERVICE_SETTINGS,
117+
ChunkingSettingsBuilder.DEFAULT_SETTINGS
118+
),
119+
RERANK_SERVICE_SETTINGS
120+
)
110121
)
111122
);
112123

113-
private static final Map<String, MinimalModel> INFERENCE_ID_TO_MINIMAL_MODEL = MODEL_NAME_TO_MINIMAL_MODEL.entrySet()
124+
private static final Map<String, MinimalModel> INFERENCE_ID_TO_MINIMAL_MODEL = MODEL_NAME_TO_MINIMAL_MODELS.entrySet()
114125
.stream()
115-
.collect(toMap(e -> e.getValue().configurations().getInferenceEntityId(), Map.Entry::getValue));
126+
.flatMap(entry -> entry.getValue().stream())
127+
.collect(toMap(m -> m.configurations().getInferenceEntityId(), Function.identity()));
116128

117129
public static final Set<String> EIS_PRECONFIGURED_ENDPOINT_IDS = Set.copyOf(INFERENCE_ID_TO_MINIMAL_MODEL.keySet());
118130

119131
public static SimilarityMeasure defaultDenseTextEmbeddingsSimilarity() {
120132
return SimilarityMeasure.COSINE;
121133
}
122134

123-
public static MinimalModel getWithModelName(String modelName) {
124-
return MODEL_NAME_TO_MINIMAL_MODEL.get(modelName);
135+
public static List<MinimalModel> getWithModelName(String modelName) {
136+
var minimalModels = MODEL_NAME_TO_MINIMAL_MODELS.get(modelName);
137+
if (minimalModels == null) {
138+
return List.of();
139+
}
140+
141+
return minimalModels;
125142
}
126143

127144
public static MinimalModel getWithInferenceId(String inferenceId) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints;
3232

3333
import java.util.EnumSet;
34+
import java.util.List;
3435
import java.util.Map;
3536
import java.util.Objects;
3637
import java.util.Set;
@@ -248,7 +249,7 @@ private Set<String> getNewInferenceEndpointsToStore(ElasticInferenceServiceAutho
248249

249250
var newInferenceIds = authorizedModelIds.stream()
250251
.map(InternalPreconfiguredEndpoints::getWithModelName)
251-
.filter(Objects::nonNull)
252+
.flatMap(List::stream)
252253
.map(model -> model.configurations().getInferenceEntityId())
253254
.collect(Collectors.toSet());
254255

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.elastic;
9+
10+
import org.elasticsearch.test.ESTestCase;
11+
12+
import static org.hamcrest.Matchers.hasSize;
13+
14+
public class InternalPreconfiguredEndpointsTests extends ESTestCase {
15+
public void testGetWithModelName_ReturnsAnEmptyList_IfNameDoesNotExist() {
16+
assertThat(InternalPreconfiguredEndpoints.getWithModelName("non-existent-model"), hasSize(0));
17+
}
18+
19+
public void testGetWithModelName_ReturnsChatCompletionModels() {
20+
var models = InternalPreconfiguredEndpoints.getWithModelName(InternalPreconfiguredEndpoints.DEFAULT_CHAT_COMPLETION_MODEL_ID_V1);
21+
assertThat(models, hasSize(1));
22+
}
23+
}

0 commit comments

Comments
 (0)