Spaces:
Sleeping
Sleeping
| # app.py | |
| import os | |
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| from huggingface_hub import hf_hub_download | |
| from pyllamacpp.model import Model | |
| # ----------------------------------------------------------------------------- | |
| # Hugging Face Hub の設定 | |
| # ----------------------------------------------------------------------------- | |
| HF_TOKEN = os.environ.get("HF_TOKEN") # 必要に応じて Secrets にセット | |
| REPO_ID = "google/gemma-3-12b-it-qat-q4_0-gguf" | |
| # 実際にリポジトリに置かれている GGUF ファイル名を確認してください。 | |
| # 例: "gemma-3-12b-it-qat-q4_0-gguf.gguf" | |
| GGUF_FILENAME = "gemma-3-12b-it-qat-q4_0-gguf.gguf" | |
| # キャッシュ先のパス(リポジトリ直下に置く場合) | |
| MODEL_PATH = os.path.join(os.getcwd(), GGUF_FILENAME) | |
| # ----------------------------------------------------------------------------- | |
| # 起動時に一度だけダウンロード | |
| # ----------------------------------------------------------------------------- | |
| if not os.path.exists(MODEL_PATH): | |
| print(f"Downloading {GGUF_FILENAME} from {REPO_ID} …") | |
| hf_hub_download( | |
| repo_id=REPO_ID, | |
| filename=GGUF_FILENAME, | |
| token=HF_TOKEN, | |
| repo_type="model", # 明示的にモデルリポジトリを指定 | |
| local_dir=os.getcwd(), # カレントディレクトリに保存 | |
| local_dir_use_symlinks=False | |
| ) | |
| # ----------------------------------------------------------------------------- | |
| # llama.cpp (pyllamacpp) で 4bit GGUF モデルをロード | |
| # ----------------------------------------------------------------------------- | |
| llm = Model( | |
| model_path=MODEL_PATH, | |
| n_ctx=512, # 必要に応じて調整 | |
| n_threads=4, # 実マシンのコア数に合わせて | |
| ) | |
| # ----------------------------------------------------------------------------- | |
| # FastAPI 定義 | |
| # ----------------------------------------------------------------------------- | |
| app = FastAPI(title="Gemma3-12B-IT Q4_0 GGUF API") | |
| class GenerationRequest(BaseModel): | |
| prompt: str | |
| max_new_tokens: int = 128 | |
| temperature: float = 0.8 | |
| top_p: float = 0.95 | |
| async def generate(req: GenerationRequest): | |
| if not req.prompt: | |
| raise HTTPException(status_code=400, detail="`prompt` は必須です。") | |
| # llama.cpp の generate を呼び出し | |
| text = llm.generate( | |
| req.prompt, | |
| top_p=req.top_p, | |
| temp=req.temperature, | |
| n_predict=req.max_new_tokens, | |
| repeat_last_n=64, | |
| repeat_penalty=1.1 | |
| ) | |
| return {"generated_text": text} | |
| # ----------------------------------------------------------------------------- | |
| # ローカル起動用 | |
| # ----------------------------------------------------------------------------- | |
| if __name__ == "__main__": | |
| import uvicorn | |
| port = int(os.environ.get("PORT", 8000)) | |
| uvicorn.run("app:app", host="0.0.0.0", port=port, log_level="info") | |