
revert RoPE
MaxMax2016 authored Sep 30, 2023
1 parent 5fd03fd commit f472b23
Showing 1 changed file with 74 additions and 81 deletions.
155 changes: 74 additions & 81 deletions grad/encoder.py
@@ -1,10 +1,9 @@
import math
import torch

from einops import rearrange
from grad.base import BaseModule
from grad.reversal import SpeakerClassifier
from grad.utils import sequence_mask
from grad.utils import sequence_mask, convert_pad_shape


class LayerNorm(BaseModule):
@@ -77,67 +76,8 @@ def calc_mean_std(self, x, mask=None):
return mn, sd


class RotaryPositionalEmbeddings(BaseModule):
"""
## RoPE module
https://github.com/labmlai/annotated_deep_learning_paper_implementations
Rotary encoding transforms pairs of features by rotating in the 2D plane.
That is, it organizes the $d$ features as $\frac{d}{2}$ pairs.
Each pair can be considered a coordinate in a 2D plane, and the encoding will rotate it
by an angle depending on the position of the token.
"""
def __init__(self, d: int, base: int = 10_000):
r"""
* `d` is the number of features $d$
* `base` is the constant used for calculating $\Theta$
"""
super().__init__()
self.base = base
self.d = int(d)
self.cos_cached = None
self.sin_cached = None

def _build_cache(self, x: torch.Tensor):
r"""
Cache $\cos$ and $\sin$ values
"""
# Return if cache is already built
if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]:
return
# Get sequence length
seq_len = x.shape[0]
theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)
# Create position indexes `[0, 1, ..., seq_len - 1]`
seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device)
# Calculate the product of position index and $\theta_i$
idx_theta = torch.einsum("n,d->nd", seq_idx, theta)
# Concatenate so that for row $m$ we have
idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
# Cache them
self.cos_cached = idx_theta2.cos()[:, None, None, :]
self.sin_cached = idx_theta2.sin()[:, None, None, :]

def _neg_half(self, x: torch.Tensor):
d_2 = self.d // 2
return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)

def forward(self, x: torch.Tensor):
"""
* `x` is the query or key tensor with shape `[batch_size, n_heads, seq_len, d]`; it is rearranged to `[seq_len, batch_size, n_heads, d]` internally and rearranged back before returning
"""
x = rearrange(x, "b h t d -> t b h d")
self._build_cache(x)
# Split the features, we can choose to apply rotary embeddings only to a partial set of features.
x_rope, x_pass = x[..., : self.d], x[..., self.d :]
# Calculate
neg_half_x = self._neg_half(x_rope)
x_rope = (x_rope * self.cos_cached[: x.shape[0]]) + (neg_half_x * self.sin_cached[: x.shape[0]])
return rearrange(torch.cat((x_rope, x_pass), dim=-1), "t b h d -> b h t d")
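
A reading aid for the module being removed (illustrative, not part of the diff): `_build_cache` repeats each $\theta_i$ over both halves of the feature vector, so `forward` rotates feature $i$ together with feature $i + d/2$ rather than interleaved pairs. A minimal sketch of that half-split rotation, with hypothetical names:

    import torch

    def rope_half_split(x, cos, sin):
        # x: [..., d]; cos/sin broadcast to x and repeat theta over both halves,
        # matching cos_cached / sin_cached built above.
        d_2 = x.shape[-1] // 2
        neg_half = torch.cat([-x[..., d_2:], x[..., :d_2]], dim=-1)
        # Per pair (i, i + d/2) this is a 2D rotation by m * theta_i:
        #   out_i       = x_i       * cos - x_{i+d/2} * sin
        #   out_{i+d/2} = x_{i+d/2} * cos + x_i       * sin
        return x * cos + neg_half * sin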


class MultiHeadAttention(BaseModule):
def __init__(self, channels, out_channels, n_heads,
def __init__(self, channels, out_channels, n_heads, window_size=None,
heads_share=True, p_dropout=0.0, proximal_bias=False,
proximal_init=False):
super(MultiHeadAttention, self).__init__()
@@ -146,6 +86,7 @@ def __init__(self, channels, out_channels, n_heads,
self.channels = channels
self.out_channels = out_channels
self.n_heads = n_heads
self.window_size = window_size
self.heads_share = heads_share
self.proximal_bias = proximal_bias
self.p_dropout = p_dropout
@@ -155,11 +96,13 @@ def __init__(self, channels, out_channels, n_heads,
self.conv_q = torch.nn.Conv1d(channels, channels, 1)
self.conv_k = torch.nn.Conv1d(channels, channels, 1)
self.conv_v = torch.nn.Conv1d(channels, channels, 1)

# from https://nn.labml.ai/transformers/rope/index.html
self.query_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5)
self.key_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5)

if window_size is not None:
n_heads_rel = 1 if heads_share else n_heads
rel_stddev = self.k_channels**-0.5
self.emb_rel_k = torch.nn.Parameter(torch.randn(n_heads_rel,
window_size * 2 + 1, self.k_channels) * rel_stddev)
self.emb_rel_v = torch.nn.Parameter(torch.randn(n_heads_rel,
window_size * 2 + 1, self.k_channels) * rel_stddev)
self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
self.drop = torch.nn.Dropout(p_dropout)

@@ -169,28 +112,31 @@ def __init__(self, channels, out_channels, n_heads,
self.conv_k.weight.data.copy_(self.conv_q.weight.data)
self.conv_k.bias.data.copy_(self.conv_q.bias.data)
torch.nn.init.xavier_uniform_(self.conv_v.weight)

def forward(self, x, c, attn_mask=None):
q = self.conv_q(x)
k = self.conv_k(c)
v = self.conv_v(c)

x, self.attn = self.attention(q, k, v, mask=attn_mask)

x = self.conv_o(x)
return x

def attention(self, query, key, value, mask=None):
b, d, t_s, t_t = (*key.size(), query.size(2))
query = rearrange(query, "b (h c) t-> b h t c", h=self.n_heads)
key = rearrange(key, "b (h c) t-> b h t c", h=self.n_heads)
value = rearrange(value, "b (h c) t-> b h t c", h=self.n_heads)

query = self.query_rotary_pe(query)
key = self.key_rotary_pe(key)
query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)

scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)

if self.window_size is not None:
assert t_s == t_t, "Relative attention is only available for self-attention."
key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings)
rel_logits = self._relative_position_to_absolute_position(rel_logits)
scores_local = rel_logits / math.sqrt(self.k_channels)
scores = scores + scores_local
if self.proximal_bias:
assert t_s == t_t, "Proximal bias is only available for self-attention."
scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device,
@@ -200,9 +146,52 @@ def attention(self, query, key, value, mask=None):
p_attn = torch.nn.functional.softmax(scores, dim=-1)
p_attn = self.drop(p_attn)
output = torch.matmul(p_attn, value)
if self.window_size is not None:
relative_weights = self._absolute_position_to_relative_position(p_attn)
value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
output = output + self._matmul_with_relative_values(relative_weights,
value_relative_embeddings)
output = output.transpose(2, 3).contiguous().view(b, d, t_t)
return output, p_attn
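
A reading note on the windowed attention above (illustrative, not from the source): writing `emb_rel_k` as $a^K_r$ and `emb_rel_v` as $a^V_r$ for relative offset $r = j - i$ (zero-padded once $|r| > \text{window\_size}$), each head computes scores $s_{ij} = \big(q_i \cdot k_j + q_i \cdot a^K_{j-i}\big) / \sqrt{d_k}$ and outputs $o_i = \sum_j p_{ij}\,(v_j + a^V_{j-i})$, i.e. relative-position attention in the style of Shaw et al., as used in the Glow-TTS/VITS encoders.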

def _matmul_with_relative_values(self, x, y):
ret = torch.matmul(x, y.unsqueeze(0))
return ret

def _matmul_with_relative_keys(self, x, y):
ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
return ret

def _get_relative_embeddings(self, relative_embeddings, length):
pad_length = max(length - (self.window_size + 1), 0)
slice_start_position = max((self.window_size + 1) - length, 0)
slice_end_position = slice_start_position + 2 * length - 1
if pad_length > 0:
padded_relative_embeddings = torch.nn.functional.pad(
relative_embeddings, convert_pad_shape([[0, 0],
[pad_length, pad_length], [0, 0]]))
else:
padded_relative_embeddings = relative_embeddings
used_relative_embeddings = padded_relative_embeddings[:,
slice_start_position:slice_end_position]
return used_relative_embeddings
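
An illustrative sketch (assumed values, not from the repository) of what `_get_relative_embeddings` does: the table holds one embedding per offset in $[-\text{window\_size}, \text{window\_size}]$; for a sequence of length $L$ the attention needs $2L-1$ offsets, so the table is zero-padded outside the window and the centered slice is taken:

    import torch
    import torch.nn.functional as F

    window_size, k_channels, L = 4, 8, 6                      # illustrative values
    emb = torch.randn(1, 2 * window_size + 1, k_channels)     # offsets -4..+4

    pad = max(L - (window_size + 1), 0)                       # offsets falling outside the window
    start = max((window_size + 1) - L, 0)
    emb_p = F.pad(emb, (0, 0, pad, pad)) if pad > 0 else emb  # zero rows beyond the window
    used = emb_p[:, start:start + 2 * L - 1]                  # one row per offset -(L-1)..(L-1)
    print(used.shape)                                         # torch.Size([1, 11, 8])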

def _relative_position_to_absolute_position(self, x):
batch, heads, length, _ = x.size()
x = torch.nn.functional.pad(x, convert_pad_shape([[0,0],[0,0],[0,0],[0,1]]))
x_flat = x.view([batch, heads, length * 2 * length])
x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0,0],[0,0],[0,length-1]]))
x_final = x_flat.view([batch, heads, length+1, 2*length-1])[:, :, :length, length-1:]
return x_final
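
A standalone sketch of the pad-and-reshape trick above (hypothetical names, assuming the same shapes as in the diff): relative logits indexed by offset, shape `[b, h, L, 2L-1]`, are skewed into absolute scores `[b, h, L, L]` without building an index matrix; `_absolute_position_to_relative_position` below undoes the same mapping for the attention weights:

    import torch
    import torch.nn.functional as F

    def rel_to_abs(x):
        # x: [batch, heads, L, 2L-1] logits per relative offset -(L-1)..(L-1)
        b, h, L, _ = x.size()
        x = F.pad(x, (0, 1))                            # [b, h, L, 2L]
        x = x.view(b, h, L * 2 * L)
        x = F.pad(x, (0, L - 1))
        x = x.view(b, h, L + 1, 2 * L - 1)[:, :, :L, L - 1:]
        return x                                        # [b, h, L, L]; entry (i, j) = input (i, j - i + L - 1)

    print(rel_to_abs(torch.randn(1, 2, 4, 7)).shape)    # torch.Size([1, 2, 4, 4])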

def _absolute_position_to_relative_position(self, x):
batch, heads, length, _ = x.size()
x = torch.nn.functional.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length-1]]))
x_flat = x.view([batch, heads, length**2 + length*(length - 1)])
x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
x_final = x_flat.view([batch, heads, length, 2*length])[:,:,:,1:]
return x_final

def _attention_bias_proximal(self, length):
r = torch.arange(length, dtype=torch.float32)
diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
@@ -235,23 +224,24 @@ def forward(self, x, x_mask):

class Encoder(BaseModule):
def __init__(self, hidden_channels, filter_channels, n_heads, n_layers,
kernel_size=1, p_dropout=0.0, **kwargs):
kernel_size=1, p_dropout=0.0, window_size=None, **kwargs):
super(Encoder, self).__init__()
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
self.n_heads = n_heads
self.n_layers = n_layers
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.window_size = window_size

self.drop = torch.nn.Dropout(p_dropout)
self.attn_layers = torch.nn.ModuleList()
self.norm_layers_1 = torch.nn.ModuleList()
self.ffn_layers = torch.nn.ModuleList()
self.norm_layers_2 = torch.nn.ModuleList()
for _ in range(self.n_layers):
self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels,
n_heads, p_dropout=p_dropout))
self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels,
n_heads, window_size=window_size, p_dropout=p_dropout))
self.norm_layers_1.append(LayerNorm(hidden_channels))
self.ffn_layers.append(FFN(hidden_channels, hidden_channels,
filter_channels, kernel_size, p_dropout=p_dropout))
@@ -278,7 +268,8 @@ def __init__(self, n_vecs, n_mels, n_embs,
n_heads=2,
n_layers=6,
kernel_size=3,
p_dropout=0.1):
p_dropout=0.1,
window_size=4):
super(TextEncoder, self).__init__()
self.n_vecs = n_vecs
self.n_mels = n_mels
@@ -289,6 +280,7 @@ def __init__(self, n_vecs, n_mels, n_embs,
self.n_layers = n_layers
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.window_size = window_size

self.prenet = ConvReluNorm(n_vecs,
n_channels,
@@ -307,7 +299,8 @@ def __init__(self, n_vecs, n_mels, n_embs,
n_heads,
n_layers,
kernel_size,
p_dropout)
p_dropout,
window_size=window_size)

self.proj_m = torch.nn.Conv1d(n_channels + n_embs + n_embs, n_mels, 1)

