<?xml version="1.0" encoding="utf-8"?><feed xmlns="http://www.w3.org/2005/Atom" xml:lang="en"><generator uri="https://jekyllrb.com/" version="4.4.1">Jekyll</generator><link href="tridao.github.io/feed.xml" rel="self" type="application/atom+xml"/><link href="tridao.github.io/" rel="alternate" type="text/html" hreflang="en"/><updated>2026-03-30T16:28:13+00:00</updated><id>tridao.github.io/feed.xml</id><title type="html">Tri Dao</title><subtitle>Homepage of Tri Dao. # A simple, whitespace theme for academics. Based on [*folio](https://github.com/bogoli/-folio) design. </subtitle><entry><title type="html">Gram Newton-Schulz: A Fast, Hardware-Aware Newton-Schulz Algorithm for Muon</title><link href="tridao.github.io/blog/2026/gram-newton-schulz/" rel="alternate" type="text/html" title="Gram Newton-Schulz: A Fast, Hardware-Aware Newton-Schulz Algorithm for Muon"/><published>2026-03-30T00:00:00+00:00</published><updated>2026-03-30T00:00:00+00:00</updated><id>tridao.github.io/blog/2026/gram-newton-schulz</id><content type="html" xml:base="tridao.github.io/blog/2026/gram-newton-schulz/"><![CDATA[<style>.post img{max-width:100%;height:auto}.post blockquote p{margin-top:.2em;margin-bottom:.2em;line-height:1.4}.post blockquote p:first-child{margin-top:0}.post blockquote p:nth-child(2),.post blockquote p:nth-child(3){margin-bottom:1em}.post blockquote strong{font-style:normal!important}.post blockquote{background-color:transparent;border-left:4px solid var(--global-theme-color,#4c9eff);padding:1rem 1.5rem;font-size:inherit;color:inherit}.post blockquote .MJXc-display,.post blockquote .katex-display{text-align:center!important;margin:1em 0!important}.post blockquote .MathJax,.post blockquote .katex,.post blockquote .MathJax_Display,.post blockquote mjx-container,.post blockquote mjx-math,.post blockquote mjx-mrow,.post blockquote .MathJax *,.post blockquote mjx-container *{color:inherit!important}html[data-theme='dark'] .post blockquote .MathJax,html[data-theme='dark'] .post blockquote mjx-container,html[data-theme='dark'] .post blockquote mjx-container *{color:var(--global-text-color)!important}.post h1{font-weight:normal!important;font-style:normal!important;border-bottom:1px solid var(--global-divider-color)!important;padding-bottom:.5rem!important}.post h1{margin-top:3rem!important;margin-bottom:1.5rem!important}.post h2{margin-top:2.5rem!important;margin-bottom:1.25rem!important}.post h3{margin-top:2rem!important;margin-bottom:1rem!important}.post h4{margin-top:1.5rem!important;margin-bottom:.75rem!important}.post h5,.post h6{margin-top:1rem!important;margin-bottom:.5rem!important}</style> <p>Muon is becoming the optimizer of choice for training state-of-the-art language models like Kimi K2 Thinking and GLM-5.<sup id="fnref:kimi"><a href="#fn:kimi" class="footnote" rel="footnote" role="doc-noteref">1</a></sup><sup id="fnref:GLM"><a href="#fn:GLM" class="footnote" rel="footnote" role="doc-noteref">2</a></sup> Compared to AdamW, Muon needs fewer optimizer steps to reach a given loss, but each step is more expensive. This overhead is due to Muon’s Newton-Schulz orthogonalization procedure, a cubic time matrix operation not present in older optimizers.</p> <p><img src="https://hackmd.io/_uploads/ByJX-6BsWg.png" alt="icml_optimizer_plot_blackwell (2)"/> <em>Figure 1: AdamW vs. Muon: Wall clock time of optimizer step across LLaMa model sizes, benchmarked on B300.</em></p> <p>Muon’s superior optimization quality justifies its more expensive optimizer step. 
However, as model size scales up, the overhead of computing each Muon step grows rapidly. Traditional optimization methods (SGD, AdamW) perform element-wise operations, such as updating the momentum or rescaling it by the second moment. For a weight matrix of size $n \times m$, performing the optimizer step takes $O(mn)$ time given the gradient matrix as input. In contrast, many modern optimizers (Muon, Scion, Dion, SOAP, Shampoo, SPlus, etc.) use orthogonalization or higher-order preconditioning to compute the update to the weights at each training step.<sup id="fnref:muon"><a href="#fn:muon" class="footnote" rel="footnote" role="doc-noteref">3</a></sup><sup id="fnref:dion"><a href="#fn:dion" class="footnote" rel="footnote" role="doc-noteref">4</a></sup><sup id="fnref:scion"><a href="#fn:scion" class="footnote" rel="footnote" role="doc-noteref">5</a></sup><sup id="fnref:soap"><a href="#fn:soap" class="footnote" rel="footnote" role="doc-noteref">6</a></sup><sup id="fnref:shampoo"><a href="#fn:shampoo" class="footnote" rel="footnote" role="doc-noteref">7</a></sup><sup id="fnref:splus"><a href="#fn:splus" class="footnote" rel="footnote" role="doc-noteref">8</a></sup> These methods require matrix multiplications that cost $O(mn^2)$ time (assuming $n \leq m$). Therefore, the runtime of each call to the optimizer is far greater than for AdamW. Depending on the training setup (global batch size, cluster size, and parallelism settings), Newton-Schulz accounts for between <a href="#appendix">2% and 17%</a> of end-to-end wall clock time.</p> <p>While $O(mn^2)$ runtime is an unavoidable cost of these algorithms, there is still significant room for improvement in both FLOPs and wall clock time. As it is typically implemented, the Newton-Schulz routine has several shortcomings:</p> <ol> <li>It uses not just one or two, but <em>ten</em> multiplications of $n \times m$ matrices, costing $2mn^2$ FLOPs each. Most weights in popular architectures are rectangular, with $m \gg n$, and those of recent MoE architectures with many fine-grained experts are even <em>more</em> rectangular. Thus, the rectangular matrix multiplications dominate the costs of other operations (like small multiplications of $n \times n$ matrices).</li> <li>Many of the intermediate matrices it computes are symmetric, but no computational advantage is taken of this structure. Half the work used to compute these matrices is redundant.</li> <li>It uses cuBLAS for batched matrix multiplication/addition $\alpha \mathbf A \mathbf B + \beta \mathbf C$, which is not fully optimized for the Hopper GPU architecture. </li> </ol> <p>Previous work has sought to improve Newton-Schulz by optimizing its polynomial coefficients or its normalization step. While this can reduce the number of iterations needed for Newton-Schulz to converge, it does not address the shortcomings listed above. Others<sup id="fnref:flashmuon"><a href="#fn:flashmuon" class="footnote" rel="footnote" role="doc-noteref">9</a></sup> have implemented Newton-Schulz using special-purpose symmetric matrix multiplication routines, but the runtime benefit is limited due to the high number of rectangular and non-symmetric multiplications. While Newton-Schulz and related methods have been studied for decades in the numerical analysis literature, research attention has mostly focused on regimes where high accuracy is required, where algorithms are optimized for CPUs rather than GPUs, or where input matrices are square. 
In recent years, randomized sketching has been used to design sophisticated algorithms for many computations involving highly rectangular matrices; however (aside from further optimizing the coefficients<sup id="fnref:PRISM"><a href="#fn:PRISM" class="footnote" rel="footnote" role="doc-noteref">10</a></sup>) these do not seem to be applicable to Muon.</p> <h2 id="our-contributions">Our Contributions</h2> <p>To address these shortcomings, we introduce <strong>Gram Newton-Schulz</strong>, a reworking of the Newton-Schulz routine that <strong>reduces the optimizer time by up to 50%</strong> in trillion-parameter MoE models like Kimi K2. Instead of iterating on the rectangular input matrix $\mathbf{X} \in \mathbb{R}^{n \times m}$, Gram Newton-Schulz iterates on the small square symmetric Gram matrix $\mathbf{XX^\top} \in\mathbb{R}^{n \times n}$, reducing the FLOP cost and enabling a greater use of symmetric GEMM kernels.</p> <p>Our contributions are as follows. First, we show how to rewrite standard Newton-Schulz in a form that is <em>mathematically identical</em>, producing the exact same output up to floating-point error, but that acts mostly on the space of $n \times n$ matrices. Because these matrices are smaller and admit specialized symmetric matrix multiplication routines, each iteration is faster than in standard Newton-Schulz. Only the preprocessing step (forming $\mathbf X \mathbf X^\top$) and the post-processing step (multiplying by $\mathbf X$) require rectangular matrix multiplications. We call this new form <a href="#alg-naive-gram-ns">Naive Gram Newton-Schulz</a>.</p> <p>Second, we conduct a thorough study of the numerical properties of Naive Gram Newton-Schulz. We identify the potential for numerical instability when using half-precision floating point arithmetic, especially due to spurious negative eigenvalues in the Gram matrix. We remedy this instability using a “restarting” strategy, where we reconstruct the Gram matrix partway through the algorithm. We call this modified algorithm <a href="#alg-stable-gram-ns">Stabilized Gram Newton-Schulz</a>.</p> <p>Third, to take full advantage of the latest generation of GPUs and of the mathematical structure of Newton-Schulz, we implement custom kernels for <em>symmetric</em> matrix multiplication. The kernels, implemented in CuTeDSL for the Hopper and Blackwell architectures, attain state-of-art performance.</p> <p>Finally, we replace Muon’s Newton-Schulz routine with Gram Newton-Schulz, an optimizer we call <strong>GramMuon</strong>, and observe a 40-50% reduction in the runtime of the orthogonalization step. Experiments confirm that training language models with GramMuon is stable and preserves the optimization quality of the standard version within $0.01$ validation perplexity, making our algorithm a rare instance of “free lunch” performance improvement.</p> <p>To facilitate the adoption of Gram Newton-Schulz, we release the following open-source implementations:</p> <ol> <li>A <a href="https://github.com/Dao-AILab/gram-newton-schulz">drop-in replacement</a> for Muon’s Newton-Schulz routine that is mathematically equivalent, numerically stable, and up to twice as fast.</li> <li><a href="https://github.com/Dao-AILab/quack/blob/main/quack/gemm_symmetric.py">Fast GPU kernels</a> for symmetric matrix multiplication ($AB$, $\alpha AB + \beta C$) written in CuTeDSL for Hopper and Blackwell, which may be of independent interest. 
</li> </ol> <p></p> <p>We will begin by recapping Muon to see why we need Newton-Schulz in the first place, describing how standard Newton-Schulz works mathematically, and analyzing its performance bottlenecks.</p> <h1 id="muon-recap">Muon Recap</h1> <p>The Muon optimizer<sup id="fnref:muon:1"><a href="#fn:muon" class="footnote" rel="footnote" role="doc-noteref">3</a></sup> is best described as steepest-direction descent with respect to the spectral norm.<sup id="fnref:deriving_muon"><a href="#fn:deriving_muon" class="footnote" rel="footnote" role="doc-noteref">11</a></sup> At step $k$ of training, let $\mathbf W_k \in \mathbb{R}^{n \times m}$ be a weight matrix and let $\mathbf G_k$ be the gradient of the loss with respect to $\mathbf W_k$. The Muon update rule is</p> \[\begin{align*} \mathbf{M}_k &amp;= \mu \mathbf{M}_{k-1} + \mathbf{G}_k \\ \mathbf{W}_{k+1} &amp;= \mathbf{W}_k - \eta \operatorname{polar}(\mathbf{M}_k) \end{align*}\] <p>where $\mu$ is the momentum coefficient, $\eta$ is the learning rate, and $\mathbf M_k$ is the momentum matrix (with $\mathbf M_0 := 0$).</p> <p>In most ways, Muon resembles basic stochastic gradient descent (SGD) with momentum. Its key innovation is using the $\operatorname{polar}$ operation, which is defined as follows:</p> <blockquote> <p><strong>Definition 1: Polar Decomposition</strong></p> <p>If $\mathbf X = \mathbf U \mathbf \Sigma \mathbf V^\top$ is the singular value decomposition (SVD) of a matrix, then $\operatorname{polar}(\mathbf X) = \mathbf U \mathbf V^\top$.</p> </blockquote> <p>Since $\operatorname{polar}(\mathbf X)$ is expensive to compute exactly, Muon uses the Newton-Schulz method to approximate it. Newton-Schulz is an iterative method based on matrix polynomials. Beginning with $\mathbf X_0$, each iteration improves the approximation $\mathbf X_t \approx \operatorname{polar}(\mathbf X)$ according to the update rule</p> \[\mathbf X_{t+1} = a_t \mathbf X_t + b_t \mathbf X_t \mathbf X_t^\top \mathbf X_t + c_t \left(\mathbf X_t \mathbf X_t^\top\right)^2 \mathbf X_t.\] <p>We can interpret Newton-Schulz by understanding how it affects the singular value decomposition.</p> <p>Let $\mathbf X_0 = \mathbf U \mathbf \Sigma \mathbf V^\top$ be the SVD. Recall that $\mathbf U$ and $\mathbf V$ have orthonormal columns, such that $\mathbf U^\top \mathbf U = \mathbf V^\top \mathbf V = \mathbf I$, and $\mathbf \Sigma$ is a diagonal matrix whose entries are called the singular values. Then</p> \[\mathbf X_0 \mathbf X_0^\top \mathbf X_0 = \left(\mathbf U \mathbf \Sigma \mathbf V^\top\right) \left(\mathbf U \mathbf \Sigma \mathbf V^\top\right)^\top \left(\mathbf U \mathbf \Sigma \mathbf V^\top\right) = \mathbf U \mathbf \Sigma \mathbf V^\top \mathbf V \mathbf \Sigma \mathbf U^\top \mathbf U \mathbf \Sigma \mathbf V^\top = \mathbf U \mathbf \Sigma^3 \mathbf V^\top\] <p>By the same logic,</p> \[\mathbf X_1 = \mathbf U \left(a_1 \mathbf \Sigma + b_1 \mathbf \Sigma^3 + c_1 \mathbf \Sigma^5 \right) \mathbf V^\top = \mathbf U p_1(\mathbf \Sigma) \mathbf V^\top\] <p>where we have defined the polynomial $p_1(x) = a_1 x + b_1 x^3 + c_1 x^5$. Since $\mathbf U$ and $\mathbf V$ have orthonormal columns and $p_1(\mathbf \Sigma)$ is diagonal, the right-hand side of this equation must be the SVD of $\mathbf X_1$! This shows that $\mathbf X_1$ shares the same singular vectors $\mathbf U$ and $\mathbf V$ as $\mathbf X_0$, and that its singular values are those of $\mathbf X_0$ transformed according to the polynomial $p_1$. 
By extension, $\mathbf X_T$ also shares the same singular vectors $\mathbf U$ and $\mathbf V$, and its singular values have been transformed according to the composition of polynomials $(p_T \circ \cdots \circ p_1)(\mathbf \Sigma)$. If $(p_T \circ \cdots \circ p_1)(x) \approx 1$ for all singular values, then $(p_T \circ \cdots \circ p_1)(\mathbf \Sigma) \approx \mathbf I$ and so $\mathbf X_T \approx \mathbf U \mathbf V^\top = \operatorname{polar}(\mathbf X_0)$.</p> <p>All that remains is to find a sequence of odd polynomials for which $(p_T \circ \cdots \circ p_1)(x) \approx 1$ on the singular values. To make this easier, we first normalize the matrix $\mathbf X_0 = \mathbf X / |\mathbf X|_{\mathsf F}$. This ensures that the singular values of $\mathbf X_0$ lie in the interval $[0, 1]$. The developers of Muon identified a sequence of five degree-5 odd polynomials that approximate $1$ for every input on this interval $[0, 1]$, giving a decent approximation to $\operatorname{polar}(\mathbf X)$ for typical inputs $\mathbf X$ in just five iterations.<sup id="fnref:muon:2"><a href="#fn:muon" class="footnote" rel="footnote" role="doc-noteref">3</a></sup></p> <p>A standard implementation of Newton-Schulz looks like this:</p> <p><a id="alg-standard-ns"></a></p> <blockquote> <p><strong>Algorithm 1: Standard Newton-Schulz</strong></p> <p>Input: $\mathbf X \in \mathbb{R}^{n \times m}$, coefficients ${(a_t, b_t, c_t)}_{t=1}^5$</p> <ol> <li>$\mathbf X \gets \mathbf X \,/\, (\lVert\mathbf X\rVert_{\mathsf F} + \epsilon)$    // Normalize sing vals to $[0, 1]$. $\epsilon = 10^{-7}$</li> <li>$\mathbf X \gets \texttt{bfloat16}(\mathbf X)$     // Cast to half precision for speed</li> <li>If $m &lt; n$:  $\mathbf X \gets \mathbf X^\top$   // Trick to make $\mathbf X \mathbf X^\top$ cheaper</li> <li>For $t = 1, \ldots, 5$:       // Apply $p_t(\mathbf X)$</li> <li>   $\mathbf A \gets \mathbf X\mathbf X^\top$</li> <li>   $\mathbf B \gets b_t \mathbf A + c_t \mathbf A^2$</li> <li>   $\mathbf X \gets a_t \mathbf X + \mathbf B \mathbf X$</li> <li>If $m &lt; n$:  $\mathbf X \gets \mathbf X^\top$   // Undo trick</li> <li>Return $\mathbf X$</li> </ol> </blockquote> <p>Successive work has sought to improve Muon in several ways. Most of these proposals modify Muon’s update rule so as to reach the same loss in fewer training steps; however, they use the same Newton-Schulz routine described above. Some methods (e.g., Polar Express) do address Newton-Schulz by changing the sequence of polynomials or the normalization step.<sup id="fnref:polar-express"><a href="#fn:polar-express" class="footnote" rel="footnote" role="doc-noteref">12</a></sup><sup id="fnref:grishina"><a href="#fn:grishina" class="footnote" rel="footnote" role="doc-noteref">13</a></sup> While they improve its approximation accuracy, they do not change its wall-clock runtime. The Dion optimizer<sup id="fnref:dion:1"><a href="#fn:dion" class="footnote" rel="footnote" role="doc-noteref">4</a></sup> reduces the runtime in the distributed setting, when weights and gradients are sharded across different GPUs. It uses a low-rank approximation of Muon to reduce the communication cost and the dimension of $\mathbf X$, but each step still calls the standard Newton-Schulz routine.</p> <p>In contrast, our work speeds up Newton-Schulz itself. 
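</p>

<p>For concreteness, here is a minimal PyTorch sketch of Algorithm 1 above. It is written for readability rather than speed (production implementations typically use <code class="language-plaintext highlighter-rouge">torch.baddbmm</code> to fold the scalar multiplies and adds into the GEMM calls), and the coefficient list is assumed to be supplied by the caller as five $(a_t, b_t, c_t)$ tuples.</p>

<pre><code class="language-python">import torch

def newton_schulz(X: torch.Tensor, coeffs, eps: float = 1e-7) -> torch.Tensor:
    """Minimal sketch of Algorithm 1 (standard Newton-Schulz)."""
    transposed = X.shape[0] > X.shape[1]
    X = X / (X.norm() + eps)      # normalize singular values into [0, 1]
    X = X.bfloat16()              # half precision for speed
    if transposed:
        X = X.mT                  # make X X^T the small n x n Gram matrix
    for a, b, c in coeffs:
        A = X @ X.mT              # n x n, symmetric
        B = b * A + c * (A @ A)   # b_t * A + c_t * A^2, also symmetric
        X = a * X + B @ X         # apply p_t to X
    if transposed:
        X = X.mT
    return X
</code></pre>

<p>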
Since Gram Newton-Schulz is mathematically identical to the standard version, it is compatible with nearly all varieties of Muon.</p> <h2 id="runtime-of-standard-newton-schulz">Runtime of Standard Newton-Schulz</h2> <p>Let’s analyze the runtime of Newton-Schulz in FLOPs to help us understand its performance bottlenecks. We count only the cubic-time matrix multiplication operations, ignoring the lower-order scalar multiplications and matrix additions. For clarity, we let $T$ denote the number of iterations, remembering that within Muon, $T=5$.<sup id="fnref:num-iters"><a href="#fn:num-iters" class="footnote" rel="footnote" role="doc-noteref">14</a></sup> We also assume without loss of generality that $n \leq m$ and define the aspect ratio $\alpha = m / n \geq 1$. Intuitively, $\alpha$ measures how rectangular the shape of the matrix is, with $\alpha = 1$ being square and $\alpha \gg 1$ being very rectangular.</p> <p>Each iteration has three steps. Each step contains a single matrix multiplication costing, respectively,</p> <ul> <li>$\mathbf X \mathbf X^\top$: $2mn^2$</li> <li>$\mathbf A^2$: $2n^3$</li> <li>$\mathbf B \mathbf X$: $2mn^2$</li> </ul> <p>for a total cost of $T(4mn^2 + 2n^3) = 2T(2\alpha + 1)n^3$ FLOPs. When $T=5$, the cost is $(20\alpha + 10)n^3$ spread across 15 GEMMs. This analysis highlights two shortcomings of standard Newton-Schulz that inspire our work:</p> <h3 id="symmetric-matrix-multiplication">Symmetric Matrix Multiplication</h3> <p>The matrices $\mathbf A = \mathbf X \mathbf X^\top$ and $\mathbf B = b_t \mathbf A + c_t \mathbf A^2$ computed at each iteration of Newton-Schulz are symmetric by definition. This fact can be exploited to reduce the cost of Newton-Schulz. Instead of calling general matrix multiplication routines as typical implementations of Newton-Schulz do, we can compute the lower triangular part of these matrices in the usual way and then simply copy the results to the upper triangular part. This technique halves the cost of computing $\mathbf X \mathbf X^\top$ and $\mathbf A^2$, giving an overall total of $T(3\alpha + 1)n^3$ FLOPs. We describe our custom CuTeDSL kernels that implement this technique <a href="#symmetric-gemm-kernels-in-cutedsl">below</a>.</p> <p><a id="dependence-on-alpha"></a></p> <h3 id="dependence-on-alpha">Dependence on $\alpha$</h3> <p>Even using symmetric GEMMs, Newton-Schulz’s runtime is dominated by the large rectangular matrix multiplications needed to compute $\mathbf A$ and $\mathbf X$, which together cost $3\alpha n^3$ FLOPs per iteration. A typical implementation with $T=5$ requires 10 of these expensive rectangular multiplications.</p> <p>This strong dependence on $\alpha$ is unfortunate.
Most of the weight matrices in transformer architectures are rectangular, including the MLP weights, MoE weights, and attention projection weights when using GQA or MLA.<sup id="fnref:embeddings"><a href="#fn:embeddings" class="footnote" rel="footnote" role="doc-noteref">15</a></sup> Furthermore, we observe that the latest MoE architectures are trending towards finer-grained, sparser experts, meaning that the aspect ratios of their hidden dimensions to intermediate dimensions are increasing as well.<sup id="fnref:kimi:1"><a href="#fn:kimi" class="footnote" rel="footnote" role="doc-noteref">1</a></sup><sup id="fnref:sonicmoe"><a href="#fn:sonicmoe" class="footnote" rel="footnote" role="doc-noteref">16</a></sup><sup id="fnref:qwen"><a href="#fn:qwen" class="footnote" rel="footnote" role="doc-noteref">17</a></sup><sup id="fnref:gpt-oss"><a href="#fn:gpt-oss" class="footnote" rel="footnote" role="doc-noteref">18</a></sup></p> <p>Thus, at large scales, pretraining time would benefit greatly from an algorithm that uses fewer rectangular multiplications and more small symmetric ones.</p> <h1 id="gram-newton-schulz">Gram Newton-Schulz</h1> <p>We now show how to rewrite Newton-Schulz to reduce the number of expensive rectangular matrix multiplications by iterating on the small, square, symmetric Gram matrix $\mathbf X \mathbf X^\top$ instead of the rectangular input matrix $\mathbf X$. The output of this algorithm is mathematically identical to that of standard Newton-Schulz, but it is significantly cheaper to compute.</p> <p>At a high level, our strategy is based on the following formula. If $\mathbf X \in \mathbb{R}^{n \times m}$ with $n \leq m$, then $\mathrm{polar}(\mathbf X) = (\mathbf X \mathbf X^\top)^{-1/2} \mathbf X$. Rather than use an iterative method to approximate $\mathbf X_T \approx \mathrm{polar}(\mathbf X)$ directly, we instead</p> <ol> <li>Compute the $n \times n$ Gram matrix $\mathbf X \mathbf X^\top$</li> <li>Use an iterative method to approximate $\mathbf Q_T \approx (\mathbf X \mathbf X^\top)^{-1/2}$</li> <li>Compute $\mathbf Q_T \mathbf X$</li> </ol> <p>Step 2—which comprises almost all of the algorithm’s wall clock runtime and FLOP cost—works entirely with small $n \times n$ symmetric matrices. This version uses just two rectangular matrix multiplications: $\mathbf X \mathbf X^\top$ in the beginning, and $\mathbf Q_T \mathbf X$ at the end. It also synergizes well with our symmetric GEMM kernels. Because we now use more symmetric multiplications, our kernels provide an even greater speedup than before. Since this method works on the $n \times n$ Gram matrix of $\mathbf X$, we call it “Gram Newton-Schulz”.</p> <p>How can we turn an iterative polynomial method $(p_T \circ \cdots \circ p_1)(\mathbf X) \approx \operatorname{polar}(\mathbf X)$ like Newton-Schulz into an iterative polynomial method for approximating $\mathbf Y \mapsto \mathbf Y^{-1/2}$? Recall that each $p_t$ is an odd polynomial $p(x) = ax + bx^3 + cx^5$. Any odd polynomial can be rewritten in the form $p(x) = xh(x^2)$, where $h$ is a lower-degree polynomial with the same coefficients, like $h(x) = a + bx + cx^2$. Intuitively, if $p(x) \approx 1$, then $h(y) = p(y^{1/2})y^{-1/2} \approx y^{-1/2}$, so the Newton-Schulz polynomials implicitly provide a way to approximate inverse square roots.</p> <p>Formally, Gram Newton-Schulz is based on the following theorem. 
In effect, it shows how to compute $\mathbf X_T$ from $\mathbf X_0$ without ever constructing the intermediate values $\mathbf X_1, \ldots, \mathbf X_{T-1}$:</p> <blockquote> <p><strong>Theorem 1:</strong></p> <p>If $p_t(x) = xh_t(x^2)$ for all $t \in {1, \ldots, T}$, then $(p_T \circ \cdots \circ p_1)(x) = q_T x$, where $q_T$ is defined by the iteration $r_0 = x^2$, $q_0 = 1$, and</p> \[z_t = h_t(r_{t-1})\] \[r_t = r_{t-1}z_t^2\] \[q_t = q_{t-1}z_t\] <p>for all $t \in {1, \ldots, T}$.</p> </blockquote> <p><em>Proof</em>. Define $x_0 = x$ and $x_t = p_t(x_{t-1})$ for $t \in {1, \ldots, T}$. We will show by induction that $r_t = x_t^2$ and $q_t = x_t / x_0$ for all $t$. The base case $t = 0$ holds by the definition $r_0 = x^2, q_0 = 1$. Now assume the hypothesis holds for $t-1$. By assumption,</p> \[x_t = p_t(x_{t-1}) = x_{t-1} h_t(x_{t-1}^2)\] <p>Observe that $h_t(x_{t-1}^2) = h_t(r_{t-1}) = z_t$, so $x_t = x_{t-1} z_t$. Squaring both sides,</p> \[x_t^2 = x_{t-1}^2 z_t^2 = r_{t-1} z_t^2 = r_t\] <p>If we instead divide both sides by $x_0$,</p> \[\frac{x_t}{x_0} = \frac{x_{t-1}}{x_0}z_t = q_{t-1} z_t = q_t\] <p>Thus, the hypothesis holds for $t$ as well. Finally, observe that $(p_T \circ \cdots \circ p_1)(x) = x_T = q_T x_0$.$\blacksquare$</p> <p>Note that, as an immediate corollary of the proof, $q_t = x_t / x_0 \to 1/x_0 = \left(x_0^2\right)^{-1/2} = r_0^{-1/2}$. In effect, this shows that $\mathbf Q_T \to (\mathbf X \mathbf X^\top)^{-1/2}$.</p> <p>To obtain our initial version of Gram Newton-Schulz, we simply lift the iteration from Theorem 1 to matrices. As in standard Newton-Schulz, each matrix operation preserves singular vectors. Therefore, each singular value of $\mathbf R_t$, $\mathbf Q_t$, and $\mathbf Z_t$ evolves independently of the others according to the scalar iteration described above. Note that while this algorithm is mathematically equivalent to standard Newton-Schulz, it is not yet practical due to numerical instability. The only difference between <a href="#alg-stable-gram-ns">our proposed method</a> and this naive version is the presence of what we call a “restart” at the beginning of iteration 3 of the loop. We will motivate this modification soon.</p> <p><a id="alg-naive-gram-ns"></a></p> <blockquote> <p><strong>Algorithm 2: Naive Gram Newton-Schulz</strong></p> <p>Input: $\mathbf X \in \mathbb{R}^{n \times m}$ with $n \leq m$, coefficients ${(a_t, b_t, c_t)}_{t=1}^5$</p> <ol> <li>$\mathbf X \gets \mathbf X \,/\, (\lVert\mathbf X\rVert_{\mathsf F} + \epsilon)$    // Normalize sing vals to $[0, 1]$. $\epsilon = 10^{-7}$</li> <li>$\mathbf R_0 = \mathbf X \mathbf X^\top$</li> <li>$\mathbf Q_0 = \mathbf I$</li> <li>For $t = 1, \ldots, 5$:</li> <li>   $\mathbf Z_t \gets a_t\mathbf I + b_t \mathbf R_{t-1} + c_t \mathbf R_{t-1}^2$    // Apply $h_t(\mathbf R_{t-1})$</li> <li>   $\mathbf Q_t \gets \mathbf Q_{t-1} \mathbf Z_t$</li> <li>   $\mathbf R_t \gets \mathbf Z_t \mathbf R_{t-1} \mathbf Z_t$</li> <li>Return $\mathbf Q_5 \mathbf X$</li> </ol> </blockquote> <p>Gram Newton-Schulz is closely akin to a method proposed in Appendix F of the Polar Express paper.<sup id="fnref:polar-express:1"><a href="#fn:polar-express" class="footnote" rel="footnote" role="doc-noteref">12</a></sup> Both form the Gram matrix and transform standard Newton-Schulz into an iteration on $n \times n$ matrices. Both aim to reduce the FLOP cost of Newton-Schulz. However, our work supersedes the proposal from Appendix F in several ways. 
First, the precise formulas of Gram Newton-Schulz are different, and, we believe, more stable. Second, we use symmetric matrix multiplication kernels; the opportunity to use these kernels more is an essential advantage of Gram Newton-Schulz not studied previously, and using symmetric matrix multiplication can have subtly different numerical properties in half-precision that require more careful stability strategies. Third, we undertake a thorough stability analysis and provide practical recommendations that allow Gram Newton-Schulz to be used in practice with minimal ad-hoc hyperparameter tuning.</p> <h2 id="runtime-of-naive-gram-newton-schulz">Runtime of Naive Gram Newton-Schulz</h2> <p>Let’s measure the FLOP count of this new algorithm to see how its runtime improves on standard Newton-Schulz. There are four matrix multiplications per iteration. If we use our symmetric GEMM kernel, these cost:</p> <ul> <li>$\mathbf R_{t-1}^2$: $n^3$</li> <li>$\mathbf Q_{t-1} \mathbf Z_t$: $n^3$</li> <li>$\mathbf Z_t \mathbf R_{t-1} \mathbf Z_t$: $n^3 + n^3$</li> </ul> <p>The initialization and output steps cost:</p> <ul> <li>$\mathbf X \mathbf X^\top$: $mn^2$</li> <li>$\mathbf Q_5 \mathbf X$: $2mn^2$ (not symmetric)</li> </ul> <p>Lastly, computing $\mathbf Q_1 = \mathbf Z_1$ is free since $\mathbf Q_0 = \mathbf I$, and we do not need to compute $\mathbf R_5$:</p> <ul> <li>Skipping $\mathbf Q_0 \mathbf Z_1,\,\mathbf Z_5 \mathbf R_4 \mathbf Z_5$: $-3n^3$</li> </ul> <p>Thus, the total FLOP count is $T\cdot4n^3 + 3mn^2 - 3n^3 = (4T + 3\alpha - 3)n^3$ for general $T$, or $(17 + 3\alpha)n^3$ across 19 GEMMs for $T=5$. Compare this to standard Newton-Schulz’s $T(3\alpha + 1)n^3$ FLOPs when using symmetric GEMMs. When $\alpha = 1$, they are equal. When $\alpha &gt; 1$, Gram Newton-Schulz is cheaper. For a typical Muon application ($T=5, \alpha = 4$), <strong>it saves 55% of the FLOPs</strong> used by standard Newton-Schulz with symmetric GEMMs, <strong>or 68%</strong> compared to a typical implementation without symmetric GEMMs.</p> <p>In practice, when $\alpha=1$, we fall back to <a href="#kernel-optimizations-for-standard-newton-schulz">standard Newton-Schulz with our symmetric GEMMs</a>, since it launches fewer GEMMs and will have a faster wall clock time.</p> <h1 id="instability-of-naive-gram-newton-schulz">Instability of Naive Gram Newton-Schulz</h1> <p>Let’s try training a transformer LLM with Muon using Naive Gram Newton-Schulz:</p> <p><img src="https://hackmd.io/_uploads/SJiRa6W9Wl.png" alt="llama_430_no_reset"/> <em>Figure 2: Naive Gram Newton-Schulz on Llama-430M.</em></p> <p>This is no good. Not only do we get loss spikes, but eventually, the output of Gram Newton-Schulz is full of Infs! While Gram Newton-Schulz is mathematically equivalent to standard Newton-Schulz in exact arithmetic, it behaves differently in finite precision, especially in half precision.</p> <p>We will now pause to explain the source of this instability in detail and motivate our solution. Readers not concerned with these technical details can <a href="#stabilized-gram-newton-schulz">skip ahead</a> to see the stabilized method. Code for running these stability experiments and generating the figures is available <a href="https://github.com/NoahAmsel/PolarExpress/blob/appF-stability/gram_newton_schulz_stability.ipynb">here</a>.</p> <h2 id="tracking-eigenvalues-of-intermediate-matrices">Tracking Eigenvalues of Intermediate Matrices</h2> <p>We can understand how matrices evolve and why they diverge by studying their eigenvalues and singular values.
Recall that the entries of any matrix are upper bounded by its largest singular value, so if we control the singular values, we will prevent blowups.</p> <p>If $\mathbf X = \mathbf U \mathbf \Sigma \mathbf V^\top$ is the SVD of the input matrix, then intermediate matrices of Algorithm 2 ($\mathbf R_t$, $\mathbf Q_t$, $\mathbf Z_t$) are square symmetric with eigenvectors $\mathbf U$. In exact arithmetic, $\mathbf U^\top \mathbf R_t \mathbf U$ is a diagonal matrix containing $\mathbf R_t$’s eigenvalues, each of which corresponds to a singular value of $\mathbf X$. We can therefore plot the eigenvalues of $\mathbf R_t$ and $\mathbf Q_t$ against the corresponding singular values of $\mathbf X$ to track how each evolves according to the polynomial update rules—or diverges from them.</p> <p>To see how things should look, let’s start by running Naive Gram Newton-Schulz in full <code class="language-plaintext highlighter-rouge">float64</code> precision for $10$ steps. We will use a synthetic input—a $128 \times 512$ matrix with an exponentially decaying spectrum. In order to make our plots more readable, with smooth monotonic curves, the experiments in this section use the coefficients $(a_t, b_t, c_t) = (\tfrac{15}8, -\tfrac{10}8, \tfrac38)$ at every iteration. The numerical behavior we observe will generalize to other coefficients; those used in practice (like You Jiacheng’s or Polar Express) will blow up at an even earlier iteration, matching the behavior we observe in training.<sup id="fnref:you"><a href="#fn:you" class="footnote" rel="footnote" role="doc-noteref">19</a></sup><sup id="fnref:polar-express:2"><a href="#fn:polar-express" class="footnote" rel="footnote" role="doc-noteref">12</a></sup> Even though our method does not need to compute the intermediate matrices $\mathbf X_1, \ldots, \mathbf X_{T-1}$, we do so here for demonstration using the formula $\mathbf X_t = \mathbf Q_t \mathbf X_0$, where we label the input $\mathbf X_0$ for clarity.</p> <p><img src="https://hackmd.io/_uploads/Bkm4pJ_iZg.gif" alt="f64_diagnostics"/> <em>Figure 3: Evolution of eigenvalues of $\mathbf R_t$, $\mathbf Q_t$, and $\mathbf X_t$ in Float64 in Naive Gram Newton-Schulz with coefficients $(\tfrac{15}8, -\tfrac{10}8, \tfrac38)$.</em></p> <p>Initially, we have $r_0 = x_0^2$ and $q_0 = 1$. As the algorithm proceeds, we know that $x_t \to 1$, so we expect $r_t \to 1$ and $q_t = x_t / x_0 \to 1/x_0 = r_0^{-1/2}$ as per Theorem 1. Note that if $x_0$ is close to 1, the method converges quickly, while if $x_0$ is close to zero, it converges slowly. After 10 iterations, the spectrum of $\mathbf X_t$ is visually indistinguishable from $1$, as expected.</p> <p>Now let’s repeat the experiment using <code class="language-plaintext highlighter-rouge">bfloat16</code> instead of <code class="language-plaintext highlighter-rouge">float64</code> arithmetic:</p> <p><img src="https://hackmd.io/_uploads/B1PA2ydo-e.gif" alt="f16_diagnostics"/> <em>Figure 4: Evolution of eigenvalues of $\mathbf R_t$, $\mathbf Q_t$, and $\mathbf X_t$ in BFloat16 in Naive Gram Newton-Schulz with coefficients $(\tfrac{15}8, -\tfrac{10}8, \tfrac38)$.</em></p> <p>The first few iterations proceed as before. However, by step 7, we see unexpected behavior in the spectrum of $\mathbf X_t$. The singular values that began near $0$ suddenly jump up above 1, instead of converging to 1 from below. By step 8, the algorithm is returning complete junk.
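</p>

<p>The setup for these experiments is easy to reproduce. Below is a sketch of the synthetic input and the most basic diagnostic, the smallest eigenvalue of the Gram matrix in each precision; the seed and exact decay rate of the spectrum are illustrative assumptions rather than the values used for the figures.</p>

<pre><code class="language-python">import torch

torch.manual_seed(0)
n, m = 128, 512

# Synthetic input with an exponentially decaying spectrum.
U, _ = torch.linalg.qr(torch.randn(n, n, dtype=torch.float64))
V, _ = torch.linalg.qr(torch.randn(m, n, dtype=torch.float64))
sigma = torch.logspace(0, -8, n, dtype=torch.float64)   # decays from 1 to 1e-8
X0 = U @ torch.diag(sigma) @ V.T
X0 = X0 / X0.norm()                                     # normalize as in Algorithm 2

# Smallest eigenvalue of the Gram matrix in high vs. half precision.
R0_f64 = X0 @ X0.T
R0_bf16 = (X0.bfloat16() @ X0.bfloat16().T).double()
print("min eig, float64 :", torch.linalg.eigvalsh(R0_f64).min().item())
print("min eig, bfloat16:", torch.linalg.eigvalsh(R0_bf16).min().item())
# In float64 the smallest eigenvalue is ~0; in bfloat16 it is typically slightly negative.
</code></pre>

<p>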
What happened?</p> <p>We identify two key causes of divergence:</p> <ol> <li>Spurious negative eigenvalues of the Gram matrix $\mathbf X \mathbf X^\top$</li> <li>Eigenvector drift</li> </ol> <h2 id="spurious-negative-eigenvalues">Spurious Negative Eigenvalues</h2> <p>The main cause of divergence is the presence of negative eigenvalues in the Gram matrix due to half-precision arithmetic. These negative eigenvalues blow up after a few iterations of Gram Newton-Schulz.</p> <p>If you look closely, you can see that the trouble begins in $\mathbf R_t$. By construction, $r_t = x_t^2 \geq 0$, so in exact arithmetic, $\mathbf R_t$ should be a positive semidefinite matrix. However, when using <code class="language-plaintext highlighter-rouge">bfloat16</code>, our plots show that $\mathbf R_t$ has negative eigenvalues! Because $\mathbf X_0$ is numerically low rank, $\mathbf R_0 = \mathbf X_0 \mathbf X_0^\top$ has eigenvalues that are <em>numerically</em> equal to zero, and in <code class="language-plaintext highlighter-rouge">bfloat16</code>, a number like $-10^{-5}$ is numerically equal to zero. Let’s transform the y-axis to emphasize values close to zero and replot this:</p> <p><img src="https://hackmd.io/_uploads/ry9ZpJdjWg.gif" alt="f16_diagnostics_zoomed"/> <em>Figure 5: Evolution of eigenvalues of $\mathbf R_t$, $\mathbf Q_t$, and $\mathbf X_t$ in BFloat16, with y-axis centered around $0$.</em></p> <p>Now we see that from the very beginning, $\mathbf R_0$ has tiny negative eigenvalues introduced in the first computation $\mathbf X_0 \mathbf X_0^\top$. Later computations can introduce additional negative eigenvalues to $\mathbf R_t$ too. These eigenvalues represent nothing about the original problem; they are just an artifact of floating point arithmetic. Therefore, we call them “spurious eigenvalues”.</p> <p>These spurious negative eigenvalues start small, but the plot shows that their magnitude grows quickly. Let’s understand mathematically why this happens. Recall the update rule: \(r_t = r_{t-1} z_t^2 = r_{t-1} h_t(r_{t-1})^2\). If we now substitute $h_t(x) = \tfrac{15}8 - \tfrac{10}8 x + \tfrac38 x^2$ and plot this update rule, we can see the problem:</p> <p><img src="https://hackmd.io/_uploads/HkOvAkujWe.svg" alt="r_update_map"/></p> <p><em>Figure 6: Negative values of $r_t$ diverge towards negative infinity.</em></p> <p>As the plot shows, $r_t &lt; \left(\tfrac{15}{8}\right)^2 r_{t-1}$. Thus, if $r_0 &lt; 0$, the magnitude of the spurious eigenvalues grows exponentially! This sets off a chain reaction. As $r_t \to -\infty$, we get $z_t \to \infty$. This causes $q_t \to \infty$ and therefore also $x_t \to \infty$. This problem cannot be fixed by choosing different polynomials. Conceptually, in the main loop, we are attempting to compute the inverse square root of a negative number. It cannot help but diverge.</p> <p>To show that the spurious negative eigenvalues of $\mathbf R_0$ are enough to cause this catastrophic failure, let’s rerun the method with every operation in <code class="language-plaintext highlighter-rouge">float64</code> precision, except that we will convert $\mathbf R_0$ from <code class="language-plaintext highlighter-rouge">float64</code> to <code class="language-plaintext highlighter-rouge">bfloat16</code> and then back to <code class="language-plaintext highlighter-rouge">float64</code> to induce a little floating point error.
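</p>

<p>The chain reaction is easy to see in the scalar iteration itself, using the coefficients $(\tfrac{15}8, -\tfrac{10}8, \tfrac38)$ from above. A few lines suffice (the starting value $-10^{-4}$ is just an illustrative magnitude for a spurious eigenvalue):</p>

<pre><code class="language-python">a, b, c = 15 / 8, -10 / 8, 3 / 8

def h(x):
    # h_t(x) = a + b*x + c*x^2, so that p_t(x) = x * h_t(x^2)
    return a + b * x + c * x * x

r, q = -1e-4, 1.0   # a spurious negative eigenvalue of R_0, and q_0 = 1
for t in range(1, 9):
    z = h(r)
    q = q * z
    r = r * z * z
    print(f"t={t}  r={r:.3e}  q={q:.3e}")
# r grows more negative by roughly a factor of (15/8)^2 per step at first,
# and q (hence the output X_t = Q_t X_0) blows up along with it.
</code></pre>

<p>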
As you can see, even this causes a blowup.</p> <p><img src="https://hackmd.io/_uploads/HyEH0k_iZx.gif" alt="posthoc_f16_diagnostics"/> <em>Figure 7: Evolution of eigenvalues of $\mathbf R_t$, $\mathbf Q_t$, and $\mathbf X_t$ when all operations use Float64 except $\mathbf R_0 = \mathbf X \mathbf X^\top$.</em></p> <p>Recall that the average magnitude of a matrix’s entries (root mean squared) is proportional to its Frobenius norm, which is at least as large as the largest singular value. Therefore, as $\mathbf Q_t$’s largest singular value blows up, its entries do too.</p> <h2 id="eigenvector-drift">Eigenvector Drift</h2> <p>Spurious negative eigenvalues are not the only source of instability. If we take as input a matrix that excludes small singular values (i.e., all $\geq 0.017$), then we do not observe any negative eigenvalues in $\mathbf R_t$, but we still see a moderate blowup in $\mathbf X_t$. The culprit seems to be eigenvector drift.</p> <p>In exact arithmetic, the eigenvectors of all intermediate matrices match $\mathbf U$, the left singular vectors of $\mathbf X_0$, but in finite precision they do not. This effect can be measured by observing how far $\mathbf U^\top \mathbf R_t \mathbf U$, $\mathbf U^\top \mathbf Q_t \mathbf U$, and $\mathbf U^\top \mathbf X_t \mathbf V$ are from being diagonal matrices. The plot below shows that after several iterations, the eigenvectors of $\mathbf Q_t$ and $\mathbf X_t$ have drifted significantly. At the same time, we see the eigen<em>values</em> of $\mathbf Q_t$ (and by extension, those of $\mathbf X_t$) diverge from where they should be in exact arithmetic. The growing eigenvalues of $\mathbf Q_t$ seem to spill into one another. The strength of this effect is less consistent than that of negative eigenvalues, but it is still harmful.</p> <p><img src="https://hackmd.io/_uploads/HyuQYBm5-x.svg" alt="easy_spectrum_diagnostics"/> <em>Figure 8: As the eigenvectors drift (left), the spectral norms of $\mathbf R_t$, $\mathbf Q_t$, and $\mathbf X_t$ diverge.</em></p> <h2 id="stabilizing-gram-newton-schulz-by-restarting">Stabilizing Gram Newton-Schulz by Restarting</h2> <p>If we run Gram Newton-Schulz for more than a few iterations, the spurious negative eigenvalues grow unmanageably large and $\mathbf Q_t$ blows up. Our solution is simple: run Gram Newton-Schulz for only a few iterations. Rather than using Gram Newton-Schulz to compute $\mathbf X_T$ directly, we use it to compute, say, $\mathbf X_5$ in a stable manner for coefficients $(\tfrac{15}8, -\tfrac{10}8, \tfrac38)$. While $\mathbf X_5$ is not a good approximation to $\lim_{T \to \infty} \mathbf X_T = \mathrm{polar}(\mathbf X_0)$, we are closer than when we started. Now we can apply Gram Newton-Schulz a second time on the input $\mathbf X_5$ to compute $\mathbf X_{10}$ stably. We can repeat this over and over to reach whatever $T$ we like. This restarting technique sacrifices some of the performance gains of Gram Newton-Schulz, but it still offers a significant speedup over standard Newton-Schulz.</p> <p>Below we plot the results of this method on the same test matrix used above. As before, we compute $\mathbf X_t$ for all $t$ for diagnostic purposes, though the algorithm computes only $\mathbf X_5, \mathbf X_{10}, \ldots, \mathbf X_{30}$. Looking closely, you can see that $\mathbf R_t$ develops some negative eigenvalues, but unlike before, the growth of these eigenvalues is controlled.
Each time we restart, we re-initialize $\mathbf R_t = \mathbf X_t \mathbf X_t^\top$, eliminating any negative eigenvalues of large magnitude. As you can see, at iterations $5, 10, 15, 20, 25$, and $30$, $\mathbf Q_t$ resets to the identity. Therefore, the eigenvalues of $\mathbf Q_t$ never grow beyond $\approx 12$, despite the negative eigenvalues in $\mathbf R_t$. Since the eigenvalues of $\mathbf Q_t$ remain controlled, those of $\mathbf X_t = \mathbf Q_t \mathbf X_{t-5}$ stay strictly smaller than $1$.</p> <p><img src="https://hackmd.io/_uploads/BJ5oC1_oZl.gif" alt="restart5_diagnostics"/> <em>Figure 9: Restarting prevents the divergence of $\mathbf R_t$.</em></p> <p>Restarting also helps control eigenvector drift. We repeat the experiment from above on the same matrix (with all singular values $&gt; 0.017$), but now with a restart after step 5. We observe that the diagonalization error remains $\leq 0.05$ for all matrices, and the maximum eigenvalues now align closely with their values in exact arithmetic. Note that we always measure eigenvector drift relative to the original input $\mathbf X_0$, not the restarted $\mathbf X_5$.</p> <p><img src="https://hackmd.io/_uploads/rkmItSQ5Wx.svg" alt="easy_spectrum_restart2_diagnostics"/> <em>Figure 10: Restarting prevents eigenvector drift.</em></p> <h3 id="when-to-restart-polar-express-coefficients-for-muon">When to Restart: Polar Express Coefficients for Muon</h3> <p>At what iteration should we restart? To avoid numerical trouble, we need to control the magnitude of $\mathbf Q_t$, even when $\mathbf R_0$ has spurious negative eigenvalues. (Because each $q_t \geq 1$, this is equivalent to controlling the condition number of $\mathbf Q_t$.) So long as each eigenvalue of $\mathbf Q_t$ remains smaller than the inverse of the corresponding singular value of $\mathbf X$, then $\mathbf X_t = \mathbf Q_t \mathbf X$ will have singular values $\leq 1$.</p> <p>The growth of $\mathbf Q_t$ in turn depends on the size of the spurious negative eigenvalues and the specific sequence of polynomials we use. Furthermore, since the polynomial $p_t$ changes at each iteration, it may not be ideal to restart at regular intervals. Instead, we can choose when to restart adaptively, depending on the specific sequence of polynomials we have applied since the previous restart.</p> <p>For the application to Muon, let’s now switch over to using five iterations of the Polar Express polynomials, which are defined as follows:</p> <table> <thead> <tr> <th>$t$</th> <th>$a$</th> <th>$b$</th> <th>$c$</th> </tr> </thead> <tbody> <tr> <td>1</td> <td>8.123737</td> <td>-22.232240</td> <td>16.373715</td> </tr> <tr> <td>2</td> <td>4.026529</td> <td>-2.776323</td> <td>0.514551</td> </tr> <tr> <td>3</td> <td>3.870284</td> <td>-2.739120</td> <td>0.520999</td> </tr> <tr> <td>4</td> <td>3.253351</td> <td>-2.343223</td> <td>0.481420</td> </tr> <tr> <td>5</td> <td>2.300652</td> <td>-1.668904</td> <td>0.418807</td> </tr> </tbody> </table> <p>In the example above, we observe that the most negative spurious eigenvalue of $\mathbf R_0$ is about $-4 \cdot 10^{-4}$. Using the scalar analogue of Gram Newton-Schulz, let’s simulate how the eigenvalues of $\mathbf Q_t$ evolve in full precision when $\mathbf R_0$ has eigenvalues in the range $[-4\cdot 10^{-4}, 1]$. With no restart, they blow up:</p> <p><img src="https://hackmd.io/_uploads/BJjkC1dsZl.svg" alt="polar_no_restart_growth (1)"/></p> <p><em>Figure 11: Min/max eigenvalue of $\mathbf R_t$ and $\mathbf Q_t$ without restarts. 
$\mathbf R_0$ starts with a negative eigenvalue as low as $-4 \times 10^{-4}$.</em></p> <p>Now let’s repeat the experiment with a restarted version of the algorithm. To obtain a good balance of stability and speed, let’s limit ourselves to a single restart. When should this restart take place? We’ll try all possibilities. As above, we begin with eigenvalues in the range $[-4 \cdot 10^{-4}, 1]$. Every time we restart and form $\mathbf R = \mathbf X\mathbf X^\top$, we subtract $4 \times 10^{-4}\mathbf I$ to simulate a potentially dangerous shift in the eigenvalues due to floating point error. As you can see, restarting after the second iteration ensures that the eigenvalues of $\mathbf R_t$ stay well above $-0.4$ and that the condition number of $\mathbf Q_t$ stays below $\approx 100$ for all iterations, much better than the other options.</p> <p><img src="https://hackmd.io/_uploads/Hk4qTyuiZx.svg" alt="polar1restart_results (1)"/></p> <p><em>Figure 12: Minimum eigenvalue of $\mathbf R_t$ and condition number of $\mathbf Q_t$ if restart is placed after iteration $1$, $2$, $3$, or $4$. $\mathbf R_0$ starts with a negative eigenvalue as low as $-4 \times 10^{-4}$. Restarting after iteration $2$ provides the best bound on $\mathbf Q_t$.</em></p> <p>Note that restarting works precisely because we reset the minimum negative eigenvalue of $\mathbf R_t$, which in turn tightens the bound on $\mathbf Q_t$’s eigenvalues. In <a href="https://github.com/Dao-AILab/gram-newton-schulz">our repo</a>, we provide a utility that performs this analysis. For any given Newton-Schulz coefficients and any number of restarts, it identifies the best iterations at which to restart.</p> <p>Now let’s run the full method with a restart after the second iteration on our test matrix. Now it converges! All singular values of $\mathbf X_t$ approach 1.</p> <p><img src="https://hackmd.io/_uploads/HyJU6JOiWl.gif" alt="final_diagnostics"/> <em>Figure 13: Restarting after $2$ iterations creates a stable polar decomposition of our test matrix with Polar Express coefficients.</em></p> <h2 id="further-precautions">Further Precautions</h2> <p>While restarting greatly improves stability, it is not absolutely foolproof. The usual numerical snags for Newton-Schulz still apply.</p> <p>For example, most choices of Newton-Schulz polynomials are designed to converge only when $\lVert\mathbf X_0\rVert \leq 1$; any singular values larger than $1$ may diverge rapidly.</p> <p><img src="https://hackmd.io/_uploads/rkkBqyBcZl.png" alt="X_final_unbounded (4)"/> <em>Figure 14: Theoretical behavior of both standard and Gram Newton-Schulz on $\sigma_{X_0}$ slightly above $1$ using Polar Express coefficients.</em></p> <p>Even with a properly normalized input, perturbed singular values of $\mathbf X_0$ slightly greater than $1$ can develop due to numerical error. This problem affects standard Newton-Schulz as well, so the Polar Express polynomials are typically adjusted according to the formula $\tilde p_t(x) = p_t(x / 1.02)$. This ensures convergence even for singular values as large as $1.02$. When using Gram Newton-Schulz, roundoff errors like this can worsen due to computations like $\mathbf X\mathbf X^\top$, which do not have built-in safety factors; however, we have never seen this happen when using our recommended setup (<code class="language-plaintext highlighter-rouge">float16</code> arithmetic with restarting after $2$ iterations).
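</p>

<p>The safety factor itself is cheap to apply, since $\tilde p_t(x) = p_t(x / s)$ is just a rescaling of the coefficients: $p_t(x/s) = (a_t/s)\,x + (b_t/s^3)\,x^3 + (c_t/s^5)\,x^5$. A sketch, using the Polar Express coefficients from the table above:</p>

<pre><code class="language-python">POLAR_EXPRESS = [
    (8.123737, -22.232240, 16.373715),
    (4.026529,  -2.776323,  0.514551),
    (3.870284,  -2.739120,  0.520999),
    (3.253351,  -2.343223,  0.481420),
    (2.300652,  -1.668904,  0.418807),
]

def apply_safety_factor(coeffs, s: float = 1.02):
    """Fold p_t(x) -> p_t(x / s) into the (a, b, c) coefficients."""
    return [(a / s, b / s**3, c / s**5) for a, b, c in coeffs]

coeffs = apply_safety_factor(POLAR_EXPRESS, s=1.05)   # the more conservative choice
</code></pre>

<p>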
It is generally wise to be extra conservative in the choice of safety factor, for instance, by replacing $1.02$ with $1.05$.</p> <h3 id="float16-vs-bfloat16-in-newton-schulz">Float16 vs BFloat16 in Newton-Schulz</h3> <p>In addition, we argue for using <code class="language-plaintext highlighter-rouge">float16</code> instead of <code class="language-plaintext highlighter-rouge">bfloat16</code> to implement Newton-Schulz. Compared to <code class="language-plaintext highlighter-rouge">bfloat16</code>, <code class="language-plaintext highlighter-rouge">float16</code> can only represent values from a narrower range, but it has greater precision within that range. For our purposes, the range of <code class="language-plaintext highlighter-rouge">float16</code> (roughly $6.1\cdot 10^{-5}$ to $6.5 \cdot 10^4$) suffices because the magnitudes of our matrices are controlled to lie near 1. And in some cases, we can benefit from using <code class="language-plaintext highlighter-rouge">float16</code> to reduce numerical errors.</p> <p>On certain test matrices, we see more accurate $\operatorname{polar}(\mathbf X)$ approximations with <code class="language-plaintext highlighter-rouge">float16</code>, but in practice, we have not found a case where the pretraining loss is meaningfully different between <code class="language-plaintext highlighter-rouge">float16</code> and <code class="language-plaintext highlighter-rouge">bfloat16</code>. Still, we default to <code class="language-plaintext highlighter-rouge">float16</code>.</p> <h3 id="computing-matrix-quadratics">Computing Matrix Quadratics</h3> <p>A key step in Gram Newton-Schulz is computing the matrix quadratic $\mathbf Z_t \gets a_t\mathbf I + b_t \mathbf R_{t-1} + c_t \mathbf R_{t-1}^2$. PyTorch implementations of Newton-Schulz typically do not assemble such polynomials explicitly; to compute $\mathbf X(\mathbf a_t \mathbf I + b_t \mathbf A + c_t \mathbf A^2)$, they partly distribute $\mathbf X$ and use two calls to <code class="language-plaintext highlighter-rouge">torch.baddbmm</code>, which dispatches to cuBLAS GEMM, as follows</p> <blockquote> <ol> <li>   $\mathbf B \gets b_t \mathbf A + c_t \mathbf A^2$</li> <li>   $\mathbf X \gets a_t \mathbf X + \mathbf B \mathbf X$</li> </ol> </blockquote> <p>Our symmetric GEMM kernel is capable of fusing these matrix quadratics into a single step. In particular, it can fuse the addition of $\gamma \mathbf I$ by adding $\gamma$ to all diagonal entries of the output when they are at the register level. This optimization completely obviates any I/O operations needed for the $\gamma I$ addition, typically outspeeding <code class="language-plaintext highlighter-rouge">gemm_symmetric(A, B, C, alpha, beta) + gamma * I</code>, which would require loading $\mathbf I$ from general memory to shared memory to registers. Once $\mathbf Z_t$ is assembled, Gram Newton-Schulz can use it in three subsequent multiplications.</p> <p>However, our tests show that adding $\gamma \mathbf I$ explicitly can be less stable than handling it implicitly in some corner cases. If we stress-test our method by ignoring some of our own advice—restarting after three iterations instead of two and using a Polar Express safety factor of $1.02$ instead of $1.05$, and computing the quadratic with $a_t \mathbf I$ explicitly—then we observe instability. This instability disappears if we use non-symmetric GEMMs (either from <code class="language-plaintext highlighter-rouge">torch</code> or Quack) instead of our symmetric kernels. 
We conclude that our fused quadratic kernel can hurt stability in this setting. Since we reproduce this issue by forcing symmetry after calling standard <code class="language-plaintext highlighter-rouge">torch</code> GEMMs, we know this is not a kernel bug, but a numerical property.</p> <p>We believe this effect can be explained as follows. While the fused kernel computes $a_t\mathbf I + b_t \mathbf R_{t-1} + c_t \mathbf R_{t-1}^2$ in <code class="language-plaintext highlighter-rouge">float32</code> arithmetic under the hood, the result $\mathbf Z_t$ is rounded back down to <code class="language-plaintext highlighter-rouge">float16</code> at the end of the GEMM. Future computations like $\mathbf Q_t \mathbf Z_t$ suffer from this loss of precision in $a_t$. In contrast, if the $a_t \mathbf I$ term is handled implicitly, all arithmetic involving $a_t$ takes place in <code class="language-plaintext highlighter-rouge">float32</code>. Therefore, it is more accurate to compute $a_t \mathbf Q_t + \mathbf Q_t\left(b_t \mathbf R_{t-1} + c_t \mathbf R_{t-1}^2\right)$ than $\mathbf Q_t\left(a_t\mathbf I + b_t \mathbf R_{t-1} + c_t \mathbf R_{t-1}^2\right)$.</p> <p>We reiterate that in all our experiments, this instability can be avoided entirely by restarting correctly or by using a higher safety factor of $1.05$. Out of an abundance of caution, we rearrange the arithmetic of <a href="#alg-naive-gram-ns">Naive Gram Newton-Schulz</a> to avoid adding $\mathbf I$ explicitly. That is, we change</p> <blockquote> <ol> <li>   $\mathbf Z_t \gets a_t\mathbf I + b_t \mathbf R_{t-1} + c_t \mathbf R_{t-1}^2$    // Apply $h_t(\mathbf R_{t-1})$</li> <li>   $\mathbf Q_t \gets \mathbf Q_{t-1} \mathbf Z_t$</li> <li>   \((\mathbf{RZ})_t \gets \mathbf R_{t-1} \mathbf Z_t\)</li> <li>   \(\mathbf R_t \gets \mathbf Z_t (\mathbf{RZ})_t\)</li> </ol> </blockquote> <p>to</p> <blockquote> <ol> <li>   $\mathbf Z_t \gets b_t \mathbf R_{t-1} + c_t \mathbf R_{t-1}^2$</li> <li>   $\mathbf Q_t \gets \mathbf Q_{t-1} \mathbf Z_t + a_t\mathbf Q_{t-1}$</li> <li>   \((\mathbf{RZ})_t \gets \mathbf R_{t-1} \mathbf Z_t + a_t\mathbf R_{t-1}\)</li> <li>   \(\mathbf R_t \gets \mathbf Z_t (\mathbf{RZ})_t + a_t(\mathbf{RZ})_t\)</li> </ol> </blockquote> <p>This change fixes all collected examples in which symmetric GEMMs were less stable than non-symmetric GEMMs.</p> <h2 id="takeaways-on-stability">Takeaways on Stability</h2> <p>While Gram Newton-Schulz is fundamentally more unstable than standard Newton-Schulz, it can be coaxed into behaving equally stably with the proper care. The understanding gleaned from these experiments gives us the confidence to use Gram Newton-Schulz in practice. However, users should be willing to monitor the method, and if they find instability, to adjust the hyperparameters (e.g., $1.02 \to 1.05$ above). For example, a second restart may be required if using a particularly sensitive set of coefficients or if running more than five iterations with Polar Express polynomials.</p> <p>In the application of Muon for pretraining, we do not need very high polar decomposition accuracy, and our experiments below show that Muon with Gram Newton-Schulz yields effectively identical results to Muon with standard Newton-Schulz in terms of training quality. However, when high accuracy is desired, the usual warnings about forming the Gram matrix apply. 
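</p>

<p>To make that warning concrete: forming the Gram matrix squares the condition number, since $\mathbf X \mathbf X^\top$ has extreme eigenvalues $\sigma_{\max}^2$ and $\sigma_{\min}^2$. A quick check on a hypothetical random matrix:</p>

<pre><code class="language-python">import torch

torch.manual_seed(0)
X = torch.randn(128, 512, dtype=torch.float64)

s = torch.linalg.svdvals(X)
g = torch.linalg.svdvals(X @ X.T)
print("cond(X)     =", (s[0] / s[-1]).item())
print("cond(X X^T) =", (g[0] / g[-1]).item())   # equal to cond(X)**2 up to roundoff
</code></pre>

<p>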
Since forming $\mathbf X \mathbf X^\top$ immediately squares the condition number, Gram Newton-Schulz may not be appropriate in these cases.</p> <h1 id="stabilized-gram-newton-schulz">Stabilized Gram Newton-Schulz</h1> <p>We now present our complete algorithm, which enjoys the speed of naive Gram Newton-Schulz while remaining numerically stable. We use five iterations of Newton-Schulz with degree-5 polynomials (such as Polar Express). We use <code class="language-plaintext highlighter-rouge">float16</code> arithmetic, and we “restart” after the first two iterations by setting $\mathbf X \gets \mathbf Q_2 \mathbf X$ and reinitializing $\mathbf R_2$ and $\mathbf Q_2$. As in standard Newton-Schulz, we write the logic of our routine assuming that $\mathbf X$ has more columns than rows. If this is not the case, we simply run on $\mathbf X^\top$ and output the transpose of the result.</p> <p><a id="alg-stable-gram-ns"></a></p> <blockquote> <p><strong>Algorithm 3: Stabilized Gram Newton-Schulz</strong></p> <p>Input: $\mathbf X \in \mathbb{R}^{n \times m}$ with $n \leq m$, coefficients ${(a_t, b_t, c_t)}_{t=1}^5$</p> <ol> <li>$\mathbf X \gets \mathbf X \,/\, (\lVert\mathbf X\rVert_{\mathsf F} + \epsilon)$    // Normalize sing vals to $[0, 1]$. $\epsilon = 10^{-7}$</li> <li>$\mathbf X \gets \texttt{float16}(\mathbf X)$     // Cast to half precision for speed</li> <li>If $m &lt; n$:  $\mathbf X \gets \mathbf X^\top$   // Trick to make $\mathbf X \mathbf X^\top$ cheaper</li> <li>$\mathbf R_{0} \gets \mathbf X \mathbf X^\top$</li> <li>$\mathbf Q_{0} \gets \mathbf I$</li> <li>For $t = 1, \ldots, 5$:</li> <li>   If $t = 3$:        // Restart to stabilize</li> <li>     $\mathbf X \gets \mathbf Q_2 \mathbf X$</li> <li>     $\mathbf R_2 \gets \mathbf X \mathbf X^\top$</li> <li>     $\mathbf Q_2 \gets \mathbf I$</li> <li>   $\mathbf Z_t \gets b_t \mathbf R_{t-1} + c_t \mathbf R_{t-1}^2$</li> <li>   $\mathbf Q_t \gets \mathbf Q_{t-1} \mathbf Z_t + a_t\mathbf Q_{t-1}$</li> <li>   \((\mathbf{RZ})_t \gets \mathbf R_{t-1} \mathbf Z_t + a_t\mathbf R_{t-1}\)</li> <li>   \(\mathbf R_t \gets \mathbf Z_t (\mathbf{RZ})_t + a_t(\mathbf{RZ})_t\)</li> <li>$\mathbf X \gets \mathbf Q_5 \mathbf X$</li> <li>If $m &lt; n$:  $\mathbf X \gets \mathbf X^\top$  // Undo trick</li> <li>Return $\mathbf X$</li> </ol> </blockquote> <h2 id="runtime-of-stabilized-gram-newton-schulz">Runtime of Stabilized Gram Newton-Schulz</h2> <p>Above, we showed that Naive Gram Newton-Schulz uses $(4T + 3\alpha - 3)n^3$ FLOPs. How does restarting change this? It requires two additional matrix multiplications:</p> <ul> <li>$\mathbf X \gets \mathbf Q_2 \mathbf X$: $2mn^2$</li> <li>$\mathbf R_2 \gets \mathbf X \mathbf X^\top$: $mn^2$</li> </ul> <p>Since the initial value of $\mathbf R_2$ is discarded and $\mathbf Q_2 = \mathbf I$, it also allows us to skip three matrix multiplications:</p> <ul> <li>$\mathbf R_2 \gets \mathbf Z_2 \mathbf R_1\mathbf Z_2$: $-n^3 - n^3$</li> <li>$\mathbf Q_3 \gets \mathbf Q_2 \mathbf Z_3$: $-n^3$</li> </ul> <p>Therefore, Stabilized Gram Newton-Schulz with one restart uses $(4T + 6\alpha - 6)n^3$ FLOPs. As before, this matches standard Newton-Schulz for $\alpha = 1$ and improves on it for $\alpha &gt; 1$. For $T=5, \alpha = 4$, our algorithm reduces the number of FLOPs by 42% compared to standard Newton-Schulz with symmetric GEMMs, or by 58% compared to typical implementations lacking symmetric GEMMs.</p> <p>Observe that if we hypothetically used more restarts, each would increase the FLOPs by $3mn^2 - 3n^3$.
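</p>

<p>For readers who prefer code to pseudocode, here is a minimal PyTorch sketch of Algorithm 3 using plain GEMMs. It is only a sketch: it does not use the symmetric kernels, and it does not skip the few GEMMs that the FLOP count above shows are unnecessary. The released drop-in replacement linked above is the version intended for actual use.</p>

<pre><code class="language-python">import torch

def gram_newton_schulz(X: torch.Tensor, coeffs, restart_after: int = 2,
                       eps: float = 1e-7) -> torch.Tensor:
    """Minimal sketch of Algorithm 3 (Stabilized Gram Newton-Schulz) with plain GEMMs."""
    transposed = X.shape[0] > X.shape[1]
    X = X / (X.norm() + eps)          # normalize singular values into [0, 1]
    X = X.half()                      # float16, as recommended (run on GPU; CPU fp16 GEMMs may be unsupported)
    if transposed:
        X = X.mT                      # make the Gram matrix the small n x n one
    R = X @ X.mT
    Q = torch.eye(R.shape[0], dtype=X.dtype, device=X.device)
    for t, (a, b, c) in enumerate(coeffs, start=1):
        if t == restart_after + 1:    # restart: fold Q into X, rebuild R, reset Q
            X = Q @ X
            R = X @ X.mT
            Q = torch.eye(R.shape[0], dtype=X.dtype, device=X.device)
        Z = b * R + c * (R @ R)       # h_t(R) with the a_t * I term kept implicit
        Q = Q @ Z + a * Q
        RZ = R @ Z + a * R
        R = Z @ RZ + a * RZ
    X = Q @ X
    if transposed:
        X = X.mT
    return X
</code></pre>

<p>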
With $T-1$ restarts, Gram Newton-Schulz would be exactly the same algorithm as standard Newton-Schulz.</p> <p>In this sense, adding restarts can be viewed as trading wall clock time for greater guaranteed stability, with the extrema being Naive Gram Newton-Schulz and standard Newton-Schulz.</p> <h1 id="symmetric-gemm-kernels-in-cutedsl">Symmetric GEMM Kernels in CuTeDSL</h1> <p>To take advantage of the greater share of symmetric matrix multiplications enabled by Gram Newton-Schulz, we implement kernels for the operations $\mathbf A \mathbf B$ and $\alpha \mathbf A \mathbf B + \beta \mathbf C$ that assume $\mathbf A \mathbf B$ and $\mathbf C$ are symmetric. Symmetric kernels also accelerate standard Newton-Schulz; this idea has been around for a while for the construction of the Gram $\mathbf{XX^\top}$, but to our knowledge, hasn’t been explored for fused symmetric matrix multiplication with addition.<sup id="fnref:flashmuon:1"><a href="#fn:flashmuon" class="footnote" rel="footnote" role="doc-noteref">9</a></sup><sup id="fnref:laker"><a href="#fn:laker" class="footnote" rel="footnote" role="doc-noteref">20</a></sup> We target the Hopper and Blackwell GPU architectures and <a href="https://github.com/Dao-AILab/quack/blob/main/quack/gemm_symmetric.py">open source</a> our implementation in the <a href="https://github.com/Dao-AILab/quack">Quack</a> library of CuTeDSL kernels developed by our lab.</p> <p><img src="https://hackmd.io/_uploads/BJq8sVPibl.png" alt="gemm_benchmarks (1)"/> <em>Figure 15: SOTA Symmetric GEMM Kernels benchmarked on Hopper and Blackwell against cuBLAS.</em></p> <h2 id="layout-engineering-and-work-scheduling">Layout Engineering and Work Scheduling</h2> <p>GEMM implementations of $\mathbf A \mathbf B$ and $\alpha \mathbf A \mathbf B + \beta \mathbf C$ can be broken down into the following components:</p> <ol> <li>How do we schedule GEMM output tiles as work among groups of workers?</li> <li>Once assigned a tile, how does a group of workers compute the tile?</li> </ol> <p>In most GEMM and fused GEMM kernels, tiles are computed in the same way, with the following components:</p> <ol> <li>The prologue, in which rows of $A$ and columns of $B$ needed for the current tile are loaded in from general memory (high-bandwidth memory) to shared memory (SRAM)</li> <li>Matrix-Multiply Accumulate (MMA), in which those rows and columns are multiplied and written to the register file in Hopper or tensor memory in Blackwell</li> <li>The epilogue, in which additional tensors needed for the fusion are loaded in, the fused arithmetic occurs, and the final values are written to the output tensor(s), from the register file to shared to general memory. An example is loading in $C$, $\alpha$, $\beta$, and then scaling $\mathbf A \mathbf B$ with $\alpha$ and adding $\beta \mathbf C$.</li> </ol> <p>Our symmetric GEMM kernel and the standard GEMM kernel only differ in how they schedule and partition output tiles as work and how they implement their epilogues.</p> <h3 id="triangular-scheduler">Triangular Scheduler</h3> <p>In the standard GEMM, the entire output matrix is divided into work tiles that are load balanced and evenly partitioned amongst clusters of thread blocks, where thread blocks in the same cluster can access the same shared memory and are therefore scheduled to run together. Then, each cluster computes its assigned work tiles in succession.</p> <p>Our tile scheduler in the symmetric GEMM is almost identical. 
The only difference is that only the work tiles in the lower triangle of the matrix are partitioned amongst the clusters, and work tiles in the upper triangle are unassigned, since their values are identical to the transposed values of the lower triangle.</p> <p>Instead of using the standard tile scheduler, which evenly divides the tiles of both triangles among the clusters, we use a <em>triangular scheduler</em> to evenly divide only the tiles of the lower triangle among the clusters. This ensures that every cluster is assigned the same number of tiles that actually need to be worked on, keeping the load balanced.</p> <h3 id="epilogue-writing-to-the-transposed-tile">Epilogue: Writing to the Transposed Tile</h3> <p>In the GEMM epilogue, when the computed values of the lower triangle are written to their assigned tile in general memory (HBM), they are also written to their transposed tile location in the upper triangle.</p> <p><img src="https://hackmd.io/_uploads/SkkoEAVq-e.png" alt="symm_gemm_diagram (1)"/> <em>Figure 16: Symmetric GEMM only computes $256 \times 256$ work tiles on the diagonal and in the lower triangle, copying each lower tile, transposed, to its mirrored location in the upper triangle.</em></p> <p>We implement all of our symmetric GEMM kernels with square cluster work tiles. Hopper uses cluster size $(2, 1)$ and thread block tile size $(128, 256)$, and Blackwell uses cluster size $(2, 1)$ and 2-CTA collaboration, in which the 2 thread blocks in the cluster collaborate on the same big $(256, 256)$ tile.</p> <p>Notably, highly optimized custom GEMM kernels on Hopper typically use Ping Pong Scheduling, in which the MMA of tile $i$ and the epilogue of tile $i-1$ are overlapped in two consumer warp groups<sup id="fnref:ping-pong"><a href="#fn:ping-pong" class="footnote" rel="footnote" role="doc-noteref">21</a></sup>. However, Ping Pong Scheduling uses more registers at once, and $(128, 256)$ is too large a tile size for Ping Pong Scheduling, leading to register spillage that makes it much slower than standard single producer warp, single consumer warp scheduling. Thus, our Hopper symmetric kernels do not use Ping Pong Scheduling. Blackwell GEMM kernels have no explicit conception of Ping Pong Scheduling, since by default in both cuBLAS and Quack, two accumulators are kept in the new tensor memory hierarchy, and MMA is computed on one accumulator while the epilogue is computed on the other.</p> <p>As a small implementation detail, note that the main diagonal of $256 \times 256$ cluster work tiles is part of the work assigned by the triangular scheduler. Since their transposed locations are identical to their current locations, we only write those values to general memory once - writing twice can cause inaccurate values or NaNs.</p> <h2 id="implementation-strategy-in-code">Implementation Strategy in Code</h2> <p>There are only two differences between the symmetric GEMM kernel and the standard GEMM kernel: the triangular scheduler and the transposed tile write in the epilogue. Quack is designed around abstracting the standard GEMM kernel to enable lightweight but maximally performant GEMM epilogue fusions. 
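</p>

<p>For intuition, the index math behind these two differences is simple. Here is a rough sketch (ours, not Quack’s actual scheduler code) of how a flat work index maps to a lower-triangle tile:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import math

def lower_triangle_tile(i):
    """Map flat work index i to the (row, col) of a lower-triangle tile,
    diagonal included, enumerated row by row: (0,0), (1,0), (1,1), (2,0), ..."""
    row = (math.isqrt(8 * i + 1) - 1) // 2
    col = i - row * (row + 1) // 2
    return row, col

# For an N x N grid of output tiles, only N * (N + 1) // 2 work items are launched.
# In the epilogue, each tile is written to (row, col) and, unless it lies on the
# diagonal, also to the mirrored location (col, row) with a transposed layout.
</code></pre></div></div>

<p>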
Using these abstractions, we are able to implement the symmetric GEMM kernel for both Hopper and Blackwell in just 160 lines, while achieving SOTA performance.</p> <p>We override the standard tile scheduler with our triangular scheduler and wrap the symmetric GEMM class around the <a href="https://github.com/Dao-AILab/quack/blob/main/quack/gemm_act.py">GEMM with activation</a> class. GEMM with activation itself is a wrapper around the <a href="https://github.com/Dao-AILab/quack/blob/main/quack/gemm_sm100.py">Blackwell</a> and <a href="https://github.com/Dao-AILab/quack/blob/main/quack/gemm_sm90.py">Hopper</a> default GEMMs. It supports writing two output tensors - the standard GEMM output (the preactivation) and the standard GEMM output with an activation function such as SwiGLU or ReLU applied (the postactivation). We define the activation function to be the identity and the postactivation tensor to be the in-place transpose of the preactivation tensor. Then, when the GEMM with activation class writes to the postactivation, it is really writing to the upper triangle with a transposed layout - this is exactly the intent of the symmetric GEMM kernel. We override the epilogue of GEMM with activation just to ensure we don’t write twice to the diagonal tiles, for the correctness reasons mentioned previously.</p> <p>We’re super excited about this simplicity! Without this abstraction, the initial implementation was close to 1500 lines of CuTeDSL. It shows the convenience of principled abstractions in kernel engineering, specifically that of the GEMM main loop + Epilogue paradigm.<sup id="fnref:JCZ_anecdote"><a href="#fn:JCZ_anecdote" class="footnote" rel="footnote" role="doc-noteref">22</a></sup></p> <h2 id="kernel-optimizations-for-standard-newton-schulz">Kernel Optimizations for Standard Newton-Schulz</h2> <p>Using just our Quack CuTeDSL kernels, we can accelerate standard Newton-Schulz with two changes.</p> <ol> <li><strong>Symmetric Matrix Multiplication</strong> As discussed above, the matrices $\mathbf A = \mathbf X \mathbf X^\top$ and $\mathbf B = b_t \mathbf A + c_t \mathbf A^2$ computed at each iteration of Newton-Schulz are symmetric by definition. Therefore, we use our symmetric GEMM kernels for these operations, reducing their FLOP cost by half.</li> <li><strong>Fused GEMM + Add</strong> The typical way to implement the non-symmetric multiply-add $\mathbf X \gets a_t \mathbf X + \mathbf B \mathbf X$ is to use <code class="language-plaintext highlighter-rouge">torch.baddbmm</code>, which calls cuBLAS under the hood. However, Quack offers a much faster implementation of this “Fused GEMM + Add” operation for Hopper. Unlike cuBLAS, Quack supports <a href="https://github.com/Dao-AILab/quack/blob/main/quack/gemm_sm90.py">Ping Pong Scheduling for Hopper</a>, which better hides the epilogue addition of $a_t \mathbf X$.</li> </ol> <p>This table shows that the total runtime of applying standard Newton-Schulz to all weight matrices of various LLMs decreases by about 25% when combining these kernel optimizations on the Hopper architecture.</p> <table> <thead> <tr> <th>Model</th> <th><code class="language-plaintext highlighter-rouge">torch.compile</code> (Pure cuBLAS)</th> <th>1. CuTeDSL Symmetric GEMMs</th> <th>1. + 2. 
Fused GEMM Add</th> <th>Final Speedup over <code class="language-plaintext highlighter-rouge">torch.compile</code></th> </tr> </thead> <tbody> <tr> <td>Llama-430M</td> <td>18.909 ms</td> <td>16.114 ms</td> <td>13.71 ms</td> <td><strong>27% faster</strong></td> </tr> <tr> <td>Qwen-600M</td> <td>24.751 ms</td> <td>21.939 ms</td> <td>17.606 ms</td> <td><strong>29% faster</strong></td> </tr> <tr> <td>Gemma-1B</td> <td>75.055 ms</td> <td>66.063 ms</td> <td>55.444 ms</td> <td><strong>26% faster</strong></td> </tr> </tbody> </table> <p><em>Table 1: On Hopper, using symmetric kernels and Ping Pong Scheduling in GEMM + Add accelerates standard Newton-Schulz by around $25\%$ already.</em></p> <h1 id="training-experiments-and-benchmarks">Training Experiments and Benchmarks</h1> <p>We validate Gram Newton-Schulz’s training quality and performance gain on Llama-430M, Qwen-600M, Gemma-1B, and a custom MoE-1B architecture with ~20% active parameters across 1 billion total parameters.<sup id="fnref:llama"><a href="#fn:llama" class="footnote" rel="footnote" role="doc-noteref">23</a></sup><sup id="fnref:qwen:1"><a href="#fn:qwen" class="footnote" rel="footnote" role="doc-noteref">17</a></sup><sup id="fnref:gemma"><a href="#fn:gemma" class="footnote" rel="footnote" role="doc-noteref">24</a></sup></p> <p>We train on FineWeb-Edu. The number of training tokens for each dense model is given by the Chinchilla scaling law and for MoE-1B by twice the Chinchilla scaling law with respect to its active parameters. We use a cosine learning rate scheduler with the following base learning rates:</p> <table> <thead> <tr> <th>Model</th> <th>Learning Rate</th> </tr> </thead> <tbody> <tr> <td>Llama-430M</td> <td>3e-3</td> </tr> <tr> <td>Qwen-600M</td> <td>1.5e-3</td> </tr> <tr> <td>Gemma-1B</td> <td>3e-4</td> </tr> <tr> <td>MoE-1B</td> <td>2.5e-3</td> </tr> </tbody> </table> <p>For both profiling and full training runs, our Muon setup is as follows:</p> <ol> <li>Weights orthogonalized by Muon include $\mathbf W_q, \mathbf W_k, \mathbf W_v$ (the projection matrices for attention), $\mathbf W_o$ (the out-projection matrix following attention), $\mathbf W_{MLP_{UP}}$, $\mathbf W_{MLP_{GATE}}$, and $\mathbf W_{MLP_{DOWN}}$ (the SwiGLU MLP weights), and $\mathbf W_{router}$ (the token router matrix for MoE).</li> <li>Each instance of $\mathbf W_q, \mathbf W_k, \mathbf W_v, \mathbf W_o, \mathbf W_{MLP_{UP}}, \mathbf W_{MLP_{GATE}}$, $\mathbf W_{MLP_{DOWN}},$ and $\mathbf W_{router}$ is batched across all transformer layers; that is, we execute a Newton-Schulz call for all the $\mathbf W_q$’s at once, for all the $\mathbf W_k$’s at once, etc. Maximizing the batch size of Newton-Schulz improves efficiency by making the batched GEMM operations as compute-bound as possible.</li> </ol> <p>Muon is generally combined with a learning rate adjustment that scales the effective step size for each weight matrix based on its dimensions. We find that using Moonshot AI’s strategy of scaling the update by $0.2 \sqrt{\max(\mathrm{fan_out}, \mathrm{fan_in})}$—roughly matching the RMS of Muon’s update with that of AdamW—yields the best loss curves.<sup id="fnref:moonshot-muon-is-scalable"><a href="#fn:moonshot-muon-is-scalable" class="footnote" rel="footnote" role="doc-noteref">25</a></sup></p> <h3 id="splitting-the-weights">Splitting the Weights</h3> <p>We draw special attention to the fact that we split $\mathbf W_{MLP_{UP}}$ from $\mathbf W_{MLP_{GATE}}$ and orthogonalize them separately. 
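</p>

<p>Schematically (this is an illustration, not our released training code), the split amounts to the following, where <code class="language-plaintext highlighter-rouge">orthogonalize</code> stands in for whichever Newton-Schulz routine is in use:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch

def split_up_gate_update(grad_up_gate: torch.Tensor, d_ff: int, orthogonalize):
    """Orthogonalize the up- and gate-projection gradients separately instead of
    as one fused matrix. Assumes the fused gradient stacks [W_up; W_gate] along
    dim 0, so it has shape (2 * d_ff, d_model)."""
    g_up, g_gate = grad_up_gate.split(d_ff, dim=0)
    return torch.cat([orthogonalize(g_up), orthogonalize(g_gate)], dim=0)
</code></pre></div></div>

<p>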
Ordinarily, MLPs are implemented as Linear + SwiGlu + Linear, where the weight matrix of the first linear layer is a concatenation of $\mathbf W_{MLP_{UP}}$ and $\mathbf W_{MLP_{GATE}}$. However, the gradients flowing back into the $\mathbf W_{MLP_{UP}}$ and $\mathbf W_{MLP_{GATE}}$ halves are calculated quite differently since their contributions to the activation are fundamentally different. We observe that orthogonalizing them separately improves the final loss; for example, in Llama-430M, we observe an improvement of $\approx 0.2$ in perplexity. In addition, splitting $\mathbf W_{MLP_{UP/GATE}}$ halves its small dimension in MoE architectures, where the intermediate size is smaller than the hidden size, leading to greater speedup from Gram Newton-Schulz.</p> <p>Likewise, while earlier implementations of Muon orthogonalized the combined matrix $\begin{bmatrix} \mathbf W_q \,\vert\, \mathbf W_k \,\vert\, \mathbf W_v\end{bmatrix}$, we orthogonalize each piece separately.</p> <p>We are also aware that in some settings, including pretraining GLM-5, Muon benefits from splitting Multi-Latent Attention weights ($\mathbf W^{UQ}$, $\mathbf W^{UK}$, and $\mathbf W^{UV}$) by attention head before orthogonalizing.<sup id="fnref:GLM:1"><a href="#fn:GLM" class="footnote" rel="footnote" role="doc-noteref">2</a></sup> This choice is principled, since the actual matrix multiplications happening in attention are between attention heads rather than the full query, key, value, and out projections. On our test models, we experimented with splitting $\mathbf W_q$, $\mathbf W_k$, $\mathbf W_v$, and $\mathbf W_o$ by attention heads to form $H$ matrices each of size $\tfrac{d}{H} \times d$, where $d$ is the embedding dimension and $H$ is the number of heads. However, we observed higher losses throughout training when using this design.</p> <p>Still, we believe that there are other settings like GLM-5 where this strategy works well. Such cases would benefit <em>immensely</em> from Gram Newton-Schulz, since the aspect ratio of these weight matrices would be the number of heads $H$. For a standard attention weight like $\mathbf W_q$ with $H=16$ and $T=5$, Gram Newton-Schulz on the little matrices would use <strong>$80\times$</strong> fewer FLOPs than orthogonalizing the big matrix!</p> <h2 id="model-quality-is-preserved">Model Quality is Preserved</h2> <p>We see loss preserved as follows, when both using the Polar Express coefficients and the coefficients derived by You Jiacheng:<sup id="fnref:you:1"><a href="#fn:you" class="footnote" rel="footnote" role="doc-noteref">19</a></sup> <img src="https://hackmd.io/_uploads/ryyhZAroWe.png" alt="validation_perplexity_hopper"/> <em>Figure 17: Validation perplexity is always preserved within 0.01. We train with Muon using the Chinchilla scaling law on Hopper.</em> <img src="https://hackmd.io/_uploads/ryrgMRrsZx.png" alt="moe_1b_blackwell_ppl"/> <em>Figure 18: Validation perplexity is preserved within 0.01. We train with Muon using the Chinchilla scaling law on Blackwell.</em></p> <h2 id="our-method-speeds-up-the-optimizer-step">Our Method Speeds up the Optimizer Step</h2> <p><strong>Newton-Schulz Performance</strong> We observe that our method speeds up the runtime of the Newton-Schulz step in each iteration by up to $2\times$, especially as weights become more rectangular. The tables below report these speed-ups for each model, benchmarked on both H100 and B300. 
In these experiments, we use standard Newton-Schulz as the fallback when $m = n$:</p> <p><img src="https://hackmd.io/_uploads/BJy11AZoWe.png" alt="icml_ns_breakdown (6)"/> <em>Figure 19: Hopper architecture Newton-Schulz time per model weight. Very rectangular weights like Up/Gate and Down in Gemma-1B will especially benefit from Gram Newton-Schulz, while square weights like Llama-430M’s attention weights will just benefit from the kernels. The speedup on MoE-1B for Up/Gate and Down doesn’t even take advantage of the symmetric kernel, since the small intermediate size of $256$ is exactly the tile size. The speedup is fully algorithmic.</em></p> <p><img src="https://hackmd.io/_uploads/Bk1YJCZj-x.png" alt="icml_ns_breakdown_b300"/> <em>Figure 20: Blackwell architecture Newton-Schulz time per model weight. The speedup on MoE-1B for Up/Gate and Down is fully algorithmic, like in Figure 19.</em></p> <p><strong>End-to-End Optimizer Performance</strong> The following figure shows the end-to-end wall clock time of the optimizer step for each method. For Muon, these timings include the AdamW updates for weights not assigned to Muon (such as the embedding layer and the vector-valued weights), PyTorch operations for splitting and reconcatenating weights, and learning rate scaling.</p> <p><img src="https://hackmd.io/_uploads/r1s2fCZjWg.png" alt="icml_optimizer_plot2 (4)"/> <em>Figure 21: Hopper architecture end-to-end optimizer step during training, including matrix splitting and recombination for QKV and MLP, LR scaling, master weight updates, and the scalar optimizer (AdamW) step for non-2D weights.</em></p> <p>These results allow us to measure the impact of our optimized kernels separately from that of our Gram Newton-Schulz algorithm. We see that both pieces contribute significantly to the speedup. We observe that Llama-430M’s and Qwen-600M’s smaller, square weights benefit from the kernels - again, we stress that square weight matrices are the rare case. Meanwhile, Gemma benefits from both the algorithm and kernels, seeing the biggest speedup due to its MLP weights’ higher aspect ratio of $8$ instead of $4$.</p> <p>We run our experiments on a single GPU. The speedup of using our method in different parallelism configurations should be the same as on one GPU in most cases.</p> <h3 id="gram-newton-schulz-time-in-kimi-k2">Gram Newton-Schulz time in Kimi K2</h3> <p>Kimi K2 is a trillion parameter sparse, fine-grained MoE model with $384$ experts per layer, a hidden size of $7168$, and a small expert intermediate dimension of $2048$. 
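</p>

<p>A back-of-the-envelope estimate (ours, using the FLOP counting from the runtime section above; wall clock will of course differ) shows why these shapes are so favorable:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code># One Kimi K2 expert up/gate/down weight
T = 5                 # Newton-Schulz iterations
n, m = 2048, 7168     # expert weight shape
alpha = m / n         # aspect ratio = 3.5

gram_ns  = (4 * T + 6 * alpha - 6) * n**3   # Stabilized Gram NS, one restart
standard = T * (4 * alpha + 2) * n**3       # standard NS without symmetric GEMMs

print(standard / gram_ns)  # ~2.3x fewer FLOPs per expert weight
</code></pre></div></div>

<p>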
Since models are trending towards finer-grained MoE architectures and Kimi K2 was trained with Muon, this is a perfect setting to benchmark Gram Newton-Schulz.</p> <p>In the <a href="#appendix">Appendix</a>, we approximate the exposed Newton-Schulz wall clock time of a global training step of Kimi K2 to be that of:</p> <ul> <li>216 expert up/gate/down weights of shape $2048 \times 7168$</li> <li>1 dense up/gate/down weight of shape $7168\times18432$</li> </ul> <p><img src="https://hackmd.io/_uploads/B1w5VCZjbe.png" alt="kimi (2)"/> <em>Figure 22: On Hopper, Gram Newton-Schulz is $2\times$ faster than standard Newton-Schulz in Kimi K2’s pipeline parallelism configuration.</em></p> <p><img src="https://hackmd.io/_uploads/HkQAVRZoWe.png" alt="kimi_b300"/> <em>Figure 23: On Blackwell, Gram Newton-Schulz is $2\times$ faster than standard Newton-Schulz in Kimi K2’s pipeline parallelism configuration.</em></p> <p>Observe that the speedup of Gram Newton-Schulz over standard Newton-Schulz in <code class="language-plaintext highlighter-rouge">torch</code> is twice the speedup of standard Newton-Schulz in CuTeDSL over standard Newton-Schulz in <code class="language-plaintext highlighter-rouge">torch</code>, showing the contribution of the new algorithm.</p> <h1 id="impact-on-end-to-end-training-time">Impact on End-to-End Training Time</h1> <p>In the previous section, we showed that Gram Newton-Schulz significantly speeds up the optimizer step time. This improvement is most impactful when the optimizer time is a large share of the global training step time. Many factors affect the relative runtimes of the optimizer step and the forward and backward passes. In this section, we describe several common settings where the optimizer step is a meaningful performance bottleneck.</p> <h3 id="low-precision-training">Low precision training</h3> <p>In low precision training, the forward and backward passes are computed in 4 bit or 8 bit precision, greatly speeding up their wall clock time. However, Newton-Schulz must be computed in 16 bit precision for stability and accuracy. Therefore, the optimizer time will occupy a greater share of training time.</p> <h3 id="small-global-batch-size">Small global batch size</h3> <p>When global batch size decreases, fewer microbatches are needed, so fewer forward and backward passes will occur per global training step. The optimizer time will remain the same, since it is agnostic to batch size. Therefore, the optimizer step will occupy a greater share of training time. For example, when SFT and RL use Muon, as in Kimi K2’s post-training pipeline, batch sizes are significantly smaller than in pretraining.<sup id="fnref:kimi:2"><a href="#fn:kimi" class="footnote" rel="footnote" role="doc-noteref">1</a></sup><sup id="fnref:SFT"><a href="#fn:SFT" class="footnote" rel="footnote" role="doc-noteref">26</a></sup></p> <h3 id="optimizer-step-frequency-is-bottlenecked-by-optimizer-duration">Optimizer step frequency is bottlenecked by optimizer duration</h3> <p>Fixing the total number of tokens used in training, smaller global batch sizes are typically preferable to larger global batch sizes for model quality, since they allow for more frequent weight updates.<sup id="fnref:allen"><a href="#fn:allen" class="footnote" rel="footnote" role="doc-noteref">27</a></sup> However, when using pipeline parallelism at scale, smaller batch sizes can come with a performance tradeoff. 
The backward pass of pipeline stage $i-1$ needs to hide the optimizer step of pipeline stage $i$ as much as possible, and increasing the batch size to better hide the optimizer step with a longer backward pass can increase throughput.</p> <p>Gram Newton-Schulz decreases the optimizer step time, allowing the backward pass to hide the optimizer at smaller batch sizes. Thus, Gram Newton-Schulz can improve model quality by allowing for smaller batch sizes and more frequent updates without a throughput tradeoff.</p> <h3 id="large-cluster-size">Large cluster size</h3> <p>A larger cluster size allows for more data parallel groups, decreasing the forward and backward pass time of a global training step. The optimizer step time will usually be the same. Distributing the Newton-Schulz work of a GPU’s model parameters across its corresponding rank in the other data parallel groups is possible, but it invokes significant internode communication overhead and occupies bandwidth that is usually not worth the cost.</p> <h1 id="conclusion">Conclusion</h1> <p>We hope our analysis and experiments will encourage researchers to try Gram Newton-Schulz. Our results show that Gram Newton-Schulz preserves training quality and speeds up the optimizer step by up to $2\times$ on popular model architectures, providing a rare case of free lunch performance.</p> <p>We release an <a href="https://github.com/Dao-AILab/gram-newton-schulz/blob/main/gram_newton_schulz/gram_newton_schulz.py">implementation of Gram Newton-Schulz</a> that serves as a drop-in replacement for the standard five-step Newton-Schulz used in Muon along with the <a href="https://github.com/Dao-AILab/quack/blob/main/quack/gemm_symmetric.py">symmetric GEMM kernels</a> that accelerate it. We believe that the stability analysis provided in this blog post lays the foundation for easily adapting Gram Newton-Schulz to other use cases. The only hyperparameter that needs to be retuned at all is the set of iterations at which to restart. To this end, we provide an <a href="https://github.com/Dao-AILab/gram-newton-schulz/blob/main/gram_newton_schulz/autotune_restarts.py">autotuning script</a> that takes a series of coefficients (for instance, 10 steps of Polar Express) and suggests the optimal set of restarts according to <a href="#stabilized-gram-newton-schulz">our analysis above</a>.</p> <h2 id="citing-this-blog-post">Citing this blog post</h2> <div class="language-bibtex highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="nc">@misc</span><span class="p">{</span><span class="nl">GramNewtonSchulz</span><span class="p">,</span>
  <span class="na">title</span>   <span class="p">=</span> <span class="s">{Gram Newton-Schulz}</span><span class="p">,</span>
  <span class="na">author</span>  <span class="p">=</span> <span class="s">{Jack Zhang and Noah Amsel and Berlin Chen and Tri Dao}</span><span class="p">,</span>
  <span class="na">year</span>    <span class="p">=</span> <span class="s">{2026}</span><span class="p">,</span>
  <span class="na">url</span>     <span class="p">=</span> <span class="s">{https://dao-ailab.github.io/blog/2026/gram-newton-schulz/}</span>
<span class="p">}</span>
</code></pre></div></div> <h2 id="references">References</h2> <ol> <li>Keller Jordan, Yuchen Jin, Vlado Boza, Jiacheng You, Franz Cesista, Laker Newhouse, and Jeremy Bernstein. “Muon: An optimizer for hidden layers in neural networks.” Blog post, 2024. Available at: https://kellerjordan.github.io/posts/muon/</li> <li>Jeremy Bernstein. “Deriving Muon.” Blog post, 2025. Available at: https://jeremybernste.in/writing/deriving-muon</li> <li>Less Wright and Adnan Hoque. “CUTLASS Ping-Pong GEMM Kernel.” PyTorch Blog, November 1, 2024. Available at: https://pytorch.org/blog/cutlass-ping-pong-gemm-kernel/</li> <li>Noah Amsel, David Persson, Christopher Musco, and Robert M. Gower. “The Polar Express: Optimal Matrix Sign Methods and Their Application to the Muon Algorithm.” International Conference on Learning Representations (ICLR), 2026.</li> <li>Ekaterina Grishina, Matvey Smirnov, and Maxim Rakhuba. “Accelerating Newton-Schulz Iteration for Orthogonalization via Chebyshev-type Polynomials.” arXiv preprint arXiv:2506.10935 (2025).</li> <li>Kwangjun Ahn, Byron Xu, Natalie Abreu, Ying Fan, Gagik Magakyan, Pratyusha Sharma, Zheng Zhan, and John Langford. “Dion: Distributed Orthonormalized Updates.” arXiv preprint arXiv:2504.05295 (2025).</li> <li>Jingyuan Liu et al. “Muon is Scalable for LLM Training.” arXiv preprint arXiv:2502.16982 (2025).</li> <li>Kimi Team. “Kimi K2: Open Agentic Intelligence.” arXiv preprint arXiv:2507.20534 (2026).</li> <li>Aaron Grattafiori et al. “The Llama 3 Herd of Models.” arXiv preprint arXiv:2407.21783 (2024).</li> <li>GLM-5 Team et al. “GLM-5: From Vibe Coding to Agentic Engineering.” arXiv preprint arXiv:2602.15763 (2026).</li> <li>An Yang et al. “Qwen3 Technical Report.” arXiv preprint arXiv:2505.09388 (2025).</li> <li>OpenAI et al. “gpt-oss-120b &amp; gpt-oss-20b Model Card.” arXiv preprint arXiv:2508.10925 (2025).</li> <li>DeepSeek-AI et al. “DeepSeek-V3 Technical Report.” arXiv preprint arXiv:2412.19437 (2025).</li> <li>Han Zhong, Zikang Shan, Guhao Feng, Wei Xiong, Xinle Cheng, Li Zhao, Di He, Jiang Bian, and Liwei Wang. “DPO Meets PPO: Reinforced Token Optimization for RLHF.” arXiv preprint arXiv:2404.18922 (2025).</li> <li>Thomas Pethick, Wanyun Xie, Kimon Antonakopoulos, Zhenyu Zhu, Antonio Silveti-Falls, and Volkan Cevher. “Training Deep Learning Models with Norm-Constrained LMOs.” arXiv preprint arXiv:2502.07529 (2025).</li> <li>Nikhil Vyas, Depen Morwani, Rosie Zhao, Mujin Kwun, Itai Shapira, David Brandfonbrener, Lucas Janson, and Sham Kakade. “SOAP: Improving and Stabilizing Shampoo using Adam.” arXiv preprint arXiv:2409.11321 (2025).</li> <li>Kevin Frans, Sergey Levine, and Pieter Abbeel. “A Stable Whitening Optimizer for Efficient Neural Network Training.” arXiv preprint arXiv:2506.07254 (2025).</li> <li>Vineet Gupta, Tomer Koren, and Yoram Singer. “Shampoo: Preconditioned Stochastic Tensor Optimization.” International Conference on Machine Learning. PMLR, 2018.</li> <li>Gemma Team et al. “Gemma 3 Technical Report.” arXiv preprint arXiv:2503.19786 (2025).</li> <li>Laker Newhouse, Dakota Goldberg, and Ricardo Ruiz. “Faster Symmetric Matrix Multiplication with ThunderKittens.”Available at: https://www.lakernewhouse.com/assets/writing/faster-symmul-with-thunderkittens.pdf</li> <li>Tianyang Lin. “Flash-Muon: An Efficient Implementation of Muon Optimizer.” GitHub repository, 2025. Available at: https://github.com/nil0x9/flash-muon</li> <li>Will Merrill. 
“Critical Batch Size Revisited: A Simple Empirical Approach to Large-Batch Language Model Training.” arXiv preprint arXiv:2505.23971 (2025). Available at: https://allenai.org/blog/critical-batch-size</li> </ol> <h1 id="appendix">Appendix</h1> <p>The share of end-to-end training time taken up by Newton-Schulz can vary widely depending on the training setup. To explain this variability, we analyze two idealized scenarios. In one, standard Newton-Schulz takes 2% of training time; in the other it takes 17%.</p> <h3 id="case-study-1-standard-newton-schulz-accounts-for-2-of-kimi-k2-training-time">Case Study 1: Standard Newton-Schulz accounts for 2% of Kimi K2 training time</h3> <p>The following analysis gives a very optimistic estimate of the optimizer’s wall clock time. We assume an efficient training infrastructure with highly optimized pipeline parallelism. Moreover, we assume that the optimizer step of each pipeline stage is completely hidden behind the backward pass of the next pipeline stage.</p> <p>Kimi K2 Thinking is a $1.1$ trillion parameter model with $32$ billion active parameters. It has $1$ dense layer followed by $60$ MoE layers.<sup id="fnref:kimi:3"><a href="#fn:kimi" class="footnote" rel="footnote" role="doc-noteref">1</a></sup> It is pretrained with $256$-GPU model parallel groups, $16$-way pipeline parallelism, $16$-way expert parallelism within each pipeline stage, and a huge batch size of $67$ million tokens.</p> <p>We use a single H100 to approximate the share of each training step’s runtime occupied by Newton-Schulz in this setting under the following assumptions:</p> <ol> <li>The training cluster is $2048$ H100s across $256$ nodes (8 GPUs per node), connected with NDR 400 Gb/s InfiniBand inter-node (8 NICs per node, 1:1 NIC-to-GPU ratio) and NVLink 4.0 intra-node. This is the size of the cluster used to train DeepSeek-V3, with upgraded hardware.<sup id="fnref:deepseek"><a href="#fn:deepseek" class="footnote" rel="footnote" role="doc-noteref">28</a></sup> This means there are $\frac{2048}{256} = 8$ data parallel groups.</li> <li>Training in <code class="language-plaintext highlighter-rouge">bfloat16</code> hits $35\%$ to $45\%$ MFU, which is a typical range for MoEs at this scale on H100 clusters.</li> <li>The only non-overlapped optimizer wall clock time is of the last pipeline stage that completes its backward (i.e. pipeline stage $1$ of $16$). The optimizer steps of pipeline stages $2$ to $16$ are fully hidden behind the backward passes of stages $1$ to $15$.</li> <li>Pipeline stage $1$ has the dense layer and $3$ MoE layers.</li> </ol> <p>Under these assumptions, the optimal way to partition the Newton-Schulz work of pipeline stage 1 is as follows:</p> <ol> <li>Each of the $16$ GPUs in pipeline stage $1$’s expert parallel group gets \(\frac{384 \text{ experts/layer} \times 3 \text{ MoE layers}}{16 \text{ GPUs}} = 72 \text{ experts/GPU} = 216 \text{ expert up-gate-down/GPU}\). Each of the 16 GPUs has its own unique expert weights, so no communication is needed.</li> <li>The four shared experts’ weights and the dense MLP’s weights are divided evenly based on orthogonalization wall clock time amongst the 16 expert parallel GPUs, which run Newton-Schulz in parallel. The dense MLP’s three $7168\times18432$ weights (up/gate/down) dominate the wall clock time, so they are sent to 3 different GPUs, with the rest of the weights split amongst the remaining 13. 
Thus, the total Newton-Schulz time for all these weights when the 16 GPUs run in parallel is the same as the time to run Newton-Schulz on one of the dense up/gate/down weights. An <code class="language-plaintext highlighter-rouge">all_gather</code> is required between the two nodes to collect the distributed orthogonalized gradients, but we assume it is substantially faster than redundant Newton-Schulz work.</li> </ol> <p>Then, the total Newton-Schulz time of Pipeline Stage 1 is that of 216 expert up/gate/down weights and 1 dense up/gate/down weight.</p> <p>Per our assumptions, Pipeline Stage 1’s Newton-Schulz time is the only non-overlapped Newton-Schulz time. As benchmarked <a href="#gram-newton-schulz-time-in-kimi-k2">here</a>, standard Newton-Schulz in <code class="language-plaintext highlighter-rouge">torch</code> will take 315 ms.</p> <p>Let’s estimate the end-to-end wall clock time of an entire Kimi K2 global training step.</p> <p><strong>Given:</strong></p> <ul> <li>Active parameters: $N = 32 \times 10^9$</li> <li>H100 peak: $P = 989 \times 10^{12}$ FLOP/s</li> <li>Cluster size: $G = 2048$ GPUs</li> <li>Global batch size: $B = 67 \times 10^6$ tokens</li> </ul> \[\text{sec/batch} = \frac{B \times 6N}{P \times \text{MFU} \times G} = \frac{67 \times 10^6 \times 6 \times 32 \times 10^9}{989 \times 10^{12} \times \text{MFU} \times 2048} = \frac{6.351}{\text{MFU}}\] <p>For realistic estimates of MFU, we have:</p> <table> <thead> <tr> <th>MFU</th> <th>sec/batch</th> </tr> </thead> <tbody> <tr> <td>35%</td> <td>18.14 s</td> </tr> <tr> <td>45%</td> <td>14.11 s</td> </tr> </tbody> </table> <p>Thus, Newton-Schulz takes approximately $\frac{315\text{ ms}}{18140\text{ ms} + 315\text{ ms}} = 1.7\%$ to $\frac{315\text{ ms}}{14110\text{ ms} + 315\text{ ms}} = 2.2\%$ of total pretraining wall clock time in this setting.</p> <h3 id="case-study-2-standard-newton-schulz-occupies-17-of-llama3-70b-sft-time">Case Study 2: Standard Newton-Schulz occupies 17% of Llama3-70B SFT time</h3> <p>Llama3-70B is an 80-layer dense model with hidden size 8192, intermediate size 28672, and grouped query attention with $1024 \times 8192$ $\mathbf W_k, \mathbf W_v$ weights and $8192 \times 8192$ $\mathbf W_q, \mathbf W_o$ weights.<sup id="fnref:llama:1"><a href="#fn:llama" class="footnote" rel="footnote" role="doc-noteref">23</a></sup> Supervised finetuning (SFT) typically uses small batch sizes, ranging from $32$ to $256$ sequences.<sup id="fnref:SFT:1"><a href="#fn:SFT" class="footnote" rel="footnote" role="doc-noteref">26</a></sup><sup id="fnref:deepseek:1"><a href="#fn:deepseek" class="footnote" rel="footnote" role="doc-noteref">28</a></sup></p> <p>We construct the following SFT case:</p> <ol> <li>Training uses $32$ H100s across $4$ nodes (8 GPUs per node).</li> <li>Training in <code class="language-plaintext highlighter-rouge">bfloat16</code> hits $40\%$ MFU.</li> <li>Weights are sharded evenly across GPUs using FSDP, and the exposed Newton-Schulz time is that of $\frac{80 \text{ layers}}{32 \text{ GPUs}} \approx 3 \text{ layers}$. 
Each layer has 3 up-gate-down weights, 2 $\mathbf W_q, \mathbf W_o$ weights, and 2 $\mathbf W_k, \mathbf W_v$ weights.</li> </ol> <p>According to our benchmarking, standard Newton-Schulz on</p> <ul> <li>Nine $8192 \times 28672$ weights takes 738.731 ms</li> <li>Six $8192 \times 8192$ weights takes 156.368 ms</li> <li>Six $1024 \times 8192$ weights takes 2.318 ms</li> </ul> <p>totalling 897.417 ms.</p> <p><strong>Given:</strong></p> <ul> <li>Parameters: $N = 70 \times 10^9$</li> <li>H100 peak: $P = 989 \times 10^{12}$ FLOP/s</li> <li>Cluster size: $G = 32$ GPUs</li> <li>Global batch size: $B = 64 \text{ sequences} \times 2048 \text{ tokens/sequence} = 131{,}072$ tokens</li> </ul> \[\text{sec/batch} = \frac{B \times 6N}{P \times \text{MFU} \times G} = \frac{131{,}072 \times 6 \times 70 \times 10^9}{989 \times 10^{12} \times 0.40 \times 32} = 4.35\text{ s}\] <p>Newton-Schulz takes approximately $\frac{897.417\text{ ms}}{4350\text{ ms} + 897.417\text{ ms}} = 17\%$ of total SFT wall clock time in this parallelism setting.</p> <div class="footnotes" role="doc-endnotes"> <ol> <li id="fn:kimi"> <p>https://arxiv.org/abs/2507.20534 <a href="#fnref:kimi" class="reversefootnote" role="doc-backlink">&#8617;</a> <a href="#fnref:kimi:1" class="reversefootnote" role="doc-backlink">&#8617;<sup>2</sup></a> <a href="#fnref:kimi:2" class="reversefootnote" role="doc-backlink">&#8617;<sup>3</sup></a> <a href="#fnref:kimi:3" class="reversefootnote" role="doc-backlink">&#8617;<sup>4</sup></a></p> </li> <li id="fn:GLM"> <p>https://arxiv.org/abs/2602.15763 <a href="#fnref:GLM" class="reversefootnote" role="doc-backlink">&#8617;</a> <a href="#fnref:GLM:1" class="reversefootnote" role="doc-backlink">&#8617;<sup>2</sup></a></p> </li> <li id="fn:muon"> <p>https://kellerjordan.github.io/posts/muon/ <a href="#fnref:muon" class="reversefootnote" role="doc-backlink">&#8617;</a> <a href="#fnref:muon:1" class="reversefootnote" role="doc-backlink">&#8617;<sup>2</sup></a> <a href="#fnref:muon:2" class="reversefootnote" role="doc-backlink">&#8617;<sup>3</sup></a></p> </li> <li id="fn:dion"> <p>https://arxiv.org/abs/2504.05295 <a href="#fnref:dion" class="reversefootnote" role="doc-backlink">&#8617;</a> <a href="#fnref:dion:1" class="reversefootnote" role="doc-backlink">&#8617;<sup>2</sup></a></p> </li> <li id="fn:scion"> <p>https://arxiv.org/abs/2502.07529 <a href="#fnref:scion" class="reversefootnote" role="doc-backlink">&#8617;</a></p> </li> <li id="fn:soap"> <p>https://arxiv.org/abs/2409.11321 <a href="#fnref:soap" class="reversefootnote" role="doc-backlink">&#8617;</a></p> </li> <li id="fn:shampoo"> <p>https://arxiv.org/abs/1802.09568 <a href="#fnref:shampoo" class="reversefootnote" role="doc-backlink">&#8617;</a></p> </li> <li id="fn:splus"> <p>https://arxiv.org/abs/2506.07254 <a href="#fnref:splus" class="reversefootnote" role="doc-backlink">&#8617;</a></p> </li> <li id="fn:flashmuon"> <p>https://github.com/nil0x9/flash-muon <a href="#fnref:flashmuon" class="reversefootnote" role="doc-backlink">&#8617;</a> <a href="#fnref:flashmuon:1" class="reversefootnote" role="doc-backlink">&#8617;<sup>2</sup></a></p> </li> <li id="fn:PRISM"> <p>https://arxiv.org/abs/2601.22137 <a href="#fnref:PRISM" class="reversefootnote" role="doc-backlink">&#8617;</a></p> </li> <li id="fn:deriving_muon"> <p>https://jeremybernste.in/writing/deriving-muon <a href="#fnref:deriving_muon" class="reversefootnote" role="doc-backlink">&#8617;</a></p> </li> <li id="fn:polar-express"> <p>https://openreview.net/forum?id=yRtgZ1K8hO <a href="#fnref:polar-express" 
class="reversefootnote" role="doc-backlink">&#8617;</a> <a href="#fnref:polar-express:1" class="reversefootnote" role="doc-backlink">&#8617;<sup>2</sup></a> <a href="#fnref:polar-express:2" class="reversefootnote" role="doc-backlink">&#8617;<sup>3</sup></a></p> </li> <li id="fn:grishina"> <p>https://arxiv.org/abs/2506.10935 <a href="#fnref:grishina" class="reversefootnote" role="doc-backlink">&#8617;</a></p> </li> <li id="fn:num-iters"> <p>Some variants set $T=6$ or $T=4$, but never anything else. <a href="#fnref:num-iters" class="reversefootnote" role="doc-backlink">&#8617;</a></p> </li> <li id="fn:embeddings"> <p>In the case of standard attention, the $\mathbf W_{QKV}$ matrix is rectangular with aspect ratio 3, but for unrelated reasons we divide it into three square matrices and apply Newton-Schulz to each as we discuss <a href="#training-experiments-and-benchmarks">here</a>. Other authors subdivide these matrices into separate weights for each head, making them highly rectangular. The embedding and unembedding matrices are also rectangular, but these are not typically optimized using Muon. <a href="#fnref:embeddings" class="reversefootnote" role="doc-backlink">&#8617;</a></p> </li> <li id="fn:sonicmoe"> <p>https://arxiv.org/abs/2512.14080 <a href="#fnref:sonicmoe" class="reversefootnote" role="doc-backlink">&#8617;</a></p> </li> <li id="fn:qwen"> <p>https://arxiv.org/abs/2505.09388 <a href="#fnref:qwen" class="reversefootnote" role="doc-backlink">&#8617;</a> <a href="#fnref:qwen:1" class="reversefootnote" role="doc-backlink">&#8617;<sup>2</sup></a></p> </li> <li id="fn:gpt-oss"> <p>https://arxiv.org/abs/2508.10925 <a href="#fnref:gpt-oss" class="reversefootnote" role="doc-backlink">&#8617;</a></p> </li> <li id="fn:you"> <p>https://x.com/YouJiacheng/status/1905861218138804534 <a href="#fnref:you" class="reversefootnote" role="doc-backlink">&#8617;</a> <a href="#fnref:you:1" class="reversefootnote" role="doc-backlink">&#8617;<sup>2</sup></a></p> </li> <li id="fn:laker"> <p>https://www.lakernewhouse.com/assets/writing/faster-symmul-with-thunderkittens.pdf <a href="#fnref:laker" class="reversefootnote" role="doc-backlink">&#8617;</a></p> </li> <li id="fn:ping-pong"> <p>https://pytorch.org/blog/cutlass-ping-pong-gemm-kernel/ <a href="#fnref:ping-pong" class="reversefootnote" role="doc-backlink">&#8617;</a></p> </li> <li id="fn:JCZ_anecdote"> <p>We had previously mentioned a fused symmetric quadratic kernel for $\mathbf a_t \mathbf I + b_t \mathbf A + c_t \mathbf A^2$ that we ended up passing on for stability reasons. Quack’s abstraction was so convenient that Claude and I were able to write the register-level $\mathbf a_t \mathbf I$ fusion in 5 minutes on a car ride. <a href="#fnref:JCZ_anecdote" class="reversefootnote" role="doc-backlink">&#8617;</a></p> </li> <li id="fn:llama"> <p>https://arxiv.org/abs/2407.21783 <a href="#fnref:llama" class="reversefootnote" role="doc-backlink">&#8617;</a> <a href="#fnref:llama:1" class="reversefootnote" role="doc-backlink">&#8617;<sup>2</sup></a></p> </li> <li id="fn:gemma"> <p>https://arxiv.org/abs/2503.19786 <a href="#fnref:gemma" class="reversefootnote" role="doc-backlink">&#8617;</a></p> </li> <li id="fn:moonshot-muon-is-scalable"> <p>See section 2.2 of https://arxiv.org/abs/2502.16982. 
<a href="#fnref:moonshot-muon-is-scalable" class="reversefootnote" role="doc-backlink">&#8617;</a></p> </li> <li id="fn:SFT"> <p>https://arxiv.org/abs/2404.18922 <a href="#fnref:SFT" class="reversefootnote" role="doc-backlink">&#8617;</a> <a href="#fnref:SFT:1" class="reversefootnote" role="doc-backlink">&#8617;<sup>2</sup></a></p> </li> <li id="fn:allen"> <p>https://allenai.org/blog/critical-batch-size <a href="#fnref:allen" class="reversefootnote" role="doc-backlink">&#8617;</a></p> </li> <li id="fn:deepseek"> <p>https://arxiv.org/abs/2412.19437 <a href="#fnref:deepseek" class="reversefootnote" role="doc-backlink">&#8617;</a> <a href="#fnref:deepseek:1" class="reversefootnote" role="doc-backlink">&#8617;<sup>2</sup></a></p> </li> </ol> </div>]]></content><author><name>Jack Zhang</name></author><summary type="html"><![CDATA[]]></summary></entry><entry><title type="html">Mamba-3 Part 1</title><link href="tridao.github.io/blog/2026/mamba3-part1/" rel="alternate" type="text/html" title="Mamba-3 Part 1"/><published>2026-03-16T00:00:00+00:00</published><updated>2026-03-16T00:00:00+00:00</updated><id>tridao.github.io/blog/2026/mamba3-part1</id><content type="html" xml:base="tridao.github.io/blog/2026/mamba3-part1/"><![CDATA[<p><strong>This series is cross-posted at <a href="https://goombalab.github.io/blog/2026/mamba3-part1/">GoombaLab</a></strong></p> <p>[<a href="https://arxiv.org/abs/2603.15569">Paper</a>] [<a href="https://github.com/state-spaces/mamba">Code</a>]</p> <ol> <li>Part I</li> <li><a href="/blog/2026/mamba3-part2/">Part II</a></li> </ol> <p>Since the release of Mamba-2 in mid-2024, most architectures have switched from Mamba-1. Why? Mamba-2 made the bet that training efficiency was the largest bottleneck for state space models (SSMs), and thus simplified the underlying SSM mechanism to deliver $2-8\times$ faster training compared to its predecessor, leading to wider adoption.</p> <p>Since then, the LLM landscape has started to shift. While pretraining is still super important, more attention has been focused on post-training and deployment, both of which are <em>extremely inference-heavy</em>. The scaling of post-training methods, especially with reinforcement learning with verifiable rewards (RLVR) for coding or math, requires huge amounts of generated rollouts, and most recently, agentic workflows, such as Codex, Claude Code, or even OpenClaw, have <strong>pushed inference demand through the roof</strong>.</p> <p>Despite the clear, growing importance of inference, many linear architectures (including Mamba-2) were developed from a training-first perspective. To accelerate pretraining, the underlying SSM was <em>progressively simplified</em> (e.g., the diagonal transition was reduced to a scalar times identity). While this brought training speed, it left the inference step “too simple” and squarely memory-bound — the GPUs aren’t brr-ing but moving memory most of the time.</p> <p>In this new age of inference, we care a lot about pushing the boundaries of the quality-efficiency frontier: we want the <em>better</em> models to run <em>faster</em>.</p> <p>A natural question arises:</p> <blockquote> <p>What would an SSM designed with <strong>inference</strong> in mind look like?</p> </blockquote> <h2 id="the-mamba-3-model">The Mamba-3 Model</h2> <p><strong>What’s missing?</strong> The main appeal of linear models is in their name: compute scales linearly with sequence length because of a fixed-size state. Unfortunately, there is <em>no free lunch</em>. 
The same <strong>fixed state size</strong> that enables efficient computation forces the model to compress all past information into one representation, the exact opposite of a Transformer, which stores all past information through a continuously growing state (the KV cache) — a <em>fundamental</em> difference. So, if we can’t grow the state, how do we make that fixed state do more work?</p> <p>We see that earlier designs simplified the recurrence and the transition matrix to make training fast. However, the change also <em>reduced the richness</em> of the dynamics and left decoding memory-bound: each token update performs very little computation relative to memory movement. This provides us with three levers we can pull: <strong>(1)</strong> make the recurrence itself more expressive, <strong>(2)</strong> use a richer transition matrix, and <strong>(3)</strong> add more parallel (and almost free) work inside each update.</p> <p>From these insights, we improve upon Mamba-2 in three core ways that:</p> <ol> <li>increase the expressivity of the SSM mechanism through a more general recurrence derived from our <strong>exponential-trapezoidal discretization scheme</strong>,</li> <li>expand the state-tracking capabilities by modeling a <strong>complex-valued SSM system</strong>, and</li> <li>improve the model’s general performance with little impact on decode latency by using <strong>multi-input, multi-output (MIMO) SSMs</strong>, which model multiple SSMs in parallel, instead of the current single-input, single-output (SISO) SSMs.</li> </ol> <p>Through these three changes, <strong>Mamba-3 pushes the frontier of performance while maintaining similar inference latency</strong>.</p> <blockquote> <p>Notably, all three of these changes are inspired by the more “classical” control theory and <strong>state space model</strong> literature.</p> </blockquote> <p>Our work goes against the grain of many modern linear architectures, which use alternative interpretations of recurrence (such as <strong>linear attention</strong> or <strong>test-time training</strong>) that <em>don’t easily capture these concepts</em>.</p> <h2 id="architecture">Architecture</h2> <p>What has changed in the Mamba-2 layer? Beyond the three methodological upgrades to the core SSM discussed above, we’ve revamped the architecture a bit to make it more in line with conventional modern language models.</p> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/2026-03-16-mamba-3/mamba3-arch-480.webp 480w,/assets/img/2026-03-16-mamba-3/mamba3-arch-800.webp 800w,/assets/img/2026-03-16-mamba-3/mamba3-arch-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/2026-03-16-mamba-3/mamba3-arch.png" width="100%" height="auto" title="Mamba-3 Architecture" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> <figcaption class="caption">Mamba-3 Architecture.</figcaption> </figure> <p>Based on the diagram, you’ll notice we’ve changed a couple of things. On a high level,</p> <p><strong>Norms.</strong> We added in QKNorm<d-footnote>or "BCNorm" in SSM terminology</d-footnote>, which empirically stabilizes the training of Mamba-3 models. The addition of this norm brings Mamba-3 in line with contemporary Transformer and Gated DeltaNet (GDN) models. With QKNorm, the RMSNorm from Mamba-2 becomes optional. However, we empirically find that it may still be worth keeping in hybrid models due to helping length extrapolation capabilities. 
More on this later.</p> <p><strong>Goodbye Short Conv.</strong> We’ve been able to get rid of the pesky short causal convolution of Mamba-1/2 by combining (1) simple biases on B and C after BCNorm with (2) our new discretization-based recurrence. The new recurrence implicitly applies a <strong>convolution</strong> on the input to the hidden state, and we show how this is the case in Part 2 of our blog.</p> <details><summary>Can the short conv really be removed?</summary> <p>The changes in Mamba-3 add convolution-like components <strong>inside the SSM recurrence</strong> but aren’t exactly interchangeable with the standard short conv placed <strong>outside the SSM recurrence</strong>.</p> <p>The latter can still be used together with Mamba-3, but the decision not to was made empirically. We find adding the standard short conv back:</p> <ol> <li>does not improve performance; in fact, it <em>slightly worsens it</em>, and</li> <li>does not degrade retrieval capabilities on more real-world tasks (e.g., NIAH). That said, without a short convolution, training on small-scale synthetic tasks like MQAR becomes somewhat harder. Since real-world retrieval behavior remains unaffected, though, we don’t consider this a major limitation.</li> </ol> <p>As for why? We didn’t study the theoretical mechanisms, but in the paper, we hypothesize about how both the BC bias and the exponential-trapezoidal recurrence perform similar <strong>convolution-like mechanisms</strong> which empirically serve the same function as the external short conv.</p> </details> <details><summary>Quick history lesson on the short conv.</summary> <p>The short convolution is now a core component of most performant linear models today <d-cite key="gu2024mambalineartimesequencemodeling"></d-cite><d-cite key="dao2024transformersssmsgeneralizedmodels"></d-cite><d-cite key="yang2025gateddeltanetworksimproving"></d-cite><d-cite key="sun2025learninglearntesttime"></d-cite>. Versions of the short conv were first used in recurrent architectures by H3<d-cite key="fu2023hungryhungryhipposlanguage"></d-cite> (in the form of a “shift SSM” which was inspired by the “smeared” induction heads work by Anthropic <d-cite key="olsson2022context"></d-cite>) and RWKV-4 <d-cite key="peng2023rwkvreinventingrnnstransformer"></d-cite> (through its “token shift” mechanism), before being popularized in its current form by Mamba-1.</p> <p>The reason it’s so commonplace is because previous works have repeatedly shown that short convolutions improve empirical performance as well as theoretically support <strong>induction-style retrieval capabilities</strong><d-cite key="wang2025testtimeregressionunifyingframework"></d-cite> .</p> </details> <p>Finally, you’ll notice a couple of new components, namely <strong>RoPE</strong> and <strong>MIMO projections</strong>. The RoPE module expresses complex-valued SSMs via the interpretation of complex transitions as rotations, forgoing the costly reimplementation of kernels. 
The MIMO projections expand the B and C matrices to the appropriate representation needed for MIMO SSMs.</p> <p>We dig into the motivation and exact implementation of these two in greater detail in the second part of our blog (lots of goodies there 🎁), so for now, just think of them as <strong>standalone, fundamental improvements</strong> that individually contribute to improving the model’s performance and/or capabilities.</p> <p>Finally, our overall architecture now adopts interleaved MLP layers following the standard convention of Transformers and other linear models.</p> <h2 id="empirical-results">Empirical Results</h2> <p>We evaluate our final Mamba-3 model against other popular linear alternatives and the Transformer baseline.</p> <h3 id="language-modeling">Language Modeling</h3> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/2026-03-16-mamba-3/evals-480.webp 480w,/assets/img/2026-03-16-mamba-3/evals-800.webp 800w,/assets/img/2026-03-16-mamba-3/evals-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/2026-03-16-mamba-3/evals.png" width="100%" height="auto" title="Downstream LM Evals" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> <figcaption class="caption">Downstream Language Modeling Evaluations for Pretrained Models.</figcaption> </figure> <p>We find that our new Mamba-3 model <em>outperforms</em> the prior Mamba-2 model and strong linear attention alternatives, such as GDN, on language modeling across various pretrained model scales. <strong>Mamba-3-SISO</strong> is directly comparable to prior linear models; for example, it matches Mamba-2 exactly in architecture shapes (model dimensions, state size, etc.) and has comparable training time. Our <strong>MIMO</strong> variant of Mamba-3 further boosts accuracy on our downstream tasks by more than 1 percentage point over the regular Mamba-3 at the 1B scale, with the caveat that MIMO requires longer training times but not longer decoding latencies!</p> <details><summary>How can training costs go up but not inference?</summary> <p>While we will talk about this in detail in the second part of the blog, we give readers a sneak peek here:</p> <p>This dichotomy can be traced back to the respective compute versus memory-bound nature of training and inference. 
Current linear models have been designed to use lots of <strong>GPU tensor cores</strong> (one of the main contributions of Mamba-2) for fast training, but during decoding, each timestep requires so little compute that the hardware remains cold most of the time.</p> <p>Thus, if we design architectures around just increasing the amount of FLOPs needed for each time-step, inference latency stays roughly constant since we can just use some of the idle cores — not so much for training!</p> </details> <h3 id="retrieval-tasks">Retrieval Tasks</h3> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/2026-03-16-mamba-3/retrieval-480.webp 480w,/assets/img/2026-03-16-mamba-3/retrieval-800.webp 800w,/assets/img/2026-03-16-mamba-3/retrieval-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/2026-03-16-mamba-3/retrieval.png" width="100%" height="auto" title="Retrieval Tasks" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> <figcaption class="caption">Real-world and Synthetic Retrieval Tasks.</figcaption> </figure> <p>Linear models, with their fixed-size state, <strong>naturally underperform</strong> their Transformer counterparts on <strong>retrieval-based tasks</strong>. As expected, within pure models, the Transformer is superior on retrieval tasks, but Mamba-3 performs well within the class of sub-quadratic alternatives. Interestingly, the addition of MIMO further improves retrieval performance <em>without increasing the state size</em>.</p> <p>Given this innate deficit but overall strong modeling performance,</p> <blockquote> <p>we predict that linear layers will be predominantly used in <strong>conjunction</strong> with global self-attention layers in the future. $^*$</p> </blockquote> <p>$^*$ at least for language modeling</p> <p><strong>Hybrid models</strong> that combine the general <em>memory-like</em> nature of linear layers with the exact <em>database-like</em> storage of self-attention’s KV cache have been shown empirically to outperform pure models while enabling significant memory and compute savings <d-cite key="waleffe2024empiricalstudymambabasedlanguage"></d-cite>, and we do find here that the combination of linear layers with self-attention enables better retrieval compared to a vanilla Transformer.</p> <p>However, we highlight that the exact way that these linear models interact with self-attention is <em>not fully understood</em>. For instance, we find that the use of the optional pre-output projection for Mamba-3 improves the length generalization performance on the synthetic NIAH tasks at the slight cost of in-context real-world retrieval tasks. Furthermore, even the details of the returned norm such as placement, e.g., pre-gate vs post-gate, and type, grouped vs regular, have non-negligible effects on accuracy on tasks composed of semi-structured and unstructured data, such as FDA and SWDE.</p> <h2 id="kernels-here-there-and-everywhere">Kernels Here, There, and Everywhere</h2> <p>We’re excited to see what people build with Mamba-3. 
To help facilitate this, we are open-sourcing our kernels, which are <strong>on par in terms of speed</strong> with the original Mamba-2 Triton kernels.</p> <h3 id="benchmarking-latencies">Benchmarking Latencies</h3> <p><strong>Prefill Latency</strong></p> <table> <thead> <tr> <th>Model</th> <th style="text-align: right">n=512</th> <th style="text-align: right">1024</th> <th style="text-align: right">2048</th> <th style="text-align: right">4096</th> <th style="text-align: right">16384</th> </tr> </thead> <tbody> <tr> <td>vLLM (Llama-3.2-1B)</td> <td style="text-align: right"><strong>0.26</strong></td> <td style="text-align: right"><strong>0.52</strong></td> <td style="text-align: right"><strong>1.08</strong></td> <td style="text-align: right"><strong>2.08</strong></td> <td style="text-align: right"><strong>12.17</strong></td> </tr> <tr> <td>Gated DeltaNet</td> <td style="text-align: right">0.51</td> <td style="text-align: right">1.01</td> <td style="text-align: right">2.01</td> <td style="text-align: right">4.00</td> <td style="text-align: right">16.21</td> </tr> <tr> <td>Mamba-2</td> <td style="text-align: right">0.51</td> <td style="text-align: right">1.02</td> <td style="text-align: right">2.02</td> <td style="text-align: right">4.02</td> <td style="text-align: right">16.22</td> </tr> <tr> <td>Mamba-3 (SISO)</td> <td style="text-align: right">0.51</td> <td style="text-align: right">1.01</td> <td style="text-align: right">2.02</td> <td style="text-align: right">4.01</td> <td style="text-align: right">16.22</td> </tr> <tr> <td>Mamba-3 (MIMO $R=4$)</td> <td style="text-align: right">0.60</td> <td style="text-align: right">1.21</td> <td style="text-align: right">2.42</td> <td style="text-align: right">4.76</td> <td style="text-align: right">19.44</td> </tr> </tbody> </table> <p><strong>Prefill+Decode Latency</strong></p> <table> <thead> <tr> <th>Model</th> <th style="text-align: right">n=512</th> <th style="text-align: right">1024</th> <th style="text-align: right">2048</th> <th style="text-align: right">4096</th> <th style="text-align: right">16384</th> </tr> </thead> <tbody> <tr> <td>vLLM (Llama-3.2-1B)</td> <td style="text-align: right">4.45</td> <td style="text-align: right">9.60</td> <td style="text-align: right">20.37</td> <td style="text-align: right">58.64</td> <td style="text-align: right">976.50</td> </tr> <tr> <td>Gated DeltaNet</td> <td style="text-align: right">4.56</td> <td style="text-align: right">9.11</td> <td style="text-align: right">18.22</td> <td style="text-align: right">36.41</td> <td style="text-align: right">145.87</td> </tr> <tr> <td>Mamba-2</td> <td style="text-align: right">4.66</td> <td style="text-align: right">9.32</td> <td style="text-align: right">18.62</td> <td style="text-align: right">37.22</td> <td style="text-align: right">149.02</td> </tr> <tr> <td>Mamba-3 (SISO)</td> <td style="text-align: right"><strong>4.39</strong></td> <td style="text-align: right"><strong>8.78</strong></td> <td style="text-align: right"><strong>17.57</strong></td> <td style="text-align: right"><strong>35.11</strong></td> <td style="text-align: right"><strong>140.61</strong></td> </tr> <tr> <td>Mamba-3 (MIMO $R=4$)</td> <td style="text-align: right">4.74</td> <td style="text-align: right">9.48</td> <td style="text-align: right">18.96</td> <td style="text-align: right">37.85</td> <td style="text-align: right">151.81</td> </tr> </tbody> </table> <figure> <figcaption class="caption"> Prefill and prefill+decode (same token count for both prefill and decode) latencies across sequence 
lengths for a 1.5B model on a single H100-SXM 80GB GPU. A batch size of 128 was used for all sequence lengths; wall-clock times (in seconds) are reported over three repetitions. </figcaption> </figure> <p>When comparing models at the 1.5B scale, Mamba-3 (SISO variant) <em>achieves the fastest prefill + decode latency</em> across all sequence lengths, outperforming Mamba-2, Gated DeltaNet, and even the Transformer with its highly optimized vLLM ecosystem. Furthermore, <strong>Mamba-3 MIMO is comparable to Mamba-2 in terms of speed but has much stronger performance</strong>.</p> <p>Mamba-3 SISO’s Triton-based prefill maintains nearly identical performance to Mamba-2, demonstrating that the new discretization and data-dependent RoPE embeddings do not introduce additional overhead, while Mamba-3 MIMO only incurs a moderate slowdown for prefill due to its efficient TileLang implementation. The strong decode performance for both Mamba-3 variants can be partially attributed to the CuTe DSL implementation, which was made significantly easier by the simplicity of Mamba-3 components.</p> <h3 id="design-choices">Design Choices</h3> <p>We spent a lot of time thinking about how to make the kernels as fast as possible without compromising on ease-of-use. We ended up using the following stack: <strong>Triton</strong>, <strong>TileLang</strong>, and <strong>CuTe DSL</strong>.</p> <p>The use of <strong>Triton</strong> was quite an easy choice. It’s pretty much standard for architecture development (the great <a href="https://github.com/fla-org/flash-linear-attention">flash linear attention</a> repo is purely in PyTorch and Triton) for good reason: it enables better performance than standard PyTorch through controlled tiling and kernel fusion while remaining a platform-agnostic language. Triton also has some pretty nifty features, like <a href="https://modal.com/gpu-glossary/device-software/parallel-thread-execution">PTX</a> (a GPU-oriented assembly language) injection and its Tensor Memory Accelerator support (on Hopper GPUs) for bulk, asynchronous transfers from global to shared memory.</p> <p>Our MIMO prefill kernels were developed with <strong>TileLang</strong> instead. The additional projections introduced by this variant present an opportunity to reduce memory IO via strategic manipulation across the GPU’s memory hierarchy. Unfortunately, Triton didn’t provide the granularity of memory control we desired, so we opted for TileLang, which allows us to explicitly declare and control shared-memory tiles and create register fragments, reusing memory more efficiently while still being high-level enough for us to develop the kernels quickly.</p> <p>Since we’ve been hammering the importance of inference and decode, we decided to use <strong>CuTe DSL</strong> for our decode kernels. Through its Python interface, we’re able to generate low-level kernels using high-level abstractions from CUTLASS. Here, we practically have CUDA-level control, enabling us to develop highly-performant kernels tailored to the specifications of our hardware (Hopper GPUs, in this case). With fine-grained control over tensor layouts and warp specialization, we built a kernel that takes advantage of all the bells and whistles in the GPU.</p> <p>Importantly, these implementations across varying levels of GPU abstraction are made possible by the <strong>underlying algorithmic design</strong> of Mamba-3’s simple, lightweight additions and their clever instantiations. 
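</p> <p>To give a flavor of the Triton level of the stack, here is a tiny, hypothetical fused kernel (not one of our actual kernels, with all names and shapes chosen purely for illustration) in the spirit of an SSM-style state update $h \leftarrow \bar{A} \odot h + \bar{B} \odot x$; it shows the kind of controlled tiling and kernel fusion described above.</p> <pre><code class="language-python">import triton
import triton.language as tl


@triton.jit
def fused_state_update(h_ptr, abar_ptr, bbar_ptr, x_ptr, n_elements, BLOCK: tl.constexpr):
    # One fused pass over flattened tensors: h = abar * h + bbar * x.
    # A naive PyTorch version would launch several elementwise kernels and
    # round-trip h through global memory between them.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = n_elements > offsets
    h = tl.load(h_ptr + offsets, mask=mask)
    abar = tl.load(abar_ptr + offsets, mask=mask)
    bbar = tl.load(bbar_ptr + offsets, mask=mask)
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(h_ptr + offsets, abar * h + bbar * x, mask=mask)


def launch(h, abar, bbar, x, BLOCK=1024):
    # h, abar, bbar, x: contiguous CUDA tensors of the same shape.
    n = h.numel()
    grid = (triton.cdiv(n, BLOCK),)
    fused_state_update[grid](h, abar, bbar, x, n, BLOCK=BLOCK)
</code></pre> <p>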
We discuss details such as the exact fusion structure and kernel DSL in more depth in our full release.</p> <h2 id="next-up">Next Up</h2> <p>Glad you made it to the end of Part 1! There were a lot of details regarding our kernels and experimental results and ablations we didn’t have time to cover in this post, but don’t fret! Everything can be found in <a href="https://arxiv.org/abs/2603.15569">our paper</a>, and the kernels have been open-sourced at <a href="https://github.com/state-spaces/mamba">mamba-ssm</a>!</p> <p>Up next, the <a href="/blog/2026/mamba3-part2/">second (and final) part</a> of the series delves into the three core improvements to Mamba-3 and their SSM foundations, and gives some directions we’re especially interested in.</p>]]></content><author><name>Aakash Lahoti*</name></author><summary type="html"><![CDATA[This series is cross-posted at GoombaLab]]></summary></entry><entry><title type="html">Mamba-3 Part 2 - Methodological Deep Dive</title><link href="tridao.github.io/blog/2026/mamba3-part2/" rel="alternate" type="text/html" title="Mamba-3 Part 2 - Methodological Deep Dive"/><published>2026-03-16T00:00:00+00:00</published><updated>2026-03-16T00:00:00+00:00</updated><id>tridao.github.io/blog/2026/mamba3-part2</id><content type="html" xml:base="tridao.github.io/blog/2026/mamba3-part2/"><![CDATA[<p><strong>This series is cross-posted at <a href="https://goombalab.github.io/blog/2026/mamba3-part2/">GoombaLab</a></strong></p> <ol> <li><a href="/blog/2026/mamba3-part1/">Part I</a></li> <li>Part II</li> </ol> <p>We introduced our Mamba-3 model in <a href="/blog/2026/mamba3-part1/">Part I</a> in which we mentioned that the three core methodological changes were inspired by the SSM perspective. Here, we’ll actually do a deep dive into what each of these three improvements entail, their motivations, and their derivations.</p> <p>But first, let’s refresh our memory on the underlying state space model and its background.</p> <h2 id="state-space-foundations">State Space Foundations</h2> <p>The state space model, at its most primitive, is a simple, <strong>continuous</strong> ordinary differential equation (ODE). The input $x(t) \in \mathbb{R}$ is mapped to output $y(t) \in \mathbb{R}$ through a hidden state $h(t) \in \mathbb{R}^N$ of size $N$, also referred to as the state size. 
In the past, in both deep learning and classical control theory, these systems were <strong>linear time-invariant</strong> (LTI), where the “state decay” transition $A \in \mathbb{R}^{(N\times N)}$, $B \in \mathbb{R}^{N}$, and $C \in \mathbb{R}^{N}$ terms were constant.</p> \[\begin{aligned} h'(t) &amp;= A h(t) + B x(t) \\ y(t) &amp;= C^\top h(t) \end{aligned}\] <p>We will occasionally refer to $A$ as the <em>state-transition</em>, and $Bx(t)$ as the <em>state-input</em>.</p> <p>Upon <strong>discretization</strong> with one’s favorite method, as demonstrated with the zero-order hold (ZOH) used in both Mamba-1 and Mamba-2, a familiar <strong>recurrence</strong> materializes,</p> \[\begin{aligned} h_{t} &amp;= e^{\Delta_t A_t} h_{t-1}+ A_t^{-1}(e^{\Delta_t A_t} - I)\,B_t\,x_t \\ y_t &amp;= C_t^\top h_t \end{aligned}\] <p>where the discretized $\bar{A}$ and $\bar{B}$ are now $e^{\Delta_t A_t}$ and $A_t^{-1}(e^{\Delta_t A_t} - I)\,B_t$ respectively.</p> <p>Eagle-eyed readers may ask “how does one go from a LTI system to a linear-time varying (LTV) system?” — if you did, the answer is revealed below!</p> <details><summary>Aside on prior Mamba discretizations</summary> <p>We’ll let you in on a little secret: prior Mamba discretizations used the canonical ZOH discretization scheme and just converted the fixed time-invariant variables A, B, and C to time-varying!</p> <p>No worries if this feels uneasy. We felt that too, which is why we formalized the discretization later (sorry for the clickbait, the answer <em>is</em> not here; you’ll have to keep on reading).</p> </details> <p>While there are no theoretical restrictions on the class of matrices $\bar{A}$ can be, <strong>computational constraints keep transition matrices structured</strong>, e.g., diagonal, scalar times identity, Householder (identity plus low-rank), etc.</p> <p>Great, now we’ve set up the underlying mechanism of Mamba used in both Mamba-1 and Mamba-2!</p> <p>As a quick recap, Mamba-3 builds on Mamba-2 to <em>improve the efficiency-performance trade-off</em> of current SSMs with inference at the forefront. The three core improvements we’ve been discussing in this post are rooted in classical state space theory:</p> <ol> <li>Instantiating a more generalized recurrence through a formal framework for discretizing the underlying ODE</li> <li>Improving state-tracking abilities by converting to complex-valued SSM without the engineering challenges of explicit complex numbers</li> <li>Increasing the expressivity of the SSM without increasing state size through a multi-input, multi-output (MIMO) formulation.</li> </ol> <p>Let’s jump straight into it.</p> <h2 id="upgraded-discretization">Upgraded Discretization</h2> <p>Our end goal is to obtain a more general recurrence than that of current models from first principles. Luckily for us, the discretization of the continuous ODE provides the <em>perfect</em> opportunity to do so.</p> <p>But first, let’s lay down the foundation for our framework used in discretizing time-varying systems. Remember how we mentioned that prior Mamba discretizations adapted the canonical ZOH discretization by adding a time subscript to convert the method from LTI to LTV? Well, to be honest, we left a bit more out earlier. The <a href="https://github.com/state-spaces/mamba/issues/129">actual implementation</a> of Mamba frankensteined the canonical ZOH and Euler methods to create discretized parameters $\bar{A}_t = \exp(\Delta_t A_t), \bar{B}_t=\Delta_t B_t$.</p> <p>Holy heuristic! 
But it works empirically ¯\_(ツ)_/¯.</p> <p>One potential explanation for why this mixture does so well despite not being theoretically grounded is that Euler is an approximation of ZOH. Taking the ZOH formula for discretized $\bar{B}_t=A_t^{-1}\left(\exp(\Delta_t A_t) - I \right) B_t$, if we use the approximation $\exp(x) \approx 1+x$, the resulting $\bar{B}_t \approx \Delta_t B_t$.</p> <p>This heuristic was always bugging us in the back of our minds, so in our work, we finally formalized the discretization.</p> <blockquote> <p>We develop a method that produces a class of formal discretizations for time-varying systems, including one called exponential-Euler that exactly corresponds to the formula used in Mamba-1/2.</p> </blockquote> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/2026-03-16-mamba-3/discretization-table-480.webp 480w,/assets/img/2026-03-16-mamba-3/discretization-table-800.webp 800w,/assets/img/2026-03-16-mamba-3/discretization-table-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/2026-03-16-mamba-3/discretization-table.png" width="100%" height="auto" title="Discretization Table" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> <figcaption class="caption">Resulting final recurrence $h_t = \alpha_t h_{t-1} + \beta_t B_{t-1}x_{t-1} + \gamma_t B_tx_t$ from various discretization methods. The top half lists LTI methods, and the bottom half lists LTV methods derived from our discretization framework. </figcaption> </figure> <h3 id="exponential-adjusted-discretizations">Exponential-Adjusted Discretizations</h3> <p>So let’s actually figure out how to discretize our LTV system in a principled manner.</p> \[h'(t) = A(t)h(t) + B(t)x(t)\] <p>The general intuition behind our framework is that a bare-bones ODE $f'(t) = Af(t)$ has a closed-form solution $f(t) = e^{tA} f(0)$. It follows that the one-step discrete update is then $f(t+\Delta) = e^{\Delta A}f(t)$. Here, since the derivative includes $Ah(t)$, the state directly impacts the rate of change. Thus, the parameterization of $A$ can make the dynamics of the system change rapidly, which forces explicit methods, like Euler, to take small $\Delta$ steps, limiting the expressivity of the system.</p> <p>To mitigate this, we <strong>adjust</strong> the dynamics with an integrating factor of $e^{-At}$ to counteract the dominating exponential and directly analyze $e^{-At}h(t)$ instead. Let’s see how it applies to our system.</p> <p>Taking our $h'(t)=A(t)h(t)+B(t)x(t)$ system, we apply an integrating factor of $e^{\int_0^t -A(s)ds}$ as $A$ is now time-varying.</p> \[\begin{aligned} e^{\int_0^t -A(s)ds}h'(t) &amp;= e^{\int_0^t -A(s)ds}A(t)h(t) + e^{\int_0^t -A(s)ds}B(t)x(t) \\ (e^{\int_0^t -A(s)ds}h(t))' &amp;= e^{\int_0^t -A(s)ds}B(t)x(t) \end{aligned}\] <p>since $(e^{\int_0^t -A(s)ds})' = -A(t)e^{\int_0^t -A(s)ds}$.</p> <p>Thus, when we want to discretize between timesteps $[\tau_{t-1}, \tau_t]$, we can just integrate both sides over that interval. 
For ease of notation, we denote $z(t) := e^{\int_0^t -A(s)ds}$.</p> \[\begin{aligned} \tfrac{d}{dt}(z(t)h(t)) &amp;= z(t)B(t)x(t) \\ \int_{\tau_{t-1}}^{\tau_t}\tfrac{d}{d\tau}(z(\tau)h(\tau))d\tau &amp;= \int_{\tau_{t-1}}^{\tau_t}z(\tau)B(\tau)x(\tau)d\tau \\ z(\tau_{t})h(\tau_{t}) - z(\tau_{t-1})h(\tau_{t-1}) &amp;= \int_{\tau_{t-1}}^{\tau_t}z(\tau)B(\tau)x(\tau)d\tau \\ z(\tau_{t})h(\tau_{t}) &amp;= z(\tau_{t-1})h(\tau_{t-1}) + \int_{\tau_{t-1}}^{\tau_t}z(\tau)B(\tau)x(\tau)d\tau \\ \end{aligned}\] <p>Rearranging into a more familiar form and substituting back the $z(\tau)$ value:</p> \[\begin{aligned} h(\tau_{t}) &amp;= z(\tau_{t})^{-1}z(\tau_{t-1})h(\tau_{t-1}) + z(\tau_{t})^{-1}\int_{\tau_{t-1}}^{\tau_t}z(\tau)B(\tau)x(\tau)d\tau \\ h(\tau_t) &amp;= \exp\left(\int_{\tau_{t-1}}^{\tau_t}A(s)ds\right)h(\tau_{t-1}) + \int_{\tau_{t-1}}^{\tau_t} \exp\left(\int_{\tau}^{\tau_t}A(s)ds\right)B(\tau)x(\tau) d\tau \end{aligned}\] <p>Now, we’ve isolated the state-transition and the state-input through our integration factor. This means the most “difficult” part of the adjusted system can be calculated independently of the state-input integral, which is left to be approximated with many possible methods.</p> <p>Under the LTV case, because $A(s)$ is continuous, we “sample” it with a right-hold assumption where $\forall s \in [\tau_{t-1},\tau_t], A(s) = A(\tau_t) = A_t$, resulting in</p> \[h_t \approx \exp(\Delta_t A_t)h_{t-1} + \int_{\tau_{t-1}}^{\tau_t} \exp\left((\tau_t - \tau)A_t\right)B(\tau)x(\tau)d\tau\] <details><summary>When can the transition integral be directly calculated?</summary> <p>If $A$ is LTI, i.e., constant, then the state-transition integral is exactly $\exp(A \Delta_t)$, which is the discretized $\bar{A}$ term for canonical ZOH.</p> <p>Under the assumption $x(\tau) = x_t$, if $B$ is also LTI, then $\bar{B}$ also recovers the canonical ZOH term, $A^{-1}\left(\exp(\Delta A) - I\right)B$.</p> </details> <p>This final equation lays the foundation for recovering the prior Mamba discretization methods and is also the inspiration of our “exponential-“ style name, as the application of the integration factor can be seen as a style of exponential tilting or adjustment.</p> <h3 id="recovering-prior-mamba-discretization">Recovering Prior Mamba Discretization</h3> <p>As previously mentioned, prior Mamba discretizations <em>differ</em> on paper and in practice. 
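</p> <p>To see how small that gap typically is, here is a quick numeric sketch (toy scalar values of our own choosing, not numbers from our experiments) comparing the ZOH and exponential-Euler $\bar{B}$ terms; for small $\Delta_t A_t$ they nearly coincide, consistent with the $\exp(x) \approx 1+x$ argument above.</p> <pre><code class="language-python">import numpy as np

# Toy scalar comparison of the two B-bar terms discussed above.
# Illustrative values only; in practice A is negative and Delta is input-dependent.
a, b = -1.0, 0.5
for delta in (0.01, 0.1, 1.0):
    abar = np.exp(delta * a)                      # shared state-transition term
    bbar_zoh = (np.exp(delta * a) - 1.0) / a * b  # ZOH: A^{-1}(exp(Delta A) - I) B
    bbar_euler = delta * b                        # exponential-Euler: Delta B
    print(f"delta={delta}: Abar={abar:.4f} ZOH B={bbar_zoh:.4f} exp-Euler B={bbar_euler:.4f}")
</code></pre> <p>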
Now, using our new discretization derivation and certain sampling assumptions, we can recover the reported LTV ZOH discretization <em>and</em> the implemented exponential-adjusted Euler discretization scheme, or exponential-Euler for short.</p> <p>We have already recovered the $\bar{A}$ term for both, so we will focus only on the remaining state-input integral which evaluates to $\bar{B}$.</p> <p><strong>ZOH</strong>: Assuming similar assumptions where $B(\tau), x(\tau)$ are constant and sampled at the right endpoint,</p> \[\begin{aligned} (\cdot) &amp;= B(\tau_t)x(\tau_t) \int_{\tau_{t-1}}^{\tau_t} \exp\left((\tau_t - \tau)A_t\right)d\tau\\ &amp;= B_t\,x_t\exp(A_t\tau_t) \int_{\tau_{t-1}}^{\tau_t} \exp\left(-A_t\tau\right)d\tau \\ &amp;= B_t\,x_t\exp(A_t\tau_t)\dfrac{1}{A_t}\left(\exp(-A_t\tau_{t-1}) - \exp(-A_t\tau_t)\right) \\ &amp;= A_t^{-1}\left(\exp(A_t(\tau_t - \tau_{t-1})) - I\right)B_t\,x_t \\ &amp;= A_t^{-1}\left(\exp({\Delta_tA_t})-I\right)B_t\,x_t \end{aligned}\] <p><strong>Exponential-Euler</strong>: Once again, we approximate the integral with Euler’s rule and hold the $B, x$ terms to the right endpoint.</p> \[\begin{aligned} (\cdot) &amp;= B(\tau_t)x(\tau_t) \int_{\tau_{t-1}}^{\tau_t} \exp\left((\tau_t - \tau)A_t\right)d\tau \\ &amp;= B_t\,x_t\exp(A_t\tau_t) \int_{\tau_{t-1}}^{\tau_t} \exp\left(-A_t\tau\right)d\tau \\ &amp;\approx B_t\,x_t\exp(A_t\tau_t) (\tau_t - \tau_{t-1})\exp\left(-A_t\tau_t\right) \\ &amp;= \Delta_t\,B_t\,x_t \\ \end{aligned}\] <h3 id="new-exponential-trapezoidal-discretization">New Exponential-Trapezoidal Discretization</h3> <p>Thus, the linchpin of converting continuous-time SSMs into various tangible recurrences is <em>approximating $\int_{\tau_{t-1}}^{\tau_t} \exp\left((\tau_t - \tau)A_t\right)B(\tau)x(\tau)d\tau$ in different ways</em>. We’ve shown above that one can either analytically solve it (<strong>ZOH</strong>) or approximate it (<strong>Euler</strong>), but what if we don’t want to use inverses but also want something more precise than Euler’s?</p> <p>We can instead use the <strong>trapezoid method</strong> to approximate the integral using both endpoints instead of just one for Euler. Unlike the standard method which averages both endpoints, we use a convex combination which we find empirically performs better. The integral then evaluates to</p> \[\begin{aligned} &amp; \Delta_t \left(\lambda_t \exp((\tau_t - \tau_t) A_t) B_t\,x_t + (1 - \lambda_t) \exp((\tau_t - \tau_{t-1}) A_t)B_{t-1}\,x_{t-1} \right) \\ =&amp; (1-\lambda_t)\Delta_t\,\exp({\Delta_tA_t})\,B_{t-1}\,x_{t-1} + \lambda_t\,\Delta_t\,B_t\,x_t \end{aligned}\] <p>Interestingly, we can see here that for our new exponential-trapezoidal recurrence, there is some structured time-mixing across the state-input terms. Thus, it acts as an <strong>implicit data-dependent convolution</strong> of size two on the SSM’s state-input.</p> <h3 id="parallel-representation-of-new-recurrence">Parallel Representation of New Recurrence</h3> <p>Now how can we format our recurrence in a parallel representation to enable faster training? To do so, we’ll be viewing the recurrence in its parallel form. 
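</p> <p>As a concrete reference point before going parallel, here is a minimal sequential sketch of the exponential-trapezoidal update for a single SISO channel (toy shapes of our own choosing), using the $\exp(\Delta_t A_t)$, $(1-\lambda_t)\Delta_t\exp(\Delta_t A_t)$, and $\lambda_t\Delta_t$ coefficients derived above.</p> <pre><code class="language-python">import numpy as np

def exp_trapezoidal_scan(a, B, C, x, delta, lam):
    """Sequential reference for the exponential-trapezoidal recurrence.

    a, delta, lam, x: shape (T,) scalars per step; B, C: shape (T, N).
    Shapes are illustrative, not the batched layout used by the kernels.
    """
    T, N = B.shape
    h = np.zeros(N)
    Bx_prev = np.zeros(N)                  # carries the previous state-input B_{t-1} x_{t-1}
    ys = np.empty(T)
    for t in range(T):
        alpha = np.exp(delta[t] * a[t])    # state transition
        beta = (1.0 - lam[t]) * delta[t] * alpha
        gamma = lam[t] * delta[t]
        Bx = B[t] * x[t]                   # state input at time t
        h = alpha * h + beta * Bx_prev + gamma * Bx
        ys[t] = C[t] @ h
        Bx_prev = Bx
    return ys
</code></pre> <p>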
Viewing the recurrence in its parallel form hearkens back to the state space duality (SSD) framework introduced in Mamba-2.</p> <p>Let’s rewrite our recurrence so this doesn’t get <em>too</em> messy:</p> \[h_t = \alpha_t h_{t-1} + \beta_t B_{t-1}x_{t-1} + \gamma_t B_tx_t\] <p>where $\alpha_t = e^{\Delta_t A_t}, \beta_t=(1-\lambda_t)\Delta_t e^{\Delta_t A_t},\gamma_t=\lambda_t\Delta_t$.</p> <details><summary>Refresher on SSD</summary> <p>SSD demonstrated that a large class of recurrent SSMs could be represented in a parallel form that uses an element-wise multiplicative mask to model the state-transition decay. The form that such parallel representations take is $Y = (L \circ C B^\top) X$ where $L\in\mathbb{R}^{T,T}, C, B\in\mathbb{R}^{T,N}, X,Y\in\mathbb{R}^{T,D}$, with $T$ being the total sequence length.</p> <p>This format makes the connection between SSMs and attention pretty clear, especially when changing the SSM-centric notation to one that is more common in attention literature: $C \to Q, B \to K, X \to V$. When $L$ is a lower triangular matrix of all ones, we get the vanilla linear attention <d-cite key="katharopoulos2020transformersrnnsfastautoregressive"></d-cite>, and Mamba-2 corresponds to $L$ being a lower triangular 1-semiseparable matrix.</p> <p>This parallel formulation is what enables the matmul-focused forward pass.</p> </details> <p>If we expand the recurrence where $h_{-1}=0$,</p> \[\begin{aligned} h_0 &amp; = \gamma_0 B_0x_0 \\ h_1 &amp; = (\alpha_1 \gamma_0 + \beta_1)B_0x_0 + \gamma_1 B_1x_1 \\ h_2 &amp; = \alpha_2(\alpha_1 \gamma_0 + \beta_1)B_0x_0 + (\alpha_2 \gamma_1 + \beta_2)B_1x_1 + \gamma_2 B_2x_2 \\ ... \\ h_T &amp; = \alpha_{T\dots2}(\alpha_1 \gamma_0 + \beta_1)B_0x_0 + \ldots + \gamma_T B_Tx_T \end{aligned}\] <p>, we can express the output as a matrix operation</p> \[\small \begin{bmatrix} y_0 \\ y_1 \\ y_2 \\ \vdots \end{bmatrix} = \left( \begin{bmatrix} \gamma_0 &amp; &amp; \\ (\gamma_0\alpha_1 + \beta_1) &amp; \gamma_1 &amp; \\ \alpha_2(\gamma_0\alpha_1 + \beta_1) &amp; (\gamma_1\alpha_2+\beta_2) &amp; \gamma_2 \\ \vdots &amp; &amp; &amp; \ddots \\ \end{bmatrix} \odot \begin{bmatrix} C_0^\top B_0 &amp; &amp; &amp; \\ C_1^\top B_0 &amp; C_1^\top B_1 &amp; &amp; \\ C_2^\top B_0 &amp; C_2^\top B_1 &amp; C_2^\top B_2 &amp; \\ \vdots &amp; &amp; &amp; \ddots \\ \end{bmatrix} \right) \begin{bmatrix} x_0 \\ x_1 \\ x_2 \\ \vdots \end{bmatrix}\] <p>which can further be decomposed into the product of a 1-semiseparable matrix (Mamba-2’s decay mask) and a 2-band matrix.</p> <blockquote> <p>The equivalence established between the recurrent and parallel forms of Mamba-3 is another instance of what Mamba-2’s SSD established: that certain classes of SSMs have a matrix form that vectorizes the recurrence. This forms the <strong>foundation of the hardware-efficient algorithm</strong> used for training.</p> </blockquote> <h2 id="complex-valued-ssm">Complex-valued SSM</h2> <p>As we’ve mentioned, the simplification of the SSM over the past couple years to improve efficiency has reduced the abilities of newer models<d-footnote>We've mainly been highlighting the simplification of diagonal to identity times scalar across the LTV systems of Mamba-1 to Mamba-2, but the original LTI SSMs were actually complex-valued! Mamba-1 simplified the SSM to be all real-valued, which empirically did not impact language modeling, but as we will see, reduced state-tracking capabilities. </d-footnote>. 
This has been corroborated by a whole host of work which finds that linear RNN-style models are theoretically constrained on state-tracking tasks by their lack of non-linearity between timesteps and their structured matrix transitions, both of which, unfortunately, are critical to their efficient computation <d-cite key="grazzi2025unlockingstatetrackinglinearrnns"></d-cite><d-cite key="merrill2025illusionstatestatespacemodels"></d-cite><d-cite key="cirone2025theoreticalfoundationsdeepselective"></d-cite><d-cite key="Hahn_2024"></d-cite>.</p> <p>While more complex state transitions, such as diagonal plus low-rank (DPLR) in older LTI SSM models, can improve the method’s expressivity, the simplification of transitions across iterations of LTV SSMs has resulted in even the simplest state-tracking tasks falling out-of-reach for Mamba-style models <d-footnote>Related delta-rule based linear attention models, e.g., GDN, KDA<d-cite key="kimiteam2025kimilinearexpressiveefficient"></d-cite>, are able to partially mitigate these state-tracking limitations through more expressive state transitions, i.e., identity plus low-rank. </d-footnote>. The inability to solve some of these state-tracking synthetics may signal poor performance in practice where models might need to keep track of parentheses and diffs for coding or actions and states throughout a story.</p> <h3 id="parity-what-is-that">Parity, what is that?</h3> <p>One of the simplest tasks, <strong>parity</strong>, determining whether the sum of a sequence of 0’s and 1’s is even, is unsolvable by Mamba models in a constant number of layers. The ideal solution requires the hidden state to track whether the running sum is even or odd and to alternate depending on the next input, modeling a simple two-state automaton <d-cite key="grazzi2025unlockingstatetrackinglinearrnns"></d-cite>. While this seems simple enough, current Mamba models constrain the transition $\bar{A}_t \in [0,1]$, which <em>forces</em> the model to learn the naive solution: add all the values together then mod 2 <d-footnote>If $\bar{A}_t$ can be -1, it would enable the alternating solution, but would require the underlying implementation to be rewritten as currently the $\bar{A}$ is handled in log-space. </d-footnote>. This does work for shorter sequences but quickly becomes infeasible when the sequence outgrows the state.</p> <p>But, parity and other modulo tasks can be solved with <strong>rotations</strong>! The way one can visualize how rotations solve modulo $m$ problems is that one has some 2D vector that can be rotated around the origin. The entire possible angle distribution $[0, 2\pi]$ is partitioned into $m$ sections, and the vector is rotated by $\tfrac{2\pi}{m}$ to align with the current running modulo remainder.</p> <h3 id="representing-with-real-valued-ssms">Representing with Real-valued SSMs</h3> <p>Working with complex values in computer systems is quite a pain due to their multiplicative interactions. Luckily for us, <strong>diagonal complex-valued continuous SSMs <em>can</em> be represented as discretized real-valued SSMs</strong> (without any additional approximation loss compared to standard discretization).</p> <p>While the full proof can be found within our paper, the general intuition is to expand the original $N$-sized state complex-valued SSM to a $2N$-sized real-valued where each complex-valued dimension is split into its real and imaginary counterpart, and the complex transition matrix is partitioned into its scaling and oscillatory portions. 
With the commutative property of the scaling and oscillatory components in effect (due to the diagonal structure of the underlying matrix), we can map the continuous diagonal complex transition into a block-diagonal, scaled rotation transformation.</p> <details><summary>Converting general complex SSMs to real-valued SSMs</summary> <p>It is also possible to convert a general, unstructured complex SSM transition matrix to a real-valued SSM, though the scaled rotation intuition breaks down. The conversion still doubles the state size, but while the expansion of the $B, C, x$ will remain similar, the transition matrix will no longer be as simple. With an unstructured transition matrix $\mathbf{A} + i\Theta$, the exponential (resulting from our integrating factor technique) cannot be factored as</p> \[\exp(\Delta(\mathbf{A} + i\Theta)) \neq \exp(\Delta\mathbf{A}) \exp(i\Delta\Theta)\] <p>since $\mathbf{A},\Theta$ generally do not commute, unlike the diagonal case. Consequently, while a real-valued equivalent exists, computing it would require the expensive full matrix exponential.</p> </details> <p>Eventually, under the prior exponential-Euler discretization, we obtain the following recurrence</p> \[\begin{aligned} h_t &amp;= e^{\Delta_t A_t} \underbrace{ \begin{bmatrix} \cos(\Delta_t \theta_t) &amp; -\sin(\Delta_t \theta_t) \\ \sin(\Delta_t \theta_t) &amp; \cos(\Delta_t \theta_t) \end{bmatrix}}_{\vphantom{\Big|}R_t} h_{t-1} + \Delta_t B_t x_t \end{aligned}\] <p>for an $N=2$ state. For larger states, the rotation matrix $R_t$ is block-diagonal and the $\theta$’s can differ.</p> <h3 id="efficient-implementation-with-rope-trick">Efficient Implementation with RoPE Trick</h3> <p>Great, now we’ve shown that we can implement a complex SSM without having to explicitly model the imaginary components! But another issue remains: the rotation of the hidden state requires us to reimplement the kernels to incorporate this new type of transition — more moving parts, rotating the entire hidden state, etc., — seems like quite the hassle <em>ugh</em>… or is it?</p> <p>Luckily for us, given the structure of $A$, we can sidestep all of this and directly adjust the $B, C$ to achieve the same goal. This is because the output for timestep $t$ can be modeled as</p> \[y_t = C^\top_t\bar{B}_t + \cdots + C^\top_t(\bar{A}R)^\times_{t\cdots 1}\bar{B}_0\] <p>Since $\bar{A}$ is a scaled identity matrix, we can ignore the $\bar{A}$ terms for now by absorbing them into $C$. This results in the term $C_i^\top R_i \cdots R_{j+1} \bar{B}_j$ which can be represented by $\left(R_i \cdots R_0 C_i\right)^\top \left(R_j \cdots R_0 \bar{B}_j\right)$. The $\bar{A}$ terms can be reintroduced at this point. Thus, it’s apparent that the rotations can be embedded into the $B, C$ terms prior to performing the SSM recurrence instead of directly adjusting the transition matrix.</p> <blockquote> <p>The application of our data-dependent rotations onto $B, C$ can be done efficiently. Instead of performing numerous matrix multiplications, we can run a cumulative sum over the $\theta$’s and perform the efficient realization of rotary matrix multiplication from the RoFormer paper <d-cite key="su2023roformerenhancedtransformerrotary"></d-cite>, which itself used data-independent rotations. 
This inspired us to call the use of a vanilla SSM to compute a complex SSM as the “RoPE trick.”</p> </blockquote> <p>The RoPE trick extends to our exponential-trapezoidal recurrence, and we empirically validate that our complex-valued SSM is able to solve state-tracking tasks previously too hard for prior Mambas.</p> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/2026-03-16-mamba-3/state-tracking-480.webp 480w,/assets/img/2026-03-16-mamba-3/state-tracking-800.webp 800w,/assets/img/2026-03-16-mamba-3/state-tracking-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/2026-03-16-mamba-3/state-tracking.png" width="100%" height="auto" title="Complex vs Real Mamba on State-Tracking" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> <figcaption class="caption">Converting the SSM from real to complex-valued gives the model the capability to solve parity and other state-tracking tasks.</figcaption> </figure> <h2 id="multi-input-multi-output">Multi-Input, Multi-Output</h2> <p>The compute paradigm in scaling LLMs has shifted from training to inference in the past two years or so. Nowadays, more and more compute is dedicated to the actual deployment and usage of these models, and to some degree, the writing was on the wall. Emergent properties such as chain-of-thought and in-context learning techniques dramatically improved performance of earlier models by making them think longer and process more tokens, and all of the best models to date are reasoners which have (almost certainly) been post-trained with reinforcement learning using large rollout budgets. With the advent of agentic workflows, we have agents spawning subagents and so forth.</p> <blockquote> <p>What does such a paradigm change mean for hardware efficiency?</p> </blockquote> <h3 id="working-while-memory-bound">Working while Memory-Bound</h3> <p>Compared to the <strong>compute-bound</strong> nature of training, the deployment of the same models, especially decoding, is <strong>memory-bound</strong>. Throughout training, the hardware is constantly performing operations, but during decoding, the compute units of the <em>hardware sit idle</em> for large swathes of time as it waits for data to be moved across different levels of memory hierarchy.</p> <p>As an example of why this happens, think about a simple MLP. During training, the entire sequence is processed, but during decode, only the current token is processed as the past tokens are cached. The latency spent moving the MLP weights is around the same for both training and decode, but can be amortized a lot better over more computation under a training regime.</p> <p>So with current linear models where the state update and output calculation can be performed in <strong>constant time</strong>, compute units sit idle for most of the time, and we are bottlenecked by simply moving data back and forth! One way to estimate how “hard” the hardware is working is through <strong>arithmetic intensity</strong>, a ratio of compute performed to memory moved.</p> <p>Let’s analyze how SSMs are deployed in practice and their arithmetic intensity. 
A typical SSM, say Mamba-2 for instance, is organized into heads with head dimension $P$, where a single head is composed of $P$ SISO SSMs that share the same $a_t, B_t, C_t$:</p> \[\begin{align*} \mathbf{h}_t &amp; = a_t \mathbf{h}_{t-1} + B_t \mathbf{x}_t \\ \mathbf{y}_t &amp;= C_t^\top \mathbf{h}_t \end{align*}\] <p>where $a_t$ is a scalar decay and $\mathbf{x}_t, \mathbf{y}_t \in \mathbb{R}^{P}, \mathbf{h}_t \in \mathbb{R}^{N\times P}$. If we use 2-byte data, for a single decode step, the total memory traffic is $2(1 + 2N + P + NP)$ when accounting for all SSM parameters. The movement of the hidden state is clearly the main contributing factor at reasonable values of $P$ and $N$.</p> <p>When calculating the number of FLOPs used for the same operation, we get around $5NP - P$ <d-footnote>A quick rundown of why. The $a_th_{t-1}$ scaling, $B_tx_t$ outer product, and their summation take a total of 3NP. The matmul between $C_t^\top h_t$ takes $2N - 1$ per $P$ dimension ($N$ multiplications and $N-1$ accumulations), resulting in a final $5NP - P$.</d-footnote>. Thus, default SSM decoding has an arithmetic intensity of around $2.5$. To put this into context, the arithmetic intensity of matmuls for an H100 is around $300$ ops per byte; anything above this is compute-bound. Having an arithmetic intensity as low as $2.5$ means that decoding is squarely memory-bound… <em>yikes</em></p> <p>Since we pay for the entire rack and the expensive tensor cores, how can we use as many of them as possible?</p> <h3 id="the-mimo-system">The MIMO System</h3> <details><summary>Q: How does one increase a ratio?</summary> <p>A: Either increase the numerator or decrease the denominator.</p> </details> <p>We’ve seen empirically that the state size is quite important for performance but expanding it also increases memory… so, let’s keep that the same. Now how do we increase the compute required for calculating the hidden state recurrence while maintaining the same hidden state? Referring back to our state space/control theory toolbox, <strong>multi-input, multi-output (MIMO)</strong> SSMs can be used instead of the single-input, single-output (SISO) SSMs we’ve been using.</p> <p>Through the expansion of the dimension of $\mathcal{C}_t, \mathcal{B}_t$ to $N \times R$ and $\mathbf{x}_t, \mathbf{y}_t$ to $P \times R$ where $R$ is the rank of the system, we can maintain similar memory traffic (for small enough $R$) while increasing the FLOPs utilized by operating with matrix multiplications $\mathcal{B}_t \mathbf{x}_t^\top$ instead of outer products.</p> \[\begin{aligned} \mathbf{h}_{t} &amp;= a_t \mathbf{h}_{t-1} + \mathcal{B}_t \mathbf{x}_t^\top \\ \mathbf{y}_t &amp;= \mathcal{C}_t^\top \mathbf{h}_t \end{aligned}\] <p>The total FLOP count thus increases to $4NPR + NP - PR$, which <strong>results in an arithmetic intensity that scales with $O(R)$</strong> when $R \ll P, N$ (generally the case, as $P=64, N=128$ and $R=4$).</p> <h3 id="intuition-and-training">Intuition and Training</h3> <blockquote> <p>The downstream gains and comparable decoding latency associated with switching from SISO to MIMO require compute costs that scale linearly with $R$ during compute-bound training.</p> </blockquote> <p>Expressing the output of a MIMO SSM would require $R^2$ SISO SSMs due to its rank $R$ state-input and $R$ unique outputs. Its hidden state can be partitioned into the sum of $R$ SISO hidden states, and subsequently, the hidden state needs to be instantiated $R$ times for each of the outputs. 
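</p> <p>Before turning to the training side, it may help to collect the decode-step accounting from the previous subsection into a few lines. This is just a back-of-the-envelope sketch: the FLOP count is the one quoted above, while the way the traffic term grows with $R$ is our own rough assumption extending the $2(1 + 2N + P + NP)$ count.</p> <pre><code class="language-python">def decode_arithmetic_intensity(N=128, P=64, R=1):
    # FLOPs per decode step: 4*N*P*R + N*P - P*R (reduces to 5NP - P when R = 1).
    flops = 4 * N * P * R + N * P - P * R
    # bf16 traffic in bytes; scaling the B, C, x, y terms by R is an assumption
    # extending the SISO count 2 * (1 + 2N + P + NP) quoted in the text.
    bytes_moved = 2 * (1 + 2 * N * R + P * R + N * P)
    return flops / bytes_moved

print(round(decode_arithmetic_intensity(R=1), 2))  # ~2.4 ops/byte: squarely memory-bound
print(round(decode_arithmetic_intensity(R=4), 2))  # ~7.3 ops/byte: grows roughly with R
</code></pre> <p>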
But if the expressivity gain is effectively $R^2$, as argued above, how does the required training compute scale by <em>only $R$</em>?</p> <p>The <strong>chunked training algorithm</strong> is the reason for this disparity. Most linear models, including both Mamba-2 and Mamba-3, are computed in a chunked fashion, where the sequence is partitioned into chunks of size $C$. The hidden state is aggregated across chunks in a <em>sequential</em> manner, while the output of the SSM is calculated with a quadratic, <em>parallel</em> algorithm.</p> <p>For MIMO, the computation of outputs between chunks increases by a factor of $R$, whereas the computation of outputs within each chunk increases by $R^2$. So by decreasing the chunk size to $\tfrac{C}{R}$, the total FLOP count only increases by a factor of $R$. Our paper covers the actual FLOP calculations, but one way to think about it is that we want to reduce the amount of compute required for each quadratic algorithmic pass by increasing the number of times we call it.</p> <h3 id="instantiation">Instantiation</h3> <p>Given the interpretation of MIMO SSMs as multiple SISO ones, improvements introduced for vanilla SSMs, like our exponential-trapezoidal discretization and complex-valued transition, can be directly applied to our MIMO variant. However, the conversion must be done carefully without drastically increasing the total parameter count. The naive solution of expanding the projection size would lead to an $R\times$ increase, as the SSM inputs $x, B, C$ would all need to be adjusted. The subsequent rank $R$ output $Y$ would also force the output gate $Z$ and output projection to expand as well. This approach is clearly untenable.</p> <p>Instead, we can use Mamba’s multi-value attention structure to our advantage. Since the $B, C$ are tied across all heads, we can increase the projection size without much issue, resulting in a fairly negligible $DN \to DNR$ increase for the entire layer. However, the input $x$, output $y$, and gate $Z$ are unique per head and are the main source of parameters, and thus cannot be increased in such a way. Instead, we keep the original projections and then expand each dimension of the projected value to size $R$ by element-wise scaling with a learnable, data-independent vector. For each head, we are able to reduce the parameter count from $DPR$ to $DP + PR$, which is quite the reduction given the number of heads each Mamba layer has!</p> <p>We show that our instantiation <strong>balances the expressivity of multi-input, multi-output systems and parameter efficiency</strong>. In parameter-matched settings, our Mamba-3 MIMO variant further improves the already strong performance of regular Mamba-3 at all scales we tested on. 
When analyzing state size (proxy for decoding speed) to performance in controlled experiments, Mamba-3 sets the Pareto front compared to prior Mamba-2, able to achieve comparable performance with half the state size.</p> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/2026-03-16-mamba-3/mamba-pareto-480.webp 480w,/assets/img/2026-03-16-mamba-3/mamba-pareto-800.webp 800w,/assets/img/2026-03-16-mamba-3/mamba-pareto-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/2026-03-16-mamba-3/mamba-pareto.png" width="100%" height="auto" title="Mamba State Size to Performance" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> <figcaption class="caption">When analyzing the Pareto frontier of state size (a good proxy for decoding speed) to performance, Mamba-3 dominates prior Mamba-2. The MIMO variant of Mamba-3 pushes performance further without increasing state size at all.</figcaption> </figure> <blockquote> <p>Mamba-3 offers a faster model with the same quality or a better model for the same speed.</p> </blockquote> <h2 id="the-end-for-now">The End, For Now</h2> <p>We had to cut a bunch of proofs and results to keep the content in here digestible, but if that interests you, please do read our paper!</p> <p>Within our work, we’ve aimed at boosting the performance and capabilities of the Mamba series from a few SSM-centric improvements. We’re curious to see how and where the community explores in architecture research. In particular, we are quite excited (and think addressing them would be really impactful) in the following directions:</p> <ul> <li> <p><strong>Building better hybrids</strong>: It’s been amazing to see the general research community and industry labs appreciate the benefits linear models can provide, especially with hybrid models <d-cite key="kimiteam2025kimilinearexpressiveefficient"></d-cite><d-cite key="qwen3technicalreport"></d-cite><d-cite key="tencenthunyuanteam2025hunyuanturbosadvancinglargelanguage"></d-cite>. Most architectures follow an interleaved structure, but the “science” of what enables good linear-self attention synergy is still unknown. We’ve seen a lot of cool work making important ground, e.g., shifting from RoPE to NoPE for attention layers <d-cite key="yang2025ropenopeagainnew"></d-cite> or keeping the first and last layers attention-free <d-cite key="waleffe2024empiricalstudymambabasedlanguage"></d-cite> A more meta question might be: <em>are interleaved hybrid models truly the best way to utilize linear models?</em></p> </li> <li> <p><strong>Improving Layer Primitives</strong>: Our methodological improvements, while most natural to SSMs, can be applied to other architectures. It would be interesting to see how they scale under different transition mechanisms. In addition, there seems to be a whole trove of untapped improvements waiting to be uncovered or inspired in the “classics,” if you will. Just as Mamba and other SSMs are grounded within signal processing and traditional state space literature, such parallels can be found in other types of linear models — fast-weight programmers for linear attention <d-cite key="schlag2021lineartransformerssecretlyfast"></d-cite>, for example. 
<em>What might the standard transition look like for the best self-attention alternative in two, three years?</em></p> </li> </ul>]]></content><author><name>Aakash Lahoti*</name></author><summary type="html"><![CDATA[This series is cross-posted at GoombaLab]]></summary></entry><entry><title type="html">FlashAttention-4: Algorithm and Kernel Pipelining Co-Design for Asymmetric Hardware Scaling</title><link href="tridao.github.io/blog/2026/flash4/" rel="alternate" type="text/html" title="FlashAttention-4: Algorithm and Kernel Pipelining Co-Design for Asymmetric Hardware Scaling"/><published>2026-03-05T00:00:00+00:00</published><updated>2026-03-05T00:00:00+00:00</updated><id>tridao.github.io/blog/2026/flash4</id><content type="html" xml:base="tridao.github.io/blog/2026/flash4/"><![CDATA[<p>[<a href="https://github.com/Dao-AILab/flash-attention/blob/main/assets/fa4_paper.pdf">Paper</a>] [<a href="https://github.com/Dao-AILab/flash-attention">Code</a>]</p> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/2026-03-05-flash4/h100_vs_b200-480.webp 480w,/assets/img/2026-03-05-flash4/h100_vs_b200-800.webp 800w,/assets/img/2026-03-05-flash4/h100_vs_b200-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/2026-03-05-flash4/h100_vs_b200.png" width="100%" height="auto" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> <p>Modern accelerators like Blackwell GPUs continue the trend of asymmetric hardware scaling, where tensor core throughput grows far faster than other resources such as shared memory bandwidth, special function units (SFUs) for transcendental operations like exponential, and general-purpose integer and floating-point ALUs. From the Hopper H100 to the Blackwell B200, for instance, BF16 tensor core throughput increases from 1 to 2.25 PFLOPs, while both the SFU count and shared memory bandwidth remains unchanged.</p> <p>This scaling asymmetry has profound implications for optimizing complex kernels like attention for the Blackwell architecture. At its core, attention comprises two GEMMs ($S = Q \cdot K^T$ and $O = P \cdot V$) with softmax in-between; in practice, it also involves substantial plumbing and bookkeeping: data movement, synchronization, layout transforms, element-wise ops, scheduling, masking, etc.</p> <p>A naive viewpoint on attention might be that the speed of the GEMMs completely controls the kernel performance and one can effectively disregard these other attention components, at least to first order. However, doing a “feeds and speeds” analysis for B200 in fact shows the opposite: the main performance bottleneck lies not in how fast the tensor cores can do MMA, but rather (a) in the SFU units for softmax exponential during the FWD computation, and (b) in the shared-memory traffic during the BWD computation.</p> <p>In this blog post, we present FlashAttention-4, an algorithm and kernel co-design that maximizes overlap between matmul and these other resource bottlenecks. 
On B200 with BF16, it reaches up to 1605 TFLOPs/s (71% utilization), up to 1.3x faster than cuDNN version 9.13 and 2.7x faster than Triton.</p> <p>Our main algorithmic and kernel co-design ideas are as follows:</p> <ol> <li> <p><strong>New pipelining for maximum overlap</strong>: New forward and backward software pipelines that exploit Blackwell fully asynchronous MMA and larger tile sizes, overlapping tensor cores, softmax exponential, and memory operations.</p> </li> <li> <p><strong>Forward (FWD) pass</strong>: A software emulation of the exponential function implemented via polynomial approximation on FMA units to mitigate the exponential bottleneck, plus conditional online softmax rescaling.</p> </li> <li> <p><strong>Backward (BWD) pass</strong>: Storing intermediate results in tensor memory to relieve shared-memory traffic, combined with Blackwell’s new 2-CTA MMA mode to reduce shared memory traffic further and also cut atomic reduction in half, and additional support for deterministic execution mode for reproducible training.</p> </li> <li> <p><strong>Scheduling</strong>: New tile scheduler to mitigate load imbalance from causal mask and variable sequence length.</p> </li> </ol> <h2 id="new-hardware-features-on-blackwell">New hardware features on Blackwell</h2> <p><strong>Tensor memory (TMEM)</strong>: On B200, each of the 148 SMs has 256 KB of TMEM, an on-chip scratchpad wired into the tensor cores for warp-synchronous intermediate storage.</p> <p><strong>Fully asynchronous 5th gen tensor cores</strong>: <code class="language-plaintext highlighter-rouge">tcgen05.mma</code> is asynchronous and accumulates in TMEM. For BF16 and FP16, the largest single CTA UMMA tile is 128x256x16, which is about 2x larger than the largest Hopper WGMMA atom. UMMA is launched by a single thread, easing register pressure and making larger tiles and deeper pipelines practical without the spilling pain points of Hopper warpgroup MMA. This also makes warp specialization more viable, with some warps moving tiles while others issue MMA to overlap matrix multiply accumulate with softmax and memory traffic. <code class="language-plaintext highlighter-rouge">tcgen05.mma</code> can also source operand A from TMEM.</p> <p><strong>2-CTA MMA</strong>: Blackwell can execute one UMMA across a CTA pair in the same cluster, spanning the TMEM of both peer CTAs. One thread in the leader CTA launches the MMA, but both CTAs must stay active while it is in flight. This scales the MMA tile dimension up to 256x256x16 by splitting M and N across the pair, reducing redundant traffic and lowering per-CTA footprint. 
The CTA group size, 1 or 2, must remain constant across TMEM and tensor core operations within a kernel.</p> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/2026-03-05-flash4/2cta_mma-480.webp 480w,/assets/img/2026-03-05-flash4/2cta_mma-800.webp 800w,/assets/img/2026-03-05-flash4/2cta_mma-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/2026-03-05-flash4/2cta_mma.png" width="100%" height="auto" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> <h2 id="feeds-and-speeds">Feeds and Speeds</h2> <p>For M=N=D=128, here are the feeds on B200 (per SM):</p> <ul> <li><strong>Tensor Cores (BF16)</strong>: 8192 ops/cycle</li> <li><strong>Exponential unit</strong>: 16 ops/cycle</li> <li><strong>Shared Memory traffic</strong>: 128 bytes/cycle</li> </ul> <p>And the speeds (clock-cycles per tile):</p> <p><strong>Forward (2 MMAs + MN exp)</strong>:</p> <ul> <li>Tensor Cores: 1024</li> <li>Exp: 1024</li> <li>SMEM: 768</li> </ul> <p><strong>Backward (5 MMAs + MN exp) — 1-CTA</strong>:</p> <ul> <li>Tensor Cores: 2560</li> <li>Exp: 1024</li> <li>SMEM: 3328</li> </ul> <p><strong>Takeaway</strong>: Forward is bottlenecked by compute and exponential, backward is bottlenecked by shared memory bandwidth. So we overlap softmax with MMA in the forward pass and reduce shared memory traffic in the backward pass.</p> <h2 id="forward-pass-new-softmax-pipelining-with-conditional-rescaling">Forward pass: New softmax pipelining with conditional rescaling</h2> <p>The forward pass has two matmuls, <code class="language-plaintext highlighter-rouge">QK^T</code> and <code class="language-plaintext highlighter-rouge">PV</code>. On Blackwell, tensor cores got much faster, but the exponential unit (MUFU.EX2) did not. So softmax is no longer “just the thing between the two matmuls” — it is a bottleneck that must be carefully pipelined.</p> <p>The FWD pass in short:</p> <ul> <li><strong>Ping-pong schedule</strong> 2x Q and 2x O tiles per CTA: maximize overlap between MMA and Softmax</li> <li><strong>2x softmax warpgroups</strong>: per-tile softmax with synchronization to not overlap when computing exponential</li> <li><strong>Software emulation of $2^x$</strong>: distribute exp computation across hardware’s MUFU and software emulated on FMA</li> <li><strong>Store P in TMEM in stages</strong>: mitigate register pressure</li> <li><strong>Correction warpgroup</strong>: designated “correction” warpgroup to perform rescaling to remove from critical path</li> <li><strong>Online softmax (conditional) rescaling</strong>: Rescale less frequently to minimize non-matmul operations</li> </ul> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/2026-03-05-flash4/fa4_fwd_pipeline-480.webp 480w,/assets/img/2026-03-05-flash4/fa4_fwd_pipeline-800.webp 800w,/assets/img/2026-03-05-flash4/fa4_fwd_pipeline-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/2026-03-05-flash4/fa4_fwd_pipeline.png" width="100%" height="auto" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> <h3 id="pipeline-ping-pong-q-tiles-plus-a-dedicated-correction-stage">Pipeline: Ping-pong Q tiles plus a dedicated correction stage</h3> <p>FlashAttention-4 computes two query tiles per CTA — $Q^H$ and $Q^L$ — each covering 128 query tokens, and alternates them in a ping-pong schedule.</p> <p>Blackwell changes the softmax mapping. 
The accumulator tile for <code class="language-plaintext highlighter-rouge">S = QK^T</code> is 128x128 and lives in tensor memory; however, upon being read into registers, we have one thread per row for the partitioning of the tile as dictated by the hardware. We use two 128-thread warpgroups, one per Q tile, and each softmax warpgroup executes the following sequence of operations:</p> <ol> <li>Each thread loads one 128-element row of <code class="language-plaintext highlighter-rouge">S</code> from tensor memory into registers</li> <li>Reduce <code class="language-plaintext highlighter-rouge">rowmax</code> and <code class="language-plaintext highlighter-rouge">rowsum</code></li> <li>Using a tunable parameter, decide which portion of the 128 elements uses hardware’s MUFU vs. software-emulated $e^x$</li> <li>Compute <code class="language-plaintext highlighter-rouge">P = softmax(S)</code> and convert to BF16 precision</li> <li>Store <code class="language-plaintext highlighter-rouge">P</code> back to tensor memory in stages to relieve register pressure (as opposed to holding 128 elements of S and 64 BF16 elements of P simultaneously)</li> <li>Trigger the corresponding <code class="language-plaintext highlighter-rouge">PV</code> matmul as soon as a 3/4 chunk of <code class="language-plaintext highlighter-rouge">P</code> is stored</li> </ol> <p>The critical detail is that exp is the bottlenecked section. We explicitly synchronize the two softmax warpgroups so they do not evaluate exp at the same time, thereby reducing MUFU contention.</p> <p>To keep rescaling off the critical path, the kernel assigns it to a dedicated warpgroup. The correction warpgroup computes:</p> <p>Only rescale when the max jump is large:</p> \[O_j = \begin{cases}\exp(m_{j-1}-m_j)\,O_{j-1} + \exp(S_j-m_j)\,V_j, &amp; \text{if } m_j - m_{j-1} &gt; \tau,\\O_{j-1} + \exp(S_j-m_{j-1})\,V_j, &amp; \text{otherwise.}\end{cases}\] <p>Apply the final normalization at the end of the iteration $O_{\text{final}} = \frac{O}{l_{\text{final}}}$.</p> <p>At the end we still normalize using the true final statistics, so skipping small rescale steps preserves the final output while deleting many vector computations from the critical path. We make the decision at warp granularity to avoid divergence.</p> <h3 id="faster-exponential-distribute-2x-across-mufuex2-and-fma">Faster exponential: Distribute 2^x across MUFU.EX2 and FMA</h3> <p>Softmax requires many exponentials, and MUFU throughput is much lower than tensor core throughput. FlashAttention-4 increases effective exp throughput by running the software emulation of <code class="language-plaintext highlighter-rouge">exp2</code> alongside the hardware <code class="language-plaintext highlighter-rouge">MUFU.EX2</code> path, using FMA units that would otherwise be underutilized.</p> <p><strong>Range-reduction (Cody-Waite)</strong>: We use the classical technique of Cody-Waite range reduction to decompose the exponential computation into the integer and the fractional part: $2^x = 2^{n} \cdot 2^{f}$. 
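</p> <p>The remaining steps (the fractional-part polynomial and the exponent packing) are spelled out below; as one combined, illustrative sketch, a scalar Python emulation of the idea (not the actual FMA-unit kernel code) looks roughly like this:</p> <pre><code class="language-python">import math
import struct

def exp2_emulated(x):
    # Illustrative scalar emulation of 2**x in float32.
    n = math.floor(x)                 # integer part, folded into the exponent field
    f = x - n                         # fractional part in [0, 1)
    # Cubic fit of 2**f on [0, 1); coefficients as quoted below (Sollya-fitted).
    p0, p1, p2, p3 = 1.0, 0.6951, 0.2276, 0.0771
    poly = p0 + f * (p1 + f * (p2 + f * p3))      # Horner form: three FMAs
    # Exponent update: add n to the biased exponent of poly's float32 bit pattern.
    bits = struct.unpack("I", struct.pack("f", poly))[0]
    bits = bits + n * (2 ** 23)
    return struct.unpack("f", struct.pack("I", bits))[0]

print(exp2_emulated(-3.7), 2.0 ** -3.7)   # both come out around 0.0770
</code></pre> <p>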
In IEEE 754 float32, scaling by $2^n$ is just an exponent update.</p> <p><strong>Polynomial approximation of $2^{x_\text{frac}}$ (Horner’s Method)</strong>: To approximate $2^f$ we rewrite in Horner’s form for efficient evaluation.</p> \[2^{x_{\text{frac}}} \approx p_0 + p_1 x_{\text{frac}} + p_2 x_{\text{frac}}^{2} + p_3 x_{\text{frac}}^{3}\] <p>The coefficients <code class="language-plaintext highlighter-rouge">p0 = 1.0</code>, <code class="language-plaintext highlighter-rouge">p1 ≈ 0.6951</code>, <code class="language-plaintext highlighter-rouge">p2 ≈ 0.2276</code>, <code class="language-plaintext highlighter-rouge">p3 ≈ 0.0771</code> are chosen using the Sollya software package to minimize the relative approximation error over $[0, 1)$.</p> <p><strong>Exponent bits shift and add</strong>: The final step is to combine the integer part $n$ and the fractional approximation $2^f$ to form $2^{x} \approx 2^{n}\cdot 2^{f}$. Since $2^f \in [1,2)$ has float32 exponent 127, multiplying by $2^{n}$ is just shifting the integer $n$ into the exponent field and then adding the mantissa bits of $2^{f}$.</p> <h2 id="backward-pass-where-shared-memory-traffic-dominates">Backward pass: Where shared memory traffic dominates</h2> <p>Optimizing FlashAttention backward can feel like stuffing an oversized rug into a room: flatten one corner and another pops up. Backward computes about 2.5x the tensor core work of the forward pass, chaining five MMA operations to recompute S and run the QK and PV gradient MMAs for dQ, dK, dP, and dV, plus the element-wise work for P and dS. On Blackwell, FLOPs are not the limiter for backward; shared memory bandwidth is.</p> <h3 id="pipeline-overlap-mmas-with-softmax">Pipeline: Overlap MMAs with softmax</h3> <p>Hopper-era FlashAttention-3 keeps MMA accumulators in registers, so register pressure often forces a more serial schedule. On Blackwell, accumulators live in TMEM, which makes it practical to keep multiple MMAs in flight while the CUDA cores handle the element-wise work for P and dS. Since exponential throughput is comparable to two MMAs in our roofline, hiding it is worth it.</p> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/2026-03-05-flash4/fa4_bwd_pipeline-480.webp 480w,/assets/img/2026-03-05-flash4/fa4_bwd_pipeline-800.webp 800w,/assets/img/2026-03-05-flash4/fa4_bwd_pipeline-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/2026-03-05-flash4/fa4_bwd_pipeline.png" width="100%" height="auto" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> <p>The key overlap is simple: while we compute softmax for tile j, we already issue the dK and dQ MMAs for tile j-1.</p> <p>To reduce shared memory traffic, the backward pass recomputes S and P in a transposed tile relative to the forward pass, so the intermediate is already $S^T$ and $P^T$. 
We can then store $P^T$ (and later $dS^T$) directly in TMEM in the exact operand A layout consumed by the dV and dK MMAs respectively.</p> <p>TMEM cannot hold five full accumulators and intermediates at once, so FA4 reuses TMEM columns across stages: S and P share one set of columns, and dP, dS, and dQ share another.</p> <h3 id="2-cta-backward-pass-reducing-shared-memory-traffic-and-global-atomic-adds">2-CTA backward pass: Reducing shared memory traffic and global atomic adds</h3> <p><strong>Shared memory traffic.</strong> Even with the improved pipeline and with two of the ten GEMM operands kept in tensor memory, the backward pass is still limited by shared memory bandwidth. We mitigate this with Blackwell 2-CTA MMA mode, which partitions the output accumulator across the CTA pair. With M=256 and N=K=128, the two CTAs cooperate as one tile: each CTA stages half of operand B and keeps only its own accumulator slice. This roughly halves shared memory traffic for operand B.</p> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/2026-03-05-flash4/2cta_bwd_tiles-480.webp 480w,/assets/img/2026-03-05-flash4/2cta_bwd_tiles-800.webp 800w,/assets/img/2026-03-05-flash4/2cta_bwd_tiles-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/2026-03-05-flash4/2cta_bwd_tiles.png" width="100%" height="auto" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> <p><strong>Reduction axis conflict.</strong> We use M=256 and N=K=128 MMA tile across the five backward GEMMs to cut B traffic, but the nature of dQ MMA introduces a mismatch. In FlashAttention backward, each CTA owns a fixed KV tile (outer loop parallelized across N CTAs) and iterates over M tiles in the inner loop. The dQ update reduces over the KV sequence in the outer loop. 2-CTA MMA splits the output tile, not the reduction, and the dQ reduction dimension is N, which is already split across the CTA pair. Each CTA still needs the full reduction for the rows it owns.</p> <p><strong>Solution: DSMEM exchange.</strong> We resolve this by exchanging half of dS between the two CTAs using distributed shared memory within the cluster. This repacks dS so it is partitioned along the non-reduction axis: each CTA owns M/2 rows while holding the full 2N reduction. The per-CTA dQ MMA becomes (M/2, 2N)(2N, d), accumulating an (M/2, d) tile in tensor memory. In 2-CTA mode, the S, dP, dV, and dK MMAs keep M=256, while dQ uses M=128 with doubled reduction 2N=256. We then reorder the pipeline to hide DSMEM latency: compute dP for the current tile before computing dQ for the previous tile. Since the dQ tile fits in TMEM alongside P, it can reuse the TMEM region used for S, so dP and dQ no longer share a region as in 1-CTA mode. With this ordering, element-wise dS for the current tile overlaps with the dQ MMA from the previous iteration.</p> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/2026-03-05-flash4/2cta_dq_dsmem-480.webp 480w,/assets/img/2026-03-05-flash4/2cta_dq_dsmem-800.webp 800w,/assets/img/2026-03-05-flash4/2cta_dq_dsmem-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/2026-03-05-flash4/2cta_dq_dsmem.png" width="100%" height="auto" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> <p><strong>dQ atomic adds.</strong> As a side benefit, the dQ decomposition halves the number of global atomic reductions. 
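</p> <p>The following small NumPy sketch shows the repartitioning at the level of whole tiles (which half of dS each CTA starts with is our assumption; the tile sizes follow the text): each CTA begins with a split along the reduction axis, the DSMEM exchange leaves it with M/2 full rows instead, and its dQ contribution, and hence its atomic-add footprint, shrinks to an (M/2, d) tile.</p> <pre><code class="language-python">
import numpy as np

rng = np.random.default_rng(0)
M, N, d = 256, 128, 128                                   # tile sizes from the text
dS = rng.standard_normal((M, 2 * N), dtype=np.float32)    # logical dS for the CTA pair
K = rng.standard_normal((2 * N, d), dtype=np.float32)     # the pair's KV tile

# Before the exchange: split along the reduction axis (each CTA computed one N-half).
dS_cta0, dS_cta1 = dS[:, :N], dS[:, N:]

# DSMEM exchange: repack so each CTA owns M/2 full rows (the non-reduction axis).
dS_rows0 = np.concatenate([dS_cta0[: M // 2], dS_cta1[: M // 2]], axis=1)
dS_rows1 = np.concatenate([dS_cta0[M // 2 :], dS_cta1[M // 2 :]], axis=1)

# Each CTA now runs one (M/2, 2N) x (2N, d) MMA and atomically adds only M/2 rows of dQ.
dQ0, dQ1 = dS_rows0 @ K, dS_rows1 @ K
assert np.allclose(np.concatenate([dQ0, dQ1]), dS @ K, atol=1e-3)
</code></pre> <p>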
Atomics are nondeterministic and expensive, and they occur in every inner-loop iteration. Consequently, in the 2-CTA backward pass each CTA writes only half of the dQ tile and performs half as many global atomic reductions as the 1-CTA counterpart.</p> <h3 id="deterministic-mode">Deterministic mode</h3> <p>The source of nondeterminism is the global atomic accumulation for dQ. FA4 provides a deterministic mode that serializes the global reductions with a semaphore-style lock and memory fence to enforce a fixed accumulation order. However, determinism does not have to mean “everything stops.” FA4 reduces lock contention with CTA swizzling, and uses a shortest-processing-time-first (SPT) ordering for causal masking to reduce stalls. In practice, deterministic backward reaches up to about 85-90% of the nondeterministic throughput in our benchmarks.</p> <h2 id="scheduling">Scheduling</h2> <p>Causal masking and variable sequence length make attention load-imbalanced because different worktiles have different mainloop lengths, so FA4 improves grid linearization and applies longest-processing-time-first (LPT) scheduling to reduce the tail. In fact, these ideas are non-specific to Blackwell or any particular GPU architecture, and we also use them in FA3.</p> <p>For causal masking, the standard (mblocks, heads, batches) grid order suboptimally processes tiles from shortest to longest, so FA4 swizzles batch-heads into L2-sized sections and traverses the grid by batch-head section, iterating mblocks in reverse order and then the batch-heads within each section.</p> <p>For variable sequence length, since different batches involve different amounts of work, the given batch-processing order is typically suboptimal from the point of view of the LPT scheduling heuristic. To rectify this, we can launch a preprocessing kernel that sorts batches by maximum per-worktile execution time and writes a virtual-to-actual batch index mapping that the attention kernel uses to traverse batches in sorted order; moreover, the metadata can be cached so that sorting adds no performance loss.</p> <h2 id="language-and-framework-cute-dsl">Language and framework: CuTe-DSL</h2> <p>FA4 is implemented entirely in CuTe-DSL, CUTLASS’ Python kernel DSL. Kernels are written in Python; the DSL lowers to PTX, then the CUDA toolkit compiles to GPU machine code. The programming model mirrors CuTe/CUTLASS abstractions with a PTX escape hatch, while cutting compile times by ~20-30x vs C++ templates.</p> <h2 id="attention-benchmarks">Attention Benchmarks</h2> <p>We show results for FlashAttention-4 on B200 (BF16) and compare it to FlashAttention-2, as well as to implementations in Triton, Gluon, and cuDNN. For cuDNN, we compare against cuDNN 9.13 and the latest version, 9.19.1.2. Starting with versions 9.13 and 9.14, we have worked with the cuDNN team to incorporate some techniques from FlashAttention-4 into cuDNN, so that our work can benefit as many practitioners as possible.</p> <p>For the forward pass, FlashAttention-4 is 1.1-1.3x faster than cuDNN 9.13 and 2.1-2.7x faster than Triton. 
For the backward pass, FlashAttention-4 consistently outperforms the other baselines for large sequence lengths.</p> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/2026-03-05-flash4/fa4_fwd_causalFalse_hdim128-480.webp 480w,/assets/img/2026-03-05-flash4/fa4_fwd_causalFalse_hdim128-800.webp 800w,/assets/img/2026-03-05-flash4/fa4_fwd_causalFalse_hdim128-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/2026-03-05-flash4/fa4_fwd_causalFalse_hdim128.png" width="100%" height="auto" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/2026-03-05-flash4/fa4_fwd_causalTrue_hdim128-480.webp 480w,/assets/img/2026-03-05-flash4/fa4_fwd_causalTrue_hdim128-800.webp 800w,/assets/img/2026-03-05-flash4/fa4_fwd_causalTrue_hdim128-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/2026-03-05-flash4/fa4_fwd_causalTrue_hdim128.png" width="100%" height="auto" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/2026-03-05-flash4/fa4_bwd_causalFalse_hdim128-480.webp 480w,/assets/img/2026-03-05-flash4/fa4_bwd_causalFalse_hdim128-800.webp 800w,/assets/img/2026-03-05-flash4/fa4_bwd_causalFalse_hdim128-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/2026-03-05-flash4/fa4_bwd_causalFalse_hdim128.png" width="100%" height="auto" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/2026-03-05-flash4/fa4_bwd_causalTrue_hdim128-480.webp 480w,/assets/img/2026-03-05-flash4/fa4_bwd_causalTrue_hdim128-800.webp 800w,/assets/img/2026-03-05-flash4/fa4_bwd_causalTrue_hdim128-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/2026-03-05-flash4/fa4_bwd_causalTrue_hdim128.png" width="100%" height="auto" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> <p>Since our initial code release 8 months ago, it’s been fun collaborating with the cuDNN and CUTLASS teams at NVIDIA. Newer versions of cuDNN have now implemented many of the optimizations here, and latest cuDNN offers similar perf to FA4.</p> <h2 id="acknowledgements">Acknowledgements</h2> <p>We thank Together AI, Meta, xAI, and Princeton Language and Intelligence (PLI) for compute support. 
We want to further thank the following teams at NVIDIA: cuDNN, TensorRT-LLM, and CUTLASS teams for constant discussions, ideas, and feedback.</p>]]></content><author><name>Ted Zadouri</name></author><summary type="html"><![CDATA[[Paper] [Code]]]></summary></entry><entry><title type="html">FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision</title><link href="tridao.github.io/blog/2024/flash3/" rel="alternate" type="text/html" title="FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision"/><published>2024-07-11T00:00:00+00:00</published><updated>2024-07-11T00:00:00+00:00</updated><id>tridao.github.io/blog/2024/flash3</id><content type="html" xml:base="tridao.github.io/blog/2024/flash3/"><![CDATA[<p>[<a href="https://arxiv.org/abs/2407.08608">Paper</a>] [<a href="https://github.com/Dao-AILab/flash-attention">Code</a>]</p> <p>Attention, as a core layer of the ubiquitous Transformer architecture, is a bottleneck for large language models and long-context applications. FlashAttention (and FlashAttention-2) pioneered an approach to speed up attention on GPUs by minimizing memory reads/writes, and is now used by most <a href="https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html">libraries</a> to accelerate Transformer training and inference. This has contributed to a massive increase in LLM context length in the last two years, from 2-4K (GPT-3, OPT) to 128K (GPT-4), or even 1M (<a href="https://huggingface.co/gradientai/Llama-3-8B-Instruct-Gradient-1048k">Llama 3</a>). However, despite its success, FlashAttention has yet to take advantage of new capabilities in modern hardware, with FlashAttention-2 achieving only 35% utilization of theoretical max FLOPs on the H100 GPU. In this blogpost, we describe three main techniques to speed up attention on Hopper GPUs: exploiting asynchrony of the Tensor Cores and TMA to (1) overlap overall computation and data movement via warp-specialization and (2) interleave block-wise matmul and softmax operations, and (3) incoherent processing that leverages hardware support for FP8 low-precision.</p> <p>We’re excited to release FlashAttention-3 that incorporates these techniques. It’s 1.5-2.0x faster than FlashAttention-2 with FP16, up to 740 TFLOPS, i.e., 75% utilization of H100 theoretical max FLOPS. With FP8, FlashAttention-3 reaches close to 1.2 PFLOPS, with 2.6x smaller error than baseline FP8 attention.</p> <p>The improvements from FlashAttention-3 will result in:</p> <ol> <li><strong>More efficient GPU Utilization</strong>: The new technique uses up to 75% of an H100 GPU’s maximum capabilities, up from just 35% before. This results in significantly (1.5-2x) faster than previous versions for training and running of large language models (LLMs).</li> <li><strong>Better performance with lower precision</strong>: FlashAttention-3 can work with lower precision numbers (FP8) while maintaining accuracy. This allows for even faster processing and potentially lower memory usage, which could lead to cost savings and improved efficiency for customers running large-scale AI operations.</li> <li><strong>Ability to use longer context in LLMs</strong>: By speeding up the attention mechanism, FlashAttention-3 enables AI models to work with much longer pieces of text more efficiently. 
This could allow for applications that can understand and generate longer, more complex content without slowing down.</li> </ol> <p>FlashAttention-3 is available at: <a href="https://github.com/Dao-AILab/flash-attention">https://github.com/Dao-AILab/flash-attention</a></p> <h2 id="flashattention-recap">FlashAttention Recap</h2> <p><a href="https://arxiv.org/abs/2205.14135">FlashAttention</a> is an algorithm that reorders the attention computation and leverages tiling and recomputation to significantly speed it up and reduce memory usage from quadratic to linear in sequence length. We use tiling to load blocks of inputs from HBM (GPU memory) to SRAM (fast cache), perform attention with respect to that block, and update the output in HBM. By not writing the large intermediate attention matrices to HBM, we reduce the amount of memory reads/writes, which brings 2-4x wallclock time speedup.</p> <p>Here we show a diagram of FlashAttention forward pass: with tiling and softmax rescaling, we operate by blocks and avoid having to read/write from HBM, while obtaining the correct output with no approximation.</p> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/2024-07-11-flash3/flash_recap_diagram-480.webp 480w,/assets/img/2024-07-11-flash3/flash_recap_diagram-800.webp 800w,/assets/img/2024-07-11-flash3/flash_recap_diagram-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/2024-07-11-flash3/flash_recap_diagram.png" width="100%" height="auto" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> <h2 id="new-hardware-features-on-hopper-gpus---wgmma-tma-fp8">New hardware features on Hopper GPUs - WGMMA, TMA, FP8</h2> <p>While FlashAttention-2 can achieve up to 70% theoretical max FLOPS on Ampere (A100) GPUs, it does not yet take advantage of new features on Hopper GPUs to maximize performance. We describe some of the new Hopper-specific features here, and why they are important.</p> <ol> <li>WGMMA (Warpgroup Matrix Multiply-Accumulate). This new feature makes use of the new Tensor Cores on Hopper, with much higher throughput<d-footnote>Without the wgmma instruction, the older mma.sync instruction can only reach about 2/3 the peak throughput of Hopper Tensor Cores: https://arxiv.org/abs/2402.13499v1.</d-footnote> than the older mma.sync instruction in Ampere (image from the <a href="https://resources.nvidia.com/en-us-tensor-core/gtc22-whitepaper-hopper?ncid=no-ncid">H100 white paper</a>).</li> </ol> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/2024-07-11-flash3/h100_wgmma-480.webp 480w,/assets/img/2024-07-11-flash3/h100_wgmma-800.webp 800w,/assets/img/2024-07-11-flash3/h100_wgmma-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/2024-07-11-flash3/h100_wgmma.png" width="100%" height="auto" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> <ol start="2"> <li>TMA (Tensor Memory Accelerator). This is a special hardware unit that accelerates the transfer of data between global memory and shared memory, taking care of all index calculation and out-of-bound predication. 
This frees up registers, which is a valuable resource to increase tile size and efficiency.</li> </ol> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/2024-07-11-flash3/h100_tma-480.webp 480w,/assets/img/2024-07-11-flash3/h100_tma-800.webp 800w,/assets/img/2024-07-11-flash3/h100_tma-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/2024-07-11-flash3/h100_tma.png" width="100%" height="auto" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> <ol start="3"> <li>Low-precision with FP8. This doubles the Tensor Core throughput (e.g. 989 TFLOPS with FP16 and 1978 TFLOPS with FP8), but trades off accuracy by using fewer bits to represent floating point numbers.</li> </ol> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/2024-07-11-flash3/h100_wgmma_fp8-480.webp 480w,/assets/img/2024-07-11-flash3/h100_wgmma_fp8-800.webp 800w,/assets/img/2024-07-11-flash3/h100_wgmma_fp8-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/2024-07-11-flash3/h100_wgmma_fp8.png" width="100%" height="auto" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> <p>FlashAttention-3 makes use of all of these new features of Hopper, using powerful abstractions from <a href="https://github.com/NVIDIA/cutlass">NVIDIA’s CUTLASS</a> library.</p> <p>Several work such as <a href="https://github.com/HazyResearch/ThunderKittens">ThunderKitten</a><d-footnote>ThunderKitten also presents an elegant tile-based abstraction for writing fast kernels, you should definite check that out</d-footnote> and <a href="https://developer.nvidia.com/blog/accelerating-transformers-with-nvidia-cudnn-9/">cuDNN 9</a> has already shown that these new hardware features can speedup attention computation. By rewriting FlashAttention to use these new features, we can already significantly speed it up (e.g., from 350 TFLOPS in FlashAttention-2 FP16 forward pass to around 540-570 TFLOPS). However, the asynchronous nature of the new instructions on Hopper (WGMMA and TMA) opens up additional algorithmic opportunities to overlap operations and thereby extract even greater performance. For this blogpost, we’ll explain two such techniques specific to attention. The generic technique of warp specialization, with separate producer and consumer warps doing TMA and WGMMA, is <a href="https://github.com/NVIDIA/cutlass/blob/main/media/docs/efficient_gemm.md#warp-specialization">well-covered elsewhere</a> in the context of GEMM and works the same here.</p> <h2 id="asynchrony-overlapping-gemm-and-softmax">Asynchrony: Overlapping GEMM and Softmax</h2> <h3 id="why-overlap">Why overlap?</h3> <p>Attention has GEMMs (those matmuls between Q and K and between attention probability P and V) and softmax as its two main operations. Why do we need to overlap them? Isn’t most of the FLOPS in the GEMMs anyway? As long as the GEMMs are fast (e.g., computed using WGMMA instructions), shouldn’t the <a href="https://horace.io/brrr_intro.html">GPU be going brrrr</a>?</p> <p>The problem is that non-matmul operations are much slower than matmul operations on modern accelerators. Special functions such as exponential (for the softmax) have even lower throughput than floating point multiply-add; they are evaluated by the multi-function unit, a unit separate from floating point multiply-add or matrix multiply-add. 
As an example, the H100 GPU SXM5 has 989 TFLOPS of FP16 matrix multiply, but only 3.9 TFLOPS (256x less throughput) for special functions<d-footnote>The CUDA programming guide specifies that the throughput for special functions is 16 operations per streaming multiprocessor (SM) per clock cycle. We multiply 16 by 132 SMs and 1830 Mhz (clock speed used to calculate 989 TFLOPS of FP16 matmul) to get 3.9 TFLOPS</d-footnote>! For head dimension 128, there are 512x more matmul FLOPS than exponential, which means that exponential can take 50% of the time compared to matmul. The situation is even worse for FP8, where the matmul FLOPS are twice as fast yet exponential FLOPS stay the same speed. Ideally we want matmul and softmax to operate in parallel. While the Tensor Cores are busy with matmul, the multi-function units should be calculating exponential!</p> <h3 id="inter-warpgroup-overlapping-with-pingpong-scheduling">Inter-warpgroup overlapping with pingpong scheduling</h3> <p>The first and easiest way to overlap GEMM and softmax is to do nothing at all! The warp schedulers already try to schedule warps so that if some warps are blocked (e.g., waiting for GEMM results), other warps can run. That is, the warp schedulers do some of this overlapping for us, for free.</p> <p>However, we can improve on this by doing some of the scheduling manually. As an example, if we have 2 warpgroups (labeled 1 and 2 – each warpgroup is a group of 4 warps), we can use synchronization barriers (bar.sync) so that warpgroup 1 first does its GEMMs (e.g., GEMM1 of one iteration and GEMM0 of the next iteration), and then warpgroup 2 does its GEMMs while warpgroup 1 does its softmax, and so on. This “pingpong” schedule is illustrated in the figure below, where the same color denotes the same iteration.</p> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/2024-07-11-flash3/pingpong_pipelining-480.webp 480w,/assets/img/2024-07-11-flash3/pingpong_pipelining-800.webp 800w,/assets/img/2024-07-11-flash3/pingpong_pipelining-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/2024-07-11-flash3/pingpong_pipelining.png" width="100%" height="auto" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> <p>This would allow us to perform the softmax in the shadow of the GEMMs of the other warpgroup. Of course, this figure is just a caricature; in practice the scheduling is not really this clean. Nevertheless, pingpong scheduling can improve FP16 attention forward pass from around 570 TFLOPS to 620 TFLOPS (head dim 128, seqlen 8K).</p> <h3 id="intra-warpgroup-overlapping-of-gemm-and-softmax">Intra-warpgroup overlapping of GEMM and Softmax</h3> <p>Even within one warpgroup, we can have some part of softmax running while the GEMMs of that warpgroup is running. 
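</p> <p>As a Python stand-in for the loop structure (placeholder callables rather than real WGMMA or softmax code), the 2-stage schedule amounts to issuing the GEMM that computes the next iteration’s scores before running the current iteration’s softmax:</p> <pre><code class="language-python">
# Python stand-in for the 2-stage schedule: the next iteration's score GEMM is issued
# before the current iteration's softmax, so on hardware the two can overlap.
def pipelined_mainloop(num_iters, issue_qk, wait_qk, softmax, issue_pv):
    issue_qk(0)                                   # prologue
    for j in range(num_iters):
        S_j = wait_qk(j)                          # scores for iteration j are ready
        if j + 1 != num_iters:
            issue_qk(j + 1)                       # in flight while softmax runs below
        P_j = softmax(S_j)                        # exp / row-max / row-sum work
        issue_pv(j, P_j)                          # O accumulation on the tensor cores

# Tiny trace with logging stubs, just to show the issue order:
log = []
pipelined_mainloop(
    3,
    issue_qk=lambda j: log.append(f"issue QK GEMM {j}"),
    wait_qk=lambda j: log.append(f"wait  QK GEMM {j}") or j,
    softmax=lambda s: log.append(f"softmax iter  {s}") or s,
    issue_pv=lambda j, p: log.append(f"issue PV GEMM {j}"),
)
print("\n".join(log))
</code></pre> <p>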
This is illustrated in this figure, where the same color denotes the same iteration.</p> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/2024-07-11-flash3/2_stage_pipelining-480.webp 480w,/assets/img/2024-07-11-flash3/2_stage_pipelining-800.webp 800w,/assets/img/2024-07-11-flash3/2_stage_pipelining-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/2024-07-11-flash3/2_stage_pipelining.png" width="100%" height="auto" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> <p>This pipelining increases throughput from around 620 TFLOPS to around 640-660 TFLOPS for FP16 attention forward, at the cost of higher register pressure. We need more registers to hold both accumulators of the GEMMs, and the input/output of softmax. Overall, we find this technique to offer a favorable tradeoff.</p> <h2 id="low-precision-reduce-quantization-error-with-incoherent-processing">Low-precision: reduce quantization error with incoherent processing</h2> <p>LLM activation can have <a href="https://arxiv.org/abs/2208.07339">outliers</a> with much larger magnitude than the rest of the features. These outliers make it difficult to quantize, producing much larger quantization errors. We leverage incoherent processing, a technique used in the quantization literature (e.g. from <a href="https://arxiv.org/abs/2307.13304">QuIP</a> and <a href="https://arxiv.org/abs/2402.04396">QuIP#</a>) that multiplies the query and key with a random orthogonal matrix to “spread out” the outliers and reduce quantization error. In particular, we use the Hadamard transform (with random signs), which can be done per attention head in O(d log d) instead of O(d^2) time, where d is the head dimension. Since the Hadamard transform is memory-bandwidth bound, it can be fused with previous operations such as rotary embedding (also memory-bandwidth bound) “for free”.</p> <p>In our experiment where Q, K, V are generated from a standard normal distribution but 0.1% of the entries have large magnitudes (to simulate outliers), we found that incoherent processing can reduce the quantization error by 2.6x. We show numerical error comparison in the table below. 
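</p> <p>The key property is easy to check numerically: because the sign-randomized, scaled Hadamard matrix is orthogonal, rotating Q and K leaves the attention scores unchanged while spreading any outlier across all coordinates. A small NumPy sketch (not the fused kernel):</p> <pre><code class="language-python">
import numpy as np

def randomized_hadamard(d, rng):
    """H_d times diag(random signs), scaled by 1/sqrt(d); orthogonal for d a power of 2."""
    H = np.array([[1.0]])
    while H.shape[0] != d:
        H = np.kron(H, np.array([[1.0, 1.0], [1.0, -1.0]]))
    return (H * rng.choice([-1.0, 1.0], size=d)) / np.sqrt(d)

rng = np.random.default_rng(0)
T, d = 64, 128
Q, K = rng.standard_normal((T, d)), rng.standard_normal((T, d))
Q[0, 0] = 50.0                                 # simulate an outlier feature
R = randomized_hadamard(d, rng)

# Attention scores are unchanged: (QR)(KR)^T = Q R R^T K^T = Q K^T.
assert np.allclose((Q @ R) @ (K @ R).T, Q @ K.T, atol=1e-6)

# The outlier's energy is now spread over all d coordinates, so quantizing
# Q and K to low precision sees a much smaller dynamic range.
print(np.abs(Q).max(), np.abs(Q @ R).max())    # roughly 50 vs. 5-6
</code></pre> <p>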
Please see the paper for details.</p> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/2024-07-11-flash3/flash3_numerical_error-480.webp 480w,/assets/img/2024-07-11-flash3/flash3_numerical_error-800.webp 800w,/assets/img/2024-07-11-flash3/flash3_numerical_error-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/2024-07-11-flash3/flash3_numerical_error.png" width="100%" height="auto" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> <h2 id="attention-benchmark">Attention Benchmark</h2> <p>We show some results with FlashAttention-3, and compare it to FlashAttention-2, as well as the implementation in Triton and cuDNN (both of which already use new hardware features of Hopper GPUs).</p> <p>For FP16, we see about 1.6x-2.0x speedup over FlashAttention-2.</p> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/2024-07-11-flash3/flash3_fp16_fwd-480.webp 480w,/assets/img/2024-07-11-flash3/flash3_fp16_fwd-800.webp 800w,/assets/img/2024-07-11-flash3/flash3_fp16_fwd-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/2024-07-11-flash3/flash3_fp16_fwd.png" width="100%" height="auto" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/2024-07-11-flash3/flash3_fp16_bwd-480.webp 480w,/assets/img/2024-07-11-flash3/flash3_fp16_bwd-800.webp 800w,/assets/img/2024-07-11-flash3/flash3_fp16_bwd-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/2024-07-11-flash3/flash3_fp16_bwd.png" width="100%" height="auto" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> <p>For FP8, we can reach close to 1.2 PFLOPS!</p> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/2024-07-11-flash3/flash3_fp8_fwd-480.webp 480w,/assets/img/2024-07-11-flash3/flash3_fp8_fwd-800.webp 800w,/assets/img/2024-07-11-flash3/flash3_fp8_fwd-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/2024-07-11-flash3/flash3_fp8_fwd.png" width="100%" height="auto" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> <h2 id="discussion">Discussion</h2> <p>This blogpost highlights some of the optimizations for FlashAttention available on Hopper GPUs. Other optimizations (e.g., variable length sequences, persistent kernel, and in-kernel transpose for FP8) are covered in the paper.</p> <p>We have seen that designing algorithms that take advantage of the hardware they run on can bring significant efficiency gains and unlock new model capabilities such as long context. We look forward to future work on optimization for LLM inference, as well as generalizing our techniques to other hardware architectures. 
We also look forward to FlashAttention-3 being integrated in a future release of PyTorch.</p>]]></content><author><name>Jay Shah</name></author><summary type="html"><![CDATA[[Paper] [Code]]]></summary></entry><entry><title type="html">State Space Duality (Mamba-2) Part I - The Model</title><link href="tridao.github.io/blog/2024/mamba2-part1-model/" rel="alternate" type="text/html" title="State Space Duality (Mamba-2) Part I - The Model"/><published>2024-05-31T00:00:00+00:00</published><updated>2024-05-31T00:00:00+00:00</updated><id>tridao.github.io/blog/2024/mamba2-part1-model</id><content type="html" xml:base="tridao.github.io/blog/2024/mamba2-part1-model/"><![CDATA[<figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/2024-05-31-mamba-2/mamba-2-V3-transparent-480.webp 480w,/assets/img/2024-05-31-mamba-2/mamba-2-V3-transparent-800.webp 800w,/assets/img/2024-05-31-mamba-2/mamba-2-V3-transparent-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/2024-05-31-mamba-2/mamba-2-V3-transparent.png" width="100%" height="auto" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> <p>[<a href="https://arxiv.org/abs/2405.21060">Paper</a>] [<a href="https://github.com/state-spaces/mamba">Code</a>]</p> <p><strong>This series is cross-posted at <a href="https://goombalab.github.io/blog/2024/mamba2-part1-model/">GoombaLab</a></strong></p> <ol> <li>Part I - The Model</li> <li><a href="/blog/2024/mamba2-part2-theory/">Part II - The Theory</a></li> <li><a href="/blog/2024/mamba2-part3-algorithm/">Part III - The Algorithm</a></li> <li><a href="/blog/2024/mamba2-part4-systems/">Part IV - The Systems</a></li> </ol> <p>Since the release of <a href="https://arxiv.org/abs/2312.00752">Mamba</a> 6 months ago, we’ve been pleasantly surprised by the overwhelming <a href="https://github.com/AvivBick/awesome-ssm-ml">community response</a>. It’s been incredibly gratifying to see the line of research on efficient sequence models we’ve been pursuing for years really resonate with the machine learning community and take off more than we could have anticipated. We’ve seen an enormous amount of exciting follow-up work, from direct applications (e.g. vision <d-cite key="zhu2024vision"></d-cite><d-cite key="ma2024u"></d-cite><d-cite key="liu2024vmamba"></d-cite>, genomics <d-cite key="schiff2024caduceus"></d-cite>, graphs <d-cite key="wang2024graph"></d-cite><d-cite key="behrouz2024graph"></d-cite>, and more) to understanding (e.g. on recall abilities <d-cite key="jelassi2024repeat"></d-cite>, in-context learning<d-cite key="akyurek2024context"></d-cite> <d-cite key="grazzi2024mamba"></d-cite> <d-cite key="park2024can"></d-cite>, and formal language expressivity <d-cite key="merrill2024illusion"></d-cite><d-cite key="sarrof2024expressive"></d-cite>), and an enormous number of <a href="https://jackcook.com/2024/02/23/mamba.html">online</a> <a href="https://srush.github.io/annotated-mamba/hard.html">blogs</a>, <a href="https://www.youtube.com/watch?v=dVH1dRoMPBc">tutorials</a>, <a href="https://www.youtube.com/watch?v=8Q_tqwpTpVU">and</a> <a href="https://www.youtube.com/watch?v=N6Piou4oYx8">videos</a>. 
We couldn’t be more excited about the direction of this research!</p> <p>Yet despite its potential so far, we weren’t completely satisfied with the first version of Mamba…</p> <h3 id="problem-1-understanding">Problem 1 (Understanding)</h3> <p>From a conceptual standpoint, one of the reasons we found SSMs so fascinating is how they just feel <em>fundamental</em>. One way this is exemplified is how they have rich ties to many major paradigms of sequence models. As developed in our earlier works on structured SSMs <d-cite key="gu2021combining"></d-cite><d-cite key="gu2023thesis"></d-cite>, they seem to capture the essence of continuous, convolutional, and recurrent sequence models – all wrapped up in a simple and elegant model.</p> <p>But of course, aside from these, there’s another major sequence model paradigm: variants of the ubiquitous <strong>attention</strong> mechanism<d-cite key="bahdanau2015neural"></d-cite><d-cite key="vaswani2017attention"></d-cite>. SSMs always felt somewhat disjoint from attention, and we’ve tried for a while to understand their relationship better.</p> <blockquote> <p>Question 1: <strong>What are the conceptual connections between state space models and attention?</strong> Can we combine them?</p> </blockquote> <h3 id="problem-2-efficiency">Problem 2 (Efficiency)</h3> <p>From a computational standpoint, despite the work that went into making Mamba fast (in particular, its hardware-aware selective scan implementation) it’s still much less hardware-efficient than mechanisms such as attention. The missing piece is that modern accelerators such as GPUs and TPUs are <em>highly</em> specialized for matrix multiplications. While this isn’t a problem for inference, which is bottlenecked by somewhat different considerations, this can be a big deal during training time.</p> <blockquote> <p>Question 2: <strong>Can we speed up the training of Mamba models by recasting them as matrix multiplications?</strong></p> </blockquote> <p>These are the main questions that Mamba-2 – in particular, its new state space model variant – tries to address.</p> <h2 id="the-ssd-model">The SSD Model</h2> <p>The main point of the Mamba-2 paper is what we call <strong>structured state space duality</strong> (SSD), which refers to several things:</p> <ol> <li>The <strong>SSD model</strong> refers to a specific standalone layer, like attention or an SSM, that can be incorporated into deep neural networks</li> <li>The <strong>SSD framework</strong> is a general framework for reasoning about this model (and many more theoretical connections)</li> <li>The <strong>SSD algorithm</strong> is an algorithm for computing SSD layers much more efficiently than previous SSMs</li> </ol> <p>The main SSD model or “state space dual model” itself really isn’t so complicated! 
In this first part of a series of blog posts, we’ll provide a self-contained description of the SSD layer (and Mamba-2) in isolation and how it compares to related models, particularly Mamba-1.</p> <p>In the next parts of this series, we’ll describe the general framework and theoretical connections, which aren’t necessary to actually use Mamba-2.</p> <h3 id="the-linear-ssm-mode">The Linear (SSM) Mode</h3> <p>SSD starts from the same set of equations as Mamba:</p> \[\begin{aligned} h_{t} &amp;= A_t h_{t-1} + B_t x_t \\ y_t &amp;= C_t^{\top} h_t \end{aligned}\] <p>\begin{equation} \label{eq:ssm} (\text{Selective state space model (SSM)}) \end{equation}</p> <p>To recap, a <strong>structured state space model (SSM)</strong> <d-cite key="gu2022efficiently"></d-cite><d-cite key="gu2023thesis"></d-cite> defines a map from $x \in \mathbb{R}^\mathtt{T} \to y \in \mathbb{R}^\mathtt{T}$. Think of $x_t$ and $y_t$ as being scalars, and the hidden state $h_t$ as an $\mathtt{N}$-dimensional vector, where $\mathtt{N}$ is an independent hyperparameter called the <em>state size, state dimension, or state expansion factor</em>.</p> <p>A <em>selective</em> state space model allows the $(A, B, C)$ SSM parameters to vary across time <d-cite key="gu2023mamba"></d-cite>. We’ll think of them as tensors with shapes $A \in \mathbb{R}^\mathtt{(T, N, N)}$, $B \in \mathbb{R}^\mathtt{(T, N)}$, and $C \in \mathbb{R}^\mathtt{(T, N)}$ respectively.<d-footnote>As with Mamba-1, we take everything over the reals $\mathbb{R}$, although complex variants as with other structured SSMs like the S4 lineage <d-cite key="gu2022efficiently"></d-cite> are also possible.</d-footnote></p> <p>Structured SSMs require $A$ to have structure to be efficiently computable, such as the most commonly used diagonal structure <d-cite key="gu2022parameterization"></d-cite><d-cite key="gupta2022diagonal"></d-cite><d-cite key="smith2023s5"></d-cite><d-cite key="gupta2022simplifying"></d-cite>. In this case $A$ has shape $\mathtt{(T, N)}$ where only the diagonal elements of the $\mathtt{N} \times \mathtt{N}$ matrices are stored.</p> <h4 id="ssd-scalar-structured-ssm">SSD: Scalar Structured SSM</h4> <p>The original Mamba (or more precisely its core “S6” layer) is exactly a selective SSM with diagonal structure.</p> <p><strong>The SSD layer of Mamba-2 makes only one small modification</strong>: it restricts the diagonal $A$ even further to a <em>scalar times identity</em> structure; in other words the diagonal elements of $A$ must all be the same value. In this case $A$ can be represented with shape just $\mathtt{(T)}$ and one can also identify $A_t$ as just a scalar (and so we’ll sometimes denote it $a_t$).</p> <h4 id="multihead-ssms">Multihead SSMs</h4> <p>Equation \eqref{eq:ssm} is defined only for a single dimensional input $x \in \mathbb{R}^\mathtt{T}$. If $X \in \mathbb{R}^\mathtt{(T, P)}$ has $\mathtt{P}$ separate channels, we can use the same dynamics (i.e. the same SSM $(A, B, C)$) independently for each channel. This can be interpreted as a <em>single head</em> of the SSM model.</p> <p>Here, we think of $X$ as a tensor of shape $\mathtt{(T, P)}$ where $\mathtt{T}$ is the sequence (time) dimension and $\mathtt{P}$ is the “head dimension”.<d-footnote>Normally there's an additional batch dimension $\mathtt{B}$ when implementing these models, which we'll ignore throughout this presentation.</d-footnote></p> <p>Multiple heads can be constructed completely independently; for the remainder of this post, we assume that we’re working with a single head. 
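</p> <p>For concreteness, here is a deliberately naive PyTorch sketch of one SSD head evaluated exactly as the recurrence above, with $a$ of shape $\mathtt{(T)}$, $B, C$ of shape $\mathtt{(T, N)}$, and $X$ of shape $\mathtt{(T, P)}$ (nothing here is optimized):</p> <pre><code class="language-python">
import torch

def ssd_head_recurrent(a, B, C, X):
    """One SSD head, evaluated step by step.
    a: (T,) scalar A_t per step; B, C: (T, N); X: (T, P). Returns Y: (T, P)."""
    T, N = B.shape
    P = X.shape[-1]
    h = X.new_zeros(N, P)                        # the head's state, size N x P
    Y = torch.empty(T, P, dtype=X.dtype)
    for t in range(T):
        h = a[t] * h + torch.outer(B[t], X[t])   # h_t = A_t h_{t-1} + B_t x_t
        Y[t] = C[t] @ h                          # y_t = C_t^T h_t
    return Y
</code></pre> <p>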
Note that these heads are exactly analogous to how heads in multi-head attention models work, and in Mamba-2 we also choose similar dimensions as modern Transformers, e.g. $\mathtt{P} = 64$ or $\mathtt{P}=128$. (To scale to larger model widths $\mathtt{D} = \mathtt{d\_model}$, we keep this fixed and increase the number of independent heads.)</p> <p>We can notate the general (selective) state space model as \begin{equation} \label{eq:ssm-transformation} Y^\mathtt{(T,P)} = \mathsf{SSM}(A^\mathtt{(T,…)}, B^\mathtt{(T,N)}, C^\mathtt{(T,N)})(X^\mathtt{(T,P)}) \end{equation}</p> <p>Some axes of variation include</p> <ol> <li>The structure on $A$, which affects its parameter shape: <ul> <li><code class="language-plaintext highlighter-rouge">... = (N,N)</code> for general (unstructured) SSMs</li> <li><code class="language-plaintext highlighter-rouge">... = (N)</code> for diagonal SSMs (or other structures, such as diagonal-plus-low-rank <d-cite key="gu2022efficiently"></d-cite>)</li> <li><code class="language-plaintext highlighter-rouge">... = ()</code> for scalar SSMs (i.e. SSD)</li> </ul> </li> <li>The state dimension $\mathtt{N}$ (i.e. <code class="language-plaintext highlighter-rouge">d_state</code>)</li> <li>The head dimension $\mathtt{P}$ (i.e. <code class="language-plaintext highlighter-rouge">d_head</code>)</li> </ol> <p>There are other axes of variation of structured SSMs (e.g. time-invariance vs. selectivity, SISO vs. MIMO<d-cite key="smith2023s5"></d-cite>, real vs. complex, etc.), but we’re highlighting these so that we can contrast Mamba-2 to Mamba-1 in just a second…</p> <h3 id="the-quadratic-attention-mode">The Quadratic (Attention) Mode</h3> <p>But first, let’s switch tacks and forget about state space models for a moment. Given the same tensors above with the same shapes $(A^\mathtt{(T)}, B^\mathtt{(T, N)}, C^\mathtt{(T, N)})$, let’s define a different object.</p> <p>First, we’ll define the following matrix (don’t worry, we’ll explain more and give it a name in Part II of this series!)</p> \[L = \begin{bmatrix} 1 &amp; \\ a_1 &amp; 1 &amp; \\ a_2a_1 &amp; a_2 &amp; 1 \\ \vdots &amp; \vdots &amp; \ddots &amp; \ddots \\ a_{\mathtt{T}-1}\dots a_1 &amp; a_{\mathtt{T}-1}\dots a_2 &amp; \dots &amp; a_{\mathtt{T}-1} &amp; 1 \\ \end{bmatrix} .\] <p>Then, let’s define the following matrix</p> <p>\begin{equation} \label{eq:ssd-attention} M = L \circ C B^\top \in \mathbb{R}^{\mathtt{(T,T)}} \end{equation}</p> <p>Finally, $M$ encodes a <em>sequence transformation</em> $x \in \mathbb{R}^\mathtt{T} \to y \in \mathbb{R}^\mathtt{T}$ mapping a 1D input to a 1D output—just as in equation \eqref{eq:ssm}—through basic matrix multiplication $y = Mx$.</p> <p>What’s special about this? Well, you may notice that it looks very similar to an attention computation. 
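</p> <p>As a quick preview of where this is going, the following toy PyTorch snippet builds $L$ and $M = L \circ C B^\top$ explicitly and checks that $y = Mx$ matches the recurrence \eqref{eq:ssm} from before:</p> <pre><code class="language-python">
import torch

torch.manual_seed(0)
T, N, P = 6, 4, 3
a = torch.rand(T)                                     # scalar a_t per step
B, C, X = torch.randn(T, N), torch.randn(T, N), torch.randn(T, P)

# Quadratic ("attention") mode: Y = (L * (C @ B.T)) @ X.
L = torch.zeros(T, T)
for i in range(T):
    for j in range(i + 1):
        L[i, j] = torch.prod(a[j + 1 : i + 1])        # empty product = 1 on the diagonal
Y_quad = (L * (C @ B.T)) @ X

# Linear (SSM) mode: h_t = a_t h_{t-1} + B_t x_t,  y_t = C_t^T h_t.
h, Y_lin = torch.zeros(N, P), torch.zeros(T, P)
for t in range(T):
    h = a[t] * h + torch.outer(B[t], X[t])
    Y_lin[t] = C[t] @ h

assert torch.allclose(Y_quad, Y_lin, atol=1e-5)       # same model, two algorithms
</code></pre> <p>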
In fact, if all $a_t = 1$, then $L$ is simply the lower-triangular <em>causal mask</em> and \eqref{eq:ssd-attention} is equivalent to <strong>causal linear attention</strong> <d-cite key="katharopoulos2020transformers"></d-cite>:</p> \[Y = (L \circ Q K^\top) V\] <p>This is exactly the same as equation \eqref{eq:ssd-attention} if we rename $(C, B, X) \mapsto (Q, K, V)$!</p> <h2 id="state-space-duality">State Space Duality</h2> <p>The so-called “duality” refers to the fact that the two models defined in equations \eqref{eq:ssm} (for the scalar-identity structured $A_t$ case) and \eqref{eq:ssd-attention} are actually <em>exactly the same model</em>, which we can view as a particular function</p> \[(A^\mathtt{(T)}, B^\mathtt{(T, N)}, C^\mathtt{(T, N)}, X^\mathtt{(T, P)}) \mapsto Y^\mathtt{(T, P)}\] <p>In the general <em>SSD Framework</em> (Part II of this series), we’ll show this equivalence in two completely different ways, both of which are actually much more general and each quite illuminating.</p> <p>If you take our word for it, though, then SSD is relatively simple to contrast in relation to either SSMs or attention.</p> <h3 id="ssd-vs-state-space-models">SSD vs. State Space Models</h3> <p>Compared to previous SSMs, SSD is pretty much the same as the core layer of Mamba but with even more structure on the recurrent $A$ matrices.</p> <ol> <li>Mamba-1 (S6) uses diagonal structure on $A$, while Mamba-2 (SSD) uses scalar-times-identity structure on $A$.</li> <li>Mamba-1 has a head dimension of $\mathtt{P}=1$ (i.e. all channels are completely independently controlled by separate SSMs), while Mamba-2 uses a head dimension of $\mathtt{P}&gt;1$ (something like $\mathtt{P}=64$ by default).</li> </ol> <p>In particular, this can be viewed as weight-tied in two ways:</p> <ul> <li>By restricting the diagonal structure of $A$ to scalar-times-identity, the recurrence dynamics are shared across all $\mathtt{N}$ elements of the state space.</li> <li>These dynamics are also shared across all $\mathtt{P}$ channels of a given head.</li> </ul> <p>In other words, a single SSM head has total state size $\mathtt{P} \times \mathtt{N}$, which are each governed by separate scalar recurrences in Mamba-1 but are controlled by a single shared recurrence in Mamba-2.</p> <p>Why make these restrictions? The main motivation is efficiency: these changes are necessary to be able to view the model in its [<a href="#the-quadratic-attention-mode">dual attention form</a>], which allows matrix multiplications to be used.</p> <blockquote class="block-tip"> <h4 id="the-bottom-line-mamba-1-vs-mamba-2">The Bottom Line: Mamba-1 vs. Mamba-2</h4> <p>Compared to Mamba-1, Mamba-2 allows <strong>much larger state dimensions</strong> (from <code class="language-plaintext highlighter-rouge">N=16</code> in Mamba-1 to <code class="language-plaintext highlighter-rouge">N=64</code> to <code class="language-plaintext highlighter-rouge">N=256</code> or even higher in Mamba-2) while simultaneously being <strong>much faster during training</strong>.</p> </blockquote> <p>But can this hurt us? There’s some intuition to believe that it shouldn’t. One of the main reasons for the selectivity (e.g. $A$ that depends on the input $X$) introduced in Mamba is to let the SSM be able to control whether to remember or ignore particular pieces of information; for example, if a filler “um” is encountered in a text transcript. 
But if such information should be ignored, then the entire state can ignore it together, and so it should be okay if the state’s dynamics are shared across all features.</p> <p>Empirically, we haven’t found evidence that the restricted expressivity of Mamba-2 might hurt, but the jury’s still out! From one perspective, Mamba-2 isn’t <em>strictly</em> better than Mamba-1: while it’s a dramatic improvement from a <em>training</em> perspective, Mamba-1 might be better from a pure <em>inference</em> perspective. Since inference speed of SSMs is entirely governed by the state dimension, if one wants to maximize performance for a target inference efficiency (i.e. for a particular state size $\mathtt{N}$), then the increased expressivity of Mamba-1 might be better. We haven’t fully analyzed the (theoretical or empirical) tradeoffs here, and think this would be a cool direction for the community to dig in more!</p> <h3 id="ssd-vs-attention">SSD vs. Attention</h3> <p>Compared, to standard (self-)attention, SSD also only has two differences:</p> <ol> <li>The softmax normalization is dropped.</li> <li>A separate elementwise mask matrix is applied multiplicatively.</li> </ol> <p>The first difference can be interpreted as what reduces the effective state size of the model from linear to constant, and improves its efficiency from quadratic to linear.</p> <p>The second difference is what distinguishes SSD from standard linear attention. One way to think of the mask is as <strong>input-dependent relative positional encodings</strong>. Because of the mask $L$ in \eqref{eq:ssd-attention}, the standard attention score $\langle Q_i, K_j \rangle$ is attenuated by a weight</p> \[a_{i:j}^\times = a_i \cdots a_{j+1}\] <p>which can be interpreted as a “discount factor” based on how far apart the positions $i$ and $j$ are. (This interpretation was concurrently espoused by Tobias Katsch’s <a href="https://arxiv.org/abs/2311.01927">GateLoop</a> paper<d-cite key="katsch2023gateloop"></d-cite>.) In its attention form, this input-dependent positional mask can be interpreted as the key factor that encodes the “selectivity” of Mamba!</p> <h2 id="best-of-both-worlds">Best of Both Worlds</h2> <p>So why do we care that there are two views of this model? Well, first of all, it’s extremely mathematically interesting, as we’ll cover in <a href="/blog/2024/mamba2-part2-theory/">Part II</a>, and we hope will inspire future directions. But there are immediate practical benefits too!</p> <h3 id="efficiency-the-ssm-and-attention-modes">Efficiency: the SSM and Attention Modes</h3> <p>The SSM \eqref{eq:ssm} and attention \eqref{eq:ssd-attention} modes represent two different ways of computing the same function, so let’s contrast them.</p> <p>First, remember that one main reason why SSMs are interesting to begin with is because computing \eqref{eq:ssm} as a recurrence requires maintaining a <em>constant-size state</em> (size $\mathtt{N}$ per channel) and scales <em>linearly in the sequence length</em> $\mathtt{T}$. The downside is that the raw FLOPs don’t reflect actual speed in practice because of hardware considerations…</p> <p>On the other hand, computing this sequence transformation $y = Mx$ through equation \eqref{eq:ssd-attention} takes quadratic time in the sequence length, because we’re materializing this $\mathtt{T} \times \mathtt{T}$ matrix. 
But it can be fast in practice because it only uses matrix multiplications, which are extremely optimized on GPUs and TPUs.</p> <h3 id="efficiency-the-ssd-mode">Efficiency: the SSD Mode</h3> <p>So if there are two equivalent ways of computing the same model, when should we use one mode or the other? During inference, there’s no trade-off: the SSM mode is designed for fast autoregressive inference. But what about training? Here there’s a tension between FLOPs and hardware efficiency where the attention mode uses more FLOPs, but uses them more efficiently through matrix multiplications.</p> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/2024-05-31-mamba-2/ssd_algorithm-480.webp 480w,/assets/img/2024-05-31-mamba-2/ssd_algorithm-800.webp 800w,/assets/img/2024-05-31-mamba-2/ssd_algorithm-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/2024-05-31-mamba-2/ssd_algorithm.png" width="100%" height="auto" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> <p>It turns out we can get the best of both worlds by combining the algorithms! There are two equivalent interpretations of this “state space dual” algorithm, either as</p> <ol> <li>A block decomposition of a particular structured matrix that defines the SSD “token-mixing” sequence transformation.</li> <li>A “chunkwise” algorithm that splits the sequence into segments, computes the quadratic attention form on each segment, and adjusts the result by passing the SSM states between segments.</li> </ol> <p>We’ll leave the details of this algorithm to <a href="/blog/2024/mamba2-part3-algorithm/">Part III</a> (or Section 6 of the <a href="https://arxiv.org/abs/2405.21060">full paper</a>), as it requires a bit of machinery from the theory to derive. 
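</p> <p>To convey the flavor without the theory, here is a deliberately simple toy version of the chunkwise idea (not the algorithm from the paper, and not efficient): use the quadratic form inside each chunk, and carry a single SSM state across chunk boundaries.</p> <pre><code class="language-python">
import torch

def ssd_chunked(a, B, C, X, chunk=16):
    """Toy chunkwise evaluation of one SSD head (quadratic inside chunks, SSM state
    across chunks). a: (T,), B, C: (T, N), X: (T, P); T must be a multiple of chunk."""
    T, N = B.shape
    P = X.shape[-1]
    Y = torch.zeros(T, P)
    h = torch.zeros(N, P)                              # state at the chunk boundary
    for s in range(0, T, chunk):
        ac, Bc, Cc, Xc = a[s:s+chunk], B[s:s+chunk], C[s:s+chunk], X[s:s+chunk]
        # contribution of all earlier chunks, decayed into this one
        decay_in = torch.cumprod(ac, dim=0)            # a_s, a_s a_{s+1}, ...
        Y[s:s+chunk] = decay_in[:, None] * (Cc @ h)
        # intra-chunk quadratic form with mask M[t, u] = a_{u+1} ... a_t
        Mc = torch.zeros(chunk, chunk)
        for t in range(chunk):
            for u in range(t + 1):
                Mc[t, u] = torch.prod(ac[u + 1 : t + 1])
        Y[s:s+chunk] += (Mc * (Cc @ Bc.T)) @ Xc
        # carry the state to the next chunk (recomputed sequentially here for clarity)
        for t in range(chunk):
            h = ac[t] * h + torch.outer(Bc[t], Xc[t])
    return Y

# check against the fully sequential recurrence
torch.manual_seed(0)
T, N, P = 64, 8, 4
a, B, C, X = torch.rand(T), torch.randn(T, N), torch.randn(T, N), torch.randn(T, P)
h, Y_ref = torch.zeros(N, P), torch.zeros(T, P)
for t in range(T):
    h = a[t] * h + torch.outer(B[t], X[t])
    Y_ref[t] = C[t] @ h
assert torch.allclose(ssd_chunked(a, B, C, X), Y_ref, atol=1e-5)
</code></pre> <p>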
But we do emphasize that the implementation of this algorithm isn’t too complicated – a minimal implementation that we provide is only ~30 lines of PyTorch!</p> <p>The benefits of the SSD algorithm is that it preserves the same efficient FLOP counts as SSMs (compared to quadratic attention), and also dramatically speeds up training compared to general state space models by utilizing matmuls.</p> <table> <thead> <tr> <th> </th> <th>Attention</th> <th>SSM</th> <th>SSD</th> </tr> </thead> <tbody> <tr> <td>State size</td> <td>$\mathrm{T}$</td> <td>$\mathbf{N}$</td> <td>$\mathbf{N}$</td> </tr> <tr> <td>Training FLOPs</td> <td>$\mathrm{T}^2\mathrm{N}$</td> <td>$\mathbf{TN^2}$</td> <td>$\mathbf{TN^2}$</td> </tr> <tr> <td>Inference FLOPs</td> <td>$\mathrm{T}\mathrm{N}$</td> <td>$\mathbf{N^2}$</td> <td>$\mathbf{N^2}$</td> </tr> <tr> <td>(Naive) memory</td> <td>$\mathrm{T}^2$</td> <td>$\mathrm{TN}^2$</td> <td>$\mathbf{TN}$</td> </tr> <tr> <td>Matrix multiplications?</td> <td>:heavy_check_mark:</td> <td>:x:</td> <td>:heavy_check_mark:</td> </tr> </tbody> </table> <h2 id="the-mamba-2-architecture">The Mamba-2 Architecture</h2> <p>Although the core contribution of Mamba-2 is the new SSD layer and theory, we also make some small changes to Mamba’s neural network architecture.</p> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/2024-05-31-mamba-2/architecture_2-480.webp 480w,/assets/img/2024-05-31-mamba-2/architecture_2-800.webp 800w,/assets/img/2024-05-31-mamba-2/architecture_2-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/2024-05-31-mamba-2/architecture_2.png" width="100%" height="auto" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> <p>The main change is producing the $(A, B, C)$ SSM parameters in parallel with the $X$ input, instead of sequentially. This is partly motivated by the connections to attention; but more pragmatically, it’s simpler and more amenable to scaling techniques such as tensor parallelism, which will be discussed in Part IV of this series!</p> <p>There are some other small differences which are covered in more detail in the paper. However, we do want to emphasize that these architectural changes aren’t really the main point of the model.</p> <h3 id="language-modeling">Language Modeling</h3> <p>In terms of empirical results, we didn’t test Mamba-2 as extensively as Mamba-1, but believe it should generally be on par or better across the board. Our full language model results use the same protocol as Mamba, and found slightly better scaling at Chinchilla laws <d-cite key="hoffmann2022empirical"></d-cite>.</p> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/2024-05-31-mamba-2/pile_8k_mamba2-480.webp 480w,/assets/img/2024-05-31-mamba-2/pile_8k_mamba2-800.webp 800w,/assets/img/2024-05-31-mamba-2/pile_8k_mamba2-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/2024-05-31-mamba-2/pile_8k_mamba2.png" width="100%" height="auto" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> <p>Fully trained models on the Pile dataset<d-cite key="pile"></d-cite> and the standard zero-shot downstream evaluations show similar trends. 
We emphasize that even when the performance is comparable, Mamba-2 is <em>much</em> faster to train than Mamba-1!</p> <h3 id="synthetic-language-modeling-mqar">Synthetic Language Modeling: MQAR</h3> <p>More interestingly, we highlight the one synthetic task we tried. Since the original Mamba paper, which investigated synthetics such as Synthetic Copying and Induction Heads, many follow-up works have begun investigating harder associative recall tasks. The <strong>multi-query associative recall (MQAR)</strong> task introduced by the Zoology and Based <d-cite key="arora2024zoology"></d-cite><d-cite key="arora2024simple"></d-cite> line of work has become a de facto standard.</p> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/2024-05-31-mamba-2/mqar-480.webp 480w,/assets/img/2024-05-31-mamba-2/mqar-800.webp 800w,/assets/img/2024-05-31-mamba-2/mqar-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/2024-05-31-mamba-2/mqar.png" width="100%" height="auto" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> <p>We ran a version of this task that’s much harder than the one usually reported in the literature, and found that Mamba-2 is substantially better than Mamba-1. One reason for the improved performance is the much larger state size (up to $16\times$ larger than Mamba-1 here), which was one of the primary motivations of Mamba-2 in the first place.</p> <p>Interestingly, Mamba-2 also appears to be noticeably better than Mamba-1 on this particular task even when the state size is controlled. We’re not quite sure why to be honest, and it would be great to ablate the other aspects of the model to investigate… for example, could it be possible that the [<a href="#ssd-vs-state-space-models">restricted structure of SSD</a>] is actually <em>helpful</em> here?</p> <h2 id="next-up">Next Up</h2> <p>In <a href="/blog/2024/mamba2-part2-theory/">the next part of this series</a>, we’ll go more into the full SSD framework, including how to prove the claimed “duality” of the SSD layer, and strong generalizations of it.</p>]]></content><author><name>Albert Gu</name></author><summary type="html"><![CDATA[]]></summary></entry><entry><title type="html">State Space Duality (Mamba-2) Part II - The Theory</title><link href="tridao.github.io/blog/2024/mamba2-part2-theory/" rel="alternate" type="text/html" title="State Space Duality (Mamba-2) Part II - The Theory"/><published>2024-05-31T00:00:00+00:00</published><updated>2024-05-31T00:00:00+00:00</updated><id>tridao.github.io/blog/2024/mamba2-part2-theory</id><content type="html" xml:base="tridao.github.io/blog/2024/mamba2-part2-theory/"><![CDATA[<ol> <li><a href="/blog/2024/mamba2-part1-model/">Part I - The Model</a></li> <li>Part II - The Theory</li> <li><a href="/blog/2024/mamba2-part3-algorithm/">Part III - The Algorithm</a></li> <li><a href="/blog/2024/mamba2-part4-systems/">Part IV - The Systems</a></li> </ol> <p>In <a href="/blog/2024/mamba2-part1-model/">Part I</a> of this series, we defined the state space dual (SSD) <em>model</em>. In isolation, this model is relatively simple to define, and we claimed that it can be computed either as an SSM recurrence or with an attention-like pattern. If you just want to use the model, feel free to skip this post!</p> <p>In this post, we’ll dive into the theory behind the model. We’ll derive the SSD “duality” in two completely separate ways, one starting from the SSM perspective and one from the attention perspective. 
Each method is actually much broader than the SSD model itself, and the union of these two strong generalizations is what we call the SSD <em>framework</em>. This framework provides a rich body of connections between state space models, attention, and structured matrices. While the SSD model can be viewed as a specific instantiation of each prong of the framework, the SSD framework is much more general and opens up many directions for future work.</p> <h4 id="the-state-space-duality-framework">The State Space Duality framework</h4> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/2024-05-31-mamba-2/ssd_venn-480.webp 480w,/assets/img/2024-05-31-mamba-2/ssd_venn-800.webp 800w,/assets/img/2024-05-31-mamba-2/ssd_venn-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/2024-05-31-mamba-2/ssd_venn.png" width="100%" height="auto" title="Structured State Space Duality" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> <figcaption class="caption">SSD Framework (red, blue): State space models (i.e. semiseparable matrices) and structured masked attention encapsulate large classes of efficient sequence models. Their intersection is the SSD model (purple).</figcaption> </figure> <p>For each of the two parts of this framework, we’ll</p> <ol> <li>Define the general concepts</li> <li>Show how the SSD model is an instantiation, and prove the duality</li> <li>Suggest future directions for how the framework can be used</li> </ol> <p>Note that this theory is <em>not necessary</em> to use the SSD model itself; this part of the series can be safely skipped for the practitioner that just wants to use SSD (Mamba-2).</p> <h2 id="recap-the-ssd-model">Recap: The SSD Model</h2> <p><a href="/blog/2024/mamba2-part1-model/">Part I</a> of this series introduced the SSD layer, which is defined as a selective SSM</p> \[\begin{aligned} h_{t} &amp;= A_t h_{t-1} + B_t x_t \\ y_t &amp;= C_t^{\top} h_t \end{aligned}\] <p>\begin{equation} \label{eq:ssm} (\text{Selective state space model (SSM)}) \end{equation}</p> <p>with scalar-identity structure on $A$.</p> <p>More formally, we view it as a <em>sequence transformation</em> $X \mapsto Y$</p> <p>\begin{equation} \label{eq:ssm-transformation} Y^\mathtt{(T,P)} = \mathsf{SSM}(A^\mathtt{(T)}, B^\mathtt{(T,N)}, C^\mathtt{(T,N)})(X^\mathtt{(T,P)}) \end{equation}</p> <p>The dual attention-like form of the SSD layer is</p> <p>\begin{equation} \label{eq:ssd-attention} M = L \circ C B^\top \in \mathbb{R}^{\mathtt{(T,T)}} \end{equation}</p> <p>Now let’s see how to prove this!</p> <h2 id="ssd-framework-1-structured-matrix-transformations">SSD Framework 1: Structured Matrix Transformations</h2> <p>The first framing of the duality will be from an SSM-centric perspective, where we’ll prove the duality through the framework of <strong>matrix sequence transformations</strong> or “matrix mixers”.</p> <h3 id="matrix-transformations">Matrix Transformations</h3> <p>The idea is that many sequence models, i.e. <em>sequence transformations</em> $X \in \mathbb{R}^\mathtt{(T,P)} \mapsto Y \in \mathbb{R}^\mathtt{(T,P)}$, can be written in the form of a single matrix multiplication $Y = M(X) \cdot X$ where $M$ is a matrix which can itself depend on $X$. We call this a <em>matrix sequence transformation</em>, or matrix transformation for short. In the literature sequence transformations have also been referred to as “sequence mixers” or “token mixers”, and matrix sequence transformations as “matrix mixers”. 
There are many examples of these, which are distinguished by the structure of the $M$ matrix. The de facto example is self-attention itself, where $M = \mathsf{softmax}(QK^\top)$ is the attention matrix. Other examples include MLP-Mixer<d-cite key="tolstikhin2021mlp"></d-cite>, FNet<d-cite key="lee2021fnet"></d-cite>, and Monarch Mixer<d-cite key="dao2022monarch"></d-cite><d-cite key="fu2024monarch"></d-cite>.</p> <p>Why do we care about these types of models?</p> <blockquote> <p>Writing a sequence model as a matrix transformation provides a powerful tool to understand the structure and characteristics of the model.</p> </blockquote> <p>And although general non-linear RNNs such as LSTMs <em>cannot</em> be written as matrix mixers, state space models can! In fact, this is pretty easy to see by just unrolling the definition of the SSM recurrence. The upshot is that the SSM \eqref{eq:ssm-transformation} can be written as a matrix transformation</p> \[Y = \mathsf{SSM}(A, B, C)(X) = MX\] <p>where $M_{ij} = 0$ for $i &lt; j$ (i.e. it’s lower triangular) and otherwise \begin{equation} \label{eq:semiseparable} M_{ij} = C_i^\top A_{i:j}^\times B_j := C_i^\top A_i \dots A_{j+1} B_j \end{equation}</p> <p>Drawing it out, this matrix looks like</p> \[\begin{bmatrix} C_0^\top B_0 &amp; \\ C_1^\top A_1 B_0 &amp; C_1^\top B_1 &amp; \\ C_2^\top A_2A_1 B_0 &amp; C_2^\top A_2 B_1 &amp; C_2^\top B_2 \\ \vdots &amp; \vdots &amp; \ddots &amp; \ddots \\ C_\mathtt{T}^\top A_{\mathtt{T}-1}\dots A_1 B_0 &amp; C_\mathtt{T}^\top A_{\mathtt{T}-1}\dots A_2 B_1 &amp; \dots &amp; C_\mathtt{T}^\top A_{\mathtt{T}-1} B_{\mathtt{T}-2} &amp; C_\mathtt{T}^\top B_{\mathtt{T}-1} \\ \end{bmatrix}\] <p>\begin{equation} \label{eq:ssm-matrix} (\text{Matrix Transformation Representation of State Space Models}) \end{equation}</p> <h3 id="semiseparable-matrices">Semiseparable Matrices</h3> <p>This type of matrix in fact has a name: it’s called a (triangular) <strong>semiseparable matrix</strong>, and has been studied in other fields of engineering and computational linear algebra<d-cite key="vandebril2005bibliography"></d-cite>. These matrices are (IMO) quite fundamental and beautiful, and the full paper talks about more of their properties. For example, an alternative characterization of semiseparable matrices is their <em>structured rank property</em>, which says that every submatrix contained in the lower-triangular portion is low rank.</p> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/2024-05-31-mamba-2/semiseparable-480.webp 480w,/assets/img/2024-05-31-mamba-2/semiseparable-800.webp 800w,/assets/img/2024-05-31-mamba-2/semiseparable-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/2024-05-31-mamba-2/semiseparable.png" width="100%" height="auto" title="State Space Models are Semiseparable Matrices" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> <figcaption class="caption">All submatrices contained on-and-below the diagonal of a semiseparable matrix are low-rank.</figcaption> </figure> <p>For our purposes, we’ll care about this form mainly for the algorithmic considerations. 
One of the central messages of this SSD paper is that:</p> <blockquote class="block-tip"> <h4 id="takeaway-computing-ssms-through-matrix-multiplication">Takeaway: Computing SSMs Through Matrix Multiplication</h4> <p>All algorithms for computing state space models can be viewed as structured matrix multiplication algorithms on semiseparable matrices.</p> </blockquote> <p>Let’s see an easy instantiation of this, focusing on our main objective!</p> <h3 id="deriving-the-duality-ssm-to-attention">Deriving the Duality: SSM to Attention</h3> <p>To show that equation \eqref{eq:ssd-attention} follows from equation \eqref{eq:ssm} (in the case of the SSD model, i.e. scalar SSM), we directly use the matrix form of the state space model \eqref{eq:semiseparable}. Because the $A_t$ are all scalars in this case, they can be factored out of the entries</p> \[C_i^\top A_{i:j}^\times B_j = A_{i:j}^\times \cdot (C_i^\top B_j)\] <p>which directly implies equation \eqref{eq:ssd-attention}.</p> <p>In summary:</p> <blockquote class="block-tip"> <h4 id="duality-representation-1-ssm">Duality Representation 1 (SSM)</h4> <p>The duality for the SSD model can be seen as two <strong>different matrix multiplication algorithms</strong> on the semiseparable matrix.</p> </blockquote> <ul> <li>The linear form is a <em>structured matrix multiplication algorithm</em> that computes the outputs $Y_0, Y_1, \dots$ sequentially, leveraging the structure of the semiseparable matrix.</li> <li>The quadratic form is the <em>naive matrix multiplication algorithm</em> that materializes the full matrix.</li> </ul> <h3 id="going-beyond-the-ssd-layer-1">Going Beyond the SSD Layer 1</h3> <p>The power of the semiseparable matrix representation applies to <em>all</em> state space models, with various downstream implications.</p> <h4 id="algorithms">Algorithms</h4> <p>Algorithmically, the Mamba-2 paper explores several consequences, such as:</p> <ol> <li>The above duality result for the SSD model, i.e. a scalar-identity structured SSM.</li> <li>New asymptotic efficiency results for state space models (<a href="https://arxiv.org/abs/2405.21060">Theorem 3.7</a>), which follow from applying known results from the semiseparable matrix literature <d-cite key="pernet2016computing"></d-cite><d-cite key="pernet2018time"></d-cite><d-cite key="pernet2023exact"></d-cite>.</li> <li>A more general hybrid algorithm that can be viewed as combining both the linear and quadratic forms to get the best of both worlds. This can be derived as a new matrix multiplication algorithm utilizing <em>block decompositions</em> of the semiseparable matrix. This is the subject of Part III of this blog series!</li> </ol> <h4 id="understanding">Understanding</h4> <p>Conceptually, the matrix transformation viewpoint helps provide a unifying view of sequence models. Some example downstream ideas include</p> <ul> <li><strong>New sequence models</strong>: Restricting ourselves to matrix transformations reduces the problem of developing new sequence models to that of finding structured matrix classes with target properties. In ongoing work by my students, we study this point of view, and use it to derive the most natural bidirectional extension of Mamba (coming very soon!).</li> <li><strong>Expressivity</strong>: Looking at the matrix transformation representation can help us understand what different models can represent from a linear algebraic perspective. 
In another ongoing work, we use this as a tool to study which subquadratic models are the most amenable to being distilled from Transformers.</li> <li><strong>Interpretability</strong>: A concurrent work <d-cite key="ali2024hidden"></d-cite> derived the matrix formulation of SSMs and use it to probe the internal representations of Mamba models.</li> </ul> <p>We’re excited to see what algorithmic and conceptual ideas from the structured matrix literature can be applied to further improve state space models!</p> <h2 id="ssd-framework-2-structured-attention">SSD Framework 2: Structured Attention</h2> <p>The second framing of the duality is from an attention-centric perspective, where we’ll prove the duality through the framework of <strong>tensor contractions</strong>.</p> <p>Note that this is entirely independent of the previous [<a href="#ssd-framework-1-structured-matrix-transformations">matrix transformation viewpoint</a>].</p> <h3 id="warm-up-kernel-attention">Warm-up: Kernel Attention</h3> <p>For our purposes, we’ll define attention as a function</p> \[(Q^\mathtt{(T,N)}, K^\mathtt{(S,N)} , V^\mathtt{(S,P)} ) \mapsto Y^\mathtt{(T,P)}\] <p>given by the pairwise matrix multiplications</p> \[Y = (QK^\top) \cdot V\] <details><summary>On Dimensions</summary> <p>Think of $\mathtt{P} = \mathtt{N}$ as the head dimension; technically speaking, in attention the $V$ head dimension $\mathtt{P}$ can differ from the $QK$ head dimension $\mathtt{N}$. Think of $\mathtt{T}$ as the <em>target</em> sequence dimension and $\mathtt{S}$ as the <em>source</em> sequence dimension. Giving these two axes different names will make the math more clear and also covers more general forms of attention such as cross-attention, where the source and target are separate sequences with different lengths. However, for our purposes we’ll assume the self-attention setting where $\mathtt{S}=\mathtt{T}$.</p> </details> <details><summary>Why can we assume this form?</summary> <p>The usual form of attention $Y = f(QK^\top) \cdot V$ (e.g. where $f$ is the softmax function) can, for essentially all functions $f$<d-footnote>And up to some additional massaging such as row-wise normalization, which is easy to handle</d-footnote>, be written as $Y = \psi(Q)\psi(K)^\top \cdot V$ for some appropriate feature map $\psi$ (which may be infinite dimensional). In this case, we can simply redefine $Q \leftarrow \psi(Q)$ and define $\mathtt{N}$ to be the <strong>feature dimension</strong> of the attention kernel to begin with. Softmax attention, for example, can be represented with a particular infinite-dimensional feature map ($\mathtt{N}=\infty$) which represents the exponential kernel.</p> </details> <p>We’ll restrict ourselves to the case when $\psi$ is finite, which is sometimes called <strong>kernel attention</strong>. Many, many variants have been proposed before!<d-cite key="katharopoulos2020transformers"></d-cite><d-cite key="peng2021random"></d-cite><d-cite key="choromanski2021rethinking"></d-cite><d-cite key="qin2022cosformer"></d-cite><d-cite key="zheng2022linear"></d-cite><d-cite key="wang2020linformer"></d-cite><d-cite key="xiong2021nystromformer"></d-cite></p> <p>Why do we care about this formulation? When the sequence length $\mathtt{T}$ grows and the feature dimension $\mathtt{N}$ is small—commonly, in the regime when $\psi$ is simple such as an elementwise transform and so $\mathtt{N}$ is constant—then the cost of attention can be reduced from quadratic in $\mathtt{T}$ to linear. 
This follows from simply computing the matrix multiplications in a different order</p> \[Y = Q \cdot (K^\top V)\] <p>This is a somewhat “folklore” interpretation of linear attention.<d-footnote>At least, one lineage of efficient attention; other varieties exist, such as those based on sparsity or hashing. We reserve the term "linear attention" to those related to Katharopoulos et al.<d-cite key="katharopoulos2020transformers"></d-cite>, or more broadly low-rank attention.</d-footnote></p> <blockquote> <p>The most common way of linearizing attention is usually viewed as a consequence of the <strong>associativity of matrix multiplication</strong></p> </blockquote> <h3 id="causal-linear-attention">(Causal) Linear Attention</h3> <p>However, once the basic kernel attention is slightly modified, we can no longer use the associativity of matrix multiplication directly.</p> <p>The seminal <strong>Linear Attention (LA)</strong> framework of Katharopoulos et al. <d-cite key="katharopoulos2020transformers"></d-cite> shows that it can still be extended to the important case of incorporating causality into attention, for autoregressive settings such as language modeling.</p> <p>Let’s be a lot more explicit about how it works. The quadratic form of <strong>causal linear attention</strong> is \begin{equation} \label{eq:quadratic-kernel-attention} Y = (L \circ QK^\top) \cdot V \end{equation} where</p> \[L = \begin{bmatrix} 1 \\ \vdots &amp; \ddots \\ 1 &amp; \dots &amp; 1 \end{bmatrix}\] <p>is the <strong>causal mask</strong> matrix.</p> <p>The issue is: once the $L$ mask is incorporated into \eqref{eq:quadratic-kernel-attention}, we can no longer directly apply matrix associativity! This is the problem that the original Linear Attention paper addresses. What they show is that \eqref{eq:quadratic-kernel-attention} is equivalent to a different form which avoids materializing the quadratic $QK^\top$ attention matrix and has linear time complexity</p> \[Y = Q \cdot \mathsf{cumsum}(K^\top V)\] <p>As far as we’re aware this wasn’t explicitly proved in the paper, although it isn’t too hard to write out the summation to show it.</p> <p>What we’ll do is prove this equivalence in essentially one line, while revealing <em>exactly</em> where the “linear” part of Linear Attention comes from, and how to strongly generalize it.</p> <p>Spoiler alert:</p> <blockquote class="block-tip"> <h4 id="where-does-the-cumsum-in-linear-attention-come-from">Where does the cumsum in Linear Attention come from?</h4> <p>The appearance of the <em>cumulative sum</em> in linear attention is exactly equivalent to the fact that the causal mask $L$, as a matrix multiplication, encodes cumulative sums:</p> \[y = L \cdot x \iff y = \mathsf{cumsum}(x)\] </blockquote> <h3 id="a-tensor-contraction-proof-of-linear-attention">A Tensor Contraction Proof of Linear Attention</h3> <p>Let’s write out the quadratic form of linear attention \eqref{eq:quadratic-kernel-attention} very explicitly in <strong>tensor contraction</strong> or <a href="https://numpy.org/doc/stable/reference/generated/numpy.einsum.html">einsum</a> notation, with shape annotations:</p> \[\begin{aligned} G &amp;= \mathsf{contract}(\mathtt{TN, SN} \to \mathtt{TS})(Q, K) \\ M &amp;= \mathsf{contract}(\mathtt{TS, TS} \to \mathtt{TS})(G, L) \\ Y &amp;= \mathsf{contract}(\mathtt{TS, SP} \to \mathtt{TP})(M, V) \end{aligned}\] <p>\begin{equation} \label{eq:sma-quad} (\text{Structured Masked Attention - Quadratic Form}) \end{equation}</p> <p>With this notation, we can notice that this 
sequence of contractions can be written as a <em>single four-way contraction</em></p> <p>\begin{equation} \label{eq:sma} Y = \mathsf{contract}(\mathtt{TN},\mathtt{SN},\mathtt{SP},\mathtt{TS} \to \mathtt{TP})(Q, K, V, L) . \end{equation}</p> <p>And finally, it can be computed with any other contraction ordering. In particular, we can perform pairwise reductions in the order $V, K, L, Q$ instead of $Q, K, L, V$</p> \[\begin{aligned} Z &amp;= \mathsf{contract}(\mathtt{SP},\mathtt{SN} \to \mathtt{SPN})(V, K) \\ H &amp;= \mathsf{contract}(\mathtt{TS},\mathtt{SPN} \to \mathtt{TPN})(L, Z) \\ Y &amp;= \mathsf{contract}(\mathtt{TN},\mathtt{TPN} \to \mathtt{TP})(Q, H) \end{aligned}\] <p>\begin{equation} \label{eq:sma-lin} (\text{Structured Masked Attention - Linear Form}) \end{equation}</p> <p>Now the key observation is that the second line of \eqref{eq:sma-lin} is simply a matrix multiplication by $L$, which can be computed with a cumulative sum.</p> <p>That’s the entire proof of linear attention! The beauty of it is that we didn’t have to write out a single summation, which was abstracted out into a tensor contraction combined with the structure of $L$.</p> <p>This immediately proves our claim about the <a href="#where-does-the-cumsum-in-linear-attention-come-from">cumsum in linear attention</a>. Moreover, this immediately reveals that the efficiency of linear attention can be made much more general…</p> <h3 id="structured-masked-attention">Structured Masked Attention</h3> <p>The critical observation is that in order for \eqref{eq:sma-lin} to be fast, all that is necessary is for $L$ to be <em>any structured matrix</em> – in other words any matrix that has subquadratic matrix-vector multiplication.</p> <p>This immediately motivates one of the main prongs of the SSD framework, which can be seen as a strong generalization of LA.</p> <blockquote class="block-tip"> <h4 id="definition-structured-masked-attention">Definition: Structured Masked Attention</h4> <p><strong>Structured masked attention (SMA)</strong> is defined as the <em>four-way tensor contraction</em> \eqref{eq:sma} using an attention mask $L$ that is a structured matrix.</p> </blockquote> <blockquote class="block-tip"> <h4 id="duality-representation-2-sma">Duality Representation 2 (SMA)</h4> <p>SMA has <strong>dual quadratic and linear</strong><d-footnote>Assuming that the structured matrix $L$ has linear time matrix-vector multiplication</d-footnote> <strong>modes</strong> which are simply <em>two different pairwise reduction orders</em> \eqref{eq:sma-quad} and \eqref{eq:sma-lin}.</p> </blockquote> <p>Finally, let’s just connect this back to the commonly held view of linear attention as matrix multiplication associativity.</p> <blockquote> <p>Although it is commonly believed that incorporating attention masks $L$ prevents matrix multiplication reordering, it turns out to still be compatible.
In particular, <strong>associativity of matrix multiplication</strong> is a special case of <strong>tensor contraction reduction orders</strong>; although the former no longer applies, the latter can integrate the attention mask $L$.</p> </blockquote> <p>Next, let’s look at some consequences of the structured attention framework.</p> <h3 id="deriving-the-duality-attention-to-ssm">Deriving the Duality: Attention to SSM</h3> <p>Recall that the SSD model is defined as either a scalar-identity SSM in equation \eqref{eq:ssm}, or through the attention-like form in equation \eqref{eq:ssd-attention}.</p> <p>To show the equivalence of these forms, we simply recognize that \eqref{eq:ssd-attention} is a special case of structured masked attention where the mask matrix is</p> \[L = \begin{bmatrix} 1 &amp; \\ a_1 &amp; 1 &amp; \\ a_2a_1 &amp; a_2 &amp; 1 \\ \vdots &amp; \vdots &amp; \ddots &amp; \ddots \\ a_{\mathtt{T}-1}\dots a_1 &amp; a_{\mathtt{T}-1}\dots a_2 &amp; \dots &amp; a_{\mathtt{T}-1} &amp; 1 \\ \end{bmatrix} .\] <p>\begin{equation} \label{eq:1-ss} (\text{1-semiseparable (1-SS) matrix}) \end{equation}</p> <p>We call this a <strong>1-semiseparable (1-SS) matrix</strong>, for reasons that are explained in more detail in the Mamba-2 paper.</p> <p>Thus, we can also say that the SSD model is <strong>1-semiseparable masked attention</strong> or <strong>1-SS SMA</strong>.</p> <p>To prove that this can be written as an SSM, we simply appeal to the SMA framework, which says that the dual form of this model can be computed through matrix multiplication by $L$. So how fast is that? It’s not too hard to see that multiplication $y = Lx$ can be computed in linear time through a scalar recurrence:</p> \[\begin{aligned} y_0 &amp;= x_0 \\ y_1 &amp;= a_1 x_0 + x_1 = a_1 y_0 + x_1 \\ y_2 &amp;= a_2a_1 x_0 + a_2 x_1 + x_2 = a_2 y_1 + x_2 \\ \vdots &amp; \qquad \vdots \end{aligned}\] <p>This corresponds exactly to the original SSM recurrence!</p> <p>(In fact, multiplication by 1-SS matrices $L$ can be computed in a <em>lot</em> more ways, which we compile in the full paper! Alternative algorithms can reveal more insights: for example, the associative scan algorithm used by S5 <d-cite key="smith2023s5"></d-cite> and Mamba can also be shown to be a structured matrix multiplication algorithm on 1-SS matrices.)</p> <h3 id="going-beyond-the-ssd-layer-2">Going Beyond the SSD Layer 2</h3> <p>Structured masked attention not only helps define the SSD model and prove its duality, but it is a much broader framework of efficient attention models.</p> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/2024-05-31-mamba-2/sma-480.webp 480w,/assets/img/2024-05-31-mamba-2/sma-800.webp 800w,/assets/img/2024-05-31-mamba-2/sma-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/2024-05-31-mamba-2/sma.png" width="100%" height="auto" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> <p>Prior examples include the original linear attention as well as the recent Retentive Network (RetNet) model<d-cite key="sun2023retentive"></d-cite>. These can be viewed as direct special cases of SSD. But beyond SSD, we can define classes of efficient attention by replacing the mask $L$ with <em>any structured matrix</em>.
As a suggestion, we think that Toeplitz or Fourier structured attention may be interesting to consider because they might encode different forms of positional information.</p> <p>Additionally, other forms of structure can be incorporated into the $L$ mask. For example, another extension my students are developing is viewing SSD (and recurrences in general) as an algorithm operating on <em>directed line graphs</em>, and generalizing it to incorporate arbitrary graph structures.</p> <h2 id="state-space-duality">State Space Duality</h2> <p>We’ll end this post with a brief recap of what we’ve covered.</p> <p>The <strong>SSD framework</strong> consists of the two broad approaches covered in this post, which is summarized by the two areas of the [<a href="#the-state-space-duality-framework">Venn diagram</a>]:</p> <ol> <li>Viewing state space models through [<a href="#ssd-framework-1-structured-matrix-transformations">structured matrix transformations</a>]</li> <li>Generalizing linear attention through [<a href="#ssd-framework-2-structured-attention">tensor contractions</a>]</li> </ol> <p>The [<a href="#recap-the-ssd-model">SSD layer</a>] is a particular model which is the purple intersection in the figure, which can be viewed as an instance of either part of the SSD framework, and in particular has dual quadratic and linear forms that can be derived from either representation.</p> <table> <thead> <tr> <th><em>SSD Framework</em></th> <th>Structured SSMs</th> <th>Structured Attention</th> </tr> </thead> <tbody> <tr> <td>The main representation is…</td> <td>Structured matrix \eqref{eq:ssm-matrix} <br/> sequence transformations</td> <td>The 4-way \eqref{eq:sma} <br/> tensor contraction</td> </tr> <tr> <td>This generalizes…</td> <td>State space models</td> <td>Linear attention</td> </tr> <tr> <td>The SSD model is <br/> an instantiation as…</td> <td>Scalar state space model <br/> ($A_t$ is a scalar-identity matrix)</td> <td>1-semiseparable masked attention <br/> ($L$ mask is a 1-SS matrix)</td> </tr> <tr> <td>The linear-quadratic duality is <br/> revealed through…</td> <td>Structured matrix <br/> multiplication algorithms</td> <td>Tensor contraction <br/> reduction orderings</td> </tr> </tbody> </table> <h2 id="next-up">Next Up</h2> <p>In <a href="/blog/2024/mamba2-part3-algorithm/">the next part of this series</a>, we’ll see how to use some of the SSD framework (in particular, the <a href="#takeaway-computing-ssms">structured matrix algorithm</a> point of view) to derive the more efficient hybrid SSD algorithm that leverages both of the dual forms.</p>]]></content><author><name>Albert Gu</name></author><summary type="html"><![CDATA[Part I - The Model Part II - The Theory Part III - The Algorithm Part IV - The Systems]]></summary></entry><entry><title type="html">State Space Duality (Mamba-2) Part III - The Algorithm</title><link href="tridao.github.io/blog/2024/mamba2-part3-algorithm/" rel="alternate" type="text/html" title="State Space Duality (Mamba-2) Part III - The Algorithm"/><published>2024-05-31T00:00:00+00:00</published><updated>2024-05-31T00:00:00+00:00</updated><id>tridao.github.io/blog/2024/mamba2-part3-algorithm</id><content type="html" xml:base="tridao.github.io/blog/2024/mamba2-part3-algorithm/"><![CDATA[<ol> <li><a href="/blog/2024/mamba2-part1-model/">Part I - The Model</a></li> <li><a href="/blog/2024/mamba2-part2-theory/">Part II - The Theory</a></li> <li>Part III - The Algorithm</li> <li><a href="/blog/2024/mamba2-part4-systems/">Part IV - The Systems</a></li> </ol> <p>The 
theoretical framework of structured state space duality (see <a href="/blog/2024/mamba2-part1-model/">Part I</a> and <a href="/blog/2024/mamba2-part2-theory/">Part II</a> of this series) connects SSMs and (linear) attention through structured matrices. As mentioned in Part I, this connection allows us to derive new algorithms for selective SSMs that are faster than the parallel associative scan in Mamba-1 by leveraging matrix multiplication as a primitive. Moreover, the connection can bring system optimizations (e.g. tensor parallelism, sequence parallelism, variable sequence length) originally developed for Transformer to SSM-land.</p> <h2 id="the-ssd-algorithm">The SSD Algorithm</h2> <p>Even though we already developed optimized scans implementations for Mamba-1, we were limited to small state expansion (typically $\mathtt{N}=16$) as the algorithm and implementation did not use tensor cores (specialized hardware units that perform matrix multiplication). Typically matrix multiplication (matmul) FLOPs are much faster (up to 16x) than non-matmul FLOPs: the A100 GPU has 312 TFLOPS of BF16 matmul but only 19 TFLOPS of FP32 arithmetics, and the H100 has 989 TFLOPS of BF16 matmul but only 67 TFLOPS of FP32 arithmetics. One of our primary goals with Mamba-2 is to <strong>leverage tensor cores to speed up the SSM</strong>.</p> <p>To recap, after tying parameters and introducing the head structure, the SSM in Mamba-1 turns into SSD, a more restrictive form that has an attention-like formulation. And as SSD connects SSMs and structured matrices, we saw in Part II that efficient algorithms to compute SSMs correspond directly to different decompositions of the “token-mixing” or “sequence-mixing” matrix $M$.</p> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/2024-05-31-mamba-2/ssd_algorithm-480.webp 480w,/assets/img/2024-05-31-mamba-2/ssd_algorithm-800.webp 800w,/assets/img/2024-05-31-mamba-2/ssd_algorithm-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/2024-05-31-mamba-2/ssd_algorithm.png" width="100%" height="auto" title="SSD Algorithm" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> <p>We can therefore create new algorithms to compute SSMs simply by looking for alternative ways to multiply this matrix, for example by decomposing it in various ways. A simple block decomposition of this matrix, with carefully chosen block sizes, turns out to get all the advantages of both the linear-recurrent and quadratic-attention dual forms of SSD. This leads to the SSD algorithm, which has 4 steps. There are two completely different interpretations of this algorithm!</p> <h3 id="ssd-algorithm-block-matrix-decomposition">SSD Algorithm: Block Matrix Decomposition</h3> <p>We first partition the SSM (semiseparable) matrix into blocks of size $\mathtt{Q} \times \mathtt{Q}$. Then, we use the properties of semiseparable matrices to factorize each off-diagonal block, which is low rank.</p> <ol> <li>(<em>Orange</em>) Each diagonal block is a smaller semiseparable matrix; we can compute this multiplication however we like; in particular, using the quadratic (attention-like) form of SSD.</li> <li>(<em>Green</em>) There are only $\mathtt{T} / \mathtt{Q}$ total different green blocks because many of them are shared. 
These can be computed with a batched matmul.</li> <li>(<em>Yellow</em>) Notice that the yellow terms themselves form a 1-semiseparable matrix; in other words, this step is equivalently to an SSM scan (on some modified $A$ factors)!</li> <li>(<em>Blue</em>) Similar to green, these can be computed with a batched matmul.</li> </ol> <h3 id="ssd-algorithm-chunking-and-state-passing">SSD Algorithm: Chunking and State Passing</h3> <p>An alternative interpretation of the algorithm involves reasoning about how the SSM operates on the actual sequence. We first split the sequence of input into blocks (or chunks) of size $\mathtt{Q}$. The steps then have the interpretation</p> <ol> <li><strong>Intra-chunk outputs</strong>: compute the local output of each chunk (<em>what is the output per chunk supposing that the initial state (to the chunk) is 0?</em>)</li> <li><strong>Chunk states</strong>: compute the final state of each chunk (<em>what is the final state per chunk supposing that the initial state (to the chunk) is 0?</em>)</li> <li><strong>Pass states</strong>: compute a recurrence on all of the chunks’ final states – using any desired algorithm, e.g. parallel or sequential scan (<em>what is the actual final state per chunk taking into account all previous inputs?</em>)</li> <li><strong>Output states</strong>: for each chunk, given its true initial state (computed in Step 3), compute the contribution to the output just from the initial state</li> </ol> <p>Either way, we see that most of the algorithm (Step 1, 2, and 4) leverages matmuls (and hence tensor cores), and also can be computed completely in parallel! Only Step 3 requires a scan, but it operates on a much shorter sequence and usually only takes a small fraction of the time of the full algorithm.</p> <h3 id="special-cases">Special Cases</h3> <p>We note that special cases of this algorithm have been seen before. In particular RetNet<d-cite key="sun2023retentive"></d-cite>, which we showed in Part II to be a special case of SSD, mention a “chunkwise” algorithm which computes the quadratic form on a chunk of the input one-at-a-time and passes the final state to the next chunk. This turns out to be essentially equivalent to the SSD algorithm specialized to a restricted case (i.e. a decay matrix mask $L$). Our derivation comes from a different direction—the block matrix decomposition—which also makes it more obvious how to parallelize this algorithm and make it really fast in practice.</p> <p>Other forms of “chunkwise” recurrences have recently become popular, such as in <a href="https://arxiv.org/abs/2312.06635">Gated Linear Attention (GLA)</a><d-cite key="yang2024gated"></d-cite>.</p> <h2 id="the-code">The Code</h2> <p>In the “Minimal SSD” code that we provide in the paper and the <a href="https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/ssd_minimal.py">code release</a>, we delineate each of these four steps. As promised, this algorithm is not only faster but also much easier to implement than the original selective scan of Mamba, coming in at just around 25 lines of code!</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">segsum</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
    """More stable segment sum calculation. exp(segsum(A)) produces a 1-SS matrix,
       which is equivalent to a scalar SSM."""
    T = x.size(-1)
    x = repeat(x, "... d -&gt; ... d e", e=T)
    mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1)
    x = x.masked_fill(~mask, 0)
    x_segsum = torch.cumsum(x, dim=-2)
    mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
    x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
    <span class="k">return</span> <span class="n">x_segsum</span>
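<span class="c1"># Note: running this listing assumes `import torch`, `import torch.nn.functional as F`,</span>
<span class="c1"># and `from einops import rearrange, repeat` (see the linked ssd_minimal.py in the code release).</span>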

<span class="k">def</span> <span class="nf">ssd</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">C</span><span class="p">,</span> <span class="n">block_len</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span> <span class="n">initial_states</span><span class="o">=</span><span class="bp">None</span><span class="p">):</span>
    <span class="sh">"""</span><span class="s">
    Arguments:
        X: (batch, length, n_heads, d_head)
        A: (batch, length, n_heads)
        B: (batch, length, n_heads, d_state)
        C: (batch, length, n_heads, d_state)
    Return:
        Y: (batch, length, n_heads, d_head)
    </span><span class="sh">"""</span>
    <span class="k">assert</span> <span class="n">X</span><span class="p">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">A</span><span class="p">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">B</span><span class="p">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">C</span><span class="p">.</span><span class="n">dtype</span>
    <span class="k">assert</span> <span class="n">X</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">%</span> <span class="n">block_len</span> <span class="o">==</span> <span class="mi">0</span>

    <span class="c1"># Rearrange into blocks/chunks
</span>    <span class="n">X</span><span class="p">,</span> <span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">C</span> <span class="o">=</span> <span class="p">[</span><span class="nf">rearrange</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="sh">"</span><span class="s">b (c l) ... -&gt; b c l ...</span><span class="sh">"</span><span class="p">,</span> <span class="n">l</span><span class="o">=</span><span class="n">block_len</span><span class="p">)</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">C</span><span class="p">)]</span>

    <span class="n">A</span> <span class="o">=</span> <span class="nf">rearrange</span><span class="p">(</span><span class="n">A</span><span class="p">,</span> <span class="sh">"</span><span class="s">b c l h -&gt; b h c l</span><span class="sh">"</span><span class="p">)</span>
    <span class="n">A_cumsum</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nf">cumsum</span><span class="p">(</span><span class="n">A</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
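    <span class="c1"># Einsum dimension key used below: b=batch, c=chunk, l and s=position within a chunk,</span>
    <span class="c1"># h=head, p=head dim, n=state dim (z in step 3 indexes chunk boundaries).</span>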

    <span class="c1"># 1. Compute the output for each intra-chunk (diagonal blocks)
</span>    <span class="n">L</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nf">exp</span><span class="p">(</span><span class="nf">segsum</span><span class="p">(</span><span class="n">A</span><span class="p">))</span>
    <span class="n">Y_diag</span>  <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nf">einsum</span><span class="p">(</span><span class="sh">"</span><span class="s">bclhn,bcshn,bhcls,bcshp-&gt;bclhp</span><span class="sh">"</span><span class="p">,</span> <span class="n">C</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">L</span><span class="p">,</span> <span class="n">X</span><span class="p">)</span>

    <span class="c1"># 2. Compute the state for each intra-chunk
</span>    <span class="c1"># (right term of low-rank factorization of off-diagonal blocks; B terms)
</span>    <span class="n">decay_states</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nf">exp</span><span class="p">((</span><span class="n">A_cumsum</span><span class="p">[:,</span> <span class="p">:,</span> <span class="p">:,</span> <span class="o">-</span><span class="mi">1</span><span class="p">:]</span> <span class="o">-</span> <span class="n">A_cumsum</span><span class="p">))</span>
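    <span class="c1"># decay_states[..., l] = exp(A[l+1] + ... + A[last]): how much the input at position l</span>
    <span class="c1"># has decayed by the end of its chunk.</span>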
    <span class="n">states</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nf">einsum</span><span class="p">(</span><span class="sh">"</span><span class="s">bclhn,bhcl,bclhp-&gt;bchpn</span><span class="sh">"</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">decay_states</span><span class="p">,</span> <span class="n">X</span><span class="p">)</span>

    <span class="c1"># 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
</span>    <span class="c1"># (middle term of factorization of off-diag blocks; A terms)
</span>    <span class="k">if</span> <span class="n">initial_states</span> <span class="ow">is</span> <span class="bp">None</span><span class="p">:</span>
        <span class="n">initial_states</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nf">zeros_like</span><span class="p">(</span><span class="n">states</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">1</span><span class="p">])</span>
    <span class="n">states</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nf">cat</span><span class="p">([</span><span class="n">initial_states</span><span class="p">,</span> <span class="n">states</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
    <span class="n">decay_chunk</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nf">exp</span><span class="p">(</span><span class="nf">segsum</span><span class="p">(</span><span class="n">F</span><span class="p">.</span><span class="nf">pad</span><span class="p">(</span><span class="n">A_cumsum</span><span class="p">[:,</span> <span class="p">:,</span> <span class="p">:,</span> <span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">))))</span>
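    <span class="c1"># decay_chunk is the 1-SS matrix over per-chunk decays, so the inter-chunk recurrence (Step 3)</span>
    <span class="c1"># is computed here as a dense matmul rather than an associative scan (see "The SSM Scan" below).</span>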
    <span class="n">new_states</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nf">einsum</span><span class="p">(</span><span class="sh">"</span><span class="s">bhzc,bchpn-&gt;bzhpn</span><span class="sh">"</span><span class="p">,</span> <span class="n">decay_chunk</span><span class="p">,</span> <span class="n">states</span><span class="p">)</span>
    <span class="n">states</span><span class="p">,</span> <span class="n">final_state</span> <span class="o">=</span> <span class="n">new_states</span><span class="p">[:,</span> <span class="p">:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">new_states</span><span class="p">[:,</span> <span class="o">-</span><span class="mi">1</span><span class="p">]</span>

    <span class="c1"># 4. Compute state -&gt; output conversion per chunk
</span>    <span class="c1"># (left term of low-rank factorization of off-diagonal blocks; C terms)
</span>    <span class="n">state_decay_out</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nf">exp</span><span class="p">(</span><span class="n">A_cumsum</span><span class="p">)</span>
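    <span class="c1"># state_decay_out[..., l] = exp(A[0] + ... + A[l]): decay applied to the chunk's</span>
    <span class="c1"># incoming state by position l.</span>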
    <span class="n">Y_off</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nf">einsum</span><span class="p">(</span><span class="sh">'</span><span class="s">bclhn,bchpn,bhcl-&gt;bclhp</span><span class="sh">'</span><span class="p">,</span> <span class="n">C</span><span class="p">,</span> <span class="n">states</span><span class="p">,</span> <span class="n">state_decay_out</span><span class="p">)</span>

    <span class="c1"># Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
</span>    <span class="n">Y</span> <span class="o">=</span> <span class="nf">rearrange</span><span class="p">(</span><span class="n">Y_diag</span><span class="o">+</span><span class="n">Y_off</span><span class="p">,</span> <span class="sh">"</span><span class="s">b c l h p -&gt; b (c l) h p</span><span class="sh">"</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">Y</span><span class="p">,</span> <span class="n">final_state</span>
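
<span class="c1"># Example usage (an illustrative sketch, not part of the original listing; shapes and values are made up):</span>
<span class="c1">#   X = torch.randn(2, 256, 8, 64)    # (batch, length, n_heads, d_head)</span>
<span class="c1">#   A = -torch.rand(2, 256, 8)        # (batch, length, n_heads); log-decays, since L = exp(segsum(A))</span>
<span class="c1">#   B = torch.randn(2, 256, 8, 16)    # (batch, length, n_heads, d_state)</span>
<span class="c1">#   C = torch.randn(2, 256, 8, 16)</span>
<span class="c1">#   Y, final_state = ssd(X, A, B, C, block_len=64)</span>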
</code></pre></div></div> <h2 id="the-details">The Details</h2> <p>Let’s talk about a couple of additional details in the implementation (these don’t even appear in the full paper, so pay attention!) that unpack some of the choices in this reference code.</p> <h3 id="the-ssm-scan">The SSM Scan</h3> <p>In the above code, we utilized the connection between scalar SSM recurrences</p> \[h_{t+1} = A_t h_t + B_t x_t\] <p>and matrix multiplication by 1-semiseparable matrices</p> \[L = \begin{bmatrix} 1 &amp; \\ a_1 &amp; 1 &amp; \\ a_2a_1 &amp; a_2 &amp; 1 \\ \vdots &amp; \vdots &amp; \ddots &amp; \ddots \\ a_{\mathtt{T}-1}\dots a_1 &amp; a_{\mathtt{T}-1}\dots a_2 &amp; \dots &amp; a_{\mathtt{T}-1} &amp; 1 \\ \end{bmatrix}\] <p>which we covered in Part II (and Section 3.2.2 of the paper). In this minimal implementation, we compute Step 3 of the algorithm, which is computing a scalar SSM by <em>any</em> algorithm of our choice, by explicitly materializing a 1-SS matrix and doing dense matrix multiplication.</p> <p>We use this version for several reasons:</p> <ol> <li>Code-wise, it’s simpler to materialize and multiply by this matrix than to actually implement a parallel associative scan</li> <li>Because of the block decomposition of the SSM matrix, the sequence length $\mathtt{T}$ is reduced by a factor of $\approx 100$ – so doing the scan in time $O(\mathtt{T}^2)$ instead of $O(\mathtt{T})$ isn’t too bad</li> <li>We have to materialize a 1-SS matrix anyways for Step 1 of the algorithm (the diagonal blocks), so might as well reuse the code ¯\_(ツ)_/¯</li> </ol> <p>While this example code is simpler and reasonably efficient on GPU (and probably TPU as well!), it’s no longer truly linear at long sequences. Our more optimized Triton implementation does replace the 1-SS multiplication in Step 3 with an actual associative scan.</p> <h3 id="stability">Stability</h3> <h4 id="attempt-1-ratios-of-cumprods">Attempt 1: Ratios of cumprods</h4> <p>The first naive attempt may be to notice that the entries of this matrix are cumulative products</p> \[a_{i:j}^\times = a_i \times \cdots \times a_{j-1} = \frac{a_{i:\mathtt{T}}^\times}{a_{j:\mathtt{T}}^\times}\] <p>However, this runs into severe numerical issues because these products can get really tiny (imagine $a_t \approx 0.9$ and powering it up for a sequence length $\mathtt{T}$ in the thousands!)</p> <h4 id="fix-1-the-segment-sum-segsum-operation">Fix 1: The Segment Sum (<code class="language-plaintext highlighter-rouge">segsum</code>) Operation</h4> <p>The second attempt would be to do all of this in log-space, because all the $a_t$ are positive; so the products become additions, and instead of <code class="language-plaintext highlighter-rouge">cumprod</code>s to deal with we have <code class="language-plaintext highlighter-rouge">cumsum</code>s instead. Then in order to compute the 1-SS matrix, we just have to compute the sums $\log a_i + \dots + \log a_{j-1}$ for every <em>segment</em> $[i:j]$. We call this the <strong>segment sum (segsum)</strong> primitive, analogous to cumulative sum (cumsum).</p> <h4 id="attempt-2-differences-of-cumsums">Attempt 2: Differences of cumsums</h4> <p>The obvious way to do this again is using the same idea as above, but in log space</p> \[a_{i:j}^\times = \exp\left( \log a_i + \cdots + \log a_{j-1} \right) = \left( (\log a)_{i:\mathtt{T}}^+ - (\log a)_{j:\mathtt{T}}^+ \right)\] <p>where we compute a single cumulative sum of $a$ along the time axis, and then compute all pairwise differences. 
In code, we can do this with</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">segsum_unstable</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
    <span class="sh">"""</span><span class="s">Naive segment sum calculation.</span><span class="sh">"""</span>
    <span class="n">T</span> <span class="o">=</span> <span class="n">x</span><span class="p">.</span><span class="nf">size</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
    <span class="n">x_cumsum</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nf">cumsum</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
    <span class="n">x_segsum</span> <span class="o">=</span> <span class="n">x_cumsum</span><span class="p">[...,</span> <span class="p">:,</span> <span class="bp">None</span><span class="p">]</span> <span class="o">-</span> <span class="n">x_cumsum</span><span class="p">[...,</span> <span class="bp">None</span><span class="p">,</span> <span class="p">:]</span>
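    <span class="c1"># Caution: this pairwise subtraction can catastrophically cancel once the cumulative</span>
    <span class="c1"># sums grow large (see "Fix 2" below).</span>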
    <span class="n">mask</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nf">tril</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="nf">ones</span><span class="p">(</span><span class="n">T</span><span class="p">,</span> <span class="n">T</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">x</span><span class="p">.</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="nb">bool</span><span class="p">),</span> <span class="n">diagonal</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
    <span class="n">x_segsum</span> <span class="o">=</span> <span class="n">x_segsum</span><span class="p">.</span><span class="nf">masked_fill</span><span class="p">(</span><span class="o">~</span><span class="n">mask</span><span class="p">,</span> <span class="o">-</span><span class="n">torch</span><span class="p">.</span><span class="n">inf</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">x_segsum</span>
</code></pre></div></div> <p>(and then the 1-semiseparable matrix is just the exponential of this output).</p> <p>Sums/differences are a lot more stable than products/quotients, so this should work – right?</p> <h4 id="fix-2-remove-all-subtractions">Fix 2: Remove All Subtractions</h4> <p>Unfortunately, it turns out this still doesn’t work. The values of this 1-SS matrix roughly represent the SSM dynamics, which are very sensitive to these values of $a_t$, so we have to be very precise. And even in log space, these cumsums can be fairly large, which runs into <a href="https://en.wikipedia.org/wiki/Catastrophic_cancellation">catastrophic cancellation</a> when subtracted. So we really have to find a way to compute this matrix with only additions, while still vectorizing everything…</p> <h4 id="attempt-3-stable-segsum">Attempt 3: Stable Segsum</h4> <p>This leads to the helper function in the reference SSD code. Instead of computing a single cumsum and then subtracting, we find a way to use a batch of independent cumsums that immediately produces the right answer without subtraction.</p> <p>These details do matter! Without the right implementation of these primitives, the basic SSD algorithm produces NaNs immediately during training (even with FP32).</p> <h3 id="discretization">Discretization</h3> <p>This lineage of structured state space models developed from <a href="https://arxiv.org/abs/2111.00396">S4</a> and <a href="https://arxiv.org/abs/2110.13985">its</a> <a href="https://arxiv.org/abs/2008.07669">predecessors</a> which were viewed as continuous-time systems.<d-cite key="gu2023thesis"></d-cite><d-cite key="gu2022efficiently"></d-cite><d-cite key="gu2021combining"></d-cite><d-cite key="gu2020hippo"></d-cite></p> <p>In Mamba, however, we don’t really view the SSM as continuous anymore. In fact, as mentioned in the Discussion (Section 5) of the <a href="https://arxiv.org/abs/2312.00752">original paper</a>, Mamba trades off with S4 on modeling different types of data:</p> <ul> <li>S4 is a continuous-time model that excels at modeling continuous data, e.g. perceptual signals such as audio waveforms and pixel-level vision.</li> <li>Mamba is a discrete-time model that excels at modeling discrete data, e.g. tokenized data such as language.</li> </ul> <p>However, the parameterization of Mamba still used the same discretization step as in prior structured SSMs, where there is another parameter $\Delta$ being modeled. We do this because the discretization step has other side effects such as properly normalizing the activations <d-cite key="gu2023train"></d-cite><d-cite key="orvieto2023resurrecting"></d-cite> which is important for performance.</p> <p>The initializations and parameterizations from the previous <a href="https://arxiv.org/abs/2206.12037">theory on structured SSMs</a> still work out-of-the-box, so why fix what’s not broken?</p> <p>Despite this, we’re pretty sure that the discretization step isn’t really necessary for Mamba. 
In the Mamba-2 paper, we chose to work directly with the “discrete parameters” $A$ and $B$, which in all previous structured SSM papers (including Mamba-1) were denoted $(\bar{A}, \bar{B})$ and defined through an additional transformation</p> \[\begin{align*} \bar{A} &amp;= \exp(\Delta A) \\ \bar{B} &amp;= (\exp(\Delta A) - I) A^{-1} B \end{align*}\] <p>This doesn’t pose any problems: to use the continuous SSM parameterization, simply transform the parameters through the above formulas before plugging into the SSD code above.</p> <p>In the full Mamba-2 code, we also kept the same parameterization and discretization step as in Mamba—again, why fix what’s not broken?—but hypothesize that “discrete-centric” variants (such as the <em>gamma normalization</em> of <a href="https://arxiv.org/abs/2303.06349">LRU</a><d-cite key="orvieto2023resurrecting"></d-cite> and <a href="https://arxiv.org/abs/2402.19427">Griffin</a><d-cite key="de2024griffin"></d-cite>) should work equally well.</p> <blockquote class="block-tip"> <h4 id="is-discretization-necessary">Is Discretization Necessary?</h4> <p>It’s useful for other structured SSMs, but perhaps not needed for Mamba. But it’s just a simple invertible transformation, so use either discrete or continuous parameterizations as you like!</p> </blockquote> <h2 id="whats-next">What’s Next</h2> <p>In the <a href="/blog/2024/mamba2-part4-systems/">final part of this series</a>, we’ll continue talking about the implementation of Mamba-2, but on a more macroscopic level; about the entire neural network, instead of just details of the core SSD layer.</p> <p>We’ll also talk about the actual speed of the algorithm covered in this post.</p>]]></content><author><name>Tri Dao</name></author><summary type="html"><![CDATA[Part I - The Model Part II - The Theory Part III - The Algorithm Part IV - The Systems]]></summary></entry><entry><title type="html">State Space Duality (Mamba-2) Part IV - The Systems</title><link href="tridao.github.io/blog/2024/mamba2-part4-systems/" rel="alternate" type="text/html" title="State Space Duality (Mamba-2) Part IV - The Systems"/><published>2024-05-31T00:00:00+00:00</published><updated>2024-05-31T00:00:00+00:00</updated><id>tridao.github.io/blog/2024/mamba2-part4-systems</id><content type="html" xml:base="tridao.github.io/blog/2024/mamba2-part4-systems/"><![CDATA[<ol> <li><a href="/blog/2024/mamba2-part1-model/">Part I - The Model</a></li> <li><a href="/blog/2024/mamba2-part2-theory/">Part II - The Theory</a></li> <li><a href="/blog/2024/mamba2-part3-algorithm/">Part III - The Algorithm</a></li> <li>Part IV - The Systems</li> </ol> <p>Transformers have benefited from 7 years of systems optimization from the whole research community and large companies. The SSD framework draws connections between SSMs and attention, and allows us to implement many of these optimizations for models like Mamba-2 as well.
We focus on tensor parallel and sequence parallel for large-scale training, as well as variable-length sequences for efficient finetuning and inference.</p> <h2 id="systems-and-scaling-optimizations">Systems and Scaling Optimizations</h2> <h3 id="tensor-parallelism">Tensor Parallelism</h3> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/2024-05-31-mamba-2/mamba_tp-480.webp 480w,/assets/img/2024-05-31-mamba-2/mamba_tp-800.webp 800w,/assets/img/2024-05-31-mamba-2/mamba_tp-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/2024-05-31-mamba-2/mamba_tp.png" width="100%" height="auto" title="Mamba-2 Tensor Parallelism" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> <p>One difficulty with large-scaling training of Mamba-1 using tensor parallelism (TP) is that it requires 2 all-reduces per layer, compared to just 1 all-reduce per attention or MLP layer in Transformer. This is because some of the SSM parameters are functions of the inner activations, not of the input to the layer. In Mamba-2, with the “parallel projection” structure, all SSM parameters are functions of the input to the layer, and we can easily apply TP to the input projection: We split the input projection and output projection matrices into 2, 4, 8 shards, depending on the TP degree. We use a grouped norm with number of groups divisible by the TP degree, so that normalization is done separately per GPU. These changes result in 1 all-reduce per layer, instead of 2.</p> <h3 id="sequence-parallelism">Sequence Parallelism</h3> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/2024-05-31-mamba-2/mamba_cp-480.webp 480w,/assets/img/2024-05-31-mamba-2/mamba_cp-800.webp 800w,/assets/img/2024-05-31-mamba-2/mamba_cp-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/2024-05-31-mamba-2/mamba_cp.png" width="100%" height="auto" title="Mamba-2 Sequence Parallelism" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> <p>When training on very long sequence length, we might need to split along the sequence length and assign different parts to different devices. There are two main forms of sequence parallelism (SP): For the residual and normalization operation: this replaces the all-reduce in TP with a reduce-scatter, residual + normalization, then all-gather. Since Mamba-2 uses the same residual and normalization structure as Transformer, this form of SP applies directly with no modification. For the attention or SSM operation, aka context parallelism (CP). For attention, one could use Ring attention to split it up along the sequence dimension. For Mamba-2, the SSD framework comes to our help once again: using the same block decomposition, we can have each GPU computing its local output and its final states, then pass the states between GPUs (using send/receive communication primitives), before updating the final output of each GPU.</p> <h3 id="variable-length">Variable Length</h3> <p>For finetuning and inference, in the same batch we often have sequences of different lengths. For Transformer, one would usually pad so all sequences have the same length (wasting computation), or implement attention specifically for variable length sequences with careful load-balancing. 
With SSM, we can simply treat the whole batch as a long “sequence”, and avoid passing the states between different sequences in the batch by setting the state transition $A_t$ to 0 for tokens at the end of each sequence.</p> <h2 id="results">Results</h2> <p>How well do these optimizations work? The faster SSD algorithm allows us to increase the state dimension ($\mathtt{N}=64$ or $128$ compared to $\mathtt{N}=16$ in Mamba-1). Even though technically Mamba-2 is more restricted than Mamba-1 for the same $\mathtt{N}$, the larger state dimensions generally improve model quality. Here we show results for models trained on 300B tokens on the Pile, with Mamba-2 outperforming Mamba-1 and Pythia.</p> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/2024-05-31-mamba-2/blog_lm_downstream-480.webp 480w,/assets/img/2024-05-31-mamba-2/blog_lm_downstream-800.webp 800w,/assets/img/2024-05-31-mamba-2/blog_lm_downstream-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/2024-05-31-mamba-2/blog_lm_downstream.png" width="100%" height="auto" title="Downstream Evaluations" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> <figcaption class="caption">Standard downstream evaluations for open source models trained on the Pile</figcaption> </figure> <p>What about <strong>hybrid models</strong>? We have seen from recent and concurrent work (such as <a href="https://arxiv.org/abs/2403.19887">Jamba</a> and <a href="https://arxiv.org/abs/2405.16712">Zamba</a>) that combining Mamba layers with attention layers can improve over pure Transformer or Mamba. We validate at 2.7B parameters and 300B tokens scale that a hybrid model with just 6 attention blocks (and 58 SSD blocks) outperforms 64 SSD blocks, as well as our standard Transformer++ baseline (32 gated MLP and 32 attention blocks).</p> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/2024-05-31-mamba-2/blog_hybrid-480.webp 480w,/assets/img/2024-05-31-mamba-2/blog_hybrid-800.webp 800w,/assets/img/2024-05-31-mamba-2/blog_hybrid-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/2024-05-31-mamba-2/blog_hybrid.png" width="100%" height="auto" title="Downstream Evaluations for Hybrid Models" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> <figcaption class="caption">Downstream evaluations for hybrid Mamba/attention models</figcaption> </figure> <p>We also validated that the SSD algorithm is significantly faster than the selective scan algorithm from Mamba-1 for the same state dimension, and scales much better computationally to larger state dimensions. 
Getting those tensor cores to go brrr is the key!</p> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/2024-05-31-mamba-2/ssm_ssd_dstate-480.webp 480w,/assets/img/2024-05-31-mamba-2/ssm_ssd_dstate-800.webp 800w,/assets/img/2024-05-31-mamba-2/ssm_ssd_dstate-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/2024-05-31-mamba-2/ssm_ssd_dstate.png" width="100%" height="auto" title="Mamba-2 Efficiency Benchmarks" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> <figcaption class="caption">Efficiency benchmarks on sequence length 2K</figcaption> </figure> <h2 id="future-directions">Future Directions</h2> <p>With SSD, we have connected (linear) attention and SSMs, allowing us to design faster algorithms and implement systems optimizations for SSMs. There are still tons of exciting directions that we (and hopefully the community) want to tackle:</p> <ul> <li><strong>Understanding</strong>: hybrid models with a few (4-6) attention layers perform very well, even better than pure Mamba(-2) or Transformer++. What are these attention layers doing? Can they be replaced with another mechanism?</li> <li><strong>Training optimizations</strong>: though SSD might be faster than attention, Mamba-2 as a whole might still be slower than Transformers at short (e.g. 2K) sequence length, since the MLP layers in Transformers are very hardware-friendly. Our implementation of SSD does not specifically take advantage of new features on H100 GPUs, and we look forward to future optimizations that could make SSMs faster to train than Transformers for large-scale pretraining at 2-4K sequence length.</li> <li><strong>Inference optimizations</strong>: there’s a whole suite of optimizations tailored to Transformers, in particular handling the KV cache (quantization, speculative decoding). How would the inference landscape change if model states (e.g. SSM states) no longer scale with context length, and KV cache is no longer the bottleneck?</li> </ul>]]></content><author><name>Tri Dao</name></author><summary type="html"><![CDATA[Part I - The Model Part II - The Theory Part III - The Algorithm Part IV - The Systems]]></summary></entry></feed>