diff --git a/grl/numerical_methods/numerical_solvers/ode_solver.py b/grl/numerical_methods/numerical_solvers/ode_solver.py index 92b9037..7e46df4 100644 --- a/grl/numerical_methods/numerical_solvers/ode_solver.py +++ b/grl/numerical_methods/numerical_solvers/ode_solver.py @@ -2,7 +2,6 @@ import torch import torch.nn as nn -import torchode import treetensor from tensordict import TensorDict from torch import nn @@ -32,13 +31,13 @@ def __init__( ): """ Overview: - Initialize the ODE solver using torchode or torchdyn library. + Initialize the ODE solver using torchdiffeq or torchdyn library. Arguments: ode_solver (:obj:`str`): The ODE solver to use. dt (:obj:`float`): The time step. atol (:obj:`float`): The absolute tolerance. rtol (:obj:`float`): The relative tolerance. - library (:obj:`str`): The library to use for the ODE solver. Currently, it supports 'torchdiffeq', 'torchdyn' and 'torchode'. + library (:obj:`str`): The library to use for the ODE solver. Currently, it supports 'torchdiffeq' and 'torchdyn'. **kwargs: Additional arguments for the ODE solver. """ self.ode_solver = ode_solver @@ -76,8 +75,6 @@ def integrate( return self.odeint_by_torchdyn(drift, x0, t_span) elif self.library == "torchdyn_NeuralODE": return self.odeint_by_torchdyn_NeuralODE(drift, x0, t_span) - elif self.library == "torchode": - return self.odeint_by_torchode(drift, x0, t_span) else: raise ValueError(f"library {self.library} is not supported") @@ -210,9 +207,6 @@ def forward_ode_drift_by_torchdyn_NeuralODE(t, x, args): trajectory = neural_ode(x0, t_span) return trajectory - def odeint_by_torchode(self, x0, t_span): - pass - class DictTensorConverter(nn.Module): @@ -344,13 +338,13 @@ def __init__( ): """ Overview: - Initialize the ODE solver using torchode or torchdyn library. + Initialize the ODE solver using torchdiffeq or torchdyn library. Arguments: ode_solver (:obj:`str`): The ODE solver to use. dt (:obj:`float`): The time step. atol (:obj:`float`): The absolute tolerance. rtol (:obj:`float`): The relative tolerance. - library (:obj:`str`): The library to use for the ODE solver. Currently, it supports 'torchdyn' and 'torchode'. + library (:obj:`str`): The library to use for the ODE solver. Currently, it supports 'torchdyn' and 'torchdiffeq'. **kwargs: Additional arguments for the ODE solver. """ self.ode_solver = ode_solver @@ -400,8 +394,6 @@ def integrate( return self.odeint_by_torchdyn_NeuralODE( drift, x0, t_span, batch_size, x_size ) - elif self.library == "torchode": - return self.odeint_by_torchode(drift, x0, t_span, batch_size, x_size) else: raise ValueError(f"library {self.library} is not supported") @@ -517,6 +509,3 @@ def forward_ode_drift_by_torchdyn_NeuralODE(t, x, args={}): ) return trajectory - - def odeint_by_torchode(self, x0, t_span, batch_size, x_size): - pass diff --git a/setup.py b/setup.py index 0256ff7..d6fd51e 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,6 @@ 'easydict', 'tqdm', 'torchdyn', - 'torchode', 'torchsde', 'scipy', 'POT',