"""Normalization layers (``starling.models.normalization``)."""

import torch
from torch import nn
from torch.nn import functional as F


class RMSNorm(nn.Module):
    """Root-mean-square normalization over the channel axis (dim 1).

    Each element is divided by the RMS of the values across dim 1 at the same
    position, then multiplied by a learnable per-channel gain ``g``. No mean
    is subtracted and no bias is added.

    Args:
        dim: Number of channels, i.e. the expected size of ``x.shape[1]``.
    """

    def __init__(self, dim):
        super().__init__()
        # Kept as (1, dim, 1, 1) so existing state_dicts still load; forward()
        # reshapes it to match the rank of the actual input.
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))

    def forward(self, x):
        """Normalize ``x`` along dim 1 and apply the learned gain.

        Args:
            x: Tensor of shape ``(N, dim, *rest)`` with at least 2 dimensions
                (e.g. ``(N, C)``, ``(N, C, L)``, ``(N, C, H, W)``).

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Reshape the gain to (1, C, 1, ..., 1) so it broadcasts along the
        # channel axis for any input rank, not only 4-D inputs.
        gain = self.g.reshape(1, -1, *([1] * (x.dim() - 2)))
        # x / ||x||_2 * sqrt(C) == x / rms(x), because ||x||_2 = sqrt(C) * rms(x).
        # F.normalize clamps the norm at its default eps, avoiding division by zero.
        return F.normalize(x, dim=1) * gain * (x.shape[1] ** 0.5)