import json
import os
import re

import gradio as gr
import numpy as np
from langchain_google_genai import ChatGoogleGenerativeAI
from sentence_transformers import SentenceTransformer
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics.pairwise import cosine_distances
from transformers import AutoTokenizer

# Tokenizer used for token counting and splitting; a separate sentence model produces the embeddings
tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
sentence_model = SentenceTransformer('all-MiniLM-L6-v2')

# Maximum token budget for a single chunk sent to the LLM
max_tokens = 3000

def clean_text(text):
    # Remove diarization tags such as [speaker_1] and collapse whitespace
    text = re.sub(r'\[speaker_\d+\]', '', text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text

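# Illustrative only (hypothetical input): clean_text drops diarization tags and collapses whitespace,
# e.g. clean_text("[speaker_1] Hello   there. [speaker_2] Hi.") -> "Hello there. Hi."
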
def split_text_with_modernbert_tokenizer(text):
    text = clean_text(text)
    # Rough sentence split on terminal punctuation
    rough_splits = re.split(r'(?<=[.!?])\s+', text)
    segments = []
    current_segment = ""
    current_token_count = 0
    for sentence in rough_splits:
        if not sentence.strip():
            continue
        sentence_tokens = len(tokenizer.encode(sentence, add_special_tokens=False))
        # Start a new segment when the 100-token budget would be exceeded or the
        # current segment already ends in terminal punctuation
        if (current_token_count + sentence_tokens > 100 or
                re.search(r'[.!?]$', current_segment.strip())):
            if current_segment:
                segments.append(current_segment.strip())
            current_segment = sentence
            current_token_count = sentence_tokens
        else:
            current_segment += " " + sentence if current_segment else sentence
            current_token_count += sentence_tokens
    if current_segment:
        segments.append(current_segment.strip())
    # Refine: merge tiny fragments into the previous segment and split overly long segments
    refined_segments = []
    for segment in segments:
        if len(segment.split()) < 3:
            if refined_segments:
                refined_segments[-1] += ' ' + segment
            else:
                refined_segments.append(segment)
            continue
        tokens = tokenizer.tokenize(segment)
        if len(tokens) < 50:
            refined_segments.append(segment)
            continue
        # Candidate split points: tokens containing punctuation (excluding the final token)
        break_indices = [i for i, token in enumerate(tokens)
                         if ('.' in token or ',' in token or '?' in token or '!' in token)
                         and i < len(tokens) - 1]
        if not break_indices or break_indices[-1] < len(tokens) * 0.7:
            refined_segments.append(segment)
            continue
        # Split roughly in the middle, at a punctuation token
        mid_idx = break_indices[len(break_indices) // 2]
        first_half = tokenizer.convert_tokens_to_string(tokens[:mid_idx + 1])
        second_half = tokenizer.convert_tokens_to_string(tokens[mid_idx + 1:])
        refined_segments.append(first_half.strip())
        refined_segments.append(second_half.strip())
    return refined_segments

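# Sketch of the splitter's behaviour (raw_transcript is a hypothetical variable): it returns
# sentence-level segments of roughly <= 100 ModernBERT tokens, with very short fragments merged
# into their predecessor and overly long segments split near a mid-segment punctuation token.
#   segments = split_text_with_modernbert_tokenizer(raw_transcript)
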
def semantic_chunking(text):
    segments = split_text_with_modernbert_tokenizer(text)
    segment_embeddings = sentence_model.encode(segments)
    distances = cosine_distances(segment_embeddings)
    # Cluster segments by semantic similarity using the precomputed cosine-distance matrix
    agg_clustering = AgglomerativeClustering(
        n_clusters=None,
        distance_threshold=1,
        metric='precomputed',
        linkage='average'
    )
    clusters = agg_clustering.fit_predict(distances)
    # Group segments by cluster
    cluster_groups = {}
    for i, cluster_id in enumerate(clusters):
        if cluster_id not in cluster_groups:
            cluster_groups[cluster_id] = []
        cluster_groups[cluster_id].append(segments[i])
    # Pack each cluster's segments into chunks that respect the max_tokens budget
    chunks = []
    for cluster_id in sorted(cluster_groups.keys()):
        cluster_segments = cluster_groups[cluster_id]
        current_chunk = []
        current_token_count = 0
        for segment in cluster_segments:
            segment_tokens = len(tokenizer.encode(segment, truncation=True, add_special_tokens=True))
            if segment_tokens > max_tokens:
                # An oversized segment becomes its own chunk
                if current_chunk:
                    chunks.append(" ".join(current_chunk))
                    current_chunk = []
                    current_token_count = 0
                chunks.append(segment)
                continue
            if current_token_count + segment_tokens > max_tokens and current_chunk:
                chunks.append(" ".join(current_chunk))
                current_chunk = [segment]
                current_token_count = segment_tokens
            else:
                current_chunk.append(segment)
                current_token_count += segment_tokens
        if current_chunk:
            chunks.append(" ".join(current_chunk))
    # Merge adjacent chunks that are highly similar, as long as the result still fits the budget
    if len(chunks) > 1:
        chunk_embeddings = sentence_model.encode(chunks)
        chunk_similarities = 1 - cosine_distances(chunk_embeddings)
        i = 0
        while i < len(chunks) - 1:
            j = i + 1
            if chunk_similarities[i, j] > 0.75:
                combined = chunks[i] + " " + chunks[j]
                combined_tokens = len(tokenizer.encode(combined, truncation=True, add_special_tokens=True))
                if combined_tokens <= max_tokens:
                    # Merge chunks and recompute the similarity matrix
                    chunks[i] = combined
                    chunks.pop(j)
                    chunk_embeddings = sentence_model.encode(chunks)
                    chunk_similarities = 1 - cosine_distances(chunk_embeddings)
                else:
                    i += 1
            else:
                i += 1
    return chunks

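# Sketch (hypothetical usage): semantic_chunking clusters segments by embedding similarity
# (agglomerative clustering over cosine distances), packs each cluster into chunks of at most
# max_tokens tokens, then greedily merges adjacent chunks whose similarity exceeds 0.75.
#   chunks = semantic_chunking(raw_transcript)
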
def analyze_segment_with_gemini(cluster_text, is_full_text=False):
    llm = ChatGoogleGenerativeAI(
        model="gemini-1.5-flash",
        temperature=0.7,
        max_tokens=None,
        timeout=None,
        max_retries=3
    )
    if is_full_text:
        prompt = f"""
        FIRST ASSESS THE TEXT:
        - Check if it's primarily self-introduction, biographical information, or conclusion
        - Check if it's too short or lacks meaningful content (less than 100 words of substance)
        - If either case is true, respond with a simple JSON: {{"status": "insufficient", "reason": "Brief explanation"}}

        Analyze the following text (likely a transcript or document) and:
        1. First, do text segmentation and identify DISTINCT key topics within the text
        2. For each segment/topic you identify:
           - Provide a SPECIFIC and UNIQUE topic name (3-5 words) that clearly differentiates it from other segments
           - List 3-5 key concepts discussed in that segment
           - Write a brief summary of that segment (3-5 sentences)
           - Create 5 quiz questions based DIRECTLY on the content in that segment

        For each quiz question:
        - Create one correct answer that comes DIRECTLY from the text
        - Create two plausible but incorrect answers
        - IMPORTANT: Ensure all answer options have similar length (± 3 words)
        - Ensure the correct answer is clearly indicated

        Text:
        {cluster_text}

        Format your response as JSON with the following structure:
        {{
          "segments": [
            {{
              "topic_name": "Name of segment 1",
              "key_concepts": ["concept1", "concept2", "concept3"],
              "summary": "Brief summary of this segment.",
              "quiz_questions": [
                {{
                  "question": "Question text?",
                  "options": [
                    {{"text": "Option A", "correct": false}},
                    {{"text": "Option B", "correct": true}},
                    {{"text": "Option C", "correct": false}}
                  ]
                }}
                // More questions...
              ]
            }}
            // More segments...
          ]
        }}
        """
    else:
        prompt = f"""
        FIRST ASSESS THE TEXT:
        - Check if it's primarily self-introduction, biographical information, or conclusion
        - Check if it's too short or lacks meaningful content (less than 100 words of substance)
        - If either case is true, respond with a simple JSON: {{"status": "insufficient", "reason": "Brief explanation"}}

        Analyze the following text segment and provide:
        1. A SPECIFIC and DESCRIPTIVE topic name (3-5 words) that precisely captures the main focus
        2. 3-5 key concepts discussed
        3. A brief summary (6-7 sentences)
        4. Create 5 quiz questions based DIRECTLY on the text content (not from your summary)

        For each quiz question:
        - Create one correct answer that comes DIRECTLY from the text
        - Create two plausible but incorrect answers
        - IMPORTANT and STRICTLY: Ensure all answer options have similar length (± 3 words)
        - Ensure the correct answer is clearly indicated

        Text segment:
        {cluster_text}

        Format your response as JSON with the following structure:
        {{
          "topic_name": "Name of the topic",
          "key_concepts": ["concept1", "concept2", "concept3"],
          "summary": "Brief summary of the text segment.",
          "quiz_questions": [
            {{
              "question": "Question text?",
              "options": [
                {{"text": "Option A", "correct": false}},
                {{"text": "Option B", "correct": true}},
                {{"text": "Option C", "correct": false}}
              ]
            }}
            // More questions...
          ]
        }}
        """
    response = llm.invoke(prompt)
    response_text = response.content
    try:
        # Pull out the JSON object, tolerating any prose the model wraps around it
        json_match = re.search(r'\{[\s\S]*\}', response_text)
        if json_match:
            response_json = json.loads(json_match.group(0))
        else:
            response_json = json.loads(response_text)
        return response_json
    except json.JSONDecodeError as e:
        print(f"Error parsing JSON response: {e}")
        print(f"Raw response: {response_text}")
        # Fall back to a placeholder result so downstream formatting does not break
        if is_full_text:
            return {
                "segments": [
                    {
                        "topic_name": "JSON Parsing Error",
                        "key_concepts": ["Error in response format"],
                        "summary": f"Could not parse the API response. Raw text: {response_text[:200]}...",
                        "quiz_questions": []
                    }
                ]
            }
        else:
            return {
                "topic_name": "JSON Parsing Error",
                "key_concepts": ["Error in response format"],
                "summary": f"Could not parse the API response. Raw text: {response_text[:200]}...",
                "quiz_questions": []
            }

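# Sketch (assumes GOOGLE_API_KEY is set in the environment): each call returns a dict shaped like
# the JSON skeleton in the prompt, or {"status": "insufficient", "reason": "..."} when the model
# judges the text too thin to analyze.
#   analysis = analyze_segment_with_gemini(chunk, is_full_text=False)
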
def process_document_with_quiz(text):
    token_count = len(tokenizer.encode(text))
    print(f"Text contains {token_count} tokens")
    if token_count < 8000:
        print("Text is short enough to analyze directly without text segmentation")
        full_analysis = analyze_segment_with_gemini(text, is_full_text=True)
        results = []
        if "segments" in full_analysis:
            for i, segment in enumerate(full_analysis["segments"]):
                segment["segment_number"] = i + 1
                segment["segment_text"] = "Segment identified by Gemini"
                results.append(segment)
            print(f"Gemini identified {len(results)} segments in the text")
        else:
            print("Unexpected response format from Gemini")
            results = [full_analysis]
        return results
    # Longer documents are split into semantic chunks and analyzed one by one
    chunks = semantic_chunking(text)
    print(f"{len(chunks)} semantic chunks were found\n")
    results = []
    for i, chunk in enumerate(chunks):
        print(f"Analyzing segment {i+1}/{len(chunks)}...")
        analysis = analyze_segment_with_gemini(chunk, is_full_text=False)
        analysis["segment_number"] = i + 1
        analysis["segment_text"] = chunk
        results.append(analysis)
        print(f"Completed analysis of segment {i+1}: {analysis.get('topic_name', 'insufficient content')}")
    return results

def save_results_to_file(results, output_file="analysis_results.json"):
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    print(f"Results saved to {output_file}")

def format_quiz_for_display(results):
    output = []
    for segment_result in results:
        segment_num = segment_result.get("segment_number", 1)
        topic = segment_result.get("topic_name", "Insufficient content")
        output.append(f"\n\n{'='*40}")
        output.append(f"SEGMENT {segment_num}: {topic}")
        output.append(f"{'='*40}\n")
        output.append("KEY CONCEPTS:")
        for concept in segment_result.get("key_concepts", []):
            output.append(f"• {concept}")
        output.append("\nSUMMARY:")
        output.append(segment_result.get("summary", ""))
        output.append("\nQUIZ QUESTIONS:")
        for i, q in enumerate(segment_result.get("quiz_questions", [])):
            output.append(f"\n{i+1}. {q['question']}")
            for j, option in enumerate(q['options']):
                letter = chr(97 + j).upper()
                correct_marker = " ✓" if option["correct"] else ""
                output.append(f"   {letter}. {option['text']}{correct_marker}")
    return "\n".join(output)

def analyze_document(document_text: str, api_key: str) -> tuple:
    # Expose the key to langchain_google_genai via the environment
    os.environ["GOOGLE_API_KEY"] = api_key
    try:
        results = process_document_with_quiz(document_text)
        formatted_output = format_quiz_for_display(results)
        json_path = "analysis_results.json"
        txt_path = "analysis_results.txt"
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2, ensure_ascii=False)
        with open(txt_path, "w", encoding="utf-8") as f:
            f.write(formatted_output)
        return formatted_output, json_path, txt_path
    except Exception as e:
        error_msg = f"Error processing document: {str(e)}"
        return error_msg, None, None

with gr.Blocks(title="Quiz Generator") as app:
    gr.Markdown("# Quiz Generator")
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                label="Input Text",
                placeholder="Paste your document text here...",
                lines=10
            )
            api_key = gr.Textbox(
                label="Gemini API Key",
                placeholder="Enter your Gemini API key",
                type="password"
            )
            analyze_btn = gr.Button("Analyze Document")
        with gr.Column():
            output_results = gr.Textbox(
                label="Analysis Results",
                lines=20
            )
            json_file_output = gr.File(label="Download JSON")
            txt_file_output = gr.File(label="Download TXT")
    analyze_btn.click(
        fn=analyze_document,
        inputs=[input_text, api_key],
        outputs=[output_results, json_file_output, txt_file_output]
    )

if __name__ == "__main__":
    app.launch()