HiFloat8 Cast 算子在 Atlas A2/A3 昇腾 NPU 上为 PyTorch 实现 FP16/BF16 ↔ HiFloat8 双向转换,通过半空间 LUT、动态 tiling 和 DataCopy 分支优化,在大数据量下单方向吞吐稳定在 4 GB/s 量级——接近当前软件查表实现的吞吐上限。
Ascend 950系列产品 已具备原生 HiFloat8 硬件能力,本算子主要面向尚无该硬件加速的 A2/A3 平台。

一、为什么需要 HiFloat8 Cast算子

低精度是当前大模型推理落地的主线趋势。推理最大的瓶颈之一是带宽:模型权重几十上百 GB,每生成一个 token 都要把权重从显存搬一遍。把权重从 16 位压到 8 位,搬运量直接砍半,吞吐就上去了。

业界已经有 FP8 标准(E4M3 / E5M2 两个变体),它把 8 位固定切成"1 位符号 + 4/5 位指数 + 3/2 位尾数"。HiFloat8 的思路不一样:指数和尾数的位数不固定,根据数值大小动态切分。

  • 数值靠近 0 时:多分一些位给尾数,精度高
  • 数值远离 0 时:多分一些位给指数,动态范围大
  • 8 位里还要塞下 ±0、±Inf、NaN、subnormal 等特殊值

这种"分段编码"在大模型权重分布(绝大多数值集中在 0 附近)下精度更好,但代价是编解码逻辑更复杂——不能像 FP8 那样直接位移截断,必须做格式判断和位重排。

Ascend 950系列芯片已经把 HiFloat8 的编解码做进了硬件,可以像 FP16/FP32 一样原生处理。但当前主流在用的 Atlas A2、A3 还没有这条硬件通路——要在这两代芯片上跑低精度浮点量化,就得靠软件 kernel 把转换路径补出来。这也意味着:不具备 Ascend 950系列 开发条件的开发者,完全可以基于 A2/A3 + 本算子先把 HiFloat8 数据格式吃透,把上层的量化/校准/优化算法跑通。 等 Ascend 950系列 硬件就绪时,只要把转换路径切到原生指令,上层算法可以平滑迁移,前期的开发投入不会白费。

我们要解决的问题就是:把 FP16/BF16 张量在 A2/A3 NPU 上高效地转成 HiFloat8,再转回来。一切下游的 HiFloat8 量化推理/训练特性,都依赖这个最底层的转换算子。

二、算子全景:四种模式,一个 kernel

PyTorch 用户只看到两个函数:

from amct_ops.hifloat8_cast import encode_to_hifloat8, decode_from_hifloat8

x = torch.randn(1024, 256, dtype=torch.bfloat16, device='npu')
y = encode_to_hifloat8(x)                       # FP16/BF16 → uint8(HiFloat8)
z = decode_from_hifloat8(y, torch.bfloat16)     # uint8     → FP16/BF16

HiFloat8 不是 PyTorch 原生类型,所以编码后的张量用 uint8 承载——一个字节一个值,shape 不变。

底层根据输入 dtype 自动派发到 4 种 castMode

castMode 转换 实现要点
0 FP16 → HiFloat8 半空间 LUT:32768 条 × 1 B = 32 KB UB
1 BF16 → HiFloat8 半空间 LUT:32768 条 × 1 B = 32 KB UB
2 HiFloat8 → FP16 LUT:256 条 × 2 B = 512 B UB
3 HiFloat8 → BF16 LUT:256 条 × 2 B = 512 B UB

四种模式共用同一个 device kernel KernelHiFloat8CastLut,差异只在 LUT(查找表)内容和 input/output 的 dtype。

为什么用查找表? 因为前面说过,HiFloat8 的编解码逻辑很碎:要判断指数段、处理 subnormal、特判 ±Inf/NaN……如果在 device 上一条条算,分支多、指令多。但 FP16/BF16 一共只有 65536 种位模式,HiFloat8 只有 256 种——干脆在 host(CPU)上预先把所有结果算出来打成一张表,device 上只剩"查表 + 写回",又快又简单。

三、核心优化一:半空间 LUT —— 用对称性省下一半 UB

最初版本简单粗暴:65536 条 LUT,把 FP16/BF16 的所有位模式直接映射到 HiFloat8。

问题马上来了。65536 × 1 字节 = 64 KB UB。A2 平台 UB 标称 256 KB,CANN 运行时还要保留一部分,实际只剩 ~192 KB。这 64 KB 放进去后,剩余空间还要装 input 队列、output 队列,最后能切给一个 tile(一次处理的一小片数据)的 UB 就很捉襟见肘。tile 切得越小,启动次数越多,吞吐越低。

观察到一个关键事实:HiFloat8 的符号位与幅值正交。HiFloat8 编码把最高位作为符号、低 7 位作为幅值;而 FP16/BF16 的最高位也是符号位。于是同一个绝对值不论正负,HiFloat8 编码的低 7 位完全相同:

