Commit 
							
							·
						
						c5a6a24
	
1
								Parent(s):
							
							7d8f12e
								
init spaces
Browse files- __pycache__/run.cpython-310.pyc +0 -0
- app.py +44 -0
- requirements.txt +5 -0
- run.py +210 -0
- saved_model/last-iter-015000-ckpt.pth +3 -0
- tokenizer_Llama-2-7b-chat-hf/generation_config.json +10 -0
- tokenizer_Llama-2-7b-chat-hf/lit_config.json +1 -0
- tokenizer_Llama-2-7b-chat-hf/pytorch_model.bin.index.json +330 -0
- tokenizer_Llama-2-7b-chat-hf/tokenizer.json +0 -0
- tokenizer_Llama-2-7b-chat-hf/tokenizer.model +3 -0
- tokenizer_Llama-2-7b-chat-hf/tokenizer_config.json +36 -0
- tsai_gpt/__init__.py +15 -0
- tsai_gpt/__pycache__/__init__.cpython-310.pyc +0 -0
- tsai_gpt/__pycache__/__init__.cpython-39.pyc +0 -0
- tsai_gpt/__pycache__/config.cpython-310.pyc +0 -0
- tsai_gpt/__pycache__/config.cpython-39.pyc +0 -0
- tsai_gpt/__pycache__/model.cpython-310.pyc +0 -0
- tsai_gpt/__pycache__/model.cpython-39.pyc +0 -0
- tsai_gpt/__pycache__/packed_dataset.cpython-39.pyc +0 -0
- tsai_gpt/__pycache__/speed_monitor.cpython-39.pyc +0 -0
- tsai_gpt/__pycache__/tokenizer.cpython-310.pyc +0 -0
- tsai_gpt/__pycache__/tokenizer.cpython-39.pyc +0 -0
- tsai_gpt/__pycache__/utils.cpython-310.pyc +0 -0
- tsai_gpt/__pycache__/utils.cpython-39.pyc +0 -0
- tsai_gpt/config.py +1192 -0
- tsai_gpt/model.py +367 -0
- tsai_gpt/packed_dataset.py +254 -0
- tsai_gpt/rmsnorm.py +26 -0
- tsai_gpt/speed_monitor.py +438 -0
- tsai_gpt/tokenizer.py +106 -0
- tsai_gpt/utils.py +367 -0
    	
        __pycache__/run.cpython-310.pyc
    ADDED
    
    | Binary file (6.19 kB). View file | 
|  | 
    	
        app.py
    ADDED
    
    | @@ -0,0 +1,44 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import gradio as gr
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import run
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            title = "Lit GPT: Pythia 160M "
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            with gr.Blocks(title=title) as interface:
         | 
| 8 | 
            +
                with gr.Row():
         | 
| 9 | 
            +
                    prompt = gr.Textbox(label="Input Text")
         | 
| 10 | 
            +
             | 
| 11 | 
            +
                    temperature = gr.Slider(
         | 
| 12 | 
            +
                        0,
         | 
| 13 | 
            +
                        1,
         | 
| 14 | 
            +
                        value=0.8,
         | 
| 15 | 
            +
                        label="Temperature",
         | 
| 16 | 
            +
                        info="Set the creativity level: Higher values produce more varied results, lower values generate more predictable text.",
         | 
| 17 | 
            +
                    )
         | 
| 18 | 
            +
                    top_k = gr.Slider(
         | 
| 19 | 
            +
                        200,
         | 
| 20 | 
            +
                        300,
         | 
| 21 | 
            +
                        value=200,
         | 
| 22 | 
            +
                        label="Top K",
         | 
| 23 | 
            +
                        info="Control the randomness: Limits the AI to consider only the top K most likely next words.",
         | 
| 24 | 
            +
                    )
         | 
| 25 | 
            +
                    max_new_tokens = gr.Slider(
         | 
| 26 | 
            +
                        10,
         | 
| 27 | 
            +
                        500,
         | 
| 28 | 
            +
                        value=500,
         | 
| 29 | 
            +
                        label="Max Tokens",
         | 
| 30 | 
            +
                        info="top most preferable tokens to consider in the sampling process",
         | 
| 31 | 
            +
                    )
         | 
| 32 | 
            +
             | 
| 33 | 
            +
                    inputs = [prompt, max_new_tokens, top_k, temperature]
         | 
| 34 | 
            +
             | 
| 35 | 
            +
                with gr.Column():
         | 
| 36 | 
            +
                    outputs = gr.Textbox(label="Generated")
         | 
| 37 | 
            +
                    button = gr.Button("Generate")
         | 
| 38 | 
            +
                    button.click(run.generate_from_prompt, inputs=inputs, outputs=outputs)
         | 
| 39 | 
            +
             | 
| 40 | 
            +
                # with gr.Row():
         | 
| 41 | 
            +
                #     gr.Examples(examples=examples, inputs=inputs, outputs=outputs, fn=generate_dialogue, cache_examples=True,)
         | 
| 42 | 
            +
             | 
| 43 | 
            +
             | 
| 44 | 
            +
            interface.launch()
         | 
    	
        requirements.txt
    ADDED
    
    | @@ -0,0 +1,5 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            torch>=2.1.0
         | 
| 2 | 
            +
            lightning @ git+https://github.com/Lightning-AI/lightning@6cbe9ceb560d798892bdae9186291acf9bf5d2e3 
         | 
| 3 | 
            +
            tokenizers           
         | 
| 4 | 
            +
            sentencepiece
         | 
| 5 | 
            +
            bitsandbytes==0.41.0 
         | 
    	
        run.py
    ADDED
    
    | @@ -0,0 +1,210 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import sys
         | 
| 2 | 
            +
            import time
         | 
| 3 | 
            +
            from pathlib import Path
         | 
| 4 | 
            +
            from typing import Any, Literal, Optional
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import lightning as L
         | 
| 7 | 
            +
            import torch
         | 
| 8 | 
            +
            from lightning.fabric.plugins import BitsandbytesPrecision
         | 
| 9 | 
            +
            from lightning.fabric.strategies import FSDPStrategy
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            from tsai_gpt.model import GPT, Block, Config
         | 
| 12 | 
            +
            from tsai_gpt.tokenizer import Tokenizer
         | 
| 13 | 
            +
            from tsai_gpt.utils import (get_default_supported_precision, gptq_quantization,
         | 
| 14 | 
            +
                                        load_checkpoint)
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            L.seed_everything(1234)
         | 
| 17 | 
            +
             | 
| 18 | 
            +
             | 
| 19 | 
            +
            def multinomial_num_samples_1(probs: torch.Tensor) -> torch.Tensor:
         | 
| 20 | 
            +
                if torch._dynamo.is_compiling():
         | 
| 21 | 
            +
                    # Faster alternative to `torch.multinomial(probs, num_samples=1)` that is also CUDAGraph friendly
         | 
| 22 | 
            +
                    distribution = torch.empty_like(probs).exponential_(1)
         | 
| 23 | 
            +
                    return torch.argmax(probs / distribution, dim=-1, keepdim=True)
         | 
| 24 | 
            +
                return torch.multinomial(probs, num_samples=1)
         | 
| 25 | 
            +
             | 
| 26 | 
            +
             | 
| 27 | 
            +
            def sample(
         | 
| 28 | 
            +
                logits: torch.Tensor, temperature: float = 1.0, top_k: Optional[int] = None
         | 
| 29 | 
            +
            ) -> torch.Tensor:
         | 
| 30 | 
            +
                logits = logits[0, -1]
         | 
| 31 | 
            +
                # optionally crop the logits to only the top k options
         | 
| 32 | 
            +
                if top_k is not None:
         | 
| 33 | 
            +
                    v, i = torch.topk(logits, min(top_k, logits.size(-1)))
         | 
| 34 | 
            +
                    # do not use `torch.where` as in nanogpt because it will repeat top-k collisions
         | 
| 35 | 
            +
                    logits = torch.full_like(logits, float("-inf")).scatter_(-1, i, v)
         | 
| 36 | 
            +
                # optionally scale the logits and sample from a probability distribution
         | 
| 37 | 
            +
                if temperature > 0.0:
         | 
| 38 | 
            +
                    probs = torch.nn.functional.softmax(logits / temperature, dim=-1)
         | 
| 39 | 
            +
                    return multinomial_num_samples_1(probs)
         | 
| 40 | 
            +
                return torch.argmax(logits, dim=-1, keepdim=True)
         | 
| 41 | 
            +
             | 
| 42 | 
            +
             | 
| 43 | 
            +
            def next_token(
         | 
| 44 | 
            +
                model: GPT, input_pos: torch.Tensor, x: torch.Tensor, **kwargs: Any
         | 
| 45 | 
            +
            ) -> torch.Tensor:
         | 
| 46 | 
            +
                logits = model(x, input_pos)
         | 
| 47 | 
            +
                next = sample(logits, **kwargs)
         | 
| 48 | 
            +
                return next.type_as(x)
         | 
| 49 | 
            +
             | 
| 50 | 
            +
             | 
| 51 | 
            +
            @torch.inference_mode()
         | 
| 52 | 
            +
            def generate(
         | 
| 53 | 
            +
                model: GPT,
         | 
| 54 | 
            +
                prompt: torch.Tensor,
         | 
| 55 | 
            +
                max_returned_tokens: int,
         | 
| 56 | 
            +
                *,
         | 
| 57 | 
            +
                temperature: float = 1.0,
         | 
| 58 | 
            +
                top_k: Optional[int] = None,
         | 
| 59 | 
            +
                eos_id: Optional[int] = None,
         | 
| 60 | 
            +
            ) -> torch.Tensor:
         | 
| 61 | 
            +
                """Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                The implementation of this function is modified from A. Karpathy's nanoGPT.
         | 
| 64 | 
            +
             | 
| 65 | 
            +
                Args:
         | 
| 66 | 
            +
                    model: The model to use.
         | 
| 67 | 
            +
                    prompt: Tensor of shape (T) with indices of the prompt sequence.
         | 
| 68 | 
            +
                    max_returned_tokens: The maximum number of tokens to return (given plus generated).
         | 
| 69 | 
            +
                    temperature: Scales the predicted logits by 1 / temperature.
         | 
| 70 | 
            +
                    top_k: If specified, only sample among the tokens with the k highest probabilities.
         | 
| 71 | 
            +
                    eos_id: If specified, stop generating any more token once the <eos> token is triggered.
         | 
| 72 | 
            +
                """
         | 
| 73 | 
            +
                T = prompt.size(0)
         | 
| 74 | 
            +
                assert max_returned_tokens > T
         | 
| 75 | 
            +
                if model.max_seq_length < max_returned_tokens - 1:
         | 
| 76 | 
            +
                    # rolling the kv cache based on the `input_pos` value would be necessary. However, doing so would introduce a
         | 
| 77 | 
            +
                    # data dependency on the `input_pos` tensor and impact model compilation. Since this setting is uncommon, we do
         | 
| 78 | 
            +
                    # not support it to avoid negatively impacting the overall speed
         | 
| 79 | 
            +
                    raise NotImplementedError(
         | 
| 80 | 
            +
                        f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}"
         | 
| 81 | 
            +
                    )
         | 
| 82 | 
            +
             | 
| 83 | 
            +
                device = prompt.device
         | 
| 84 | 
            +
                tokens = [prompt]
         | 
| 85 | 
            +
                input_pos = torch.tensor([T], device=device)
         | 
| 86 | 
            +
                token = next_token(
         | 
| 87 | 
            +
                    model,
         | 
| 88 | 
            +
                    torch.arange(0, T, device=device),
         | 
| 89 | 
            +
                    prompt.view(1, -1),
         | 
| 90 | 
            +
                    temperature=temperature,
         | 
| 91 | 
            +
                    top_k=top_k,
         | 
| 92 | 
            +
                ).clone()
         | 
| 93 | 
            +
                tokens.append(token)
         | 
| 94 | 
            +
                for _ in range(2, max_returned_tokens - T + 1):
         | 
| 95 | 
            +
                    token = next_token(
         | 
| 96 | 
            +
                        model, input_pos, token.view(1, -1), temperature=temperature, top_k=top_k
         | 
| 97 | 
            +
                    ).clone()
         | 
| 98 | 
            +
                    tokens.append(token)
         | 
| 99 | 
            +
                    if token == eos_id:
         | 
| 100 | 
            +
                        break
         | 
| 101 | 
            +
                    input_pos = input_pos.add_(1)
         | 
| 102 | 
            +
                return torch.cat(tokens)
         | 
| 103 | 
            +
             | 
| 104 | 
            +
             | 
| 105 | 
            +
            """
         | 
| 106 | 
            +
            quantize (Optional[Literal["bnb.nf4", "bnb.nf4, optional): quantization method to use. Defaults to None.
         | 
| 107 | 
            +
                - "bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq": 4-bit quantization bitsandbytes
         | 
| 108 | 
            +
                - "bnb.int8": 8-bit quantization bitsandbytes
         | 
| 109 | 
            +
                - "gptq.int4": 4-bit quantization GPTQ
         | 
| 110 | 
            +
                for more details see: https://github.com/facebookresearch/bitsandbytes, https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials/quantize.md
         | 
| 111 | 
            +
            strategy (str, optional): Fabric strategy setting. Defaults to "auto".
         | 
| 112 | 
            +
            devices (int, optional): number of devices to be used. Defaults to 1.
         | 
| 113 | 
            +
            precision (Optional[str], optional): fabic precision settings. Defaults to None.
         | 
| 114 | 
            +
            """
         | 
| 115 | 
            +
             | 
| 116 | 
            +
            chptk_path: str = "saved_model/last-iter-015000-ckpt.pth"
         | 
| 117 | 
            +
            tokenizer_path: str = "tokenizer_Llama-2-7b-chat-hf"
         | 
| 118 | 
            +
            quantize: Optional[
         | 
| 119 | 
            +
                Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8", "gptq.int4"]
         | 
| 120 | 
            +
            ] = None
         | 
| 121 | 
            +
            strategy: str = "auto"
         | 
| 122 | 
            +
            devices: int = 1
         | 
| 123 | 
            +
            precision: Optional[str] = None
         | 
| 124 | 
            +
             | 
| 125 | 
            +
            precision = precision or get_default_supported_precision(training=False)
         | 
| 126 | 
            +
            plugins = None
         | 
| 127 | 
            +
            if quantize is not None:
         | 
| 128 | 
            +
                if devices > 1:
         | 
| 129 | 
            +
                    raise NotImplemented("Multi-GPU quantization is not supported yet.")
         | 
| 130 | 
            +
                if quantize.startswith("bnb."):
         | 
| 131 | 
            +
                    if "mixed" in precision:
         | 
| 132 | 
            +
                        raise ValueError("Quantization and mixed precision is not supported.")
         | 
| 133 | 
            +
                    dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[
         | 
| 134 | 
            +
                        precision
         | 
| 135 | 
            +
                    ]
         | 
| 136 | 
            +
                    plugins = BitsandbytesPrecision(quantize[4:], dtype)
         | 
| 137 | 
            +
                    precision = None
         | 
| 138 | 
            +
             | 
| 139 | 
            +
            if strategy == "fsdp":
         | 
| 140 | 
            +
                strategy = FSDPStrategy(auto_wrap_policy={Block}, cpu_offload=False)
         | 
| 141 | 
            +
             | 
| 142 | 
            +
            fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, plugins=plugins)
         | 
| 143 | 
            +
            fabric.launch()
         | 
| 144 | 
            +
             | 
| 145 | 
            +
            tokenizer = Tokenizer(Path("tokenizer_Llama-2-7b-chat-hf"))
         | 
| 146 | 
            +
            config = Config.from_name("pythia-160m")
         | 
| 147 | 
            +
             | 
| 148 | 
            +
            fabric.print(f"Loading model from {chptk_path}", file=sys.stderr)
         | 
| 149 | 
            +
            t0 = time.perf_counter()
         | 
| 150 | 
            +
            with fabric.init_module(empty_init=True), gptq_quantization(quantize == "gptq.int4"):
         | 
| 151 | 
            +
                model = GPT(config)
         | 
| 152 | 
            +
            fabric.print(
         | 
| 153 | 
            +
                f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr
         | 
| 154 | 
            +
            )
         | 
| 155 | 
            +
            with fabric.init_tensor():
         | 
| 156 | 
            +
                # enable the kv cache
         | 
| 157 | 
            +
                model.set_kv_cache(batch_size=1)
         | 
| 158 | 
            +
             | 
| 159 | 
            +
            model.eval()
         | 
| 160 | 
            +
            model = fabric.setup_module(model)
         | 
| 161 | 
            +
             | 
| 162 | 
            +
            t0 = time.perf_counter()
         | 
| 163 | 
            +
            load_checkpoint(fabric, model, chptk_path)
         | 
| 164 | 
            +
            fabric.print(
         | 
| 165 | 
            +
                f"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr
         | 
| 166 | 
            +
            )
         | 
| 167 | 
            +
             | 
| 168 | 
            +
             | 
| 169 | 
            +
            def generate_from_prompt(
         | 
| 170 | 
            +
                prompt: str = "",
         | 
| 171 | 
            +
                max_new_tokens: int = 500,
         | 
| 172 | 
            +
                top_k: int = 200,
         | 
| 173 | 
            +
                temperature: float = 0.8,
         | 
| 174 | 
            +
            ):
         | 
| 175 | 
            +
                """Generate text from a prompt using pre-trained model
         | 
| 176 | 
            +
             | 
| 177 | 
            +
                Args:
         | 
| 178 | 
            +
                    prompt (str, optional): Prompt string to be used for generating samples. Defaults to "".
         | 
| 179 | 
            +
                    num_samples (int, optional): Number of samples to be generated. Defaults to 1.
         | 
| 180 | 
            +
                    max_new_tokens (int, optional): number of generation steps to take. Defaults to 500.
         | 
| 181 | 
            +
                    top_k (int, optional): top most preferable tokens to consider in the sampling process. Defaults to 200.
         | 
| 182 | 
            +
                    temperature (float, optional): Control randomness for sampelling process. Defaults to 0.8.
         | 
| 183 | 
            +
                """
         | 
| 184 | 
            +
                encoded = tokenizer.encode(prompt, device=fabric.device)
         | 
| 185 | 
            +
                prompt_length = encoded.size(0)
         | 
| 186 | 
            +
                max_returned_tokens = prompt_length + max_new_tokens
         | 
| 187 | 
            +
                with fabric.init_tensor():
         | 
| 188 | 
            +
                    # set the max_seq_length to limit the memory usage to what we need
         | 
| 189 | 
            +
                    model.max_seq_length = max_returned_tokens
         | 
| 190 | 
            +
             | 
| 191 | 
            +
                num_samples: int = 1
         | 
| 192 | 
            +
                for i in range(num_samples):
         | 
| 193 | 
            +
                    t0 = time.perf_counter()
         | 
| 194 | 
            +
                    y = generate(model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k)
         | 
| 195 | 
            +
                    t = time.perf_counter() - t0
         | 
| 196 | 
            +
                    # for block in model.transformer.h:
         | 
| 197 | 
            +
                    #     block.attn.kv_cache.reset_parameters()
         | 
| 198 | 
            +
                    pred = tokenizer.decode(y)
         | 
| 199 | 
            +
                    fabric.print(pred)
         | 
| 200 | 
            +
                    tokens_generated = y.size(0) - prompt_length
         | 
| 201 | 
            +
                    fabric.print(
         | 
| 202 | 
            +
                        f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec",
         | 
| 203 | 
            +
                        file=sys.stderr,
         | 
| 204 | 
            +
                    )
         | 
| 205 | 
            +
                if fabric.device.type == "cuda":
         | 
| 206 | 
            +
                    fabric.print(
         | 
| 207 | 
            +
                        f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", file=sys.stderr
         | 
| 208 | 
            +
                    )
         | 
| 209 | 
            +
             | 
| 210 | 
            +
                return pred
         | 
    	
        saved_model/last-iter-015000-ckpt.pth
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            +
            oid sha256:a3a07dec62fbdcda721f26f4581e736995ee0c195f46ce2f9587a20e969a996c
         | 
| 3 | 
            +
            size 1948052114
         | 
    	
        tokenizer_Llama-2-7b-chat-hf/generation_config.json
    ADDED
    
    | @@ -0,0 +1,10 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            {
         | 
| 2 | 
            +
              "bos_token_id": 1,
         | 
| 3 | 
            +
              "do_sample": true,
         | 
| 4 | 
            +
              "eos_token_id": 2,
         | 
| 5 | 
            +
              "max_length": 4096,
         | 
| 6 | 
            +
              "pad_token_id": 0,
         | 
| 7 | 
            +
              "temperature": 0.6,
         | 
| 8 | 
            +
              "top_p": 0.9,
         | 
| 9 | 
            +
              "transformers_version": "4.32.0.dev0"
         | 
| 10 | 
            +
            }
         | 
    	
        tokenizer_Llama-2-7b-chat-hf/lit_config.json
    ADDED
    
    | @@ -0,0 +1 @@ | |
