Skip to content

Commit

Permalink
implementing ScanData2D class
Browse files Browse the repository at this point in the history
  • Loading branch information
Bing Li committed Oct 22, 2024
1 parent 5e0275e commit 5ea407b
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 49 deletions.
67 changes: 60 additions & 7 deletions src/tavi/data/plotter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# import matplotlib.colors as colors
from typing import Optional

from tavi.data.scan_data import ScanData1D
from tavi.data.scan_data import ScanData1D, ScanData2D


class Plot1D(object):
Expand All @@ -20,7 +20,7 @@ class Plot1D(object):

def __init__(self) -> None:
# self.ax = None
self.data_list: list[ScanData1D] = []
self.scan_data: list[ScanData1D] = []
self.title = ""
self.xlabel = None
self.ylabel = None
Expand All @@ -32,12 +32,12 @@ def __init__(self) -> None:
self.LOG_Y = False

def add_scan(self, scan_data: ScanData1D, **kwargs):
self.data_list.append(scan_data)
self.scan_data.append(scan_data)
for key, val in kwargs.items():
scan_data.fmt.update({key: val})

def plot(self, ax):
for data in self.data_list:
for data in self.scan_data:
if data.err is None:
if not data.label:
ax.plot(data.x, data.y, **data.fmt)
Expand All @@ -59,13 +59,13 @@ def plot(self, ax):

if self.xlabel is None:
xlabels = []
for data in self.data_list:
for data in self.scan_data:
xlabels.append(data.xlabel)
ax.set_xlabel(",".join(xlabels))

if self.ylabel is None:
ylabels = []
for data in self.data_list:
for data in self.scan_data:
ylabels.append(data.ylabel)
ax.set_ylabel(",".join(ylabels))

Expand All @@ -74,4 +74,57 @@ def plot(self, ax):


class Plot2D(object):
pass

def __init__(self) -> None:
# self.ax = None
self.contour_data: list[ScanData2D] = []
self.curve_data: list[ScanData1D] = []
self.title = ""
self.xlabel = None
self.ylabel = None

# plot specifications
self.xlim: Optional[tuple[float, float]] = None
self.ylim: Optional[tuple[float, float]] = None
self.LOG_X = False
self.LOG_Y = False

def add_contour(self, contour_data: ScanData2D, **kwargs):
self.contour_data.append(contour_data)
for key, val in kwargs.items():
contour_data.fmt.update({key: val})

# TODO
def add_curve(self, curve_data: ScanData1D, **kwargs):
self.curve_data.append(curve_data)
for key, val in kwargs.items():
curve_data.fmt.update({key: val})

def plot(self, ax):
for contour in self.contour_data:
ax.pcolormesh(contour.x, contour.y, contour.z, **contour.fmt)
for curve in self.curve_data:
ax.errorbar(x=curve.x, y=curve.y, yerr=curve.err)

if self.xlim is not None:
ax.set_xlim(left=self.xlim[0], right=self.xlim[1])
if self.ylim is not None:
ax.set_ylim(bottom=self.ylim[0], top=self.ylim[1])

if self.title is not None:
ax.set_title(self.title)

if self.xlabel is None:
xlabels = []
for contour in self.contour_data:
xlabels.append(contour.xlabel)
ax.set_xlabel(",".join(xlabels))

if self.ylabel is None:
ylabels = []
for contour in self.contour_data:
ylabels.append(contour.ylabel)
ax.set_ylabel(",".join(ylabels))

ax.grid(alpha=0.6)
ax.legend()
32 changes: 30 additions & 2 deletions src/tavi/data/scan_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,12 +198,40 @@ def __init__(self, x: np.ndarray, y: np.ndarray, z: np.ndarray) -> None:

self.err = np.sqrt(z)
self.title = ""
self.fmt: dict = {}

def make_labels(
self,
axes: tuple[str, str, str],
norm_to: tuple[float, str],
label: str = "",
title: str = "",
) -> None:
"""Create axes labels, plot title and curve label"""
x_str, y_str, z_str = axes
norm_val, norm_channel = norm_to
if norm_channel == "time":
norm_channel_str = "seconds"
else:
norm_channel_str = norm_channel
if norm_val == 1:
self.zlabel = z_str + "/ " + norm_channel_str
else:
self.zlabel = z_str + f" / {norm_val} " + norm_channel_str

self.xlabel = x_str
self.ylabel = y_str
self.label = label
self.title = title + self.zlabel

def __sub__(self, other):
pass

def renorm(self):
def renorm(self, norm_col: np.ndarray, norm_val: float = 1.0):
pass

def rebin_grid(self):
def rebin_grid(self, rebin_params: tuple):
pass

