Commit eb1a863 · Parent: 3ff8c65

reactivate gemma models, add some ice cream, support peft

Files changed:
- requirements.txt  +2 -1
- utils/models.py  +156 -66
- utils/prompts.py  +1 -1

requirements.txt  CHANGED

@@ -7,4 +7,5 @@ openai>=1.60.2
 torch>=2.5.1
 tqdm==4.67.1
 vllm>=0.8.5
-spaces
+spaces
+peft>=0.15.1
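
To sanity-check that the updated dependency list resolves in a fresh environment (in particular the newly added peft), a minimal check like the following can help; the exact versions printed depend on what pip actually installs:

    # Quick check that the key runtime dependencies are importable and report their versions.
    from importlib.metadata import version

    for pkg in ("spaces", "peft"):
        print(pkg, version(pkg))  # peft should report >= 0.15.1 per requirements.txt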
    	
utils/models.py  CHANGED

@@ -1,32 +1,41 @@
 import os
-
+
+os.environ["MKL_THREADING_LAYER"] = "GNU"
 import spaces
+from peft import PeftModel
+import traceback

 import torch
-from transformers import 
+from transformers import (
+    pipeline,
+    AutoTokenizer,
+    AutoModelForCausalLM,
+    StoppingCriteria,
+    StoppingCriteriaList,
+)
 from .prompts import format_rag_prompt
 from .shared import generation_interrupt

 models = {
-
-
-
-
-
-
-
-
+    "Qwen2.5-1.5b-Instruct": "qwen/qwen2.5-1.5b-instruct",
+    "Qwen2.5-3b-Instruct": "qwen/qwen2.5-3b-instruct",
+    "Llama-3.2-1b-Instruct": "meta-llama/llama-3.2-1b-instruct",
+    "Llama-3.2-3b-Instruct": "meta-llama/llama-3.2-3b-instruct",
+    "Gemma-3-1b-it": "google/gemma-3-1b-it",
+    "Gemma-3-4b-it": "google/gemma-3-4b-it",
+    "Gemma-2-2b-it": "google/gemma-2-2b-it",
+    "Phi-4-mini-instruct": "microsoft/phi-4-mini-instruct",
     "Cogito-v1-preview-llama-3b": "deepcogito/cogito-v1-preview-llama-3b",
-
+    "IBM Granite-3.3-2b-instruct": "ibm-granite/granite-3.3-2b-instruct",
     # #"Bitnet-b1.58-2B4T": "microsoft/bitnet-b1.58-2B-4T",
     # #"MiniCPM3-RAG-LoRA": "openbmb/MiniCPM3-RAG-LoRA",
     "Qwen3-0.6b": "qwen/qwen3-0.6b",
-
-
-
-
-
-
+    "Qwen3-1.7b": "qwen/qwen3-1.7b",
+    "Qwen3-4b": "qwen/qwen3-4b",
+    "SmolLM2-1.7b-Instruct": "HuggingFaceTB/SmolLM2-1.7B-Instruct",
+    "EXAONE-3.5-2.4B-instruct": "LGAI-EXAONE/EXAONE-3.5-2.4B-Instruct",
+    "OLMo-2-1B-Instruct": "allenai/OLMo-2-0425-1B-Instruct",
+    "icecream-3b": "aizip-dev/icecream-3b",
 }

 tokenizer_cache = {}

@@ -34,14 +43,16 @@ tokenizer_cache = {}
 # List of model names for easy access
 model_names = list(models.keys())

+
 # Custom stopping criteria that checks the interrupt flag
 class InterruptCriteria(StoppingCriteria):
     def __init__(self, interrupt_event):
         self.interrupt_event = interrupt_event
-
+
     def __call__(self, input_ids, scores, **kwargs):
         return self.interrupt_event.is_set()

+
 @spaces.GPU
 def generate_summaries(example, model_a_name, model_b_name):
     """

@@ -49,48 +60,49 @@ def generate_summaries(example, model_a_name, model_b_name):
     """
     if generation_interrupt.is_set():
         return "", ""
-
+
     context_text = ""
     context_parts = []
-
+
     if "full_contexts" in example and example["full_contexts"]:
         for i, ctx in enumerate(example["full_contexts"]):
             content = ""
-
+
             # Extract content from either dict or string
             if isinstance(ctx, dict) and "content" in ctx:
                 content = ctx["content"]
             elif isinstance(ctx, str):
                 content = ctx
-
+
             # Add document number if not already present
             if not content.strip().startswith("Document"):
-                content = f"Document {i+1}:\n{content}"
-
+                content = f"Document {i + 1}:\n{content}"
+
             context_parts.append(content)
-
+
         context_text = "\n\n".join(context_parts)
     else:
         # Provide a graceful fallback instead of raising an error
         print("Warning: No full context found in the example, using empty context")
         context_text = ""
-
+
     question = example.get("question", "")
-
+
     if generation_interrupt.is_set():
         return "", ""
-
+
     # Run model A
     summary_a = run_inference(models[model_a_name], context_text, question)
-
+
     if generation_interrupt.is_set():
         return summary_a, ""
-
+
     # Run model B
     summary_b = run_inference(models[model_b_name], context_text, question)
-
+
     return summary_a, summary_b

+
 @spaces.GPU
 def run_inference(model_name, context, question):
     """

@@ -105,29 +117,40 @@ def run_inference(model_name, context, question):
     result = ""
     tokenizer_kwargs = {
         "add_generation_prompt": True,
-    }
+    }  # make sure qwen3 doesn't use thinking
     generation_kwargs = {
         "max_new_tokens": 512,
     }
-    if "qwen3" in model_name.lower():
-        print(
+    if "qwen3" in model_name.lower():
+        print(
+            f"Recognized {model_name} as a Qwen3 model. Setting enable_thinking=False."
+        )
         tokenizer_kwargs["enable_thinking"] = False

     try:
+        print("REACHED HERE BEFORE tokenizer")
         if model_name in tokenizer_cache:
             tokenizer = tokenizer_cache[model_name]
         else:
-            tokenizer 
-
-                padding_side="left", 
-                token=True, 
-                kwargs=tokenizer_kwargs
-                )
-            tokenizer_cache[model_name] = tokenizer
+            # Common arguments for tokenizer loading
+            tokenizer_load_args = {"padding_side": "left", "token": True}

+            # Determine the Hugging Face model name for the tokenizer
+            actual_model_name_for_tokenizer = model_name
+            if "icecream" in model_name.lower():
+                actual_model_name_for_tokenizer = "meta-llama/llama-3.2-3b-instruct"
+
+            # Note: tokenizer_kwargs (defined earlier, with add_generation_prompt etc.)
+            # is intended for tokenizer.apply_chat_template, not for AutoTokenizer.from_pretrained generally.
+            # If a specific tokenizer (e.g., Qwen) needs special __init__ args that happen to be in tokenizer_kwargs,
+            # that would require more specific handling here. For now, we assume general constructor args.
+            tokenizer = AutoTokenizer.from_pretrained(actual_model_name_for_tokenizer, **tokenizer_load_args)
+            tokenizer_cache[model_name] = tokenizer
+
         accepts_sys = (
             "System role not supported" not in tokenizer.chat_template
-            if tokenizer.chat_template
+            if tokenizer.chat_template
+            else False  # Handle missing chat_template
         )

         if tokenizer.pad_token is None:

@@ -136,40 +159,107 @@ def run_inference(model_name, context, question):
         # Check interrupt before loading the model
        if generation_interrupt.is_set():
             return ""
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+        print("REACHED HERE BEFORE pipe")
+        print(f"Loading model {model_name}...")
+        if "icecream" not in model_name.lower():
+            pipe = pipeline(
+                "text-generation",
+                model=model_name,
+                tokenizer=tokenizer,
+                device_map="cuda",
+                trust_remote_code=True,
+                torch_dtype=torch.bfloat16,
+                model_kwargs={
+                    "attn_implementation": "eager",
+                },
+            )
+        else:
+            base_model = AutoModelForCausalLM.from_pretrained(
+                "meta-llama/llama-3.2-3b-instruct",
+                device_map="cuda",
+                torch_dtype=torch.bfloat16,
+                trust_remote_code=True,
+            )
+            model = PeftModel.from_pretrained(
+                base_model,
+                "aizip-dev/icecream-3b",
+                device_map="cuda",
+                torch_dtype=torch.bfloat16,
+            )

         text_input = format_rag_prompt(question, context, accepts_sys)
-        if "Gemma-3".lower() 
+        if "Gemma-3".lower() in model_name.lower():
+            print("REACHED HERE BEFORE GEN")
+            result = pipe(
+                text_input,
+                max_new_tokens=512,
+                generation_kwargs={"skip_special_tokens": True},
+            )[0]["generated_text"]
+
+            result = result[-1]["content"]
+        elif "icecream" in model_name.lower():
+
+            print("ICECREAM")
+            # text_input is the list of messages from format_rag_prompt
+            # tokenizer_kwargs (e.g., {"add_generation_prompt": True}) are correctly passed to apply_chat_template
+            model_inputs = tokenizer.apply_chat_template(
+                text_input,
+                tokenize=True,
+                return_tensors="pt",
+                return_dict=True,
+                **tokenizer_kwargs,
+            )
+
+            # Move all tensors within the BatchEncoding (model_inputs) to the model's device
+            model_inputs = model_inputs.to(model.device)
+
+            input_ids = model_inputs.input_ids
+            attention_mask = model_inputs.attention_mask  # Expecting this from a correctly configured tokenizer
+
+            prompt_tokens_length = input_ids.shape[1]  # Get length of tokenized prompt
+
+            with torch.inference_mode():
+                # Check interrupt before generation
+                if generation_interrupt.is_set():
+                    return ""
+
+                # Explicitly pass input_ids, attention_mask, and pad_token_id
+                # tokenizer.pad_token is set to tokenizer.eos_token if None, earlier in the code.
+                output_sequences = model.generate(
+                    input_ids=input_ids,
+                    attention_mask=attention_mask,
+                    max_new_tokens=512,
+                    eos_token_id=tokenizer.eos_token_id,  # Good practice for stopping generation
+                    pad_token_id=tokenizer.pad_token_id  # Addresses the warning
+                )
+
+            # output_sequences[0] contains the full sequence (prompt + generation)
+            # Decode only the newly generated tokens
+            generated_token_ids = output_sequences[0][prompt_tokens_length:]
+            result = tokenizer.decode(generated_token_ids, skip_special_tokens=True)
+
+        else:  # For other models
             formatted = pipe.tokenizer.apply_chat_template(
                 text_input,
-                tokenize=
+                tokenize=True,
                 **tokenizer_kwargs,
             )
-
+
             input_length = len(formatted)
-
+            # Check interrupt before generation

-            outputs = pipe(
-
-
-
-
-
+            outputs = pipe(
+                formatted,
+                max_new_tokens=512,
+                generation_kwargs={"skip_special_tokens": True},
+            )
+            # print(outputs[0]['generated_text'])
+            result = outputs[0]["generated_text"][input_length:]

     except Exception as e:
         print(f"Error in inference for {model_name}: {e}")
+        print(traceback.format_exc())
         result = f"Error generating response: {str(e)[:200]}..."

     finally:

@@ -177,4 +267,4 @@ def run_inference(model_name, context, question):
         if torch.cuda.is_available():
             torch.cuda.empty_cache()

-    return result
+    return result
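
For readers who want to exercise the new PEFT path outside the Space, here is a minimal sketch of the adapter flow that run_inference now uses for icecream-3b. The chat messages are placeholders standing in for format_rag_prompt's output, and the sketch assumes a CUDA device plus access to the gated meta-llama base weights; inside the Space this all happens within the @spaces.GPU-decorated function.

    # Minimal sketch of the icecream-3b branch: Llama-3.2-3B base weights + PEFT adapter.
    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer

    BASE = "meta-llama/llama-3.2-3b-instruct"  # base model used by the adapter (per the diff)
    ADAPTER = "aizip-dev/icecream-3b"          # PEFT adapter loaded on top

    tokenizer = AutoTokenizer.from_pretrained(BASE, padding_side="left")
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE, device_map="cuda", torch_dtype=torch.bfloat16
    )
    model = PeftModel.from_pretrained(base_model, ADAPTER)

    # Placeholder messages; in the Space these come from format_rag_prompt.
    messages = [{"role": "user", "content": "Document 1:\n<context here>\n\nQuestion: <question here>"}]
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)

    with torch.inference_mode():
        output = model.generate(
            **inputs,
            max_new_tokens=512,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens, as the diff does.
    print(tokenizer.decode(output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))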
    	
utils/prompts.py  CHANGED

@@ -26,7 +26,7 @@ Given the following query and context, please provide your response:

 {context}

-WITHOUT mentioning your judgement either your grounded answer, OR refusal and clarifications:
+WITHOUT mentioning your judgement on answerability, either your grounded answer, OR refusal and clarifications:
 """

     messages = (