|
|
import asyncio
import base64
import hmac
import io
import json
import logging
import os
import re
from contextlib import asynccontextmanager
from typing import Optional

from dotenv import load_dotenv
from elevenlabs.client import ElevenLabs
from fastapi import (
    Depends,
    FastAPI,
    Header,
    HTTPException,
    Query,
    Request,
    WebSocket,
    WebSocketDisconnect,
    status,
)
from fastapi.concurrency import run_in_threadpool
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_postgres.vectorstores import PGVector
from openai import OpenAI
from pydantic import BaseModel
from sqlalchemy import create_engine
|
|
|
|
|
|
|
|
# --- Logging setup ---------------------------------------------------------
# Presumably meant to quiet TensorFlow's native and Python-side log output;
# NOTE(review): this runs after the imports at the top of the file, so confirm
# the environment variable is still read late enough to take effect.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
logging.getLogger('tensorflow').setLevel(logging.ERROR)

# Root logger: INFO and above, "<timestamp> - <level> - <message>".
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
|
|
|
|
|
|
|
|
# --- Secrets and connection strings ----------------------------------------
# Load a local .env file (if any) before reading the variables below.
load_dotenv()

# Each value is None when the corresponding variable is not set.
NEON_DATABASE_URL = os.environ.get("NEON_DATABASE_URL")
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
ELEVENLABS_API_KEY = os.environ.get("ELEVENLABS_API_KEY")
SHARED_SECRET = os.environ.get("SHARED_SECRET")
|
|
|
|
|
|
|
|
# --- Static configuration --------------------------------------------------
# Vector-store collection name and the embedding model used to query it.
COLLECTION_NAME = "real_estate_embeddings"
EMBEDDING_MODEL = "hkunlp/instructor-large"

# Voice passed to ElevenLabs when synthesising replies.
ELEVENLABS_VOICE_NAME = "Leo"

# Chat models: one for building the search plan, one for the final answer.
PLANNER_MODEL = "gpt-4o-mini"
ANSWERER_MODEL = "gpt-4o"

# Bullet list of source tables, interpolated into the planner prompt.
# The text starts and ends with a newline.
TABLE_DESCRIPTIONS = "\n".join((
    "",
    '- "ongoing_projects_source": Details about projects currently under construction.',
    '- "upcoming_projects_source": Information on future planned projects.',
    '- "completed_projects_source": Facts about projects that are already finished.',
    '- "historical_sales_source": Specific sales records, including price, date, and property ID.',
    '- "past_customers_source": Information about previous customers.',
    '- "feedback_source": Customer feedback and ratings for projects.',
    "",
))
|
|
|
|
|
|
|
|
# --- Shared runtime objects ------------------------------------------------
# Both are placeholders here; lifespan() assigns the real objects at startup.
embeddings = vector_store = None

