Skip to content

Commit

Permalink
Type safe ECS queries (#32)
Browse files Browse the repository at this point in the history
* Add macro to (almost) make data querying safe

* Make `query_iter!` macro memory safe

* Add documentation to helper functions

* Optionally get handle to entity with query

* Make clippy happy

* Fix weird mutability
  • Loading branch information
foodelevator authored May 5, 2022
1 parent 0560eea commit 9e5b8ad
Show file tree
Hide file tree
Showing 4 changed files with 195 additions and 57 deletions.
5 changes: 5 additions & 0 deletions ecs/src/entity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,11 @@ impl<'e> Iter<'e> {
unused_ids: entities.unused_ids.iter().copied().collect(),
}
}

/// Get the iter's entities.
pub fn entities(&self) -> &Entities {
self.entities
}
}

impl<'e> Iterator for Iter<'e> {
Expand Down
58 changes: 55 additions & 3 deletions ecs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
pub mod component;
mod entity;
mod error;
#[macro_use]
mod query;
mod world;

pub use entity::{Entities, Entity};
pub use error::BorrowMutError;
pub use query::{ComponentQuery, Query};
pub use query::{as_mut_lt, as_ref_lt, ComponentQuery, Query};
pub use world::World;

#[cfg(test)]
Expand Down Expand Up @@ -508,6 +509,57 @@ mod tests {
mem::drop(r);
}

#[test]
fn type_safe_macros() {
let mut world = World::default();
struct Name(String);
struct Speed(f32);
let sanic = world.spawn();
world.add(sanic, Name("Sanic".into()));
world.add(sanic, Speed(100.0));
let mario = world.spawn();
world.add(mario, Name("Mario".into()));
world.add(mario, Speed(200.0)); // copilot thinks mario is faster than sanic

query_iter!(world, (name: Name, speed: mut Speed) => {
match name.0.as_ref() {
"Mario" => assert_eq!(speed.0, 200.0),
"Sanic" => {
assert_eq!(speed.0, 100.0);
speed.0 = 300.0; // copilot thinks he's faster than mario
}
_ => panic!("Unexpected name"),
}
});

query_iter!(world, (entity: Entity, name: Name, speed: Speed) => {
match name.0.as_ref() {
"Mario" => {
assert_eq!(entity, mario);
assert_eq!(speed.0, 200.0)
}
"Sanic" => {
assert_eq!(entity, sanic);
assert_eq!(speed.0, 300.0);
}
_ => panic!("Unexpected name"),
}
});

let mut found_sanic = false;
let mut found_mario = false;
query_iter!(world, (entity: Entity) => {
if found_sanic {
assert_eq!(entity, mario);
found_mario = true;
} else {
assert_eq!(entity, sanic);
found_sanic = true;
}
});
assert!(found_sanic && found_mario);
}

#[test]
fn iterate_over_query() {
let mut world = World::default();
Expand Down Expand Up @@ -535,7 +587,7 @@ mod tests {
.unwrap();
let mut q = world.query(&q);
for (pos, vel) in unsafe {
q.iter().map(|comps| {
q.iter().map(|(_e, comps)| {
if let [pos, vel] = comps[..] {
(
pos.cast::<Position>().as_mut(),
Expand All @@ -558,7 +610,7 @@ mod tests {
let mut q = world.query(&q);
for (i, pos) in unsafe {
q.iter()
.map(|comps| {
.map(|(_e, comps)| {
if let [pos] = comps[..] {
pos.cast::<Position>().as_ref()
} else {
Expand Down
182 changes: 129 additions & 53 deletions ecs/src/query.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,111 @@
use std::{collections::HashSet, marker::PhantomData, ptr::NonNull};
use std::{collections::HashSet, ptr::NonNull};

use crate::{
component::{ComponentEntryRef, ComponentId, ComponentRegistry},
BorrowMutError, Entity,
component::{ComponentEntryRef, ComponentId},
entity::Iter as EntityIter,
BorrowMutError, Entity, World,
};

/// Casts `ptr` to a reference with the lifetime `'a`.
/// # Safety
/// It is the responsibility of the caller to ensure that the lifetime `'a` outlives
/// the lifetime of the data pointed to by `ptr`.
#[allow(clippy::needless_lifetimes)]
pub unsafe fn as_ref_lt<'a, T>(_lifetime: &'a (), ptr: NonNull<T>) -> &'a T {
ptr.as_ref()
}

/// Casts `ptr` to a mutable reference with the lifetime `'a`.
/// # Safety
/// It is the responsibility of the caller to ensure that the lifetime `'a` outlives
/// the lifetime of the data pointed to by `ptr`.
#[allow(clippy::mut_from_ref, clippy::needless_lifetimes)]
pub unsafe fn as_mut_lt<'a, T>(_lifetime: &'a (), mut ptr: NonNull<T>) -> &'a mut T {
ptr.as_mut()
}

#[macro_export]
macro_rules! _query_definition {
( $world:expr, $vec:expr, ($name:ident: Entity, $($tail:tt)*) ) => {{
_query_definition!($world, $vec, ($($tail)*));
}};
( $world:expr, $vec:expr, ($name:ident: $type:ty, $($tail:tt)*) ) => {{
$vec.push(ComponentQuery {
id: $world.component_id::<$type>().unwrap(),
mutable: false,
});
_query_definition!($world, $vec, ($($tail)*));
}};
( $world:expr, $vec:expr, ($name:ident: mut $type:ty, $($tail:tt)*) ) => {{
$vec.push(ComponentQuery {
id: $world.component_id::<$type>().unwrap(),
mutable: true,
});
_query_definition!($world, $vec, ($($tail)*));
}};

// Last entry
( $world:expr, $vec:expr, ($name:ident: Entity) ) => { };
( $world:expr, $vec:expr, ($name:ident: $type:ty) ) => {{
$vec.push(ComponentQuery {
id: $world.component_id::<$type>().unwrap(),
mutable: false,
});
}};
( $world:expr, $vec:expr, ($name:ident: mut $type:ty) ) => {{
$vec.push(ComponentQuery {
id: $world.component_id::<$type>().unwrap(),
mutable: true,
});
}};
}

#[macro_export]
macro_rules! _query_defvars {
( $comps:expr, $lt:expr, $entity:expr, ($name:ident: Entity, $($tail:tt)*) ) => {
let $name = $entity;
_query_defvars!($comps[..], $lt, $entity, ($($tail)*));
};
( $comps:expr, $lt:expr, $entity:expr, ($name:ident: $type:ty, $($tail:tt)*) ) => {
let $name = unsafe { $crate::query::as_ref_lt($lt, $comps[0].cast::<$type>()) };
_query_defvars!($comps[1..], $lt, $entity, ($($tail)*));
};
( $comps:expr, $lt:expr, $entity:expr, ($name:ident: mut $type:ty, $($tail:tt)*) ) => {
let $name = unsafe { $crate::query::as_mut_lt($lt, $comps[0].cast::<$type>()) };
_query_defvars!($comps[1..], $lt, $entity, ($($tail)*));
};

// Last entry
( $comps:expr, $lt:expr, $entity:expr, ($name:ident: Entity) ) => {
let $name = $entity;
};
( $comps:expr, $lt:expr, $entity:expr, ($name:ident: $type:ty) ) => {
let $name = unsafe { $crate::query::as_ref_lt($lt, $comps[0].cast::<$type>()) };
};
( $comps:expr, $lt:expr, $entity:expr, ($name:ident: mut $type:ty) ) => {
let $name = unsafe { $crate::query::as_mut_lt($lt, $comps[0].cast::<$type>()) };
};
}

#[macro_export]
macro_rules! query_iter {
( $world:expr, ($($query:tt)*) => $body:block ) => {{
#[allow(unused_mut)]
let mut v = vec![];
_query_definition!($world, v, ($($query)*));
let q = Query::new(v).expect("Query violates rusts borrow rules");

let mut res = $world.query(&q);

#[allow(unused_variables)]
for (e, comps) in unsafe { res.iter() } {
let lt = ();
$crate::_query_defvars!(comps, &lt, e, ($($query)*));
$body
}
}};
}

/// Represents a valid query for components without multiple mutable access to the same type of
/// component.
/// NOTE: there's currently no way of for example having one query for `mut A` on entities with a
Expand All @@ -19,7 +120,6 @@ pub struct Query {
pub struct ComponentQuery {
pub id: ComponentId,
pub mutable: bool,
// TODO: add optional queries: `optional: bool,`
}

impl Query {
Expand Down Expand Up @@ -50,21 +150,17 @@ impl Query {
}

#[derive(Debug)]
pub struct QueryResponse<'r, 'q> {
_world_marker: PhantomData<&'r ComponentRegistry>,
pub struct QueryResponse<'w, 'q> {
world: &'w World,
entries: Vec<ComponentEntryRef>,
query: &'q Query,
}

impl<'r, 'q> QueryResponse<'r, 'q> {
pub(crate) fn new(
_registry: &'r ComponentRegistry,
query: &'q Query,
entries: Vec<ComponentEntryRef>,
) -> Self {
impl<'w, 'q> QueryResponse<'w, 'q> {
pub(crate) fn new(world: &'w World, query: &'q Query, entries: Vec<ComponentEntryRef>) -> Self {
debug_assert!(query.components().len() == entries.len());
Self {
_world_marker: PhantomData,
world,
entries,
query,
}
Expand All @@ -91,61 +187,41 @@ impl<'r, 'q> QueryResponse<'r, 'q> {

unsafe fn try_get_by_index(&mut self, index: u32) -> Option<Vec<NonNull<u8>>> {
let mut res = Vec::with_capacity(self.entries.len());
for (e, cq) in self.entries.iter_mut().zip(self.query.components().iter()) {
res.push(if cq.mutable {
NonNull::new(e.get_mut().storage.get_mut_ptr(index as usize))?
} else {
NonNull::new(e.get().storage.get_ptr(index as usize) as *mut _)?
});
for (e, _) in self.entries.iter().zip(self.query.components().iter()) {
res.push(NonNull::new(
e.get().storage.get_ptr(index as usize) as *mut _
)?);
}
Some(res)
}

/// Returns the last index of an entity that has at least one component in the query. There
/// might not actually be a hit for this query at this index, but there is definitly no hits
/// after this index.
fn last_index_worth_checking(&self) -> Option<u32> {
self.entries
.iter()
.flat_map(|e| e.get().storage.last_set_index())
.max()
.map(|max| max as u32)
}

pub unsafe fn iter<'a>(&'a mut self) -> Iter<'a, 'r, 'q> {
Iter::new(self, self.last_index_worth_checking())
pub unsafe fn iter<'a>(&'a mut self) -> Iter<'a, 'w, 'q> {
Iter::new(self)
}
}

pub struct Iter<'a, 'r, 'q> {
index: u32,
last: Option<u32>,
res: &'a mut QueryResponse<'r, 'q>,
pub struct Iter<'a, 'w, 'q> {
res: &'a mut QueryResponse<'w, 'q>,
entity_iter: EntityIter<'w>,
}

impl<'a, 'r, 'q> Iter<'a, 'r, 'q> {
pub fn new(res: &'a mut QueryResponse<'r, 'q>, last: Option<u32>) -> Self {
Self {
index: 0,
last,
res,
}
impl<'a, 'w, 'q> Iter<'a, 'w, 'q> {
pub fn new(res: &'a mut QueryResponse<'w, 'q>) -> Self {
let mut entity_iter = res.world.entities().iter();
entity_iter.next(); // Skip the resource entity.
Self { res, entity_iter }
}
}

// TODO: for sparse components this could be optimized
impl<'a, 'r, 'q> Iterator for Iter<'a, 'r, 'q> {
type Item = Vec<NonNull<u8>>;
type Item = (Entity, Vec<NonNull<u8>>);

fn next(&mut self) -> Option<Self::Item> {
while self.index < self.last? {
let index = self.index;
self.index += 1;
let res = unsafe { self.res.try_get_by_index(index) };
if res.is_some() {
return res;
}
}
None
self.entity_iter.next().and_then(|e| unsafe {
self.res
.try_get_by_index(self.entity_iter.entities().id(e).unwrap())
.map(|comps| (e, comps))
})
}
}
7 changes: 6 additions & 1 deletion ecs/src/world.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ impl World {
None => return Err(BorrowMutError::new(c.id)),
}
}
Ok(QueryResponse::new(&self.component_registry, query, entries))
Ok(QueryResponse::new(self, query, entries))
}

/// Tries to query for a set of components. If thats not possible (see `try_query`) this
Expand Down Expand Up @@ -147,4 +147,9 @@ impl World {
.get_mut(id as usize)
})
}

/// Get a reference to the world's entities.
pub fn entities(&self) -> &Entities {
&self.entities
}
}

0 comments on commit 9e5b8ad

Please sign in to comment.