从一张查找表到 4GB/s:HiFloat8 Cast 算子的工程化之路
HiFloat8 Cast 算子在 Atlas A2/A3 昇腾 NPU 上为 PyTorch 实现 FP16/BF16 ↔ HiFloat8 双向转换,通过半空间 LUT、动态 tiling 和 DataCopy 分支优化,在大数据量下单方向吞吐稳定在 4 GB/s 量级——接近当前软件查表实现的吞吐上限。
Ascend 950系列产品 已具备原生 HiFloat8 硬件能力,本算子主要面向尚无该硬件加速的 A2/A3 平台。
一、为什么需要 HiFloat8 Cast算子
低精度是当前大模型推理落地的主线趋势。推理最大的瓶颈之一是带宽:模型权重几十上百 GB,每生成一个 token 都要把权重从显存搬一遍。把权重从 16 位压到 8 位,搬运量直接砍半,吞吐就上去了。
业界已经有 FP8 标准(E4M3 / E5M2 两个变体),它把 8 位固定切成"1 位符号 + 4/5 位指数 + 3/2 位尾数"。HiFloat8 的思路不一样:指数和尾数的位数不固定,根据数值大小动态切分。
- 数值靠近 0 时:多分一些位给尾数,精度高
- 数值远离 0 时:多分一些位给指数,动态范围大
- 8 位里还要塞下 ±0、±Inf、NaN、subnormal 等特殊值
这种"分段编码"在大模型权重分布(绝大多数值集中在 0 附近)下精度更好,但代价是编解码逻辑更复杂——不能像 FP8 那样直接位移截断,必须做格式判断和位重排。
Ascend 950系列芯片已经把 HiFloat8 的编解码做进了硬件,可以像 FP16/FP32 一样原生处理。但当前主流在用的 Atlas A2、A3 还没有这条硬件通路——要在这两代芯片上跑低精度浮点量化,就得靠软件 kernel 把转换路径补出来。这也意味着:不具备 Ascend 950系列 开发条件的开发者,完全可以基于 A2/A3 + 本算子先把 HiFloat8 数据格式吃透,把上层的量化/校准/优化算法跑通。 等 Ascend 950系列 硬件就绪时,只要把转换路径切到原生指令,上层算法可以平滑迁移,前期的开发投入不会白费。
我们要解决的问题就是:把 FP16/BF16 张量在 A2/A3 NPU 上高效地转成 HiFloat8,再转回来。一切下游的 HiFloat8 量化推理/训练特性,都依赖这个最底层的转换算子。
二、算子全景:四种模式,一个 kernel
PyTorch 用户只看到两个函数:
from amct_ops.hifloat8_cast import encode_to_hifloat8, decode_from_hifloat8
x = torch.randn(1024, 256, dtype=torch.bfloat16, device='npu')
y = encode_to_hifloat8(x) # FP16/BF16 → uint8(HiFloat8)
z = decode_from_hifloat8(y, torch.bfloat16) # uint8 → FP16/BF16
HiFloat8 不是 PyTorch 原生类型,所以编码后的张量用
uint8承载——一个字节一个值,shape 不变。
底层根据输入 dtype 自动派发到 4 种 castMode:
| castMode | 转换 | 实现要点 |
|---|---|---|
| 0 | FP16 → HiFloat8 | 半空间 LUT:32768 条 × 1 B = 32 KB UB |
| 1 | BF16 → HiFloat8 | 半空间 LUT:32768 条 × 1 B = 32 KB UB |
| 2 | HiFloat8 → FP16 | LUT:256 条 × 2 B = 512 B UB |
| 3 | HiFloat8 → BF16 | LUT:256 条 × 2 B = 512 B UB |
四种模式共用同一个 device kernel KernelHiFloat8CastLut,差异只在 LUT(查找表)内容和 input/output 的 dtype。
为什么用查找表? 因为前面说过,HiFloat8 的编解码逻辑很碎:要判断指数段、处理 subnormal、特判 ±Inf/NaN……如果在 device 上一条条算,分支多、指令多。但 FP16/BF16 一共只有 65536 种位模式,HiFloat8 只有 256 种——干脆在 host(CPU)上预先把所有结果算出来打成一张表,device 上只剩"查表 + 写回",又快又简单。
三、核心优化一:半空间 LUT —— 用对称性省下一半 UB
最初版本简单粗暴:65536 条 LUT,把 FP16/BF16 的所有位模式直接映射到 HiFloat8。
问题马上来了。65536 × 1 字节 = 64 KB UB。A2 平台 UB 标称 256 KB,CANN 运行时还要保留一部分,实际只剩 ~192 KB。这 64 KB 放进去后,剩余空间还要装 input 队列、output 队列,最后能切给一个 tile(一次处理的一小片数据)的 UB 就很捉襟见肘。tile 切得越小,启动次数越多,吞吐越低。
观察到一个关键事实:HiFloat8 的符号位与幅值正交。HiFloat8 编码把最高位作为符号、低 7 位作为幅值;而 FP16/BF16 的最高位也是符号位。于是同一个绝对值不论正负,HiFloat8 编码的低 7 位完全相同:
encode(-x) = encode(x) | 0x80 (当 encode(x) ≠ 0 时)
encode(-0) = 0x00 (0 不带符号)
这里说的"对称"不是 IEEE 浮点严格意义上的负零/NaN 对称(BF16 的 NaN 编码就不严格对称),而是HiFloat8 编码侧的对称——低 7 位幅值只取决于源浮点的低 15 位,与符号位无关。
那 LUT 完全没必要存两份。只存正半空间 32768 条"绝对值表",编码时先剥离符号、查表、再把符号补回去:
LocalTensor<uint16_t> xU16 = xLocal.ReinterpretCast<uint16_t>(); // BF16/FP16 的位视图
uint16_t v = xU16.GetValue(i); // 取一个 FP16/BF16 值
uint8_t sign = static_cast<uint8_t>(v >> 15); // 提取最高位(符号)
uint8_t mag = lut.GetValue(v & 0x7FFFu); // 用绝对值查表得到幅值
yLocal.SetValue(i, (mag == 0) ? 0x00u : (mag | (sign << 7))); // 0 编码不带符号位
收益是双重的:
- UB 占用从 64 KB 降到 32 KB,给 tile 留出更多空间
- LUT 一次搬完:
DataCopyPad的 blockLen 上限是 65535 字节,64 KB 恰好越线,原本要拆两次,现在一次完成
每个元素多花 4-5 条标量指令(AND / 移位 / 比较 / OR),但 compute 本就是标量循环,这点开销被数据搬运完全掩盖。
为什么编码方向是标量循环? Ascend C 编译器在 A2/A3 上不支持
Cast<uint32_t, uint16_t>、也不支持对uint16_t/uint32_t做ShiftLeft/Right,导致无法用Gather指令做向量化查表。每个元素只能GetValue→ 查表 →SetValue走标量路径。这也是双缓冲不带来收益的根因——下一节会讲。
四、核心优化二:动态 tiling —— 一份代码适配 A2/A3
A2 和 A3 的 UB 标称大小分别是 256 KB 和 512 KB,CANN 还会保留一部分。如果 tile 大小写死,要么浪费 UB,要么在小 UB 平台上越界崩溃。代码里同时也兼容了 Ascend 950(A5),但 A5 已有硬件原生 HiFloat8 通路,本算子主要服役在 A2/A3 上。
我们把 tile 计算彻底放到 host 运行时:
uint32_t ubBytes = platform->GetCoreMemSize(UB); // 查询当前平台实际可用 UB
uint32_t lutBytes = isEncode ? LUT16_SIZE : (LUT8_SIZE * 2);
uint32_t maxTile = (ubBytes - lutBytes) / 3u; // input(2B) + output(1B) for encode
maxTile = std::min(maxTile, 65536u); // 65536: A3 大 UB 实测可用上限
为什么是 /3? 编码方向 input 是 2 字节、output 是 1 字节,加起来每个元素占 3 字节 UB;解码方向反之,仍是 3 字节。
为什么不开双缓冲? 通常算子会用 ping-pong 双缓冲让数据搬运和计算重叠。但这里的 compute 是标量循环(上一节解释过:编译器对 uint16_t 的位操作和 Cast 限制,逼着只能逐元素查表),和 MTE(数据搬运单元)天然串不起来,重叠收益接近 0。把双缓冲省下的另一半 UB 直接给 tile,让单次处理更大块,反而更划算。
对齐策略也讲究:
if (maxTile >= 32768) return (maxTile / 32768) * 32768; // 大 tile 对齐到 32K
return std::max((maxTile / 32) * 32, 32); // 小 tile 对齐到向量粒度
32768 这个数字不是随便取的——下一节会解释。
五、核心优化三:65535 陷阱与 DataCopy 三分支
NPU 的数据搬运指令 DataCopyPad 的 DataCopyParams.blockLen 是 uint16_t 类型,最大 65535 字节。一旦单次搬运的 byteCount 超过这个值,就要拆成多个 block 调用。我们用三个分支精确分流:
if (byteCount <= 65535) {
cp = {1, byteCount, 0, 0}; // Case A: 一次搬完
} else if (byteCount % 32768 == 0) {
cp = {byteCount / 32768, 32768, 0, 0}; // Case B: n × 32K 拆分
} else {
cp = {2, byteCount / 2, 0, 0}; // Case C: encode 尾块
}
为什么这三个分支能完整覆盖所有情况?记住编码 input=2B/elem、解码 input=1B/elem,且 tile 元素数 ≤ 65536:
- Case A:byteCount ≤ 65535。覆盖编码 tile 元素数 < 32768 的情况,以及解码 tile 元素数 ≤ 65535 的情况
- Case B:byteCount 是 32768 的整数倍。覆盖完整 tile——上一节 tileLength 对齐到 32768 正是为了这一步。具体命中:编码 tileLength=32768/65536 → byteCount=65536/131072 →
{2,32768}或{4,32768};解码 tileLength=65536 → byteCount=65536 →{2,32768}。注意 32768 元素的 encode tile 和 65536 元素的 decode tile 都走这条路,不进 Case A - Case C:byteCount 为偶数但不是 32768 的整数倍。这恰好是编码尾块的形态——编码尾块元素数 ≤ 65535,乘以 2 后 byteCount/2 仍 ≤ 65535,正好放进 blockLen
三种情况是数学上的完整覆盖,没有兜底分支。这种"分支由数据形状的代数性质保证完备"的设计,比写一个 fallback 路径再祈祷它别被触发要可靠得多。
六、Host 侧 LUT 构建:精度边界的处理
HiFloat8 的解码是 256 条查表,平淡无奇。编码 LUT 才是工程量所在——32768 条 FP16/BF16 → HiFloat8 的映射要在 CPU 上一条条算出来。
工程上几个容易踩的坑:
- FP16 subnormal(非规格化数):当指数位为 0 但尾数非零时,浮点规范规定要按"无隐含 1"的方式解释,不能直接套用规格化数的公式。代码里用
__builtin_clz找尾数最高位、重新规格化为 FP32,再走 HiFloat8 编码路径。 - ±Inf / NaN:FP16 用
exp=31表示。如果不特判,把指数加上偏置(127-15=112)会变成 143,丢掉特殊值语义。必须显式映射到 FP32 的exp=255。 - HiFloat8 的特殊值:
0x80是 NaN,0x6F是 +Inf,0xEF是 -Inf。解码侧硬编码这三个特判。
LUT 在每个 device、每种 castMode 上只构建一次,加锁缓存:
static std::map<std::pair<int64_t, int64_t>, at::Tensor> lutCache; // (deviceIndex, castMode)
进程生命周期内的所有调用共享这张表,构建过程不在性能关键路径上。
七、性能数据:大数据量下吞吐稳定在 4 GB/s 量级
A2 平台实测,吞吐量按 (input bytes + output bytes) / time 计算(100 次迭代均值,10 次预热,NPU synchronize 计时)。下表为 BF16 ↔ HiF8 数据;FP16 ↔ HiF8 与 BF16 同构(仅 LUT 内容不同),实测性能与 BF16 基本一致:
| 数据量 | BF16 Encode | BF16 Decode | Roundtrip |
|---|---|---|---|
| 1 K | 14.9 MB/s | 13.5 MB/s | 10.5 MB/s |
| 64 K | 914 MB/s | 678 MB/s | 674 MB/s |
| 1 M | 3627 MB/s | 3646 MB/s | 2447 MB/s |
| 4 M | 4016.7 MB/s | 4044.3 MB/s | 2684.8 MB/s |
要点:
- 大数据(≥ 4 M)encode/decode 单方向稳定在 4 GB/s 量级,对一个标量循环 + 查表的 kernel 来说已经接近这条计算路径的天花板
- 小数据(< 64 K)受限于 kernel 启动 + LUT 加载这些固定开销,单次 ~0.2 ms
- 通过"核数优先"策略(每核最少分到 4096 / 2048 元素)避免小数据时启动过多核重复加载 LUT
精度侧覆盖了 FP16/BF16 全 65536 个位模式的 roundtrip、HiFloat8 全 256 个值的 decode,以及 ±0、±Inf、NaN、subnormal 的边界场景。
八、超出本算子的几条可迁移设计原则
-
"用空间换分支"是查表法的本质。位运算分支多、有非平凡边界(subnormal / Inf / NaN)的算子,先想想能不能 LUT 化。在 device 上写 if-else 永远比查表慢。
-
由代数性质保证分支完备,强于单元测试覆盖。单元测试只能验证你想到的输入,分支完备性能保证任意输入都有归属。本算子的 DataCopy 三分支就是按 byteCount 的可整除性穷举得到的——不需要兜底分支。
-
跨平台编译产物的关键是把硬件参数从编译期推到 host 运行时。
#ifdef PLATFORM_A2是设计失败的信号;GetCoreMemSize()+ 运行时 tiling 才是干净的解法。device 端 kernel 不应该知道自己跑在哪一代芯片上。 -
Host 慢一点没关系,device 必须快。host LUT 构建一次性预计算 32768 条边界,进程内永远不再算第二次。把所有"复杂、易错、慢"的逻辑挪到 host,device 只剩查表 + 写回——这种 host/device 分工模式可推广到任何"映射类"算子。
-
能用对称性就别用大表。半空间 LUT 多了 4-5 条标量指令,但省下的 UB 直接换成更大的 tile,整体吞吐反而上去了。优化经常是这种"局部多花、整体省回"的算账。
这个算子是 AMCT HiFloat8 量化链路的基础设施,后续的 W8A8 推理、训练前向都会复用它做格式转换。
延伸阅读
代码位置:amct_ops/hifloat8_cast/,使用方式见同目录 README.md。完整源码与社区入口:
更多推荐


所有评论(0)