Spaces:
Running
Running
import asyncio
import logging
import os
import threading

import torch
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# === Application setup ===
app = FastAPI(title="Code Security API")

# Cross-origin requests are accepted from any origin, with any method and any
# header, so browser front-ends served from other domains can call this API.
_cors_settings = dict(
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_settings)
# === Force the Hugging Face cache location ===
# NOTE(review): transformers is imported before this code runs, and
# huggingface_hub presumably derives its default cache paths from HF_HOME when
# it is first imported. If so, this assignment does not move those defaults.
# The model-loading code passes cache_path explicitly as cache_dir, and that is
# what pins the download location. Consider setting HF_HOME in the container
# environment or before the transformers import. Confirm this behavior.
cache_path = "/app/.cache/huggingface"
os.environ["HF_HOME"] = cache_path
os.makedirs(cache_path, exist_ok=True)
# === Logging configuration ===
# Configure the root logger at import time so INFO-level messages are emitted.
logging.basicConfig(level=logging.INFO)
logger: logging.Logger = logging.getLogger(name="CodeBERT-API")
# === Root-path route (must be defined) ===
# The route must be registered on the app. Without the decorator, GET / is
# never served and the function is dead code.
@app.get("/")
async def read_root():
    """Health-check endpoint.

    Returns:
        A static dict with the service status and a short description of
        the available endpoints.
    """
    return {
        "status": "running",
        "endpoints": {
            "detect": "POST /detect - 代码安全检测",
            "specs": "GET /openapi.json - API文档",
        },
    }
# === Model loading ===
# The model and tokenizer are loaded once at import time. If loading fails,
# the module (and therefore the app) fails to start.
MODEL_NAME = "mrm8488/codebert-base-finetuned-detect-insecure-code"

try:
    logger.info("Loading model from: %s", cache_path)
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME,
        cache_dir=cache_path,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        cache_dir=cache_path,
    )
    logger.info("Model loaded successfully")
except Exception as e:
    # logger.exception also records the traceback, which is the only
    # diagnostic available when a hosted deployment fails to start.
    logger.exception("Model load failed: %s", e)
    # Chain the original error so the real cause stays attached.
    raise RuntimeError("模型初始化失败") from e
# === Core detection endpoint ===
_MAX_INPUT_CHARS = 2000  # characters kept before tokenization
_MAX_TOKENS = 512  # tokenizer truncation length

# Inference runs on worker threads (see detect_vulnerability). Hugging Face
# fast tokenizers can raise "RuntimeError: Already borrowed" when used
# concurrently, so access to the shared tokenizer/model is serialized.
_inference_lock = threading.Lock()


def _classify(code: str) -> dict:
    """Score ``code`` with the classifier and return its label and confidence.

    This call blocks (tokenization plus a forward pass), so run it off the
    event loop.
    """
    code = code[:_MAX_INPUT_CHARS]  # cap very long inputs before tokenizing
    with _inference_lock, torch.no_grad():
        inputs = tokenizer(
            code,
            return_tensors="pt",
            truncation=True,
            max_length=_MAX_TOKENS,
        )
        logits = model(**inputs).logits
    # A single string was tokenized, so the batch has exactly one row.
    scores = logits[0]
    label_id = scores.argmax(dim=-1).item()
    return {
        "label": label_id,  # per the original author: 0 = secure, 1 = insecure
        "confidence": scores.softmax(dim=-1)[label_id].item(),
    }


@app.post("/detect")
async def detect_vulnerability(code: str):
    """Classify a code snippet as secure or insecure.

    Args:
        code: Source code to inspect. As a plain ``str`` parameter, FastAPI
            reads it from the ``code`` query parameter. Only the first 2000
            characters, and at most 512 tokens, are scored.

    Returns:
        ``{"label": int, "confidence": float}`` on success, or
        ``{"error": str, "tip": str}`` if inference raised an exception.
    """
    loop = asyncio.get_running_loop()
    try:
        # Inference is CPU-bound and blocking. Running it in the default
        # thread pool keeps the event loop free to serve other requests.
        return await loop.run_in_executor(None, _classify, code)
    except Exception as e:
        # The client still gets the error payload, but the failure is now
        # also recorded server-side with its traceback.
        logger.exception("Detection failed")
        return {
            "error": str(e),
            "tip": "请检查输入代码是否包含非ASCII字符",
        }