Skip to content

Commit

Permalink
Merge pull request #28 from iguinn/memleak
Browse files Browse the repository at this point in the history
Added constant variables to processing chain
  • Loading branch information
iguinn authored Dec 6, 2023
2 parents b9c7b29 + 688ee0b commit 5d80089
Show file tree
Hide file tree
Showing 10 changed files with 684 additions and 632 deletions.
4 changes: 3 additions & 1 deletion src/dspeed/build_dsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,9 @@ def build_dsp(
if lh5_tables is None:
lh5_tables = lh5.ls(f_raw)
elif isinstance(lh5_tables, str):
lh5_tables = [lh5_tables]
lh5_tables = lh5.ls(f_raw, lh5_tables)
elif isinstance(lh5_tables, list):
lh5_tables = [tab for tab_wc in lh5_tables for tab in lh5.ls(f_raw, tab_wc)]
elif not (
hasattr(lh5_tables, "__iter__")
and all(isinstance(el, str) for el in lh5_tables)
Expand Down
123 changes: 111 additions & 12 deletions src/dspeed/processing_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def __init__(
grid: CoordinateGrid = auto,
unit: str | Unit = auto,
is_coord: bool = auto,
is_const: bool = False,
) -> None:
"""
Parameters
Expand All @@ -154,6 +155,10 @@ def __init__(
is_coord
If ``True``, variable represents an array index and can be converted
into a unitted number using grid.
is_const
If ``True``, variable is a constant. Variable will be set before
executing, and will not be recomputed. Does not have outer
dimension of size _block_width
"""
assert isinstance(proc_chain, ProcessingChain) and isinstance(name, str)
self.proc_chain = proc_chain
Expand All @@ -168,6 +173,7 @@ def __init__(
self.grid = grid
self.unit = unit
self.is_coord = is_coord
self.is_const = is_const

log.debug(f"added variable: {self.description()}")

Expand Down Expand Up @@ -203,18 +209,27 @@ def __setattr__(self, name: str, value: Any) -> None:

super().__setattr__(name, value)

def _make_buffer(self) -> np.ndarray:
    """Allocate a zero-filled buffer whose data start is 64-byte aligned.

    The buffer has shape ``self.shape`` when the variable is a constant;
    otherwise an outer dimension of size ``proc_chain._block_width`` is
    prepended.

    Returns
    -------
    buffer
        view of shape described above into a freshly allocated, padded
        1-D array of ``self.dtype``
    """
    shape = (
        self.shape
        if self.is_const
        else (self.proc_chain._block_width,) + self.shape
    )
    # Force an integer product: for a scalar constant the shape is () and
    # the default product of an empty sequence is the float 1.0, which
    # cannot be used as an array length or slice bound.
    n_elem = int(np.prod(shape, dtype=int))
    itemsize = self.dtype.itemsize
    # Flattened array, with padding to allow memory alignment
    flat = np.zeros(n_elem + 64 // itemsize, dtype=self.dtype)
    # Number of leading elements to skip so the view starts on a multiple
    # of 64 bytes
    offset = (64 - flat.ctypes.data) % 64 // itemsize
    return flat[offset : offset + n_elem].reshape(shape)

def get_buffer(self, unit: str | Unit = None) -> np.ndarray:
# If buffer needs to be created, do so now
if self._buffer is None:
if self.shape is auto:
raise ProcessingChainError(f"cannot deduce shape of {self.name}")
if self.dtype is auto:
raise ProcessingChainError(f"cannot deduce shape of {self.name}")

# create the buffer so that the array start is aligned in memory on a multiple of 64 bytes
self._buffer = np.zeros(
shape=(self.proc_chain._block_width,) + self.shape, dtype=self.dtype
)
raise ProcessingChainError(f"cannot deduce dtype of {self.name}")
self._buffer = self._make_buffer()

if isinstance(self._buffer, np.ndarray):
if self.is_coord is True:
Expand Down Expand Up @@ -418,6 +433,47 @@ def add_variable(
self._vars_dict[name] = var
return var

def set_constant(
    self,
    varname: str,
    val: np.ndarray | int | float | Quantity,
    dtype: str | np.dtype = None,
    unit: str | Unit | Quantity = None,
) -> ProcChainVar:
    """Make a variable act as a constant and set it to ``val``.

    Parameters
    ----------
    varname
        name of internal variable to set. If it does not exist, create
        it; otherwise, set existing variable to be constant
    val
        value of constant. If a ``Quantity``, its units override ``unit``
        and its magnitude is stored
    dtype
        dtype of constant
    unit
        unit of constant

    Returns
    -------
    param
        the variable that now holds the constant

    Raises
    ------
    ProcessingChainError
        if the variable already has a buffer and is not a constant
    """

    param = self.get_variable(varname)
    # A buffer that already exists for a non-constant variable was not
    # allocated as a constant, so refuse to convert it after the fact.
    if not (param.is_const or param._buffer is None):
        raise ProcessingChainError(
            f"cannot set {varname} as a constant: it already has a "
            "non-constant buffer"
        )
    # Must be flagged before get_buffer() is called below, so the buffer is
    # allocated as a constant
    param.is_const = True

    if isinstance(val, Quantity):
        unit = val.units
        val = val.magnitude

    val = np.array(val, dtype=dtype)

    param.update_auto(
        shape=val.shape,
        dtype=val.dtype,
        unit=unit,
    )
    # "unsafe" casting: the buffer dtype may differ from val's dtype if the
    # variable's dtype was already fixed before this call
    np.copyto(param.get_buffer(), val, casting="unsafe")
    log.debug(f"set constant: {param.description()} = {val}")
    return param

def link_input_buffer(
self, varname: str, buff: np.ndarray | LGDO = None
) -> np.ndarray | LGDO:
Expand Down Expand Up @@ -698,7 +754,12 @@ def _parse_expr(
op, op_form = ast_ops_dict[type(node.op)]

if not (isinstance(lhs, ProcChainVar) or isinstance(rhs, ProcChainVar)):
return op(lhs, rhs)
ret = op(lhs, rhs)
if isinstance(ret, Quantity) and ureg.is_compatible_with(
ret.u, ureg.dimensionless
):
ret = ret.to(ureg.dimensionless).magnitude
return ret

name = "(" + op_form.format(str(lhs), str(rhs)) + ")"
if isinstance(lhs, ProcChainVar) and isinstance(rhs, ProcChainVar):
Expand Down Expand Up @@ -847,7 +908,10 @@ def get_index(slice_value):
return attr

# Otherwise this is probably a ProcChainVar
val = self._parse_expr(node.value, expr, dry_run, var_name_list)
# Note that we are excluding this variable from the vars list
# because it does not strictly need to be computed before as a
# prerequisite before accessing its attributes
val = self._parse_expr(node.value, expr, dry_run, [])
if val is None:
return None
return getattr(val, node.attr)
Expand Down Expand Up @@ -1319,7 +1383,7 @@ def __init__(
# Convert scalar to right type, including units
if isinstance(param, (Quantity, Unit)):
if ureg.is_compatible_with(ureg.dimensionless, param):
param = float(param)
param = param.to(ureg.dimensionless).magnitude
elif not isinstance(
grid, CoordinateGrid
) or not ureg.is_compatible_with(grid.period.u, param):
Expand All @@ -1328,7 +1392,7 @@ def __init__(
f"CoordinateGrid is {grid}"
)
else:
param = float(param / grid.period)
param = (param / grid.period).to(ureg.dimensionless).magnitude
if np.issubdtype(dtype, np.integer):
param = dtype.type(round(param))
else:
Expand Down Expand Up @@ -1977,10 +2041,10 @@ def resolve_dependencies(
module = importlib.import_module(recipe["module"])
func = getattr(module, recipe["function"])
args = recipe["args"]
new_vars = [k for k in re.split(",| ", proc_par) if k != ""]

# Initialize the new variables, if needed
if "unit" in recipe:
new_vars = [k for k in re.split(",| ", proc_par) if k != ""]
for i, name in enumerate(new_vars):
unit = recipe.get("unit", auto)
if isinstance(unit, list):
Expand Down Expand Up @@ -2038,7 +2102,42 @@ def resolve_dependencies(
except KeyError:
pass

proc_chain.add_processor(func, *args, **kwargs)
# Check if new variables should be treated as constants
if not recipe["prereqs"]:
arg_params = []
kwarg_params = {}
out_is_arg = False
for arg in args:
if isinstance(arg, str):
arg = proc_chain.get_variable(arg)
if isinstance(arg, dict):
kwarg_params.update(arg)
arg = list(arg.values())[0]
else:
arg_params.append(arg)
if isinstance(arg, ProcChainVar) and arg.name in new_vars:
out_is_arg = True
arg.is_const = True
# arg = arg.get_buffer()

if out_is_arg:
proc_man = ProcessorManager(
proc_chain,
func,
arg_params,
kwarg_params,
kwargs.get("signature", None),
kwargs.get("types", None),
)
proc_man.execute()
else:
const_val = func(*arg_params, **kwarg_params)
if len(new_vars) == 1:
const_val = const_val
for var, val in zip(new_vars, const_val):
proc_chain.set_constant(var, val)
else:
proc_chain.add_processor(func, *args, **kwargs)
except Exception as e:
raise ProcessingChainError(
"Exception raised while attempting to add processor:\n"
Expand Down
7 changes: 5 additions & 2 deletions src/dspeed/processors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,16 @@
"""

from .bl_subtract import bl_subtract
from .convolutions import cusp_filter, moving_slope, step, t0_filter, zac_filter
from .dplms import dplms
from .convolutions import convolve_wf, fft_convolve_wf
from .dwt import discrete_wavelet_transform
from .energy_kernels import cusp_filter, dplms, zac_filter
from .fftw import dft, inv_dft, psd
from .fixed_time_pickoff import fixed_time_pickoff
from .gaussian_filter1d import gaussian_filter1d
from .get_multi_local_extrema import get_multi_local_extrema
from .get_wf_centroid import get_wf_centroid
from .histogram import histogram, histogram_stats
from .kernels import moving_slope, step, t0_filter
from .linear_slope_fit import linear_slope_diff, linear_slope_fit
from .log_check import log_check
from .min_max import min_max
Expand Down Expand Up @@ -100,6 +101,8 @@

__all__ = [
"bl_subtract",
"convolve_wf",
"fft_convolve_wf",
"cusp_filter",
"t0_filter",
"zac_filter",
Expand Down
Loading

0 comments on commit 5d80089

Please sign in to comment.