Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement contour and quiver plots for densities #146

Merged
merged 46 commits into from
Apr 23, 2024
Merged
Changes from all commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
0dd36f7
Import Contour class from STM branch
martin-schlipf Apr 8, 2024
f78d76e
Add test for quiver plot
martin-schlipf Apr 8, 2024
af38bcc
Refactor test
martin-schlipf Apr 8, 2024
43a36c5
Fix transpose of axis
martin-schlipf Apr 8, 2024
92b2f6e
Add unit cell to quiver plot
martin-schlipf Apr 8, 2024
3dea206
Update documentation
martin-schlipf Apr 9, 2024
d17e16c
Begin implement utility to construct 2d plane from slice
martin-schlipf Apr 9, 2024
64b466d
Make slicing work for unusual order of vectors
martin-schlipf Apr 9, 2024
18f2003
Make transformation work for nearly orthorhombic cell
martin-schlipf Apr 9, 2024
4da0acd
wip: slicing
martin-schlipf Apr 10, 2024
ca51c5e
Generalize slicing algorithm so that everything uses the same code
martin-schlipf Apr 10, 2024
9e5516c
Refactor routine to rotate to largest component
martin-schlipf Apr 12, 2024
350b4bd
Raise error when axis is not obvious
martin-schlipf Apr 12, 2024
b0a5593
WIP: slicing
martin-schlipf Apr 15, 2024
43f3e70
Allow manual selection of normal direction
martin-schlipf Apr 15, 2024
fdf9a13
Add test for providing normal with orthorhombic cell
martin-schlipf Apr 15, 2024
db5e923
Allow disabling rotation altogether
martin-schlipf Apr 15, 2024
9b2d1ef
Add sanity checks and documentation
martin-schlipf Apr 15, 2024
77d5873
Update format
martin-schlipf Apr 15, 2024
ac53e52
WIP: density contour
martin-schlipf Apr 16, 2024
a021abd
Implement helper routine for slicing
martin-schlipf Apr 16, 2024
a2f5811
Expose lattice vectors and positions to users
martin-schlipf Apr 16, 2024
f08146e
Implement basic Contour plot
martin-schlipf Apr 16, 2024
24c5f8d
Implement selecting different cuts
martin-schlipf Apr 16, 2024
d674603
Implement selecting specific contour plot
martin-schlipf Apr 16, 2024
4531074
Implement noncollinear contour plots
martin-schlipf Apr 17, 2024
94914a7
Implement supercell for contour plots
martin-schlipf Apr 17, 2024
8a3a728
Allow passing normal vector
martin-schlipf Apr 17, 2024
d83b0b7
Use contour instead of heatmap for density
martin-schlipf Apr 17, 2024
5b14f25
Apply VASP color scheme
martin-schlipf Apr 18, 2024
c443520
Add test for subtraction in contour plot
martin-schlipf Apr 18, 2024
46d8805
Add annotations to identify lattice vectors
martin-schlipf Apr 18, 2024
2ff3f02
Show label at middle instead of at end
martin-schlipf Apr 18, 2024
3850fa2
Change default color for unit cell to dark
martin-schlipf Apr 18, 2024
ccc6203
Fix broken test
martin-schlipf Apr 18, 2024
7e1576e
Begin implementation of grid vectors for quiver plots
martin-schlipf Apr 18, 2024
70da109
Rename lattice to plane and add cut info
martin-schlipf Apr 19, 2024
8f8b136
Implement slicing of vectorial data for orthorhombic cell
martin-schlipf Apr 19, 2024
fbf32dd
Implement basic to_quiver for noncollinear calculations
martin-schlipf Apr 19, 2024
cd09a03
Add cell argument to Plane
martin-schlipf Apr 19, 2024
838a0c3
Implement projecting vectors onto generic plane
martin-schlipf Apr 23, 2024
3911473
Add test for different cuts of quiver plot
martin-schlipf Apr 23, 2024
f55dc47
Add normal vector for quiver plot
martin-schlipf Apr 23, 2024
9099f82
Add documentation to to_contour
martin-schlipf Apr 23, 2024
8fae0b1
Add documentation for to_quiver plot
martin-schlipf Apr 23, 2024
e6a218c
Refactor common parts of documentation
martin-schlipf Apr 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/py4vasp/_config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright © VASP Software GmbH,
# Licensed under the Apache License 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
VASP_COLORS = ("#4C265F", "#2FB5AB", "#2C68FC", "#A82C35", "#808080")
VASP_PURPLE, VASP_CYAN, VASP_BLUE, VASP_RED, VASP_GRAY = VASP_COLORS
VASP_COLORS = ("#4C265F", "#2FB5AB", "#2C68FC", "#A82C35", "#808080", "#212529")
VASP_PURPLE, VASP_CYAN, VASP_BLUE, VASP_RED, VASP_GRAY, VASP_DARK = VASP_COLORS
11 changes: 9 additions & 2 deletions src/py4vasp/_third_party/graph/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
# Copyright © VASP Software GmbH,
# Licensed under the Apache License 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
from py4vasp._config import VASP_COLORS
import copy

