Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python | |
| # coding: utf-8 | |
| import os | |
| import pickle | |
| import argparse | |
| import faiss | |
| import numpy as np | |
| import torch | |
| import gradio as gr | |
| from datasets import load_dataset | |
| from sentence_transformers import SentenceTransformer, CrossEncoder | |
| from transformers import ( | |
| AutoTokenizer, | |
| AutoModelForSeq2SeqLM, | |
| pipeline as hf_pipeline, | |
| ) | |
| import evaluate | |
# ── 1. Configuration ──
# Everything cached on disk lives under ./data, relative to the working
# directory at import time.
DATA_DIR = os.path.join(os.getcwd(), "data")
INDEX_PATH, EMB_PATH, PCTX_PATH = (
    os.path.join(DATA_DIR, filename)
    for filename in ("faiss_index.faiss", "embeddings.npy", "passages.pkl")
)

# Model choices and retrieval knobs, each overridable through the environment.
MODEL_NAME = os.environ.get("MODEL_NAME", "google/flan-t5-small")
EMBEDDER_MODEL = os.environ.get("EMBEDDER_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
DIST_THRESHOLD = float(os.environ.get("DIST_THRESHOLD", "1.0"))
MAX_CTX_WORDS = int(os.environ.get("MAX_CTX_WORDS", "200"))

# Device index handed to the transformers pipeline: 0 = first CUDA GPU, -1 = CPU.
if torch.cuda.is_available():
    DEVICE = 0
else:
    DEVICE = -1

os.makedirs(DATA_DIR, exist_ok=True)
# ── 2. Helpers ──
def make_context_snippets(contexts, max_words=MAX_CTX_WORDS):
    """Return *contexts*, cutting down any entry longer than *max_words* words.

    An over-long entry keeps its first ``max_words`` whitespace-separated
    words, re-joined with single spaces, followed by " ... [truncated]".
    Entries within the limit are returned exactly as given.
    """
    def _clip(text):
        tokens = text.split()
        if len(tokens) <= max_words:
            return text
        return " ".join(tokens[:max_words]) + " ... [truncated]"

    return [_clip(text) for text in contexts]
def chunk_text(text, max_tokens, stride=None):
    """Split *text* into overlapping windows of whitespace-separated words.

    Despite the parameter name, windows are measured in words (``str.split``),
    not tokenizer tokens.

    Args:
        text: Input string.
        max_tokens: Maximum number of words per chunk; must be positive.
        stride: Number of words to advance between window starts; must be
            positive. Defaults to ``max_tokens // 4`` (at least 1).

    Returns:
        A list of chunk strings, empty for blank text. Emission stops once a
        window reaches the end of the text, so no chunk is merely a suffix of
        the previous one.

    Raises:
        ValueError: If ``max_tokens`` or an explicit ``stride`` is not positive
            (either would otherwise make the loop run forever).
    """
    if max_tokens <= 0:
        raise ValueError(f"max_tokens must be positive, got {max_tokens}")
    if stride is None:
        # max_tokens // 4 is 0 for max_tokens < 4, which would never advance.
        stride = max(1, max_tokens // 4)
    elif stride <= 0:
        raise ValueError(f"stride must be positive, got {stride}")

    words = text.split()
    chunks = []
    for start in range(0, len(words), stride):
        end = start + max_tokens
        chunks.append(" ".join(words[start:end]))
        if end >= len(words):
            # This window already covers the tail; later windows would be
            # strict subsets of it and only add duplicate retrieval noise.
            break
    return chunks
# ── 3. Load & preprocess passages ──
def _trivia_contexts(example):
    """Return the non-empty document texts attached to one TriviaQA example.

    Per the ``mandarjoshi/trivia_qa`` dataset card, the "rc" config nests its
    documents: ``example["entity_pages"]["wiki_context"]`` and
    ``example["search_results"]["search_context"]`` are lists holding one
    string per attached page. Flat top-level ``wiki_context`` /
    ``search_context`` strings are also accepted.
    """
    texts = []
    for group, field in (("entity_pages", "wiki_context"), ("search_results", "search_context")):
        nested = example.get(group) or {}
        texts.extend(nested.get(field) or [])
        flat = example.get(field)
        if isinstance(flat, str):
            texts.append(flat)
    return [t for t in texts if t]


def load_passages():
    """Download the source corpora, de-duplicate and chunk them, and cache the result.

    Sources are:
      * the rag-mini-wikipedia passages;
      * the contexts of the first 100 SQuAD v2 training examples;
      * the documents attached to the first 100 TriviaQA "rc" validation questions.

    A passage is split with ``chunk_text`` when its token count under the
    generator tokenizer exceeds ``tokenizer.model_max_length``.

    Side effects:
        Writes the chunk list to PCTX_PATH with pickle.

    Returns:
        list[str]: The passage chunks, in first-seen order.
    """
    wiki_ds = load_dataset("rag-datasets/rag-mini-wikipedia", "text-corpus", split="passages")
    squad_ds = load_dataset("rajpurkar/squad_v2", split="train[:100]")
    trivia_ds = load_dataset("mandarjoshi/trivia_qa", "rc", split="validation[:100]")

    # list(...) so the "+" below works even where column access returns a
    # lazy column object instead of a plain list (newer `datasets` releases).
    wiki_passages = list(wiki_ds["passage"])
    squad_passages = [ex["context"] for ex in squad_ds]
    # The TriviaQA texts are nested per example; a flat ex.get("wiki_context")
    # lookup finds nothing, which left this corpus out of the index entirely.
    trivia_passages = [txt for ex in trivia_ds for txt in _trivia_contexts(ex)]

    # dict.fromkeys removes exact duplicates while preserving first-seen order.
    all_passages = list(dict.fromkeys(wiki_passages + squad_passages + trivia_passages))

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    max_tokens = tokenizer.model_max_length
    chunks = []
    for p in all_passages:
        if len(tokenizer.tokenize(p)) > max_tokens:
            # NOTE(review): chunk_text windows by whitespace words, not tokens,
            # so a chunk of max_tokens words may still exceed the token limit.
            chunks.extend(chunk_text(p, max_tokens))
        else:
            chunks.append(p)

    print(f"[load_passages] total chunks: {len(chunks)}")
    with open(PCTX_PATH, "wb") as f:
        pickle.dump(chunks, f)
    return chunks
# ── 4. Build or load FAISS ──
def load_faiss_index(passages):
    """Load the embedder and re-ranker, then load or build the FAISS index.

    A cached index (INDEX_PATH, with EMB_PATH alongside it) is reused only if
    it holds exactly ``len(passages)`` vectors. Otherwise the passages are
    re-encoded, because a stale index would map search hits to the wrong
    passages. Freshly built passage vectors are L2-normalised and stored in
    an ``IndexFlatIP``, so search scores are inner-product similarities.

    Args:
        passages: The passage strings; vector id i corresponds to passages[i].

    Returns:
        (embedder, reranker, index)

    Raises:
        ValueError: If *passages* is empty (nothing to index).

    Side effects:
        When (re)building, writes INDEX_PATH and EMB_PATH.
    """
    if not passages:
        raise ValueError("load_faiss_index: no passages to index")

    embedder = SentenceTransformer(EMBEDDER_MODEL)
    reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

    index = None
    if os.path.exists(INDEX_PATH) and os.path.exists(EMB_PATH):
        print("Loading cached FAISS index…")
        index = faiss.read_index(INDEX_PATH)
        if index.ntotal != len(passages):
            print(
                f"Cached index has {index.ntotal} vectors but there are "
                f"{len(passages)} passages; rebuilding…"
            )
            index = None

    if index is None:
        print("Encoding passages & building FAISS index…")
        embeddings = embedder.encode(
            passages,
            show_progress_bar=True,
            convert_to_numpy=True,
            batch_size=32,
        )
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
        # Clamp the norm so an all-zero row stays zero instead of becoming NaN;
        # FAISS expects contiguous float32 input.
        embeddings = np.ascontiguousarray(embeddings / np.maximum(norms, 1e-12), dtype="float32")
        index = faiss.IndexFlatIP(embeddings.shape[1])
        index.add(embeddings)
        faiss.write_index(index, INDEX_PATH)
        np.save(EMB_PATH, embeddings)

    return embedder, reranker, index
# ── 5. Initialize RAG components ──
def setup_rag():
    """Assemble every component the QA app needs.

    Passages come from the pickle cache at PCTX_PATH when that file exists,
    and from ``load_passages()`` otherwise. The embedder, re-ranker and FAISS
    index come from ``load_faiss_index``. The generator is a
    "text2text-generation" pipeline around MODEL_NAME on DEVICE.

    Returns:
        (passages, embedder, reranker, index, qa_pipe)
    """
    if not os.path.exists(PCTX_PATH):
        passages = load_passages()
    else:
        # NOTE(review): unpickling is acceptable only because this app wrote
        # the file itself; never point PCTX_PATH at untrusted data.
        with open(PCTX_PATH, "rb") as cache:
            passages = pickle.load(cache)

    embedder, reranker, index = load_faiss_index(passages)

    seq2seq_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    seq2seq_model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
    # Generation settings forwarded through the pipeline constructor.
    generation_options = dict(
        truncation=True,
        max_length=512,
        num_beams=4,
        early_stopping=True,
    )
    qa_pipe = hf_pipeline(
        "text2text-generation",
        model=seq2seq_model,
        tokenizer=seq2seq_tokenizer,
        device=DEVICE,
        **generation_options,
    )
    return passages, embedder, reranker, index, qa_pipe
# ── 6. Retrieval & generation ──
def retrieve(question, passages, embedder, reranker, index, k=20, rerank_k=5):
    """Dense-retrieve up to *k* passages for *question*, then re-rank them.

    The query embedding is L2-normalised before the search, matching the
    normalised passage vectors that the index builder in this file stores.
    With an inner-product index the scores are then cosine similarities.

    Args:
        question: Query text.
        passages: Passage list whose positions match the index's vector ids.
        embedder: Object whose ``encode([question], convert_to_numpy=True)``
            returns one embedding row per input text.
        reranker: Object whose ``predict(pairs)`` scores ``[question, passage]`` pairs.
        index: FAISS index, searched with ``index.search(query, k)``.
        k: Number of dense candidates to fetch.
        rerank_k: Number of re-ranked candidates to keep.

    Returns:
        (contexts, scores): up to *rerank_k* passages in descending
        cross-encoder order, plus each one's FAISS score as a float (higher
        means more similar). Both lists are empty when the search yields no
        valid hits or ``rerank_k <= 0``.
    """
    query = np.asarray(embedder.encode([question], convert_to_numpy=True), dtype="float32")
    norms = np.linalg.norm(query, axis=1, keepdims=True)
    # Clamp the norm so a zero vector cannot produce NaN scores.
    query = np.ascontiguousarray(query / np.maximum(norms, 1e-12), dtype="float32")
    scores, ids = index.search(query, k)

    # FAISS pads missing results with id -1 (e.g. when k exceeds the index
    # size); passages[-1] would silently return the last passage, so drop them.
    hits = [(int(i), float(s)) for i, s in zip(ids[0], scores[0]) if i >= 0]
    if not hits or rerank_k <= 0:
        # Explicit check: slicing with [-0:] would keep every candidate.
        return [], []

    candidates = [passages[i] for i, _ in hits]
    rr_scores = np.asarray(reranker.predict([[question, c] for c in candidates]))
    order = np.argsort(rr_scores)[::-1][:rerank_k]
    return [candidates[j] for j in order], [hits[j][1] for j in order]
def generate(question, contexts, qa_pipe):
    """Answer *question* with the seq2seq pipeline, grounded on *contexts*.

    Builds an instruction prompt that lists each context, truncated by
    ``make_context_snippets``, as "Context i: ...", followed by the question.
    Returns the pipeline's first generated text with surrounding whitespace
    stripped.
    """
    instructions = (
        "You are a helpful assistant. Use ONLY the following contexts to answer. "
        "If the answer is not contained, say 'Sorry, I don't know.'\n\n"
    )
    numbered_contexts = "\n".join(
        "Context {}: {}".format(position, snippet)
        for position, snippet in enumerate(make_context_snippets(contexts), start=1)
    )
    prompt = f"{instructions}{numbered_contexts}\n\nQuestion: {question}\nAnswer:"
    outputs = qa_pipe(prompt)
    return outputs[0]["generated_text"].strip()
def retrieve_and_answer(question, passages, embedder, reranker, index, qa_pipe):
    """Retrieve contexts for *question* and answer from them, or abstain.

    ``retrieve`` returns each context's FAISS score. The index built in this
    file is an ``IndexFlatIP`` over L2-normalised passage vectors, so that
    score is an inner-product similarity: higher means closer. The score is
    converted to a distance (``1 - score``) before comparison with
    DIST_THRESHOLD. With the default threshold of 1.0, the function abstains
    only when the top re-ranked context's score is negative.

    Returns:
        (answer, contexts): the generated answer and the contexts used, or
        ("Sorry, I don't know.", []) when retrieval found nothing close enough.
    """
    contexts, scores = retrieve(question, passages, embedder, reranker, index)
    if not contexts:
        return "Sorry, I don't know.", []
    # Threshold a distance, not the raw similarity: testing `score > threshold`
    # would abstain only on near-perfect matches and never on unrelated ones.
    distance = 1.0 - float(scores[0])
    if distance > DIST_THRESHOLD:
        return "Sorry, I don't know.", []
    return generate(question, contexts, qa_pipe), contexts
def answer_and_contexts(question, passages, embedder, reranker, index, qa_pipe):
    """Gradio adapter: return the answer and a printable view of its contexts.

    Each context is truncated by ``make_context_snippets`` and labelled
    "Context i: ...". The entries are joined with a ``---`` separator line
    surrounded by blank lines. When ``retrieve_and_answer`` returns no
    contexts, the second value is the empty string.
    """
    answer, used_contexts = retrieve_and_answer(
        question, passages, embedder, reranker, index, qa_pipe
    )
    if used_contexts:
        labelled = (
            "Context {}: {}".format(position, text)
            for position, text in enumerate(make_context_snippets(used_contexts), 1)
        )
        return answer, "\n\n---\n\n".join(labelled)
    return answer, ""
# ── 7. Evaluation routines ──
def retrieval_recall(dataset, passages, embedder, reranker, index, k=20, rerank_k=None, num_samples=100):
    """Report how often a gold answer string appears among the retrieved contexts.

    Iterates over the first *num_samples* examples, clamped to the dataset
    size. An example counts as a hit when any of its ``answers["text"]``
    strings is a substring of any retrieved context. Examples with no gold
    answers can never hit, so they count as misses.

    If *rerank_k* is truthy, contexts come from ``retrieve`` (dense top-*k*,
    then cross-encoder top-*rerank_k*). Otherwise the raw dense top-*k* is
    used, with the query L2-normalised to match the passage vectors the
    index is built from.

    Returns:
        Recall in [0, 1], or 0.0 when no examples are evaluated.
    """
    n = min(num_samples, len(dataset))
    hits = 0
    for ex in dataset.select(range(n)):
        question = ex["question"]
        gold_answers = ex["answers"]["text"]
        if rerank_k:
            ctxs, _ = retrieve(question, passages, embedder, reranker, index, k=k, rerank_k=rerank_k)
        else:
            q_emb = np.asarray(embedder.encode([question], convert_to_numpy=True), dtype="float32")
            norms = np.linalg.norm(q_emb, axis=1, keepdims=True)
            q_emb = np.ascontiguousarray(q_emb / np.maximum(norms, 1e-12), dtype="float32")
            _, idxs = index.search(q_emb, k)
            # FAISS pads with -1 when the index holds fewer than k vectors.
            ctxs = [passages[i] for i in idxs[0] if i >= 0]
        if any(ans in ctx for ans in gold_answers for ctx in ctxs):
            hits += 1
    recall = hits / n if n else 0.0
    print(f"Retrieval Recall@{k} (rerank_k={rerank_k}): {recall:.3f} ({hits}/{n})")
    return recall
def retrieval_recall_answerable(dataset, passages, embedder, reranker, index, k=20, rerank_k=None, num_samples=100):
    """Like a plain recall check, but only over examples that have a gold answer.

    Scans the first *num_samples* examples (clamped to the dataset size) and
    skips those whose ``answers["text"]`` is empty. The skipped examples are
    excluded from the denominator. A kept example is a hit when any gold
    answer string is a substring of any retrieved context.

    If *rerank_k* is truthy, contexts come from ``retrieve``. Otherwise the
    raw dense top-*k* is used, with the query L2-normalised to match the
    passage vectors the index is built from.

    Returns:
        Recall over answerable examples, or 0.0 when there are none.
    """
    n = min(num_samples, len(dataset))
    hits, total = 0, 0
    for ex in dataset.select(range(n)):
        gold = ex["answers"]["text"]
        if not gold:
            continue  # no gold answer: not part of this metric
        total += 1
        question = ex["question"]
        if rerank_k:
            ctxs, _ = retrieve(question, passages, embedder, reranker, index, k=k, rerank_k=rerank_k)
        else:
            q_emb = np.asarray(embedder.encode([question], convert_to_numpy=True), dtype="float32")
            norms = np.linalg.norm(q_emb, axis=1, keepdims=True)
            q_emb = np.ascontiguousarray(q_emb / np.maximum(norms, 1e-12), dtype="float32")
            _, idxs = index.search(q_emb, k)
            # FAISS pads with -1 when the index holds fewer than k vectors.
            ctxs = [passages[i] for i in idxs[0] if i >= 0]
        if any(ans in ctx for ans in gold for ctx in ctxs):
            hits += 1
    recall = hits / total if total > 0 else 0.0
    print(f"Retrieval Recall@{k} on answerable only (rerank_k={rerank_k}): {recall:.3f} ({hits}/{total})")
    return recall
def qa_eval_answerable(dataset, passages, embedder, reranker, index, qa_pipe, k=20, num_samples=100):
    """Score end-to-end answers with the "squad" EM/F1 metric on answerable examples.

    Runs ``retrieve_and_answer`` on each of the first *num_samples* examples
    (clamped to the dataset size) that has at least one gold answer.
    Abstentions ("Sorry, I don't know.") are scored like any other prediction.

    Args:
        k: Accepted for call-site compatibility but NOT used.
            ``retrieve_and_answer`` calls ``retrieve`` with its own default depth.

    Returns:
        The metric's result dict (``exact_match`` and ``f1`` among its keys),
        or ``{"exact_match": 0.0, "f1": 0.0}`` when the sample contains no
        answerable examples.
    """
    squad_metric = evaluate.load("squad")
    n = min(num_samples, len(dataset))
    preds, refs = [], []
    for ex in dataset.select(range(n)):
        if not ex["answers"]["text"]:
            continue  # no gold answer: not scored here
        qid = ex["id"]
        answer, _ = retrieve_and_answer(ex["question"], passages, embedder, reranker, index, qa_pipe)
        preds.append({"id": qid, "prediction_text": answer})
        refs.append({"id": qid, "answers": ex["answers"]})
    if not preds:
        # Nothing to average over; don't hand the metric an empty sample.
        print("Answerable-only QA: no answerable examples in sample, skipping.")
        return {"exact_match": 0.0, "f1": 0.0}
    results = squad_metric.compute(predictions=preds, references=refs)
    print(f"Answerable-only QA EM: {results['exact_match']:.2f}, F1: {results['f1']:.2f}")
    return results
# ── 8. Main entry ──
def main():
    """Command-line entry point: run the SQuAD v2 evaluation or launch the Gradio demo.

    Arguments are parsed before any model or index is loaded, so ``--help``
    and bad flags return immediately instead of after every model has been
    downloaded and the passages encoded.
    """
    import inspect

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--eval", action="store_true",
        help="Run retrieval/QA evaluations on SQuAD instead of launching the demo",
    )
    args = parser.parse_args()

    passages, embedder, reranker, index, qa_pipe = setup_rag()

    if args.eval:
        squad = load_dataset("rajpurkar/squad_v2", split="validation")
        retrieval_recall(squad, passages, embedder, reranker, index, k=20, rerank_k=5, num_samples=100)
        retrieval_recall_answerable(squad, passages, embedder, reranker, index, k=20, rerank_k=5, num_samples=100)
        qa_eval_answerable(squad, passages, embedder, reranker, index, qa_pipe, k=20, num_samples=100)
        return

    # Newer Gradio releases take `flagging_mode`, older ones `allow_flagging`.
    # Pick whichever keyword the installed gr.Interface actually declares.
    interface_params = inspect.signature(gr.Interface.__init__).parameters
    flag_kwarg = "flagging_mode" if "flagging_mode" in interface_params else "allow_flagging"

    demo = gr.Interface(
        fn=lambda q: answer_and_contexts(q, passages, embedder, reranker, index, qa_pipe),
        inputs=gr.Textbox(lines=1, placeholder="Ask me anything…", label="Question"),
        outputs=[gr.Textbox(label="Answer"), gr.Textbox(label="Contexts")],
        # NOTE(review): the leading glyph looks like a mis-encoded emoji; restore the intended one.
        title="π RAG QA Demo",
        description="Retrieval-Augmented QA with threshold and context preview",
        examples=[
            "When was Abraham Lincoln inaugurated?",
            "What is the capital of France?",
            "Who wrote '1984'?",
        ],
        **{flag_kwarg: "never"},
    )
    demo.launch()


if __name__ == "__main__":
    main()