Commit c0fdd5a
Parent(s): e2b5d99

remove bitnet handling completely

utils/models.py  +2 -46

utils/models.py  CHANGED

@@ -11,7 +11,6 @@ from transformers import (
     AutoTokenizer,
     AutoModelForCausalLM,
     StoppingCriteria,
-    BitNetForCausalLM
 )
 from .prompts import format_rag_prompt
 from .shared import generation_interrupt
@@ -156,25 +155,7 @@ def run_inference(model_name, context, question):
 
     print("REACHED HERE BEFORE pipe")
     print(f"Loading model {model_name}...")
-    if "bitnet" in model_name.lower():
-        bitnet_model = BitNetForCausalLM.from_pretrained(
-            model_name,
-            #device_map="auto",
-            torch_dtype=torch.bfloat16,
-            #trust_remote_code=True,
-        )
-        pipe = pipeline(
-            "text-generation",
-            model=bitnet_model,
-            tokenizer=tokenizer,
-            #device_map="auto",
-            #trust_remote_code=True,
-            torch_dtype=torch.bfloat16,
-            model_kwargs={
-                "attn_implementation": "eager",
-            },
-        )
-    elif "icecream" not in model_name.lower():
+    if "icecream" not in model_name.lower():
         pipe = pipeline(
             "text-generation",
             model=model_name,
@@ -221,12 +202,8 @@ def run_inference(model_name, context, question):
             **tokenizer_kwargs,
         )
 
-
         model_inputs = model_inputs.to(model.device)
-
         input_ids = model_inputs.input_ids
-        attention_mask = model_inputs.attention_mask
-
         prompt_tokens_length = input_ids.shape[1]
 
         with torch.inference_mode():
@@ -235,33 +212,12 @@ def run_inference(model_name, context, question):
                 return ""
 
             output_sequences = model.generate(
-
-                attention_mask=attention_mask,
+                **model_inputs,
                 max_new_tokens=512,
-                eos_token_id=tokenizer.eos_token_id,
-                pad_token_id=tokenizer.pad_token_id  # Addresses the warning
             )
 
         generated_token_ids = output_sequences[0][prompt_tokens_length:]
         result = tokenizer.decode(generated_token_ids, skip_special_tokens=True)
-    # elif "bitnet" in model_name.lower():
-    #     formatted = tokenizer.apply_chat_template(
-    #         text_input,
-    #         tokenize=True,
-    #         return_tensors="pt",
-    #         return_dict=True,
-    #         **tokenizer_kwargs,
-    #     ).to(bitnet_model.device)
-    #     with torch.inference_mode():
-    #         # Check interrupt before generation
-    #         if generation_interrupt.is_set():
-    #             return ""
-    #         output_sequences = bitnet_model.generate(
-    #             **formatted,
-    #             max_new_tokens=512,
-    #         )
-
-    #         result = tokenizer.decode(output_sequences[0][formatted['input_ids'].shape[-1]:], skip_special_tokens=True)
     else:  # For other models
         formatted = pipe.tokenizer.apply_chat_template(
             text_input,
 
			


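For reference, below is a minimal sketch of the tokenize → generate → decode path this commit simplifies, not the Space's actual code: the helper name generate_answer and the messages argument are illustrative, and the real run_inference also builds the RAG prompt via format_rag_prompt, checks generation_interrupt, and goes through pipeline() for non-"icecream" models, all of which are omitted here.

# Minimal sketch (assumed names, condensed from the diff) of the non-bitnet path.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def generate_answer(model_name: str, messages: list) -> str:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)

    # apply_chat_template with return_dict=True yields a BatchEncoding that
    # holds both input_ids and attention_mask.
    model_inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
        add_generation_prompt=True,
    ).to(model.device)

    prompt_tokens_length = model_inputs.input_ids.shape[1]

    with torch.inference_mode():
        # Unpacking the BatchEncoding passes input_ids and attention_mask in one go,
        # which is what the "+ **model_inputs," line in the diff does instead of
        # spelling out input_ids / attention_mask / eos_token_id / pad_token_id.
        output_sequences = model.generate(**model_inputs, max_new_tokens=512)

    # Keep only the newly generated continuation, then decode it.
    generated_token_ids = output_sequences[0][prompt_tokens_length:]
    return tokenizer.decode(generated_token_ids, skip_special_tokens=True)

Here messages follows the standard chat format, e.g. [{"role": "user", "content": "..."}], matching what the chat template expects.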