Skip to content

Commit

Permalink
Correct method get_projection_on_elements docstring under Procar
Browse files Browse the repository at this point in the history
…class (#3945)

* correct Procar docs

* more specific get_projection_on_elements return type

---------

Co-authored-by: Janosh Riebesell <[email protected]>
  • Loading branch information
DanielYang59 and janosh authored Jul 24, 2024
1 parent 44b8c6e commit 5d925fe
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ ci:

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.5.0
rev: v0.5.4
hooks:
- id: ruff
args: [ --fix, --unsafe-fixes ]
Expand All @@ -22,7 +22,7 @@ repos:
- id: trailing-whitespace

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.10.1
rev: v1.11.0
hooks:
- id: mypy

Expand Down Expand Up @@ -65,6 +65,6 @@ repos:
args: [ --drop-empty-cells, --keep-output ]

- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.369
rev: v1.1.373
hooks:
- id: pyright
12 changes: 6 additions & 6 deletions src/pymatgen/io/vasp/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3889,31 +3889,31 @@ def __init__(self, filename: PathLike) -> None:
self.data = data
self.phase_factors = phase_factors

def get_projection_on_elements(self, structure: Structure) -> dict[Spin, list]:
def get_projection_on_elements(self, structure: Structure) -> dict[Spin, list[list[dict[str, float]]]]:
"""Get a dict of projections on elements.
Args:
structure (Structure): Input structure.
Returns:
A dict as {Spin.up: [k index][b index][{Element: values}]].
A dict as {Spin: [band index][kpoint index][{Element: values}]].
"""
assert self.data is not None, "Data cannot be None."
assert self.nkpoints is not None
assert self.nbands is not None
assert self.nions is not None

dico: dict[Spin, list] = {}
elem_proj: dict[Spin, list] = {}
for spin in self.data:
dico[spin] = [[defaultdict(float) for _ in range(self.nkpoints)] for _ in range(self.nbands)]
elem_proj[spin] = [[defaultdict(float) for _ in range(self.nkpoints)] for _ in range(self.nbands)]

for iat in range(self.nions):
name = structure.species[iat].symbol
for spin, data in self.data.items():
for kpoint, band in itertools.product(range(self.nkpoints), range(self.nbands)):
dico[spin][band][kpoint][name] += np.sum(data[kpoint, band, iat, :])
elem_proj[spin][band][kpoint][name] += np.sum(data[kpoint, band, iat, :])

return dico
return elem_proj

def get_occupation(self, atom_index: int, orbital: str) -> dict:
"""Get the occupation for a particular orbital of a particular atom.
Expand Down

0 comments on commit 5d925fe

Please sign in to comment.