# API clients are created at import time from the keys read above.
client_openai = OpenAI(api_key=OPENAI_API_KEY)
client_elevenlabs = ElevenLabs(api_key=ELEVENLABS_API_KEY)
|
|
|
|
|
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage application startup and shutdown.

    Startup: loads the HuggingFace embedding model and opens the PGVector
    collection, publishing both through the module-level ``embeddings`` and
    ``vector_store`` globals.

    Shutdown: disposes the SQLAlchemy engine so pooled database connections
    are closed instead of being left to the garbage collector.

    Raises:
        RuntimeError: if ``NEON_DATABASE_URL`` is not configured.
    """
    global embeddings, vector_store

    # Fail fast with a readable message rather than an opaque engine error.
    if not NEON_DATABASE_URL:
        raise RuntimeError("NEON_DATABASE_URL is not set; cannot connect to the vector store.")

    logging.info(f"Initializing embedding model: '{EMBEDDING_MODEL}'...")
    embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
    logging.info("Embedding model loaded successfully.")

    logging.info(f"Connecting to vector store '{COLLECTION_NAME}'...")
    engine = create_engine(NEON_DATABASE_URL, pool_pre_ping=True)
    try:
        vector_store = PGVector(
            connection=engine,
            collection_name=COLLECTION_NAME,
            embeddings=embeddings,
        )
        logging.info("Successfully connected to the vector store.")
        yield
    finally:
        # Runs on clean shutdown and when startup/serving raised.
        logging.info("Application shutting down.")
        engine.dispose()
|
|
|
|
|
|
|
|
# The ASGI application; startup and shutdown are delegated to lifespan().
app = FastAPI(
    lifespan=lifespan,
)
|
|
|
|
|
|
|
|
# Planner prompt template. TABLE_DESCRIPTIONS is baked in here; the remaining
# "{user_query}" placeholder is filled later via str.format(user_query=...).
QUERY_FORMULATION_PROMPT = "\n".join((
    "",
    "You are a query analysis agent. Your task is to transform a user's query into a precise search query for a vector database and determine the correct table to filter by.",
    "**Available Tables:**",
    TABLE_DESCRIPTIONS,
    "**User's Query:** \"{user_query}\"",
    "**Your Task:**",
    "1. Rephrase the user's query into a clear, keyword-focused English question suitable for a database search.",
    "2. Analyze the user's query for keywords indicating project status (e.g., \"ongoing\", \"under construction\", \"completed\", \"finished\", \"upcoming\", \"new launch\").",
    "3. If such status keywords are present, identify the single most relevant table from the list above to filter by.",
    "4. If no specific status keywords are mentioned (e.g., the user asks generally about projects in a location), set the filter table to null.",
    "5. Respond ONLY with a JSON object containing \"search_query\" and \"filter_table\" (which should be the table name string or null).",
    "",
))
|
|
ANSWER_SYSTEM_PROMPT = """ |
|
|
You are an expert AI assistant for a premier real estate developer. |
|
|
## YOUR PERSONA |
|
|
- You are professional, helpful, and highly knowledgeable. Your tone should be polite and articulate. |
|
|
## CORE BUSINESS KNOWLEDGE |
|
|
- **Operational Cities:** We are currently operational in Pune, Mumbai, Bengaluru, Delhi, Chennai, Hyderabad, Goa, Gurgaon, Kolkata. |
|
|
- **Property Types:** We offer luxury apartments, villas, and commercial properties. |
|
|
- **Budget Range:** Our residential properties typically range from 45 lakhs to 5 crores. |
|
|
## CORE RULES |
|
|
1. **Language Adaptation:** If the user's original query was in Hinglish, respond in Hinglish. If in English, respond in English. |
|
|
2. **Fact-Based Answers:** Use the provided CONTEXT to answer the user's question. If the context is empty, use your Core Business Knowledge. |
|
|
3. **Stay on Topic:** Only answer questions related to real estate. |
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Matches any character in the Devanagari Unicode block; compiled once.
_DEVANAGARI_RE = re.compile(r'[\u0900-\u097F]')


def _guess_audio_filename(audio_bytes: bytes) -> str:
    """Return an upload filename whose extension matches the payload's magic bytes.

    Falls back to the generic "input.audio" when the container is not recognised.
    """
    head = audio_bytes[:12]
    if head[:4] == b"RIFF" and head[8:12] == b"WAVE":
        ext = "wav"
    elif head[:4] == b"\x1a\x45\xdf\xa3":  # EBML header (WebM / Matroska)
        ext = "webm"
    elif head[:4] == b"OggS":  # Ogg container (e.g. Opus, Vorbis)
        ext = "ogg"
    elif head[:4] == b"fLaC":
        ext = "flac"
    elif head[4:8] == b"ftyp":  # ISO base media (MP4 / M4A)
        ext = "mp4"
    elif head[:3] == b"ID3" or (
        len(head) >= 2
        and head[0] == 0xFF
        and (head[1] & 0xE0) == 0xE0
        and (head[1] & 0x06) != 0  # layer bits 00 are reserved, so not MPEG audio
    ):
        ext = "mp3"
    else:
        ext = "audio"
    return f"input.{ext}"


