---
title: "Maintain the unmaintainable:\n1M python loc, 400+ models"
subtitle: "A peek into software engineering for the transformers library"
description: "A peek into software engineering for the transformers library"
authors:
  - name: "Pablo Montalvo"
    url: "https://huggingface.co/Molbap"
    affiliations: [1]
affiliations:
  - name: "Hugging Face"
    url: "https://huggingface.co"
published: "October 2, 2025"
tags: [transformers, engineering, design-philosophy]
tableOfContentsAutoCollapse: true
---
|
|
|
|
|
import HtmlEmbed from "../components/HtmlEmbed.astro"; |
|
|
|
|
|
## Preface |
|
|
|
|
|
One million lines of `python` code. Through them, the `transformers` library supports more than 400 model architectures, from state-of-the-art LLMs and VLMs to specialized models for audio, video, and tables. |
|
|
|
|
|
Built on `PyTorch`, transformers is a foundational tool for modern LLM usage, research, education, and tens of thousands of other open-source projects. Each AI model is added by the community, harmonized into a consistent interface, and tested daily on a CI to ensure reproducibility. |
|
|
|
|
|
This scale presents a monumental engineering challenge. |
|
|
|
|
|
How do you keep such a ship afloat, made of so many moving, unrelated parts, contributed to by a buzzing hivemind? Especially as the pace of ML research accelerates? We receive constant feedback on everything from function signatures with hundreds of arguments to duplicated code and optimization concerns, and we listen to all of it, or try to. The library |
|
|
We continue to support all new models and expect to do so for the foreseeable future. |
|
|
|
|
|
This post dissects the design philosophy that makes this possible. It |
|
|
|
|
|
We formalize and articulate the "tenets" that have been guiding our development, demonstrate how they are implemented in code, and show the measurable impact they have on the library |
|
|
|
|
|
For any OSS maintainer, power user, or contributor, this is the map to understanding, using, and building upon `transformers`. It may also be useful beyond that audience: any project of comparable size requires deep choices, not only about design and abstractions, but about the very mindset of the software you are building. These tenets may or may not apply to your project, but they offer a glimpse into how we work that could be helpful or inspirational.
|
|
|
|
|
Conventions used throughout this post: |
|
|
|
|
|
* [Tenets exemplified](#source-of-truth) will have their summary available on hover. |
|
|
|
|
|
* [External links](https://huggingface.co/blog/welcome-openai-gpt-oss) to articles will help you solidify your knowledge. |
|
|
|
|
|
* [Several interactive visualisations](#generated-modeling) are available as you go - scroll, zoom, drag away to explore. |
|
|
|
|
|
<div class="crumbs"> |
|
|
* Breadcrumb boxes summarize what you just learned, connect it to the tenets, and point to what |
|
|
</div> |
|
|
|
|
|
We will get started by enumerating the tenets. Then we |
|
|
|
|
|
## The core tenets of transformers |
|
|
|
|
|
|
|
|
We summarize the foundations on which we |
|
|
|
|
|
These principles were not decided in a vacuum. The library _evolved_ towards them, and once they _emerged_, they were recognized as critical. |
|
|
|
|
|
<div class="tenet-list"> |
|
|
<ol> |
|
|
<li class="tenet"> |
|
|
<a id="source-of-truth"></a> |
|
|
<strong>Source of Truth</strong> |
|
|
<p>We aim to be a [source of truth for all model definitions](https://huggingface.co/blog/transformers-model-definition). This is more of a goal than a tenet, but it strongly guides our decisions. Model implementations should be reliable, reproducible, and faithful to the original implementations. If we are successful, they should become reference baselines for the ecosystem, so they |
|
|
<em>This overarching guideline ensures quality and reproducibility across all models in the library, and aspires to make the community work easier.</em> |
|
|
</li> |
|
|
|
|
|
<li class="tenet"> |
|
|
<a id="one-model-one-file"></a> |
|
|
<strong>One Model, One File</strong> |
|
|
<p>All inference and training core logic has to be visible, top‑to‑bottom, to maximize each model |
|
|
<em>Every model should be completely understandable and hackable by reading a single file from top to bottom.</em> |
|
|
</li> |
|
|
<li class="tenet"> |
|
|
<a id="code-is-product"></a> |
|
|
<strong>Code is The Product</strong> |
|
|
<p>Optimize for reading, diffing, and tweaking. Our users are power users. Variable names are explicit: we use full words, even several words, rather than abbreviations. Readability is paramount.</p>
|
|
<em>Code quality matters as much as functionality - optimize for human readers, not just computers.</em> |
|
|
</li> |
|
|
<li class="tenet"> |
|
|
<a id="standardize-dont-abstract"></a> |
|
|
<strong>Standardize, Don't Abstract</strong>
|
|
<p>If it |
|
|
<em>Model-specific logic belongs in the model file, not hidden behind abstractions.</em> |
|
|
</li> |
|
|
<li class="tenet"> |
|
|
<a id="do-repeat-yourself"></a> |
|
|
<strong>DRY* (DO Repeat Yourself)</strong> |
|
|
<p>Copy code when it helps users; keep successors in sync without centralizing behavior.</p> |
|
|
<p><strong>Evolution:</strong> With the introduction and global adoption of <a href="#modular">modular</a> transformers, we do not repeat any logic in the modular files, but end user files remain faithful to the original tenet as if code had been copied to make modeling files standalone.</p> |
|
|
<em>Strategic duplication can improve readability and maintainability when done thoughtfully.</em> |
|
|
</li> |
|
|
<li class="tenet"> |
|
|
<a id="minimal-user-api"></a> |
|
|
<strong>Minimal User API</strong> |
|
|
<p>Config, model, preprocessing; `from_pretrained`, `save_pretrained`, `push_to_hub`. We want as few code paths as possible. Reading should be obvious, configurations should be obvious.</p>

<em>Keep the public interface simple and predictable; users should know what to expect.</em>
|
|
</li> |
|
|
<li class="tenet"> |
|
|
<a id="backwards-compatibility"></a> |
|
|
<strong>Backwards Compatibility</strong> |
|
|
<p>Evolve by additive standardization, never break public APIs.</p> |
|
|
<p>Any artifact that was once on the hub and loadable with transformers should be usable indefinitely with the same interface. Further, public methods should not change, so that downstream dependencies do not break.</p>
|
|
<em>Once something is public, it stays public. Evolution through addition, not breaking changes.</em> |
|
|
</li> |
|
|
<li class="tenet"> |
|
|
<a id="consistent-public-surface"></a> |
|
|
<strong>Consistent Public Surface</strong> |
|
|
<p>Same argument names, same outputs, hidden states and attentions exposed, enforced by tests. This is a goal as well as a tenet.</p> |
|
|
<em>All models should feel familiar - consistent interfaces reduce cognitive load.</em> |
|
|
</li> |
|
|
</ol> |
|
|
</div> |
|
|
|
|
|
|
|
|
When a PR is merged, it is because the contribution is worthwhile, and because the `transformers` team finds the design of the contribution to be aligned with these principles. |
|
|
|
|
|
Does all the code in the library strictly follow these tenets? No. The library is a gigantic house with connected nooks, corridors, crannies everywhere, built by thousands of different workers. We _try_ to make it so all the code added is compliant, because if we fail and merge it, we cannot change it lest we break [backwards compatibility](#backwards-compatibility). |
|
|
|
|
|
|
|
|
|
|
To see what constitutes adherence to the tenets, let |
|
|
|
|
|
The following function, which is essential to the implementation of [Rotary Positional Embeddings](https://huggingface.co/papers/2104.09864), can be found in 70 `modeling_<file>.py` files across `src/transformers/models/`. Why keep it? Because we want all the model logic to be [contained in the modeling file](#one-model-one-file). In order to do that, we [do repeat ourselves](#do-repeat-yourself).
|
|
|
|
|
```python
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)
```
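To make the role of `rotate_half` concrete, here is a minimal pure-Python sketch that uses plain lists instead of tensors, so it is not the library code. It assumes the usual rotary formulation, `x * cos + rotate_half(x) * sin`. The helper names `rotate_half_list` and `apply_rotation` are hypothetical. For a two-dimensional vector, this reduces to an ordinary 2D rotation.

```python
import math


def rotate_half_list(x):
    """List-based analogue of rotate_half: returns (-x2, x1)."""
    half = len(x) // 2
    x1, x2 = x[:half], x[half:]
    return [-value for value in x2] + x1


def apply_rotation(x, angle):
    """Computes x * cos(angle) + rotate_half(x) * sin(angle), element-wise."""
    rotated = rotate_half_list(x)
    return [a * math.cos(angle) + b * math.sin(angle) for a, b in zip(x, rotated)]


# Rotating the unit vector (1, 0) by a quarter turn lands (up to float error) on (0, 1).
quarter_turn = apply_rotation([1.0, 0.0], math.pi / 2)
```

The point of the sketch is that the function is tiny and self-explanatory, which is why duplicating it in each modeling file costs little in readability.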
|
|
|
|
|
You can use a simple regex, like [this one]() to look at all methods of a given name across your codebase and look at their differences and similarities. |
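As an illustration of that kind of search, here is a small sketch. It swaps the regex for the standard-library `ast` module, so that formatting and comments do not count as differences. The in-memory `SOURCES` dictionary and the `find_function_variants` helper are hypothetical stand-ins; a real run would read the modeling files from disk.

```python
import ast
from collections import defaultdict

# Hypothetical in-memory stand-ins for modeling files.
SOURCES = {
    "modeling_a.py": "def rotate_half(x):\n    return x[::-1]\n",
    "modeling_b.py": "def rotate_half(x):  # same logic\n    return x[::-1]\n",
    "modeling_c.py": "def rotate_half(x):\n    return [-v for v in x]\n",
}


def find_function_variants(sources, function_name):
    """Groups file names by the structure of the function called `function_name`."""
    variants = defaultdict(list)
    for filename, source in sources.items():
        for node in ast.walk(ast.parse(source)):
            if isinstance(node, ast.FunctionDef) and node.name == function_name:
                # ast.dump ignores formatting and comments, so only real differences remain.
                variants[ast.dump(node)].append(filename)
    # Largest group first: the most common variant is the likely reference.
    return sorted(variants.values(), key=len, reverse=True)


groups = find_function_variants(SOURCES, "rotate_half")
```

Each group lists the files that share one variant of the function, which makes divergent copies easy to spot.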
|
|
|
|
|
|
|
We want all models to have self-contained modeling code. Every core functionality _must_ be in the modeling code, every non-core functionality _can_ be outside of it. |
|
|
|
|
|
This comes at a great cost. For a long time we used the `#Copied from...` mechanism: comments documenting that some code was copied from another model. This saved time for both reviewers and the CI, since we had tooling to ensure that the copied blocks remained in sync. But the LOC count kept creeping up. Each new model copied over hundreds of lines that we considered largely boilerplate, yet we could not remove them.
|
|
|
|
|
We needed to separate two principles that were so far intertwined, [repetition](#do-repeat-yourself) and [hackability](#one-model-one-file).
|
|
|
|
|
What was the solution to this? Let |
|
|
|
|
|
<div class="crumbs"> |
|
|
<strong>TL;DR:</strong> Read the code in one place (<a href="#one-model-one-file">One Model, One File</a>). Keep semantics local (<a href="#standardize-dont-abstract">Standardize, Don |
|
|
|
|
|
<strong>Next:</strong> how modular transformers honor these while removing boilerplate. |
|
|
</div> |
|
|
|
|
|
|
|
|
## <a id="modular"></a> Modular transformers |
|
|
|
|
|
Transformers is an opinionated library. The previous [philosophy](https://huggingface.co/docs/transformers/en/philosophy) page and the [2022 blog post](https://huggingface.co/blog/transformers-design-philosophy) were already pointing at the drawbacks mentioned just above, which have been iteratively addressed. [`modular` transformers was introduced](https://huggingface.co/docs/transformers/en/modular_transformers) to allow a form of inheritance without breaking [One model, One file](#one-model-one-file).
|
|
|
|
|
We amended the principle of [DRY*](#do-repeat-yourself) by progressively removing all pieces of code that were "copied from" another file. |
|
|
|
|
|
It works as follows. In order to contribute a model –GLM, for instance– we define a `modular_` file that can inherit from _any function across all other modeling, configuration and processor files already available in the library_. The modular file can use inheritance across models, but then it |
|
|
|
|
|
<summary id="generated-modeling">Auto-generated modeling code</summary> |
|
|
|
|
|
<HtmlEmbed src="transformers/glm-compare.html" /> |
|
|
|
|
|
As you can see, we can define a new model as a _modular_ combination of fragments taken from others. |
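To give an intuition for what "expanding" a modular definition means, here is a toy sketch built on the standard-library `ast` module. It is not the actual converter used in `transformers`: the sources, the class bodies, and the `expand_modular` helper are hypothetical, and it only handles single inheritance from a class found in one parent source.

```python
import ast

# Hypothetical parent "modeling" source and a "modular" file that reuses it.
PARENT_SOURCE = """
class LlamaMLP:
    def forward(self, x):
        return x * 2
"""

MODULAR_SOURCE = """
class GlmMLP(LlamaMLP):
    pass
"""


def expand_modular(modular_source, parent_source):
    """Inlines parent class bodies so the result has no cross-file inheritance."""
    parents = {
        node.name: node
        for node in ast.parse(parent_source).body
        if isinstance(node, ast.ClassDef)
    }
    tree = ast.parse(modular_source)
    for node in tree.body:
        if not isinstance(node, ast.ClassDef):
            continue
        base_names = [base.id for base in node.bases if isinstance(base, ast.Name)]
        if len(base_names) != 1 or base_names[0] not in parents:
            continue
        parent = parents[base_names[0]]
        own = [stmt for stmt in node.body if not isinstance(stmt, ast.Pass)]
        overridden = {stmt.name for stmt in own if isinstance(stmt, ast.FunctionDef)}
        inherited = [
            stmt
            for stmt in parent.body
            if not (isinstance(stmt, ast.FunctionDef) and stmt.name in overridden)
        ]
        node.body = (inherited + own) or [ast.Pass()]
        node.bases = parent.bases  # the child now stands alone
    return ast.unparse(tree)


standalone_source = expand_modular(MODULAR_SOURCE, PARENT_SOURCE)
```

The generated `standalone_source` defines `GlmMLP` with the inherited `forward` written out in full, which is the property the expanded modeling file is meant to have: everything is readable in one place.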
|
|
|
|
|
You might think "well that's just how inheritance works". The crucial difference is that we do _visibly_ what is essentially the _compiler_ |
|
|
|
|
|
|
|
|
<!-- some ideas for additional hand-holding: link to the implementation of `LlamaAttention` to show it was copied (and modified), or maybe provide a git diff view between the GlmAttention and LlamaAttention implementations --> |
|
|
|
|
|
What is the consequence? When adding a model, we do not need to go over the entire modeling file. The modular (left side above) is enough. |
|
|
|
|
|
When `AutoModel.from_pretrained(...)` is called, it is indeed the modeling (right side) that is run, and all the tests run on the modeling code. More importantly, the auto-generated modeling file is what users _read_ to understand the code, what they step through in their debuggers and what they hack for their needs.
|
|
|
|
|
What does that give us? |
|
|
|
|
|
<div class="crumbs"> |
|
|
<strong>TL;DR:</strong> A small <code>modular_*.py</code> declares reuse; the expanded modeling file stays visible (<a href="#one-model-one-file">One Model, One File tenet preserved</a>). Reviewers and contributors maintain the shard, not the repetition. |
|
|
|
|
|
<strong>Next:</strong> the measurable effect on effective LOC and maintenance cost. |
|
|
</div> |
|
|
|
|
|
|
|
|
### A maintainable control surface |
|
|
|
|
|
The effect of modular can be measured in lines of code (LOC). If a model only has a modeling file, we add its LOC count. |
|
|
However, if a model has a `modular_*.py` and a corresponding automatically generated `modeling_*.py`, we only count the LOC of the modular file. The modeling code has no maintenance cost, as it is strictly derived from the modular file.
|
|
|
|
|
That gives an "effective LOC" curve: the **maintenance surface**.
|
|
|
|
|
Measured on git history, raw `modeling_*.py` grew at ~362 LOC/day before modular; counting only modular shards yields ~25 LOC/day after — about **15× lower**. The effective curve (blue line below) represents the **maintenance surface** today: what maintainers actually read and review. |
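The counting rule can be written down in a few lines. The sketch below is only an illustration: the file names and line counts in `LOC_BY_FILE` are made up, and `effective_loc` is a hypothetical helper, not the script used to produce the figures above. The last line reproduces the ratio between the two quoted growth rates.

```python
# Hypothetical line counts per file; real numbers would come from the git history.
LOC_BY_FILE = {
    "models/llama/modeling_llama.py": 900,
    "models/glm/modular_glm.py": 150,
    "models/glm/modeling_glm.py": 800,
}


def effective_loc(loc_by_file):
    """Counts the modular file when one exists, otherwise the modeling file."""
    per_model = {}
    for path, loc in loc_by_file.items():
        directory, filename = path.rsplit("/", 1)
        kind = filename.split("_", 1)[0]  # "modular" or "modeling"
        per_model.setdefault(directory, {})[kind] = loc
    return sum(
        files.get("modular", files.get("modeling", 0)) for files in per_model.values()
    )


total = effective_loc(LOC_BY_FILE)  # 900 (llama) + 150 (glm shard) = 1050
ratio = 362 / 25  # growth rates quoted above, in LOC/day
```

With the quoted rates, `ratio` comes out at about 14.5, which is the "about 15×" mentioned above.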
|
|
|
|
|
<!-- Yeah, super good point that effective == maintenable --> |
|
|
|
|
|
Less code to hand-maintain means fewer places to break. Of course, LOC is not a direct measure of complexity, but it correlates with review effort and change risk.
|
|
|
|
|
<HtmlEmbed src="transformers/loc-growth.html" /> |
|
|
|
|
|
<!-- What is "Modeling LOC (included)"? The modeling code, not counting the files that have a modular counterpart? If so, perhaps we can say that the blue line (effective) is the sum of the red + green, whereas the yellow would have been the progression without modular. Also worth mentioning imo that the surface area has been essentially constant (in LOC) since modular. --> |
|
|
|
|
|
Notice there |
|
|
|
|
|
But this was not the only effort that allowed us to reduce maintenance load. |
|
|
|
|
|
We recently underwent a thoughtful refactor of the attention implementation. You |
|
|
|
|
|
_Attention computation_ happens at a _lower_ level of abstraction than the model itself. |
|
|
|
|
|
However, we were adding specific torch operations to every model for each backend (sdpa, various flash-attention versions, flex attention) but it wasn |
|
|
|
|
|
<div class="crumbs"> |
|
|
Evidence: effective (i.e., maintainable) LOC growth drops ~15× when counting shards instead of expanded modeling files. Less code to read, fewer places to break.
|
|
|
|
|
<strong>Next:</strong> how the attention interface stays standard without hiding semantics. |
|
|
</div> |
|
|
|
|
|
### <a id="attention-classes"></a> External Attention classes |
|
|
|
|
|
The solution for the "attention abstraction problem" was to move to a standard [attention interface](https://huggingface.co/docs/transformers/en/attention_interface) that allows the following: |
|
|
|
|
|
The naive implementation of attention, called "eager", is available by default. We use a `Callable` called `eager_attention_forward`, which can run as long as the user has PyTorch installed – which is a requirement anyway.
|
|
|
|
|
Instead of using a class interface and a class hierarchy, we just moved to a function interface. When a more complex attention implementation is needed, we use other Callables, including much faster kernel bindings when available. The decision to use a different attention implementation is based on the model configuration file we download from the Hub, and it can also be overridden by the user. |
|
|
|
|
|
This is a clear example that we prefer an interface that is [standard, but not abstract](#standardize-dont-abstract). To be completely precise, this is what the interface selection looks like in transformers code:
|
|
|
|
|
```python
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
    attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
```
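The same pattern can be tried in isolation. The following is a self-contained toy version, under the assumption that the registry behaves like a plain dictionary of callables. `TOY_ATTENTION_FUNCTIONS`, `toy_sdpa_forward`, and `select_attention` are hypothetical names, and the functions return labels instead of computing attention.

```python
from typing import Callable, Dict


def eager_attention_forward(query, key, value):
    """Stand-in for the default implementation; returns a label instead of tensors."""
    return "eager"


def toy_sdpa_forward(query, key, value):
    """Stand-in for a faster backend."""
    return "sdpa"


# Hypothetical registry: a name from the configuration maps to a callable.
TOY_ATTENTION_FUNCTIONS: Dict[str, Callable] = {"sdpa": toy_sdpa_forward}


def select_attention(attn_implementation: str) -> Callable:
    """Mirrors the selection logic above: eager by default, registry lookup otherwise."""
    attention_interface: Callable = eager_attention_forward
    if attn_implementation != "eager":
        attention_interface = TOY_ATTENTION_FUNCTIONS[attn_implementation]
    return attention_interface


# A new backend plugs in by registering one more callable; no class hierarchy is involved.
TOY_ATTENTION_FUNCTIONS["my_kernel"] = lambda query, key, value: "my_kernel"
```

Because every backend shares one call signature, the model code that calls `attention_interface(...)` does not change when the backend does.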
|
|
|
|
|
A strength of the new attention interface is the possibility to enforce specific kwargs, which are needed by kernel providers and other dependencies. We know that kwargs are often a necessary evil that plagues tools with widespread compatibility. We have aimed to reduce them, and will continue to do so in order to improve readability; even with them, the current system is a [minimal user API](#minimal-user-api).
|
|
|
|
|
<!-- not fully following the transition here --> |
|
|
|
|
|
Backend integrations sometimes require specific kwargs. We reduce that surface and document expectations; where flexibility is necessary, we plan to use `typing.Annotated` to convey shapes and invariants without constraining integrations. Such an implementation could look like this in the future: |
|
|
|
|
|
```python
from typing import Annotated

MyModelOutputAnnotated = Annotated[MyModelOutput, "shape: (B, C, H, W)"]
```
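For readers unfamiliar with `typing.Annotated`, the sketch below shows how such metadata can be read back at runtime with the standard library. The `forward` function and its argument are hypothetical; nothing here reflects how `transformers` will actually consume these annotations.

```python
from typing import Annotated, get_args, get_type_hints


def forward(pixel_values: Annotated[list, "shape: (B, C, H, W)"]) -> list:
    """Toy function whose argument carries a shape note as metadata."""
    return pixel_values


# include_extras=True keeps the Annotated wrapper instead of stripping it.
hints = get_type_hints(forward, include_extras=True)
base_type, shape_note = get_args(hints["pixel_values"])
```

The annotation informs tools and readers without changing what the function accepts: `forward` still runs on any list.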
|
|
|
|
|
|
|
|
<div class="crumbs"> |
|
|
Attention semantics remain in <code>eager_attention_forward</code>; faster backends are opt-in via config. We inform via types/annotations rather than enforce rigid kwargs, preserving integrations. |
|
|
|
|
|
<strong>Next:</strong> parallel partitioning is declared as a plan, not through model surgery. |
|
|
</div> |
|
|
|
|
|
### <a id="simpler-tensor-parallelism"></a> Configurable Tensor Parallelism |
|
|
|
|
|
If you |
|
|
|
|
|
The essential part is that, as [the documentation states](https://huggingface.co/docs/transformers/v4.56.2/perf_train_gpu_many#tensor-parallelism), when tensors get too large to fit on a single GPU, they are sliced along a particular dimension and every slice is sent to a different GPU. |
|
|
|
|
|
Why does it matter? |
|
|
|
|
|
Because we want to avoid code modifications that are unrelated to the model. |
|
|
|
|
|
We choose to place the level of abstraction higher than the device placement: a matrix multiplication (an `nn.Linear` layer) should always be expressed in the same way, regardless of how it is placed.
|
|
|
|
|
Hence, we want to touch the modeling code [minimally](#minimal-user-api), and only modify it when _architectural changes_ are involved – not depending on the way you run it. For tensor parallelism, we simply specify a `tp_plan`: |
|
|
|
|
|
<HtmlEmbed src="transformers/tp-plan.html" /> |
|
|
|
|
|
The plan is written once, saved as part of the config and passed to `.from_pretrained()`. It maps module name patterns to partitioning strategies. Strategies are resolved by the internal `ParallelInterface`, which wires to sharding implementations `ColwiseParallel`, `RowwiseParallel`, packed variants, and so on. |
|
|
|
|
|
The alternative would be to modify classes depending on supported types of parallelism. |
|
|
|
|
|
The `tp_plan` solution allows users to run the same model on a single GPU, or distribute it using multiple processes per node, e.g. 4 GPUs: |
|
|
|
|
|
`torchrun --nproc-per-node 4 demo.py` |
|
|
|
|
|
Semantics stay in the model (a Linear stays a Linear), while parallelization is orthogonal and declared via strings: "colwise" splits columns of weights/bias across ranks; "rowwise" splits rows; packed variants shard fused weights. The mapping keys accept glob patterns like `layers.*.mlp.down_proj` to target repeated submodules.
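Here is a minimal sketch of how glob-style keys can be matched against module names, using the standard-library `fnmatch` module. The plan contents and the `resolve_strategy` helper are hypothetical, and the actual pattern resolution inside `transformers` may differ.

```python
from fnmatch import fnmatchcase

# Hypothetical plan: keys are glob patterns, values are strategy names.
TOY_TP_PLAN = {
    "layers.*.self_attn.q_proj": "colwise",
    "layers.*.self_attn.o_proj": "rowwise",
    "layers.*.mlp.down_proj": "rowwise",
}


def resolve_strategy(module_name, tp_plan):
    """Returns the strategy of the first matching pattern, or None if nothing matches."""
    for pattern, strategy in tp_plan.items():
        if fnmatchcase(module_name, pattern):
            return strategy
    return None
```

A single entry such as `layers.*.mlp.down_proj` covers the same submodule in every layer, so the plan stays short however deep the model is.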
|
|
|
|
|
<div class="crumbs"> |
|
|
Parallelization is specified in the configuration (<code>tp_plan</code>), not through edits to <code>Linear</code>s. Glob patterns target repeated blocks; modeling semantics stay intact. |
|
|
|
|
|
<strong>Next:</strong> per-layer attention/caching schedules declared in config, not hardcoded. |
|
|
</div> |
|
|
|
|
|
### <a id="layers-attentions-caches"></a> Layers, attentions and caches |
|
|
|
|
|
Following the same logic, the _nature_ of attention and per-layer caching should not be hardcoded. We should be able to specify in the configuration how each layer is implemented. Thus, we define a mapping like: |
|
|
|
|
|
|
|
|
```python
ALLOWED_LAYER_TYPES = (
    "full_attention",
    "sliding_attention",
    "chunked_attention",
    "linear_attention",
    ...
)
```
|
|
|
|
|
and the configuration can be _explicit_ about which attention type is in which layer. See, for example, [gpt-oss](https://huggingface.co/openai/gpt-oss-120b/blob/main/config.json#L15), which alternates sliding and full attention: |
|
|
|
|
|
```python
"layer_types": [
    "sliding_attention",
    "full_attention",
    ...,
    "sliding_attention",
    "full_attention"
],
```
|
|
|
|
|
This is [minimal](#minimal-user-api) to implement on the user side, and allows us to keep the modeling code untouched. It is also easy to tweak.
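A short sketch of what building and checking such a schedule could look like follows. The helpers `alternating_layer_types` and `validate_layer_types` are hypothetical, and the tuple below only lists the four layer types named in this post (the real one contains further entries).

```python
# Only the four types named in this post; the real tuple has more entries.
ALLOWED_LAYER_TYPES = (
    "full_attention",
    "sliding_attention",
    "chunked_attention",
    "linear_attention",
)


def alternating_layer_types(num_layers):
    """Builds a schedule where even layers slide and odd layers attend fully."""
    return [
        "sliding_attention" if index % 2 == 0 else "full_attention"
        for index in range(num_layers)
    ]


def validate_layer_types(layer_types, num_layers):
    """Raises ValueError if the schedule has the wrong length or an unknown type."""
    if len(layer_types) != num_layers:
        raise ValueError(f"expected {num_layers} entries, got {len(layer_types)}")
    unknown = sorted(set(layer_types) - set(ALLOWED_LAYER_TYPES))
    if unknown:
        raise ValueError(f"unknown layer types: {unknown}")


schedule = alternating_layer_types(4)
validate_layer_types(schedule, 4)
```

Changing the schedule means editing a list in the configuration, not the model class.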
|
|
|
|
|
<div class="crumbs"> |
|
|
Allowed layer types are explicit; schedules (e.g., sliding/full alternation) live in config. This keeps the file readable and easy to tweak. |
|
|
|
|
|
<strong>Next:</strong> speedups come from kernels that don |
|
|
</div> |
|
|
|
|
|
|
|
|
### <a id="community-kernels"></a>Community Kernels |
|
|
|
|
|
The same principle extends to normalization, activation, and other code paths. The model defines **semantics**; a kernel defines **how** to execute them faster. We annotate the module to borrow a community‑provided forward, keeping a [consistent public surface](#consistent-public-surface):
|
|
|
|
|
```python
@use_kernel_forward_from_hub("RMSNorm")
class GlmRMSNorm(nn.Module):
    ...
```
|
|
|
|
|
This also opens another contribution path: GPU specialists can contribute optimized kernels to the [Kernels Hub](https://huggingface.co/kernels-community), and have them immediately available to use in `transformers` and other libraries. You can check the [kernel community blog post](https://huggingface.co/blog/hello-hf-kernels) to learn more about it! |
|
|
|
|
|
Even more resources have been added, like the formidable [kernel builder](https://github.com/huggingface/kernel-builder) with its connected resources to [help you build kernels with it](https://github.com/huggingface/kernel-builder/blob/main/docs/writing-kernels.md) and [with nix](https://github.com/huggingface/kernel-builder/blob/main/docs/nix.md). |
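To illustrate the general idea of a decorator that swaps in a faster `forward`, here is a toy sketch in plain Python. It does not show how the kernels decorator is implemented: `TOY_KERNEL_REGISTRY`, `use_toy_kernel_forward`, and both classes are hypothetical, and a local dictionary stands in for kernels fetched from the Hub.

```python
import math

# Hypothetical local registry standing in for kernels fetched from the Hub.
TOY_KERNEL_REGISTRY = {}


def use_toy_kernel_forward(layer_name):
    """Class decorator: replaces `forward` if a kernel is registered under `layer_name`."""
    def decorator(cls):
        fast_forward = TOY_KERNEL_REGISTRY.get(layer_name)
        if fast_forward is not None:
            cls.forward = fast_forward
        return cls
    return decorator


def fast_rms_norm_forward(self, values):
    """Pretend-optimized forward: same math, different code path."""
    scale = math.sqrt(sum(v * v for v in values) / len(values) + self.eps)
    return [v / scale for v in values]


TOY_KERNEL_REGISTRY["RMSNorm"] = fast_rms_norm_forward


@use_toy_kernel_forward("RMSNorm")
class ToyRMSNorm:
    def __init__(self, eps=1e-6):
        self.eps = eps

    def forward(self, values):
        """Reference semantics, kept readable in the class itself."""
        mean_square = sum(v * v for v in values) / len(values)
        return [v / (mean_square + self.eps) ** 0.5 for v in values]


@use_toy_kernel_forward("NotRegistered")
class ToyIdentity:
    def forward(self, values):
        return values
```

The class body still documents the semantics, while the decorator decides which code path runs; when no kernel is registered, the reference `forward` is used unchanged.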
|
|
|
|
|
<div class="crumbs"> |
|
|
Models define semantics; kernels define how to run them faster. Use decorators to borrow community forwards while keeping a consistent public surface.
|
|
|
|
|
<strong>Next:</strong> what modularity looks like across the repo. |
|
|
</div> |
|
|
|
|
|
## The State of Modular
|
|
|
|
|
Modular provides a form of inheritance in our codebase. Some models become standards, and model contributors have the opportunity to _define standards_ if their architectures are adopted. Pushing the boundaries of scientific knowledge can translate into the boundaries of engineering if this effort is made, and we |
|
|
|
|
|
It |
|
|
So I wanted to take a look at the current **state of modularity** across the repository. How many models are defined using components of others? |
|
|
|
|
|
To get this graph, I used the heuristic of modular inheritance. |
|
|
1. Does this model have a `modular` file? |
|
|
2. In this `modular` file, which models, configurations, and processors are imported?
|
|
3. Recurse through the model list that way. |
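The three steps above can be sketched as follows. This is an illustration under stated assumptions, not the script used for the graph: the `MODULAR_SOURCES` contents are invented, and the import pattern assumes relative imports of the form `from ..<model>.modeling_<model> import ...`.

```python
import re

# Hypothetical modular sources keyed by model name (step 1: only models with a modular file).
MODULAR_SOURCES = {
    "glm": (
        "from ..llama.modeling_llama import LlamaAttention\n"
        "from ..phi3.modeling_phi3 import Phi3MLP\n"
    ),
    "phi3": "from ..mistral.modeling_mistral import MistralDecoderLayer\n",
    "mistral": "from ..llama.modeling_llama import LlamaMLP\n",
}

# Step 2: capture the model directory in relative modeling/configuration/processing imports.
IMPORT_PATTERN = re.compile(
    r"^from \.\.(\w+)\.(?:modeling|configuration|processing)_\w+ import",
    re.MULTILINE,
)


def direct_parents(model):
    """Models whose files are imported by this model's modular file."""
    return sorted(set(IMPORT_PATTERN.findall(MODULAR_SOURCES.get(model, ""))))


def all_ancestors(model, seen=None):
    """Step 3: follow the imports recursively, guarding against cycles."""
    seen = set() if seen is None else seen
    for parent in direct_parents(model):
        if parent not in seen:
            seen.add(parent)
            all_ancestors(parent, seen)
    return seen
```

Each `(model, parent)` pair returned by `direct_parents` corresponds to one edge of the graph.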
|
|
|
|
|
So what do we see? Llama is a basis and an influence for many models, and it shows. |
|
|
Radically different architectures such as mamba have spawned their own dependency subgraph. |
|
|
|
|
|
<!-- A couple of ideas here: |
|
|
- Use screenshots to clearly show the points we make. For example, the cluster with Llama in the center, or the one about DETR/llava below. |
|
|
- Use a link to open the viewer full-screen for better manipulation and exploration. |
|
|
--> |
|
|
|
|
|
(Graph reading guide: nodes are models; edges are modular imports). |
|
|
|
|
|
<HtmlEmbed src="transformers/dependency-graph.html" /> |
|
|
|
|
|
In the case of VLMs, there |
|
|
As you can see, there is a small DETR island, a little llava pocket, and so on, but it |
|
|
|
|
|
Another problem is, this visualization only shows `modular` models. Several models still do NOT have a modular file. |
|
|
|
|
|
How do we spot them, and how do we identify modularisable models? |
|
|
|
|
|
<div class="crumbs"> |
|
|
Llama-lineage is a hub; several VLMs remain islands — engineering opportunity for shared parents. |
|
|
|
|
|
<strong>Next:</strong> timeline + similarity signals to spot modularisable candidates. |
|
|
</div> |
|
|
|
|
|
|
|
|
### Many models, but not enough yet, are alike |
|
|
|
|
|
I used the Jaccard index, which measures the overlap between two sets, to find similarities across models. Code is of course more than a set of characters strung together. We also tried code-embedding models, which ranked candidates better in practice, but for this post we stick to the deterministic Jaccard index.
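For reference, the Jaccard index of two sets A and B is |A ∩ B| / |A ∪ B|. The sketch below applies it to two code strings. It assumes the sets are made of identifier tokens, which is a choice made for this example; the post does not specify the exact tokenization, and the helper names are hypothetical.

```python
import re


def token_set(source):
    """Set of identifier-like tokens found in a code string."""
    return set(re.findall(r"[A-Za-z_]\w*", source))


def jaccard_index(source_a, source_b):
    """Size of the intersection divided by size of the union of the two token sets."""
    set_a, set_b = token_set(source_a), token_set(source_b)
    union = set_a | set_b
    if not union:
        return 1.0  # two empty inputs are treated as identical
    return len(set_a & set_b) / len(union)


# {x, a, b} vs {x, a, c}: 2 shared tokens out of 4 distinct ones.
similarity = jaccard_index("x = a + b", "x = a + c")
```

A score close to 1 between two modeling files suggests that one could be expressed as a modular variant of the other.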
|
|
|
|
|
It is interesting, for our comparison, to look at _when_ we deployed the modular logic and what its ripple effect on the library was. You can check the [larger space](https://huggingface.co/spaces/Molbap/transformers-modular-refactor) to play around, but the gist is: adding modular allowed us to connect more and more models to solid reference points. But we still have a lot of gaps to fill.
|
|
|
|
|
Zoom out below - it |
|
|
|
|
|
<HtmlEmbed src="transformers/model-timeline.html" /> |
|
|
|
|
|
<!-- screenshot would be helpful --> |
|
|
|
|
|
If you check llava, you |
|
|
|
|
|
<div class="crumbs"> |
|
|
Similarity metrics (Jaccard or embeddings) surface likely parents; the timeline shows consolidation after modular landed. Red nodes/edges = candidates (e.g., <code>llava_video</code> → <code>llava</code>) for refactors that preserve behavior. |
|
|
|
|
|
<strong>Next:</strong> concrete VLM choices that avoid leaky abstractions. |
|
|
</div> |
|
|
|
|
|
### VLM improvements, avoiding abstraction |
|
|
|
|
|
We don |
|
|
|
|
|
For instance, we thought of abstracting away the mixing of `inputs_embeds`, the tensor fed into an LLM decoder in 95% of the existing VLMs. It would have looked something like this:
|
|
|
|
|
```python
class InputsEmbeddingMixerMixin(nn.Module):
    ...
```
|
|
|
|
|
But this is [abstracting away an important component of the modeling](#standardize-dont-abstract). Embedding mixing is part of the model; removing it would break it. A user opening [`modeling_qwen2.5_vl`](https://huggingface.co/collections/Qwen/qwen25-vl-6795ffac22b334a837c0f9a5) should not have to go to another file to understand how it works.
|
|
|
|
|
<!-- ^ should we link to the code instead? --> |
|
|
|
|
|
What is the current state of these “abstractions” across the codebase? |
|
|
You will see all the imports around a modeling file, here [Gemma3n](https://huggingface.co/google/gemma-3n-E4B-it). |
|
|
|
|
|
 |
|
|
|
|
|
|
|
|
As you can see, the `GenerationMixin` node is already very heavy. It encompasses all of the utilities around `.generate`, and it is second only to `nn.Module`.
|
|
That means every decision we make to abstract something else has to be extremely careful. |
|
|
|
|
|
The following [Pull request to standardize placeholder masking](https://github.com/huggingface/transformers/pull/39777) is a good example of what kind of changes are acceptable. In a VLM, we always need to insert embeddings from various encoders at various positions, so we can have a function to do it. For Qwen2 VL, for instance, it will look like this: |
|
|
|
|
|
```python
def get_placeholder_mask(
    self,
    input_ids: torch.LongTensor,
    inputs_embeds: torch.FloatTensor,
    image_features: torch.FloatTensor = None,
    video_features: torch.FloatTensor = None,
):
    """
    Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
    equal to the length of multimodal features. If the lengths are different, an error is raised.
    """
    if input_ids is None:
        special_image_mask = inputs_embeds == self.get_input_embeddings()(
            torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
        )
        special_image_mask = special_image_mask.all(-1)
        special_video_mask = inputs_embeds == self.get_input_embeddings()(
            torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
        )
        special_video_mask = special_video_mask.all(-1)
    else:
        special_image_mask = input_ids == self.config.image_token_id
        special_video_mask = input_ids == self.config.video_token_id

    n_image_tokens = special_image_mask.sum()
    special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
    if image_features is not None and inputs_embeds[special_image_mask].numel() != image_features.numel():
        raise ValueError(
            f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}"
        )

    n_video_tokens = special_video_mask.sum()
    special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
    if video_features is not None and inputs_embeds[special_video_mask].numel() != video_features.numel():
        raise ValueError(
            f"Videos features and video tokens do not match: tokens: {n_video_tokens}, features {video_features.shape[0]}"
        )

    return special_image_mask, special_video_mask
```
|
|
|
|
|
But this is _within_ the modeling file, not in the `PreTrainedModel` base class. It will not move away from it, because it |
|
|
|
|
|
<!-- So the main conclusion here is that VLMs should use modular more to come up with de-facto standard modules without abstracting them away? --> |
|
|
|
|
|
<div class="crumbs"> |
|
|
Keep VLM embedding mix in the modeling file (semantics), standardize safe helpers (e.g., placeholder masking), don |
|
|
|
|
|
<strong>Next:</strong> pipeline-level wins that came from PyTorch-first choices (fast processors). |
|
|
</div> |
|
|
|
|
|
|
|
|
### On image processing and processors |
|
|
|
|
|
Deciding to become a `torch`-first library meant shedding a tremendous amount of support work for `jax` and `TensorFlow`, and it also meant that we could be more lenient about the torch-dependent utilities we accept. One of these is the _fast processing_ of images. Where inputs were once minimally assumed to be ndarrays, enforcing native `torch` and `torchvision` inputs allowed us to massively improve processing speed for each model.
|
|
|
|
|
The gains in performance are immense, up to 20x speedup for most models when using compiled torchvision ops. Furthermore, it allows us to run the whole pipeline solely on GPU.
|
|
|
|
|
 |
|
|
<p class="figure-legend">Thanks <a href="https://huggingface.co/yonigozlan">Yoni Gozlan</a> for the great work!</p> |
|
|
|
|
|
<div class="crumbs"> |
|
|
PyTorch-first lets processors assume torch/torchvision and run the whole pipeline on GPU; big per-model speedups. |
|
|
|
|
|
<strong>Next:</strong> how this lowers friction for contributors and downstream users. |
|
|
</div> |
|
|
|
|
|
|
|
|
## Reduce barrier to entry/contribution |
|
|
|
|
|
This is an overall objective: there |
|
|
|
|
|
Having a framework means forcing users into it. It restrains flexibility and creativity, which are the fertile soil for new ideas to grow. |
|
|
|
|
|
Among the most valuable contributions to `transformers` is of course the addition of new models. Very recently, [OpenAI added GPT-OSS](https://huggingface.co/blog/welcome-openai-gpt-oss), which prompted the addition of many new features to the library in order to support [their model](https://huggingface.co/openai/gpt-oss-120b). These additions are immediately available for other models to use. |
|
|
|
|
|
Another important advantage is the ability to fine-tune these models and plug them into many other libraries and tools. Check on the Hub how many fine-tunes are registered for [gpt-oss 120b](https://huggingface.co/models?other=base_model:finetune:openai/gpt-oss-120b), despite its size!
|
|
|
|
|
|
|
|
<div class="crumbs"> |
|
|
The shape of a contribution: add a model (or variant) with a small modular shard; the community and serving stacks pick it up immediately. Popularity trends (encoders/embeddings) guide where we invest. |
|
|
|
|
|
<strong>Next:</strong> power tools enabled by a consistent API. |
|
|
</div> |
|
|
|
|
|
|
|
|
### <a id="encoders-ftw"></a> Model popularity
|
|
|
|
|
Speaking of dependencies, we can take the number of downloads as a measure of popularity. One thing we see is the prominence of encoders, despite the apparent prevalence of decoder LLMs. The reason is that encoders are used to generate embeddings, which have many downstream uses. Just check out [EmbeddingGemma](https://huggingface.co/blog/embeddinggemma) for a modern recap. Hence, it is vital to keep the encoder portion of the library viable, usable, and fine-tunable.
|
|
|
|
|
<div> |
|
|
<HtmlEmbed src="transformers/model-visualisation.html" /> |
|
|
</div> |
|
|
|
|
|
As the codebase grows, we need to maintain it in coordination with our friends at [Sentence Transformers](https://huggingface.co/sentence-transformers). Retrieval use-cases, smart databases, and FAISS-based indexing rely on it, and thus indirectly on transformers.
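For readers less familiar with how an encoder output becomes a retrieval embedding, here is a hedged sketch of one common recipe: masked mean pooling followed by cosine similarity. Which pooling a given model expects is model-specific, and the helper names `mean_pool` and `top_k` are made up for this example.

```python
import torch
import torch.nn.functional as F


def mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average token vectors per sequence, ignoring padded positions."""
    mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)  # (batch, seq, 1)
    summed = (last_hidden_state * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1.0)  # guard against all-padding rows
    return summed / counts


def top_k(query: torch.Tensor, corpus: torch.Tensor, k: int = 1) -> list:
    """Indices of the k corpus embeddings closest to the query (cosine similarity)."""
    q = F.normalize(query, dim=-1)    # (hidden,)
    c = F.normalize(corpus, dim=-1)   # (n_docs, hidden)
    scores = c @ q                    # (n_docs,)
    return torch.topk(scores, k=min(k, corpus.shape[0])).indices.tolist()


# Toy encoder output: one sequence of three tokens, the last one is padding.
hidden = torch.tensor([[[1.0, 1.0], [3.0, 3.0], [100.0, 100.0]]])
mask = torch.tensor([[1, 1, 0]])
embedding = mean_pool(hidden, mask)[0]
```

An index such as FAISS then replaces the brute-force `top_k` once the corpus grows.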
|
|
|
|
|
In that regard, we DO want to be a modular toolbox, being [minimal](#minimal-user-api) enough and well documented enough so any ML/AI developer can use `transformers` without having to think about it. We aim to reduce the cognitive load brought about by model development, not increase it. |
|
|
|
|
|
So, how do these design choices, these "tenets", influence the development of models and the overall usage of transformers?
|
|
|
|
|
<div class="crumbs"> |
|
|
Encoders remain critical for embeddings and retrieval; maintaining them well benefits the broader ecosystem (e.g., Sentence Transformers, FAISS). |
|
|
|
|
|
<strong>Next:</strong> dev tools that leverage unified attention APIs and PyTorch-only internals. |
|
|
</div> |
|
|
|
|
|
|
|
|
## A surgical toolbox for model development |
|
|
|
|
|
Transformers provides many tools that can help you while adding a new architecture, or help you understand the inner workings of the library. |
|
|
|
|
|
### Attention visualisation |
|
|
|
|
|
All models have the same internal API for attention computation, thanks to [the externalisation of attention classes](#external-attention-classes). This allows us to build cool tools to visualize the inner workings of the attention mechanism. |
|
|
|
|
|
One particular piece of machinery is the `attention mask`. Here you see the famous bidirectional attention pattern for the whole prefix (text + image) in PaliGemma and all Gemma2+ models, contrasting with the usual "causal-only" models. |
|
|
|
|
|
<HtmlEmbed src="transformers/attention-visualizer.html" /> |
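The difference between the two patterns fits in a few lines. The sketch below builds both masks as plain Python lists (rows are queries, columns are keys); `prefix_len` stands for the length of the image + text prefix, and the helper names are made up for illustration, not library functions.

```python
def causal_mask(seq_len: int) -> list:
    """mask[q][k] is True when query position q may attend to key position k."""
    return [[k <= q for k in range(seq_len)] for q in range(seq_len)]


def prefix_bidirectional_mask(seq_len: int, prefix_len: int) -> list:
    """Full attention inside the prefix, causal attention everywhere else."""
    return [
        [k <= q or (q < prefix_len and k < prefix_len) for k in range(seq_len)]
        for q in range(seq_len)
    ]


def render(mask: list) -> str:
    """ASCII view: 'x' means the query may attend to the key, '.' means masked."""
    return "\n".join("".join("x" if allowed else "." for allowed in row) for row in mask)


print(render(prefix_bidirectional_mask(seq_len=5, prefix_len=3)))
# xxx..
# xxx..
# xxx..
# xxxx.
# xxxxx
```

The top-left 3x3 block is fully filled: every prefix token sees the whole prefix, while the two suffix tokens stay causal.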
|
|
|
|
|
<div class="crumbs"> |
|
|
Uniform attention APIs enable cross-model diagnostics (e.g., PaliGemma prefix bidirectionality vs causal). |
|
|
|
|
|
<strong>Next:</strong> whole-model tracing for ports and regressions. |
|
|
</div> |
|
|
|
|
|
|
|
|
### Logging entire model activations |
|
|
|
|
|
Because everything is PyTorch, we can easily [debug any model](https://huggingface.co/docs/transformers/internal/model_debugging_utils) when we want to add it to transformers. We now have a power-user tool for porting or adding models that wraps a forward pass, intercepts every submodule call, and logs shapes, dtypes, and sample statistics of inputs/outputs to nested JSON.
|
|
|
|
|
It just works with PyTorch models and is especially useful when aligning outputs with a reference implementation, to match our [Source of Truth guideline](#source-of-truth). |
|
|
|
|
|
 |
|
|
|
|
|
|
|
|
<div class="crumbs"> |
|
|
Forward interception and nested JSON logging align ports to reference implementations, reinforcing "Source of Truth."





<strong>Next:</strong> CUDA warmup reduces load-time without touching modeling semantics.
|
|
</div> |
|
|
|
|
|
|
|
|
|
|
|
### Cooking faster CUDA warmups |
|
|
|
|
|
Having a clean _external_ API allows us to work on the [true inner workings of transformers](#code-is-product). One recent addition is the _CUDA warmup_ via `caching_allocator_warmup`, which dramatically improves loading times by pre-allocating GPU memory to avoid malloc bottlenecks during model loading. It can achieve a 7x speedup for an 8B model, or 6x for a 32B one, as you can check in [the PR](https://github.com/huggingface/transformers/pull/36380)!
|
|
|
|
|
<HtmlEmbed src="transformers/warmup_demo.html" /> |
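As a rough sketch of the idea (not the code from the PR): ask the allocator once for a single block as large as all the weights, then release it. On CUDA, PyTorch's caching allocator keeps the released block reserved, so the many per-tensor allocations that follow are served from that pool instead of hitting `cudaMalloc` each time. The function name `warmup_allocator` is hypothetical, and on CPU the call is only a plain allocate-and-free with no caching benefit.

```python
import torch
from torch import nn


def warmup_allocator(model: nn.Module, device: str) -> int:
    """Reserve roughly the model's weight footprint on `device` in one allocation.

    Returns the number of bytes requested.
    """
    nbytes = sum(p.numel() * p.element_size() for p in model.parameters())
    if nbytes > 0:
        block = torch.empty(nbytes, dtype=torch.uint8, device=device)
        # Dropping the reference hands the block back to the caching allocator,
        # which (on CUDA) keeps it reserved for upcoming allocations.
        del block
    return nbytes


model = nn.Linear(1024, 1024)
device = "cuda" if torch.cuda.is_available() else "cpu"
reserved = warmup_allocator(model, device)
model.to(device)  # weight tensors can now reuse the reserved pool
```

A real loader works per device and per shard rather than on an already materialized model; the sketch only shows the single large allocation that replaces many small ones.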
|
|
|
|
|
It |
|
|
|
|
|
<div class="crumbs"> |
|
|
Pre-allocating GPU memory removes malloc spikes (e.g., 7× for 8B, 6× for 32B in the referenced PR). |
|
|
|
|
|
<strong>Next:</strong> consistent interfaces allow transformers-serve. |
|
|
</div> |
|
|
|
|
|
|
|
|
### Transformers-serve and continuous batching |
|
|
|
|
|
Having all these models readily available and sharing the same interface allowed us to implement transformers-serve, a CLI tool that exposes models through a standard OpenAI-compatible HTTP API.
|
|
|
|
|
```bash
transformers serve

curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d
```
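A Python equivalent of such a request can be sketched with the standard library alone. The payload follows the OpenAI chat-completions shape; the model id is a placeholder you need to fill in, and `build_chat_request` and `send` are hypothetical helper names, not part of transformers.

```python
import json
import urllib.request


def build_chat_request(prompt: str, model: str, base_url: str = "http://localhost:8000"):
    """Build a POST request for an OpenAI-compatible /v1/chat/completions endpoint."""
    payload = {
        "model": model,  # placeholder: use a model id your server can load
        "messages": [{"role": "user", "content": prompt}],
    }
    return urllib.request.Request(
        url=f"{base_url}/v1/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def send(request: urllib.request.Request, timeout: float = 60.0) -> dict:
    """Send the request and decode the JSON response (needs a running server)."""
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return json.loads(response.read().decode("utf-8"))


# With `transformers serve` running locally:
# reply = send(build_chat_request("Hello!", model="<your-model-id>"))
# print(reply["choices"][0]["message"]["content"])
```

Any client that already speaks the OpenAI chat format should work the same way once pointed at the local base URL.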
|
|
|
|
|
transformers-serve uses continuous batching (see [this PR](https://github.com/huggingface/transformers/pull/38085) and also [this one](https://github.com/huggingface/transformers/pull/40426)) for better GPU utilization, and is closely linked to the great work of vLLM on the `paged attention kernel`, which further justifies [external kernels](#community-kernels).
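Why continuous batching helps utilization can be seen in a toy simulation. This is not the scheduler from the PRs above and it ignores prefill and KV-cache paging entirely: each request is reduced to a number of decode steps, and both function names are invented for the example.

```python
from collections import deque


def static_batching_steps(lengths: list, batch_size: int) -> int:
    """Each batch runs until its longest request finishes; finished slots sit idle."""
    steps = 0
    for start in range(0, len(lengths), batch_size):
        steps += max(lengths[start:start + batch_size])
    return steps


def continuous_batching_steps(lengths: list, batch_size: int) -> int:
    """A freed slot is refilled from the queue before the next decode step."""
    waiting = deque(lengths)
    running = []  # remaining decode steps for each in-flight request
    steps = 0
    while waiting or running:
        while waiting and len(running) < batch_size:
            running.append(waiting.popleft())
        # One decode step for every in-flight request, then evict finished ones.
        running = [remaining - 1 for remaining in running]
        running = [remaining for remaining in running if remaining > 0]
        steps += 1
    return steps


# Decode lengths (in tokens) for six requests, served with two slots.
workload = [8, 2, 2, 2, 8, 2]
print(static_batching_steps(workload, batch_size=2))      # 18
print(continuous_batching_steps(workload, batch_size=2))  # 14
```

On this workload the short requests no longer wait behind the long ones, so the same six requests finish in 14 decode steps instead of 18.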
|
|
|
|
|
transformers-serve is not meant for user-facing production services – tools like vLLM or SGLang are super optimized for that –, but it |
|
|
- Quickly verify that your model is compatible with continuous batching and paged attention. |
|
|
- Run ad-hoc vibe tests on any model, without worrying about deploying anything.
|
|
- Run evaluations efficiently, again without having to spend a lot of time engineering your infrastructure. |
|
|
|
|
|
For model deployment, check [Inference Providers](https://huggingface.co/docs/inference-providers/en/index) or roll your own solution using any of the excellent serving libraries.
|
|
|
|
|
<div class="crumbs"> |
|
|
OpenAI-compatible surface + continuous batching; kernels/backends slot in because the modeling API stayed stable. |
|
|
|
|
|
<strong>Next:</strong> reuse across vLLM/SGLang relies on the same consistency. |
|
|
</div> |
|
|
|
|
|
|
|
|
## Community reusability |
|
|
|
|
|
The transformers-serve CLI is built on transformers, for sure, but the library is made first and foremost to be _reused_ at large by the open-source ecosystem.
|
|
|
|
|
Adding a model to transformers means: |
|
|
|
|
|
- having it immediately available to the community |
|
|
- having it immediately usable in vLLM, [SGLang](https://huggingface.co/blog/transformers-backend-sglang), and so on without additional code. In the case of vLLM, transformers was added as a backend to run models on vLLM, which optimizes throughput/latency on top of existing transformers architectures, [as seen in this great vLLM x HF blog post](https://blog.vllm.ai/2025/04/11/transformers-backend.html).
|
|
- being the reference code for implementations in MLX, llama.cpp and other libraries. |
|
|
|
|
|
This further cements the need for a [consistent public surface](#consistent-public-surface): we are a backend and a reference, and there |
|
|
|
|
|
|
|
|
<div class="crumbs"> |
|
|
Being a good backend consumer requires a consistent public surface; modular shards and configs make that stability practical. |
|
|
|
|
|
<strong>Next:</strong> what changes in v5 without breaking the promise of visible semantics. |
|
|
</div> |
|
|
|
|
|
## What is coming next |
|
|
|
|
|
The next major version of `transformers` is just around the corner (and will have another blog post to its name when it comes out). When v5 is released, we aim to keep [backwards compatibility](#backwards-compatibility) as solid as possible. The changes we make now are in service of that goal. |
|
|
|
|
|
We will lean further into a modular toolbox, not a framework. You should not be forced to rewrite modeling code. It’s better when a model can inherit from `PreTrainedModel` and opt into Tensor Parallel, `from_pretrained`, sharding, `push_to_hub`, loss plumbing, and external stacks like PEFT/TRL/SGLang/vLLM. |
|
|
|
|
|
|
|
|