Skip to content

Commit

Permalink
Add camera follow entity feature
Browse files Browse the repository at this point in the history
Fixes Genesis-Embodied-AI#538

Add functionality for the camera to follow an entity.

* **Camera Class (`genesis/vis/camera.py`)**
  - Add `_follow_entity`, `_follow_height`, and `_follow_smoothing` attributes.
  - Add `follow_entity` method to set the camera to follow a specified entity.
  - Update `render` method to call `follow_entity` if enabled.

* **Scene Class (`genesis/engine/scene.py`)**
  - Add `set_camera_follow_entity` method to set the camera to follow a specified entity.
  - Update `step` method to call `set_camera_follow_entity` if enabled.

* **Drone Example (`examples/drone/interactive_drone.py`)**
  - Add `--follow` argument to enable camera follow.
  - Update `main` function to use `follow_entity` method if `--follow` argument is provided.
  • Loading branch information
Likhithsai2580 committed Jan 13, 2025
1 parent 10a1078 commit c4bd73c
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 0 deletions.
11 changes: 11 additions & 0 deletions examples/drone/interactive_drone.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def main():
parser = argparse.ArgumentParser()
parser.add_argument("-v", "--vis", action="store_true", default=True, help="Enable visualization (default: True)")
parser.add_argument("-m", "--mac", action="store_true", default=False, help="Running on MacOS (default: False)")
parser.add_argument("-f", "--follow", action="store_true", default=False, help="Enable camera follow (default: False)")
args = parser.parse_args()

# Initialize Genesis
Expand Down Expand Up @@ -184,6 +185,16 @@ def main():
listener = keyboard.Listener(on_press=controller.on_press, on_release=controller.on_release)
listener.start()

if args.follow:
camera = scene.add_camera(
res=(640, 480),
pos=(0.0, -4.0, 2.0),
lookat=(0.0, 0.0, 0.5),
fov=45,
GUI=True,
)
scene.set_camera_follow_entity(camera, drone, height=2.0, smoothing=0.1)

if args.mac:
# Run simulation in another thread
sim_thread = threading.Thread(target=run_sim, args=(scene, drone, controller))
Expand Down
22 changes: 22 additions & 0 deletions genesis/engine/scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -729,6 +729,10 @@ def step(self, update_visualizer=True):
if self._show_FPS:
self.FPS_tracker.step()

# Update camera to follow entity if enabled
if hasattr(self, "_camera_follow_entity") and self._camera_follow_entity is not None:
self._camera_follow_entity.follow_entity()

def _step_grad(self):
self._sim.collect_output_grads()
self._sim._step_grad()
Expand Down Expand Up @@ -913,6 +917,24 @@ def _backward(self):
self._backward_ready = False
self._forward_ready = False

def set_camera_follow_entity(self, camera, entity, height=None, smoothing=None):
"""
Set the camera to follow a specified entity.
Parameters
----------
camera : genesis.Camera
The camera to follow the entity.
entity : genesis.Entity
The entity to follow.
height : float, optional
The height at which the camera should follow the entity. If None, the camera will maintain its current height.
smoothing : float, optional
The smoothing factor for the camera's movement. If None, no smoothing will be applied.
"""
camera.follow_entity(entity, height, smoothing)
self._camera_follow_entity = camera

# ------------------------------------------------------------------------------------
# ----------------------------------- properties -------------------------------------
# ------------------------------------------------------------------------------------
Expand Down
47 changes: 47 additions & 0 deletions genesis/vis/camera.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ def __init__(
self._in_recording = False
self._recorded_imgs = []

self._follow_entity = None
self._follow_height = None
self._follow_smoothing = None

if self._model not in ["pinhole", "thinlens"]:
gs.raise_exception(f"Invalid camera model: {self._model}")

Expand Down Expand Up @@ -146,6 +150,9 @@ def render(self, rgb=True, depth=False, segmentation=False, colorize_seg=False,

rgb_arr, depth_arr, seg_idxc_arr, seg_arr, normal_arr = None, None, None, None, None

if self._follow_entity is not None:
self.follow_entity()

if self._raytracer is not None:
if rgb:
self._raytracer.update_scene()
Expand Down Expand Up @@ -279,6 +286,46 @@ def set_params(self, fov=None, aperture=None, focus_dist=None):
if self._raytracer is not None:
self._raytracer.update_camera(self)

@gs.assert_built
def follow_entity(self, entity, height=None, smoothing=None):
"""
Set the camera to follow a specified entity.
Parameters
----------
entity : genesis.Entity
The entity to follow.
height : float, optional
The height at which the camera should follow the entity. If None, the camera will maintain its current height.
smoothing : float, optional
The smoothing factor for the camera's movement. If None, no smoothing will be applied.
"""
self._follow_entity = entity
self._follow_height = height
self._follow_smoothing = smoothing

def follow_entity(self):
"""
Update the camera position to follow the specified entity.
"""
if self._follow_entity is None:
return

entity_pos = self._follow_entity.get_pos()
camera_pos = np.array(self._pos)

if self._follow_height is not None:
camera_pos[2] = self._follow_height

if self._follow_smoothing is not None:
camera_pos[:2] = (
self._follow_smoothing * camera_pos[:2] + (1 - self._follow_smoothing) * entity_pos[:2]
)
else:
camera_pos[:2] = entity_pos[:2]

self.set_pose(pos=camera_pos, lookat=entity_pos)

@gs.assert_built
def start_recording(self):
"""
Expand Down

0 comments on commit c4bd73c

Please sign in to comment.