-
Notifications
You must be signed in to change notification settings - Fork 24
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update doper.py #351
Update doper.py #351
Conversation
WalkthroughThe pull request modifies the Changes
Possibly Related PRs
Poem
Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media? 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments. CodeRabbit Commands (Invoked using PR comments)
Other keywords and placeholders
CodeRabbit Configuration File (
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
🧹 Nitpick comments (8)
smact/dopant_prediction/doper.py (8)
90-90
: Avoid relying on assert in production code.Asserts can be disabled with certain Python optimisations, so consider using an explicit check and raising a custom exception or warning if this condition is breached.
- assert len(dopants) == 5 + if len(dopants) != 5: + raise ValueError("Unexpected dopants length. Expected 5.")
Line range hint
177-201
: Consolidate repetitive doping checks.The logic for cation doping is repeated in loops with similar thresholds and sub-prob checks. Try refactoring into a helper function to reduce duplication and improve clarity.
Line range hint
215-229
: Eliminate duplication in anion doping.As with the cation code, much of this block mirrors the same pattern. Refactoring into a shared helper function could significantly streamline future changes.
238-240
: Clarify dynamic length adjustment.Line 238 sets ‘self.len_list’ to 4, then line 240 changes it to 6. Consider using a single assignment or a more descriptive mechanism to indicate why these values differ when selectivity is enabled.
247-257
: Evaluate weighting in combined score ordering.You compute a combined score, then reorder dopants. If the user expects to emphasise different factors (similarity vs. selectivity), allow custom weighting to avoid hard-coded distribution.
305-312
: Simplify parameter-based plotting logic.This chain of if-elif statements might be more succinctly expressed with a dictionary mapping from ‘plot_value’ to the proper index. It will be easier to extend with new plot modes in the future.
325-327
: Generalise the weighting parameter.The weighting factor of 0.25 is currently fixed. You might allow users to configure it via a method parameter or class attribute for greater flexibility.
341-347
: Remove or justify commented-out code.Lines 341 to 345 are commented out without explanation. If unused, consider removing them to improve clarity and reduce confusion about which code paths are relevant.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
smact/dopant_prediction/doper.py
(12 hunks)
🔇 Additional comments (1)
smact/dopant_prediction/doper.py (1)
85-85
: Verify that each ‘cation’ is valid in the mutation table.
Although this update is straightforward, consider confirming that ‘cation’ exists in the mutator’s specification to prevent potential KeyErrors.
✅ Verification successful
Let me gather more information about how the cations are validated and where they come from.
The cations are properly validated before being used with sub_prob.
After examining the code, I can confirm that cations are validated through multiple safeguards:
- Cations are first validated during initialisation via
parse_spec()
which checks their charges - The
_get_dopants
method filters species based on valid charge states - Before calling
sub_prob
, there's an explicit charge validation check - The cation must exist in
self.cation_mutator.specs
(loaded from a JSON file) to be used - The probability threshold check (
> self.threshold
) ensures only valid substitutions are considered
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# This script searches for sub_prob calls to ensure that all references use valid cation species
rg -A3 'sub_prob\(cation,'
Length of output: 1624
Script:
#!/bin/bash
# Check the CationMutator class implementation and initialization
ast-grep --pattern 'class CationMutator {
$$$
}'
# Also check where cations are defined or validated
rg -l "cation.*mutator" --type py
Length of output: 121
Script:
#!/bin/bash
# Let's examine the CationMutator class implementation in doper.py
cat smact/dopant_prediction/doper.py
Length of output: 14051
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## develop #351 +/- ##
===========================================
+ Coverage 77.56% 78.04% +0.48%
===========================================
Files 31 31
Lines 2599 2624 +25
===========================================
+ Hits 2016 2048 +32
+ Misses 583 576 -7 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
🧹 Nitpick comments (3)
smact/dopant_prediction/doper.py (3)
62-73
: LGTM! Consider updating docstring for new embeddings.The embedding validation and initialization logic is well-structured. However, the docstring still mentions that only 'skipspecies' is supported.
Update the docstring to reflect all supported embeddings:
- embedding (str): Name of the species embedding to use. Currently only 'skipspecies' is supported. + embedding (str): Name of the species embedding to use. Supported options are: + - 'skipspecies' + - 'M3GNet-MP-2023.11.1-oxi-Eform' + - 'M3GNet-MP-2023.11.1-oxi-band_gap'
Line range hint
192-216
: Consider using list comprehension for cleaner code.The probability calculation blocks contain repeated patterns that could be simplified.
- for _i, n_specie in enumerate(poss_n_type_cat): - n_specie_charge = utilities.parse_spec(n_specie)[1] - if cation_charge >= n_specie_charge: - continue - if CM.sub_prob(cation, n_specie) > self.threshold: - n_type_cat.append([ - n_specie, - cation, - self.cation_mutator.sub_prob(cation, n_specie), - self.cation_mutator.get_lambda(cation, n_specie), - ]) + n_type_cat.extend([ + [n_specie, cation, CM.sub_prob(cation, n_specie), CM.get_lambda(cation, n_specie)] + for n_specie in poss_n_type_cat + if utilities.parse_spec(n_specie)[1] < cation_charge + and CM.sub_prob(cation, n_specie) > self.threshold + ])Also applies to: 230-244
262-271
: Verify the combined score weighting.The combined scoring system uses a fixed weight of 0.75 for similarity and 0.25 for selectivity. Consider making these weights configurable.
+ def __init__( + self, + original_species: tuple[str, ...], + filepath: str | None = None, + embedding: str | None = None, + use_probability: bool = True, + similarity_weight: float = 0.75, # Add weight parameters + ): + self.similarity_weight = similarity_weight + self.selectivity_weight = 1 - similarity_weight def _calculate_combined_score(self, similarity: float, selectivity: float) -> float: - return (1 - 0.25) * similarity + 0.25 * selectivity + return self.similarity_weight * similarity + self.selectivity_weight * selectivity
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
requirements.txt
(17 hunks)smact/dopant_prediction/doper.py
(14 hunks)
✅ Files skipped from review due to trivial changes (1)
- requirements.txt
🔇 Additional comments (2)
smact/dopant_prediction/doper.py (2)
Line range hint 300-327
: LGTM! Enhanced plotting functionality with flexible value selection.
The plot_dopants method now supports multiple visualization options through the plot_value parameter.
23-29
: LGTM! Verify file existence in package data directory.
The new constants for M3GNet embeddings are well-defined and follow the project's naming conventions.
✅ Verification successful
Let me check if these files are included in the PR or if they are expected to be downloaded/generated during installation.
Let me try a different approach to find any information about these data files.
Let me check the README and the data directory contents for any information about these files.
The required M3GNet embedding files are present in the repository
I can see that both files are actually present in the repository at the correct location:
smact/data/species_rep/ion_embedding_M3GNet-MP-2023.11.1-oxi-Eform_cosine_similarity.json
smact/data/species_rep/ion_embedding_SpeciesM3GNet-MP-2023.11.1-oxi-band_gap_SCRATCH.json
The paths defined in the constants are correct and the files exist in the expected location within the package's data directory.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Verify the existence of embedding files in the data directory
for file in "species_rep/ion_embedding_M3GNet-MP-2023.11.1-oxi-Eform_cosine_similarity.json" "species_rep/ion_embedding_SpeciesM3GNet-MP-2023.11.1-oxi-band_gap_SCRATCH.json"; do
if [ -f "${SMACT_DATA_DIR:-data_directory}/$file" ]; then
echo "✓ Found $file"
else
echo "✗ Missing $file"
fi
done
Length of output: 939
Script:
#!/bin/bash
# Check if files are part of the PR
git diff --name-only origin/main... | grep -E "species_rep/.*json"
# Check for any setup/download scripts
fd -e py -e sh setup download install | xargs grep -l "species_rep.*json"
# Look for documentation about data files
fd -e md -e rst | xargs grep -l "species_rep.*json"
Length of output: 543
Script:
#!/bin/bash
# Search for any mentions of these files in the codebase
rg -l "M3GNet-MP-2023"
# Check if there's a data directory in the repository
fd -t d "data"
# Look for any documentation about data files or setup
fd "README" -e md
Length of output: 110
Script:
#!/bin/bash
# Check README content for data file information
cat README.md
# List contents of the data directory
ls -R smact/data/
# Check if there's any documentation about data files
cat smact/data/README.md 2>/dev/null || echo "No README in data directory"
Length of output: 13399
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
🔭 Outside diff range comments (1)
smact/dopant_prediction/doper.py (1)
Line range hint
300-327
: Improve robustness of plotting functionalityThe current implementation uses magic strings and lacks validation. Consider:
- Using an enum for plot value options
- Adding input validation
- Improving error messages
+from enum import Enum + +class PlotValueType(Enum): + PROBABILITY = "probability" + SIMILARITY = "similarity" + SELECTIVITY = "selectivity" + COMBINED = "combined" - def plot_dopants(self, cmap: str = "YlOrRd", plot_value: str = "probability") -> None: + def plot_dopants( + self, + cmap: str = "YlOrRd", + plot_value: PlotValueType = PlotValueType.PROBABILITY + ) -> None: """Plot the dopant suggestions using the periodic table heatmap. Args: cmap: The colormap to use for the heatmap - plot_value: The value to plot on the heatmap + plot_value: The value to plot (probability, similarity, selectivity, or combined) """ assert self.results, "Dopants are not calculated. Run get_dopants first." + if not isinstance(plot_value, PlotValueType): + raise ValueError(f"Invalid plot_value. Expected one of {[e.value for e in PlotValueType]}") for dopants in self.results.values(): if self.len_list == 3: dict_results = {utilities.parse_spec(x)[0]: y for x, _, y in dopants.get("sorted")} - elif plot_value == "probability": + elif plot_value == PlotValueType.PROBABILITY: dict_results = {utilities.parse_spec(x)[0]: y for x, _, y, _, _, _ in dopants.get("sorted")} - elif plot_value == "similarity": + elif plot_value == PlotValueType.SIMILARITY: dict_results = {utilities.parse_spec(x)[0]: y for x, _, _, y, _, _ in dopants.get("sorted")} - elif plot_value == "selectivity": + elif plot_value == PlotValueType.SELECTIVITY: dict_results = {utilities.parse_spec(x)[0]: y for x, _, _, _, y, _ in dopants.get("sorted")} else: dict_results = {utilities.parse_spec(x)[0]: y for x, _, _, _, _, y in dopants.get("sorted")}
🧹 Nitpick comments (4)
smact/dopant_prediction/doper.py (4)
23-29
: Consider improving embedding configuration managementThe current implementation hard-codes embedding paths and options. Consider:
- Moving paths to a configuration file for easier maintenance
- Using an Enum for supported embeddings to prevent typos and improve maintainability
Here's a suggested implementation:
from enum import Enum class EmbeddingType(Enum): SKIPSPECIES = "skipspecies" M3GNET_EFORM = "M3GNet-MP-2023.11.1-oxi-Eform" M3GNET_GAP = "M3GNet-MP-2023.11.1-oxi-band_gap" EMBEDDING_PATHS = { EmbeddingType.SKIPSPECIES: SKIPSSPECIES_COSINE_SIM_PATH, EmbeddingType.M3GNET_EFORM: SPECIES_M3GNET_MP2023_EFORM_COSINE_PATH, EmbeddingType.M3GNET_GAP: SPECIES_M3GNET_MP2023_GAP_COSINE_PATH, }Then update the validation:
-if embedding and embedding not in [ - "skipspecies", - "M3GNet-MP-2023.11.1-oxi-Eform", - "M3GNet-MP-2023.11.1-oxi-band_gap", -]: +if embedding and embedding not in [e.value for e in EmbeddingType]: raise ValueError(f"Embedding {embedding} is not supported")Also applies to: 62-73
100-105
: Improve robustness of selectivity calculationThe hard-coded assertion might break if the data structure changes. Consider:
- Using a constant for the expected list length
- Adding more descriptive error messages
+ DOPANT_LIST_LENGTH = 5 # [selected_site, original_specie, sub_prob, selectivity] + def _get_selectivity(self, data_list, cations, sub): data = data_list.copy() for dopants in data: if sub == "anion": dopants.append(1.0) continue selected_site, original_specie, sub_prob = dopants[:3] sum_prob = sub_prob for cation in cations: if cation != original_specie: sum_prob += self.cation_mutator.sub_prob(cation, selected_site) selectivity = sub_prob / sum_prob selectivity = round(selectivity, 2) dopants.append(selectivity) - assert len(dopants) == 5 + if len(dopants) != self.DOPANT_LIST_LENGTH: + raise ValueError(f"Expected dopant list length {self.DOPANT_LIST_LENGTH}, got {len(dopants)}") return data
340-342
: Enhance flexibility of the scoring mechanismThe current implementation uses hard-coded weights. Consider:
- Making weights configurable via constructor parameters
- Adding documentation about the scoring methodology
- Validating input ranges
+ def __init__( + self, + original_species: tuple[str, ...], + filepath: str | None = None, + embedding: str | None = None, + use_probability: bool = True, + similarity_weight: float = 0.75, # Add weight parameters + ): + self.similarity_weight = similarity_weight + self.selectivity_weight = 1 - similarity_weight def _calculate_combined_score(self, similarity: float, selectivity: float) -> float: + """Calculate the combined score using weighted sum of similarity and selectivity. + + Args: + similarity: The similarity score (0 to 1) + selectivity: The selectivity score (0 to 1) + + Returns: + float: The combined score (0 to 1) + """ + if not 0 <= similarity <= 1 or not 0 <= selectivity <= 1: + raise ValueError("Similarity and selectivity must be between 0 and 1") - return (1 - 0.25) * similarity + 0.25 * selectivity + return self.similarity_weight * similarity + self.selectivity_weight * selectivity
356-362
: Clean up table formatting codeRemove the commented code and consider defining headers as a constant for better maintainability.
- # if self.use_probability: - # headers = ["Rank", "Dopant", "Host", "Probability", "Selectivity", "Combined"] - # else: - # headers = ["Rank", "Dopant", "Host", "Similarity", "Selectivity", "Combined"] - - headers = ["Rank", "Dopant", "Host", "Probability", "Similarity", "Selectivity", "Combined"] + HEADERS = [ + "Rank", + "Dopant", + "Host", + "Probability", + "Similarity", + "Selectivity", + "Combined" + ]
def test_alternative_representations(self): | ||
test_specie = ("Cu+", "Ga3+", "S2-") | ||
test_gap = doper.Doper(test_specie, embedding="M3GNet-MP-2023.11.1-oxi-band_gap") | ||
test_eform = doper.Doper(test_specie, embedding="M3GNet-MP-2023.11.1-oxi-Eform") | ||
test_lamba = doper.Doper(test_specie, filepath=TEST_LAMBDA_JSON) | ||
for test in [test_gap, test_eform, test_lamba]: | ||
self.assertIsInstance(test, doper.Doper) | ||
result = test.get_dopants() | ||
self.assertIsInstance(result, dict) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion
Enhance test coverage for alternative representations
The current test method only verifies basic instantiation and return types. Consider adding the following test cases:
- Validate the structure and content of the returned dictionary
- Assert the expected number of dopants in the results
- Test edge cases with empty or invalid species
- Verify that the combined scores are calculated correctly
Here's a suggested enhancement:
def test_alternative_representations(self):
test_specie = ("Cu+", "Ga3+", "S2-")
test_gap = doper.Doper(test_specie, embedding="M3GNet-MP-2023.11.1-oxi-band_gap")
test_eform = doper.Doper(test_specie, embedding="M3GNet-MP-2023.11.1-oxi-Eform")
test_lamba = doper.Doper(test_specie, filepath=TEST_LAMBDA_JSON)
+ expected_keys = {
+ "n-type cation substitutions",
+ "p-type cation substitutions",
+ "n-type anion substitutions",
+ "p-type anion substitutions"
+ }
for test in [test_gap, test_eform, test_lamba]:
self.assertIsInstance(test, doper.Doper)
result = test.get_dopants()
self.assertIsInstance(result, dict)
+ self.assertEqual(set(result.keys()), expected_keys)
+ for key in result:
+ self.assertIn("sorted", result[key])
+ self.assertIsInstance(result[key]["sorted"], list)
+ # Verify each dopant entry has the expected structure
+ for dopant in result[key]["sorted"]:
+ self.assertEqual(len(dopant), 6) # [species, host, prob, sim, select, combined]
+
+ # Test edge cases
+ with self.assertRaises(ValueError):
+ doper.Doper((), embedding="M3GNet-MP-2023.11.1-oxi-band_gap")
Committable suggestion skipped: line range outside the PR's diff.
Pull Request Template
Description
Type of change
How Has This Been Tested?
Test Configuration:
Reviewers
N/A
Checklist
Summary by CodeRabbit
New Features
Bug Fixes
Documentation