Polish Dataset device settings.
zjowowen committed Jul 23, 2024
1 parent eaea86c commit b9d9118
Showing 73 changed files with 228 additions and 2,667 deletions.
README.md: 2 changes (1 addition & 1 deletion)
@@ -106,7 +106,7 @@ from grl.utils.log import log
from grl_pipelines.diffusion_model.configurations.lunarlander_continuous_qgpo import config

def qgpo_pipeline(config):
- qgpo = QGPOAlgorithm(config, dataset=QGPOCustomizedDataset(numpy_data_path="./data.npz", device=config.train.device))
+ qgpo = QGPOAlgorithm(config, dataset=QGPOCustomizedDataset(numpy_data_path="./data.npz",))
qgpo.train()

agent = qgpo.deploy()
README.zh.md: 2 changes (1 addition & 1 deletion)
@@ -103,7 +103,7 @@ from grl.utils.log import log
from grl_pipelines.diffusion_model.configurations.lunarlander_continuous_qgpo import config

def qgpo_pipeline(config):
- qgpo = QGPOAlgorithm(config, dataset=QGPOCustomizedDataset(numpy_data_path="./data.npz", device=config.train.device))
+ qgpo = QGPOAlgorithm(config, dataset=QGPOCustomizedDataset(numpy_data_path="./data.npz",))
qgpo.train()

agent = qgpo.deploy()
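Both README snippets drop the `device=config.train.device` argument: after this commit the dataset is built on the CPU and the training pipeline decides when batches are moved to a device. The sketch below only illustrates that pattern with a hypothetical `NumpyCPUDataset`; the real `QGPOCustomizedDataset` lives in the `grl` package, and the `.npz` keys shown here ("obs", "action") are assumptions, not its actual schema.

```python
# Hypothetical stand-in for QGPOCustomizedDataset, shown only to illustrate the
# "keep data on the CPU, move batches later" pattern.
import numpy as np
import torch
from torch.utils.data import Dataset


class NumpyCPUDataset(Dataset):
    def __init__(self, numpy_data_path: str):
        data = np.load(numpy_data_path)
        # CPU tensors only; no device argument is needed at construction time.
        self.states = torch.from_numpy(data["obs"]).float()
        self.actions = torch.from_numpy(data["action"]).float()

    def __len__(self) -> int:
        return self.states.shape[0]

    def __getitem__(self, idx):
        return {"s": self.states[idx], "a": self.actions[idx]}
```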
grl/agents/gp.py: 92 changes (0 additions & 92 deletions)

This file was deleted.

grl/algorithms/gmpg.py: 37 changes (20 additions & 17 deletions)
@@ -878,8 +878,9 @@ def policy(obs: np.ndarray) -> np.ndarray:
batch_size=config.parameter.behaviour_policy.batch_size,
shuffle=False,
sampler=sampler,
- pin_memory=False,
+ pin_memory=True,
drop_last=True,
+ num_workers=8,
)

counter = 1
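The loader hunks in gmpg.py (and the matching ones in gmpo.py below) turn on pinned host memory and multi-process loading. Below is a minimal sketch of the same `DataLoader` settings; the toy `TensorDataset`, batch size, and sampler stand in for the project's config-driven values.

```python
# Minimal sketch of the DataLoader settings applied in these hunks.
import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

dataset = TensorDataset(torch.randn(1024, 17), torch.randn(1024, 6))  # toy data
sampler = RandomSampler(dataset, replacement=True, num_samples=4096)

data_loader = DataLoader(
    dataset,
    batch_size=256,
    shuffle=False,      # a sampler is passed, so shuffle must stay False
    sampler=sampler,
    pin_memory=True,    # page-locked host memory speeds up CPU-to-GPU copies
    drop_last=True,
    num_workers=8,      # worker processes load and collate batches in parallel
)
```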
@@ -889,8 +890,8 @@ def policy(obs: np.ndarray) -> np.ndarray:
behaviour_policy_loss = self.model[
"GPPolicy"
].behaviour_policy_loss(
- action=data["a"],
- state=data["s"],
+ action=data["a"].to(config.model.GPPolicy.device),
+ state=data["s"].to(config.model.GPPolicy.device),
maximum_likelihood=(
config.parameter.behaviour_policy.maximum_likelihood
if hasattr(
@@ -960,8 +961,9 @@ def policy(obs: np.ndarray) -> np.ndarray:
batch_size=config.parameter.critic.batch_size,
shuffle=False,
sampler=sampler,
- pin_memory=False,
+ pin_memory=True,
drop_last=True,
+ num_workers=8,
)

counter = 1
@@ -974,19 +976,19 @@ def policy(obs: np.ndarray) -> np.ndarray:
for data in data_loader:

v_loss, next_v = self.model["GPPolicy"].critic.v_loss(
- state=data["s"],
- action=data["a"],
- next_state=data["s_"],
+ state=data["s"].to(config.model.GPPolicy.device),
+ action=data["a"].to(config.model.GPPolicy.device),
+ next_state=data["s_"].to(config.model.GPPolicy.device),
tau=config.parameter.critic.tau,
)
v_optimizer.zero_grad(set_to_none=True)
v_loss.backward()
v_optimizer.step()
q_loss, q, q_target = self.model["GPPolicy"].critic.iql_q_loss(
- state=data["s"],
- action=data["a"],
- reward=data["r"],
- done=data["d"],
+ state=data["s"].to(config.model.GPPolicy.device),
+ action=data["a"].to(config.model.GPPolicy.device),
+ reward=data["r"].to(config.model.GPPolicy.device),
+ done=data["d"].to(config.model.GPPolicy.device),
next_v=next_v,
discount=config.parameter.critic.discount_factor,
)
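The loss hunks all follow one pattern: batches come off the loader as (pinned) CPU tensors, and each tensor is moved to the policy's device right before the loss call. A small self-contained sketch of that transfer, where `policy_device` stands in for `config.model.GPPolicy.device` and the batch is a toy:

```python
# Per-batch device transfer, as in the behaviour-policy, critic, and guided-policy
# loss hunks above; policy_device stands in for config.model.GPPolicy.device.
import torch

policy_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

data = {
    "s": torch.randn(8, 17),   # state
    "a": torch.randn(8, 6),    # action
    "r": torch.randn(8, 1),    # reward
    "d": torch.zeros(8, 1),    # done flag
}

# Move each tensor explicitly, exactly as the diff does with .to(...) per argument.
batch = {key: value.to(policy_device) for key, value in data.items()}
print({key: value.device for key, value in batch.items()})
```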
@@ -1078,8 +1080,9 @@ def policy(obs: np.ndarray) -> np.ndarray:
batch_size=config.parameter.guided_policy.batch_size,
shuffle=False,
sampler=sampler,
- pin_memory=False,
+ pin_memory=True,
drop_last=True,
+ num_workers=8,
)

counter = 1
@@ -1092,7 +1095,7 @@ def policy(obs: np.ndarray) -> np.ndarray:
log_p_loss,
log_u_loss,
) = self.model["GPPolicy"].policy_gradient_loss(
- data["s"],
+ data["s"].to(config.model.GPPolicy.device),
gradtime_step=config.parameter.guided_policy.gradtime_step,
beta=beta,
repeats=(
@@ -1108,7 +1111,7 @@ def policy(obs: np.ndarray) -> np.ndarray:
log_p_loss,
log_u_loss,
) = self.model["GPPolicy"].policy_gradient_loss_by_REINFORCE(
- data["s"],
+ data["s"].to(config.model.GPPolicy.device),
gradtime_step=config.parameter.guided_policy.gradtime_step,
beta=beta,
repeats=(
@@ -1133,7 +1136,7 @@ def policy(obs: np.ndarray) -> np.ndarray:
) = self.model[
"GPPolicy"
].policy_gradient_loss_by_REINFORCE_softmax(
- data["s"],
+ data["s"].to(config.model.GPPolicy.device),
gradtime_step=config.parameter.guided_policy.gradtime_step,
beta=beta,
repeats=(
@@ -1146,8 +1149,8 @@ def policy(obs: np.ndarray) -> np.ndarray:
guided_policy_loss = self.model[
"GPPolicy"
].policy_gradient_loss_add_matching_loss(
- data["a"],
- data["s"],
+ data["a"].to(config.model.GPPolicy.device),
+ data["s"].to(config.model.GPPolicy.device),
maximum_likelihood=(
config.parameter.guided_policy.maximum_likelihood
if hasattr(
grl/algorithms/gmpo.py: 49 changes (26 additions & 23 deletions)
@@ -789,8 +789,9 @@ def policy(obs: np.ndarray) -> np.ndarray:
batch_size=config.parameter.behaviour_policy.batch_size,
shuffle=False,
sampler=sampler,
- pin_memory=False,
+ pin_memory=True,
drop_last=True,
+ num_workers=8,
)

counter = 1
@@ -800,8 +801,8 @@ def policy(obs: np.ndarray) -> np.ndarray:
behaviour_policy_loss = self.model[
"GPPolicy"
].behaviour_policy_loss(
- action=data["a"],
- state=data["s"],
+ action=data["a"].to(config.model.GPPolicy.device),
+ state=data["s"].to(config.model.GPPolicy.device),
maximum_likelihood=(
config.parameter.behaviour_policy.maximum_likelihood
if hasattr(
@@ -857,17 +858,17 @@ def policy(obs: np.ndarray) -> np.ndarray:

fake_actions = generate_fake_action(
self.model["GPPolicy"],
- self.dataset.states[:],
+ self.dataset.states[:].to(config.model.GPPolicy.device),
config.parameter.sample_per_state,
)
fake_next_actions = generate_fake_action(
self.model["GPPolicy"],
- self.dataset.next_states[:],
+ self.dataset.next_states[:].to(config.model.GPPolicy.device),
config.parameter.sample_per_state,
)

- self.dataset.fake_actions = fake_actions
- self.dataset.fake_next_actions = fake_next_actions
+ self.dataset.fake_actions = fake_actions.to("cpu")
+ self.dataset.fake_next_actions = fake_next_actions.to("cpu")

# ---------------------------------------
# make fake action ↑
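In gmpo.py the fake actions are now sampled with the states moved onto the policy device, and the results are pushed back to the CPU before being stored on the dataset, so the dataset itself stays host-resident. A hedged sketch of that round trip, with `sample_fake_actions` as a hypothetical placeholder for the `GPPolicy` sampling call:

```python
# "Generate on the device, store on the CPU" round trip, as in the hunk above.
# sample_fake_actions is a placeholder, not the real GPPolicy API.
import torch

policy_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def sample_fake_actions(states: torch.Tensor, samples_per_state: int) -> torch.Tensor:
    # Placeholder: returns random actions on whatever device the states are on.
    return torch.randn(states.shape[0], samples_per_state, 6, device=states.device)


states_cpu = torch.randn(128, 17)                     # dataset tensors stay on the CPU
fake_actions = sample_fake_actions(states_cpu.to(policy_device), 16)
fake_actions = fake_actions.to("cpu")                 # stored back alongside the dataset
```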
@@ -901,8 +902,9 @@ def policy(obs: np.ndarray) -> np.ndarray:
batch_size=config.parameter.critic.batch_size,
shuffle=False,
sampler=sampler,
- pin_memory=False,
+ pin_memory=True,
drop_last=True,
+ num_workers=8,
)

counter = 1
@@ -915,19 +917,19 @@ def policy(obs: np.ndarray) -> np.ndarray:
for data in data_loader:

v_loss, next_v = self.model["GPPolicy"].critic.v_loss(
- state=data["s"],
- action=data["a"],
- next_state=data["s_"],
+ state=data["s"].to(config.model.GPPolicy.device),
+ action=data["a"].to(config.model.GPPolicy.device),
+ next_state=data["s_"].to(config.model.GPPolicy.device),
tau=config.parameter.critic.tau,
)
v_optimizer.zero_grad(set_to_none=True)
v_loss.backward()
v_optimizer.step()
q_loss, q, q_target = self.model["GPPolicy"].critic.iql_q_loss(
- state=data["s"],
- action=data["a"],
- reward=data["r"],
- done=data["d"],
+ state=data["s"].to(config.model.GPPolicy.device),
+ action=data["a"].to(config.model.GPPolicy.device),
+ reward=data["r"].to(config.model.GPPolicy.device),
+ done=data["d"].to(config.model.GPPolicy.device),
next_v=next_v,
discount=config.parameter.critic.discount_factor,
)
@@ -1022,8 +1024,9 @@ def policy(obs: np.ndarray) -> np.ndarray:
batch_size=config.parameter.guided_policy.batch_size,
shuffle=False,
sampler=sampler,
- pin_memory=False,
+ pin_memory=True,
drop_last=True,
+ num_workers=8,
)

counter = 1
Expand All @@ -1049,8 +1052,8 @@ def policy(obs: np.ndarray) -> np.ndarray:
) = self.model[
"GPPolicy"
].policy_optimization_loss_by_advantage_weighted_regression(
- data["a"],
- data["s"],
+ data["a"].to(config.model.GPPolicy.device),
+ data["s"].to(config.model.GPPolicy.device),
maximum_likelihood=(
config.parameter.guided_policy.maximum_likelihood
if hasattr(
@@ -1079,8 +1082,8 @@ def policy(obs: np.ndarray) -> np.ndarray:
) = self.model[
"GPPolicy"
].policy_optimization_loss_by_advantage_weighted_regression_softmax(
- data["s"],
- data["fake_a"],
+ data["s"].to(config.model.GPPolicy.device),
+ data["fake_a"].to(config.model.GPPolicy.device),
maximum_likelihood=(
config.parameter.guided_policy.maximum_likelihood
if hasattr(
@@ -1095,10 +1098,10 @@ def policy(obs: np.ndarray) -> np.ndarray:
matching_loss_sum += matching_loss
elif config.parameter.algorithm_type == "GMPO_softmax_sample":
fake_actions_ = self.model["GPPolicy"].behaviour_policy_sample(
- state=data["s"],
+ state=data["s"].to(config.model.GPPolicy.device),
t_span=(
torch.linspace(0.0, 1.0, config.parameter.t_span).to(
- data["s"].device
+ config.model.GPPolicy.device
)
if hasattr(config.parameter, "t_span")
and config.parameter.t_span is not None
@@ -1115,7 +1118,7 @@ def policy(obs: np.ndarray) -> np.ndarray:
) = self.model[
"GPPolicy"
].policy_optimization_loss_by_advantage_weighted_regression_softmax(
- data["s"],
+ data["s"].to(config.model.GPPolicy.device),
fake_actions_,
maximum_likelihood=(
config.parameter.guided_policy.maximum_likelihood
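One further detail in the GMPO_softmax_sample branch: the integration time grid `t_span` is now created directly on the policy device rather than inheriting the device of `data["s"]`, which after this commit stays a CPU tensor until it is explicitly moved. A one-line sketch, with `32` standing in for `config.parameter.t_span`:

```python
# t_span built directly on the policy device; 32 stands in for config.parameter.t_span.
import torch

policy_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
t_span = torch.linspace(0.0, 1.0, 32).to(policy_device)
```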
(Diffs for the remaining changed files are not shown.)