|  | 
|  | |
| 1 | 
            +
            {"name": "Llama-2-7b-chat-hf", "hf_config": {"org": "meta-llama", "name": "Llama-2-7b-chat-hf"}, "block_size": 4096, "vocab_size": 32000, "padding_multiple": 64, "padded_vocab_size": 32000, "n_layer": 32, "n_head": 32, "n_embd": 4096, "rotary_percentage": 1.0, "parallel_residual": false, "bias": false, "lm_head_bias": false, "n_query_groups": 32, "shared_attention_norm": false, "_norm_class": "RMSNorm", "norm_eps": 1e-05, "_mlp_class": "LLaMAMLP", "gelu_approximate": "none", "intermediate_size": 11008, "rope_condense_ratio": 1, "rope_base": 10000}
         | 
    	
        tokenizer_Llama-2-7b-chat-hf/pytorch_model.bin.index.json
    ADDED
    
    | @@ -0,0 +1,330 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            {
         | 
| 2 | 
            +
              "metadata": {
         | 
| 3 | 
            +
                "total_size": 13476839424
         | 
| 4 | 
            +
              },
         | 
| 5 | 
            +
              "weight_map": {
         | 
| 6 | 
            +
                "lm_head.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 7 | 
            +
                "model.embed_tokens.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 8 | 
            +
                "model.layers.0.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 9 | 
            +
                "model.layers.0.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 10 | 
            +
                "model.layers.0.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 11 | 
            +
                "model.layers.0.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 12 | 
            +
                "model.layers.0.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 13 | 
            +
                "model.layers.0.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 14 | 
            +
                "model.layers.0.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 15 | 
            +
                "model.layers.0.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 16 | 
            +
                "model.layers.0.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
         | 
| 17 | 
            +
                "model.layers.0.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 18 | 
            +
                "model.layers.1.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 19 | 
            +
                "model.layers.1.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 20 | 
            +
                "model.layers.1.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 21 | 
            +
                "model.layers.1.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 22 | 
            +
                "model.layers.1.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 23 | 
            +
                "model.layers.1.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 24 | 
            +
                "model.layers.1.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 25 | 
            +
                "model.layers.1.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 26 | 
            +
                "model.layers.1.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
         | 
| 27 | 
            +
                "model.layers.1.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 28 | 
            +
                "model.layers.10.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 29 | 
            +
                "model.layers.10.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 30 | 
            +
                "model.layers.10.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 31 | 
            +
                "model.layers.10.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 32 | 
            +
                "model.layers.10.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 33 | 
            +
                "model.layers.10.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 34 | 
            +
                "model.layers.10.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 35 | 
            +
                "model.layers.10.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 36 | 
            +
                "model.layers.10.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
         | 
| 37 | 
            +
                "model.layers.10.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 38 | 
            +
                "model.layers.11.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 39 | 
            +
                "model.layers.11.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 40 | 
            +
                "model.layers.11.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 41 | 
            +
                "model.layers.11.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 42 | 
            +
                "model.layers.11.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 43 | 
            +
                "model.layers.11.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 44 | 
            +
                "model.layers.11.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 45 | 
            +
                "model.layers.11.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 46 | 
            +
                "model.layers.11.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
         | 
| 47 | 
            +
                "model.layers.11.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 48 | 
            +
                "model.layers.12.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 49 | 
            +
                "model.layers.12.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 50 | 
            +
                "model.layers.12.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 51 | 
            +
                "model.layers.12.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 52 | 
            +
                "model.layers.12.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 53 | 
            +
                "model.layers.12.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 54 | 
            +
                "model.layers.12.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 55 | 
            +
                "model.layers.12.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 56 | 
            +
                "model.layers.12.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
         | 
| 57 | 
            +
                "model.layers.12.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 58 | 
            +
                "model.layers.13.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 59 | 
            +
                "model.layers.13.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 60 | 
            +
                "model.layers.13.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 61 | 
            +
                "model.layers.13.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 62 | 
            +
                "model.layers.13.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 63 | 
            +
                "model.layers.13.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 64 | 
            +
                "model.layers.13.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 65 | 
            +
                "model.layers.13.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 66 | 
            +
                "model.layers.13.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
         | 
| 67 | 
            +
                "model.layers.13.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 68 | 
            +
                "model.layers.14.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 69 | 
            +
                "model.layers.14.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 70 | 
            +
                "model.layers.14.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 71 | 
            +
                "model.layers.14.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 72 | 
            +
                "model.layers.14.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 73 | 
            +
                "model.layers.14.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 74 | 
            +
                "model.layers.14.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 75 | 
            +
                "model.layers.14.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 76 | 
            +
                "model.layers.14.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
         | 
| 77 | 
            +
                "model.layers.14.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 78 | 
            +
                "model.layers.15.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 79 | 
            +
                "model.layers.15.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 80 | 
            +
                "model.layers.15.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 81 | 
            +
                "model.layers.15.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 82 | 
            +
                "model.layers.15.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 83 | 
            +
                "model.layers.15.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 84 | 
            +
                "model.layers.15.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 85 | 
            +
                "model.layers.15.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 86 | 
            +
                "model.layers.15.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
         | 
| 87 | 
            +
                "model.layers.15.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 88 | 
            +
                "model.layers.16.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 89 | 
            +
                "model.layers.16.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 90 | 
            +
                "model.layers.16.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 91 | 
            +
                "model.layers.16.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 92 | 
            +
                "model.layers.16.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 93 | 
            +
                "model.layers.16.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 94 | 
            +
                "model.layers.16.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 95 | 
            +
                "model.layers.16.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 96 | 
            +
                "model.layers.16.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
         | 
| 97 | 
            +
                "model.layers.16.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 98 | 
            +
                "model.layers.17.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 99 | 
            +
                "model.layers.17.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 100 | 
            +
                "model.layers.17.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 101 | 
            +
                "model.layers.17.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 102 | 
            +
                "model.layers.17.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 103 | 
            +
                "model.layers.17.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 104 | 
            +
                "model.layers.17.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 105 | 
            +
                "model.layers.17.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 106 | 
            +
                "model.layers.17.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
         | 
| 107 | 
            +
                "model.layers.17.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 108 | 
            +
                "model.layers.18.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 109 | 
            +
                "model.layers.18.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 110 | 
            +
                "model.layers.18.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 111 | 
            +
                "model.layers.18.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 112 | 
            +
                "model.layers.18.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 113 | 
            +
                "model.layers.18.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 114 | 
            +
                "model.layers.18.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 115 | 
            +
                "model.layers.18.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 116 | 
            +
                "model.layers.18.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
         | 
| 117 | 
            +
                "model.layers.18.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 118 | 
            +
                "model.layers.19.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 119 | 
            +
                "model.layers.19.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 120 | 
            +
                "model.layers.19.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 121 | 
            +
                "model.layers.19.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 122 | 
            +
                "model.layers.19.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 123 | 
            +
                "model.layers.19.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 124 | 
            +
                "model.layers.19.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 125 | 
            +
                "model.layers.19.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 126 | 
            +
                "model.layers.19.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
         | 
| 127 | 
            +
                "model.layers.19.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 128 | 
            +
                "model.layers.2.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 129 | 
            +
                "model.layers.2.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 130 | 
            +
                "model.layers.2.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 131 | 
            +
                "model.layers.2.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 132 | 
            +
                "model.layers.2.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 133 | 
            +
                "model.layers.2.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 134 | 
            +
                "model.layers.2.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 135 | 
            +
                "model.layers.2.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 136 | 
            +
                "model.layers.2.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
         | 
| 137 | 
            +
                "model.layers.2.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 138 | 
            +
                "model.layers.20.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 139 | 
            +
                "model.layers.20.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 140 | 
            +
                "model.layers.20.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 141 | 
            +
                "model.layers.20.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 142 | 
            +
                "model.layers.20.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 143 | 
            +
                "model.layers.20.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 144 | 
            +
                "model.layers.20.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 145 | 
            +
                "model.layers.20.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 146 | 
            +
                "model.layers.20.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
         | 
| 147 | 
            +
                "model.layers.20.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 148 | 
            +
                "model.layers.21.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 149 | 
            +
                "model.layers.21.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 150 | 
            +
                "model.layers.21.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 151 | 
            +
                "model.layers.21.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 152 | 
            +
                "model.layers.21.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 153 | 
            +
                "model.layers.21.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 154 | 
            +
                "model.layers.21.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 155 | 
            +
                "model.layers.21.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 156 | 
            +
                "model.layers.21.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
         | 
| 157 | 
            +
                "model.layers.21.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 158 | 
            +
                "model.layers.22.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 159 | 
            +
                "model.layers.22.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 160 | 
            +
                "model.layers.22.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 161 | 
            +
                "model.layers.22.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 162 | 
            +
                "model.layers.22.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 163 | 
            +
                "model.layers.22.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 164 | 
            +
                "model.layers.22.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 165 | 
            +
                "model.layers.22.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 166 | 
            +
                "model.layers.22.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
         | 
| 167 | 
            +
                "model.layers.22.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 168 | 
            +
                "model.layers.23.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 169 | 
            +
                "model.layers.23.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 170 | 
            +
                "model.layers.23.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 171 | 
            +
                "model.layers.23.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 172 | 
            +
                "model.layers.23.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 173 | 
            +
                "model.layers.23.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 174 | 
            +
                "model.layers.23.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 175 | 
            +
                "model.layers.23.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 176 | 
            +
                "model.layers.23.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
         | 
| 177 | 
            +
                "model.layers.23.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 178 | 
            +
                "model.layers.24.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 179 | 
            +
                "model.layers.24.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 180 | 
            +
                "model.layers.24.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 181 | 
            +
                "model.layers.24.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 182 | 
            +
                "model.layers.24.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 183 | 
            +
                "model.layers.24.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 184 | 
            +
                "model.layers.24.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 185 | 
            +
                "model.layers.24.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 186 | 
            +
                "model.layers.24.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
         | 
| 187 | 
            +
                "model.layers.24.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 188 | 
            +
                "model.layers.25.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 189 | 
            +
                "model.layers.25.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 190 | 
            +
                "model.layers.25.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 191 | 
            +
                "model.layers.25.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 192 | 
            +
                "model.layers.25.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 193 | 
            +
                "model.layers.25.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 194 | 
            +
                "model.layers.25.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 195 | 
            +
                "model.layers.25.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 196 | 
            +
                "model.layers.25.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
         | 
| 197 | 
            +
                "model.layers.25.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 198 | 
            +
                "model.layers.26.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 199 | 
            +
                "model.layers.26.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 200 | 
            +
                "model.layers.26.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 201 | 
            +
                "model.layers.26.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 202 | 
            +
                "model.layers.26.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 203 | 
            +
                "model.layers.26.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 204 | 
            +
                "model.layers.26.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 205 | 
            +
                "model.layers.26.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 206 | 
            +
                "model.layers.26.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
         | 
| 207 | 
            +
                "model.layers.26.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 208 | 
            +
                "model.layers.27.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 209 | 
            +
                "model.layers.27.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 210 | 
            +
                "model.layers.27.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 211 | 
            +
                "model.layers.27.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 212 | 
            +
                "model.layers.27.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 213 | 
            +
                "model.layers.27.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 214 | 
            +
                "model.layers.27.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 215 | 
            +
                "model.layers.27.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 216 | 
            +
                "model.layers.27.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
         | 
| 217 | 
            +
                "model.layers.27.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 218 | 
            +
                "model.layers.28.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 219 | 
            +
                "model.layers.28.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 220 | 
            +
                "model.layers.28.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 221 | 
            +
                "model.layers.28.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 222 | 
            +
                "model.layers.28.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 223 | 
            +
                "model.layers.28.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 224 | 
            +
                "model.layers.28.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 225 | 
            +
                "model.layers.28.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 226 | 
            +
                "model.layers.28.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
         | 
| 227 | 
            +
                "model.layers.28.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 228 | 
            +
                "model.layers.29.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 229 | 
            +
                "model.layers.29.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 230 | 
            +
                "model.layers.29.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 231 | 
            +
                "model.layers.29.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 232 | 
            +
                "model.layers.29.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 233 | 
            +
                "model.layers.29.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 234 | 
            +
                "model.layers.29.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 235 | 
            +
                "model.layers.29.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 236 | 
            +
                "model.layers.29.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
         | 
| 237 | 
            +
                "model.layers.29.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 238 | 
            +
                "model.layers.3.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 239 | 
            +
                "model.layers.3.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 240 | 
            +
                "model.layers.3.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 241 | 
            +
                "model.layers.3.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 242 | 
            +
                "model.layers.3.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 243 | 
            +
                "model.layers.3.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 244 | 
            +
                "model.layers.3.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 245 | 
            +
                "model.layers.3.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 246 | 
            +
                "model.layers.3.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
         | 
| 247 | 
            +
                "model.layers.3.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 248 | 
            +
                "model.layers.30.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 249 | 
            +
                "model.layers.30.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 250 | 
            +
                "model.layers.30.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 251 | 
            +
                "model.layers.30.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 252 | 
            +
                "model.layers.30.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 253 | 
            +
                "model.layers.30.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 254 | 
            +
                "model.layers.30.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 255 | 
            +
                "model.layers.30.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 256 | 
            +
                "model.layers.30.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
         | 
| 257 | 
            +
                "model.layers.30.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 258 | 
            +
                "model.layers.31.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 259 | 
            +
                "model.layers.31.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 260 | 
            +
                "model.layers.31.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 261 | 
            +
                "model.layers.31.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 262 | 
            +
                "model.layers.31.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 263 | 
            +
                "model.layers.31.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 264 | 
            +
                "model.layers.31.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 265 | 
            +
                "model.layers.31.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 266 | 
            +
                "model.layers.31.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
         | 
| 267 | 
            +
                "model.layers.31.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
         | 
| 268 | 
            +
                "model.layers.4.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 269 | 
            +
                "model.layers.4.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 270 | 
            +
                "model.layers.4.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 271 | 
            +
                "model.layers.4.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 272 | 
            +
                "model.layers.4.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 273 | 
            +
                "model.layers.4.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 274 | 
            +
                "model.layers.4.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 275 | 
            +
                "model.layers.4.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 276 | 
            +
                "model.layers.4.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
         | 
| 277 | 
            +
                "model.layers.4.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 278 | 
            +
                "model.layers.5.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 279 | 
            +
                "model.layers.5.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 280 | 
            +
                "model.layers.5.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 281 | 
            +
                "model.layers.5.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 282 | 
            +
                "model.layers.5.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 283 | 
            +
                "model.layers.5.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 284 | 
            +
                "model.layers.5.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 285 | 
            +
                "model.layers.5.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 286 | 
            +
                "model.layers.5.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
         | 
| 287 | 
            +
                "model.layers.5.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 288 | 
            +
                "model.layers.6.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 289 | 
            +
                "model.layers.6.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 290 | 
            +
                "model.layers.6.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 291 | 
            +
                "model.layers.6.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 292 | 
            +
                "model.layers.6.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 293 | 
            +
                "model.layers.6.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 294 | 
            +
                "model.layers.6.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 295 | 
            +
                "model.layers.6.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 296 | 
            +
                "model.layers.6.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
         | 
| 297 | 
            +
                "model.layers.6.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 298 | 
            +
                "model.layers.7.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 299 | 
            +
                "model.layers.7.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 300 | 
            +
                "model.layers.7.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 301 | 
            +
                "model.layers.7.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 302 | 
            +
                "model.layers.7.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 303 | 
            +
                "model.layers.7.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 304 | 
            +
                "model.layers.7.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 305 | 
            +
                "model.layers.7.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 306 | 
            +
                "model.layers.7.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
         | 
| 307 | 
            +
                "model.layers.7.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 308 | 
            +
                "model.layers.8.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 309 | 
            +
                "model.layers.8.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 310 | 
            +
                "model.layers.8.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 311 | 
            +
                "model.layers.8.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 312 | 
            +
                "model.layers.8.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 313 | 
            +
                "model.layers.8.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 314 | 
            +
                "model.layers.8.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 315 | 
            +
                "model.layers.8.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 316 | 
            +
                "model.layers.8.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
         | 
| 317 | 
            +
                "model.layers.8.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 318 | 
            +
                "model.layers.9.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 319 | 
            +
                "model.layers.9.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 320 | 
            +
                "model.layers.9.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 321 | 
            +
                "model.layers.9.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 322 | 
            +
                "model.layers.9.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 323 | 
            +
                "model.layers.9.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 324 | 
            +
                "model.layers.9.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 325 | 
            +
                "model.layers.9.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 326 | 
            +
                "model.layers.9.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
         | 
| 327 | 
            +
                "model.layers.9.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
         | 
| 328 | 
            +
                "model.norm.weight": "pytorch_model-00002-of-00002.bin"
         | 
| 329 | 
            +
              }
         | 
| 330 | 
            +
            }
         | 
    	
        tokenizer_Llama-2-7b-chat-hf/tokenizer.json
    ADDED
    
    | The diff for this file is too large to render. 
		See raw diff | 
|  | 
    	
        tokenizer_Llama-2-7b-chat-hf/tokenizer.model
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            +
            oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
         | 
| 3 | 
            +
            size 499723
         | 
    	
        tokenizer_Llama-2-7b-chat-hf/tokenizer_config.json
    ADDED
    
    | @@ -0,0 +1,36 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            {
         | 
| 2 | 
            +
              "add_bos_token": true,
         | 
| 3 | 
            +
              "add_eos_token": false,
         | 
| 4 | 
            +
              "bos_token": {
         | 
| 5 | 
            +
                "__type": "AddedToken",
         | 
| 6 | 
            +
                "content": "<s>",
         | 
| 7 | 
            +
                "lstrip": false,
         | 
| 8 | 
            +
                "normalized": false,
         | 
| 9 | 
            +
                "rstrip": false,
         | 
| 10 | 
            +
                "single_word": false
         | 
| 11 | 
            +
              },
         | 
| 12 | 
            +
              "chat_template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' '  + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}",
         | 
| 13 | 
            +
              "clean_up_tokenization_spaces": false,
         | 
| 14 | 
            +
              "eos_token": {
         | 
| 15 | 
            +
                "__type": "AddedToken",
         | 
| 16 | 
            +
                "content": "</s>",
         | 
| 17 | 
            +
                "lstrip": false,
         | 
| 18 | 
            +
                "normalized": false,
         | 
| 19 | 
            +
                "rstrip": false,
         | 
| 20 | 
            +
                "single_word": false
         | 
| 21 | 
            +
              },
         | 
| 22 | 
            +
              "legacy": false,
         | 
| 23 | 
            +
              "model_max_length": 1000000000000000019884624838656,
         | 
| 24 | 
            +
              "pad_token": null,
         | 
| 25 | 
            +
              "padding_side": "right",
         | 
| 26 | 
            +
              "sp_model_kwargs": {},
         | 
| 27 | 
            +
              "tokenizer_class": "LlamaTokenizer",
         | 
| 28 | 
            +
              "unk_token": {
         | 
| 29 | 
            +
                "__type": "AddedToken",
         | 
| 30 | 
            +
                "content": "<unk>",
         | 
| 31 | 
            +
                "lstrip": false,
         | 
| 32 | 
            +
                "normalized": false,
         | 
| 33 | 
            +
                "rstrip": false,
         | 
| 34 | 
            +
                "single_word": false
         | 
| 35 | 
            +
              }
         | 
| 36 | 
            +
            }
         | 
    	
        tsai_gpt/__init__.py
    ADDED
    
    | @@ -0,0 +1,15 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from lightning_utilities.core.imports import RequirementCache
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            from tsai_gpt.config import Config
         | 
| 4 | 
            +
            from tsai_gpt.model import GPT
         | 
| 5 | 
            +
            from tsai_gpt.tokenizer import Tokenizer
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            _LIGHTNING_AVAILABLE = RequirementCache("lightning>=2.1.0.dev0")
         | 
| 8 | 
            +
            if not bool(_LIGHTNING_AVAILABLE):
         | 
| 9 | 
            +
                raise ImportError(
         | 
| 10 | 
            +
                    "Lit-GPT requires lightning==2.1. Please run:\n"
         | 
| 11 | 
            +
                    f" pip uninstall -y lightning; pip install -r requirements.txt\n{str(_LIGHTNING_AVAILABLE)}"
         | 
| 12 | 
            +
                )
         | 
| 13 | 
            +
             | 
| 14 | 
            +
             | 
| 15 | 
            +
            __all__ = ["GPT", "Config", "Tokenizer"]
         | 
    	
        tsai_gpt/__pycache__/__init__.cpython-310.pyc
    ADDED
    
    | Binary file (620 Bytes). View file | 
|  | 
    	
        tsai_gpt/__pycache__/__init__.cpython-39.pyc
    ADDED
    
    | Binary file (606 Bytes). View file | 
|  | 
    	
        tsai_gpt/__pycache__/config.cpython-310.pyc
    ADDED
    
    | Binary file (13.4 kB). View file | 
|  | 
    	
        tsai_gpt/__pycache__/config.cpython-39.pyc
    ADDED
    
    | Binary file (12.7 kB). View file | 
|  | 
    	
        tsai_gpt/__pycache__/model.cpython-310.pyc
    ADDED
    
    | Binary file (11.8 kB). View file | 
|  | 
    	
        tsai_gpt/__pycache__/model.cpython-39.pyc
    ADDED
    
    | Binary file (11.5 kB). View file | 
|  | 
    	
        tsai_gpt/__pycache__/packed_dataset.cpython-39.pyc
    ADDED
    
    | Binary file (7.62 kB). View file | 
|  | 
    	
        tsai_gpt/__pycache__/speed_monitor.cpython-39.pyc
    ADDED
    
    | Binary file (15.6 kB). View file | 
|  | 
    	
        tsai_gpt/__pycache__/tokenizer.cpython-310.pyc
    ADDED
    
    | Binary file (3.57 kB). View file | 
|  | 
    	
        tsai_gpt/__pycache__/tokenizer.cpython-39.pyc
    ADDED
    
    | Binary file (3.54 kB). View file | 
|  | 
    	
        tsai_gpt/__pycache__/utils.cpython-310.pyc
    ADDED
    
    | Binary file (11.8 kB). View file | 
|  | 
    	
        tsai_gpt/__pycache__/utils.cpython-39.pyc
    ADDED
    
    | Binary file (11.8 kB). View file | 
|  | 
    	
        tsai_gpt/config.py
    ADDED
    
    | @@ -0,0 +1,1192 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import json
         | 
| 2 | 
            +
            from copy import deepcopy
         | 
| 3 | 
            +
            from dataclasses import dataclass, field
         | 
| 4 | 
            +
            from pathlib import Path
         | 
| 5 | 
            +
            from typing import Any, Literal, Optional, Type, Union
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            import torch
         | 
| 8 | 
            +
            from typing_extensions import Self
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            from tsai_gpt.utils import find_multiple
         | 
| 11 | 
            +
             | 
| 12 | 
            +
             | 
| 13 | 
            +
            @dataclass
         | 
| 14 | 
            +
            class Config:
         | 
| 15 | 
            +
                name: str = ""
         | 
| 16 | 
            +
                hf_config: dict = field(default_factory=dict)
         | 
| 17 | 
            +
                block_size: int = 4096
         | 
| 18 | 
            +
                vocab_size: int = 50254
         | 
| 19 | 
            +
                padding_multiple: int = 512
         | 
| 20 | 
            +
                padded_vocab_size: Optional[int] = None
         | 
| 21 | 
            +
                n_layer: int = 16
         | 
| 22 | 
            +
                n_head: int = 32
         | 
| 23 | 
            +
                n_embd: int = 4096
         | 
| 24 | 
            +
                rotary_percentage: float = 0.25
         | 
| 25 | 
            +
                parallel_residual: bool = True
         | 
| 26 | 
            +
                bias: bool = True
         | 
| 27 | 
            +
                lm_head_bias: bool = False
         | 
| 28 | 
            +
                # to use multi-head attention (MHA), set this to `n_head` (default)
         | 
| 29 | 
            +
                # to use multi-query attention (MQA), set this to 1
         | 
| 30 | 
            +
                # to use grouped-query attention (GQA), set this to a value in between
         | 
| 31 | 
            +
                # Example with `n_head=4`
         | 
| 32 | 
            +
                # ┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
         | 
| 33 | 
            +
                # │ v ││ v ││ v ││ v │     │ v │    │ v │             │ v │
         | 
| 34 | 
            +
                # └───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
         | 
| 35 | 
            +
                #   │    │    │    │         │        │                 │
         | 
| 36 | 
            +
                # ┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
         | 
| 37 | 
            +
                # │ k ││ k ││ k ││ k │     │ k │    │ k │             │ k │
         | 
| 38 | 
            +
                # └───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
         | 
| 39 | 
            +
                #   │    │    │    │      ┌──┴──┐  ┌──┴──┐      ┌────┬──┴─┬────┐
         | 
| 40 | 
            +
                # ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐
         | 
| 41 | 
            +
                # │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │
         | 
| 42 | 
            +
                # └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘
         | 
| 43 | 
            +
                # ◀──────────────────▶  ◀──────────────────▶  ◀──────────────────▶
         | 
| 44 | 
            +
                #         MHA                    GQA                   MQA
         | 
| 45 | 
            +
                #   n_query_groups=4       n_query_groups=2      n_query_groups=1
         | 
| 46 | 
            +
                #
         | 
| 47 | 
            +
                # credit https://arxiv.org/pdf/2305.13245.pdf
         | 
| 48 | 
            +
                n_query_groups: Optional[int] = None
         | 
| 49 | 
            +
                shared_attention_norm: bool = False
         | 
| 50 | 
            +
                _norm_class: Literal["LayerNorm", "RMSNorm"] = "LayerNorm"
         | 
| 51 | 
            +
                norm_eps: float = 1e-5
         | 
| 52 | 
            +
                _mlp_class: Literal["GptNeoxMLP", "LLaMAMLP"] = "GptNeoxMLP"
         | 
| 53 | 
            +
                gelu_approximate: str = "none"
         | 
| 54 | 
            +
                intermediate_size: Optional[int] = None
         | 
| 55 | 
            +
                rope_condense_ratio: int = 1
         | 
| 56 | 
            +
                rope_base: int = 10000
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                def __post_init__(self):
         | 
| 59 | 
            +
                    if not self.name:
         | 
| 60 | 
            +
                        self.name = self.hf_config.get("name", self.name)
         | 
| 61 | 
            +
             | 
| 62 | 
            +
                    assert self.n_embd % self.n_head == 0
         | 
| 63 | 
            +
                    self.head_size = self.n_embd // self.n_head
         | 
| 64 | 
            +
             | 
| 65 | 
            +
                    # vocab size should be a power of 2 to be optimal on hardware. compute the closest value
         | 
| 66 | 
            +
                    if self.padded_vocab_size is None:
         | 
| 67 | 
            +
                        self.padded_vocab_size = find_multiple(self.vocab_size, self.padding_multiple)
         | 
| 68 | 
            +
                    else:
         | 
| 69 | 
            +
                        # vocab size shouldn't be larger than padded vocab size
         | 
| 70 | 
            +
                        self.vocab_size = min(self.vocab_size, self.padded_vocab_size)
         | 
| 71 | 
            +
             | 
| 72 | 
            +
                    # compute the number of query groups
         | 
| 73 | 
            +
                    if self.n_query_groups is not None:
         | 
| 74 | 
            +
                        assert self.n_head % self.n_query_groups == 0
         | 
| 75 | 
            +
                    else:
         | 
| 76 | 
            +
                        self.n_query_groups = self.n_head
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                    # compute the intermediate size for MLP if not set
         | 
| 79 | 
            +
                    if self.intermediate_size is None:
         | 
| 80 | 
            +
                        if self._mlp_class == "LLaMAMLP":
         | 
| 81 | 
            +
                            raise ValueError("The config needs to set the `intermediate_size`")
         | 
| 82 | 
            +
                        self.intermediate_size = 4 * self.n_embd
         | 
| 83 | 
            +
             | 
| 84 | 
            +
                    self.rope_n_elem = int(self.rotary_percentage * self.head_size)
         | 
| 85 | 
            +
             | 
| 86 | 
            +
                @classmethod
         | 
| 87 | 
            +
                def from_name(cls, name: str, **kwargs: Any) -> Self:
         | 
| 88 | 
            +
                    if name not in name_to_config:
         | 
| 89 | 
            +
                        # search through all `config['hf_config']['name']`
         | 
| 90 | 
            +
                        conf_dict = next(config for config in configs if name == config["hf_config"]["name"])
         | 
| 91 | 
            +
                    else:
         | 
| 92 | 
            +
                        conf_dict = name_to_config[name]
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                    conf_dict = conf_dict.copy()
         | 
| 95 | 
            +
                    if "condense_ratio" in kwargs:  # legacy name
         | 
| 96 | 
            +
                        kwargs["rope_condense_ratio"] = kwargs.pop("condense_ratio")
         | 
| 97 | 
            +
                    conf_dict.update(kwargs)
         | 
| 98 | 
            +
                    return cls(**conf_dict)
         | 
| 99 | 
            +
             | 
| 100 | 
            +
                @classmethod
         | 
| 101 | 
            +
                def from_json(cls, path: Union[str, Path], **kwargs: Any) -> Self:
         | 
| 102 | 
            +
                    with open(path, encoding="utf-8") as fp:
         | 
| 103 | 
            +
                        json_kwargs = json.load(fp)
         | 
| 104 | 
            +
                    if "condense_ratio" in json_kwargs:  # legacy name
         | 
| 105 | 
            +
                        json_kwargs["rope_condense_ratio"] = json_kwargs.pop("condense_ratio")
         | 
| 106 | 
            +
                    if "condense_ratio" in kwargs:  # legacy name
         | 
| 107 | 
            +
                        kwargs["rope_condense_ratio"] = kwargs.pop("condense_ratio")
         | 
| 108 | 
            +
                    if "org" in json_kwargs:  # legacy name
         | 
| 109 | 
            +
                        json_kwargs["hf_config"] = {"name": json_kwargs["name"], "org": json_kwargs.pop("org")}
         | 
| 110 | 
            +
                    if "org" in kwargs:  # legacy name
         | 
| 111 | 
            +
                        kwargs["hf_config"] = {
         | 
| 112 | 
            +
                            "name": kwargs.get("name", json_kwargs["name"]),
         | 
| 113 | 
            +
                            "org": kwargs.pop("org"),
         | 
| 114 | 
            +
                        }
         | 
| 115 | 
            +
                    json_kwargs.update(kwargs)
         | 
| 116 | 
            +
                    return cls(**json_kwargs)
         | 
| 117 | 
            +
             | 
| 118 | 
            +
                @property
         | 
| 119 | 
            +
                def mlp_class(self) -> Type:
         | 
| 120 | 
            +
                    # `self._mlp_class` cannot be the type to keep the config json serializable
         | 
| 121 | 
            +
                    import tsai_gpt.model
         | 
| 122 | 
            +
             | 
| 123 | 
            +
                    return getattr(tsai_gpt.model, self._mlp_class)
         | 
| 124 | 
            +
             | 
| 125 | 
            +
                @property
         | 
| 126 | 
            +
                def norm_class(self) -> Type:
         | 
| 127 | 
            +
                    # `self._norm_class` cannot be the type to keep the config json serializable
         | 
| 128 | 
            +
                    if self._norm_class == "RMSNorm":
         | 
| 129 | 
            +
                        from tsai_gpt.rmsnorm import RMSNorm
         | 
| 130 | 
            +
             | 
| 131 | 
            +
                        return RMSNorm
         | 
| 132 | 
            +
                    return getattr(torch.nn, self._norm_class)
         | 
| 133 | 
            +
             | 
| 134 | 
            +
             | 
| 135 | 
            +
            ########################
         | 
| 136 | 
            +
            # Stability AI StableLM
         | 
| 137 | 
            +
            ########################
         | 
| 138 | 
            +
            configs = [
         | 
| 139 | 
            +
                # https://huggingface.co/stabilityai/stablelm-base-alpha-3b/blob/main/config.json
         | 
| 140 | 
            +
                dict(
         | 
| 141 | 
            +
                    name="stablelm-base-alpha-3b",
         | 
| 142 | 
            +
                    hf_config=dict(org="stabilityai", name="stablelm-base-alpha-3b"),
         | 
| 143 | 
            +
                ),
         | 
| 144 | 
            +
                # https://huggingface.co/stabilityai/stablelm-base-alpha-7b/blob/main/config.json
         | 
| 145 | 
            +
                dict(
         | 
| 146 | 
            +
                    name="stablelm-base-alpha-7b",
         | 
| 147 | 
            +
                    hf_config=dict(org="stabilityai", name="stablelm-base-alpha-7b"),
         | 
| 148 | 
            +
                    n_head=48,
         | 
| 149 | 
            +
                    n_embd=6144,
         | 
| 150 | 
            +
                    padding_multiple=256,
         | 
| 151 | 
            +
                ),
         | 
| 152 | 
            +
                # https://huggingface.co/stabilityai/stablelm-tuned-alpha-3b/blob/main/config.json
         | 
| 153 | 
            +
                dict(
         | 
| 154 | 
            +
                    name="stablelm-tuned-alpha-3b",
         | 
| 155 | 
            +
                    hf_config=dict(org="stabilityai", name="stablelm-tuned-alpha-3b"),
         | 
| 156 | 
            +
                    n_head=32,
         | 
| 157 | 
            +
                ),
         | 
| 158 | 
            +
                # https://huggingface.co/stabilityai/stablelm-tuned-alpha-7b/blob/main/config.json
         | 
| 159 | 
            +
                dict(
         | 
| 160 | 
            +
                    name="stablelm-tuned-alpha-7b",
         | 
| 161 | 
            +
                    hf_config=dict(org="stabilityai", name="stablelm-tuned-alpha-7b"),
         | 
| 162 | 
            +
                    n_head=48,
         | 
| 163 | 
            +
                    n_embd=6144,
         | 
| 164 | 
            +
                    padding_multiple=256,
         | 
| 165 | 
            +
                ),
         | 
| 166 | 
            +
            ]
         | 
| 167 | 
            +
             | 
| 168 | 
            +
            ####################
         | 
| 169 | 
            +
            # EleutherAI Pythia
         | 
| 170 | 
            +
            ####################
         | 
| 171 | 
            +
            pythia = [
         | 
| 172 | 
            +
                # https://huggingface.co/EleutherAI/pythia-70m/blob/main/config.json
         | 
| 173 | 
            +
                dict(
         | 
| 174 | 
            +
                    name="pythia-70m",
         | 
| 175 | 
            +
                    hf_config=dict(org="EleutherAI", name="pythia-70m"),
         | 
| 176 | 
            +
                    block_size=2048,
         | 
| 177 | 
            +
                    n_layer=6,
         | 
| 178 | 
            +
                    n_embd=512,
         | 
| 179 | 
            +
                    n_head=8,
         | 
| 180 | 
            +
                    padding_multiple=128,
         | 
| 181 | 
            +
                ),
         | 
| 182 | 
            +
                # https://huggingface.co/EleutherAI/pythia-160m/blob/main/config.json
         | 
| 183 | 
            +
                dict(
         | 
| 184 | 
            +
                    name="pythia-160m",
         | 
| 185 | 
            +
                    hf_config=dict(org="EleutherAI", name="pythia-160m"),
         | 
| 186 | 
            +
                    block_size=2048,
         | 
| 187 | 
            +
                    n_layer=12,
         | 
| 188 | 
            +
                    n_embd=768,
         | 
| 189 | 
            +
                    n_head=12,
         | 
| 190 | 
            +
                    padding_multiple=128,
         | 
| 191 | 
            +
                ),
         | 
| 192 | 
            +
                # https://huggingface.co/EleutherAI/pythia-410m/blob/main/config.json
         | 
| 193 | 
            +
                dict(
         | 
| 194 | 
            +
                    name="pythia-410m",
         | 
| 195 | 
            +
                    hf_config=dict(org="EleutherAI", name="pythia-410m"),
         | 
| 196 | 
            +
                    block_size=2048,
         | 
| 197 | 
            +
                    n_layer=24,
         | 
| 198 | 
            +
                    n_embd=1024,
         | 
| 199 | 
            +
                    n_head=16,
         | 
| 200 | 
            +
                    padding_multiple=128,
         | 
| 201 | 
            +
                ),
         | 
| 202 | 
            +
                # https://huggingface.co/EleutherAI/pythia-1b/blob/main/config.json
         | 
| 203 | 
            +
                dict(
         | 
| 204 | 
            +
                    name="pythia-1b",
         | 
| 205 | 
            +
                    hf_config=dict(org="EleutherAI", name="pythia-1b"),
         | 
| 206 | 
            +
                    block_size=2048,
         | 
| 207 | 
            +
                    n_embd=2048,
         | 
| 208 | 
            +
                    n_head=8,
         | 
| 209 | 
            +
                    padding_multiple=128,
         | 
| 210 | 
            +
                ),
         | 
| 211 | 
            +
                # https://huggingface.co/EleutherAI/pythia-1.4b/blob/main/config.json
         | 
| 212 | 
            +
                dict(
         | 
| 213 | 
            +
                    name="pythia-1.4b",
         | 
| 214 | 
            +
                    hf_config=dict(org="EleutherAI", name="pythia-1.4b"),
         | 
| 215 | 
            +
                    block_size=2048,
         | 
| 216 | 
            +
                    n_layer=24,
         | 
| 217 | 
            +
                    n_embd=2048,
         | 
| 218 | 
            +
                    n_head=16,
         | 
| 219 | 
            +
                    padding_multiple=128,
         | 
| 220 | 
            +
                ),
         | 
| 221 | 
            +
                # https://huggingface.co/EleutherAI/pythia-2.8b/blob/main/config.json
         | 
| 222 | 
            +
                dict(
         | 
| 223 | 
            +
                    name="pythia-2.8b",
         | 
| 224 | 
            +
                    hf_config=dict(org="EleutherAI", name="pythia-2.8b"),
         | 
| 225 | 
            +
                    block_size=2048,
         | 
| 226 | 
            +
                    n_layer=32,
         | 
| 227 | 
            +
                    n_embd=2560,
         | 
| 228 | 
            +
                    padding_multiple=128,
         | 
| 229 | 
            +
                ),
         | 
| 230 | 
            +
                # https://huggingface.co/EleutherAI/pythia-6.9b/blob/main/config.json
         | 
| 231 | 
            +
                dict(
         | 
| 232 | 
            +
                    name="pythia-6.9b",
         | 
| 233 | 
            +
                    hf_config=dict(org="EleutherAI", name="pythia-6.9b"),
         | 
| 234 | 
            +
                    block_size=2048,
         | 
| 235 | 
            +
                    n_layer=32,
         | 
| 236 | 
            +
                    padding_multiple=256,
         | 
| 237 | 
            +
                ),
         | 
| 238 | 
            +
                # https://huggingface.co/EleutherAI/pythia-12b/blob/main/config.json
         | 
| 239 | 
            +
                dict(
         | 
| 240 | 
            +
                    name="pythia-12b",
         | 
| 241 | 
            +
                    hf_config=dict(org="EleutherAI", name="pythia-12b"),
         | 
| 242 | 
            +
                    block_size=2048,
         | 
| 243 | 
            +
                    n_layer=36,
         | 
| 244 | 
            +
                    n_embd=5120,
         | 
| 245 | 
            +
                    n_head=40,
         | 
| 246 | 
            +
                ),
         | 
| 247 | 
            +
            ]
         | 
