Spaces:
Running
Running
import os
import logging

# Force the Hugging Face cache path BEFORE importing transformers:
# transformers resolves its cache directories at import time, so setting
# HF_HOME after the import would be too late (works around a permission
# problem with the default cache location in the container).
os.environ["HF_HOME"] = "/app/.cache/huggingface"

import torch
from fastapi import FastAPI
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Module-level logger for the service.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("CodeSecurityAPI")

# CodeBERT checkpoint fine-tuned to flag insecure code snippets.
MODEL_NAME = "mrm8488/codebert-base-finetuned-detect-insecure-code"

# Load model and tokenizer once at startup so all requests share them.
try:
    logger.info("Loading model...")
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME,
        cache_dir=os.getenv("HF_HOME"),
    )
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        cache_dir=os.getenv("HF_HOME"),
    )
    logger.info("Model loaded successfully")
except Exception as e:
    # Log with full traceback and chain the original cause so the real
    # failure (network, path, permissions) stays visible to operators.
    logger.exception("Model load failed: %s", e)
    raise RuntimeError("模型加载失败,请检查网络连接或模型路径") from e

app = FastAPI()
async def detect(code: str):
    """Classify a code snippet as secure/insecure with the CodeBERT model.

    Returns a dict with the predicted label name and its softmax
    probability, or ``{"error": ...}`` if inference fails for any reason.
    """
    try:
        # Cap very long payloads before tokenization.
        snippet = code[:2000]

        # Tokenize and run the classifier without tracking gradients.
        encoded = tokenizer(
            snippet, return_tensors="pt", truncation=True, max_length=512
        )
        with torch.no_grad():
            result = model(**encoded)

        # Softmax once; the argmax of the probabilities equals the argmax
        # of the raw logits, and probs[best] is the confidence score.
        probs = result.logits.softmax(dim=-1)[0]
        best = probs.argmax().item()
        return {
            "label": model.config.id2label[best],
            "score": probs[best].item(),
        }
    except Exception as e:
        # Best-effort API: surface the failure to the caller, never raise.
        return {"error": str(e)}
async def health():
    """Liveness probe: report that the service process is up."""
    return {"status": "ok"}