In Part I of this series, we defined the state space dual (SSD) model. In isolation, this model is relatively simple to define, and we claimed that it can be computed either as an SSM recurrence or with an attention-like pattern. If you just want to use the model, feel free to skip this post!
In this post, we’ll dive into the theory behind the model. We’ll derive the SSD “duality” in two completely separate ways, one starting from the SSM perspective and one from the attention perspective. Each method is actually much broader than the SSD model itself, and the union of these two strong generalizations is what we call the SSD framework. This framework provides a rich body of connections between state space models, attention, and structured matrices. While the SSD model can be viewed as a specific instantiation of each prong of the framework, the SSD framework is much more general and opens up many directions for future work.
For each of the two parts of this framework, we’ll define the general concept, show how the SSD model arises as a special case, and derive its version of the linear-quadratic duality.
Note that this theory is not necessary to use the SSD model itself; this part of the series can be safely skipped by practitioners who just want to use SSD (Mamba-2).
Part I of this series introduced the SSD layer, which is defined as a selective SSM
\[\begin{aligned} h_{t} &= A_t h_{t-1} + B_t x_t \\ y_t &= C_t^{\top} h_t \end{aligned}\]\begin{equation} \label{eq:ssm} (\text{Selective state space model (SSM)}) \end{equation}
with scalar-identity structure on $A$.
More formally, we view it as a sequence transformation $X \mapsto Y$
\begin{equation} \label{eq:ssm-transformation} Y^\mathtt{(T,P)} = \mathsf{SSM}(A^\mathtt{(T)}, B^\mathtt{(T,N)}, C^\mathtt{(T,N)})(X^\mathtt{(T,P)}) \end{equation}
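As a concrete reference point, here is a minimal NumPy sketch of this sequence transformation in the SSD case, where each $A_t$ is a single scalar $a_t$. It is deliberately naive, and the helper name `ssd_scan` is just for illustration; this is not the implementation used in Mamba-2.

```python
import numpy as np

def ssd_scan(a, B, C, X):
    """Naive recurrence for the SSD layer (scalar-identity A).

    a: (T,)    scalars a_t            B: (T, N)  vectors B_t
    C: (T, N)  vectors C_t            X: (T, P)  input sequence
    Returns Y: (T, P).
    """
    T, P = X.shape
    N = B.shape[1]
    h = np.zeros((N, P))                        # state h_t
    Y = np.zeros((T, P))
    for t in range(T):
        h = a[t] * h + np.outer(B[t], X[t])     # h_t = a_t h_{t-1} + B_t x_t
        Y[t] = C[t] @ h                         # y_t = C_t^T h_t
    return Y
```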
The dual attention-like form of the SSD layer is
\begin{equation} \label{eq:ssd-attention} \begin{aligned} M &= L \circ C B^\top \in \mathbb{R}^{\mathtt{(T,T)}} \\ Y &= M X \end{aligned} \end{equation}
where $L$ is the 1-semiseparable decay mask built from the scalars $a_t$, defined precisely in \eqref{eq:1-ss} below.
Now let’s see how to prove this!
The first framing of the duality will be from an SSM-centric perspective, where we’ll prove the duality through the framework of matrix sequence transformations or “matrix mixers”.
The idea is that many sequence models, i.e. sequence transformations $X \in \mathbb{R}^\mathtt{(T,P)} \mapsto Y \in \mathbb{R}^\mathtt{(T,P)}$, can be written in the form of a single matrix multiplication $Y = M(X) \cdot X$ where $M$ is a matrix which can itself depend on $X$. We call this a matrix sequence transformation, or matrix transformation for short. In the literature, sequence transformations have also been referred to as “sequence mixers” or “token mixers”, and matrix sequence transformations as “matrix mixers”. There are many examples of these, which are distinguished by the structure of the $M$ matrix. The prototypical example is self-attention itself, where $M = \mathsf{softmax}(QK^\top)$ is the attention matrix. Other examples include MLP-Mixer, whose token-mixing step multiplies along the sequence dimension by dense learned matrices.
Why do we care about these types of models?
Writing a sequence model as a matrix transformation provides a powerful tool to understand the structure and characteristics of the model.
And although general non-linear RNNs such as LSTMs cannot be written as matrix mixers, state space models can! In fact, this is pretty easy to see by just unrolling the definition of the SSM recurrence. The upshot is that the SSM \eqref{eq:ssm-transformation} can be written as a matrix transformation
\[Y = \mathsf{SSM}(A, B, C)(X) = MX\]where $M_{ij} = 0$ for $i < j$ (i.e. it’s lower triangular) and otherwise \begin{equation} \label{eq:semiseparable} M_{ij} = C_i^\top A_{i:j}^\times B_j := C_i^\top A_i \dots A_{j+1} B_j \end{equation}
Drawing it out, this matrix looks like
\[\begin{bmatrix} C_0^\top B_0 & \\ C_1^\top A_1 B_0 & C_1^\top B_1 & \\ C_2^\top A_2A_1 B_0 & C_2^\top A_2 B_1 & C_2^\top B_2 \\ \vdots & \vdots & \ddots & \ddots \\ C_{\mathtt{T}-1}^\top A_{\mathtt{T}-1}\dots A_1 B_0 & C_{\mathtt{T}-1}^\top A_{\mathtt{T}-1}\dots A_2 B_1 & \dots & C_{\mathtt{T}-1}^\top A_{\mathtt{T}-1} B_{\mathtt{T}-2} & C_{\mathtt{T}-1}^\top B_{\mathtt{T}-1} \\ \end{bmatrix}\]\begin{equation} \label{eq:ssm-matrix} (\text{Matrix Transformation Representation of State Space Models}) \end{equation}
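To sanity-check the matrix form, here is a small numerical experiment continuing the illustrative `ssd_scan` sketch above (so it is specialized to scalar $A_t$): materialize $M$ entry by entry from \eqref{eq:semiseparable} and verify that $MX$ matches the recurrence.

```python
def ssd_matrix(a, B, C):
    """Materialize M with M[i, j] = (a_{j+1} ... a_i) * C_i^T B_j for i >= j, else 0."""
    T = len(a)
    M = np.zeros((T, T))
    for i in range(T):
        for j in range(i + 1):
            # empty product on the diagonal evaluates to 1
            M[i, j] = np.prod(a[j + 1:i + 1]) * (C[i] @ B[j])
    return M

rng = np.random.default_rng(0)
T, N, P = 6, 4, 3
a = rng.uniform(0.5, 1.0, size=T)
B, C, X = rng.normal(size=(T, N)), rng.normal(size=(T, N)), rng.normal(size=(T, P))
assert np.allclose(ssd_matrix(a, B, C) @ X, ssd_scan(a, B, C, X))
```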
This type of matrix in fact has a name: it’s called a (triangular) semiseparable matrix, and it has been studied in other fields of engineering and computational linear algebra.
For our purposes, we’ll care about this form mainly for the algorithmic considerations. One of the central messages of this SSD paper is that:
Takeaway: Computing SSMs Through Matrix Multiplication
All algorithms for computing state space models can be viewed as structured matrix multiplication algorithms on semiseparable matrices.
Let’s see an easy instantiation of this, focusing on our main objective!
To show that equation \eqref{eq:ssd-attention} follows from equation \eqref{eq:ssm} (in the case of the SSD model, i.e. scalar SSM), we directly use the matrix form of the state space model \eqref{eq:semiseparable}. Because the $A_t$ are all scalars in this case, they can be factored out of the entries
\[C_i^\top A_{i:j}^\times B_j = A_{i:j}^\times \cdot (C_i^\top B_j)\]which directly implies equation \eqref{eq:ssd-attention}.
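In code, the factoring is equally short. The sketch below (reusing the illustrative toy tensors above) builds the decay mask $L$ from cumulative products of the $a_t$ and computes $Y = (L \circ CB^\top)X$; the cumulative-product ratio is just an illustrative shortcut that assumes the $a_t$ are nonzero.

```python
def ssd_quadratic(a, B, C, X):
    """Attention-like form of SSD: Y = (L ∘ C B^T) X with L[i, j] = a_{j+1} ... a_i for i >= j."""
    T = len(a)
    p = np.cumprod(np.concatenate([[1.0], a[1:]]))   # p[t] = a_1 ... a_t
    L = np.tril(p[:, None] / p[None, :])             # L[i, j] = p[i] / p[j]  (assumes nonzero a_t)
    M = L * (C @ B.T)                                # elementwise mask on the (T, T) "attention" matrix
    return M @ X

assert np.allclose(ssd_quadratic(a, B, C, X), ssd_scan(a, B, C, X))
```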
In summary:
Duality Representation 1 (SSM)
The dual quadratic and linear forms of the SSD model are simply two different matrix multiplication algorithms on the same semiseparable matrix $M$: the quadratic (attention) mode materializes $M$ \eqref{eq:ssm-matrix} and multiplies it by $X$ naively, while the linear (recurrent) mode exploits the semiseparable structure to compute $MX$ without ever materializing $M$.
The power of the semiseparable matrix representation applies to all state space models, with various downstream implications.
Algorithmically, the Mamba-2 paper explores several consequences, such as:
Conceptually, the matrix transformation viewpoint helps provide a unifying view of sequence models. Some example downstream ideas include
We’re excited to see what algorithmic and conceptual ideas from the structured matrix literature can be applied to further improve state space models!
The second framing of the duality is from an attention-centric perspective, where we’ll prove the duality through the framework of tensor contractions.
Note that this is entirely independent of the previous [matrix transformation viewpoint].
For our purposes, we’ll define attention as a function
\[(Q^\mathtt{(T,N)}, K^\mathtt{(S,N)} , V^\mathtt{(S,P)} ) \mapsto Y^\mathtt{(T,P)}\]given by the pairwise matrix multiplications
\[Y = (QK^\top) \cdot V\]Think of $\mathtt{P} = \mathtt{N}$ as the head dimension; technically speaking, in attention the $V$ head dimension $\mathtt{P}$ can differ from the $QK$ head dimension $\mathtt{N}$. Think of $\mathtt{T}$ as the target sequence dimension and $\mathtt{S}$ as the source sequence dimension. Giving these two axes different names will make the math more clear and also covers more general forms of attention such as cross-attention, where the source and target are separate sequences with different lengths. However, for our purposes we’ll assume the self-attention setting where $\mathtt{S}=\mathtt{T}$.
The usual form of attention $Y = f(QK^\top) \cdot V$ (e.g. where $f$ is the softmax function) can, for essentially all functions $f$ of interest, be rewritten (up to normalization) in the form $Y = \psi(Q) \psi(K)^\top \cdot V$, where $\psi$ is a feature map applied to the rows of $Q$ and $K$ and may be infinite dimensional.
We’ll restrict ourselves to the case when $\psi$ is finite, which is sometimes called kernel attention. Many, many variants have been proposed before!
Why do we care about this formulation? When the sequence length $\mathtt{T}$ grows while the feature dimension $\mathtt{N}$ stays small (commonly, $\psi$ is a simple elementwise transform, so $\mathtt{N}$ is constant), the cost of attention can be reduced from quadratic in $\mathtt{T}$ to linear. This follows from simply computing the matrix multiplications in a different order
\[Y = Q \cdot (K^\top V)\]This is a somewhat “folklore” interpretation of linear attention.
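As a quick sanity check of this reordering (with no mask and no softmax), the following NumPy snippet uses random placeholder tensors to verify that the two multiplication orders agree while producing very differently sized intermediates.

```python
import numpy as np

rng = np.random.default_rng(1)
T, S, N, P = 8, 8, 4, 3
Q, K, V = rng.normal(size=(T, N)), rng.normal(size=(S, N)), rng.normal(size=(S, P))

Y_quadratic = (Q @ K.T) @ V    # materializes a (T, S) matrix: cost scales with T * S
Y_linear    = Q @ (K.T @ V)    # only an (N, P) intermediate:  cost scales linearly in T and S
assert np.allclose(Y_quadratic, Y_linear)
```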
The most common way of linearizing attention is usually viewed as a consequence of the associativity of matrix multiplication
However, once the basic kernel attention is slightly modified, we can no longer use the associativity of matrix multiplication directly.
The seminal Linear Attention (LA) framework of Katharopoulos et al. addresses exactly this setting: it shows that causal (autoregressive) kernel attention can still be computed in linear time by reformulating it as a recurrence.
Let’s be a lot more explicit about how it works. The quadratic form of causal linear attention is \begin{equation} \label{eq:quadratic-kernel-attention} Y = (L \circ QK^\top) \cdot V \end{equation} where
\[L = \begin{bmatrix} 1 \\ \vdots & \ddots \\ 1 & \dots & 1 \end{bmatrix}\]is the causal mask matrix.
The issue is: once the $L$ mask is incorporated into \eqref{eq:quadratic-kernel-attention}, we can no longer directly apply matrix associativity! This is the problem that the original Linear Attention paper addresses. What they show is that \eqref{eq:quadratic-kernel-attention} is equivalent to a different form which avoids materializing the quadratic $QK^\top$ attention matrix and has linear time complexity
\[Y = Q \cdot \mathsf{cumsum}(K^\top V)\]As far as we’re aware this wasn’t explicitly proved in the paper, although it isn’t too hard to write out the summation to show it.
What we’ll do is prove this equivalence in essentially one line, while revealing exactly where the “linear” part of Linear Attention comes from, and how to strongly generalize it.
Spoiler alert:
Where does the cumsum in Linear Attention come from?
The appearance of the cumulative sum in linear attention is exactly equivalent to the fact that the causal mask $L$, as a matrix multiplication, encodes cumulative sums:
\[y = L \cdot x \iff y = \mathsf{cumsum}(x)\]
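Continuing the toy setup above, this fact is a one-line NumPy check:

```python
x = rng.normal(size=(T, P))
L = np.tril(np.ones((T, T)))    # the causal mask, written out as an explicit matrix
assert np.allclose(L @ x, np.cumsum(x, axis=0))
```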
Let’s write out the quadratic form of linear attention \eqref{eq:quadratic-kernel-attention} very explicitly in tensor contraction or einsum notation, with shape annotations:
\[\begin{aligned} G &= \mathsf{contract}(\mathtt{TN, SN} \to \mathtt{TS})(Q, K) \\ M &= \mathsf{contract}(\mathtt{TS, TS} \to \mathtt{TS})(G, L) \\ Y &= \mathsf{contract}(\mathtt{TS, SP} \to \mathtt{TP})(M, V) \end{aligned}\]\begin{equation} \label{eq:sma-quad} (\text{Structured Masked Attention - Quadratic Form}) \end{equation}
With this notation, we can notice that this sequence of contractions can be written as a single four-way contraction
\begin{equation} \label{eq:sma} Y = \mathsf{contract}(\mathtt{TN},\mathtt{SN},\mathtt{SP},\mathtt{TS} \to \mathtt{TP})(Q, K, V, L) . \end{equation}
And finally, it can be computed with any other contraction ordering. In particular, we can perform pairwise reductions in the order $V, K, L, Q$ instead of $Q, K, L, V$
\[\begin{aligned} Z &= \mathsf{contract}(\mathtt{SP},\mathtt{SN} \to \mathtt{SPN})(V, K) \\ H &= \mathsf{contract}(\mathtt{TS},\mathtt{SPN} \to \mathtt{TPN})(L, Z) \\ Y &= \mathsf{contract}(\mathtt{TN},\mathtt{TPN} \to \mathtt{TP})(Q, H) \end{aligned}\]\begin{equation} \label{eq:sma-lin} (\text{Structured Masked Attention - Linear Form}) \end{equation}
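As an illustrative sketch (not the Mamba-2 implementation), here are the two reduction orders written with `np.einsum`, reusing the toy $Q, K, V$ from above with a causal mask. As written, the linear form still materializes $L$ and the $(\mathtt{T},\mathtt{P},\mathtt{N})$ tensor $H$; the actual linear-time benefit comes from replacing the middle contraction (multiplication by $L$) with a cumsum or scan, which is exactly the observation below.

```python
def sma_quadratic(Q, K, V, L):
    """Structured masked attention, quadratic order: materializes the (T, S) masked matrix."""
    G = np.einsum("tn,sn->ts", Q, K)
    M = G * L
    return np.einsum("ts,sp->tp", M, V)

def sma_linear(Q, K, V, L):
    """Same contraction, reduced in the order V, K, L, Q."""
    Z = np.einsum("sp,sn->spn", V, K)
    H = np.einsum("ts,spn->tpn", L, Z)   # just L applied along S; a cumsum/scan for structured L
    return np.einsum("tn,tpn->tp", Q, H)

L_causal = np.tril(np.ones((T, T)))
assert np.allclose(sma_quadratic(Q, K, V, L_causal), sma_linear(Q, K, V, L_causal))
```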
Now the key observation is that the second line of \eqref{eq:sma-lin} is simply a matrix multiplication by $L$, which can be computed with a cumulative sum.
That’s the entire proof of linear attention! The beauty of it is that we didn’t have to write out a single summation, which was abstracted out into a tensor contraction combined with the structure of $L$.
This immediately proves our claim about the cumsum in linear attention. Moreover, this immediately reveals that the efficiency of linear attention can be made much more general…
The critical observation is that in order for \eqref{eq:sma-lin} to be fast, all that is necessary is for $L$ to be any structured matrix – in other words any matrix that has subquadratic matrix-vector multiplication.
This immediately motivates one of the main prongs of the SSD framework, which can be seen as a strong generalization of LA.
Definition: Structured Masked Attention
Structured masked attention (SMA) is defined as the four-way tensor contraction \eqref{eq:sma} using an attention mask $L$ that is a structured matrix.
Duality Representation 2 (SMA)
SMA has dual quadratic and linear modes, which are simply two different pairwise reduction orders \eqref{eq:sma-quad} and \eqref{eq:sma-lin}. (The linear mode assumes that the structured matrix $L$ has linear-time matrix-vector multiplication.)
Finally, let’s just connect this back to the commonly held view of linear attention as matrix multiplication associativity.
Although it is commonly believed that incorporating attention masks $L$ prevents matrix multiplication reordering, it turns out to still be compatible. In particular, associativity of matrix multiplication is a special case of tensor contraction reduction orders; although the former no longer applies, the latter can integrate the attention mask $L$.
Next, let’s look at some consequences of the structured attention framework.
Recall that the SSD model is defined as either a scalar-identity SSM in equation \eqref{eq:ssm}, or through the attention-like form in equation \eqref{eq:ssd-attention}.
To show the equivalence of these forms, we simply recognize that \eqref{eq:ssd-attention} is a special case of structured masked attention where the mask matrix is
\[L = \begin{bmatrix} 1 & \\ a_1 & 1 & \\ a_2a_1 & a_2 & 1 \\ \vdots & \vdots & \ddots & \ddots \\ a_{\mathtt{T}-1}\dots a_1 & a_{\mathtt{T}-1}\dots a_2 & \dots & a_{\mathtt{T}-1} & 1 \\ \end{bmatrix} .\]\begin{equation} \label{eq:1-ss} (\text{1-semiseparable (1-SS) matrix}) \end{equation}
We call this a 1-semiseparable (1-SS) matrix, for reasons that are explained in more detail in the Mamba-2 paper.
Thus, we can also say that the SSD model is 1-semiseparable masked attention or 1-SS SMA.
To prove that this can be written as an SSM, we simply appeal to the SMA framework, which says that the dual form of this model can be computed through matrix multiplication by $L$. So how fast is that? It’s not too hard to see that multiplication $y = Lx$ can be computed in linear time through a scalar recurrence:
\[\begin{aligned} y_0 &= x_0 \\ y_1 &= a_1 x_0 + x_1 = a_1 y_0 + x_1 \\ y_2 &= a_2a_1 x_0 + a_2 x_1 + x_2 = a_2 y_1 + x_2 \\ \vdots & \qquad \vdots \end{aligned}\]This corresponds exactly to the original SSM recurrence!
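Here is a small self-contained check that this recurrence really is multiplication by the 1-SS matrix \eqref{eq:1-ss}; the helper name `ss1_matvec` is illustrative only.

```python
import numpy as np

def ss1_matvec(a, x):
    """Compute y = L x for the 1-SS matrix L in O(T) time via y_t = a_t y_{t-1} + x_t."""
    y = np.zeros_like(x)
    y[0] = x[0]
    for t in range(1, len(x)):
        y[t] = a[t] * y[t - 1] + x[t]
    return y

rng = np.random.default_rng(2)
T = 6
a, x = rng.uniform(0.5, 1.0, size=T), rng.normal(size=T)
# materialize the 1-SS matrix explicitly for comparison
L_1ss = np.array([[np.prod(a[j + 1:i + 1]) if i >= j else 0.0 for j in range(T)] for i in range(T)])
assert np.allclose(ss1_matvec(a, x), L_1ss @ x)
```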
(In fact, multiplication by 1-SS matrices $L$ can be computed in a lot more ways, which we compile in the full paper! Alternative algorithms can reveal more insights: for example, the associative scan algorithm used by S5 and Mamba corresponds to yet another way of multiplying by 1-SS matrices.)
Structured masked attention not only helps define the SSD model and prove its duality, but it is a much broader framework of efficient attention models.
Prior examples include the original linear attention, where the mask $L$ is the plain lower-triangular causal matrix of ones, as well as the recent Retentive Network (RetNet) model, whose decay mask is the special case of \eqref{eq:1-ss} with constant $a_t = \gamma$.
Additionally, other forms of structure can be incorporated into the $L$ mask. For example, another extension my students are developing is viewing SSD (and recurrences in general) as an algorithm operating on directed line graphs, and generalizing it to incorporate arbitrary graph structures.
We’ll end this post with a brief recap of what we’ve covered.
The SSD framework consists of the two broad approaches covered in this post, which are summarized by the two areas of the [Venn diagram]:
The [SSD layer] is a particular model corresponding to the purple intersection in the figure: it can be viewed as an instance of either part of the SSD framework, and in particular has dual quadratic and linear forms that can be derived from either representation.
| SSD Framework | Structured SSMs | Structured Attention |
|---|---|---|
| The main representation is… | Structured matrix sequence transformations \eqref{eq:ssm-matrix} | The 4-way tensor contraction \eqref{eq:sma} |
| This generalizes… | State space models | Linear attention |
| The SSD model is an instantiation as… | A scalar state space model ($A_t$ is a scalar-identity matrix) | 1-semiseparable masked attention (the $L$ mask is a 1-SS matrix) |
| The linear-quadratic duality is revealed through… | Structured matrix multiplication algorithms | Tensor contraction reduction orderings |
In the next part of this series, we’ll see how to use some of the SSD framework (in particular, the structured matrix algorithm point of view) to derive the more efficient hybrid SSD algorithm that leverages both of the dual forms.