Skip to content

Commit

Permalink
Slight cleanup, added docstrings, and a couple more tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
MichaelWolloch committed Mar 5, 2024
1 parent 2ce734b commit 8081576
Show file tree
Hide file tree
Showing 2 changed files with 166 additions and 56 deletions.
176 changes: 125 additions & 51 deletions src/py4vasp/calculation/_partial_charge.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,24 +22,28 @@ class PartialCharge(_base.Refinery, _structure.Mixin):
Partial charges are produced by a post-processing VASP run after self-consistent
convergence is achieved. They are stored in an array of shape
(ngxf, ngyf, ngzf, ispin, nbands, nkpts). The first three dimensions are the
FFT grid dimensions and stored, the fourth dimension is the spin index, the fifth dimension
FFT grid dimensions, the fourth dimension is the spin index, the fifth dimension
is the band index, and the sixth dimension is the k-point index. Both band and
k-point arrays are also saved in the VASP output. If ispin=2, the second spin
index is the magnetization density (up-down), not the down-spin density.
k-point arrays are also saved and accessible in the .bands() and kpoints() methods.
If ispin=2, the second spin index is the magnetization density (up-down),
not the down-spin density.
Since this is postprocessing data for a fixed density, there are no ionic steps
to separate the data.
"""

@_base.data_access
def __str__(self):
"""Return a string representation of the partial charge density."""
return f"""
{"spin polarized" if self._spin_polarized() else ""} partial charge density:
on fine FFT grid: {self._raw_data.grid[:]}
{"summed over all contributing bands" if 0 in self._raw_data.bands[:] else f" separated for bands: {self._raw_data.bands[:]}"}
{"summed over all contributing k-points" if 0 in self._raw_data.kpoints[:] else f" separated for k-points: {self._raw_data.kpoints[:]}"}
{"spin polarized" if self._spin_polarized() else ""} partial charge density of {self._topology()}:
on fine FFT grid: {self.grid()}
{"summed over all contributing bands" if 0 in self.bands() else f" separated for bands: {self.bands()}"}
{"summed over all contributing k-points" if 0 in self.kpoints() else f" separated for k-points: {self.kpoints()}"}
""".strip()

@_base.data_access
def grid(self):
return self._raw_data.grid[:]

def to_dict(self, squeeze=True):
"""Store the partial charges in a dictionary.
Expand All @@ -61,29 +65,50 @@ def to_dict(self, squeeze=True):
def to_stm(
self,
mode="constant_height",
tip_height=4.0,
tip_height=2.0,
current=1e-9,
spin="both",
**kwargs,
):
"""Generate a STM image from the partial charge density.
"""Generate STM image data from the partial charge density.
Parameters
----------
mode : str
The mode in which the STM is operated. The default is "constant_height".
Other options are "constant_current".
The other option is "constant_current".
tip_height : float
The height of the STM tip above the surface in Angstrom. The default is 4.0 Angstrom.
The height of the STM tip above the surface in Angstrom.
The default is 2.0 Angstrom. Only used in "constant_height" mode.
current : float
The tunneling current in A. The default is 1e-5.
The tunneling current in A. The default is 1e-9.
Only used in "constant_current" mode.
spin : str
The spin channel to be used. The default is "both".
The other options are "up" and "down".
kwargs
Additional keyword arguments are passed to the STM calculation.
Specifically, the following parameters can be set:
sigma_z : float
The standard deviation of the Gaussian filter in the z-direction.
The default is 4.0.
sigma_xy : float
The standard deviation of the Gaussian filter in the xy-plane.
The default is 4.0.
truncate : float
The truncation of the Gaussian filter. The default is 3.0.
enhancement_factor : float
The enhancement factor for the output of the constant heigth
STM image. The default is 1000.
interpolation_factor : int
The interpolation factor for the z-direction in case of
constant current mode. The default is 10.
Returns
-------
STM
The STM object contains the STM image and the structural information.
The STM object contains the data to plot an image as well as the lattice vectors in the xy-plane and a label.
"""

default_params = {
Expand All @@ -95,8 +120,9 @@ def to_stm(
}

self._check_z_orth()
if mode in stm_modes["constant_height"]:
self._check_tip_height(tip_height)
if mode.lower() in stm_modes["constant_height"]:
self.tip_height = tip_height
self._check_tip_height()

default_params.update(kwargs)
self.sigma_z = default_params["sigma_z"]
Expand All @@ -107,31 +133,34 @@ def to_stm(

self.smoothed_charge = self._get_stm_data(spin)

if mode in stm_modes["constant_height"]:
if mode.lower() in stm_modes["constant_height"]:
self.STM = self._constant_height_stm(tip_height, spin)
return self.STM
elif mode in stm_modes["constant_current"]:
elif mode.lower() in stm_modes["constant_current"]:
self.STM = self._constant_current_stm(current, spin)
return self.STM
else:
raise ValueError(
f"STM mode '{mode}' not understood. Use 'constant_height' or 'constant_current'."
)

def plot_STM(self):
def plot_STM(self, **kwargs):
"""Plot the STM image.
If the STM is not calculated yet, a ValueError is raised.
"""

# check if STM is calculated already
if getattr(self, "STM", None) is None:
raise ValueError("STM is not calculated yet. Please calculate STM first.")
plot_scan(self.STM)
plot_scan(self.STM, **kwargs)

@_base.data_access
def _constant_current_stm(self, current, spin):
z_start = min_of_z_charge(
self._get_stm_data(spin), sigma=self.sigma_z, truncate=self.truncate
)
grid = self._raw_data.grid[:]
grid = self.grid()
cc_scan = np.zeros((grid[0], grid[1]))

# scan over the x and y grid
for i in range(grid[0]):
for j in range(grid[1]):
Expand All @@ -145,32 +174,34 @@ def _constant_current_stm(self, current, spin):
# normalize the scan
cc_scan = cc_scan - np.min(cc_scan.flatten())
spin_label = "both spin channels" if spin == "both" else f"spin {spin}"
label = f"STM for {spin_label} at constant current={current:.1e} A"
topology = self._topology()
label = (
f"STM of {topology} for {spin_label} at constant current={current:.1e} A"
)
return STM_data(data=cc_scan, lattice=self.lattice_vectors()[:2], label=label)

@_base.data_access
def _constant_height_stm(self, tip_height, spin):
grid = self._raw_data.grid[:]
z_index = self._z_index_for_height(tip_height + self._get_highest_z_coord())
grid = self.grid()
z_index = self._z_index_for_height(
self.tip_height + self._get_highest_z_coord()
)
ch_scan = np.zeros((grid[0], grid[1]))
for i in range(grid[0]):
for j in range(grid[1]):
ch_scan[i][j] = (
self.smoothed_charge[i][j][z_index] * self.enhancement_factor
)
spin_label = "both spin channels" if spin == "both" else f"spin {spin}"
label = (
f"STM for {spin_label} at constant height={float(tip_height):.2f} Angstrom"
)
topology = self._topology()
label = f"STM of {topology} for {spin_label} at constant height={float(self.tip_height):.2f} Angstrom"
return STM_data(
data=ch_scan,
lattice=self.lattice_vectors()[:2],
label=label,
)

@_base.data_access
def _z_index_for_height(self, tip_height):
return int(tip_height / self.lattice_vectors()[2][2] * self._raw_data.grid[2])
return int(tip_height / self.lattice_vectors()[2][2] * self.grid()[2])

@_base.data_access
def _get_highest_z_coord(self):
Expand All @@ -180,16 +211,20 @@ def _get_highest_z_coord(self):
def _get_lowest_z_coord(self):
return np.min(self._structure.cartesian_positions()[:, 2])

@_base.data_access
def _topology(self):
return str(self._structure._topology())

def _estimate_vacuum(self):
slab_thickness = self._get_highest_z_coord() - self._get_lowest_z_coord()
z_vector = self.lattice_vectors()[2, 2]
return z_vector - slab_thickness

def _check_tip_height(self, tip_height):
if tip_height > self._estimate_vacuum() / 2:
message = f"""The tip position at {tip_height} is above half of the
def _check_tip_height(self):
if self.tip_height > self._estimate_vacuum() / 2:
message = f"""The tip position at {self.tip_height:.2f} is above half of the
estimated vacuum thickness {self._estimate_vacuum():.2f} Angstrom.
You are probably sampling the bottom of your slab, which is not supported."""
You would be sampling the bottom of your slab, which is not supported."""
raise ValueError(message)

def _check_z_orth(self):
Expand All @@ -200,9 +235,8 @@ def _check_z_orth(self):
The STM calculation is not supported."""
raise ValueError(message)