def rebin_grid_renorm(self, rebin_params: tuple, norm_col: np.ndarray, norm_val: float = 1.0):
pass
45 changes: 14 additions & 31 deletions src/tavi/data/scan_group.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Optional, Union

import matplotlib.pyplot as plt
import numpy as np

from tavi.data.scan import Scan
Expand Down Expand Up @@ -134,31 +133,35 @@ def _get_data_2d(
"""

x_axis, y_axis, z_axis = axes
x_array = np.array([])
y_array = np.array([])
z_array = np.array([])
x_array = []
y_array = []
z_array = []

title = "Combined scans: "

for scan in self.scans:
x_array = np.append(x_array, scan.data[x_axis])
y_array = np.append(y_array, scan.data[y_axis])
z_array = np.append(y_array, scan.data[z_axis])
x_array.append(scan.data[x_axis])
y_array.append(scan.data[y_axis])
z_array.append(scan.data[z_axis])
title += f"{scan.scan_info.scan_num} "

scan_data_2d = ScanData2D(x=x_array, y=y_array, z=z_array)
scan_data_2d = ScanData2D(
x=np.vstack(x_array),
y=np.vstack(y_array),
z=np.vstack(z_array),
)
rebin_params = rebin_params_dict.get("grid")

if not rebin_params: # no rebin,
if norm_to is not None: # renorm
norm_val, norm_channel = norm_to
norm_list = self._get_norm_list(norm_channel)
scan_data_1d.renorm(norm_col=norm_list, norm_val=norm_val)
scan_data_2d.renorm(norm_col=norm_list, norm_val=norm_val)
else: # no renorm, check if all presets are the same
norm_to = self._get_default_renorm_params()

scan_data_1d.make_labels(axes, norm_to, title=title)
return scan_data_1d
scan_data_2d.make_labels(axes, norm_to, title=title)
return scan_data_2d

# if not isinstance(rebin_params, tuple):
# raise ValueError(f"rebin parameters ={rebin_params} needs to be a tuple.")
Expand Down Expand Up @@ -207,23 +210,3 @@ def get_data(
raise ValueError(f"x axes={x_axis} or y axes={y_axis} are not identical.")
axes = (*x_axis, *y_axis)
return self._get_data_1d(axes, norm_to, **rebin_params_dict)

def plot(self, contour_plot, cmap="turbo", vmax=100, vmin=0, ylim=None, xlim=None):
"""Plot contour"""

x, y, z, _, _, xlabel, ylabel, zlabel, title = contour_plot

fig, ax = plt.subplots()
p = ax.pcolormesh(x, y, z, shading="auto", cmap=cmap, vmax=vmax, vmin=vmin)
fig.colorbar(p, ax=ax)
ax.set_title(title)
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
ax.grid(alpha=0.6)

if xlim is not None:
ax.set_xlim(left=xlim[0], right=xlim[1])
if ylim is not None:
ax.set_ylim(bottom=ylim[0], top=ylim[1])

fig.show()
17 changes: 8 additions & 9 deletions tests/test_scan_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,14 @@ def test_scan_group_1d_rebin():
scan_list = list(range(42, 49, 1)) + list(range(70, 76, 1))

sg = tavi.combine_scans(scan_list, name="dispH")
scan_data_1 = sg.get_data(tol=(0.5, 4, 0.2))
scan_data_2 = sg.get_data(grid=(0.5, 4, 0.2))

fig, ax = plt.subplots()
plot1d = Plot1D()

scan_data_1 = sg.get_data(tol=(0.5, 4, 0.2))
plot1d.add_scan(scan_data_1, c="C0", fmt="o")

scan_data_2 = sg.get_data(grid=(0.5, 4, 0.2))
plot1d.add_scan(scan_data_2, c="C1", fmt="o")

fig, ax = plt.subplots()
plot1d.plot(ax)
plt.show()

Expand All @@ -47,11 +45,12 @@ def test_scan_group_2d():
sg = tavi.combine_scans(scan_list, name="dispH")
scan_data_2d = sg.get_data(
axes=("qh", "en", "detector"),
norm_to=(1, "mcu"),
grid=(0.025, 0.1),
# norm_to=(1, "mcu"),
# grid=(0.025, 0.1),
)

fig, ax = plt.subplots()
plot2d = Plot2D()
plot2d.plot(ax, scan_data_2d, cmap="turbo", vmax=80)
plot2d.add_contour(scan_data_2d, cmap="turbo", vmax=80)
fig, ax = plt.subplots()
plot2d.plot(ax)
plt.show()

0 comments on commit 5ea407b

Please sign in to comment.