84 changes: 55 additions & 29 deletions willa/chatbot/graph_manager.py
@@ -1,4 +1,5 @@
 """Manages the shared state and workflow for Willa chatbots."""
+import re
 from typing import Any, Optional, Annotated, NotRequired
 from typing_extensions import TypedDict

@@ -19,10 +20,10 @@ class WillaChatbotState(TypedDict):
     messages: Annotated[list[AnyMessage], add_messages]
     filtered_messages: NotRequired[list[AnyMessage]]
     summarized_messages: NotRequired[list[AnyMessage]]
-    docs_context: NotRequired[str]
     search_query: NotRequired[str]
     tind_metadata: NotRequired[str]
-    context: NotRequired[dict[str, Any]]
+    documents: NotRequired[list[Any]]
+    citations: NotRequired[list[dict[str, Any]]]


 class GraphManager: # pylint: disable=too-few-public-methods
@@ -51,13 +52,15 @@ def _create_workflow(self) -> CompiledStateGraph:
         workflow.add_node("summarize", summarization_node)
         workflow.add_node("prepare_search", self._prepare_search_query)
         workflow.add_node("retrieve_context", self._retrieve_context)
+        workflow.add_node("prepare_for_generation", self._prepare_for_generation)
         workflow.add_node("generate_response", self._generate_response)

         # Define edges
         workflow.add_edge("filter_messages", "summarize")
         workflow.add_edge("summarize", "prepare_search")
         workflow.add_edge("prepare_search", "retrieve_context")
-        workflow.add_edge("retrieve_context", "generate_response")
+        workflow.add_edge("retrieve_context", "prepare_for_generation")
+        workflow.add_edge("prepare_for_generation", "generate_response")

         workflow.set_entry_point("filter_messages")
         workflow.set_finish_point("generate_response")
@@ -87,52 +90,75 @@ def _retrieve_context(self, state: WillaChatbotState) -> dict[str, str]:
         vector_store = self._vector_store

         if not search_query or not vector_store:
-            return {"docs_context": "", "tind_metadata": ""}
+            return {"tind_metadata": "", "documents": []}

         # Search for relevant documents
         retriever = vector_store.as_retriever(search_kwargs={"k": int(CONFIG['K_VALUE'])})
         matching_docs = retriever.invoke(search_query)

-        # Format context and metadata
-        docs_context = '\n\n'.join(doc.page_content for doc in matching_docs)
+        formatted_documents = [
+            {
+                "id": f"{doc.metadata.get('tind_metadata', {}).get('tind_id', [''])[0]}_{i}",
+                "page_content": doc.page_content,
+                "title": doc.metadata.get('tind_metadata', {}).get('title', [''])[0],
+                "project": doc.metadata.get('tind_metadata', {}).get('isPartOf', [''])[0],
+                "tind_link": format_tind_context.get_tind_url(
+                    doc.metadata.get('tind_metadata', {}).get('tind_id', [''])[0])
+            }
+            for i, doc in enumerate(matching_docs, 1)
+        ]

         # Format tind metadata
         tind_metadata = format_tind_context.get_tind_context(matching_docs)

-        return {"docs_context": docs_context, "tind_metadata": tind_metadata}
+        return {"tind_metadata": tind_metadata, "documents": formatted_documents}

-    # This should be refactored probably. Very bulky
-    def _generate_response(self, state: WillaChatbotState) -> dict[str, list[AnyMessage]]:
-        """Generate response using the model."""
+    def _prepare_for_generation(self, state: WillaChatbotState) -> dict[str, list[AnyMessage]]:
+        """Prepare the current and past messages for response generation."""
         messages = state["messages"]
         summarized_conversation = state.get("summarized_messages", messages)
-        docs_context = state.get("docs_context", "")
-        tind_metadata = state.get("tind_metadata", "")
         model = self._model

         if not model:
             return {"messages": [AIMessage(content="Model not available.")]}

-        # Get the latest human message
-        latest_message = next(
-            (msg for msg in reversed(messages) if isinstance(msg, HumanMessage)),
-            None
-        )
-
-        if not latest_message:
+        if not any(isinstance(msg, HumanMessage) for msg in messages):
             return {"messages": [AIMessage(content="I'm sorry, I didn't receive a question.")]}

         prompt = get_langfuse_prompt()
-        system_messages = prompt.invoke({'context': docs_context,
-                                         'question': latest_message.content})
+        system_messages = prompt.invoke({})
Member:
    what's this doing? where is the user's question being inserted?

Contributor Author:
    The user's question is passed as a user message, like it always is. There might actually be some unrelated cleanup work for us to do there: right now we're passing the user query as a user message twice (due to the summary step) and then again as part of the system message (the instruction prompt).
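A rough sketch of the duplication described in that reply (the message contents and prompt wording are hypothetical, for illustration only):

    from langchain_core.messages import HumanMessage, SystemMessage

    # Illustrative only: the same query can reach the model three times
    all_messages = [
        HumanMessage(content="How do I reserve a study room?"),   # raw conversation turn
        HumanMessage(content="How do I reserve a study room?"),   # carried over again by the summary step
        SystemMessage(content="...Question: How do I reserve a study room?"),  # copy in the instruction prompt
    ]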


         if hasattr(system_messages, "messages"):
             all_messages = summarized_conversation + system_messages.messages
         else:
             all_messages = summarized_conversation + [system_messages]

+        return {"messages": all_messages}
+
+    def _generate_response(self, state: WillaChatbotState) -> dict[str, list[AnyMessage]]:
+        """Generate response using the model."""
+        tind_metadata = state.get("tind_metadata", "")
+        model = self._model
+        documents = state.get("documents", [])
+        messages = state["messages"]
+
+        if not model:
+            return {"messages": [AIMessage(content="Model not available.")]}
+
         # Get response from model
-        response = model.invoke(all_messages)
+        response = model.invoke(
+            messages,
+            additional_model_request_fields={"documents": documents},
+            additional_model_response_field_paths=["/citations"]
+        )
+        citations = (response.response_metadata
+                     .get('additionalModelResponseFields', {})
+                     .get('citations')) if response.response_metadata else None
Contributor Author:
    What's neat about this is that the citations returned by Cohere can be used to tie specific parts of the response message back to the documents that we passed above in line 145.

    An example list of citations looks like this:

[
  {
      "start": 184,
      "end": 229,
      "text": "admission criteria and standards of practice.",
      "document_ids": ["doc_0"]
  },
  {
      "start": 260,
      "end": 275,
      "text": "different roles",
      "document_ids": ["doc_1", "doc_2"]
  }
]
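
A minimal sketch (not part of this PR) of rendering those spans inline, assuming `start`/`end` are character offsets into the response text as in the example above:

    def annotate_with_citations(text: str, citations: list[dict]) -> str:
        """Append [doc ids] after each cited span, inserting from the end of
        the text first so earlier character offsets stay valid."""
        for citation in sorted(citations, key=lambda c: c["end"], reverse=True):
            marker = " [" + ", ".join(citation["document_ids"]) + "]"
            text = text[:citation["end"]] + marker + text[citation["end"]:]
        return text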


         # Create clean response content
         response_content = str(response.content) if hasattr(response, 'content') else str(response)

+        if citations:
+            state['citations'] = citations
+            response_content += "\n\nCitations:\n"
+            for citation in citations:
+                doc_ids = list(dict.fromkeys([re.sub(r'_\d*$', '', doc_id)
+                                              for doc_id in citation.get('document_ids', [])]))
+                response_content += f"- {citation.get('text', '')} ({', '.join(doc_ids)})\n"

         response_messages: list[AnyMessage] = [AIMessage(content=response_content),
                                                ChatMessage(content=tind_metadata, role='TIND',
                                                            response_metadata={'tind': True})]