def transcribe_audio(audio_bytes: bytes) -> str:
    """Transcribe raw audio bytes with OpenAI ``whisper-1``.

    If the transcript contains Devanagari characters it is additionally sent
    to ``gpt-4o-mini`` to be transliterated into Roman script.

    Args:
        audio_bytes: Encoded audio (e.g. WAV, MP3, WebM, Ogg/Opus).

    Returns:
        The stripped transcript, or "" when the input is empty or all three
        attempts fail. Exceptions are logged, never raised.
    """
    # Nothing to transcribe: skip three pointless API round-trips.
    if not audio_bytes:
        logging.warning("transcribe_audio called with empty audio; skipping.")
        return ""

    # NOTE(review): the upload's filename is how the format is conveyed to the
    # API, so give it a real extension instead of the opaque ".audio" -
    # confirm against the transcription API docs.
    filename = _guess_audio_filename(audio_bytes)

    for attempt in range(3):
        try:
            # Fresh buffer each attempt so a retry starts reading at offset 0.
            audio_file = io.BytesIO(audio_bytes)
            audio_file.name = filename

            transcript = client_openai.audio.transcriptions.create(
                model="whisper-1",
                file=audio_file
            )
            text = transcript.text or ""

            if _DEVANAGARI_RE.search(text):
                translit_prompt = f"Transliterate this Hindi text to Roman script (Hinglish style): {text}"
                response = client_openai.chat.completions.create(
                    model="gpt-4o-mini",
                    messages=[{"role": "user", "content": translit_prompt}],
                    temperature=0.0
                )
                # Keep the original transcript if the model returns no content,
                # rather than failing on None.strip() and re-transcribing.
                text = response.choices[0].message.content or text

            return text.strip()
        except Exception as e:
            logging.error(f"Error during transcription (attempt {attempt+1}): {e}", exc_info=True)

    # All attempts failed.
    return ""
|
|
|
|
|
def generate_elevenlabs_sync(text: str, voice: str) -> bytes:
    """Synthesise ``text`` to MP3 with ElevenLabs (blocking; run in a thread pool).

    Args:
        text: The text to speak.
        voice: Voice passed through to ``client_elevenlabs.generate``.

    Returns:
        The complete audio as ``bytes``, or ``b''`` when ``text`` is blank or
        all three attempts fail. Exceptions are logged, never raised.
    """
    # Nothing to speak: avoid a guaranteed-useless API call.
    if not text or not text.strip():
        return b''

    for attempt in range(3):
        try:
            audio = client_elevenlabs.generate(
                text=text,
                voice=voice,
                model="eleven_multilingual_v2",
                output_format="mp3_44100_128"
            )
            if isinstance(audio, (bytes, bytearray)):
                return bytes(audio)
            # The SDK may hand back an iterator of chunks. Drain it here so the
            # declared ``bytes`` return holds, the caller's emptiness check and
            # base64 encoding work, and streaming errors hit this retry loop.
            return b''.join(audio)
        except Exception as e:
            logging.error(f"Error in ElevenLabs generate (attempt {attempt+1}): {e}", exc_info=True)

    # All attempts failed.
    return b''
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def formulate_search_plan(user_query: str) -> dict:
    """Ask the planner model for a search query and an optional table filter.

    Args:
        user_query: The user's question as received.

    Returns:
        A dict with at least:
          * ``"search_query"``: a non-empty string (falls back to ``user_query``);
          * ``"filter_table"``: a table-name string, or ``None`` for no filter.
        After three failed attempts the fallback
        ``{"search_query": user_query, "filter_table": None}`` is returned.
        Exceptions are logged, never raised.
    """
    logging.info("Formulating search plan with Planner LLM...")
    for attempt in range(3):
        try:
            response = await run_in_threadpool(
                client_openai.chat.completions.create,
                model=PLANNER_MODEL,
                messages=[{"role": "user", "content": QUERY_FORMULATION_PROMPT.format(user_query=user_query)}],
                response_format={"type": "json_object"},
                temperature=0.0
            )
            plan = json.loads(response.choices[0].message.content)
            if not isinstance(plan, dict):
                raise ValueError(f"Planner returned non-object JSON: {plan!r}")

            # Normalise the model output so callers never see a null or empty
            # query, or a textual "null" masquerading as a table name.
            search_query = plan.get("search_query")
            if not isinstance(search_query, str) or not search_query.strip():
                plan["search_query"] = user_query

            filter_table = plan.get("filter_table")
            if (
                not isinstance(filter_table, str)
                or filter_table.strip().lower() in ("", "null", "none")
            ):
                plan["filter_table"] = None

            logging.info(f"Search plan received: {plan}")
            return plan
        except Exception as e:
            logging.error(f"Error in Planner LLM call (attempt {attempt+1}): {e}", exc_info=True)

    # All attempts failed: search with the raw query and no filter.
    return {"search_query": user_query, "filter_table": None}