| 248 | 
            +
            configs.extend(pythia)
         | 
| 249 | 
            +
            for c in pythia:
         | 
| 250 | 
            +
                copy = c.copy()
         | 
| 251 | 
            +
                copy["name"] = f"{c['name']}-deduped"
         | 
| 252 | 
            +
                copy["hf_config"]["name"] = f"{c['hf_config']['name']}-deduped"
         | 
| 253 | 
            +
                configs.append(copy)
         | 
| 254 | 
            +
             | 
| 255 | 
            +
             | 
| 256 | 
            +
            ####################################
         | 
| 257 | 
            +
            # togethercomputer RedPajama INCITE
         | 
| 258 | 
            +
            ####################################
         | 
| 259 | 
            +
            redpajama_incite = [
         | 
| 260 | 
            +
                # https://huggingface.co/togethercomputer/RedPajama-INCITE-Base-3B-v1/blob/main/config.json
         | 
| 261 | 
            +
                dict(
         | 
| 262 | 
            +
                    name="RedPajama-INCITE-{}-3B-v1",
         | 
| 263 | 
            +
                    hf_config=dict(org="togethercomputer", name="RedPajama-INCITE-{}-3B-v1"),
         | 
| 264 | 
            +
                    block_size=2048,
         | 
| 265 | 
            +
                    n_layer=32,
         | 
| 266 | 
            +
                    n_embd=2560,
         | 
| 267 | 
            +
                    padding_multiple=256,
         | 
| 268 | 
            +
                    rotary_percentage=1.0,
         | 
| 269 | 
            +
                    parallel_residual=False,
         | 
| 270 | 
            +
                ),
         | 
| 271 | 
            +
                # https://huggingface.co/togethercomputer/RedPajama-INCITE-7B-Base/blob/main/config.json
         | 
| 272 | 
            +
                dict(
         | 
| 273 | 
            +
                    name="RedPajama-INCITE-7B-{}",
         | 
| 274 | 
            +
                    hf_config=dict(org="togethercomputer", name="RedPajama-INCITE-7B-{}"),
         | 
| 275 | 
            +
                    block_size=2048,
         | 
| 276 | 
            +
                    n_layer=32,
         | 
| 277 | 
            +
                    padding_multiple=256,
         | 
| 278 | 
            +
                    rotary_percentage=1.0,
         | 
| 279 | 
            +
                    parallel_residual=False,
         | 
| 280 | 
            +
                ),
         | 
| 281 | 
            +
                # this redirects to the checkpoint above. kept for those who had the old weights already downloaded
         | 
| 282 | 
            +
                dict(
         | 
| 283 | 
            +
                    name="RedPajama-INCITE-{}-7B-v0.1",
         | 
| 284 | 
            +
                    hf_config=dict(org="togethercomputer", name="RedPajama-INCITE-{}-7B-v0.1"),
         | 
| 285 | 
            +
                    block_size=2048,
         | 
| 286 | 
            +
                    n_layer=32,
         | 
| 287 | 
            +
                    padding_multiple=256,
         | 
| 288 | 
            +
                    rotary_percentage=1.0,
         | 
| 289 | 
            +
                    parallel_residual=False,
         | 
| 290 | 
            +
                ),
         | 
| 291 | 
            +
            ]
         | 
| 292 | 
            +
            for c in redpajama_incite:
         | 
| 293 | 
            +
                for kind in ("Base", "Chat", "Instruct"):
         | 
| 294 | 
            +
                    copy = c.copy()
         | 
| 295 | 
            +
                    copy["name"] = c["name"].format(kind)
         | 
| 296 | 
            +
                    copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
         | 
| 297 | 
            +
                    configs.append(copy)
         | 
| 298 | 
            +
             | 
| 299 | 
            +
             | 
| 300 | 
            +
            #################
         | 
| 301 | 
            +
            # TII UAE Falcon
         | 
| 302 | 
            +
            #################
         | 
| 303 | 
            +
            falcon = [
         | 
| 304 | 
            +
                # https://huggingface.co/tiiuae/falcon-7b/blob/main/config.json
         | 
| 305 | 
            +
                dict(
         | 
| 306 | 
            +
                    name="falcon-7b{}",
         | 
| 307 | 
            +
                    hf_config=dict(org="tiiuae", name="falcon-7b{}"),
         | 
| 308 | 
            +
                    block_size=2048,
         | 
| 309 | 
            +
                    vocab_size=65024,
         | 
| 310 | 
            +
                    padded_vocab_size=65024,
         | 
| 311 | 
            +
                    n_layer=32,
         | 
| 312 | 
            +
                    n_head=71,
         | 
| 313 | 
            +
                    n_embd=4544,
         | 
| 314 | 
            +
                    rotary_percentage=1.0,
         | 
| 315 | 
            +
                    n_query_groups=1,
         | 
| 316 | 
            +
                    bias=False,
         | 
| 317 | 
            +
                    # this is not in the config, but in the original model implementation, only for this config
         | 
| 318 | 
            +
                    shared_attention_norm=True,
         | 
| 319 | 
            +
                ),
         | 
| 320 | 
            +
                # https://huggingface.co/tiiuae/falcon-40b/blob/main/config.json
         | 
| 321 | 
            +
                dict(
         | 
| 322 | 
            +
                    name="falcon-40b{}",
         | 
| 323 | 
            +
                    hf_config=dict(org="tiiuae", name="falcon-40b{}"),
         | 
| 324 | 
            +
                    block_size=2048,
         | 
| 325 | 
            +
                    vocab_size=65024,
         | 
| 326 | 
            +
                    padded_vocab_size=65024,
         | 
| 327 | 
            +
                    n_layer=60,
         | 
| 328 | 
            +
                    n_head=128,
         | 
| 329 | 
            +
                    n_embd=8192,
         | 
| 330 | 
            +
                    rotary_percentage=1.0,
         | 
| 331 | 
            +
                    n_query_groups=8,
         | 
| 332 | 
            +
                    bias=False,
         | 
| 333 | 
            +
                ),
         | 
| 334 | 
            +
            ]
         | 
| 335 | 
            +
            for c in falcon:
         | 
| 336 | 
            +
                for kind in ("", "-instruct"):
         | 
| 337 | 
            +
                    copy = c.copy()
         | 
| 338 | 
            +
                    copy["name"] = c["name"].format(kind)
         | 
| 339 | 
            +
                    copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
         | 
| 340 | 
            +
                    configs.append(copy)
         | 
| 341 | 
            +
             | 
| 342 | 
            +
            # https://huggingface.co/tiiuae/falcon-180b/blob/main/config.json
         | 
| 343 | 
            +
            falcon180b = dict(
         | 
| 344 | 
            +
                name="falcon-180B{}",
         | 
| 345 | 
            +
                hf_config=dict(org="tiiuae", name="falcon-180B{}"),
         | 
| 346 | 
            +
                block_size=2048,
         | 
| 347 | 
            +
                vocab_size=65024,
         | 
| 348 | 
            +
                padded_vocab_size=65024,
         | 
| 349 | 
            +
                n_layer=80,
         | 
| 350 | 
            +
                n_head=232,
         | 
| 351 | 
            +
                n_embd=14848,
         | 
| 352 | 
            +
                rotary_percentage=1.0,
         | 
| 353 | 
            +
                n_query_groups=8,
         | 
| 354 | 
            +
                bias=False,
         | 
| 355 | 
            +
            )
         | 
| 356 | 
            +
             | 
| 357 | 
            +
            for kind in ("", "-chat"):
         | 
| 358 | 
            +
                copy = falcon180b.copy()
         | 
| 359 | 
            +
                copy["name"] = falcon180b["name"].format(kind)
         | 
| 360 | 
            +
                copy["hf_config"]["name"] = falcon180b["hf_config"]["name"].format(kind)
         | 
| 361 | 
            +
                configs.append(copy)
         | 
| 362 | 
            +
             | 
| 363 | 
            +
             | 
| 364 | 
            +
            #############################
         | 
| 365 | 
            +
            # OpenLM Research Open LLaMA
         | 
| 366 | 
            +
            #############################
         | 
| 367 | 
            +
            open_LLaMA = [
         | 
| 368 | 
            +
                # https://huggingface.co/openlm-research/open_llama_3b/blob/main/config.json
         | 
| 369 | 
            +
                dict(
         | 
| 370 | 
            +
                    name="open_llama_3b",
         | 
| 371 | 
            +
                    hf_config=dict(org="openlm-research", name="open_llama_3b"),
         | 
| 372 | 
            +
                    block_size=2048,
         | 
| 373 | 
            +
                    vocab_size=32000,
         | 
| 374 | 
            +
                    padding_multiple=64,
         | 
| 375 | 
            +
                    n_layer=26,
         | 
| 376 | 
            +
                    n_embd=3200,
         | 
| 377 | 
            +
                    rotary_percentage=1.0,
         | 
| 378 | 
            +
                    parallel_residual=False,
         | 
| 379 | 
            +
                    bias=False,
         | 
| 380 | 
            +
                    _norm_class="RMSNorm",
         | 
| 381 | 
            +
                    norm_eps=1e-6,
         | 
| 382 | 
            +
                    _mlp_class="LLaMAMLP",
         | 
| 383 | 
            +
                    intermediate_size=8640,
         | 
| 384 | 
            +
                ),
         | 
| 385 | 
            +
                # https://huggingface.co/openlm-research/open_llama_7b/blob/main/config.json
         | 
| 386 | 
            +
                dict(
         | 
| 387 | 
            +
                    name="open_llama_7b",
         | 
| 388 | 
            +
                    hf_config=dict(org="openlm-research", name="open_llama_7b"),
         | 
| 389 | 
            +
                    block_size=2048,
         | 
| 390 | 
            +
                    vocab_size=32000,
         | 
| 391 | 
            +
                    padding_multiple=64,
         | 
| 392 | 
            +
                    n_layer=32,
         | 
| 393 | 
            +
                    rotary_percentage=1.0,
         | 
| 394 | 
            +
                    parallel_residual=False,
         | 
| 395 | 
            +
                    bias=False,
         | 
| 396 | 
            +
                    _norm_class="RMSNorm",
         | 
| 397 | 
            +
                    norm_eps=1e-6,
         | 
| 398 | 
            +
                    _mlp_class="LLaMAMLP",
         | 
| 399 | 
            +
                    intermediate_size=11008,
         | 
| 400 | 
            +
                ),
         | 
| 401 | 
            +
                # https://huggingface.co/openlm-research/open_llama_13b/blob/main/config.json
         | 
| 402 | 
            +
                dict(
         | 
| 403 | 
            +
                    name="open_llama_13b",
         | 
| 404 | 
            +
                    hf_config=dict(org="openlm-research", name="open_llama_13b"),
         | 
| 405 | 
            +
                    block_size=2048,
         | 
| 406 | 
            +
                    vocab_size=32000,
         | 
| 407 | 
            +
                    padding_multiple=64,
         | 
| 408 | 
            +
                    n_layer=40,
         | 
| 409 | 
            +
                    n_head=40,
         | 
| 410 | 
            +
                    n_embd=5120,
         | 
| 411 | 
            +
                    rotary_percentage=1.0,
         | 
| 412 | 
            +
                    parallel_residual=False,
         | 
| 413 | 
            +
                    bias=False,
         | 
| 414 | 
            +
                    _norm_class="RMSNorm",
         | 
| 415 | 
            +
                    norm_eps=1e-6,
         | 
| 416 | 
            +
                    _mlp_class="LLaMAMLP",
         | 
| 417 | 
            +
                    intermediate_size=13824,
         | 
| 418 | 
            +
                ),
         | 
| 419 | 
            +
            ]
         | 
| 420 | 
            +
            configs.extend(open_LLaMA)
         | 
| 421 | 
            +
             | 
| 422 | 
            +
             | 
| 423 | 
            +
            ###############
         | 
| 424 | 
            +
            # LMSYS Vicuna
         | 
| 425 | 
            +
            ###############
         | 
| 426 | 
            +
            vicuna = [
         | 
| 427 | 
            +
                # https://huggingface.co/lmsys/vicuna-7b-v1.3/blob/main/config.json
         | 
| 428 | 
            +
                dict(
         | 
| 429 | 
            +
                    name="vicuna-7b-v1.3",
         | 
| 430 | 
            +
                    hf_config=dict(org="lmsys", name="vicuna-7b-v1.3"),
         | 
| 431 | 
            +
                    block_size=2048,
         | 
| 432 | 
            +
                    vocab_size=32000,
         | 
| 433 | 
            +
                    padding_multiple=64,
         | 
| 434 | 
            +
                    n_layer=32,
         | 
| 435 | 
            +
                    rotary_percentage=1.0,
         | 
| 436 | 
            +
                    parallel_residual=False,
         | 
| 437 | 
            +
                    bias=False,
         | 
| 438 | 
            +
                    _norm_class="RMSNorm",
         | 
| 439 | 
            +
                    norm_eps=1e-6,
         | 
| 440 | 
            +
                    _mlp_class="LLaMAMLP",
         | 
| 441 | 
            +
                    intermediate_size=11008,
         | 
| 442 | 
            +
                ),
         | 
| 443 | 
            +
                # https://huggingface.co/lmsys/vicuna-13b-v1.3/blob/main/config.json
         | 
| 444 | 
            +
                dict(
         | 
| 445 | 
            +
                    name="vicuna-13b-v1.3",
         | 
| 446 | 
            +
                    hf_config=dict(org="lmsys", name="vicuna-13b-v1.3"),
         | 
| 447 | 
            +
                    block_size=2048,
         | 
| 448 | 
            +
                    vocab_size=32000,
         | 
| 449 | 
            +
                    padding_multiple=64,
         | 
| 450 | 
            +
                    n_layer=40,
         | 
| 451 | 
            +
                    n_head=40,
         | 
| 452 | 
            +
                    n_embd=5120,
         | 
| 453 | 
            +
                    rotary_percentage=1.0,
         | 
| 454 | 
            +
                    parallel_residual=False,
         | 
| 455 | 
            +
                    bias=False,
         | 
| 456 | 
            +
                    _norm_class="RMSNorm",
         | 
| 457 | 
            +
                    norm_eps=1e-6,
         | 
| 458 | 
            +
                    _mlp_class="LLaMAMLP",
         | 
| 459 | 
            +
                    intermediate_size=13824,
         | 
| 460 | 
            +
                ),
         | 
| 461 | 
            +
                # https://huggingface.co/lmsys/vicuna-33b-v1.3/blob/main/config.json
         | 
| 462 | 
            +
                dict(
         | 
| 463 | 
            +
                    name="vicuna-33b-v1.3",
         | 
| 464 | 
            +
                    hf_config=dict(org="lmsys", name="vicuna-33b-v1.3"),
         | 
| 465 | 
            +
                    block_size=2048,
         | 
| 466 | 
            +
                    vocab_size=32000,
         | 
| 467 | 
            +
                    padding_multiple=64,
         | 
| 468 | 
            +
                    n_layer=60,
         | 
| 469 | 
            +
                    n_head=52,
         | 
| 470 | 
            +
                    n_embd=6656,
         | 
| 471 | 
            +
                    rotary_percentage=1.0,
         | 
| 472 | 
            +
                    parallel_residual=False,
         | 
| 473 | 
            +
                    bias=False,
         | 
| 474 | 
            +
                    _norm_class="RMSNorm",
         | 
| 475 | 
            +
                    norm_eps=1e-6,
         | 
| 476 | 
            +
                    _mlp_class="LLaMAMLP",
         | 
| 477 | 
            +
                    intermediate_size=17920,
         | 
| 478 | 
            +
                ),
         | 
| 479 | 
            +
                # https://huggingface.co/lmsys/vicuna-7b-v1.5/blob/main/config.json
         | 
| 480 | 
            +
                dict(
         | 
| 481 | 
            +
                    name="vicuna-7b-v1.5",
         | 
| 482 | 
            +
                    hf_config=dict(org="lmsys", name="vicuna-7b-v1.5"),
         | 
| 483 | 
            +
                    vocab_size=32000,
         | 
| 484 | 
            +
                    padding_multiple=64,
         | 
| 485 | 
            +
                    n_layer=32,
         | 
| 486 | 
            +
                    rotary_percentage=1.0,
         | 
| 487 | 
            +
                    parallel_residual=False,
         | 
| 488 | 
            +
                    bias=False,
         | 
| 489 | 
            +
                    _norm_class="RMSNorm",
         | 
| 490 | 
            +
                    _mlp_class="LLaMAMLP",
         | 
| 491 | 
            +
                    intermediate_size=11008,
         | 
| 492 | 
            +
                ),
         | 
| 493 | 
            +
                # https://huggingface.co/lmsys/vicuna-7b-v1.5-16k/blob/main/config.json
         | 
| 494 | 
            +
                dict(
         | 
| 495 | 
            +
                    name="vicuna-7b-v1.5-16k",
         | 
| 496 | 
            +
                    hf_config=dict(org="lmsys", name="vicuna-7b-v1.5-16k"),
         | 
| 497 | 
            +
                    block_size=16384,
         | 
| 498 | 
            +
                    vocab_size=32000,
         | 
| 499 | 
            +
                    padding_multiple=64,
         | 
| 500 | 
            +
                    n_layer=32,
         | 
| 501 | 
            +
                    rotary_percentage=1.0,
         | 
| 502 | 
            +
                    parallel_residual=False,
         | 
| 503 | 
            +
                    bias=False,
         | 
| 504 | 
            +
                    _norm_class="RMSNorm",
         | 
| 505 | 
            +
                    _mlp_class="LLaMAMLP",
         | 
| 506 | 
            +
                    intermediate_size=11008,
         | 
| 507 | 
            +
                    rope_condense_ratio=4,
         | 
| 508 | 
            +
                ),
         | 
| 509 | 
            +
                # https://huggingface.co/lmsys/vicuna-13b-v1.5/blob/main/config.json
         | 
| 510 | 
            +
                dict(
         | 
| 511 | 
            +
                    name="vicuna-13b-v1.5",
         | 
| 512 | 
            +
                    hf_config=dict(org="lmsys", name="vicuna-13b-v1.5"),
         | 
| 513 | 
            +
                    vocab_size=32000,
         | 
| 514 | 
            +
                    padding_multiple=64,
         | 
| 515 | 
            +
                    n_layer=40,
         | 
| 516 | 
            +
                    n_head=40,
         | 
| 517 | 
            +
                    n_embd=5120,
         | 
| 518 | 
            +
                    rotary_percentage=1.0,
         | 
| 519 | 
            +
                    parallel_residual=False,
         | 
| 520 | 
            +
                    bias=False,
         | 
| 521 | 
            +
                    _norm_class="RMSNorm",
         | 
| 522 | 
            +
                    _mlp_class="LLaMAMLP",
         | 
| 523 | 
            +
                    intermediate_size=13824,
         | 
| 524 | 
            +
                ),
         | 
| 525 | 
            +
                # https://huggingface.co/lmsys/vicuna-13b-v1.5-16k/blob/main/config.json
         | 
| 526 | 
            +
                dict(
         | 
| 527 | 
            +
                    name="vicuna-13b-v1.5-16k",
         | 
| 528 | 
            +
                    hf_config=dict(org="lmsys", name="vicuna-13b-v1.5-16k"),
         | 
| 529 | 
            +
                    block_size=16384,
         | 
| 530 | 
            +
                    vocab_size=32000,
         | 
| 531 | 
            +
                    padding_multiple=64,
         | 
| 532 | 
            +
                    n_layer=40,
         | 
| 533 | 
            +
                    n_head=40,
         | 
| 534 | 
            +
                    n_embd=5120,
         | 
| 535 | 
            +
                    rotary_percentage=1.0,
         | 
| 536 | 
            +
                    parallel_residual=False,
         | 
| 537 | 
            +
                    bias=False,
         | 
| 538 | 
            +
                    _norm_class="RMSNorm",
         | 
| 539 | 
            +
                    _mlp_class="LLaMAMLP",
         | 
| 540 | 
            +
                    intermediate_size=13824,
         | 
| 541 | 
            +
                    rope_condense_ratio=4,
         | 
| 542 | 
            +
                ),
         | 
| 543 | 
            +
            ]
         | 
| 544 | 
            +
            configs.extend(vicuna)
         | 
| 545 | 
            +
             | 
| 546 | 
            +
             | 
| 547 | 
            +
            #################
         | 
| 548 | 
            +
            # LMSYS LongChat
         | 
| 549 | 
            +
            #################
         | 
| 550 | 
            +
            long_chat = [
         | 
| 551 | 
            +
                # https://huggingface.co/lmsys/longchat-7b-16k/blob/main/config.json
         | 
| 552 | 
            +
                dict(
         | 
| 553 | 
            +
                    name="longchat-7b-16k",
         | 
| 554 | 
            +
                    hf_config=dict(org="lmsys", name="longchat-7b-16k"),
         | 
| 555 | 
            +
                    block_size=16384,
         | 
| 556 | 
            +
                    vocab_size=32000,
         | 
| 557 | 
            +
                    padding_multiple=64,
         | 
| 558 | 
            +
                    n_layer=32,
         | 
| 559 | 
            +
                    rotary_percentage=1.0,
         | 
| 560 | 
            +
                    parallel_residual=False,
         | 
| 561 | 
            +
                    bias=False,
         | 
| 562 | 
            +
                    _norm_class="RMSNorm",
         | 
| 563 | 
            +
                    norm_eps=1e-6,
         | 
| 564 | 
            +
                    _mlp_class="LLaMAMLP",
         | 
| 565 | 
            +
                    intermediate_size=11008,
         | 
| 566 | 
            +
                    rope_condense_ratio=8,
         | 
| 567 | 
            +
                ),
         | 
| 568 | 
            +
                # https://huggingface.co/lmsys/longchat-13b-16k/blob/main/config.json
         | 
| 569 | 
            +
                dict(
         | 
| 570 | 
            +
                    name="longchat-13b-16k",
         | 
| 571 | 
            +
                    hf_config=dict(org="lmsys", name="longchat-13b-16k"),
         | 
| 572 | 
            +
                    block_size=16384,
         | 
| 573 | 
            +
                    vocab_size=32000,
         | 
| 574 | 
            +
                    padding_multiple=64,
         | 
| 575 | 
            +
                    n_layer=40,
         | 
| 576 | 
            +
                    n_head=40,
         | 
| 577 | 
            +
                    n_embd=5120,
         | 
| 578 | 
            +
                    rotary_percentage=1.0,
         | 
| 579 | 
            +
                    parallel_residual=False,
         | 
| 580 | 
            +
                    bias=False,
         | 
| 581 | 
            +
                    _norm_class="RMSNorm",
         | 
| 582 | 
            +
                    norm_eps=1e-6,
         | 
| 583 | 
            +
                    _mlp_class="LLaMAMLP",
         | 
| 584 | 
            +
                    intermediate_size=13824,
         | 
| 585 | 
            +
                    rope_condense_ratio=8,
         | 
| 586 | 
            +
                ),
         | 
| 587 | 
            +
            ]
         | 
| 588 | 
            +
            configs.extend(long_chat)
         | 
| 589 | 
            +
             | 
| 590 | 
            +
             | 
| 591 | 
            +
            ######################
         | 
| 592 | 
            +
            # NousResearch Hermes
         | 
| 593 | 
            +
            ######################
         | 
| 594 | 
            +
            nous_research = [
         | 
| 595 | 
            +
                # https://huggingface.co/NousResearch/Nous-Hermes-llama-2-7b/blob/main/config.json
         | 
| 596 | 
            +
                dict(
         | 
| 597 | 
            +
                    name="Nous-Hermes-llama-2-7b",
         | 
| 598 | 
            +
                    hf_config=dict(org="NousResearch", name="Nous-Hermes-llama-2-7b"),
         | 
| 599 | 
            +
                    padded_vocab_size=32000,
         | 
| 600 | 
            +
                    n_layer=32,
         | 
| 601 | 
            +
                    rotary_percentage=1.0,
         | 
| 602 | 
            +
                    parallel_residual=False,
         | 
| 603 | 
            +
                    bias=False,
         | 
| 604 | 
            +
                    _norm_class="RMSNorm",
         | 
| 605 | 
            +
                    norm_eps=1e-05,
         | 
| 606 | 
            +
                    _mlp_class="LLaMAMLP",
         | 
| 607 | 
            +
                    intermediate_size=11008,
         | 
| 608 | 
            +
                ),
         | 
| 609 | 
            +
                # https://huggingface.co/NousResearch/Nous-Hermes-13B/blob/main/config.json
         | 
| 610 | 
            +
                dict(
         | 
| 611 | 
            +
                    name="Nous-Hermes-13b",
         | 
| 612 | 
            +
                    hf_config=dict(org="NousResearch", name="Nous-Hermes-13b"),
         | 
| 613 | 
            +
                    block_size=2048,
         | 
| 614 | 
            +
                    vocab_size=32000,
         | 
| 615 | 
            +
                    padded_vocab_size=32001,
         | 
| 616 | 
            +
                    n_layer=40,
         | 
| 617 | 
            +
                    n_head=40,
         | 
| 618 | 
            +
                    n_embd=5120,
         | 
| 619 | 
            +
                    rotary_percentage=1.0,
         | 
| 620 | 
            +
                    parallel_residual=False,
         | 
| 621 | 
            +
                    bias=False,
         | 
| 622 | 
            +
                    _norm_class="RMSNorm",
         | 
| 623 | 
            +
                    norm_eps=1e-6,
         | 
| 624 | 
            +
                    _mlp_class="LLaMAMLP",
         | 
| 625 | 
            +
                    intermediate_size=13824,
         | 
| 626 | 
            +
                ),
         | 
| 627 | 
            +
                # https://huggingface.co/NousResearch/Nous-Hermes-Llama2-13b
         | 
| 628 | 
            +
                dict(
         | 
| 629 | 
            +
                    name="Nous-Hermes-Llama2-13b",
         | 
| 630 | 
            +
                    hf_config=dict(org="NousResearch", name="Nous-Hermes-Llama2-13b"),
         | 
| 631 | 
            +
                    vocab_size=32000,
         | 
| 632 | 
            +
                    padded_vocab_size=32032,
         | 
| 633 | 
            +
                    n_layer=40,
         | 
| 634 | 
            +
                    n_head=40,
         | 
| 635 | 
            +
                    n_embd=5120,
         | 
| 636 | 
            +
                    rotary_percentage=1.0,
         | 
| 637 | 
            +
                    parallel_residual=False,
         | 
| 638 | 
            +
                    bias=False,
         | 
| 639 | 
            +
                    _norm_class="RMSNorm",
         | 
| 640 | 
            +
                    norm_eps=1e-05,
         | 
| 641 | 
            +
                    _mlp_class="LLaMAMLP",
         | 
| 642 | 
            +
                    intermediate_size=13824,
         | 
| 643 | 
            +
                ),
         | 
| 644 | 
            +
            ]
         | 
| 645 | 
            +
            configs.extend(nous_research)
         | 
| 646 | 
            +
             | 
| 647 | 
            +
             | 
| 648 | 
            +
            ###############
         | 
| 649 | 
            +
            # Meta LLaMA 2
         | 
| 650 | 
            +
            ###############
         | 
