1818
1919# document loading
2020import glob
21+
22+ # general purpose imports
2123import logging
2224import os
2325import textwrap
26+ from typing import Union
2427
2528# pinecone integration
2629import pinecone
3841
3942# hybrid search capability
4043from langchain .retrievers import PineconeHybridSearchRetriever
41- from langchain .schema import HumanMessage , SystemMessage
44+ from langchain .schema import BaseMessage , HumanMessage , SystemMessage
4245from langchain .text_splitter import Document
4346from langchain .vectorstores .pinecone import Pinecone
4447from pinecone_text .sparse import BM25Encoder
@@ -95,14 +98,20 @@ class HybridSearchRetriever:
    # Class-level collaborators, shared by every instance of the class.
    # NOTE(review): presumably text_splitter chunks loaded documents before
    # embedding and bm25_encoder provides the sparse half of the hybrid
    # (dense + sparse) search — confirm against load() and rag().
    text_splitter = TextSplitter()
    bm25_encoder = BM25Encoder().default()
97100
98- def cached_chat_request (self , system_message : str , human_message : str ) -> SystemMessage :
101+ def cached_chat_request (
102+ self , system_message : Union [str , SystemMessage ], human_message : Union [str , HumanMessage ]
103+ ) -> BaseMessage :
99104 """Cached chat request."""
100- messages = [
101- SystemMessage (content = system_message ),
102- HumanMessage (content = human_message ),
103- ]
105+ if not isinstance (system_message , SystemMessage ):
106+ logging .debug ("Converting system message to SystemMessage" )
107+ system_message = SystemMessage (content = str (system_message ))
108+
109+ if not isinstance (human_message , HumanMessage ):
110+ logging .debug ("Converting human message to HumanMessage" )
111+ human_message = HumanMessage (content = str (human_message ))
112+ messages = [system_message , human_message ]
104113 # pylint: disable=not-callable
105- retval = self .chat (messages ). content
114+ retval = self .chat (messages )
106115 return retval
107116
108117 def prompt_with_template (self , prompt : PromptTemplate , concept : str , model : str = DEFAULT_MODEL_NAME ) -> str :
@@ -158,10 +167,10 @@ def load(self, filepath: str):
158167
159168 logging .debug ("Finished loading PDFs" )
160169
161- def rag (self , prompt : str ):
170+ def rag (self , human_message : Union [ str , HumanMessage ] ):
162171 """
163172 Embedded prompt.
164- 1. Retrieve prompt: Given a user input, relevant splits are retrieved
173+ 1. Retrieve human message prompt: Given a user input, relevant splits are retrieved
165174 from storage using a Retriever.
166175 2. Generate: A ChatModel / LLM produces an answer using a prompt that includes
167176 the question and the retrieved data
@@ -174,33 +183,32 @@ def rag(self, prompt: str):
174183 The typical workflow is to use the embeddings to retrieve relevant documents,
175184 and then use the text of these documents as part of the prompt for GPT-3.
176185 """
186+ if not isinstance (human_message , HumanMessage ):
187+ logging .debug ("Converting human_message to HumanMessage" )
188+ human_message = HumanMessage (content = human_message )
189+
177190 retriever = PineconeHybridSearchRetriever (
178191 embeddings = self .openai_embeddings , sparse_encoder = self .bm25_encoder , index = self .pinecone_index
179192 )
180- documents = retriever .get_relevant_documents (query = prompt )
193+ documents = retriever .get_relevant_documents (query = human_message . content )
181194 logging .debug ("Retrieved %i related documents from Pinecone" , len (documents ))
182195
183196 # Extract the text from the documents
184197 document_texts = [doc .page_content for doc in documents ]
185198 leader = textwrap .dedent (
186- """\
187- \n \n You can assume that the following is true.
199+ """You are a helpful assistant.
200+ You can assume that all of the following is true.
188201 You should attempt to incorporate these facts
189- into your response :\n \n
202+ into your responses :\n \n
190203 """
191204 )
205+ system_message = f"{ leader } { '. ' .join (document_texts )} "
192206
193- # Create a prompt that includes the document texts
194- prompt_with_relevant_documents = f"{ prompt + leader } { '. ' .join (document_texts )} "
195-
196- logging .debug ("Prompt contains %i words" , len (prompt_with_relevant_documents .split ()))
197- logging .debug ("Prompt: %s" , prompt_with_relevant_documents )
198-
199- # Get a response from the GPT-3.5-turbo model
200- response = self .cached_chat_request (
201- system_message = "You are a helpful assistant." , human_message = prompt_with_relevant_documents
202- )
207+ logging .debug ("System messages contains %i words" , len (system_message .split ()))
208+ logging .debug ("Prompt: %s" , system_message )
209+ system_message = SystemMessage (content = system_message )
210+ response = self .cached_chat_request (system_message = system_message , human_message = human_message )
203211
204212 logging .debug ("Response:" )
205213 logging .debug ("------------------------------------------------------" )
206- return response
214+ return response . content
0 commit comments