Keyboard shortcuts

Press or to navigate between chapters

Press S or / to search in the book

Press ? to show this help

Press Esc to hide this help

5. 【CUDA】内存优化:矩阵转置解析

为什么朴素的矩阵转置算法性能差?

在开始之前,我们需要先理解三个关键事实。


GPU 内存访问的硬件真相

事实 1:GPU 以 128 字节为单位访问全局内存

想象 GPU 的内存系统就像超市的收银台。即使你只买一瓶水,收银员也要扫描整个购物篮(哪怕只有一件商品)。GPU 也是如此:

单个float = 4 bytes
一个cache line = 32个连续的float = 128 bytes

即使你只读取 1 个 float,GPU 也会从内存加载连续的 128 字节(32 个 float)。

事实 2:Warp 的 32 个线程同时发起内存访问

一个 warp 有 32 个线程,执行float x = input[i]时,这 32 个线程会同一时刻向内存发出请求。GPU 的内存控制器会分析这 32 个请求:

  • 最佳情况:32 个线程访问的地址刚好在同一个 128 字节的 cache line 内 → 只需 1 次内存事务
  • 最坏情况:32 个线程访问的地址分散在 32 个不同的 cache line → 需要 32 次内存事务

后者的性能只有前者的 1/32

事实 3:Shared Memory 有 32 个 Bank

Shared Memory 虽然快,但有自己的限制。它被分成 32 个 bank(就像银行有 32 个窗口),每个 bank 每个周期只能服务 1 个请求。如果多个线程同时访问同一个 bank 的不同地址,就会发生"排队"(bank conflict)。

Shared Memory布局(简化版):
地址:  0    1    2   ...  31  |  32   33   34  ...  63  | ...
Bank:  0    1    2   ...  31  |  0    1    2   ...  31  | ...
       └────────────────────┘    └────────────────────┘
          第一组32个地址           第二组32个地址

每 32 个连续的 4-byte 地址会循环映射到 32 个 bank。


问题:为什么朴素矩阵转置性能差?

让我们用一个8×8 的矩阵来具体演示(实际会更大,但原理相同)。

矩阵在内存中的存储(行主序)

输入矩阵 A (8×8):
    列0  列1  列2  列3  列4  列5  列6  列7
行0: 0    1    2    3    4    5    6    7
行1: 8    9   10   11   12   13   14   15
行2:16   17   18   19   20   21   22   23
行3:24   25   26   27   28   29   30   31
行4:32   33   34   35   36   37   38   39
行5:40   41   42   43   44   45   46   47
行6:48   49   50   51   52   53   54   55
行7:56   57   58   59   60   61   62   63

在内存中的线性存储(input数组):
索引: 0  1  2  3  4  5  6  7  8  9 10 11 ...  63
值:   0  1  2  3  4  5  6  7  8  9 10 11 ... 63

矩阵按行主序存储,意思是第 0 行的所有元素先连续存储,然后是第 1 行,以此类推。

朴素版本的代码

__global__ void transpose_naive(float* input, float* output, int N) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;

    // 从input读取,写入output的转置位置
    output[col * N + row] = input[row * N + col];
}

假设我们启动一个 8×8 的线程块(threadIdx.x 和 threadIdx.y 都是 0-7)。

第一步:读取 input(性能 OK)

考虑第 0 行的 8 个线程(threadIdx.y = 0, threadIdx.x = 0-7):

线程   threadIdx.x  threadIdx.y  计算的索引        读取的值
T0        0            0         0*8 + 0 = 0       input[0]  = 0
T1        1            0         0*8 + 1 = 1       input[1]  = 1
T2        2            0         0*8 + 2 = 2       input[2]  = 2
T3        3            0         0*8 + 3 = 3       input[3]  = 3
T4        4            0         0*8 + 4 = 4       input[4]  = 4
T5        5            0         0*8 + 5 = 5       input[5]  = 5
T6        6            0         0*8 + 6 = 6       input[6]  = 6
T7        7            0         0*8 + 7 = 7       input[7]  = 7

这 8 个线程读取的是 input[0]到 input[7],连续的 8 个地址!它们恰好在同一个 cache line 内(128 字节可以容纳 32 个 float),所以只需要1 次内存事务。✅ 这是高效的。

第二步:写入 output(性能灾难!)

现在看写入。转置意味着 input[row][col] → output[col][row]。

同样是这 8 个线程,它们写入的地址是:

线程   threadIdx.x  threadIdx.y  计算的输出索引            写入位置
T0        0            0         0*8 + 0 = 0           output[0]   ← input[0][0]
T1        1            0         1*8 + 0 = 8           output[8]   ← input[0][1]
T2        2            0         2*8 + 0 = 16          output[16]  ← input[0][2]
T3        3            0         3*8 + 0 = 24          output[24]  ← input[0][3]
T4        4            0         4*8 + 0 = 32          output[32]  ← input[0][4]
T5        5            0         5*8 + 0 = 40          output[40]  ← input[0][5]
T6        6            0         6*8 + 0 = 48          output[48]  ← input[0][6]
T7        7            0         7*8 + 0 = 56          output[56]  ← input[0][7]