| 651 | 
            +
            llama_2 = [
         | 
| 652 | 
            +
                # https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/config.json
         | 
| 653 | 
            +
                dict(
         | 
| 654 | 
            +
                    name="Llama-2-7b{}-hf",
         | 
| 655 | 
            +
                    hf_config=dict(org="meta-llama", name="Llama-2-7b{}-hf"),
         | 
| 656 | 
            +
                    vocab_size=32000,
         | 
| 657 | 
            +
                    padding_multiple=64,
         | 
| 658 | 
            +
                    n_layer=32,
         | 
| 659 | 
            +
                    rotary_percentage=1.0,
         | 
| 660 | 
            +
                    parallel_residual=False,
         | 
| 661 | 
            +
                    bias=False,
         | 
| 662 | 
            +
                    _norm_class="RMSNorm",
         | 
| 663 | 
            +
                    _mlp_class="LLaMAMLP",
         | 
| 664 | 
            +
                    intermediate_size=11008,
         | 
| 665 | 
            +
                ),
         | 
| 666 | 
            +
                # https://huggingface.co/meta-llama/Llama-2-13b-hf/blob/main/config.json
         | 
| 667 | 
            +
                dict(
         | 
| 668 | 
            +
                    name="Llama-2-13b{}-hf",
         | 
| 669 | 
            +
                    hf_config=dict(org="meta-llama", name="Llama-2-13b{}-hf"),
         | 
| 670 | 
            +
                    vocab_size=32000,
         | 
| 671 | 
            +
                    padding_multiple=64,
         | 
| 672 | 
            +
                    n_layer=40,
         | 
| 673 | 
            +
                    n_head=40,
         | 
| 674 | 
            +
                    n_embd=5120,
         | 
| 675 | 
            +
                    rotary_percentage=1.0,
         | 
| 676 | 
            +
                    parallel_residual=False,
         | 
| 677 | 
            +
                    bias=False,
         | 
| 678 | 
            +
                    _norm_class="RMSNorm",
         | 
| 679 | 
            +
                    _mlp_class="LLaMAMLP",
         | 
| 680 | 
            +
                    intermediate_size=13824,
         | 
| 681 | 
            +
                ),
         | 
| 682 | 
            +
                # https://huggingface.co/meta-llama/Llama-2-70b-hf/blob/main/config.json
         | 
| 683 | 
            +
                dict(
         | 
| 684 | 
            +
                    name="Llama-2-70b{}-hf",
         | 
| 685 | 
            +
                    hf_config=dict(org="meta-llama", name="Llama-2-70b{}-hf"),
         | 
| 686 | 
            +
                    vocab_size=32000,
         | 
| 687 | 
            +
                    padding_multiple=64,
         | 
| 688 | 
            +
                    n_layer=80,
         | 
| 689 | 
            +
                    n_head=64,
         | 
| 690 | 
            +
                    n_embd=8192,
         | 
| 691 | 
            +
                    n_query_groups=8,
         | 
| 692 | 
            +
                    rotary_percentage=1.0,
         | 
| 693 | 
            +
                    parallel_residual=False,
         | 
| 694 | 
            +
                    bias=False,
         | 
| 695 | 
            +
                    _norm_class="RMSNorm",
         | 
| 696 | 
            +
                    _mlp_class="LLaMAMLP",
         | 
| 697 | 
            +
                    intermediate_size=28672,
         | 
| 698 | 
            +
                ),
         | 
| 699 | 
            +
            ]
         | 
| 700 | 
            +
            for c in llama_2:
         | 
| 701 | 
            +
                for kind in ("", "-chat"):
         | 
| 702 | 
            +
                    copy = c.copy()
         | 
| 703 | 
            +
                    copy["name"] = c["name"].format(kind)
         | 
| 704 | 
            +
                    copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
         | 
| 705 | 
            +
                    configs.append(copy)
         | 
| 706 | 
            +
             | 
| 707 | 
            +
             | 
| 708 | 
            +
            ##########################
         | 
| 709 | 
            +
            # Stability AI FreeWilly2
         | 
| 710 | 
            +
            ##########################
         | 
| 711 | 
            +
            freewilly_2 = [
         | 
| 712 | 
            +
                # https://huggingface.co/stabilityai/FreeWilly2/blob/main/config.json
         | 
| 713 | 
            +
                dict(
         | 
| 714 | 
            +
                    name="FreeWilly2",
         | 
| 715 | 
            +
                    hf_config=dict(org="stabilityai", name="FreeWilly2"),
         | 
| 716 | 
            +
                    vocab_size=32000,
         | 
| 717 | 
            +
                    padding_multiple=64,
         | 
| 718 | 
            +
                    n_layer=80,
         | 
| 719 | 
            +
                    n_head=64,
         | 
| 720 | 
            +
                    n_embd=8192,
         | 
| 721 | 
            +
                    n_query_groups=8,
         | 
| 722 | 
            +
                    rotary_percentage=1.0,
         | 
| 723 | 
            +
                    parallel_residual=False,
         | 
| 724 | 
            +
                    bias=False,
         | 
| 725 | 
            +
                    _norm_class="RMSNorm",
         | 
| 726 | 
            +
                    _mlp_class="LLaMAMLP",
         | 
| 727 | 
            +
                    intermediate_size=28672,
         | 
| 728 | 
            +
                )
         | 
| 729 | 
            +
            ]
         | 
| 730 | 
            +
            configs.extend(freewilly_2)
         | 
| 731 | 
            +
             | 
| 732 | 
            +
             | 
| 733 | 
            +
            ##################
         | 
| 734 | 
            +
            # Meta Code Llama
         | 
| 735 | 
            +
            ##################
         | 
| 736 | 
            +
            code_llama = [
         | 
| 737 | 
            +
                # https://huggingface.co/codellama/CodeLlama-7b-hf/blob/main/config.json
         | 
| 738 | 
            +
                dict(
         | 
| 739 | 
            +
                    name="CodeLlama-7b-hf",
         | 
| 740 | 
            +
                    hf_config=dict(org="codellama", name="CodeLlama-7b-hf"),
         | 
| 741 | 
            +
                    block_size=16384,
         | 
| 742 | 
            +
                    vocab_size=32016,
         | 
| 743 | 
            +
                    padding_multiple=16,
         | 
| 744 | 
            +
                    n_layer=32,
         | 
| 745 | 
            +
                    rotary_percentage=1.0,
         | 
| 746 | 
            +
                    parallel_residual=False,
         | 
| 747 | 
            +
                    bias=False,
         | 
| 748 | 
            +
                    _norm_class="RMSNorm",
         | 
| 749 | 
            +
                    norm_eps=1e-05,
         | 
| 750 | 
            +
                    _mlp_class="LLaMAMLP",
         | 
| 751 | 
            +
                    intermediate_size=11008,
         | 
| 752 | 
            +
                    rope_base=1000000,
         | 
| 753 | 
            +
                ),
         | 
| 754 | 
            +
                # https://huggingface.co/codellama/CodeLlama-13b-hf/blob/main/config.json
         | 
| 755 | 
            +
                dict(
         | 
| 756 | 
            +
                    name="CodeLlama-13b-hf",
         | 
| 757 | 
            +
                    hf_config=dict(org="codellama", name="CodeLlama-13b-hf"),
         | 
| 758 | 
            +
                    block_size=16384,
         | 
| 759 | 
            +
                    vocab_size=32016,
         | 
| 760 | 
            +
                    padding_multiple=16,
         | 
| 761 | 
            +
                    n_layer=40,
         | 
| 762 | 
            +
                    n_head=40,
         | 
| 763 | 
            +
                    n_embd=5120,
         | 
| 764 | 
            +
                    rotary_percentage=1.0,
         | 
| 765 | 
            +
                    parallel_residual=False,
         | 
| 766 | 
            +
                    bias=False,
         | 
| 767 | 
            +
                    _norm_class="RMSNorm",
         | 
| 768 | 
            +
                    norm_eps=1e-05,
         | 
| 769 | 
            +
                    _mlp_class="LLaMAMLP",
         | 
| 770 | 
            +
                    intermediate_size=13824,
         | 
| 771 | 
            +
                    rope_base=1000000,
         | 
| 772 | 
            +
                ),
         | 
| 773 | 
            +
                # https://huggingface.co/codellama/CodeLlama-34b-hf/blob/main/config.json
         | 
| 774 | 
            +
                dict(
         | 
| 775 | 
            +
                    name="CodeLlama-34b-hf",
         | 
| 776 | 
            +
                    hf_config=dict(org="codellama", name="CodeLlama-34b-hf"),
         | 
| 777 | 
            +
                    block_size=16384,
         | 
| 778 | 
            +
                    vocab_size=32000,
         | 
| 779 | 
            +
                    padding_multiple=64,
         | 
| 780 | 
            +
                    n_layer=48,
         | 
| 781 | 
            +
                    n_head=64,
         | 
| 782 | 
            +
                    n_embd=8192,
         | 
| 783 | 
            +
                    n_query_groups=8,
         | 
| 784 | 
            +
                    rotary_percentage=1.0,
         | 
| 785 | 
            +
                    parallel_residual=False,
         | 
| 786 | 
            +
                    bias=False,
         | 
| 787 | 
            +
                    _norm_class="RMSNorm",
         | 
| 788 | 
            +
                    norm_eps=1e-05,
         | 
| 789 | 
            +
                    _mlp_class="LLaMAMLP",
         | 
| 790 | 
            +
                    intermediate_size=22016,
         | 
| 791 | 
            +
                    rope_base=1000000,
         | 
| 792 | 
            +
                ),
         | 
| 793 | 
            +
                # https://huggingface.co/codellama/CodeLlama-7b-Python-hf/blob/main/config.json
         | 
| 794 | 
            +
                dict(
         | 
| 795 | 
            +
                    name="CodeLlama-7b-Python-hf",
         | 
| 796 | 
            +
                    hf_config=dict(org="codellama", name="CodeLlama-7b-Python-hf"),
         | 
| 797 | 
            +
                    block_size=16384,
         | 
| 798 | 
            +
                    vocab_size=32000,
         | 
| 799 | 
            +
                    padding_multiple=64,
         | 
| 800 | 
            +
                    n_layer=32,
         | 
| 801 | 
            +
                    rotary_percentage=1.0,
         | 
| 802 | 
            +
                    parallel_residual=False,
         | 
| 803 | 
            +
                    bias=False,
         | 
| 804 | 
            +
                    _norm_class="RMSNorm",
         | 
| 805 | 
            +
                    norm_eps=1e-05,
         | 
| 806 | 
            +
                    _mlp_class="LLaMAMLP",
         | 
| 807 | 
            +
                    intermediate_size=11008,
         | 
| 808 | 
            +
                    rope_base=1000000,
         | 
| 809 | 
            +
                ),
         | 
| 810 | 
            +
                # https://huggingface.co/codellama/CodeLlama-13b-Python-hf/blob/main/config.json
         | 
| 811 | 
            +
                dict(
         | 
| 812 | 
            +
                    name="CodeLlama-13b-Python-hf",
         | 
| 813 | 
            +
                    hf_config=dict(org="codellama", name="CodeLlama-13b-Python-hf"),
         | 
| 814 | 
            +
                    block_size=16384,
         | 
| 815 | 
            +
                    vocab_size=32000,
         | 
| 816 | 
            +
                    padding_multiple=64,
         | 
| 817 | 
            +
                    n_layer=40,
         | 
| 818 | 
            +
                    n_head=40,
         | 
| 819 | 
            +
                    n_embd=5120,
         | 
| 820 | 
            +
                    rotary_percentage=1.0,
         | 
| 821 | 
            +
                    parallel_residual=False,
         | 
| 822 | 
            +
                    bias=False,
         | 
| 823 | 
            +
                    _norm_class="RMSNorm",
         | 
| 824 | 
            +
                    norm_eps=1e-05,
         | 
| 825 | 
            +
                    _mlp_class="LLaMAMLP",
         | 
| 826 | 
            +
                    intermediate_size=13824,
         | 
| 827 | 
            +
                    rope_base=1000000,
         | 
| 828 | 
            +
                ),
         | 
| 829 | 
            +
                # https://huggingface.co/codellama/CodeLlama-34b-Python-hf/blob/main/config.json
         | 
| 830 | 
            +
                dict(
         | 
| 831 | 
            +
                    name="CodeLlama-34b-Python-hf",
         | 
| 832 | 
            +
                    hf_config=dict(org="codellama", name="CodeLlama-34b-Python-hf"),
         | 
| 833 | 
            +
                    block_size=16384,
         | 
| 834 | 
            +
                    vocab_size=32000,
         | 
| 835 | 
            +
                    padding_multiple=64,
         | 
| 836 | 
            +
                    n_layer=48,
         | 
| 837 | 
            +
                    n_head=64,
         | 
| 838 | 
            +
                    n_embd=8192,
         | 
| 839 | 
            +
                    n_query_groups=8,
         | 
| 840 | 
            +
                    rotary_percentage=1.0,
         | 
| 841 | 
            +
                    parallel_residual=False,
         | 
| 842 | 
            +
                    bias=False,
         | 
| 843 | 
            +
                    _norm_class="RMSNorm",
         | 
| 844 | 
            +
                    norm_eps=1e-05,
         | 
| 845 | 
            +
                    _mlp_class="LLaMAMLP",
         | 
| 846 | 
            +
                    intermediate_size=22016,
         | 
| 847 | 
            +
                    rope_base=1000000,
         | 
| 848 | 
            +
                ),
         | 
| 849 | 
            +
                # https://huggingface.co/codellama/CodeLlama-7b-Instruct-hf/tree/main/config.json
         | 
| 850 | 
            +
                dict(
         | 
| 851 | 
            +
                    name="CodeLlama-7b-Instruct-hf",
         | 
| 852 | 
            +
                    hf_config=dict(org="codellama", name="CodeLlama-7b-Instruct-hf"),
         | 
| 853 | 
            +
                    block_size=16384,
         | 
| 854 | 
            +
                    vocab_size=32016,
         | 
| 855 | 
            +
                    padding_multiple=16,
         | 
| 856 | 
            +
                    n_layer=32,
         | 
| 857 | 
            +
                    rotary_percentage=1.0,
         | 
| 858 | 
            +
                    parallel_residual=False,
         | 
| 859 | 
            +
                    bias=False,
         | 
| 860 | 
            +
                    _norm_class="RMSNorm",
         | 
| 861 | 
            +
                    norm_eps=1e-05,
         | 
| 862 | 
            +
                    _mlp_class="LLaMAMLP",
         | 
| 863 | 
            +
                    intermediate_size=11008,
         | 
| 864 | 
            +
                    rope_base=1000000,
         | 
| 865 | 
            +
                ),
         | 
| 866 | 
            +
                # https://huggingface.co/codellama/CodeLlama-13b-Instruct-hf/blob/main/config.json
         | 
| 867 | 
            +
                dict(
         | 
| 868 | 
            +
                    name="CodeLlama-13b-Instruct-hf",
         | 
| 869 | 
            +
                    hf_config=dict(org="codellama", name="CodeLlama-13b-Instruct-hf"),
         | 
| 870 | 
            +
                    block_size=2048,
         | 
| 871 | 
            +
                    vocab_size=32016,
         | 
| 872 | 
            +
                    padding_multiple=16,
         | 
| 873 | 
            +
                    n_layer=40,
         | 
| 874 | 
            +
                    n_head=40,
         | 
| 875 | 
            +
                    n_embd=5120,
         | 
| 876 | 
            +
                    rotary_percentage=1.0,
         | 
| 877 | 
            +
                    parallel_residual=False,
         | 
| 878 | 
            +
                    bias=False,
         | 
| 879 | 
            +
                    _norm_class="RMSNorm",
         | 
| 880 | 
            +
                    norm_eps=1e-05,
         | 
| 881 | 
            +
                    _mlp_class="LLaMAMLP",
         | 
| 882 | 
            +
                    intermediate_size=13824,
         | 
| 883 | 
            +
                    rope_base=1000000,
         | 
| 884 | 
            +
                ),
         | 
| 885 | 
            +
                # https://huggingface.co/codellama/CodeLlama-34b-Instruct-hf/blob/main/config.json
         | 
| 886 | 
            +
                dict(
         | 
| 887 | 
            +
                    name="CodeLlama-34b-Instruct-hf",
         | 
| 888 | 
            +
                    hf_config=dict(org="codellama", name="CodeLlama-34b-Instruct-hf"),
         | 
| 889 | 
            +
                    block_size=16384,
         | 
| 890 | 
            +
                    vocab_size=32000,
         | 
| 891 | 
            +
                    padding_multiple=64,
         | 
| 892 | 
            +
                    n_layer=48,
         | 
| 893 | 
            +
                    n_head=64,
         | 
| 894 | 
            +
                    n_embd=8192,
         | 
| 895 | 
            +
                    n_query_groups=8,
         | 
| 896 | 
            +
                    rotary_percentage=1.0,
         | 
| 897 | 
            +
                    parallel_residual=False,
         | 
| 898 | 
            +
                    bias=False,
         | 
| 899 | 
            +
                    _norm_class="RMSNorm",
         | 
| 900 | 
            +
                    norm_eps=1e-05,
         | 
| 901 | 
            +
                    _mlp_class="LLaMAMLP",
         | 
| 902 | 
            +
                    intermediate_size=22016,
         | 
| 903 | 
            +
                    rope_base=1000000,
         | 
| 904 | 
            +
                ),
         | 
| 905 | 
            +
            ]
         | 
| 906 | 
            +
            configs.extend(code_llama)
         | 
| 907 | 
            +
             | 
| 908 | 
            +
             | 
| 909 | 
            +
            ########################
         | 
| 910 | 
            +
            # garage-bAInd Platypus
         | 
| 911 | 
            +
            ########################
         | 
| 912 | 
            +
            platypus = [
         | 
| 913 | 
            +
                # https://huggingface.co/garage-bAInd/Platypus-30B/blob/main/config.json
         | 
| 914 | 
            +
                dict(
         | 
| 915 | 
            +
                    name="Platypus-30B",
         | 
| 916 | 
            +
                    hf_config=dict(org="garage-bAInd", name="Platypus-30B"),
         | 
| 917 | 
            +
                    block_size=2048,
         | 
| 918 | 
            +
                    padded_vocab_size=32000,
         | 
| 919 | 
            +
                    n_layer=60,
         | 
| 920 | 
            +
                    n_head=52,
         | 
| 921 | 
            +
                    n_embd=6656,
         | 
| 922 | 
            +
                    rotary_percentage=1.0,
         | 
| 923 | 
            +
                    parallel_residual=False,
         | 
| 924 | 
            +
                    bias=False,
         | 
| 925 | 
            +
                    _norm_class="RMSNorm",
         | 
| 926 | 
            +
                    norm_eps=1e-06,
         | 
| 927 | 
            +
                    _mlp_class="LLaMAMLP",
         | 
| 928 | 
            +
                    intermediate_size=17920,
         | 
| 929 | 
            +
                ),
         | 
| 930 | 
            +
                # https://huggingface.co/garage-bAInd/Platypus2-7B/blob/main/config.json
         | 
| 931 | 
            +
                dict(
         | 
| 932 | 
            +
                    name="Platypus2-7B",
         | 
| 933 | 
            +
                    hf_config=dict(org="garage-bAInd", name="Platypus2-7B"),
         | 
| 934 | 
            +
                    padded_vocab_size=32000,
         | 
| 935 | 
            +
                    n_layer=32,
         | 
| 936 | 
            +
                    rotary_percentage=1.0,
         | 
| 937 | 
            +
                    parallel_residual=False,
         | 
| 938 | 
            +
                    bias=False,
         | 
| 939 | 
            +
                    _norm_class="RMSNorm",
         | 
| 940 | 
            +
                    norm_eps=1e-05,
         | 
| 941 | 
            +
                    _mlp_class="LLaMAMLP",
         | 
| 942 | 
            +
                    intermediate_size=11008,
         | 
| 943 | 
            +
                ),
         | 
| 944 | 
            +
                # https://huggingface.co/garage-bAInd/Platypus2-13B/blob/main/config.json
         | 
| 945 | 
            +
                dict(
         | 
| 946 | 
            +
                    name="Platypus2-13B",
         | 
| 947 | 
            +
                    hf_config=dict(org="garage-bAInd", name="Platypus2-13B"),
         | 
| 948 | 
            +
                    padded_vocab_size=32000,
         | 
| 949 | 
            +
                    n_layer=40,
         | 
| 950 | 
            +
                    n_head=40,
         | 
| 951 | 
            +
                    n_embd=5120,
         | 
| 952 | 
            +
                    rotary_percentage=1.0,
         | 
| 953 | 
            +
                    parallel_residual=False,
         | 
| 954 | 
            +
                    bias=False,
         | 
| 955 | 
            +
                    _norm_class="RMSNorm",
         | 
| 956 | 
            +
                    norm_eps=1e-05,
         | 
| 957 | 
            +
                    _mlp_class="LLaMAMLP",
         | 
| 958 | 
            +
                    intermediate_size=13824,
         | 
| 959 | 
            +
                ),
         | 
| 960 | 
            +
                # https://huggingface.co/garage-bAInd/Platypus2-70B/blob/main/config.json
         | 
| 961 | 
            +
                dict(
         | 
| 962 | 
            +
                    name="Platypus2-70B",
         | 
| 963 | 
            +
                    hf_config=dict(org="garage-bAInd", name="Platypus2-70B"),
         | 
| 964 | 
            +
                    padded_vocab_size=32000,
         | 
| 965 | 
            +
                    n_layer=80,
         | 
| 966 | 
            +
                    n_head=64,
         | 
| 967 | 
            +
                    n_embd=8192,
         | 
| 968 | 
            +
                    rotary_percentage=1.0,
         | 
| 969 | 
            +
                    parallel_residual=False,
         | 
| 970 | 
            +
                    bias=False,
         | 
| 971 | 
            +
                    _norm_class="RMSNorm",
         | 
| 972 | 
            +
                    _mlp_class="LLaMAMLP",
         | 
| 973 | 
            +
                    intermediate_size=28672,
         | 
| 974 | 
            +
                ),
         | 
| 975 | 
            +
                # https://huggingface.co/garage-bAInd/Camel-Platypus2-13B/blob/main/config.json
         | 
| 976 | 
            +
                dict(
         | 
| 977 | 
            +
                    name="Camel-Platypus2-13B",
         | 
| 978 | 
            +
                    hf_config=dict(org="garage-bAInd", name="Camel-Platypus2-13B"),
         | 
| 979 | 
            +
                    padded_vocab_size=32000,
         | 
| 980 | 
            +
                    n_layer=40,
         | 
| 981 | 
            +
                    n_head=40,
         | 
| 982 | 
            +
                    n_embd=5120,
         | 
| 983 | 
            +
                    rotary_percentage=1.0,
         | 
| 984 | 
            +
                    parallel_residual=False,
         | 
| 985 | 
            +
                    bias=False,
         | 
| 986 | 
            +
                    _norm_class="RMSNorm",
         | 
| 987 | 
            +
                    _mlp_class="LLaMAMLP",
         | 
| 988 | 
            +
                    intermediate_size=13824,
         | 
| 989 | 
            +
                ),
         | 
| 990 | 
            +
                # https://huggingface.co/garage-bAInd/Camel-Platypus2-70B/blob/main/config.json
         | 
| 991 | 
            +
                dict(
         | 
| 992 | 
            +
                    name="Camel-Platypus2-70B",
         | 
| 993 | 
            +
                    hf_config=dict(org="garage-bAInd", name="Camel-Platypus2-70B"),
         | 
| 994 | 
            +
                    padded_vocab_size=32000,
         | 
| 995 | 
            +
                    n_layer=80,
         | 
| 996 | 
            +
                    n_head=64,
         | 
| 997 | 
            +
                    n_embd=8192,
         | 
| 998 | 
            +
                    n_query_groups=8,
         | 
| 999 | 
            +
                    rotary_percentage=1.0,
         | 
| 1000 | 
            +
                    parallel_residual=False,
         | 
| 1001 | 
            +
                    bias=False,
         | 
| 1002 | 
            +
                    _norm_class="RMSNorm",
         | 
| 1003 | 
            +
                    _mlp_class="LLaMAMLP",
         | 
| 1004 | 
            +
                    intermediate_size=28672,
         | 
| 1005 | 
            +
                ),
         | 
| 1006 | 
            +
                # https://huggingface.co/garage-bAInd/Stable-Platypus2-13B/blob/main/config.json
         | 
| 1007 | 
            +
                dict(
         | 
| 1008 | 
            +
                    name="Stable-Platypus2-13B",
         | 
| 1009 | 
            +
                    hf_config=dict(org="garage-bAInd", name="Stable-Platypus2-13B"),
         | 
| 1010 | 
            +
                    padded_vocab_size=32000,
         | 
| 1011 | 
            +
                    n_layer=40,
         | 
| 1012 | 
            +
                    n_head=40,
         | 
| 1013 | 
            +
                    n_embd=5120,
         | 
| 1014 | 
            +
                    rotary_percentage=1.0,
         | 
| 1015 | 
            +
                    parallel_residual=False,
         | 
| 1016 | 
            +
                    bias=False,
         | 
| 1017 | 
            +
                    _norm_class="RMSNorm",
         | 
| 1018 | 
            +
                    _mlp_class="LLaMAMLP",
         | 
| 1019 | 
            +
                    intermediate_size=13824,
         | 
| 1020 | 
            +
                ),
         | 
| 1021 | 
            +
                # https://huggingface.co/garage-bAInd/Platypus2-70B-instruct/blob/main/config.json
         | 
| 1022 | 
            +
                dict(
         | 
| 1023 | 
            +
                    name="Platypus2-70B-instruct",
         | 
| 1024 | 
            +
                    hf_config=dict(org="garage-bAInd", name="Platypus2-70B-instruct"),
         | 
| 1025 | 
            +
                    padded_vocab_size=32000,
         | 
| 1026 | 
            +
                    n_layer=80,
         | 
| 1027 | 
            +
                    n_head=64,
         | 
| 1028 | 
            +
                    n_embd=8192,
         | 
| 1029 | 
            +
                    n_query_groups=8,
         | 
| 1030 | 
            +
                    rotary_percentage=1.0,
         | 
| 1031 | 
            +
                    parallel_residual=False,
         | 
| 1032 | 
            +
                    bias=False,
         | 
| 1033 | 
            +
                    _norm_class="RMSNorm",
         | 
| 1034 | 
            +
                    _mlp_class="LLaMAMLP",
         | 
| 1035 | 
            +
                    intermediate_size=28672,
         | 
| 1036 | 
            +
                ),
         | 
| 1037 | 
            +
            ]
         | 
| 1038 | 
            +
            configs.extend(platypus)
         | 
| 1039 | 
            +
             | 
| 1040 | 
            +
             | 
| 1041 | 
            +
            ##########################
         | 
| 1042 | 
            +
            # Stability AI StableCode
         | 
| 1043 | 
            +
            ##########################
         | 
| 1044 | 
            +
            stablecode = [
         | 
| 1045 | 
            +
                # https://huggingface.co/stabilityai/stablecode-completion-alpha-3b/blob/main/config.json
         | 
| 1046 | 
            +
                dict(
         | 
| 1047 | 
            +
                    name="stablecode-completion-alpha-3b",
         | 
| 1048 | 
            +
                    hf_config=dict(org="stabilityai", name="stablecode-completion-alpha-3b"),
         | 
| 1049 | 
            +
                    block_size=16384,
         | 
| 1050 | 
            +
                    vocab_size=49152,
         | 
| 1051 | 
            +
                    n_layer=32,
         | 
| 1052 | 
            +
                    n_embd=2560,
         | 
| 1053 | 
            +
                ),
         | 
| 1054 | 
            +
                # https://huggingface.co/stabilityai/stablecode-completion-alpha-3b-4k/blob/main/config.json
         | 
| 1055 | 
            +
                dict(
         | 
| 1056 | 
            +
                    name="stablecode-completion-alpha-3b-4k",
         | 
| 1057 | 
            +
                    hf_config=dict(org="stabilityai", name="stablecode-completion-alpha-3b-4k"),
         | 
| 1058 | 
            +
                    vocab_size=49152,
         | 
| 1059 | 
            +
                    n_layer=32,
         | 
| 1060 | 
            +
                    n_embd=2560,
         | 
| 1061 | 
            +
                ),
         | 
| 1062 | 
            +
                # https://huggingface.co/stabilityai/stablecode-instruct-alpha-3b/blob/main/config.json
         | 
| 1063 | 
            +
                dict(
         | 
| 1064 | 
            +
                    name="stablecode-instruct-alpha-3b",
         | 
| 1065 | 
            +
                    hf_config=dict(org="stabilityai", name="stablecode-instruct-alpha-3b"),
         | 
| 1066 | 
            +
                    vocab_size=49152,
         | 
| 1067 | 
            +
                    n_layer=32,
         | 
| 1068 | 
            +
                    n_embd=2560,
         | 
| 1069 | 
            +
                ),
         | 
| 1070 | 
            +
            ]
         | 
| 1071 | 
            +
            configs.extend(stablecode)
         | 
| 1072 | 
            +
             | 
| 1073 | 
            +
             | 
| 1074 | 
            +
            ##################################
         | 
| 1075 | 
            +
            # togethercomputer LLaMA-2-7B-32K
         | 
| 1076 | 
            +
            ##################################
         | 
