Sripriya16 committed on
Commit 1895105 · verified · 1 Parent(s): 667e559

Update app.py

Files changed (1):
  1. app.py +161 -207
app.py CHANGED
@@ -1,207 +1,161 @@
- import os
- import fitz # PyMuPDF
- import fasttext
- import requests
- import json
- import torch
- from PIL import Image
- from huggingface_hub import hf_hub_download
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
- from IndicTransToolkit.processor import IndicProcessor
- import google.generativeai as genai
- import gradio as gr
-
- # === 1. CONFIGURATION & SECRETS ===
- # --- Load the Gemini API Key from Hugging Face Secrets ---
- GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
-
- # --- Model IDs (Using the CPU-friendly TrOCR model) ---
- TRANSLATION_MODEL_REPO_ID = "ai4bharat/indictrans2-indic-en-1B"
- OCR_MODEL_ID = "microsoft/trocr-base-printed"
-
- # --- Language Settings ---
- LANGUAGE_TO_TRANSLATE = "mal"
-
- # --- Hardware Settings ---
- DEVICE = "cpu" # Forcing CPU for compatibility with free tier
-
- # === 2. LOAD MODELS & CONFIGURE API ===
- # --- Configure Gemini API ---
- if not GEMINI_API_KEY:
-     print("🔴 ERROR: Gemini API key is not set in the Space Secrets.")
- else:
-     genai.configure(api_key=GEMINI_API_KEY)
-
- # --- Load Translation Model ---
- print(f"Loading tokenizer & model: {TRANSLATION_MODEL_REPO_ID} ...")
- translation_tokenizer = AutoTokenizer.from_pretrained(TRANSLATION_MODEL_REPO_ID, trust_remote_code=True)
- translation_model = AutoModelForSeq2SeqLM.from_pretrained(
-     TRANSLATION_MODEL_REPO_ID,
-     trust_remote_code=True,
-     torch_dtype=torch.float32 # Use float32 for CPU
- ).to(DEVICE)
- print("✅ Translation model loaded.")
- ip = IndicProcessor(inference=True)
-
- # --- Load Language Detection Model ---
- print("Loading fastText language detector...")
- ft_model_path = hf_hub_download(repo_id="facebook/fasttext-language-identification", filename="model.bin")
- lang_detect_model = fasttext.load_model(ft_model_path)
- print("✅ fastText loaded.")
-
- # --- Load Standard OCR Model ---
- print(f"Loading Standard OCR model: {OCR_MODEL_ID}...")
- ocr_pipeline = pipeline("image-to-text", model=OCR_MODEL_ID, device=-1) # device=-1 ensures CPU
- print("✅ Standard OCR model loaded.")
-
-
- # === 3. HELPER FUNCTIONS ===
-
- # --- Phase 1: Text Extraction ---
- def classify_image_with_gemini(image: Image.Image):
-     """Uses Gemini to classify an image as a 'document' or 'diagram'."""
-     model = genai.GenerativeModel('gemini-1.5-flash-latest')
-     prompt = "Is this image primarily a text document or an engineering/technical diagram? Answer with only 'document' or 'diagram'."
-     response = model.generate_content([prompt, image])
-     classification = response.text.strip().lower()
-     print(f"✅ Image classified as: {classification}")
-     return "diagram" if "diagram" in classification else "document"
-
- def summarize_diagram_with_gemini(image: Image.Image):
-     """Uses Gemini to generate a summary of an engineering diagram."""
-     model = genai.GenerativeModel('gemini-1.5-flash-latest')
-     prompt = "You are an engineering assistant for Kochi Metro Rail Limited (KMRL). Describe the contents of this technical diagram or engineering drawing in a concise summary. Identify key components and their apparent purpose."
-     response = model.generate_content([prompt, image])
-     print(" Diagram summary successful.")
-     return response.text.strip()
-
- def extract_text_from_image(path):
-     """
-     Classifies an image and routes it for either OCR (if a text doc) or summarization (if a diagram).
-     """
-     print("\n--- Starting Image Processing ---")
-     try:
-         image = Image.open(path).convert("RGB")
-
-         # Step 1: Classify the image using Gemini
-         image_type = classify_image_with_gemini(image)
-
-         # Step 2: Route to the correct function
-         if image_type == "diagram":
-             print("-> Image is a diagram. Summarizing with Gemini...")
-             return summarize_diagram_with_gemini(image)
-         else:
-             print("-> Image is a document. Extracting text with TrOCR...")
-             out = ocr_pipeline(image)
-             return out[0]["generated_text"] if out else ""
-
-     except Exception as e:
-         print(f"❌ An error occurred during image processing: {e}")
-         return "Error during image processing."
-
- def extract_text_from_pdf(path):
-     doc = fitz.open(path)
-     return "".join(page.get_text("text") + "\n" for page in doc)
-
- def read_text_from_txt(path):
-     with open(path, "r", encoding="utf-8") as f:
-         return f.read()
-
- # --- Phase 2: Translation ---
- def detect_language(text_snippet):
-     s = text_snippet.replace("\n", " ").strip()
-     if not s: return None
-     preds = lang_detect_model.predict(s, k=1)
-     return preds[0][0].split("__")[-1] if preds and preds[0] else None
-
- def translate_chunk(chunk):
-     batch = ip.preprocess_batch([chunk], src_lang="mal_Mlym", tgt_lang="eng_Latn")
-     inputs = translation_tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=512).to(DEVICE)
-     with torch.no_grad():
-         generated_tokens = translation_model.generate(**inputs, num_beams=5, max_length=512, early_stopping=True)
-     decoded = translation_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
-     return ip.postprocess_batch(decoded, lang=tgt_lang)[0]
-
- # --- Phase 3: Gemini Analysis ---
- def generate_structured_json(text_to_analyze):
-     """Generates the detailed JSON analysis."""
-     model = genai.GenerativeModel('gemini-1.5-flash-latest')
-     prompt = f"You are an AI assistant for KMRL. Analyze this document and extract key info as JSON: {text_to_analyze}"
-     json_schema = {"type": "OBJECT", "properties": {"summary": {"type": "STRING"}, "actions_required": {"type": "ARRAY", "items": {"type": "OBJECT", "properties": {"action": {"type": "STRING"}, "priority": {"type": "STRING", "enum": ["High", "Medium", "Low"]}, "deadline": {"type": "STRING"}, "notes": {"type": "STRING"}}, "required": ["action", "priority", "deadline", "notes"]}}, "departments_to_notify": {"type": "ARRAY", "items": {"type": "STRING"}}, "cross_document_flags": {"type": "ARRAY", "items": {"type": "OBJECT", "properties": {"related_document_type": {"type": "STRING"}, "related_issue": {"type": "STRING"}}, "required": ["related_document_type", "related_issue"]}}}, "required": ["summary", "actions_required", "departments_to_notify", "cross_document_flags"]}
-     generation_config = genai.types.GenerationConfig(response_mime_type="application/json", response_schema=json_schema)
-     response = model.generate_content(prompt, generation_config=generation_config)
-     return json.loads(response.text)
-
- def check_relevance_with_gemini(summary_text):
-     """Checks if the summary is relevant to KMRL."""
-     model = genai.GenerativeModel('gemini-1.5-flash-latest')
-     prompt = f'Is this summary related to transportation, infrastructure, railways, or metro systems? Answer only "Yes" or "No".\n\nSummary: {summary_text}'
-     response = model.generate_content(prompt)
-     return "yes" in response.text.strip().lower()
-
- # === 4. MAIN PROCESSING FUNCTION FOR GRADIO ===
- def process_and_analyze_document(input_file):
-     if not GEMINI_API_KEY:
-         raise gr.Error("Gemini API key is not configured. The administrator must set it in the Space Secrets.")
-     if input_file is None:
-         raise gr.Error("No file uploaded. Please upload a document.")
-
-     try:
-         input_file_path = input_file.name
-         ext = os.path.splitext(input_file_path)[1].lower()
-
-         # --- Phase 1: Get Original Text ---
-         if ext == ".pdf":
-             original_text = extract_text_from_pdf(input_file_path)
-         elif ext == ".txt":
-             original_text = read_text_from_txt(input_file_path)
-         elif ext in [".png", ".jpg", ".jpeg"]:
-             original_text = extract_text_from_image(input_file_path)
-         else:
-             raise gr.Error("Unsupported file type.")
-
-         if not original_text or not original_text.strip():
-             raise gr.Error("No text could be extracted from the document.")
-
-         # --- Phase 2: Translate if Necessary ---
-         lines = original_text.split("\n")
-         translated_lines = []
-         for ln in lines:
-             if not ln.strip(): continue
-             lang = detect_language(ln)
-             if lang == LANGUAGE_TO_TRANSLATE:
-                 translated_lines.append(translate_chunk(ln))
-             else:
-                 translated_lines.append(ln)
-         final_text = "\n".join(translated_lines)
-
-         # --- Phase 3: Analyze with Gemini ---
-         summary_data = generate_structured_json(final_text)
-         if not summary_data or "summary" not in summary_data:
-             raise gr.Error("Failed to generate a valid analysis from the document.")
-
-         is_relevant = check_relevance_with_gemini(summary_data["summary"])
-
-         if is_relevant:
-             return summary_data
-         else:
-             return {"status": "Not Applicable", "reason": "The document was determined to be not relevant to KMRL."}
-
-     except Exception as e:
-         raise gr.Error(f"An unexpected error occurred: {str(e)}")
-
-
- iface = gr.Interface(
-     fn=process_and_analyze_document,
-     inputs=gr.File(label="Upload Document (.pdf, .txt, .png, .jpeg)"),
-     outputs=gr.JSON(label="Analysis Result"),
-     title="KMRL Document Analysis Pipeline",
-     description="Upload a document (Malayalam or English). The system will detect and translate Malayalam text to English, then send the full text to Gemini for structured analysis.",
-     allow_flagging="never",
-     examples=[
-         ["Malayalam-en.txt"] # If you upload this file to your Space
-     ]
- )
-
- if __name__ == "__main__":
-     iface.launch()
 
