Skip to content

Commit

Permalink
remove initial_time arg
Browse files Browse the repository at this point in the history
  • Loading branch information
Sai Krishna committed Aug 22, 2020
1 parent 1af3a6d commit e13d359
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
8 changes: 4 additions & 4 deletions jax_reactor/solver/bdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,11 @@ def _initialize_solver_internal_state(
step_size=first_step_size)


def _solve(ode_fn, initial_time, initial_state, solution_times, jacobian_fn,
def _solve(ode_fn, initial_state, solution_times, jacobian_fn,
atol, rtol, min_step_size_factor, max_step_size_factor, max_order,
max_num_newton_iters, max_num_steps, newton_tol_factor,
newton_step_size_factor, safety_factor, bdf_coefficients):
initial_time = solution_times[0]
def advance_to_solution_time(_states):
"""Takes multiple steps to advance time to `solution_times[n]`."""
n, diagnostics, iterand, solver_internal_state, state_vec, times = _states
Expand Down Expand Up @@ -344,9 +345,8 @@ def advance_to_solution_time_cond(_states):
)


@partial(jax.jit, static_argnums=(0, 4))
@partial(jax.jit, static_argnums=(0, 3))
def bdf_solve(ode_fn,
initial_time,
initial_state,
solution_times,
jacobian_fn,
Expand All @@ -362,7 +362,7 @@ def bdf_solve(ode_fn,
safety_factor=0.9,
bdf_coefficients=[0., 0.1850, -1. / 9., -0.0823, -0.0415, 0.]):

results = _solve(ode_fn, initial_time, initial_state, solution_times,
results = _solve(ode_fn, initial_state, solution_times,
jacobian_fn, atol, rtol, min_step_size_factor,
max_step_size_factor, max_order, max_num_newton_iters,
max_num_steps, newton_tol_factor, newton_step_size_factor,
Expand Down
2 changes: 1 addition & 1 deletion jax_reactor/solver/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def cond_fn(_state):

def step(_state):
i, current_state, current_time, _ = _state
results = bdf_solve(ode_fn, current_time, current_state,
results = bdf_solve(ode_fn, current_state,
np.array([current_time, current_time + dt]),
jacobian_fn)
next_state = results.states[-1]
Expand Down

0 comments on commit e13d359

Please sign in to comment.