| 1077 | 
            +
            together_llama2_32k = [
         | 
| 1078 | 
            +
                # https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/blob/main/config.json
         | 
| 1079 | 
            +
                dict(
         | 
| 1080 | 
            +
                    name="LLaMA-2-7B-32K",
         | 
| 1081 | 
            +
                    hf_config=dict(org="togethercomputer", name="LLaMA-2-7B-32K"),
         | 
| 1082 | 
            +
                    vocab_size=32000,
         | 
| 1083 | 
            +
                    padding_multiple=64,
         | 
| 1084 | 
            +
                    n_layer=32,
         | 
| 1085 | 
            +
                    rotary_percentage=1.0,
         | 
| 1086 | 
            +
                    parallel_residual=False,
         | 
| 1087 | 
            +
                    bias=False,
         | 
| 1088 | 
            +
                    _norm_class="RMSNorm",
         | 
| 1089 | 
            +
                    _mlp_class="LLaMAMLP",
         | 
| 1090 | 
            +
                    intermediate_size=11008,
         | 
| 1091 | 
            +
                    rope_condense_ratio=8,
         | 
| 1092 | 
            +
                )
         | 
| 1093 | 
            +
            ]
         | 
| 1094 | 
            +
            configs.extend(together_llama2_32k)
         | 
| 1095 | 
            +
             | 
| 1096 | 
            +
             | 
| 1097 | 
            +
            ################
         | 
| 1098 | 
            +
            # Microsoft Phi
         | 
| 1099 | 
            +
            ################
         | 
| 1100 | 
            +
            phi = [
         | 
| 1101 | 
            +
                # https://huggingface.co/microsoft/phi-1_5/blob/main/config.json
         | 
| 1102 | 
            +
                dict(
         | 
| 1103 | 
            +
                    name="phi-1_5",
         | 
| 1104 | 
            +
                    hf_config=dict(org="microsoft", name="phi-1_5"),
         | 
| 1105 | 
            +
                    vocab_size=50257,
         | 
| 1106 | 
            +
                    padded_vocab_size=51200,
         | 
| 1107 | 
            +
                    block_size=2048,
         | 
| 1108 | 
            +
                    n_embd=2048,
         | 
| 1109 | 
            +
                    n_layer=24,
         | 
| 1110 | 
            +
                    rotary_percentage=0.5,  # 32 / (n_embd / n_head) = 32 / 64
         | 
| 1111 | 
            +
                    shared_attention_norm=True,
         | 
| 1112 | 
            +
                    lm_head_bias=True,
         | 
| 1113 | 
            +
                    gelu_approximate="tanh",
         | 
| 1114 | 
            +
                )
         | 
| 1115 | 
            +
            ]
         | 
| 1116 | 
            +
            configs.extend(phi)
         | 
| 1117 | 
            +
             | 
| 1118 | 
            +
             | 
| 1119 | 
            +
            #############
         | 
| 1120 | 
            +
            # Mistral AI
         | 
| 1121 | 
            +
            #############
         | 
| 1122 | 
            +
            mistral = [
         | 
| 1123 | 
            +
                # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json
         | 
| 1124 | 
            +
                dict(
         | 
| 1125 | 
            +
                    name="Mistral-7B-{}v0.1",
         | 
| 1126 | 
            +
                    hf_config=dict(org="mistralai", name="Mistral-7B-{}v0.1"),
         | 
| 1127 | 
            +
                    padded_vocab_size=32000,
         | 
| 1128 | 
            +
                    block_size=4096,  # should be 32768 but sliding window attention is not implemented
         | 
| 1129 | 
            +
                    n_layer=32,
         | 
| 1130 | 
            +
                    n_query_groups=8,
         | 
| 1131 | 
            +
                    rotary_percentage=1.0,
         | 
| 1132 | 
            +
                    parallel_residual=False,
         | 
| 1133 | 
            +
                    bias=False,
         | 
| 1134 | 
            +
                    _norm_class="RMSNorm",
         | 
| 1135 | 
            +
                    norm_eps=1e-05,
         | 
| 1136 | 
            +
                    _mlp_class="LLaMAMLP",
         | 
| 1137 | 
            +
                    intermediate_size=14336,
         | 
| 1138 | 
            +
                )
         | 
| 1139 | 
            +
            ]
         | 
| 1140 | 
            +
            for c in mistral:
         | 
| 1141 | 
            +
                for kind in ("", "Instruct-"):
         | 
| 1142 | 
            +
                    copy = c.copy()
         | 
| 1143 | 
            +
                    copy["name"] = c["name"].format(kind)
         | 
| 1144 | 
            +
                    copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
         | 
| 1145 | 
            +
                    configs.append(copy)
         | 
| 1146 | 
            +
             | 
| 1147 | 
            +
             | 
| 1148 | 
            +
            ############
         | 
| 1149 | 
            +
            # TinyLlama
         | 
| 1150 | 
            +
            ############
         | 
| 1151 | 
            +
            tiny_llama = [
         | 
| 1152 | 
            +
                dict(
         | 
| 1153 | 
            +
                    name="tiny-llama-1.1b",
         | 
| 1154 | 
            +
                    hf_config=dict(org="PY007", name="TinyLlama-1.1B-intermediate-step-480k-1T"),
         | 
| 1155 | 
            +
                    block_size=2048,
         | 
| 1156 | 
            +
                    vocab_size=32000,
         | 
| 1157 | 
            +
                    padding_multiple=64,
         | 
| 1158 | 
            +
                    n_layer=22,
         | 
| 1159 | 
            +
                    n_head=32,
         | 
| 1160 | 
            +
                    n_embd=2048,
         | 
| 1161 | 
            +
                    rotary_percentage=1.0,
         | 
| 1162 | 
            +
                    parallel_residual=False,
         | 
| 1163 | 
            +
                    bias=False,
         | 
| 1164 | 
            +
                    _norm_class="RMSNorm",  # original TinyLlama uses FusedRMSNorm
         | 
| 1165 | 
            +
                    norm_eps=1e-5,
         | 
| 1166 | 
            +
                    _mlp_class="LLaMAMLP",
         | 
| 1167 | 
            +
                    intermediate_size=5632,
         | 
| 1168 | 
            +
                    n_query_groups=4,
         | 
| 1169 | 
            +
                ),
         | 
| 1170 | 
            +
                dict(
         | 
| 1171 | 
            +
                    name="tiny-llama-new",
         | 
| 1172 | 
            +
                    hf_config=dict(org="PY007", name="TinyLlama-1.1B-intermediate-step-480k-1T"),
         | 
| 1173 | 
            +
                    block_size=768,
         | 
| 1174 | 
            +
                    vocab_size=32000,
         | 
| 1175 | 
            +
                    padding_multiple=64,
         | 
| 1176 | 
            +
                    n_layer=18,
         | 
| 1177 | 
            +
                    n_head=32,
         | 
| 1178 | 
            +
                    n_embd=1024,
         | 
| 1179 | 
            +
                    rotary_percentage=1.0,
         | 
| 1180 | 
            +
                    parallel_residual=False,
         | 
| 1181 | 
            +
                    bias=False,
         | 
| 1182 | 
            +
                    _norm_class="RMSNorm",  # original TinyLlama uses FusedRMSNorm
         | 
| 1183 | 
            +
                    norm_eps=1e-5,
         | 
| 1184 | 
            +
                    _mlp_class="LLaMAMLP",
         | 
| 1185 | 
            +
                    intermediate_size=5632,
         | 
| 1186 | 
            +
                    n_query_groups=4,
         | 
| 1187 | 
            +
                ),
         | 
| 1188 | 
            +
            ]
         | 
| 1189 | 
            +
            configs.extend(tiny_llama)
         | 
| 1190 | 
            +
             | 
| 1191 | 
            +
             | 
| 1192 | 
            +
            name_to_config = {config["name"]: config for config in configs}
         | 
    	
        tsai_gpt/model.py
    ADDED
    
    | @@ -0,0 +1,367 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            """Full definition of a GPT NeoX Language Model, all of it in this single file.
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT and
         | 
| 4 | 
            +
            https://github.com/EleutherAI/gpt-neox/tree/main/megatron/model.
         | 
| 5 | 
            +
            """
         | 
| 6 | 
            +
            import math
         | 
| 7 | 
            +
            from typing import Any, Optional, Tuple
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            import torch
         | 
| 10 | 
            +
            import torch.nn as nn
         | 
| 11 | 
            +
            from typing_extensions import Self
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            from tsai_gpt.config import Config
         | 
| 14 | 
            +
             | 
| 15 | 
            +
             | 
| 16 | 
            +
            class GPT(nn.Module):
         | 
| 17 | 
            +
                def __init__(self, config: Config) -> None:
         | 
| 18 | 
            +
                    super().__init__()
         | 
| 19 | 
            +
                    assert config.padded_vocab_size is not None
         | 
| 20 | 
            +
                    self.config = config
         | 
| 21 | 
            +
             | 
| 22 | 
            +
                    self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias)
         | 
| 23 | 
            +
                    self.transformer = nn.ModuleDict(
         | 
| 24 | 
            +
                        dict(
         | 
| 25 | 
            +
                            wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
         | 
| 26 | 
            +
                            h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),
         | 
| 27 | 
            +
                            ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
         | 
| 28 | 
            +
                        )
         | 
| 29 | 
            +
                    )
         | 
| 30 | 
            +
                    self.max_seq_length = self.config.block_size
         | 
| 31 | 
            +
                    self.mask_cache: Optional[torch.Tensor] = None
         | 
| 32 | 
            +
             | 
| 33 | 
            +
                @property
         | 
| 34 | 
            +
                def max_seq_length(self) -> int:
         | 
| 35 | 
            +
                    return self._max_seq_length
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                @max_seq_length.setter
         | 
| 38 | 
            +
                def max_seq_length(self, value: int) -> None:
         | 
| 39 | 
            +
                    """
         | 
| 40 | 
            +
                    When doing inference, the sequences used might be shorter than the model's context length.
         | 
| 41 | 
            +
                    This allows setting a smaller number to avoid allocating unused memory
         | 
| 42 | 
            +
                    """
         | 
| 43 | 
            +
                    if value > self.config.block_size:
         | 
| 44 | 
            +
                        raise ValueError(
         | 
| 45 | 
            +
                            f"Cannot attend to {value}, block size is only {self.config.block_size}"
         | 
| 46 | 
            +
                        )
         | 
| 47 | 
            +
                    self._max_seq_length = value
         | 
| 48 | 
            +
                    if not hasattr(self, "cos"):
         | 
| 49 | 
            +
                        # first call
         | 
| 50 | 
            +
                        cos, sin = self.rope_cache()
         | 
| 51 | 
            +
                        self.register_buffer("cos", cos, persistent=False)
         | 
| 52 | 
            +
                        self.register_buffer("sin", sin, persistent=False)
         | 
| 53 | 
            +
                    elif value != self.cos.size(0):
         | 
| 54 | 
            +
                        # override
         | 
| 55 | 
            +
                        self.cos, self.sin = self.rope_cache(device=self.cos.device)
         | 
| 56 | 
            +
                    # the mask and kv cache size will get updated on `set_kv_cache`. we cannot update it here because we don't know
         | 
| 57 | 
            +
                    # if the kv cache is expected
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                def reset_parameters(self) -> None:
         | 
| 60 | 
            +
                    # Trigger resetting the rope-cache
         | 
| 61 | 
            +
                    self.max_seq_length = self.config.block_size
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                def _init_weights(self, module: nn.Module) -> None:
         | 
| 64 | 
            +
                    """Meant to be used with `gpt.apply(gpt._init_weights)`."""
         | 
| 65 | 
            +
                    if isinstance(module, nn.Linear):
         | 
| 66 | 
            +
                        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
         | 
| 67 | 
            +
                        if module.bias is not None:
         | 
| 68 | 
            +
                            torch.nn.init.zeros_(module.bias)
         | 
| 69 | 
            +
                    elif isinstance(module, nn.Embedding):
         | 
| 70 | 
            +
                        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
         | 
| 71 | 
            +
             | 
| 72 | 
            +
                def forward(self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None) -> torch.Tensor:
         | 
| 73 | 
            +
                    T = idx.size(1)
         | 
| 74 | 
            +
                    if self.max_seq_length < T:
         | 
| 75 | 
            +
                        raise ValueError(
         | 
| 76 | 
            +
                            f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}."
         | 
| 77 | 
            +
                        )
         | 
| 78 | 
            +
             | 
| 79 | 
            +
                    if input_pos is not None:  # use the kv cache
         | 
| 80 | 
            +
                        cos = self.cos.index_select(0, input_pos)
         | 
| 81 | 
            +
                        sin = self.sin.index_select(0, input_pos)
         | 
| 82 | 
            +
                        if self.mask_cache is None:
         | 
| 83 | 
            +
                            raise TypeError("You need to call `gpt.set_kv_cache()`")
         | 
| 84 | 
            +
                        mask = self.mask_cache.index_select(2, input_pos)
         | 
| 85 | 
            +
                    else:
         | 
| 86 | 
            +
                        cos = self.cos[:T]
         | 
| 87 | 
            +
                        sin = self.sin[:T]
         | 
| 88 | 
            +
                        mask = None
         | 
| 89 | 
            +
             | 
| 90 | 
            +
                    x = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)
         | 
| 91 | 
            +
                    for block in self.transformer.h:
         | 
| 92 | 
            +
                        x = block(x, cos, sin, mask, input_pos)
         | 
| 93 | 
            +
                    x = self.transformer.ln_f(x)
         | 
| 94 | 
            +
                    return self.lm_head(x)  # (b, t, vocab_size)
         | 
| 95 | 
            +
             | 
| 96 | 
            +
                @classmethod
         | 
| 97 | 
            +
                def from_name(cls, name: str, **kwargs: Any) -> Self:
         | 
| 98 | 
            +
                    return cls(Config.from_name(name, **kwargs))
         | 
| 99 | 
            +
             | 
| 100 | 
            +
                def rope_cache(
         | 
| 101 | 
            +
                    self, device: Optional[torch.device] = None
         | 
| 102 | 
            +
                ) -> Tuple[torch.Tensor, torch.Tensor]:
         | 
| 103 | 
            +
                    return build_rope_cache(
         | 
| 104 | 
            +
                        seq_len=self.max_seq_length,
         | 
| 105 | 
            +
                        n_elem=self.config.rope_n_elem,
         | 
| 106 | 
            +
                        device=device,
         | 
| 107 | 
            +
                        condense_ratio=self.config.rope_condense_ratio,
         | 
| 108 | 
            +
                        base=self.config.rope_base,
         | 
| 109 | 
            +
                    )
         | 
| 110 | 
            +
             | 
| 111 | 
            +
                def set_kv_cache(
         | 
| 112 | 
            +
                    self,
         | 
| 113 | 
            +
                    batch_size: int,
         | 
| 114 | 
            +
                    rope_cache_length: Optional[int] = None,
         | 
| 115 | 
            +
                    device: Optional[torch.device] = None,
         | 
| 116 | 
            +
                    dtype: Optional[torch.dtype] = None,
         | 
| 117 | 
            +
                ) -> None:
         | 
| 118 | 
            +
                    if rope_cache_length is None:
         | 
| 119 | 
            +
                        rope_cache_length = self.cos.size(-1)
         | 
| 120 | 
            +
                    max_seq_length = self.max_seq_length
         | 
| 121 | 
            +
             | 
| 122 | 
            +
                    # initialize the kv cache for all blocks
         | 
| 123 | 
            +
                    for block in self.transformer.h:
         | 
| 124 | 
            +
                        block.attn.kv_cache = block.attn.build_kv_cache(
         | 
| 125 | 
            +
                            batch_size, max_seq_length, rope_cache_length, device, dtype
         | 
| 126 | 
            +
                        )
         | 
| 127 | 
            +
             | 
| 128 | 
            +
                    if self.mask_cache is None or self.mask_cache.size(3) != max_seq_length:
         | 
| 129 | 
            +
                        # passing `attn_mask` to SDPA downgrades it to use the inefficient implementation. since we only need the mask
         | 
| 130 | 
            +
                        # for the kv-cache support (only during inference), we only create it in that situation
         | 
| 131 | 
            +
                        # this will be resolved by https://github.com/pytorch/pytorch/issues/96099
         | 
| 132 | 
            +
                        ones = torch.ones((max_seq_length, max_seq_length), device=device, dtype=torch.bool)
         | 
| 133 | 
            +
                        self.mask_cache = torch.tril(ones).unsqueeze(0).unsqueeze(0)
         | 
| 134 | 
            +
             | 
| 135 | 
            +
                def clear_kv_cache(self) -> None:
         | 
| 136 | 
            +
                    self.mask_cache = None
         | 
| 137 | 
            +
                    for block in self.transformer.h:
         | 
| 138 | 
            +
                        block.attn.kv_cache = None
         | 
| 139 | 
            +
             | 
| 140 | 
            +
             | 
| 141 | 
            +
            class Block(nn.Module):
         | 
| 142 | 
            +
                def __init__(self, config: Config) -> None:
         | 
| 143 | 
            +
                    super().__init__()
         | 
| 144 | 
            +
                    self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)
         | 
| 145 | 
            +
                    self.attn = CausalSelfAttention(config)
         | 
| 146 | 
            +
                    self.norm_2 = (
         | 
| 147 | 
            +
                        None
         | 
| 148 | 
            +
                        if config.shared_attention_norm
         | 
| 149 | 
            +
                        else config.norm_class(config.n_embd, eps=config.norm_eps)
         | 
| 150 | 
            +
                    )
         | 
| 151 | 
            +
                    self.mlp = config.mlp_class(config)
         | 
| 152 | 
            +
             | 
| 153 | 
            +
                    self.config = config
         | 
| 154 | 
            +
             | 
| 155 | 
            +
                def forward(
         | 
| 156 | 
            +
                    self,
         | 
| 157 | 
            +
                    x: torch.Tensor,
         | 
| 158 | 
            +
                    cos: torch.Tensor,
         | 
| 159 | 
            +
                    sin: torch.Tensor,
         | 
| 160 | 
            +
                    mask: Optional[torch.Tensor] = None,
         | 
| 161 | 
            +
                    input_pos: Optional[torch.Tensor] = None,
         | 
| 162 | 
            +
                ) -> torch.Tensor:
         | 
| 163 | 
            +
                    n_1 = self.norm_1(x)
         | 
| 164 | 
            +
                    h = self.attn(n_1, cos, sin, mask, input_pos)
         | 
| 165 | 
            +
                    if self.config.parallel_residual:
         | 
| 166 | 
            +
                        n_2 = n_1 if self.config.shared_attention_norm else self.norm_2(x)
         | 
| 167 | 
            +
                        x = self.mlp(n_2) + h + x
         | 
| 168 | 
            +
                    else:
         | 
| 169 | 
            +
                        if self.config.shared_attention_norm:
         | 
| 170 | 
            +
                            raise NotImplementedError(
         | 
| 171 | 
            +
                                "No checkpoint amongst the ones we support uses this configuration"
         | 
| 172 | 
            +
                                " (non-parallel residual and shared attention norm)."
         | 
| 173 | 
            +
                            )
         | 
| 174 | 
            +
                        x = h + x
         | 
| 175 | 
            +
                        x = self.mlp(self.norm_2(x)) + x
         | 
| 176 | 
            +
                    return x
         | 
| 177 | 
            +
             | 
| 178 | 
            +
             | 
| 179 | 
            +
            class CausalSelfAttention(nn.Module):
         | 
| 180 | 
            +
                def __init__(self, config: Config) -> None:
         | 
| 181 | 
            +
                    super().__init__()
         | 
| 182 | 
            +
                    shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
         | 
| 183 | 
            +
                    # key, query, value projections for all heads, but in a batch
         | 
| 184 | 
            +
                    self.attn = nn.Linear(config.n_embd, shape, bias=config.bias)
         | 
| 185 | 
            +
                    # output projection
         | 
| 186 | 
            +
                    self.proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
         | 
| 187 | 
            +
                    # disabled by default
         | 
| 188 | 
            +
                    self.kv_cache: Optional[KVCache] = None
         | 
| 189 | 
            +
             | 
| 190 | 
            +
                    self.config = config
         | 
| 191 | 
            +
             | 
| 192 | 
            +
                def forward(
         | 
| 193 | 
            +
                    self,
         | 
| 194 | 
            +
                    x: torch.Tensor,
         | 
| 195 | 
            +
                    cos: torch.Tensor,
         | 
| 196 | 
            +
                    sin: torch.Tensor,
         | 
| 197 | 
            +
                    mask: Optional[torch.Tensor] = None,
         | 
| 198 | 
            +
                    input_pos: Optional[torch.Tensor] = None,
         | 
| 199 | 
            +
                ) -> torch.Tensor:
         | 
| 200 | 
            +
                    B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
         | 
| 201 | 
            +
             | 
| 202 | 
            +
                    qkv = self.attn(x)
         | 
| 203 | 
            +
             | 
| 204 | 
            +
                    # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`)
         | 
| 205 | 
            +
                    q_per_kv = self.config.n_head // self.config.n_query_groups
         | 
| 206 | 
            +
                    total_qkv = q_per_kv + 2  # each group has 1+ queries, 1 key, and 1 value
         | 
| 207 | 
            +
                    qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size)
         | 
| 208 | 
            +
                    qkv = qkv.permute(0, 2, 3, 1, 4)  # (B, n_query_groups, total_qkv, T, hs)
         | 
| 209 | 
            +
             | 
| 210 | 
            +
                    # split batched computation into three
         | 
| 211 | 
            +
                    q, k, v = qkv.split((q_per_kv, 1, 1), dim=2)
         | 
| 212 | 
            +
             | 
| 213 | 
            +
                    # maybe repeat k and v if for the non multi-head attention cases
         | 
| 214 | 
            +
                    # training: flash attention requires it
         | 
| 215 | 
            +
                    # inference: multi-query would require a full kv cache so avoid it to limit its memory usage
         | 
| 216 | 
            +
                    if self.config.n_query_groups != self.config.n_head and (
         | 
| 217 | 
            +
                        input_pos is None or self.config.n_query_groups != 1
         | 
| 218 | 
            +
                    ):
         | 
| 219 | 
            +
                        k = k.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size)
         | 
| 220 | 
            +
                        v = v.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size)
         | 
| 221 | 
            +
             | 
| 222 | 
            +
                    q = q.reshape(B, -1, T, self.config.head_size)  # (B, nh_q, T, hs)
         | 
| 223 | 
            +
                    k = k.reshape(B, -1, T, self.config.head_size)  # (B, nh_k, T, hs)
         | 
| 224 | 
            +
                    v = v.reshape(B, -1, T, self.config.head_size)  # (B, nh_v, T, hs)
         | 
| 225 | 
            +
             | 
| 226 | 
            +
                    q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin)
         | 
| 227 | 
            +
                    k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin)
         | 
| 228 | 
            +
                    q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1)
         | 
| 229 | 
            +
                    k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1)
         | 
| 230 | 
            +
             | 
| 231 | 
            +
                    if input_pos is not None:
         | 
| 232 | 
            +
                        if not isinstance(self.kv_cache, KVCache):
         | 
| 233 | 
            +
                            raise TypeError("You need to call `gpt.set_kv_cache()`")
         | 
| 234 | 
            +
                        k, v = self.kv_cache(input_pos, k, v)
         | 
| 235 | 
            +
             | 
| 236 | 
            +
                    y = self.scaled_dot_product_attention(q, k, v, mask)
         | 
| 237 | 
            +
             | 
| 238 | 
            +
                    y = y.reshape(B, T, C)  # re-assemble all head outputs side by side
         | 
| 239 | 
            +
             | 
| 240 | 
            +
                    # output projection
         | 
| 241 | 
            +
                    return self.proj(y)
         | 
| 242 | 
            +
             | 
| 243 | 
            +
                def scaled_dot_product_attention(
         | 
| 244 | 
            +
                    self,
         | 
| 245 | 
            +
                    q: torch.Tensor,
         | 
| 246 | 
            +
                    k: torch.Tensor,
         | 
| 247 | 
            +
                    v: torch.Tensor,
         | 
| 248 | 
            +
                    mask: Optional[torch.Tensor] = None,
         | 
| 249 | 
            +
                ) -> torch.Tensor:
         | 
| 250 | 
            +
                    scale = 1.0 / math.sqrt(self.config.head_size)
         | 
| 251 | 
            +
                    y = torch.nn.functional.scaled_dot_product_attention(
         | 
| 252 | 
            +
                        q, k, v, attn_mask=mask, dropout_p=0.0, scale=scale, is_causal=mask is None
         | 
| 253 | 
            +
                    )
         | 
| 254 | 
            +
                    return y.transpose(1, 2)
         | 
| 255 | 
            +
             | 
| 256 | 
            +
                def build_kv_cache(
         | 
| 257 | 
            +
                    self,
         | 
| 258 | 
            +
                    batch_size: int,
         | 
| 259 | 
            +
                    max_seq_length: int,
         | 
| 260 | 
            +
                    rope_cache_length: Optional[int] = None,
         | 
| 261 | 
            +
                    device: Optional[torch.device] = None,
         | 
| 262 | 
            +
                    dtype: Optional[torch.dtype] = None,
         | 
| 263 | 
            +
                ) -> "KVCache":
         | 
| 264 | 
            +
                    heads = 1 if self.config.n_query_groups == 1 else self.config.n_head
         | 
| 265 | 
            +
                    v_shape = (batch_size, heads, max_seq_length, self.config.head_size)
         | 
| 266 | 
            +
                    if rope_cache_length is None:
         | 
| 267 | 
            +
                        if self.config.rotary_percentage != 1.0:
         | 
| 268 | 
            +
                            raise TypeError("Please pass the `rope_cache_length=gpt.cos.size(-1)` value")
         | 
| 269 | 
            +
                        k_shape = v_shape
         | 
| 270 | 
            +
                    else:
         | 
| 271 | 
            +
                        k_shape = (
         | 
| 272 | 
            +
                            batch_size,
         | 
| 273 | 
            +
                            heads,
         | 
| 274 | 
            +
                            max_seq_length,
         | 
| 275 | 
            +
                            rope_cache_length + self.config.head_size - self.config.rope_n_elem,
         | 
| 276 | 
            +
                        )
         | 
| 277 | 
            +
                    return KVCache(k_shape, v_shape, device=device, dtype=dtype)
         | 
| 278 | 
            +
             | 
| 279 | 
            +
             | 
| 280 | 
            +
            class GptNeoxMLP(nn.Module):
         | 
| 281 | 
            +
                def __init__(self, config: Config) -> None:
         | 
| 282 | 
            +
                    super().__init__()
         | 
| 283 | 
            +
                    self.fc = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
         | 
| 284 | 
            +
                    self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
         | 
| 285 | 
            +
             | 
| 286 | 
            +
                    self.config = config
         | 
| 287 | 
            +
             | 
| 288 | 
            +
                def forward(self, x: torch.Tensor) -> torch.Tensor:
         | 
| 289 | 
            +
                    x = self.fc(x)
         | 
| 290 | 
            +
                    x = torch.nn.functional.gelu(x, approximate=self.config.gelu_approximate)
         | 
| 291 | 
            +
                    return self.proj(x)
         | 
| 292 | 
            +
             | 
| 293 | 
            +
             | 
| 294 | 
            +
            class LLaMAMLP(nn.Module):
         | 
| 295 | 
            +
                def __init__(self, config: Config) -> None:
         | 
| 296 | 
            +
                    super().__init__()
         | 
| 297 | 
            +
                    self.fc_1 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
         | 
| 298 | 
            +
                    self.fc_2 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
         | 
| 299 | 
            +
                    self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
         | 
| 300 | 
            +
             | 
| 301 | 
            +
                def forward(self, x: torch.Tensor) -> torch.Tensor:
         | 
| 302 | 
            +
                    x_fc_1 = self.fc_1(x)
         | 
| 303 | 
            +
                    x_fc_2 = self.fc_2(x)
         | 
| 304 | 
            +
                    x = torch.nn.functional.silu(x_fc_1) * x_fc_2
         | 
| 305 | 
            +
                    return self.proj(x)
         | 
| 306 | 
            +
             | 
| 307 | 
            +
             | 
| 308 | 
            +
            def build_rope_cache(
         | 
| 309 | 
            +
                seq_len: int,
         | 
| 310 | 
            +
                n_elem: int,
         | 
| 311 | 
            +
                device: Optional[torch.device] = None,
         | 
| 312 | 
            +
                base: int = 10000,
         | 
| 313 | 
            +
                condense_ratio: int = 1,
         | 
| 314 | 
            +
            ) -> Tuple[torch.Tensor, torch.Tensor]:
         | 
| 315 | 
            +
                """Enhanced Transformer with Rotary Position Embedding.
         | 
| 316 | 
            +
             | 
| 317 | 
            +
                Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
         | 
| 318 | 
            +
                transformers/rope/__init__.py. MIT License:
         | 
| 319 | 
            +
                https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
         | 
| 320 | 
            +
                """
         | 
| 321 | 
            +
                # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
         | 
| 322 | 
            +
                theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem))
         | 
| 323 | 
            +
             | 
| 324 | 
            +
                # Create position indexes `[0, 1, ..., seq_len - 1]`
         | 
| 325 | 
            +
                seq_idx = torch.arange(seq_len, device=device) / condense_ratio
         | 
| 326 | 
            +
             | 
| 327 | 
            +
                # Calculate the product of position index and $\theta_i$
         | 
| 328 | 
            +
                idx_theta = torch.outer(seq_idx, theta).repeat(1, 2)
         | 
| 329 | 
            +
             | 
| 330 | 
            +
                return torch.cos(idx_theta), torch.sin(idx_theta)
         | 
| 331 | 
            +
             | 
| 332 | 
            +
             | 
