Spaces:
Running
Running
import asyncio
import logging
import os
import threading

import gradio as gr
import torch
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware  # added for CORS support
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# === FastAPI initialisation ===
app = FastAPI()

# CORS: accept cross-origin requests from any origin, with any HTTP method and
# any request header. allow_credentials is not set here; if it is ever enabled,
# replace the "*" origin with an explicit list, because browsers reject a
# wildcard origin on credentialed requests.
_cors_options = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
# === Model loading ===
# Hub checkpoint shared by the tokenizer and the sequence-classification model.
MODEL_NAME = "mrm8488/codebert-base-finetuned-detect-insecure-code"

os.environ["HF_HOME"] = "/app/.cache/huggingface"
# huggingface_hub turns HF_HOME into its cache path when it is first imported.
# The transformers/gradio imports at the top of the file have already done
# that, so the assignment above does not by itself move the download cache.
# Instead, pass the hub cache directory that HF_HOME implies (<HF_HOME>/hub)
# explicitly.
_HF_HUB_CACHE = os.path.join(os.environ["HF_HOME"], "hub")

model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, cache_dir=_HF_HUB_CACHE)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=_HF_HUB_CACHE)
# === HTTP API handler ===
logger = logging.getLogger(__name__)

# Serialises use of the shared tokenizer/model. Inference runs in worker
# threads (see api_detect), and Hugging Face fast tokenizers can fail
# ("Already borrowed") when called from several threads at once.
_inference_lock = threading.Lock()


def _classify(code):
    """Tokenize *code*, run the classifier, and return ``(label_id, score)``.

    Blocking call; api_detect runs it in an executor thread so the event loop
    stays free while it works.
    """
    with _inference_lock:
        # Cap the raw text before tokenizing; the tokenizer additionally
        # truncates the token sequence to 512 tokens.
        inputs = tokenizer(code[:2000], return_tensors="pt", truncation=True, max_length=512)
        with torch.no_grad():
            logits = model(**inputs).logits
    # A single string was tokenized, so the batch has exactly one row.
    row = logits[0]
    label_id = int(row.argmax().item())
    score = row.softmax(dim=-1)[label_id].item()
    return label_id, score


# NOTE(review): despite its name, this coroutine is not registered as a FastAPI
# route (there is no @app.get/@app.post decorator); it is only reached through
# the Gradio UI. Add a route decorator if a standalone HTTP endpoint is intended.
async def api_detect(code: str):
    """Classify a code snippet with the insecure-code model.

    Args:
        code: Source text to classify. Only the first 2000 characters are
            tokenized, and the tokens are truncated to 512.

    Returns:
        ``{"label": int, "score": float}``: the arg-max class id and its
        softmax probability. If inference raises, it returns
        ``{"error": str}`` instead of raising.
    """
    loop = asyncio.get_running_loop()
    try:
        # Inference is synchronous; running it directly in this coroutine
        # would block every other request served by the event loop.
        label_id, score = await loop.run_in_executor(None, _classify, code)
    except Exception as e:
        # Boundary handler: keep the existing "error as data" contract, but do
        # not lose the traceback.
        logger.exception("Insecure-code inference failed")
        return {"error": str(e)}
    return {
        "label": label_id,  # plain int class id (the author expects 0/1)
        "score": score,
    }
# === Gradio UI (optional) ===
async def gradio_predict(code: str):
    """Run api_detect on *code* and format the result for the Gradio text output.

    api_detect is a coroutine function and must be awaited. Calling it without
    ``await`` yields a coroutine object, and subscripting that object raises
    TypeError. Gradio accepts ``async def`` handlers.
    """
    result = await api_detect(code)
    if "error" in result:
        # api_detect reports failures as {"error": ...} rather than raising,
        # so "label"/"score" are absent on this path.
        return f"Error: {result['error']}"
    return f"Prediction: {result['label']} (Confidence: {result['score']:.2f})"
# Multi-line text box for pasting the snippet to classify.
_code_input = gr.Textbox(placeholder="Paste code here...", lines=10)

# Single-function UI: code in, formatted prediction text out.
gr_interface = gr.Interface(
    gradio_predict,
    _code_input,
    "text",
    title="Code Security Detector",
)

# Mount the Gradio UI at the root path of the FastAPI app, rebinding `app` to
# the FastAPI app returned by mount_gradio_app.
app = gr.mount_gradio_app(app, gr_interface, path="/")