|
|
|
|
|
async def get_agent_response(user_text: str) -> str:
    """Answer ``user_text`` using retrieval from the vector store plus the answerer model.

    Steps per attempt (up to three attempts):
      1. Build a search plan with ``formulate_search_plan``.
      2. Run a similarity search (k=3), filtered by ``source_table`` when the
         plan names a table.
      3. If a *filtered* search finds nothing, retry once without the filter.
      4. Send the system prompt, retrieved context and the question to the
         answerer model and return its reply.

    Returns:
        The model's reply, or a fixed apology string after three failures.
        Exceptions are logged, never raised.
    """
    for attempt in range(3):
        try:
            search_plan = await formulate_search_plan(user_text)
            # `or` also covers a key that is present but null/empty.
            search_query = search_plan.get("search_query") or user_text
            filter_table = search_plan.get("filter_table")

            # None (not {}) means "no filter" to the vector store.
            search_filter = {"source_table": filter_table} if filter_table else None
            if search_filter:
                logging.info(f"Applying initial filter: {search_filter}")

            retrieved_docs = await run_in_threadpool(
                vector_store.similarity_search,
                search_query, k=3, filter=search_filter
            )

            # Only broaden when a filter was applied; otherwise the "fallback"
            # would just repeat the identical unfiltered query.
            if not retrieved_docs and search_filter:
                logging.info("Initial search returned no results. Performing a broader fallback search.")
                retrieved_docs = await run_in_threadpool(
                    vector_store.similarity_search,
                    search_query, k=3
                )

            context_text = "\n\n".join(doc.page_content for doc in retrieved_docs)
            logging.info(f"Retrieved Context (preview): {context_text[:500]}...")

            final_prompt_messages = [
                {"role": "system", "content": ANSWER_SYSTEM_PROMPT},
                {"role": "system", "content": f"Use the following CONTEXT to answer:\n{context_text}"},
                {"role": "user", "content": f"My original question was: '{user_text}'"}
            ]

            final_response = await run_in_threadpool(
                client_openai.chat.completions.create,
                model=ANSWERER_MODEL,
                messages=final_prompt_messages
            )

            return final_response.choices[0].message.content
        except Exception as e:
            logging.error(f"Error in get_agent_response (attempt {attempt+1}): {e}", exc_info=True)

    # All attempts failed.
    return "Sorry, I couldn't generate a response. Please try again."
|
|
|
|
|
|
|
|
|
|
|
class TextQuery(BaseModel):
    """Request body carrying a single text question."""

    # The user's question.
    query: str
