Spaces:
				
			
			
	
			
			
		Running
		
			on 
			
			Zero
	
	
	
			
			
	
	
	
	
		
		
		Running
		
			on 
			
			Zero
	Update app.py
Browse files
    	
        app.py
    CHANGED
    
    | @@ -12,6 +12,8 @@ import uvicorn | |
| 12 | 
             
            from fastapi import FastAPI
         | 
| 13 | 
             
            from fastapi.staticfiles import StaticFiles
         | 
| 14 | 
             
            import trimesh
         | 
|  | |
|  | |
| 15 |  | 
| 16 | 
             
            parser = argparse.ArgumentParser()
         | 
| 17 | 
             
            parser.add_argument("--model_path", type=str, default='tencent/Hunyuan3D-2mini')
         | 
| @@ -38,6 +40,11 @@ CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) | |
| 38 | 
             
            HTML_HEIGHT = 500
         | 
| 39 | 
             
            HTML_WIDTH = 500
         | 
| 40 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 41 |  | 
| 42 | 
             
            def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
         | 
| 43 | 
             
                if randomize_seed:
         | 
| @@ -131,10 +138,22 @@ floater_remove_worker = FloaterRemover() | |
| 131 | 
             
            degenerate_face_remove_worker = DegenerateFaceRemover()
         | 
| 132 | 
             
            face_reduce_worker = FaceReducer()
         | 
| 133 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 134 | 
             
            progress=gr.Progress()
         | 
| 135 |  | 
| 136 | 
             
            @spaces.GPU(duration=40)
         | 
| 137 | 
            -
            def  | 
| 138 | 
             
                image=None,
         | 
| 139 | 
             
                steps=50,
         | 
| 140 | 
             
                guidance_scale=7.5,
         | 
