Upload 3 files

- app.py +1 -0
- prompt_generator.py +3 -3

app.py
CHANGED
@@ -48,6 +48,7 @@ logger = logging.getLogger(__name__)
 
 # Constants
 IS_COLAB = utils.is_google_colab() or os.getenv("IS_COLAB") == "1"
+HF_TOKEN = os.getenv("HF_TOKEN")
 CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") == "1"
 
 # PyTorch settings for better performance and determinism
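The commit only defines HF_TOKEN in app.py; below is a minimal sketch of how such a token is typically passed to a Hugging Face loader. The model id and the token= keyword usage are assumptions for illustration, not part of this commit.

import os

from transformers import AutoModelForCausalLM

# Read the access token from the environment, as app.py now does.
HF_TOKEN = os.getenv("HF_TOKEN")

# Hypothetical usage: authenticate when loading a gated or private model.
# "some-org/gated-model" is a placeholder id, not from this repo.
model = AutoModelForCausalLM.from_pretrained(
    "some-org/gated-model",
    token=HF_TOKEN,
)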
    	
prompt_generator.py
CHANGED
@@ -119,8 +119,8 @@ def load_model():
         _model = AutoModelForCausalLM.from_pretrained(
             model_path,
             torch_dtype=torch_dtype,
-
-
+            device_map=device_map,
+            use_cache=True,
             low_cpu_mem_usage=True,
         )
 
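For context, a minimal sketch of a load_model() consistent with this hunk. Only the from_pretrained call is visible in the diff; the model_path value, the device_map computation, and the module-level caching are assumptions.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

_model = None
_tokenizer = None

def load_model():
    global _model, _tokenizer
    if _model is not None:
        return _model, _tokenizer

    model_path = "some-org/prompt-model"  # placeholder id, not from this repo
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    device_map = "auto" if torch.cuda.is_available() else None  # assumed; "auto" requires accelerate

    _tokenizer = AutoTokenizer.from_pretrained(model_path)
    _model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch_dtype,
        device_map=device_map,  # added in this commit
        use_cache=True,         # added in this commit: keep the KV cache for faster generation
        low_cpu_mem_usage=True,
    )
    return _model, _tokenizer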
@@ -277,7 +277,7 @@ Quality-related tags such as masterpiece, best quality, and highres are handled in post-processing
             logger.warning(f"Input tokens were too many and have been truncated to {max_input_length}")
 
         # Generation
-        logger.info("before 
+        logger.info("before torch.no_grad")
         with torch.no_grad():
             generated_ids = model.generate(
                 input_ids=inputs,
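Likewise, a minimal sketch of the generation step around the second hunk. Names not visible in the diff (tokenizer, max_new_tokens, the decode step) are assumptions.

import logging

import torch

logger = logging.getLogger(__name__)

def run_generation(model, tokenizer, inputs, max_new_tokens=128):
    # inputs: input_ids tensor, already truncated to max_input_length upstream
    logger.info("before torch.no_grad")
    with torch.no_grad():  # inference only: no autograd graph, lower memory use
        generated_ids = model.generate(
            input_ids=inputs,
            max_new_tokens=max_new_tokens,
        )
    # Decode only the newly generated tokens, dropping the prompt prefix.
    new_tokens = generated_ids[0, inputs.shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)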