Model inclusion of track graph in Decoding notebooks #1157

Open
CBroz1 opened this issue Oct 8, 2024 · 3 comments
Labels
decoding, documentation (Improvements or additions to documentation), enhancement (New feature or request)

Comments

@CBroz1
Member

CBroz1 commented Oct 8, 2024

This passage in the decoding notebook mentions the possibility of including a track graph in decoding:

```python
# ### 1D Decoding
#
# If you want to do 1D decoding, you will need to specify the `track_graph`, `edge_order`, and `edge_spacing` in the `environments` parameter. You can read more about these parameters in the [linearization notebook](./24_Linearization.ipynb). You can retrieve these parameters from the `TrackGraph` table if you have stored them there. These will then go into the `environments` parameter of the `ContFragClusterlessClassifier` model.
```
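The remapping this paragraph describes mostly comes down to renaming two fields. Below is a minimal sketch in plain Python, assuming the `TrackGraph` field names used in the run snippet later in this thread (`linear_edge_order`, `linear_edge_spacing`) and that `Environment` accepts `track_graph`, `edge_order`, `edge_spacing`, and an optional `environment_name`. The helper is hypothetical; it is not part of Spyglass or `non_local_detector`.

```python
from typing import Any, Dict, Optional

# Hypothetical mapping from fetched TrackGraph fields to Environment kwargs,
# taken from the field names used in the run snippet in this thread.
TRACK_GRAPH_TO_ENV_KEYS = {
    "linear_edge_order": "edge_order",
    "linear_edge_spacing": "edge_spacing",
}


def track_graph_entry_to_env_kwargs(
    graph_entry: Dict[str, Any],
    networkx_graph: Any,
    environment_name: Optional[str] = None,
) -> Dict[str, Any]:
    """Build keyword arguments for an Environment from a fetched TrackGraph entry."""
    kwargs: Dict[str, Any] = {"track_graph": networkx_graph}
    for table_key, env_key in TRACK_GRAPH_TO_ENV_KEYS.items():
        kwargs[env_key] = graph_entry[table_key]
    # Only set a name when explicitly requested: a non-default name must
    # also be given to every ObservationModel that refers to this environment.
    if environment_name is not None:
        kwargs["environment_name"] = environment_name
    return kwargs
```

In the run snippet, this would be used as `Environment(**track_graph_entry_to_env_kwargs(graph_entry, track_graph.get_networkx_track_graph()))`, leaving `environment_name` at its default.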

In my brief attempt, I did not find the process of including this graph straightforward: I needed to remap kwargs between the fetched `TrackGraph` entry and classes from `non_local_detector`. I then hit an error because the embedded `ObservationModel` had an empty `environment_name`.

Partial error stack
```
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
src/spyglass/utils/dj_mixin.py:612: in populate
    return super().populate(*restrictions, **kwargs)
../datajoint-python/datajoint/autopopulate.py:254: in populate
    status = self._populate1(key, jobs, **populate_kwargs)
../datajoint-python/datajoint/autopopulate.py:322: in _populate1
    make(dict(key), **(make_kwargs or {}))
src/spyglass/decoding/v1/clusterless.py:212: in make
    classifier.fit(
../../miniconda3/envs/spy/lib/python3.9/site-packages/non_local_detector/models/base.py:960: in fit
    self._fit(
../../miniconda3/envs/spy/lib/python3.9/site-packages/non_local_detector/models/base.py:493: in _fit
    self.initialize_state_index()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = ClusterlessDetector(clusterless_algorithm='clusterless_kde',
                    clusterless_algorithm_params={'block_...ke=False)],
                    sampling_frequency=500.0,
                    state_names=['Continuous', 'Fragmented'])

    def initialize_state_index(self):
        self.n_discrete_states_ = len(self.state_names)
        bin_sizes = []
        state_ind = []
        is_track_interior = []
        for ind, obs in enumerate(self.observation_models):
            if obs.is_local or obs.is_no_spike:
                bin_sizes.append(1)
                state_ind.append(ind * np.ones((1,), dtype=int))
                is_track_interior.append(np.ones((1,), dtype=bool))
            else:
                environment = self.environments[
>                   self.environments.index(obs.environment_name)
                ]
E               ValueError: '' is not in list

../../miniconda3/envs/spy/lib/python3.9/site-packages/
```
CBroz1 added the documentation, enhancement, and decoding labels on Oct 8, 2024.
@edeno
Collaborator

edeno commented Oct 8, 2024

Not quite sure what you're running here. It seems like the environment name given for the observation model is not one specified in the list of Environments given to the model.
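A simplified, self-contained reproduction of the failing lookup follows. It uses hypothetical stand-in classes rather than the real `non_local_detector` types, and assumes, as the traceback suggests, that the model finds each observation model's environment by name and that an unset name is the empty string.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class StandInEnvironment:
    """Hypothetical stand-in for an Environment: only the name matters here."""

    environment_name: str = ""


@dataclass
class StandInObservationModel:
    """Hypothetical stand-in for an ObservationModel that refers to an environment."""

    environment_name: str = ""


def find_environment(
    environments: List[StandInEnvironment], obs: StandInObservationModel
) -> StandInEnvironment:
    # Look the environment up by name, as in the traceback; list.index
    # raises ValueError when the observation model's name is absent.
    names = [env.environment_name for env in environments]
    return environments[names.index(obs.environment_name)]


# Named environment, observation model left at the default "".
named_env = [StandInEnvironment(environment_name="6 arm")]
try:
    find_environment(named_env, StandInObservationModel())
except ValueError as err:
    print(f"ValueError: {err}")  # ValueError: '' is not in list
```

The lookup succeeds either when both names stay at `""` or when the observation model is given the same name as the environment.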

@CBroz1
Member Author

CBroz1 commented Oct 8, 2024

Run snippet
```python
import pytest


@pytest.fixture
def some_name(track_graph):
    from non_local_detector.environment import Environment
    from non_local_detector.models import ContFragClusterlessClassifier
    from spyglass.decoding import v1 as decode_v1

    # `track_graph` is a TrackGraph table restricted to a single entry
    graph_entry = track_graph.fetch1()
    class_kwargs = dict(
        clusterless_algorithm_params={
            "block_size": 10000,
            "position_std": 12.0,
            "waveform_std": 24.0,
        },
        # Remap the fetched TrackGraph fields onto Environment's kwargs
        environments=[
            Environment(
                environment_name=graph_entry["track_graph_name"],
                track_graph=track_graph.get_networkx_track_graph(),
                edge_order=graph_entry["linear_edge_order"],
                edge_spacing=graph_entry["linear_edge_spacing"],
            )
        ],
    )
    params_pk = {"decoding_param_name": "contfrag_clusterless"}
    decode_v1.core.DecodingParameters.insert_default()
    # Store the configured classifier as a new decoding parameter set
    decode_v1.core.DecodingParameters.insert1(
        {
            **params_pk,
            "decoding_params": ContFragClusterlessClassifier(**class_kwargs),
            "decoding_kwargs": dict(),
        },
        skip_duplicates=True,
    )
```

@edeno
Collaborator

edeno commented Oct 8, 2024

If you omit `environment_name=graph_entry["track_graph_name"]`, it should work; otherwise, you have to specify the same name in the `ObservationModel`.
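Both options amount to keeping the names consistent. A hypothetical pre-fit check could report mismatches before `classifier.fit` is reached. It is sketched against only the attributes visible in the traceback (`environment_name`, `is_local`, `is_no_spike`) and is not an existing Spyglass or `non_local_detector` function; the usage below uses `SimpleNamespace` stand-ins and a made-up environment name.

```python
from types import SimpleNamespace
from typing import Iterable, List


def missing_environment_names(
    observation_models: Iterable, environments: Iterable
) -> List[str]:
    """Return names referenced by non-local observation models that no
    environment in `environments` carries (hypothetical helper)."""
    known = {env.environment_name for env in environments}
    missing: List[str] = []
    for obs in observation_models:
        # Local and no-spike states skip the environment lookup in the traceback.
        if obs.is_local or obs.is_no_spike:
            continue
        if obs.environment_name not in known and obs.environment_name not in missing:
            missing.append(obs.environment_name)
    return missing


# Stand-ins mirroring the failing run: the environment is named,
# but the observation models keep the default "".
envs = [SimpleNamespace(environment_name="6 arm")]
obs_models = [
    SimpleNamespace(environment_name="", is_local=False, is_no_spike=False)
    for _ in range(2)
]
print(missing_environment_names(obs_models, envs))  # ['']
```

An empty list means every non-local observation model points at an environment the model actually has.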
