Skip to content

Commit

Permalink
Merge pull request #570 from KevinMusgrave/dev
Browse files Browse the repository at this point in the history
v1.7.0
  • Loading branch information
Kevin Musgrave authored Jan 16, 2023
2 parents ad3f3c9 + 774df78 commit ef755ab
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 11 deletions.
18 changes: 12 additions & 6 deletions .flake8
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
[flake8]

extend-ignore =
E266 # too many leading '#' for block comment
E203 # whitespace before ':'
E402 # module level import not at top of file
E501 # line too long
E741 # ambiguous variable names
E265 # block comment should start with #
# too many leading '#' for block comment
E266
# whitespace before ':'
E203
# module level import not at top of file
E402
# line too long
E501
# ambiguous variable names
E741
# block comment should start with #
E265

per-file-ignores =
__init__.py:F401
2 changes: 1 addition & 1 deletion src/pytorch_metric_learning/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.6.3"
__version__ = "1.7.0"
14 changes: 13 additions & 1 deletion src/pytorch_metric_learning/losses/arcface_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,19 @@ def cast_types(self, dtype, device):

def modify_cosine_of_target_classes(self, cosine_of_target_classes):
angles = self.get_angles(cosine_of_target_classes)
return torch.cos(angles + self.margin)

# Compute cos of (theta + margin) and cos of theta
cos_theta_plus_margin = torch.cos(angles + self.margin)
cos_theta = torch.cos(angles)

# Keep the cost function monotonically decreasing
unscaled_logits = torch.where(
angles <= np.deg2rad(180) - self.margin,
cos_theta_plus_margin,
cos_theta - self.margin * np.sin(self.margin),
)

return unscaled_logits

def scale_logits(self, logits, *_):
return logits * self.scale
13 changes: 10 additions & 3 deletions tests/losses/test_arcface_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,16 @@ def test_arcface_loss(self):

for i, c in enumerate(labels):
acos = torch.acos(torch.clamp(logits[i, c], -1, 1))
logits[i, c] = torch.cos(
acos + torch.tensor(np.radians(margin), dtype=dtype).to(TEST_DEVICE)
)
if acos <= (np.pi - np.radians(margin)):
logits[i, c] = torch.cos(
acos
+ torch.tensor(np.radians(margin), dtype=dtype).to(TEST_DEVICE)
)
else:
mg = np.radians(margin)
logits[i, c] -= torch.tensor(mg * np.sin(mg), dtype=dtype).to(
TEST_DEVICE
)

correct_loss = F.cross_entropy(logits * scale, labels.to(TEST_DEVICE))

Expand Down

0 comments on commit ef755ab

Please sign in to comment.