encode(-x) = encode(x) | 0x80   (当 encode(x) ≠ 0 时)
encode(-0) = 0x00                (0 不带符号)

这里说的"对称"不是 IEEE 浮点严格意义上的负零/NaN 对称(BF16 的 NaN 编码就不严格对称),而是HiFloat8 编码侧的对称——低 7 位幅值只取决于源浮点的低 15 位,与符号位无关。

那 LUT 完全没必要存两份。只存正半空间 32768 条"绝对值表",编码时先剥离符号、查表、再把符号补回去:

LocalTensor<uint16_t> xU16 = xLocal.ReinterpretCast<uint16_t>();  // BF16/FP16 的位视图
uint16_t v    = xU16.GetValue(i);                          // 取一个 FP16/BF16 值
uint8_t  sign = static_cast<uint8_t>(v >> 15);             // 提取最高位(符号)
uint8_t  mag  = lut.GetValue(v & 0x7FFFu);                 // 用绝对值查表得到幅值
yLocal.SetValue(i, (mag == 0) ? 0x00u : (mag | (sign << 7))); // 0 编码不带符号位

收益是双重的:

  • UB 占用从 64 KB 降到 32 KB,给 tile 留出更多空间
  • LUT 一次搬完DataCopyPad 的 blockLen 上限是 65535 字节,64 KB 恰好越线,原本要拆两次,现在一次完成

每个元素多花 4-5 条标量指令(AND / 移位 / 比较 / OR),但 compute 本就是标量循环,这点开销被数据搬运完全掩盖。

为什么编码方向是标量循环? Ascend C 编译器在 A2/A3 上不支持 Cast<uint32_t, uint16_t>、也不支持对 uint16_t/uint32_tShiftLeft/Right,导致无法用 Gather 指令做向量化查表。每个元素只能 GetValue → 查表 → SetValue 走标量路径。这也是双缓冲不带来收益的根因——下一节会讲。

四、核心优化二:动态 tiling —— 一份代码适配 A2/A3

A2 和 A3 的 UB 标称大小分别是 256 KB 和 512 KB,CANN 还会保留一部分。如果 tile 大小写死,要么浪费 UB,要么在小 UB 平台上越界崩溃。代码里同时也兼容了 Ascend 950(A5),但 A5 已有硬件原生 HiFloat8 通路,本算子主要服役在 A2/A3 上。

我们把 tile 计算彻底放到 host 运行时:

uint32_t ubBytes  = platform->GetCoreMemSize(UB);        // 查询当前平台实际可用 UB
uint32_t lutBytes = isEncode ? LUT16_SIZE : (LUT8_SIZE * 2);
uint32_t maxTile  = (ubBytes - lutBytes) / 3u;           // input(2B) + output(1B) for encode
maxTile = std::min(maxTile, 65536u);                     // 65536: A3 大 UB 实测可用上限

为什么是 /3 编码方向 input 是 2 字节、output 是 1 字节,加起来每个元素占 3 字节 UB;解码方向反之,仍是 3 字节。

为什么不开双缓冲? 通常算子会用 ping-pong 双缓冲让数据搬运和计算重叠。但这里的 compute 是标量循环(上一节解释过:编译器对 uint16_t 的位操作和 Cast 限制,逼着只能逐元素查表),和 MTE(数据搬运单元)天然串不起来,重叠收益接近 0。把双缓冲省下的另一半 UB 直接给 tile,让单次处理更大块,反而更划算。

对齐策略也讲究:

if (maxTile >= 32768) return (maxTile / 32768) * 32768;  // 大 tile 对齐到 32K
return std::max((maxTile / 32) * 32, 32);                // 小 tile 对齐到向量粒度

32768 这个数字不是随便取的——下一节会解释。

五、核心优化三:65535 陷阱与 DataCopy 三分支

NPU 的数据搬运指令 DataCopyPadDataCopyParams.blockLenuint16_t 类型,最大 65535 字节。一旦单次搬运的 byteCount 超过这个值,就要拆成多个 block 调用。我们用三个分支精确分流:

if (byteCount <= 65535) {
    cp = {1, byteCount, 0, 0};                  // Case A: 一次搬完
} else if (byteCount % 32768 == 0) {
    cp = {byteCount / 32768, 32768, 0, 0};      // Case B: n × 32K 拆分
} else {
    cp = {2, byteCount / 2, 0, 0};              // Case C: encode 尾块
}

为什么这三个分支能完整覆盖所有情况?记住编码 input=2B/elem、解码 input=1B/elem,且 tile 元素数 ≤ 65536:

  • Case A:byteCount ≤ 65535。覆盖编码 tile 元素数 < 32768 的情况,以及解码 tile 元素数 ≤ 65535 的情况
  • Case B:byteCount 是 32768 的整数倍。覆盖完整 tile——上一节 tileLength 对齐到 32768 正是为了这一步。具体命中:编码 tileLength=32768/65536 → byteCount=65536/131072 → {2,32768}{4,32768};解码 tileLength=65536 → byteCount=65536 → {2,32768}注意 32768 元素的 encode tile 和 65536 元素的 decode tile 都走这条路,不进 Case A
  • Case C:byteCount 为偶数但不是 32768 的整数倍。这恰好是编码尾块的形态——编码尾块元素数 ≤ 65535,乘以 2 后 byteCount/2 仍 ≤ 65535,正好放进 blockLen

