Add torchrl tensordict dataset and replay buffer.
zjowowen committed Jul 30, 2024
1 parent b5f28c7 commit 086006d
Showing 7 changed files with 590 additions and 124 deletions.
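
For context on the pattern the hunks below introduce: each training stage now builds a torchrl TensorDictReplayBuffer over the dataset's TensorDict storage instead of constructing a fresh torch.utils.data.DataLoader every epoch. The following is a minimal, self-contained sketch of that usage, not code from this commit; the dataset contents, sizes, and key names other than "s" are made-up placeholders.

import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement

# Hypothetical offline dataset held as a single TensorDict (shapes are arbitrary).
num_transitions = 1000
data = TensorDict(
    {
        "s": torch.randn(num_transitions, 17),  # states
        "a": torch.randn(num_transitions, 6),   # actions
        "r": torch.randn(num_transitions, 1),   # rewards
    },
    batch_size=[num_transitions],
)

# This storage plays the role of self.dataset.storage in the hunks below.
storage = LazyTensorStorage(max_size=num_transitions)

replay_buffer = TensorDictReplayBuffer(
    storage=storage,
    batch_size=256,
    sampler=SamplerWithoutReplacement(),  # shuffled, one full pass per iteration
    prefetch=10,
    pin_memory=True,
)
replay_buffer.extend(data)

# Iterating the buffer yields TensorDict batches until the data is exhausted,
# replacing the RandomSampler + DataLoader loop removed in the diffs.
for index, batch in enumerate(replay_buffer):
    states, actions = batch["s"], batch["a"]

One behavioural difference worth noting: the removed DataLoader passed drop_last=True, whereas SamplerWithoutReplacement defaults to drop_last=False, so the final partial batch is now kept unless SamplerWithoutReplacement(drop_last=True) is used.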
75 changes: 31 additions & 44 deletions grl/algorithms/gmpg.py
@@ -8,6 +8,8 @@
from easydict import EasyDict
from rich.progress import track
from tensordict import TensorDict
from torchrl.data import TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement

import wandb
from grl.agents.gm import GPAgent
@@ -759,7 +761,7 @@ def save_checkpoint(model, iteration=None, model_type=False):
else:
raise NotImplementedError

def generate_fake_action(model, states, sample_per_state):
def generate_fake_action(model, states, action_augment_num):

fake_actions_sampled = []
for states in track(
@@ -769,7 +771,7 @@ def generate_fake_action(model, states, sample_per_state):

fake_actions_ = model.behaviour_policy_sample(
state=states,
batch_size=sample_per_state,
batch_size=action_augment_num,
t_span=(
torch.linspace(0.0, 1.0, config.parameter.t_span).to(
states.device
@@ -862,6 +864,14 @@ def policy(obs: np.ndarray) -> np.ndarray:
lr=config.parameter.behaviour_policy.learning_rate,
)

replay_buffer=TensorDictReplayBuffer(
storage=self.dataset.storage,
batch_size=config.parameter.behaviour_policy.batch_size,
sampler=SamplerWithoutReplacement(),
prefetch=10,
pin_memory=True,
)

behaviour_policy_train_iter = 0
for epoch in track(
range(config.parameter.behaviour_policy.epochs),
@@ -870,22 +880,9 @@ def policy(obs: np.ndarray) -> np.ndarray:
if self.behaviour_policy_train_epoch >= epoch:
continue

sampler = torch.utils.data.RandomSampler(
self.dataset, replacement=False
)
data_loader = torch.utils.data.DataLoader(
self.dataset,
batch_size=config.parameter.behaviour_policy.batch_size,
shuffle=False,
sampler=sampler,
pin_memory=True,
drop_last=True,
num_workers=8,
)

counter = 1
behaviour_policy_loss_sum = 0
for data in data_loader:
for index, data in enumerate(replay_buffer):

behaviour_policy_loss = self.model[
"GPPolicy"
@@ -946,34 +943,29 @@ def policy(obs: np.ndarray) -> np.ndarray:
lr=config.parameter.critic.learning_rate,
)

replay_buffer=TensorDictReplayBuffer(
storage=self.dataset.storage,
batch_size=config.parameter.critic.batch_size,
sampler=SamplerWithoutReplacement(),
prefetch=10,
pin_memory=True,
)

critic_train_iter = 0
for epoch in track(
range(config.parameter.critic.epochs), description="Critic training"
):
if self.critic_train_epoch >= epoch:
continue

sampler = torch.utils.data.RandomSampler(
self.dataset, replacement=False
)
data_loader = torch.utils.data.DataLoader(
self.dataset,
batch_size=config.parameter.critic.batch_size,
shuffle=False,
sampler=sampler,
pin_memory=True,
drop_last=True,
num_workers=8,
)

counter = 1

v_loss_sum = 0.0
v_sum = 0.0
q_loss_sum = 0.0
q_sum = 0.0
q_target_sum = 0.0
for data in data_loader:
for index, data in enumerate(replay_buffer):

v_loss, next_v = self.model["GPPolicy"].critic.v_loss(
state=data["s"].to(config.model.GPPolicy.device),
@@ -1062,6 +1054,14 @@ def policy(obs: np.ndarray) -> np.ndarray:
lr=config.parameter.guided_policy.learning_rate,
)

replay_buffer=TensorDictReplayBuffer(
storage=self.dataset.storage,
batch_size=config.parameter.guided_policy.batch_size,
sampler=SamplerWithoutReplacement(),
prefetch=10,
pin_memory=True,
)

guided_policy_train_iter = 0
beta = config.parameter.guided_policy.beta
for epoch in track(
@@ -1072,22 +1072,9 @@ def policy(obs: np.ndarray) -> np.ndarray:
if self.guided_policy_train_epoch >= epoch:
continue

sampler = torch.utils.data.RandomSampler(
self.dataset, replacement=False
)
data_loader = torch.utils.data.DataLoader(
self.dataset,
batch_size=config.parameter.guided_policy.batch_size,
shuffle=False,
sampler=sampler,
pin_memory=True,
drop_last=True,
num_workers=8,
)

counter = 1
guided_policy_loss_sum = 0.0
for data in data_loader:
for index, data in enumerate(replay_buffer):
if config.parameter.algorithm_type == "GMPG":
(
guided_policy_loss,
93 changes: 42 additions & 51 deletions grl/algorithms/gmpo.py
@@ -8,12 +8,14 @@
from easydict import EasyDict
from rich.progress import track
from tensordict import TensorDict
from torchrl.data import TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement

import wandb
from grl.agents.gm import GPAgent

from grl.datasets import create_dataset
from grl.datasets.gp import GPDataset, GPD4RLDataset
from grl.datasets.gp import GPDataset, GPD4RLDataset, GPD4RLTensorDictDataset
from grl.generative_models.diffusion_model import DiffusionModel
from grl.generative_models.conditional_flow_model.optimal_transport_conditional_flow_model import (
OptimalTransportConditionalFlowModel,
@@ -670,7 +672,7 @@ def save_checkpoint(model, iteration=None, model_type=False):
else:
raise NotImplementedError

def generate_fake_action(model, states, sample_per_state):
def generate_fake_action(model, states, action_augment_num):

fake_actions_sampled = []
for states in track(
@@ -680,7 +682,7 @@ def generate_fake_action(model, states, sample_per_state):

fake_actions_ = model.behaviour_policy_sample(
state=states,
batch_size=sample_per_state,
batch_size=action_augment_num,
t_span=(
torch.linspace(0.0, 1.0, config.parameter.t_span).to(
states.device
@@ -739,7 +741,7 @@ def policy(obs: np.ndarray) -> np.ndarray:
evaluation_results[f"evaluation/return_max"] = return_max
evaluation_results[f"evaluation/return_min"] = return_min

if isinstance(self.dataset, GPD4RLDataset):
if isinstance(self.dataset, GPD4RLDataset) or isinstance(self.dataset, GPD4RLTensorDictDataset):
import d4rl
env_id = config.dataset.args.env_id
evaluation_results[f"evaluation/return_mean_normalized"] = (
@@ -773,6 +775,14 @@ def policy(obs: np.ndarray) -> np.ndarray:
lr=config.parameter.behaviour_policy.learning_rate,
)

replay_buffer=TensorDictReplayBuffer(
storage=self.dataset.storage,
batch_size=config.parameter.behaviour_policy.batch_size,
sampler=SamplerWithoutReplacement(),
prefetch=10,
pin_memory=True,
)

behaviour_policy_train_iter = 0
for epoch in track(
range(config.parameter.behaviour_policy.epochs),
@@ -781,22 +791,9 @@ def policy(obs: np.ndarray) -> np.ndarray:
if self.behaviour_policy_train_epoch >= epoch:
continue

sampler = torch.utils.data.RandomSampler(
self.dataset, replacement=False
)
data_loader = torch.utils.data.DataLoader(
self.dataset,
batch_size=config.parameter.behaviour_policy.batch_size,
shuffle=False,
sampler=sampler,
pin_memory=True,
drop_last=True,
num_workers=8,
)

counter = 1
behaviour_policy_loss_sum = 0
for data in data_loader:
for index, data in enumerate(replay_buffer):

behaviour_policy_loss = self.model[
"GPPolicy"
@@ -859,16 +856,18 @@ def policy(obs: np.ndarray) -> np.ndarray:
fake_actions = generate_fake_action(
self.model["GPPolicy"],
self.dataset.states[:].to(config.model.GPPolicy.device),
config.parameter.sample_per_state,
config.parameter.action_augment_num,
)
fake_next_actions = generate_fake_action(
self.model["GPPolicy"],
self.dataset.next_states[:].to(config.model.GPPolicy.device),
config.parameter.sample_per_state,
config.parameter.action_augment_num,
)

self.dataset.fake_actions = fake_actions.to("cpu")
self.dataset.fake_next_actions = fake_next_actions.to("cpu")
self.dataset.load_fake_actions(
fake_actions=fake_actions.to("cpu"),
fake_next_actions=fake_next_actions.to("cpu"),
)

# ---------------------------------------
# make fake action ↑
@@ -887,34 +886,29 @@ def policy(obs: np.ndarray) -> np.ndarray:
lr=config.parameter.critic.learning_rate,
)

replay_buffer=TensorDictReplayBuffer(
storage=self.dataset.storage,
batch_size=config.parameter.critic.batch_size,
sampler=SamplerWithoutReplacement(),
prefetch=10,
pin_memory=True,
)

critic_train_iter = 0
for epoch in track(
range(config.parameter.critic.epochs), description="Critic training"
):
if self.critic_train_epoch >= epoch:
continue

sampler = torch.utils.data.RandomSampler(
self.dataset, replacement=False
)
data_loader = torch.utils.data.DataLoader(
self.dataset,
batch_size=config.parameter.critic.batch_size,
shuffle=False,
sampler=sampler,
pin_memory=True,
drop_last=True,
num_workers=8,
)

counter = 1

v_loss_sum = 0.0
v_sum = 0.0
q_loss_sum = 0.0
q_sum = 0.0
q_target_sum = 0.0
for data in data_loader:
for index, data in enumerate(replay_buffer):

v_loss, next_v = self.model["GPPolicy"].critic.v_loss(
state=data["s"].to(config.model.GPPolicy.device),
@@ -1008,6 +1002,15 @@ def policy(obs: np.ndarray) -> np.ndarray:
)
guided_policy_train_iter = 0
beta = config.parameter.guided_policy.beta

replay_buffer=TensorDictReplayBuffer(
storage=self.dataset.storage,
batch_size=config.parameter.guided_policy.batch_size,
sampler=SamplerWithoutReplacement(),
prefetch=10,
pin_memory=True,
)

for epoch in track(
range(config.parameter.guided_policy.epochs),
description="Guided policy training",
Expand All @@ -1016,19 +1019,6 @@ def policy(obs: np.ndarray) -> np.ndarray:
if self.guided_policy_train_epoch >= epoch:
continue

sampler = torch.utils.data.RandomSampler(
self.dataset, replacement=False
)
data_loader = torch.utils.data.DataLoader(
self.dataset,
batch_size=config.parameter.guided_policy.batch_size,
shuffle=False,
sampler=sampler,
pin_memory=True,
drop_last=True,
num_workers=8,
)

counter = 1
guided_policy_loss_sum = 0.0
if config.parameter.algorithm_type == "GMPO":
Expand All @@ -1042,7 +1032,8 @@ def policy(obs: np.ndarray) -> np.ndarray:
energy_sum = 0.0
relative_energy_sum = 0.0
matching_loss_sum = 0.0
for data in data_loader:

for index, data in enumerate(replay_buffer):
if config.parameter.algorithm_type == "GMPO":
(
guided_policy_loss,
@@ -1107,7 +1098,7 @@ def policy(obs: np.ndarray) -> np.ndarray:
and config.parameter.t_span is not None
else None
),
batch_size=config.parameter.sample_per_state,
batch_size=config.parameter.action_augment_num,
)
fake_actions_ = torch.einsum("nbd->bnd", fake_actions_)
(
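The TensorDict-backed dataset class itself (GPD4RLTensorDictDataset, imported in the gmpo.py hunk above) lives in one of the five changed files whose diffs are not shown here. Purely as a hypothetical sketch of the interface the loaded hunks rely on (a .storage attribute for TensorDictReplayBuffer, .states / .next_states views, and load_fake_actions(...)), it could look roughly like the following; every implementation detail, and every key name except "s", is an assumption rather than the actual implementation.

import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage


class TensorDictDatasetSketch:
    """Hypothetical stand-in for the new TensorDict-backed dataset; not the
    actual GPD4RLTensorDictDataset implementation."""

    def __init__(self, obs, actions, rewards, next_obs, dones):
        self.data = TensorDict(
            {"s": obs, "a": actions, "r": rewards, "s_": next_obs, "d": dones},
            batch_size=[obs.shape[0]],
        )
        self._rebuild_storage()

    def _rebuild_storage(self):
        # Expose a Storage object so the algorithms can build a
        # TensorDictReplayBuffer directly from self.dataset.storage.
        n = self.data.shape[0]
        self.storage = LazyTensorStorage(max_size=n)
        self.storage.set(torch.arange(n), self.data)

    @property
    def states(self):
        return self.data["s"]

    @property
    def next_states(self):
        return self.data["s_"]

    def load_fake_actions(self, fake_actions, fake_next_actions):
        # Attach behaviour-policy action augmentations so later training
        # stages can sample them alongside the original transitions.
        self.data.set("fake_a", fake_actions)
        self.data.set("fake_a_", fake_next_actions)
        self._rebuild_storage()

Rebuilding the storage after load_fake_actions simply keeps the sketch short; the real class may well update the stored TensorDict in place instead.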
