[Common/PyTorch] Add MXFP8 cast-and-transpose op#2930

Open
jeweldave wants to merge 2 commits into NVIDIA:main from
jeweldave:feat/mxfp8-transpose-cast

Conversation

@jeweldave

Summary

Add a fused MXFP8 cast-and-transpose op that takes a high-precision tensor
plus the source's existing compact column-wise E8M0 scales and emits row-wise
compact MXFP8 storage for the source's logical transpose. The op surfaces in
three layers:

  • C API (additive, ABI-safe):
    • nvte_mxfp8_scaling_transpose_cast(input, scale_inv_colwise,
      output_rowwise, output_rowwise_scale_inv, rows, cols, stream) — minimal
      signature, E4M3 output, non-swizzled scales.
    • nvte_mxfp8_scaling_transpose_cast_v2(..., fp8_dtype,
      with_gemm_swizzled_scales, stream) — extended signature.
  • PyTorch extension: transformer_engine_torch.mxfp8_scaling_transpose_cast
    (default kwargs map to the minimal C symbol's behavior).
  • Python: MXFP8Quantizer.quantize_rowwise_transpose(tensor,
    columnwise_scale_inv, *, fake_dtype=None, with_gemm_swizzled_scales=None)
    returns a row-wise-only MXFP8Tensor whose logical shape is tensor.T.

No existing C symbol, Python signature, or default behavior is changed.
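The E8M0 scale bytes referenced above follow the OCP Microscaling (MX) convention: a bare 8-bit biased exponent with no sign or mantissa bits. A minimal pure-Python sketch of that encoding (not TE code; `e8m0_decode` and `e8m0_encode` are illustrative names):

```python
import math

# E8M0 per the OCP MX spec: one biased-exponent byte, bias 127, value
# 2**(byte - 127); byte 0xFF is reserved for NaN. Illustrative only.
E8M0_BIAS = 127

def e8m0_decode(byte: int) -> float:
    """Decode one E8M0 scale byte to its power-of-two scale value."""
    if byte == 0xFF:
        return float("nan")
    return 2.0 ** (byte - E8M0_BIAS)

def e8m0_encode(scale: float) -> int:
    """Encode a positive scale as the nearest-below power-of-two exponent."""
    exp = max(-127, min(127, math.floor(math.log2(scale))))
    return exp + E8M0_BIAS

print(e8m0_decode(127))   # 1.0
print(e8m0_decode(130))   # 8.0
print(e8m0_encode(0.25))  # 125
```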

Why

The standard MXFP8Quantizer path can already produce row-wise and
column-wise MXFP8 from BF16/FP16/FP32 input. There is currently no public TE
path that, given X and its compact column-wise scales S_col(X), produces
the row-wise compact MXFP8 storage for the logical transpose X.T without
either:

  • emitting it from BF16 in a separate pass through the standard quantizer
    (re-reads BF16 source), or
  • copying the existing column-wise MXFP8 payload and column-wise scales into
    transposed row-wise storage (extra payload+scale byte traffic).

This op closes that gap. It is the building block needed to route MXFP8
backward through TN GEMMs on hardware where cuBLASLt does not currently
support MXFP8 backward NN/NT GEMM layouts (NVIDIA Spark / sm_12.1). On
hardware where backward MXFP8 NN/NT is supported (B200 sm_10.0, H100
sm_9.0) it is unused by default; downstream code can still call it for any
path that wants direct transposed-rowwise MXFP8 emission without a payload
copy.

Detailed motivation, measurements, and out-of-scope notes are in
docs/motivation.md of the proposal directory.

What's in the change

  • transformer_engine/common/include/transformer_engine/recipe.h — declare
    nvte_mxfp8_scaling_transpose_cast and _v2.
  • transformer_engine/common/recipe/mxfp8_scaling.cu — two new kernels
    (mxfp8_scaling_transpose_cast_kernel for the FP8 payload tile transpose,
    mxfp8_scaling_transpose_scales_kernel for compact-or-swizzled scale
    transpose), one new C++ entry point, two new C symbols.
  • transformer_engine/pytorch/csrc/extensions.h /
    extensions/fp8_partial_cast.cpp /
    extensions/pybind.cpp — new
    mxfp8_scaling_transpose_cast PyTorch binding routed through _v2.
  • transformer_engine/pytorch/tensor/mxfp8_tensor.py — new
    MXFP8Quantizer.quantize_rowwise_transpose helper.

Numerics

For an input tensor X quantized with the standard MXFP8 column-wise path,
the new op's output is bit-for-bit equal to taking the existing column-wise
MXFP8 payload + scales and transposing those bytes. Confirmed on GB10 in the
existing cppmega probe at (M, N, K) = (64, 96, 128) and (256, 4096, 4096):
max_payload_abs_byte_delta == 0, payload_equal == True,
scale_equal == True.
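The byte-equivalence claim can be checked against a toy reference model. The sketch below is not TE's kernel — block size 4 and integer rounding stand in for 32-element MXFP8 blocks and FP8 rounding — but it exhibits the same invariant: column-wise-quantized payload, transposed, equals the row-wise-quantized payload of the transposed source under the shared scales.

```python
import math

BLOCK = 4  # toy block size; real MXFP8 uses 32-element blocks

def _pow2_scale(block):
    """Power-of-two scale from the block amax (toy stand-in for E8M0)."""
    amax = max(abs(v) for v in block) or 1.0
    return 2.0 ** math.floor(math.log2(amax))

def quantize(X, axis):
    """Toy block quantizer: per-block power-of-two scale, rounded payload.
    axis=0 -> blocks run down each column; axis=1 -> along each row."""
    rows, cols = len(X), len(X[0])
    payload = [[0] * cols for _ in range(rows)]
    if axis == 0:
        for c in range(cols):
            for b in range(rows // BLOCK):
                blk = [X[b * BLOCK + r][c] for r in range(BLOCK)]
                s = _pow2_scale(blk)
                for r in range(BLOCK):
                    payload[b * BLOCK + r][c] = round(X[b * BLOCK + r][c] / s)
    else:
        for r in range(rows):
            for b in range(cols // BLOCK):
                s = _pow2_scale(X[r][b * BLOCK:(b + 1) * BLOCK])
                for c in range(b * BLOCK, (b + 1) * BLOCK):
                    payload[r][c] = round(X[r][c] / s)
    return payload

def transpose(M):
    return [list(row) for row in zip(*M)]

X = [[(7 * i + 3 * j) % 13 - 6 for j in range(4)] for i in range(8)]
# Column-wise quantize, then transpose the payload bytes ...
ref = transpose(quantize(X, axis=0))
# ... equals row-wise quantizing the logical transpose with the same blocks.
assert ref == quantize(transpose(X), axis=1)
print("payload_equal == True")
```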

The included tests/test_mxfp8_scaling_transpose_cast.py exercises this
equivalence both via the raw extension call and via the
quantize_rowwise_transpose helper.

Tests

Drop-in pytest files, intended for tests/pytorch/ (this PR currently places
them under tests/):

  • test_mxfp8_scaling_transpose_cast.py:
    • byte equivalence vs. column-wise-then-copy reference, multiple shapes,
      E4M3 and E5M2;
    • Python helper equivalence;
    • decoded-value reconstruction is within MXFP8 quantization tolerance of
      the native re-quantized transpose;
    • error path: high-precision input is required (FP8 input rejected);
    • error path: source dims must be MXFP8-block-aligned.
  • test_mxfp8_scaling_transpose_cast_swizzled.py:
    • with with_gemm_swizzled_scales=True, emitted scales match the bytes
      produced by the standard MXFP8Quantizer.quantize swizzled path on the
      actual transposed source.
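The block-alignment error path above can be pictured with a small sketch. The 32-element block length comes from the OCP MX spec; treating both source dims as needing to be multiples of it is this sketch's assumption, not a statement of TE's exact rule:

```python
MX_BLOCK = 32  # OCP MX scaling-block length; the exact TE rule is assumed

def check_block_aligned(rows: int, cols: int) -> None:
    """Reject shapes whose dims are not multiples of the MX block size."""
    for name, dim in (("rows", rows), ("cols", cols)):
        if dim % MX_BLOCK:
            raise ValueError(f"{name}={dim} is not a multiple of {MX_BLOCK}")

check_block_aligned(64, 128)  # block-aligned: passes silently
try:
    check_block_aligned(60, 128)
except ValueError as err:
    print(err)  # rows=60 is not a multiple of 32
```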

All tests gate on CUDA being present and the new extension symbol being
built into the loaded transformer_engine_torch module.

Compatibility

  • C API: additive only. Original symbol name is reserved as the long-term
    stable signature; _v2 carries the extra knobs. No symbol's signature is
    changed.
  • Python: additive only. New method on MXFP8Quantizer; no existing method
    signature or default behavior is changed.
  • Build system: no new build flags or files; only changes existing files
    that already participate in the MXFP8 recipe build.

Out of scope (intentionally not in this PR)

  • Wiring TE Linear backward to use this op on GB10. That depends on
    cuBLASLt behavior we cannot upstream and on a downstream backward-rewrite
    shim.
  • Changing default behavior of quantize / quantize_rowwise.
  • Any change to cuBLASLt transposed operand consumption.
  • Dequantize support for with_gemm_swizzled_scales=True MXFP8 tensors. By
    design TE rejects this in cast/mxfp8/dequantize_mxfp8.cuh (and symmetric
    paths in cast/nvfp4/dequantize_nvfp4.cuh,
    cast/mxfp8/group_dequantize_mxfp8.cuh) with "Input must have scales in
    compact format": swizzled scales are a one-way GEMM-operand layout, and
    the dequantize kernels don't carry an inverse unswizzle_scale_idx path.
    Our swizzled-scale test
    (test_mxfp8_scaling_transpose_cast_swizzled.py) therefore compares the
    emitted row-wise payload and scale bytes against the standard
    MXFP8Quantizer.quantize(...,
    with_gemm_swizzled_scales=True) output byte-for-byte instead of via
    decoded values, since both paths target the same GEMM-ready layout for
    the same logical row-wise tensor.
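Why byte-for-byte comparison is sound can be seen with a toy stand-in: a swizzle is a fixed, deterministic permutation of scale bytes into a GEMM-operand layout, so two producers starting from the same compact scales must emit identical swizzled bytes even when no dequantize-side inverse exists. The permutation below is invented for illustration and is not TE's gemm_swizzled_scale_idx layout:

```python
def toy_swizzle(scales, group=4):
    """Stand-in permutation: reverse fixed-size groups of scale bytes."""
    out = []
    for start in range(0, len(scales), group):
        out.extend(reversed(scales[start:start + group]))
    return out

compact_a = list(range(16))  # compact scales from producer A
compact_b = list(range(16))  # identical compact scales from producer B
assert toy_swizzle(compact_a) == toy_swizzle(compact_b)
print(toy_swizzle(list(range(8))))  # [3, 2, 1, 0, 7, 6, 5, 4]
```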

apstenku123 and others added 2 commits April 26, 2026 23:28
Add a fused MXFP8 cast-and-transpose op that takes a high-precision tensor
plus the source's existing compact column-wise E8M0 scales and emits row-wise
compact MXFP8 storage for the source's logical transpose.

The standard MXFP8Quantizer path can already produce row-wise and column-wise
MXFP8 from BF16/FP16/FP32 input. There is currently no public TE path that,
given X and its compact column-wise scales S_col(X), produces the row-wise
compact MXFP8 storage for the logical transpose X.T without either re-reading
the BF16 source or copying the existing column-wise MXFP8 payload and scales
into transposed row-wise storage. This op closes that gap. It is the building
block needed to route MXFP8 backward through TN GEMMs on hardware where
cuBLASLt does not currently support MXFP8 backward NN/NT layouts (NVIDIA Spark
sm_12.1). On B200 / H100 the new op is unused by default; downstream code can
still call it for any path that wants direct transposed-rowwise MXFP8 emission
without a payload copy.

Surfaces in three layers, all additive:

* C API (ABI-safe):
  - nvte_mxfp8_scaling_transpose_cast(input, scale_inv_colwise,
    output_rowwise, output_rowwise_scale_inv, rows, cols, stream) — minimal
    signature, E4M3 output, non-swizzled scales.
  - nvte_mxfp8_scaling_transpose_cast_v2(..., fp8_dtype,
    with_gemm_swizzled_scales, stream) — extended signature.
* PyTorch extension: transformer_engine_torch.mxfp8_scaling_transpose_cast
  (default kwargs match the minimal C symbol's behavior).
* Python: MXFP8Quantizer.quantize_rowwise_transpose(tensor,
  columnwise_scale_inv, *, fake_dtype=None, with_gemm_swizzled_scales=None)
  returns a row-wise-only MXFP8Tensor whose logical shape is tensor.T.

No existing C symbol, Python signature, or default behavior is changed.

Tests in tests/pytorch/mxfp8/:
* test_mxfp8_scaling_transpose_cast.py — byte equivalence vs. column-wise-
  then-copy reference (E4M3 + E5M2, multiple shapes), Python helper
  equivalence, decoded-value reconstruction within MXFP8 quantization
  tolerance, error paths for FP8 input and non-block-aligned dims.
* test_mxfp8_scaling_transpose_cast_swizzled.py — with
  with_gemm_swizzled_scales=True, emitted row-wise payload and scales match
  the bytes produced by the standard MXFP8Quantizer.quantize swizzled path
  on the actual transposed source. Comparison is byte-for-byte rather than
  via decoded values because TE's dequantize kernels intentionally reject
  with_gemm_swizzled_scales=True inputs (one-way GEMM-operand layout).

Tested on NVIDIA GB10 (sm_12.1) with TE rebuilt from this change: all 14
parametrized tests pass.

Signed-off-by: David Gornshtein <davidgornshtein@gmail.com>
@greptile-apps
Contributor

greptile-apps Bot commented Apr 26, 2026

Greptile Summary

This PR adds a fused MXFP8 cast-and-transpose op that takes a high-precision tensor and its existing compact column-wise E8M0 scales and emits row-wise MXFP8 payload and scales for the logical transpose, avoiding an extra BF16 re-read or payload copy. The implementation is additive: new C symbols, a PyTorch binding, and an MXFP8Quantizer.quantize_rowwise_transpose helper. The shared-memory tiled transpose kernel handles boundary tiles correctly, the swizzled-index arithmetic is consistent with the existing gemm_swizzled_scale_idx helper, and the output buffer sizes are correct in both the compact and swizzled paths.

Confidence Score: 4/5

Safe to merge; all findings are style/consistency P2s with no correctness impact.

Implementation logic is sound — kernel indexing, boundary conditions, scale buffer sizing, and swizzled-index arithmetic are all correct. The only findings are: the v1 C entry point is missing its own NVTE_API_CALL (causing profiling misattribution), the transpose cast kernel lacks launch_bounds, and there is an unused variable in one test. None affect runtime correctness.

transformer_engine/common/recipe/mxfp8_scaling.cu (missing NVTE_API_CALL on v1 wrapper, missing launch_bounds on cast kernel)

Important Files Changed

Filename Overview
transformer_engine/common/recipe/mxfp8_scaling.cu Adds two new CUDA kernels (transpose-cast payload and transpose-scale), one C++ entry point, and two C API symbols; v1 wrapper is missing its own NVTE_API_CALL; cast kernel lacks launch_bounds
transformer_engine/common/include/transformer_engine/recipe.h Adds declarations for nvte_mxfp8_scaling_transpose_cast and _v2 with clear parameter documentation; additive-only, no existing signatures changed
transformer_engine/pytorch/tensor/mxfp8_tensor.py Adds quantize_rowwise_transpose helper to MXFP8Quantizer with thorough input validation and correct scale buffer allocation matching the expected (cols×rows) rowwise layout
transformer_engine/pytorch/csrc/extensions/fp8_partial_cast.cpp Adds PyTorch binding for mxfp8_scaling_transpose_cast, correctly wrapping nvte_mxfp8_scaling_transpose_cast_v2 with contiguity checks and standard TensorWrapper usage
tests/pytorch/mxfp8/test_mxfp8_scaling_transpose_cast.py Comprehensive byte-equivalence, numerical-reconstruction, and error-path tests; one unused variable (source) in test_transpose_cast_requires_block_aligned_dims
tests/pytorch/mxfp8/test_mxfp8_scaling_transpose_cast_swizzled.py Tests swizzled scale layout against native transposed quantization with byte-level equality; clean and self-contained
transformer_engine/pytorch/csrc/extensions/pybind.cpp Exposes mxfp8_scaling_transpose_cast via pybind11 with correct default args matching the v1 C symbol's behavior (E4M3, non-swizzled)
transformer_engine/pytorch/csrc/extensions.h Adds forward declaration for mxfp8_scaling_transpose_cast; consistent with adjacent declarations

Sequence Diagram

sequenceDiagram
    participant Caller
    participant PyHelper as MXFP8Quantizer.quantize_rowwise_transpose
    participant PyBind as tex.mxfp8_scaling_transpose_cast
    participant CPP as mxfp8_scaling_transpose_cast (C++)
    participant CSymV2 as nvte_mxfp8_scaling_transpose_cast_v2
    participant ScaleKernel as mxfp8_scaling_transpose_scales_kernel
    participant CastKernel as mxfp8_scaling_transpose_cast_kernel

    Caller->>PyHelper: tensor, columnwise_scale_inv
    PyHelper->>PyHelper: validate shapes, allocate outputs
    PyHelper->>PyBind: source_2d, colwise_scale_inv, out_data, out_scale, rows, cols, fp8_dtype, swizzled
    PyBind->>CSymV2: nvte tensors + params
    CSymV2->>CPP: Tensor wrappers
    CPP->>ScaleKernel: transpose E8M0 scale bytes (compact or swizzled)
    CPP->>CastKernel: cast+transpose FP8 payload tiles
    CastKernel-->>CPP: rowwise_data (cols x rows)
    ScaleKernel-->>CPP: rowwise_scale_inv (transposed)
    CPP-->>PyHelper: filled output buffers
    PyHelper->>Caller: MXFP8Tensor(shape=(cols,rows), rowwise_only)

Comments Outside Diff (2)

  1. transformer_engine/common/recipe/mxfp8_scaling.cu, line 517-524 (link)

    P2 Missing NVTE_API_CALL in v1 entry point

    nvte_mxfp8_scaling_transpose_cast does not call NVTE_API_CALL, so when a caller uses the stable v1 symbol, API profiling/tracing tools will attribute the invocation to nvte_mxfp8_scaling_transpose_cast_v2 instead of the actual entry point. Every other public nvte_ symbol in this file (e.g. nvte_mxfp8_scaling_partial_cast, nvte_mxfp8_scaling_compute_partial_amax) has its own NVTE_API_CALL entry.

  2. tests/pytorch/mxfp8/test_mxfp8_scaling_transpose_cast.py, line 203-214 (link)

    P2 Unused variable source in test

    source = _make_source(64, 128) on the first line allocates a CUDA tensor that is never referenced in the test body; only bad_source and bad_scale are used. This may confuse future readers into thinking the valid-source fixture is part of the assertion, and wastes a small amount of device memory during the test run.

Reviews (1): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

Comment on lines +389 to +394
*convertNVTETensorCheck(input), *convertNVTETensorCheck(scale_inv_colwise),
*convertNVTETensorCheck(output_rowwise), *convertNVTETensorCheck(output_rowwise_scale_inv),
rows, cols, static_cast<DType>(fp8_dtype), with_gemm_swizzled_scales, stream);
}

void nvte_mxfp8_scaling_transpose_cast(const NVTETensor input, const NVTETensor scale_inv_colwise,

P2 Missing __launch_bounds__ on transpose cast kernel

All other __global__ kernels in this file (mxfp8_scaling_compute_partial_amax_kernel, mxfp8_scaling_partial_cast_kernel, mxfp8_scaling_transpose_scales_kernel) carry __launch_bounds__(kThreadsPerBlock). This kernel is launched with block = dim3(kTransposeTileDim, kTransposeTileDim) = 256 threads, so the appropriate hint would be __launch_bounds__(kTransposeTileDim * kTransposeTileDim). Without it the compiler cannot optimize register allocation for the stated block size.
