lydiasolomon committed
Commit 897fd91 · verified · 1 Parent(s): f70d966

Update main.py

Files changed (1):
  main.py +141 -91
main.py CHANGED
@@ -1,34 +1,40 @@
  import os
- import logging
  import io
- from fastapi import FastAPI, Request, Header, HTTPException, UploadFile, File
+ import tempfile
+ import logging
+ import traceback
+ from fastapi import FastAPI, Header, HTTPException, UploadFile, File, Request
  from fastapi.responses import JSONResponse
  from pydantic import BaseModel
  from transformers import pipeline
+ from langdetect import detect, DetectorFactory
  from PIL import Image
- from smebuilder_vector import query_vector
+ from smebuilder_vector import retriever  # Local vector retriever

  # ==============================
  # Logging Setup
  # ==============================
  logging.basicConfig(level=logging.INFO)
- logger = logging.getLogger("AgriCopilot")
+ logger = logging.getLogger("DevAssist")

  # ==============================
- # App Initialization
+ # FastAPI Init
  # ==============================
- app = FastAPI(title="AgriCopilot AI API", version="2.0")
+ app = FastAPI(title="DevAssist AI Backend")

  @app.get("/")
  async def root():
-     return {"status": "AgriCopilot AI Backend is running smoothly ✅"}
+     return {"status": "DevAssist AI Backend running"}

  # ==============================
- # AUTH CONFIGURATION
+ # Auth Configuration
  # ==============================
- PROJECT_API_KEY = os.getenv("PROJECT_API_KEY", "agricopilot404")
+ PROJECT_API_KEY = os.getenv("PROJECT_API_KEY", "devassist-secret")
+ HF_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
+ SPITCH_API_KEY = os.getenv("SPITCH_API_KEY")

  def check_auth(authorization: str | None):
+     """Bearer token validator."""
      if not PROJECT_API_KEY:
          return
      if not authorization or not authorization.startswith("Bearer "):
@@ -38,125 +44,169 @@ def check_auth(authorization: str | None):
          raise HTTPException(status_code=403, detail="Invalid token")

  # ==============================
- # Exception Handling
+ # Global Error Handler
  # ==============================
  @app.exception_handler(Exception)
  async def global_exception_handler(request: Request, exc: Exception):
-     logger.error(f"Unhandled error: {exc}")
+     logger.error(f"Unhandled Exception: {exc}")
      return JSONResponse(status_code=500, content={"error": str(exc)})

  # ==============================
- # Request Models
+ # Request Schemas
  # ==============================
  class ChatRequest(BaseModel):
-     query: str
-
- class DisasterRequest(BaseModel):
-     report: str
+     question: str

- class MarketRequest(BaseModel):
-     product: str
+ class AutoDocRequest(BaseModel):
+     code: str

- class VectorRequest(BaseModel):
-     query: str
+ class SMERequest(BaseModel):
+     user_prompt: str

  # ==============================
- # Load Hugging Face Pipelines
+ # HuggingFace Pipelines
  # ==============================
- HF_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
-
  if not HF_TOKEN:
-     logger.warning("⚠️ No Hugging Face token found. Gated models may fail.")
+     logger.warning("⚠️ No Hugging Face token found. Private/gated models may fail.")
  else:
-     logger.info("✅ Hugging Face token loaded successfully.")
+     logger.info("✅ Hugging Face token detected and ready.")

- # General text-generation model for chat, disaster, market endpoints
- default_model = "meta-llama/Llama-3.1-8B-Instruct"
- vision_model = "meta-llama/Llama-3.2-11B-Vision-Instruct"
+ HF_MODELS = {
+     "chat": "meta-llama/Llama-3.1-8B-Instruct",
+     "autodoc": "Salesforce/codegen-2B-mono",
+     "sme": "deepseek-ai/deepseek-coder-1.3b-instruct"
+ }

- chat_pipe = pipeline("text-generation", model=default_model, token=HF_TOKEN)
- disaster_pipe = pipeline("text-generation", model=default_model, token=HF_TOKEN)
- market_pipe = pipeline("text-generation", model=default_model, token=HF_TOKEN)
+ def safe_pipeline(task: str, model: str, fallback="gpt2"):
+     try:
+         return pipeline(task, model=model, token=HF_TOKEN)
+     except Exception as e:
+         logger.warning(f"Failed to load {model}: {e} → Falling back to {fallback}")
+         return pipeline(task, model=fallback)

- # Multimodal crop diagnostic model
- try:
-     crop_pipe = pipeline("image-text-to-text", model=vision_model, token=HF_TOKEN)
- except Exception as e:
-     logger.warning(f"Crop model load failed: {e}")
-     crop_pipe = None
+ chat_pipe = safe_pipeline("text-generation", HF_MODELS["chat"])
+ autodoc_pipe = safe_pipeline("text-generation", HF_MODELS["autodoc"])
+ sme_pipe = safe_pipeline("text-generation", HF_MODELS["sme"])

  # ==============================
- # Helper Functions
+ # Helper: Text Generation
  # ==============================
- def run_conversational(pipe, prompt: str):
+ def run_pipeline(pipe, prompt: str, max_tokens=512):
+     """Run a text-generation pipeline with proper error capture."""
      try:
-         output = pipe(prompt, max_new_tokens=200)
+         output = pipe(prompt, max_new_tokens=max_tokens)
          if isinstance(output, list) and len(output) > 0:
-             return output[0].get("generated_text", str(output))
-         return str(output)
+             result = output[0].get("generated_text", "").strip()
+         else:
+             result = str(output).strip()
+
+         logger.info(f"\n--- PROMPT ---\n{prompt}\n--- OUTPUT ---\n{result}\n--- END ---")
+
+         if not result:
+             return {"success": False, "error": "⚠️ LLM returned empty output."}
+         return {"success": True, "data": result}
      except Exception as e:
          logger.error(f"Pipeline error: {e}")
-         return f"⚠️ Model error: {str(e)}"
-
- def run_crop_doctor(image_bytes: bytes, symptoms: str):
-     """
-     Diagnose crop issues using Meta's multimodal LLaMA Vision model.
-     """
-     if not crop_pipe:
-         return "⚠️ Crop analysis temporarily unavailable (model not loaded)."
+         return {
+             "success": False,
+             "error": f"⚠️ LLM error: {str(e)}",
+             "trace": traceback.format_exc(),
+         }
+
+ # ==============================
+ # Audio Processing Helper
+ # ==============================
+ async def process_audio(file: UploadFile, lang_hint: str | None = None):
+     import spitch
+     spitch_client = spitch.Spitch()
+     suffix = os.path.splitext(file.filename)[1] or ".wav"
+
+     with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tf:
+         tf.write(await file.read())
+         tmp_path = tf.name
+
+     with open(tmp_path, "rb") as f:
+         audio_bytes = f.read()
+
      try:
-         image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
-         prompt = (
-             f"The farmer reports: {symptoms}. "
-             "Analyze the image and diagnose the likely crop disease. "
-             "Then explain it simply and recommend possible treatment steps."
+         resp = spitch_client.speech.transcribe(
+             content=audio_bytes, language=lang_hint or "en"
          )
-         output = crop_pipe(image, prompt)
-         if isinstance(output, list) and len(output) > 0:
-             return output[0].get("generated_text", str(output))
-         return str(output)
      except Exception as e:
-         logger.error(f"Crop Doctor pipeline error: {e}")
-         return f"⚠️ Unexpected model error: {str(e)}"
+         logger.warning(f"Speech API failed: {e}")
+         resp = {"text": ""}
+
+     transcription = getattr(resp, "text", "") or (resp.get("text", "") if isinstance(resp, dict) else "")
+     detected_lang = "en"
+     try:
+         detected_lang = detect(transcription) if transcription.strip() else "en"
+     except Exception:
+         pass
+
+     # Optional translation
+     translation = transcription
+     if detected_lang != "en":
+         try:
+             translation_resp = spitch_client.text.translate(
+                 text=transcription, source=detected_lang, target="en"
+             )
+             translation = getattr(translation_resp, "text", "") or translation_resp.get("text", "")
+         except Exception:
+             translation = transcription
+
+     return transcription, detected_lang, translation

  # ==============================
- # API ROUTES
+ # Endpoints
  # ==============================
- @app.post("/multilingual-chat")
- async def multilingual_chat(req: ChatRequest, authorization: str | None = Header(None)):
+ @app.post("/chat")
+ async def chat_endpoint(req: ChatRequest, authorization: str | None = Header(None)):
      check_auth(authorization)
-     reply = run_conversational(chat_pipe, req.query)
-     return {"reply": reply}
+     prompt = f"You are a helpful developer assistant. Question:\n{req.question}\nAnswer clearly:"
+     result = run_pipeline(chat_pipe, prompt)
+     return result

- @app.post("/disaster-summarizer")
- async def disaster_summarizer(req: DisasterRequest, authorization: str | None = Header(None)):
+ @app.post("/autodoc")
+ async def autodoc_endpoint(req: AutoDocRequest, authorization: str | None = Header(None)):
      check_auth(authorization)
-     summary = run_conversational(disaster_pipe, req.report)
-     return {"summary": summary}
+     prompt = f"Generate Markdown documentation for the following Python code:\n{req.code}\nDocumentation:"
+     result = run_pipeline(autodoc_pipe, prompt)
+     return result

- @app.post("/marketplace")
- async def marketplace(req: MarketRequest, authorization: str | None = Header(None)):
+ @app.post("/sme/generate")
+ async def sme_generate_endpoint(req: SMERequest, authorization: str | None = Header(None)):
      check_auth(authorization)
-     recommendation = run_conversational(market_pipe, req.product)
-     return {"recommendation": recommendation}
+     try:
+         context_docs = retriever.get_relevant_documents(req.user_prompt)
+         context = "\n".join([doc.page_content for doc in context_docs]) if context_docs else "No extra context"
+         prompt = f"Generate production-grade frontend code based on this:\n{req.user_prompt}\nContext:\n{context}\nOutput:"
+         result = run_pipeline(sme_pipe, prompt)
+         return result
+     except Exception as e:
+         return {"success": False, "error": f"⚠️ LLM error: {str(e)}", "trace": traceback.format_exc()}

- @app.post("/vector-search")
- async def vector_search(req: VectorRequest, authorization: str | None = Header(None)):
+ @app.post("/sme/speech-generate")
+ async def sme_speech_endpoint(file: UploadFile = File(...), lang_hint: str | None = None, authorization: str | None = Header(None)):
      check_auth(authorization)
+     transcription, detected_lang, translation = await process_audio(file, lang_hint)
      try:
-         results = query_vector(req.query)
-         return {"results": results}
+         context_docs = retriever.get_relevant_documents(translation)
+         context = "\n".join([doc.page_content for doc in context_docs]) if context_docs else "No extra context"
+         prompt = f"Generate production-ready frontend code for this idea:\n{translation}\nContext:\n{context}\nOutput:"
+         result = run_pipeline(sme_pipe, prompt)
+         return {
+             "success": True,
+             "transcription": transcription,
+             "detected_language": detected_lang,
+             "translation": translation,
+             "output": result.get("data", ""),
+         }
      except Exception as e:
-         logger.error(f"Vector search error: {e}")
-         return {"error": f"Vector search error: {str(e)}"}
-
- @app.post("/crop-doctor")
- async def crop_doctor(
-     symptoms: str = Header(...),
-     image: UploadFile = File(...),
-     authorization: str | None = Header(None)
- ):
-     check_auth(authorization)
-     image_bytes = await image.read()
-     diagnosis = run_crop_doctor(image_bytes, symptoms)
-     return {"diagnosis": diagnosis}
+         return {"success": False, "error": f"⚠️ LLM error: {str(e)}", "trace": traceback.format_exc()}
+
+ # ==============================
+ # Run App
+ # ==============================
+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run("main:app", host="0.0.0.0", port=7860)
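
For quick smoke-testing of the new routes, a minimal client sketch follows. Everything in it is an assumption rather than part of the commit: it presumes the server is running locally on the port uvicorn binds (7860), that PROJECT_API_KEY is still the "devassist-secret" default, that the third-party requests package is installed, and that a local recording named idea.wav exists. Because lang_hint on the speech route is a plain scalar parameter, FastAPI reads it from the query string.

# client_example.py — hypothetical smoke test for the DevAssist API above.
# Assumptions: server at http://localhost:7860, PROJECT_API_KEY left at its
# "devassist-secret" default, `requests` installed, and a local idea.wav file.
import requests

BASE = "http://localhost:7860"
HEADERS = {"Authorization": "Bearer devassist-secret"}

# JSON endpoints: bodies mirror the ChatRequest / AutoDocRequest schemas.
r = requests.post(f"{BASE}/chat",
                  json={"question": "What does Header(None) do in FastAPI?"},
                  headers=HEADERS)
print(r.json())  # {"success": true, "data": "..."} on success

r = requests.post(f"{BASE}/autodoc",
                  json={"code": "def add(a, b):\n    return a + b"},
                  headers=HEADERS)
print(r.json())

# Speech endpoint: multipart upload under the field name "file";
# lang_hint travels in the query string.
with open("idea.wav", "rb") as f:
    r = requests.post(f"{BASE}/sme/speech-generate",
                      params={"lang_hint": "en"},
                      files={"file": ("idea.wav", f, "audio/wav")},
                      headers=HEADERS)
print(r.json())

Note that auth failures surface as FastAPI's HTTPException (a {"detail": ...} body with a 4xx status) rather than the {"success": ...} envelope, so check r.status_code before relying on the response shape.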