| 333 | 
            +
            def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
         | 
| 334 | 
            +
                head_size = x.size(-1)
         | 
| 335 | 
            +
                x1 = x[..., : head_size // 2]  # (B, nh, T, hs/2)
         | 
| 336 | 
            +
                x2 = x[..., head_size // 2 :]  # (B, nh, T, hs/2)
         | 
| 337 | 
            +
                rotated = torch.cat((-x2, x1), dim=-1)  # (B, nh, T, hs)
         | 
| 338 | 
            +
                roped = (x * cos) + (rotated * sin)
         | 
| 339 | 
            +
                return roped.type_as(x)
         | 
| 340 | 
            +
             | 
| 341 | 
            +
             | 
| 342 | 
            +
            class KVCache(nn.Module):
         | 
| 343 | 
            +
                def __init__(
         | 
| 344 | 
            +
                    self,
         | 
| 345 | 
            +
                    k_shape: Tuple[int, int, int, int],
         | 
| 346 | 
            +
                    v_shape: Tuple[int, int, int, int],
         | 
| 347 | 
            +
                    device: Optional[torch.device] = None,
         | 
| 348 | 
            +
                    dtype: Optional[torch.dtype] = None,
         | 
| 349 | 
            +
                ) -> None:
         | 
| 350 | 
            +
                    super().__init__()
         | 
| 351 | 
            +
                    self.register_buffer(
         | 
| 352 | 
            +
                        "k", torch.zeros(k_shape, device=device, dtype=dtype), persistent=False
         | 
| 353 | 
            +
                    )
         | 
| 354 | 
            +
                    self.register_buffer(
         | 
| 355 | 
            +
                        "v", torch.zeros(v_shape, device=device, dtype=dtype), persistent=False
         | 
| 356 | 
            +
                    )
         | 
| 357 | 
            +
             | 
| 358 | 
            +
                def forward(
         | 
| 359 | 
            +
                    self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor
         | 
| 360 | 
            +
                ) -> Tuple[torch.Tensor, torch.Tensor]:
         | 
| 361 | 
            +
                    # move the buffer to the activation dtype for when AMP is used
         | 
| 362 | 
            +
                    self.k = self.k.to(k.dtype)
         | 
| 363 | 
            +
                    self.v = self.v.to(v.dtype)
         | 
| 364 | 
            +
                    # update the cache
         | 
| 365 | 
            +
                    k = self.k.index_copy_(2, input_pos, k)
         | 
| 366 | 
            +
                    v = self.v.index_copy_(2, input_pos, v)
         | 
| 367 | 
            +
                    return k, v
         | 
    	
        tsai_gpt/packed_dataset.py
    ADDED
    
    | @@ -0,0 +1,254 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Very loosely inspired by indexed_dataset in Fairseq, Megatron
         | 
| 2 | 
            +
            # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/data/indexed_dataset.py
         | 
| 3 | 
            +
             | 
| 4 | 
            +
             | 
| 5 | 
            +
            import os
         | 
| 6 | 
            +
            import random
         | 
| 7 | 
            +
            import struct
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            import numpy as np
         | 
| 10 | 
            +
            import torch
         | 
| 11 | 
            +
            from torch.utils.data import IterableDataset, get_worker_info
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            dtypes = {
         | 
| 14 | 
            +
                1: np.uint8,
         | 
| 15 | 
            +
                2: np.int8,
         | 
| 16 | 
            +
                3: np.int16,
         | 
| 17 | 
            +
                4: np.int32,
         | 
| 18 | 
            +
                5: np.int64,
         | 
| 19 | 
            +
                6: np.float32,
         | 
| 20 | 
            +
                7: np.float64,
         | 
| 21 | 
            +
                8: np.uint16,
         | 
| 22 | 
            +
            }
         | 
| 23 | 
            +
             | 
| 24 | 
            +
             | 
| 25 | 
            +
            def code(dtype):
         | 
| 26 | 
            +
                for k in dtypes:
         | 
| 27 | 
            +
                    if dtypes[k] == dtype:
         | 
| 28 | 
            +
                        return k
         | 
| 29 | 
            +
                raise ValueError(dtype)
         | 
| 30 | 
            +
             | 
| 31 | 
            +
             | 
| 32 | 
            +
            HDR_MAGIC = b"LITPKDS"
         | 
| 33 | 
            +
            HDR_SIZE = 24  # bytes
         | 
| 34 | 
            +
             | 
| 35 | 
            +
             | 
| 36 | 
            +
            class PackedDataset(IterableDataset):
         | 
| 37 | 
            +
                def __init__(
         | 
| 38 | 
            +
                    self,
         | 
| 39 | 
            +
                    filenames,
         | 
| 40 | 
            +
                    n_chunks,
         | 
| 41 | 
            +
                    block_size,
         | 
| 42 | 
            +
                    seed=12345,
         | 
| 43 | 
            +
                    shuffle=True,
         | 
| 44 | 
            +
                    wrap=False,
         | 
| 45 | 
            +
                    num_processes=1,
         | 
| 46 | 
            +
                    process_rank=0,
         | 
| 47 | 
            +
                ):
         | 
| 48 | 
            +
                    self._filenames = filenames
         | 
| 49 | 
            +
                    self._n_chunks = n_chunks
         | 
| 50 | 
            +
                    self._block_size = block_size
         | 
| 51 | 
            +
                    self._seed = seed
         | 
| 52 | 
            +
                    self._shuffle = shuffle
         | 
| 53 | 
            +
                    self._wrap = wrap
         | 
| 54 | 
            +
                    self._num_processes = num_processes
         | 
| 55 | 
            +
                    self._process_rank = process_rank
         | 
| 56 | 
            +
             | 
| 57 | 
            +
                def __iter__(self):
         | 
| 58 | 
            +
                    worker_info = get_worker_info()
         | 
| 59 | 
            +
                    num_workers = worker_info.num_workers if worker_info is not None else 1
         | 
| 60 | 
            +
                    worker_id = worker_info.id if worker_info is not None else 0
         | 
| 61 | 
            +
                    num_shards = num_workers * self._num_processes
         | 
| 62 | 
            +
                    shard_id = self._process_rank * num_workers + worker_id
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                    max_num_files = len(self._filenames) // num_shards * num_shards
         | 
| 65 | 
            +
                    filenames = self._filenames[shard_id:max_num_files:num_shards]
         | 
| 66 | 
            +
             | 
| 67 | 
            +
                    return PackedDatasetIterator(
         | 
| 68 | 
            +
                        filenames=filenames,
         | 
| 69 | 
            +
                        n_chunks=self._n_chunks,
         | 
| 70 | 
            +
                        block_size=self._block_size,
         | 
| 71 | 
            +
                        seed=self._seed,
         | 
| 72 | 
            +
                        shuffle=self._shuffle,
         | 
| 73 | 
            +
                        wrap=self._wrap,
         | 
| 74 | 
            +
                    )
         | 
| 75 | 
            +
             | 
| 76 | 
            +
             | 
| 77 | 
            +
            class PackedDatasetBuilder(object):
         | 
| 78 | 
            +
                def __init__(self, outdir, prefix, chunk_size, sep_token, dtype="auto", vocab_size=None):
         | 
| 79 | 
            +
                    if dtype == "auto":
         | 
| 80 | 
            +
                        if vocab_size is None:
         | 
| 81 | 
            +
                            raise ValueError("vocab_size cannot be None when dtype='auto'")
         | 
| 82 | 
            +
                        if vocab_size is not None and vocab_size < 65500:
         | 
| 83 | 
            +
                            self._dtype = np.uint16
         | 
| 84 | 
            +
                        else:
         | 
| 85 | 
            +
                            self._dtype = np.int32
         | 
| 86 | 
            +
                    else:
         | 
| 87 | 
            +
                        self._dtype = dtype
         | 
| 88 | 
            +
                    self._counter = 0
         | 
| 89 | 
            +
                    self._chunk_size = chunk_size
         | 
| 90 | 
            +
                    self._outdir = outdir
         | 
| 91 | 
            +
                    self._prefix = prefix
         | 
| 92 | 
            +
                    self._sep_token = sep_token
         | 
| 93 | 
            +
                    self._arr = np.zeros(self._chunk_size, dtype=self._dtype)
         | 
| 94 | 
            +
                    self._arr.fill(self._sep_token)
         | 
| 95 | 
            +
                    self._idx = 0
         | 
| 96 | 
            +
                    self._version = 1
         | 
| 97 | 
            +
                    self._filenames = []
         | 
| 98 | 
            +
             | 
| 99 | 
            +
                def _write_chunk(self):
         | 
| 100 | 
            +
                    filename = f"{self._prefix}_{self._counter:010d}.bin"
         | 
| 101 | 
            +
                    filename = os.path.join(self._outdir, filename)
         | 
| 102 | 
            +
             | 
| 103 | 
            +
                    with open(filename, "wb") as f:
         | 
| 104 | 
            +
                        f.write(HDR_MAGIC)
         | 
| 105 | 
            +
                        f.write(struct.pack("<Q", self._version))
         | 
| 106 | 
            +
                        f.write(struct.pack("<B", code(self._dtype)))
         | 
| 107 | 
            +
                        f.write(struct.pack("<Q", self._chunk_size))
         | 
| 108 | 
            +
                        f.write(self._arr.tobytes(order="C"))
         | 
| 109 | 
            +
             | 
| 110 | 
            +
                    self._filenames.append(filename)
         | 
| 111 | 
            +
                    self._counter += 1
         | 
| 112 | 
            +
                    self._arr.fill(self._sep_token)
         | 
| 113 | 
            +
                    self._idx = 0
         | 
| 114 | 
            +
             | 
| 115 | 
            +
                @property
         | 
| 116 | 
            +
                def dtype(self):
         | 
| 117 | 
            +
                    return self._dtype
         | 
| 118 | 
            +
             | 
| 119 | 
            +
                @property
         | 
| 120 | 
            +
                def filenames(self):
         | 
| 121 | 
            +
                    return self._filenames.copy()
         | 
| 122 | 
            +
             | 
| 123 | 
            +
                def add_array(self, arr):
         | 
| 124 | 
            +
                    while self._idx + arr.shape[0] > self._chunk_size:
         | 
| 125 | 
            +
                        part_len = self._chunk_size - self._idx
         | 
| 126 | 
            +
                        self._arr[self._idx : self._idx + part_len] = arr[:part_len]
         | 
| 127 | 
            +
                        self._write_chunk()
         | 
| 128 | 
            +
                        arr = arr[part_len:]
         | 
| 129 | 
            +
             | 
| 130 | 
            +
                    arr_len = arr.shape[0]
         | 
| 131 | 
            +
                    self._arr[self._idx : self._idx + arr_len] = arr
         | 
| 132 | 
            +
                    self._idx += arr_len
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                def write_reminder(self):
         | 
| 135 | 
            +
                    self._write_chunk()
         | 
| 136 | 
            +
             | 
| 137 | 
            +
             | 
| 138 | 
            +
            class PackedDatasetIterator:
         | 
| 139 | 
            +
                def __init__(self, filenames, n_chunks, block_size, seed, shuffle, wrap):
         | 
| 140 | 
            +
                    self._seed = seed
         | 
| 141 | 
            +
                    self._shuffle = shuffle
         | 
| 142 | 
            +
                    self._rng = np.random.default_rng(seed) if shuffle else None
         | 
| 143 | 
            +
                    self._block_idxs = None
         | 
| 144 | 
            +
             | 
| 145 | 
            +
                    self._wrap = wrap
         | 
| 146 | 
            +
             | 
| 147 | 
            +
                    # TODO: instead of filenames, we could have a single text stream
         | 
| 148 | 
            +
                    #       (or text file) with the sequence of all files to be
         | 
| 149 | 
            +
                    #       fetched/loaded.
         | 
| 150 | 
            +
                    self._filenames = filenames
         | 
| 151 | 
            +
                    self._file_idx = 0
         | 
| 152 | 
            +
             | 
| 153 | 
            +
                    self._n_chunks = n_chunks
         | 
| 154 | 
            +
             | 
| 155 | 
            +
                    self._dtype = None
         | 
| 156 | 
            +
                    self._block_size = block_size
         | 
| 157 | 
            +
                    self._n_blocks = None
         | 
| 158 | 
            +
             | 
| 159 | 
            +
                    self._mmaps = []
         | 
| 160 | 
            +
                    self._buffers = []
         | 
| 161 | 
            +
             | 
| 162 | 
            +
                    self._block_idxs = []
         | 
| 163 | 
            +
                    self._curr_idx = 0
         | 
| 164 | 
            +
             | 
| 165 | 
            +
                    self._load_n_chunks()
         | 
| 166 | 
            +
             | 
| 167 | 
            +
                def _read_header(self, path):
         | 
| 168 | 
            +
                    with open(path, "rb") as f:
         | 
| 169 | 
            +
                        magic = f.read(len(HDR_MAGIC))
         | 
| 170 | 
            +
                        assert magic == HDR_MAGIC, "File doesn't match expected format."
         | 
| 171 | 
            +
                        version = struct.unpack("<Q", f.read(8))
         | 
| 172 | 
            +
                        assert version == (1,)
         | 
| 173 | 
            +
                        (dtype_code,) = struct.unpack("<B", f.read(1))
         | 
| 174 | 
            +
                        dtype = dtypes[dtype_code]
         | 
| 175 | 
            +
                        (chunk_size,) = struct.unpack("<Q", f.read(8))
         | 
| 176 | 
            +
                    return dtype, chunk_size
         | 
| 177 | 
            +
             | 
| 178 | 
            +
                def _close_mmaps(self):
         | 
| 179 | 
            +
                    for mmap in self._mmaps:
         | 
| 180 | 
            +
                        mmap._mmap.close()
         | 
| 181 | 
            +
             | 
| 182 | 
            +
                def _load_n_chunks(self):
         | 
| 183 | 
            +
                    self._close_mmaps()
         | 
| 184 | 
            +
                    self._mmaps = []
         | 
| 185 | 
            +
                    self._buffers = []
         | 
| 186 | 
            +
             | 
| 187 | 
            +
                    if self._n_chunks > len(self._filenames[self._file_idx :]):
         | 
| 188 | 
            +
                        if not self._wrap:
         | 
| 189 | 
            +
                            raise StopIteration
         | 
| 190 | 
            +
                        self._file_idx = 0
         | 
| 191 | 
            +
             | 
| 192 | 
            +
                    for i in range(self._n_chunks):
         | 
| 193 | 
            +
                        filename = self._filenames[self._file_idx + i]
         | 
| 194 | 
            +
                        if self._dtype is None:
         | 
| 195 | 
            +
                            self._dtype, self._chunk_size = self._read_header(filename)
         | 
| 196 | 
            +
                            self._n_blocks = self._chunk_size // self._block_size
         | 
| 197 | 
            +
                        # TODO: check header matches with previous files
         | 
| 198 | 
            +
                        mmap = np.memmap(filename, mode="r", order="C", offset=HDR_SIZE)
         | 
| 199 | 
            +
                        self._mmaps.append(mmap)
         | 
| 200 | 
            +
                        self._buffers.append(memoryview(mmap))
         | 
| 201 | 
            +
             | 
| 202 | 
            +
                    self._file_idx += self._n_chunks
         | 
| 203 | 
            +
                    n_all_blocks = self._n_chunks * self._n_blocks
         | 
| 204 | 
            +
             | 
| 205 | 
            +
                    self._block_idxs = (
         | 
| 206 | 
            +
                        self._rng.permutation(n_all_blocks) if self._shuffle else range(n_all_blocks)
         | 
| 207 | 
            +
                    )
         | 
| 208 | 
            +
             | 
| 209 | 
            +
                    self._curr_idx = 0
         | 
| 210 | 
            +
             | 
| 211 | 
            +
                def __del__(self):
         | 
| 212 | 
            +
                    self._close_mmaps()
         | 
| 213 | 
            +
                    del self._mmaps
         | 
| 214 | 
            +
                    del self._buffers
         | 
| 215 | 
            +
             | 
| 216 | 
            +
                def __iter__(self):
         | 
| 217 | 
            +
                    return self
         | 
| 218 | 
            +
             | 
| 219 | 
            +
                def __next__(self):
         | 
| 220 | 
            +
                    if self._curr_idx >= len(self._block_idxs):
         | 
| 221 | 
            +
                        self._load_n_chunks()
         | 
| 222 | 
            +
                        # TODO: trigger fetching next next n_chunks if remote
         | 
| 223 | 
            +
                    block_idx = self._block_idxs[self._curr_idx]
         | 
| 224 | 
            +
                    chunk_id = block_idx // self._n_blocks
         | 
| 225 | 
            +
                    buffer = self._buffers[chunk_id]
         | 
| 226 | 
            +
                    elem_id = (block_idx % self._n_blocks) * self._block_size
         | 
| 227 | 
            +
                    offset = np.dtype(self._dtype).itemsize * elem_id
         | 
| 228 | 
            +
                    arr = np.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset)
         | 
| 229 | 
            +
                    self._curr_idx += 1
         | 
| 230 | 
            +
                    return torch.from_numpy(arr.astype(np.int64))
         | 
| 231 | 
            +
             | 
| 232 | 
            +
             | 
| 233 | 
            +
            class CombinedDataset(IterableDataset):
         | 
| 234 | 
            +
                def __init__(self, datasets, seed, weights=None):
         | 
| 235 | 
            +
                    self._seed = seed
         | 
| 236 | 
            +
                    self._datasets = datasets
         | 
| 237 | 
            +
                    self._weights = weights
         | 
| 238 | 
            +
                    n_datasets = len(datasets)
         | 
| 239 | 
            +
                    if weights is None:
         | 
| 240 | 
            +
                        self._weights = [1 / n_datasets] * n_datasets
         | 
| 241 | 
            +
             | 
| 242 | 
            +
                def __iter__(self):
         | 
| 243 | 
            +
                    return CombinedDatasetIterator(self._datasets, self._seed, self._weights)
         | 
| 244 | 
            +
             | 
| 245 | 
            +
             | 
| 246 | 
            +
            class CombinedDatasetIterator:
         | 
| 247 | 
            +
                def __init__(self, datasets, seed, weights):
         | 
| 248 | 
            +
                    self._datasets = [iter(el) for el in datasets]
         | 
| 249 | 
            +
                    self._weights = weights
         | 
| 250 | 
            +
                    self._rng = random.Random(seed)
         | 
| 251 | 
            +
             | 
| 252 | 
            +
                def __next__(self):
         | 
| 253 | 
            +
                    (dataset,) = self._rng.choices(self._datasets, weights=self._weights, k=1)
         | 
| 254 | 
            +
                    return next(dataset)
         | 
    	
        tsai_gpt/rmsnorm.py
    ADDED
    
    | @@ -0,0 +1,26 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
             | 
| 3 | 
            +
             | 
| 4 | 
            +
            class RMSNorm(torch.nn.Module):
         | 
| 5 | 
            +
                """Root Mean Square Layer Normalization.
         | 
| 6 | 
            +
             | 
| 7 | 
            +
                Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
         | 
| 8 | 
            +
                https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.
         | 
| 9 | 
            +
                """
         | 
| 10 | 
            +
             | 
| 11 | 
            +
                def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None:
         | 
| 12 | 
            +
                    super().__init__()
         | 
| 13 | 
            +
                    self.weight = torch.nn.Parameter(torch.ones(size))
         | 
| 14 | 
            +
                    self.eps = eps
         | 
| 15 | 
            +
                    self.dim = dim
         | 
| 16 | 
            +
             | 
| 17 | 
            +
                def forward(self, x: torch.Tensor) -> torch.Tensor:
         | 
| 18 | 
            +
                    dtype = x.dtype
         | 
| 19 | 
            +
                    x = x.float()
         | 
| 20 | 
            +
                    # NOTE: the original RMSNorm paper implementation is not equivalent
         | 
| 21 | 
            +
                    norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)
         | 
| 22 | 
            +
                    x_normed = x * torch.rsqrt(norm_x + self.eps)
         | 
| 23 | 
            +
                    return (self.weight * x_normed).to(dtype=dtype)
         | 
| 24 | 
            +
             | 
| 25 | 
            +
                def reset_parameters(self) -> None:
         | 
| 26 | 
            +
                    torch.nn.init.ones_(self.weight)
         | 
    	
        tsai_gpt/speed_monitor.py
    ADDED
    
    | @@ -0,0 +1,438 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import time
         | 
| 2 | 
            +
            from collections import deque
         | 
| 3 | 
            +
            from contextlib import nullcontext
         | 
| 4 | 
            +
            from typing import Any, Callable, Deque, Dict, Optional
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import torch
         | 
| 7 | 
            +
            from lightning import Callback, Fabric, LightningModule, Trainer
         | 
| 8 | 
            +
            from lightning.fabric.accelerators.xla import _XLA_GREATER_EQUAL_2_1
         | 
| 9 | 
            +
            from lightning.fabric.plugins import (BitsandbytesPrecision, DoublePrecision,
         | 
| 10 | 
            +
                                                  FSDPPrecision, HalfPrecision,
         | 
| 11 | 
            +
                                                  MixedPrecision, Precision,
         | 
| 12 | 
            +
                                                  TransformerEnginePrecision, XLAPrecision)
         | 
| 13 | 
            +
            from lightning.fabric.utilities.rank_zero import \
         | 
| 14 | 
            +
                rank_zero_only as fabric_rank_zero_only
         | 
| 15 | 
            +
            from lightning.pytorch.plugins import (DoublePrecisionPlugin,
         | 
| 16 | 
            +
                                                   FSDPPrecisionPlugin,
         | 
| 17 | 
            +
                                                   HalfPrecisionPlugin,
         | 
| 18 | 
            +
                                                   MixedPrecisionPlugin,
         | 
| 19 | 
            +
                                                   XLAPrecisionPlugin)
         | 
| 20 | 
            +
            from lightning.pytorch.utilities.rank_zero import \
         | 
| 21 | 
            +
                rank_zero_only as trainer_rank_zero_only
         | 
| 22 | 
            +
            from torch.utils.flop_counter import FlopCounterMode
         | 
| 23 | 
            +
             | 
| 24 | 
            +
            from tsai_gpt import GPT
         | 
| 25 | 
            +
            from tsai_gpt.utils import num_parameters
         | 
| 26 | 
            +
             | 
| 27 | 
            +
            GPU_AVAILABLE_FLOPS = {
         | 
| 28 | 
            +
                # source: https://resources.nvidia.com/en-us-tensor-core/nvidia-tensor-core-gpu-datasheet
         | 
| 29 | 
            +
                # nvidia publishes spec sheet with a 2x sparsity factor
         | 
| 30 | 
            +
                "h100-sxm": {
         | 
| 31 | 
            +
                    torch.float64: 67e12,
         | 
| 32 | 
            +
                    torch.float32: 67e12,
         | 
| 33 | 
            +
                    torch.bfloat16: 1.979e15 / 2,
         | 
| 34 | 
            +
                    torch.float16: 1.979e15 / 2,
         | 
| 35 | 
            +
                    torch.int8: 3.958e15 / 2,
         | 
| 36 | 
            +
                },
         | 
| 37 | 
            +
                "h100-pcie": {
         | 
| 38 | 
            +
                    torch.float64: 51e12,
         | 
| 39 | 
            +
                    torch.float32: 51e12,
         | 
| 40 | 
            +
                    torch.bfloat16: 1.513e15 / 2,
         | 
| 41 | 
            +
                    torch.float16: 1.513e15 / 2,
         | 
| 42 | 
            +
                    torch.int8: 3.026e15 / 2,
         | 
| 43 | 
            +
                },
         | 
| 44 | 
            +
                # source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf
         | 
| 45 | 
            +
                # sxm and pcie have same flop counts
         | 
| 46 | 
            +
                "a100": {
         | 
| 47 | 
            +
                    torch.float64: 19.5e12,
         | 
| 48 | 
            +
                    torch.float32: 19.5e12,
         | 
| 49 | 
            +
                    torch.bfloat16: 312e12,
         | 
| 50 | 
            +
                    torch.float16: 312e12,
         | 
| 51 | 
            +
                },
         | 
| 52 | 
            +
                # source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a10/pdf/a10-datasheet.pdf
         | 
| 53 | 
            +
                "a10g": {torch.float32: 31.2e12, torch.bfloat16: 125e12, torch.float16: 125e12},
         | 
| 54 | 
            +
                # source: https://images.nvidia.com/content/technologies/volta/pdf/volta-v100-datasheet-update-us-1165301-r5.pdf
         | 
| 55 | 
            +
                "v100-sxm": {torch.float64: 7.8e12, torch.float32: 15.7e12, torch.float16: 125e12},
         | 
| 56 | 
            +
                "v100-pcie": {torch.float64: 7e12, torch.float32: 14e12, torch.float16: 112e12},
         | 
| 57 | 
            +
                "v100s-pcie": {torch.float64: 8.2e12, torch.float32: 16.4e12, torch.float16: 130e12},
         | 
| 58 | 
            +
                # source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/tesla-t4/t4-tensor-core-datasheet-951643.pdf
         | 
| 59 | 
            +
                # sxm and pcie have same flop counts
         | 
| 60 | 
            +
                "t4": {torch.float32: 8.1e12, torch.float16: 65e12, torch.int8: 130e12},
         | 
| 61 | 
            +
                # https://www.nvidia.com/content/dam/en-zz/Solutions/design-visualization/quadro-product-literature/quadro-rtx-5000-data-sheet-us-nvidia-704120-r4-web.pdf
         | 
| 62 | 
            +
                "quadro rtx 5000": {torch.float32: 11.2e12, torch.float16: 89.2e12},
         | 
| 63 | 
            +
            }
         | 
| 64 | 
            +
             | 
| 65 | 
            +
            TPU_AVAILABLE_FLOPS = {
         | 
| 66 | 
            +
                # flop count for each TPU generation is the same for all precisions
         | 
| 67 | 
            +
                # since bfloat16 precision is always used for performing matrix operations
         | 
| 68 | 
            +
                # for more info: https://cloud.google.com/tpu/docs/bfloat16#choosing_bfloat16
         | 
| 69 | 
            +
                # source: https://arxiv.org/pdf/1907.10701.pdf
         | 
| 70 | 
            +
                "v2": 45e12,
         | 
| 71 | 
            +
                # source: https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu_v3
         | 
| 72 | 
            +
                "v3": 123e12,
         | 
| 73 | 
            +
                # source: https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu_v4
         | 
| 74 | 
            +
                "v4": 275e12,
         | 
| 75 | 
            +
                # source: https://cloud.google.com/tpu/docs/v5e-training
         | 
| 76 | 
            +
                "v5litepod": 197e12,
         | 
| 77 | 
            +
            }
         | 
| 78 | 
            +
             | 
| 79 | 
            +
             | 
| 80 | 
            +
            def get_flops_available(device: torch.device, dtype: torch.dtype) -> Optional[float]:
         | 
| 81 | 
            +
                if device.type == "cuda":
         | 
| 82 | 
            +
                    device_name = torch.cuda.get_device_name(device).lower()
         | 
| 83 | 
            +
                    if "h100" in device_name and "hbm3" in device_name:
         | 
| 84 | 
            +
                        device_name = "h100-sxm"
         | 
| 85 | 
            +
                    elif "h100" in device_name and ("pcie" in device_name or "hbm2e" in device_name):
         | 
| 86 | 
            +
                        device_name = "h100-pcie"
         | 
| 87 | 
            +
                    elif "a100" in device_name:
         | 
| 88 | 
            +
                        device_name = "a100"
         | 
| 89 | 
            +
                    elif "a10g" in device_name:
         | 
| 90 | 
            +
                        device_name = "a10g"
         | 
| 91 | 
            +
                    elif "v100-sxm" in device_name:
         | 
| 92 | 
            +
                        device_name = "v100-sxm"
         | 
| 93 | 
            +
                    elif "v100-pcie" in device_name:
         | 
| 94 | 
            +
                        device_name = "v100-pcie"
         | 
| 95 | 
            +
                    elif "t4" in device_name:
         | 
| 96 | 
            +
                        device_name = "t4"
         | 
| 97 | 
            +
                    elif "quadro rtx 5000" in device_name:
         | 
| 98 | 
            +
                        device_name = "quadro rtx 5000"
         | 
| 99 | 
            +
                    else:
         | 
| 100 | 
            +
                        device_name = None
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                    if device_name is not None:
         | 
| 103 | 
            +
                        try:
         | 
| 104 | 
            +
                            return int(GPU_AVAILABLE_FLOPS[device_name][dtype])
         | 
| 105 | 
            +
                        except KeyError:
         | 
| 106 | 
            +
                            raise KeyError(
         | 
| 107 | 
            +
                                f"flop count not found for {device_name} with dtype: {dtype}; "
         | 
| 108 | 
            +
                                "MFU cannot be calculated and reported."
         | 
| 109 | 
            +
                            )
         | 
| 110 | 
            +
                elif device.type == "xla":
         | 
| 111 | 
            +
                    if _XLA_GREATER_EQUAL_2_1:
         | 
| 112 | 
            +
                        from torch_xla._internal import tpu
         | 
| 113 | 
            +
                    else:
         | 
| 114 | 
            +
                        from torch_xla.experimental import tpu
         | 
| 115 | 
            +
             | 
| 116 | 
            +
                    device_name = tpu.get_tpu_env()["TYPE"].lower()
         | 
| 117 | 
            +
                    try:
         | 