注意到问题了吗?这 8 个线程写入的地址是:0, 8, 16, 24, 32, 40, 48, 56

每两个地址之间相隔 8 个 float(32 字节)。而一个 cache line 只有 128 字节(32 个 float)。

可视化这些地址在 cache line 中的分布

Cache Line 0 (地址 0-31):
[0] [1] [2] [3] [4] [5] [6] [7] [8] [9] ... [31]
 ↑写                           ↑写
T0写这里                       T1写这里

Cache Line 1 (地址 32-63):
[32][33][34][35][36][37][38][39][40][41]...[56]...
 ↑写                           ↑写        ↑写
T4写                           T5写       T7写

这 8 个线程的写入分散在多个不同的 cache line中!GPU 无法合并这些写入,需要发起多次内存事务。这被称为非合并访问(Uncoalesced Access)


优化方案:使用 Shared Memory 作为"中转站"

优化的核心思想:把转置操作分成两步

  1. 第一步:从全局内存合并读取到 Shared Memory(按行读)
  2. 第二步:从 Shared Memory 读取,合并写入到全局内存(按列读 Shared Memory,但按行写全局内存)

优化代码

#define TILE_SIZE 32

__global__ void transpose_optimized(float* input, float* output, int N) {
    // 分配Shared Memory:注意+1!
    __shared__ float tile[TILE_SIZE][TILE_SIZE + 1];

    int x = blockIdx.x * TILE_SIZE + threadIdx.x;
    int y = blockIdx.y * TILE_SIZE + threadIdx.y;

    // 步骤1:合并读取input到Shared Memory
    tile[threadIdx.y][threadIdx.x] = input[y * N + x];

    __syncthreads();  // 确保所有线程都写完

    // 步骤2:从Shared Memory读取(转置),合并写入output
    int x_out = blockIdx.y * TILE_SIZE + threadIdx.x;  // 注意:x和y交换了
    int y_out = blockIdx.x * TILE_SIZE + threadIdx.y;

    output[y_out * N + x_out] = tile[threadIdx.x][threadIdx.y];  // 注意索引交换
}

用 8×8 的 tile 具体演示

假设我们处理输入矩阵的左上角 8×8 块(简化为 8 而非 32,原理相同)。

步骤 1:合并读取 input → Shared Memory

线程块中第 0 行的 8 个线程(threadIdx.y=0, threadIdx.x=0-7):

读取 input:
线程   读取地址      读取值    写入Shared Memory位置
T0    input[0]  =  0    →    tile[0][0]
T1    input[1]  =  1    →    tile[0][1]
T2    input[2]  =  2    →    tile[0][2]
T3    input[3]  =  3    →    tile[0][3]
T4    input[4]  =  4    →    tile[0][4]
T5    input[5]  =  5    →    tile[0][5]
T6    input[6]  =  6    →    tile[0][6]
T7    input[7]  =  7    →    tile[0][7]

读取地址连续(0,1,2,3,4,5,6,7),完美合并!✅

所有线程执行后,Shared Memory 的内容:

tile数组(在Shared Memory中):
       [0] [1] [2] [3] [4] [5] [6] [7]
[0]     0   1   2   3   4   5   6   7
[1]     8   9  10  11  12  13  14  15
[2]    16  17  18  19  20  21  22  23
[3]    24  25  26  27  28  29  30  31
[4]    32  33  34  35  36  37  38  39
[5]    40  41  42  43  44  45  46  47
[6]    48  49  50  51  52  53  54  55
[7]    56  57  58  59  60  61  62  63

步骤 2:转置访问 Shared Memory,合并写入 output

关键在这里!注意代码中:

output[y_out * N + x_out] = tile[threadIdx.x][threadIdx.y];
                                    ↑            ↑
                                      交换了索引!

同样是第 0 行的 8 个线程(threadIdx.y=0, threadIdx.x=0-7):

读取 Shared Memory:
线程   读取位置              读取值    写入output地址
T0    tile[0][0]    =  0    →    output[0]
T1    tile[1][0]    =  8    →    output[1]
T2    tile[2][0]    = 16    →    output[2]
T3    tile[3][0]    = 24    →    output[3]
T4    tile[4][0]    = 32    →    output[4]
T5    tile[5][0]    = 40    →    output[5]
T6    tile[6][0]    = 48    →    output[6]
T7    tile[7][0]    = 56    →    output[7]

写入的 output 地址是连续的(0,1,2,3,4,5,6,7),完美合并!✅

