# ks-version-1-1/backend/semantic_search.py
# Search engine: supports semantic search (SBERT + FAISS) and keyword search (BM25)
import os
import json
import re
import numpy as np
import pandas as pd
import faiss
import torch  # for GPU/CPU auto-detect
from sentence_transformers import SentenceTransformer
from config import VIDEO_METADATA, SEARCH_CONFIG
# For BM25 keyword ranking
from rank_bm25 import BM25Okapi
import nltk
# No downloads at import time in production; ensure 'punkt' is installed in the image
from nltk.tokenize import word_tokenize
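
# Fail fast if the tokenizer data is missing, rather than downloading at runtime.
# Minimal guard sketch; note that newer NLTK releases look up 'punkt_tab' for
# word_tokenize, so adjust the resource name to match the pinned NLTK version.
try:
    nltk.data.find("tokenizers/punkt")
except LookupError as exc:
    raise RuntimeError(
        "NLTK 'punkt' tokenizer data not found; install it in the image, "
        "e.g. `python -m nltk.downloader punkt`."
    ) from exc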
# Auto-select device (GPU on the server, CPU locally)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Paths (resolved relative to this file, so the app can start from any working directory)
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
INDEX_PATH = os.path.join(BASE_DIR, "../data/embeddings/faiss.index")
METADATA_PATH = os.path.join(BASE_DIR, "../data/metadata.csv")
# Load model + indexes
MODEL_NAME = SEARCH_CONFIG.get("embedding_model", "all-MiniLM-L6-v2")
model = SentenceTransformer(MODEL_NAME, device=DEVICE)  # uses the GPU when available
faiss_index = faiss.read_index(INDEX_PATH)
metadata_df = pd.read_csv(METADATA_PATH)
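# Expected metadata.csv schema (inferred from usage below): a 'lines' column
# holding a JSON-encoded list of {"text", "timestamp", "video_id"} entries,
# plus a 'text' column with the chunk-level text used as summary input.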
# Build the BM25 index: each subtitle line becomes one BM25 document
bm25_corpus = []
bm25_metadata = []
for _, row in metadata_df.iterrows():
    lines_raw = json.loads(row["lines"]) if isinstance(row["lines"], str) else row["lines"]
    if not isinstance(lines_raw, list) or not lines_raw:
        continue  # skip rows with missing or malformed subtitle lines
    for i, line in enumerate(lines_raw):
        bm25_corpus.append(word_tokenize(line["text"].lower()))
        bm25_metadata.append({
            "text": line["text"].strip(),
            "timestamp": line["timestamp"],
            "video_id": line["video_id"],
            "context_before": lines_raw[i - 1]["text"].strip() if i > 0 else "",
            "context_after": lines_raw[i + 1]["text"].strip() if i + 1 < len(lines_raw) else "",
            "summary_input": row["text"]
        })
bm25_index = BM25Okapi(bm25_corpus)
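# rank_bm25's BM25Okapi defaults apply here (k1=1.5, b=0.75); pass them
# explicitly above if line-level ranking ever needs tuning.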
# Search function
def search_query(query, offset=0, top_k=SEARCH_CONFIG.get("results_per_page", 5), semantic_mode=True):
"""
Search:
- Semantic mode β†’ SBERT + FAISS + similarity threshold.
- Keyword mode β†’ BM25 ranking over all subtitle lines.
"""
    if semantic_mode:
        query_vector = model.encode([query])
        faiss_top_k = SEARCH_CONFIG.get("faiss_top_k", 100)
        semantic_threshold = SEARCH_CONFIG.get("semantic_threshold", 0.40)
        semantic_top_n = SEARCH_CONFIG.get("semantic_top_n", 4)
        # Stage 1: chunk-level retrieval with FAISS (scores are distances: smaller = better)
        D, I = faiss_index.search(np.array(query_vector), faiss_top_k)
        all_hits_with_scores = []
        for idx, score in zip(I[0], D[0]):
            if idx < 0:
                continue  # FAISS pads with -1 when it returns fewer than faiss_top_k hits
            current = metadata_df.iloc[idx]
            lines_raw = json.loads(current["lines"]) if isinstance(current["lines"], str) else current["lines"]
            if not isinstance(lines_raw, list) or not lines_raw:
                continue
            # Stage 2: encode all lines in this chunk and re-score them against the query
            line_texts = [line["text"] for line in lines_raw]
            line_vectors = model.encode(line_texts)
            query_vec = query_vector[0]
            # Cosine similarity between the query and every line in the chunk
            similarities = np.dot(line_vectors, query_vec) / (
                np.linalg.norm(line_vectors, axis=1) * np.linalg.norm(query_vec)
            )
            # Keep only lines above the threshold, best-first, capped at semantic_top_n
            line_indices = [i for i, sim in enumerate(similarities) if sim >= semantic_threshold]
            line_indices.sort(key=lambda i: similarities[i], reverse=True)
            line_indices = line_indices[:semantic_top_n]
            for i in line_indices:
                match_text = lines_raw[i]["text"]
                match_time = lines_raw[i]["timestamp"]
                video_id = lines_raw[i]["video_id"]
                # Exact keyword matches get a small boost (lower distance), computed
                # per line so the adjustment does not accumulate across the chunk
                line_score = score
                if re.search(re.escape(query), match_text, re.IGNORECASE):
                    line_score -= 0.05
                friendly_key = next((k for k, v in VIDEO_METADATA.items() if v["id"] == video_id), None)
                video_title = VIDEO_METADATA[friendly_key]["title"] if friendly_key else "Unknown Video"
                before = lines_raw[i - 1]["text"] if i > 0 else ""
                after = lines_raw[i + 1]["text"] if i + 1 < len(lines_raw) else ""
                summary_block = current["text"]
                all_hits_with_scores.append((
                    line_score,
                    {
                        "text": match_text.strip(),
                        "context_before": before.strip(),
                        "context_after": after.strip(),
                        "summary_input": summary_block,
                        "timestamp": match_time,
                        "video_id": video_id,
                        "video_title": video_title
                    }
                ))
        # Distances sort ascending: the smallest distance is the best match
        all_hits_with_scores.sort(key=lambda x: x[0])
        sorted_hits = [hit for _, hit in all_hits_with_scores]
        return sorted_hits[offset:offset + top_k], len(sorted_hits)
    else:
        # Keyword mode: BM25 ranking over individual subtitle lines
        tokenized_query = word_tokenize(query.lower())
        scores = bm25_index.get_scores(tokenized_query)
        sorted_indices = np.argsort(scores)[::-1]  # already best-first
        sorted_hits = []
        for idx in sorted_indices:
            if scores[idx] <= 0:
                break  # scores are sorted, so the rest are non-positive too
            # Copy the record so the per-query title never mutates the shared index metadata
            r = dict(bm25_metadata[idx])
            friendly_key = next((k for k, v in VIDEO_METADATA.items() if v["id"] == r["video_id"]), None)
            r["video_title"] = VIDEO_METADATA[friendly_key]["title"] if friendly_key else "Unknown Video"
            sorted_hits.append(r)
        return sorted_hits[offset:offset + top_k], len(sorted_hits)
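
# Minimal smoke test (illustrative sketch; the query string is a placeholder).
if __name__ == "__main__":
    hits, total = search_query("example query", top_k=3)
    print(f"{total} hits; showing top {len(hits)}:")
    for h in hits:
        print(f'[{h["video_title"]} @ {h["timestamp"]}] {h["text"]}')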