| 118 | 
            +
                        return int(TPU_AVAILABLE_FLOPS[device_name])
         | 
| 119 | 
            +
                    except KeyError:
         | 
| 120 | 
            +
                        raise KeyError(
         | 
| 121 | 
            +
                            f"flop count not found for {device_name} with dtype: {dtype}; MFU cannot be calculated and reported."
         | 
| 122 | 
            +
                        )
         | 
| 123 | 
            +
             | 
| 124 | 
            +
                return None
         | 
| 125 | 
            +
             | 
| 126 | 
            +
             | 
| 127 | 
            +
            # Adapted from https://github.com/mosaicml/composer/blob/f2a2dc820cb75023b9eb7c46fdfd25273712abd0/composer/callbacks/speed_monitor.py
         | 
| 128 | 
            +
             | 
| 129 | 
            +
             | 
| 130 | 
            +
            class SpeedMonitorBase:
         | 
| 131 | 
            +
                """Logs the training throughput and utilization.
         | 
| 132 | 
            +
             | 
| 133 | 
            +
                +-------------------------------------+-----------------------------------------------------------+
         | 
| 134 | 
            +
                | Key                                 | Logged data                                               |
         | 
| 135 | 
            +
                +=====================================+===========================================================+
         | 
| 136 | 
            +
                |                                     | Rolling average (over `window_size` most recent           |
         | 
| 137 | 
            +
                | `throughput/batches_per_sec`        | batches) of the number of batches processed per second    |
         | 
| 138 | 
            +
                |                                     |                                                           |
         | 
| 139 | 
            +
                +-------------------------------------+-----------------------------------------------------------+
         | 
| 140 | 
            +
                |                                     | Rolling average (over `window_size` most recent           |
         | 
| 141 | 
            +
                | `throughput/samples_per_sec`        | batches) of the number of samples processed per second    |
         | 
| 142 | 
            +
                |                                     |                                                           |
         | 
| 143 | 
            +
                +-------------------------------------+-----------------------------------------------------------+
         | 
| 144 | 
            +
                |                                     | Rolling average (over `window_size` most recent           |
         | 
| 145 | 
            +
                | `throughput/tokens_per_sec`         | batches) of the number of tokens processed per second.    |
         | 
| 146 | 
            +
                |                                     | This may include padding depending on dataset             |
         | 
| 147 | 
            +
                +-------------------------------------+-----------------------------------------------------------+
         | 
| 148 | 
            +
                |                                     | Estimates flops by `flops_per_batch * batches_per_sec`    |
         | 
| 149 | 
            +
                | `throughput/flops_per_sec`          |                                                           |
         | 
| 150 | 
            +
                |                                     |                                                           |
         | 
| 151 | 
            +
                +-------------------------------------+-----------------------------------------------------------+
         | 
| 152 | 
            +
                | `throughput/device/batches_per_sec` | `throughput/batches_per_sec` divided by world size        |
         | 
| 153 | 
            +
                +-------------------------------------+-----------------------------------------------------------+
         | 
| 154 | 
            +
                | `throughput/device/samples_per_sec` | `throughput/samples_per_sec` divided by world size        |
         | 
| 155 | 
            +
                +-------------------------------------+-----------------------------------------------------------+
         | 
| 156 | 
            +
                |                                     | `throughput/tokens_per_sec` divided by world size. This   |
         | 
| 157 | 
            +
                | `throughput/device/tokens_per_sec`  | may include pad tokens depending on dataset               |
         | 
| 158 | 
            +
                |                                     |                                                           |
         | 
| 159 | 
            +
                +-------------------------------------+-----------------------------------------------------------+
         | 
| 160 | 
            +
                |                                     | `throughput/flops_per_sec` divided by world size. Only    |
         | 
| 161 | 
            +
                | `throughput/device/flops_per_sec`   | logged when model has attribute `flops_per_batch`         |
         | 
| 162 | 
            +
                |                                     |                                                           |
         | 
| 163 | 
            +
                +-------------------------------------+-----------------------------------------------------------+
         | 
| 164 | 
            +
                |                                     | `throughput/device/flops_per_sec` divided by world size.  |
         | 
| 165 | 
            +
                | `throughput/device/mfu`             |                                                           |
         | 
| 166 | 
            +
                |                                     |                                                           |
         | 
| 167 | 
            +
                +-------------------------------------+-----------------------------------------------------------+
         | 
| 168 | 
            +
                | `time/train`                        | Total elapsed training time                               |
         | 
| 169 | 
            +
                +-------------------------------------+-----------------------------------------------------------+
         | 
| 170 | 
            +
                | `time/val`                          | Total elapsed validation time                             |
         | 
| 171 | 
            +
                +-------------------------------------+-----------------------------------------------------------+
         | 
| 172 | 
            +
                | `time/total`                        | Total elapsed time (time/train + time/val)                |
         | 
| 173 | 
            +
                +-------------------------------------+-----------------------------------------------------------+
         | 
| 174 | 
            +
             | 
| 175 | 
            +
                Notes:
         | 
| 176 | 
            +
                    - The implementation assumes that devices are homogeneous as it normalizes by the world size.
         | 
| 177 | 
            +
                    - Tokens/sec, flops/sec and MFU do not account for padding tokens if present. We suggest using samples/sec or
         | 
| 178 | 
            +
                      batches/sec to measure throughput under this circumstance.
         | 
| 179 | 
            +
                    - Be careful when comparing MFU numbers across projects, as this will highly depend on the ``flops_per_batch``.
         | 
| 180 | 
            +
                      There is no widespread, realistic, and reliable implementation to compute them.
         | 
| 181 | 
            +
                      We suggest using our ``measure_flops`` function, but many other works will use ``estimated_flops`` which
         | 
| 182 | 
            +
                      will almost always be an overestimate when compared to the true value.
         | 
| 183 | 
            +
             | 
| 184 | 
            +
                Args:
         | 
| 185 | 
            +
                    window_size (int, optional): Number of batches to use for a rolling average of throughput.
         | 
| 186 | 
            +
                        Defaults to 100.
         | 
| 187 | 
            +
                    time_unit (str, optional): Time unit to use for `time` logging. Can be one of
         | 
| 188 | 
            +
                        'seconds', 'minutes', 'hours', or 'days'. Defaults to 'hours'.
         | 
| 189 | 
            +
                """
         | 
| 190 | 
            +
             | 
| 191 | 
            +
                def __init__(
         | 
| 192 | 
            +
                    self,
         | 
| 193 | 
            +
                    flops_available: float,
         | 
| 194 | 
            +
                    log_dict: Callable[[Dict, int], None],
         | 
| 195 | 
            +
                    window_size: int = 100,
         | 
| 196 | 
            +
                    time_unit: str = "hours",
         | 
| 197 | 
            +
                ):
         | 
| 198 | 
            +
                    self.flops_available = flops_available
         | 
| 199 | 
            +
                    self.log_dict = log_dict
         | 
| 200 | 
            +
             | 
| 201 | 
            +
                    # Track the batch num samples and wct to compute throughput over a window of batches
         | 
| 202 | 
            +
                    self.history_samples: Deque[int] = deque(maxlen=window_size + 1)
         | 
| 203 | 
            +
                    self.history_wct: Deque[float] = deque(maxlen=window_size + 1)
         | 
| 204 | 
            +
                    self.history_lengths: Deque[int] = deque(maxlen=window_size + 1)
         | 
| 205 | 
            +
                    self.history_flops: Deque[int] = deque(maxlen=window_size + 1)
         | 
| 206 | 
            +
             | 
| 207 | 
            +
                    self.divider = 1
         | 
| 208 | 
            +
                    if time_unit == "seconds":
         | 
| 209 | 
            +
                        self.divider = 1
         | 
| 210 | 
            +
                    elif time_unit == "minutes":
         | 
| 211 | 
            +
                        self.divider = 60
         | 
| 212 | 
            +
                    elif time_unit == "hours":
         | 
| 213 | 
            +
                        self.divider = 60 * 60
         | 
| 214 | 
            +
                    elif time_unit == "days":
         | 
| 215 | 
            +
                        self.divider = 60 * 60 * 24
         | 
| 216 | 
            +
                    else:
         | 
| 217 | 
            +
                        raise ValueError(
         | 
| 218 | 
            +
                            f'Invalid time_unit: {time_unit}. Must be one of "seconds", "minutes", "hours", or "days".'
         | 
| 219 | 
            +
                        )
         | 
| 220 | 
            +
             | 
| 221 | 
            +
                    # Keep track of time spent evaluating
         | 
| 222 | 
            +
                    self.total_eval_wct = 0.0
         | 
| 223 | 
            +
                    self.step = -1
         | 
| 224 | 
            +
             | 
| 225 | 
            +
                def on_train_batch_end(
         | 
| 226 | 
            +
                    self,
         | 
| 227 | 
            +
                    samples: int,  # total samples seen (per device)
         | 
| 228 | 
            +
                    train_elapsed: float,  # total training time (seconds)
         | 
| 229 | 
            +
                    world_size: int,
         | 
| 230 | 
            +
                    flops_per_batch: Optional[int] = None,  # (per device)
         | 
| 231 | 
            +
                    lengths: Optional[int] = None,  # total length of the samples seen (per device)
         | 
| 232 | 
            +
                ) -> None:
         | 
| 233 | 
            +
                    self.step += 1
         | 
| 234 | 
            +
                    step = self.step
         | 
| 235 | 
            +
                    metrics = {}
         | 
| 236 | 
            +
             | 
| 237 | 
            +
                    self.history_samples.append(samples)
         | 
| 238 | 
            +
                    if lengths is not None:
         | 
| 239 | 
            +
                        self.history_lengths.append(lengths)
         | 
| 240 | 
            +
                        # if lengths are passed, there should be as many values as samples
         | 
| 241 | 
            +
                        assert len(self.history_samples) == len(self.history_lengths)
         | 
| 242 | 
            +
                    self.history_wct.append(train_elapsed)
         | 
| 243 | 
            +
                    if len(self.history_wct) == self.history_wct.maxlen:
         | 
| 244 | 
            +
                        elapsed_batches = len(self.history_samples) - 1
         | 
| 245 | 
            +
                        elapsed_samples = self.history_samples[-1] - self.history_samples[0]
         | 
| 246 | 
            +
                        elapsed_wct = self.history_wct[-1] - self.history_wct[0]
         | 
| 247 | 
            +
                        samples_per_sec = elapsed_samples * world_size / elapsed_wct
         | 
| 248 | 
            +
                        dev_samples_per_sec = elapsed_samples / elapsed_wct
         | 
| 249 | 
            +
                        metrics.update(
         | 
| 250 | 
            +
                            {
         | 
| 251 | 
            +
                                "throughput/batches_per_sec": elapsed_batches * world_size / elapsed_wct,
         | 
| 252 | 
            +
                                "throughput/samples_per_sec": samples_per_sec,
         | 
| 253 | 
            +
                                "throughput/device/batches_per_sec": elapsed_batches / elapsed_wct,
         | 
| 254 | 
            +
                                "throughput/device/samples_per_sec": dev_samples_per_sec,
         | 
| 255 | 
            +
                            }
         | 
| 256 | 
            +
                        )
         | 
| 257 | 
            +
                        if lengths is not None:
         | 
| 258 | 
            +
                            elapsed_lengths = int(self.history_lengths[-1]) - int(self.history_lengths[0])
         | 
| 259 | 
            +
                            avg_length = elapsed_lengths / elapsed_batches
         | 
| 260 | 
            +
                            metrics.update(
         | 
| 261 | 
            +
                                {
         | 
| 262 | 
            +
                                    "throughput/tokens_per_sec": samples_per_sec * avg_length,
         | 
| 263 | 
            +
                                    "throughput/device/tokens_per_sec": dev_samples_per_sec * avg_length,
         | 
| 264 | 
            +
                                }
         | 
| 265 | 
            +
                            )
         | 
| 266 | 
            +
             | 
| 267 | 
            +
                    if flops_per_batch is not None:
         | 
| 268 | 
            +
                        # sum of flops per batch across ranks
         | 
| 269 | 
            +
                        self.history_flops.append(flops_per_batch * world_size)
         | 
| 270 | 
            +
                    if len(self.history_flops) == self.history_flops.maxlen:
         | 
| 271 | 
            +
                        elapsed_flops = sum(self.history_flops) - self.history_flops[0]
         | 
| 272 | 
            +
                        elapsed_wct = self.history_wct[-1] - self.history_wct[0]
         | 
| 273 | 
            +
                        flops_per_sec = elapsed_flops / elapsed_wct
         | 
| 274 | 
            +
                        device_flops_per_sec = flops_per_sec / world_size
         | 
| 275 | 
            +
                        metrics.update(
         | 
| 276 | 
            +
                            {
         | 
| 277 | 
            +
                                "throughput/flops_per_sec": flops_per_sec,
         | 
| 278 | 
            +
                                "throughput/device/flops_per_sec": device_flops_per_sec,
         | 
| 279 | 
            +
                            }
         | 
| 280 | 
            +
                        )
         | 
| 281 | 
            +
                        if self.flops_available:
         | 
| 282 | 
            +
                            metrics["throughput/device/mfu"] = device_flops_per_sec / self.flops_available
         | 
| 283 | 
            +
             | 
| 284 | 
            +
                    metrics.update(
         | 
| 285 | 
            +
                        {
         | 
| 286 | 
            +
                            "time/train": train_elapsed / self.divider,
         | 
| 287 | 
            +
                            "time/val": self.total_eval_wct / self.divider,
         | 
| 288 | 
            +
                            "time/total": (train_elapsed + self.total_eval_wct) / self.divider,
         | 
| 289 | 
            +
                            "samples": samples,
         | 
| 290 | 
            +
                        }
         | 
| 291 | 
            +
                    )
         | 
| 292 | 
            +
             | 
| 293 | 
            +
                    self.log_dict(metrics, step)
         | 
| 294 | 
            +
             | 
| 295 | 
            +
                def eval_end(self, eval_elapsed: float) -> None:
         | 
| 296 | 
            +
                    self.total_eval_wct += eval_elapsed  # seconds
         | 
| 297 | 
            +
             | 
| 298 | 
            +
             | 
| 299 | 
            +
            def plugin_to_compute_dtype(plugin: Precision) -> torch.dtype:
         | 
| 300 | 
            +
                if isinstance(plugin, BitsandbytesPrecision):
         | 
| 301 | 
            +
                    return plugin.dtype
         | 
| 302 | 
            +
                if isinstance(plugin, (HalfPrecision, MixedPrecision, HalfPrecisionPlugin)):
         | 
| 303 | 
            +
                    return plugin._desired_input_dtype
         | 
| 304 | 
            +
                if isinstance(plugin, MixedPrecisionPlugin):
         | 
| 305 | 
            +
                    return torch.bfloat16 if plugin.precision == "bf16-mixed" else torch.half
         | 
| 306 | 
            +
                if isinstance(plugin, (DoublePrecision, DoublePrecisionPlugin)):
         | 
| 307 | 
            +
                    return torch.double
         | 
| 308 | 
            +
                if isinstance(plugin, (XLAPrecision, XLAPrecisionPlugin)):
         | 
| 309 | 
            +
                    return plugin._desired_dtype
         | 
| 310 | 
            +
                if isinstance(plugin, TransformerEnginePrecision):
         | 
| 311 | 
            +
                    return torch.int8
         | 
| 312 | 
            +
                if isinstance(plugin, (FSDPPrecision, FSDPPrecisionPlugin)):
         | 
| 313 | 
            +
                    return plugin.mixed_precision_config.reduce_dtype
         | 
| 314 | 
            +
                if isinstance(plugin, Precision):
         | 
| 315 | 
            +
                    return torch.float32
         | 
| 316 | 
            +
                raise NotImplementedError(plugin)
         | 
| 317 | 
            +
             | 
| 318 | 
            +
             | 
| 319 | 
            +
            class SpeedMonitorFabric(SpeedMonitorBase):
         | 
| 320 | 
            +
                def __init__(self, fabric: Fabric, *args: Any, **kwargs: Any) -> None:
         | 
| 321 | 
            +
                    dtype = plugin_to_compute_dtype(fabric.strategy.precision)
         | 
| 322 | 
            +
                    flops_available = get_flops_available(fabric.device, dtype)
         | 
| 323 | 
            +
                    super().__init__(flops_available, fabric.log_dict, *args, **kwargs)
         | 
| 324 | 
            +
             | 
| 325 | 
            +
                @fabric_rank_zero_only
         | 
| 326 | 
            +
                def on_train_batch_end(self, *args: Any, **kwargs: Any) -> None:
         | 
| 327 | 
            +
                    super().on_train_batch_end(*args, **kwargs)
         | 
| 328 | 
            +
             | 
| 329 | 
            +
             | 
| 330 | 
            +
            class SpeedMonitorCallback(Callback):
         | 
| 331 | 
            +
                def __init__(self, length_fn: Callable[[Any], int], batch_size: int, **kwargs: Any) -> None:
         | 
| 332 | 
            +
                    super().__init__()
         | 
| 333 | 
            +
                    self.speed_monitor: Optional[SpeedMonitorBase] = None
         | 
| 334 | 
            +
                    self.speed_monitor_kwargs = kwargs
         | 
| 335 | 
            +
                    self.length_fn = length_fn
         | 
| 336 | 
            +
                    self.batch_size = batch_size
         | 
| 337 | 
            +
                    self.eval_t0: int = 0
         | 
| 338 | 
            +
                    self.train_t0: int = 0
         | 
| 339 | 
            +
                    self.total_lengths: int = 0
         | 
| 340 | 
            +
             | 
| 341 | 
            +
                def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
         | 
| 342 | 
            +
                    if self.speed_monitor is not None:
         | 
| 343 | 
            +
                        return  # already setup
         | 
| 344 | 
            +
                    dtype = plugin_to_compute_dtype(trainer.precision_plugin)
         | 
| 345 | 
            +
                    flops_available = get_flops_available(trainer.strategy.root_device, dtype)
         | 
| 346 | 
            +
                    self.speed_monitor = SpeedMonitorBase(
         | 
| 347 | 
            +
                        flops_available, trainer.logger.log_metrics, **self.speed_monitor_kwargs
         | 
| 348 | 
            +
                    )
         | 
| 349 | 
            +
             | 
| 350 | 
            +
                @trainer_rank_zero_only
         | 
| 351 | 
            +
                def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
         | 
| 352 | 
            +
                    if trainer.fit_loop._should_accumulate():
         | 
| 353 | 
            +
                        return
         | 
| 354 | 
            +
             | 
| 355 | 
            +
                    self.train_t0 = time.perf_counter()
         | 
| 356 | 
            +
             | 
| 357 | 
            +
                @trainer_rank_zero_only
         | 
| 358 | 
            +
                def on_train_batch_end(
         | 
| 359 | 
            +
                    self,
         | 
| 360 | 
            +
                    trainer: Trainer,
         | 
| 361 | 
            +
                    pl_module: LightningModule,
         | 
| 362 | 
            +
                    outputs: Any,
         | 
| 363 | 
            +
                    batch: Any,
         | 
| 364 | 
            +
                    batch_idx: int,
         | 
| 365 | 
            +
                ) -> None:
         | 
| 366 | 
            +
                    self.total_lengths += self.length_fn(batch)
         | 
| 367 | 
            +
                    if trainer.fit_loop._should_accumulate():
         | 
| 368 | 
            +
                        return
         | 
| 369 | 
            +
                    train_elapsed = time.perf_counter() - self.train_t0
         | 
| 370 | 
            +
                    assert self.speed_monitor is not None
         | 
| 371 | 
            +
                    iter_num = trainer.fit_loop.total_batch_idx
         | 
| 372 | 
            +
                    assert (measured_flops := pl_module.measured_flops) is not None
         | 
| 373 | 
            +
                    self.speed_monitor.on_train_batch_end(
         | 
| 374 | 
            +
                        (iter_num + 1) * self.batch_size,
         | 
| 375 | 
            +
                        train_elapsed,
         | 
| 376 | 
            +
                        # this assumes that device FLOPs are the same and that all devices have the same batch size
         | 
| 377 | 
            +
                        trainer.world_size,
         | 
| 378 | 
            +
                        flops_per_batch=measured_flops,
         | 
| 379 | 
            +
                        lengths=self.total_lengths,
         | 
| 380 | 
            +
                    )
         | 
| 381 | 
            +
             | 
| 382 | 
            +
                @trainer_rank_zero_only
         | 
| 383 | 
            +
                def on_validation_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
         | 
| 384 | 
            +
                    self.eval_t0 = time.perf_counter()
         | 
| 385 | 
            +
             | 
| 386 | 
            +
                @trainer_rank_zero_only
         | 
| 387 | 
            +
                def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
         | 
| 388 | 
            +
                    eval_elapsed = time.perf_counter() - self.eval_t0
         | 
| 389 | 
            +
                    assert self.speed_monitor is not None
         | 
| 390 | 
            +
                    self.speed_monitor.eval_end(eval_elapsed)
         | 
| 391 | 
            +
             | 
| 392 | 
            +
             | 
| 393 | 
            +
            def flops_per_param(max_seq_length: int, n_layer: int, n_embd: int, n_params: int) -> int:
         | 
| 394 | 
            +
                flops_per_token = (
         | 
| 395 | 
            +
                    2 * n_params
         | 
| 396 | 
            +
                )  # each parameter is used for a MAC (2 FLOPS) per network operation
         | 
| 397 | 
            +
                # this assumes that all samples have a fixed length equal to the block size
         | 
| 398 | 
            +
                # which is most likely false during finetuning
         | 
| 399 | 
            +
                flops_per_seq = flops_per_token * max_seq_length
         | 
| 400 | 
            +
                attn_flops_per_seq = n_layer * 2 * 2 * (n_embd * (max_seq_length**2))
         | 
| 401 | 
            +
                return flops_per_seq + attn_flops_per_seq
         | 
| 402 | 
            +
             | 
| 403 | 
            +
             | 
| 404 | 
            +
            def estimate_flops(model: GPT) -> int:
         | 
| 405 | 
            +
                """Measures estimated FLOPs for MFU.
         | 
| 406 | 
            +
             | 
| 407 | 
            +
                Refs:
         | 
| 408 | 
            +
                    * https://ar5iv.labs.arxiv.org/html/2205.05198#A1
         | 
| 409 | 
            +
                    * https://ar5iv.labs.arxiv.org/html/2204.02311#A2
         | 
| 410 | 
            +
                """
         | 
| 411 | 
            +
                # using all parameters for this is a naive over estimation because not all model parameters actually contribute to
         | 
| 412 | 
            +
                # this FLOP computation (e.g. embedding, norm). For this reason, the result will be higher by a fixed percentage
         | 
| 413 | 
            +
                # (~10%) compared to the measured FLOPs, making those lower but more realistic.
         | 
| 414 | 
            +
                # For a proper estimate, this needs a more fine-grained calculation as in Appendix A of the paper.
         | 
| 415 | 
            +
                n_trainable_params = num_parameters(model, requires_grad=True)
         | 
| 416 | 
            +
                trainable_flops = flops_per_param(
         | 
| 417 | 
            +
                    model.max_seq_length, model.config.n_layer, model.config.n_embd, n_trainable_params
         | 
| 418 | 
            +
                )
         | 
| 419 | 
            +
                # forward + backward + gradients (assumes no gradient accumulation)
         | 
| 420 | 
            +
                ops_per_step = 3 if model.training else 1
         | 
| 421 | 
            +
                n_frozen_params = num_parameters(model, requires_grad=False)
         | 
| 422 | 
            +
                frozen_flops = flops_per_param(
         | 
| 423 | 
            +
                    model.max_seq_length, model.config.n_layer, model.config.n_embd, n_frozen_params
         | 
| 424 | 
            +
                )
         | 
| 425 | 
            +
                # forward + backward
         | 
| 426 | 
            +
                frozen_ops_per_step = 2 if model.training else 1
         | 
| 427 | 
            +
                return ops_per_step * trainable_flops + frozen_ops_per_step * frozen_flops
         | 
| 428 | 
            +
             | 
| 429 | 
            +
             | 
| 430 | 
            +
            def measure_flops(model: GPT, x: torch.Tensor) -> int:
         | 
| 431 | 
            +
                """Measures real FLOPs for HFU"""
         | 
| 432 | 
            +
                flop_counter = FlopCounterMode(model, display=False)
         | 
| 433 | 
            +
                ctx = nullcontext() if model.training else torch.no_grad()
         | 
| 434 | 
            +
                with ctx, flop_counter:
         | 
| 435 | 
            +
                    y = model(x)
         | 
| 436 | 
            +
                    if model.training:
         | 
| 437 | 
            +
                        y.sum().backward()
         | 
| 438 | 
            +
                return flop_counter.get_total_flops()
         | 
    	
        tsai_gpt/tokenizer.py
    ADDED
    
    | @@ -0,0 +1,106 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import json
         | 
| 2 | 
            +
            from pathlib import Path
         | 
| 3 | 
            +
            from typing import Optional
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            import torch
         | 
| 6 | 
            +
             | 
| 7 | 
            +
             | 
| 8 | 
            +
            class Tokenizer:
         | 
| 9 | 
            +
                def __init__(self, checkpoint_dir: Path) -> None:
         | 
| 10 | 
            +
                    self.use_bos = self.check_if_bos_token_used(checkpoint_dir)
         | 
| 11 | 
            +
                    self.bos_id = None
         | 
| 12 | 
            +
                    self.eos_id = None
         | 
| 13 | 
            +
             | 
| 14 | 
            +
                    # some checkpoints have both files, `.model` takes precedence
         | 
| 15 | 
            +
                    if (vocabulary_path := checkpoint_dir / "tokenizer.model").is_file():
         | 
| 16 | 
            +
                        from sentencepiece import SentencePieceProcessor
         | 
| 17 | 
            +
             | 
| 18 | 
            +
                        self.processor = SentencePieceProcessor(model_file=str(vocabulary_path))
         | 
| 19 | 
            +
                        self.backend = "sentencepiece"
         | 
| 20 | 
            +
                        self.bos_id = self.processor.bos_id()
         | 
| 21 | 
            +
                        self.eos_id = self.processor.eos_id()
         | 
| 22 | 
            +
             | 
| 23 | 
            +
                    elif (vocabulary_path := checkpoint_dir / "tokenizer.json").is_file():
         | 
| 24 | 
            +
                        from tokenizers import Tokenizer as HFTokenizer
         | 
| 25 | 
            +
             | 
| 26 | 
            +
                        self.processor = HFTokenizer.from_file(str(vocabulary_path))
         | 
| 27 | 
            +
                        self.backend = "huggingface"
         | 
| 28 | 
            +
             | 
| 29 | 
            +
                        if (special_tokens_path := checkpoint_dir / "tokenizer_config.json").is_file():
         | 
| 30 | 
            +
                            with open(special_tokens_path) as fp:
         | 
| 31 | 
            +
                                config = json.load(fp)
         | 
| 32 | 
            +
                            bos_token = config.get("bos_token")
         | 
| 33 | 
            +
                            self.bos_id = self.token_to_id(bos_token) if bos_token is not None else None
         | 
| 34 | 
            +
                            eos_token = config.get("eos_token")
         | 
| 35 | 
            +
                            self.eos_id = self.token_to_id(eos_token) if eos_token is not None else None
         | 
| 36 | 
            +
                        if (special_tokens_path := checkpoint_dir / "generation_config.json").is_file():
         | 
| 37 | 
            +
                            with open(special_tokens_path) as fp:
         | 
| 38 | 
            +
                                config = json.load(fp)
         | 
| 39 | 
            +
                            if self.bos_id is None:
         | 
| 40 | 
            +
                                self.bos_id = config.get("bos_token_id")
         | 
| 41 | 
            +
                            if self.eos_id is None:
         | 
| 42 | 
            +
                                self.eos_id = config.get("eos_token_id")
         | 
| 43 | 
            +
                    else:
         | 
| 44 | 
            +
                        raise NotImplementedError
         | 
| 45 | 
            +
             | 
| 46 | 
            +
                @property
         | 
| 47 | 
            +
                def vocab_size(self) -> int:
         | 
| 48 | 
            +
                    if self.backend == "huggingface":
         | 
| 49 | 
            +
                        return self.processor.get_vocab_size(with_added_tokens=False)
         | 
| 50 | 
            +
                    if self.backend == "sentencepiece":
         | 
| 51 | 
            +
                        return self.processor.vocab_size()
         | 
| 52 | 
            +
                    raise RuntimeError
         | 
| 53 | 
            +
             | 
| 54 | 
            +
                def token_to_id(self, token: str) -> int:
         | 
| 55 | 
            +
                    if self.backend == "huggingface":
         | 
| 56 | 
            +
                        id_ = self.processor.token_to_id(token)
         | 
| 57 | 
            +
                    elif self.backend == "sentencepiece":
         | 
| 58 | 
            +
                        id_ = self.processor.piece_to_id(token)
         | 
| 59 | 
            +
                    else:
         | 
| 60 | 
            +
                        raise RuntimeError
         | 
| 61 | 
            +
                    if id_ is None:
         | 
| 62 | 
            +
                        raise ValueError(f"token {token!r} not found in the collection.")
         | 
| 63 | 
            +
                    return id_
         | 
| 64 | 
            +
             | 
| 65 | 
            +
                def check_if_bos_token_used(self, checkpoint_dir: Path) -> bool:
         | 