| @@ -152,14 +171,28 @@ def gen_shape( | |
| 152 |  | 
| 153 |  | 
| 154 | 
             
                if image is None:
         | 
| 155 | 
            -
                     | 
|  | |
|  | |
|  | |
|  | |
| 156 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 157 |  | 
| 158 | 
             
                seed = int(randomize_seed_fn(seed, randomize_seed))
         | 
| 159 | 
             
                octree_resolution = int(octree_resolution)
         | 
| 160 | 
             
                save_folder = gen_save_folder()
         | 
| 161 | 
             
                # 先移除背景
         | 
| 162 | 
            -
                image = rmbg_worker( | 
| 163 |  | 
| 164 | 
             
                # 生成模型
         | 
| 165 | 
             
                generator = torch.Generator()
         | 
| @@ -188,7 +221,11 @@ def gen_shape( | |
| 188 | 
             
                    torch.cuda.empty_cache()
         | 
| 189 |  | 
| 190 | 
             
                if path is None:
         | 
| 191 | 
            -
                     | 
|  | |
|  | |
|  | |
|  | |
| 192 |  | 
| 193 | 
             
                # 简化模型
         | 
| 194 | 
             
                print(f'exporting {path}')
         | 
| @@ -221,9 +258,37 @@ def gen_shape( | |
| 221 |  | 
| 222 |  | 
| 223 | 
             
                progress(1,desc="Complete")
         | 
| 224 | 
            -
                 | 
|  | |
|  | |
|  | |
| 225 |  | 
| 226 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 227 |  | 
| 228 | 
             
            def get_example_img_list():
         | 
| 229 | 
             
                print('Loading example img list ...')
         | 
|  | |
| 12 | 
             
            from fastapi import FastAPI
         | 
| 13 | 
             
            from fastapi.staticfiles import StaticFiles
         | 
| 14 | 
             
            import trimesh
         | 
| 15 | 
            +
            from transformers import AutoProcessor, AutoModelForImageClassification
         | 
| 16 | 
            +
            from PIL import Image
         | 
| 17 |  | 
| 18 | 
             
            parser = argparse.ArgumentParser()
         | 
| 19 | 
             
            parser.add_argument("--model_path", type=str, default='tencent/Hunyuan3D-2mini')
         | 
|  | |
| 40 | 
             
            HTML_HEIGHT = 500
         | 
| 41 | 
             
            HTML_WIDTH = 500
         | 
| 42 |  | 
| 43 | 
            +
            # -------------------- NSFW 检测模型加载 --------------------
         | 
| 44 | 
            +
            nsfw_processor = AutoProcessor.from_pretrained("Falconsai/nsfw_image_detection")
         | 
| 45 | 
            +
            nsfw_model = AutoModelForImageClassification.from_pretrained("Falconsai/nsfw_image_detection").to(args.device)
         | 
| 46 | 
            +
            # -----------------------------------------------------------
         | 
| 47 | 
            +
             | 
| 48 |  | 
| 49 | 
             
            def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
         | 
| 50 | 
             
                if randomize_seed:
         | 
|  | |
| 138 | 
             
            degenerate_face_remove_worker = DegenerateFaceRemover()
         | 
| 139 | 
             
            face_reduce_worker = FaceReducer()
         | 
| 140 |  | 
| 141 | 
            +
             | 
| 142 | 
            +
            def detect_nsfw(image: Image.Image, threshold: float = 0.5) -> bool:
         | 
| 143 | 
            +
                """Returns True if image is NSFW"""
         | 
| 144 | 
            +
                inputs = nsfw_processor(images=image, return_tensors="pt").to(args.device)
         | 
| 145 | 
            +
                with torch.no_grad():
         | 
| 146 | 
            +
                    outputs = nsfw_model(**inputs)
         | 
| 147 | 
            +
                    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
         | 
| 148 | 
            +
                    nsfw_score = probs[0][1].item()  # label 1 = NSFW
         | 
| 149 | 
            +
                return nsfw_score > threshold
         | 
| 150 | 
            +
             | 
| 151 | 
            +
             | 
| 152 | 
            +
             | 
| 153 | 
             
            progress=gr.Progress()
         | 
| 154 |  | 
| 155 | 
             
            @spaces.GPU(duration=40)
         | 
| 156 | 
            +
            def _gen_shape_on_gpu(
         | 
| 157 | 
             
                image=None,
         | 
| 158 | 
             
                steps=50,
         | 
| 159 | 
             
                guidance_scale=7.5,
         | 
|  | |
| 171 |  | 
| 172 |  | 
| 173 | 
             
                if image is None:
         | 
| 174 | 
            +
                    error_info = {
         | 
| 175 | 
            +
                        "error": "Please provide either a caption or an image.",
         | 
| 176 | 
            +
                        "status": "failed",
         | 
| 177 | 
            +
                    }
         | 
| 178 | 
            +
                    return None,None,None,None,error_info
         | 
| 179 |  | 
| 180 | 
            +
                rgbImage = image.convert('RGB')
         | 
| 181 | 
            +
                
         | 
| 182 | 
            +
                # NSFW 检测
         | 
| 183 | 
            +
                if nsfw_model and nsfw_processor:
         | 
| 184 | 
            +
                    if detect_nsfw(rgbImage):
         | 
| 185 | 
            +
                        error_info = {
         | 
| 186 | 
            +
                            "error": "The input image contains NSFW content and cannot be used. Please provide a different image and try again.",
         | 
| 187 | 
            +
                            "status": "failed",
         | 
| 188 | 
            +
                        }
         | 
| 189 | 
            +
                        return None,None,None,None,error_info
         | 
| 190 |  | 
| 191 | 
             
                seed = int(randomize_seed_fn(seed, randomize_seed))
         | 
| 192 | 
             
                octree_resolution = int(octree_resolution)
         | 
| 193 | 
             
                save_folder = gen_save_folder()
         | 
| 194 | 
             
                # 先移除背景
         | 
| 195 | 
            +
                image = rmbg_worker(rgbImage)
         | 
| 196 |  | 
| 197 | 
             
                # 生成模型
         | 
| 198 | 
             
                generator = torch.Generator()
         | 
|  | |
| 221 | 
             
                    torch.cuda.empty_cache()
         | 
| 222 |  | 
| 223 | 
             
                if path is None:
         | 
| 224 | 
            +
                    error_info = {
         | 
| 225 | 
            +
                        "error": "'Please generate a mesh first.'",
         | 
| 226 | 
            +
                        "status": "failed",
         | 
| 227 | 
            +
                    }
         | 
| 228 | 
            +
                    return None,None,None,None,error_info
         | 
| 229 |  | 
| 230 | 
             
                # 简化模型
         | 
| 231 | 
             
                print(f'exporting {path}')
         | 
|  | |
| 258 |  | 
| 259 |  | 
| 260 | 
             
                progress(1,desc="Complete")
         | 
| 261 | 
            +
                info = {
         | 
| 262 | 
            +
                    "status": "success"
         | 
| 263 | 
            +
                }
         | 
| 264 | 
            +
                return model_viewer_html, gr.update(value=sourceObjPath, interactive=True), glbPath, objPath, info
         | 
| 265 |  | 
| 266 |  | 
| 267 | 
            +
            def gen_shape(
         | 
| 268 | 
            +
                image=None,
         | 
| 269 | 
            +
                steps=50,
         | 
| 270 | 
            +
                guidance_scale=7.5,
         | 
| 271 | 
            +
                seed=1234,
         | 
| 272 | 
            +
                octree_resolution=256,
         | 
| 273 | 
            +
                num_chunks=200000,
         | 
| 274 | 
            +
                target_face_num=10000,
         | 
| 275 | 
            +
                randomize_seed: bool = False,
         | 
| 276 | 
            +
            ):
         | 
| 277 | 
            +
                # 调用 GPU 函数
         | 
| 278 | 
            +
                html_export_mesh,file_export,glbPath_output,objPath_output, info = _gen_shape_on_gpu(
         | 
| 279 | 
            +
                    image, 
         | 
| 280 | 
            +
                    steps,
         | 
| 281 | 
            +
                    guidance_scale, 
         | 
| 282 | 
            +
                    seed,
         | 
| 283 | 
            +
                    octree_resolution,
         | 
| 284 | 
            +
                    num_chunks, 
         | 
| 285 | 
            +
                    target_face_num,
         | 
| 286 | 
            +
                    randomize_seed
         | 
| 287 | 
            +
                )
         | 
| 288 | 
            +
                # 如果出错,抛出异常
         | 
| 289 | 
            +
                if info["status"] == "failed":
         | 
| 290 | 
            +
                    raise gr.Error(info["error"])
         | 
| 291 | 
            +
                return html_export_mesh, file_export, glbPath_output, objPath_output
         | 
| 292 |  | 
| 293 | 
             
            def get_example_img_list():
         | 
| 294 | 
             
                print('Loading example img list ...')
         |