Molbap HF Staff

Introduction

The transformers library, built with PyTorch, supports all state-of-the-art LLMs, many VLMs, task-specific vision language models, video models, audio models, table models, and classical encoders, for a total of almost 400 models.
The library's name mostly reflects the majority, as many models are not even transformer architectures, like Mamba, Zamba, RWKV, and convolution-based models.
Regardless, each of these was built by the research and engineering team that created it, then harmonized into a now-famous interface, callable with a simple .from_pretrained command.
Inference works for all models; training works for most. The library is a foundation for many machine learning courses and cookbooks, and several thousand other open-source libraries depend on it. All models are tested as part of a daily CI, ensuring their preservation and reproducibility. Most importantly, it is open-source and has largely been written by the community.
This isn't really to brag but to set the stakes: what does it take to keep such a ship afloat, made of so many moving, unrelated parts?
The ML wave has not stopped: more and more models are being added, at a steadily growing rate. Transformers is widely used, and we read the feedback that users post online, whether it's about a function with 300+ keyword arguments, duplicated code and helpers, mentions of Copied from ... everywhere, or optimisation concerns. Text-only models are relatively tame, but multimodal models remain to be harmonized.
Here we dissect the new design philosophy of transformers, as a continuation of the older philosophy page and an accompanying blog post from 2022.
More recently, a blog post about recent upgrades to transformers explained in particular what makes the library faster today; I recommend reading it if you haven't already.
Some time ago (I dare not say how long), we discussed the state of features in transformers with its maintainers. A lot of recent developments were satisfactory, but if we only talked about those, self-congratulation would be the only goalpost.
Reflecting on this philosophy now, as models pile up, is essential and will drive new developments.

The core tenets of transformers

Every reader, whether an OSS maintainer, power user, or casual fine-tuner, will walk away knowing how to reason about the transformers codebase, how to use it better, and how to meaningfully contribute to it. Along the way, we showcase new features you might have missed, so you'll be up to date.

So, what are the principles of transformers? We will try to summarize the foundations on which we've built everything, and write the "tenets" of the library. They behave like software interfaces, hence it is crucial that they are explicitly written down. However opinionated they are, they have evolved over time.

  1. Source of Truth

    We should be a source of truth for all model definitions. This is not a tenet, but something that still guides our decisions. Model implementations should be reliable, reproducible, and faithful to the original performances.

    This overarching guideline ensures quality and reproducibility across all models in the library.
  2. One Model, One File

    All inference logic (and most of the training logic; the loss is separate, not part of the model) is visible, top to bottom.

    Every model should be completely understandable by reading a single file from top to bottom.
  3. Code is Product

    Optimize for reading, diffing, and tweaking: our users are power users. Variable names can be explicit full words, even several words; readability is paramount.

    Code quality matters as much as functionality - optimize for human readers, not just computers.
  4. Standardize, Don't Abstract

    If it's model behavior, keep it in the file; abstractions only for generic infra.

    Model-specific logic belongs in the model file, not hidden behind abstractions.
  5. DRY* (DO Repeat Yourself)

    Copy when it helps users; keep successors in sync without centralizing behavior.

    Amendment: With the introduction and global adoption of modular transformers, we do not repeat any logic in the modular files, but end user files remain faithful to the original tenet.

    Strategic duplication can improve readability and maintainability when done thoughtfully.
  6. Minimal User API

    Config, model, preprocessing; from_pretrained, save_pretrained, push_to_hub. We want as few codepaths as possible. Reading should be obvious, and configurations should be obvious.

    Keep the public interface simple and predictable - users should know what to expect.
  7. Backwards Compatibility

    Evolve by additive standardization, never break public APIs.

    Note: some models see almost no use, and we have also stopped adding new features for non-torch frameworks. Still, we adapt to models existing on the Hub.

    Once something is public, it stays public - evolution through addition, not breaking changes.
  8. Consistent Public Surface

    Same argument names, same outputs, hidden states and attentions exposed, enforced by tests.

    All models should feel familiar - consistent interfaces reduce cognitive load.
  9. Modular Toolbox (Not A Framework)

    We ARE a toolbox. What we are not is a framework: you should not be FORCED to rewrite every modeling file, but it is better for your model to inherit from PreTrainedModel and thereby get TensorParallel, from_pretrained, sharding, push_to_hub, loss, as well as PEFT/TRL/SGLang/vLLM support.

    This is the largest change. Provide tools and utilities, but don't force users into a rigid framework.

When a PR is merged, it is because the contribution is worthwhile and the transformers team finds its design aligned with the tenets above.

Does all the code in the library strictly follow these tenets? No. The library is a gigantic house with connected nooks, corridors, and crannies everywhere, built by thousands of different workers. We try to keep all newly added code in line with the tenets, without breaking backwards compatibility.

For instance, one function essential to the implementation of Rotary Positional Embeddings is identical in 70 modeling_<file>.py files across src/transformers/models/. Why keep it? Because removing it would make those files unloadable checkpoints rather than self-contained blueprints. We do repeat ourselves.

import torch

def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)
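For context, rotate_half is what turns each pair of dimensions (i, i + d/2) into a 2D rotation in the rotary update x * cos + rotate_half(x) * sin. The list-based sketch below is an illustration, not the library code: it assumes the split-halves layout shown above and the commonly used base of 10000, and it checks the property that makes rotary embeddings useful, namely that query/key dot products depend only on the relative position.

```python
import math


def rotate_half(x):
    """List version of rotate_half: [x1, x2] -> [-x2, x1]."""
    half = len(x) // 2
    return [-v for v in x[half:]] + x[:half]


def apply_rotary(x, position, base=10000.0):
    """Rotate each pair (x[i], x[i + d/2]) by position * inv_freq[i] (assumed layout)."""
    d = len(x)
    inv_freq = [base ** (-2 * i / d) for i in range(d // 2)]
    angles = [position * f for f in inv_freq] * 2  # same angle for both halves of a pair
    rotated = rotate_half(x)
    return [xi * math.cos(a) + ri * math.sin(a) for xi, ri, a in zip(x, rotated, angles)]


def dot(a, b):
    return sum(u * v for u, v in zip(a, b))
```

Remove rotate_half and apply_rotary has nothing to rotate with: that is the clutch from the analogy that follows.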

You can use a simple regex to find all methods with a given name across your codebase and compare their differences and similarities; that's what I did, hashing each function body to avoid quadratic pairwise comparisons.
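A minimal sketch of that search, assuming the files are already loaded as strings (reading them from src/transformers/models/ is left out) and that the functions are defined at top level: a regex extracts each body, and a SHA-256 digest buckets identical copies in a single linear pass instead of diffing every pair. The name group_identical_functions is illustrative only.

```python
import hashlib
import re
from collections import defaultdict


def group_identical_functions(sources, func_name):
    """Map the hash of each distinct body of `func_name` to the files containing it."""
    # A top-level "def func_name(" up to the next unindented line or the end of the file
    pattern = re.compile(rf"^def {re.escape(func_name)}\(.*?(?=^\S|\Z)", re.MULTILINE | re.DOTALL)
    groups = defaultdict(list)
    for filename, text in sources.items():
        for match in pattern.finditer(text):
            body = match.group(0).strip()
            digest = hashlib.sha256(body.encode("utf-8")).hexdigest()
            groups[digest].append(filename)
    return dict(groups)
```

One group with many files means the copies are byte-identical; several groups point at the variants worth inspecting.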

So why keep it in all modeling files? Because if we removed it, the model would not work anymore. Think of the modeling files as a car (I know, what a novel metaphor! But it works.). All manual transmission cars have a clutch, and we want each of our cars to be drivable on its own. Remove the clutch and you can't drive. Remove the doors and it might be uncomfortable, but you'll get there. So doors can go, but you have to keep the clutch, even though you know perfectly well how it works.

Going modular

It is opinionated, and encountering an opinionated library can be frustrating. Our previous philosophy page and the blog post already pointed at some drawbacks, which have been iteratively addressed. Transformers has gone modular, allowing a form of inheritance without breaking One Model, One File. If you're familiar with this, you can skip this section and go to the next one.

We amended the DRY* principle by progressively removing all pieces of code that were "Copied from" another file.

It is explained in detail in the documentation, but overall it works like this: you define a modular_ file that can inherit from any class or function across all other modeling, configuration and processor files:

Auto-generated modeling code

{{{fragment-glm-compare}}}

As you can see, we can now define any model as a modular of another. This isn't strictly groundbreaking if you've done any programming; you might even think "well, that's just how inheritance works". The crucial difference is that we visibly do what is essentially the compiler's job: by unrolling the inheritance, we make all of the modeling code visible, keeping it all in one piece.
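To make "unrolling" concrete, here is a deliberately tiny sketch using Python's ast module. It is not the modular converter that transformers ships: it handles a single parent class, merges methods only, lets the child's overrides win, and rebases the child on the parent's own bases so the output no longer depends on the parent. The class names LlamaMLP and ToyMLP in the usage are hypothetical inputs.

```python
import ast


def unroll_inheritance(parent_src: str, child_src: str) -> str:
    """Return the child class source with its parent's methods copied in."""
    parent = next(n for n in ast.parse(parent_src).body if isinstance(n, ast.ClassDef))
    child = next(n for n in ast.parse(child_src).body if isinstance(n, ast.ClassDef))
    child_methods = {n.name: n for n in child.body if isinstance(n, ast.FunctionDef)}

    merged = []
    for node in parent.body:
        if isinstance(node, ast.FunctionDef):
            # A child override replaces the parent's method at the same position
            merged.append(child_methods.pop(node.name, node))
    # Methods that exist only in the child come last
    merged.extend(child_methods.values())

    child.bases = parent.bases  # no reference to the parent class remains
    child.body = merged or [ast.Pass()]
    return ast.unparse(child)
```

The real tool handles much more than this; the point is only the core idea of copying inherited code into a self-contained class that can be read top to bottom.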

External Attention classes

The next step after modular, and a big improvement in terms of readability, was to remove the various attention-backend-specific attention classes across the repository. Previously, we added specific torch operations for each backend (SDPA, successive flash-attention versions, flex attention), which did not make for a minimal user API.

What will always stay in the modeling code is eager_attention_forward, because it is a core part of the modeling:

attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
    attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

We often read criticism of kwargs, and we understand it. We type them wherever we can, but we cannot always enforce them, because other libraries such as vLLM don't use the same kwargs.

This is a strength of the new attention interface: it can be plugged into various backends precisely because most of the signature is not enforced. We INFORM but do not ENFORCE. That way, the current system remains a minimal user API.
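The dispatch pattern can be sketched without torch. In this hypothetical version, ATTENTION_REGISTRY stands in for ALL_ATTENTION_FUNCTIONS, the eager path is a plain softmax(Q K^T / sqrt(d)) V on nested lists, and backends accept **kwargs they are free to ignore: extra arguments flow through without the signature being locked down.

```python
import math
from typing import Callable, Dict, List

Matrix = List[List[float]]


def eager_attention_forward(query: Matrix, key: Matrix, value: Matrix, **kwargs) -> Matrix:
    # Reference implementation: softmax(Q K^T / sqrt(d)) V, one output row per query
    scale = 1.0 / math.sqrt(len(query[0]))
    output = []
    for q in query:
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in key]
        peak = max(scores)  # subtract the max for numerical stability
        weights = [math.exp(s - peak) for s in scores]
        total = sum(weights)
        output.append([sum(w * v[j] for w, v in zip(weights, value)) / total
                       for j in range(len(value[0]))])
    return output


# Hypothetical registry playing the role of ALL_ATTENTION_FUNCTIONS
ATTENTION_REGISTRY: Dict[str, Callable[..., Matrix]] = {}


def register_attention(name: str):
    def decorator(fn):
        ATTENTION_REGISTRY[name] = fn
        return fn
    return decorator


@register_attention("logged")
def logged_attention(query, key, value, **kwargs):
    # A toy "backend": records which kwargs it received, then defers to eager
    logged_attention.calls.append(sorted(kwargs))
    return eager_attention_forward(query, key, value)


logged_attention.calls = []


def attention(attn_implementation: str, query, key, value, **kwargs):
    # Same dispatch shape as the modeling snippet: eager unless configured otherwise
    attention_interface = eager_attention_forward
    if attn_implementation != "eager":
        attention_interface = ATTENTION_REGISTRY[attn_implementation]
    return attention_interface(query, key, value, **kwargs)
```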

To inform better, we plan to use Python features such as Annotated to tell users what we typically expect in an argument. That way, higher-level information could be included directly in the type annotations, like so (tentative design):

from typing import Annotated

class MyModelOutput: ...  # placeholder for a real model output class

MyModelOutputAnnotated = Annotated[MyModelOutput, "shape: (B, C, H, W)"]
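Annotated metadata is carried along but never checked by Python, which fits the INFORM-not-ENFORCE stance. Here is a small sketch of how such hints could be read back, for example by documentation or debugging tools; PixelValues and describe_arguments are hypothetical names, and the shape string follows the free-form convention of the tentative design.

```python
from typing import Annotated, get_args, get_origin, get_type_hints

# Hypothetical alias: the string documents intent for humans and tools, it is not a check
PixelValues = Annotated[list, "shape: (B, C, H, W)"]


def forward(pixel_values: PixelValues) -> int:
    # Nothing validates the shape at call time; the annotation only informs
    return len(pixel_values)


def describe_arguments(fn):
    """Return {argument name: metadata tuple} for every Annotated parameter."""
    hints = get_type_hints(fn, include_extras=True)  # keep the Annotated wrapper
    return {name: get_args(hint)[1:] for name, hint in hints.items() if get_origin(hint) is Annotated}
```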

Simpler Tensor Parallelism

We want to touch the modeling code as little as possible, and only modify it when architectural changes are involved. For tensor parallelism, for instance, we now specify a simple tp_plan instead.

It is written once in the config and passed to .from_pretrained().

The plan maps module name patterns to partitioning strategies. Strategies are resolved by the internal ParallelInterface, which wires them to sharding implementations such as ColwiseParallel, RowwiseParallel, packed variants, and so on.

{{{fragment-tp-plan}}}

This allows a user to run with multiple processes per node, e.g. on 4 GPUs:

torchrun --nproc-per-node 4 demo.py

Semantics stay in the model (a Linear stays a Linear); distribution is orthogonal and declared via strings: "colwise" splits columns of weights/bias across ranks, "rowwise" splits rows, and packed variants shard fused weights. The mapping keys accept glob patterns like layers.*.mlp.down_proj to target repeated submodules.
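A standard-library sketch of the two ideas above: resolving a module name against glob-style keys, and why colwise and rowwise splits both reproduce the full result. Assumptions, stated plainly: the plan below is a made-up example, fnmatch stands in for whatever matching transformers does internally, and W is written as (in_features, out_features) with y = x @ W, whereas nn.Linear stores the transpose.

```python
from fnmatch import fnmatch

# Hypothetical plan in the spirit of a tp_plan: glob pattern -> strategy string
tp_plan = {
    "layers.*.mlp.up_proj": "colwise",
    "layers.*.mlp.down_proj": "rowwise",
}


def resolve_strategy(module_name, plan):
    # First pattern matching the fully qualified module name wins
    for pattern, strategy in plan.items():
        if fnmatch(module_name, pattern):
            return strategy
    return None  # not sharded: replicated on every rank


def matmul(x, w):
    # x: n_in values, w: n_in rows by n_out columns -> n_out values
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]


def shard(w, strategy, rank, world_size):
    if strategy == "colwise":  # split output features (columns of w)
        step = len(w[0]) // world_size
        return [row[rank * step:(rank + 1) * step] for row in w]
    if strategy == "rowwise":  # split input features (rows of w)
        step = len(w) // world_size
        return w[rank * step:(rank + 1) * step]
    return w
```

Colwise shards produce disjoint slices of the output, which are concatenated (an all-gather in a real setup); rowwise shards each see a slice of the input and produce partial sums, which are added (an all-reduce).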

Layers, attentions and caches

Following the same logic, the nature of attention and caching in each layer of a model should not be hardcoded. We should be able to specify, through the configuration, how each layer is implemented. Thus we defined the allowed layer types:

ALLOWED_LAYER_TYPES = (
    "full_attention",
    "sliding_attention",
    "chunked_attention",
    "linear_attention",
    ...
)

The configuration can then state explicitly which attention type each layer uses. See e.g. gpt-oss, which alternates sliding and full attention:

  "layer_types": [
    "sliding_attention",
    "full_attention",
    ...,
    "sliding_attention",
    "full_attention"
  ],

This is minimal to implement on the user side, keeps the modeling code untouched, and is easy to tweak.
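To illustrate how a layer_types list can drive behavior without touching the layer code, here is a sketch that builds one boolean mask per layer from the config. The dict-based config, MASK_BUILDERS, and the sliding_window value are hypothetical; the window convention used here (each token sees itself and the window - 1 previous tokens) is an assumption and may differ from any given model.

```python
def causal_mask(seq_len):
    return [[k <= q for k in range(seq_len)] for q in range(seq_len)]


def sliding_window_mask(seq_len, window):
    # Assumed convention: a query sees itself and the previous window - 1 tokens
    return [[q - window < k <= q for k in range(seq_len)] for q in range(seq_len)]


# Hypothetical mapping from a layer type string to a mask builder
MASK_BUILDERS = {
    "full_attention": lambda seq_len, config: causal_mask(seq_len),
    "sliding_attention": lambda seq_len, config: sliding_window_mask(seq_len, config["sliding_window"]),
}


def masks_per_layer(config, seq_len):
    # The layer code is identical everywhere; only the mask it receives changes
    return [MASK_BUILDERS[layer_type](seq_len, config) for layer_type in config["layer_types"]]


config = {
    "sliding_window": 2,
    # Alternating pattern in the spirit of the gpt-oss example
    "layer_types": ["sliding_attention" if i % 2 == 0 else "full_attention" for i in range(4)],
}
```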

Community Kernels

The same principle extends to normalization, activation, and other code paths. The model defines semantics; a kernel defines how to execute them faster. We annotate the module to borrow a community-provided forward, keeping a consistent public surface:

from torch import nn
from transformers.integrations import use_kernel_forward_from_hub

@use_kernel_forward_from_hub("RMSNorm")
class GlmRMSNorm(nn.Module):
    ...
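The mechanism can be pictured as "tag the class, swap the forward later". The sketch below is hypothetical throughout: KERNEL_REGISTRY, use_kernel_forward, and kernelize are stand-ins, not the kernels library or the transformers decorator, and the "fast" forward is simply an algebraically equivalent RMSNorm. The point is that the reference forward defines the semantics and a swapped-in kernel must reproduce them.

```python
import math

# Hypothetical stand-in for kernels fetched from the Hub: layer name -> optimized forward
KERNEL_REGISTRY = {}


def use_kernel_forward(layer_name):
    # Tag the class; its own forward stays as the reference implementation
    def decorate(cls):
        cls.kernel_layer_name = layer_name
        return cls
    return decorate


def kernelize(module):
    # Bind the registered forward to this instance only, if one exists
    fast = KERNEL_REGISTRY.get(getattr(module, "kernel_layer_name", None))
    if fast is not None:
        module.forward = fast.__get__(module)
    return module


@use_kernel_forward("RMSNorm")
class ToyRMSNorm:
    def __init__(self, dim, eps=1e-6):
        self.weight = [1.0] * dim
        self.eps = eps

    def forward(self, x):
        # Reference semantics: x / sqrt(mean(x^2) + eps) * weight
        rms = math.sqrt(sum(v * v for v in x) / len(x) + self.eps)
        return [w * v / rms for w, v in zip(self.weight, x)]


def fused_rms_norm_forward(self, x):
    # "Optimized" variant: compute the inverse RMS once, then multiply
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + self.eps)
    return [w * v * inv_rms for w, v in zip(self.weight, x)]


KERNEL_REGISTRY["RMSNorm"] = fused_rms_norm_forward
```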

Plus, this opened another avenue of contribution for the community: GPU whisperers can now contribute optimized kernels. You can check the kernel community blog post to learn more about it!

Even more resources have been added, like the formidable kernel builder and its accompanying resources, which help you build kernels with it and with Nix.

The good modularity

Now, we have a form of inheritance in our codebase. Some models become standards, and model contributors are given the opportunity to define standards. Pushing the boundaries of scientific knowledge can also push the boundaries of engineering if this effort is made, and we're striving for it. It's hard to conceptualize very large libraries and how their components interact with each other, regardless of your cognitive abilities for abstractions. So I wanted to take a look at the current state of modularity across the repository. How many models are defined using components of others?

To get this graph, I used the heuristic of modular inheritance.

  1. Does this model have a modular file?
  2. In this modular file, what models, configurations and processings are imported?
  3. Recurse through the model list that way.
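The three steps above amount to a breadth-first walk. A sketch, assuming modular files are available as strings keyed by model folder name and that dependencies show up as relative imports such as from ..llama.modeling_llama import LlamaAttention; the regex and the toy sources in the test are illustrative, not the script behind the graph.

```python
import re
from collections import deque

# Relative imports from another model's modeling/configuration/processing file
IMPORT_RE = re.compile(r"^from \.\.(\w+)\.(?:modeling|configuration|processing)_\w+ import", re.MULTILINE)


def modular_dependencies(modular_sources, start):
    """Breadth-first walk over the models reachable through modular imports."""
    graph, queue, seen = {}, deque([start]), {start}
    while queue:
        model = queue.popleft()
        source = modular_sources.get(model)  # step 1: does this model have a modular file?
        parents = sorted(set(IMPORT_RE.findall(source))) if source else []  # step 2: what does it import?
        graph[model] = parents
        for parent in parents:  # step 3: recurse through the model list
            if parent not in seen:
                seen.add(parent)
                queue.append(parent)
    return graph
```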

So what do we see? Llama is a basis for many models, and it shows. Radically different architectures such as mamba have spawned their own dependency subgraph.

{{{fragment-dependency-graph}}}

However, even if llava defines a few VLMs, there are far too many vision-based architectures that are not yet defined as modulars of other existing architectures. In other words, there is no strong software reference point for vision models. As you can see, there is a small DETR island, a little llava pocket, and so on, but it's not comparable to the centrality observed for llama.

Another problem is that this only covers modular models: several models do NOT have a modular file.

Many models, but not enough yet, are alike

So I looked into Jaccard similarity, which measures the overlap between two sets. I know that code is more than a set of characters strung together. I also used code embedding models to check code similarities, which yielded better results, but for the needs of this blog post I will stick to the Jaccard index.
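For reference, the Jaccard index of two sets A and B is |A ∩ B| / |A ∪ B|. Here is a sketch applied to code, assuming identifier-like tokens as the set elements (the exact tokenization behind the figures is not specified here):

```python
import re


def tokens(source):
    # Identifier-like tokens only; whitespace, operators and punctuation are ignored
    return set(re.findall(r"[A-Za-z_]\w*", source))


def jaccard(a, b):
    sa, sb = tokens(a), tokens(b)
    union = sa | sb
    return len(sa & sb) / len(union) if union else 1.0
```

Because it only looks at set membership, two files with the same identifiers in a different order score 1.0, which is exactly the limitation acknowledged above.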

It is interesting, for that, to look at when we deployed this modular logic and what its ripple effect on the library was. You can check the larger space to play around, but the gist is: adding modular allowed us to connect more and more models to solid reference points. We still have a lot of gaps to fill.

{{{fragment-model-timeline}}}

If you've checked out llava, you've seen that llava_video is a red node, connected by a red edge to llava: it's a candidate, something we can likely remodularize without touching the actual model, making it much more readable with DRY*.

VLM improvements, avoiding abstraction

We don't have a cookbook for common VLM patterns (image token scatter, multi-tower encoders, cross-attention bridges). This is one of the main areas where we can improve.

For instance, I thought of abstracting away the mixing of inputs_embeds, the tensor fed into an LLM decoder in 95% of the existing VLMs. It would have looked something like this:

from torch import nn

class InputsEmbeddingMixerMixin(nn.Module):
    ...  # hypothetical abstraction, deliberately never implemented

But this abstracts away an important component of the modeling. Embedding mixing is part of the model; removing it would break it. A user opening modeling_qwen2_5_vl.py should not have to go to another file.

This is the current state of abstractions across a modeling file:

Bloatedness visualizer showing abstraction levels

The following pull request to standardize placeholder masking is a good example of the kind of changes that are acceptable. In a VLM, we always need to insert embeddings from various encoders at various positions, so we can have a function to do it. For Qwen2 VL, for instance, it looks like this:

    def get_placeholder_mask(
        self,
        input_ids: torch.LongTensor,
        inputs_embeds: torch.FloatTensor,
        image_features: torch.FloatTensor = None,
        video_features: torch.FloatTensor = None,
    ):
        """
        Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
        equal to the length of multimodal features. If the lengths are different, an error is raised.
        """
        if input_ids is None:
            special_image_mask = inputs_embeds == self.get_input_embeddings()(
                torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
            )
            special_image_mask = special_image_mask.all(-1)
            special_video_mask = inputs_embeds == self.get_input_embeddings()(
                torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
            )
            special_video_mask = special_video_mask.all(-1)
        else:
            special_image_mask = input_ids == self.config.image_token_id
            special_video_mask = input_ids == self.config.video_token_id

        n_image_tokens = special_image_mask.sum()
        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        if image_features is not None and inputs_embeds[special_image_mask].numel() != image_features.numel():
            raise ValueError(
                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}"
            )

        n_video_tokens = special_video_mask.sum()
        special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        if video_features is not None and inputs_embeds[special_video_mask].numel() != video_features.numel():
            raise ValueError(
                f"Videos features and video tokens do not match: tokens: {n_video_tokens}, features {video_features.shape[0]}"
            )

        return special_image_mask, special_video_mask

But this lives in the modeling file, not in the PreTrainedModel base class. It will stay there, because moving it out would break the self-contained logic of the model.
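What the mask is then used for can be sketched on plain lists: placeholder positions are overwritten, in order, by the encoder features, and everything else is left alone. This mirrors the behavior of torch's masked_scatter, which VLM modeling code typically uses for this step; scatter_image_features is a hypothetical helper, and its count check reuses the error message from get_placeholder_mask.

```python
def scatter_image_features(input_ids, inputs_embeds, image_features, image_token_id):
    """Replace the embedding at each image placeholder with the next image feature."""
    is_placeholder = [token == image_token_id for token in input_ids]
    n_image_tokens = sum(is_placeholder)
    if n_image_tokens != len(image_features):
        raise ValueError(
            f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {len(image_features)}"
        )
    features = iter(image_features)
    # masked_scatter semantics: consume features in order, only where the mask is True
    return [next(features) if placeholder else embedding
            for placeholder, embedding in zip(is_placeholder, inputs_embeds)]
```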

The weight of maintenance

The effect of modular can be measured straight from git history: at every commit, I counted LOC (lines of code) under src/transformers/models, but if a model has a modular_*.py file, I count that file instead. That gives an "effective LOC" curve: the maintenance surface.

Just look at the result: the growth rate of lines of code collapsed! Counting raw modeling_*.py files (with "Copied from..." everywhere), we were growing at around 362 LOC/day; with modular in place, the effective rate is ~25 LOC/day. That is about 15× lower! Had we continued with a strict "one model, one file" policy, who knows where we'd have ended up.
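The effective-LOC rule can be sketched as follows, assuming one dictionary per commit that maps each model folder to the line counts of its files; only modeling_*.py and modular_*.py files are considered here, which is a simplifying assumption. Evaluated at two commits, the difference divided by the elapsed days gives a LOC/day rate.

```python
def effective_loc(models):
    """models maps a model folder name to {filename: line_count} at one commit."""
    total = 0
    for files in models.values():
        modular = [n for name, n in files.items() if name.startswith("modular_")]
        if modular:
            total += sum(modular)  # the hand-maintained source of truth
        else:
            total += sum(n for name, n in files.items() if name.startswith("modeling_"))
    return total


def raw_loc(models):
    """Every modeling_*.py line, including code generated from modular files."""
    return sum(n for files in models.values() for name, n in files.items() if name.startswith("modeling_"))


def loc_per_day(loc_start, loc_end, days):
    return (loc_end - loc_start) / days
```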

Less code to hand-maintain means fewer places to break.

Cyclomatic complexity isn't LOC, but the two strongly correlate. As Les Hatton notes, defects scale like $d \sim x \ln x$, where x is the number of lines of code. Lower x (fewer lines of code) helps.
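Taking Hatton's relation at face value, with $d(x) = x \ln x$ up to a constant, halving the code size more than halves the expected defects:

$$
\frac{d(x/2)}{d(x)} = \frac{\tfrac{x}{2}\,\ln\tfrac{x}{2}}{x\,\ln x} = \frac{1}{2}\left(1 - \frac{\ln 2}{\ln x}\right) < \frac{1}{2} \qquad (x > 1)
$$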

{{{fragment-loc-growth}}}

There's a sharp drop near the end, due to removing support for JAX and TensorFlow library-wide.

Of course, this effort is not the only thing that reduced the maintenance load. Externalising the attention classes also moved out a lot of repeated, standard code.

Embedding models, now and forever.

Model popularity speaks for itself! Encoders are mostly used to produce embeddings, so we have to keep the encoder side viable, usable, and fine-tunable.

{{{fragment-model-visualisation}}}

As the codebase grows, we also need to keep supporting our friend codebase, Sentence Transformers. Retrieval use cases and smart databases, like FAISS-based indexing, rely on it, and thus indirectly on transformers.
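To make the embedding use case concrete, here is a small sketch of the usual pipeline: mean-pool an encoder's token vectors (ignoring padding), then rank documents by cosine similarity. Mean pooling is one common choice among several, and the brute-force search stands in for what an index such as FAISS does at scale; all names are illustrative.

```python
import math


def mean_pool(token_embeddings, attention_mask):
    # Average only over real tokens (mask == 1), ignoring padding
    kept = [e for e, m in zip(token_embeddings, attention_mask) if m]
    return [sum(column) / len(kept) for column in zip(*kept)]


def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def search(query, corpus, k=1):
    # Brute-force nearest neighbours by cosine similarity, best first
    ranked = sorted(range(len(corpus)), key=lambda i: cosine(query, corpus[i]), reverse=True)
    return ranked[:k]
```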

On image processing and processors

Choosing to be torch-first software relieved us of a tremendous amount of JAX and TensorFlow support, and it also meant we could be more lenient about the number of torch-dependent utilities we add. One of these is fast image processing. Images used to be assumed to be minimal ndarrays; making stronger assumptions and enforcing torch and torchvision-native inputs allowed us to massively speed up processing for each model.

The performance gains are immense: up to 20x faster for most models when using compiled torchvision ops.

Fast Image Processors Performance

Reduce barrier to entry/contribution

This is an overall objective: there's no transformer without its community.

We didn't want to make a framework, because having a framework means forcing users into it. It restrains flexibility and creativity, which are the fertile soil for new ideas to grow.

Among the most valuable contributions to transformers is, of course, the addition of new models. A second one is the ability to fine-tune these models and pipe them into many other software projects.

In that regard, we DO want to be a modular toolbox, being minimal enough (and hopefully well documented enough) so any ML/AI developer can use transformers without having to think about it. We aim to reduce the cognitive load brought about by model development, not increase it.

A surgical toolbox for model development

Attention visualisation

Since all models share the same internal API for attention computation, we can build cool tools to visualize the inner workings of the attention mechanism. One particular piece of machinery, the attention mask, is a frequent cause of confusion.

Here you see the famous bidirectional attention pattern for the whole prefix (text + image) in PaliGemma and all Gemma2+ models, contrasting with the usual "causal-only" models.

{{{fragment-attention-visualizer}}}
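The difference between the two patterns fits in one line of logic. A sketch, assuming a sequence whose first prefix_len positions form the (text + image) prefix: prefix queries see the whole prefix, and every later position stays causal. The rendering uses # for "may attend" and . for "masked"; rows are queries, columns are keys.

```python
def prefix_lm_mask(seq_len, prefix_len):
    # Query q may attend to key k if k is in the prefix or k is not in the future
    return [[k < prefix_len or k <= q for k in range(seq_len)] for q in range(seq_len)]


def render(mask):
    return "\n".join("".join("#" if allowed else "." for allowed in row) for row in mask)
```

With prefix_len=0 this degenerates to the usual causal mask, which is the "causal-only" case.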

Logging entire model activations

Further, because it is all PyTorch (all the more so now that we support only PyTorch), we can easily debug any model we want to add to transformers. We now have a power-user tool for porting or adding models that wraps a forward pass, intercepts every submodule call, and logs shapes, dtypes, and sample statistics of inputs/outputs to nested JSON.

It just works with PyTorch models and is especially useful when aligning outputs with a reference implementation, in line with our source-of-truth guideline.

Model debugger interface
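The core mechanism of such a tool can be sketched with standard PyTorch module hooks: a pre-hook opens a node when a submodule starts, a forward hook closes it with an output summary, and the result is a nested, JSON-serializable tree. This illustrates the technique and is not the transformers debugger itself; trace_forward and summarize are hypothetical names, and only positional inputs are recorded.

```python
import json

import torch
from torch import nn


def summarize(value):
    # Shape, dtype and a few sample statistics for tensors; just the type otherwise
    if isinstance(value, torch.Tensor):
        as_float = value.detach().float()
        return {
            "shape": list(value.shape),
            "dtype": str(value.dtype),
            "mean": as_float.mean().item(),
            "min": as_float.min().item(),
            "max": as_float.max().item(),
        }
    return {"type": type(value).__name__}


def trace_forward(model: nn.Module, *inputs):
    """Run one forward pass and return a nested dict mirroring the module call tree."""
    names = {module: (name or "model") for name, module in model.named_modules()}
    sentinel = {"children": []}
    stack = [sentinel]

    def pre_hook(module, args):
        node = {"name": names[module], "type": type(module).__name__,
                "inputs": [summarize(a) for a in args], "children": []}
        stack[-1]["children"].append(node)  # attach to the currently running parent
        stack.append(node)

    def post_hook(module, args, output):
        stack.pop()["output"] = summarize(output)

    handles = []
    for module in names:
        handles.append(module.register_forward_pre_hook(pre_hook))
        handles.append(module.register_forward_hook(post_hook))
    try:
        with torch.no_grad():
            model(*inputs)
    finally:
        for handle in handles:  # always leave the model unmodified
            handle.remove()
    return sentinel["children"][0]
```

Dumping the result with json.dumps(tree, indent=2) gives a file that can be diffed against the same trace from a reference implementation.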

Cooking faster CUDA warmups

Having a clean external API allows us to work on the true inner workings of transformers. One of the few recent additions is the CUDA warmup via caching_allocator_warmup, which massively improved model loading by pre-allocating GPU memory to avoid malloc bottlenecks during loading.

{{{fragment-warmup_demo}}}

It's hard to overstate how much of a lifesaver that is when you're trying to load a model as fast as possible: your iteration speed depends on it.
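As a rough, CPU-side analogy only (this is not how caching_allocator_warmup is implemented), the idea of pre-allocation looks like this: compute the total size of all parameters from their shapes and dtypes, make one large allocation, then hand out views into it instead of allocating per parameter. The parameter specs and the DTYPE_BYTES table are illustrative.

```python
import math

# Illustrative byte sizes per dtype
DTYPE_BYTES = {"float32": 4, "bfloat16": 2}


def nbytes(shape, dtype):
    return math.prod(shape) * DTYPE_BYTES[dtype]


def preallocate(param_specs):
    """param_specs: {name: (shape, dtype)}. One allocation, then zero-copy views."""
    buffer = bytearray(sum(nbytes(s, d) for s, d in param_specs.values()))
    view = memoryview(buffer)
    slices, offset = {}, 0
    for name, (shape, dtype) in param_specs.items():
        size = nbytes(shape, dtype)
        slices[name] = view[offset:offset + size]  # a view into the buffer, no new allocation
        offset += size
    return buffer, slices
```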

Transformers-serve and continuous batching

Having all these models readily available means all of them can be used with transformers serve, and interfaced with through an OpenAI-like API pattern.

transformers serve

curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{"messages": [{"role": "system", "content": "hello"}], "temperature": 0.9, "max_tokens": 1000, "stream": true, "model": "Qwen/Qwen2.5-0.5B-Instruct"}'

This provides an OpenAI-compatible API with features like continuous batching (also check here) for better GPU utilization.
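A standard-library client for the same endpoint might look like the sketch below. It assumes the server streams OpenAI-style server-sent events (data: {...} lines, terminated by data: [DONE]) and reuses the URL and model name from the curl example; build_request and iter_stream_text are hypothetical helpers.

```python
import json
import urllib.request


def build_request(prompt, model="Qwen/Qwen2.5-0.5B-Instruct",
                  url="http://localhost:8000/v1/chat/completions"):
    payload = {"messages": [{"role": "user", "content": prompt}],
               "max_tokens": 100, "stream": True, "model": model}
    return urllib.request.Request(url, data=json.dumps(payload).encode("utf-8"),
                                  headers={"Content-Type": "application/json"}, method="POST")


def iter_stream_text(lines):
    """Yield text deltas from OpenAI-style server-sent event lines (bytes or str)."""
    for raw in lines:
        line = (raw.decode("utf-8") if isinstance(raw, bytes) else raw).strip()
        if not line.startswith("data: "):
            continue  # blank keep-alive lines and anything else that is not a data event
        data = line[len("data: "):]
        if data == "[DONE]":
            break
        delta = json.loads(data)["choices"][0].get("delta", {})
        if delta.get("content"):
            yield delta["content"]
```

With the server running, iterating over the response of urllib.request.urlopen(build_request("hello")) yields raw lines, and "".join(iter_stream_text(response)) reassembles the streamed answer.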

Continuous batching is in itself very much linked to the great work of vLLM with the paged attention kernel, further justifying the facilitation of external kernels.
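Why continuous batching helps can be shown with a toy scheduler that only counts decode steps (prefill, KV-cache memory, and paged attention are ignored). With static batching, a batch runs until its longest sequence finishes; with continuous batching, a slot freed by a finished sequence is refilled before the next step. Both functions are illustrative sketches.

```python
from collections import deque


def continuous_batching_steps(lengths, max_batch):
    """Decode steps needed when finished sequences are replaced immediately."""
    waiting = deque(lengths)
    active = []  # remaining tokens to generate for each running sequence
    steps = 0
    while waiting or active:
        # Admit new requests into any free slot before each decode step
        while waiting and len(active) < max_batch:
            active.append(waiting.popleft())
        # One decode step produces one token for every active sequence
        active = [remaining - 1 for remaining in active if remaining > 1]
        steps += 1
    return steps


def static_batching_steps(lengths, max_batch):
    """Decode steps when each batch runs until its longest sequence finishes."""
    return sum(max(lengths[i:i + max_batch]) for i in range(0, len(lengths), max_batch))
```

For lengths [8, 2, 2, 2, 2] and two slots, static batching needs 12 steps while the continuous scheduler needs 8, because the short requests keep cycling through the second slot while the long one decodes.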

Community reusability

Transformers-serve is transformers-first, for sure, but it's not limited to that. Adding a model to transformers means:

  • having it immediately available to the community
  • having it immediately usable in vLLM, SGLang, and so on without additional code. In April 2025, transformers was added as a backend to run models on vLLM, which optimizes throughput/latency on top of existing transformers architectures as seen in this great blog post.

This further cements the need for a consistent public surface: we are now a backend, and there is software more optimized than us for serving. At the time of writing, more effort is going in that direction. We already have compatible configs for VLMs for vLLM (say that three times fast), here for GLM4 video support, and here for MoE support, for instance.

What is coming next

It sounds dumb, but it's true: the future is very soon. One tenet will be broken when the next major version, v5, is released: backwards compatibility will be heavily broken. Instead, what we aim to be is much more of a modular toolbox, while maintaining a consistent public surface.