Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup of 0.1 bugs + tutorials #6

Merged
merged 20 commits into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,14 @@ repos:
rev: 5.12.0
hooks:
- id: isort
- repo: local
hooks:
- id: pyright
name: pyright
entry: pyright
language: node
pass_filenames: false
types: [python]
additional_dependencies: ["pyright"]
args:
- --project=pyproject.toml
# - repo: local
# hooks:
# - id: pyright
# name: pyright
# entry: pyright
# language: node
# pass_filenames: false
# types: [python]
# additional_dependencies: ["pyright"]
# args:
# - --project=pyproject.toml
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,15 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.

## Unreleased

- Added guided tutorial notebooks
- Added an option to customize the orchestrator and datastore ports
- Added ParallelEnvironment as a default export from envs
- Added a placeholder image for the web UI
- Updated the uvicorn dependency to require the [standard] option
- Fixed a breaking bug in ParallelEnv
- Dropped openCV as a requirement
- Fixed some type issues, ignore some spurious warnings

## v0.1.0 - 2024-01-17

### Added
Expand Down
2 changes: 1 addition & 1 deletion cogment_lab/actors/nn_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ async def act(self, observation: np.ndarray, rendered_frame=None) -> int:
[act_vals] = self.network(obs)
act_probs = F.softmax(act_vals / self.temperature, dim=0)

action = torch.multinomial(act_probs, 1).item()
action = int(torch.multinomial(act_probs, 1).item())

return action

Expand Down
4 changes: 3 additions & 1 deletion cogment_lab/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,14 +256,16 @@ async def impl(self, actor_session: ActorSession):

if event.observation:
observation = self.session_helper.get_observation(event)
if observation is None:
continue
logging.info(f"Got observation: {observation}")

if not observation.active:
action = None
elif not observation.alive:
action = None
else:
action = await self.act(observation.value, observation.rendered_frame)
action = await self.act(observation.value, observation.rendered_frame) # type: ignore
logging.info(f"Got action: {action} with action_space: {self.action_space.gym_space}")
cog_action = self.action_space.create_serialize(action)
actor_session.do_action(cog_action)
Expand Down
2 changes: 1 addition & 1 deletion cogment_lab/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@
# limitations under the License.

from cogment_lab.envs.gymnasium import GymEnvironment
from cogment_lab.envs.pettingzoo import AECEnvironment
from cogment_lab.envs.pettingzoo import AECEnvironment, ParallelEnvironment
18 changes: 9 additions & 9 deletions cogment_lab/envs/conversions/teacher.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def reset(self, seed: int | None = None, options: dict | None = None):
self.infos = {"gym": info, "teacher": info}
self._gym_action = None

def step(self, action):
def step(self, action: dict):
current_agent = self.agent_selection
next_agent = self._agent_selector.next()

Expand Down Expand Up @@ -110,10 +110,10 @@ def render(self):
N = 5

# Set the borders to red
img[:N, :, :] = [255, 0, 0] # Top border
img[-N:, :, :] = [255, 0, 0] # Bottom border
img[:, :N, :] = [255, 0, 0] # Left border
img[:, -N:, :] = [255, 0, 0] # Right border
img[:N, :, :] = [255, 0, 0] # type: ignore
img[-N:, :, :] = [255, 0, 0] # type: ignore
img[:, :N, :] = [255, 0, 0] # type: ignore
img[:, -N:, :] = [255, 0, 0] # type: ignore
return img

def close(self):
Expand Down Expand Up @@ -190,10 +190,10 @@ def render(self):
N = 5

# Set the borders to red
img[:N, :, :] = [255, 0, 0] # Top border
img[-N:, :, :] = [255, 0, 0] # Bottom border
img[:, :N, :] = [255, 0, 0] # Left border
img[:, -N:, :] = [255, 0, 0] # Right border
img[:N, :, :] = [255, 0, 0] # type: ignore
img[-N:, :, :] = [255, 0, 0] # type: ignore
img[:, :N, :] = [255, 0, 0] # type: ignore
img[:, -N:, :] = [255, 0, 0] # type: ignore
return img

