Attention, as a core layer of the ubiquitous Transformer architecture, is a bottleneck for large language models and long-context applications. FlashAttention (and FlashAttention-2) pioneered an approach to speed up attention on GPUs by minimizing memory reads/writes, and is now used by most libraries to accelerate Transformer training and inference. This has contributed to a massive increase in LLM context length in the last two years, from 2-4K (GPT-3, OPT) to 128K (GPT-4), or even 1M (Llama 3). However, despite its success, FlashAttention has yet to take advantage of new capabilities in modern hardware, with FlashAttention-2 achieving only 35% utilization of theoretical max FLOPs on the H100 GPU. In this blogpost, we describe three main techniques to speed up attention on Hopper GPUs: exploiting asynchrony of the Tensor Cores and TMA to (1) overlap overall computation and data movement via warp-specialization and (2) interleave block-wise matmul and softmax operations, and (3) incoherent processing that leverages hardware support for FP8 low-precision.

We’re excited to release FlashAttention-3 that incorporates these techniques. It’s 1.5-2.0x faster than FlashAttention-2 with FP16, up to 740 TFLOPS, i.e., 75% utilization of H100 theoretical max FLOPS. With FP8, FlashAttention-3 reaches close to 1.2 PFLOPS, with 2.6x smaller error than baseline FP8 attention.

The improvements from FlashAttention-3 will result in:

- **More efficient GPU utilization**: The new technique uses up to 75% of an H100 GPU's maximum capabilities, up from just 35% before. This makes training and running large language models (LLMs) significantly (1.5-2x) faster than with previous versions.
- **Better performance with lower precision**: FlashAttention-3 can work with lower-precision numbers (FP8) while maintaining accuracy. This allows for even faster processing and potentially lower memory usage, which could lead to cost savings and improved efficiency for customers running large-scale AI operations.
- **Ability to use longer context in LLMs**: By speeding up the attention mechanism, FlashAttention-3 enables AI models to work with much longer pieces of text more efficiently. This could allow for applications that can understand and generate longer, more complex content without slowing down.

FlashAttention-3 is available at: https://github.com/Dao-AILab/flash-attention

FlashAttention is an algorithm that reorders the attention computation and leverages tiling and recomputation to significantly speed it up and reduce memory usage from quadratic to linear in sequence length. We use tiling to load blocks of inputs from HBM (GPU memory) to SRAM (fast cache), perform attention with respect to that block, and update the output in HBM. By not writing the large intermediate attention matrices to HBM, we reduce the amount of memory reads/writes, which brings 2-4x wallclock time speedup.
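To make the tiling-plus-rescaling idea concrete, here is a minimal NumPy sketch of blockwise attention with online softmax. This models the algorithm rather than the CUDA kernel; the function names and block size are ours, and batching/masking are omitted:

```python
import numpy as np

def attention_ref(Q, K, V):
    # Naive attention: materializes the full T x T score matrix.
    S = Q @ K.T / np.sqrt(Q.shape[-1])
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    P /= P.sum(axis=-1, keepdims=True)
    return P @ V

def attention_tiled(Q, K, V, block=16):
    # Blockwise attention with online softmax rescaling: K/V are streamed
    # in blocks; a running max m and running denominator l let us correct
    # the partial outputs on the fly, so the T x T matrix is never stored.
    T, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.zeros_like(Q)
    m = np.full((T, 1), -np.inf)   # running row max
    l = np.zeros((T, 1))           # running softmax denominator
    for j in range(0, K.shape[0], block):
        S = Q @ K[j:j + block].T * scale          # scores for this block
        m_new = np.maximum(m, S.max(axis=-1, keepdims=True))
        alpha = np.exp(m - m_new)                 # rescale old statistics
        P = np.exp(S - m_new)
        l = alpha * l + P.sum(axis=-1, keepdims=True)
        O = alpha * O + P @ V[j:j + block]
        m = m_new
    return O / l

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((64, 32)) for _ in range(3))
assert np.allclose(attention_tiled(Q, K, V), attention_ref(Q, K, V))
```

The final division by `l` is what makes the blockwise result exactly equal to the naive softmax, with no approximation.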

Here we show a diagram of FlashAttention forward pass: with tiling and softmax rescaling, we operate by blocks and avoid having to read/write from HBM, while obtaining the correct output with no approximation.

While FlashAttention-2 can achieve up to 70% theoretical max FLOPS on Ampere (A100) GPUs, it does not yet take advantage of new features on Hopper GPUs to maximize performance. We describe some of the new Hopper-specific features here, and why they are important.

- WGMMA (Warpgroup Matrix Multiply-Accumulate). This new feature makes use of the new Tensor Cores on Hopper, with much higher throughput than the older mma.sync instruction in Ampere (image from the H100 white paper). Without the wgmma instruction, the older mma.sync instruction can only reach about 2/3 the peak throughput of Hopper Tensor Cores: https://arxiv.org/abs/2402.13499v1.

- TMA (Tensor Memory Accelerator). This is a special hardware unit that accelerates the transfer of data between global memory and shared memory, taking care of all index calculation and out-of-bound predication. This frees up registers, which is a valuable resource to increase tile size and efficiency.

- Low-precision with FP8. This doubles the Tensor Core throughput (e.g. 989 TFLOPS with FP16 and 1978 TFLOPS with FP8), but trades off accuracy by using fewer bits to represent floating point numbers.

FlashAttention-3 makes use of all of these new features of Hopper, using powerful abstractions from NVIDIA’s CUTLASS library.


Attention has GEMMs (those matmuls between Q and K and between attention probability P and V) and softmax as its two main operations. Why do we need to overlap them? Isn’t most of the FLOPS in the GEMMs anyway? As long as the GEMMs are fast (e.g., computed using WGMMA instructions), shouldn’t the GPU be going brrrr?

The problem is that non-matmul operations are much slower than matmul operations on modern accelerators. Special functions such as exponential (for the softmax) have even lower throughput than floating-point multiply-add; they are evaluated by the multi-function unit, a unit separate from the floating-point and matrix multiply-add units. As an example, the H100 SXM5 GPU has 989 TFLOPS of FP16 matrix multiply, but only 3.9 TFLOPS (256x less throughput) for special functions.

The first and easiest way to overlap GEMM and softmax is to do nothing at all! The warp schedulers already try to schedule warps so that if some warps are blocked (e.g., waiting for GEMM results), other warps can run. That is, the warp schedulers do some of this overlapping for us, for free.

However, we can improve on this by doing some of the scheduling manually. As an example, if we have 2 warpgroups (labeled 1 and 2 – each warpgroup is a group of 4 warps), we can use synchronization barriers (bar.sync) so that warpgroup 1 first does its GEMMs (e.g., GEMM1 of one iteration and GEMM0 of the next iteration), and then warpgroup 2 does its GEMMs while warpgroup 1 does its softmax, and so on. This “pingpong” schedule is illustrated in the figure below, where the same color denotes the same iteration.

This would allow us to perform the softmax in the shadow of the GEMMs of the other warpgroup. Of course, this figure is just a caricature; in practice the scheduling is not really this clean. Nevertheless, pingpong scheduling can improve FP16 attention forward pass from around 570 TFLOPS to 620 TFLOPS (head dim 128, seqlen 8K).

Even within one warpgroup, we can have some part of the softmax running while the GEMMs of that warpgroup are running. This is illustrated in this figure, where the same color denotes the same iteration.

This pipelining increases throughput from around 620 TFLOPS to around 640-660 TFLOPS for FP16 attention forward, at the cost of higher register pressure. We need more registers to hold both accumulators of the GEMMs, and the input/output of softmax. Overall, we find this technique to offer a favorable tradeoff.

LLM activations can have outliers with much larger magnitude than the rest of the features. These outliers make quantization difficult, producing much larger quantization errors. We leverage incoherent processing, a technique used in the quantization literature (e.g. from QuIP and QuIP#) that multiplies the query and key with a random orthogonal matrix to “spread out” the outliers and reduce quantization error. In particular, we use the Hadamard transform (with random signs), which can be done per attention head in O(d log d) instead of O(d^2) time, where d is the head dimension. Since the Hadamard transform is memory-bandwidth bound, it can be fused with previous operations such as rotary embedding (also memory-bandwidth bound) “for free”.
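As a sketch of the O(d log d) transform, here is a minimal fast Walsh-Hadamard transform in NumPy. The function name and shapes are illustrative (the real kernel fuses this with other operations), and the head dimension is assumed to be a power of 2:

```python
import numpy as np

def hadamard_transform(x):
    # Fast Walsh-Hadamard transform via butterflies: O(d log d) work per row
    # instead of an O(d^2) matmul with the explicit Hadamard matrix.
    x = x.astype(float).copy()
    d = x.shape[-1]                 # head dimension, assumed a power of 2
    h = 1
    while h < d:
        for i in range(0, d, 2 * h):
            a = x[..., i:i + h].copy()
            b = x[..., i + h:i + 2 * h].copy()
            x[..., i:i + h] = a + b
            x[..., i + h:i + 2 * h] = a - b
        h *= 2
    return x / np.sqrt(d)           # normalize so the transform is orthogonal

rng = np.random.default_rng(0)
d = 64
signs = rng.choice([-1.0, 1.0], size=d)   # the random sign flips
q = rng.standard_normal((8, d))           # queries for one attention head
q_rot = hadamard_transform(q * signs)
# An orthogonal transform preserves norms while "spreading out" outliers.
assert np.allclose(np.linalg.norm(q_rot, axis=-1), np.linalg.norm(q, axis=-1))
```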

In our experiment where Q, K, V are generated from a standard normal distribution but 0.1% of the entries have large magnitudes (to simulate outliers), we found that incoherent processing can reduce the quantization error by 2.6x. We show numerical error comparison in the table below. Please see the paper for details.
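The effect is easy to reproduce in a toy setting. In the sketch below (illustrative NumPy), `fake_quantize` is a crude low-bit stand-in for FP8, and the numbers are ours rather than the paper's; the point is only the direction of the effect:

```python
import numpy as np

def hadamard(d):
    # Sylvester construction of the d x d Hadamard matrix (d a power of 2),
    # normalized so the matrix is orthogonal.
    H = np.array([[1.0]])
    while H.shape[0] < d:
        H = np.block([[H, H], [H, -H]])
    return H / np.sqrt(d)

def fake_quantize(x, bits=4):
    # Crude symmetric per-tensor quantizer, standing in for FP8.
    scale = np.abs(x).max() / (2 ** (bits - 1) - 1)
    return np.round(x / scale) * scale

rng = np.random.default_rng(0)
n, d = 4096, 64
x = rng.standard_normal((n, d))
x[rng.random((n, d)) < 0.001] *= 50.0               # 0.1% large-magnitude outliers

R = hadamard(d) * rng.choice([-1.0, 1.0], size=d)   # Hadamard with random signs
err_plain = np.abs(fake_quantize(x) - x).mean()
x_rot = x @ R
err_rot = np.abs(fake_quantize(x_rot) - x_rot).mean()
# Spreading out the outliers shrinks the quantization error.
assert err_rot < err_plain
```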

We show some results with FlashAttention-3, and compare it to FlashAttention-2, as well as the implementation in Triton and cuDNN (both of which already use new hardware features of Hopper GPUs).

For FP16, we see about 1.6x-2.0x speedup over FlashAttention-2.

For FP8, we can reach close to 1.2 PFLOPS!

This blogpost highlights some of the optimizations for FlashAttention available on Hopper GPUs. Other optimizations (e.g., variable length sequences, persistent kernel, and in-kernel transpose for FP8) are covered in the paper.

We have seen that designing algorithms that take advantage of the hardware they run on can bring significant efficiency gains and unlock new model capabilities such as long context. We look forward to future work on optimization for LLM inference, as well as generalizing our techniques to other hardware architectures. We also look forward to FlashAttention-3 being integrated in a future release of PyTorch.

**This series is cross-posted at GoombaLab**

- Part I - The Model
- Part II - The Theory
- Part III - The Algorithm
- Part IV - The Systems

Since the release of Mamba 6 months ago, we’ve been pleasantly surprised by the overwhelming community response. It’s been incredibly gratifying to see the line of research on efficient sequence models we’ve been pursuing for years really resonate with the machine learning community and take off more than we could have anticipated. We’ve seen an enormous amount of exciting follow-up work, from direct applications (e.g. vision) onward.

Yet despite its potential so far, we weren’t completely satisfied with the first version of Mamba…

From a conceptual standpoint, one of the reasons we found SSMs so fascinating is how they just feel *fundamental*. One way this is exemplified is how they have rich ties to many major paradigms of sequence models. As developed in our earlier works on structured SSMs

But of course, aside from these, there’s another major sequence model paradigm: variants of the ubiquitous **attention** mechanism

Question 1:

What are the conceptual connections between state space models and attention? Can we combine them?

From a computational standpoint, despite the work that went into making Mamba fast (in particular, its hardware-aware selective scan implementation) it’s still much less hardware-efficient than mechanisms such as attention. The missing piece is that modern accelerators such as GPUs and TPUs are *highly* specialized for matrix multiplications. While this isn’t a problem for inference, which is bottlenecked by somewhat different considerations, this can be a big deal during training time.

Question 2:

Can we speed up the training of Mamba models by recasting them as matrix multiplications?

These are the main questions that Mamba-2 – in particular, its new state space model variant – tries to address.

The main point of the Mamba-2 paper is what we call **structured state space duality** (SSD), which refers to several things:

- The **SSD model** refers to a specific standalone layer, like attention or an SSM, that can be incorporated into deep neural networks
- The **SSD framework** is a general framework for reasoning about this model (and many more theoretical connections)
- The **SSD algorithm** is an algorithm for computing SSD layers much more efficiently than previous SSMs

The main SSD model or “state space dual model” itself really isn’t so complicated! In this first part of a series of blog posts, we’ll provide a self-contained description of the SSD layer (and Mamba-2) in isolation and how it compares to related models, particularly Mamba-1.

In the next parts of this series, we’ll describe the general framework and theoretical connections, which aren’t necessary to actually use Mamba-2.

SSD starts from the same set of equations as Mamba:

\[\begin{aligned} h_{t} &= A_t h_{t-1} + B_t x_t \\ y_t &= C_t^{\top} h_t \end{aligned}\]\begin{equation} \label{eq:ssm} (\text{Selective state space model (SSM)}) \end{equation}

To recap, a **structured state space model (SSM)** maps a 1-dimensional input $x \in \mathbb{R}^\mathtt{T}$ to an output $y \in \mathbb{R}^\mathtt{T}$ through a higher-dimensional latent state $h_t \in \mathbb{R}^\mathtt{N}$, where $\mathtt{N}$ is called the *state size, state dimension, or state expansion factor*.

A *selective* state space model allows the $(A, B, C)$ SSM parameters to vary across time.

Structured SSMs require $A$ to have structure to be efficiently computable, such as the most commonly used diagonal structure.

The original Mamba (or more precisely its core “S6” layer) is exactly a selective SSM with diagonal structure.

**The SSD layer of Mamba-2 makes only one small modification**: it restricts the diagonal $A$ even further to a *scalar times identity* structure; in other words the diagonal elements of $A$ must all be the same value. In this case $A$ can be represented with shape just $\mathtt{(T)}$ and one can also identify $A_t$ as just a scalar (and so we’ll sometimes denote it $a_t$).

Equation \eqref{eq:ssm} is defined only for a single dimensional input $x \in \mathbb{R}^\mathtt{T}$. If $X \in \mathbb{R}^\mathtt{(T, P)}$ has $\mathtt{P}$ separate channels, we can use the same dynamics (i.e. the same SSM $(A, B, C)$) independently for each channel. This can be interpreted as a *single head* of the SSM model.

Here, we think of $X$ as a tensor of shape $\mathtt{(T, P)}$ where $\mathtt{T}$ is the sequence (time) dimension and $\mathtt{P}$ is the “head dimension”.

Multiple heads can be constructed completely independently; for the remainder of this post, we assume that we’re working with a single head. Note that these heads are exactly analogous to how heads in multi-head attention models work, and in Mamba-2 we also choose similar dimensions as modern Transformers, e.g. $\mathtt{P} = 64$ or $\mathtt{P}=128$. (To scale to larger model widths $\mathtt{D} = \mathtt{d\_model}$, we keep this fixed and increase the number of independent heads.)
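In code, one head of this model is just a short loop. The following is an illustrative NumPy sketch of the recurrence (function name and shapes are ours, not the optimized implementation):

```python
import numpy as np

def ssm_head(a, B, C, X):
    # One head of a selective SSM with scalar-identity A:
    #   h_t = a_t * h_{t-1} + B_t x_t^T    (state h has shape (N, P))
    #   y_t = C_t^T h_t
    # Shapes: a: (T,), B: (T, N), C: (T, N), X: (T, P).
    T, P = X.shape
    N = B.shape[-1]
    h = np.zeros((N, P))   # one shared scalar recurrence governs all N x P state elements
    Y = np.zeros((T, P))
    for t in range(T):
        h = a[t] * h + np.outer(B[t], X[t])
        Y[t] = C[t] @ h
    return Y

rng = np.random.default_rng(0)
T, N, P = 32, 16, 64
a = rng.uniform(0.0, 1.0, size=T)
B, C = rng.standard_normal((T, N)), rng.standard_normal((T, N))
X = rng.standard_normal((T, P))
Y = ssm_head(a, B, C, X)
assert Y.shape == (T, P)
```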

We can notate the general (selective) state space model as \begin{equation} \label{eq:ssm-transformation} Y^\mathtt{(T,P)} = \mathsf{SSM}(A^\mathtt{(T,…)}, B^\mathtt{(T,N)}, C^\mathtt{(T,N)})(X^\mathtt{(T,P)}) \end{equation}

Some axes of variation include:

- The structure on $A$, which affects its parameter shape:
  - `... = (N,N)` for general (unstructured) SSMs
  - `... = (N)` for diagonal SSMs (or other structures, such as diagonal-plus-low-rank)
  - `... = ()` for scalar SSMs (i.e. SSD)
- The state dimension $\mathtt{N}$ (i.e. `d_state`)
- The head dimension $\mathtt{P}$ (i.e. `d_head`)

There are other axes of variation of structured SSMs (e.g. time-invariance vs. selectivity, SISO vs. MIMO).

But first, let’s switch tacks and forget about state space models for a moment. Given the same tensors above with the same shapes $(A^\mathtt{(T)}, B^\mathtt{(T, N)}, C^\mathtt{(T, N)})$, let’s define a different object.

First, we’ll define the following matrix (don’t worry, we’ll explain more and give it a name in Part II of this series!)

\[L = \begin{bmatrix} 1 & \\ a_1 & 1 & \\ a_2a_1 & a_2 & 1 \\ \vdots & \vdots & \ddots & \ddots \\ a_{\mathtt{T}-1}\dots a_1 & a_{\mathtt{T}-1}\dots a_2 & \dots & a_{\mathtt{T}-1} & 1 \\ \end{bmatrix} .\]Then, let’s define the following matrix

\begin{equation} \label{eq:ssd-attention} M = L \circ C B^\top \in \mathbb{R}^{\mathtt{(T,T)}} \end{equation}

Finally, $M$ encodes a *sequence transformation* $x \in \mathbb{R}^\mathtt{T} \to y \in \mathbb{R}^\mathtt{T}$ mapping a 1D input to a 1D output—just as in equation \eqref{eq:ssm}—through basic matrix multiplication $y = Mx$.

What’s special about this? Well, you may notice that it looks very similar to an attention computation. In fact, if all $a_t = 1$, then $L$ is simply the lower-triangular *causal mask*, and \eqref{eq:ssd-attention} is equivalent to **causal linear attention**, whose output is $Y = (L \circ QK^\top) \cdot V$.

This is exactly the same as equation \eqref{eq:ssd-attention} if we rename $(C, B, X) \mapsto (Q, K, V)$!

The so-called “duality” refers to the fact that the two models defined in equations \eqref{eq:ssm} (for the scalar-identity structured $A_t$ case) and \eqref{eq:ssd-attention} are actually *exactly the same model*, which we can view as a particular function
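The claimed equivalence is easy to verify numerically. Here is an illustrative NumPy sketch for a 1-dimensional input (function names and shapes are ours): the linear-time recurrence \eqref{eq:ssm} and the materialized quadratic form \eqref{eq:ssd-attention} compute exactly the same output.

```python
import numpy as np

def ssm_scan(a, B, C, x):
    # Linear-time mode: h_t = a_t * h_{t-1} + B_t x_t,  y_t = C_t . h_t
    h = np.zeros(B.shape[-1])
    y = np.zeros(len(x))
    for t in range(len(x)):
        h = a[t] * h + B[t] * x[t]
        y[t] = C[t] @ h
    return y

def ssd_matrix(a, B, C):
    # Quadratic mode: materialize M = L o (C B^T), where
    # L[i, j] = a_i * a_{i-1} * ... * a_{j+1} for i >= j (zero above the diagonal).
    T = len(a)
    L = np.zeros((T, T))
    for i in range(T):
        for j in range(i + 1):
            L[i, j] = np.prod(a[j + 1:i + 1])  # empty product = 1 on the diagonal
    return L * (C @ B.T)

rng = np.random.default_rng(0)
T, N = 16, 4
a = rng.uniform(0.5, 1.0, size=T)
B, C = rng.standard_normal((T, N)), rng.standard_normal((T, N))
x = rng.standard_normal(T)
# Both modes compute exactly the same sequence transformation y = M x.
assert np.allclose(ssd_matrix(a, B, C) @ x, ssm_scan(a, B, C, x))
```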

In the general *SSD Framework* (Part II of this series), we’ll show this equivalence in two completely different ways, both of which are actually much more general and each quite illuminating.

If you take our word for it, though, then SSD is relatively simple to characterize in relation to either SSMs or attention.

Compared to previous SSMs, SSD is pretty much the same as the core layer of Mamba but with even more structure on the recurrent $A$ matrices.

- Mamba-1 (S6) uses diagonal structure on $A$, while Mamba-2 (SSD) uses scalar-times-identity structure on $A$.
- Mamba-1 has a head dimension of $\mathtt{P}=1$ (i.e. all channels are completely independently controlled by separate SSMs), while Mamba-2 uses a head dimension of $\mathtt{P}>1$ (something like $\mathtt{P}=64$ by default).

In particular, this can be viewed as weight-tied in two ways:

- By restricting the diagonal structure of $A$ to scalar-times-identity, the recurrence dynamics are shared across all $\mathtt{N}$ elements of the state space.
- These dynamics are also shared across all $\mathtt{P}$ channels of a given head.

In other words, a single SSM head has total state size $\mathtt{P} \times \mathtt{N}$, which are each governed by separate scalar recurrences in Mamba-1 but are controlled by a single shared recurrence in Mamba-2.

Why make these restrictions? The main motivation is efficiency: these changes are necessary to be able to view the model in its [dual attention form], which allows matrix multiplications to be used.

## The Bottom Line: Mamba-1 vs. Mamba-2

Compared to Mamba-1, Mamba-2 allows **much larger state dimensions** (from `N=16` in Mamba-1 to `N=64` to `N=256` or even higher in Mamba-2) while simultaneously being **much faster during training**.

But can this hurt us? There’s some intuition to believe that it shouldn’t. One of the main reasons for the selectivity (e.g. $A$ that depends on the input $X$) introduced in Mamba is to let the SSM be able to control whether to remember or ignore particular pieces of information; for example, if a filler “um” is encountered in a text transcript. But if such information should be ignored, then the entire state can ignore it together, and so it should be okay if the state’s dynamics are shared across all features.

Empirically, we haven’t found evidence that the restricted expressivity of Mamba-2 might hurt, but the jury’s still out! From one perspective, Mamba-2 isn’t *strictly* better than Mamba-1: while it’s a dramatic improvement from a *training* perspective, Mamba-1 might be better from a pure *inference* perspective. Since inference speed of SSMs is entirely governed by the state dimension, if one wants to maximize performance for a target inference efficiency (i.e. for a particular state size $\mathtt{N}$), then the increased expressivity of Mamba-1 might be better. We haven’t fully analyzed the (theoretical or empirical) tradeoffs here, and think this would be a cool direction for the community to dig in more!

Compared to standard (self-)attention, SSD also has only two differences:

- The softmax normalization is dropped.
- A separate elementwise mask matrix is applied multiplicatively.

The first difference can be interpreted as what reduces the effective state size of the model from linear to constant, and improves its efficiency from quadratic to linear.

The second difference is what distinguishes SSD from standard linear attention. One way to think of the mask is as **input-dependent relative positional encodings**. Because of the mask $L$ in \eqref{eq:ssd-attention}, the standard attention score $\langle Q_i, K_j \rangle$ is attenuated by a weight $a_i a_{i-1} \cdots a_{j+1}$,

which can be interpreted as a “discount factor” based on how far apart the positions $i$ and $j$ are. (This interpretation was concurrently espoused by Tobias Katsch’s GateLoop paper.)

So why do we care that there are two views of this model? Well, first of all, it’s extremely mathematically interesting, as we’ll cover in Part II, and we hope will inspire future directions. But there are immediate practical benefits too!

The SSM \eqref{eq:ssm} and attention \eqref{eq:ssd-attention} modes represent two different ways of computing the same function, so let’s contrast them.

First, remember that one main reason why SSMs are interesting to begin with is because computing \eqref{eq:ssm} as a recurrence requires maintaining a *constant-size state* (size $\mathtt{N}$ per channel) and scales *linearly in the sequence length* $\mathtt{T}$. The downside is that the raw FLOPs don’t reflect actual speed in practice because of hardware considerations…

On the other hand, computing this sequence transformation $y = Mx$ through equation \eqref{eq:ssd-attention} takes quadratic time in the sequence length, because we’re materializing this $\mathtt{T} \times \mathtt{T}$ matrix. But it can be fast in practice because it only uses matrix multiplications, which are extremely optimized on GPUs and TPUs.

So if there are two equivalent ways of computing the same model, when should we use one mode or the other? During inference, there’s no trade-off: the SSM mode is designed for fast autoregressive inference. But what about training? Here there’s a tension between FLOPs and hardware efficiency where the attention mode uses more FLOPs, but uses them more efficiently through matrix multiplications.

It turns out we can get the best of both worlds by combining the algorithms! There are two equivalent interpretations of this “state space dual” algorithm, either as

- A block decomposition of a particular structured matrix that defines the SSD “token-mixing” sequence transformation.
- A “chunkwise” algorithm that splits the sequence into segments, computes the quadratic attention form on each segment, and adjusts the result by passing the SSM states between segments.

We’ll leave the details of this algorithm to Part III (or Section 6 of the full paper), as it requires a bit of machinery from the theory to derive. But we do emphasize that the implementation of this algorithm isn’t too complicated – a minimal implementation that we provide is only ~30 lines of PyTorch!
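To give a flavor of it without the derivation, here is a hedged NumPy sketch of the chunkwise idea for a 1-dimensional input and scalar $A$. This is a simplification of the actual algorithm (which operates on full heads and uses batched matmuls); function names and the chunk size are ours:

```python
import numpy as np

def ssd_chunkwise(a, B, C, x, chunk=4):
    # Chunkwise SSD sketch: quadratic "attention" within each chunk,
    # with SSM states carried between chunks.
    T, N = B.shape
    y = np.zeros(T)
    h = np.zeros(N)                          # state passed between chunks
    for s in range(0, T, chunk):
        e = min(s + chunk, T)
        ac, Bc, Cc, xc = a[s:e], B[s:e], C[s:e], x[s:e]
        n = e - s
        decay = np.cumprod(ac)               # decay[t] = a_s * ... * a_{s+t}
        # 1) contribution of the incoming state to each output in the chunk
        y[s:e] = (Cc @ h) * decay
        # 2) intra-chunk quadratic ("attention") part
        L = np.zeros((n, n))
        for i in range(n):
            for j in range(i + 1):
                L[i, j] = np.prod(ac[j + 1:i + 1])
        y[s:e] += (L * (Cc @ Bc.T)) @ xc
        # 3) update the state for the next chunk
        h = decay[-1] * h
        for j in range(n):
            h = h + np.prod(ac[j + 1:n]) * Bc[j] * xc[j]
    return y
```

Step 2 is where all the matmul-friendly work lives; steps 1 and 3 are the cheap "glue" that makes the chunks agree with the exact recurrence.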

The benefit of the SSD algorithm is that it preserves the same efficient FLOP counts as SSMs (compared to quadratic attention), while also dramatically speeding up training compared to general state space models by utilizing matmuls.

| | Attention | SSM | SSD |
|---|---|---|---|
| State size | $\mathrm{T}$ | $\mathbf{N}$ | $\mathbf{N}$ |
| Training FLOPs | $\mathrm{T}^2\mathrm{N}$ | $\mathbf{TN^2}$ | $\mathbf{TN^2}$ |
| Inference FLOPs | $\mathrm{T}\mathrm{N}$ | $\mathbf{N^2}$ | $\mathbf{N^2}$ |
| (Naive) memory | $\mathrm{T}^2$ | $\mathrm{TN}^2$ | $\mathbf{TN}$ |
| Matrix multiplications? | :heavy_check_mark: | :x: | :heavy_check_mark: |

Although the core contribution of Mamba-2 is the new SSD layer and theory, we also make some small changes to Mamba’s neural network architecture.

The main change is producing the $(A, B, C)$ SSM parameters in parallel with the $X$ input, instead of sequentially. This is partly motivated by the connections to attention; but more pragmatically, it’s simpler and more amenable to scaling techniques such as tensor parallelism, which will be discussed in Part IV of this series!
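In pseudo-code, the parallel form amounts to one fused input projection that is split into $(X, B, C, A)$, rather than computing them sequentially from $X$. The sketch below is illustrative NumPy with hypothetical shapes (the real block also has convolutions, gating, and nonlinearities):

```python
import numpy as np

# Illustrative sketch (hypothetical shapes): produce the SSM parameters
# (A, B, C) in parallel with X from a single fused projection of the input.
d_model, T, N, P = 256, 32, 16, 64
rng = np.random.default_rng(0)
W = 0.02 * rng.standard_normal((d_model, P + 2 * N + 1))  # one fused projection
u = rng.standard_normal((T, d_model))                     # block input
z = u @ W
# Split the projection output into the four parameter streams.
X, B, C, a = np.split(z, [P, P + N, P + 2 * N], axis=-1)
assert X.shape == (T, P) and B.shape == (T, N) == C.shape and a.shape == (T, 1)
```

Because all four streams come from one matmul on the same input, they can be sharded together, which is what makes the layout friendly to tensor parallelism.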

There are some other small differences which are covered in more detail in the paper. However, we do want to emphasize that these architectural changes aren’t really the main point of the model.

In terms of empirical results, we didn’t test Mamba-2 as extensively as Mamba-1, but believe it should generally be on par or better across the board. Our full language model results use the same protocol as Mamba, and found slightly better scaling according to Chinchilla scaling laws.

Fully trained models on the Pile dataset confirm that Mamba-2 is *much* faster to train than Mamba-1!

More interestingly, we highlight the one synthetic task we tried. Since the original Mamba paper, which investigated synthetics such as Synthetic Copying and Induction Heads, many follow-up works have begun investigating harder associative recall tasks, such as the **multi-query associative recall (MQAR)** task introduced by the Zoology and Based line of work.

We ran a version of this task that’s much harder than the one usually reported in the literature, and found that Mamba-2 is substantially better than Mamba-1. One reason for the improved performance is the much larger state size (up to $16\times$ larger than Mamba-1 here), which was one of the primary motivations of Mamba-2 in the first place.

Interestingly, Mamba-2 also appears to be noticeably better than Mamba-1 on this particular task even when the state size is controlled. We’re not quite sure why to be honest, and it would be great to ablate the other aspects of the model to investigate… for example, could it be possible that the [restricted structure of SSD] is actually *helpful* here?

In the next part of this series, we’ll go more into the full SSD framework, including how to prove the claimed “duality” of the SSD layer, and strong generalizations of it.

In Part I of this series, we defined the state space dual (SSD) *model*. In isolation, this model is relatively simple to define, and we claimed that it can be computed either as an SSM recurrence or with an attention-like pattern. If you just want to use the model, feel free to skip this post!

In this post, we’ll dive into the theory behind the model. We’ll derive the SSD “duality” in two completely separate ways, one starting from the SSM perspective and one from the attention perspective. Each method is actually much more broad than the SSD model itself, and the union of these two strong generalizations is what we call the SSD *framework*. This framework provides a rich body of connections between state space models, attention, and structured matrices. While the SSD model can be viewed as a specific instantiation of each prong of the framework, the SSD framework is much more general and opens up many directions for future work.

For each of the two parts of this framework, we’ll

- Define the general concepts
- Show how the SSD model is an instantiation, and prove the duality
- Suggest future directions for how the framework can be used

Note that this theory is *not necessary* to use the SSD model itself; this part of the series can be safely skipped for the practitioner that just wants to use SSD (Mamba-2).

Part I of this series introduced the SSD layer, which is defined as a selective SSM

\[\begin{aligned} h_{t} &= A_t h_{t-1} + B_t x_t \\ y_t &= C_t^{\top} h_t \end{aligned}\]\begin{equation} \label{eq:ssm} (\text{Selective state space model (SSM)}) \end{equation}

with scalar-identity structure on $A$.

More formally, we view it as a *sequence transformation* $X \mapsto Y$

\begin{equation} \label{eq:ssm-transformation} Y^\mathtt{(T,P)} = \mathsf{SSM}(A^\mathtt{(T)}, B^\mathtt{(T,N)}, C^\mathtt{(T,N)})(X^\mathtt{(T,P)}) \end{equation}

The dual attention-like form of the SSD layer is

\begin{equation} \label{eq:ssd-attention} M = L \circ C B^\top \in \mathbb{R}^{\mathtt{(T,T)}} \end{equation}

Now let’s see how to prove this!

The first framing of the duality will be from an SSM-centric perspective, where we’ll prove the duality through the framework of **matrix sequence transformations** or “matrix mixers”.

The idea is that many sequence models, i.e. *sequence transformations* $X \in \mathbb{R}^\mathtt{(T,P)} \mapsto Y \in \mathbb{R}^\mathtt{(T,P)}$, can be written in the form of a single matrix multiplication $Y = M(X) \cdot X$ where $M$ is a matrix which can itself depend on $X$. We call this a *matrix sequence transformation*, or matrix transformation for short. In the literature sequence transformations have also been referred to as “sequence mixers” or “token mixers”, and matrix sequence transformations as “matrix mixers”. There are many examples of these, which are distinguished by the structure of the $M$ matrix. The de facto example is self-attention itself, where $M = \mathsf{softmax}(QK^\top)$ is the attention matrix. Other examples include MLP-Mixer.

Why do we care about these types of models?

Writing a sequence model as a matrix transformation provides a powerful tool to understand the structure and characteristics of the model.

And although general non-linear RNNs such as LSTMs *cannot* be written as matrix mixers, state space models can! In fact, this is pretty easy to see by just unrolling the definition of the SSM recurrence. The upshot is that the SSM \eqref{eq:ssm-transformation} can be written as a matrix transformation

where $M_{ij} = 0$ for $i < j$ (i.e. it’s lower triangular) and otherwise \begin{equation} \label{eq:semiseparable} M_{ij} = C_i^\top A_{i:j}^\times B_j := C_i^\top A_i \dots A_{j+1} B_j \end{equation}

Drawing it out, this matrix looks like

\[\begin{bmatrix} C_0^\top B_0 & \\ C_1^\top A_1 B_0 & C_1^\top B_1 & \\ C_2^\top A_2A_1 B_0 & C_2^\top A_2 B_1 & C_2^\top B_2 \\ \vdots & \vdots & \ddots & \ddots \\ C_{\mathtt{T}-1}^\top A_{\mathtt{T}-1}\dots A_1 B_0 & C_{\mathtt{T}-1}^\top A_{\mathtt{T}-1}\dots A_2 B_1 & \dots & C_{\mathtt{T}-1}^\top A_{\mathtt{T}-1} B_{\mathtt{T}-2} & C_{\mathtt{T}-1}^\top B_{\mathtt{T}-1} \\ \end{bmatrix}\]\begin{equation} \label{eq:ssm-matrix} (\text{Matrix Transformation Representation of State Space Models}) \end{equation}

This type of matrix in fact has a name: it’s called a (triangular) **semiseparable matrix**, and has been studied in other fields of engineering and computational linear algebra. Its defining characteristic is the *structured rank property*, which says that every submatrix contained in the lower-triangular portion is low rank.

For our purposes, we’ll care about this form mainly for the algorithmic considerations. One of the central messages of this SSD paper is that:

## Takeaway: Computing SSMs Through Matrix Multiplication

All algorithms for computing state space models can be viewed as structured matrix multiplication algorithms on semiseparable matrices.
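As a concrete check of this takeaway, the following illustrative NumPy sketch materializes the semiseparable matrix \eqref{eq:ssm-matrix} for a diagonal SSM and confirms that multiplying by it (the naive quadratic algorithm) matches the sequential recurrence (a structured, linear-time algorithm). Function names and shapes are ours:

```python
import numpy as np

def ssm_scan_diag(A, B, C, x):
    # Recurrence with diagonal A: h_t = A_t * h_{t-1} + B_t x_t  (elementwise)
    N = B.shape[-1]
    h = np.zeros(N)
    y = np.zeros(len(x))
    for t in range(len(x)):
        h = A[t] * h + B[t] * x[t]
        y[t] = C[t] @ h
    return y

def semiseparable_matrix(A, B, C):
    # M[i, j] = C_i^T (A_i ... A_{j+1}) B_j for i >= j, and 0 above the diagonal.
    T, N = B.shape
    M = np.zeros((T, T))
    for i in range(T):
        for j in range(i + 1):
            Aprod = np.prod(A[j + 1:i + 1], axis=0) if i > j else np.ones(N)
            M[i, j] = C[i] @ (Aprod * B[j])
    return M

rng = np.random.default_rng(0)
T, N = 10, 4
A = rng.uniform(0.5, 1.0, (T, N))
B, C = rng.standard_normal((T, N)), rng.standard_normal((T, N))
x = rng.standard_normal(T)
# The naive matmul on the materialized matrix equals the linear-time scan.
assert np.allclose(semiseparable_matrix(A, B, C) @ x, ssm_scan_diag(A, B, C, x))
```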

Let’s see an easy instantiation of this, focusing on our main objective!

To show that equation \eqref{eq:ssd-attention} follows from equation \eqref{eq:ssm} (in the case of the SSD model, i.e. scalar SSM), we directly use the matrix form of the state space model \eqref{eq:semiseparable}. Because the $A_t$ are all scalars in this case, they can be factored out of the entries

\[C_i^\top A_{i:j}^\times B_j = A_{i:j}^\times \cdot (C_i^\top B_j)\]which directly implies equation \eqref{eq:ssd-attention}.

In summary:

## Duality Representation 1 (SSM)

The duality for the SSD model can be seen as two **different matrix multiplication algorithms** on the semiseparable matrix.

- The linear form is a *structured matrix multiplication algorithm* that computes the outputs $Y_0, Y_1, \dots$ sequentially, leveraging the structure of the semiseparable matrix.
- The quadratic form is the *naive matrix multiplication algorithm* that materializes the full matrix.

The power of the semiseparable matrix representation applies to *all* state space models, with various downstream implications.

Algorithmically, the Mamba-2 paper explores several consequences, such as:

- The above duality result for the SSD model, i.e. a scalar-identity structured SSM.
- New asymptotic efficiency results for state space models (Theorem 3.7), which follow from applying known results from the semiseparable matrix literature.
- A more general hybrid algorithm that can be viewed as combining both the linear and quadratic forms to get the best of both worlds. This can be derived as a new matrix multiplication algorithm utilizing *block decompositions* of the semiseparable matrix. This is the subject of Part III of this blog series!

Conceptually, the matrix transformation viewpoint helps provide a unifying view of sequence models. Some example downstream ideas include

- **New sequence models**: Restricting ourselves to matrix transformations reduces the problem of developing new sequence models to that of finding structured matrix classes with target properties. In ongoing work by my students, we study this point of view, and use it to derive the most natural bidirectional extension of Mamba (coming very soon!).
- **Expressivity**: Looking at the matrix transformation representation can help us understand what different models can represent from a linear algebraic perspective. In another ongoing work, we use this as a tool to study which subquadratic models are the most amenable to being distilled from Transformers.
- **Interpretability**: A concurrent work derived the matrix formulation of SSMs and uses it to probe the internal representations of Mamba models.

We’re excited to see what algorithmic and conceptual ideas from the structured matrix literature can be applied to further improve state space models!

The second framing of the duality is from an attention-centric perspective, where we’ll prove the duality through the framework of **tensor contractions**.

Note that this is entirely independent of the previous [matrix transformation viewpoint].

For our purposes, we’ll define attention as a function

\[(Q^\mathtt{(T,N)}, K^\mathtt{(S,N)} , V^\mathtt{(S,P)} ) \mapsto Y^\mathtt{(T,P)}\]

given by the pairwise matrix multiplications

\[Y = (QK^\top) \cdot V\]

Think of $\mathtt{P} = \mathtt{N}$ as the head dimension; technically speaking, in attention the $V$ head dimension $\mathtt{P}$ can differ from the $QK$ head dimension $\mathtt{N}$. Think of $\mathtt{T}$ as the *target* sequence dimension and $\mathtt{S}$ as the *source* sequence dimension. Giving these two axes different names will make the math more clear and also covers more general forms of attention such as cross-attention, where the source and target are separate sequences with different lengths. However, for our purposes we’ll assume the self-attention setting where $\mathtt{S}=\mathtt{T}$.
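As a concrete check of the shapes involved, here is a minimal PyTorch sketch of this definition (the dimension sizes are arbitrary choices for illustration):

```python
import torch

T, S, N, P = 6, 6, 4, 5  # target length, source length, QK head dim, V head dim
torch.manual_seed(0)
Q = torch.randn(T, N)  # (T, N)
K = torch.randn(S, N)  # (S, N)
V = torch.randn(S, P)  # (S, P)

# Pairwise matrix multiplications: (T, N) x (N, S) -> (T, S), then (T, S) x (S, P) -> (T, P)
Y = (Q @ K.T) @ V
assert Y.shape == (T, P)
```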

The usual form of attention $Y = f(QK^\top) \cdot V$ (e.g. where $f$ is the softmax function) can, for essentially all functions $f$ of interest, be written in the above form by absorbing $f$ into $Q$ and $K$ through a kernel feature map $\psi$; in other words, $\mathtt{N}$ can be viewed as the **feature dimension** of the attention kernel to begin with. Softmax attention, for example, can be represented with a particular infinite-dimensional feature map ($\mathtt{N}=\infty$) which represents the exponential kernel.

We’ll restrict ourselves to the case when $\psi$ is finite, which is sometimes called **kernel attention**. Many, many variants have been proposed before!

Why do we care about this formulation? When the sequence length $\mathtt{T}$ grows and the feature dimension $\mathtt{N}$ is small—commonly, in the regime when $\psi$ is simple such as an elementwise transform and so $\mathtt{N}$ is constant—then the cost of attention can be reduced from quadratic in $\mathtt{T}$ to linear. This follows from simply computing the matrix multiplications in a different order

\[Y = Q \cdot (K^\top V)\]

This is a somewhat “folklore” interpretation of linear attention.
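A quick sketch of the reordering (sizes are arbitrary; the point is that the linear order only materializes an $(\mathtt{N}, \mathtt{P})$ intermediate instead of a $(\mathtt{T}, \mathtt{T})$ matrix):

```python
import torch

T, N, P = 1024, 16, 16
torch.manual_seed(0)
Q = torch.randn(T, N, dtype=torch.float64)
K = torch.randn(T, N, dtype=torch.float64)
V = torch.randn(T, P, dtype=torch.float64)

# Quadratic order: materializes the (T, T) matrix QK^T -- O(T^2 N) time, O(T^2) memory
Y_quadratic = (Q @ K.T) @ V

# Linear order: only the (N, P) intermediate K^T V -- O(T N P) time, O(N P) memory
Y_linear = Q @ (K.T @ V)

assert torch.allclose(Y_quadratic, Y_linear)
```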

The most common way of linearizing attention is usually viewed as a consequence of the **associativity of matrix multiplication**. However, once the basic kernel attention is slightly modified, we can no longer use the associativity of matrix multiplication directly.

The seminal **Linear Attention (LA)** framework of Katharopoulos et al. showed how to extend this idea to the causal (autoregressive) setting.

Let’s be a lot more explicit about how it works. The quadratic form of **causal linear attention** is \begin{equation} \label{eq:quadratic-kernel-attention} Y = (L \circ QK^\top) \cdot V \end{equation} where $L$ is the **causal mask** matrix: the lower-triangular matrix of all $1$’s.

The issue is: once the $L$ mask is incorporated into \eqref{eq:quadratic-kernel-attention}, we can no longer directly apply matrix associativity! This is the problem that the original Linear Attention paper addresses. What they show is that \eqref{eq:quadratic-kernel-attention} is equivalent to a different form which avoids materializing the quadratic $QK^\top$ attention matrix and has linear time complexity

\[Y = Q \cdot \mathsf{cumsum}(K^\top V)\]

As far as we’re aware this wasn’t explicitly proved in the paper, although it isn’t too hard to write out the summation to show it.
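Before proving it, we can sanity-check this equivalence numerically; the sketch below materializes the running state as an explicit $(\mathtt{T}, \mathtt{N}, \mathtt{P})$ tensor, the cumulative sum of the outer products $k_t v_t^\top$:

```python
import torch

T, N, P = 128, 8, 8
torch.manual_seed(0)
Q = torch.randn(T, N, dtype=torch.float64)
K = torch.randn(T, N, dtype=torch.float64)
V = torch.randn(T, P, dtype=torch.float64)

# Quadratic form: mask the (T, T) attention matrix with the causal mask L
L = torch.tril(torch.ones(T, T, dtype=torch.float64))
Y_quad = (L * (Q @ K.T)) @ V

# Linear form: running cumsum of the outer products k_t v_t^T ...
H = torch.cumsum(torch.einsum("tn,tp->tnp", K, V), dim=0)
# ... queried by q_t at each step
Y_lin = torch.einsum("tn,tnp->tp", Q, H)

assert torch.allclose(Y_quad, Y_lin)
```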

What we’ll do is prove this equivalence in essentially one line, while revealing *exactly* where the “linear” part of Linear Attention comes from, and how to strongly generalize it.

Spoiler alert:

## Where does the cumsum in Linear Attention come from?

The appearance of the cumulative sum in linear attention is exactly equivalent to the fact that the causal mask $L$, as a matrix multiplication, encodes cumulative sums:

\[y = L \cdot x \iff y = \mathsf{cumsum}(x)\]

Let’s write out the quadratic form of linear attention \eqref{eq:quadratic-kernel-attention} very explicitly in **tensor contraction** or einsum notation, with shape annotations:

\[\begin{aligned} G &= \mathsf{contract}(\mathtt{TN},\mathtt{SN} \to \mathtt{TS})(Q, K) \\ M &= \mathsf{contract}(\mathtt{TS},\mathtt{TS} \to \mathtt{TS})(G, L) \\ Y &= \mathsf{contract}(\mathtt{TS},\mathtt{SP} \to \mathtt{TP})(M, V) \end{aligned}\]\begin{equation} \label{eq:sma-quad} (\text{Structured Masked Attention - Quadratic Form}) \end{equation}

With this notation, we can notice that this sequence of contractions can be written as a *single four-way contraction*

\begin{equation} \label{eq:sma} y = \mathsf{contract}(\mathtt{TN},\mathtt{SN},\mathtt{SP},\mathtt{TS} \to \mathtt{TP})(Q, K, V, L) . \end{equation}

And finally, it can be computed with any other contraction ordering. In particular, we can perform pairwise reductions in the order $V, K, L, Q$ instead of $Q, K, L, V$:

\[\begin{aligned} Z &= \mathsf{contract}(\mathtt{SP},\mathtt{SN} \to \mathtt{SPN})(V, K) \\ H &= \mathsf{contract}(\mathtt{TS},\mathtt{SPN} \to \mathtt{TPN})(L, Z) \\ Y &= \mathsf{contract}(\mathtt{TN},\mathtt{TPN} \to \mathtt{TP})(Q, H) \end{aligned}\]\begin{equation} \label{eq:sma-lin} (\text{Structured Masked Attention - Linear Form}) \end{equation}

Now the key observation is that the second line of \eqref{eq:sma-lin} is simply a matrix multiplication by $L$, which can be computed with a cumulative sum.

That’s the entire proof of linear attention! The beauty of it is that we didn’t have to write out a single summation, which was abstracted out into a tensor contraction combined with the structure of $L$.
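The two reduction orders are easy to verify numerically with `torch.einsum`; in this sketch an arbitrary lower-triangular matrix stands in for the structured mask $L$:

```python
import torch

T, N, P = 64, 8, 8
torch.manual_seed(0)
Q = torch.randn(T, N, dtype=torch.float64)
K = torch.randn(T, N, dtype=torch.float64)
V = torch.randn(T, P, dtype=torch.float64)
L = torch.tril(torch.randn(T, T, dtype=torch.float64))  # any mask; structure makes it fast

# Quadratic reduction order: (Q, K) -> G, (G, L) -> M, (M, V) -> Y
G = torch.einsum("tn,sn->ts", Q, K)
M = G * L
Y_quad = torch.einsum("ts,sp->tp", M, V)

# Linear reduction order: (V, K) -> Z, (L, Z) -> H, (Q, H) -> Y
Z = torch.einsum("sp,sn->spn", V, K)
H = torch.einsum("ts,spn->tpn", L, Z)
Y_lin = torch.einsum("tn,tpn->tp", Q, H)

assert torch.allclose(Y_quad, Y_lin)
```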

This immediately proves our claim about the cumsum in linear attention. Moreover, this immediately reveals that the efficiency of linear attention can be made much more general…

The critical observation is that in order for \eqref{eq:sma-lin} to be fast, all that is necessary is for $L$ to be *any structured matrix* – in other words any matrix that has subquadratic matrix-vector multiplication.

This immediately motivates one of the main prongs of the SSD framework, which can be seen as a strong generalization of LA.

## Definition: Structured Masked Attention

**Structured masked attention (SMA)** is defined as the four-way tensor contraction \eqref{eq:sma} using an attention mask $L$ that is a structured matrix.

## Duality Representation 2 (SMA)

SMA has dual quadratic and linear modes, which are simply the two different pairwise reduction orders \eqref{eq:sma-quad} and \eqref{eq:sma-lin}. (The linear mode assumes that the structured matrix $L$ has linear-time matrix-vector multiplication.)

Finally, let’s just connect this back to the commonly held view of linear attention as matrix multiplication associativity.

Although it is commonly believed that incorporating attention masks $L$ prevents matrix multiplication reordering, it turns out to still be compatible. In particular, associativity of matrix multiplication is a special case of tensor contraction reduction orders; although the former no longer applies, the latter can integrate the attention mask $L$.

Next, let’s look at some consequences of the structured attention framework.

Recall that the SSD model is defined as either a scalar-identity SSM in equation \eqref{eq:ssm}, or through the attention-like form in equation \eqref{eq:ssd-attention}.

To show the equivalence of these forms, we simply recognize that \eqref{eq:ssd-attention} is a special case of structured masked attention where the mask matrix is

\[L = \begin{bmatrix} 1 & \\ a_1 & 1 & \\ a_2a_1 & a_2 & 1 \\ \vdots & \vdots & \ddots & \ddots \\ a_{\mathtt{T}-1}\dots a_1 & a_{\mathtt{T}-1}\dots a_2 & \dots & a_{\mathtt{T}-1} & 1 \\ \end{bmatrix} .\]\begin{equation} \label{eq:1-ss} (\text{1-semiseparable (1-SS) matrix}) \end{equation}

We call this a **1-semiseparable (1-SS) matrix**, for reasons that are explained in more detail in the Mamba-2 paper.

Thus, we can also say that the SSD model is **1-semiseparable masked attention** or **1-SS SMA**.

To prove that this can be written as an SSM, we simply appeal to the SMA framework, which says that the dual form of this model can be computed through matrix multiplication by $L$. So how fast is that? It’s not too hard to see that multiplication $y = Lx$ can be computed in linear time through a scalar recurrence:

\[\begin{aligned} y_0 &= x_0 \\ y_1 &= a_1 x_0 + x_1 = a_1 y_0 + x_1 \\ y_2 &= a_2a_1 x_0 + a_2 x_1 + x_2 = a_2 y_1 + x_2 \\ \vdots & \qquad \vdots \end{aligned}\]

This corresponds exactly to the original SSM recurrence!
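Here is a small sketch verifying that dense multiplication by a 1-SS matrix matches the $O(\mathtt{T})$ scalar recurrence (the dense construction below is the naive $O(\mathtt{T}^2)$ one):

```python
import torch

T = 8
torch.manual_seed(0)
a = torch.rand(T, dtype=torch.float64)   # a_1, ..., a_{T-1} used below; a[0] unused
x = torch.randn(T, dtype=torch.float64)

# Dense 1-SS matrix: L[i, j] = a_i * a_{i-1} * ... * a_{j+1} for i >= j
L = torch.zeros(T, T, dtype=torch.float64)
for i in range(T):
    for j in range(i + 1):
        L[i, j] = torch.prod(a[j + 1 : i + 1])  # empty product = 1 on the diagonal
y_dense = L @ x  # O(T^2)

# Scalar recurrence y_t = a_t * y_{t-1} + x_t: O(T)
y_rec = torch.empty_like(x)
y_rec[0] = x[0]
for t in range(1, T):
    y_rec[t] = a[t] * y_rec[t - 1] + x[t]

assert torch.allclose(y_dense, y_rec)
```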

(In fact, multiplication by 1-SS matrices $L$ can be computed in a *lot* more ways, which we compile in the full paper! Alternative algorithms can reveal more insights: for example, the associative scan algorithm used by S5 corresponds to yet another structured decomposition of $L$.)

Structured masked attention not only helps define the SSD model and prove its duality, but it is a much broader framework of efficient attention models.

Prior examples include the original linear attention as well as the recent Retentive Network (RetNet) model, which correspond to specific simple choices of the mask $L$; SMA generalizes these to *any structured matrix*. As a suggestion, we think that Toeplitz or Fourier structured attention may be interesting to consider because they might encode different forms of positional information.

Additionally, other forms of structure can be incorporated into the $L$ mask. For example, another extension my students are developing is viewing SSD (and recurrences in general) as an algorithm operating on *directed line graphs*, and generalizing it to incorporate arbitrary graph structures.

We’ll end this post with a brief recap of what we’ve covered.

The **SSD framework** consists of the two broad approaches covered in this post, which is summarized by the two areas of the [Venn diagram]:

- Viewing state space models through [structured matrix transformations]
- Generalizing linear attention through [tensor contractions]

The [SSD layer] is a particular model, the purple intersection in the figure, which can be viewed as an instance of either part of the SSD framework; in particular, it has dual quadratic and linear forms that can be derived from either representation.

SSD Framework | Structured SSMs | Structured Attention
---|---|---
The main representation is… | Structured matrix sequence transformations \eqref{eq:ssm-matrix} | The 4-way tensor contraction \eqref{eq:sma}
This generalizes… | State space models | Linear attention
The SSD model is an instantiation as… | Scalar state space model ($A_t$ is a scalar-identity matrix) | 1-semiseparable masked attention ($L$ mask is a 1-SS matrix)
The linear-quadratic duality is revealed through… | Structured matrix multiplication algorithms | Tensor contraction reduction orderings

In the next part of this series, we’ll see how to use some of the SSD framework (in particular, the structured matrix algorithm point of view) to derive the more efficient hybrid SSD algorithm that leverages both of the dual forms.

The theoretical framework of structured state space duality (see Part I and Part II of this series) connects SSMs and (linear) attention through structured matrices. As mentioned in Part I, this connection allows us to derive new algorithms for selective SSMs that are faster than the parallel associative scan in Mamba-1 by leveraging matrix multiplication as a primitive. Moreover, the connection can bring system optimizations (e.g. tensor parallelism, sequence parallelism, variable sequence length) originally developed for Transformer to SSM-land.

Even though we already developed optimized scan implementations for Mamba-1, we were limited to small state expansion (typically $\mathtt{N}=16$) as the algorithm and implementation did not use tensor cores (specialized hardware units that perform matrix multiplication). Typically matrix multiplication (matmul) FLOPs are much faster (up to 16x) than non-matmul FLOPs: the A100 GPU has 312 TFLOPS of BF16 matmul but only 19 TFLOPS of FP32 arithmetic, and the H100 has 989 TFLOPS of BF16 matmul but only 67 TFLOPS of FP32 arithmetic. One of our primary goals with Mamba-2 is to **leverage tensor cores to speed up the SSM**.

To recap, after tying parameters and introducing the head structure, the SSM in Mamba-1 turns into SSD, a more restrictive form that has an attention-like formulation. And as SSD connects SSMs and structured matrices, we saw in Part II that efficient algorithms to compute SSMs correspond directly to different decompositions of the “token-mixing” or “sequence-mixing” matrix $M$.

We can therefore create new algorithms to compute SSMs simply by looking for alternative ways to multiply this matrix, for example by decomposing it in various ways. A simple block decomposition of this matrix, with carefully chosen block sizes, turns out to get all the advantages of both the linear-recurrent and quadratic-attention dual forms of SSD. This leads to the SSD algorithm, which has 4 steps. There are two completely different interpretations of this algorithm!

We first partition the SSM (semiseparable) matrix into blocks of size $\mathtt{Q} \times \mathtt{Q}$. Then, we use the properties of semiseparable matrices to factorize each off-diagonal block, which is low rank.

- (*Orange*) Each diagonal block is a smaller semiseparable matrix; we can compute this multiplication however we like, in particular using the quadratic (attention-like) form of SSD.
- (*Green*) There are only $\mathtt{T} / \mathtt{Q}$ total different green blocks because many of them are shared. These can be computed with a batched matmul.
- (*Yellow*) Notice that the yellow terms themselves form a 1-semiseparable matrix; in other words, this step is equivalent to an SSM scan (on some modified $A$ factors)!
- (*Blue*) Similar to green, these can be computed with a batched matmul.

An alternative interpretation of the algorithm involves reasoning about how the SSM operates on the actual sequence. We first split the input sequence into blocks (or chunks) of size $\mathtt{Q}$. The steps then have the following interpretation:

- **Intra-chunk outputs**: compute the local output of each chunk (*what is the output per chunk supposing that the initial state (to the chunk) is 0?*)
- **Chunk states**: compute the final state of each chunk (*what is the final state per chunk supposing that the initial state (to the chunk) is 0?*)
- **Pass states**: compute a recurrence on all of the chunks’ final states, using any desired algorithm, e.g. parallel or sequential scan (*what is the actual final state per chunk taking into account all previous inputs?*)
- **Output states**: for each chunk, given its true initial state (computed in Step 3), compute the contribution to the output just from the initial state

Either way, we see that most of the algorithm (Steps 1, 2, and 4) leverages matmuls (and hence tensor cores), and can also be computed completely in parallel! Only Step 3 requires a scan, but it operates on a much shorter sequence and usually takes only a small fraction of the time of the full algorithm.

We note that special cases of this algorithm have been seen before. In particular, the chunkwise form of RetNet can be viewed as an instance of it.

Other forms of “chunkwise” recurrences have recently become popular, such as in Gated Linear Attention (GLA).

In the “Minimal SSD” code that we provide in the paper and the code release, we delineate each of these four steps. As promised, this algorithm is not only faster but also much easier to implement than the original selective scan of Mamba, coming in at just around 25 lines of code!

```
import torch
import torch.nn.functional as F
from einops import rearrange, repeat


def segsum(x):
    """Stable segment sum calculation. exp(segsum(A)) produces a 1-SS matrix,
    which is equivalent to a scalar SSM."""
    T = x.size(-1)
    x = repeat(x, "... d -> ... d e", e=T)
    mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1)
    x = x.masked_fill(~mask, 0)
    x_segsum = torch.cumsum(x, dim=-2)
    mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
    x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
    return x_segsum


def ssd(X, A, B, C, block_len=64, initial_states=None):
    """
    Arguments:
        X: (batch, length, n_heads, d_head)
        A: (batch, length, n_heads)
        B: (batch, length, n_heads, d_state)
        C: (batch, length, n_heads, d_state)
    Return:
        Y: (batch, length, n_heads, d_head)
    """
    assert X.dtype == A.dtype == B.dtype == C.dtype
    assert X.shape[1] % block_len == 0

    # Rearrange into blocks/chunks
    X, A, B, C = [rearrange(x, "b (c l) ... -> b c l ...", l=block_len) for x in (X, A, B, C)]

    A = rearrange(A, "b c l h -> b h c l")
    A_cumsum = torch.cumsum(A, dim=-1)

    # 1. Compute the output for each intra-chunk (diagonal blocks)
    L = torch.exp(segsum(A))
    Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X)

    # 2. Compute the state for each intra-chunk
    # (right term of low-rank factorization of off-diagonal blocks; B terms)
    decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
    states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X)

    # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
    # (middle term of factorization of off-diag blocks; A terms)
    if initial_states is None:
        initial_states = torch.zeros_like(states[:, :1])
    states = torch.cat([initial_states, states], dim=1)
    decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0))))
    new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)
    states, final_state = new_states[:, :-1], new_states[:, -1]

    # 4. Compute state -> output conversion per chunk
    # (left term of low-rank factorization of off-diagonal blocks; C terms)
    state_decay_out = torch.exp(A_cumsum)
    Y_off = torch.einsum('bclhn,bchpn,bhcl->bclhp', C, states, state_decay_out)

    # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
    Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p")
    return Y, final_state
```

Let’s talk about a couple of additional details in the implementation (these don’t even appear in the full paper, so pay attention!) that unpack some of the choices in this reference code.

In the above code, we utilized the connection between scalar SSM recurrences

\[h_{t+1} = A_t h_t + B_t x_t\]

and matrix multiplication by 1-semiseparable matrices

\[L = \begin{bmatrix} 1 & \\ a_1 & 1 & \\ a_2a_1 & a_2 & 1 \\ \vdots & \vdots & \ddots & \ddots \\ a_{\mathtt{T}-1}\dots a_1 & a_{\mathtt{T}-1}\dots a_2 & \dots & a_{\mathtt{T}-1} & 1 \\ \end{bmatrix}\]

which we covered in Part II (and Section 3.2.2 of the paper). In this minimal implementation, we compute Step 3 of the algorithm, which is computing a scalar SSM by *any* algorithm of our choice, by explicitly materializing a 1-SS matrix and doing dense matrix multiplication.

We use this version for several reasons:

- Code-wise, it’s simpler to materialize and multiply by this matrix than to actually implement a parallel associative scan
- Because of the block decomposition of the SSM matrix, the sequence length $\mathtt{T}$ is reduced by a factor of $\approx 100$ – so doing the scan in time $O(\mathtt{T}^2)$ instead of $O(\mathtt{T})$ isn’t too bad
- We have to materialize a 1-SS matrix anyways for Step 1 of the algorithm (the diagonal blocks), so might as well reuse the code ¯\_(ツ)_/¯

While this example code is simpler and reasonably efficient on GPU (and probably TPU as well!), it’s no longer truly linear at long sequences. Our more optimized Triton implementation does replace the 1-SS multiplication in Step 3 with an actual associative scan.

The first naive attempt may be to notice that the entries of this matrix are cumulative products

\[a_{i:j}^\times = a_i \times \cdots \times a_{j-1} = \frac{a_{i:\mathtt{T}}^\times}{a_{j:\mathtt{T}}^\times}\]

However, this runs into severe numerical issues because these products can get really tiny (imagine $a_t \approx 0.9$ and powering it up for a sequence length $\mathtt{T}$ in the thousands!).
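A quick float32 illustration of this failure mode (the specific lengths are arbitrary):

```python
import torch

# Cumulative products of a_t ~ 0.9 underflow in float32 long before T hits the thousands:
tiny = torch.tensor(0.9) ** 2000    # ~1e-92, far below float32's smallest subnormal
tinier = torch.tensor(0.9) ** 1500  # ~1e-69, also underflows
assert tiny == 0.0 and tinier == 0.0

# ... so the ratio formula a_{i:j} = a_{i:T} / a_{j:T} produces 0/0 = NaN
assert torch.isnan(tiny / tinier)

# In log space, the same segment product is perfectly representable:
log_a = torch.log(torch.tensor(0.9))
segment = torch.exp(500 * log_a)    # 0.9**500 ~ 1.3e-23, fine in float32
assert segment > 0
```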

## The Segment Sum (`segsum`) Operation

The second attempt would be to do all of this in log-space, because all the $a_t$ are positive; the products become additions, and instead of `cumprod`s we have `cumsum`s to deal with. Then in order to compute the 1-SS matrix, we just have to compute the sums $\log a_i + \dots + \log a_{j-1}$ for every *segment* $[i:j]$. We call this the **segment sum (segsum)** primitive, analogous to cumulative sum (cumsum).

The obvious way to do this again is using the same idea as above, but in log space:

\[a_{i:j}^\times = \exp\left( \log a_i + \cdots + \log a_{j-1} \right) = \exp\left( (\log a)_{i:\mathtt{T}}^+ - (\log a)_{j:\mathtt{T}}^+ \right)\]

where we compute a single cumulative sum of $\log a$ along the time axis, and then compute all pairwise differences. In code, we can do this with

```
import torch

def segsum_unstable(x):
    """Naive segment sum calculation."""
    T = x.size(-1)
    x_cumsum = torch.cumsum(x, dim=-1)
    x_segsum = x_cumsum[..., :, None] - x_cumsum[..., None, :]
    mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
    x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
    return x_segsum
```

(and then the 1-semiseparable matrix is just the exponential of this output).

Sums/differences are a lot more stable than products/quotients, so this should work – right?

Unfortunately, it turns out this still doesn’t work. The values of this 1-SS matrix roughly represent the SSM dynamics, which are very sensitive to these values of $a_t$, so we have to be very precise. And even in log space, these cumsums can be fairly large, which runs into catastrophic cancellation when subtracted. So we really have to find a way to compute this matrix with only additions, while still vectorizing everything…

This leads to the helper function in the reference SSD code. Instead of computing a single cumsum and then subtracting, we find a way to use a batch of independent cumsums that immediately produces the right answer without subtraction.

These details do matter! Without the right implementation of these primitives, the basic SSD algorithm produces NaNs immediately during training (even with FP32).

This lineage of structured state space models developed from S4 and its predecessors, which were viewed as continuous-time systems.

In Mamba, however, we don’t really view the SSM as continuous anymore. In fact, as mentioned in the Discussion (Section 5) of the original paper, Mamba trades off with S4 on modeling different types of data:

- S4 is a continuous-time model that excels at modeling continuous data, e.g. perceptual signals such as audio waveforms and pixel-level vision.
- Mamba is a discrete-time model that excels at modeling discrete data, e.g. tokenized data such as language.

However, the parameterization of Mamba still used the same discretization step as in prior structured SSMs, where there is another parameter $\Delta$ being modeled. We did this because the discretization step has other side effects, such as properly normalizing the activations.

The initializations and parameterizations from the previous theory on structured SSMs still work out-of-the-box, so why fix what’s not broken?

Despite this, we’re pretty sure that the discretization step isn’t really necessary for Mamba. In the Mamba-2 paper, we chose to work directly with the “discrete parameters” $A$ and $B$, which in all previous structured SSM papers (including Mamba-1) were denoted $(\bar{A}, \bar{B})$ and defined through an additional transformation

\[\begin{align*} \bar{A} &= \exp(\Delta A) \\ \bar{B} &= (\exp(\Delta A) - I) A^{-1} B \end{align*}\]

This doesn’t pose any problems: to use the continuous SSM parameterization, simply transform the parameters through the above formulas before plugging into the SSD code above.

In the full Mamba-2 code, we also kept the same parameterization and discretization step as in Mamba—again, why fix what’s not broken?—but hypothesize that “discrete-centric” variants (such as the *gamma normalization* of LRU) could work equally well.

## Is Discretization Necessary?

It’s useful for other structured SSMs, but perhaps not needed for Mamba. But it’s just a simple invertible transformation, so use either discrete or continuous parameterizations as you like!

In the final part of this series, we’ll continue talking about the implementation of Mamba-2, but on a more macroscopic level; about the entire neural network, instead of just details of the core SSD layer.

We’ll also talk about the actual speed of the algorithm covered in this post.

Transformers have benefited from 7 years of systems optimization from the whole research community and large companies. The SSD framework draws connections between SSMs and attention, and allows us to implement many of these optimizations for models like Mamba-2 as well. We focus on tensor parallel and sequence parallel for large-scale training, as well as variable-length sequences for efficient finetuning and inference.

One difficulty with large-scale training of Mamba-1 using tensor parallelism (TP) is that it requires 2 all-reduces per layer, compared to just 1 all-reduce per attention or MLP layer in a Transformer. This is because some of the SSM parameters are functions of the inner activations, not of the input to the layer. In Mamba-2, with the “parallel projection” structure, all SSM parameters are functions of the input to the layer, and we can easily apply TP to the input projection: we split the input projection and output projection matrices into 2, 4, or 8 shards, depending on the TP degree. We use a grouped norm with the number of groups divisible by the TP degree, so that normalization is done separately per GPU. These changes result in 1 all-reduce per layer, instead of 2.

When training on very long sequence lengths, we might need to split along the sequence length and assign different parts to different devices. There are two main forms of sequence parallelism (SP):

- For the residual and normalization operations: this replaces the all-reduce in TP with a reduce-scatter, residual + normalization, then all-gather. Since Mamba-2 uses the same residual and normalization structure as Transformer, this form of SP applies directly with no modification.
- For the attention or SSM operation, aka context parallelism (CP): for attention, one could use Ring attention to split it up along the sequence dimension. For Mamba-2, the SSD framework comes to our help once again: using the same block decomposition, we can have each GPU compute its local output and its final states, then pass the states between GPUs (using send/receive communication primitives), before updating the final output of each GPU.

For finetuning and inference, in the same batch we often have sequences of different lengths. For Transformers, one would usually pad so all sequences have the same length (wasting computation), or implement attention specifically for variable-length sequences with careful load-balancing. With SSMs, we can simply treat the whole batch as one long “sequence”, and avoid passing the states between different sequences in the batch by setting the state transition $A_t$ to 0 for tokens at the end of each sequence.
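Here is a toy sketch of this trick using a scalar SSM recurrence (state dimension 1); `scalar_ssm` is a hypothetical helper for illustration, not part of the Mamba-2 code. With the convention $h_t = a_t h_{t-1} + b_t x_t$, the transition to zero out is the one entering the first token of each new sequence:

```python
import torch

def scalar_ssm(a, b, x):
    """Toy scalar SSM: h_t = a_t * h_{t-1} + b_t * x_t, with output y_t = h_t."""
    h = torch.zeros(())
    ys = []
    for t in range(len(x)):
        h = a[t] * h + b[t] * x[t]
        ys.append(h)
    return torch.stack(ys)

torch.manual_seed(0)
a1, b1, x1 = torch.rand(5), torch.randn(5), torch.randn(5)
a2, b2, x2 = torch.rand(3), torch.randn(3), torch.randn(3)

# Pack both sequences into one long "sequence", zeroing the state transition
# at the boundary so no state leaks from the first sequence into the second.
a_packed = torch.cat([a1, a2])
a_packed[5] = 0.0
y_packed = scalar_ssm(a_packed, torch.cat([b1, b2]), torch.cat([x1, x2]))

# The packed run reproduces each sequence's outputs exactly
assert torch.allclose(y_packed[:5], scalar_ssm(a1, b1, x1))
assert torch.allclose(y_packed[5:], scalar_ssm(a2, b2, x2))
```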

How well do these optimizations work? The faster SSD algorithm allows us to increase the state dimension ($\mathtt{N}=64$ or $128$ compared to $\mathtt{N}=16$ in Mamba-1). Even though technically Mamba-2 is more restricted than Mamba-1 for the same $\mathtt{N}$, the larger state dimensions generally improve model quality. Here we show results for models trained on 300B tokens on the Pile, with Mamba-2 outperforming Mamba-1 and Pythia.

What about **hybrid models**? We have seen from recent and concurrent work (such as Jamba and Zamba) that combining Mamba layers with attention layers can improve over pure Transformer or Mamba. We validate at 2.7B parameters and 300B tokens scale that a hybrid model with just 6 attention blocks (and 58 SSD blocks) outperforms 64 SSD blocks, as well as our standard Transformer++ baseline (32 gated MLP and 32 attention blocks).

We also validated that the SSD algorithm is significantly faster than the selective scan algorithm from Mamba-1 for the same state dimension, and scales much better computationally to larger state dimensions. Getting those tensor cores to go brrr is the key!

With SSD, we have connected (linear) attention and SSMs, allowing us to design faster algorithms and implement systems optimizations for SSMs. There are still tons of exciting directions that we (and hopefully the community) want to tackle:

- **Understanding**: hybrid models with a few (4-6) attention layers perform very well, even better than pure Mamba(-2) or Transformer++. What are these attention layers doing? Can they be replaced with another mechanism?
- **Training optimizations**: though SSD might be faster than attention, Mamba-2 as a whole might still be slower than Transformers at short (e.g. 2K) sequence length, since the MLP layers in Transformers are very hardware-friendly. Our implementation of SSD does not specifically take advantage of new features on H100 GPUs, and we look forward to future optimizations that could make SSMs faster to train than Transformers for large-scale pretraining at 2-4K sequence length.
- **Inference optimizations**: there’s a whole suite of optimizations tailored to Transformers, in particular handling the KV cache (quantization, speculative decoding). How would the inference landscape change if model states (e.g. SSM states) no longer scale with context length, and the KV cache is no longer the bottleneck?