-
Notifications
You must be signed in to change notification settings - Fork 0
/
loss_function.py
163 lines (137 loc) · 6 KB
/
loss_function.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import torch
from torch import nn
from utils import make_non_pad_mask
class Tacotron2Loss(nn.Module):
    """Tacotron 2 training loss: mel reconstruction (MSE) plus stop-token (BCE).

    The mel target is compared with both the decoder output and the post-net
    output using mean squared error. The gate (stop-token) logits are compared
    with the gate target using binary cross-entropy on logits. The three terms
    are summed with equal weight.
    """

    def __init__(self):
        super(Tacotron2Loss, self).__init__()

    def forward(self, model_output, targets):
        """Compute the summed Tacotron 2 loss.

        Args:
            model_output (Sequence[Tensor]): 5-element sequence unpacked as
                ``(_, mel_out, mel_out_postnet, gate_out, _)``; the first and
                last entries are ignored.
            targets (Sequence[Tensor]): ``targets[0]`` is the mel target and
                ``targets[1]`` the gate target (floating point, as required by
                BCE-with-logits). Any further entries are ignored.

        Returns:
            Tensor: Scalar ``MSE(mel_out, mel_target)
            + MSE(mel_out_postnet, mel_target) + BCE(gate_out, gate_target)``.
        """
        # Use detach() rather than assigning requires_grad = False. The
        # assignment mutates the caller's tensors, and PyTorch raises a
        # RuntimeError when it is applied to a non-leaf tensor.
        mel_target = targets[0].detach()
        # Flatten gate target/logits to (N, 1). reshape (unlike view) also
        # accepts non-contiguous tensors.
        gate_target = targets[1].detach().reshape(-1, 1)

        _, mel_out, mel_out_postnet, gate_out, _ = model_output
        gate_out = gate_out.reshape(-1, 1)

        mel_loss = nn.functional.mse_loss(mel_out, mel_target) + nn.functional.mse_loss(
            mel_out_postnet, mel_target
        )
        gate_loss = nn.functional.binary_cross_entropy_with_logits(gate_out, gate_target)
        return mel_loss + gate_loss
