# Search engine: supports semantic search (SBERT + FAISS) and keyword search (BM25)
import os
import json
import re
import numpy as np
import pandas as pd
import faiss
import torch  # for GPU/CPU auto-detect
from sentence_transformers import SentenceTransformer
from config import VIDEO_METADATA, SEARCH_CONFIG
# For BM25 keyword ranking
from rank_bm25 import BM25Okapi
import nltk
# No downloads at import time in production; ensure 'punkt' is installed in the image
from nltk.tokenize import word_tokenize
# Auto-select device (GPU on server, CPU locally)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Paths
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
INDEX_PATH = os.path.join(BASE_DIR, "../data/embeddings/faiss.index")
METADATA_PATH = os.path.join(BASE_DIR, "../data/metadata.csv")
# Load model + indexes
MODEL_NAME = SEARCH_CONFIG.get("embedding_model", "all-MiniLM-L6-v2")
model = SentenceTransformer(MODEL_NAME, device=DEVICE)  # now uses GPU if available
faiss_index = faiss.read_index(INDEX_PATH)
metadata_df = pd.read_csv(METADATA_PATH)
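
# Light consistency guard (an assumption: the index was built with exactly one
# vector per metadata.csv row, which metadata_df.iloc[idx] below relies on)
if faiss_index.ntotal != len(metadata_df):
    raise RuntimeError(
        f"FAISS/metadata mismatch: {faiss_index.ntotal} vectors vs {len(metadata_df)} rows"
    )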
# Build BM25 index
bm25_corpus = []
bm25_metadata = []
for _, row in metadata_df.iterrows():
    lines_raw = json.loads(row["lines"]) if isinstance(row["lines"], str) else row["lines"]
    if not isinstance(lines_raw, list) or not lines_raw:  # guard against NaN/empty cells from read_csv
        continue
for i, line in enumerate(lines_raw):
bm25_corpus.append(word_tokenize(line["text"].lower()))
bm25_metadata.append({
"text": line["text"].strip(),
"timestamp": line["timestamp"],
"video_id": line["video_id"],
"context_before": lines_raw[i - 1]["text"].strip() if i > 0 else "",
"context_after": lines_raw[i + 1]["text"].strip() if i + 1 < len(lines_raw) else "",
"summary_input": row["text"]
})
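# BM25Okapi takes a pre-tokenized corpus (list of token lists); get_scores()
# later returns one score per entry, positionally aligned with bm25_metadata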
bm25_index = BM25Okapi(bm25_corpus)
# Search function
def search_query(query, offset=0, top_k=SEARCH_CONFIG.get("results_per_page", 5), semantic_mode=True):
"""
Search:
- Semantic mode β SBERT + FAISS + similarity threshold.
- Keyword mode β BM25 ranking over all subtitle lines.
"""
if semantic_mode:
query_vector = model.encode([query])
faiss_top_k = SEARCH_CONFIG.get("faiss_top_k", 100)
semantic_threshold = SEARCH_CONFIG.get("semantic_threshold", 0.40)
semantic_top_n = SEARCH_CONFIG.get("semantic_top_n", 4)
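        # faiss_top_k: candidate chunks fetched from FAISS; semantic_threshold: minimum
        # per-line cosine similarity to keep a hit; semantic_top_n: best lines kept per chunk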
# Semantic search with FAISS
D, I = faiss_index.search(np.array(query_vector), faiss_top_k)
all_hits_with_scores = []
        for idx, score in zip(I[0], D[0]):
            if idx < 0:  # FAISS pads results with -1 when fewer than faiss_top_k vectors match
                continue
            current = metadata_df.iloc[idx]
lines_raw = json.loads(current["lines"]) if isinstance(current["lines"], str) else current["lines"]
            if not isinstance(lines_raw, list) or not lines_raw:  # guard against NaN/empty cells
                continue
# Encode all lines in this chunk
line_texts = [line["text"] for line in lines_raw]
line_vectors = model.encode(line_texts)
query_vec = query_vector[0]
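            # Cosine similarity of every line vector against the query vector:
            # dot(a, b) / (|a| * |b|), where higher means semantically closer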
similarities = np.dot(line_vectors, query_vec) / (
np.linalg.norm(line_vectors, axis=1) * np.linalg.norm(query_vec)
)
line_indices = [i for i, sim in enumerate(similarities) if sim >= semantic_threshold]
line_indices.sort(key=lambda i: similarities[i], reverse=True)
line_indices = line_indices[:semantic_top_n]
for i in line_indices:
match_text = lines_raw[i]["text"]
match_time = lines_raw[i]["timestamp"]
video_id = lines_raw[i]["video_id"]
                # Nudge the score (lower L2 distance = better) for literal keyword matches,
                # without mutating the chunk-level score shared by sibling lines
                hit_score = score - 0.05 if re.search(re.escape(query), match_text, re.IGNORECASE) else score
friendly_key = next((k for k, v in VIDEO_METADATA.items() if v["id"] == video_id), None)
video_title = VIDEO_METADATA[friendly_key]["title"] if friendly_key else "Unknown Video"
before = lines_raw[i - 1]["text"] if i > 0 else ""
after = lines_raw[i + 1]["text"] if i + 1 < len(lines_raw) else ""
summary_block = current["text"]
all_hits_with_scores.append((
                    hit_score,
{
"text": match_text.strip(),
"context_before": before.strip(),
"context_after": after.strip(),
"summary_input": summary_block,
"timestamp": match_time,
"video_id": video_id,
"video_title": video_title
}
))
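        # Ascending sort: assumes an L2 (distance) FAISS index, so a smaller score is a better match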
all_hits_with_scores.sort(key=lambda x: x[0])
sorted_hits = [hit for _, hit in all_hits_with_scores]
return sorted_hits[offset:offset + top_k], len(sorted_hits)
else:
# Keyword mode: BM25
tokenized_query = word_tokenize(query.lower())
scores = bm25_index.get_scores(tokenized_query)
sorted_indices = np.argsort(scores)[::-1]
all_hits_with_scores = []
for idx in sorted_indices:
if scores[idx] <= 0:
continue
            r = dict(bm25_metadata[idx])  # copy so repeated queries don't mutate the shared index entry
video_id = r["video_id"]
friendly_key = next((k for k, v in VIDEO_METADATA.items() if v["id"] == video_id), None)
video_title = VIDEO_METADATA[friendly_key]["title"] if friendly_key else "Unknown Video"
r["video_title"] = video_title
all_hits_with_scores.append((scores[idx], r))
        # argsort already yielded descending-score order, so no re-sort is needed
        sorted_hits = [hit for _, hit in all_hits_with_scores]
return sorted_hits[offset:offset + top_k], len(sorted_hits)
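

# Manual smoke test, a sketch only: assumes the FAISS index and metadata.csv
# already exist at the paths above, and uses purely illustrative example
# queries. Not part of the Space's serving path.
if __name__ == "__main__":
    hits, total = search_query("resetting a password", semantic_mode=True)
    print(f"semantic: {total} hits")
    for h in hits:
        print(f"  [{h['video_title']} @ {h['timestamp']}] {h['text']}")

    hits, total = search_query("password reset", semantic_mode=False)
    print(f"keyword (BM25): {total} hits")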