This series is cross-posted at GoombaLab
Since the release of Mamba 6 months ago, we’ve been pleasantly surprised by the overwhelming community response. It’s been incredibly gratifying to see the line of research on efficient sequence models we’ve been pursuing for years really resonate with the machine learning community and take off more than we could have anticipated. We’ve seen an enormous amount of exciting follow-up work, from direct applications (e.g. vision
Yet despite its potential so far, we weren’t completely satisfied with the first version of Mamba…
From a conceptual standpoint, one of the reasons we found SSMs so fascinating is how they just feel fundamental. One way this is exemplified is how they have rich ties to many major paradigms of sequence models. As developed in our earlier works on structured SSMs
But of course, aside from these, there’s another major sequence model paradigm: variants of the ubiquitous attention mechanism
Question 1: What are the conceptual connections between state space models and attention? Can we combine them?
From a computational standpoint, despite the work that went into making Mamba fast (in particular, its hardware-aware selective scan implementation), it’s still much less hardware-efficient than mechanisms such as attention. The missing piece is that modern accelerators such as GPUs and TPUs are highly specialized for matrix multiplications, which the selective scan does not use. While this isn’t a problem for inference, which is bottlenecked by somewhat different considerations, this can be a big deal during training time.
Question 2: Can we speed up the training of Mamba models by recasting them as matrix multiplications?
These are the main questions that Mamba-2 – in particular, its new state space model variant – tries to address.
The main point of the Mamba-2 paper is what we call structured state space duality (SSD), which refers to several things:
The main SSD model or “state space dual model” itself really isn’t so complicated! In this first part of a series of blog posts, we’ll provide a self-contained description of the SSD layer (and Mamba-2) in isolation and how it compares to related models, particularly Mamba-1.
In the next parts of this series, we’ll describe the general framework and theoretical connections, which aren’t necessary to actually use Mamba-2.
SSD starts from the same set of equations as Mamba:
\[\begin{aligned} h_{t} &= A_t h_{t-1} + B_t x_t \\ y_t &= C_t^{\top} h_t \end{aligned}\]\begin{equation} \label{eq:ssm} (\text{Selective state space model (SSM)}) \end{equation}
To recap, a structured state space model (SSM)
A selective state space model allows the $(A, B, C)$ SSM parameters to vary across time
Structured SSMs require $A$ to have structure to be efficiently computable, such as the most commonly used diagonal structure
The original Mamba (or more precisely its core “S6” layer) is exactly a selective SSM with diagonal structure.
The SSD layer of Mamba-2 makes only one small modification: it restricts the diagonal $A$ even further to a scalar times identity structure; in other words the diagonal elements of $A$ must all be the same value. In this case $A$ can be represented with shape just $\mathtt{(T)}$ and one can also identify $A_t$ as just a scalar (and so we’ll sometimes denote it $a_t$).
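To make the difference in structure concrete, here is a minimal NumPy sketch of the recurrence for a single 1D input, assuming a zero initial state. The function names `diagonal_ssm_1d` and `scalar_ssm_1d` are ours for illustration and are not taken from any Mamba codebase. The diagonal case stores $A$ with shape $\mathtt{(T, N)}$, while the scalar-identity case only needs shape $\mathtt{(T)}$ and reuses the same code by broadcasting.

```python
import numpy as np

def diagonal_ssm_1d(A, B, C, x):
    """Diagonal selective SSM on a 1D input (zero initial state assumed).

    A: (T, N) diagonal entries of A_t, B: (T, N), C: (T, N), x: (T,) -> y: (T,)
    """
    T, N = B.shape
    h = np.zeros(N)
    y = np.empty(T)
    for t in range(T):
        h = A[t] * h + B[t] * x[t]  # a diagonal A_t acts elementwise on the state
        y[t] = C[t] @ h             # y_t = C_t^T h_t
    return y

def scalar_ssm_1d(a, B, C, x):
    """Scalar-identity case A_t = a_t * I, with a of shape (T,)."""
    A = np.broadcast_to(a[:, None], B.shape)  # every state element shares a_t
    return diagonal_ssm_1d(A, B, C, x)
```

The only thing that changes between the two cases is how many distinct decay values are stored per time step: $\mathtt{N}$ for diagonal, one for scalar-identity.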
Equation \eqref{eq:ssm} is defined only for a single dimensional input $x \in \mathbb{R}^\mathtt{T}$. If $X \in \mathbb{R}^\mathtt{(T, P)}$ has $\mathtt{P}$ separate channels, we can use the same dynamics (i.e. the same SSM $(A, B, C)$) independently for each channel. This can be interpreted as a single head of the SSM model.
Here, we think of $X$ as a tensor of shape $\mathtt{(T, P)}$ where $\mathtt{T}$ is the sequence (time) dimension and $\mathtt{P}$ is the “head dimension”.
Multiple heads can be constructed completely independently; for the remainder of this post, we assume that we’re working with a single head. Note that these heads are exactly analogous to how heads in multi-head attention models work, and in Mamba-2 we also choose similar dimensions as modern Transformers, e.g. $\mathtt{P} = 64$ or $\mathtt{P}=128$. (To scale to larger model widths $\mathtt{D} = \mathtt{d\_model}$, we keep this fixed and increase the number of independent heads.)
We can notate the general (selective) state space model as \begin{equation} \label{eq:ssm-transformation} Y^\mathtt{(T,P)} = \mathsf{SSM}(A^\mathtt{(T,…)}, B^\mathtt{(T,N)}, C^\mathtt{(T,N)})(X^\mathtt{(T,P)}) \end{equation}
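As a rough illustration of the shapes involved, the sketch below runs one head on an input of shape $\mathtt{(T, P)}$ and then stacks several independent heads to cover a model width $\mathtt{D}$. The parameter layout (a separate $(a, B, C)$ per head, stored along an extra head axis) and the names `ssd_head` and `ssd_multi_head` are assumptions made for this example, not a description of the actual Mamba-2 code.

```python
import numpy as np

def ssd_head(a, B, C, X):
    """One head: a (T,), B (T, N), C (T, N), X (T, P) -> Y (T, P)."""
    T, N = B.shape
    H = np.zeros((N, X.shape[1]))            # one size-N state per channel
    Y = np.empty(X.shape)
    for t in range(T):
        H = a[t] * H + np.outer(B[t], X[t])  # same (a_t, B_t) for every channel
        Y[t] = C[t] @ H
    return Y

def ssd_multi_head(a, B, C, X, head_dim):
    """Hypothetical layout: a (T, H), B and C (T, H, N), X (T, D), D = H * head_dim."""
    T, D = X.shape
    n_heads = D // head_dim
    Xh = X.reshape(T, n_heads, head_dim)
    # Heads never interact: each one is an independent SSM on its own channels.
    Yh = [ssd_head(a[:, h], B[:, h], C[:, h], Xh[:, h]) for h in range(n_heads)]
    return np.stack(Yh, axis=1).reshape(T, D)
```

Scaling the width then means increasing the number of heads while `head_dim` stays fixed, as described above.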
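Some axes of variation include:

1. The structure on $A$, which determines the `...` in its parameter shape:
   - `... = (N,N)` for general (unstructured) SSMs
   - `... = (N)` for diagonal SSMs (or other structures, such as diagonal-plus-low-rank)
   - `... = ()` for scalar SSMs (i.e. SSD)
2. The state dimension $\mathtt{N}$ (i.e. `d_state`)
3. The head dimension $\mathtt{P}$ (i.e. `d_head`)

There are other axes of variation of structured SSMs (e.g. time-invariance vs. selectivity, SISO vs. MIMO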
But first, let’s switch tacks and forget about state space models for a moment. Given the same tensors above with the same shapes $(A^\mathtt{(T)}, B^\mathtt{(T, N)}, C^\mathtt{(T, N)})$, let’s define a different object.
First, we’ll define the following matrix (don’t worry, we’ll explain more and give it a name in Part II of this series!)
\[L = \begin{bmatrix} 1 & \\ a_1 & 1 & \\ a_2a_1 & a_2 & 1 \\ \vdots & \vdots & \ddots & \ddots \\ a_{\mathtt{T}-1}\dots a_1 & a_{\mathtt{T}-1}\dots a_2 & \dots & a_{\mathtt{T}-1} & 1 \\ \end{bmatrix} .\]
Then, let’s define the following matrix
\begin{equation} \label{eq:ssd-attention} M = L \circ C B^\top \in \mathbb{R}^{\mathtt{(T,T)}} \end{equation}
Finally, $M$ encodes a sequence transformation $x \in \mathbb{R}^\mathtt{T} \to y \in \mathbb{R}^\mathtt{T}$ mapping a 1D input to a 1D output—just as in equation \eqref{eq:ssm}—through basic matrix multiplication $y = Mx$.
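The construction above can be sketched directly in NumPy. This is a deliberately naive version that fills in $L$ entry by entry (0-indexed, so `a[0]` is never used); the helper names `decay_matrix` and `ssd_matrix_form` are ours.

```python
import numpy as np

def decay_matrix(a):
    """L[i, j] = a_i * ... * a_{j+1} for i > j, 1 on the diagonal, 0 above it."""
    T = len(a)
    L = np.zeros((T, T))
    for i in range(T):
        L[i, i] = 1.0
        for j in range(i - 1, -1, -1):
            L[i, j] = L[i, j + 1] * a[j + 1]  # extend the product one step to the left
    return L

def ssd_matrix_form(a, B, C, x):
    """y = M x with M = L ∘ (C B^T); a (T,), B (T, N), C (T, N), x (T,)."""
    M = decay_matrix(a) * (C @ B.T)  # elementwise (Hadamard) product, shape (T, T)
    return M @ x
```

For example, `decay_matrix(np.array([0.9, 0.5, 0.2]))` has rows `[1, 0, 0]`, `[0.5, 1, 0]` and `[0.1, 0.2, 1]`.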
What’s special about this? Well, you may notice that it looks very similar to an attention computation. In fact, if all $a_t = 1$, then $L$ is simply the lower-triangular causal mask and \eqref{eq:ssd-attention} is equivalent to causal linear attention:
\[Y = (L \circ Q K^\top) V\]
This is exactly the same as equation \eqref{eq:ssd-attention} if we rename $(C, B, X) \mapsto (Q, K, V)$!
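For the special case $a_t = 1$, a small sketch can make the correspondence tangible. It computes causal linear attention once through the masked quadratic form and once through a running sum of outer products $K_t V_t^\top$, which is the SSM recurrence with no decay. The function names are ours, and this is an illustration rather than an optimized implementation.

```python
import numpy as np

def causal_linear_attention_quadratic(Q, K, V):
    """Y = (L ∘ Q K^T) V with L the lower-triangular mask of ones."""
    T = Q.shape[0]
    mask = np.tril(np.ones((T, T)))
    return (mask * (Q @ K.T)) @ V

def causal_linear_attention_recurrent(Q, K, V):
    """Same function as a recurrence: S_t = S_{t-1} + K_t V_t^T, Y_t = Q_t^T S_t."""
    T, N = K.shape
    S = np.zeros((N, V.shape[1]))
    Y = np.empty(V.shape)
    for t in range(T):
        S = S + np.outer(K[t], V[t])  # state update with a_t = 1 (no decay)
        Y[t] = Q[t] @ S
    return Y
```

Under the renaming $(C, B, X) \mapsto (Q, K, V)$, the state `S` plays the role of the SSM state $h_t$.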
The so-called “duality” refers to the fact that the two models defined in equations \eqref{eq:ssm} (for the scalar-identity structured $A_t$ case) and \eqref{eq:ssd-attention} are actually exactly the same model, which we can view as a particular function
\[(A^\mathtt{(T)}, B^\mathtt{(T, N)}, C^\mathtt{(T, N)}, X^\mathtt{(T, P)}) \mapsto Y^\mathtt{(T, P)}\]
In the general SSD Framework (Part II of this series), we’ll show this equivalence in two completely different ways, both of which are actually much more general and each quite illuminating.
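Short of a proof, the claimed equivalence can at least be spot-checked numerically. The sketch below evaluates the same function in both modes on random inputs, assuming a zero initial state; the names `ssm_mode` and `attention_mode` are ours. The two outputs should agree up to floating-point error.

```python
import numpy as np

def ssm_mode(a, B, C, X):
    """Linear-time recurrence: H_t = a_t H_{t-1} + B_t X_t^T, Y_t = C_t^T H_t."""
    T, N = B.shape
    H = np.zeros((N, X.shape[1]))
    Y = np.empty(X.shape)
    for t in range(T):
        H = a[t] * H + np.outer(B[t], X[t])
        Y[t] = C[t] @ H
    return Y

def attention_mode(a, B, C, X):
    """Quadratic form: Y = (L ∘ C B^T) X with L[i, j] = a_i * ... * a_{j+1}."""
    T = len(a)
    L = np.zeros((T, T))
    for i in range(T):
        for j in range(i + 1):
            L[i, j] = np.prod(a[j + 1 : i + 1])  # empty product gives 1 on the diagonal
    return (L * (C @ B.T)) @ X

rng = np.random.default_rng(0)
T, N, P = 8, 4, 3
a = rng.uniform(0.1, 1.0, size=T)
B, C, X = rng.normal(size=(T, N)), rng.normal(size=(T, N)), rng.normal(size=(T, P))
print(np.allclose(ssm_mode(a, B, C, X), attention_mode(a, B, C, X)))
```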
If you take our word for it, though, then SSD is relatively simple to contrast with either SSMs or attention.
Compared to previous SSMs, SSD is pretty much the same as the core layer of Mamba but with even more structure on the recurrent $A$ matrices.
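In particular, this can be viewed as weight-tied in two ways:

- By restricting the diagonal structure of $A$ to scalar-times-identity, the recurrence dynamics are shared across all $\mathtt{N}$ elements of the state space.
- These dynamics are also shared across all $\mathtt{P}$ channels of a given head.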
In other words, a single SSM head has total state size $\mathtt{P} \times \mathtt{N}$, which are each governed by separate scalar recurrences in Mamba-1 but are controlled by a single shared recurrence in Mamba-2.
Why make these restrictions? The main motivation is efficiency: these changes are necessary to be able to view the model in its dual attention form, which allows matrix multiplications to be used.
The Bottom Line: Mamba-1 vs. Mamba-2
Compared to Mamba-1, Mamba-2 allows much larger state dimensions (from `N=16` in Mamba-1 to `N=64` to `N=256` or even higher in Mamba-2) while simultaneously being much faster during training.
But can this hurt us? There’s some intuition to believe that it shouldn’t. One of the main reasons for the selectivity (e.g. $A$ that depends on the input $X$) introduced in Mamba is to let the SSM be able to control whether to remember or ignore particular pieces of information; for example, if a filler “um” is encountered in a text transcript. But if such information should be ignored, then the entire state can ignore it together, and so it should be okay if the state’s dynamics are shared across all features.
Empirically, we haven’t found evidence that the restricted expressivity of Mamba-2 might hurt, but the jury’s still out! From one perspective, Mamba-2 isn’t strictly better than Mamba-1: while it’s a dramatic improvement from a training perspective, Mamba-1 might be better from a pure inference perspective. Since inference speed of SSMs is entirely governed by the state dimension, if one wants to maximize performance for a target inference efficiency (i.e. for a particular state size $\mathtt{N}$), then the increased expressivity of Mamba-1 might be better. We haven’t fully analyzed the (theoretical or empirical) tradeoffs here, and think this would be a cool direction for the community to dig in more!
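Compared to standard (self-)attention, SSD also has only two differences:

- The softmax normalization is dropped.
- A separate elementwise mask matrix is applied multiplicatively.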
The first difference can be interpreted as what reduces the effective state size of the model from linear to constant, and improves its efficiency from quadratic to linear.
The second difference is what distinguishes SSD from standard linear attention. One way to think of the mask is as input-dependent relative positional encodings. Because of the mask $L$ in \eqref{eq:ssd-attention}, the standard attention score $\langle Q_i, K_j \rangle$ is attenuated by a weight
\[a_{i:j}^\times = a_i \cdots a_{j+1}\]
which can be interpreted as a “discount factor” based on how far apart the positions $i$ and $j$ are. (This interpretation was concurrently espoused by Tobias Katsch’s GateLoop paper
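One way to compute all of these discount factors at once is through cumulative sums in log space, since $\log a_{i:j}^\times = \sum_{k=j+1}^{i} \log a_k$ is a difference of two prefix sums. The sketch below assumes every $a_t$ is strictly positive (so the logarithm is defined); the name `decay_mask` is ours.

```python
import numpy as np

def decay_mask(a):
    """Mask L with L[i, j] = a_i * ... * a_{j+1} for i >= j and 0 above the diagonal.

    Assumes a > 0 elementwise so that log(a) is defined.
    """
    T = len(a)
    cs = np.cumsum(np.log(a))             # cs[i] = log(a_0 * ... * a_i)
    diff = cs[:, None] - cs[None, :]      # diff[i, j] = log(a_{j+1} * ... * a_i)
    causal = np.tril(np.ones((T, T), dtype=bool))
    # Entries above the diagonal become exp(-inf) = 0, which enforces causality.
    return np.exp(np.where(causal, diff, -np.inf))

# With a constant a_t = 0.5 the weight depends only on the distance i - j.
L = decay_mask(np.full(4, 0.5))
```

In the constant case `L[i, j]` equals `0.5 ** (i - j)`, a fixed exponential discount. Input-dependent $a_t$ lets the model choose that discount per position.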
So why do we care that there are two views of this model? Well, first of all, it’s extremely mathematically interesting, as we’ll cover in Part II, and we hope will inspire future directions. But there are immediate practical benefits too!
The SSM \eqref{eq:ssm} and attention \eqref{eq:ssd-attention} modes represent two different ways of computing the same function, so let’s contrast them.
First, remember that one main reason SSMs are interesting to begin with is that computing \eqref{eq:ssm} as a recurrence requires maintaining only a constant-size state (size $\mathtt{N}$ per channel) and scales linearly in the sequence length $\mathtt{T}$. The downside is that the raw FLOPs don’t reflect actual speed in practice because of hardware considerations…
On the other hand, computing this sequence transformation $y = Mx$ through equation \eqref{eq:ssd-attention} takes quadratic time in the sequence length, because we’re materializing this $\mathtt{T} \times \mathtt{T}$ matrix. But it can be fast in practice because it only uses matrix multiplications, which are extremely optimized on GPUs and TPUs.
So if there are two equivalent ways of computing the same model, when should we use one mode or the other? During inference, there’s no trade-off: the SSM mode is designed for fast autoregressive inference. But what about training? Here there’s a tension between FLOPs and hardware efficiency where the attention mode uses more FLOPs, but uses them more efficiently through matrix multiplications.
It turns out we can get the best of both worlds by combining the algorithms! There are two equivalent interpretations of this “state space dual” algorithm, either as
We’ll leave the details of this algorithm to Part III (or Section 6 of the full paper), as it requires a bit of machinery from the theory to derive. But we do emphasize that the implementation of this algorithm isn’t too complicated – a minimal implementation that we provide is only ~30 lines of PyTorch!
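To give a flavor of how the two modes can be combined, here is a simplified chunkwise sketch in NumPy. It is not the minimal PyTorch implementation mentioned above and omits everything that makes the real algorithm fast. It assumes a zero initial state, strictly positive $a_t$, and a sequence length divisible by the chunk size. Within each chunk it uses the quadratic (matmul) form, and between chunks it passes a single state of shape $\mathtt{(N, P)}$ using the recurrence. All function names are ours.

```python
import numpy as np

def decay_mask(a):
    """L[i, j] = a_i * ... * a_{j+1} for i >= j, 0 otherwise (assumes a > 0)."""
    T = len(a)
    cs = np.cumsum(np.log(a))
    causal = np.tril(np.ones((T, T), dtype=bool))
    return np.exp(np.where(causal, cs[:, None] - cs[None, :], -np.inf))

def ssd_chunked(a, B, C, X, chunk):
    """Chunkwise evaluation; a (T,), B (T, N), C (T, N), X (T, P) -> Y (T, P)."""
    T, N = B.shape
    assert T % chunk == 0, "this sketch assumes T is a multiple of the chunk size"
    H = np.zeros((N, X.shape[1]))          # state carried across chunk boundaries
    Y = np.empty(X.shape)
    for s in range(0, T, chunk):
        sl = slice(s, s + chunk)
        a_c, B_c, C_c, X_c = a[sl], B[sl], C[sl], X[sl]
        L = decay_mask(a_c)
        Y_intra = (L * (C_c @ B_c.T)) @ X_c        # attention mode inside the chunk
        decay_in = np.cumprod(a_c)                 # decay applied to the incoming state
        Y_inter = decay_in[:, None] * (C_c @ H)    # contribution of earlier chunks
        Y[sl] = Y_intra + Y_inter
        # SSM mode between chunks: decay the old state, add this chunk's inputs.
        H = decay_in[-1] * H + (B_c * L[-1][:, None]).T @ X_c
    return Y

def ssd_recurrent(a, B, C, X):
    """Reference: the plain step-by-step recurrence."""
    H = np.zeros((B.shape[1], X.shape[1]))
    Y = np.empty(X.shape)
    for t in range(len(a)):
        H = a[t] * H + np.outer(B[t], X[t])
        Y[t] = C[t] @ H
    return Y
```

With a chunk size equal to the sequence length this reduces to the pure quadratic form, and with a chunk size of 1 it reduces to the pure recurrence. Intermediate sizes trade between the two.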
The benefit of the SSD algorithm is that it preserves the same efficient FLOP counts as SSMs (compared to quadratic attention), and also dramatically speeds up training compared to general state space models by utilizing matmuls.
| | Attention | SSM | SSD |
|---|---|---|---|
| State size | $\mathrm{T}$ | $\mathbf{N}$ | $\mathbf{N}$ |
| Training FLOPs | $\mathrm{T}^2\mathrm{N}$ | $\mathbf{TN^2}$ | $\mathbf{TN^2}$ |
| Inference FLOPs | $\mathrm{T}\mathrm{N}$ | $\mathbf{N^2}$ | $\mathbf{N^2}$ |
| (Naive) memory | $\mathrm{T}^2$ | $\mathrm{TN}^2$ | $\mathbf{TN}$ |
| Matrix multiplications? | Yes | No | Yes |
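To get a feel for these scalings, the snippet below plugs concrete numbers into the expressions from the table. These are asymptotic scalings with constants dropped, not measured FLOP counts, and the sequence length `T = 2048` is an arbitrary value chosen for illustration.

```python
def table_costs(T, N):
    """Evaluate the table's scaling expressions for sequence length T and state size N."""
    return {
        "attention": {"state": T, "train_flops": T**2 * N, "infer_flops": T * N, "memory": T**2},
        "ssm": {"state": N, "train_flops": T * N**2, "infer_flops": N**2, "memory": T * N**2},
        "ssd": {"state": N, "train_flops": T * N**2, "infer_flops": N**2, "memory": T * N},
    }

costs = table_costs(T=2048, N=64)
# Ratio of the training scalings: (T^2 N) / (T N^2) = T / N
train_ratio = costs["attention"]["train_flops"] / costs["ssd"]["train_flops"]
# Ratio of the naive memory scalings for SSM vs SSD: (T N^2) / (T N) = N
memory_ratio = costs["ssm"]["memory"] / costs["ssd"]["memory"]
```

Here `train_ratio` comes out to `T / N = 32` and `memory_ratio` to `N = 64`, so under these scalings the gap to attention widens as sequences get longer relative to the state size.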
Although the core contribution of Mamba-2 is the new SSD layer and theory, we also make some small changes to Mamba’s neural network architecture.
The main change is producing the $(A, B, C)$ SSM parameters in parallel with the $X$ input, instead of sequentially. This is partly motivated by the connections to attention; but more pragmatically, it’s simpler and more amenable to scaling techniques such as tensor parallelism, which will be discussed in Part IV of this series!
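As a loose illustration of what “in parallel” means here, the sketch below produces $X$, $B$, $C$ and the scalar $a$ from the layer input with one projection and one split. Everything about it is hypothetical: the weight layout, the name `parallel_projections`, and the sigmoid used to keep $a_t$ in $(0, 1)$ are placeholders chosen for the example, not the parameterization used in Mamba-2.

```python
import numpy as np

def parallel_projections(u, W_in, P, N):
    """Single-head toy example.

    u: (T, D) layer input, W_in: (D, P + 2N + 1) hypothetical fused weight.
    Returns X (T, P), B (T, N), C (T, N), a (T,).
    """
    z = u @ W_in                                     # one matmul produces everything
    X, B, C, a_raw = np.split(z, [P, P + N, P + 2 * N], axis=1)
    a = 1.0 / (1.0 + np.exp(-a_raw[:, 0]))           # placeholder squashing into (0, 1)
    return X, B, C, a
```

Because all four outputs depend only on `u`, the columns of `W_in` could in principle be sharded across devices without one projection waiting on another, which is the kind of property that makes tensor parallelism simpler.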
There are some other small differences which are covered in more detail in the paper. However, we do want to emphasize that these architectural changes aren’t really the main point of the model.
In terms of empirical results, we didn’t test Mamba-2 as extensively as Mamba-1, but believe it should generally be on par or better across the board. Our full language model results use the same protocol as Mamba, and found slightly better scaling at Chinchilla laws
Fully trained models on the Pile dataset
More interestingly, we highlight the one synthetic task we tried. Since the original Mamba paper, which investigated synthetics such as Synthetic Copying and Induction Heads, many follow-up works have begun investigating harder associative recall tasks. The multi-query associative recall (MQAR) task introduced by the Zoology and Based
We ran a version of this task that’s much harder than the one usually reported in the literature, and found that Mamba-2 is substantially better than Mamba-1. One reason for the improved performance is the much larger state size (up to $16\times$ larger than Mamba-1 here), which was one of the primary motivations of Mamba-2 in the first place.
Interestingly, Mamba-2 also appears to be noticeably better than Mamba-1 on this particular task even when the state size is controlled. We’re not quite sure why, to be honest, and it would be great to ablate the other aspects of the model to investigate… for example, could it be possible that the restricted structure of SSD is actually helpful here?
In the next part of this series, we’ll go more into the full SSD framework, including how to prove the claimed “duality” of the SSD layer, and strong generalizations of it.