Skip to content

Commit

Permalink
Merge pull request theochem#146 from FanwangM/fix_failed_refactor
Browse files Browse the repository at this point in the history
Rename `total_diversity_volume` --> `hypersphere_overlap_of_subset`
  • Loading branch information
FanwangM authored Jul 9, 2023
2 parents 9c0523e + 9f746c2 commit 4d1a080
Showing 1 changed file with 25 additions and 14 deletions.
39 changes: 25 additions & 14 deletions DiverseSelector/diversity.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@

"""Molecule dataset diversity calculation module."""

import warnings
from typing import List

import numpy as np
from scipy.spatial.distance import euclidean
import warnings

__all__ = [
"compute_diversity",
Expand All @@ -42,7 +42,7 @@

def compute_diversity(
features: np.array,
div_type: str = "total_diversity_volume",
div_type: str = "hypersphere_overlap_of_subset",
) -> float:
"""Compute diversity metrics.
Expand All @@ -53,7 +53,8 @@ def compute_diversity(
div_type : str, optional
Method of calculation diversity for a given molecule set, which
includes "entropy", "logdet", "shannon_entropy", "wdud",
gini_coefficient" and "total_diversity_volume". Default is "total_diversity_volume".
gini_coefficient" and "hypersphere_overlap_of_subset".
Default is "hypersphere_overlap_of_subset".
mols : List[rdkit.Chem.rdchem.Mol], optional
List of RDKit molecule objects. This is only needed when using the
"explicit_diversity_index" method. Default=None.
Expand All @@ -68,7 +69,7 @@ def compute_diversity(
"logdet": logdet,
"shannon_entropy": shannon_entropy,
"wdud": wdud,
"total_diversity_volume": total_diversity_volume,
"hypersphere_overlap_of_subset": hypersphere_overlap_of_subset,
"gini_coefficient": gini_coefficient,
}

Expand Down Expand Up @@ -201,7 +202,10 @@ def shannon_entropy(x: np.ndarray) -> float:
p_i = np.count_nonzero(x[:, i]) / num_mols
# sum all non-zero terms
if p_i == 0:
raise ValueError(f"Feature {i} has value 0 for all molecules. Remove extraneous feature from data set.")
raise ValueError(
f"Feature {i} has value 0 for all molecules."
"Remove extraneous feature from data set."
)
h_x += (-1 * p_i) * np.log10(p_i)
return h_x

Expand Down Expand Up @@ -246,15 +250,17 @@ def wdud(x: np.ndarray) -> float:
min_x = np.min(x, axis=0)
# Normalization of each feature to [0, 1]
if np.any(np.abs(max_x - min_x) < 1e-30):
raise ValueError(f"One of the features is redundant and causes normalization to fail.")
raise ValueError(
f"One of the features is redundant and causes normalization to fail."
)
x_norm = (x - min_x) / (max_x - min_x)
ans = [] # store the Wasserstein distance for each feature
for i in range(0, num_features):
wdu = 0.0
y = np.sort(x_norm[:, i])
# Round to the sixth decimal place and count number of unique elements
# to construct an accurate cumulative discrete distribution func \sum_{x <= y_{i + 1}} 1/k
y, counts = np.unique(np.round(x_norm[:,i], decimals=6), return_counts=True)
y, counts = np.unique(np.round(x_norm[:, i], decimals=6), return_counts=True)
p = 0
# Ignore 0 and because v_min= 0
for j in range(1, len(counts)):
Expand All @@ -264,7 +270,7 @@ def wdud(x: np.ndarray) -> float:
# Make a grid from yi1 to yi
grid = np.linspace(yi1, yi, num=1000, endpoint=True)
# Evaluate the integrand |x - \sum_{x <= y_{i + 1}} 1/k|
p += counts[j-1]
p += counts[j - 1]
integrand = np.abs(grid - p / num_mols)
# Integrate using np.trapz
wdu += np.trapz(y=integrand, x=grid)
Expand Down Expand Up @@ -311,23 +317,28 @@ def hypersphere_overlap_of_subset(lib: np.ndarray, x: np.array) -> float:
min_x = np.min(lib, axis=0)
# Normalization of each feature to [0, 1]
if np.any(np.abs(max_x - min_x) < 1e-30):
raise ValueError(f"One of the features is redundant and causes normalization to fail.")
raise ValueError(
f"One of the features is redundant and causes normalization to fail."
)
x_norm = (x - min_x) / (max_x - min_x)
# r_o = hypersphere radius
r_o = d * np.sqrt(1 / k)
if r_o > 0.5:
warnings.warn(f"The number of molecules should be much larger"
" than the number of features.")
warnings.warn(
f"The number of molecules should be much larger"
" than the number of features."
)
g_s = 0
edge = 0
lam = (d - 1.0) / d # Lambda parameter controls edge penalty
# lambda parameter controls edge penalty
lam = (d - 1.0) / d
# calculate overlap volume
for i in range(0, (k - 1)):
for j in range((i + 1), k):
dist = np.linalg.norm(x_norm[i] - x_norm[j])
# Overlap penalty
if dist <= (2 * r_o):
with np.errstate(divide='ignore'):
with np.errstate(divide="ignore"):
# min(100) ignores the inf case with divide by zero
g_s += min(100, 2 * (r_o / dist) - 1)
# Edge penalty: lambda (1 - \sum^d_j e_{ij} / (dr_0)
Expand All @@ -342,7 +353,7 @@ def hypersphere_overlap_of_subset(lib: np.ndarray, x: np.array) -> float:
if dist > r_o:
dist = r_o
edge_pen += dist
edge_pen /= (d * r_o)
edge_pen /= d * r_o
# print("Should be positive value only", (1.0 - edge_pen))
edge_pen = lam * (1.0 - edge_pen)
edge += edge_pen
Expand Down

0 comments on commit 4d1a080

Please sign in to comment.