|
|
|
|
|
async def verify_token(x_auth_token: str = Header(...)):
    """Dependency that checks the request's auth-token header against SHARED_SECRET.

    Args:
        x_auth_token: Token supplied by the client via the required header.

    Raises:
        HTTPException: 401 when SHARED_SECRET is unset/empty or the token does
            not match.
    """
    # Constant-time comparison so response timing does not reveal how much of
    # the secret a guess got right. Compared as bytes because compare_digest
    # rejects non-ASCII str arguments.
    token_ok = bool(SHARED_SECRET) and hmac.compare_digest(
        x_auth_token.encode("utf-8"), SHARED_SECRET.encode("utf-8")
    )
    if not token_ok:
        logging.warning("Authentication failed for /test-text-query.")
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or missing authentication token")
    logging.info("Authentication successful for /test-text-query.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/test-text-query", dependencies=[Depends(verify_token)])
async def test_text_query_endpoint(query: TextQuery):
    """Text-only test endpoint: run the agent on ``query.query`` and return its reply.

    Guarded by the ``verify_token`` dependency. Responds with
    ``{"response": <agent reply>}``.
    """
    user_question = query.query
    logging.info(f"Received text query: {user_question}")

    answer = await get_agent_response(user_question)
    logging.info(f"Generated text response: {answer}")

    return {"response": answer}
|
|
|
|
|
|
|
|
@app.websocket("/browser-listen")
async def browser_websocket_endpoint(
    websocket: WebSocket,
    token: Optional[str] = Query(None)
):
    """Main WebSocket endpoint for browser-based audio.

    Authenticates with the ``token`` query parameter, then loops:
    receive ``{"audio": <base64>}`` -> transcribe -> generate the agent reply
    -> synthesise speech -> send ``{"text": ..., "audio": <base64 mp3>}``.

    Messages that cannot be processed (no audio, invalid base64, empty
    transcript, empty reply, failed synthesis) are skipped without a reply
    and the connection stays open.
    """
    # Constant-time comparison; bytes because compare_digest rejects
    # non-ASCII str. A missing token or unset secret is always a failure.
    authorized = bool(
        token
        and SHARED_SECRET
        and hmac.compare_digest(token.encode("utf-8"), SHARED_SECRET.encode("utf-8"))
    )
    if not authorized:
        # Never log the supplied token: near-miss secrets must not reach the logs.
        logging.warning("Browser auth failed: missing or invalid token.")
        # Accept first, presumably so the client receives the 1008 close code.
        await websocket.accept()
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return

    await websocket.accept()
    logging.info("Browser client connected and authenticated.")

    try:
        while True:
            message = await websocket.receive_json()
            # Tolerate valid JSON that is not an object (e.g. a list or string).
            audio_base64 = message.get("audio") if isinstance(message, dict) else None
            if not audio_base64:
                continue

            logging.info("Received audio blob from browser.")
            try:
                audio_bytes = base64.b64decode(audio_base64)
            except (ValueError, TypeError):
                # One malformed payload should not tear down the whole session.
                logging.warning("Discarding message: 'audio' is not valid base64.")
                continue

            user_text = await run_in_threadpool(transcribe_audio, audio_bytes)
            if not user_text:
                logging.info("Transcription empty; skipping.")
                continue
            logging.info(f"User said: {user_text}")

            agent_response_text = await get_agent_response(user_text)
            if not agent_response_text:
                logging.warning("Agent generated empty response.")
                continue
            logging.info(f"AI Responded (preview): {agent_response_text[:100]}...")

            ai_audio_bytes = await run_in_threadpool(
                generate_elevenlabs_sync,
                agent_response_text,
                ELEVENLABS_VOICE_NAME
            )
            if not ai_audio_bytes:
                continue

            response_audio_base64 = base64.b64encode(ai_audio_bytes).decode('utf-8')
            await websocket.send_json({
                "text": agent_response_text,
                "audio": response_audio_base64
            })
            logging.info("Sent AI audio response back to browser.")

    except WebSocketDisconnect:
        logging.info("Browser client disconnected.")
    except Exception as e:
        logging.error(f"An error occurred in browser websocket: {e}", exc_info=True)
    finally:
        try:
            await websocket.close()
        except Exception:
            # Best-effort cleanup: the socket is usually already closed here.
            logging.debug("WebSocket close during cleanup failed; ignoring.", exc_info=True)