|
16 | 16 | import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankServiceSettings; |
17 | 17 | import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings; |
18 | 18 |
|
| 19 | +import java.util.List; |
19 | 20 | import java.util.Map; |
20 | 21 | import java.util.Set; |
| 22 | +import java.util.function.Function; |
21 | 23 |
|
22 | 24 | import static java.util.stream.Collectors.toMap; |
23 | 25 |
|
@@ -63,65 +65,80 @@ public record MinimalModel( |
63 | 65 | private static final ElasticInferenceServiceRerankServiceSettings RERANK_SERVICE_SETTINGS = |
64 | 66 | new ElasticInferenceServiceRerankServiceSettings(DEFAULT_RERANK_MODEL_ID_V1); |
65 | 67 |
|
66 | | - private static final Map<String, MinimalModel> MODEL_NAME_TO_MINIMAL_MODEL = Map.of( |
| 68 | + // A single model name can map to multiple inference endpoints, so each model name maps to a list of MinimalModels |
| 69 | + private static final Map<String, List<MinimalModel>> MODEL_NAME_TO_MINIMAL_MODELS = Map.of( |
67 | 70 | DEFAULT_CHAT_COMPLETION_MODEL_ID_V1, |
68 | | - new MinimalModel( |
69 | | - new ModelConfigurations( |
70 | | - DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1, |
71 | | - TaskType.CHAT_COMPLETION, |
72 | | - ElasticInferenceService.NAME, |
73 | | - COMPLETION_SERVICE_SETTINGS, |
74 | | - ChunkingSettingsBuilder.DEFAULT_SETTINGS |
75 | | - ), |
76 | | - COMPLETION_SERVICE_SETTINGS |
| 71 | + List.of( |
| 72 | + new MinimalModel( |
| 73 | + new ModelConfigurations( |
| 74 | + DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1, |
| 75 | + TaskType.CHAT_COMPLETION, |
| 76 | + ElasticInferenceService.NAME, |
| 77 | + COMPLETION_SERVICE_SETTINGS, |
| 78 | + ChunkingSettingsBuilder.DEFAULT_SETTINGS |
| 79 | + ), |
| 80 | + COMPLETION_SERVICE_SETTINGS |
| 81 | + ) |
77 | 82 | ), |
78 | 83 | DEFAULT_ELSER_2_MODEL_ID, |
79 | | - new MinimalModel( |
80 | | - new ModelConfigurations( |
81 | | - DEFAULT_ELSER_ENDPOINT_ID_V2, |
82 | | - TaskType.SPARSE_EMBEDDING, |
83 | | - ElasticInferenceService.NAME, |
84 | | - SPARSE_EMBEDDINGS_SERVICE_SETTINGS, |
85 | | - ChunkingSettingsBuilder.DEFAULT_SETTINGS |
86 | | - ), |
87 | | - SPARSE_EMBEDDINGS_SERVICE_SETTINGS |
| 84 | + List.of( |
| 85 | + new MinimalModel( |
| 86 | + new ModelConfigurations( |
| 87 | + DEFAULT_ELSER_ENDPOINT_ID_V2, |
| 88 | + TaskType.SPARSE_EMBEDDING, |
| 89 | + ElasticInferenceService.NAME, |
| 90 | + SPARSE_EMBEDDINGS_SERVICE_SETTINGS, |
| 91 | + ChunkingSettingsBuilder.DEFAULT_SETTINGS |
| 92 | + ), |
| 93 | + SPARSE_EMBEDDINGS_SERVICE_SETTINGS |
| 94 | + ) |
88 | 95 | ), |
89 | 96 | DEFAULT_MULTILINGUAL_EMBED_MODEL_ID, |
90 | | - new MinimalModel( |
91 | | - new ModelConfigurations( |
92 | | - DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID, |
93 | | - TaskType.TEXT_EMBEDDING, |
94 | | - ElasticInferenceService.NAME, |
95 | | - DENSE_TEXT_EMBEDDINGS_SERVICE_SETTINGS, |
96 | | - ChunkingSettingsBuilder.DEFAULT_SETTINGS |
97 | | - ), |
98 | | - DENSE_TEXT_EMBEDDINGS_SERVICE_SETTINGS |
| 97 | + List.of( |
| 98 | + new MinimalModel( |
| 99 | + new ModelConfigurations( |
| 100 | + DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID, |
| 101 | + TaskType.TEXT_EMBEDDING, |
| 102 | + ElasticInferenceService.NAME, |
| 103 | + DENSE_TEXT_EMBEDDINGS_SERVICE_SETTINGS, |
| 104 | + ChunkingSettingsBuilder.DEFAULT_SETTINGS |
| 105 | + ), |
| 106 | + DENSE_TEXT_EMBEDDINGS_SERVICE_SETTINGS |
| 107 | + ) |
99 | 108 | ), |
100 | 109 | DEFAULT_RERANK_MODEL_ID_V1, |
101 | | - new MinimalModel( |
102 | | - new ModelConfigurations( |
103 | | - DEFAULT_RERANK_ENDPOINT_ID_V1, |
104 | | - TaskType.RERANK, |
105 | | - ElasticInferenceService.NAME, |
106 | | - RERANK_SERVICE_SETTINGS, |
107 | | - ChunkingSettingsBuilder.DEFAULT_SETTINGS |
108 | | - ), |
109 | | - RERANK_SERVICE_SETTINGS |
| 110 | + List.of( |
| 111 | + new MinimalModel( |
| 112 | + new ModelConfigurations( |
| 113 | + DEFAULT_RERANK_ENDPOINT_ID_V1, |
| 114 | + TaskType.RERANK, |
| 115 | + ElasticInferenceService.NAME, |
| 116 | + RERANK_SERVICE_SETTINGS, |
| 117 | + ChunkingSettingsBuilder.DEFAULT_SETTINGS |
| 118 | + ), |
| 119 | + RERANK_SERVICE_SETTINGS |
| 120 | + ) |
110 | 121 | ) |
111 | 122 | ); |
112 | 123 |
|
113 | | - private static final Map<String, MinimalModel> INFERENCE_ID_TO_MINIMAL_MODEL = MODEL_NAME_TO_MINIMAL_MODEL.entrySet() |
| 124 | + private static final Map<String, MinimalModel> INFERENCE_ID_TO_MINIMAL_MODEL = MODEL_NAME_TO_MINIMAL_MODELS.entrySet() |
114 | 125 | .stream() |
115 | | - .collect(toMap(e -> e.getValue().configurations().getInferenceEntityId(), Map.Entry::getValue)); |
| 126 | + .flatMap(entry -> entry.getValue().stream()) |
| 127 | + .collect(toMap(m -> m.configurations().getInferenceEntityId(), Function.identity())); |
116 | 128 |
|
117 | 129 | public static final Set<String> EIS_PRECONFIGURED_ENDPOINT_IDS = Set.copyOf(INFERENCE_ID_TO_MINIMAL_MODEL.keySet()); |
118 | 130 |
|
119 | 131 | public static SimilarityMeasure defaultDenseTextEmbeddingsSimilarity() { |
120 | 132 | return SimilarityMeasure.COSINE; |
121 | 133 | } |
122 | 134 |
|
123 | | - public static MinimalModel getWithModelName(String modelName) { |
124 | | - return MODEL_NAME_TO_MINIMAL_MODEL.get(modelName); |
| 135 | + public static List<MinimalModel> getWithModelName(String modelName) { |
| 136 | + var minimalModels = MODEL_NAME_TO_MINIMAL_MODELS.get(modelName); |
| 137 | + if (minimalModels == null) { |
| 138 | + return List.of(); |
| 139 | + } |
| 140 | + |
| 141 | + return minimalModels; |
125 | 142 | } |
126 | 143 |
|
127 | 144 | public static MinimalModel getWithInferenceId(String inferenceId) { |
|
0 commit comments