Update app.py
app.py CHANGED
@@ -3,6 +3,7 @@

 import os
 import pickle
+import argparse
 import faiss
 import numpy as np
 import torch
@@ -15,6 +16,7 @@ from transformers import (
     AutoModelForSeq2SeqLM,
     pipeline as hf_pipeline,
 )
+import evaluate

 # ── 1. Configuration ──
 DATA_DIR = os.path.join(os.getcwd(), "data")
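Note: the new evaluate dependency is used by the SQuAD metric added further down. For reference, evaluate.load("squad") expects predictions and references in the dataset's native shape; a toy example (values are illustrative, not from this app):

    import evaluate

    squad_metric = evaluate.load("squad")
    preds = [{"id": "q1", "prediction_text": "Paris"}]
    refs = [{"id": "q1", "answers": {"text": ["Paris"], "answer_start": [0]}}]
    print(squad_metric.compute(predictions=preds, references=refs))
    # -> {'exact_match': 100.0, 'f1': 100.0}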
@@ -30,7 +32,6 @@ MAX_CTX_WORDS = int(os.getenv("MAX_CTX_WORDS", 200))
 DEVICE = 0 if torch.cuda.is_available() else -1
 os.makedirs(DATA_DIR, exist_ok=True)

-print(f"MODEL={MODEL_NAME}, EMBEDDER={EMBEDDER_MODEL}, DEVICE={'GPU' if DEVICE==0 else 'CPU'}")

 # ── 2. Helpers ──
 def make_context_snippets(contexts, max_words=MAX_CTX_WORDS):
@@ -53,15 +54,15 @@ def chunk_text(text, max_tokens, stride=None):
         start += stride
     return chunks

+
 # ── 3. Load & preprocess passages ──
 def load_passages():
-    # 3.1 load raw corpora
     wiki_ds = load_dataset("rag-datasets/rag-mini-wikipedia", "text-corpus", split="passages")
     squad_ds = load_dataset("rajpurkar/squad_v2", split="train[:100]")
     trivia_ds = load_dataset("mandarjoshi/trivia_qa", "rc", split="validation[:100]")

-    wiki_passages
-    squad_passages
+    wiki_passages = wiki_ds["passage"]
+    squad_passages = [ex["context"] for ex in squad_ds]
     trivia_passages = []
     for ex in trivia_ds:
         for fld in ("wiki_context", "search_context"):
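For context, only the tail of chunk_text is visible in the hunk above; the full helper is not part of this diff. A word-level sketch of the sliding-window idea it implements (the assumed 50% overlap default is not confirmed by the change):

    def chunk_text(text, max_tokens, stride=None):
        # Emit overlapping windows of at most max_tokens whitespace tokens.
        words = text.split()
        stride = stride or max_tokens // 2  # assumed default overlap
        chunks, start = [], 0
        while start < len(words):
            chunks.append(" ".join(words[start:start + max_tokens]))
            start += stride
        return chunks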
@@ -69,12 +70,10 @@ def load_passages():
             if txt:
                 trivia_passages.append(txt)

-    # dedupe
     all_passages = list(dict.fromkeys(wiki_passages + squad_passages + trivia_passages))
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    max_tokens = tokenizer.model_max_length

-    # chunk long passages
-    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
-    max_tokens = tokenizer.model_max_length
     chunks = []
     for p in all_passages:
         toks = tokenizer.tokenize(p)
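The dict.fromkeys call kept above is an order-preserving dedupe, which keeps the passage order stable before chunking and indexing. Quick illustration:

    >>> list(dict.fromkeys(["a", "b", "a", "c", "b"]))
    ['a', 'b', 'c']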
@@ -88,6 +87,7 @@ def load_passages():
         pickle.dump(chunks, f)
     return chunks

+
 # ── 4. Build or load FAISS ──
 def load_faiss_index(passages):
     embedder = SentenceTransformer(EMBEDDER_MODEL)
@@ -116,6 +116,7 @@ def load_faiss_index(passages):

     return embedder, reranker, index

+
 # ── 5. Initialize RAG components ──
 def setup_rag():
     if os.path.exists(PCTX_PATH):
@@ -141,8 +142,9 @@ def setup_rag():

     return passages, embedder, reranker, index, qa_pipe

+
 # ── 6. Retrieval & generation ──
-def retrieve(question, passages, embedder, index, k=20, rerank_k=5):
+def retrieve(question, passages, embedder, reranker, index, k=20, rerank_k=5):
     q_emb = embedder.encode([question], convert_to_numpy=True)
     distances, idxs = index.search(q_emb, k)

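retrieve now receives the reranker explicitly, matching the updated call sites below. The rerank step itself is outside this hunk; assuming reranker is a sentence_transformers CrossEncoder (not confirmed by the diff), the usual pattern looks roughly like this hypothetical helper:

    import numpy as np

    def rerank(question, candidates, reranker, rerank_k=5):
        # Score each (question, passage) pair and keep the top rerank_k passages.
        scores = reranker.predict([(question, p) for p in candidates])
        order = np.argsort(scores)[::-1][:rerank_k]
        return [candidates[i] for i in order]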
@@ -166,7 +168,7 @@ def generate(question, contexts, qa_pipe):
     return qa_pipe(prompt)[0]["generated_text"].strip()

 def retrieve_and_answer(question, passages, embedder, reranker, index, qa_pipe):
-    contexts, dists = retrieve(question, passages, embedder, index)
+    contexts, dists = retrieve(question, passages, embedder, reranker, index)
     if not contexts or dists[0] > DIST_THRESHOLD:
         return "Sorry, I don't know.", []
     return generate(question, contexts, qa_pipe), contexts
@@ -181,24 +183,99 @@ def answer_and_contexts(question, passages, embedder, reranker, index, qa_pipe):
     ]
     return ans, "\n\n---\n\n".join(snippets)

-
+
+# ── 7. Evaluation routines ──
+def retrieval_recall(dataset, passages, embedder, reranker, index, k=20, rerank_k=None, num_samples=100):
+    hits = 0
+    for ex in dataset.select(range(num_samples)):
+        question = ex["question"]
+        gold_answers = ex["answers"]["text"]
+
+        if rerank_k:
+            ctxs, _ = retrieve(question, passages, embedder, reranker, index, k=k, rerank_k=rerank_k)
+        else:
+            q_emb = embedder.encode([question], convert_to_numpy=True)
+            distances, idxs = index.search(q_emb, k)
+            ctxs = [passages[i] for i in idxs[0]]
+
+        if any(any(ans in ctx for ctx in ctxs) for ans in gold_answers):
+            hits += 1
+
+    recall = hits / num_samples
+    print(f"Retrieval Recall@{k} (rerank_k={rerank_k}): {recall:.3f} ({hits}/{num_samples})")
+    return recall
+
+def retrieval_recall_answerable(dataset, passages, embedder, reranker, index, k=20, rerank_k=None, num_samples=100):
+    hits, total = 0, 0
+    for ex in dataset.select(range(num_samples)):
+        gold = ex["answers"]["text"]
+        if not gold:
+            continue
+        total += 1
+        question = ex["question"]
+
+        if rerank_k:
+            ctxs, _ = retrieve(question, passages, embedder, reranker, index, k=k, rerank_k=rerank_k)
+        else:
+            q_emb = embedder.encode([question], convert_to_numpy=True)
+            distances, idxs = index.search(q_emb, k)
+            ctxs = [passages[i] for i in idxs[0]]
+
+        if any(any(ans in ctx for ctx in ctxs) for ans in gold):
+            hits += 1
+
+    recall = hits / total if total > 0 else 0.0
+    print(f"Retrieval Recall@{k} on answerable only (rerank_k={rerank_k}): {recall:.3f} ({hits}/{total})")
+    return recall
+
+def qa_eval_answerable(dataset, passages, embedder, reranker, index, qa_pipe, k=20, num_samples=100):
+    squad_metric = evaluate.load("squad")
+    preds, refs = [], []
+    for ex in dataset.select(range(num_samples)):
+        gold = ex["answers"]["text"]
+        if not gold:
+            continue
+        qid = ex["id"]
+        answer, _ = retrieve_and_answer(ex["question"], passages, embedder, reranker, index, qa_pipe)
+        preds.append({"id": qid, "prediction_text": answer})
+        refs.append({"id": qid, "answers": ex["answers"]})
+
+    results = squad_metric.compute(predictions=preds, references=refs)
+    print(f"Answerable-only QA EM: {results['exact_match']:.2f}, F1: {results['f1']:.2f}")
+    return results
+
+
+# ── 8. Main entry ──
 def main():
     passages, embedder, reranker, index, qa_pipe = setup_rag()

-
-
-
-
-        title="🔍 RAG QA Demo",
-        description="Retrieval-Augmented QA with threshold and context preview",
-        examples=[
-            "When was Abraham Lincoln inaugurated?",
-            "What is the capital of France?",
-            "Who wrote '1984'?"
-        ],
-        allow_flagging="never",
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--eval", action="store_true",
+        help="Run retrieval/QA evaluations on SQuAD instead of launching the demo"
     )
-
+    args = parser.parse_args()
+
+    if args.eval:
+        squad = load_dataset("rajpurkar/squad_v2", split="validation")
+        retrieval_recall(squad, passages, embedder, reranker, index, k=20, rerank_k=5, num_samples=100)
+        retrieval_recall_answerable(squad, passages, embedder, reranker, index, k=20, rerank_k=5, num_samples=100)
+        qa_eval_answerable(squad, passages, embedder, reranker, index, qa_pipe, k=20, num_samples=100)
+    else:
+        demo = gr.Interface(
+            fn=lambda q: answer_and_contexts(q, passages, embedder, reranker, index, qa_pipe),
+            inputs=gr.Textbox(lines=1, placeholder="Ask me anything…", label="Question"),
+            outputs=[gr.Textbox(label="Answer"), gr.Textbox(label="Contexts")],
+            title="🔍 RAG QA Demo",
+            description="Retrieval-Augmented QA with threshold and context preview",
+            examples=[
+                "When was Abraham Lincoln inaugurated?",
+                "What is the capital of France?",
+                "Who wrote '1984'?"
+            ],
+            allow_flagging="never",
+        )
+        demo.launch()

 if __name__ == "__main__":
     main()
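After this change app.py has two entry modes: python app.py launches the Gradio demo, and python app.py --eval runs the Recall@k and SQuAD EM/F1 checks on squad_v2 validation samples. A hypothetical smoke test of the plumbing (not part of app.py):

    # Assumes setup_rag() and answer_and_contexts() are importable, e.g. from app.py.
    passages, embedder, reranker, index, qa_pipe = setup_rag()
    answer, contexts = answer_and_contexts(
        "What is the capital of France?", passages, embedder, reranker, index, qa_pipe
    )
    print(answer)
    print(contexts)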