Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Commit 
							
							·
						
						0d83943
	
1
								Parent(s):
							
							b725492
								
remove post-processing
Browse files
    	
        app.py
    CHANGED
    
    | @@ -25,23 +25,15 @@ def predict(steps=100, seed=42,scheduler="ddim"): | |
| 25 | 
             
                torch.cuda.empty_cache()
         | 
| 26 | 
             
                generator = torch.manual_seed(seed)
         | 
| 27 | 
             
                if(scheduler == "ddim"):
         | 
| 28 | 
            -
                     | 
| 29 | 
             
                elif(scheduler == "ddpm"):
         | 
| 30 | 
            -
                     | 
| 31 | 
             
                elif(scheduler == "pndm"):
         | 
| 32 | 
             
                    if(steps > 100):
         | 
| 33 | 
             
                        steps = 100
         | 
| 34 | 
            -
                     | 
| 35 | 
            -
             | 
| 36 | 
            -
                 | 
| 37 | 
            -
                if scheduler == "pndm":
         | 
| 38 | 
            -
                    image_processed = (image_processed + 1.0) / 2
         | 
| 39 | 
            -
                    image_processed = torch.clamp(image_processed, 0.0, 1.0)
         | 
| 40 | 
            -
                    image_processed = image_processed * 255
         | 
| 41 | 
            -
                else:
         | 
| 42 | 
            -
                    image_processed = (image_processed + 1.0) * 127.5
         | 
| 43 | 
            -
                image_processed = image_processed.detach().numpy().astype(np.uint8)
         | 
| 44 | 
            -
                return(PIL.Image.fromarray(image_processed[0]))
         | 
| 45 |  | 
| 46 |  | 
| 47 | 
             
            random_seed = random.randint(0, 2147483647)
         | 
|  | |
| 25 | 
             
                torch.cuda.empty_cache()
         | 
| 26 | 
             
                generator = torch.manual_seed(seed)
         | 
| 27 | 
             
                if(scheduler == "ddim"):
         | 
| 28 | 
            +
                    images = ddim_pipeline(generator=generator, num_inference_steps=steps)["sample"]
         | 
| 29 | 
             
                elif(scheduler == "ddpm"):
         | 
| 30 | 
            +
                    images = ddpm_pipeline(generator=generator)["sample"]
         | 
| 31 | 
             
                elif(scheduler == "pndm"):
         | 
| 32 | 
             
                    if(steps > 100):
         | 
| 33 | 
             
                        steps = 100
         | 
| 34 | 
            +
                    images = pndm_pipeline(generator=generator, num_inference_steps=steps)["sample"]
         | 
| 35 | 
            +
             | 
| 36 | 
            +
                return(images[0])
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 37 |  | 
| 38 |  | 
| 39 | 
             
            random_seed = random.randint(0, 2147483647)
         | 
