Polish documents.
zjowowen committed Jun 21, 2024
1 parent 533b3e7 commit 6940902
Showing 4 changed files with 5 additions and 46 deletions.
2 changes: 1 addition & 1 deletion grl/algorithms/gmpg.py
@@ -9,7 +9,6 @@
 from rich.progress import track
 from tensordict import TensorDict
 
-import d4rl
 import wandb
 from grl.agents.gm import GPAgent
 
@@ -830,6 +829,7 @@ def policy(obs: np.ndarray) -> np.ndarray:
 evaluation_results[f"evaluation/return_min"] = return_min
 
 if isinstance(self.dataset, GPD4RLDataset):
+    import d4rl
     env_id = config.dataset.args.env_id
     evaluation_results[f"evaluation/return_mean_normalized"] = (
         d4rl.get_normalized_score(env_id, return_mean)
2 changes: 1 addition & 1 deletion grl/algorithms/gmpo.py
@@ -9,7 +9,6 @@
 from rich.progress import track
 from tensordict import TensorDict
 
-import d4rl
 import wandb
 from grl.agents.gm import GPAgent
 
@@ -741,6 +740,7 @@ def policy(obs: np.ndarray) -> np.ndarray:
 evaluation_results[f"evaluation/return_min"] = return_min
 
 if isinstance(self.dataset, GPD4RLDataset):
+    import d4rl
     env_id = config.dataset.args.env_id
     evaluation_results[f"evaluation/return_mean_normalized"] = (
         d4rl.get_normalized_score(env_id, return_mean)
2 changes: 1 addition & 1 deletion grl/algorithms/gp.py
@@ -11,7 +11,6 @@
 from torch.utils.data import DataLoader
 from torch.optim.lr_scheduler import CosineAnnealingLR
 
-import d4rl
 import wandb
 from grl.agents.gp import GPAgent
 
@@ -1439,6 +1438,7 @@ def policy(obs: np.ndarray) -> np.ndarray:
 ] = return_min
 
 if isinstance(self.dataset, GPD4RLDataset):
+    import d4rl
     env_id = config.dataset.args.env_id
     evaluation_results[
         f"evaluation/guidance_scale:[{guidance_scale}]/return_mean_normalized"
45 changes: 2 additions & 43 deletions grl/algorithms/srpo.py
@@ -1,15 +1,13 @@
 import copy
 import os
 from datetime import datetime
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Dict, Union
 
-import d4rl
-import gym
 import numpy as np
 import torch
 import torch.nn as nn
 from easydict import EasyDict
-from rich.progress import Progress, track
+from rich.progress import track
 from tensordict import TensorDict
 from torch.utils.data import DataLoader
 
@@ -388,45 +386,6 @@ def get_train_data(dataloader):
     while True:
         yield from dataloader
 
-def pallaral_simple_eval_policy(
-    policy_fn, env_name, seed, eval_episodes=20
-):
-    eval_envs = []
-    for i in range(eval_episodes):
-        env = gym.make(env_name)
-        eval_envs.append(env)
-        env.seed(seed + 1001 + i)
-        env.buffer_state = env.reset()
-        env.buffer_return = 0.0
-    ori_eval_envs = [env for env in eval_envs]
-    import time
-
-    t = time.time()
-    while len(eval_envs) > 0:
-        new_eval_envs = []
-        states = np.stack([env.buffer_state for env in eval_envs])
-        states = torch.Tensor(states).to("cuda")
-        with torch.no_grad():
-            actions = policy_fn(states).detach().cpu().numpy()
-        for i, env in enumerate(eval_envs):
-            state, reward, done, info = env.step(actions[i])
-            env.buffer_return += reward
-            env.buffer_state = state
-            if not done:
-                new_eval_envs.append(env)
-        eval_envs = new_eval_envs
-    for i in range(eval_episodes):
-        ori_eval_envs[i].buffer_return = d4rl.get_normalized_score(
-            env_name, ori_eval_envs[i].buffer_return
-        )
-    mean = np.mean(
-        [ori_eval_envs[i].buffer_return for i in range(eval_episodes)]
-    )
-    std = np.std(
-        [ori_eval_envs[i].buffer_return for i in range(eval_episodes)]
-    )
-    return mean, std
-
 def evaluate(policy_fn, train_iter):
     evaluation_results = dict()
 
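For reference, a minimal standalone sketch of the deferred-import pattern these hunks adopt: d4rl becomes an optional dependency that is loaded only on the code path that needs it, so the modules import cleanly in environments where D4RL is not installed. The helper name and signature below are hypothetical; only the call d4rl.get_normalized_score(env_id, return_mean) comes from the diff itself.

# Illustrative sketch only (hypothetical helper, not the repository's code):
# defer the d4rl import until a D4RL dataset is actually being evaluated.
from typing import Optional


def maybe_normalized_score(
    env_id: str, return_mean: float, is_d4rl_dataset: bool
) -> Optional[float]:
    """Return the D4RL-normalized score, or None for non-D4RL datasets."""
    if not is_d4rl_dataset:
        return None
    # Deferred import: d4rl is required only when this branch runs,
    # so importing the surrounding module never needs d4rl installed.
    import d4rl

    return d4rl.get_normalized_score(env_id, return_mean)

The guard plays the same role as the isinstance(self.dataset, GPD4RLDataset) check in the hunks above: datasets that are not D4RL-backed skip both the import and the normalization.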
