import io
import json
import logging
import os
import re
import tempfile
from contextlib import asynccontextmanager

import gradio as gr
from dotenv import load_dotenv
from elevenlabs.client import ElevenLabs
from fastapi import Depends, FastAPI, Header, HTTPException
from fastapi.concurrency import run_in_threadpool
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_postgres.vectorstores import PGVector
from openai import OpenAI
from pydantic import BaseModel
from sqlalchemy import create_engine

# Silence TensorFlow logging before any model import triggers it.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
logging.getLogger("tensorflow").setLevel(logging.ERROR)
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)

load_dotenv()
NEON_DATABASE_URL = os.getenv("NEON_DATABASE_URL")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
ELEVENLABS_API_KEY = os.getenv("ELEVENLABS_API_KEY")
SHARED_SECRET = os.getenv("SHARED_SECRET")

COLLECTION_NAME = "real_estate_embeddings"
EMBEDDING_MODEL = "hkunlp/instructor-large"
ELEVENLABS_VOICE_ID = "LHJy3mhZWsvhUjy0zUM1"

PLANNER_MODEL = "gpt-4o-mini"
ANSWERER_MODEL = "gpt-4o"

TABLE_DESCRIPTIONS = """
- "ongoing_projects_source": Details about projects currently under construction.
- "upcoming_projects_source": Information on future planned projects.
- "completed_projects_source": Facts about projects that are already finished.
- "historical_sales_source": Specific sales records, including price, date, and property ID.
- "past_customers_source": Information about previous customers.
- "feedback_source": Customer feedback and ratings for projects.
"""

embeddings = None
vector_store = None
client_openai = OpenAI(api_key=OPENAI_API_KEY)
client_elevenlabs = None

try:
    key_preview = (
        f"{ELEVENLABS_API_KEY[:5]}...{ELEVENLABS_API_KEY[-4:]}"
        if ELEVENLABS_API_KEY and len(ELEVENLABS_API_KEY) > 9
        else "None"
    )
    logging.info(f"Initializing ElevenLabs client with key: {key_preview}")

    if not ELEVENLABS_API_KEY:
        raise ValueError("ELEVENLABS_API_KEY is missing or empty.")

    client_elevenlabs = ElevenLabs(api_key=ELEVENLABS_API_KEY)
    logging.info(f"ElevenLabs client created (type: {type(client_elevenlabs)})")

    # Smoke test: listing voices confirms the key is actually valid.
    voices = client_elevenlabs.voices.get_all()
    logging.info(f"Fetched {len(voices.voices)} voices from ElevenLabs.")

except Exception as e:
    logging.error(f"ElevenLabs init failed: {e}", exc_info=True)
    client_elevenlabs = None

try:
    import elevenlabs

    logging.info(f"elevenlabs SDK version: {elevenlabs.__version__}")
except Exception:
    logging.error("Could not import elevenlabs package.")


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Load heavyweight resources once at startup rather than per request.
    global embeddings, vector_store
    logging.info(f"Loading embedding model: {EMBEDDING_MODEL}")
    embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)

    logging.info(f"Connecting to vector store: {COLLECTION_NAME}")
    engine = create_engine(NEON_DATABASE_URL, pool_pre_ping=True)
    vector_store = PGVector(
        connection=engine,
        collection_name=COLLECTION_NAME,
        embeddings=embeddings,
    )
    logging.info("Vector store ready.")
    yield
    logging.info("Shutting down.")


app = FastAPI(lifespan=lifespan)

QUERY_FORMULATION_PROMPT = """
You are a query analysis agent. Transform the user's query into a precise search query and determine the correct table to filter by.

**Available Tables:**
{table_descriptions}

**User's Query:** "{user_query}"

**Task:**
1. Rephrase into a clear, keyword-focused English search query.
2. If status keywords (ongoing, completed, upcoming, etc.) are present, pick the matching table.
3. If no status keyword, set filter_table to null.
4. Return JSON: {{"search_query": "...", "filter_table": "table_name or null"}}
""".strip()

ANSWER_SYSTEM_PROMPT = """
You are an expert AI assistant for a premier real estate developer.

## CORE KNOWLEDGE
- Cities: Pune, Mumbai, Bengaluru, Delhi, Chennai, Hyderabad, Goa, Gurgaon, Kolkata.
- Properties: Luxury apartments, villas, commercial.
- Budget: 45 lakhs to 5 crores.

## RULES
1. Match user language (Hinglish -> Hinglish, English -> English).
2. Use CONTEXT if available, else use core knowledge.
3. Only answer real estate questions.
""".strip()

def transcribe_audio(audio_path: str, audio_bytes: bytes) -> str:
    """Transcribe audio with Whisper; transliterate Devanagari output to Hinglish."""
    for attempt in range(3):
        try:
            audio_file = io.BytesIO(audio_bytes)
            filename = os.path.basename(audio_path)

            logging.info(f"Transcribing {filename} ({len(audio_bytes)} bytes)")
            transcript = client_openai.audio.transcriptions.create(
                model="whisper-1",
                file=(filename, audio_file),
            )
            text = transcript.text.strip()

            # Whisper may return Hindi in Devanagari script; convert to Roman.
            if re.search(r"[\u0900-\u097F]", text):
                resp = client_openai.chat.completions.create(
                    model="gpt-4o-mini",
                    messages=[
                        {"role": "user", "content": f"Transliterate to Roman (Hinglish): {text}"}
                    ],
                    temperature=0.0,
                )
                text = resp.choices[0].message.content.strip()

            logging.info(f"Transcribed: {text}")
            return text
        except Exception as e:
            logging.error(f"Transcription error (attempt {attempt + 1}): {e}", exc_info=True)
            if attempt == 2:
                return ""
    return ""

def generate_elevenlabs_sync(text: str) -> bytes:
    """
    Uses the hard-coded voice ID and the correct SDK method.
    NOTE: `model` parameter is REMOVED in SDK v2.17.0+
    """
    if client_elevenlabs is None:
        logging.error("ElevenLabs client not initialized; skipping TTS.")
        return b""

    for attempt in range(3):
        try:
            logging.info("Calling ElevenLabs text_to_speech.convert...")
            stream = client_elevenlabs.text_to_speech.convert(
                voice_id=ELEVENLABS_VOICE_ID,
                text=text,
                output_format="mp3_44100_128",
            )
            # convert() yields audio chunks; join them into a single payload.
            audio_bytes = b"".join(chunk for chunk in stream if chunk)
            logging.info(f"TTS returned {len(audio_bytes)} bytes.")
            return audio_bytes
        except Exception as e:
            logging.error(
                f"ElevenLabs TTS error (attempt {attempt + 1}): {e}", exc_info=True
            )
            if attempt == 2:
                return b""
    return b""


async def formulate_search_plan(user_query: str) -> dict:
    logging.info(f"Formulating search plan for: {user_query}")
    for attempt in range(3):
        try:
            formatted = QUERY_FORMULATION_PROMPT.format(
                table_descriptions=TABLE_DESCRIPTIONS, user_query=user_query
            )
            resp = await run_in_threadpool(
                client_openai.chat.completions.create,
                model=PLANNER_MODEL,
                messages=[{"role": "user", "content": formatted}],
                response_format={"type": "json_object"},
                temperature=0.0,
            )
            raw = resp.choices[0].message.content
            logging.info(f"Planner raw response: {raw}")
            plan = json.loads(raw)
            logging.info(f"Parsed plan: {plan}")
            return plan
        except Exception as e:
            logging.error(f"Planner error (attempt {attempt + 1}): {e}", exc_info=True)
            if attempt == 2:
                return {"search_query": user_query, "filter_table": None}
    # Fallback: search with the raw query and no table filter.
    return {"search_query": user_query, "filter_table": None}

async def get_agent_response(user_text: str) -> str:
    for attempt in range(3):
        try:
            plan = await formulate_search_plan(user_text)
            search_q = plan.get("search_query", user_text)
            filter_tbl = plan.get("filter_table")
            search_filter = {"source_table": filter_tbl} if filter_tbl else {}

            docs = await run_in_threadpool(
                vector_store.similarity_search,
                search_q,
                k=3,
                filter=search_filter,
            )
            # If the table filter was too strict, retry the search unfiltered.
            if not docs:
                docs = await run_in_threadpool(vector_store.similarity_search, search_q, k=3)

            context = "\n\n".join(d.page_content for d in docs)

            resp = await run_in_threadpool(
                client_openai.chat.completions.create,
                model=ANSWERER_MODEL,
                messages=[
                    {"role": "system", "content": ANSWER_SYSTEM_PROMPT},
                    {"role": "system", "content": f"CONTEXT:\n{context}"},
                    {"role": "user", "content": f"Question: {user_text}"},
                ],
            )
            return resp.choices[0].message.content.strip()
        except Exception as e:
            logging.error(f"RAG error (attempt {attempt + 1}): {e}", exc_info=True)
            if attempt == 2:
                return "Sorry, I couldn't respond. Please try again."
    return "Sorry, I couldn't respond."

class TextQuery(BaseModel):
    query: str


async def verify_token(x_auth_token: str = Header(...)):
    if not SHARED_SECRET or x_auth_token != SHARED_SECRET:
        logging.warning("Auth failed for /test-text-query")
        raise HTTPException(status_code=401, detail="Invalid token")
    logging.info("Auth passed")


@app.post("/test-text-query", dependencies=[Depends(verify_token)])
async def test_text_query_endpoint(query: TextQuery):
    logging.info(f"Text query: {query.query}")
    response = await get_agent_response(query.query)
    return {"response": response}

async def process_audio(audio_path):
    if not audio_path or not os.path.exists(audio_path):
        return None, "No valid audio file received."

    try:
        # Step 1: read the recorded audio from disk.
        with open(audio_path, "rb") as f:
            audio_bytes = f.read()
        if not audio_bytes:
            return None, "Empty audio file."

        # Step 2: speech-to-text.
        user_text = await run_in_threadpool(transcribe_audio, audio_path, audio_bytes)
        if not user_text:
            return None, "Couldn't understand audio. Try again."

        logging.info(f"User: {user_text}")

        # Step 3: retrieval-augmented answer.
        agent_response = await get_agent_response(user_text)
        if not agent_response:
            return None, "No response generated."

        logging.info(f"AI: {agent_response[:100]}...")
        logging.info(f"FULL AI Response sent to ElevenLabs: >>>{agent_response}<<<")

        # Step 4: text-to-speech; fall back to text-only output on failure.
        ai_audio_bytes = await run_in_threadpool(generate_elevenlabs_sync, agent_response)
        if not ai_audio_bytes:
            logging.error("TTS failed; returning text only.")
            return (
                None,
                f"**You:** {user_text}\n\n**AI:** {agent_response}\n\n_(Audio generation failed)_",
            )

        # Step 5: write the MP3 to a temp file so Gradio can play it back.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as f:
            f.write(ai_audio_bytes)
            out_path = f.name
        logging.info(f"Saved TTS audio to {out_path}")

        return out_path, f"**You:** {user_text}\n\n**AI:** {agent_response}"

    except Exception as e:
        logging.error(f"Audio processing error: {e}", exc_info=True)
        return None, f"Error: {str(e)}"

with gr.Blocks(title="Real Estate AI") as demo:
    gr.Markdown("# Real Estate Voice Assistant")
    gr.Markdown("Ask about projects in Pune, Mumbai, Bengaluru, etc.")

    with gr.Row():
        inp = gr.Audio(sources=["microphone"], type="filepath", label="Speak")
        out_audio = gr.Audio(label="AI Response", type="filepath")

    out_text = gr.Textbox(label="Conversation", lines=8)

    inp.change(process_audio, inputs=inp, outputs=[out_audio, out_text])


app = gr.mount_gradio_app(app, demo, path="/")
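
# To serve the API and the Gradio UI together locally (assuming this file is
# saved as main.py):
#
#   uvicorn main:app --host 0.0.0.0 --port 8000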