[Common/PyTorch] Add MXFP8 cast-and-transpose op #2930
jeweldave wants to merge 2 commits into NVIDIA:main from
Conversation
Add a fused MXFP8 cast-and-transpose op that takes a high-precision tensor
plus the source's existing compact column-wise E8M0 scales and emits row-wise
compact MXFP8 storage for the source's logical transpose.
The standard MXFP8Quantizer path can already produce row-wise and column-wise
MXFP8 from BF16/FP16/FP32 input. There is currently no public TE path that,
given X and its compact column-wise scales S_col(X), produces the row-wise
compact MXFP8 storage for the logical transpose X.T without either re-reading
the BF16 source or copying the existing column-wise MXFP8 payload and scales
into transposed row-wise storage. This op closes that gap. It is the building
block needed to route MXFP8 backward through TN GEMMs on hardware where
cuBLASLt does not currently support MXFP8 backward NN/NT layouts (NVIDIA Spark
sm_12.1). On B200 / H100 the new op is unused by default; downstream code can
still call it for any path that wants direct transposed-rowwise MXFP8 emission
without a payload copy.
The op surfaces in three layers, all additive:
* C API (ABI-safe):
- nvte_mxfp8_scaling_transpose_cast(input, scale_inv_colwise,
output_rowwise, output_rowwise_scale_inv, rows, cols, stream) — minimal
signature, E4M3 output, non-swizzled scales.
- nvte_mxfp8_scaling_transpose_cast_v2(..., fp8_dtype,
with_gemm_swizzled_scales, stream) — extended signature.
* PyTorch extension: transformer_engine_torch.mxfp8_scaling_transpose_cast
(default kwargs match the minimal C symbol's behavior).
* Python: MXFP8Quantizer.quantize_rowwise_transpose(tensor,
columnwise_scale_inv, *, fake_dtype=None, with_gemm_swizzled_scales=None)
returns a row-wise-only MXFP8Tensor whose logical shape is tensor.T.
No existing C symbol, Python signature, or default behavior is changed.
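For illustration only, a minimal usage sketch of the Python helper described above. The import path, the quantizer construction, and the MXFP8Tensor attribute name _columnwise_scale_inv are assumptions not spelled out in this PR:

    import torch
    import transformer_engine_torch as tex
    from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer

    x = torch.randn(256, 512, dtype=torch.bfloat16, device="cuda")

    # Standard path: produce column-wise MXFP8 for x (payload + compact E8M0 scales).
    quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)
    x_q = quantizer.quantize(x)

    # New helper: reuse x's compact column-wise scales to emit row-wise MXFP8 storage
    # for x.T directly, without re-reading the BF16 source or copying the column-wise payload.
    xt_q = quantizer.quantize_rowwise_transpose(x, x_q._columnwise_scale_inv)  # attribute name assumed
    assert xt_q.shape == torch.Size([512, 256])  # logical shape is x.T, row-wise-only storage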
Tests in tests/pytorch/mxfp8/:
* test_mxfp8_scaling_transpose_cast.py — byte equivalence vs. column-wise-
then-copy reference (E4M3 + E5M2, multiple shapes), Python helper
equivalence, decoded-value reconstruction within MXFP8 quantization
tolerance, error paths for FP8 input and non-block-aligned dims.
* test_mxfp8_scaling_transpose_cast_swizzled.py — with
with_gemm_swizzled_scales=True, emitted row-wise payload and scales match
the bytes produced by the standard MXFP8Quantizer.quantize swizzled path
on the actual transposed source. Comparison is byte-for-byte rather than
via decoded values because TE's dequantize kernels intentionally reject
with_gemm_swizzled_scales=True inputs (one-way GEMM-operand layout).
Tested on NVIDIA GB10 (sm_12.1) with TE rebuilt from this change: all 14
parametrized tests pass.
Signed-off-by: David Gornshtein <davidgornshtein@gmail.com>
Greptile Summary
This PR adds a fused MXFP8 cast-and-transpose op that takes a high-precision tensor and its existing compact column-wise E8M0 scales and emits row-wise MXFP8 payload and scales for the logical transpose, avoiding an extra BF16 re-read or payload copy. The implementation is additive (new C symbols, PyTorch binding, and Python helper).
Confidence Score: 4/5
Safe to merge; all findings are style/consistency P2s with no correctness impact. Implementation logic is sound — kernel indexing, boundary conditions, scale buffer sizing, and swizzled-index arithmetic are all correct. The only findings are: the v1 C entry point is missing its own NVTE_API_CALL (causing profiling misattribution), the transpose cast kernel lacks launch_bounds, and there is an unused variable in one test. None affect runtime correctness.
Important Files Changed: transformer_engine/common/recipe/mxfp8_scaling.cu (missing NVTE_API_CALL on v1 wrapper, missing launch_bounds on cast kernel)
Sequence Diagram
sequenceDiagram
participant Caller
participant PyHelper as MXFP8Quantizer.quantize_rowwise_transpose
participant PyBind as tex.mxfp8_scaling_transpose_cast
participant CPP as mxfp8_scaling_transpose_cast (C++)
participant CSymV2 as nvte_mxfp8_scaling_transpose_cast_v2
participant ScaleKernel as mxfp8_scaling_transpose_scales_kernel
participant CastKernel as mxfp8_scaling_transpose_cast_kernel
Caller->>PyHelper: tensor, columnwise_scale_inv
PyHelper->>PyHelper: validate shapes, allocate outputs
PyHelper->>PyBind: source_2d, colwise_scale_inv, out_data, out_scale, rows, cols, fp8_dtype, swizzled
PyBind->>CSymV2: nvte tensors + params
CSymV2->>CPP: Tensor wrappers
CPP->>ScaleKernel: transpose E8M0 scale bytes (compact or swizzled)
CPP->>CastKernel: cast+transpose FP8 payload tiles
CastKernel-->>CPP: rowwise_data (cols x rows)
ScaleKernel-->>CPP: rowwise_scale_inv (transposed)
CPP-->>PyHelper: filled output buffers
PyHelper->>Caller: MXFP8Tensor(shape=(cols,rows), rowwise_only)
Diff context (transformer_engine/common/recipe/mxfp8_scaling.cu):

          *convertNVTETensorCheck(input), *convertNVTETensorCheck(scale_inv_colwise),
          *convertNVTETensorCheck(output_rowwise), *convertNVTETensorCheck(output_rowwise_scale_inv),
          rows, cols, static_cast<DType>(fp8_dtype), with_gemm_swizzled_scales, stream);
    }

    void nvte_mxfp8_scaling_transpose_cast(const NVTETensor input, const NVTETensor scale_inv_colwise,
Missing __launch_bounds__ on transpose cast kernel
All other __global__ kernels in this file (mxfp8_scaling_compute_partial_amax_kernel, mxfp8_scaling_partial_cast_kernel, mxfp8_scaling_transpose_scales_kernel) carry __launch_bounds__(kThreadsPerBlock). This kernel is launched with block = dim3(kTransposeTileDim, kTransposeTileDim) = 256 threads, so the appropriate hint would be __launch_bounds__(kTransposeTileDim * kTransposeTileDim). Without it the compiler cannot optimize register allocation for the stated block size.
Summary
Add a fused MXFP8 cast-and-transpose op that takes a high-precision tensor
plus the source's existing compact column-wise E8M0 scales and emits row-wise
compact MXFP8 storage for the source's logical transpose. The op surfaces in
three layers:
* C API (ABI-safe):
  - nvte_mxfp8_scaling_transpose_cast(input, scale_inv_colwise,
    output_rowwise, output_rowwise_scale_inv, rows, cols, stream) — minimal
    signature, E4M3 output, non-swizzled scales.
  - nvte_mxfp8_scaling_transpose_cast_v2(..., fp8_dtype,
    with_gemm_swizzled_scales, stream) — extended signature.
* PyTorch extension: transformer_engine_torch.mxfp8_scaling_transpose_cast
  (default kwargs map to the minimal C symbol's behavior).
* Python: MXFP8Quantizer.quantize_rowwise_transpose(tensor,
  columnwise_scale_inv, *, fake_dtype=None, with_gemm_swizzled_scales=None)
  returns a row-wise-only MXFP8Tensor whose logical shape is tensor.T.
No existing C symbol, Python signature, or default behavior is changed.
Why
The standard MXFP8Quantizer path can already produce row-wise and
column-wise MXFP8 from BF16/FP16/FP32 input. There is currently no public TE
path that, given X and its compact column-wise scales S_col(X), produces
the row-wise compact MXFP8 storage for the logical transpose X.T without
either:
* re-quantizing from the high-precision source (re-reads the BF16 source), or
* copying the existing column-wise MXFP8 payload and scales into
  transposed row-wise storage (extra payload+scale byte traffic).
This op closes that gap. It is the building block needed to route MXFP8
backward through TN GEMMs on hardware where cuBLASLt does not currently
support MXFP8 backward NN/NT GEMM layouts (NVIDIA Spark / sm_12.1). On
hardware where backward MXFP8 NN/NT is supported (B200 sm_10.0, H100
sm_9.0) it is unused by default; downstream code can still call it for any
path that wants direct transposed-rowwise MXFP8 emission without a payload
copy.
Detailed motivation, measurements, and out-of-scope notes are in
docs/motivation.md of the proposal directory.
What's in the change
* C API: nvte_mxfp8_scaling_transpose_cast and _v2.
* transformer_engine/common/recipe/mxfp8_scaling.cu — two new kernels
  (mxfp8_scaling_transpose_cast_kernel for the FP8 payload tile transpose,
  mxfp8_scaling_transpose_scales_kernel for compact-or-swizzled scale
  transpose), one new C++ entry point, two new C symbols.
* extensions/fp8_partial_cast.cpp / extensions/pybind.cpp — new
  mxfp8_scaling_transpose_cast PyTorch binding routed through _v2.
* Python: new MXFP8Quantizer.quantize_rowwise_transpose helper.
Numerics
For an input tensor X quantized with the standard MXFP8 column-wise path,
the new op's output is bit-for-bit equal to taking the existing column-wise
MXFP8 payload + scales and transposing those bytes. Confirmed on GB10 in the
existing cppmega probe at (M, N, K) = (64, 96, 128) and (256, 4096, 4096):
max_payload_abs_byte_delta == 0, payload_equal == True,
scale_equal == True.
The included tests/test_mxfp8_scaling_transpose_cast.py exercises this
equivalence both via the raw extension call and via the
quantize_rowwise_transpose helper.
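As a rough illustration of the reference being matched, a sketch under the assumptions that MXFP8Tensor stores its column-wise payload non-transposed in the source layout and exposes _columnwise_data / _columnwise_scale_inv / _rowwise_data internals (names assumed, not defined by this PR):

    import torch
    import transformer_engine_torch as tex
    from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer

    rows, cols = 256, 4096
    x = torch.randn(rows, cols, dtype=torch.bfloat16, device="cuda")

    quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)
    x_q = quantizer.quantize(x)  # standard path: row-wise + column-wise MXFP8 for x

    # Reference payload for x.T: physically transpose the existing column-wise FP8 bytes.
    ref_payload = x_q._columnwise_data.view(torch.uint8).reshape(rows, cols).t().contiguous()

    # New op: the same bytes emitted directly from the BF16 source plus column-wise scales.
    xt_q = quantizer.quantize_rowwise_transpose(x, x_q._columnwise_scale_inv)
    new_payload = xt_q._rowwise_data.view(torch.uint8).reshape(cols, rows)
    assert torch.equal(new_payload, ref_payload)  # "payload_equal == True" in the probe above
    # Scale bytes compare analogously ("scale_equal == True"), modulo compact-layout padding.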
Tests
Drop-in pytest files added under tests/pytorch/ (this PR puts them in
tests/):
* byte equivalence of the emitted row-wise payload and scales vs. the
  column-wise-then-copy reference, for E4M3 and E5M2;
* decoded-value reconstruction, within MXFP8 quantization tolerance, against
  the native re-quantized transpose;
* with with_gemm_swizzled_scales=True, byte equality against the output
  produced by the standard MXFP8Quantizer.quantize swizzled path on the
  actual transposed source.
All tests gate on CUDA being present and the new extension symbol being
built into the loaded transformer_engine_torch module.
Compatibility
* nvte_mxfp8_scaling_transpose_cast keeps the minimal stable signature; _v2
  carries the extra knobs. No symbol's signature is changed.
* No existing Python signature or default behavior is changed.
* The new code touches only files that already participate in the MXFP8
  recipe build.
Out of scope (intentionally not in this PR)
* Routing MXFP8 backward GEMMs through this op by default: that depends on
  cuBLASLt behavior we cannot upstream and on a downstream backward-rewrite
  shim.
* Dequantizing swizzled-scale MXFP8 tensors: by design TE rejects this in
  cast/mxfp8/dequantize_mxfp8.cuh (and symmetric paths in
  cast/nvfp4/dequantize_nvfp4.cuh, cast/mxfp8/group_dequantize_mxfp8.cuh) with
  "Input must have scales in compact format": swizzled scales are a one-way
  GEMM-operand layout, and the dequantize kernels don't carry an inverse
  unswizzle_scale_idx path. Our swizzled-scale test
  (test_mxfp8_scaling_transpose_cast_swizzled.py) therefore compares the
  emitted row-wise payload and scale bytes against the standard
  MXFP8Quantizer.quantize(..., with_gemm_swizzled_scales=True) output
  byte-for-byte instead of via decoded values (sketched below), since both
  paths target the same GEMM-ready layout for the same logical row-wise
  tensor.
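A rough sketch of that byte-for-byte check, borrowing the quantize(..., with_gemm_swizzled_scales=True) spelling from the text above; the exact kwarg plumbing and the MXFP8Tensor attribute names are assumptions, not this PR's test code:

    import torch
    import transformer_engine_torch as tex
    from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer

    x = torch.randn(256, 512, dtype=torch.bfloat16, device="cuda")

    quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)
    x_q = quantizer.quantize(x)  # provides x's compact column-wise scales

    # New op, emitting GEMM-ready (swizzled) scales for x.T directly.
    xt_new = quantizer.quantize_rowwise_transpose(
        x, x_q._columnwise_scale_inv, with_gemm_swizzled_scales=True
    )

    # Reference: standard swizzled row-wise quantization of the materialized transpose.
    xt_ref = quantizer.quantize(x.t().contiguous(), with_gemm_swizzled_scales=True)

    # Byte-for-byte comparison; decoded-value comparison is unavailable because
    # the dequantize kernels intentionally reject swizzled-scale inputs.
    assert torch.equal(xt_new._rowwise_data, xt_ref._rowwise_data)
    assert torch.equal(xt_new._rowwise_scale_inv, xt_ref._rowwise_scale_inv)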