+ # inference_api.py
+ import os
+ import fitz # PyMuPDF
+ import fasttext
+ import torch
+ from PIL import Image
+ from huggingface_hub import hf_hub_download
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
+ from IndicTransToolkit.processor import IndicProcessor
+ import google.generativeai as genai
+ from fastapi import FastAPI
+ from pydantic import BaseModel
+ from typing import Optional
+ import json
+
+ app = FastAPI()
+
+ # === CONFIGURATION ===
+ GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
+ TRANSLATION_MODEL_REPO_ID = "ai4bharat/indictrans2-indic-en-1B"
+ OCR_MODEL_ID = "microsoft/trocr-base-printed"
+ LANGUAGE_TO_TRANSLATE = "mal"
+ DEVICE = "cpu"
+
+ # --- Configure Gemini ---
+ if GEMINI_API_KEY:
+     genai.configure(api_key=GEMINI_API_KEY)
+ else:
+     print("🔴 GEMINI_API_KEY not set.")
+
+ # --- Load Models ---
+ translation_tokenizer = AutoTokenizer.from_pretrained(TRANSLATION_MODEL_REPO_ID, trust_remote_code=True)
+ translation_model = AutoModelForSeq2SeqLM.from_pretrained(
+     TRANSLATION_MODEL_REPO_ID, trust_remote_code=True, torch_dtype=torch.float32
+ ).to(DEVICE)
+ ip = IndicProcessor(inference=True)
+
+ ft_model_path = hf_hub_download(repo_id="facebook/fasttext-language-identification", filename="model.bin")
+ lang_detect_model = fasttext.load_model(ft_model_path)
+
+ ocr_pipeline = pipeline("image-to-text", model=OCR_MODEL_ID, device=-1)
+
+ # === HELPER FUNCTIONS ===
+ def classify_image_with_gemini(image: Image.Image):
+     model = genai.GenerativeModel('gemini-1.5-flash-latest')
+     prompt = "Is this image primarily a text document or an engineering/technical diagram? Answer with only 'document' or 'diagram'."
+     response = model.generate_content([prompt, image])
+     classification = response.text.strip().lower()
+     return "diagram" if "diagram" in classification else "document"
+
+ def summarize_diagram_with_gemini(image: Image.Image):
+     model = genai.GenerativeModel('gemini-1.5-flash-latest')
+     prompt = "Describe the contents of this technical diagram in a concise summary."
+     response = model.generate_content([prompt, image])
+     return response.text.strip()
+
+ def extract_text_from_image(path):
+     image = Image.open(path).convert("RGB")
+     image_type = classify_image_with_gemini(image)
+     if image_type == "diagram":
+         return summarize_diagram_with_gemini(image)
+     else:
+         out = ocr_pipeline(image)
+         return out[0]["generated_text"] if out else ""
+
+ def extract_text_from_pdf(path):
+     doc = fitz.open(path)
+     return "".join(page.get_text("text") + "\n" for page in doc)
+
+ def read_text_from_txt(path):
+     with open(path, "r", encoding="utf-8") as f:
+         return f.read()
+
+ def detect_language(text_snippet):
+     s = text_snippet.replace("\n", " ").strip()
+     if not s: return None
+     preds = lang_detect_model.predict(s, k=1)
+     return preds[0][0].split("__")[-1] if preds and preds[0] else None
+
+ def translate_chunk(chunk):
+     batch = ip.preprocess_batch([chunk], src_lang="mal_Mlym", tgt_lang="eng_Latn")
+     inputs = translation_tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=512).to(DEVICE)
+     with torch.no_grad():
+         generated_tokens = translation_model.generate(**inputs, num_beams=5, max_length=512, early_stopping=True)
+     decoded = translation_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
+     return ip.postprocess_batch(decoded, lang="eng_Latn")[0]
+
+ def generate_structured_json(text_to_analyze):
+     model = genai.GenerativeModel('gemini-1.5-flash-latest')
+     prompt = f"Analyze this document and extract key info as JSON: {text_to_analyze}"
+     json_schema = {
+         "type": "OBJECT",
+         "properties": {
+             "summary": {"type": "STRING"},
+             "actions_required": {"type": "ARRAY", "items": {
+                 "type": "OBJECT",
+                 "properties": {"action": {"type": "STRING"}, "priority": {"type": "STRING", "enum": ["High","Medium","Low"]}, "deadline": {"type": "STRING"}, "notes": {"type": "STRING"}},
+                 "required": ["action","priority","deadline","notes"]
+             }},
+             "departments_to_notify": {"type": "ARRAY", "items": {"type": "STRING"}},
+             "cross_document_flags": {"type": "ARRAY", "items": {
+                 "type": "OBJECT",
+                 "properties": {"related_document_type": {"type": "STRING"}, "related_issue": {"type": "STRING"}},
+                 "required": ["related_document_type","related_issue"]
+             }}
+         },
+         "required": ["summary","actions_required","departments_to_notify","cross_document_flags"]
+     }
+     generation_config = genai.types.GenerationConfig(response_mime_type="application/json", response_schema=json_schema)
+     response = model.generate_content(prompt, generation_config=generation_config)
+     return json.loads(response.text)
+
+ def check_relevance_with_gemini(summary_text):
+     model = genai.GenerativeModel('gemini-1.5-flash-latest')
+     prompt = f'Is this summary relevant to transportation, infrastructure, railways, or metro systems? Answer "Yes" or "No". Summary: {summary_text}'
+     response = model.generate_content(prompt)
+     return "yes" in response.text.strip().lower()
+
+ # === API INPUT SCHEMA ===
+ class InputFile(BaseModel):
+     file_path: str
+
+ @app.post("/predict")
+ def predict(file: InputFile):
+     if not GEMINI_API_KEY:
+         return {"error": "Gemini API key not set."}
+     path = file.file_path
+     ext = os.path.splitext(path)[1].lower()
+
+     # Phase 1: Extract text
+     if ext == ".pdf":
+         original_text = extract_text_from_pdf(path)
+     elif ext == ".txt":
+         original_text = read_text_from_txt(path)
+     elif ext in [".png", ".jpg", ".jpeg"]:
+         original_text = extract_text_from_image(path)
+     else:
+         return {"error": "Unsupported file type."}
+
+     # Phase 2: Translate Malayalam if detected
+     lines = original_text.split("\n")
+     translated_lines = []
+     for ln in lines:
+         if not ln.strip(): continue
+         lang = detect_language(ln)
+         if lang == LANGUAGE_TO_TRANSLATE:
+             translated_lines.append(translate_chunk(ln))
+         else:
+             translated_lines.append(ln)
+     final_text = "\n".join(translated_lines)
+
+     # Phase 3: Gemini analysis
+     summary_data = generate_structured_json(final_text)
+     if not summary_data or "summary" not in summary_data:
+         return {"error": "Failed to generate analysis."}
+
+     is_relevant = check_relevance_with_gemini(summary_data["summary"])
+     if is_relevant:
+         return summary_data
+     else:
+         return {"status": "Not Applicable", "reason": "Document not relevant to KMRL."}