三种情况是数学上的完整覆盖,没有兜底分支。这种"分支由数据形状的代数性质保证完备"的设计,比写一个 fallback 路径再祈祷它别被触发要可靠得多。

六、Host 侧 LUT 构建:精度边界的处理

HiFloat8 的解码是 256 条查表,平淡无奇。编码 LUT 才是工程量所在——32768 条 FP16/BF16 → HiFloat8 的映射要在 CPU 上一条条算出来。

工程上几个容易踩的坑:

  • FP16 subnormal(非规格化数):当指数位为 0 但尾数非零时,浮点规范规定要按"无隐含 1"的方式解释,不能直接套用规格化数的公式。代码里用 __builtin_clz 找尾数最高位、重新规格化为 FP32,再走 HiFloat8 编码路径。
  • ±Inf / NaN:FP16 用 exp=31 表示。如果不特判,把指数加上偏置(127-15=112)会变成 143,丢掉特殊值语义。必须显式映射到 FP32 的 exp=255
  • HiFloat8 的特殊值0x80 是 NaN,0x6F 是 +Inf,0xEF 是 -Inf。解码侧硬编码这三个特判。

LUT 在每个 device、每种 castMode 上只构建一次,加锁缓存:

static std::map<std::pair<int64_t, int64_t>, at::Tensor> lutCache;  // (deviceIndex, castMode)

进程生命周期内的所有调用共享这张表,构建过程不在性能关键路径上。

七、性能数据:大数据量下吞吐稳定在 4 GB/s 量级

A2 平台实测,吞吐量按 (input bytes + output bytes) / time 计算(100 次迭代均值,10 次预热,NPU synchronize 计时)。下表为 BF16 ↔ HiF8 数据;FP16 ↔ HiF8 与 BF16 同构(仅 LUT 内容不同),实测性能与 BF16 基本一致:

数据量 BF16 Encode BF16 Decode Roundtrip
1 K 14.9 MB/s 13.5 MB/s 10.5 MB/s
64 K 914 MB/s 678 MB/s 674 MB/s
1 M 3627 MB/s 3646 MB/s 2447 MB/s
4 M 4016.7 MB/s 4044.3 MB/s 2684.8 MB/s

要点:

  • 大数据(≥ 4 M)encode/decode 单方向稳定在 4 GB/s 量级,对一个标量循环 + 查表的 kernel 来说已经接近这条计算路径的天花板
  • 小数据(< 64 K)受限于 kernel 启动 + LUT 加载这些固定开销,单次 ~0.2 ms
  • 通过"核数优先"策略(每核最少分到 4096 / 2048 元素)避免小数据时启动过多核重复加载 LUT

精度侧覆盖了 FP16/BF16 全 65536 个位模式的 roundtrip、HiFloat8 全 256 个值的 decode,以及 ±0、±Inf、NaN、subnormal 的边界场景。

八、超出本算子的几条可迁移设计原则

  1. "用空间换分支"是查表法的本质。位运算分支多、有非平凡边界(subnormal / Inf / NaN)的算子,先想想能不能 LUT 化。在 device 上写 if-else 永远比查表慢。

  2. 由代数性质保证分支完备,强于单元测试覆盖。单元测试只能验证你想到的输入,分支完备性能保证任意输入都有归属。本算子的 DataCopy 三分支就是按 byteCount 的可整除性穷举得到的——不需要兜底分支。

  3. 跨平台编译产物的关键是把硬件参数从编译期推到 host 运行时#ifdef PLATFORM_A2 是设计失败的信号;GetCoreMemSize() + 运行时 tiling 才是干净的解法。device 端 kernel 不应该知道自己跑在哪一代芯片上。

  4. Host 慢一点没关系,device 必须快。host LUT 构建一次性预计算 32768 条边界,进程内永远不再算第二次。把所有"复杂、易错、慢"的逻辑挪到 host,device 只剩查表 + 写回——这种 host/device 分工模式可推广到任何"映射类"算子。

  5. 能用对称性就别用大表。半空间 LUT 多了 4-5 条标量指令,但省下的 UB 直接换成更大的 tile,整体吞吐反而上去了。优化经常是这种"局部多花、整体省回"的算账。


这个算子是 AMCT HiFloat8 量化链路的基础设施,后续的 W8A8 推理、训练前向都会复用它做格式转换。

延伸阅读

代码位置:amct_ops/hifloat8_cast/,使用方式见同目录 README.md。完整源码与社区入口:

Logo

CANN开发者社区旨在汇聚广大开发者,围绕CANN架构重构、算子开发、部署应用优化等核心方向,展开深度交流与思想碰撞,携手共同促进CANN开放生态突破!

更多推荐