from py4vasp._config import VASP_BLUE, VASP_COLORS, VASP_GRAY, VASP_RED
from py4vasp._util import import_

from .contour import Contour
from .graph import Graph
from .mixin import Mixin
from .plot import plot
@@ -13,6 +16,10 @@

if import_.is_imported(go) and import_.is_imported(pio):
axis_format = {"showexponent": "all", "exponentformat": "power"}
contour = copy.copy(pio.templates["ggplot2"].data.contour[0])
contour.colorscale = [[0, VASP_BLUE], [0.5, VASP_GRAY], [1, VASP_RED]]
data = {"contour": (contour,)}
layout = {"colorway": VASP_COLORS, "xaxis": axis_format, "yaxis": axis_format}
pio.templates["vasp"] = go.layout.Template(layout=layout)
pio.templates["vasp"] = go.layout.Template(data=data, layout=layout)
pio.templates["ggplot2"].layout.shapedefaults = {}
pio.templates.default = "ggplot2+vasp"
178 changes: 178 additions & 0 deletions src/py4vasp/_third_party/graph/contour.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
# Copyright © VASP Software GmbH,
# Licensed under the Apache License 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
from __future__ import annotations

import dataclasses
import itertools

import numpy as np

from py4vasp import _config
from py4vasp._third_party.graph import trace
from py4vasp._util import import_

ff = import_.optional("plotly.figure_factory")
go = import_.optional("plotly.graph_objects")
interpolate = import_.optional("scipy.interpolate")


@dataclasses.dataclass
class Contour(trace.Trace):
"""Represents data on a 2d slice through the unit cell.

This class creates a visualization of the data within the unit cell based on its
configuration. Currently it supports the creation of heatmaps and quiver plots.
For heatmaps each data point corresponds to one point on the grid. For quiver plots
each data point should be a 2d vector within the plane.
"""

_interpolation_factor = 2
"""If the lattice does not align with the cartesian axes, the data is interpolated
to by approximately this factor along each line."""
_shift_label_pixels = 10
"Shift the labels by this many pixels to avoid overlap."

data: np.array
"""2d or 3d grid data in the plane spanned by the lattice vectors. If the data is
the dimensions should be the ones of the grid, if the data is 3d the first dimension
should be a 2 for a vector in the plane of the grid and the other two dimensions
should be the grid."""
lattice: Lattice
"""2 vectors spanning the plane in which the data is represented. Each vector should
have two components, so remove any element normal to the plane."""
label: str
"Assign a label to the visualization that may be used to identify one among multiple plots."
isolevels: bool = False
"Defines whether isolevels should be added or a heatmap is used."
supercell: np.array = (1, 1)
"Multiple of each lattice vector to be drawn."
show_cell: bool = True
"Show the unit cell in the resulting visualization."

def to_plotly(self):
lattice_supercell = np.diag(self.supercell) @ self.lattice.vectors
# swap a and b axes because that is the way plotly expects the data
data = np.tile(self.data, self.supercell).T
if self._is_contour():
yield self._make_contour(lattice_supercell, data), self._options()
elif self._is_heatmap():
yield self._make_heatmap(lattice_supercell, data), self._options()
else:
yield self._make_quiver(lattice_supercell, data), self._options()

def _is_contour(self):
return self.data.ndim == 2 and self.isolevels

def _is_heatmap(self):
return self.data.ndim == 2 and not self.isolevels

def _make_contour(self, lattice, data):
x, y, z = self._interpolate_data_if_necessary(lattice, data)
return go.Contour(x=x, y=y, z=z, name=self.label, autocontour=True)

def _make_heatmap(self, lattice, data):
x, y, z = self._interpolate_data_if_necessary(lattice, data)
return go.Heatmap(x=x, y=y, z=z, name=self.label, colorscale="turbid_r")

