Spaces:
Running
Running
File size: 8,538 Bytes
7223d40 8721ec2 7223d40 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 |
import os
from PIL import Image
import numpy as np
import onnxruntime as ort
import json
from huggingface_hub import hf_hub_download
class NSFWDetector:
    """NSFW image classifier backed by a YOLOv9 model run through ONNX Runtime.

    Constructing an instance downloads the model file and a JSON label map
    from a Hugging Face repository (cached under ``./hf_cache``), parses the
    label map and opens an ``onnxruntime.InferenceSession``.  Images can then
    be classified either from a file path (``predict``) or from an in-memory
    PIL image (``predict_from_pil``).
    """

    # Local directory handed to ``hf_hub_download`` as its cache location.
    _CACHE_DIR = "./hf_cache"

    def __init__(self, repo_id="Falconsai/nsfw_image_detection",
                 model_filename="falconsai_yolov9_nsfw_model_quantized.pt",
                 labels_filename="labels.json",
                 input_size=(224, 224)):
        """Download the model and labels, then open the inference session.

        Args:
            repo_id (str): Hugging Face repository ID.
            model_filename (str): Name of the model file inside the repo.
            labels_filename (str): Name of the JSON label file inside the repo.
            input_size (tuple): Model input size as ``(height, width)``.

        Raises:
            RuntimeError: If a download fails or the model cannot be loaded.
            FileNotFoundError: If the downloaded label file cannot be opened.
            ValueError: If the label file is not valid JSON.
        """
        self.repo_id = repo_id
        self.model_filename = model_filename
        self.labels_filename = labels_filename
        self.input_size = input_size
        # Fetch both artefacts from the Hugging Face Hub.
        self.model_path = self._download_model()
        self.labels_path = self._download_labels()
        # Parse the label map.
        self.labels = self._load_labels()
        # Open the session and remember the names of its first input/output.
        self.session = self._load_model()
        self.input_name = self.session.get_inputs()[0].name
        self.output_name = self.session.get_outputs()[0].name

    def _download(self, filename):
        """Fetch ``filename`` from ``self.repo_id`` and return its local path."""
        return hf_hub_download(
            repo_id=self.repo_id,
            filename=filename,
            cache_dir=self._CACHE_DIR,
        )

    def _download_model(self):
        """Download the model file from Hugging Face.

        Returns:
            str: Local path of the downloaded model file.

        Raises:
            RuntimeError: If the download fails for any reason.
        """
        try:
            print(f"正在从 {self.repo_id} 下载模型文件: {self.model_filename}")
            model_path = self._download(self.model_filename)
            print(f"✅ 模型下载成功: {model_path}")
            return model_path
        except Exception as e:
            raise RuntimeError(f"模型下载失败: {e}") from e

    def _download_labels(self):
        """Download the label file from Hugging Face.

        Returns:
            str: Local path of the downloaded label file.

        Raises:
            RuntimeError: If the download fails for any reason.
        """
        try:
            print(f"正在从 {self.repo_id} 下载标签文件: {self.labels_filename}")
            labels_path = self._download(self.labels_filename)
            print(f"✅ 标签文件下载成功: {labels_path}")
            return labels_path
        except Exception as e:
            raise RuntimeError(f"标签文件下载失败: {e}") from e

    def _load_labels(self):
        """Load the class labels from ``self.labels_path``.

        Returns:
            The parsed JSON content; ``_postprocess_predictions`` indexes it
            with the stringified class index.

        Raises:
            FileNotFoundError: If the label file does not exist.
            ValueError: If the label file is not valid JSON.
        """
        try:
            # JSON text is UTF-8; do not depend on the platform's locale encoding.
            with open(self.labels_path, "r", encoding="utf-8") as f:
                return json.load(f)
        except FileNotFoundError as e:
            raise FileNotFoundError(f"标签文件未找到: {self.labels_path}") from e
        except json.JSONDecodeError as e:
            raise ValueError(f"标签文件格式错误: {self.labels_path}") from e

    def _load_model(self):
        """Open an ONNX Runtime session for ``self.model_path``.

        Returns:
            onnxruntime.InferenceSession: The model session.

        Raises:
            RuntimeError: If the session cannot be created.
        """
        try:
            return ort.InferenceSession(self.model_path)
        except Exception as e:
            raise RuntimeError(f"模型加载失败: {self.model_path}, 错误: {e}") from e

    def _image_to_tensor(self, rgb_image):
        """Turn an RGB PIL image into a ``[1, C, H, W]`` float32 tensor.

        Args:
            rgb_image (PIL.Image): Image already converted to RGB mode.

        Returns:
            numpy.ndarray: Pixel values divided by 255, laid out as [1, C, H, W].
        """
        # ``input_size`` is (height, width) but PIL's resize expects (width, height).
        height, width = self.input_size
        image_resized = rgb_image.resize((width, height), Image.Resampling.BILINEAR)
        # Convert to a float array and normalise.
        image_np = np.array(image_resized, dtype=np.float32) / 255.0
        # Reorder axes [H, W, C] -> [C, H, W].
        image_np = np.transpose(image_np, (2, 0, 1))
        # Add the batch axis [C, H, W] -> [1, C, H, W].
        return np.expand_dims(image_np, axis=0).astype(np.float32)

    def _run_inference(self, input_tensor):
        """Run the session on ``input_tensor`` and return the predicted label."""
        outputs = self.session.run([self.output_name], {self.input_name: input_tensor})
        return self._postprocess_predictions(outputs[0])

    def _preprocess_image(self, image_path):
        """Load an image file and prepare it for the model.

        Args:
            image_path (str): Path of the image file.

        Returns:
            tuple: ``(input tensor, RGB copy of the original image)``.

        Raises:
            FileNotFoundError: If ``image_path`` does not exist.
            RuntimeError: If the image cannot be decoded or preprocessed.
        """
        try:
            # ``convert`` yields an independent in-memory image, so the file
            # handle can be released as soon as the ``with`` block exits.
            with Image.open(image_path) as opened:
                original_image = opened.convert("RGB")
            input_tensor = self._image_to_tensor(original_image)
            return input_tensor, original_image
        except FileNotFoundError as e:
            raise FileNotFoundError(f"图像文件未找到: {image_path}") from e
        except Exception as e:
            raise RuntimeError(f"图像预处理失败: {e}") from e

    def _postprocess_predictions(self, predictions):
        """Map raw model output to a class label.

        Args:
            predictions: Model output array; the index of its largest element
                (over the flattened array) selects the class.

        Returns:
            The label stored in ``self.labels`` under the stringified index.
        """
        predicted_index = int(np.argmax(predictions))
        return self.labels[str(predicted_index)]

    def predict(self, image_path):
        """Run NSFW detection on a single image file.

        Args:
            image_path (str): Path of the image file.

        Returns:
            tuple: ``(predicted label, RGB copy of the original image)``.
        """
        input_tensor, original_image = self._preprocess_image(image_path)
        predicted_label = self._run_inference(input_tensor)
        return predicted_label, original_image

    def predict_label_only(self, image_path):
        """Run NSFW detection on an image file and return only the label.

        Args:
            image_path (str): Path of the image file.

        Returns:
            The predicted class label.
        """
        predicted_label, _ = self.predict(image_path)
        return predicted_label

    def predict_from_pil(self, pil_image):
        """Run NSFW detection directly on a PIL image object.

        Args:
            pil_image (PIL.Image): Image to classify; converted to RGB if needed.

        Returns:
            tuple: ``(predicted label, the RGB image that was classified)``.

        Raises:
            RuntimeError: If preprocessing or inference fails.
        """
        try:
            # The model pipeline expects three-channel RGB input.
            if pil_image.mode != "RGB":
                pil_image = pil_image.convert("RGB")
            input_tensor = self._image_to_tensor(pil_image)
            predicted_label = self._run_inference(input_tensor)
            return predicted_label, pil_image
        except Exception as e:
            raise RuntimeError(f"PIL图像预测失败: {e}") from e

    def predict_pil_label_only(self, pil_image):
        """Run NSFW detection on a PIL image and return only the label.

        Args:
            pil_image (PIL.Image): Image to classify.

        Returns:
            The predicted class label.
        """
        predicted_label, _ = self.predict_from_pil(pil_image)
        return predicted_label
# --- Usage example ---
def main():
    """Classify one sample image and print the predicted label."""
    # Image to classify in this demo run.
    single_image_path = "datas/bad01.jpg"

    # Check the input first so a missing image does not trigger the model
    # and label downloads performed by the detector's constructor.
    if not os.path.exists(single_image_path):
        print(f"错误: 指定的图像文件不存在: {single_image_path}")
        return

    try:
        # Creating the detector downloads its files from Hugging Face.
        detector = NSFWDetector()
    except Exception as e:
        print(f"初始化检测器时发生错误: {e}")
        return

    try:
        predicted_label = detector.predict_label_only(single_image_path)
    except Exception as e:
        # Reported separately so a prediction failure is not mislabelled as
        # an initialisation failure.
        print(f"预测时发生错误: {e}")
        return

    print(f"图像文件: {single_image_path}")
    print(f"预测结果: {predicted_label}")


if __name__ == "__main__":
    main()