Spaces:
Running
Running
File size: 6,401 Bytes
9f84bcd 639d612 9f84bcd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 |
# main.py
from dotenv import load_dotenv
load_dotenv()
import sys
import uuid
import asyncio
import logging
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from chatbot import Chatbot
# -----------------------------
# Windows Asyncio Fix
# -----------------------------
# On Windows, install the selector-based event loop policy so that event loops
# created after this point are SelectorEventLoops.
# NOTE(review): asyncio's SelectorEventLoop on Windows does not support
# subprocesses, and Playwright launches its browser as a subprocess, so this
# policy may itself be the source of the NotImplementedError that
# initialize_chatbot logs as "Playwright async not supported on Windows".
# Also, when served by uvicorn this module may be imported after the server's
# loop already exists, in which case the policy would not affect that loop.
# Confirm both points before relying on this block.
if sys.platform == "win32":
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
# -----------------------------
# FastAPI app & CORS
# -----------------------------
app = FastAPI(
    title="Session-Based RAG Chatbot API",
    description="Session-based RAG Chatbot API with WebSocket support",
    version="1.1.0"
)

# Browser origins allowed to call this API (with credentials).
# A CORS origin is only scheme://host[:port]: browsers never put a path in the
# Origin header, and CORSMiddleware compares allow_origins entries by exact
# string match. A path-bearing entry such as ".../WebIQ-Frontend" can therefore
# never match, so the GitHub Pages frontend is listed by its bare origin.
origins = [
    "http://localhost:8080",
    "http://127.0.0.1:8080",
    "http://127.0.0.1:5500",
    "https://vaishnavibasukar20.github.io",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# -----------------------------
# Session storage
# -----------------------------
# Maps session_id -> session state:
#   None     -> not ready yet (initialization scheduled or in progress)
#   Chatbot  -> initialized instance, ready to answer queries
#   "err"    -> initialization failed (details are in the server log)
# This is a plain in-process dict: state is lost on restart and is not shared
# between worker processes.
chatbot_sessions = {}
# -----------------------------
# Root endpoint
# -----------------------------
@app.get("/")
def read_root():
    """Return a static welcome message together with a constant "Ready" status."""
    greeting = "Welcome to the Session-Based RAG Chatbot API!"
    return {"message": greeting, "status": "Ready"}
@app.get("/create_session")
def create_session():
    """
    Generate a new random (UUID4) session identifier.

    The id is only handed back to the client; it is not stored in
    ``chatbot_sessions`` here, so ``/session_status`` reports "not_found"
    until a ``/scrape/`` request registers it.
    """
    fresh_id = uuid.uuid4()
    return {"session": f"{fresh_id}"}
@app.get("/session_status/{session_id}")
def session_status(session_id: str):
    """
    Return the current status of a chatbot session.

    Status can be:
    - not_found     (session_id is not in chatbot_sessions, e.g. /scrape/
                     was never called for it)
    - initializing  (session exists but chatbot not ready)
    - ready         (chatbot instance ready)
    - failed        (chatbot initialization failed; stored as "err")
    """
    if session_id not in chatbot_sessions:
        return {"status": "not_found"}

    state = chatbot_sessions[session_id]
    if state is None:
        return {"status": "initializing"}
    if state == "err":
        return {"status": "failed"}
    return {"status": "ready"}
# -----------------------------
# Helper: Run async init in background
# -----------------------------
# Strong references to in-flight tasks started by run_chatbot_init. The event
# loop keeps only weak references to tasks, so an unreferenced fire-and-forget
# task can be garbage-collected before it finishes.
_background_tasks = set()


def run_chatbot_init(session_id, urls, llm_model, embedding_model, api_key):
    """
    Schedule ``initialize_chatbot`` as a task on the running event loop.

    Must be called while an event loop is running (e.g. from inside a
    coroutine); ``asyncio.create_task`` raises RuntimeError otherwise.

    NOTE(review): scrape_and_load schedules initialization through FastAPI's
    BackgroundTasks instead, so this helper may be unused.

    Returns:
        The created ``asyncio.Task``. It stays referenced until it completes.
    """
    task = asyncio.create_task(
        initialize_chatbot(session_id, urls, llm_model, embedding_model, api_key)
    )
    _background_tasks.add(task)
    # Drop the reference once the task is done so the set does not grow.
    task.add_done_callback(_background_tasks.discard)
    return task
# -----------------------------
# Scrape & initialize chatbot
# -----------------------------
@app.post("/scrape/")
async def scrape_and_load(response: dict, background_tasks: BackgroundTasks):
    """
    Schedule chatbot initialization for a session.

    ``response`` is the JSON request body (the name is historical). Keys:
    - urls (required): forwarded to ``Chatbot(url=...)``.
    - session_id: key for this session. A fresh UUID is generated when it is
      missing or empty; the id in use is always echoed back to the client.
    - llm_model, embedding_model, api_key: optional overrides.

    Initialization runs after the HTTP response has been sent; poll
    ``/session_status/{session_id}`` or open the WebSocket to follow it.
    A session whose previous initialization failed may be submitted again.

    Raises:
        HTTPException: 400 if ``urls`` is missing or empty.
    """
    # Without a generated id, every request lacking session_id would share the
    # key None and all but the first would be reported as "already initialized".
    session_id = response.get("session_id") or str(uuid.uuid4())
    urls = response.get("urls")
    llm_model = response.get("llm_model", "TheBloke/Llama-2-7B-Chat-GGML")
    embedding_model = response.get("embedding_model", "BAAI/bge-small-en")
    api_key = response.get("api_key", None)

    if not urls:
        raise HTTPException(status_code=400, detail="urls are required.")

    # Do not restart a ready or in-progress session, but allow a retry after a
    # failed attempt (stored as the "err" sentinel).
    if session_id in chatbot_sessions and chatbot_sessions[session_id] != "err":
        return {"message": f"Chatbot for session {session_id} already initialized.", "session_id": session_id}

    # Mark session as initializing
    chatbot_sessions[session_id] = None

    # BackgroundTasks awaits coroutine functions on the event loop after the
    # response is sent (only sync callables are run in a thread pool).
    async def init_wrapper():
        try:
            await initialize_chatbot(session_id, urls, llm_model, embedding_model, api_key)
        except Exception as e:
            # Safety net: record a failure instead of leaving None, which would
            # look "initializing" forever to pollers and WebSocket clients.
            logging.exception("[%s] Initialization error: %s", session_id, e)
            chatbot_sessions[session_id] = "err"

    background_tasks.add_task(init_wrapper)
    logging.info("[%s] Chatbot initialization scheduled in background.", session_id)
    return {"message": "Chatbot initialization started.", "session_id": session_id}
# -----------------------------
# Initialize chatbot
# -----------------------------
async def initialize_chatbot(session_id, urls, llm_model, embedding_model, api_key):
    """
    Build and initialize a Chatbot, recording the outcome in ``chatbot_sessions``.

    On success ``chatbot_sessions[session_id]`` is set to the ready instance.
    On any failure it is set to the sentinel ``"err"``, so ``/session_status``
    reports "failed" and waiting WebSocket clients stop polling.
    ``Exception`` subclasses are logged, not propagated to the caller.
    """
    logging.info("[%s] Initializing chatbot...", session_id)
    try:
        chatbot = Chatbot(
            url=urls,
            llm_model=llm_model,
            embedding_model=embedding_model,
            api_key=api_key
        )
        await chatbot.initialize()
    except NotImplementedError as e:
        # NOTE(review): presumably raised when Playwright cannot spawn its
        # browser subprocess (asyncio's SelectorEventLoop on Windows does not
        # support subprocesses). Storing None here would leave the session
        # "initializing" forever, so it is recorded as a failure.
        logging.error("[%s] Playwright async not supported on Windows: %s", session_id, e, exc_info=True)
        chatbot_sessions[session_id] = "err"
    except Exception as e:
        logging.error("[%s] Initialization failed: %s", session_id, e, exc_info=True)
        chatbot_sessions[session_id] = "err"
    else:
        chatbot_sessions[session_id] = chatbot
        logging.info("[%s] Chatbot ready.", session_id)
# -----------------------------
# WebSocket endpoint
# -----------------------------
@app.websocket("/ws/chat/{session_id}")
async def websocket_endpoint(websocket: WebSocket, session_id: str):
    """
    Chat with a session's chatbot over a WebSocket.

    Every server message is a JSON object ``{"text": ...}``.
    - While the session is unknown or still initializing, a "please wait"
      message is sent roughly once per second.
    - If initialization failed, an error message is sent and the handler
      returns.
    - Otherwise the client sends ``{"query": "..."}`` objects and receives one
      answer per query. Messages without a non-empty ``query`` are ignored.
    """
    await websocket.accept()
    logging.info("[%s] WebSocket connected.", session_id)
    try:
        # An unknown session is treated like an initializing one, because the
        # client may connect before its POST /scrape/ has been processed.
        while chatbot_sessions.get(session_id) is None:
            await websocket.send_json({"text": "Initializing chatbot, please wait..."})
            await asyncio.sleep(1)

        chatbot_instance = chatbot_sessions[session_id]
        # A failed initialization is stored as the sentinel "err". Without this
        # check the handler would go on to call "err".query(...).
        if chatbot_instance == "err":
            await websocket.send_json({
                "text": "Chatbot initialization failed. Check the server logs for details."
            })
            return

        await websocket.send_json({"text": f"Chatbot session {session_id} is ready! You can start chatting."})

        while True:
            data = await websocket.receive_json()
            # receive_json() can yield any JSON value; only objects carry a query.
            query = data.get("query") if isinstance(data, dict) else None
            if not query:
                continue
            # Assumes Chatbot.query is a coroutine returning JSON-serializable text.
            response_text = await chatbot_instance.query(query)
            await websocket.send_json({"text": response_text})
    except WebSocketDisconnect:
        logging.info("[%s] WebSocket disconnected.", session_id)
    except Exception as e:
        logging.error("[%s] WebSocket error: %s", session_id, e, exc_info=True)
        try:
            await websocket.send_json({"text": "An unexpected server error occurred."})
        except Exception:
            # Best effort: the socket may already be closed. Catching Exception
            # instead of using a bare except lets task cancellation propagate.
            pass
# -----------------------------
# Local development entry point.
# Equivalent CLI: uvicorn main:app --reload
# -----------------------------
if __name__ == "__main__":
    # Imported here so the server launcher is only loaded when this file is
    # executed directly as a script.
    from uvicorn import run as serve_app

    serve_app("main:app", host="127.0.0.1", port=8000, reload=True)
|