def _make_quiver(self, lattice, data):
u = data[:, :, 0].flatten()
v = data[:, :, 1].flatten()
meshes = [
np.linspace(np.zeros(2), vector, num_points, endpoint=False)
for vector, num_points in zip(reversed(lattice), data.shape)
# remember that b and a axis are swapped
]
x, y = np.array([sum(points) for points in itertools.product(*meshes)]).T
fig = ff.create_quiver(x, y, u, v, scale=1)
return fig.data[0]

def _interpolate_data_if_necessary(self, lattice, data):
if self._interpolation_required():
x, y, z = self._interpolate_data(lattice, data)
else:
x, y, z = self._use_data_without_interpolation(lattice, data)
return x, y, z

def _interpolation_required(self):
y_position_first_vector = self.lattice.vectors[0, 1]
x_position_second_vector = self.lattice.vectors[1, 0]
return not np.allclose((y_position_first_vector, x_position_second_vector), 0)

def _interpolate_data(self, lattice, data):
area_cell = abs(np.cross(lattice[0], lattice[1]))
points_per_area = data.size / area_cell
points_per_line = np.sqrt(points_per_area) * self._interpolation_factor
lengths = np.sum(np.abs(lattice), axis=0)
shape = np.ceil(points_per_line * lengths).astype(int)
line_mesh_a = self._make_mesh(lattice, data.shape[1], 0)
line_mesh_b = self._make_mesh(lattice, data.shape[0], 1)
x_in, y_in = (line_mesh_a[:, np.newaxis] + line_mesh_b[np.newaxis, :]).T
x_in = x_in.flatten()
y_in = y_in.flatten()
z_in = data.flatten()
x_out, y_out = np.meshgrid(
np.linspace(x_in.min(), x_in.max(), shape[0]),
np.linspace(y_in.min(), y_in.max(), shape[1]),
)
z_out = interpolate.griddata((x_in, y_in), z_in, (x_out, y_out), method="cubic")
return x_out[0], y_out[:, 0], z_out

def _use_data_without_interpolation(self, lattice, data):
x = self._make_mesh(lattice, data.shape[1], 0)
y = self._make_mesh(lattice, data.shape[0], 1)
return x, y, data

def _make_mesh(self, lattice, num_point, index):
vector = index if self._interpolation_required() else (index, index)
return (
np.linspace(0, lattice[vector], num_point, endpoint=False)
+ 0.5 * lattice[vector] / num_point
)

def _options(self):
return {
"shapes": self._create_unit_cell(),
"annotations": self._label_unit_cell_vectors(),
}

def _create_unit_cell(self):
if not self.show_cell:
return ()
pos_to_str = lambda pos: f"{pos[0]} {pos[1]}"
vectors = self.lattice.vectors
corners = (vectors[0], vectors[0] + vectors[1], vectors[1])
to_corners = (f"L {pos_to_str(corner)}" for corner in corners)
path = f"M 0 0 {' '.join(to_corners)} Z"
unit_cell = {"type": "path", "line": {"color": _config.VASP_DARK}, "path": path}
return (unit_cell,)

def _label_unit_cell_vectors(self):
if self.lattice.cut is None:
return []
vectors = self.lattice.vectors
labels = tuple("abc".replace(self.lattice.cut, ""))
return [
{
"text": label,
"showarrow": False,
"x": 0.5 * vectors[i, 0],
"y": 0.5 * vectors[i, 1],
**self._shift_label(vectors[i], vectors[1 - i]),
}
for i, label in enumerate(labels)
]

def _shift_label(self, current_vector, other_vector):
invert = np.cross(current_vector, other_vector) < 0
norm = np.linalg.norm(current_vector)
shifts = self._shift_label_pixels * current_vector[::-1] / norm
if invert:
return {
"xshift": -shifts[0],
"yshift": shifts[1],
}
else:
return {
"xshift": shifts[0],
"yshift": -shifts[1],
}
43 changes: 35 additions & 8 deletions src/py4vasp/_third_party/graph/graph.py
Original file line number Diff line number Diff line change
@@ -9,7 +9,9 @@

from py4vasp import exception
from py4vasp._config import VASP_COLORS
from py4vasp._third_party.graph.contour import Contour
from py4vasp._third_party.graph.series import Series
from py4vasp._third_party.graph.trace import Trace
from py4vasp._util import import_

