Skip to content

Commit

Permalink
feat: implement contour and quiver plots for densities (#146)
Browse files Browse the repository at this point in the history
* import Contour class from STM branch
* extend Contour class to allow for contour plots, heatmaps, and quiver plots
* add annotations to identify lattice vectors

* implement utility class to take a slice of a 3d cell
* add functionality to slice grid scalars and grid vectors

* implement contour and quiver plots for densities
* implement selecting different densities (only contour)
* implement selecting different cuts
* implement supercell
* allow passing normal vector
* add documentation to to_contour and to_quiver
  • Loading branch information
martin-schlipf authored Apr 23, 2024
1 parent 351133b commit 77fe2ae
Show file tree
Hide file tree
Showing 15 changed files with 1,310 additions and 34 deletions.
4 changes: 2 additions & 2 deletions src/py4vasp/_config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright © VASP Software GmbH,
# Licensed under the Apache License 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
VASP_COLORS = ("#4C265F", "#2FB5AB", "#2C68FC", "#A82C35", "#808080")
VASP_PURPLE, VASP_CYAN, VASP_BLUE, VASP_RED, VASP_GRAY = VASP_COLORS
VASP_COLORS = ("#4C265F", "#2FB5AB", "#2C68FC", "#A82C35", "#808080", "#212529")
VASP_PURPLE, VASP_CYAN, VASP_BLUE, VASP_RED, VASP_GRAY, VASP_DARK = VASP_COLORS
11 changes: 9 additions & 2 deletions src/py4vasp/_third_party/graph/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
# Copyright © VASP Software GmbH,
# Licensed under the Apache License 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
from py4vasp._config import VASP_COLORS
import copy

from py4vasp._config import VASP_BLUE, VASP_COLORS, VASP_GRAY, VASP_RED
from py4vasp._util import import_

from .contour import Contour
from .graph import Graph
from .mixin import Mixin
from .plot import plot
Expand All @@ -13,6 +16,10 @@

if import_.is_imported(go) and import_.is_imported(pio):
axis_format = {"showexponent": "all", "exponentformat": "power"}
contour = copy.copy(pio.templates["ggplot2"].data.contour[0])
contour.colorscale = [[0, VASP_BLUE], [0.5, VASP_GRAY], [1, VASP_RED]]
data = {"contour": (contour,)}
layout = {"colorway": VASP_COLORS, "xaxis": axis_format, "yaxis": axis_format}
pio.templates["vasp"] = go.layout.Template(layout=layout)
pio.templates["vasp"] = go.layout.Template(data=data, layout=layout)
pio.templates["ggplot2"].layout.shapedefaults = {}
pio.templates.default = "ggplot2+vasp"
178 changes: 178 additions & 0 deletions src/py4vasp/_third_party/graph/contour.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
# Copyright © VASP Software GmbH,
# Licensed under the Apache License 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
from __future__ import annotations

import dataclasses
import itertools

import numpy as np

from py4vasp import _config
from py4vasp._third_party.graph import trace
from py4vasp._util import import_

ff = import_.optional("plotly.figure_factory")
go = import_.optional("plotly.graph_objects")
interpolate = import_.optional("scipy.interpolate")


@dataclasses.dataclass
class Contour(trace.Trace):
"""Represents data on a 2d slice through the unit cell.
This class creates a visualization of the data within the unit cell based on its
configuration. Currently it supports the creation of heatmaps and quiver plots.
For heatmaps each data point corresponds to one point on the grid. For quiver plots
each data point should be a 2d vector within the plane.
"""

_interpolation_factor = 2
"""If the lattice does not align with the cartesian axes, the data is interpolated
to by approximately this factor along each line."""
_shift_label_pixels = 10
"Shift the labels by this many pixels to avoid overlap."

data: np.array
"""2d or 3d grid data in the plane spanned by the lattice vectors. If the data is
the dimensions should be the ones of the grid, if the data is 3d the first dimension
should be a 2 for a vector in the plane of the grid and the other two dimensions
should be the grid."""
lattice: Lattice
"""2 vectors spanning the plane in which the data is represented. Each vector should
have two components, so remove any element normal to the plane."""
label: str
"Assign a label to the visualization that may be used to identify one among multiple plots."
isolevels: bool = False
"Defines whether isolevels should be added or a heatmap is used."
supercell: np.array = (1, 1)
"Multiple of each lattice vector to be drawn."
show_cell: bool = True
"Show the unit cell in the resulting visualization."

def to_plotly(self):
lattice_supercell = np.diag(self.supercell) @ self.lattice.vectors
# swap a and b axes because that is the way plotly expects the data
data = np.tile(self.data, self.supercell).T
if self._is_contour():
yield self._make_contour(lattice_supercell, data), self._options()
elif self._is_heatmap():
yield self._make_heatmap(lattice_supercell, data), self._options()
else:
yield self._make_quiver(lattice_supercell, data), self._options()

def _is_contour(self):
return self.data.ndim == 2 and self.isolevels

def _is_heatmap(self):
return self.data.ndim == 2 and not self.isolevels

def _make_contour(self, lattice, data):
x, y, z = self._interpolate_data_if_necessary(lattice, data)
return go.Contour(x=x, y=y, z=z, name=self.label, autocontour=True)

def _make_heatmap(self, lattice, data):
x, y, z = self._interpolate_data_if_necessary(lattice, data)
return go.Heatmap(x=x, y=y, z=z, name=self.label, colorscale="turbid_r")

def _make_quiver(self, lattice, data):
u = data[:, :, 0].flatten()
v = data[:, :, 1].flatten()
meshes = [
np.linspace(np.zeros(2), vector, num_points, endpoint=False)
for vector, num_points in zip(reversed(lattice), data.shape)
# remember that b and a axis are swapped
]
x, y = np.array([sum(points) for points in itertools.product(*meshes)]).T
fig = ff.create_quiver(x, y, u, v, scale=1)
return fig.data[0]

def _interpolate_data_if_necessary(self, lattice, data):
if self._interpolation_required():
x, y, z = self._interpolate_data(lattice, data)
else:
x, y, z = self._use_data_without_interpolation(lattice, data)
return x, y, z

def _interpolation_required(self):
y_position_first_vector = self.lattice.vectors[0, 1]
x_position_second_vector = self.lattice.vectors[1, 0]
return not np.allclose((y_position_first_vector, x_position_second_vector), 0)

def _interpolate_data(self, lattice, data):
area_cell = abs(np.cross(lattice[0], lattice[1]))
points_per_area = data.size / area_cell
points_per_line = np.sqrt(points_per_area) * self._interpolation_factor
lengths = np.sum(np.abs(lattice), axis=0)
shape = np.ceil(points_per_line * lengths).astype(int)
line_mesh_a = self._make_mesh(lattice, data.shape[1], 0)
line_mesh_b = self._make_mesh(lattice, data.shape[0], 1)
x_in, y_in = (line_mesh_a[:, np.newaxis] + line_mesh_b[np.newaxis, :]).T
x_in = x_in.flatten()
y_in = y_in.flatten()
z_in = data.flatten()
x_out, y_out = np.meshgrid(
np.linspace(x_in.min(), x_in.max(), shape[0]),
np.linspace(y_in.min(), y_in.max(), shape[1]),
)
z_out = interpolate.griddata((x_in, y_in), z_in, (x_out, y_out), method="cubic")
return x_out[0], y_out[:, 0], z_out

def _use_data_without_interpolation(self, lattice, data):
x = self._make_mesh(lattice, data.shape[1], 0)
y = self._make_mesh(lattice, data.shape[0], 1)
return x, y, data

def _make_mesh(self, lattice, num_point, index):
vector = index if self._interpolation_required() else (index, index)
return (
np.linspace(0, lattice[vector], num_point, endpoint=False)
+ 0.5 * lattice[vector] / num_point
)

def _options(self):
return {
"shapes": self._create_unit_cell(),
"annotations": self._label_unit_cell_vectors(),
}

def _create_unit_cell(self):
if not self.show_cell:
return ()
pos_to_str = lambda pos: f"{pos[0]} {pos[1]}"
vectors = self.lattice.vectors
corners = (vectors[0], vectors[0] + vectors[1], vectors[1])
to_corners = (f"L {pos_to_str(corner)}" for corner in corners)
path = f"M 0 0 {' '.join(to_corners)} Z"
unit_cell = {"type": "path", "line": {"color": _config.VASP_DARK}, "path": path}
return (unit_cell,)

def _label_unit_cell_vectors(self):
if self.lattice.cut is None:
return []
vectors = self.lattice.vectors
labels = tuple("abc".replace(self.lattice.cut, ""))
return [
{
"text": label,
"showarrow": False,
"x": 0.5 * vectors[i, 0],
"y": 0.5 * vectors[i, 1],
**self._shift_label(vectors[i], vectors[1 - i]),
}
for i, label in enumerate(labels)
]

def _shift_label(self, current_vector, other_vector):
invert = np.cross(current_vector, other_vector) < 0
norm = np.linalg.norm(current_vector)
shifts = self._shift_label_pixels * current_vector[::-1] / norm
if invert:
return {
"xshift": -shifts[0],
"yshift": shifts[1],
}
else:
return {
"xshift": shifts[0],
"yshift": -shifts[1],
}
43 changes: 35 additions & 8 deletions src/py4vasp/_third_party/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@

from py4vasp import exception
from py4vasp._config import VASP_COLORS
from py4vasp._third_party.graph.contour import Contour
from py4vasp._third_party.graph.series import Series
from py4vasp._third_party.graph.trace import Trace
from py4vasp._util import import_

go = import_.optional("plotly.graph_objects")
Expand All @@ -25,7 +27,7 @@ class Graph(Sequence):
parameters set in this class.
"""

series: Series or Sequence[Series]
series: Trace or Sequence[Trace]
"One or more series shown in the graph."
xlabel: str = None
"Label for the x axis."
Expand Down Expand Up @@ -71,11 +73,14 @@ def to_plotly(self):
"Convert the graph to a plotly figure."
figure = self._make_plotly_figure()
for trace, options in self._generate_plotly_traces():
if options["row"] is None:
if options.get("row") is None:
figure.add_trace(trace)
else:
figure.add_trace(trace, row=options["row"], col=1)

for shape in options.get("shapes", ()):
figure.add_shape(**shape)
for annotation in options.get("annotations", ()):
figure.add_annotation(**annotation)
return figure

def show(self):
Expand Down Expand Up @@ -107,9 +112,8 @@ def _ipython_display_(self):
def _generate_plotly_traces(self):
colors = itertools.cycle(VASP_COLORS)
for series in self:
if not series.color:
series = replace(series, color=next(colors))
yield from series._generate_traces()
series = _set_color_if_not_present(series, colors)
yield from series.to_plotly()

def _make_plotly_figure(self):
figure = self._figure_with_one_or_two_y_axes()
Expand All @@ -120,12 +124,13 @@ def _make_plotly_figure(self):
return figure

def _figure_with_one_or_two_y_axes(self):
has_secondary_y_axis = lambda series: isinstance(series, Series) and series.y2
if self._subplot_on:
max_row = max(series.subplot for series in self)
figure = subplots.make_subplots(rows=max_row, cols=1)
figure.update_layout(showlegend=False)
return figure
elif any(series.y2 for series in self):
elif any(has_secondary_y_axis(series) for series in self):
return subplots.make_subplots(specs=[[{"secondary_y": True}]])
else:
return go.Figure()
Expand All @@ -142,6 +147,8 @@ def _set_xaxis_options(self, figure):
figure.layout.xaxis.tickmode = "array"
figure.layout.xaxis.tickvals = tuple(self.xticks.keys())
figure.layout.xaxis.ticktext = self._xtick_labels()
if self._all_are_contour():
figure.layout.xaxis.visible = False

def _xtick_labels(self):
# empty labels will be overwritten by plotly so we put a single space in them
Expand All @@ -156,6 +163,17 @@ def _set_yaxis_options(self, figure):
figure.layout.yaxis.title.text = self.ylabel
if self.y2label:
figure.layout.yaxis2.title.text = self.y2label
if self._all_are_contour():
figure.layout.yaxis.visible = False
if self._any_are_contour():
figure.layout.yaxis.scaleanchor = "x"
figure.layout.height = 500

def _all_are_contour(self):
return all(isinstance(series, Contour) for series in self)

def _any_are_contour(self):
return any(isinstance(series, Contour) for series in self)

def to_frame(self):
"""Convert graph to a pandas dataframe.
Expand Down Expand Up @@ -213,7 +231,16 @@ def _name_column(self, series, suffix, idx=None):

@property
def _subplot_on(self):
return any(series.subplot for series in self)
has_subplot = lambda series: isinstance(series, Series) and series.subplot
return any(has_subplot(series) for series in self)


def _set_color_if_not_present(series, color_iterator):
if isinstance(series, Contour):
return series
if not series.color:
series = replace(series, color=next(color_iterator))
return series


Graph._fields = tuple(field.name for field in fields(Graph))
Expand Down
8 changes: 6 additions & 2 deletions src/py4vasp/_third_party/graph/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
import numpy as np

from py4vasp import exception
from py4vasp._third_party.graph import trace
from py4vasp._util import import_

go = import_.optional("plotly.graph_objects")


@dataclass
class Series:
class Series(trace.Trace):
"""Represents a single series in a graph.
Typically this corresponds to a single line of x-y data with an optional name used
Expand Down Expand Up @@ -55,7 +56,7 @@ def __setattr__(self, key, value):
assert not self._frozen or hasattr(self, key)
super().__setattr__(key, value)

def _generate_traces(self):
def to_plotly(self):
first_trace = True
for item in enumerate(np.atleast_2d(np.array(self.y))):
yield self._make_trace(*item, first_trace), {"row": self.subplot}
Expand Down Expand Up @@ -123,5 +124,8 @@ def _common_options(self, first_trace):
"yaxis": "y2" if self.y2 else "y",
}

def _generate_shapes(self):
return ()


Series._fields = tuple(field.name for field in fields(Series))
14 changes: 14 additions & 0 deletions src/py4vasp/_third_party/graph/trace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright © VASP Software GmbH,
# Licensed under the Apache License 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
from abc import ABC, abstractmethod


class Trace(ABC):
"""Defines a base class with all methods that need to be implemented for Graph to
work as intended"""

@abstractmethod
def to_plotly(self):
"""Use yield to generate one or more plotly traces. Each returned element should
be a tuple (trace, dict) where the trace can be used as data for plotly and the
options modify the generation of the figure."""
Loading

0 comments on commit 77fe2ae

Please sign in to comment.