Upload folder using huggingface_hub

- .gitattributes +1 -0
- .ruff_cache/.gitignore +2 -0
- .ruff_cache/0.12.8/5591301162804142724 +0 -0
- .ruff_cache/CACHEDIR.TAG +1 -0
- README.md +75 -0
- config.json +30 -0
- custom_generate/generate.py +608 -0
- generation_config.json +13 -0
- merges.txt +0 -0
- model.safetensors +3 -0
- tokenizer.json +3 -0
- tokenizer_config.json +239 -0
- vocab.json +0 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+tokenizer.json filter=lfs diff=lfs merge=lfs -text
.ruff_cache/.gitignore ADDED
@@ -0,0 +1,2 @@
# Automatically created by ruff.
*
.ruff_cache/0.12.8/5591301162804142724 ADDED
Binary file (154 Bytes)
.ruff_cache/CACHEDIR.TAG ADDED
@@ -0,0 +1 @@
Signature: 8a477f597d28d172789f06886806bc55
README.md ADDED
@@ -0,0 +1,75 @@
---
library_name: transformers
tags:
  - custom_generate
---

## Description
Implementation of [Contrastive Search](https://huggingface.co/blog/introducing-csearch), a decoding strategy that jointly optimizes model confidence and a degeneration penalty to produce fluent, coherent, and low-repetition text. At each step, the model considers the top-k candidate tokens and selects the one maximizing:

`score(v) = (1 - alpha) * p(v | context) - alpha * max_cosine_similarity(h_v, H_context)`

where `alpha` is the trade-off between confidence and the cosine-similarity-based penalty.

This strategy typically:

- Reduces repetition compared to greedy/beam search
- Preserves semantic coherence better than pure sampling
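To make the re-ranking step concrete, here is a minimal, self-contained sketch of the score above on toy tensors. The shapes and the `alpha`/`top_k` values are illustrative assumptions; the actual implementation used by this repository is `_ranking_fast` in `custom_generate/generate.py`.

```py
import torch

torch.manual_seed(0)
batch, seq_len, hidden, top_k, alpha = 1, 5, 8, 4, 0.6  # toy sizes, assumptions

# Hidden states of the tokens generated so far (the "context"), one copy per candidate
context_hidden = torch.randn(batch * top_k, seq_len, hidden)
# Hidden state of each of the top-k candidate continuations
next_hidden = torch.randn(batch * top_k, 1, hidden)
# Model probability of each candidate (would come from a softmax over the logits)
next_top_k_probs = torch.rand(batch, top_k)

# Cosine similarity between every candidate and every context position
norm_context = context_hidden / context_hidden.norm(dim=2, keepdim=True)
norm_next = next_hidden / next_hidden.norm(dim=2, keepdim=True)
cosine_matrix = torch.matmul(norm_context, norm_next.transpose(1, 2)).squeeze(-1)  # [B*K, S]

# Degeneration penalty: maximum similarity to any previous token
degeneration_penalty, _ = cosine_matrix.max(dim=-1)  # [B*K]

# score(v) = (1 - alpha) * p(v | context) - alpha * penalty
scores = (1 - alpha) * next_top_k_probs.view(-1) - alpha * degeneration_penalty
best = scores.view(batch, top_k).argmax(dim=-1)
print(best)  # index of the selected candidate for each batch row
```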
---

## Base model

- `Qwen/Qwen2.5-0.5B-Instruct` (example)

---

## Model compatibility

- Decoder-only transformer models for causal LM

---

## Additional Arguments

- `top_k` (int): Number of candidate tokens to consider each step (e.g., 4)
- `penalty_alpha` (float): Weight of the degeneration penalty (e.g., 0.6)

Tips:
- Larger `top_k` explores more candidates but increases compute
- `penalty_alpha` in [0.3, 0.8] often works well; `0.0` reduces to greedy

---

## Output Type changes

(none): returns the same structure as standard `transformers` generation

---

## Example usage

```py
from transformers import AutoModelForCausalLM, AutoTokenizer, infer_device

device = infer_device()

model_id = "Qwen/Qwen2.5-0.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id).to(device)

inputs = tokenizer(["DeepMind Company is"], return_tensors="pt").to(device)

# Contrastive search
gen_out = model.generate(
    **inputs,
    custom_generate="contrastive_search",
    penalty_alpha=0.6,
    top_k=4,
    max_new_tokens=128,
    trust_remote_code=True,
)

print(tokenizer.batch_decode(gen_out, skip_special_tokens=True))
```
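To get a feel for the trade-off controlled by `penalty_alpha`, the sketch below decodes the same prompt with a few values (per the tips above, `0.0` behaves like greedy decoding). It mirrors the example usage; the model choice and the specific alpha values are illustrative assumptions.

```py
from transformers import AutoModelForCausalLM, AutoTokenizer, infer_device

device = infer_device()
model_id = "Qwen/Qwen2.5-0.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id).to(device)

inputs = tokenizer(["DeepMind Company is"], return_tensors="pt").to(device)

# Compare a few degeneration-penalty weights (0.0 reduces to greedy-like decoding)
for alpha in (0.0, 0.4, 0.6):
    out = model.generate(
        **inputs,
        custom_generate="contrastive_search",
        penalty_alpha=alpha,
        top_k=4,
        max_new_tokens=64,
        trust_remote_code=True,
    )
    print(f"penalty_alpha={alpha}:")
    print(tokenizer.batch_decode(out, skip_special_tokens=True)[0], "\n")
```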
    	
config.json ADDED
@@ -0,0 +1,30 @@
{
  "architectures": [
    "Qwen3ForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 151643,
  "eos_token_id": 151645,
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 40960,
  "max_window_layers": 28,
  "model_type": "qwen3",
  "num_attention_heads": 16,
  "num_hidden_layers": 28,
  "num_key_value_heads": 8,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 1000000,
  "sliding_window": null,
  "tie_word_embeddings": true,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.56.0",
  "use_cache": true,
  "use_sliding_window": false,
  "vocab_size": 151936
}
custom_generate/generate.py ADDED
@@ -0,0 +1,608 @@
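# Custom decoding loop for contrastive search, loaded by `transformers` through
# `model.generate(custom_generate=..., trust_remote_code=True)` as shown in the
# README of this repository.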
from typing import Union, Optional, TYPE_CHECKING
import torch
from transformers import LogitsProcessorList, StoppingCriteriaList, GenerationConfig
from transformers.generation.utils import (
    GenerateNonBeamOutput,
    GenerateDecoderOnlyOutput,
)
from transformers.cache_utils import Cache, EncoderDecoderCache, DynamicCache
from transformers.modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
from transformers.generation.utils import GenerateEncoderDecoderOutput, ALL_CACHE_NAMES
from transformers.utils import ModelOutput
from transformers.configuration_utils import PretrainedConfig
import torch.nn as nn
import logging

if TYPE_CHECKING:
    from transformers.generation.streamers import BaseStreamer

logger = logging.getLogger(__name__)


def stack_model_outputs(
    model_outputs: list[ModelOutput], config: PretrainedConfig
) -> ModelOutput:
    """
    Stack a list of ModelOutput objects (or its subclasses) along the batch_size dimension. The function infers the
    specific ModelOutput subclass from the list provided.
    """
    if not model_outputs:
        raise ValueError("Input list is empty.")

    # Infer the class from the first object in the list
    model_output_cls = type(model_outputs[0])

    # Ensure all objects are of the same type
    if not all(isinstance(obj, model_output_cls) for obj in model_outputs):
        raise ValueError("All elements in the list should be of the same type.")

    # Helper function to concat tensors or tuples of tensors
    def _concat(data):
        """
        Concatenate tensors or (nested) tuples of tensors along the batch dimension.
        """
        if any(d is None for d in data):
            return None
        if isinstance(data[0], torch.Tensor):
            return torch.cat(data, dim=0)
        elif isinstance(data[0], tuple):
            # If the elements of the tuple are also tuples (e.g., legacy `past_key_values`)
            if isinstance(data[0][0], tuple):
                return tuple(
                    tuple(
                        torch.cat([attr[i][j] for attr in data], dim=0)
                        for j in range(len(data[0][0]))
                    )
                    for i in range(len(data[0]))
                )
            else:
                return tuple(
                    torch.cat([attr[i] for attr in data], dim=0)
                    for i in range(len(data[0]))
                )
        elif isinstance(data[0], (int, float)):
            # If the elements are integers or floats, return a tensor
            return torch.tensor(data)
        else:
            raise TypeError(f"Unexpected attribute type: {type(data[0])}")

    # Use a dictionary comprehension to gather attributes from all objects and concatenate them
    concatenated_data = {
        k: _concat([getattr(model_output, k) for model_output in model_outputs])
        for k in model_output_cls.__dataclass_fields__
    }

    # Return a new object of the inferred class with the concatenated attributes
    return model_output_cls(**concatenated_data)

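# Shape reference for `_ranking_fast` below (derived from its call site in `_contrastive_search`):
#   context_hidden:     [batch_size * top_k, context_len, hidden]  - context states, repeated per candidate
#   next_hidden:        [batch_size * top_k, 1, hidden]            - hidden state of each candidate token
#   next_top_k_probs:   [batch_size, top_k]                        - model probability of each candidate
#   cosine_matrix_mask: [batch_size * top_k, context_len]          - 1 for real tokens, 0 for padding
#   returns selected_idx: [batch_size], the index of the best candidate per row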
def _ranking_fast(
    context_hidden: torch.FloatTensor,
    next_hidden: torch.FloatTensor,
    next_top_k_probs: torch.FloatTensor,
    cosine_matrix_mask: torch.LongTensor,
    alpha: float,
    beam_width: int,
) -> torch.FloatTensor:
    """
    Reranks the top_k candidates based on a degeneration penalty (cosine similarity with previous tokens), as described
    in the paper "A Contrastive Framework for Neural Text Generation". Returns the index of the best candidate for each
    row in the batch.
    """
    norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True)
    norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True)
    cosine_matrix = torch.matmul(
        norm_context_hidden, norm_next_hidden.transpose(1, 2)
    ).squeeze(-1)  # [B*K, S]

    # Penalize cosine_matrix based on the cosine_matrix_mask (ignore padding positions)
    # Using a large negative value for masked positions
    cosine_matrix_mask = cosine_matrix_mask.to(dtype=cosine_matrix.dtype)
    cosine_matrix_mask = (1 - cosine_matrix_mask) * torch.finfo(cosine_matrix.dtype).min
    cosine_matrix = cosine_matrix + cosine_matrix_mask

    degeneration_penalty, _ = torch.max(cosine_matrix, dim=-1)  # [B*K]
    next_top_k_probs = next_top_k_probs.view(-1)  # [B*K]
    contrastive_score = (1.0 - alpha) * next_top_k_probs - alpha * degeneration_penalty
    contrastive_score = torch.stack(
        torch.split(contrastive_score, beam_width)
    )  # [B, K]
    _, selected_idx = contrastive_score.max(dim=-1)  # [B]
    return selected_idx

@torch.no_grad()
def _contrastive_search(
    model,
    input_ids: torch.LongTensor,
    logits_processor: LogitsProcessorList,
    stopping_criteria: StoppingCriteriaList,
    generation_config: GenerationConfig,
    synced_gpus: bool,
    streamer: Optional["BaseStreamer"],
    **model_kwargs,
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
    r"""
    Generates sequences of token ids for models with a language modeling head using **contrastive search** and can
    be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

    Parameters:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The sequence used as a prompt for the generation.
        logits_processor (`LogitsProcessorList`):
            An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
            used to modify the prediction scores of the language modeling head applied at each generation step.
        stopping_criteria (`StoppingCriteriaList`):
            An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
            used to tell if the generation loop should stop.
        generation_config ([`~generation.GenerationConfig`]):
            The generation configuration to be used as parametrization of the decoding method.
        synced_gpus (`bool`):
            Whether to continue running the while loop until max_length (needed to avoid deadlocking with
            `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
        streamer (`BaseStreamer`, *optional*):
            Streamer object that will be used to stream the generated sequences. Generated tokens are passed
            through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
        model_kwargs:
            Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
            If model is an encoder-decoder model the kwargs should include `encoder_outputs`.

    Return:
        [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`]
        or `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
        [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
        `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
        `model.config.is_encoder_decoder=True`.
    """
            +
                if not model_kwargs["use_cache"]:
         | 
| 158 | 
            +
                    raise ValueError("Contrastive search requires `use_cache=True`")
         | 
| 159 | 
            +
                if model._is_stateful:
         | 
| 160 | 
            +
                    # Just like assisted generation, we need to be able to rollback to a previous state (see comment above)
         | 
| 161 | 
            +
                    raise ValueError(
         | 
| 162 | 
            +
                        f"contrastive search is not supported with stateful models, such as {model.__class__.__name__}"
         | 
| 163 | 
            +
                    )
         | 
| 164 | 
            +
                # init values
         | 
| 165 | 
            +
                has_eos_stopping_criteria = any(
         | 
| 166 | 
            +
                    hasattr(criteria, "eos_token_id") for criteria in stopping_criteria
         | 
| 167 | 
            +
                )
         | 
| 168 | 
            +
                top_k = generation_config.top_k
         | 
| 169 | 
            +
                penalty_alpha = generation_config.penalty_alpha
         | 
| 170 | 
            +
                pad_token_id = generation_config._pad_token_tensor
         | 
| 171 | 
            +
                output_attentions = generation_config.output_attentions
         | 
| 172 | 
            +
                output_hidden_states = generation_config.output_hidden_states
         | 
| 173 | 
            +
                output_scores = generation_config.output_scores
         | 
| 174 | 
            +
                output_logits = generation_config.output_logits
         | 
| 175 | 
            +
                return_dict_in_generate = generation_config.return_dict_in_generate
         | 
| 176 | 
            +
                sequential = generation_config.low_memory
         | 
| 177 | 
            +
             | 
| 178 | 
            +
                # init attention / hidden states / scores tuples
         | 
| 179 | 
            +
                raw_logits = () if (return_dict_in_generate and output_logits) else None
         | 
| 180 | 
            +
                scores = () if (return_dict_in_generate and output_scores) else None
         | 
| 181 | 
            +
                decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
         | 
| 182 | 
            +
                cross_attentions = () if (return_dict_in_generate and output_attentions) else None
         | 
| 183 | 
            +
                decoder_hidden_states = (
         | 
| 184 | 
            +
                    () if (return_dict_in_generate and output_hidden_states) else None
         | 
| 185 | 
            +
                )
         | 
| 186 | 
            +
             | 
| 187 | 
            +
                # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
         | 
| 188 | 
            +
                if return_dict_in_generate and model.config.is_encoder_decoder:
         | 
| 189 | 
            +
                    encoder_attentions = (
         | 
| 190 | 
            +
                        model_kwargs["encoder_outputs"].get("attentions")
         | 
| 191 | 
            +
                        if output_attentions
         | 
| 192 | 
            +
                        else None
         | 
| 193 | 
            +
                    )
         | 
| 194 | 
            +
                    encoder_hidden_states = (
         | 
| 195 | 
            +
                        model_kwargs["encoder_outputs"].get("hidden_states")
         | 
| 196 | 
            +
                        if output_hidden_states
         | 
| 197 | 
            +
                        else None
         | 
| 198 | 
            +
                    )
         | 
| 199 | 
            +
             | 
| 200 | 
            +
                # keep track of which sequences are already finished
         | 
| 201 | 
            +
                batch_size, cur_len = input_ids.shape[:2]
         | 
| 202 | 
            +
                unfinished_sequences = torch.ones(
         | 
| 203 | 
            +
                    batch_size, dtype=torch.long, device=input_ids.device
         | 
| 204 | 
            +
                )
         | 
| 205 | 
            +
                model_kwargs = model._get_initial_cache_position(
         | 
| 206 | 
            +
                    cur_len, input_ids.device, model_kwargs
         | 
| 207 | 
            +
                )
         | 
| 208 | 
            +
             | 
| 209 | 
            +
                # Create cosine_matrix_mask based on the attention_mask
         | 
| 210 | 
            +
                cosine_matrix_mask = torch.ones_like(input_ids, dtype=torch.long)
         | 
| 211 | 
            +
                if model.config.is_encoder_decoder:
         | 
| 212 | 
            +
                    if (
         | 
| 213 | 
            +
                        "decoder_attention_mask" in model_kwargs
         | 
| 214 | 
            +
                        and model_kwargs["decoder_attention_mask"] is not None
         | 
| 215 | 
            +
                    ):
         | 
| 216 | 
            +
                        cosine_matrix_mask = model_kwargs["decoder_attention_mask"]
         | 
| 217 | 
            +
                else:
         | 
| 218 | 
            +
                    cosine_matrix_mask = model_kwargs["attention_mask"]
         | 
| 219 | 
            +
                cosine_matrix_mask = cosine_matrix_mask.repeat_interleave(top_k, dim=0)
         | 
| 220 | 
            +
             | 
| 221 | 
            +
                this_peer_finished = False
         | 
| 222 | 
            +
             | 
| 223 | 
            +
                while model._has_unfinished_sequences(
         | 
| 224 | 
            +
                    this_peer_finished, synced_gpus, device=input_ids.device
         | 
| 225 | 
            +
                ):
         | 
| 226 | 
            +
                    # if the first step in the loop, encode all the prefix and obtain: (1) past_key_values;
         | 
| 227 | 
            +
                    # (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step
         | 
| 228 | 
            +
                    if model_kwargs.get("past_key_values") is None or (
         | 
| 229 | 
            +
                        isinstance(model_kwargs["past_key_values"], (Cache, EncoderDecoderCache))
         | 
| 230 | 
            +
                        and model_kwargs["past_key_values"].get_seq_length() == 0
         | 
| 231 | 
            +
                    ):
         | 
| 232 | 
            +
                        # prepare inputs
         | 
| 233 | 
            +
                        model_kwargs["use_cache"] = True
         | 
| 234 | 
            +
                        model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
         | 
| 235 | 
            +
             | 
| 236 | 
            +
                        # encode the given prefix and prepare model inputs; encoder-decoder model process the prefix and save
         | 
| 237 | 
            +
                        # the `encoder_outputs`
         | 
| 238 | 
            +
                        outputs = model(
         | 
| 239 | 
            +
                            **model_inputs,
         | 
| 240 | 
            +
                            return_dict=True,
         | 
| 241 | 
            +
                            output_hidden_states=True,
         | 
| 242 | 
            +
                            output_attentions=output_attentions,
         | 
| 243 | 
            +
                        )
         | 
| 244 | 
            +
             | 
| 245 | 
            +
                        # last decoder hidden states will be used to compute the degeneration penalty (cosine similarity with
         | 
| 246 | 
            +
                        # previous tokens)
         | 
| 247 | 
            +
                        if model.config.is_encoder_decoder:
         | 
| 248 | 
            +
                            last_hidden_states = outputs.decoder_hidden_states[-1]
         | 
| 249 | 
            +
                        else:
         | 
| 250 | 
            +
                            last_hidden_states = outputs.hidden_states[-1]
         | 
| 251 | 
            +
             | 
| 252 | 
            +
                        # next logit for contrastive search to select top-k candidate tokens
         | 
| 253 | 
            +
                        # Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for this first iteration
         | 
| 254 | 
            +
                        # (the clone itmodel is always small)
         | 
| 255 | 
            +
                        # torch.float32 is needed to retain precision for later logits manipulations
         | 
| 256 | 
            +
                        logit_for_next_step = outputs.logits[:, -1, :].to(
         | 
| 257 | 
            +
                            copy=True, dtype=torch.float32, device=input_ids.device
         | 
| 258 | 
            +
                        )
         | 
| 259 | 
            +
             | 
| 260 | 
            +
                        model_kwargs = model._update_model_kwargs_for_generation(
         | 
| 261 | 
            +
                            outputs,
         | 
| 262 | 
            +
                            model_kwargs,
         | 
| 263 | 
            +
                            is_encoder_decoder=model.config.is_encoder_decoder,
         | 
| 264 | 
            +
                        )
         | 
| 265 | 
            +
             | 
| 266 | 
            +
                        if not sequential:
         | 
| 267 | 
            +
                            # Expands model inputs top_k times, for batched forward passes (akin to beam search).
         | 
| 268 | 
            +
                            # input_ids is required for expanding visual inputs in qwen2vl
         | 
| 269 | 
            +
                            _, model_kwargs = model._expand_inputs_for_generation(
         | 
| 270 | 
            +
                                input_ids=input_ids,
         | 
| 271 | 
            +
                                expand_size=top_k,
         | 
| 272 | 
            +
                                is_encoder_decoder=model.config.is_encoder_decoder,
         | 
| 273 | 
            +
                                **model_kwargs,
         | 
| 274 | 
            +
                            )
         | 
| 275 | 
            +
             | 
| 276 | 
            +
                        past_key_values = model_kwargs.get("past_key_values")
         | 
| 277 | 
            +
                        if past_key_values is None:
         | 
| 278 | 
            +
                            raise ValueError(
         | 
| 279 | 
            +
                                f"{model.__class__.__name__} does not support caching and therefore **can't** be used "
         | 
| 280 | 
            +
                                "for contrastive search."
         | 
| 281 | 
            +
                            )
         | 
| 282 | 
            +
                        elif (
         | 
| 283 | 
            +
                            not isinstance(past_key_values[0], (tuple, torch.Tensor))
         | 
| 284 | 
            +
                            or past_key_values[0][0].shape[0] != batch_size
         | 
| 285 | 
            +
                        ):
         | 
| 286 | 
            +
                            raise ValueError(
         | 
| 287 | 
            +
                                f"{model.__class__.__name__} does not have a standard cache format and therefore **can't** be "
         | 
| 288 | 
            +
                                "used for contrastive search without further modifications."
         | 
| 289 | 
            +
                            )
         | 
| 290 | 
            +
             | 
| 291 | 
            +
                    # contrastive_search main logic start:
         | 
| 292 | 
            +
                    # contrastive search decoding consists of two steps: (1) candidate tokens recall; (2) candidate re-rank by
         | 
| 293 | 
            +
                    # degeneration penalty
         | 
| 294 | 
            +
                    processed_logit_for_next_step = logits_processor(input_ids, logit_for_next_step)
         | 
| 295 | 
            +
                    next_probs = nn.functional.softmax(processed_logit_for_next_step, dim=-1)
         | 
| 296 | 
            +
             | 
| 297 | 
            +
                    top_k_probs, top_k_ids = torch.topk(next_probs, dim=-1, k=top_k)
         | 
| 298 | 
            +
             | 
        # Store scores, attentions and hidden_states when required
        if return_dict_in_generate:
            if output_logits:
                raw_logits += (logit_for_next_step,)
            if output_scores:
                scores += (processed_logit_for_next_step,)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,)
                    if model.config.is_encoder_decoder
                    else (outputs.attentions,)
                )
                if model.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)

            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,)
                    if model.config.is_encoder_decoder
                    else (outputs.hidden_states,)
                )

        # This is needed to properly delete outputs.logits which may be very large for this first iteration
        # Otherwise a reference to outputs.logits is kept all along until after the next call to model.forward()
        del outputs

        if not sequential:
            # Replicates the new past_key_values to match the `top_k` candidates
            past = model_kwargs["past_key_values"]
            # If it is a static cache, modify it in-place layer after layer to save memory
            if isinstance(past, DynamicCache) or (
                isinstance(past, EncoderDecoderCache)
                and isinstance(past.self_attention_cache, DynamicCache)
            ):
                past.batch_repeat_interleave(top_k)
            else:
                new_key_values = []
                for layer in past:
                    items = []
                    # item is either the key or the value matrix
                    for item in layer:
                        items.append(item.repeat_interleave(top_k, dim=0))
                    new_key_values.append(tuple(items))

                past = tuple(new_key_values)

            model_kwargs["past_key_values"] = past

        if sequential:
            all_outputs = []
            for i in range(top_k):
                # compute the candidate tokens by the language model and collect their hidden_states
                next_model_inputs = model.prepare_inputs_for_generation(
                    top_k_ids[:, i].view(-1, 1), **model_kwargs
                )

                outputs = model(
                    **next_model_inputs,
                    return_dict=True,
                    output_hidden_states=True,
                    output_attentions=output_attentions,
                )
                if isinstance(outputs["past_key_values"], DynamicCache) or (
                    isinstance(outputs["past_key_values"], EncoderDecoderCache)
                    and isinstance(
                        outputs["past_key_values"].self_attention_cache, DynamicCache
                    )
                ):
                    # Remove past K-V from output since we don't need to stack later
                    outputs["past_key_values"] = None
                    # Remove last token from past K-V since we don't want to append it at this point
                    model_kwargs["past_key_values"].crop(-1)

                all_outputs.append(outputs)
            outputs = stack_model_outputs(all_outputs, model.config.get_text_config())

        else:
            # compute the candidate tokens by the language model and collect their hidden_states
            # assembles top_k_ids into batch of size k
            next_model_inputs = model.prepare_inputs_for_generation(
                top_k_ids.view(-1, 1), **model_kwargs
            )

            outputs = model(
                **next_model_inputs,
                return_dict=True,
                output_hidden_states=True,
                output_attentions=output_attentions,
            )

        # This is essential to avoid having a last reference to the big past K-V and double the necessary memory
        # in the next loop
        del next_model_inputs

        # name is different for encoder-decoder and decoder-only models
        if model.config.is_encoder_decoder:
            next_hidden = outputs.decoder_hidden_states[-1]
            full_hidden_states = outputs.decoder_hidden_states
        else:
            next_hidden = outputs.hidden_states[-1]
            full_hidden_states = outputs.hidden_states

        # .float() is needed to retain precision for later logits manipulations
        logits = outputs.logits[:, -1, :].float()
        context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0)

        # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the
        # model confidence. Keeping `selected_idx` on CPU enables multi-device contrastive search and doesn't
        # introduce (noticeable) slowdowns on single-device runs.
        selected_idx = _ranking_fast(
            context_hidden,
            next_hidden,
            top_k_probs,
            cosine_matrix_mask,
            penalty_alpha,
            top_k,
        )
        cosine_matrix_mask = torch.cat(
            [
                cosine_matrix_mask,
                cosine_matrix_mask.new_ones((cosine_matrix_mask.shape[0], 1)),
            ],
            dim=-1,
        )
        selected_idx = selected_idx.to("cpu")

| 425 | 
            +
                    # This will be used instead of the previous inneficient torch.stack(torch.split())
         | 
| 426 | 
            +
                    augmented_idx = torch.tensor(
         | 
| 427 | 
            +
                        [x + i * top_k for i, x in enumerate(selected_idx)]
         | 
| 428 | 
            +
                    )
         | 
| 429 | 
            +
             | 
| 430 | 
            +
        # prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing
        # the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores
        # (model confidence minus degeneration penalty); (6) decoder hidden_states
        next_tokens = top_k_ids[range(len(top_k_ids)), selected_idx]
        next_hidden = torch.stack(torch.split(next_hidden.squeeze(dim=1), top_k))
        next_hidden = next_hidden[range(batch_size), selected_idx, :]
        last_hidden_states = torch.cat(
            [last_hidden_states, next_hidden.unsqueeze(1)], dim=1
        )

        next_decoder_hidden_states = ()
        for layer in full_hidden_states:
            layer = torch.stack(torch.split(layer, top_k))[
                range(batch_size), selected_idx, :
            ]
            next_decoder_hidden_states += (layer,)

        # generate past_key_values cache of only the selected token
        if sequential:
            next_model_input = model.prepare_inputs_for_generation(
                top_k_ids[:, selected_idx].view(-1, 1), **model_kwargs
            )

            selected_outputs = model(
                **next_model_input,
                return_dict=True,
                output_hidden_states=False,
                output_attentions=False,
            )
            next_past_key_values = selected_outputs["past_key_values"]

        else:
            next_past_key_values = None
            for possible_cache_name in ALL_CACHE_NAMES:
                next_past_key_values = next_past_key_values or getattr(
                    outputs, possible_cache_name, None
                )
            # Do it in-place layer per layer to save memory
            if isinstance(next_past_key_values, DynamicCache) or (
                isinstance(next_past_key_values, EncoderDecoderCache)
                and isinstance(next_past_key_values.model_attention_cache, DynamicCache)
            ):
                next_past_key_values.batch_select_indices(augmented_idx)
            else:
                new_key_values = []
                for layer in next_past_key_values:
                    items = []
                    # item is either the key or the value matrix
                    for item in layer:
                        items.append(item[augmented_idx, ...])
                    new_key_values.append(tuple(items))

                next_past_key_values = tuple(new_key_values)

        logit_for_next_step = torch.stack(torch.split(logits, top_k))[
            range(batch_size), selected_idx, :
        ]
        logit_for_next_step = logit_for_next_step.to(input_ids.device)

        # Rebuilds the relevant parts of the model output for the selected token, for use in the next iteration
        if model.config.is_encoder_decoder:
            next_step_cross_attentions = ()
            next_step_decoder_attentions = ()
            if output_attentions:
                for layer in outputs.cross_attentions:
                    layer = torch.stack(torch.split(layer, top_k, dim=0))[
                        range(batch_size), selected_idx, ...
                    ]
                    next_step_cross_attentions += (layer,)
                for layer in outputs.decoder_attentions:
                    layer = torch.stack(torch.split(layer, top_k, dim=0))[
                        range(batch_size), selected_idx, ...
                    ]
                    next_step_decoder_attentions += (layer,)
            outputs = Seq2SeqLMOutput(
                past_key_values=next_past_key_values,
                decoder_hidden_states=next_decoder_hidden_states,
                decoder_attentions=next_step_decoder_attentions or None,
                cross_attentions=next_step_cross_attentions or None,
            )
        else:
            next_step_attentions = ()
            if output_attentions:
                for layer in outputs.attentions:
                    layer = torch.stack(torch.split(layer, top_k, dim=0))[
                        range(batch_size), selected_idx, ...
                    ]
                    next_step_attentions += (layer,)
            outputs = CausalLMOutputWithPast(
                past_key_values=next_past_key_values,
                hidden_states=next_decoder_hidden_states,
                attentions=next_step_attentions or None,
            )
        # contrastive_search main logic end

        # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
        model_kwargs = model._update_model_kwargs_for_generation(
            outputs,
            model_kwargs,
            is_encoder_decoder=model.config.is_encoder_decoder,
        )
        if synced_gpus and this_peer_finished:
            continue

        # finished sentences should have their next token be a padding token
        if has_eos_stopping_criteria:
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
                1 - unfinished_sequences
            )

        # update generated ids, model inputs, and length for next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        if streamer is not None:
            streamer.put(next_tokens.cpu())

        # stop when each sentence is finished
        unfinished_sequences = unfinished_sequences & ~stopping_criteria(
            input_ids, scores
        )
        this_peer_finished = unfinished_sequences.max() == 0

    if streamer is not None:
        streamer.end()

    if return_dict_in_generate:
        # Contrastive search works by forward looking at the next token, so we need to exclude it from
        # `past_key_values` to be consistent with the other decoding methods
        if model_kwargs.get("past_key_values") is not None:
            if isinstance(model_kwargs["past_key_values"], DynamicCache) or (
                isinstance(model_kwargs["past_key_values"], EncoderDecoderCache)
                and isinstance(
                    model_kwargs["past_key_values"].model_attention_cache, DynamicCache
                )
            ):
                model_kwargs["past_key_values"].crop(-1)
            else:
                past_key_values = []
                for layer in model_kwargs["past_key_values"]:
                    layer_past_key_values = []
                    for item in layer:
                        layer_past_key_values.append(item[..., :-1, :])
                    past_key_values.append(tuple(layer_past_key_values))
                model_kwargs["past_key_values"] = tuple(past_key_values)

        if model.config.is_encoder_decoder:
            return GenerateEncoderDecoderOutput(
                sequences=input_ids,
                scores=scores,
                logits=raw_logits,
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
        else:
            return GenerateDecoderOnlyOutput(
                sequences=input_ids,
                scores=scores,
                logits=raw_logits,
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
    else:
        return input_ids


def generate(model, *args, **kwargs):
    """Custom generate function for Contrastive Search decoding.

    Args:
        model (`PreTrainedModel`):
            The model to generate from.
        penalty_alpha (`float`): The alpha value for the degeneration penalty.
        top_k (`int`): The number of candidates to consider at each step.
    """
    generation_outputs = model.generate(*args, custom_generate=_contrastive_search, **kwargs)
    return generation_outputs
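For reference, a minimal usage sketch of the wrapper above. This is not taken from the repository's README: the repository/model identifier is a placeholder, while `custom_generate`, `penalty_alpha`, and `top_k` come from the function and docstring shown in the diff, and `trust_remote_code=True` is what allows `transformers` to load Hub-hosted custom generation code.

```python
# Hypothetical usage sketch; the repo id below is a placeholder, not a real identifier.
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "<this-repo>"  # placeholder for the repository that ships custom_generate/generate.py
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id)

inputs = tokenizer("DeepMind Company is", return_tensors="pt")
outputs = model.generate(
    **inputs,
    custom_generate=repo_id,   # route generation through this repo's contrastive-search loop
    trust_remote_code=True,
    penalty_alpha=0.6,         # weight of the degeneration penalty (alpha)
    top_k=4,                   # number of candidate tokens re-scored at each step
    max_new_tokens=64,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```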
    	
generation_config.json
ADDED
@@ -0,0 +1,13 @@
{
    "bos_token_id": 151643,
    "do_sample": true,
    "eos_token_id": [
        151645,
        151643
    ],
    "pad_token_id": 151643,
    "temperature": 0.6,
    "top_k": 20,
    "top_p": 0.95,
    "transformers_version": "4.56.0"
}
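As a side note, a small sketch (assuming a standard `transformers` setup; the repository id is a placeholder) of how these stored defaults surface at runtime. Anything passed directly to `generate()` overrides the file-level values.

```python
from transformers import GenerationConfig

gen_config = GenerationConfig.from_pretrained("<this-repo>")  # placeholder id
print(gen_config.do_sample, gen_config.temperature, gen_config.top_k)  # True 0.6 20
# Per-call kwargs take precedence over these defaults, e.g. top_k=4 here replaces the stored top_k=20:
# model.generate(**inputs, penalty_alpha=0.6, top_k=4)
```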
    	
merges.txt
ADDED
The diff for this file is too large to render. See raw diff.
    	
model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f47f71177f32bcd101b7573ec9171e6a57f4f4d31148d38e382306f42996874b
size 1503300328
    	
tokenizer.json
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:aeb13307a71acd8fe81861d94ad54ab689df773318809eed3cbe794b4492dae4
size 11422654
    	
tokenizer_config.json
ADDED
@@ -0,0 +1,239 @@
{
  "add_bos_token": false,
  "add_prefix_space": false,
  "added_tokens_decoder": {
    "151643": {
      "content": "<|endoftext|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "151644": {
      "content": "<|im_start|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "151645": {
      "content": "<|im_end|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "151646": {
      "content": "<|object_ref_start|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "151647": {
      "content": "<|object_ref_end|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "151648": {
      "content": "<|box_start|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "151649": {
      "content": "<|box_end|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "151650": {
      "content": "<|quad_start|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "151651": {
      "content": "<|quad_end|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "151652": {
      "content": "<|vision_start|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "151653": {
      "content": "<|vision_end|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "151654": {
      "content": "<|vision_pad|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "151655": {
      "content": "<|image_pad|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "151656": {
      "content": "<|video_pad|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "151657": {
      "content": "<tool_call>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "151658": {
      "content": "</tool_call>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "151659": {
      "content": "<|fim_prefix|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "151660": {
      "content": "<|fim_middle|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "151661": {
      "content": "<|fim_suffix|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "151662": {
      "content": "<|fim_pad|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "151663": {
      "content": "<|repo_name|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "151664": {
      "content": "<|file_sep|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "151665": {
      "content": "<tool_response>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "151666": {
      "content": "</tool_response>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "151667": {
      "content": "<think>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "151668": {
      "content": "</think>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": false
    }
  },
  "additional_special_tokens": [
    "<|im_start|>",
    "<|im_end|>",
    "<|object_ref_start|>",
    "<|object_ref_end|>",
    "<|box_start|>",
    "<|box_end|>",
    "<|quad_start|>",
    "<|quad_end|>",
    "<|vision_start|>",
    "<|vision_end|>",
    "<|vision_pad|>",
    "<|image_pad|>",
    "<|video_pad|>"
  ],
  "bos_token": null,
  "chat_template": "{%- if tools %}\n    {{- '<|im_start|>system\\n' }}\n    {%- if messages[0].role == 'system' %}\n        {{- messages[0].content + '\\n\\n' }}\n    {%- endif %}\n    {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n    {%- for tool in tools %}\n        {{- \"\\n\" }}\n        {{- tool | tojson }}\n    {%- endfor %}\n    {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n    {%- if messages[0].role == 'system' %}\n        {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n    {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n    {%- set index = (messages|length - 1) - loop.index0 %}\n    {%- if ns.multi_step_tool and message.role == \"user\" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n        {%- set ns.multi_step_tool = false %}\n        {%- set ns.last_query_index = index %}\n    {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n    {%- if message.content is string %}\n        {%- set content = message.content %}\n    {%- else %}\n        {%- set content = '' %}\n    {%- endif %}\n    {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n        {{- '<|im_start|>' + message.role + '\\n' + content + '<|im_end|>' + '\\n' }}\n    {%- elif message.role == \"assistant\" %}\n        {%- set reasoning_content = '' %}\n        {%- if message.reasoning_content is string %}\n            {%- set reasoning_content = message.reasoning_content %}\n        {%- else %}\n            {%- if '</think>' in content %}\n                {%- set reasoning_content = content.split('</think>')[0].rstrip('\\n').split('<think>')[-1].lstrip('\\n') %}\n                {%- set content = content.split('</think>')[-1].lstrip('\\n') %}\n            {%- endif %}\n        {%- endif %}\n        {%- if loop.index0 > ns.last_query_index %}\n            {%- if loop.last or (not loop.last and reasoning_content) %}\n                {{- '<|im_start|>' + message.role + '\\n<think>\\n' + reasoning_content.strip('\\n') + '\\n</think>\\n\\n' + content.lstrip('\\n') }}\n            {%- else %}\n                {{- '<|im_start|>' + message.role + '\\n' + content }}\n            {%- endif %}\n        {%- else %}\n            {{- '<|im_start|>' + message.role + '\\n' + content }}\n        {%- endif %}\n        {%- if message.tool_calls %}\n            {%- for tool_call in message.tool_calls %}\n                {%- if (loop.first and content) or (not loop.first) %}\n                    {{- '\\n' }}\n                {%- endif %}\n                {%- if tool_call.function %}\n                    {%- set tool_call = tool_call.function %}\n                {%- endif %}\n                {{- '<tool_call>\\n{\"name\": \"' }}\n                {{- tool_call.name }}\n                {{- '\", \"arguments\": ' }}\n                {%- if tool_call.arguments is string %}\n                    {{- tool_call.arguments }}\n                {%- else %}\n                    {{- tool_call.arguments | tojson }}\n                {%- endif %}\n                {{- '}\\n</tool_call>' }}\n            {%- endfor %}\n        {%- endif %}\n        {{- '<|im_end|>\\n' }}\n    {%- elif message.role == \"tool\" %}\n        {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n            {{- '<|im_start|>user' }}\n        {%- endif %}\n        {{- '\\n<tool_response>\\n' }}\n        {{- content }}\n        {{- '\\n</tool_response>' }}\n        {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n            {{- '<|im_end|>\\n' }}\n        {%- endif %}\n    {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n    {{- '<|im_start|>assistant\\n' }}\n    {%- if enable_thinking is defined and enable_thinking is false %}\n        {{- '<think>\\n\\n</think>\\n\\n' }}\n    {%- endif %}\n{%- endif %}",
  "clean_up_tokenization_spaces": false,
  "eos_token": "<|im_end|>",
  "errors": "replace",
  "model_max_length": 131072,
  "pad_token": "<|endoftext|>",
  "split_special_tokens": false,
  "tokenizer_class": "Qwen2Tokenizer",
  "unk_token": null
}
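To show how the chat template above is consumed, a brief sketch assuming the standard `apply_chat_template` API; the repository id is a placeholder. Passing `enable_thinking=False` exercises the branch of the template that pre-fills an empty `<think>` block after the generation prompt.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("<this-repo>")  # placeholder id
messages = [{"role": "user", "content": "Give me a short introduction to contrastive search."}]
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,  # appends "<|im_start|>assistant\n"
    enable_thinking=False,       # template then emits "<think>\n\n</think>\n\n"
)
print(prompt)
```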
    	
vocab.json
ADDED
The diff for this file is too large to render. See raw diff.