@_base.data_access
def _get_stm_data(self, spin):
if 0 not in self._raw_data.bands[:] or 0 not in self._raw_data.kpoints[:]:
if 0 not in self.bands() or 0 not in self.kpoints():
massage = """Simulated STM images are only supported for non-separated bands and k-points.
Please set LSEPK and LSEPB to .FALSE. in the INCAR file."""
raise ValueError(massage)
Expand All @@ -211,7 +245,7 @@ def _get_stm_data(self, spin):

@_base.data_access
def _correct_units(self, charge_data):
grid_volume = np.prod(self._raw_data.grid[:])
grid_volume = np.prod(self.grid())
cell_volume = self._structure.volume()
return charge_data / (grid_volume * cell_volume)

Expand All @@ -226,23 +260,26 @@ def _smooth_stm_data(self, data):

@_base.data_access
def lattice_vectors(self):
"""Return the lattice vectors of the input structure."""
return self._structure._lattice_vectors()

def _spin_polarized(self):
return self._raw_data.partial_charge.shape[2] == 2

def _read_grid(self):
return {"grid": self._raw_data.grid[:]}
return {"grid": self.grid()}

def _read_bands(self):
return {"bands": self._raw_data.bands[:]}
return {"bands": self.bands()}

def _read_kpoints(self):
return {"kpoints": self._raw_data.kpoints[:]}
return {"kpoints": self.kpoints()}

