Skip to content

Commit

Permalink
remove dependency on torchode.
Browse files Browse the repository at this point in the history
  • Loading branch information
zjowowen committed Jul 31, 2024
1 parent 3d42e7e commit 69dba2b
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 16 deletions.
19 changes: 4 additions & 15 deletions grl/numerical_methods/numerical_solvers/ode_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import torch
import torch.nn as nn
import torchode
import treetensor
from tensordict import TensorDict
from torch import nn
Expand Down Expand Up @@ -32,13 +31,13 @@ def __init__(
):
"""
Overview:
Initialize the ODE solver using torchode or torchdyn library.
Initialize the ODE solver using torchdiffeq or torchdyn library.
Arguments:
ode_solver (:obj:`str`): The ODE solver to use.
dt (:obj:`float`): The time step.
atol (:obj:`float`): The absolute tolerance.
rtol (:obj:`float`): The relative tolerance.
library (:obj:`str`): The library to use for the ODE solver. Currently, it supports 'torchdiffeq', 'torchdyn' and 'torchode'.
library (:obj:`str`): The library to use for the ODE solver. Currently, it supports 'torchdiffeq' and 'torchdyn'.
**kwargs: Additional arguments for the ODE solver.
"""
self.ode_solver = ode_solver
Expand Down Expand Up @@ -76,8 +75,6 @@ def integrate(
return self.odeint_by_torchdyn(drift, x0, t_span)
elif self.library == "torchdyn_NeuralODE":
return self.odeint_by_torchdyn_NeuralODE(drift, x0, t_span)
elif self.library == "torchode":
return self.odeint_by_torchode(drift, x0, t_span)
else:
raise ValueError(f"library {self.library} is not supported")

Expand Down Expand Up @@ -210,9 +207,6 @@ def forward_ode_drift_by_torchdyn_NeuralODE(t, x, args):
trajectory = neural_ode(x0, t_span)
return trajectory

def odeint_by_torchode(self, x0, t_span):
pass


class DictTensorConverter(nn.Module):

Expand Down Expand Up @@ -344,13 +338,13 @@ def __init__(
):
"""
Overview:
Initialize the ODE solver using torchode or torchdyn library.
Initialize the ODE solver using torchdiffeq or torchdyn library.
Arguments:
ode_solver (:obj:`str`): The ODE solver to use.
dt (:obj:`float`): The time step.
atol (:obj:`float`): The absolute tolerance.
rtol (:obj:`float`): The relative tolerance.
library (:obj:`str`): The library to use for the ODE solver. Currently, it supports 'torchdyn' and 'torchode'.
library (:obj:`str`): The library to use for the ODE solver. Currently, it supports 'torchdyn' and 'torchdiffeq'.
**kwargs: Additional arguments for the ODE solver.
"""
self.ode_solver = ode_solver
Expand Down Expand Up @@ -400,8 +394,6 @@ def integrate(
return self.odeint_by_torchdyn_NeuralODE(
drift, x0, t_span, batch_size, x_size
)
elif self.library == "torchode":
return self.odeint_by_torchode(drift, x0, t_span, batch_size, x_size)
else:
raise ValueError(f"library {self.library} is not supported")

Expand Down Expand Up @@ -517,6 +509,3 @@ def forward_ode_drift_by_torchdyn_NeuralODE(t, x, args={}):
)

return trajectory

def odeint_by_torchode(self, x0, t_span, batch_size, x_size):
pass
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
'easydict',
'tqdm',
'torchdyn',
'torchode',
'torchsde',
'scipy',
'POT',
Expand Down

0 comments on commit 69dba2b

Please sign in to comment.