import math
import torch
from torch import nn
# Non-learnable positional encodings
class PositionalEncoding1D(nn.Module):
    """
    Fixed (non-learnable) sinusoidal positional encoding for 1D data.

    The encoding is added to the input tensor to provide information about
    the position of each element along the sequence axis. Even feature
    indices receive sine values and odd feature indices receive cosine
    values. Encodings are generated on demand for the sequence length and
    device of the input and kept in a small per-instance cache.
    """

    # Upper bound on the number of cached encodings kept per instance.
    _MAX_CACHE_ENTRIES = 10

    def __init__(self, embedding_size):
        """
        Parameters
        ----------
        embedding_size : int
            The number of features of the input data. Both even and odd
            sizes are supported.
        """
        super().__init__()
        self.embedding_size = embedding_size
        # Maps "<seq_len>_<device>" -> tensor of shape (1, seq_len, embedding_size)
        self.cached_encodings = {}

    def _generate_positional_encoding(self, seq_len, device):
        """
        Generate positional encodings dynamically based on sequence length.

        Parameters
        ----------
        seq_len : int
            The length of the sequence for which to generate positional encodings.
        device : torch.device
            The device on which to create the encodings.

        Returns
        -------
        torch.Tensor
            Positional encodings tensor of shape (1, seq_len, embedding_size)
        """
        pe = torch.zeros(seq_len, self.embedding_size, device=device)
        # Column vector of positions (0, 1, ..., seq_len - 1)
        position = torch.arange(
            0, seq_len, dtype=torch.float32, device=device
        ).unsqueeze(1)
        # One frequency per even feature index: 10000 ** (-i / embedding_size)
        div_term = torch.exp(
            torch.arange(0, self.embedding_size, 2, dtype=torch.float32, device=device)
            * (-torch.log(torch.tensor(10000.0, device=device)) / self.embedding_size)
        )
        angles = position * div_term  # (seq_len, ceil(embedding_size / 2))
        # Sine on even feature indices
        pe[:, 0::2] = torch.sin(angles)
        # Cosine on odd feature indices. For an odd embedding_size there is one
        # fewer odd index than even index, so the last frequency is dropped here.
        pe[:, 1::2] = torch.cos(angles[:, : self.embedding_size // 2])
        # Add batch dimension
        return pe.unsqueeze(0)

    def forward(self, x):
        """
        Add positional encodings to the input tensor.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, seq_len, embedding_size)

        Returns
        -------
        torch.Tensor
            Input tensor with positional encodings added. For floating-point
            inputs the result keeps the dtype of ``x``.
        """
        seq_len = x.size(1)
        cache_key = f"{seq_len}_{x.device}"
        pe = self.cached_encodings.get(cache_key)
        if pe is None:
            pe = self._generate_positional_encoding(seq_len, x.device)
            self.cached_encodings[cache_key] = pe
            # Keep the cache bounded: dicts preserve insertion order, so the
            # first key is the oldest entry.
            if len(self.cached_encodings) > self._MAX_CACHE_ENTRIES:
                oldest_key = next(iter(self.cached_encodings))
                if oldest_key != cache_key:  # never evict what was just added
                    del self.cached_encodings[oldest_key]
        # The cached encoding is float32; match the input's floating dtype so
        # half/bfloat16 inputs are not silently promoted to float32.
        if x.is_floating_point() and pe.dtype != x.dtype:
            pe = pe.to(x.dtype)
        return x + pe
class PositionalEncoding2D(nn.Module):
    """
    Fixed (non-learnable) sinusoidal positional encoding for 2D data.

    The channel axis is split into four equal groups: sine of the x
    position, cosine of the x position, sine of the y position and cosine
    of the y position. Encodings are generated on demand for the spatial
    size and device of the input and kept in a small per-instance cache.
    """

    # Upper bound on the number of cached encodings kept per instance.
    _MAX_CACHE_ENTRIES = 10

    def __init__(self, embed_dim: int):
        """
        Positional encoding for 2D data using alternating sine and cosine.

        Parameters
        ----------
        embed_dim : int
            The number of embedding dimensions (channels). It must be
            divisible by 4; this is checked when an encoding is generated.
        """
        super().__init__()
        self.embed_dim = embed_dim
        # Maps "<height>_<width>_<device>" -> tensor of shape (1, embed_dim, height, width)
        self.cached_encodings = {}

    def forward(self, x):
        """
        Add 2D positional encodings to the input tensor.

        Parameters
        ----------
        x : torch.Tensor
            4D input tensor of shape (batch_size, channels, height, width).
            The channel count is expected to equal ``embed_dim``; this is
            not validated here.

        Returns
        -------
        torch.Tensor
            Input tensor with positional encodings added. For floating-point
            inputs the result keeps the dtype of ``x``.

        Raises
        ------
        ValueError
            If ``x`` is not 4-dimensional, or if ``embed_dim`` is not
            divisible by 4.
        """
        # Unpacking enforces a 4D input; only the spatial sizes are needed.
        _, _, height, width = x.shape
        cache_key = f"{height}_{width}_{x.device}"
        pe = self.cached_encodings.get(cache_key)
        if pe is None:
            pe = self.generate_pe(height, width, x.device)
            self.cached_encodings[cache_key] = pe
            # Keep the cache bounded: dicts preserve insertion order, so the
            # first key is the oldest entry.
            if len(self.cached_encodings) > self._MAX_CACHE_ENTRIES:
                oldest_key = next(iter(self.cached_encodings))
                if oldest_key != cache_key:  # never evict what was just added
                    del self.cached_encodings[oldest_key]
        # The cached encoding is float32; match the input's floating dtype so
        # half/bfloat16 inputs are not silently promoted to float32.
        if x.is_floating_point() and pe.dtype != x.dtype:
            pe = pe.to(x.dtype)
        return x + pe

    def generate_pe(self, height, width, device):
        """
        Generate 2D positional encodings with both sine and cosine functions.

        Parameters
        ----------
        height : int
            Number of rows to encode.
        width : int
            Number of columns to encode.
        device : torch.device
            The device on which to create the encodings.

        Returns
        -------
        torch.Tensor
            Positional encodings tensor of shape (1, embed_dim, height, width)

        Raises
        ------
        ValueError
            If ``embed_dim`` is not divisible by 4.
        """
        if self.embed_dim % 4 != 0:
            raise ValueError(
                f"Embedding dimension must be divisible by 4, got {self.embed_dim}"
            )
        # Each axis gets 1/4 of the channels for sin and 1/4 for cos
        dim_t = self.embed_dim // 4

        y_pos = torch.arange(height, device=device).float().view(height, 1)  # [height, 1]
        x_pos = torch.arange(width, device=device).float().view(1, width)  # [1, width]

        # Frequency bands: 10000 ** (-k / dim_t) for k = 0 .. dim_t - 1
        freq = torch.exp(
            torch.arange(0, dim_t, dtype=torch.float32, device=device)
            * (-math.log(10000.0) / dim_t)
        ).view(dim_t, 1, 1)  # [dim_t, 1, 1]

        # Broadcast positions over the grid, then scale by each frequency band
        pos_x_enc = x_pos.expand(height, -1).unsqueeze(0) * freq  # [dim_t, height, width]
        pos_y_enc = y_pos.expand(-1, width).unsqueeze(0) * freq  # [dim_t, height, width]

        pe = torch.zeros(1, self.embed_dim, height, width, device=device)
        # X axis - sin and cos
        pe[0, :dim_t] = torch.sin(pos_x_enc)
        pe[0, dim_t : 2 * dim_t] = torch.cos(pos_x_enc)
        # Y axis - sin and cos
        pe[0, 2 * dim_t : 3 * dim_t] = torch.sin(pos_y_enc)
        pe[0, 3 * dim_t :] = torch.cos(pos_y_enc)
        return pe
# Learnable positional encodings
class LearnablePositionalEncoding1D(nn.Module):
    """
    Learnable positional encoding for 1D data.

    Holds a single trainable tensor of shape (1, sequence_length, embed_dim),
    initialised from a standard normal distribution, that is added to the
    input.
    """

    def __init__(self, sequence_length, embed_dim):
        """
        Parameters
        ----------
        sequence_length : int
            Length of the sequence axis of the learned encoding.
        embed_dim : int
            Number of features of the learned encoding.
        """
        super().__init__()
        self.sequence_length = sequence_length
        self.embed_dim = embed_dim
        self.positional_encoding = nn.Parameter(
            torch.randn(1, sequence_length, embed_dim)
        )

    def forward(self, x):
        """
        Add the learned positional encoding to the input tensor.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor that broadcasts against
            (1, sequence_length, embed_dim).

        Returns
        -------
        torch.Tensor
            ``x`` with the positional encoding added.
        """
        pe = self.positional_encoding
        if pe.device != x.device:
            # Move a local copy only. Re-assigning the attribute with the
            # result of ``.to()`` would hand nn.Module a plain Tensor for a
            # registered parameter name, which raises TypeError. The copy
            # stays connected to the parameter for autograd.
            pe = pe.to(x.device)
        return x + pe
class LearnablePositionalEncoding2D(nn.Module):
    """
    Learnable positional encoding for 2D data.

    Holds a single trainable tensor of shape (1, embed_dim, height, width),
    initialised from a standard normal distribution, that is added to the
    input.
    """

    def __init__(self, height, width, embed_dim):
        """
        Parameters
        ----------
        height : int
            Height of the learned encoding.
        width : int
            Width of the learned encoding.
        embed_dim : int
            Number of channels of the learned encoding.
        """
        super().__init__()
        self.height = height
        self.width = width
        self.embed_dim = embed_dim
        self.positional_encoding = nn.Parameter(
            torch.randn(1, embed_dim, height, width)
        )

    def forward(self, x):
        """
        Add the learned positional encoding to the input tensor.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor that broadcasts against
            (1, embed_dim, height, width).

        Returns
        -------
        torch.Tensor
            ``x`` with the positional encoding added.
        """
        pe = self.positional_encoding
        if pe.device != x.device:
            # Move a local copy only. Re-assigning the attribute with the
            # result of ``.to()`` would hand nn.Module a plain Tensor for a
            # registered parameter name, which raises TypeError. The copy
            # stays connected to the parameter for autograd.
            pe = pe.to(x.device)
        return x + pe