转置发生在哪里?

转置发生在访问 Shared Memory 时:

  • 读取 input 时按tile[row][col]存储
  • 写入 output 时按tile[col][row]读取

这样,对 Shared Memory 的列访问(非连续)变成了对 output 的行访问(连续)!


神秘的 +1:为什么 [TILE_SIZE + 1] 能避免 Bank Conflict?

Bank Conflict 的本质

还记得 Shared Memory 有 32 个 bank 吗?地址到 bank 的映射规则是:

bank_id = (地址 / 4) % 32

也就是说,每隔 32 个 4 字节的元素,bank ID 就会循环一次。

不加 +1 时的灾难

假设 TILE_SIZE = 32,没有 +1:

__shared__ float tile[32][32];  // 每行32个元素

内存布局

地址(以4-byte为单位):
行0: 0   1   2   ... 31   | Bank: 0  1  2  ... 31
行1: 32  33  34  ... 63   | Bank: 0  1  2  ... 31
行2: 64  65  66  ... 95   | Bank: 0  1  2  ... 31
...

注意:每行的起始元素(0, 32, 64...)都映射到 Bank 0!

问题出现在转置访问时

当我们执行tile[threadIdx.x][threadIdx.y],比如读取第 0 列:

32个线程访问第0列:
Thread 0:  tile[0][0]   → 地址 0   → Bank 0
Thread 1:  tile[1][0]   → 地址 32  → Bank 0  ❌
Thread 2:  tile[2][0]   → 地址 64  → Bank 0  ❌
Thread 3:  tile[3][0]   → 地址 96  → Bank 0  ❌
...
Thread 31: tile[31][0]  → 地址 992 → Bank 0  ❌

灾难!所有 32 个线程都在访问 Bank 0 的不同地址!

由于一个 bank 每次只能服务 1 个请求,这 32 个访问必须串行执行,需要 32 个周期。这就是32-way bank conflict,性能下降 32 倍!

加上 +1 的魔法

__shared__ float tile[32][33];  // 每行33个元素(多1个)

内存布局

地址(以4-byte为单位):
行0: 0   1   2   ... 31  32 |  Bank: 0  1  2  ... 31  0
行1: 33  34  35  ... 63  64 65|  Bank: 1  2  3  ... 31  0  1
行2: 66  67  68  ... 97  98 99|  Bank: 2  3  4  ... 31  0  1  2
...

注意:每行有 33 个元素,所以下一行的起始地址不再对齐到 Bank 0!

现在访问第 0 列

32个线程访问第0列:
Thread 0:  tile[0][0]   → 地址 0   → Bank 0  ✅
Thread 1:  tile[1][0]   → 地址 33  → Bank 1  ✅
Thread 2:  tile[2][0]   → 地址 66  → Bank 2  ✅
Thread 3:  tile[3][0]   → 地址 99  → Bank 3  ✅
...
Thread 31: tile[31][0]  → 地址 1023→ Bank 31 ✅

完美!所有 32 个线程访问不同的 bank,可以并行执行,只需 1 个周期!

为什么是 +1 而不是 +2 或其他数?

+1 是最小的能打破对齐的数字:

  • +1:下一行偏移 33 个元素 = 32 + 1,bank ID 偏移 1
  • +2:偏移 34 = 32 + 2,bank ID 偏移 2,也能避免冲突,但浪费更多空间

实际上,任何不是 32 的倍数的偏移都可以,但+1 最节省 Shared Memory。


性能对比

在 H100 上测试 32×32 的 tile,转置一个 4096×4096 的矩阵:

版本全局内存带宽利用率Shared Memory 冲突率性能(GB/s)
朴素版本~20%N/A600
优化版本(无+1)~90%96.9% (32-way)1200
优化版本(有+1)~90%0%2800

朴素版本由于非合并访问,带宽只有理论峰值的 20%。

优化版本即使有严重的 bank conflict,也比朴素版本快 2 倍,因为合并访问的收益太大了。

加上+1 消除 bank conflict 后,性能再提升 2.3 倍,达到接近 HBM3 理论带宽(3 TB/s)的 93%!


核心要点总结

  1. 合并访问是 GPU 性能的生命线

    • 同一 warp 的线程应该访问连续的内存地址
    • 非合并访问会导致 30 倍以上的性能损失
  2. Shared Memory 是优化的关键中转站

    • 允许我们将不规则的全局内存访问模式转换为规则的
    • 即使有额外的读写步骤,总体性能仍然大幅提升
  3. Bank Conflict 不可忽视

    • 看似微小的内存布局调整(+1)可以带来 2 倍以上性能差异
    • 转置、矩阵运算等操作特别容易触发 bank conflict
  4. 内存访问模式决定一切

    • GPU 硬件设计高度优化了连续访问模式
    • 算法设计时必须考虑数据的物理布局