import os
from typing import List, Dict
import logging

import dotenv
import torch
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse, HTMLResponse, StreamingResponse
from pydantic import BaseModel
from llama_cpp import Llama
from huggingface_hub import hf_hub_download, login
import uvicorn
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForCausalLM,
    pipeline
)

from UofTearsBot import UofTearsBot
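# As used in the /chat endpoint below, UofTearsBot wraps the shared Llama
# instance and exposes a converse(text) generator that yields response tokens.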

MODEL_REPO = "bartowski/Mistral-7B-Instruct-v0.3-GGUF"
MODEL_FILE = "Mistral-7B-Instruct-v0.3-Q4_K_M.gguf"
CHAT_FORMAT = "mistral-instruct"
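# Q4_K_M is a 4-bit GGUF quantization, a size/quality trade-off that suits CPU
# inference; the "mistral-instruct" chat format tells llama.cpp to apply
# Mistral's [INST] ... [/INST] prompt template automatically.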

dotenv.load_dotenv()
login(token=os.getenv("HF_TOKEN"))

# Download the GGUF weights from the Hub into local storage.
MODEL_PATH = hf_hub_download(
    repo_id=MODEL_REPO,
    filename=MODEL_FILE,
    local_dir="/tmp/models",
    local_dir_use_symlinks=False,  # deprecated (and ignored) in recent huggingface_hub
)

llm = Llama(
    model_path=MODEL_PATH,
    n_ctx=int(os.getenv("N_CTX", "1024")),    # context window size
    n_threads=os.cpu_count() or 4,            # use every available CPU core
    n_batch=int(os.getenv("N_BATCH", "32")),  # prompt-processing batch size
    chat_format=CHAT_FORMAT,
)
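# The env-driven defaults above are kept small for CPU inference; they can be
# raised per deployment without code changes, e.g.:
#   N_CTX=2048 N_BATCH=64 python app.py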

# Start the FastAPI app
app = FastAPI()

# One bot instance per user id, so each user keeps a separate conversation.
chatbots: Dict[str, UofTearsBot] = {}

class ChatRequest(BaseModel):
    user_id: str
    user_text: str
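# Example request body: {"user_id": "u1", "user_text": "hello"}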

@app.post("/chat")  # decorator missing in the original; the /chat path is assumed
async def chat(request: ChatRequest):
    try:
        # Lazily create a bot the first time a user id is seen.
        if request.user_id not in chatbots:
            chatbots[request.user_id] = UofTearsBot(llm)
        current_bot = chatbots[request.user_id]

        def token_generator():
            print("[INFO] Model is streaming response...", flush=True)
            for token in current_bot.converse(request.user_text):
                yield token
            print("[INFO] Model finished streaming", flush=True)

        return StreamingResponse(token_generator(), media_type="text/plain")
    except Exception as e:
        import traceback
        traceback.print_exc()  # logs to HF logs
        return JSONResponse(
            status_code=500,
            content={"error": str(e)}
        )
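# Usage sketch (assumes the /chat path above); curl's -N flag disables output
# buffering so tokens appear as they stream:
#   curl -N -X POST http://localhost:7860/chat \
#        -H "Content-Type: application/json" \
#        -d '{"user_id": "u1", "user_text": "hello"}'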

@app.get("/", response_class=HTMLResponse)  # decorator missing in the original; "/" path assumed
async def home():
    return "<h1>App is running</h1>"

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)  # huggingface port