[PyTorch] Fix stale columnwise data usage #2925
Conversation
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Greptile Summary

This PR fixes stale columnwise data usage.
Confidence Score: 3/5

Mostly a safe bug fix, but one module-level regression in linear.py drops a defensive assignment that guards against an AttributeError in quantize_weight. The core fix (elif→if and FSDP2 tensor changes) is correct and well-tested for the primary code paths. A P1 regression in linear.py removes a guard for the weight_quantizer=None + QuantizedTensor scenario, turning a silent fallback into a potential crash. While this path may be rarely hit in practice, the asymmetry with how layernorm_mlp.py and layernorm_linear.py handle the same pattern makes linear.py's divergence worth addressing.

Important Files Changed

transformer_engine/pytorch/module/linear.py — dropped defensive quantizer assignment before the quantize_weight call
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[Forward pass enters module] --> B{weight is QuantizedTensorStorage AND not debug?}
    B -- Yes --> C[weight_quantizer = weight._quantizer]
    B -- No --> D{weight_quantizer is not None?}
    C --> D
    D -- Yes --> E[set_usage rowwise=True columnwise=is_grad_enabled ...]
    D -- No --> F[skip set_usage]
    E --> G[quantize_weight]
    F --> G
    G --> H{FSDP2 reshard_after_forward?}
    H -- Yes --> I[columnwise = is_backward_pass]
    H -- No --> J[columnwise = is_backward_pass OR grad_enabled]
    J --> K{columnwise=True but _columnwise_data is None?}
    K -- Yes --> L[RuntimeError raised - mxfp8_tensor only]
    K -- No --> M[all-gather sharded tensors]
    I --> M
    M --> N[GEMM forward]
```
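The branch at node H in the flowchart above can be sketched as a small helper. This is a hypothetical illustration of the decision logic only; `pick_columnwise_usage` and its argument names are not part of the Transformer Engine API.

```python
# Hypothetical sketch of the columnwise-usage decision from the flowchart.
# Function and argument names are illustrative, not actual TE code.

def pick_columnwise_usage(reshard_after_forward: bool,
                          is_backward_pass: bool,
                          grad_enabled: bool) -> bool:
    """Decide whether the all-gathered FSDP2 weight shard needs columnwise data."""
    if reshard_after_forward:
        # The weight is gathered again for backward, so only the backward
        # pass itself needs the columnwise layout.
        return is_backward_pass
    # Without resharding, the forward-time gather must already carry
    # columnwise data whenever a backward pass may follow.
    return is_backward_pass or grad_enabled


print(pick_columnwise_usage(True, False, True))   # reshard path ignores grad state
print(pick_columnwise_usage(False, False, True))  # grad enabled forces columnwise
```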
```python
if weight_quantizer is not None:
    if isinstance(weight, QuantizedTensor) and not debug:
        weight_quantizer = weight._quantizer
```
Dropped defensive weight_quantizer assignment breaks the quantize_weight call
The original `elif isinstance(weight, QuantizedTensor): weight_quantizer = weight._quantizer` handled the case where `weight_quantizer` arrives as `None` while `weight` is already a `QuantizedTensor`. In that path, `quantize_weight` immediately dereferences `quantizer.rowwise_usage` (line 710 of base.py) and will raise `AttributeError: 'NoneType' object has no attribute 'rowwise_usage'`.

The new code only re-assigns `weight_quantizer` when it is already non-`None`, so the previously guarded scenario now crashes instead of falling back to the weight's own quantizer. The missing assignment should be:
```python
if weight_quantizer is not None:
    if isinstance(weight, QuantizedTensor) and not debug:
        weight_quantizer = weight._quantizer
    columnwise_usage = ...
    weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
elif isinstance(weight, QuantizedTensor):
    # weight_quantizer is None but weight is pre-quantized -- pick up its quantizer
    weight_quantizer = weight._quantizer
```

```diff
+columnwise_usage = is_backward_pass or torch.is_grad_enabled()
 sharded_tensors = (self._rowwise_data, rowwise_scale_inv)
-columnwise_usage = self._quantizer.columnwise_usage
 if columnwise_usage:
     if self._columnwise_data is None or columnwise_scale_inv is None:
         raise RuntimeError(
             "FSDP2 (reshard_after_forward=False) needs columnwise MXFP8 data "
             "for the upcoming backward pass, but the local shard has none. "
             "Ensure the weight is quantized with columnwise_usage=True before "
             "this all-gather."
         )
```
torch.is_grad_enabled() can be True during eval without torch.no_grad()
model.eval() alone does not disable the gradient tape — torch.is_grad_enabled() stays True unless the caller wraps the eval loop with torch.no_grad(). In that situation columnwise_usage becomes True, but the local shard may not have _columnwise_data (it was never quantized with columnwise support during eval), so the new RuntimeError fires.
Users who ran eval with grads enabled previously got silently incorrect (stale) data; they now get a hard crash. While the crash is more correct, the error message could guide them:
Suggested change:

```diff
         raise RuntimeError(
             "FSDP2 (reshard_after_forward=False) needs columnwise MXFP8 data "
             "for the upcoming backward pass, but the local shard has none. "
             "Ensure the weight is quantized with columnwise_usage=True before "
-            "this all-gather."
+            "this all-gather. If you are running evaluation without requiring "
+            "gradients, wrap the eval loop with torch.no_grad()."
         )
```
```python
# 2. Eval forward -- should set columnwise=False on primary FP8 weight
# quantizers, simulating the start of an evaluation loop.
run_forward(is_eval=True)
for q in get_weight_quantizers():
```
The test doesn't make sense to me. I don't think we should be toggling the quantizer usages in the case of quantized_model_init at all.

This breaks the very principle that a quantized tensor and its internal quantizer shouldn't be in conflict with each other. And here, the columnwise_data is present for the quantized tensor even though columnwise_usage is set to False.
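The invariant this comment describes can be sketched with stand-in types. `QuantizerStub`, `QuantizedTensorStub`, and the checker below are hypothetical illustrations, not the real Transformer Engine classes:

```python
from dataclasses import dataclass
from typing import Optional

# Hypothetical stand-ins for TE's quantizer / quantized-tensor pair; the real
# classes differ. This only illustrates the consistency invariant at issue.

@dataclass
class QuantizerStub:
    rowwise_usage: bool
    columnwise_usage: bool

@dataclass
class QuantizedTensorStub:
    _quantizer: QuantizerStub
    _rowwise_data: Optional[object]
    _columnwise_data: Optional[object]

def usages_consistent(t: QuantizedTensorStub) -> bool:
    """Each usage flag should match the presence of its data buffer."""
    return (t._quantizer.rowwise_usage == (t._rowwise_data is not None)
            and t._quantizer.columnwise_usage == (t._columnwise_data is not None))

# The conflicting state flagged above: columnwise data is present even though
# the internal quantizer says columnwise_usage=False.
conflicting = QuantizedTensorStub(QuantizerStub(True, False), object(), object())
consistent = QuantizedTensorStub(QuantizerStub(True, True), object(), object())
print(usages_consistent(conflicting), usages_consistent(consistent))  # False True
```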
```python
if isinstance(weights[0], QuantizedTensorStorage) and not debug:
    weight_quantizers = [weight._quantizer for weight in weights]
    for weight_quantizer in weight_quantizers:
```
I don't think these changes are needed for this file or any of the other files in the PR. If the weight is already quantized, it doesn't make sense to change its internal quantizer and leave the quantized weight and its internal quantizer in a state of conflict with each other.

In general, in the case of quantized_model_init, if we change a quantized tensor's internal quantizer, the quantized tensor should also be updated to carry the appropriate usages.
```diff
 else:
     rowwise_usage = True
-    columnwise_usage = self._quantizer.columnwise_usage
+    columnwise_usage = is_backward_pass or torch.is_grad_enabled()
```
I think it still makes sense to use `self._quantizer.columnwise_usage` as the source of truth for what data is "really available" in the sharded quantized tensor, and to throw an error if that usage doesn't match `is_backward_pass or torch.is_grad_enabled()` (similar to the MXFP8 tensor).

What we are doing here is silently creating columnwise data after the all-gather for the all-gathered tensor, even though the original sharded data tensor didn't have that data.

In my opinion, I am against any change here, since even doing such a validation and throwing an error is going to incur CPU overhead when using `torch.is_grad_enabled()`.
Same comment applies to all the other FSDP2-related changes.
Description
This PR sets columnwise usage correctly for all quantizers instead of retaining the value in the quantizer, which may be incorrect after resuming training post-validation steps, as the columnwise usage is set to False for eval mode.

Type of change
Changes
Checklist: