# DigiTwin RAG — Streamlit chat app with hybrid (FAISS + BM25) retrieval over uploaded documents.
# Assumed dependencies: streamlit, torch, transformers, accelerate, langchain, langchain-community,
# sentence-transformers, faiss-cpu, rank_bm25, pypdf.

import os
import tempfile
from threading import Thread

import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from langchain_community.document_loaders import PyPDFLoader, TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever
from langchain.docstore.document import Document as LangchainDocument

# --- Avatars ---
USER_AVATAR = "👤"
BOT_AVATAR = "🤖"

# --- HF Token ---
HF_TOKEN = st.secrets["HF_TOKEN"]

# --- Page Config ---
st.set_page_config(page_title="DigiTwin RAG", page_icon="📂", layout="centered")
st.title("📂 DigiTs the Twin")

# --- Sidebar ---
with st.sidebar:
    st.header("📄 Upload Knowledge Files")
    uploaded_files = st.file_uploader(
        "Upload PDFs or .txt files", accept_multiple_files=True, type=["pdf", "txt"]
    )
    hybrid_toggle = st.checkbox("🔀 Enable Hybrid Search", value=True)
    clear_chat = st.button("🧹 Clear Chat History")

# --- Session State ---
if "messages" not in st.session_state or clear_chat:
    st.session_state.messages = []

# --- Load Model ---
@st.cache_resource
def load_model():
    model_id = "tiiuae/falcon-7b-instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.float16, device_map="auto", token=HF_TOKEN
    )
    return tokenizer, model

tokenizer, model = load_model()

# --- Load & Chunk Documents ---
def process_documents(files):
    documents = []
    for file in files:
        suffix = ".pdf" if file.name.endswith(".pdf") else ".txt"
        # Loaders expect a file path, so spill each upload to a temp file first.
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
            tmp_file.write(file.read())
            tmp_file_path = tmp_file.name
        try:
            loader = PyPDFLoader(tmp_file_path) if suffix == ".pdf" else TextLoader(tmp_file_path)
            documents.extend(loader.load())
        finally:
            os.unlink(tmp_file_path)  # remove the temp file once loaded
    return documents

def chunk_documents(documents):
    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    return splitter.split_documents(documents)

def build_retrievers(chunks):
    # Dense retriever: MiniLM sentence embeddings indexed in FAISS.
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    faiss_vectorstore = FAISS.from_documents(chunks, embeddings)
    faiss_retriever = faiss_vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 5})
    # Sparse retriever: BM25 over the raw chunk text.
    bm25_retriever = BM25Retriever.from_documents(
        [LangchainDocument(page_content=d.page_content) for d in chunks]
    )
    bm25_retriever.k = 5
    # Hybrid search: blend dense and sparse results with equal weight.
    hybrid_retriever = EnsembleRetriever(retrievers=[faiss_retriever, bm25_retriever], weights=[0.5, 0.5])
    return faiss_retriever, hybrid_retriever

# --- Prompt Builder ---
def build_prompt(history, context=""):
    conversation = ""
    for turn in history:
        role = "User" if turn["role"] == "user" else "Assistant"
        conversation += f"{role}: {turn['content']}\n"
    return (
        "You are DigiTwin, an expert advisor in asset integrity, reliability, inspection, and maintenance "
        "of topside piping, structural, mechanical systems, floating units, pressure vessels (VII), "
        "and pressure safety devices (PSDs).\n\n"
        f"Context:\n{context}\n\n"
        f"{conversation}Assistant:"
    )

# --- Generator ---
def generate_response(prompt):
    # Run generation on a background thread and stream tokens back as they arrive.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    generation_kwargs = dict(**inputs, streamer=streamer, max_new_tokens=300)
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    for token in streamer:
        yield token

# --- Main App ---
retriever = None
if uploaded_files:
    with st.spinner("Processing documents..."):
        docs = process_documents(uploaded_files)
        chunks = chunk_documents(docs)
        faiss_retriever, hybrid_retriever = build_retrievers(chunks)
        retriever = hybrid_retriever if hybrid_toggle else faiss_retriever
    st.success("Documents processed. Ask away!")

# Replay chat history on each rerun.
for msg in st.session_state.messages:
    with st.chat_message(msg["role"], avatar=USER_AVATAR if msg["role"] == "user" else BOT_AVATAR):
        st.markdown(msg["content"])

# --- Chat UI ---
if prompt := st.chat_input("Ask something based on uploaded documents..."):
    st.chat_message("user", avatar=USER_AVATAR).markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    # Retrieve supporting context for the question, if documents are loaded.
    context = ""
    if retriever:
        docs = retriever.get_relevant_documents(prompt)
        context = "\n\n".join([d.page_content for d in docs])

    full_prompt = build_prompt(st.session_state.messages, context=context)

    with st.chat_message("assistant", avatar=BOT_AVATAR):
        streamer = generate_response(full_prompt)
        container = st.empty()
        answer = ""
        for chunk in streamer:
            answer += chunk
            container.markdown(answer + "▌", unsafe_allow_html=True)  # cursor effect while streaming
        container.markdown(answer)

    st.session_state.messages.append({"role": "assistant", "content": answer})