5. 【CUDA】内存优化:矩阵转置解析
为什么朴素的矩阵转置算法性能差?
在开始之前,我们需要先理解三个关键事实。
GPU 内存访问的硬件真相
事实 1:GPU 以 128 字节为单位访问全局内存
想象 GPU 的内存系统就像超市的收银台。即使你只买一瓶水,收银员也要扫描整个购物篮(哪怕只有一件商品)。GPU 也是如此:
单个float = 4 bytes
一个cache line = 32个连续的float = 128 bytes
即使你只读取 1 个 float,GPU 也会从内存加载连续的 128 字节(32 个 float)。
事实 2:Warp 的 32 个线程同时发起内存访问
一个 warp 有 32 个线程,执行float x = input[i]时,这 32 个线程会同一时刻向内存发出请求。GPU 的内存控制器会分析这 32 个请求:
- 最佳情况:32 个线程访问的地址刚好在同一个 128 字节的 cache line 内 → 只需 1 次内存事务
- 最坏情况:32 个线程访问的地址分散在 32 个不同的 cache line → 需要 32 次内存事务
后者的性能只有前者的 1/32!
事实 3:Shared Memory 有 32 个 Bank
Shared Memory 虽然快,但有自己的限制。它被分成 32 个 bank(就像银行有 32 个窗口),每个 bank 每个周期只能服务 1 个请求。如果多个线程同时访问同一个 bank 的不同地址,就会发生"排队"(bank conflict)。
Shared Memory布局(简化版):
地址: 0 1 2 ... 31 | 32 33 34 ... 63 | ...
Bank: 0 1 2 ... 31 | 0 1 2 ... 31 | ...
└────────────────────┘ └────────────────────┘
第一组32个地址 第二组32个地址
每 32 个连续的 4-byte 地址会循环映射到 32 个 bank。
问题:为什么朴素矩阵转置性能差?
让我们用一个8×8 的矩阵来具体演示(实际会更大,但原理相同)。
矩阵在内存中的存储(行主序)
输入矩阵 A (8×8):
列0 列1 列2 列3 列4 列5 列6 列7
行0: 0 1 2 3 4 5 6 7
行1: 8 9 10 11 12 13 14 15
行2:16 17 18 19 20 21 22 23
行3:24 25 26 27 28 29 30 31
行4:32 33 34 35 36 37 38 39
行5:40 41 42 43 44 45 46 47
行6:48 49 50 51 52 53 54 55
行7:56 57 58 59 60 61 62 63
在内存中的线性存储(input数组):
索引: 0 1 2 3 4 5 6 7 8 9 10 11 ... 63
值: 0 1 2 3 4 5 6 7 8 9 10 11 ... 63
矩阵按行主序存储,意思是第 0 行的所有元素先连续存储,然后是第 1 行,以此类推。
朴素版本的代码
__global__ void transpose_naive(float* input, float* output, int N) {
int row = blockIdx.y * blockDim.y + threadIdx.y;
int col = blockIdx.x * blockDim.x + threadIdx.x;
// 从input读取,写入output的转置位置
output[col * N + row] = input[row * N + col];
}
假设我们启动一个 8×8 的线程块(threadIdx.x 和 threadIdx.y 都是 0-7)。
第一步:读取 input(性能 OK)
考虑第 0 行的 8 个线程(threadIdx.y = 0, threadIdx.x = 0-7):
线程 threadIdx.x threadIdx.y 计算的索引 读取的值
T0 0 0 0*8 + 0 = 0 input[0] = 0
T1 1 0 0*8 + 1 = 1 input[1] = 1
T2 2 0 0*8 + 2 = 2 input[2] = 2
T3 3 0 0*8 + 3 = 3 input[3] = 3
T4 4 0 0*8 + 4 = 4 input[4] = 4
T5 5 0 0*8 + 5 = 5 input[5] = 5
T6 6 0 0*8 + 6 = 6 input[6] = 6
T7 7 0 0*8 + 7 = 7 input[7] = 7
这 8 个线程读取的是 input[0]到 input[7],连续的 8 个地址!它们恰好在同一个 cache line 内(128 字节可以容纳 32 个 float),所以只需要1 次内存事务。✅ 这是高效的。
第二步:写入 output(性能灾难!)
现在看写入。转置意味着 input[row][col] → output[col][row]。
同样是这 8 个线程,它们写入的地址是:
线程 threadIdx.x threadIdx.y 计算的输出索引 写入位置
T0 0 0 0*8 + 0 = 0 output[0] ← input[0][0]
T1 1 0 1*8 + 0 = 8 output[8] ← input[0][1]
T2 2 0 2*8 + 0 = 16 output[16] ← input[0][2]
T3 3 0 3*8 + 0 = 24 output[24] ← input[0][3]
T4 4 0 4*8 + 0 = 32 output[32] ← input[0][4]
T5 5 0 5*8 + 0 = 40 output[40] ← input[0][5]
T6 6 0 6*8 + 0 = 48 output[48] ← input[0][6]
T7 7 0 7*8 + 0 = 56 output[56] ← input[0][7]
注意到问题了吗?这 8 个线程写入的地址是:0, 8, 16, 24, 32, 40, 48, 56
每两个地址之间相隔 8 个 float(32 字节)。而一个 cache line 只有 128 字节(32 个 float)。
可视化这些地址在 cache line 中的分布:
Cache Line 0 (地址 0-31):
[0] [1] [2] [3] [4] [5] [6] [7] [8] [9] ... [31]
↑写 ↑写
T0写这里 T1写这里
Cache Line 1 (地址 32-63):
[32][33][34][35][36][37][38][39][40][41]...[56]...
↑写 ↑写 ↑写
T4写 T5写 T7写
这 8 个线程的写入分散在多个不同的 cache line中!GPU 无法合并这些写入,需要发起多次内存事务。这被称为非合并访问(Uncoalesced Access)。
优化方案:使用 Shared Memory 作为"中转站"
优化的核心思想:把转置操作分成两步
- 第一步:从全局内存合并读取到 Shared Memory(按行读)
- 第二步:从 Shared Memory 读取,合并写入到全局内存(按列读 Shared Memory,但按行写全局内存)
优化代码
#define TILE_SIZE 32
__global__ void transpose_optimized(float* input, float* output, int N) {
// 分配Shared Memory:注意+1!
__shared__ float tile[TILE_SIZE][TILE_SIZE + 1];
int x = blockIdx.x * TILE_SIZE + threadIdx.x;
int y = blockIdx.y * TILE_SIZE + threadIdx.y;
// 步骤1:合并读取input到Shared Memory
tile[threadIdx.y][threadIdx.x] = input[y * N + x];
__syncthreads(); // 确保所有线程都写完
// 步骤2:从Shared Memory读取(转置),合并写入output
int x_out = blockIdx.y * TILE_SIZE + threadIdx.x; // 注意:x和y交换了
int y_out = blockIdx.x * TILE_SIZE + threadIdx.y;
output[y_out * N + x_out] = tile[threadIdx.x][threadIdx.y]; // 注意索引交换
}
用 8×8 的 tile 具体演示
假设我们处理输入矩阵的左上角 8×8 块(简化为 8 而非 32,原理相同)。
步骤 1:合并读取 input → Shared Memory
线程块中第 0 行的 8 个线程(threadIdx.y=0, threadIdx.x=0-7):
读取 input:
线程 读取地址 读取值 写入Shared Memory位置
T0 input[0] = 0 → tile[0][0]
T1 input[1] = 1 → tile[0][1]
T2 input[2] = 2 → tile[0][2]
T3 input[3] = 3 → tile[0][3]
T4 input[4] = 4 → tile[0][4]
T5 input[5] = 5 → tile[0][5]
T6 input[6] = 6 → tile[0][6]
T7 input[7] = 7 → tile[0][7]
读取地址连续(0,1,2,3,4,5,6,7),完美合并!✅
所有线程执行后,Shared Memory 的内容:
tile数组(在Shared Memory中):
[0] [1] [2] [3] [4] [5] [6] [7]
[0] 0 1 2 3 4 5 6 7
[1] 8 9 10 11 12 13 14 15
[2] 16 17 18 19 20 21 22 23
[3] 24 25 26 27 28 29 30 31
[4] 32 33 34 35 36 37 38 39
[5] 40 41 42 43 44 45 46 47
[6] 48 49 50 51 52 53 54 55
[7] 56 57 58 59 60 61 62 63
步骤 2:转置访问 Shared Memory,合并写入 output
关键在这里!注意代码中:
output[y_out * N + x_out] = tile[threadIdx.x][threadIdx.y];
↑ ↑
交换了索引!
同样是第 0 行的 8 个线程(threadIdx.y=0, threadIdx.x=0-7):
读取 Shared Memory:
线程 读取位置 读取值 写入output地址
T0 tile[0][0] = 0 → output[0]
T1 tile[1][0] = 8 → output[1]
T2 tile[2][0] = 16 → output[2]
T3 tile[3][0] = 24 → output[3]
T4 tile[4][0] = 32 → output[4]
T5 tile[5][0] = 40 → output[5]
T6 tile[6][0] = 48 → output[6]
T7 tile[7][0] = 56 → output[7]
写入的 output 地址是连续的(0,1,2,3,4,5,6,7),完美合并!✅
转置发生在哪里?
转置发生在访问 Shared Memory 时:
- 读取 input 时按
tile[row][col]存储 - 写入 output 时按
tile[col][row]读取
这样,对 Shared Memory 的列访问(非连续)变成了对 output 的行访问(连续)!
神秘的 +1:为什么 [TILE_SIZE + 1] 能避免 Bank Conflict?
Bank Conflict 的本质
还记得 Shared Memory 有 32 个 bank 吗?地址到 bank 的映射规则是:
bank_id = (地址 / 4) % 32
也就是说,每隔 32 个 4 字节的元素,bank ID 就会循环一次。
不加 +1 时的灾难
假设 TILE_SIZE = 32,没有 +1:
__shared__ float tile[32][32]; // 每行32个元素
内存布局:
地址(以4-byte为单位):
行0: 0 1 2 ... 31 | Bank: 0 1 2 ... 31
行1: 32 33 34 ... 63 | Bank: 0 1 2 ... 31
行2: 64 65 66 ... 95 | Bank: 0 1 2 ... 31
...
注意:每行的起始元素(0, 32, 64...)都映射到 Bank 0!
问题出现在转置访问时:
当我们执行tile[threadIdx.x][threadIdx.y],比如读取第 0 列:
32个线程访问第0列:
Thread 0: tile[0][0] → 地址 0 → Bank 0
Thread 1: tile[1][0] → 地址 32 → Bank 0 ❌
Thread 2: tile[2][0] → 地址 64 → Bank 0 ❌
Thread 3: tile[3][0] → 地址 96 → Bank 0 ❌
...
Thread 31: tile[31][0] → 地址 992 → Bank 0 ❌
灾难!所有 32 个线程都在访问 Bank 0 的不同地址!
由于一个 bank 每次只能服务 1 个请求,这 32 个访问必须串行执行,需要 32 个周期。这就是32-way bank conflict,性能下降 32 倍!
加上 +1 的魔法
__shared__ float tile[32][33]; // 每行33个元素(多1个)
内存布局:
地址(以4-byte为单位):
行0: 0 1 2 ... 31 32 | Bank: 0 1 2 ... 31 0
行1: 33 34 35 ... 63 64 65| Bank: 1 2 3 ... 31 0 1
行2: 66 67 68 ... 97 98 99| Bank: 2 3 4 ... 31 0 1 2
...
注意:每行有 33 个元素,所以下一行的起始地址不再对齐到 Bank 0!
现在访问第 0 列:
32个线程访问第0列:
Thread 0: tile[0][0] → 地址 0 → Bank 0 ✅
Thread 1: tile[1][0] → 地址 33 → Bank 1 ✅
Thread 2: tile[2][0] → 地址 66 → Bank 2 ✅
Thread 3: tile[3][0] → 地址 99 → Bank 3 ✅
...
Thread 31: tile[31][0] → 地址 1023→ Bank 31 ✅
完美!所有 32 个线程访问不同的 bank,可以并行执行,只需 1 个周期!
为什么是 +1 而不是 +2 或其他数?
+1 是最小的能打破对齐的数字:
- +1:下一行偏移 33 个元素 = 32 + 1,bank ID 偏移 1
- +2:偏移 34 = 32 + 2,bank ID 偏移 2,也能避免冲突,但浪费更多空间
实际上,任何不是 32 的倍数的偏移都可以,但+1 最节省 Shared Memory。
性能对比
在 H100 上测试 32×32 的 tile,转置一个 4096×4096 的矩阵:
| 版本 | 全局内存带宽利用率 | Shared Memory 冲突率 | 性能(GB/s) |
|---|---|---|---|
| 朴素版本 | ~20% | N/A | 600 |
| 优化版本(无+1) | ~90% | 96.9% (32-way) | 1200 |
| 优化版本(有+1) | ~90% | 0% | 2800 |
朴素版本由于非合并访问,带宽只有理论峰值的 20%。
优化版本即使有严重的 bank conflict,也比朴素版本快 2 倍,因为合并访问的收益太大了。
加上+1 消除 bank conflict 后,性能再提升 2.3 倍,达到接近 HBM3 理论带宽(3 TB/s)的 93%!
核心要点总结
-
合并访问是 GPU 性能的生命线
- 同一 warp 的线程应该访问连续的内存地址
- 非合并访问会导致 30 倍以上的性能损失
-
Shared Memory 是优化的关键中转站
- 允许我们将不规则的全局内存访问模式转换为规则的
- 即使有额外的读写步骤,总体性能仍然大幅提升
-
Bank Conflict 不可忽视
- 看似微小的内存布局调整(+1)可以带来 2 倍以上性能差异
- 转置、矩阵运算等操作特别容易触发 bank conflict
-
内存访问模式决定一切
- GPU 硬件设计高度优化了连续访问模式
- 算法设计时必须考虑数据的物理布局