Two acquaintances independently asked about this today, so it seems worth a write-up: recently (as of this writing), DeepMind published a new paper about a new practical fast matrix multiplication algorithm, along with a press release that is a bit misleading and seems to have led to some confusion already.

In particular, while the paper does introduce many new “practical” (not scare quotes, I’m using the word in a specific sense here that I’ll make more precise in a minute) fast small-matrix multiplication algorithms, that does not mean that you should replace your small-matrix library routines for 3×3 or 4×4 matrix multiplication with them. That’s not actually what they’re meant for, and they wouldn’t be good at it.

If you just want the executive summary, here it is: these are definitely interesting algorithms from an arithmetic complexity theory standpoint – especially for the case of 4×4 matrices over finite fields, where (to the best of my knowledge) Strassen’s algorithm from 1969 was still the reigning champion. These algorithms are also practically relevant: not only do they have a better asymptotic operation count than Strassen’s algorithm, they are still algorithms you might actually use in practice. That sets them apart from essentially everything else that has been written on the topic in the last 50 years: those other algorithms are correct, and will in principle win over Strassen’s algorithm with large enough matrices, but that cut-off is well beyond the sizes that anyone is actually doing computations with.

And if you know what Strassen’s algorithm is, you might be in the market for the results from this paper. In fact, I’ll go one further and say that if you are currently using Strassen’s algorithm somewhere, you should definitely check the paper out. For the rest of you, I’ll try to give a very short summary of the topic and explain why actual small matrices are not really the intended use case.

Strassen’s algorithm

Regular matrix multiplication of 2×2 matrices uses the standard “dot products of rows with columns” algorithm:

$\begin{pmatrix} a_{11} & a_{12} \\ a_{21} & a_{22} \end{pmatrix} \begin{pmatrix} b_{11} & b_{12} \\ b_{21} & b_{22} \end{pmatrix} = \begin{pmatrix} a_{11} b_{11} + a_{12} b_{21} & a_{11} b_{12} + a_{12} b_{22} \\ a_{21} b_{11} + a_{22} b_{21} & a_{21} b_{12} + a_{22} b_{22} \end{pmatrix}$

As written, this does 8 multiplications and 4 additions. In 1969, Volker Strassen discovered an algorithm for this that only uses 7 multiplications but 18 additions (I won’t go into the details here); Winograd later presented a variant that only uses 15 additions. This is interesting in principle if multiplications are much more expensive than additions, something which is true in some settings but does not commonly apply to hardware floating-point math these days. Hardware floating-point now commonly implements fused multiply-add (FMA) instructions, and the 2×2 matrix multiplication above can be implemented using 4 regular multiplications, 4 fused multiply-adds, and no separate additions at all. Moreover, Strassen’s algorithm has some numerical stability issues when used with floating-point (these concerns do not exist when it’s used for finite field arithmetic, though!), which means it must be used carefully. The upshot is that you would never actually consider using Strassen’s algorithm on 2×2 matrices of scalars, and that is in fact not how it’s normally presented.
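For reference, here is a minimal Python sketch of the 7-multiplication scheme in its commonly cited textbook form. The function name and the nested-list layout are my own choices for illustration; the point is only to make the 7 products and 18 additions/subtractions visible.

```python
def strassen_2x2(a, b):
    """Multiply two 2x2 matrices (nested lists) using 7 multiplications."""
    (a11, a12), (a21, a22) = a
    (b11, b12), (b21, b22) = b

    # Seven products. The a-term always stays on the left, so the same
    # formulas remain valid when the entries do not commute.
    m1 = (a11 + a22) * (b11 + b22)
    m2 = (a21 + a22) * b11
    m3 = a11 * (b12 - b22)
    m4 = a22 * (b21 - b11)
    m5 = (a11 + a12) * b22
    m6 = (a21 - a11) * (b11 + b12)
    m7 = (a12 - a22) * (b21 + b22)

    # 10 additions/subtractions above plus 8 below: 18 in total.
    return [[m1 + m4 - m5 + m7, m3 + m5],
            [m2 + m4, m1 - m2 + m3 + m6]]


def standard_2x2(a, b):
    """The usual rows-times-columns product: 8 multiplications, 4 additions."""
    return [[a[0][0] * b[0][0] + a[0][1] * b[1][0], a[0][0] * b[0][1] + a[0][1] * b[1][1]],
            [a[1][0] * b[0][0] + a[1][1] * b[1][0], a[1][0] * b[0][1] + a[1][1] * b[1][1]]]
```

With integer inputs, both functions return identical results; with floating-point inputs they can differ slightly in the last bits, which is the numerical caveat mentioned above in miniature.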

The form of Strassen’s algorithm that is of practical interest works not with 2×2 matrices, but with 2×2 block matrices, that is, 2×2 matrices whose elements are themselves matrices. Matrix multiplication has a very regular structure that looks exactly the same when multiplying block matrices as it does when multiplying matrices of scalars. You just need to be careful about operand order for multiplications because matrix multiplications, unlike multiplications in the scalar ring or field, are not commutative:

$\begin{pmatrix} A_{11} & A_{12} \\ A_{21} & A_{22} \end{pmatrix} \begin{pmatrix} B_{11} & B_{12} \\ B_{21} & B_{22} \end{pmatrix} = \begin{pmatrix} A_{11} B_{11} + A_{12} B_{21} & A_{11} B_{12} + A_{12} B_{22} \\ A_{21} B_{11} + A_{22} B_{21} & A_{21} B_{12} + A_{22} B_{22} \end{pmatrix}$

This looks identical to what we had before (except it’s now all upper case), but the operations mean something different. Before, we were multiplying scalars with each other, so we had something like individual real number multiplications (or floating-point multiplications in actual numerical code). Now the products are matrix-matrix products (which are O(N³) operations when multiplying two square N×N matrices using the standard algorithm, a mix of multiplications and additions or possibly fused multiply-adds) and the sums are matrix-matrix sums (which are O(N²) additions). And because what we’re describing here is a matrix multiplication algorithm to begin with, the smaller sub-matrix multiplications can also use Strassen’s algorithm if they want to.
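As a rough illustration of the block form, here is a pure-Python sketch of recursive Strassen multiplication that falls back to the standard algorithm below a cut-off. All names (`strassen`, `cutoff`, the helpers) are hypothetical, it assumes square matrices of equal size stored as nested lists, and it is written for clarity, not speed; a real implementation would sit on top of a tuned matrix multiplication kernel.

```python
def mat_add(x, y):
    return [[p + q for p, q in zip(r, s)] for r, s in zip(x, y)]


def mat_sub(x, y):
    return [[p - q for p, q in zip(r, s)] for r, s in zip(x, y)]


def mat_mul_standard(x, y):
    """Standard rows-times-columns product."""
    cols = list(zip(*y))
    return [[sum(p * q for p, q in zip(row, col)) for col in cols] for row in x]


def split(m):
    """Split a square matrix of even size into its four quadrants."""
    h = len(m) // 2
    return ([r[:h] for r in m[:h]], [r[h:] for r in m[:h]],
            [r[:h] for r in m[h:]], [r[h:] for r in m[h:]])


def strassen(a, b, cutoff=2):
    """One Strassen step per recursion level; standard product at or below
    `cutoff`, or whenever the size is odd and cannot be halved."""
    n = len(a)
    if n <= cutoff or n % 2:
        return mat_mul_standard(a, b)
    a11, a12, a21, a22 = split(a)
    b11, b12, b21, b22 = split(b)

    # Seven block products; A-blocks stay on the left because block
    # multiplication is not commutative.
    m1 = strassen(mat_add(a11, a22), mat_add(b11, b22), cutoff)
    m2 = strassen(mat_add(a21, a22), b11, cutoff)
    m3 = strassen(a11, mat_sub(b12, b22), cutoff)
    m4 = strassen(a22, mat_sub(b21, b11), cutoff)
    m5 = strassen(mat_add(a11, a12), b22, cutoff)
    m6 = strassen(mat_sub(a21, a11), mat_add(b11, b12), cutoff)
    m7 = strassen(mat_sub(a12, a22), mat_add(b21, b22), cutoff)

    c11 = mat_add(mat_sub(mat_add(m1, m4), m5), m7)
    c12 = mat_add(m3, m5)
    c21 = mat_add(m2, m4)
    c22 = mat_add(mat_add(mat_sub(m1, m2), m3), m6)

    # Stitch the four quadrants back together.
    top = [left + right for left, right in zip(c11, c12)]
    bottom = [left + right for left, right in zip(c21, c22)]
    return top + bottom
```

The `cutoff` parameter is where all the practical tuning lives: setting it very low means recursing all the way down to tiny blocks, which is exactly the regime where the extra additions are not worth it.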

In short, not only does each multiplication and addition in this block matrix represent many underlying scalar operations, but the relative cost of multiplications compared to additions keeps growing as N (the size of the blocks in our block matrix) increases. And that’s where Strassen’s algorithm is actually used in practice: eventually, once N becomes large enough, the many extra additions and irregular structure become worthwhile. It’s important to note that arithmetic operation count is not the only factor here: the extremely regular structure of conventional matrix multiplication, and the ease with which it can use FMAs, means that the arithmetic operations in a tuned matrix multiplication kernel are a lot cheaper than you might expect, because these kernels tend to do an extremely good job of keeping the machine busy. With small matrices (and actual 4×4 matrices definitely fit that description), other overheads dominate. Kernels for somewhat larger blocks are mostly limited by memory subsystem effects and carefully manage their footprint in the various cache levels and TLBs, which tends to include techniques such as extra memory copying and packing stages that might seem like a waste because they’re not spamming floating-point math, but are worth the cost because they make everything else go faster. With much larger blocks, eventually Strassen can become attractive, although the actual cut-off varies wildly between architectures. I’ve seen reports of Strassen multiplication becoming useful for blocks as small as 128×128, but more usually, it’s used for blocks that are at least 512×512 elements, if not much more. All this assuming that its less-than-ideal numerical properties are not a show-stopper for the given application.
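To put rough numbers on the arithmetic side alone, here is a small Python sketch that counts scalar operations for a standard n×n product versus a single Strassen step whose seven half-size products are done the standard way. The function names are mine, and the model is deliberately naive: it counts multiplications and additions separately, with no FMAs, caches, or memory traffic.

```python
def standard_ops(n):
    """(multiplications, additions) for a standard n x n product."""
    return n ** 3, n ** 2 * (n - 1)


def one_level_strassen_ops(n):
    """(multiplications, additions) for one Strassen step on an n x n
    product (n even), with the 7 half-size products done the standard way."""
    h = n // 2
    mul, add = standard_ops(h)
    # 7 block products, plus 18 block additions of h*h scalars each.
    return 7 * mul, 7 * add + 18 * h * h


def first_even_size_where_strassen_wins(limit=1024):
    """Smallest even n (up to `limit`) with fewer total scalar operations."""
    for n in range(2, limit + 1, 2):
        if sum(one_level_strassen_ops(n)) < sum(standard_ops(n)):
            return n
    return None
```

In this toy model, one Strassen step already comes out ahead at 16×16. That is far below the cut-offs seen in practice, and the gap is a decent measure of how much everything other than raw operation count matters.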

AlphaTensor

That brings us back to AlphaTensor. What DeepMind has implemented is, in essence, a variant of the neural-net-guided Monte Carlo Tree Search they’ve been using with great success to build Chess and Go AIs. This time, they use it to search for small-matrix multiplication algorithms. I’m not the right person to go into the details here, and it’s not important for the purposes of this post. We can just treat this procedure as a black-box optimizer that can do a guided search for matrix-multiplication algorithms meeting a user-specified set of criteria.

One obvious criterion would be optimizing for minimum multiplication count, and in fact one of the discoveries reported is a “Strassen-like” factorization that uses 47 multiplications to multiply two 4×4 matrices (compared to 7*7=49 multiplications for two nested applications of Strassen’s 2×2 algorithm, or 64 multiplications for the direct algorithm). And separate from the more theoretical concern of minimum operation count for a “practical” algorithm, the same optimizer can also incorporate measured runtimes on actual devices into its operation, and thus be used to perform a guided search for algorithms that are fast on certain devices.
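To see what those multiplication counts buy you, recall the standard fact that applying a k×k scheme with r multiplications recursively to blocks gives an algorithm needing on the order of n^(log_k r) scalar multiplications. The Python sketch below (helper names are mine) just evaluates that exponent for the three 4×4 counts mentioned above; keep in mind that, as noted in the summary, the headline improvement concerns the finite-field case.

```python
import math


def exponent(block_dim, multiplications):
    """Exponent w such that recursing on block_dim x block_dim blocks
    needs on the order of n**w scalar multiplications."""
    return math.log(multiplications) / math.log(block_dim)


def recursive_multiplications(multiplications, levels):
    """Scalar multiplications after `levels` nested applications."""
    return multiplications ** levels


if __name__ == "__main__":
    schemes = [("direct", 64), ("Strassen applied twice", 49), ("47-multiplication", 47)]
    for name, mults in schemes:
        print(f"{name:>24}: {mults} multiplications, exponent {exponent(4, mults):.3f}")
```

The exponent drops from 3 for the direct algorithm to roughly 2.807 for nested Strassen and roughly 2.777 for the 47-multiplication factorization: a real improvement, but one that only pays off once the individual “multiplications” are themselves large matrix products.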

That’s the process used to yield figure 5 in the paper, which reports speed-ups of optimized 4×4 matrix multiplication factorizations against the baseline (which is the regular algorithm). Note that what this does is one high-level 4×4 block matrix multiply using the algorithm in question at the top level; all the smaller lower-level matrix multiplies use regular matrix multiplication. Also note that the reported matrix sizes start at 8192×8192, i.e. in this case, the 2048×2048-element blocks are multiplied regularly.

Furthermore, note that the reported speed-ups that the press release refers to as “10-20%” are compared to the baseline of using regular matrix multiplication, not against Strassen (in this case, “Strassen-square”, the aforementioned 4×4 factorization obtained by applying Strassen’s algorithm twice to a 4×4 matrix). Strassen results are reported in the figure as well; they’re the red bars.

In short, the new algorithms are a sizeable improvement, especially on the TPU, but the perspective here is important: this type of algorithm is interesting when individual multiplications are quite expensive, in this case because they are actually multiplications of fairly large matrices themselves (2048×2048 blocks are nothing to sneeze at).

If you’re actually multiplying 4×4 matrices of scalars, the standard algorithm remains the way to go, and is likely to stay that way for the foreseeable future.
