11package com .apolloconfig .apollo .ai .qabot .milvus ;
22
3- import com .google .common .collect .Lists ;
43import com .apolloconfig .apollo .ai .qabot .api .VectorDBService ;
54import com .apolloconfig .apollo .ai .qabot .config .MilvusConfig ;
65import com .apolloconfig .apollo .ai .qabot .markdown .MarkdownSearchResult ;
6+ import com .google .common .collect .Lists ;
77import com .theokanning .openai .embedding .Embedding ;
88import io .milvus .client .MilvusServiceClient ;
99import io .milvus .common .clientenum .ConsistencyLevelEnum ;
3232import java .util .Arrays ;
3333import java .util .Collections ;
3434import java .util .List ;
35+ import java .util .Random ;
3536import java .util .stream .Collectors ;
3637import org .springframework .context .annotation .Profile ;
3738import org .springframework .stereotype .Service ;
@@ -43,6 +44,7 @@ class MilvusService implements VectorDBService {
4344
4445 private final MilvusServiceClient milvusServiceClient ;
4546 private final MilvusConfig milvusConfig ;
47+ private final List <Float > dummyEmbeddings = Lists .newArrayList ();
4648
4749 public MilvusService (MilvusConfig milvusConfig ) {
4850 this .milvusConfig = milvusConfig ;
@@ -160,7 +162,7 @@ private List<Long> queryChunkIdByFileRoot(String fileRoot) {
160162 R <RpcStatus > loadStatus = milvusServiceClient .loadCollection (
161163 loadCollectionParam );
162164
163- List <String > query_output_fields = Arrays . asList ("chunk_id" );
165+ List <String > query_output_fields = List . of ("chunk_id" );
164166 QueryParam queryParam = QueryParam .newBuilder ()
165167 .withCollectionName (milvusConfig .getCollection ())
166168 .withConsistencyLevel (ConsistencyLevelEnum .STRONG )
@@ -169,6 +171,10 @@ private List<Long> queryChunkIdByFileRoot(String fileRoot) {
169171 .build ();
170172 R <QueryResults > respQuery = milvusServiceClient .query (queryParam );
171173
174+ if (respQuery .getStatus () != Status .Success .getCode ()) {
175+ throw new RuntimeException ("Query failed: " + respQuery .getMessage ());
176+ }
177+
172178 QueryResultsWrapper wrapperQuery = new QueryResultsWrapper (respQuery .getData ());
173179 List <?> chunkIds = wrapperQuery .getFieldWrapper ("chunk_id" ).getFieldData ();
174180
@@ -180,8 +186,116 @@ private List<Long> queryChunkIdByFileRoot(String fileRoot) {
180186 .collect (Collectors .toList ());
181187 }
182188
189+ @ Override
190+ public String queryFileHashValue (String fileRoot ) {
191+ LoadCollectionParam loadCollectionParam = LoadCollectionParam .newBuilder ()
192+ .withCollectionName (milvusConfig .getFileCollection ())
193+ .build ();
194+
195+ R <RpcStatus > loadStatus = milvusServiceClient .loadCollection (
196+ loadCollectionParam );
197+
198+ List <String > query_output_fields = List .of ("hash_value" );
199+ QueryParam queryParam = QueryParam .newBuilder ()
200+ .withCollectionName (milvusConfig .getFileCollection ())
201+ .withConsistencyLevel (ConsistencyLevelEnum .STRONG )
202+ .withExpr (String .format ("file_root in ['%s']" , fileRoot ))
203+ .withOutFields (query_output_fields )
204+ .build ();
205+ R <QueryResults > respQuery = milvusServiceClient .query (queryParam );
206+
207+ if (respQuery .getStatus () != Status .Success .getCode ()) {
208+ throw new RuntimeException ("Query failed: " + respQuery .getMessage ());
209+ }
210+
211+ QueryResultsWrapper wrapperQuery = new QueryResultsWrapper (respQuery .getData ());
212+ List <?> hashValues = wrapperQuery .getFieldWrapper ("hash_value" ).getFieldData ();
213+
214+ if (CollectionUtils .isEmpty (hashValues )) {
215+ return null ;
216+ }
217+
218+ return hashValues .get (0 ).toString ();
219+ }
220+
221+ @ Override
222+ public void persistFile (String fileRoot , String hashValue ) {
223+ List <Long > currentFileIds = queryFileIdByFileRoot (fileRoot );
224+
225+ List <Field > fields = new ArrayList <>();
226+ fields .add (new InsertParam .Field ("hash_value" , List .of (hashValue )));
227+ fields .add (new InsertParam .Field ("dummy_embedding" , List .of (dummyEmbeddings )));
228+ fields .add (new InsertParam .Field ("file_root" , List .of (fileRoot )));
229+
230+ InsertParam insertParam = InsertParam .newBuilder ()
231+ .withCollectionName (milvusConfig .getFileCollection ())
232+ .withFields (fields )
233+ .build ();
234+ milvusServiceClient .insert (insertParam );
235+
236+ deleteByFileIdList (currentFileIds );
237+
238+ FlushParam flushParam = FlushParam .newBuilder ()
239+ .withCollectionNames (Lists .newArrayList (milvusConfig .getFileCollection ()))
240+ .build ();
241+ milvusServiceClient .flush (flushParam );
242+ }
243+
244+ private void deleteByFileIdList (List <Long > fileIds ) {
245+ if (!fileIds .isEmpty ()) {
246+ StringBuilder sb = new StringBuilder ();
247+ sb .append ("file_id in [" );
248+ for (int i = 0 ; i < fileIds .size (); i ++) {
249+ sb .append (fileIds .get (i ));
250+ if (i != fileIds .size () - 1 ) {
251+ sb .append ("," );
252+ }
253+ }
254+ sb .append ("]" );
255+ DeleteParam deleteParam = DeleteParam .newBuilder ()
256+ .withCollectionName (milvusConfig .getFileCollection ())
257+ .withExpr (sb .toString ())
258+ .build ();
259+ milvusServiceClient .delete (deleteParam );
260+ }
261+ }
262+
263+ private List <Long > queryFileIdByFileRoot (String fileRoot ) {
264+ LoadCollectionParam loadCollectionParam = LoadCollectionParam .newBuilder ()
265+ .withCollectionName (milvusConfig .getFileCollection ())
266+ .build ();
267+
268+ R <RpcStatus > loadStatus = milvusServiceClient .loadCollection (
269+ loadCollectionParam );
270+
271+ List <String > query_output_fields = List .of ("file_id" );
272+ QueryParam queryParam = QueryParam .newBuilder ()
273+ .withCollectionName (milvusConfig .getFileCollection ())
274+ .withConsistencyLevel (ConsistencyLevelEnum .STRONG )
275+ .withExpr (String .format ("file_root in ['%s']" , fileRoot ))
276+ .withOutFields (query_output_fields )
277+ .build ();
278+ R <QueryResults > respQuery = milvusServiceClient .query (queryParam );
279+
280+ if (respQuery .getStatus () != Status .Success .getCode ()) {
281+ throw new RuntimeException ("Query failed: " + respQuery .getMessage ());
282+ }
283+
284+ QueryResultsWrapper wrapperQuery = new QueryResultsWrapper (respQuery .getData ());
285+ List <?> fileIds = wrapperQuery .getFieldWrapper ("file_id" ).getFieldData ();
286+
287+ if (CollectionUtils .isEmpty (fileIds )) {
288+ return Collections .emptyList ();
289+ }
290+
291+ return fileIds .stream ().map (id -> Long .parseLong (id .toString ()))
292+ .collect (Collectors .toList ());
293+ }
294+
295+
183296 private void ensureCollections () {
184297 ensureChunkCollection ();
298+ ensureFileCollection ();
185299 }
186300
187301 private void ensureChunkCollection () {
@@ -239,5 +353,64 @@ private void ensureChunkCollection() {
239353
240354 }
241355
356+ private void ensureFileCollection () {
357+ // prepare dummy embedding data
358+ Random random = new Random ();
359+ for (int i = 0 ; i < 1536 ; i ++) {
360+ dummyEmbeddings .add (random .nextFloat ());
361+ }
362+
363+ HasCollectionParam hasCollectionParam = HasCollectionParam .newBuilder ()
364+ .withCollectionName (milvusConfig .getFileCollection ())
365+ .build ();
366+
367+ if (milvusServiceClient .hasCollection (hasCollectionParam ).getData ()) {
368+ return ;
369+ }
370+
371+ FieldType fileId = FieldType .newBuilder ()
372+ .withName ("file_id" )
373+ .withDataType (DataType .Int64 )
374+ .withPrimaryKey (true )
375+ .withAutoID (true )
376+ .build ();
377+ FieldType fileRoot = FieldType .newBuilder ()
378+ .withName ("file_root" )
379+ .withDataType (DataType .VarChar )
380+ .withMaxLength (100 )
381+ .build ();
382+ FieldType hashValue = FieldType .newBuilder ()
383+ .withName ("hash_value" )
384+ .withDataType (DataType .VarChar )
385+ .withMaxLength (3000 )
386+ .build ();
387+ // not used, just for compatibility
388+ FieldType dummyEmbedding = FieldType .newBuilder ()
389+ .withName ("dummy_embedding" )
390+ .withDataType (DataType .FloatVector )
391+ .withDimension (1536 )
392+ .build ();
393+ CreateCollectionParam createCollectionReq = CreateCollectionParam .newBuilder ()
394+ .withCollectionName (milvusConfig .getFileCollection ())
395+ .withDescription ("Files for QA Search" )
396+ .addFieldType (fileId )
397+ .addFieldType (hashValue )
398+ .addFieldType (fileRoot )
399+ .addFieldType (dummyEmbedding )
400+ .build ();
401+
402+ milvusServiceClient .createCollection (createCollectionReq );
403+
404+ // not used, just for compatibility
405+ milvusServiceClient .createIndex (
406+ CreateIndexParam .newBuilder ()
407+ .withCollectionName (milvusConfig .getFileCollection ())
408+ .withFieldName ("dummy_embedding" )
409+ .withIndexType (IndexType .FLAT )
410+ .withMetricType (MetricType .L2 )
411+ .withSyncMode (Boolean .FALSE )
412+ .build ()
413+ );
414+ }
242415
243416}
0 commit comments