RAG_Chatbot_and_LLM_evaluation/chatAPP.py
# Import the Streamlit library
import streamlit as streamlit_interface
# Import the load_dotenv function from the dotenv module
from dotenv import load_dotenv
# Import the PdfReader class from the PyPDF2 module
from PyPDF2 import PdfReader
# Import the CharacterTextSplitter class from the langchain.text_splitter module
from langchain.text_splitter import CharacterTextSplitter
# Import the OpenAIEmbeddings class from the langchain.embeddings module
from langchain.embeddings import OpenAIEmbeddings
# Import the FAISS class from the langchain.vectorstores module
from langchain.vectorstores import FAISS
# Import the ChatOpenAI class from the langchain.chat_models module
from langchain.chat_models import ChatOpenAI
# Import the ConversationBufferMemory class from the langchain.memory module
from langchain.memory import ConversationBufferMemory
# Import the ConversationalRetrievalChain class from the langchain.chains module
from langchain.chains import ConversationalRetrievalChain
# Import the css, bot_template, and user_template variables from the htmlTemplates module
from htmlTemplates import css, bot_template, user_template
# Import the torch library
import torch
# Import the HuggingFaceInstructEmbeddings class from the langchain_community.embeddings module
from langchain_community.embeddings import HuggingFaceInstructEmbeddings
# Import the Ollama class from the langchain_community.llms module
from langchain_community.llms import Ollama
# Import the PromptTemplate class from the langchain_core.prompts module
from langchain_core.prompts import PromptTemplate
# Import the InferenceApi class from the huggingface_hub.inference_api module
from huggingface_hub.inference_api import InferenceApi
# Load model directly
from transformers import AutoTokenizer, AutoModelForCausalLM
# Import the pipeline helper from the transformers module
from transformers import pipeline
# Import the HuggingFacePipeline wrapper from the langchain_community.llms module
from langchain_community.llms import HuggingFacePipeline
# Import the Chroma vector store from the langchain.vectorstores module
from langchain.vectorstores import Chroma
# Import the RetrievalQA chain from the langchain.chains module
from langchain.chains import RetrievalQA
# Import the logging module
import logging
# Import the INSTRUCTOR class from the InstructorEmbedding module
from InstructorEmbedding import INSTRUCTOR
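# NOTE (assumed dependency list, inferred from the imports above and not confirmed by the repository):
# running this app likely requires packages along the lines of streamlit, python-dotenv, PyPDF2,
# langchain, langchain-community, langchain-core, transformers, torch, chromadb, huggingface_hub,
# InstructorEmbedding and sentence-transformers. Check the repository's requirements file for the
# authoritative list.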
# Define the system prompt that is prepended to every prompt template
system_prompt = "Your system prompt goes here"

# Define a function that builds the prompt template and the conversation memory
def get_prompt_template(system_prompt=system_prompt, promptTemplate_type=None, history=False):
    if promptTemplate_type == "llama":
        # Llama-2 chat models expect the [INST]/<<SYS>> instruction format
        B_INST, E_INST = "[INST]", "[/INST]"
        B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
        SYSTEM_PROMPT = B_SYS + system_prompt + E_SYS
        if history:
            instruction = """
            Context: {history} \n {context}
            User: {question}"""
            prompt_template = B_INST + SYSTEM_PROMPT + instruction + E_INST
            prompt = PromptTemplate(input_variables=["history", "context", "question"], template=prompt_template)
        else:
            instruction = """
            Context: {context}
            User: {question}"""
            prompt_template = B_INST + SYSTEM_PROMPT + instruction + E_INST
            prompt = PromptTemplate(input_variables=["context", "question"], template=prompt_template)
    else:
        # Change this based on the model you have selected.
        if history:
            prompt_template = system_prompt + """
            Context: {history} \n {context}
            User: {question}
            Answer:"""
            prompt = PromptTemplate(input_variables=["history", "context", "question"], template=prompt_template)
        else:
            prompt_template = system_prompt + """
            Context: {context}
            User: {question}
            Answer:"""
            prompt = PromptTemplate(input_variables=["context", "question"], template=prompt_template)
    # Memory keyed on the question, storing previous turns under "history"
    memory = ConversationBufferMemory(input_key="question", memory_key="history")
    return prompt, memory
# Define a function to extract text from PDF files
def extract_text_from_pdf(pdf_documents):
    # Initialize an empty string to store the extracted text
    extracted_text = ""
    # Loop over each PDF document
    for pdf in pdf_documents:
        # Create a PdfReader object for the current PDF document
        pdf_reader = PdfReader(pdf)
        # Loop over each page in the current PDF document
        for page in pdf_reader.pages:
            # Add the text extracted from the current page to the extracted_text string
            extracted_text += page.extract_text()
    # Return the extracted text
    return extracted_text
# Define a function to split the text into chunks
def split_text_into_chunks(text):
    # Create a CharacterTextSplitter object
    text_splitter = CharacterTextSplitter(
        separator="\n",
        chunk_size=1000,
        chunk_overlap=200,
        length_function=len
    )
    # Use the CharacterTextSplitter object to split the text into chunks
    chunks = text_splitter.split_text(text)
    # Return the chunks
    return chunks
# Define a function to create a vector store from the text chunks
def create_vector_store(text_chunks):
    # Create a HuggingFaceInstructEmbeddings object (OpenAIEmbeddings is kept as a commented-out alternative)
    #embeddings = OpenAIEmbeddings()
    embeddings = HuggingFaceInstructEmbeddings()
    # Create a Chroma vector store from the text chunks and the embeddings
    vector_store = Chroma.from_texts(texts=text_chunks, embedding=embeddings)
    # Return the vector store
    return vector_store
# Define a function to create a conversation chain
def create_conversation_chain(vector_store):
    # Create a pipeline for text generation
    # NOTE: meta-llama/Llama-2-7b-chat-hf is a gated model; Hugging Face access must be granted before it can be downloaded.
    pipe = pipeline(
        "text-generation",
        model=AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf"),
        tokenizer=AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf"),
        max_length=1000,
        temperature=0.2,
        repetition_penalty=1.15,
    )
    # Wrap the pipeline so LangChain can use it as a local LLM
    local_llm = HuggingFacePipeline(pipeline=pipe)
    #logging.info("Local LLM Loaded")
    # Use the freshly built vector store if one was passed in; otherwise open the persisted Chroma store
    if vector_store is not None:
        retriever = vector_store.as_retriever()
    else:
        embeddings = HuggingFaceInstructEmbeddings(model_name="hkunlp/instructor-large")
        #embeddings = HuggingFaceInstructEmbeddings(model_name="gpt-3.5-turbo")
        db = Chroma(persist_directory="./", embedding_function=embeddings)
        retriever = db.as_retriever()
    # Get the prompt template and memory if set by the user.
    prompt, memory = get_prompt_template(promptTemplate_type="llama", history=True)
    # Keep the conversation history in the chain; set to False to drop the memory
    use_history = True
    if use_history:
        conversation_chain = RetrievalQA.from_chain_type(
            llm=local_llm,
            chain_type="stuff",
            retriever=retriever,
            return_source_documents=True,
            callbacks=None,
            chain_type_kwargs={"prompt": prompt, "memory": memory},
        )
    else:
        conversation_chain = RetrievalQA.from_chain_type(
            llm=local_llm,
            chain_type="stuff",
            retriever=retriever,
            return_source_documents=True,
            callbacks=None,
            chain_type_kwargs={"prompt": prompt},
        )
    return conversation_chain
# Define a function to handle user input and generate responses
def process_user_input(user_question):
    # Warn the user if no documents have been processed yet
    if streamlit_interface.session_state.conversation is None:
        streamlit_interface.warning("Please upload and process your documents first.")
        return
    # Query the RetrievalQA chain; it expects the input under the "query" key and returns the answer under "result"
    response = streamlit_interface.session_state.conversation({'query': user_question})
    # Store the new exchange in the chat history kept in the session state
    if streamlit_interface.session_state.chat_history is None:
        streamlit_interface.session_state.chat_history = []
    streamlit_interface.session_state.chat_history.append(user_question)
    streamlit_interface.session_state.chat_history.append(response['result'])
    # Loop over each message in the chat history
    for i, message in enumerate(streamlit_interface.session_state.chat_history):
        # Even indices are user messages
        if i % 2 == 0:
            # Write the user's message to the Streamlit interface
            streamlit_interface.write(user_template.replace(
                "{{MSG}}", message), unsafe_allow_html=True)
        # Odd indices are bot messages
        else:
            # Write the bot's message to the Streamlit interface
            streamlit_interface.write(bot_template.replace(
                "{{MSG}}", message), unsafe_allow_html=True)
# Define the main function
def main():
    # Load the environment variables
    load_dotenv()
    # Set the page configuration of the Streamlit interface
    streamlit_interface.set_page_config(page_title="Chat with your own documents",
                                        page_icon=":books:")
    # Write the CSS to the Streamlit interface
    streamlit_interface.write(css, unsafe_allow_html=True)
    # If there is no conversation in the session state
    if "conversation" not in streamlit_interface.session_state:
        # Set the conversation in the session state to None
        streamlit_interface.session_state.conversation = None
    # If there is no chat history in the session state
    if "chat_history" not in streamlit_interface.session_state:
        # Set the chat history in the session state to None
        streamlit_interface.session_state.chat_history = None
    # Write a header to the Streamlit interface
    streamlit_interface.header("Chat with your documents :books:")
    # Get the user's question from the Streamlit interface
    user_question = streamlit_interface.text_input("Ask a question about your documents:")
    # If the user has asked a question
    if user_question:
        # Process the user's question
        process_user_input(user_question)
    # Create a sidebar in the Streamlit interface
    with streamlit_interface.sidebar:
        # Write a subheader to the sidebar
        streamlit_interface.subheader("Documents")
        # Create a file uploader in the sidebar
        pdf_documents = streamlit_interface.file_uploader(
            "Upload your documents here and click on 'Upload'", accept_multiple_files=True)
        # If the user clicks the "Upload" button
        if streamlit_interface.button("Upload"):
            # Show a spinner in the Streamlit interface while processing
            with streamlit_interface.spinner("Processing"):
                # Extract the text from the PDF documents
                raw_text = extract_text_from_pdf(pdf_documents)
                # Split the text into chunks
                text_chunks = split_text_into_chunks(raw_text)
                # Create a vector store from the text chunks
                vector_store = create_vector_store(text_chunks)
                # Create a conversation chain from the vector store
                streamlit_interface.session_state.conversation = create_conversation_chain(
                    vector_store)

# If this script is being run directly (not imported as a module)
if __name__ == '__main__':
    # Run the main function
    main()
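# Usage sketch (assuming Streamlit is installed and the file path matches the repository layout):
# launch the app from the repository root with
#   streamlit run RAG_Chatbot_and_LLM_evaluation/chatAPP.py
# then upload one or more PDFs in the sidebar, click "Upload", and ask questions in the text box.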