Spaces:
Running
Running
| import os | |
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| from sentence_transformers import SentenceTransformer | |
| import torch | |
class EmbeddingRequest(BaseModel):
    """Schema of the request body for the embedding endpoint.

    Attributes:
        text: List of input strings to compute embeddings for.
    """

    text: list[str]
# Application instance that routes are registered on.
app = FastAPI()

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# The model identifier comes from the environment. os.getenv returns None when
# the variable is unset, so fail fast with a clear message instead of handing
# None to SentenceTransformer.
model_name = os.getenv("MODEL_NAME")
if not model_name:
    raise RuntimeError(
        "MODEL_NAME environment variable must be set to a "
        "SentenceTransformer model name or path"
    )

# Loaded once at import time and shared by every request.
model = SentenceTransformer(model_name, device=device)
# Define the embedding endpoint
# NOTE(review): no route decorator (e.g. ``@app.post(...)``) is visible on this
# function in this view -- confirm it is registered on ``app`` somewhere.
async def get_embeddings(request: EmbeddingRequest):
    """Encode the request's texts and return their embeddings.

    Args:
        request: Validated request body; ``request.text`` is the list of
            strings handed to the model.

    Returns:
        A dict ``{"embeddings": ...}`` whose value is the model output
        converted to nested Python lists via ``.tolist()``.

    Raises:
        HTTPException: With status 500 and the original error message as
            ``detail`` if encoding or conversion fails.
    """
    import asyncio  # local import keeps this block self-contained

    try:
        # model.encode is a blocking call; run it in a worker thread so the
        # event loop stays free to serve other requests in the meantime.
        vectors = await asyncio.to_thread(
            model.encode, request.text, convert_to_numpy=True
        )
        embeddings = vectors.tolist()
    except Exception as e:
        # Boundary handler: report any failure as a 500 and keep the
        # original exception chained for debugging.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"embeddings": embeddings}