Skip to content

Commit

Permalink
Merge pull request #12 from VisualComputingInstitute/fix-sqeuclid
Browse files Browse the repository at this point in the history
Fix important bug
  • Loading branch information
lucasb-eyer authored Nov 26, 2017
2 parents 23d314a + 31a3b08 commit 0e30b89
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ def cdist(a, b, metric='euclidean'):
"""
with tf.name_scope("cdist"):
diffs = all_diffs(a, b)
if metric == 'euclidean':
if metric == 'sqeuclidean':
return tf.reduce_sum(tf.square(diffs), axis=-1)
elif metric == 'sqeuclidean':
elif metric == 'euclidean':
return tf.sqrt(tf.reduce_sum(tf.square(diffs), axis=-1) + 1e-12)
elif metric == 'cityblock':
return tf.reduce_sum(tf.abs(diffs), axis=-1)
Expand Down Expand Up @@ -82,10 +82,10 @@ def batch_hard(dists, pids, margin, batch_precision_at_k=None):
"""
with tf.name_scope("batch_hard"):
same_identity_mask = tf.equal(tf.expand_dims(pids, axis=1),
tf.expand_dims(pids, axis=0))
tf.expand_dims(pids, axis=0))
negative_mask = tf.logical_not(same_identity_mask)
positive_mask = tf.logical_xor(same_identity_mask,
tf.eye(tf.shape(pids)[0], dtype=tf.bool))
tf.eye(tf.shape(pids)[0], dtype=tf.bool))

furthest_positive = tf.reduce_max(dists*tf.cast(positive_mask, tf.float32), axis=1)
closest_negative = tf.map_fn(lambda x: tf.reduce_min(tf.boolean_mask(x[0], x[1])),
Expand Down

0 comments on commit 0e30b89

Please sign in to comment.