-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlangchain_helper.py
61 lines (47 loc) · 2.55 KB
/
langchain_helper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from langchain.vectorstores import FAISS
from langchain.llms import GooglePalm
from langchain.document_loaders.csv_loader import CSVLoader
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
import os
from dotenv import load_dotenv
# Folder where create_vector_db() saves the FAISS index and from which
# get_qa_chain() loads it again.
vectordb_file_path = "faiss_index"

# Copy variables from a local .env file into the process environment; the
# Google API key read just below is expected to be among them.
load_dotenv()

# Google PaLM LLM used to generate the answers. If GOOGLE_API_KEY is not set,
# the os.environ lookup raises KeyError as soon as this module is imported.
llm = GooglePalm(
    google_api_key=os.environ["GOOGLE_API_KEY"],
    temperature=0.3,
)

# Instructor embedding model from Hugging Face; the same instance is used to
# build the index and to load/query it.
instructor_embeddings = HuggingFaceInstructEmbeddings(
    model_name="hkunlp/instructor-large"
)
def create_vector_db():
    """Build a FAISS index from the FAQ CSV and save it to disk.

    Loads ``question_answer.csv`` (resolved against the current working
    directory) with the ``question`` column as each document's source, embeds
    the documents with the module-level ``instructor_embeddings`` and writes
    the resulting index to ``vectordb_file_path``. Returns nothing.
    """
    faq_documents = CSVLoader(
        file_path='question_answer.csv',
        source_column="question",
        encoding="utf-8",
    ).load()

    # Embed the documents and persist the index locally so that
    # get_qa_chain() can load it without re-embedding the CSV.
    FAISS.from_documents(
        documents=faq_documents,
        embedding=instructor_embeddings,
    ).save_local(vectordb_file_path)
def get_qa_chain():
    """Build a RetrievalQA chain over the locally saved FAISS index.

    Loads the index from ``vectordb_file_path`` using the module-level
    ``instructor_embeddings``, wraps it in a score-thresholded retriever and
    combines it with the module-level ``llm`` and a prompt that restricts the
    answer to the retrieved context.

    Returns:
        A ``RetrievalQA`` chain ("stuff" type) that reads the question from
        the ``"query"`` input key and is configured to return the source
        documents alongside the answer.
    """
    # Load the vector database from the local folder
    vectordb = FAISS.load_local(vectordb_file_path, instructor_embeddings)

    # The threshold only takes effect when it is passed through search_kwargs
    # together with search_type="similarity_score_threshold"; given as a
    # direct keyword to as_retriever() it is not applied to the search.
    # NOTE(review): with the threshold active, a question may retrieve no
    # documents at all; the prompt below then has an empty context.
    retriever = vectordb.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={"score_threshold": 0.7},
    )

    prompt_template = """Given the following context and a question, generate only the answer based on the provided context only.
In the answer try to provide as much text as possible from "response" section in the source document context without making much changes.
If the answer is not found in the context, kindly state "I don't know." Don't try to make up an answer.
CONTEXT: {context}
QUESTION: {question}"""

    PROMPT = PromptTemplate(
        template=prompt_template, input_variables=["context", "question"]
    )

    chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        input_key="query",
        return_source_documents=True,
        chain_type_kwargs={"prompt": PROMPT},
    )
    return chain
if __name__ == "__main__":
    # Rebuilding the index is only needed when the CSV file changes;
    # uncomment the next line to do so before querying.
    # create_vector_db()
    chain = get_qa_chain()
    result = chain("Buddhism was founded in part as a response to questions about which faith?")
    print(result)