Source code for starling.models.ema

import torch
import torch.nn as nn


class EMA:
    """Exponential moving average (EMA) of a model's parameters.

    Maintains a separate "EMA model" whose parameters track a running
    average of a training model's parameters::

        ema = beta * ema + (1 - beta) * current

    For a warm-up period (see ``step_ema``) the EMA model is instead
    overwritten with an exact copy of the training model's state dict.
    """

    def __init__(self, beta: float = 0.999):
        """Create an EMA tracker.

        Args:
            beta: Weight given to the previous average on each update;
                values closer to 1 make the average change more slowly.
        """
        self.beta = beta
        # Number of ``step_ema`` calls made so far; drives the warm-up logic.
        self.step = 0

    def update_model_average(self, ma_model, current_model):
        """Blend ``current_model``'s parameters into ``ma_model``.

        Parameters are paired positionally in ``.parameters()`` order, so the
        two models must share the same architecture. Only parameters are
        averaged; buffers are left untouched.

        Args:
            ma_model: Model holding the running average; updated in place.
            current_model: Model whose current parameters are blended in.

        Raises:
            ValueError: If the two models expose a different number of
                parameters.
        """
        current_params = list(current_model.parameters())
        ma_params = list(ma_model.parameters())
        # A bare zip() would silently stop at the shorter list and leave a
        # partially-averaged EMA model; fail loudly on a mismatch instead.
        if len(current_params) != len(ma_params):
            raise ValueError(
                "EMA model and current model have a different number of "
                f"parameters ({len(ma_params)} vs {len(current_params)})"
            )
        for current_param, ma_param in zip(current_params, ma_params):
            ema_model_weights = ma_param.data
            current_weights = current_param.data
            ma_param.data = self.update_average(ema_model_weights, current_weights)

    def update_average(self, ema_model_params, current_model_params):
        """Return a single EMA step: ``beta * ema + (1 - beta) * current``.

        If ``ema_model_params`` is None (no average yet), the current values
        are returned unchanged.
        """
        if ema_model_params is None:
            return current_model_params
        return ema_model_params * self.beta + (1 - self.beta) * current_model_params

    def step_ema(self, ema_model, model, step_start_ema=2000):
        """Advance the EMA by one training step.

        While fewer than ``step_start_ema`` calls have been made, the EMA
        model is reset to an exact copy of ``model`` (warm-up). After that,
        its parameters are averaged via ``update_model_average``.
        ``self.step`` is incremented once per successful call.
        """
        if self.step < step_start_ema:
            self.reset_parameters(ema_model, model)
        else:
            self.update_model_average(ema_model, model)
        self.step += 1

    def reset_parameters(self, ema_model, model):
        """Overwrite ``ema_model``'s state with ``model``'s ``state_dict()``."""
        ema_model.load_state_dict(model.state_dict())