| 66 | 
            +
                    if not (tokenizer_config_path := checkpoint_dir / "tokenizer_config.json").is_file():
         | 
| 67 | 
            +
                        return False
         | 
| 68 | 
            +
                    with open(tokenizer_config_path) as fp:
         | 
| 69 | 
            +
                        config = json.load(fp)
         | 
| 70 | 
            +
                    if any(config.get(check, False) for check in ("add_bos_token", "add_prefix_space")):
         | 
| 71 | 
            +
                        return True
         | 
| 72 | 
            +
                    # for examples that also use the Llama tokenizer, but do not have or set add_bos_token to True.
         | 
| 73 | 
            +
                    # ex: https://huggingface.co/stabilityai/StableBeluga2/blob/main/tokenizer_config.json#L2
         | 
| 74 | 
            +
                    return (
         | 
| 75 | 
            +
                        config.get("add_bos_token") is None
         | 
| 76 | 
            +
                        and config.get("tokenizer_class") == "LlamaTokenizer"
         | 
| 77 | 
            +
                    )
         | 
| 78 | 
            +
             | 
| 79 | 
            +
                def encode(
         | 
| 80 | 
            +
                    self,
         | 
| 81 | 
            +
                    string: str,
         | 
| 82 | 
            +
                    device: Optional[torch.device] = None,
         | 
| 83 | 
            +
                    bos: Optional[bool] = None,
         | 
| 84 | 
            +
                    eos: bool = False,
         | 
| 85 | 
            +
                    max_length: int = -1,
         | 
| 86 | 
            +
                ) -> torch.Tensor:
         | 
| 87 | 
            +
                    if self.backend == "huggingface":
         | 
| 88 | 
            +
                        tokens = self.processor.encode(string).ids
         | 
| 89 | 
            +
                    elif self.backend == "sentencepiece":
         | 
| 90 | 
            +
                        tokens = self.processor.encode(string)
         | 
| 91 | 
            +
                    else:
         | 
| 92 | 
            +
                        raise RuntimeError
         | 
| 93 | 
            +
                    if bos or (bos is None and self.use_bos):
         | 
| 94 | 
            +
                        bos_id = self.bos_id
         | 
| 95 | 
            +
                        if bos_id is None:
         | 
| 96 | 
            +
                            raise NotImplementedError("This tokenizer does not have a defined a bos token")
         | 
| 97 | 
            +
                        tokens = [bos_id] + tokens
         | 
| 98 | 
            +
                    if eos:
         | 
| 99 | 
            +
                        tokens = tokens + [self.eos_id]
         | 
| 100 | 
            +
                    if max_length > 0:
         | 
| 101 | 
            +
                        tokens = tokens[:max_length]
         | 
| 102 | 
            +
                    return torch.tensor(tokens, dtype=torch.int, device=device)
         | 
| 103 | 
            +
             | 
| 104 | 
            +
                def decode(self, tensor: torch.Tensor) -> str:
         | 
| 105 | 
            +
                    tokens = [tensor.item()] if tensor.ndim == 0 else tensor.tolist()
         | 
| 106 | 
            +
                    return self.processor.decode(tokens)
         | 
    	
        tsai_gpt/utils.py
    ADDED
    
    | @@ -0,0 +1,367 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            """Utility functions for training and inference."""
         | 
| 2 | 
            +
            import math
         | 
| 3 | 
            +
            import pickle
         | 
| 4 | 
            +
            import sys
         | 
| 5 | 
            +
            from contextlib import nullcontext
         | 
| 6 | 
            +
            from io import BytesIO
         | 
| 7 | 
            +
            from pathlib import Path
         | 
| 8 | 
            +
            from typing import (TYPE_CHECKING, ContextManager, Dict, List, Mapping,
         | 
| 9 | 
            +
                                Optional, TypeVar, Union)
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            import lightning as L
         | 
| 12 | 
            +
            import torch
         | 
| 13 | 
            +
            import torch.nn as nn
         | 
| 14 | 
            +
            import torch.utils._device
         | 
| 15 | 
            +
            from lightning.fabric.strategies import FSDPStrategy
         | 
| 16 | 
            +
            from lightning.fabric.utilities.load import _lazy_load as lazy_load
         | 
| 17 | 
            +
            from torch.serialization import normalize_storage_type
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            if TYPE_CHECKING:
         | 
| 20 | 
            +
                from model import GPT
         | 
| 21 | 
            +
             | 
| 22 | 
            +
             | 
| 23 | 
            +
            def find_multiple(n: int, k: int) -> int:
         | 
| 24 | 
            +
                assert k > 0
         | 
| 25 | 
            +
                if n % k == 0:
         | 
| 26 | 
            +
                    return n
         | 
| 27 | 
            +
                return n + k - (n % k)
         | 
| 28 | 
            +
             | 
| 29 | 
            +
             | 
| 30 | 
            +
            def num_parameters(module: nn.Module, requires_grad: Optional[bool] = None) -> int:
         | 
| 31 | 
            +
                total = 0
         | 
| 32 | 
            +
                for p in module.parameters():
         | 
| 33 | 
            +
                    if requires_grad is None or p.requires_grad == requires_grad:
         | 
| 34 | 
            +
                        if hasattr(p, "quant_state"):
         | 
| 35 | 
            +
                            # bitsandbytes 4bit layer support
         | 
| 36 | 
            +
                            total += math.prod(p.quant_state[1])
         | 
| 37 | 
            +
                        else:
         | 
| 38 | 
            +
                            total += p.numel()
         | 
| 39 | 
            +
                return total
         | 
| 40 | 
            +
             | 
| 41 | 
            +
             | 
| 42 | 
            +
            def gptq_quantization(enabled: bool = False) -> ContextManager:
         | 
| 43 | 
            +
                if not enabled:
         | 
| 44 | 
            +
                    return nullcontext()
         | 
| 45 | 
            +
             | 
| 46 | 
            +
                from lightning.fabric.plugins.precision.utils import \
         | 
| 47 | 
            +
                    _ClassReplacementContextManager
         | 
| 48 | 
            +
                from quantize.gptq import ColBlockQuantizedLinear
         | 
| 49 | 
            +
             | 
| 50 | 
            +
                class QuantizedLinear(ColBlockQuantizedLinear):
         | 
| 51 | 
            +
                    def __init__(self, *args, **kwargs):
         | 
| 52 | 
            +
                        super().__init__(*args, bits=4, tile_cols=-1, **kwargs)
         | 
| 53 | 
            +
             | 
| 54 | 
            +
                return _ClassReplacementContextManager({"torch.nn.Linear": QuantizedLinear})
         | 
| 55 | 
            +
             | 
| 56 | 
            +
             | 
| 57 | 
            +
            def check_valid_checkpoint_dir(checkpoint_dir: Path) -> None:
         | 
| 58 | 
            +
                files = {
         | 
| 59 | 
            +
                    "lit_model.pth": (checkpoint_dir / "lit_model.pth").is_file(),
         | 
| 60 | 
            +
                    "lit_config.json": (checkpoint_dir / "lit_config.json").is_file(),
         | 
| 61 | 
            +
                    "tokenizer.json OR tokenizer.model": (checkpoint_dir / "tokenizer.json").is_file()
         | 
| 62 | 
            +
                    or (checkpoint_dir / "tokenizer.model").is_file(),
         | 
| 63 | 
            +
                    "tokenizer_config.json": (checkpoint_dir / "tokenizer_config.json").is_file(),
         | 
| 64 | 
            +
                }
         | 
| 65 | 
            +
                if checkpoint_dir.is_dir():
         | 
| 66 | 
            +
                    if all(files.values()):
         | 
| 67 | 
            +
                        # we're good
         | 
| 68 | 
            +
                        return
         | 
| 69 | 
            +
                    problem = f" is missing the files: {[f for f, exists in files.items() if not exists]!r}"
         | 
| 70 | 
            +
                else:
         | 
| 71 | 
            +
                    problem = " is not a checkpoint directory"
         | 
| 72 | 
            +
             | 
| 73 | 
            +
                # list locally available checkpoints
         | 
| 74 | 
            +
                available = list(Path("checkpoints").glob("*/*"))
         | 
| 75 | 
            +
                if available:
         | 
| 76 | 
            +
                    options = "\n --checkpoint_dir ".join([""] + [repr(str(p.resolve())) for p in available])
         | 
| 77 | 
            +
                    extra = f"\nYou have downloaded locally:{options}\n"
         | 
| 78 | 
            +
                else:
         | 
| 79 | 
            +
                    extra = ""
         | 
| 80 | 
            +
             | 
| 81 | 
            +
                error_message = (
         | 
| 82 | 
            +
                    f"--checkpoint_dir {str(checkpoint_dir.absolute())!r}{problem}."
         | 
| 83 | 
            +
                    "\nFind download instructions at https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials\n"
         | 
| 84 | 
            +
                    f"{extra}\nSee all download options by running:\n python scripts/download.py"
         | 
| 85 | 
            +
                )
         | 
| 86 | 
            +
                print(error_message, file=sys.stderr)
         | 
| 87 | 
            +
                raise SystemExit(1)
         | 
| 88 | 
            +
             | 
| 89 | 
            +
             | 
| 90 | 
            +
            class SavingProxyForStorage:
         | 
| 91 | 
            +
                def __init__(self, obj, saver, protocol_version=5):
         | 
| 92 | 
            +
                    self.protocol_version = protocol_version
         | 
| 93 | 
            +
                    self.saver = saver
         | 
| 94 | 
            +
                    if not (isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj)):
         | 
| 95 | 
            +
                        raise TypeError(f"expected storage, not {type(obj)}")
         | 
| 96 | 
            +
             | 
| 97 | 
            +
                    # this logic is taken from PyTorch 2.0+ torch/serialization.py
         | 
| 98 | 
            +
                    if isinstance(obj, torch.storage.TypedStorage):
         | 
| 99 | 
            +
                        # PT upstream wants to deprecate this eventually...
         | 
| 100 | 
            +
                        storage = obj._untyped_storage
         | 
| 101 | 
            +
                        storage_type_str = obj._pickle_storage_type()
         | 
| 102 | 
            +
                        storage_type = getattr(torch, storage_type_str)
         | 
| 103 | 
            +
                        storage_numel = obj._size()
         | 
| 104 | 
            +
                    else:
         | 
| 105 | 
            +
                        storage = obj
         | 
| 106 | 
            +
                        storage_type = normalize_storage_type(type(obj))
         | 
| 107 | 
            +
                        storage_numel = storage.nbytes()
         | 
| 108 | 
            +
             | 
| 109 | 
            +
                    storage_key = saver._write_storage_and_return_key(storage)
         | 
| 110 | 
            +
                    location = torch.serialization.location_tag(storage)
         | 
| 111 | 
            +
             | 
| 112 | 
            +
                    self.storage_info = ("storage", storage_type, storage_key, location, storage_numel)
         | 
| 113 | 
            +
             | 
| 114 | 
            +
                def __reduce_ex__(self, protocol_version):
         | 
| 115 | 
            +
                    assert False, "this should be handled with out of band"
         | 
| 116 | 
            +
             | 
| 117 | 
            +
             | 
| 118 | 
            +
            class SavingProxyForTensor:
         | 
| 119 | 
            +
                def __init__(self, tensor, saver, protocol_version=5):
         | 
| 120 | 
            +
                    self.protocol_version = protocol_version
         | 
| 121 | 
            +
                    self.reduce_ret_fn, reduce_args = tensor.__reduce_ex__(protocol_version)
         | 
| 122 | 
            +
                    if reduce_args[0] == torch._utils._rebuild_tensor_v2:
         | 
| 123 | 
            +
                        # for Tensors with Python attributes
         | 
| 124 | 
            +
                        (a0, a1, (storage, *a2_other), *other_reduce_args) = reduce_args
         | 
| 125 | 
            +
                        assert isinstance(storage, torch.storage.TypedStorage), "Please check for updates"
         | 
| 126 | 
            +
                        storage_proxy = SavingProxyForStorage(
         | 
| 127 | 
            +
                            storage, saver, protocol_version=protocol_version
         | 
| 128 | 
            +
                        )
         | 
| 129 | 
            +
                        self.reduce_args = (a0, a1, (storage_proxy, *a2_other), *other_reduce_args)
         | 
| 130 | 
            +
                    else:
         | 
| 131 | 
            +
                        (storage, *other_reduce_args) = reduce_args
         | 
| 132 | 
            +
                        assert isinstance(storage, torch.storage.TypedStorage), "Please check for updates"
         | 
| 133 | 
            +
                        storage_proxy = SavingProxyForStorage(
         | 
| 134 | 
            +
                            storage, saver, protocol_version=protocol_version
         | 
| 135 | 
            +
                        )
         | 
| 136 | 
            +
                        self.reduce_args = (storage_proxy, *other_reduce_args)
         | 
| 137 | 
            +
             | 
| 138 | 
            +
                def __reduce_ex__(self, protocol_version):
         | 
| 139 | 
            +
                    if protocol_version != self.protocol_version:
         | 
| 140 | 
            +
                        raise RuntimeError(
         | 
| 141 | 
            +
                            f"Unexpected protocol version: expected {self.protocol_version}, got {protocol_version}"
         | 
| 142 | 
            +
                        )
         | 
| 143 | 
            +
                    return self.reduce_ret_fn, self.reduce_args
         | 
| 144 | 
            +
             | 
| 145 | 
            +
             | 
| 146 | 
            +
            class IncrementalPyTorchPickler(pickle.Pickler):
         | 
| 147 | 
            +
                def __init__(self, saver, *args, **kwargs):
         | 
| 148 | 
            +
                    super().__init__(*args, **kwargs)
         | 
| 149 | 
            +
                    self.storage_dtypes = {}
         | 
| 150 | 
            +
                    self.saver = saver
         | 
| 151 | 
            +
                    self.id_map = {}
         | 
| 152 | 
            +
             | 
| 153 | 
            +
                # this logic is taken from PyTorch 2.0+ torch/serialization.py
         | 
| 154 | 
            +
                def persistent_id(self, obj):
         | 
| 155 | 
            +
                    # FIXME: the docs say that persistent_id should only return a string
         | 
| 156 | 
            +
                    # but torch store returns tuples. This works only in the binary protocol
         | 
| 157 | 
            +
                    # see
         | 
| 158 | 
            +
                    # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
         | 
| 159 | 
            +
                    # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
         | 
| 160 | 
            +
                    if isinstance(obj, SavingProxyForStorage):
         | 
| 161 | 
            +
                        return obj.storage_info
         | 
| 162 | 
            +
             | 
| 163 | 
            +
                    if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
         | 
| 164 | 
            +
                        if isinstance(obj, torch.storage.TypedStorage):
         | 
| 165 | 
            +
                            # TODO: Once we decide to break serialization FC, this case
         | 
| 166 | 
            +
                            # can be deleted
         | 
| 167 | 
            +
                            storage = obj._untyped_storage
         | 
| 168 | 
            +
                            storage_dtype = obj.dtype
         | 
| 169 | 
            +
                            storage_type_str = obj._pickle_storage_type()
         | 
| 170 | 
            +
                            storage_type = getattr(torch, storage_type_str)
         | 
| 171 | 
            +
                            storage_numel = obj._size()
         | 
| 172 | 
            +
             | 
| 173 | 
            +
                        else:
         | 
| 174 | 
            +
                            storage = obj
         | 
| 175 | 
            +
                            storage_dtype = torch.uint8
         | 
| 176 | 
            +
                            storage_type = normalize_storage_type(type(obj))
         | 
| 177 | 
            +
                            storage_numel = storage.nbytes()
         | 
| 178 | 
            +
             | 
| 179 | 
            +
                        # If storage is allocated, ensure that any other saved storages
         | 
| 180 | 
            +
                        # pointing to the same data all have the same dtype. If storage is
         | 
| 181 | 
            +
                        # not allocated, don't perform this check
         | 
| 182 | 
            +
                        if storage.data_ptr() != 0:
         | 
| 183 | 
            +
                            if storage.data_ptr() in self.storage_dtypes:
         | 
| 184 | 
            +
                                if storage_dtype != self.storage_dtypes[storage.data_ptr()]:
         | 
| 185 | 
            +
                                    raise RuntimeError(
         | 
| 186 | 
            +
                                        "Cannot save multiple tensors or storages that view the same data as different types"
         | 
| 187 | 
            +
                                    )
         | 
| 188 | 
            +
                            else:
         | 
| 189 | 
            +
                                self.storage_dtypes[storage.data_ptr()] = storage_dtype
         | 
| 190 | 
            +
             | 
| 191 | 
            +
                        storage_key = self.id_map.get(storage._cdata)
         | 
| 192 | 
            +
                        if storage_key is None:
         | 
| 193 | 
            +
                            storage_key = self.saver._write_storage_and_return_key(storage)
         | 
| 194 | 
            +
                            self.id_map[storage._cdata] = storage_key
         | 
| 195 | 
            +
                        location = torch.serialization.location_tag(storage)
         | 
| 196 | 
            +
             | 
| 197 | 
            +
                        return ("storage", storage_type, storage_key, location, storage_numel)
         | 
| 198 | 
            +
             | 
| 199 | 
            +
                    return None
         | 
| 200 | 
            +
             | 
| 201 | 
            +
             | 
| 202 | 
            +
            class incremental_save:
         | 
| 203 | 
            +
                def __init__(self, name):
         | 
| 204 | 
            +
                    self.name = name
         | 
| 205 | 
            +
                    self.zipfile = torch._C.PyTorchFileWriter(str(name))
         | 
| 206 | 
            +
                    self.has_saved = False
         | 
| 207 | 
            +
                    self.next_key = 0
         | 
| 208 | 
            +
             | 
| 209 | 
            +
                def __enter__(self):
         | 
| 210 | 
            +
                    return self
         | 
| 211 | 
            +
             | 
| 212 | 
            +
                def store_early(self, tensor):
         | 
| 213 | 
            +
                    if isinstance(tensor, torch.Tensor):
         | 
| 214 | 
            +
                        return SavingProxyForTensor(tensor, self)
         | 
| 215 | 
            +
                    raise TypeError(f"can only store tensors early, not {type(tensor)}")
         | 
| 216 | 
            +
             | 
| 217 | 
            +
                def save(self, obj):
         | 
| 218 | 
            +
                    if self.has_saved:
         | 
| 219 | 
            +
                        raise RuntimeError("have already saved")
         | 
| 220 | 
            +
                    # Write the pickle data for `obj`
         | 
| 221 | 
            +
                    data_buf = BytesIO()
         | 
| 222 | 
            +
                    pickler = IncrementalPyTorchPickler(self, data_buf, protocol=5)
         | 
| 223 | 
            +
                    pickler.dump(obj)
         | 
| 224 | 
            +
                    data_value = data_buf.getvalue()
         | 
| 225 | 
            +
                    self.zipfile.write_record("data.pkl", data_value, len(data_value))
         | 
| 226 | 
            +
                    self.has_saved = True
         | 
| 227 | 
            +
             | 
| 228 | 
            +
                def _write_storage_and_return_key(self, storage):
         | 
| 229 | 
            +
                    if self.has_saved:
         | 
| 230 | 
            +
                        raise RuntimeError("have already saved")
         | 
| 231 | 
            +
                    key = self.next_key
         | 
| 232 | 
            +
                    self.next_key += 1
         | 
| 233 | 
            +
                    name = f"data/{key}"
         | 
| 234 | 
            +
                    if storage.device.type != "cpu":
         | 
| 235 | 
            +
                        storage = storage.cpu()
         | 
| 236 | 
            +
                    num_bytes = storage.nbytes()
         | 
| 237 | 
            +
                    self.zipfile.write_record(name, storage.data_ptr(), num_bytes)
         | 
| 238 | 
            +
                    return key
         | 
| 239 | 
            +
             | 
| 240 | 
            +
                def __exit__(self, type, value, traceback):
         | 
| 241 | 
            +
                    self.zipfile.write_end_of_file()
         | 
| 242 | 
            +
             | 
| 243 | 
            +
             | 
| 244 | 
            +
            T = TypeVar("T")
         | 
| 245 | 
            +
             | 
| 246 | 
            +
             | 
| 247 | 
            +
            def chunked_cross_entropy(
         | 
| 248 | 
            +
                logits: Union[torch.Tensor, List[torch.Tensor]], targets: torch.Tensor, chunk_size: int = 128
         | 
| 249 | 
            +
            ) -> torch.Tensor:
         | 
| 250 | 
            +
                # with large max_sequence_lengths, the beginning of `backward` allocates a large memory chunk which can dominate
         | 
| 251 | 
            +
                # the memory usage in fine-tuning settings with low number of parameters.
         | 
| 252 | 
            +
                # as a workaround hack, the cross entropy computation is chunked to force it to deallocate on the go, reducing
         | 
| 253 | 
            +
                # the memory spike's magnitude
         | 
| 254 | 
            +
             | 
| 255 | 
            +
                # lm_head was chunked (we are fine-tuning)
         | 
| 256 | 
            +
                if isinstance(logits, list):
         | 
| 257 | 
            +
                    # don't want to chunk cross entropy
         | 
| 258 | 
            +
                    if chunk_size == 0:
         | 
| 259 | 
            +
                        logits = torch.cat(logits, dim=1)
         | 
| 260 | 
            +
                        logits = logits.reshape(-1, logits.size(-1))
         | 
| 261 | 
            +
                        targets = targets.reshape(-1)
         | 
| 262 | 
            +
                        return torch.nn.functional.cross_entropy(logits, targets, ignore_index=-1)
         | 
| 263 | 
            +
             | 
| 264 | 
            +
                    # chunk cross entropy
         | 
| 265 | 
            +
                    logit_chunks = [logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits]
         | 
| 266 | 
            +
                    target_chunks = [
         | 
| 267 | 
            +
                        target_chunk.reshape(-1) for target_chunk in targets.split(logits[0].size(1), dim=1)
         | 
| 268 | 
            +
                    ]
         | 
| 269 | 
            +
                    loss_chunks = [
         | 
| 270 | 
            +
                        torch.nn.functional.cross_entropy(
         | 
| 271 | 
            +
                            logit_chunk, target_chunk, ignore_index=-1, reduction="none"
         | 
| 272 | 
            +
                        )
         | 
| 273 | 
            +
                        for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)
         | 
| 274 | 
            +
                    ]
         | 
| 275 | 
            +
                    return torch.cat(loss_chunks).mean()
         | 
| 276 | 
            +
             | 
| 277 | 
            +
                # no chunking at all
         | 
| 278 | 
            +
                logits = logits.reshape(-1, logits.size(-1))
         | 
| 279 | 
            +
                targets = targets.reshape(-1)
         | 
| 280 | 
            +
                if chunk_size == 0:
         | 
| 281 | 
            +
                    return torch.nn.functional.cross_entropy(logits, targets, ignore_index=-1)
         | 
| 282 | 
            +
             | 
| 283 | 
            +
                # lm_head wasn't chunked, chunk cross entropy
         | 
| 284 | 
            +
                logit_chunks = logits.split(chunk_size)
         | 
| 285 | 
            +
                target_chunks = targets.split(chunk_size)
         | 
| 286 | 
            +
                loss_chunks = [
         | 
| 287 | 
            +
                    torch.nn.functional.cross_entropy(
         | 
| 288 | 
            +
                        logit_chunk, target_chunk, ignore_index=-1, reduction="none"
         | 
| 289 | 
            +
                    )
         | 
| 290 | 
            +
                    for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)
         | 
| 291 | 
            +
                ]
         | 
| 292 | 
            +
                return torch.cat(loss_chunks).mean()
         | 
| 293 | 
            +
             | 
| 294 | 
            +
             | 
| 295 | 
            +
            def map_old_state_dict_weights(state_dict: Dict, mapping: Mapping, prefix: str) -> Dict:
         | 
| 296 | 
            +
                for checkpoint_name, attribute_name in mapping.items():
         | 
| 297 | 
            +
                    full_checkpoint_name = prefix + checkpoint_name
         | 
| 298 | 
            +
                    if full_checkpoint_name in state_dict:
         | 
| 299 | 
            +
                        full_attribute_name = prefix + attribute_name
         | 
| 300 | 
            +
                        state_dict[full_attribute_name] = state_dict.pop(full_checkpoint_name)
         | 
| 301 | 
            +
                return state_dict
         | 
| 302 | 
            +
             | 
| 303 | 
            +
             | 
| 304 | 
            +
            def get_default_supported_precision(training: bool) -> str:
         | 
| 305 | 
            +
                """Return default precision that is supported by the hardware: either `bf16` or `16`.
         | 
| 306 | 
            +
             | 
| 307 | 
            +
                Args:
         | 
| 308 | 
            +
                    training: `-mixed` or `-true` version of the precision to use
         | 
| 309 | 
            +
             | 
| 310 | 
            +
                Returns:
         | 
| 311 | 
            +
                    default precision that is suitable for the task and is supported by the hardware
         | 
| 312 | 
            +
                """
         | 
| 313 | 
            +
                from lightning.fabric.accelerators import MPSAccelerator
         | 
| 314 | 
            +
             | 
| 315 | 
            +
                if MPSAccelerator.is_available() or (
         | 
| 316 | 
            +
                    torch.cuda.is_available() and not torch.cuda.is_bf16_supported()
         | 
| 317 | 
            +
                ):
         | 
| 318 | 
            +
                    return "16-mixed" if training else "16-true"
         | 
| 319 | 
            +
                return "bf16-mixed" if training else "bf16-true"
         | 
| 320 | 
            +
             | 
| 321 | 
            +
             | 
| 322 | 
            +
            def load_checkpoint(
         | 
| 323 | 
            +
                fabric: L.Fabric, model: nn.Module, checkpoint_path: Path, strict: bool = True
         | 
| 324 | 
            +
            ) -> None:
         | 
| 325 | 
            +
                if isinstance(fabric.strategy, FSDPStrategy):
         | 
| 326 | 
            +
                    fabric.load_raw(checkpoint_path, model, strict=strict)
         | 
| 327 | 
            +
                else:
         | 
| 328 | 
            +
                    state_dict = lazy_load(checkpoint_path)
         | 
| 329 | 
            +
                    state_dict = state_dict.get("model", state_dict)
         | 
| 330 | 
            +
                    model.load_state_dict(state_dict, strict=strict)
         | 
| 331 | 
            +
             | 
| 332 | 
            +
             | 
| 333 | 
            +
            def flops_per_param(max_seq_length: int, n_layer: int, n_embd: int, n_params: int) -> int:
         | 
| 334 | 
            +
                flops_per_token = (
         | 
| 335 | 
            +
                    2 * n_params
         | 
| 336 | 
            +
                )  # each parameter is used for a MAC (2 FLOPS) per network operation
         | 
| 337 | 
            +
                # this assumes that all samples have a fixed length equal to the block size
         | 
| 338 | 
            +
                # which is most likely false during finetuning
         | 
| 339 | 
            +
                flops_per_seq = flops_per_token * max_seq_length
         | 
| 340 | 
            +
                attn_flops_per_seq = n_layer * 2 * 2 * (n_embd * (max_seq_length**2))
         | 
| 341 | 
            +
                return flops_per_seq + attn_flops_per_seq
         | 
| 342 | 
            +
             | 
| 343 | 
            +
             | 
| 344 | 
            +
            def estimate_flops(model: "GPT", training: bool) -> int:
         | 
| 345 | 
            +
                """Measures estimated FLOPs for MFU.
         | 
| 346 | 
            +
             | 
| 347 | 
            +
                Refs:
         | 
| 348 | 
            +
                    * https://ar5iv.labs.arxiv.org/html/2205.05198#A1
         | 
| 349 | 
            +
                    * https://ar5iv.labs.arxiv.org/html/2204.02311#A2
         | 
| 350 | 
            +
                """
         | 
| 351 | 
            +
                # using all parameters for this is a naive over estimation because not all model parameters actually contribute to
         | 
| 352 | 
            +
                # this FLOP computation (e.g. embedding, norm). For this reason, the result will be higher by a fixed percentage
         | 
| 353 | 
            +
                # (~10%) compared to the measured FLOPs, making those lower but more realistic.
         | 
| 354 | 
            +
                # For a proper estimate, this needs a more fine-grained calculation as in Appendix A of the paper.
         | 
| 355 | 
            +
                n_trainable_params = num_parameters(model, requires_grad=True)
         | 
| 356 | 
            +
                trainable_flops = flops_per_param(
         | 
| 357 | 
            +
                    model.max_seq_length, model.config.n_layer, model.config.n_embd, n_trainable_params
         | 
| 358 | 
            +
                )
         | 
| 359 | 
            +
                # forward + backward + gradients (assumes no gradient accumulation)
         | 
| 360 | 
            +
                ops_per_step = 3 if training else 1
         | 
| 361 | 
            +
                n_frozen_params = num_parameters(model, requires_grad=False)
         | 
| 362 | 
            +
                frozen_flops = flops_per_param(
         | 
| 363 | 
            +
                    model.max_seq_length, model.config.n_layer, model.config.n_embd, n_frozen_params
         | 
| 364 | 
            +
                )
         | 
| 365 | 
            +
                # forward + backward
         | 
| 366 | 
            +
                frozen_ops_per_step = 2 if training else 1
         | 
| 367 | 
            +
                return ops_per_step * trainable_flops + frozen_ops_per_step * frozen_flops
         |