Files changed (2) hide show
  1. app.py +5 -70
  2. requirements.txt +1 -1
app.py CHANGED
@@ -12,8 +12,6 @@ import uvicorn
12
  from fastapi import FastAPI
13
  from fastapi.staticfiles import StaticFiles
14
  import trimesh
15
- from transformers import AutoProcessor, AutoModelForImageClassification
16
- from PIL import Image
17
 
18
  parser = argparse.ArgumentParser()
19
  parser.add_argument("--model_path", type=str, default='tencent/Hunyuan3D-2mini')
@@ -40,11 +38,6 @@ CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
40
  HTML_HEIGHT = 500
41
  HTML_WIDTH = 500
42
 
43
- # -------------------- NSFW 检测模型加载 --------------------
44
- nsfw_processor = AutoProcessor.from_pretrained("Falconsai/nsfw_image_detection")
45
- nsfw_model = AutoModelForImageClassification.from_pretrained("Falconsai/nsfw_image_detection").to(args.device)
46
- # -----------------------------------------------------------
47
-
48
 
49
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
50
  if randomize_seed:
@@ -138,22 +131,10 @@ floater_remove_worker = FloaterRemover()
138
  degenerate_face_remove_worker = DegenerateFaceRemover()
139
  face_reduce_worker = FaceReducer()
140
 
141
-
142
- def detect_nsfw(image: Image.Image, threshold: float = 0.5) -> bool:
143
- """Returns True if image is NSFW"""
144
- inputs = nsfw_processor(images=image, return_tensors="pt").to(args.device)
145
- with torch.no_grad():
146
- outputs = nsfw_model(**inputs)
147
- probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
148
- nsfw_score = probs[0][1].item() # label 1 = NSFW
149
- return nsfw_score > threshold
150
-
151
-
152
-
153
  progress=gr.Progress()
154
 
155
  @spaces.GPU(duration=40)
156
- def _gen_shape_on_gpu(
157
  image=None,
158
  steps=50,
159
  guidance_scale=7.5,
@@ -171,28 +152,14 @@ def _gen_shape_on_gpu(
171
 
172
 
173
  if image is None:
174
- error_info = {
175
- "error": "Please provide either a caption or an image.",
176
- "status": "failed",
177
- }
178
- return None,None,None,None,error_info
179
 
180
- rgbImage = image.convert('RGB')
181
-
182
- # NSFW 检测
183
- if nsfw_model and nsfw_processor:
184
- if detect_nsfw(rgbImage):
185
- error_info = {
186
- "error": "The input image contains NSFW content and cannot be used. Please provide a different image and try again.",
187
- "status": "failed",
188
- }
189
- return None,None,None,None,error_info
190
 
191
  seed = int(randomize_seed_fn(seed, randomize_seed))
192
  octree_resolution = int(octree_resolution)
193
  save_folder = gen_save_folder()
194
  # 先移除背景
195
- image = rmbg_worker(rgbImage)
196
 
197
  # 生成模型
198
  generator = torch.Generator()
@@ -221,11 +188,7 @@ def _gen_shape_on_gpu(
221
  torch.cuda.empty_cache()
222
 
223
  if path is None:
224
- error_info = {
225
- "error": "'Please generate a mesh first.'",
226
- "status": "failed",
227
- }
228
- return None,None,None,None,error_info
229
 
230
  # 简化模型
231
  print(f'exporting {path}')
@@ -258,37 +221,9 @@ def _gen_shape_on_gpu(
258
 
259
 
260
  progress(1,desc="Complete")
261
- info = {
262
- "status": "success"
263
- }
264
- return model_viewer_html, gr.update(value=sourceObjPath, interactive=True), glbPath, objPath, info
265
 
266
 
267
- def gen_shape(
268
- image=None,
269
- steps=50,
270
- guidance_scale=7.5,
271
- seed=1234,
272
- octree_resolution=256,
273
- num_chunks=200000,
274
- target_face_num=10000,
275
- randomize_seed: bool = False,
276
- ):
277
- # 调用 GPU 函数
278
- html_export_mesh,file_export,glbPath_output,objPath_output, info = _gen_shape_on_gpu(
279
- image,
280
- steps,
281
- guidance_scale,
282
- seed,
283
- octree_resolution,
284
- num_chunks,
285
- target_face_num,
286
- randomize_seed
287
- )
288
- # 如果出错,抛出异常
289
- if info["status"] == "failed":
290
- raise gr.Error(info["error"])
291
- return html_export_mesh, file_export, glbPath_output, objPath_output
292
 
293
  def get_example_img_list():
294
  print('Loading example img list ...')
 
12
  from fastapi import FastAPI
13
  from fastapi.staticfiles import StaticFiles
14
  import trimesh
 
 
15
 
16
  parser = argparse.ArgumentParser()
17
  parser.add_argument("--model_path", type=str, default='tencent/Hunyuan3D-2mini')
 
38
  HTML_HEIGHT = 500
39
  HTML_WIDTH = 500
40
 
 
 
 
 
 
41
 
42
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
43
  if randomize_seed:
 
131
  degenerate_face_remove_worker = DegenerateFaceRemover()
132
  face_reduce_worker = FaceReducer()
133
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  progress=gr.Progress()
135
 
136
  @spaces.GPU(duration=40)
137
+ def gen_shape(
138
  image=None,
139
  steps=50,
140
  guidance_scale=7.5,
 
152
 
153
 
154
  if image is None:
155
+ raise gr.Error("Please provide either a caption or an image.")
 
 
 
 
156
 
 
 
 
 
 
 
 
 
 
 
157
 
158
  seed = int(randomize_seed_fn(seed, randomize_seed))
159
  octree_resolution = int(octree_resolution)
160
  save_folder = gen_save_folder()
161
  # 先移除背景
162
+ image = rmbg_worker(image.convert('RGB'))
163
 
164
  # 生成模型
165
  generator = torch.Generator()
 
188
  torch.cuda.empty_cache()
189
 
190
  if path is None:
191
+ raise gr.Error('Please generate a mesh first.')
 
 
 
 
192
 
193
  # 简化模型
194
  print(f'exporting {path}')
 
221
 
222
 
223
  progress(1,desc="Complete")
224
+ return model_viewer_html, gr.update(value=sourceObjPath, interactive=True), glbPath, objPath
 
 
 
225
 
226
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
 
228
  def get_example_img_list():
229
  print('Loading example img list ...')
requirements.txt CHANGED
@@ -18,7 +18,7 @@ tqdm
18
 
19
  # Mesh Processing
20
  trimesh
21
- pymeshlab==2023.12.post3
22
  pygltflib
23
  xatlas
24
  #kornia
 
18
 
19
  # Mesh Processing
20
  trimesh
21
+ pymeshlab
22
  pygltflib
23
  xatlas
24
  #kornia