Transformers have benefited from 7 years of systems optimizations developed by the whole research community and large companies. The SSD framework draws connections between SSMs and attention, and allows us to implement many of these optimizations for models like Mamba-2 as well. We focus on tensor parallelism and sequence parallelism for large-scale training, as well as variable-length sequences for efficient finetuning and inference.
One difficulty with large-scale training of Mamba-1 using tensor parallelism (TP) is that it requires 2 all-reduces per layer, compared to just 1 all-reduce per attention or MLP layer in a Transformer. This is because some of the SSM parameters are functions of the inner activations, not of the input to the layer. In Mamba-2, with the “parallel projection” structure, all SSM parameters are functions of the input to the layer, so we can easily apply TP to the input projection: we split the input projection and output projection matrices into 2, 4, or 8 shards, depending on the TP degree. We use a grouped norm with the number of groups divisible by the TP degree, so that normalization is done separately on each GPU. These changes result in 1 all-reduce per layer, instead of 2.
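As a concrete picture, below is a minimal single-process sketch of this sharding pattern. The TP ranks are simulated with a loop on one device, a simple gating operation stands in for the actual SSD scan, and all names and shapes are illustrative assumptions rather than the real Mamba-2 code; the sum over partial outputs at the end plays the role of the single all-reduce.

```python
import torch

def tp_mamba2_block_sketch(u, W_in_shards, W_out_shards, n_groups):
    """Illustrative tensor-parallel sharding for a Mamba-2-style block.

    u:            (batch, seqlen, d_model) layer input
    W_in_shards:  per-rank column shards of the input projection, each (d_model, 2 * d_inner // tp)
    W_out_shards: per-rank row shards of the output projection, each (d_inner // tp, d_model)
    n_groups:     number of groups in the grouped norm, divisible by the TP degree
    """
    tp = len(W_in_shards)
    partial_outputs = []
    for rank in range(tp):                              # each iteration = one GPU's local work
        xz = u @ W_in_shards[rank]                      # local slice of the projected activations
        x, z = xz.chunk(2, dim=-1)                      # stand-in for the x, B, C, dt, z projections
        x = x * torch.sigmoid(z)                        # placeholder for the SSD scan + gating
        # grouped RMSNorm: each rank owns n_groups // tp groups, so no communication is needed
        g = x.view(*x.shape[:-1], n_groups // tp, -1)
        g = g * torch.rsqrt(g.pow(2).mean(dim=-1, keepdim=True) + 1e-6)
        x = g.flatten(-2)
        partial_outputs.append(x @ W_out_shards[rank])  # local slice of the output projection
    return sum(partial_outputs)                         # the one all-reduce per layer
```

In the real multi-GPU setting each rank holds only its own shards, and the final sum is performed with `torch.distributed.all_reduce`.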
When training on very long sequences, we may need to split along the sequence length and assign different parts to different devices. There are two main forms of sequence parallelism (SP):

- For the residual and normalization operations: this replaces the all-reduce in TP with a reduce-scatter, residual + normalization, then all-gather. Since Mamba-2 uses the same residual and normalization structure as the Transformer, this form of SP applies directly with no modification.
- For the attention or SSM operation, also known as context parallelism (CP): for attention, one could use Ring Attention to split the computation along the sequence dimension. For Mamba-2, the SSD framework comes to our help once again: using the same block decomposition, each GPU computes its local output and its final state, then the states are passed between GPUs (using send/receive communication primitives) before the final output of each GPU is updated, as sketched below.
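To make the block decomposition concrete, here is a minimal single-device sketch of this context-parallel pattern for a simplified diagonal SSM with scalar decay $a_t$ (the scalar-times-identity structure of SSD). Each chunk plays the role of one GPU: it computes its local output and final state assuming a zero incoming state, and the only thing handed to the next chunk is that state. The recurrence and names are illustrative assumptions, not the actual Mamba-2 kernel.

```python
import torch

def context_parallel_ssm_sketch(a, B, C, x, chunk):
    """Chunked scan for h_t = a_t * h_{t-1} + B_t * x_t,  y_t = <C_t, h_t>.

    a: (T,) scalar decay per step; B, C: (T, N); x: (T,) a single input channel.
    Each chunk of length `chunk` stands in for one GPU's shard of the sequence.
    """
    T, N = B.shape
    y = torch.zeros(T)
    h_init = torch.zeros(N)                        # state handed over from the previous chunk
    for s in range(0, T, chunk):
        e = min(s + chunk, T)
        # --- local work (runs independently on each device) ---
        h = torch.zeros(N)
        decay = torch.ones(e - s)                  # running product of a_t within the chunk
        for i, t in enumerate(range(s, e)):
            h = a[t] * h + B[t] * x[t]
            decay[i] = a[t] * (decay[i - 1] if i > 0 else 1.0)
            y[t] = C[t] @ h                        # local output, ignoring the incoming state
        # --- cross-device step: only the previous chunk's final state is needed ---
        for i, t in enumerate(range(s, e)):
            y[t] += C[t] @ (decay[i] * h_init)     # correct the output with the incoming state
        h_init = h + decay[-1] * h_init            # final state sent to the next chunk
    return y
```

On real hardware the outer loop becomes the set of GPUs, the local work runs in parallel, and `h_init` is the only tensor exchanged with send/receive, so the communication volume is independent of the chunk length.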
For finetuning and inference, the same batch often contains sequences of different lengths. With a Transformer, one would usually pad so that all sequences have the same length (wasting computation), or implement attention specifically for variable-length sequences with careful load balancing. With SSMs, we can simply treat the whole batch as one long “sequence”, and avoid passing states between different sequences in the batch by setting the state transition $A_t$ to $0$ for the tokens at the sequence boundaries, so that no state is carried from one sequence into the next.
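Below is a minimal sketch of this trick, using the same simplified diagonal recurrence as above. The `cu_seqlens` convention (cumulative sequence lengths marking where each packed sequence starts) is an assumption for illustration, not necessarily the actual Mamba-2 interface.

```python
import torch

def packed_varlen_ssm_sketch(a, B, C, x, cu_seqlens):
    """Run one scan over a packed batch of variable-length sequences.

    a: (T,), B, C: (T, N), x: (T,) are the concatenated sequences;
    cu_seqlens: e.g. torch.tensor([0, 5, 12, T]) for three packed sequences.
    """
    a = a.clone()
    a[cu_seqlens[:-1]] = 0.0           # zero the transition at each sequence start: no state leaks across
    T, N = B.shape
    h = torch.zeros(N)
    y = torch.zeros(T)
    for t in range(T):
        h = a[t] * h + B[t] * x[t]     # h_t = A_t h_{t-1} + B_t x_t, with A_t = 0 at boundaries
        y[t] = C[t] @ h
    return y
```

The result matches running the scan on each sequence separately, with no padding tokens and no wasted computation.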
How well do these optimizations work? The faster SSD algorithm allows us to increase the state dimension ($\mathtt{N}=64$ or $128$, compared to $\mathtt{N}=16$ in Mamba-1). Even though Mamba-2 is technically more restricted than Mamba-1 for the same $\mathtt{N}$, the larger state dimension generally improves model quality. Here we show results for models trained on 300B tokens on the Pile, with Mamba-2 outperforming Mamba-1 and Pythia.
What about hybrid models? We have seen from recent and concurrent work (such as Jamba and Zamba) that combining Mamba layers with attention layers can improve over a pure Transformer or a pure Mamba model. We validate at the 2.7B-parameter / 300B-token scale that a hybrid model with just 6 attention blocks (and 58 SSD blocks) outperforms one with 64 SSD blocks, as well as our standard Transformer++ baseline (32 gated MLP and 32 attention blocks).
We also validated that the SSD algorithm is significantly faster than the selective scan algorithm from Mamba-1 for the same state dimension, and scales much better computationally to larger state dimensions. Getting those tensor cores to go brrr is the key!
With SSD, we have connected (linear) attention and SSMs, allowing us to design faster algorithms and implement systems optimizations for SSMs. There are still tons of exciting directions that we (and hopefully the community) want to tackle: