fix: baselines after big refactor
* feat: docs

* feat: README.md

* feat: TD3 baseline

* feat: PPO baseline

* feat: new env: ant with ball in maze

* chore: remove notebooks and clean_rl

* fix: typos

* refactor: docstrings and metrics_recorder refactor

* Update docs

* chore: update contrastive loss functions with references

* chore: Add new methods and environments documentation

* chore: update docs

* chore: update README.md

* some minor changes/repairs in links (#17)

* Merge remote-tracking branch 'public/anonymize-v2' into dev

* Merge anonymize-v2 and master

* chore: update README.md

* chore: add LICENSE

* chore: update README.md

* refactor: refactor losses.py and update README.md based on Colin's pr…

* feat: (WIP) wrapper - working with ant

* refactor: remove seed/traj id logic from envs (not manipulation)

* refactor: remove seed/traj id logic from envs (manipulation)

* chore: small ant_push env creation fix

* refactor: rename seed to traj_id - so trajectories are distinguishabl…

* Add env wrapper with traj id for CRL

* UNTESTED (extremely likely to be broken, but fixable): remove most br…

* Fix many refactoring bugs but not done yet + stylistic changes + rend…

* Fix rendering + verified working on ant

* Merge with dev branch

* Clean up comments

* Move arguments to MetricsRecorder and add visualization frequency flag

* Update docstrings

* Merge pull request #26 from MichalBortkiewicz/feat/disentangle_brax_t…

* Environment interaction snippet in README

* Update README.md

* Update README.md

* Refactor utils.create_env

* Typos

* Merge branch 'dev' into env_snippet

* chore: environment interaction snippet

* fix: add sanity checks for training configuration

* chore: remove cleanJaxGCRL

* Modify env XMLs to make floor plane larger. This is helpful for large…

* More informative error message for num_envs and batch_size

* Add Dockerfile for containerization

* chore: modify env XMLs to make floor plane larger. This is helpful fo…

* Update readme for new file structure

* feat: add GitHub Actions workflow for CI deployment

* Merge remote-tracking branch 'public/master' into dev

* Merge remote-tracking branch 'origin/dev' into dev

* Merge branch 'dev' of github.com:MichalBortkiewicz/JaxGCRL into dev

* Merge remote-tracking branch 'public/master' into dev

* Merge remote-tracking branch 'origin/master' into dev

* fix: ppo baseline

* fix: sac baseline

* fix: td3 baseline
MichalBortkiewicz authored Jan 5, 2025
1 parent 46e2452 commit b54cd4c
Showing 8 changed files with 88 additions and 57 deletions.
38 changes: 38 additions & 0 deletions dockerfile
@@ -0,0 +1,38 @@
# Use Miniconda as the base image
FROM continuumio/miniconda3:latest

# Set the working directory
WORKDIR /app

# Install dependencies for headless rendering (EGL and related libraries)
RUN apt-get update && apt-get install -y \
libegl1-mesa \
libgles2-mesa \
mesa-utils \
&& rm -rf /var/lib/apt/lists/*

# Copy the environment file
COPY environment.yml .

# Create the Conda environment
RUN conda env create -f environment.yml

# Activate the environment by default and ensure PATH includes the environment
RUN echo "conda activate my_env" >> ~/.bashrc
ENV PATH /opt/conda/envs/my_env/bin:$PATH

# Copy the rest of the application
COPY . .

# Make the training script executable
RUN chmod +x ./scripts/train.sh

# Run the training script to verify installation (optional; can be commented out after testing)
# RUN ./scripts/train.sh

# Expose a port if needed (adjust based on the application)
EXPOSE 8000

# Define the default command (adjust based on the application)
CMD ["./scripts/train.sh"]

27 changes: 14 additions & 13 deletions src/baselines/ppo.py
@@ -43,6 +43,7 @@
import optax
from orbax import checkpoint as ocp

from envs.wrappers import TrajectoryIdWrapper
from src.evaluator import CrlEvaluator


@@ -108,6 +109,7 @@ def train(
Callable[[base.System, jnp.ndarray], Tuple[base.System, base.System]]
] = None,
restore_checkpoint_path: Optional[str] = None,
visualization_interval: int = 5,
):
"""PPO training.
@@ -218,12 +220,15 @@ def train(
else:
wrap_for_training = envs_v1.wrappers.wrap_for_training

env = environment
env = TrajectoryIdWrapper(env)
env = wrap_for_training(
environment,
episode_length=episode_length,
action_repeat=action_repeat,
randomization_fn=v_randomization_fn,
)
unwrapped_env = environment

reset_fn = jax.jit(jax.vmap(env.reset))
key_envs = jax.random.split(key_env, num_envs // process_count)
@@ -415,6 +420,7 @@ def training_epoch_with_timing(
v_randomization_fn = functools.partial(
randomization_fn, rng=jax.random.split(eval_key, num_eval_envs)
)
eval_env = TrajectoryIdWrapper(eval_env)
eval_env = wrap_for_training(
eval_env,
episode_length=episode_length,
@@ -433,18 +439,15 @@
# Run initial eval
metrics = {}
if process_id == 0 and num_evals > 1:
metrics = evaluator.run_evaluation(
_unpmap(
(training_state.normalizer_params, training_state.params.policy)),
training_metrics={})
logging.info(metrics)
progress_fn(0, metrics)
metrics = evaluator.run_evaluation(_unpmap((training_state.normalizer_params, training_state.params.policy)), training_metrics={})
logging.info(metrics)
progress_fn(0, metrics, make_policy, _unpmap((training_state.normalizer_params, training_state.params.policy)), unwrapped_env)

training_metrics = {}
training_walltime = 0
current_step = 0
for it in range(num_evals_after_init):
logging.info('starting iteration %s %s', it, time.time() - xt)
for eval_epoch_num in range(num_evals_after_init):
logging.info('starting iteration %s %s', eval_epoch_num, time.time() - xt)

for _ in range(max(num_resets_per_eval, 1)):
# optimization
@@ -468,11 +471,9 @@ def training_epoch_with_timing(
(training_state.normalizer_params, training_state.params.policy)),
training_metrics)
logging.info(metrics)
progress_fn(current_step, metrics)
params = _unpmap(
(training_state.normalizer_params, training_state.params)
)
policy_params_fn(current_step, make_policy, params)
do_render = (eval_epoch_num % visualization_interval) == 0
progress_fn(current_step, metrics, make_policy, _unpmap((training_state.normalizer_params, training_state.params.policy)),
unwrapped_env, do_render=do_render)

total_steps = current_step
assert total_steps >= num_timesteps
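
In all three baselines the progress callback is now invoked with extra arguments: the policy factory, the current (normalizer, policy) parameters, the unwrapped environment, and a do_render flag, rather than just the step count and metrics. A minimal sketch of a callback compatible with the new call sites (illustrative only; the real implementation is utils.MetricsRecorder.progress, which is not shown in this diff):

def progress(num_steps, metrics, make_policy, params, unwrapped_env, do_render=True):
    # Log scalar metrics for this training step.
    for key, value in metrics.items():
        print(f"{num_steps}: {key}={value}")
    if do_render:
        # params is (normalizer_params, policy_params), as passed by the baselines.
        inference_fn = make_policy(params)
        # A real recorder would roll out inference_fn in unwrapped_env and save a video here.
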
12 changes: 9 additions & 3 deletions src/baselines/sac.py
@@ -42,6 +42,7 @@
import jax.numpy as jnp
import optax

from envs.wrappers import TrajectoryIdWrapper
from src.evaluator import CrlEvaluator
from src.replay_buffer import QueueBase, Sample

@@ -262,6 +263,7 @@ def train(
checkpoint_logdir: Optional[str] = None,
eval_env: Optional[envs.Env] = None,
randomization_fn: Optional[Callable[[base.System, jnp.ndarray], Tuple[base.System, base.System]]] = None,
visualization_interval: int = 5,
):
"""SAC training."""
process_id = jax.process_index()
@@ -307,12 +309,14 @@ def train(
randomization_fn,
rng=jax.random.split(key, num_envs // jax.process_count() // local_devices_to_use),
)
env = TrajectoryIdWrapper(env)
env = wrap_for_training(
env,
episode_length=episode_length,
action_repeat=action_repeat,
randomization_fn=v_randomization_fn,
)
unwrapped_env = environment

obs_size = env.observation_size
action_size = env.action_size
@@ -639,6 +643,7 @@ def training_epoch_with_timing(
eval_env = environment
if randomization_fn is not None:
v_randomization_fn = functools.partial(randomization_fn, rng=jax.random.split(eval_key, num_eval_envs))
eval_env = TrajectoryIdWrapper(eval_env)
eval_env = wrap_for_training(
eval_env,
episode_length=episode_length,
@@ -662,7 +667,7 @@
_unpmap((training_state.normalizer_params, training_state.policy_params)), training_metrics={}
)
logging.info(metrics)
progress_fn(0, metrics)
progress_fn(0, metrics, make_policy, _unpmap((training_state.normalizer_params, training_state.policy_params)), unwrapped_env)

# Create and initialize the replay buffer.
t = time.time()
@@ -678,7 +683,7 @@
training_walltime = time.time() - t

current_step = 0
for _ in range(num_evals_after_init):
for eval_epoch_num in range(num_evals_after_init):
logging.info("step %s", current_step)

# Optimization
@@ -702,7 +707,8 @@ def training_epoch_with_timing(
_unpmap((training_state.normalizer_params, training_state.policy_params)), training_metrics
)
logging.info(metrics)
progress_fn(current_step, metrics)
do_render = (eval_epoch_num % visualization_interval) == 0
progress_fn(current_step, metrics, make_policy, _unpmap((training_state.normalizer_params, training_state.policy_params)), unwrapped_env, do_render)

total_steps = current_step
assert total_steps >= num_timesteps
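
Each baseline also wraps both the training and evaluation environments in TrajectoryIdWrapper before wrap_for_training, so individual trajectories stay distinguishable after vectorization. The wrapper's implementation is not part of this diff; purely as an illustration of the idea (not the repository's actual code), a Brax-style sketch could look like:

import jax
import jax.numpy as jnp
from brax.envs.base import State, Wrapper  # import path assumes a recent Brax release

class TrajectoryIdWrapperSketch(Wrapper):
    """Illustrative only: tags each rollout with an id stored in state.info."""

    def reset(self, rng) -> State:
        state = self.env.reset(rng)
        # Derive an id from the reset key; the real wrapper may use a counter instead.
        state.info['traj_id'] = jax.random.randint(rng, (), 0, 2**31 - 1)
        return state

    def step(self, state: State, action) -> State:
        traj_id = state.info['traj_id']
        state = self.env.step(state, action)
        state.info['traj_id'] = traj_id
        return state
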
13 changes: 10 additions & 3 deletions src/baselines/td3/td3_train.py
@@ -42,6 +42,7 @@
import jax.numpy as jnp
import optax

from envs.wrappers import TrajectoryIdWrapper
from src.evaluator import CrlEvaluator
from src.replay_buffer import QueueBase, Sample

@@ -266,6 +267,7 @@ def train(
smoothing_noise: int = 0.2,
exploration_noise: float = 0.4,
randomization_fn: Optional[Callable[[base.System, jnp.ndarray], Tuple[base.System, base.System]]] = None,
visualization_interval: int = 5,
):
"""TD3 training."""
process_id = jax.process_index()
@@ -311,12 +313,14 @@ def train(
randomization_fn,
rng=jax.random.split(key, num_envs // jax.process_count() // local_devices_to_use),
)
env = TrajectoryIdWrapper(env)
env = wrap_for_training(
env,
episode_length=episode_length,
action_repeat=action_repeat,
randomization_fn=v_randomization_fn,
)
unwrapped_env = environment

obs_size = env.observation_size
action_size = env.action_size
@@ -642,6 +646,7 @@ def training_epoch_with_timing(
eval_env = environment
if randomization_fn is not None:
v_randomization_fn = functools.partial(randomization_fn, rng=jax.random.split(eval_key, num_eval_envs))
eval_env = TrajectoryIdWrapper(eval_env)
eval_env = wrap_for_training(
eval_env,
episode_length=episode_length,
@@ -665,7 +670,7 @@
_unpmap((training_state.normalizer_params, training_state.policy_params)), training_metrics={}
)
logging.info(metrics)
progress_fn(0, metrics)
progress_fn(0, metrics, make_policy, _unpmap((training_state.normalizer_params, training_state.policy_params)), unwrapped_env)

# Create and initialize the replay buffer.
t = time.time()
@@ -681,7 +686,7 @@
training_walltime = time.time() - t

current_step = 0
for _ in range(num_evals_after_init):
for eval_epoch_num in range(num_evals_after_init):
logging.info("step %s", current_step)

# Optimization
@@ -705,7 +710,9 @@ def training_epoch_with_timing(
_unpmap((training_state.normalizer_params, training_state.policy_params)), training_metrics
)
logging.info(metrics)
progress_fn(current_step, metrics)
do_render = (eval_epoch_num % visualization_interval) == 0
progress_fn(current_step, metrics, make_policy, _unpmap((training_state.normalizer_params, training_state.policy_params)), unwrapped_env, do_render)


total_steps = current_step
assert total_steps >= num_timesteps
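
All three train() functions also gain a visualization_interval parameter (default 5): the progress callback is asked to render only on every visualization_interval-th evaluation epoch. A quick illustration of the cadence implied by do_render = (eval_epoch_num % visualization_interval) == 0:

visualization_interval = 5   # the new default
num_evals_after_init = 12    # illustrative value, set by num_evals in practice
render_epochs = [e for e in range(num_evals_after_init)
                 if e % visualization_interval == 0]
print(render_epochs)  # [0, 5, 10] -> do_render is True only on these eval epochs
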
13 changes: 5 additions & 8 deletions training_ppo.py
@@ -8,7 +8,7 @@
from pyinstrument import Profiler

from src.baselines.ppo import train
from utils import MetricsRecorder, create_env, create_eval_env, create_parser, render
from utils import MetricsRecorder, create_env, create_eval_env, create_parser


def main(args):
@@ -54,7 +54,7 @@ def main(args):
num_evals=args.num_evals,
reward_scaling=1,
episode_length=args.episode_length,
normalize_observations=args.normalize_observations,
normalize_observations=False,
action_repeat=args.action_repeat,
unroll_length=args.unroll_length,
discounting=args.discounting,
@@ -89,13 +89,10 @@ def main(args):
"training/alpha_loss",
"training/entropy",
]
metrics_recorder = MetricsRecorder(args.num_timesteps, metrics_to_collect)
metrics_recorder = MetricsRecorder(args.num_timesteps, metrics_to_collect, run_dir, args.exp_name)

make_inference_fn, params, _ = train_fn(environment=env, progress_fn=metrics_recorder.progress)

os.makedirs("./params", exist_ok=True)
model.save_params(f'./params/param_{args.exp_name}_s_{args.seed}', params)
render(make_inference_fn, params, env, "./renders", args.exp_name)
make_policy, params, _ = train_fn(environment=env, progress_fn=metrics_recorder.progress)
model.save_params(ckpt_dir + '/final', params)

if __name__ == "__main__":
parser = create_parser()
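
The training scripts no longer save to ./params or call render() themselves; each now saves a single checkpoint via brax.io.model after training. A minimal sketch of loading that checkpoint back for inference (the path is an assumption, since ckpt_dir is defined elsewhere in the script):

from brax.io import model

ckpt_dir = "./runs/my_experiment/ckpt"           # illustrative path, not from this commit
params = model.load_params(ckpt_dir + '/final')  # mirrors model.save_params in the script
# params can then be passed to make_policy(params) to rebuild an inference function.
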
28 changes: 4 additions & 24 deletions training_sac.py
@@ -4,12 +4,11 @@
import pickle

import wandb
import math
from brax.io import model
from pyinstrument import Profiler

from src.baselines.sac import train
from utils import MetricsRecorder, get_env_config, create_env, create_eval_env, create_parser, render
from utils import MetricsRecorder, get_env_config, create_env, create_eval_env, create_parser


def main(args):
@@ -48,7 +47,7 @@ def main(args):
num_evals=args.num_evals,
reward_scaling=1,
episode_length=args.episode_length,
normalize_observations=args.normalize_observations,
normalize_observations=False,
action_repeat=args.action_repeat,
discounting=args.discounting,
learning_rate=args.critic_lr,
@@ -64,15 +63,6 @@
eval_env=eval_env
)

metrics_recorder = MetricsRecorder(args.num_timesteps)

def ensure_metric(metrics, key):
if key not in metrics:
metrics[key] = 0
else:
if math.isnan(metrics[key]):
raise Exception(f"Metric: {key} is Nan")

metrics_to_collect = [
"eval/episode_reward",
"eval/episode_success",
@@ -93,20 +83,10 @@ def ensure_metric(metrics, key):
"training/entropy",
]

def progress(num_steps, metrics):
for key in metrics_to_collect:
ensure_metric(metrics, key)
metrics_recorder.record(
num_steps,
{key: value for key, value in metrics.items() if key in metrics_to_collect},
)
metrics_recorder.log_wandb()
metrics_recorder.print_progress()

make_inference_fn, params, _ = train_fn(environment=env, progress_fn=progress)
metrics_recorder = MetricsRecorder(args.num_timesteps, metrics_to_collect, run_dir, args.exp_name)

make_policy, params, _ = train_fn(environment=env, progress_fn=metrics_recorder.progress)
model.save_params(ckpt_dir + '/final', params)
render(make_inference_fn, params, env, run_dir, args.exp_name)

if __name__ == "__main__":
parser = create_parser()
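
The inline ensure_metric/progress helpers removed from training_sac.py appear to have been folded into MetricsRecorder, which is now constructed with the metric list, run directory, and experiment name. A rough sketch of the consolidated logic, based only on the deleted code (not the actual utils.py implementation):

import math

class MetricsRecorderSketch:
    """Illustrative only: mirrors the behaviour previously inlined in training_sac.py."""

    def __init__(self, num_timesteps, metrics_to_collect, run_dir, exp_name):
        self.num_timesteps = num_timesteps
        self.metrics_to_collect = metrics_to_collect
        self.run_dir = run_dir
        self.exp_name = exp_name

    def progress(self, num_steps, metrics, make_policy, params, env, do_render=True):
        collected = {}
        for key in self.metrics_to_collect:
            value = metrics.get(key, 0)
            if isinstance(value, float) and math.isnan(value):
                raise ValueError(f"Metric: {key} is NaN")
            collected[key] = value
        print(f"{num_steps}/{self.num_timesteps}: {collected}")
        # The real recorder also logs to wandb and, when do_render is True,
        # renders a rollout of make_policy(params) in env under run_dir.
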
10 changes: 4 additions & 6 deletions training_td3.py
@@ -8,7 +8,7 @@
from pyinstrument import Profiler

from src.baselines.td3.td3_train import train
from utils import MetricsRecorder, get_env_config, create_env, create_eval_env, create_parser, render
from utils import MetricsRecorder, get_env_config, create_env, create_eval_env, create_parser


def main(args):
@@ -47,7 +47,7 @@ def main(args):
num_evals=args.num_evals,
reward_scaling=1,
episode_length=args.episode_length,
normalize_observations=args.normalize_observations,
normalize_observations=False,
action_repeat=args.action_repeat,
discounting=args.discounting,
learning_rate=args.critic_lr,
@@ -82,12 +82,10 @@ def main(args):
"training/entropy",
]

metrics_recorder = MetricsRecorder(args.num_timesteps, metrics_to_collect)

make_inference_fn, params, _ = train_fn(environment=env, progress_fn=metrics_recorder.progress)
metrics_recorder = MetricsRecorder(args.num_timesteps, metrics_to_collect, run_dir, args.exp_name)

make_policy, params, _ = train_fn(environment=env, progress_fn=metrics_recorder.progress)
model.save_params(ckpt_dir + '/final', params)
render(make_inference_fn, params, env, run_dir, args.exp_name)

if __name__ == "__main__":
parser = create_parser()
4 changes: 4 additions & 0 deletions utils.py
@@ -205,6 +205,10 @@ def get_env_config(args: argparse.Namespace):
if args.env_name not in legal_envs and "maze" not in args.env_name:
raise ValueError(f"Unknown environment: {args.env_name}")

# TODO: round num_envs to nearest valid value instead of throwing error
if ((args.episode_length - 1) * args.num_envs) % args.batch_size != 0:
raise ValueError("(episode_length - 1) * num_envs must be divisible by batch_size")

args_dict = vars(args)
Config = namedtuple("Config", [*args_dict.keys()])
config = Config(*args_dict.values())
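
The new sanity check in get_env_config rejects configurations where (episode_length - 1) * num_envs is not divisible by batch_size, failing fast instead of erroring later in training. A quick worked example with illustrative numbers (not from the commit):

episode_length, num_envs, batch_size = 1001, 1024, 256
transitions = (episode_length - 1) * num_envs   # 1000 * 1024 = 1_024_000
print(transitions % batch_size == 0)            # True  -> configuration passes
print(transitions % 300)                        # 100   -> batch_size=300 would raise ValueError
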
