import json
import math
from datetime import datetime

import chromadb
from youtube_transcript_api import YouTubeTranscriptApi

from utils.embedding_utils import MyEmbeddingFunction
from utils.general_utils import timeit

def run_etl(json_path="data/videos.json", db=None, batch_size=None, overlap=None):
    """Fetch transcripts for the videos listed in json_path, chunk them into
    (optionally overlapping) segments, and load them into the ChromaDB instance
    at db. If no database path is given, the segments are only printed."""
    with open(json_path) as f:
        video_info = json.load(f)

    videos = []
    for video in video_info:
        video_id = video["id"]
        video_title = video["title"]
        transcript = get_video_transcript(video_id)
        print(f"Transcript for video {video_id} fetched.")
        if transcript:
            formatted_transcript = format_transcript(
                transcript, video_id, video_title, batch_size=batch_size, overlap=overlap
            )
            videos.extend(formatted_transcript)

    if db:
        initialize_db(db)
        load_data_to_db(db, videos)
        log_data_load(json_path, db, batch_size, overlap)
    else:
        print("No database specified. Skipping database load.")
        print(videos)

def get_video_transcript(video_id):
    try:
        transcript = YouTubeTranscriptApi.get_transcript(video_id, languages=['en', 'en-US'])
        return transcript
    except Exception as e:
        print(f"Error fetching transcript for video {video_id}: {str(e)}")
        return None
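
# For reference, YouTubeTranscriptApi.get_transcript returns a list of dicts, one per
# caption line, each with "text", "start", and "duration" keys. The values below are
# made up, purely to illustrate the shape that format_transcript consumes:
#   [{"text": "Welcome to the podcast", "start": 0.0, "duration": 3.2}, ...]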

def format_transcript(transcript, video_id, video_title, batch_size=None, overlap=None):
    formatted_data = []
    base_url = f"https://www.youtube.com/watch?v={video_id}"
    query_params = "&t={start}s"

    # Default to one transcript entry per segment with no overlap, and guard against
    # an overlap that would make the window step zero or negative.
    if not batch_size:
        batch_size = 1
    if not overlap or overlap >= batch_size:
        overlap = 0

    for i in range(0, len(transcript), batch_size - overlap):
        batch = list(transcript[i:i + batch_size])
        start_time = batch[0]["start"]
        text = " ".join(entry["text"] for entry in batch)
        url = base_url + query_params.format(start=start_time)
        metadata = {
            "video_id": video_id,
            "segment_id": video_id + "__" + str(i),
            "title": video_title,
            "source": url
        }
        segment = {"text": text, "metadata": metadata}
        formatted_data.append(segment)
    return formatted_data
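
# Worked example (illustrative): with batch_size=3 and overlap=1 the window starts are
# range(0, len(transcript), 2), so a 5-entry transcript yields segments covering entries
# [0:3], [2:5], and [4:5]; consecutive segments share one entry, and each segment's
# source URL points at the start time of its first entry.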

# Shared embedding function, used both when the Chroma collection is created and
# when data is loaded into it.
embed_text = MyEmbeddingFunction()


def initialize_db(db_path, distance_metric="cosine"):
    client = chromadb.PersistentClient(path=db_path)
    # Clear existing data
    # client.reset()
    client.create_collection(
        name="huberman_videos",
        embedding_function=embed_text,
        metadata={"hnsw:space": distance_metric}
    )
    print(f"Database created at {db_path}")

def load_data_to_db(db_path, data):
    client = chromadb.PersistentClient(path=db_path)
    # Pass the same embedding function the collection was created with, so new
    # documents are embedded consistently.
    collection = client.get_collection("huberman_videos", embedding_function=embed_text)

    num_rows = len(data)
    # Chroma limits how many records a single add() call may contain, so the
    # segments are inserted in batches.
    batch_size = 5461
    num_batches = math.ceil(num_rows / batch_size)

    for i in range(num_batches):
        batch_data = data[i * batch_size : (i + 1) * batch_size]
        documents = [segment['text'] for segment in batch_data]
        metadata = [segment['metadata'] for segment in batch_data]
        ids = [segment['metadata']['segment_id'] for segment in batch_data]
        collection.add(
            documents=documents,
            metadatas=metadata,
            ids=ids
        )
        print(f"Batch {i+1} of {num_batches} loaded to database.")
    print(f"Data loaded to database at {db_path}.")

def log_data_load(json_path, db_path, batch_size, overlap):
    log_json = json.dumps({
        "videos_info_path": json_path,
        "db_path": db_path,
        "batch_size": batch_size,
        "overlap": overlap,
        "load_time": str(datetime.now())
    })

    db_file = db_path.split("/")[-1]
    db_name = db_file.split(".")[0]
    log_path = f"data/logs/{db_name}_load_log.json"
    with open(log_path, "w") as f:
        f.write(log_json)
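
# --- Usage sketch (not from the original file) ---
# A minimal way to run the pipeline end to end. The database path and window settings
# below are assumptions for illustration, not values taken from this repo, and the
# data/logs/ directory must already exist for log_data_load to write its log.
if __name__ == "__main__":
    run_etl(
        json_path="data/videos.json",  # same default the function already uses
        db="data/chroma_db",           # hypothetical ChromaDB persistence directory
        batch_size=5,                  # hypothetical window of 5 caption lines
        overlap=1,                     # hypothetical one-line overlap between windows
    )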