Coelacanth's Dream

RDNA 4 における WMMA 命令と改善

RDNA 3/GFX11 アーキテクチャ からサポートされている WMMA (Wave Matrix Multiply-accumulate) 命令だが、RDNA 4/GFX12 アーキテクチャ でも引き続きサポートしており、そして対応フォーマットの追加や対応する行列レイアウトの追加、必要なベクタレジスタサイズの削減などが実装されている。

使用ベクタレジスタの削減

RDNA 3/GFX11 における WMMA 命令は内部的に複数のドット積命令に分解して実行する方式となっていた。
また、性能を最大化するには積算に使われる入力行列 A, B の配置を工夫する必要があり、thread/work-item/vector-lane 0-15 のデータを 16-31 にも複製する必要があった。これは Wave32 モードの場合であり、Wave64 の場合は 0-31 のデータを 32-63 に複製する必要がある。
RDNA 3/GFX11 では 32-bit VGPR (ベクタレジスタ) を 2つの 16-bit VGPR に分割する機能 (16-bit VGPR-pairs) を持ち、また WMMA 命令は入力行列 A, B に F16/BF16/IU8 のフォーマットのみをサポートしている。
そのため、理論上は 16x16 の行列は 4 VGPR (4 [vgpr]* 2 [16-bit pair] * 32 [lane]) に収まるはずだが、先の理由のため 8 VGPR を割り当てる必要があった。
また、アキュムレータとなる行列 C も出力フォーマットが F16/BF16 の場合でも 8 VGPR を必要とする。
これは WMMA 命令では設定したフラグによって 32-bit VGPR の上半分と下半分のどちらに結果を格納するか決まるためとされている。

それが RDNA 4/GFX12 では一部データを複製する必要がなくなり、行列 A, B が 4 VGPR (Wave32) に収まるようになった。
また、結果の格納に関するフラグを設定する必要もなくなり、行列 C も 4 VGPR に収まるようになった。

    +/* This pass lowers cooperative matrix.
    + *
    + * On GFX11, the A&B matrices needs to be replicated, lanes 0..15 are replicated
    + * to 16..31 and for wave64 also into lanes 32..47 and 48..63. A&B matrices are
    + * always vectors of 16 elements.
    + *
    + * On GFX12, there is no data replication and the matrices layout is described
    + * as below:
    + *
    + * Wave32:
    + * A&B:
    + *         0..15  | 16..31 (lanes)
    + * v0 lo:  row 0  | row 4
    + * v0 hi:  row 1  | row 5
    + * v1 lo:  row 2  | row 6
    + * v1 hi:  row 3  | row 7
    + * v2 lo:  row 8  | row 12
    + * v2 hi:  row 9  | row 13
    + * v3 lo:  row 10 | row 14
    + * v3 hi:  row 11 | row 15
    + *
    + * C:
    + *         0..15  | 16..31 (lanes)
    + * v0 lo:  row 0  | row 8
    + * v0 hi:  row 1  | row 9
    + * v1 lo:  row 2  | row 10
    + * v1 hi:  row 3  | row 11
    + * v2 lo:  row 4  | row 12
    + * v2 hi:  row 5  | row 13
    + * v3 lo:  row 6  | row 14
    + * v3 hi:  row 7  | row 15
    + *
    + * Wave64:
    + * A&B:
    + *         0..15 | 16..31 | 32..47 | 48..63 (lanes)
    + * v0 lo:  row 0 | row 4  | row 8  | row 12
    + * v0 hi:  row 1 | row 5  | row 9  | row 13
    + * v1 lo:  row 2 | row 6  | row 10 | row 14
    + * v1 hi:  row 3 | row 7  | row 11 | row 15
    + *
    + * C:
    + *         0..15 | 16..31 | 32..47 | 48..63 (lanes)
    + * v0 lo:  row 0 | row 8  | row 4  | row 12
    + * v0 hi:  row 1 | row 9  | row 5  | row 13
    + * v1 lo:  row 2 | row 10 | row 6  | row 14
    + * v1 hi:  row 3 | row 11 | row 7  | row 15
    + */

対応フォーマットとレイアウト

RDNA 3/GFX11 では WMMA 命令の入力フォーマットに F16/BF16/IU8/IU4、出力フォーマットに F32/F16/BF16/I32 をサポートし、行列のレイアウトは入出力ともに 16x16 のみをサポートしていた。
RDNA 4/GFX12 では入力フォーマットに FP8 (1 bit sign, 4 bit exponent, 3 bit mantissa) と BF8 (1 bit sign, 5 bit exponent, 2 bit mantissa) のサポートが追加され、
入力フォーマットが IU4 の場合は 16x16x32 (A: 16x32, B: 32x16, C: 16x16) のレイアウトもサポートする。

RDNA 4/GFX12 がサポートする FP8/BF8 は OCP (Open Compute Project) 準拠のフォーマットであり、MI300/CDNA 3 がサポートする FP8/BF8 とは異なって無限大 (Inf) をサポートしている。

また、疎行列を入力に取る SWMMAC (Wave Matrix(sparse) Multiply-Accumulate) 命令のサポートが追加された。
SWMMAC 命令の対応フォーマットは WMMA 命令と同じであり、行列のレイアウトは CDNA 系の SMFMAC 命令と同じだとすると 16x16x32 (A: sparse 32x16, B: 16x32, C: 16x16)、入力フォーマットが IU4 の場合は 16x16x64 (A: sparse 16x64, B: 64x16, C: 16x16) をサポートする。

RDNA 4/GFX12 でも RDNA 3/GFX11 の内部的に複数のドット積命令に分解して発行する形式を継続するのか、それとも CDNA 系のように専用の行列演算ユニットを実装するのかまだ不明である。
RDNA 3/GFX11 から行列演算の IPC が向上しているのかについても、RDNA 4/GFX12 におけるサイクル数が公開されていないため不明。
しかし、モデルのサイズや VRAM の使用サイズの削減に役立つ FP8/BF8 フォーマットのサポート、疎行列のサポート、必要な VGPR の削減と RDNA 3/GFX11 から進化しているのは確かである。

参考リンク