import os
import re
import json
import time
import tempfile
from typing import Dict, Any, List, Optional

from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer
from huggingface_hub import login

from src.prompts import SUMMARY_PROMPT_TEMPLATE, QUIZ_PROMPT_TEMPLATE

GEMINI_MODEL = "gemini-2.0-flash"
DEFAULT_TEMPERATURE = 0.7
TOKENIZER_MODEL = "answerdotai/ModernBERT-base"
SENTENCE_TRANSFORMER_MODEL = "all-MiniLM-L6-v2"

# Only authenticate with the Hugging Face Hub when a token is actually provided.
hf_token = os.environ.get("HF_TOKEN", None)
if hf_token:
    login(token=hf_token)

tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_MODEL)
sentence_model = SentenceTransformer(SENTENCE_TRANSFORMER_MODEL)

def clean_text(text):
    text = re.sub(r'\[speaker_\d+\]', '', text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text
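
# Illustrative behavior: clean_text("[speaker_1] Hello   world [speaker_2] Hi")
# returns "Hello world Hi" (speaker tags removed, whitespace collapsed).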

def split_text_by_tokens(text, max_tokens=12000):
    """Return the text as a single chunk, or split it into two sentence-aligned halves."""
    text = clean_text(text)
    tokens = tokenizer.encode(text)
    if len(tokens) <= max_tokens:
        return [text]

    # Split roughly at the token midpoint, keeping whole sentences together.
    split_point = len(tokens) // 2
    sentences = re.split(r'(?<=[.!?])\s+', text)
    first_half = []
    second_half = []
    current_tokens = 0
    for sentence in sentences:
        sentence_tokens = len(tokenizer.encode(sentence))
        if current_tokens + sentence_tokens <= split_point:
            first_half.append(sentence)
            current_tokens += sentence_tokens
        else:
            second_half.append(sentence)
    return [" ".join(first_half), " ".join(second_half)]

def generate_with_gemini(text, api_key, language, content_type="summary"):
    """Call Gemini through LangChain and parse the JSON payload from its response."""
    from langchain_google_genai import ChatGoogleGenerativeAI

    os.environ["GOOGLE_API_KEY"] = api_key
    llm = ChatGoogleGenerativeAI(
        model=GEMINI_MODEL,
        temperature=DEFAULT_TEMPERATURE,
        max_retries=3,
    )

    if content_type == "summary":
        base_prompt = SUMMARY_PROMPT_TEMPLATE.format(text=text)
    else:
        base_prompt = QUIZ_PROMPT_TEMPLATE.format(text=text)

    language_instruction = f"\nIMPORTANT: Generate ALL content in {language} language."
    prompt = base_prompt + language_instruction

    try:
        messages = [
            {"role": "system", "content": "You are a helpful AI assistant that creates high-quality text summaries and quizzes."},
            {"role": "user", "content": prompt},
        ]
        response = llm.invoke(messages)

        try:
            content = response.content
            # Prefer a fenced ```json block, then any top-level JSON object, then the raw content.
            json_match = re.search(r'```json\s*([\s\S]*?)\s*```', content)
            if json_match:
                json_str = json_match.group(1)
            else:
                json_match = re.search(r'(\{[\s\S]*\})', content)
                if json_match:
                    json_str = json_match.group(1)
                else:
                    json_str = content

            # Parse the JSON
            return json.loads(json_str)
        except json.JSONDecodeError:
            raise Exception("Could not parse JSON from LLM response")
    except Exception as e:
        raise Exception(f"Error calling API: {str(e)}")

def format_summary_for_display(results, language="English"):
    """Render the summary JSON as readable plain text, with localized section headers."""
    output = []

    if language == "Uzbek":
        title_header = "SARLAVHA"
        overview_header = "UMUMIY KO'RINISH"
        key_points_header = "ASOSIY NUQTALAR"
        key_entities_header = "ASOSIY SHAXSLAR VA TUSHUNCHALAR"
        conclusion_header = "XULOSA"
    elif language == "Russian":
        title_header = "ЗАГОЛОВОК"
        overview_header = "ОБЗОР"
        key_points_header = "КЛЮЧЕВЫЕ МОМЕНТЫ"
        key_entities_header = "КЛЮЧЕВЫЕ ОБЪЕКТЫ"
        conclusion_header = "ЗАКЛЮЧЕНИЕ"
    else:
        title_header = "TITLE"
        overview_header = "OVERVIEW"
        key_points_header = "KEY POINTS"
        key_entities_header = "KEY ENTITIES"
        conclusion_header = "CONCLUSION"

    if "summary" not in results:
        # Fallback: some responses arrive as a list of topic segments instead of a single summary.
        if "segments" in results:
            segments = results.get("segments", [])
            for i, segment in enumerate(segments):
                topic = segment.get("topic_name", f"Section {i+1}")
                segment_num = i + 1
                output.append(f"\n\n{'='*40}")
                output.append(f"SEGMENT {segment_num}: {topic}")
                output.append(f"{'='*40}\n")

                if "key_concepts" in segment:
                    output.append("KEY CONCEPTS:")
                    for concept in segment["key_concepts"]:
                        output.append(f"• {concept}")

                if "summary" in segment:
                    output.append("\nSUMMARY:")
                    output.append(segment["summary"])
            return "\n".join(output)
        else:
            return "Error: Could not parse summary results. Invalid format received."

    summary = results["summary"]

    # Title
    if "title" in summary:
        output.append(f"\n\n{'='*40}")
        output.append(f"{title_header}: {summary['title']}")
        output.append(f"{'='*40}\n")

    # Overview
    if "overview" in summary:
        output.append(f"{overview_header}:")
        output.append(f"{summary['overview']}\n")

    # Key Points
    if "key_points" in summary and summary["key_points"]:
        output.append(f"{key_points_header}:")
        for theme_group in summary["key_points"]:
            if "theme" in theme_group:
                output.append(f"\n{theme_group['theme']}:")
            if "points" in theme_group:
                for point in theme_group["points"]:
                    output.append(f"• {point}")

    # Key Entities
    if "key_entities" in summary and summary["key_entities"]:
        output.append(f"\n{key_entities_header}:")
        for entity in summary["key_entities"]:
            if "name" in entity and "description" in entity:
                output.append(f"• **{entity['name']}**: {entity['description']}")

    # Conclusion
    if "conclusion" in summary:
        output.append(f"\n{conclusion_header}:")
        output.append(summary["conclusion"])

    return "\n".join(output)

def format_quiz_for_display(results, language="English"):
    """Render quiz questions as plain text, marking the correct option with a check mark."""
    output = []

    if language == "Uzbek":
        quiz_questions_header = "TEST SAVOLLARI"
    elif language == "Russian":
        quiz_questions_header = "ТЕСТОВЫЕ ВОПРОСЫ"
    else:
        quiz_questions_header = "QUIZ QUESTIONS"

    output.append(f"{'='*40}")
    output.append(f"{quiz_questions_header}")
    output.append(f"{'='*40}\n")

    quiz_questions = results.get("quiz_questions", [])
    for i, q in enumerate(quiz_questions):
        output.append(f"\n{i+1}. {q['question']}")
        for j, option in enumerate(q['options']):
            letter = chr(97 + j).upper()  # a→A, b→B, c→C, ...
            correct_marker = " ✓" if option["correct"] else ""
            output.append(f"  {letter}. {option['text']}{correct_marker}")

    return "\n".join(output)

def analyze_document(text, gemini_api_key, language, content_type="summary"):
    """Run summary or quiz generation over the document and return the formatted text plus file paths."""
    try:
        start_time = time.time()
        text_parts = split_text_by_tokens(text)
        input_tokens = 0
        output_tokens = 0

        if content_type == "summary":
            all_results = {}
            for part in text_parts:
                actual_prompt = SUMMARY_PROMPT_TEMPLATE.format(text=part)
                prompt_tokens = len(tokenizer.encode(actual_prompt))
                input_tokens += prompt_tokens

                analysis = generate_with_gemini(part, gemini_api_key, language, "summary")
                if not all_results and "summary" in analysis:
                    all_results = analysis
                elif "summary" in analysis:
                    # Merge key points and entities from later chunks into the first result.
                    if "key_points" in analysis["summary"] and "key_points" in all_results["summary"]:
                        all_results["summary"]["key_points"].extend(analysis["summary"]["key_points"])
                    if "key_entities" in analysis["summary"] and "key_entities" in all_results["summary"]:
                        all_results["summary"]["key_entities"].extend(analysis["summary"]["key_entities"])

            formatted_output = format_summary_for_display(all_results, language)
        else:
            all_results = {"quiz_questions": []}
            for part in text_parts:
                actual_prompt = QUIZ_PROMPT_TEMPLATE.format(text=part)
                prompt_tokens = len(tokenizer.encode(actual_prompt))
                input_tokens += prompt_tokens

                analysis = generate_with_gemini(part, gemini_api_key, language, "quiz")
                if "quiz_questions" in analysis:
                    # Cap the quiz at 10 questions across all chunks.
                    remaining_slots = 10 - len(all_results["quiz_questions"])
                    if remaining_slots > 0:
                        questions_to_add = analysis["quiz_questions"][:remaining_slots]
                        all_results["quiz_questions"].extend(questions_to_add)

            formatted_output = format_quiz_for_display(all_results, language)

        end_time = time.time()
        total_time = end_time - start_time

        output_tokens = len(tokenizer.encode(formatted_output))
        token_info = f"Input tokens: {input_tokens}\nOutput tokens: {output_tokens}\nTotal tokens: {input_tokens + output_tokens}\n"
        formatted_text = f"Total Processing time: {total_time:.2f}s\n{token_info}\n" + formatted_output

        # Write the raw JSON and the formatted text to temporary files for download.
        # tempfile.mkstemp replaces the deprecated, race-prone tempfile.mktemp;
        # ensure_ascii=False keeps Uzbek and Russian output readable in the JSON file.
        json_fd, json_path = tempfile.mkstemp(suffix='.json')
        with os.fdopen(json_fd, 'w', encoding='utf-8') as json_file:
            json.dump(all_results, json_file, indent=2, ensure_ascii=False)

        txt_fd, txt_path = tempfile.mkstemp(suffix='.txt')
        with os.fdopen(txt_fd, 'w', encoding='utf-8') as txt_file:
            txt_file.write(formatted_text)

        return formatted_text, json_path, txt_path
    except Exception as e:
        error_message = f"Error processing document: {str(e)}"
        return error_message, None, None
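
# Minimal manual-test sketch. GEMINI_API_KEY and sample.txt are placeholders for local
# experimentation only; they are not referenced anywhere else in the project.
if __name__ == "__main__":
    sample_key = os.environ.get("GEMINI_API_KEY", "")
    with open("sample.txt", encoding="utf-8") as f:
        sample_text = f.read()
    formatted, json_file, txt_file = analyze_document(sample_text, sample_key, "English", "summary")
    print(formatted[:500])
    print("JSON saved to:", json_file)
    print("Formatted text saved to:", txt_file)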