Spaces:
Sleeping
Sleeping
| """Deployment server. | |
| Routes: | |
| - Get client.zip | |
| - Add a key | |
| - Compute | |
| """ | |
| import io | |
| import os | |
| import uuid | |
| from pathlib import Path | |
| from typing import Dict | |
| import uvicorn | |
| from fastapi import FastAPI, Form, HTTPException, UploadFile | |
| from fastapi.responses import FileResponse, StreamingResponse | |
| # No relative import here because when not used in the package itself | |
| from concrete.ml.deployment import FHEModelServer | |
if __name__ == "__main__":
    app = FastAPI(debug=False)

    FILE_FOLDER = Path(__file__).parent

    # Locations and port can be overridden through environment variables.
    # NOTE(review): KEY_PATH is not referenced anywhere else in this script;
    # uploaded keys are only held in memory (see KEYS below).
    KEY_PATH = Path(os.environ.get("KEY_PATH", FILE_FOLDER / Path("server_keys")))
    CLIENT_SERVER_PATH = Path(os.environ.get("PATH_TO_MODEL", FILE_FOLDER / Path("dev")))
    PORT = os.environ.get("PORT", "5000")

    PATH_TO_CLIENT = (CLIENT_SERVER_PATH / "client.zip").resolve()
    PATH_TO_SERVER = (CLIENT_SERVER_PATH / "server.zip").resolve()

    # Validate the deployment artifacts explicitly (an `assert` would be
    # stripped under `python -O`) and before building the FHE server from the
    # same folder, so a missing file fails fast with a clear message.
    for required_path in (PATH_TO_CLIENT, PATH_TO_SERVER):
        if not required_path.exists():
            raise FileNotFoundError(f"Required deployment file not found: {required_path}")

    fhe = FHEModelServer(str(CLIENT_SERVER_PATH.resolve()))

    # In-memory store: uid -> serialized evaluation keys uploaded via /add_key.
    # NOTE(review): entries are never removed, so this grows with every upload.
    KEYS: Dict[str, bytes] = {}

    @app.get("/get_client")
    def get_client() -> FileResponse:
        """Get client.

        Returns:
            FileResponse: client.zip, served as "application/zip"

        Raises:
            HTTPException: 500 if the file can't be found locally
        """
        # Checked again per request: the file was present at startup but may
        # have been removed since.
        if not PATH_TO_CLIENT.exists():
            raise HTTPException(status_code=500, detail="Could not find client.")
        return FileResponse(PATH_TO_CLIENT, media_type="application/zip")

    @app.post("/add_key")
    async def add_key(key: UploadFile) -> Dict[str, str]:
        """Add public key.

        Stores the uploaded serialized evaluation key under a freshly
        generated uid.

        Arguments:
            key (UploadFile): serialized public (evaluation) key

        Returns:
            Dict[str, str]
                - uid: the personal uid to pass to /compute with this key
        """
        uid = str(uuid.uuid4())
        KEYS[uid] = await key.read()
        return {"uid": uid}

    @app.post("/compute")
    async def compute(model_input: UploadFile, uid: str = Form()) -> StreamingResponse:  # noqa: B008
        """Compute the circuit over encrypted input.

        Arguments:
            model_input (UploadFile): serialized encrypted input of the circuit
            uid (str): uid of the public key to use, as returned by /add_key

        Returns:
            StreamingResponse: the serialized encrypted result of the circuit

        Raises:
            HTTPException: 404 if no key was registered under `uid`
        """
        # A plain KEYS[uid] lookup would surface an unknown uid as an opaque
        # 500 error; report it to the caller as a missing resource instead.
        key = KEYS.get(uid)
        if key is None:
            raise HTTPException(status_code=404, detail="Unknown uid: add a key first.")

        # NOTE(review): fhe.run is called synchronously inside an async
        # endpoint, so it presumably blocks the event loop while it runs.
        encrypted_results = fhe.run(
            serialized_encrypted_quantized_data=await model_input.read(),
            serialized_evaluation_keys=key,
        )
        return StreamingResponse(
            io.BytesIO(encrypted_results),
        )

    uvicorn.run(app, host="0.0.0.0", port=int(PORT))