Skip to content

Commit

Permalink
many changes for zebrafish data checkpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
RichieHakim committed Jun 29, 2024
1 parent b2b5b28 commit 542028b
Show file tree
Hide file tree
Showing 5 changed files with 515 additions and 8 deletions.
51 changes: 49 additions & 2 deletions face_rhythm/decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,8 +412,10 @@ def fit(
self._model = tl.decomposition.__dict__[method](**self.params_method)

self._cleanup() ## Clean up any previous runs
cp_all = {key: self._model.fit_transform(torch.as_tensor(d, device=self._DEVICE)) for key,d in self.data.items()}
cp_all = {key: self._model.fit_transform(torch.as_tensor(d, device=self._DEVICE)) for key, d in self.data.items()}
self.factors_raw = cp_all
self.factors = {key_factor: {key: cp.factors[ii].cpu().numpy() for ii, key in enumerate(self.names_dims_array_preDecomp)} for key_factor,cp in cp_all.items()}
self.factor_weights = {key: cp.weights.cpu().numpy() for key, cp in cp_all.items()}

## Clean up
self._cleanup()
Expand All @@ -425,6 +427,51 @@ def fit(
self.config['device'] = DEVICE


def order_factors_by_EVR(self, data: dict=None, factors: dict=None, weights: dict=None, overwrite_factors: bool=True):
"""
Order the factors by how much variance each factor explains in the data.
Args:
data (dict of np.ndarray):
Dictionary of data arrays.
Each array should have the same shape.
The arrays should be numpy arrays.
factors (dict of np.ndarray):
Dictionary of factors of the TCA model.
weights (dict of np.ndarray):
Dictionary of weights of the TCA model.
Returns:
factors_ordered (dict of np.ndarray):
Dictionary of factors ordered by how much variance they explain.
"""
factors = factors if factors is not None else self.factors
weights = weights if weights is not None else self.factor_weights
data = data if data is not None else self.data

assert [key_d == key_f for key_d, key_f in zip(data.keys(), factors.keys())], "Data keys and factors keys must match."

outs = {key_d: helpers.order_cp_factors_by_EVR(
tensor_dense=torch.as_tensor(d),
cp_factors=[torch.as_tensor(f) for f in f.values()],
cp_weights=torch.as_tensor(weights[key_d]),
orthogonalizable_EVR=True,
) for (key_d, d), (key_f, f) in zip(data.items(), factors.items())}
orders = {key: out[0] for key, out in outs.items()}
evrs = {key: out[1] for key, out in outs.items()}

factors_ordered = {key: {key_factor: factors[key][key_factor][:, orders[key]] for key_factor in factors[key].keys()} for key in factors.keys()}
weights_ordered = {key: weights[key][orders[key]] for key in weights.keys()}
evrs_ordered = {key: evrs[key] for key in evrs.keys()}

if overwrite_factors:
self.factors = factors_ordered
self.factor_weights = weights_ordered
self.evrs_ordered = evrs_ordered

return orders, factors_ordered, weights_ordered, evrs_ordered


def rearrange_factors(
self,
factors: dict=None,
Expand Down Expand Up @@ -565,7 +612,7 @@ def plot_factors(
title_figure = f"{name_factors}_[{name_factor}]"
## Plot the factor
fig, ax = plt.subplots()
ax.plot(val)
ax.plot(val / val.max(axis=0))
ax.set_title(title_figure)
ax.set_xlabel(f'{name_factor}_bin')
ax.legend(np.arange(val.shape[1]))
Expand Down
Loading

0 comments on commit 542028b

Please sign in to comment.