Skip to content

Commit

Permalink
completed get_data_1d for ScanGroup class
Browse files Browse the repository at this point in the history
  • Loading branch information
Bing Li committed Oct 21, 2024
1 parent e0b3193 commit ca020d5
Show file tree
Hide file tree
Showing 7 changed files with 257 additions and 196 deletions.
87 changes: 37 additions & 50 deletions src/tavi/data/plotter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# import matplotlib.colors as colors
from typing import Optional

import numpy as np
from tavi.data.scan_data import ScanData1D


class Plot1D(object):
Expand All @@ -18,60 +18,36 @@ class Plot1D(object):
"""

def __init__(
self,
x: np.ndarray,
y: np.ndarray,
xerr: Optional[np.ndarray] = None,
yerr: Optional[np.ndarray] = None,
) -> None:
def __init__(self) -> None:
# self.ax = None
self.x = x
self.y = y
self.xerr = xerr
self.yerr = yerr

self.title: str = ""
self.xlabel: Optional[str] = None
self.ylabel: Optional[str] = None
self.label: Optional[str] = None
self.data_list: list[ScanData1D] = []
self.title = ""
self.xlabel = None
self.ylabel = None

# plot specifications
self.xlim: Optional[tuple[float, float]] = None
self.ylim: Optional[tuple[float, float]] = None
self.color = "C0"
self.fmt = "o"
self.LOG_X = False
self.LOG_Y = False

def make_labels(
self,
x_str: str,
y_str: str,
norm_to: tuple[float, str],
label: Optional[str] = None,
title: Optional[str] = None,
):
"""Create axes labels, plot title and curve label"""

norm_val, norm_channel = norm_to
if norm_channel == "time":
norm_channel_str = "seconds"
else:
norm_channel_str = norm_channel
if norm_val == 1:
self.ylabel = y_str + "/ " + norm_channel_str
else:
self.ylabel = y_str + f" / {norm_val} " + norm_channel_str

self.xlabel = x_str
self.label = label
self.title = title

def plot_curve(self, ax):
if self.yerr is None:
ax.plot(self.x, self.y, label=self.label)
else:
ax.errorbar(x=self.x, y=self.y, yerr=self.yerr, fmt=self.fmt, label=self.label)
def add_scan(self, scan_data: ScanData1D, **kwargs):
self.data_list.append(scan_data)
for key, val in kwargs.items():
scan_data.fmt.update({key: val})

def plot(self, ax):
for data in self.data_list:
if data.err is None:
if not data.label:
ax.plot(data.x, data.y, **data.fmt)
else:
ax.plot(data.x, data.y, label=data.label, **data.fmt)
else:
if not data.label:
ax.errorbar(x=data.x, y=data.y, yerr=data.err, **data.fmt)
else:
ax.errorbar(x=data.x, y=data.y, yerr=data.err, label=data.label, **data.fmt)

if self.xlim is not None:
ax.set_xlim(left=self.xlim[0], right=self.xlim[1])
Expand All @@ -80,8 +56,19 @@ def plot_curve(self, ax):

if self.title is not None:
ax.set_title(self.title)
ax.set_xlabel(self.xlabel)
ax.set_ylabel(self.ylabel)

if self.xlabel is None:
xlabels = []
for data in self.data_list:
xlabels.append(data.xlabel)
ax.set_xlabel(",".join(xlabels))

if self.ylabel is None:
ylabels = []
for data in self.data_list:
ylabels.append(data.ylabel)
ax.set_ylabel(",".join(ylabels))

ax.grid(alpha=0.6)
ax.legend()

Expand Down
48 changes: 24 additions & 24 deletions src/tavi/data/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def __init__(self, name: str, nexus_dict: NexusEntry) -> None:

self.name: str = name
self._nexus_dict: NexusEntry = nexus_dict
self.data: dict = self.get_data()
self.data: dict = self.get_data_columns()

@classmethod
def from_spice(
Expand Down Expand Up @@ -161,7 +161,7 @@ def instrument_info(self):
)
return instru_info

def get_data(self) -> dict:
def get_data_columns(self) -> dict[str, np.ndarray]:
"""Get scan data as a dictionary"""
data_dict = {}
names = self._nexus_dict.get_dataset_names()
Expand All @@ -187,13 +187,12 @@ def validate_rebin_params(rebin_params: float | int | tuple) -> tuple:
raise ValueError(f"Unrecogonized rebin parameters {rebin_params}")
return rebin_params

def get_plot_data(
def get_data(
self,
axes: tuple[Optional[str], Optional[str]] = (None, None),
norm_to: Optional[tuple[float, Literal["time", "monitor", "mcu"]]] = None,
rebin_type: Literal["tol", "grid", None] = None,
rebin_params: Union[float, tuple] = 0.0,
) -> Plot1D:
**rebin_params_dict: Union[float, tuple],
) -> ScanData1D:
"""Generate a curve from a single scan to plot, with the options
to normalize the y-axis and rebin x-axis.
Expand All @@ -214,17 +213,20 @@ def get_plot_data(
label = "scan " + str(self.scan_info.scan_num)
title = f"{label}: {self.scan_info.scan_title}"

if rebin_type is None: # no rebin
for rebin_type in ["grid", "tol"]:
rebin_params = rebin_params_dict.get(rebin_type)
if rebin_params is not None:
break

if not rebin_params: # no rebin
if norm_to is not None: # normalize y-axis without rebining along x-axis
norm_val, norm_channel = norm_to
scan_data_1d.renorm(norm_col=self.data[norm_channel] / norm_val)
else:
else: # equivalent to normalizing to preset
norm_to = (self.scan_info.preset_value, self.scan_info.preset_channel)

plot1d = Plot1D(x=scan_data_1d.x, y=scan_data_1d.y, yerr=scan_data_1d.err)
plot1d.make_labels(x_str, y_str, norm_to, label, title)

return plot1d
scan_data_1d.make_labels((x_str, y_str), norm_to, label, title)
return scan_data_1d

# Rebin, first validate rebin params
rebin_params_tuple = Scan.validate_rebin_params(rebin_params)
Expand Down Expand Up @@ -254,25 +256,23 @@ def get_plot_data(
norm_val=norm_val,
)
case _:
raise ValueError('Unrecogonized rebin type. Needs to be "tol" or "grid".')
raise ValueError('Unrecogonized rebin_params_dict keyword. Needs to be "tol" or "grid".')

plot1d = Plot1D(x=scan_data_1d.x, y=scan_data_1d.y, yerr=scan_data_1d.err)
plot1d.make_labels(x_str, y_str, norm_to, label, title)
return plot1d
scan_data_1d.make_labels((x_str, y_str), norm_to, label, title)
return scan_data_1d

def plot(
self,
x_str: Optional[str] = None,
y_str: Optional[str] = None,
norm_channel: Literal["time", "monitor", "mcu", None] = None,
norm_val: float = 1.0,
rebin_type: Literal["tol", "grid", None] = None,
rebin_step: float = 0.0,
axes: tuple[Optional[str], Optional[str]] = (None, None),
norm_to: Optional[tuple[float, Literal["time", "monitor", "mcu"]]] = None,
**rebin_params_dict: Union[float, tuple],
):
"""Plot a 1D curve gnerated from a singal scan in a new window"""

plot1d = self.get_plot_data(x_str, y_str, norm_channel, norm_val, rebin_type, rebin_step)
scan_data_1d = self.get_data(axes, norm_to, **rebin_params_dict)

fig, ax = plt.subplots()
plot1d.plot_curve(ax)
plot1d = Plot1D()
plot1d.add_scan(scan_data_1d, c="C0", fmt="o")
plot1d.plot(ax)
fig.show()
54 changes: 37 additions & 17 deletions src/tavi/data/scan_data.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# -*- coding: utf-8 -*-


import numpy as np


class ScanData1D(object):
"""1D scan data ready to be plot, with ooptions to renormalize or rebin"""

ZERO = 1e-6

Expand All @@ -17,6 +17,26 @@ def __init__(self, x: np.ndarray, y: np.ndarray) -> None:
self.y = y
self.err = np.sqrt(y)
# self._ind = ind
self.label = ""
self.title = ""
self.fmt: dict = {}

def make_labels(self, axes: tuple[str, str], norm_to: tuple[float, str], label: str = "", title: str = "") -> None:
"""Create axes labels, plot title and curve label"""
x_str, y_str = axes
norm_val, norm_channel = norm_to
if norm_channel == "time":
norm_channel_str = "seconds"
else:
norm_channel_str = norm_channel
if norm_val == 1:
self.ylabel = y_str + "/ " + norm_channel_str
else:
self.ylabel = y_str + f" / {norm_val} " + norm_channel_str

self.xlabel = x_str
self.label = label
self.title = title

def __add__(self, other):
# check x length, rebin other if do not match
Expand Down Expand Up @@ -85,7 +105,7 @@ def rebin_tol(self, rebin_params: tuple, weight_col: np.ndarray):
rebin_min = np.min(self.x) if rebin_min is None else rebin_min
rebin_max = np.max(self.x) if rebin_max is None else rebin_max

x_grid = np.arange(rebin_min + rebin_step / 2, rebin_max + rebin_step / 2, rebin_step)
x_grid = np.arange(rebin_min - rebin_step / 2, rebin_max + rebin_step * 3 / 2, rebin_step)
x = np.zeros_like(x_grid)
y = np.zeros_like(x_grid)
counts = np.zeros_like(x_grid)
Expand All @@ -98,17 +118,17 @@ def rebin_tol(self, rebin_params: tuple, weight_col: np.ndarray):
weights[idx] += weight_col[i]
counts[idx] += 1

self.err = np.sqrt(y) / counts
self.y = y / counts
self.x = x / weights
self.err = np.sqrt(y[1:-2]) / counts[1:-2]
self.y = y[1:-2] / counts[1:-2]
self.x = x[1:-2] / weights[1:-2]

def rebin_tol_renorm(self, rebin_params: tuple, norm_col: np.ndarray, norm_val: float = 1.0):
"""Rebin with tolerance and renormalize"""
rebin_min, rebin_max, rebin_step = rebin_params
rebin_min = np.min(self.x) if rebin_min is None else rebin_min
rebin_max = np.max(self.x) if rebin_max is None else rebin_max

x_grid = np.arange(rebin_min + rebin_step / 2, rebin_max + rebin_step / 2, rebin_step)
x_grid = np.arange(rebin_min - rebin_step / 2, rebin_max + rebin_step * 3 / 2, rebin_step)
x = np.zeros_like(x_grid)
y = np.zeros_like(x_grid)
counts = np.zeros_like(x_grid)
Expand All @@ -121,17 +141,17 @@ def rebin_tol_renorm(self, rebin_params: tuple, norm_col: np.ndarray, norm_val:
x[idx] += self.x[i] * norm_col[i]
counts[idx] += norm_col[i]

self.err = np.sqrt(y) / counts * norm_val
self.y = y / counts * norm_val
self.x = x / counts
self.err = np.sqrt(y[1:-2]) / counts[1:-2] * norm_val
self.y = y[1:-2] / counts[1:-2] * norm_val
self.x = x[1:-2] / counts

def rebin_grid(self, rebin_params: tuple):
"""Rebin with a regular grid"""
rebin_min, rebin_max, rebin_step = rebin_params
rebin_min = np.min(self.x) if rebin_min is None else rebin_min
rebin_max = np.max(self.x) if rebin_max is None else rebin_max

x = np.arange(rebin_min + rebin_step / 2, rebin_max + rebin_step / 2, rebin_step)
x = np.arange(rebin_min - rebin_step / 2, rebin_max + rebin_step * 3 / 2, rebin_step)
y = np.zeros_like(x)
counts = np.zeros_like(x)

Expand All @@ -140,9 +160,9 @@ def rebin_grid(self, rebin_params: tuple):
y[idx] += self.y[i]
counts[idx] += 1

self.x = x
self.err = np.sqrt(y) / counts
self.y = y / counts
self.x = x[1:-2]
self.err = np.sqrt(y[1:-2]) / counts[1:-2]
self.y = y[1:-2] / counts[1:-2]

def rebin_grid_renorm(self, rebin_params: tuple, norm_col: np.ndarray, norm_val: float = 1.0):
"""Rebin with a regular grid and renormalize"""
Expand All @@ -151,7 +171,7 @@ def rebin_grid_renorm(self, rebin_params: tuple, norm_col: np.ndarray, norm_val:
rebin_min = np.min(self.x) if rebin_min is None else rebin_min
rebin_max = np.max(self.x) if rebin_max is None else rebin_max

x = np.arange(rebin_min + rebin_step / 2, rebin_max + rebin_step / 2, rebin_step)
x = np.arange(rebin_min - rebin_step / 2, rebin_max + rebin_step * 3 / 2, rebin_step)
y = np.zeros_like(x)
counts = np.zeros_like(x)

Expand All @@ -162,9 +182,9 @@ def rebin_grid_renorm(self, rebin_params: tuple, norm_col: np.ndarray, norm_val:
y[idx] += self.y[i]
counts[idx] += norm_col[i]

self.x = x
self.err = np.sqrt(y) / counts * norm_val
self.y = y / counts * norm_val
self.x = x[1:-2]
self.err = np.sqrt(y[1:-2]) / counts[1:-2] * norm_val
self.y = y[1:-2] / counts[1:-2] * norm_val


class ScanData2D(object):
Expand Down
Loading

0 comments on commit ca020d5

Please sign in to comment.