 # -*- coding: utf-8 -*-
 # pylint: disable=too-few-public-methods
-"""Sales Support Model (SSM) for the LangChain project."""
-
+"""
+Sales Support Model (SSM) for the LangChain project.
+See: https://python.langchain.com/docs/modules/model_io/llms/llm_caching
+     https://python.langchain.com/docs/modules/data_connection/document_loaders/pdf
+"""
+
+import glob
+import os
 from typing import ClassVar, List

 import pinecone
+from langchain import hub
+from langchain.cache import InMemoryCache
+
+# prompting and chat
 from langchain.chat_models import ChatOpenAI
+
+# document loading
+from langchain.document_loaders import PyPDFLoader
+
+# embedding
 from langchain.embeddings import OpenAIEmbeddings
+
+# llm caching
+from langchain.globals import set_llm_cache
 from langchain.llms.openai import OpenAI
 from langchain.prompts import PromptTemplate
-from langchain.schema import HumanMessage, SystemMessage  # AIMessage (not used)
+from langchain.schema import HumanMessage, StrOutputParser, SystemMessage
+from langchain.schema.runnable import RunnablePassthrough
 from langchain.text_splitter import Document, RecursiveCharacterTextSplitter
 from langchain.vectorstores.pinecone import Pinecone
 from pydantic import BaseModel, ConfigDict, Field  # ValidationError

+# this project
 from models.const import Credentials


+###############################################################################
+# initializations
+###############################################################################
 DEFAULT_MODEL_NAME = "text-davinci-003"
 pinecone.init(api_key=Credentials.PINECONE_API_KEY, environment=Credentials.PINECONE_ENVIRONMENT)
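+# per the llm_caching docs linked above, an InMemoryCache serves repeated
+# identical LLM requests from process memory instead of re-calling the API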
+set_llm_cache(InMemoryCache())


 class SalesSupportModel(BaseModel):
@@ -31,24 +55,17 @@ class SalesSupportModel(BaseModel):
         default_factory=lambda: ChatOpenAI(
             api_key=Credentials.OPENAI_API_KEY,
             organization=Credentials.OPENAI_API_ORGANIZATION,
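+            # cache=True opts this model into the global in-memory LLM cache
+            # configured at module load via set_llm_cache()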
+            cache=True,
             max_retries=3,
             model="gpt-3.5-turbo",
-            temperature=0.3,
+            temperature=0.0,
         )
     )

     # embeddings
-    text_splitter: RecursiveCharacterTextSplitter = Field(
-        default_factory=lambda: RecursiveCharacterTextSplitter(
-            chunk_size=100,
-            chunk_overlap=0,
-        )
-    )
-
     texts_splitter_results: List[Document] = Field(None, description="Text splitter results")
     pinecone_search: Pinecone = Field(None, description="Pinecone search")
-    pinecone_index_name: str = Field(default="netec-ssm", description="Pinecone index name")
-    openai_embedding: OpenAIEmbeddings = Field(default_factory=lambda: OpenAIEmbeddings(model="ada"))
+    openai_embedding: OpenAIEmbeddings = Field(default_factory=OpenAIEmbeddings)
     query_result: List[float] = Field(None, description="Vector database query result")

     def cached_chat_request(self, system_message: str, human_message: str) -> SystemMessage:
@@ -68,24 +85,72 @@ def prompt_with_template(self, prompt: PromptTemplate, concept: str, model: str

     def split_text(self, text: str) -> List[Document]:
         """Split text."""
-        # pylint: disable=no-member
-        retval = self.text_splitter.create_documents([text])
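+        # chunk_size is measured in characters; RecursiveCharacterTextSplitter
+        # tries "\n\n", then "\n", then " " before splitting mid-word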
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=100,
+            chunk_overlap=0,
+        )
+        retval = text_splitter.create_documents([text])
         return retval

     def embed(self, text: str) -> List[float]:
         """Embed."""
-        texts_splitter_results = self.split_text(text)
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=100,
+            chunk_overlap=0,
+        )
+        texts_splitter_results = text_splitter.create_documents([text])
         embedding = texts_splitter_results[0].page_content
         # pylint: disable=no-member
         self.openai_embedding.embed_query(embedding)

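+        # from_documents() embeds every chunk and upserts the vectors into the
+        # named Pinecone index, returning a searchable vector store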
         self.pinecone_search = Pinecone.from_documents(
             texts_splitter_results,
             embedding=self.openai_embedding,
-            index_name=self.pinecone_index_name,
+            index_name=Credentials.PINECONE_INDEX_NAME,
         )

+    def rag(self, filepath: str, prompt: str):
+        """
+        Retrieval Augmented Generation (RAG) over a directory of PDFs.
+        1. Load the PDF document text data
+        2. Split the text into chunks
+        3. Embed each chunk
+        4. Store the embeddings in Pinecone and answer the prompt
+        """
+
+        def format_docs(docs):
+            """Concatenate the retrieved documents into one context string."""
+            return "\n\n".join(doc.page_content for doc in docs)
+
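+        # gather every page from every PDF in the target directory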
+        docs = []
+        for pdf_file in glob.glob(os.path.join(filepath, "*.pdf")):
+            loader = PyPDFLoader(file_path=pdf_file)
+            docs.extend(loader.load())
+
+        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
+        splits = text_splitter.split_documents(docs)
+        vectorstore = Pinecone.from_documents(
+            documents=splits, embedding=self.openai_embedding, index_name=Credentials.PINECONE_INDEX_NAME
+        )
+        retriever = vectorstore.as_retriever()
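+        # "rlm/rag-prompt" is a community RAG prompt template pulled from
+        # LangChain Hub; it expects "context" and "question" inputs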
+        rag_prompt = hub.pull("rlm/rag-prompt")
+
+        rag_chain = (
+            {"context": retriever | format_docs, "question": RunnablePassthrough()}
+            | rag_prompt
+            | self.chat
+            | StrOutputParser()
+        )
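+        # the mapping step runs both branches on the invoke() input: the
+        # retriever fetches and formats matching chunks into "context", while
+        # RunnablePassthrough() forwards the raw question unchanged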
+
+        return rag_chain.invoke(prompt)
+
     def embedded_prompt(self, prompt: str) -> List[Document]:
-        """Embedded prompt."""
+        """
+        Embedded prompt.
+        1. Retrieve: given a user input, relevant splits are retrieved
+           from storage using a Retriever.
+        2. Generate: a ChatModel / LLM produces an answer using a prompt that
+           includes the question and the retrieved data.
+        """
         result = self.pinecone_search.similarity_search(prompt)
         return result
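+
+
+# A minimal usage sketch (assumes a ./pdfs directory of PDF files and valid
+# Pinecone/OpenAI credentials in models.const; the filepath and prompt below
+# are illustrative only):
+#   ssm = SalesSupportModel()
+#   answer = ssm.rag(filepath="./pdfs", prompt="Summarize the key points of these documents")
+#   print(answer)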