@_base.data_access
def _read_structure(self):
return {"structure": self._structure.read()}

@_base.data_access
def _read_partial_charge(self, squeeze=True):
if squeeze:
return {"partial_charge": np.squeeze(self._raw_data.partial_charge[:].T)}
Expand All @@ -251,6 +288,24 @@ def _read_partial_charge(self, squeeze=True):

@_base.data_access
def to_array(self, band=0, kpoint=0, spin="both"):
"""Return the partial charge density as a 3D array.
Parameters
----------
band : int
The band index. The default is 0, which means that all bands are summed.
kpoint : int
The k-point index. The default is 0, which means that all k-points are summed.
spin : str
The spin channel to be used. The default is "both".
The other options are "up" and "down".
Returns
-------
np.array
The partial charge density as a 3D array.
"""

parchg = self._raw_data.partial_charge[:].T

band = self._check_band_index(band)
Expand All @@ -277,10 +332,20 @@ def to_array(self, band=0, kpoint=0, spin="both"):
return parchg

@_base.data_access
def bands(self):
"""Return the band array listing the contributing bands.
[2,4,5] means that the 2nd, 4th, and 5th bands are contributing while
[0] means that all bands are contributing.
"""

return self._raw_data.bands[:]

def _check_band_index(self, band):
if band in self._raw_data.bands[:]:
return np.where(self._raw_data.bands[:] == band)[0][0]
elif 0 in self._raw_data.bands[:]:
bands = self.bands()
if band in bands:
return np.where(bands == band)[0][0]
elif 0 in bands:
message = f"""The band index {band} is not available.
The summed partial charge density is returned instead."""
warnings.warn(message, UserWarning)
Expand All @@ -289,10 +354,19 @@ def _check_band_index(self, band):
raise ValueError(f"Band {band} not found in the bands array.")

@_base.data_access
def kpoints(self):
"""Return the k-points array listing the contributing k-points.
[2,4,5] means that the 2nd, 4th, and 5th k-points are contributing with
all weights = 1. [0] means that all k-points are contributing.
"""
return self._raw_data.kpoints[:]

def _check_kpoint_index(self, kpoint):
if kpoint in self._raw_data.kpoints[:]:
return np.where(self._raw_data.kpoints[:] == kpoint)[0][0]
elif 0 in self._raw_data.kpoints[:]:
kpoints = self.kpoints()
if kpoint in kpoints:
return np.where(kpoints == kpoint)[0][0]
elif 0 in kpoints:
message = f"""The k-point index {kpoint} is not available.
The summed partial charge density is returned instead."""
warnings.warn(message, UserWarning)
Expand Down Expand Up @@ -331,6 +405,7 @@ def min_of_z_charge(
def plot_scan(
stm_data,
mult_xy=[2, 2],
levels=40,
cmap="copper",
name="STM",
):
Expand All @@ -347,11 +422,11 @@ def plot_scan(
# make the xy-grid in cartesian coordinates
XX, YY = make_cart_grid(grid, stm_data.lattice, mult_xy)
# multiply the image in the x and y directions
scan = multiply_image(stm_data.data, mult_xy)
scan = multiply_image(stm_data.data, [mult_xy[1], mult_xy[0]])
# plot the STM image
import matplotlib.pyplot as plt

plt.contourf(XX, YY, scan, 40, cmap=cmap)
plt.contourf(XX, YY, scan.T, levels, cmap=cmap)
plt.colorbar()
# use the 2D lattice vectors to plot the xy-unit cell
lattice = stm_data.lattice
Expand Down Expand Up @@ -396,7 +471,6 @@ def make_cart_grid(grid, lattice, mult):
# reshape the coordinates to the shape of the meshgrid
XX = np.reshape(coordinates[:, 0], XX.shape)
YY = np.reshape(coordinates[:, 1], YY.shape)

return XX, YY
elif len(mult) == 3:
grid = (grid[0] * mult[0], grid[1] * mult[1], grid[2] * mult[2])
Expand Down
Loading

0 comments on commit 8081576

Please sign in to comment.