class GuidedAttentionLoss(nn.Module):
    """Guided attention loss function module.

    See https://github.com/espnet/espnet/blob/e962a3c609ad535cd7fb9649f9f9e9e0a2a27291/espnet/nets/pytorch_backend/e2e_tts_tacotron2.py#L25

    This module calculates the guided attention loss described
    in `Efficiently Trainable Text-to-Speech System Based
    on Deep Convolutional Networks with Guided Attention`_,
    which forces the attention to be diagonal.

    .. _`Efficiently Trainable Text-to-Speech System
        Based on Deep Convolutional Networks with Guided Attention`:
        https://arxiv.org/abs/1710.08969
    """

    def __init__(self, sigma=0.4, alpha=1.0, reset_always=True):
        """Initialize guided attention loss module.

        Args:
            sigma (float, optional): Standard deviation to control
                how close attention to a diagonal.
            alpha (float, optional): Scaling coefficient (lambda).
            reset_always (bool, optional): Whether to drop the cached masks
                after every forward call. When False, the masks built on the
                first call are reused on later calls, so later ``ilens`` /
                ``olens`` are ignored; only safe if lengths never change.
        """
        super(GuidedAttentionLoss, self).__init__()
        self.sigma = sigma
        self.alpha = alpha
        self.reset_always = reset_always
        # Lazily built caches; see forward() and _reset_masks().
        self.guided_attn_masks = None
        self.masks = None

    def _reset_masks(self):
        """Drop the cached guided-attention and padding masks."""
        self.guided_attn_masks = None
        self.masks = None

    def forward(self, att_ws, ilens, olens):
        """Calculate forward propagation.

        Args:
            att_ws (Tensor): Batch of attention weights (B, T_max_out, T_max_in).
            ilens (LongTensor or List[int]): Batch of input lengths (B,).
            olens (LongTensor or List[int]): Batch of output lengths (B,).

        Returns:
            Tensor: ``alpha`` times the mean guided-attention penalty over the
            non-padded (output, input) positions.
        """
        if self.guided_attn_masks is None:
            self.guided_attn_masks = self._make_guided_attention_masks(ilens, olens).to(
                att_ws.device
            )
        if self.masks is None:
            self.masks = self._make_masks(ilens, olens).to(att_ws.device)
        # Penalize attention mass far from the diagonal, then average only
        # over the positions that are not padding.
        losses = self.guided_attn_masks * att_ws
        loss = torch.mean(losses.masked_select(self.masks))
        if self.reset_always:
            self._reset_masks()
        return self.alpha * loss

    def _make_guided_attention_masks(self, ilens, olens):
        """Build the zero-padded batch of guided attention weight matrices.

        Args:
            ilens (LongTensor or List[int]): Input lengths (B,).
            olens (LongTensor or List[int]): Output lengths (B,).

        Returns:
            Tensor: Float tensor (B, max(olens), max(ilens)) on CPU. Entry
            ``[b, :olen_b, :ilen_b]`` holds the mask for sample ``b``; the
            rest is zero.
        """
        # Plain ints work for lists and for tensors on any device, and keep
        # the per-sample helper free of device assumptions.
        ilens = [int(ilen) for ilen in ilens]
        olens = [int(olen) for olen in olens]
        guided_attn_masks = torch.zeros((len(ilens), max(olens), max(ilens)))
        for idx, (ilen, olen) in enumerate(zip(ilens, olens)):
            guided_attn_masks[idx, :olen, :ilen] = self._make_guided_attention_mask(
                ilen, olen, self.sigma
            )
        return guided_attn_masks

    @staticmethod
    def _make_guided_attention_mask(ilen, olen, sigma):
        """Make guided attention mask.

        Computes ``W[t, n] = 1 - exp(-(n / ilen - t / olen) ** 2 / (2 * sigma ** 2))``
        for output step ``t`` and input step ``n``.

        Args:
            ilen (int or 0-d Tensor): Input length.
            olen (int or 0-d Tensor): Output length.
            sigma (float): Standard deviation of the diagonal band.

        Returns:
            Tensor: Float32 CPU tensor of shape (olen, ilen).

        Examples:
            >>> guided_attn_mask = GuidedAttentionLoss._make_guided_attention_mask(5, 5, 0.4)
            >>> guided_attn_mask.shape
            torch.Size([5, 5])
            >>> guided_attn_mask
            tensor([[0.0000, 0.1175, 0.3935, 0.6753, 0.8647],
                    [0.1175, 0.0000, 0.1175, 0.3935, 0.6753],
                    [0.3935, 0.1175, 0.0000, 0.1175, 0.3935],
                    [0.6753, 0.3935, 0.1175, 0.0000, 0.1175],
                    [0.8647, 0.6753, 0.3935, 0.1175, 0.0000]])
            >>> guided_attn_mask = GuidedAttentionLoss._make_guided_attention_mask(3, 6, 0.4)
            >>> guided_attn_mask.shape
            torch.Size([6, 3])
            >>> guided_attn_mask
            tensor([[0.0000, 0.2934, 0.7506],
                    [0.0831, 0.0831, 0.5422],
                    [0.2934, 0.0000, 0.2934],
                    [0.5422, 0.0831, 0.0831],
                    [0.7506, 0.2934, 0.0000],
                    [0.8858, 0.5422, 0.0831]])
        """
        ilen, olen = int(ilen), int(olen)
        # A column of output positions and a row of input positions broadcast
        # to the (olen, ilen) grid. This avoids torch.meshgrid, which warns
        # when called without `indexing=` on newer PyTorch.
        grid_out = torch.arange(olen, dtype=torch.float32).unsqueeze(1)
        grid_in = torch.arange(ilen, dtype=torch.float32).unsqueeze(0)
        return 1.0 - torch.exp(
            -((grid_in / ilen - grid_out / olen) ** 2) / (2 * (sigma ** 2))
        )

    @staticmethod
    def _make_masks(ilens, olens):
        """Make masks indicating non-padded part.

        Args:
            ilens (LongTensor or List): Batch of lengths (B,).
            olens (LongTensor or List): Batch of lengths (B,).

        Returns:
            Tensor: Mask tensor (B, max(olens), max(ilens)) indicating the
                non-padded part. Its dtype is whatever ``make_non_pad_mask``
                returns (noted upstream as uint8 before PyTorch 1.2 and bool
                from 1.2 on).

        Examples:
            >>> ilens, olens = [5, 2], [8, 5]
            >>> GuidedAttentionLoss._make_masks(ilens, olens)
            tensor([[[1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1]],
                    [[1, 1, 0, 0, 0],
                     [1, 1, 0, 0, 0],
                     [1, 1, 0, 0, 0],
                     [1, 1, 0, 0, 0],
                     [1, 1, 0, 0, 0],
                     [0, 0, 0, 0, 0],
                     [0, 0, 0, 0, 0],
                     [0, 0, 0, 0, 0]]], dtype=torch.uint8)
        """
        in_masks = make_non_pad_mask(ilens)  # (B, T_in)
        out_masks = make_non_pad_mask(olens)  # (B, T_out)
        # Outer AND: a position is valid only if both its output and input
        # indices are inside the sample's true lengths.
        return out_masks.unsqueeze(-1) & in_masks.unsqueeze(-2)  # (B, T_out, T_in)