
about cuda of dtw #11

Open
hdmjdp opened this issue Mar 8, 2023 · 3 comments

@hdmjdp

hdmjdp commented Mar 8, 2023

sdtw = SoftDTW(use_cuda=False, gamma=0.01, warp=134.4)

Why use use_cuda=False?

@heatz123
Owner

heatz123 commented Mar 8, 2023

Hi @hdmjdp,
The current implementation of softdtw without CUDA uses parallel loops via numba, and I think this can be efficient enough. Additionally, due to the high memory usage of the softdtw matrices, a lower batch size may be required, which can reduce the potential advantages of using CUDA. There also seems to be some numerical instability when performing backward computations in the CUDA sdtw implementation, particularly when dealing with long input sequences (refer to https://github.com/Maghoumi/pytorch-softdtw-cuda), which is our case.

However, if you want the potential speedup of using CUDA, you can modify the models/sdtw_cuda.py file to handle the warp penalty appropriately in the forward and backward functions, as in the no-CUDA implementation.
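
For reference, here is a minimal sketch of the kind of computation the no-CUDA path performs: a numba-parallel soft-DTW forward pass in which the warp penalty is charged only on the two non-diagonal ("skip") transitions. This is an illustrative version, not the exact code in models/sdtw_cuda.py, and the argument names D, gamma, and warp are assumptions.

import numpy as np
from numba import njit, prange

@njit(parallel=True, fastmath=True)
def softdtw_forward(D, gamma, warp):
    # D: (B, N, M) batch of pairwise-distance matrices (float64)
    B, N, M = D.shape
    out = np.empty(B, dtype=np.float64)
    for b in prange(B):  # batch items are processed in parallel threads
        R = np.full((N + 1, M + 1), np.inf, dtype=np.float64)
        R[0, 0] = 0.0
        for i in range(1, N + 1):
            for j in range(1, M + 1):
                # warp penalty is added only to the two non-diagonal moves
                r0 = -R[i - 1, j - 1] / gamma
                r1 = -(R[i - 1, j] + warp) / gamma
                r2 = -(R[i, j - 1] + warp) / gamma
                rmax = max(max(r0, r1), r2)
                softmin = -gamma * (np.log(
                    np.exp(r0 - rmax) + np.exp(r1 - rmax) + np.exp(r2 - rmax)
                ) + rmax)
                R[i, j] = D[b, i - 1, j - 1] + softmin
        out[b] = R[N, M]
    return out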

Thanks.

@hdmjdp
Author

hdmjdp commented Mar 13, 2023

@heatz123 Thanks. I trained it to 145k steps without the dtw loss, then fine-tuned from the 145k checkpoint. And I get this error:

torch.arange(0, max_len)
RuntimeError: upper bound and larger bound inconsistent with step sign

This error is caused by max_len being a negative value.
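
For illustration, assuming max_len has somehow become negative, a minimal reproduction of the failure looks like this (the value -3 is only an example):

import torch

max_len = -3              # e.g. if the computed durations collapse to non-positive values
torch.arange(0, max_len)  # RuntimeError: upper bound and larger bound inconsistent with step sign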

@heatz123
Owner

Hi @hdmjdp,
I've checked the implementation, but it is not clear how max_len could be negative, as max_len is set to the maximum of the durations (> 0) in a batch. Could you please provide more details, such as the config settings you are using and any steps for reproducing the issue? Additionally, it would be helpful if you could print the value of max_len to figure out the issue.
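
As a sketch of where to add the print, assuming a FastSpeech-style mask helper (the name get_mask_from_lengths and its exact signature are assumptions, not necessarily this repository's code):

import torch

def get_mask_from_lengths(lengths):
    max_len = int(lengths.max().item())
    # if this ever prints a non-positive value, the durations upstream are already broken
    print("max_len =", max_len, "lengths =", lengths.tolist())
    ids = torch.arange(0, max_len, device=lengths.device)
    return ids.unsqueeze(0) < lengths.unsqueeze(1)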

Thank you.
