Skip to content

[PyTorch] Avoid removing usages from quantized weight tensors#2929

Open
timmoon10 wants to merge 6 commits intoNVIDIA:mainfrom
timmoon10:tmoon/debug-alternate-train-infer
Open

[PyTorch] Avoid removing usages from quantized weight tensors#2929
timmoon10 wants to merge 6 commits intoNVIDIA:mainfrom
timmoon10:tmoon/debug-alternate-train-infer

Conversation

@timmoon10
Copy link
Copy Markdown
Collaborator

@timmoon10 timmoon10 commented Apr 25, 2026

Description

We have experienced a correctness error in the grouped linear op when alternating between training and validation. During validation steps, we configure the weight quantizer without column-wise usage, which can cause quantized weight tensors to have stale column-wise data in the next training step. This PR makes sure to avoid disabling quantizer usages in quantized weight tensors, since a usage that was required in the past may be required in the future.

Related: #2222

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Avoid removing usages from quantized weight tensors in basic linear op
  • Avoid removing usages from quantized weight tensors in grouped linear op
  • Add tests for the linear op that alternate between training and inference steps

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Quantized weight tensor may be used across steps, so removing a usage is not safe.

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
@timmoon10 timmoon10 added bug Something isn't working 2.15.0 labels Apr 25, 2026
Comment on lines -332 to -333
# Note: We cache the quantized input for backward pass,
# but discard the quantized weights.
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment was made incorrect in #1817.

@timmoon10

This comment was marked as outdated.

@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented Apr 25, 2026

Greptile Summary

This PR fixes a correctness bug where quantized weight tensors could hold stale column-wise data after alternating between training and inference steps. The root cause was that pre_fuser_forward always set columnwise=False on the weight quantizer, and reset_recipe_state could overwrite an enabled column-wise usage with a disabled one; the fix changes the usage to columnwise=requires_grad and makes usage updates monotonically increasing (never disabling what was previously set).

Confidence Score: 5/5

Safe to merge — targeted, well-reasoned bug fix with no API-breaking changes and direct regression tests.

No P0 or P1 issues found. The monotonic-enable logic is self-consistent across both basic and grouped linear ops. The asymmetry between requires_grad (weight quantizer column-wise) and weight_requires_grad (input/grad-output quantizer column-wise) is intentional and correct: column-wise weight data is needed to compute input gradients regardless of whether the weight itself has requires_grad=True. The single_grouped_weight dead-code correction is a clean refactor with no side effects.

No files require special attention.

Important Files Changed

Filename Overview
transformer_engine/pytorch/ops/basic/basic_linear.py Core fix: weight quantizer now uses columnwise=requires_grad in pre_fuser_forward (was always False), and reset_recipe_state monotonically enables — never disables — column-wise usage on the quantized weight tensor.
transformer_engine/pytorch/ops/basic/grouped_linear.py Same monotonic-usage fix applied; also corrects dead code path where the single_grouped_weight branch inside is_quantized_tensor(weight{group_idx}) was never reachable because weightN attributes don't exist when single_grouped_weight is True.
tests/pytorch/test_fusible_ops.py Adds TestTrainingLoops with train/infer/train and infer/train/infer stage sequences; promotes to_cpu helper to module scope and extends it to handle QuantizedTensor dequantization.

Sequence Diagram

sequenceDiagram
    participant T as Training step
    participant I as Inference step
    participant PFF as pre_fuser_forward
    participant RRS as reset_recipe_state
    participant WQ as weight_quantizer
    participant WT as weight tensor

    Note over T,WT: Old behaviour (buggy)
    T->>PFF: requires_grad=True
    PFF->>WQ: set_usage(rowwise=True, columnwise=False)
    WQ->>WT: column-wise NOT cached
    I->>PFF: requires_grad=False
    PFF->>WQ: set_usage(rowwise=True, columnwise=False)
    WQ->>WT: column-wise NOT refreshed (stale)
    T->>PFF: requires_grad=True
    PFF->>WQ: set_usage(rowwise=True, columnwise=False)
    Note over WT: backward pass uses stale column-wise data — correctness error

    Note over T,WT: New behaviour (fixed)
    T->>PFF: requires_grad=True
    PFF->>WQ: set_usage(rowwise=True, columnwise=True)
    WQ->>WT: column-wise cached
    I->>PFF: requires_grad=False
    PFF->>WQ: set_usage(rowwise=True, columnwise=False)
    Note over RRS: reset_recipe_state preserves columnwise if previously set
    RRS->>WQ: columnwise stays True (monotonic)
    WQ->>WT: column-wise retained
    T->>PFF: requires_grad=True
    PFF->>WQ: set_usage(rowwise=True, columnwise=True)
    WQ->>WT: column-wise refreshed
Loading

Reviews (2): Last reviewed commit: "Restore pre-forward quantizer config in ..." | Re-trigger Greptile

Turns out we still need this in case the quantizer is used before the forward, e.g. in previous ops or CPU offloading.

Signed-off-by: Tim Moon <tmoon@nvidia.com>
@timmoon10

This comment was marked as outdated.

1 similar comment
@timmoon10
Copy link
Copy Markdown
Collaborator Author

/te-ci pytorch L1

if group_idx == 0:
weight.quantizer = weight_quantizer.copy()
else:
weight.update_quantizer(weight_quantizer.copy())
Copy link
Copy Markdown
Collaborator

@vthumbe1503 vthumbe1503 Apr 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldnt we update the quantized weight data itself as well to respect the quantizer usages here?

Comment on lines +678 to +701
# Update quantizer in quantized weight tensor
if weight_quantizer is not None and weight_is_quantized:
# Get quantizer from weight tensor
weight_tensor_quantizer = (
weight.quantizer if self.single_grouped_weight else weight._quantizer
)

# Set quantizer usages
# Note: Avoid disabling usages that are already set. The
# weight tensor may be reused across steps, so future
# steps may need usages that are currently unnecessary.
weight_quantizer.set_usage(rowwise=True)
columnwise_usage = torch.is_grad_enabled()
if weight_tensor_quantizer is not None and weight_tensor_quantizer.columnwise_usage:
columnwise_usage = True
if columnwise_usage:
weight_quantizer.set_usage(columnwise=True)

# Update weight tensor
if self.single_grouped_weight:
if group_idx == 0:
weight.quantizer = weight_quantizer.copy()
else:
weight.update_quantizer(weight_quantizer.copy())
Copy link
Copy Markdown
Collaborator

@vthumbe1503 vthumbe1503 Apr 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we really need this change of updating quantizer in case of quantized_model_init?

In general, for the quantized_model_init case, I am not sure if it even makes sense to modify the quantizer usages of the quantized_tensor ever during the lifecyle of a module being created.

For example lets say if we create module under torch.no_grad context manager. columnwise_usage will be set to False and if we enter a training loop, we ll be modifying the quantizer usage without modifying the parameters. Leading to quantizer and quantized tensor being in an inconsistent state.

Now in that case we can technically dequantize and quantize it back to have both usages, but we suffer dequantization errors in those case(which might be ok?).

I am wondering should we even touch the quantizer usages in case of quantized_model_init at all after the module is initialized?(for the quantized model parameters)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

2.15.0 bug Something isn't working

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants