Skip to content

Commit 3bde173

Browse files
committed
Add retry prompt for llm alignscore failure
1 parent 6306091 commit 3bde173

File tree

4 files changed

+30
-12
lines changed

4 files changed

+30
-12
lines changed

core_backend/app/llm_call/llm_prompts.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,11 @@ def get_prompt(cls) -> str:
177177
You are a helpful question-answering AI. You understand user question and answer their \
178178
question using the REFERENCE TEXT below.
179179
"""
180+
RETRY_PROMPT_SUFFIX = """
181+
If the response above is not aligned with the question, please rectify this by considering \
182+
the following reason(s) for misalignment: "{failure_reason}". Make necessary adjustments \
183+
to ensure the answer is aligned with the question.
184+
"""
180185
RAG_RESPONSE_PROMPT = (
181186
_RAG_PROFILE_PROMPT
182187
+ """
@@ -224,6 +229,7 @@ class RAG(BaseModel):
224229
answer: str
225230

226231
prompt: ClassVar[str] = RAG_RESPONSE_PROMPT
232+
retry_prompt: ClassVar[str] = RAG_RESPONSE_PROMPT + RETRY_PROMPT_SUFFIX
227233

228234

229235
class AlignmentScore(BaseModel):

core_backend/app/llm_call/llm_rag.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,14 @@ async def get_llm_rag_answer(
3737
"""
3838

3939
metadata = metadata or {}
40-
prompt = RAG.prompt.format(context=context, original_language=original_language)
40+
if "failure_reason" in metadata and metadata["failure_reason"]:
41+
prompt = RAG.retry_prompt.format(
42+
context=context,
43+
original_language=original_language,
44+
failure_reason=metadata["failure_reason"],
45+
)
46+
else:
47+
prompt = RAG.prompt.format(context=context, original_language=original_language)
4148

4249
result = await _ask_llm_async(
4350
user_message=question,

core_backend/app/llm_call/process_output.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ async def generate_llm_query_response(
5656
Only runs if the generate_llm_response flag is set to True.
5757
Requires "search_results" and "original_language" in the response.
5858
"""
59-
if isinstance(response, QueryResponseError):
59+
if isinstance(response, QueryResponseError) and not metadata["failure_reason"]:
6060
return response
6161

6262
if response.search_results is None:

core_backend/app/question_answer/routers.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,12 @@ async def search(
125125
query_refined=user_query_refined_template,
126126
response=response,
127127
)
128+
if is_unable_to_generate_response(response):
129+
failure_reason = response.debug_info["factual_consistency"]
130+
response = await retry_search(
131+
query_refined=user_query_refined_template, response=response
132+
)
133+
response.debug_info["past_failure"] = failure_reason
128134

129135
await save_query_response_to_db(user_query_db, response, asession)
130136
await increment_query_count(
@@ -230,7 +236,6 @@ async def voice_search(
230236
asession=asession,
231237
exclude_archived=True,
232238
)
233-
234239
if user_query.generate_llm_response:
235240
response = await get_generation_response(
236241
query_refined=user_query_refined_template,
@@ -343,18 +348,15 @@ def is_unable_to_generate_response(response: QueryResponse) -> bool:
343348
async def retry_search(
344349
query_refined: QueryRefined,
345350
response: QueryResponse | QueryResponseError,
346-
user_id: int,
347-
n_similar: int,
348-
asession: AsyncSession,
349-
exclude_archived: bool = True,
350351
) -> QueryResponse | QueryResponseError:
351352
"""
352-
Retry wrapper for search_base.
353+
Retry wrapper for get_generation_response.
353354
"""
354355

355-
return await search_base(
356-
query_refined, response, user_id, n_similar, asession, exclude_archived
357-
)
356+
metadata = query_refined.query_metadata
357+
metadata["failure_reason"] = response.debug_info["factual_consistency"]["reason"]
358+
query_refined.query_metadata = metadata
359+
return await get_generation_response(query_refined, response)
358360

359361

360362
@generate_tts__after
@@ -376,10 +378,13 @@ async def get_generation_response(
376378
query_id=response.query_id, user_id=query_refined.user_id
377379
)
378380

381+
metadata["failure_reason"] = query_refined.query_metadata.get(
382+
"failure_reason", None
383+
)
384+
379385
response = await generate_llm_query_response(
380386
query_refined=query_refined, response=response, metadata=metadata
381387
)
382-
383388
return response
384389

385390

0 commit comments

Comments (0)