laserbeam2045 committed
Commit · 9d3ba14
1 Parent(s): 215bcb0

fix

Files changed:
- app.py +40 -15
- requirements.txt +1 -1
app.py
CHANGED
@@ -3,7 +3,7 @@ import os
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
 from huggingface_hub import hf_hub_download
-from
+from llama_cpp import Llama  # import llama-cpp-python
 
 # -----------------------------------------------------------------------------
 # Hugging Face Hub settings
@@ -32,11 +32,24 @@ if not os.path.exists(MODEL_PATH):
 )
 
 # -----------------------------------------------------------------------------
-# llama
+# Load the 4-bit GGUF model with llama-cpp-python
 # -----------------------------------------------------------------------------
-
-
-
+print(f"Loading model from {MODEL_PATH}...")
+try:
+    llm = Llama(
+        model_path=MODEL_PATH,
+        n_ctx=2048,          # context size (adjust to match the model)
+        # n_gpu_layers=-1,   # when using a GPU (usually 0 on the free Hugging Face Spaces tier)
+        n_gpu_layers=0,      # CPU only
+        verbose=True         # verbose logging
+    )
+    print("Model loaded successfully.")
+except Exception as e:
+    print(f"Error loading model: {e}")
+    # If loading fails, either terminate the application or handle the error here.
+    # This example simply prints a short error message and aborts.
+    raise RuntimeError(f"Failed to load the LLM model: {e}")
+
 
 # -----------------------------------------------------------------------------
 # FastAPI definitions
@@ -48,21 +61,32 @@ class GenerationRequest(BaseModel):
     max_new_tokens: int = 128
     temperature: float = 0.8
     top_p: float = 0.95
+    # Other parameters supported by llama-cpp-python can be added as well:
+    # stop: list[str] | None = None
+    # repeat_penalty: float = 1.1
 
 @app.post("/generate")
 async def generate(req: GenerationRequest):
     if not req.prompt:
         raise HTTPException(status_code=400, detail="`prompt` is required.")
-
-
-
-
-
-
-
-
-
-
+
+    try:
+        # Generate with llama-cpp-python's __call__ method
+        output = llm(
+            req.prompt,
+            max_tokens=req.max_new_tokens,
+            temperature=req.temperature,
+            top_p=req.top_p,
+            # stop=req.stop,                      # add if needed
+            # repeat_penalty=req.repeat_penalty,  # add if needed
+        )
+        # Extract the generated text
+        generated_text = output["choices"][0]["text"]
+        return {"generated_text": generated_text}
+    except Exception as e:
+        print(f"Error during generation: {e}")
+        raise HTTPException(status_code=500, detail=f"An error occurred during generation: {e}")
+
 
 # -----------------------------------------------------------------------------
 # For local startup
@@ -70,4 +94,5 @@ async def generate(req: GenerationRequest):
 if __name__ == "__main__":
     import uvicorn
     port = int(os.environ.get("PORT", 8000))
+    # Consider adding a try-except here in case the application fails to load.
     uvicorn.run("app:app", host="0.0.0.0", port=port, log_level="info")
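The hunk headers above reference MODEL_PATH and hf_hub_download, but the lines that define them fall outside this diff. As a rough sketch only (not part of the commit), a download step for a 4-bit GGUF file typically looks like the following; REPO_ID, FILENAME, and MODEL_DIR are hypothetical placeholders, not values from this repository.

# Sketch only, not from this commit: REPO_ID, FILENAME, and MODEL_DIR are hypothetical placeholders.
import os
from huggingface_hub import hf_hub_download

REPO_ID = "your-namespace/your-gguf-model"   # hypothetical repository id
FILENAME = "model.Q4_K_M.gguf"               # hypothetical 4-bit GGUF file name
MODEL_DIR = "models"                         # hypothetical local directory

os.makedirs(MODEL_DIR, exist_ok=True)
MODEL_PATH = os.path.join(MODEL_DIR, FILENAME)

if not os.path.exists(MODEL_PATH):
    # hf_hub_download fetches the file from the Hub and returns its local path
    MODEL_PATH = hf_hub_download(
        repo_id=REPO_ID,
        filename=FILENAME,
        local_dir=MODEL_DIR,
    )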
requirements.txt
CHANGED
@@ -1,3 +1,3 @@
 fastapi
 uvicorn[standard]
-
+llama-cpp-python
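Once the Space (or a local uvicorn run) is up, the /generate endpoint added in this commit can be exercised with the standard library alone. A minimal client sketch follows; the base URL and port assume a local run and are not taken from the commit.

# Minimal client sketch; http://localhost:8000 is an assumed local address.
import json
import urllib.request

payload = {
    "prompt": "Hello, how are you?",
    "max_new_tokens": 64,
    "temperature": 0.8,
    "top_p": 0.95,
}

req = urllib.request.Request(
    "http://localhost:8000/generate",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
    method="POST",
)

with urllib.request.urlopen(req) as resp:
    result = json.loads(resp.read().decode("utf-8"))
    print(result["generated_text"])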