go = import_.optional("plotly.graph_objects")
@@ -25,7 +27,7 @@ class Graph(Sequence):
parameters set in this class.
"""

series: Series or Sequence[Series]
series: Trace or Sequence[Trace]
"One or more series shown in the graph."
xlabel: str = None
"Label for the x axis."
@@ -71,11 +73,14 @@ def to_plotly(self):
"Convert the graph to a plotly figure."
figure = self._make_plotly_figure()
for trace, options in self._generate_plotly_traces():
if options["row"] is None:
if options.get("row") is None:
figure.add_trace(trace)
else:
figure.add_trace(trace, row=options["row"], col=1)

for shape in options.get("shapes", ()):
figure.add_shape(**shape)
for annotation in options.get("annotations", ()):
figure.add_annotation(**annotation)
return figure

def show(self):
@@ -107,9 +112,8 @@ def _ipython_display_(self):
def _generate_plotly_traces(self):
colors = itertools.cycle(VASP_COLORS)
for series in self:
if not series.color:
series = replace(series, color=next(colors))
yield from series._generate_traces()
series = _set_color_if_not_present(series, colors)
yield from series.to_plotly()

def _make_plotly_figure(self):
figure = self._figure_with_one_or_two_y_axes()
@@ -120,12 +124,13 @@ def _make_plotly_figure(self):
return figure

def _figure_with_one_or_two_y_axes(self):
has_secondary_y_axis = lambda series: isinstance(series, Series) and series.y2
if self._subplot_on:
max_row = max(series.subplot for series in self)
figure = subplots.make_subplots(rows=max_row, cols=1)
figure.update_layout(showlegend=False)
return figure
elif any(series.y2 for series in self):
elif any(has_secondary_y_axis(series) for series in self):
return subplots.make_subplots(specs=[[{"secondary_y": True}]])
else:
return go.Figure()
@@ -142,6 +147,8 @@ def _set_xaxis_options(self, figure):
figure.layout.xaxis.tickmode = "array"
figure.layout.xaxis.tickvals = tuple(self.xticks.keys())
figure.layout.xaxis.ticktext = self._xtick_labels()
if self._all_are_contour():
figure.layout.xaxis.visible = False

def _xtick_labels(self):
# empty labels will be overwritten by plotly so we put a single space in them
@@ -156,6 +163,17 @@ def _set_yaxis_options(self, figure):
figure.layout.yaxis.title.text = self.ylabel
if self.y2label:
figure.layout.yaxis2.title.text = self.y2label
if self._all_are_contour():
figure.layout.yaxis.visible = False
if self._any_are_contour():
figure.layout.yaxis.scaleanchor = "x"
figure.layout.height = 500

def _all_are_contour(self):
return all(isinstance(series, Contour) for series in self)

def _any_are_contour(self):
return any(isinstance(series, Contour) for series in self)

def to_frame(self):
"""Convert graph to a pandas dataframe.
@@ -213,7 +231,16 @@ def _name_column(self, series, suffix, idx=None):

@property
def _subplot_on(self):
return any(series.subplot for series in self)
has_subplot = lambda series: isinstance(series, Series) and series.subplot
return any(has_subplot(series) for series in self)


def _set_color_if_not_present(series, color_iterator):
if isinstance(series, Contour):
return series
if not series.color:
series = replace(series, color=next(color_iterator))
return series


Graph._fields = tuple(field.name for field in fields(Graph))
8 changes: 6 additions & 2 deletions src/py4vasp/_third_party/graph/series.py
Original file line number Diff line number Diff line change
@@ -5,13 +5,14 @@
import numpy as np

from py4vasp import exception
from py4vasp._third_party.graph import trace
from py4vasp._util import import_

go = import_.optional("plotly.graph_objects")


@dataclass
class Series:
class Series(trace.Trace):
"""Represents a single series in a graph.

Typically this corresponds to a single line of x-y data with an optional name used
@@ -55,7 +56,7 @@ def __setattr__(self, key, value):
assert not self._frozen or hasattr(self, key)
super().__setattr__(key, value)

def _generate_traces(self):
def to_plotly(self):
first_trace = True
for item in enumerate(np.atleast_2d(np.array(self.y))):
yield self._make_trace(*item, first_trace), {"row": self.subplot}
@@ -123,5 +124,8 @@ def _common_options(self, first_trace):
"yaxis": "y2" if self.y2 else "y",
}

def _generate_shapes(self):
return ()


Series._fields = tuple(field.name for field in fields(Series))
14 changes: 14 additions & 0 deletions src/py4vasp/_third_party/graph/trace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright © VASP Software GmbH,
# Licensed under the Apache License 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
from abc import ABC, abstractmethod


class Trace(ABC):
"""Defines a base class with all methods that need to be implemented for Graph to
work as intended"""

@abstractmethod
def to_plotly(self):
"""Use yield to generate one or more plotly traces. Each returned element should
be a tuple (trace, dict) where the trace can be used as data for plotly and the
options modify the generation of the figure."""
Loading