diff --git a/fairseq/modules/positional_embedding.py b/fairseq/modules/positional_embedding.py
index 97cd474b51..fbc13d80ac 100644
--- a/fairseq/modules/positional_embedding.py
+++ b/fairseq/modules/positional_embedding.py
@@ -14,6 +14,7 @@ def PositionalEmbedding(
     embedding_dim: int,
     padding_idx: int,
     learned: bool = False,
+    auto_expand: bool = True,
 ):
     if learned:
         # if padding_idx is specified then offset the embedding ids by
@@ -31,5 +32,6 @@ def PositionalEmbedding(
             embedding_dim,
             padding_idx,
             init_size=num_embeddings + padding_idx + 1,
+            auto_expand=auto_expand,
         )
     return m
diff --git a/fairseq/modules/sinusoidal_positional_embedding.py b/fairseq/modules/sinusoidal_positional_embedding.py
index e7ecd0f2c8..dd93ddc397 100644
--- a/fairseq/modules/sinusoidal_positional_embedding.py
+++ b/fairseq/modules/sinusoidal_positional_embedding.py
@@ -18,14 +18,19 @@ class SinusoidalPositionalEmbedding(nn.Module):
     Padding symbols are ignored.
     """
 
-    def __init__(self, embedding_dim, padding_idx, init_size=1024):
+    def __init__(self, embedding_dim, padding_idx, init_size=1024, auto_expand=True):
         super().__init__()
         self.embedding_dim = embedding_dim
         self.padding_idx = padding_idx if padding_idx is not None else 0
-        self.register_buffer("weights", SinusoidalPositionalEmbedding.get_embedding(
-            init_size, embedding_dim, padding_idx
-        ), persistent=False)
+        self.register_buffer(
+            "weights",
+            SinusoidalPositionalEmbedding.get_embedding(
+                init_size, embedding_dim, padding_idx
+            ),
+            persistent=False,
+        )
         self.max_positions = int(1e5)
+        self.auto_expand = auto_expand
         self.onnx_trace = False
 
     def prepare_for_onnx_export_(self):
@@ -75,28 +80,36 @@ def forward(
         bspair = torch.onnx.operators.shape_as_tensor(input)
         bsz, seq_len = bspair[0], bspair[1]
         max_pos = self.padding_idx + 1 + seq_len
+        weights = self.weights
+
         if max_pos > self.weights.size(0):
-            # expand embeddings if needed
-            self.weights = SinusoidalPositionalEmbedding.get_embedding(
+            # If the input is longer than the number of pre-computed embeddings,
+            # compute the extra embeddings on the fly.
+            # Only store the expanded embeddings if auto_expand=True.
+            # In multithreading environments, mutating the weights of a module
+            # may cause trouble. Set auto_expand=False if this happens.
+            weights = SinusoidalPositionalEmbedding.get_embedding(
                 max_pos, self.embedding_dim, self.padding_idx
             ).to(self.weights)
+            if self.auto_expand:
+                self.weights = weights
 
         if incremental_state is not None:
             # positions is the same for every token when decoding a single step
             pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
             if self.onnx_trace:
                 return (
-                    self.weights.index_select(index=self.padding_idx + pos, dim=0)
+                    weights.index_select(index=self.padding_idx + pos, dim=0)
                     .unsqueeze(1)
                     .repeat(bsz, 1, 1)
                 )
-            return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
+            return weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
 
         positions = utils.make_positions(
             input, self.padding_idx, onnx_trace=self.onnx_trace
         )
         if self.onnx_trace:
-            flat_embeddings = self.weights.detach().index_select(0, positions.view(-1))
+            flat_embeddings = weights.detach().index_select(0, positions.view(-1))
             embedding_shape = torch.cat(
                 (bsz.view(1), seq_len.view(1), torch.tensor([-1], dtype=torch.long))
             )
@@ -105,7 +118,5 @@ def forward(
             )
             return embeddings
         return (
-            self.weights.index_select(0, positions.view(-1))
-            .view(bsz, seq_len, -1)
-            .detach()
+            weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()
         )
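For reviewers, here is a minimal usage sketch (not part of the patch) of what the new `auto_expand` flag changes. It assumes `SinusoidalPositionalEmbedding` is exported from `fairseq.modules`, and the sizes (`embedding_dim=16`, `init_size=8`, `padding_idx=1`) are arbitrary illustrations: with `auto_expand=False`, an input longer than `init_size` still gets correct positional embeddings computed on the fly, but the registered `weights` buffer is never mutated, which is the thread-safety concern described in the added comments.

```python
# Usage sketch only (not part of the patch); sizes are arbitrary for illustration.
import torch

from fairseq.modules import SinusoidalPositionalEmbedding

pad = 1
emb = SinusoidalPositionalEmbedding(
    embedding_dim=16, padding_idx=pad, init_size=8, auto_expand=False
)

# Two sequences of 32 tokens, i.e. longer than the 8 pre-computed positions.
tokens = torch.full((2, 32), 5, dtype=torch.long)  # token id 5, no padding symbols

out = emb(tokens)

print(out.shape)            # torch.Size([2, 32, 16]): extra positions computed on the fly
print(emb.weights.size(0))  # still 8: the registered buffer was left untouched
# With the default auto_expand=True, emb.weights would have grown to
# padding_idx + 1 + 32 = 34 rows after this call.
```

The change in `positional_embedding.py` only threads the flag through the factory, so the same behaviour is available to callers of `PositionalEmbedding(..., learned=False, auto_expand=False)`.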