import torch
from torch import nn
from torch.nn import functional as F
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the channel axis (dim 1).

    Each position's channel vector is divided by its L2 norm and rescaled by
    sqrt(C), which is equivalent to dividing by the RMS of the channel values.
    The result is then multiplied by a learnable per-channel gain ``g``.

    ``g`` is stored with shape ``(1, dim, 1, 1)``, so inputs are presumably
    4-D ``(N, C, H, W)`` feature maps with ``C == dim`` -- TODO confirm
    against callers.
    """

    def __init__(self, dim):
        """Create the per-channel gain.

        Args:
            dim: Number of channels; sizes the gain ``g``, initialised to ones.
        """
        super().__init__()
        initial_gain = torch.ones(1, dim, 1, 1)
        self.g = nn.Parameter(initial_gain)

    def forward(self, x):
        """Return ``x`` RMS-normalised along dim 1 and scaled by ``g``.

        ``F.normalize`` clamps the norm from below with its default ``eps``
        (1e-12) to avoid dividing by zero.
        NOTE(review): that eps may underflow in reduced precision (fp16) --
        verify behaviour on all-zero inputs if half precision is used.
        """
        # ||v||_2 / sqrt(C) == RMS(v), so unit_norm(v) * sqrt(C) == v / RMS(v).
        # The channel count is read from the input at call time, not from `dim`.
        sqrt_channels = x.size(1) ** 0.5
        unit = F.normalize(x, dim=1)
        return unit * self.g * sqrt_channels