11package com .apolloconfig .apollo .ai .qabot .controller ;
22
33import com .apolloconfig .apollo .ai .qabot .api .AiService ;
4+ import com .apolloconfig .apollo .ai .qabot .api .VectorDBService ;
45import com .apolloconfig .apollo .ai .qabot .markdown .MarkdownSearchResult ;
56import com .google .common .base .Strings ;
67import com .google .common .collect .Lists ;
7- import com .apolloconfig . apollo . ai . qabot . api . VectorDBService ;
8+ import com .theokanning . openai . completion . chat . ChatCompletionChunk ;
89import com .theokanning .openai .embedding .Embedding ;
10+ import io .reactivex .Flowable ;
911import java .util .Collections ;
1012import java .util .List ;
1113import java .util .Set ;
14+ import java .util .concurrent .atomic .AtomicInteger ;
1215import java .util .stream .Collectors ;
1316import org .slf4j .Logger ;
1417import org .slf4j .LoggerFactory ;
1518import org .springframework .beans .factory .annotation .Value ;
19+ import org .springframework .http .MediaType ;
20+ import org .springframework .web .bind .annotation .GetMapping ;
1621import org .springframework .web .bind .annotation .PostMapping ;
1722import org .springframework .web .bind .annotation .RequestMapping ;
1823import org .springframework .web .bind .annotation .RequestParam ;
1924import org .springframework .web .bind .annotation .RestController ;
25+ import org .springframework .web .server .ServerWebExchange ;
26+ import reactor .core .publisher .Flux ;
27+ import reactor .core .publisher .Mono ;
2028
2129@ RestController
2230@ RequestMapping ("/qa" )
@@ -38,54 +46,109 @@ public QAController(AiService aiService, VectorDBService vectorDBService) {
3846 this .vectorDBService = vectorDBService ;
3947 }
4048
41- @ PostMapping
42- public Answer qa (@ RequestParam String question ) {
49+ @ GetMapping ( produces = MediaType . TEXT_EVENT_STREAM_VALUE )
50+ public Flux < Answer > qa (@ RequestParam String question ) {
4351 question = question .trim ();
4452 if (Strings .isNullOrEmpty (question )) {
45- return Answer .EMPTY ;
53+ return Flux . just ( Answer .EMPTY ) ;
4654 }
4755
4856 try {
4957 return doQA (question );
5058 } catch (Throwable exception ) {
5159 LOGGER .error ("Error while calling OpenAI API" , exception );
52- return Answer .ERROR ;
60+ return Flux . just ( Answer .ERROR ) ;
5361 }
5462 }
5563
56- private Answer doQA (String question ) {
57- List <Embedding > embeddings = aiService .getEmbeddings (Lists .newArrayList (question ));
64+ /**
65+ * @deprecated Use {@link #qa(String)} instead.
66+ */
67+ @ Deprecated
68+ @ PostMapping
69+ public Mono <Answer > qaSync (ServerWebExchange serverWebExchange ) {
70+ Mono <String > field = getFormField (serverWebExchange , "question" );
71+ return field .flatMap (question -> {
72+ if (Strings .isNullOrEmpty (question )) {
73+ return Mono .just (Answer .EMPTY );
74+ }
75+
76+ try {
77+ Flux <Answer > answer = doQA (question .trim ());
78+ return answer .reduce ((a1 , a2 ) -> {
79+ if (Answer .END .answer ().equals (a2 .answer ())) {
80+ return a1 ;
81+ }
82+ a1 .relatedFiles ().addAll (a2 .relatedFiles );
83+
84+ return new Answer (a1 .answer () + a2 .answer (), a1 .relatedFiles );
85+ });
86+ } catch (Throwable exception ) {
87+ LOGGER .error ("Error while calling OpenAI API" , exception );
88+ return Mono .just (Answer .ERROR );
89+ }
90+ });
91+ }
5892
59- List <List <Float >> searchVectors = Collections .singletonList (
60- embeddings .get (0 ).getEmbedding ().stream ()
61- .map (Double ::floatValue ).collect (Collectors .toList ()));
93+ private Mono <String > getFormField (ServerWebExchange exchange , String fieldName ) {
94+ return exchange .getFormData ()
95+ .flatMap (data -> Mono .justOrEmpty (data .getFirst (fieldName )));
96+ }
6297
63- List <MarkdownSearchResult > searchResults = vectorDBService .search (searchVectors , topK );
98+ private Flux <Answer > doQA (String question ) {
99+ List <MarkdownSearchResult > searchResults = searchFromVectorDB (question );
64100
65101 if (searchResults .isEmpty ()) {
66- return Answer .UNKNOWN ;
102+ return Flux . just ( Answer .UNKNOWN ) ;
67103 }
68104
69105 Set <String > relatedFiles = searchResults .stream ()
70106 .map (MarkdownSearchResult ::getFileRoot ).collect (Collectors .toSet ());
71107
72- StringBuilder sb = new StringBuilder ();
73- searchResults .forEach (
74- markdownSearchResult -> sb .append (markdownSearchResult .getContent ()).append ("\n " ));
75-
76- String promptMessage = prompt .replace ("{question}" , question )
77- .replace ("{context}" , sb .toString ());
108+ String promptMessage = assemblePromptMessage (searchResults , question );
78109
79- String answer = aiService .getCompletion (promptMessage );
110+ Flowable < ChatCompletionChunk > result = aiService .getCompletion (promptMessage );
80111
81112 if (LOGGER .isDebugEnabled ()) {
82- LOGGER .debug ("\n Prompt message: {}\n Answer: {} " , promptMessage , answer );
113+ LOGGER .debug ("\n Prompt message: {}" , promptMessage );
83114 }
84115
85- return new Answer (answer , relatedFiles );
116+ final AtomicInteger counter = new AtomicInteger ();
117+ Flux <Answer > flux = Flux .from (result .filter (
118+ chatCompletionChunk -> chatCompletionChunk .getChoices ().get (0 ).getMessage ().getContent ()
119+ != null ).map (chatCompletionChunk -> {
120+ String value = chatCompletionChunk .getChoices ().get (0 ).getMessage ().getContent ();
121+ if (LOGGER .isDebugEnabled ()) {
122+ System .out .print (value );
123+ }
124+
125+ return counter .incrementAndGet () == 1 ? new Answer (value , relatedFiles )
126+ : new Answer (value , Collections .emptySet ());
127+ }));
128+
129+ return flux .concatWith (Flux .just (Answer .END ));
130+ }
131+
132+ private List <MarkdownSearchResult > searchFromVectorDB (String question ) {
133+ List <Embedding > embeddings = aiService .getEmbeddings (Lists .newArrayList (question ));
134+
135+ List <List <Float >> searchVectors = Collections .singletonList (
136+ embeddings .get (0 ).getEmbedding ().stream ()
137+ .map (Double ::floatValue ).collect (Collectors .toList ()));
138+
139+ return vectorDBService .search (searchVectors , topK );
86140 }
87141
88- static class Answer {
142+ private String assemblePromptMessage (List <MarkdownSearchResult > searchResults , String question ) {
143+ StringBuilder sb = new StringBuilder ();
144+ searchResults .forEach (
145+ markdownSearchResult -> sb .append (markdownSearchResult .getContent ()).append ("\n " ));
146+
147+ return prompt .replace ("{question}" , question )
148+ .replace ("{context}" , sb .toString ());
149+ }
150+
151+ public record Answer (String answer , Set <String > relatedFiles ) {
89152
90153 static final Answer EMPTY = new Answer ("" , Collections .emptySet ());
91154 static final Answer UNKNOWN = new Answer ("Sorry, I don't know the answer." ,
@@ -95,20 +158,6 @@ static class Answer {
95158 "Sorry, I can't answer your question right now. Please try again later." ,
96159 Collections .emptySet ());
97160
98- private final String answer ;
99- private final Set <String > relatedFiles ;
100-
101- public Answer (String answer , Set <String > relatedFiles ) {
102- this .answer = answer ;
103- this .relatedFiles = relatedFiles ;
104- }
105-
106- public String getAnswer () {
107- return answer ;
108- }
109-
110- public Set <String > getRelatedFiles () {
111- return relatedFiles ;
112- }
161+ static final Answer END = new Answer ("$END$" , Collections .emptySet ());
113162 }
114163}
0 commit comments