[PyTorch] Avoid removing usages from quantized weight tensors #2929
timmoon10 wants to merge 6 commits into NVIDIA:main
Conversation
Quantized weight tensor may be used across steps, so removing a usage is not safe. Signed-off-by: Tim Moon <tmoon@nvidia.com>
```python
# Note: We cache the quantized input for backward pass,
# but discard the quantized weights.
```
This comment was made incorrect in #1817.
Greptile Summary

This PR fixes a correctness bug where quantized weight tensors could hold stale column-wise data after alternating between training and inference steps. The root cause was that inference steps configured the weight quantizer without column-wise usage, so the cached column-wise data was not refreshed and went stale by the next training step.

Confidence Score: 5/5. Safe to merge: a targeted, well-reasoned bug fix with no API-breaking changes and direct regression tests. No P0 or P1 issues found. The monotonic-enable logic is self-consistent across both the basic and grouped linear ops. No files require special attention.

Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant T as Training step
    participant I as Inference step
    participant PFF as pre_fuser_forward
    participant RRS as reset_recipe_state
    participant WQ as weight_quantizer
    participant WT as weight tensor
    Note over T,WT: Old behaviour (buggy)
    T->>PFF: requires_grad=True
    PFF->>WQ: set_usage(rowwise=True, columnwise=True)
    WQ->>WT: column-wise cached
    I->>PFF: requires_grad=False
    PFF->>WQ: set_usage(rowwise=True, columnwise=False)
    WQ->>WT: column-wise NOT refreshed (stale)
    T->>PFF: requires_grad=True
    PFF->>WQ: set_usage(rowwise=True, columnwise=True)
    Note over WT: backward pass uses stale column-wise data: correctness error
    Note over T,WT: New behaviour (fixed)
    T->>PFF: requires_grad=True
    PFF->>WQ: set_usage(rowwise=True, columnwise=True)
    WQ->>WT: column-wise cached
    I->>PFF: requires_grad=False
    PFF->>WQ: set_usage(rowwise=True, columnwise=False)
    Note over RRS: reset_recipe_state preserves columnwise if previously set
    RRS->>WQ: columnwise stays True (monotonic)
    WQ->>WT: column-wise retained
    T->>PFF: requires_grad=True
    PFF->>WQ: set_usage(rowwise=True, columnwise=True)
    WQ->>WT: column-wise refreshed
```
Turns out we still need this in case the quantizer is used before the forward, e.g. in previous ops or CPU offloading. Signed-off-by: Tim Moon <tmoon@nvidia.com>
/te-ci pytorch L1
```python
if group_idx == 0:
    weight.quantizer = weight_quantizer.copy()
else:
    weight.update_quantizer(weight_quantizer.copy())
```
Shouldn't we update the quantized weight data itself as well, to respect the quantizer usages here?
```python
# Update quantizer in quantized weight tensor
if weight_quantizer is not None and weight_is_quantized:
    # Get quantizer from weight tensor
    weight_tensor_quantizer = (
        weight.quantizer if self.single_grouped_weight else weight._quantizer
    )

    # Set quantizer usages
    # Note: Avoid disabling usages that are already set. The
    # weight tensor may be reused across steps, so future
    # steps may need usages that are currently unnecessary.
    weight_quantizer.set_usage(rowwise=True)
    columnwise_usage = torch.is_grad_enabled()
    if weight_tensor_quantizer is not None and weight_tensor_quantizer.columnwise_usage:
        columnwise_usage = True
    if columnwise_usage:
        weight_quantizer.set_usage(columnwise=True)

    # Update weight tensor
    if self.single_grouped_weight:
        if group_idx == 0:
            weight.quantizer = weight_quantizer.copy()
        else:
            weight.update_quantizer(weight_quantizer.copy())
```
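The usage-merging rule in this hunk can be modeled with plain Python objects. `ToyQuantizer` and `configure_for_step` below are hypothetical stand-ins, not the Transformer Engine API; the sketch only shows the monotonic rule of never disabling a usage that the cached weight tensor already has.

```python
class ToyQuantizer:
    """Minimal stand-in for a quantizer with row-/column-wise usage flags."""

    def __init__(self):
        self.rowwise_usage = False
        self.columnwise_usage = False

    def set_usage(self, *, rowwise=None, columnwise=None):
        if rowwise is not None:
            self.rowwise_usage = rowwise
        if columnwise is not None:
            self.columnwise_usage = columnwise


def configure_for_step(quantizer, tensor_quantizer, grad_enabled):
    # Row-wise data is always needed for the forward GEMM.
    quantizer.set_usage(rowwise=True)
    # Column-wise data is needed when grad is enabled, OR when the
    # cached weight tensor already carries it (monotonic rule: a usage
    # required in a past step may be required again in a future step).
    columnwise = grad_enabled
    if tensor_quantizer is not None and tensor_quantizer.columnwise_usage:
        columnwise = True
    if columnwise:
        quantizer.set_usage(columnwise=True)
```

With this rule, an inference step (grad disabled) that follows a training step keeps column-wise usage enabled instead of dropping it, which is exactly the behaviour change this PR makes.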
There was a problem hiding this comment.
Do we really need this change of updating the quantizer in the quantized_model_init case?
In general, for the quantized_model_init case, I am not sure it even makes sense to modify the quantizer usages of the quantized tensor at any point during the lifecycle of a module being created.
For example, say we create a module under the torch.no_grad context manager. columnwise_usage will be set to False, and if we then enter a training loop, we'll be modifying the quantizer usage without modifying the parameters, leaving the quantizer and the quantized tensor in an inconsistent state.
In that case we can technically dequantize and quantize back to get both usages, but we suffer dequantization errors in those cases (which might be OK?).
I am wondering whether we should touch the quantizer usages at all in the quantized_model_init case after the module is initialized (for the quantized model parameters).
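To make the no_grad concern concrete: `torch.is_grad_enabled()` really is False inside a `torch.no_grad()` block, so any usage decision keyed off it at module-construction time would start with column-wise disabled. This is a minimal check using only the real PyTorch grad-mode API; the tie to quantizer usages is the reviewer's scenario, not something this snippet exercises.

```python
import torch

# Inside no_grad, grad mode is off, so a column-wise usage decision
# based on torch.is_grad_enabled() at construction time starts False.
with torch.no_grad():
    columnwise_at_init = torch.is_grad_enabled()  # False

# Outside the context, grad mode is back on, so a later training loop
# would want column-wise usage even though the tensor was built without it.
grad_enabled_later = torch.is_grad_enabled()  # True
```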
Description
We have experienced a correctness error in the grouped linear op when alternating between training and validation. During validation steps, we configure the weight quantizer without column-wise usage, which can cause quantized weight tensors to have stale column-wise data in the next training step. This PR makes sure to avoid disabling quantizer usages in quantized weight tensors, since a usage that was required in the past may be required in the future.
Related: #2222
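The alternation that exposed the bug can be sketched with a plain `nn.Linear` standing in for a quantized grouped-linear module (no Transformer Engine calls; the comments mark where the quantizer usages would change):

```python
import torch
from torch import nn

model = nn.Linear(4, 4)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

for step in range(3):
    # Training step: autograd is enabled, so a quantized weight would
    # need its column-wise copy for the weight gradient in backward.
    loss = model(torch.randn(2, 4)).sum()
    loss.backward()
    opt.step()
    opt.zero_grad()

    # Validation step: grad is disabled, so column-wise usage looks
    # unnecessary here. Disabling it at this point is what left stale
    # column-wise data for the next training step before this fix.
    with torch.no_grad():
        model(torch.randn(2, 4))
```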
Type of change
Changes
Checklist: