frogleo commited on
Commit
66c7593
·
verified ·
1 Parent(s): 3d711d0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -5
app.py CHANGED
@@ -12,6 +12,8 @@ import uvicorn
12
  from fastapi import FastAPI
13
  from fastapi.staticfiles import StaticFiles
14
  import trimesh
 
 
15
 
16
  parser = argparse.ArgumentParser()
17
  parser.add_argument("--model_path", type=str, default='tencent/Hunyuan3D-2mini')
@@ -38,6 +40,11 @@ CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
38
  HTML_HEIGHT = 500
39
  HTML_WIDTH = 500
40
 
 
 
 
 
 
41
 
42
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
43
  if randomize_seed:
@@ -131,10 +138,22 @@ floater_remove_worker = FloaterRemover()
131
  degenerate_face_remove_worker = DegenerateFaceRemover()
132
  face_reduce_worker = FaceReducer()
133
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  progress=gr.Progress()
135
 
136
  @spaces.GPU(duration=40)
137
- def gen_shape(
138
  image=None,
139
  steps=50,
140
  guidance_scale=7.5,
@@ -152,14 +171,28 @@ def gen_shape(
152
 
153
 
154
  if image is None:
155
- raise gr.Error("Please provide either a caption or an image.")
 
 
 
 
156
 
 
 
 
 
 
 
 
 
 
 
157
 
158
  seed = int(randomize_seed_fn(seed, randomize_seed))
159
  octree_resolution = int(octree_resolution)
160
  save_folder = gen_save_folder()
161
  # 先移除背景
162
- image = rmbg_worker(image.convert('RGB'))
163
 
164
  # 生成模型
165
  generator = torch.Generator()
@@ -188,7 +221,11 @@ def gen_shape(
188
  torch.cuda.empty_cache()
189
 
190
  if path is None:
191
- raise gr.Error('Please generate a mesh first.')
 
 
 
 
192
 
193
  # 简化模型
194
  print(f'exporting {path}')
@@ -221,9 +258,37 @@ def gen_shape(
221
 
222
 
223
  progress(1,desc="Complete")
224
- return model_viewer_html, gr.update(value=sourceObjPath, interactive=True), glbPath, objPath
 
 
 
225
 
226
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
 
228
  def get_example_img_list():
229
  print('Loading example img list ...')
 
12
  from fastapi import FastAPI
13
  from fastapi.staticfiles import StaticFiles
14
  import trimesh
15
+ from transformers import AutoProcessor, AutoModelForImageClassification
16
+ from PIL import Image
17
 
18
  parser = argparse.ArgumentParser()
19
  parser.add_argument("--model_path", type=str, default='tencent/Hunyuan3D-2mini')
 
40
  HTML_HEIGHT = 500
41
  HTML_WIDTH = 500
42
 
43
+ # -------------------- NSFW 检测模型加载 --------------------
44
+ nsfw_processor = AutoProcessor.from_pretrained("Falconsai/nsfw_image_detection")
45
+ nsfw_model = AutoModelForImageClassification.from_pretrained("Falconsai/nsfw_image_detection").to(args.device)
46
+ # -----------------------------------------------------------
47
+
48
 
49
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
50
  if randomize_seed:
 
138
  degenerate_face_remove_worker = DegenerateFaceRemover()
139
  face_reduce_worker = FaceReducer()
140
 
141
+
142
+ def detect_nsfw(image: Image.Image, threshold: float = 0.5) -> bool:
143
+ """Returns True if image is NSFW"""
144
+ inputs = nsfw_processor(images=image, return_tensors="pt").to(args.device)
145
+ with torch.no_grad():
146
+ outputs = nsfw_model(**inputs)
147
+ probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
148
+ nsfw_score = probs[0][1].item() # label 1 = NSFW
149
+ return nsfw_score > threshold
150
+
151
+
152
+
153
  progress=gr.Progress()
154
 
155
  @spaces.GPU(duration=40)
156
+ def _gen_shape_on_gpu(
157
  image=None,
158
  steps=50,
159
  guidance_scale=7.5,
 
171
 
172
 
173
  if image is None:
174
+ error_info = {
175
+ "error": "Please provide either a caption or an image.",
176
+ "status": "failed",
177
+ }
178
+ return None,None,None,None,error_info
179
 
180
+ rgbImage = image.convert('RGB')
181
+
182
+ # NSFW 检测
183
+ if nsfw_model and nsfw_processor:
184
+ if detect_nsfw(rgbImage):
185
+ error_info = {
186
+ "error": "The input image contains NSFW content and cannot be used. Please provide a different image and try again.",
187
+ "status": "failed",
188
+ }
189
+ return None,None,None,None,error_info
190
 
191
  seed = int(randomize_seed_fn(seed, randomize_seed))
192
  octree_resolution = int(octree_resolution)
193
  save_folder = gen_save_folder()
194
  # 先移除背景
195
+ image = rmbg_worker(rgbImage)
196
 
197
  # 生成模型
198
  generator = torch.Generator()
 
221
  torch.cuda.empty_cache()
222
 
223
  if path is None:
224
+ error_info = {
225
+ "error": "'Please generate a mesh first.'",
226
+ "status": "failed",
227
+ }
228
+ return None,None,None,None,error_info
229
 
230
  # 简化模型
231
  print(f'exporting {path}')
 
258
 
259
 
260
  progress(1,desc="Complete")
261
+ info = {
262
+ "status": "success"
263
+ }
264
+ return model_viewer_html, gr.update(value=sourceObjPath, interactive=True), glbPath, objPath, info
265
 
266
 
267
+ def gen_shape(
268
+ image=None,
269
+ steps=50,
270
+ guidance_scale=7.5,
271
+ seed=1234,
272
+ octree_resolution=256,
273
+ num_chunks=200000,
274
+ target_face_num=10000,
275
+ randomize_seed: bool = False,
276
+ ):
277
+ # 调用 GPU 函数
278
+ html_export_mesh,file_export,glbPath_output,objPath_output, info = _gen_shape_on_gpu(
279
+ image,
280
+ steps,
281
+ guidance_scale,
282
+ seed,
283
+ octree_resolution,
284
+ num_chunks,
285
+ target_face_num,
286
+ randomize_seed
287
+ )
288
+ # 如果出错,抛出异常
289
+ if info["status"] == "failed":
290
+ raise gr.Error(info["error"])
291
+ return html_export_mesh, file_export, glbPath_output, objPath_output
292
 
293
  def get_example_img_list():
294
  print('Loading example img list ...')