import os
import re
import json
import time
import gradio as gr
import tempfile
from typing import Dict, Any, List
from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer
from pydantic import BaseModel, Field
from anthropic import Anthropic
from huggingface_hub import login
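
# Model identifiers and generation defaults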
CLAUDE_MODEL = "claude-3-5-sonnet-20241022"
OPENAI_MODEL = "gpt-4o"
GEMINI_MODEL = "gemini-2.0-flash"
DEFAULT_TEMPERATURE = 0.7
TOKENIZER_MODEL = "answerdotai/ModernBERT-base"
SENTENCE_TRANSFORMER_MODEL = "all-MiniLM-L6-v2"
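
# Pydantic models describing the structured analysis the LLM must return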
class CourseInfo(BaseModel):
    course_name: str = Field(description="Name of the course")
    section_name: str = Field(description="Name of the course section")
    lesson_name: str = Field(description="Name of the lesson")

class QuizOption(BaseModel):
    text: str = Field(description="The text of the answer option")
    correct: bool = Field(description="Whether this option is correct")

class QuizQuestion(BaseModel):
    question: str = Field(description="The text of the quiz question")
    options: List[QuizOption] = Field(description="List of answer options")

class Segment(BaseModel):
    segment_number: int = Field(description="The segment number")
    topic_name: str = Field(description="Unique and specific topic name that clearly differentiates it from other segments")
    key_concepts: List[str] = Field(description="3-5 key concepts discussed in the segment")
    summary: str = Field(description="Brief summary of the segment (3-5 sentences)")
    quiz_questions: List[QuizQuestion] = Field(description="5 quiz questions based on the segment content")

class TextSegmentAnalysis(BaseModel):
    course_info: CourseInfo = Field(description="Information about the course")
    segments: List[Segment] = Field(description="List of text segments with analysis")
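
# Hugging Face authentication and model loading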
hf_token = os.environ.get('HF_TOKEN', None)
if hf_token:  # only log in when a token is set; login(None) prompts and fails in headless environments
    login(token=hf_token)
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_MODEL)
sentence_model = SentenceTransformer(SENTENCE_TRANSFORMER_MODEL)  # loaded but not used in the current pipeline

# System prompt for the LangChain (OpenAI / Gemini) path; the Claude path defines its own
system_prompt = """You are an expert educational content analyzer. Your task is to analyze text content,
identify distinct segments, and create high-quality educational quiz questions for each segment."""
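
# Strip transcript speaker tags (e.g. [speaker_1]) and normalize whitespace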
def clean_text(text):
    text = re.sub(r'\[speaker_\d+\]', '', text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text
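
# Split over-long text roughly in half at sentence boundaries so each part fits the
# model context. This is a single split, so very long inputs may still exceed max_tokens.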
def split_text_by_tokens(text, max_tokens=8000):
    text = clean_text(text)
    tokens = tokenizer.encode(text)
    if len(tokens) <= max_tokens:
        return [text]
    # Aim for the halfway point in token count
    split_point = len(tokens) // 2
    sentences = re.split(r'(?<=[.!?])\s+', text)
    first_half = []
    second_half = []
    current_tokens = 0
    in_first_half = True
    for sentence in sentences:
        # Count without special tokens so per-sentence counts add up cleanly
        sentence_tokens = len(tokenizer.encode(sentence, add_special_tokens=False))
        if in_first_half and current_tokens + sentence_tokens <= split_point:
            first_half.append(sentence)
            current_tokens += sentence_tokens
        else:
            # Once we cross the midpoint, keep every remaining sentence in the second
            # half; re-testing each sentence would let short ones jump back out of order
            in_first_half = False
            second_half.append(sentence)
    return [" ".join(first_half), " ".join(second_half)]
def generate_with_claude(text, api_key, course_name="", section_name="", lesson_name=""):
    from prompts import ANALYSIS_PROMPT_TEMPLATE_CLAUDE
    client = Anthropic(api_key=api_key)
    segment_analysis_schema = TextSegmentAnalysis.model_json_schema()
    tools = [
        {
            "name": "build_segment_analysis",
            "description": "Build the text segment analysis with quiz questions",
            "input_schema": segment_analysis_schema
        }
    ]
    system_prompt = """You are a helpful assistant specialized in text analysis and educational content creation.
You analyze texts to identify distinct segments, create summaries, and generate quiz questions."""
    prompt = ANALYSIS_PROMPT_TEMPLATE_CLAUDE.format(
        course_name=course_name,
        section_name=section_name,
        lesson_name=lesson_name,
        text=text
    )
    try:
        response = client.messages.create(
            model=CLAUDE_MODEL,
            max_tokens=8192,
            temperature=DEFAULT_TEMPERATURE,
            system=system_prompt,
            messages=[
                {
                    "role": "user",
                    "content": prompt
                }
            ],
            tools=tools,
            tool_choice={"type": "tool", "name": "build_segment_analysis"}
        )
        # The forced tool call arrives as a tool_use content block; scan for it
        # rather than assuming it is the first block in the response
        for block in response.content:
            if getattr(block, "type", None) == "tool_use":
                return block.input
        raise Exception("No valid tool call found in the response")
    except Exception as e:
        raise Exception(f"Error calling Anthropic API: {str(e)}")
def get_llm_by_api_key(api_key):
    if api_key.startswith("sk-ant-"):  # Claude API key format
        from langchain_anthropic import ChatAnthropic
        return ChatAnthropic(
            anthropic_api_key=api_key,
            model_name=CLAUDE_MODEL,
            temperature=DEFAULT_TEMPERATURE,
            max_retries=3
        )
    elif api_key.startswith("sk-"):  # OpenAI API key format
        from langchain_openai import ChatOpenAI
        return ChatOpenAI(
            openai_api_key=api_key,
            model_name=OPENAI_MODEL,
            temperature=DEFAULT_TEMPERATURE,
            max_retries=3
        )
    else:  # Default to Gemini
        from langchain_google_genai import ChatGoogleGenerativeAI
        os.environ["GOOGLE_API_KEY"] = api_key
        return ChatGoogleGenerativeAI(
            model=GEMINI_MODEL,
            temperature=DEFAULT_TEMPERATURE,
            max_retries=3
        )
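
# Route the request: Claude via the native SDK (tool calling), everything else via
# LangChain with JSON extracted from the free-text reply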
def segment_and_analyze_text(text: str, api_key: str, course_name="", section_name="", lesson_name="") -> Dict[str, Any]:
    from prompts import ANALYSIS_PROMPT_TEMPLATE_GEMINI
    if api_key.startswith("sk-ant-"):
        return generate_with_claude(text, api_key, course_name, section_name, lesson_name)
    # For other models, use LangChain
    llm = get_llm_by_api_key(api_key)
    prompt = ANALYSIS_PROMPT_TEMPLATE_GEMINI.format(
        course_name=course_name,
        section_name=section_name,
        lesson_name=lesson_name,
        text=text
    )
    try:
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt}
        ]
        response = llm.invoke(messages)
        try:
            content = response.content
            # Prefer a fenced json code block, then any top-level {...}, then the raw content
            json_match = re.search(r'```json\s*([\s\S]*?)\s*```', content)
            if json_match:
                json_str = json_match.group(1)
            else:
                json_match = re.search(r'(\{[\s\S]*\})', content)
                if json_match:
                    json_str = json_match.group(1)
                else:
                    json_str = content
            # Parse the JSON
            function_call = json.loads(json_str)
            return function_call
        except json.JSONDecodeError:
            raise Exception("Could not parse JSON from LLM response")
    except Exception as e:
        raise Exception(f"Error calling API: {str(e)}")
def format_quiz_for_display(results):
    output = []
    if "course_info" in results:
        course_info = results["course_info"]
        output.append(f"{'='*40}")
        output.append(f"COURSE: {course_info.get('course_name', 'N/A')}")
        output.append(f"SECTION: {course_info.get('section_name', 'N/A')}")
        output.append(f"LESSON: {course_info.get('lesson_name', 'N/A')}")
        output.append(f"{'='*40}\n")
    segments = results.get("segments", [])
    for i, segment in enumerate(segments):
        topic = segment["topic_name"]
        segment_num = i + 1
        output.append(f"\n\n{'='*40}")
        output.append(f"SEGMENT {segment_num}: {topic}")
        output.append(f"{'='*40}\n")
        output.append("KEY CONCEPTS:")
        for concept in segment["key_concepts"]:
            output.append(f"• {concept}")
        output.append("\nSUMMARY:")
        output.append(segment["summary"])
        output.append("\nQUIZ QUESTIONS:")
        # Use a distinct loop variable so the segment index i is not shadowed
        for q_num, q in enumerate(segment["quiz_questions"]):
            output.append(f"\n{q_num + 1}. {q['question']}")
            for j, option in enumerate(q['options']):
                letter = chr(65 + j)  # A, B, C, ...
                correct_marker = " ✓" if option["correct"] else ""
                output.append(f"   {letter}. {option['text']}{correct_marker}")
    return "\n".join(output)
def analyze_document(text, api_key, course_name, section_name, lesson_name):
    try:
        start_time = time.time()
        # Split text if it's too long
        text_parts = split_text_by_tokens(text)
        all_results = {
            "course_info": {
                "course_name": course_name,
                "section_name": section_name,
                "lesson_name": lesson_name
            },
            "segments": []
        }
        segment_counter = 1
        # Process each part of the text
        for part in text_parts:
            analysis = segment_and_analyze_text(
                part,
                api_key,
                course_name=course_name,
                section_name=section_name,
                lesson_name=lesson_name
            )
            if "segments" in analysis:
                for segment in analysis["segments"]:
                    segment["segment_number"] = segment_counter
                    all_results["segments"].append(segment)
                    segment_counter += 1
        end_time = time.time()
        total_time = end_time - start_time
        # Format the results for display
        formatted_text = format_quiz_for_display(all_results)
        formatted_text = f"Total processing time: {total_time:.2f} seconds\n\n" + formatted_text
        # Create temporary files for JSON and text output; mkstemp avoids the race
        # condition in the deprecated tempfile.mktemp
        json_fd, json_path = tempfile.mkstemp(suffix='.json')
        with os.fdopen(json_fd, 'w', encoding='utf-8') as json_file:
            json.dump(all_results, json_file, indent=2)
        txt_fd, txt_path = tempfile.mkstemp(suffix='.txt')
        with os.fdopen(txt_fd, 'w', encoding='utf-8') as txt_file:
            txt_file.write(formatted_text)
        return formatted_text, json_path, txt_path
    except Exception as e:
        error_message = f"Error processing document: {str(e)}"
        return error_message, None, None
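
# Gradio interface: course metadata, document text, and API key in; formatted
# results plus JSON/TXT downloads out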
with gr.Blocks(title="Quiz Generator") as app:
    gr.Markdown("# Quiz Generator")
    with gr.Row():
        with gr.Column():
            course_name = gr.Textbox(
                placeholder="Enter the course name",
                label="Course Name"
            )
            section_name = gr.Textbox(
                placeholder="Enter the section name",
                label="Section Name"
            )
            lesson_name = gr.Textbox(
                placeholder="Enter the lesson name",
                label="Lesson Name"
            )
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                label="Input Document Text",
                placeholder="Paste your document text here...",
                lines=10
            )
            api_key = gr.Textbox(
                label="API Key",
                placeholder="Enter your OpenAI, Claude, or Gemini API key",
                type="password"
            )
            analyze_btn = gr.Button("Analyze Document")
        with gr.Column():
            output_results = gr.Textbox(
                label="Analysis Results",
                lines=20
            )
            json_file_output = gr.File(label="Download JSON")
            txt_file_output = gr.File(label="Download TXT")
    analyze_btn.click(
        fn=analyze_document,
        inputs=[input_text, api_key, course_name, section_name, lesson_name],
        outputs=[output_results, json_file_output, txt_file_output]
    )

if __name__ == "__main__":
    app.launch()