def close(self):
Expand Down
4 changes: 3 additions & 1 deletion cogment_lab/envs/gymnasium.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,10 @@ def __init__(

if isinstance(self.env_id, Callable):
self.env_maker = self.env_id
elif isinstance(self.env_id, str):
self.env_maker = lambda **kwargs: gym.make(self.env_id, **kwargs) # type: ignore
else:
self.env_maker = lambda **kwargs: gym.make(self.env_id, **kwargs)
raise ValueError(f"env_id must be a string or a callable, got {self.env_id}")

if self.registration:
importlib.import_module(self.registration)
Expand Down
18 changes: 9 additions & 9 deletions cogment_lab/envs/pettingzoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from pettingzoo import AECEnv, ParallelEnv

from cogment_lab.core import CogmentEnv, State
from cogment_lab.generated.data_pb2 import Observation as PbObservation
from cogment_lab.generated.data_pb2 import Observation as PbObservation # type: ignore
from cogment_lab.session_helpers import EnvironmentSessionHelper
from cogment_lab.specs import AgentSpecs
from cogment_lab.utils import import_object
Expand All @@ -42,7 +42,7 @@ def __init__(
self,
env_path: str,
make_kwargs: dict | None = None,
reset_options: dict = None,
reset_options: dict | None = None,
render: bool = False,
reinitialize: bool = False,
dry: bool = False,
Expand Down Expand Up @@ -91,7 +91,7 @@ def __init__(
self.env: AECEnv = self.env_maker(**self.make_args)
self.agent_specs = self.create_agent_specs(self.env)
else:
self.env = None
self.env = AECEnv()
self.agent_specs = {}

self.initialized = False
Expand Down Expand Up @@ -221,14 +221,15 @@ async def end(self, state: State):

@staticmethod
def fill_observations_(
state: State, observations: dict[str, PbObservation], frame: np.ndarray
state: State, observations: dict[str, PbObservation], frame: np.ndarray | None
) -> dict[str, PbObservation]:
"""
Fill in any missing observations with the default observation. Mutates the observations dict.

Args:
state: The Cogment state.
observations: The observations dict.
frame: The rendered frame.

Returns:
The filled observations dict.
Expand Down Expand Up @@ -292,7 +293,7 @@ def __init__(
self,
env_path: str,
make_kwargs: dict | None = None,
reset_options: dict = None,
reset_options: dict | None = None,
render: bool = False,
reinitialize: bool = False,
dry: bool = False,
Expand Down Expand Up @@ -341,7 +342,7 @@ def __init__(
self.env: ParallelEnv = self.env_maker(**self.make_args)
self.agent_specs = self.create_agent_specs(self.env)
else:
self.env = None
self.env = ParallelEnv()
self.agent_specs = {}

self.initialized = False
Expand All @@ -365,10 +366,10 @@ async def initialize(self, state: State, environment_session: EnvironmentSession
state.env = self.env
state.agent_specs = self.agent_specs
elif self.initialized and not self.reinitialize:
state.env: ParallelEnv = self.env
state.env = self.env # type: ignore
state.agent_specs = self.agent_specs
elif self.reinitialize:
state.env: ParallelEnv = self.env_maker(**self.make_args)
state.env = self.env_maker(**self.make_args) # type: ignore
state.agent_specs = self.create_agent_specs(state.env)

self.initialized = True
Expand Down Expand Up @@ -429,7 +430,6 @@ async def step(self, state: State, action: dict[str, Any]):
"""
logging.info("Stepping environment")

state.env.step(action)
obs, rewards, terminated, truncated, info = state.env.step(action)

frame = state.env.render() if state.session_cfg.render else None
Expand Down
41 changes: 7 additions & 34 deletions cogment_lab/humans/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@
def image_to_msg(img: np.ndarray | None) -> str | None:
if img is None:
return None
img = Image.fromarray(img)
image = Image.fromarray(img)
img_byte_array = io.BytesIO()
img.save(img_byte_array, format="PNG") # type: ignore
image.save(img_byte_array, format="PNG") # type: ignore
base64_encoded_result_bytes = base64.b64encode(img_byte_array.getvalue())
base64_encoded_result_str = base64_encoded_result_bytes.decode("ascii")
return f"data:image/png;base64,{base64_encoded_result_str}"
Expand Down Expand Up @@ -157,11 +157,11 @@ def __init__(self, send_queue: asyncio.Queue, recv_queue: asyncio.Queue):
self.recv_queue = recv_queue

async def act(self, observation: Any, rendered_frame: np.ndarray | None = None) -> int:
logging.info(
f"Getting an action with {observation=}" + f" and {rendered_frame.shape=}"
if rendered_frame is not None
else "no frame"
)
# logging.info(
# f"Getting an action with {observation=}" + f" and {rendered_frame.shape=}"
# if rendered_frame is not None
# else "no frame"
# )
await self.send_queue.put(rendered_frame)
action = await self.recv_queue.get()
return action
Expand All @@ -187,30 +187,3 @@ async def run_cogment_actor(
signal_queue.put(True)

await serve


async def shutdown():
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
for task in tasks:
task.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
asyncio.get_event_loop().stop()


def signal_handler(sig, frame):
asyncio.create_task(shutdown())


async def main(app_port: int = 8000, cogment_port: int = 8999):
app_to_actor = asyncio.Queue()
actor_to_app = asyncio.Queue()
fastapi_task = asyncio.create_task(start_fastapi(port=app_port, send_queue=app_to_actor, recv_queue=actor_to_app))
cogment_task = asyncio.create_task(
run_cogment_actor(port=cogment_port, send_queue=actor_to_app, recv_queue=app_to_actor)
)

await asyncio.gather(fastapi_task, cogment_task)


if __name__ == "__main__":
asyncio.run(main())
3 changes: 2 additions & 1 deletion cogment_lab/humans/static/index.html

Large diffs are not rendered by default.

45 changes: 28 additions & 17 deletions cogment_lab/process_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@
import multiprocessing as mp
import os
from asyncio import Task
from collections.abc import Sequence
from collections.abc import Callable, Coroutine, Sequence
from multiprocessing import Process, Queue
from typing import Any, Callable, Coroutine
from typing import Any

import cogment
from cogment.control import Controller
from cogment.datastore import Datastore

from cogment_lab.actors.runner import actor_runner
from cogment_lab.core import BaseActor, BaseEnv
Expand All @@ -39,13 +41,6 @@
)


ORCHESTRATOR_ENDPOINT = "grpc://localhost:9000"
ENVIRONMENT_ENDPOINT = "grpc://localhost:9001"
RANDOM_AGENT_ENDPOINT = "grpc://localhost:9002"
HUMAN_AGENT_ENDPOINT = "grpc://localhost:8999"
DATASTORE_ENDPOINT = "grpc://localhost:9003"


AgentName = str
ImplName = str
TrialName = str
Expand All @@ -54,12 +49,17 @@
class Cogment:
"""Main Cogment class for managing experiments"""

controller: Controller
datastore: Datastore

def __init__(
self,
user_id: str = "cogment_lab",
torch_mode: bool = False,
log_dir: str | None = None,
mp_method: str | None = None,
orchestrator_port: int = 9000,
datastore_port: int = 9003,
):
"""Initializes the Cogment instance
Expand All @@ -79,9 +79,20 @@ def __init__(
self.envs: dict[ImplName, BaseEnv] = {}
self.actors: dict[ImplName, BaseActor] = {}

self.orchestrator_endpoint = f"grpc://localhost:{orchestrator_port}"
self.datastore_endpoint = f"grpc://localhost:{datastore_port}"

self.context = cogment.Context(cog_settings=cog_settings, user_id=user_id)
self.controller = self.context.get_controller(endpoint=cogment.Endpoint(ORCHESTRATOR_ENDPOINT))
self.datastore = self.context.get_datastore(endpoint=cogment.Endpoint(DATASTORE_ENDPOINT))
controller = self.context.get_controller(endpoint=cogment.Endpoint(self.orchestrator_endpoint))
datastore = self.context.get_datastore(endpoint=cogment.Endpoint(self.datastore_endpoint))

assert isinstance(
controller, Controller
), "self.controller is not an instance of Controller. Please report this."
assert isinstance(datastore, Datastore), "self.datastore is not an instance of Datastore. Please report this."

self.controller = controller
self.datastore = datastore

self.env_ports: dict[ImplName, int] = {}
self.actor_ports: dict[ImplName, int] = {}
Expand Down Expand Up @@ -121,7 +132,7 @@ def _add_process(

p = TorchProcess(target=target, args=args)
else:
p = self.mp_ctx.Process(target=target, args=args)
p = self.mp_ctx.Process(target=target, args=args) # type: ignore
p.start()
self.processes[name] = p

Expand All @@ -148,7 +159,7 @@ def run_env(
env_name: ImplName,
port: int = 9001,
log_file: str | None = None,
) -> Coroutine[bool]:
) -> Coroutine[None, None, bool]:
"""Given an environment, runs it in a subprocess
Args:
Expand Down Expand Up @@ -195,7 +206,7 @@ def run_actor(
actor_name: ImplName,
port: int = 9002,
log_file: str | None = None,
) -> Coroutine[bool]:
) -> Coroutine[None, None, bool]:
"""Given an actor, runs it
Args:
Expand Down Expand Up @@ -282,7 +293,7 @@ def run_web_ui(
html_override: str | None = None,
file_override: str | None = None,
jinja_parameters: dict[str, Any] | None = None,
) -> Coroutine[bool]:
) -> Coroutine[None, None, bool]:
"""Runs the human actor in a separate process
Args:
Expand Down Expand Up @@ -457,7 +468,7 @@ async def start_trial(
for agent_name, actor_impl in actor_impls.items()
]

env_config = data_pb2.EnvironmentConfig(**session_config)
env_config = data_pb2.EnvironmentConfig(**session_config) # type: ignore

trial_params = cogment.TrialParameters(
cog_settings,
Expand All @@ -466,7 +477,7 @@ async def start_trial(
environment_config=env_config,
actors=actor_params,
environment_implementation=env_name,
datalog_endpoint=DATASTORE_ENDPOINT,
datalog_endpoint=self.datastore_endpoint,
)

trial_id = await self.controller.start_trial(trial_id_requested=trial_name, trial_params=trial_params)
Expand Down
Loading
Loading