Sripriya16 commited on
Commit
20e0676
·
verified ·
1 Parent(s): 695e326

Upload inference_api.py

Browse files
Files changed (1) hide show
  1. inference_api.py +161 -0
inference_api.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # inference_api.py
2
+ import os
3
+ import fitz # PyMuPDF
4
+ import fasttext
5
+ import torch
6
+ from PIL import Image
7
+ from huggingface_hub import hf_hub_download
8
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
9
+ from IndicTransToolkit.processor import IndicProcessor
10
+ import google.generativeai as genai
11
+ from fastapi import FastAPI
12
+ from pydantic import BaseModel
13
+ from typing import Optional
14
+ import json
15
+
16
+ app = FastAPI()
17
+
18
+ # === CONFIGURATION ===
19
+ GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
20
+ TRANSLATION_MODEL_REPO_ID = "ai4bharat/indictrans2-indic-en-1B"
21
+ OCR_MODEL_ID = "microsoft/trocr-base-printed"
22
+ LANGUAGE_TO_TRANSLATE = "mal"
23
+ DEVICE = "cpu"
24
+
25
+ # --- Configure Gemini ---
26
+ if GEMINI_API_KEY:
27
+ genai.configure(api_key=GEMINI_API_KEY)
28
+ else:
29
+ print("🔴 GEMINI_API_KEY not set.")
30
+
31
+ # --- Load Models ---
32
+ translation_tokenizer = AutoTokenizer.from_pretrained(TRANSLATION_MODEL_REPO_ID, trust_remote_code=True)
33
+ translation_model = AutoModelForSeq2SeqLM.from_pretrained(
34
+ TRANSLATION_MODEL_REPO_ID, trust_remote_code=True, torch_dtype=torch.float32
35
+ ).to(DEVICE)
36
+ ip = IndicProcessor(inference=True)
37
+
38
+ ft_model_path = hf_hub_download(repo_id="facebook/fasttext-language-identification", filename="model.bin")
39
+ lang_detect_model = fasttext.load_model(ft_model_path)
40
+
41
+ ocr_pipeline = pipeline("image-to-text", model=OCR_MODEL_ID, device=-1)
42
+
43
+ # === HELPER FUNCTIONS ===
44
+ def classify_image_with_gemini(image: Image.Image):
45
+ model = genai.GenerativeModel('gemini-1.5-flash-latest')
46
+ prompt = "Is this image primarily a text document or an engineering/technical diagram? Answer with only 'document' or 'diagram'."
47
+ response = model.generate_content([prompt, image])
48
+ classification = response.text.strip().lower()
49
+ return "diagram" if "diagram" in classification else "document"
50
+
51
+ def summarize_diagram_with_gemini(image: Image.Image):
52
+ model = genai.GenerativeModel('gemini-1.5-flash-latest')
53
+ prompt = "Describe the contents of this technical diagram in a concise summary."
54
+ response = model.generate_content([prompt, image])
55
+ return response.text.strip()
56
+
57
+ def extract_text_from_image(path):
58
+ image = Image.open(path).convert("RGB")
59
+ image_type = classify_image_with_gemini(image)
60
+ if image_type == "diagram":
61
+ return summarize_diagram_with_gemini(image)
62
+ else:
63
+ out = ocr_pipeline(image)
64
+ return out[0]["generated_text"] if out else ""
65
+
66
+ def extract_text_from_pdf(path):
67
+ doc = fitz.open(path)
68
+ return "".join(page.get_text("text") + "\n" for page in doc)
69
+
70
+ def read_text_from_txt(path):
71
+ with open(path, "r", encoding="utf-8") as f:
72
+ return f.read()
73
+
74
+ def detect_language(text_snippet):
75
+ s = text_snippet.replace("\n", " ").strip()
76
+ if not s: return None
77
+ preds = lang_detect_model.predict(s, k=1)
78
+ return preds[0][0].split("__")[-1] if preds and preds[0] else None
79
+
80
+ def translate_chunk(chunk):
81
+ batch = ip.preprocess_batch([chunk], src_lang="mal_Mlym", tgt_lang="eng_Latn")
82
+ inputs = translation_tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=512).to(DEVICE)
83
+ with torch.no_grad():
84
+ generated_tokens = translation_model.generate(**inputs, num_beams=5, max_length=512, early_stopping=True)
85
+ decoded = translation_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
86
+ return ip.postprocess_batch(decoded, lang="eng_Latn")[0]
87
+
88
+ def generate_structured_json(text_to_analyze):
89
+ model = genai.GenerativeModel('gemini-1.5-flash-latest')
90
+ prompt = f"Analyze this document and extract key info as JSON: {text_to_analyze}"
91
+ json_schema = {
92
+ "type": "OBJECT",
93
+ "properties": {
94
+ "summary": {"type": "STRING"},
95
+ "actions_required": {"type": "ARRAY", "items": {
96
+ "type": "OBJECT",
97
+ "properties": {"action": {"type": "STRING"}, "priority": {"type": "STRING", "enum": ["High","Medium","Low"]}, "deadline": {"type": "STRING"}, "notes": {"type": "STRING"}},
98
+ "required": ["action","priority","deadline","notes"]
99
+ }},
100
+ "departments_to_notify": {"type": "ARRAY", "items": {"type": "STRING"}},
101
+ "cross_document_flags": {"type": "ARRAY", "items": {
102
+ "type": "OBJECT",
103
+ "properties": {"related_document_type": {"type": "STRING"}, "related_issue": {"type": "STRING"}},
104
+ "required": ["related_document_type","related_issue"]
105
+ }}
106
+ },
107
+ "required": ["summary","actions_required","departments_to_notify","cross_document_flags"]
108
+ }
109
+ generation_config = genai.types.GenerationConfig(response_mime_type="application/json", response_schema=json_schema)
110
+ response = model.generate_content(prompt, generation_config=generation_config)
111
+ return json.loads(response.text)
112
+
113
+ def check_relevance_with_gemini(summary_text):
114
+ model = genai.GenerativeModel('gemini-1.5-flash-latest')
115
+ prompt = f'Is this summary relevant to transportation, infrastructure, railways, or metro systems? Answer "Yes" or "No". Summary: {summary_text}'
116
+ response = model.generate_content(prompt)
117
+ return "yes" in response.text.strip().lower()
118
+
119
+ # === API INPUT SCHEMA ===
120
+ class InputFile(BaseModel):
121
+ file_path: str
122
+
123
+ @app.post("/predict")
124
+ def predict(file: InputFile):
125
+ if not GEMINI_API_KEY:
126
+ return {"error": "Gemini API key not set."}
127
+ path = file.file_path
128
+ ext = os.path.splitext(path)[1].lower()
129
+
130
+ # Phase 1: Extract text
131
+ if ext == ".pdf":
132
+ original_text = extract_text_from_pdf(path)
133
+ elif ext == ".txt":
134
+ original_text = read_text_from_txt(path)
135
+ elif ext in [".png", ".jpg", ".jpeg"]:
136
+ original_text = extract_text_from_image(path)
137
+ else:
138
+ return {"error": "Unsupported file type."}
139
+
140
+ # Phase 2: Translate Malayalam if detected
141
+ lines = original_text.split("\n")
142
+ translated_lines = []
143
+ for ln in lines:
144
+ if not ln.strip(): continue
145
+ lang = detect_language(ln)
146
+ if lang == LANGUAGE_TO_TRANSLATE:
147
+ translated_lines.append(translate_chunk(ln))
148
+ else:
149
+ translated_lines.append(ln)
150
+ final_text = "\n".join(translated_lines)
151
+
152
+ # Phase 3: Gemini analysis
153
+ summary_data = generate_structured_json(final_text)
154
+ if not summary_data or "summary" not in summary_data:
155
+ return {"error": "Failed to generate analysis."}
156
+
157
+ is_relevant = check_relevance_with_gemini(summary_data["summary"])
158
+ if is_relevant:
159
+ return summary_data
160
+ else:
161
+ return {"status": "Not Applicable", "reason": "Document not relevant to KMRL."}