Skip to content

Commit

Permalink
Add an optional progress bar in data collection (#14)
Browse files Browse the repository at this point in the history
tqdm-based progress bar, defaults to disabled, but can be useful for
interactive usage
  • Loading branch information
RedTachyon authored Feb 2, 2024
1 parent 2daf574 commit 74f9a03
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 13 deletions.
3 changes: 1 addition & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.1.0/)
and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).

## Unreleased

### Fixed
- jinja2 is now correctly a dependency
- Added an optional progress bar to data collection

## v0.1.1 - 2024-01-22

Expand Down
11 changes: 10 additions & 1 deletion cogment_lab/process_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,14 +500,23 @@ async def get_trial_data(
"next_observations",
"last_observation",
),
use_tqdm: bool = False,
tqdm_kwargs: dict[str, Any] | None = None,
) -> dict[str, TrialData]:
"""Gets trial data from the datastore, formatting it appropriately."""
if env_name is None:
env_name = self.trial_envs[trial_id]
env = self.envs[env_name]
agent_specs = env.agent_specs

data = await format_data_multiagent(self.datastore, trial_id, agent_specs, fields)
data = await format_data_multiagent(
datastore=self.datastore,
trial_id=trial_id,
actor_agent_specs=agent_specs,
fields=fields,
use_tqdm=use_tqdm,
tqdm_kwargs=tqdm_kwargs,
)

return data

Expand Down
16 changes: 8 additions & 8 deletions cogment_lab/utils/trial_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import numpy as np
from cogment import ActorParameters
from cogment.datastore import Datastore, DatastoreSample
from tqdm.auto import tqdm

from cogment_lab.generated import cog_settings
from cogment_lab.specs import AgentSpecs
Expand Down Expand Up @@ -302,6 +303,8 @@ async def format_data_multiagent(
"next_observations",
"last_observation",
),
use_tqdm: bool = False,
tqdm_kwargs: dict[str, Any] | None = None,
) -> dict[str, TrialData]:
"""
Formats trial data from a multiagent Cogment trial into structured formats for reinforcement learning.
Expand All @@ -311,10 +314,14 @@ async def format_data_multiagent(
trial_id (str): The identifier of the trial.
actor_agent_specs (dict[str, EnvironmentSpecs]): A dictionary mapping actor IDs to their environment specifications.
fields (List[str]): The list of fields to include in the formatted data.
tqdm_kwargs (dict[str, Any] | None): Optional keyword arguments to pass to tqdm.
Returns:
dict[str, TrialData]: A dictionary mapping actor IDs to their formatted trial data.
"""
if tqdm_kwargs is None:
tqdm_kwargs = {}

trials = []
while len(trials) == 0:
try:
Expand All @@ -327,14 +334,7 @@ async def format_data_multiagent(
actor_reward_samples = {actor_id: [] for actor_id in actor_agent_specs.keys()}

# Get all samples
all_samples = []
async for sample in datastore.all_samples(trials): # type: ignore
all_samples.append(sample)

# Sort according to tick_id -- this might not be necessary with some version of cogment
all_samples.sort(key=lambda x: x.tick_id)

for sample in all_samples:
async for sample in tqdm(datastore.all_samples(trials), disable=not use_tqdm, **tqdm_kwargs): # type: ignore
for actor_id in actor_agent_specs.keys():
# Add the sample to the list for an actor if the observation for that actor is not None
if (
Expand Down
2 changes: 0 additions & 2 deletions examples/demos/active-lunar/lunar-base.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,6 @@ async def main():

for episode in (pbar := trange(args.num_episodes)):
actor.set_eps(get_current_eps(episode))
if episode == args.human_episodes:
cog.stop_service("lunar")

trial_id = await cog.start_trial(
env_name="lunar",
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ dependencies = [
"numpy",
"fastapi>=0.103",
"pillow>=9.0",
"tqdm",
"jinja2>=3.1"
]

Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ starlette>=0.21.0
uvicorn[standard]==0.17.6
fastapi>=0.103.2
pillow>=9.0
tqdm

# environments
Gymnasium~=0.29
Expand Down

0 comments on commit 74f9a03

Please sign in to comment.