Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions willa/chatbot/chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
LANGFUSE_HANDLER = CallbackHandler()
"""The Langfuse callback handler."""


class Chatbot: # pylint: disable=R0903
"""An instance of a Willa chatbot.

Expand Down Expand Up @@ -89,6 +90,7 @@ def ask(self, question: str) -> dict[str, str]:

if ai_message:
answers["ai_message"] = str(ai_message[-1].content)
answers["langfuse_trace_id"] = str(LANGFUSE_HANDLER.last_trace_id)

if len(answers) == 0:
return {"no_result": "I'm sorry, I couldn't generate a response."}
Expand Down
49 changes: 46 additions & 3 deletions willa/web/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@

import logging
import os
from typing import Optional

import chainlit as cl
from chainlit.data import get_data_layer
from chainlit.data.chainlit_data_layer import ChainlitDataLayer
from chainlit.types import ThreadDict, CommandDict
from chainlit.types import ThreadDict, CommandDict, Feedback
from chainlit.step import StepDict

from willa.chatbot import Chatbot
from willa.config import CONFIG
from willa.config import CONFIG, get_langfuse_client
from willa.web.cas_provider import CASProvider
from willa.web.inject_custom_auth import add_custom_oauth_provider

Expand All @@ -34,6 +37,22 @@
]


async def get_step(self: ChainlitDataLayer, step_id: str) -> Optional[StepDict]:
"""Get step and related feedback"""
query = """
SELECT s.*,
f.id feedback_id,
f.value feedback_value,
f."comment" feedback_comment
FROM "Step" s LEFT JOIN "Feedback" f ON s.id = f."stepId"
WHERE s.id = $1
"""
result = await self.execute_query(query, {"step_id": step_id})
if not result:
return None
return self._convert_step_row_to_dict(result[0]) # pylint: disable="protected-access"


@cl.on_chat_start
async def ocs() -> None:
"""loaded when new chat is started"""
Expand All @@ -48,6 +67,25 @@ async def on_chat_resume(thread: ThreadDict) -> None:
# pylint: enable="unused-argument"


@cl.on_feedback
async def on_feedback(feedback: Feedback) -> None:
"""Handle feedback."""
step: Optional[StepDict] = await get_data_layer().get_step(feedback.forId)
if step is None:
LOGGER.warning("Feedback left for unknown step %s", feedback.forId)
return

trace_id: Optional[str] = step['metadata'].get('langfuse_trace_id')
get_langfuse_client().create_score(
name='feedback',
value=float(feedback.value),
session_id=step['threadId'] if not trace_id else None,
trace_id=trace_id,
data_type='BOOLEAN',
comment=feedback.comment
)
Comment on lines +79 to +86
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Based on 29816ad



@cl.data_layer
def data_layer() -> ChainlitDataLayer:
"""Retrieve the data layer to use with Chainlit.
Expand All @@ -68,7 +106,11 @@ def _secret() -> str:
database_url = os.environ.get(
'DATABASE_URL', f"postgresql://{_pg('USER')}:{_secret()}@{_pg('HOST')}/{_pg('DB')}"
)
return ChainlitDataLayer(database_url=database_url)
dl = ChainlitDataLayer(database_url=database_url)
# pylint: disable="no-value-for-parameter"
dl.get_step = get_step.__get__(dl) # type: ignore[attr-defined]
# pylint: enable="no-value-for-parameter"
return dl


def _get_history() -> str:
Expand Down Expand Up @@ -117,6 +159,7 @@ async def chat(message: cl.Message) -> None:

if 'ai_message' in reply:
await cl.Message(content=reply['ai_message']).send()
cl.context.current_run.metadata['langfuse_trace_id'] = reply['langfuse_trace_id']
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Save the trace ID into the run metadata, which is where Chainlit stores Feedback objects (forId -> the run, not the message).


if 'tind_message' in reply:
tind_refs = cl.CustomElement(
Expand Down