diff --git a/ldm/modules/ema.py b/ldm/modules/ema.py index c8c75af435..4f35a14c5b 100644 --- a/ldm/modules/ema.py +++ b/ldm/modules/ema.py @@ -3,14 +3,14 @@ class LitEma(nn.Module): - def __init__(self, model, decay=0.9999, use_num_upates=True): + def __init__(self, model, decay=0.9999, use_num_updates=True): super().__init__() if decay < 0.0 or decay > 1.0: raise ValueError('Decay must be between 0 and 1') self.m_name2s_name = {} self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) - self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates + self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_updates else torch.tensor(-1,dtype=torch.int)) for name, p in model.named_parameters():