在 kvoffload 这个版本的 sglang 的推理的投机采样(Speculative Sampling)阶段,引擎需要对生成的 TopK 候选 Token 索引进行去重,这就引入了本次的优化算子。
算子语义:
batch 将 mtp_step 个 top‑k index 行做去重,输出长度固定为 mtp_step * k,不足用 -1 填充场景特征:
bs=115,GPU 并行度(Occupancy)天然不足。我们经历了一个从“Naive 实现”到“指令级 Hacking”的完整过程,最终将耗时压缩了 120倍 ,从 2ms 压缩至 16us。
当我们拿到“去重”这个需求时,第一反应通常是:哈希表。这在 CPU 上是标准答案($O(N)$)。于是前人就写出了第一版 Kernel,在 Shared Memory 中维护一个哈希表。
__global__ void deduplicate_topk_kernel_v0(
const int* topk_indices,
int* topk_indices_spec,
int total_elements,
int k)
{
// ... setup shared memory ...
extern __shared__ int shared_mem[];
int* seen = shared_mem; // 哈希表
// 初始化哈希表
for (int i = tid; i < total_elements; i += blockDim.x) seen[i] = -1;
__syncthreads();
// 循环处理每个元素
for (int i = tid; i < total_elements; i += blockDim.x) {
int val = topk_indices[row * k + col];
// 简单的线性探测哈希
int hash_idx = val % total_elements;
// 【致命点1】探测循环:可能 1 次成功,也可能 10 次
// 导致 Warp 内线程严重发散 (Divergence)
for (int probe = 0; probe < total_elements; probe++) {
int idx = (hash_idx + probe) % total_elements;
// 【致命点2】原子操作竞争:所有线程抢着写
int old = atomicCAS(&seen[idx], -1, val);
if (old == -1) { // 抢到了,是新元素
int pos = atomicAdd(unique_count, 1);
unique_list[pos] = val; // 【致命点3】随机写,非合并访存
break;
} else if (old == val) {
break; // 已存在
}
}
}
}
NCU 的报告显示:
atomicCAS,硬件序列化严重。初步结论:在 GPU 上,“确定的执行路径” 往往比理论上的算法复杂度更重要。哈希表在 GPU 上通常不是好选择。同时,考虑到推理阶段的特性,输入的 2048 个索引中,很可能重复的索引并不多,导致了哈希表的冲突比理想情况糟糕得多。