Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type safe ECS queries #32

Merged
merged 7 commits into from
May 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions ecs/src/entity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,11 @@ impl<'e> Iter<'e> {
unused_ids: entities.unused_ids.iter().copied().collect(),
}
}

/// Get the iter's entities.
pub fn entities(&self) -> &Entities {
self.entities
}
}

impl<'e> Iterator for Iter<'e> {
Expand Down
58 changes: 55 additions & 3 deletions ecs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
pub mod component;
mod entity;
mod error;
#[macro_use]
mod query;
mod world;

pub use entity::{Entities, Entity};
pub use error::BorrowMutError;
pub use query::{ComponentQuery, Query};
pub use query::{as_mut_lt, as_ref_lt, ComponentQuery, Query};
pub use world::World;

#[cfg(test)]
Expand Down Expand Up @@ -508,6 +509,57 @@ mod tests {
mem::drop(r);
}

#[test]
fn type_safe_macros() {
let mut world = World::default();
struct Name(String);
struct Speed(f32);
let sanic = world.spawn();
world.add(sanic, Name("Sanic".into()));
world.add(sanic, Speed(100.0));
let mario = world.spawn();
world.add(mario, Name("Mario".into()));
world.add(mario, Speed(200.0)); // copilot thinks mario is faster than sanic

query_iter!(world, (name: Name, speed: mut Speed) => {
match name.0.as_ref() {
"Mario" => assert_eq!(speed.0, 200.0),
"Sanic" => {
assert_eq!(speed.0, 100.0);
speed.0 = 300.0; // copilot thinks he's faster than mario
}
_ => panic!("Unexpected name"),
}
});

query_iter!(world, (entity: Entity, name: Name, speed: Speed) => {
match name.0.as_ref() {
"Mario" => {
assert_eq!(entity, mario);
assert_eq!(speed.0, 200.0)
}
"Sanic" => {
assert_eq!(entity, sanic);
assert_eq!(speed.0, 300.0);
}
_ => panic!("Unexpected name"),
}
});

let mut found_sanic = false;
let mut found_mario = false;
query_iter!(world, (entity: Entity) => {
if found_sanic {
assert_eq!(entity, mario);
found_mario = true;
} else {
assert_eq!(entity, sanic);
found_sanic = true;
}
});
assert!(found_sanic && found_mario);
}

#[test]
fn iterate_over_query() {
let mut world = World::default();
Expand Down Expand Up @@ -535,7 +587,7 @@ mod tests {
.unwrap();
let mut q = world.query(&q);
for (pos, vel) in unsafe {
q.iter().map(|comps| {
q.iter().map(|(_e, comps)| {
if let [pos, vel] = comps[..] {
(
pos.cast::<Position>().as_mut(),
Expand All @@ -558,7 +610,7 @@ mod tests {
let mut q = world.query(&q);
for (i, pos) in unsafe {
q.iter()
.map(|comps| {
.map(|(_e, comps)| {
if let [pos] = comps[..] {
pos.cast::<Position>().as_ref()
} else {
Expand Down
182 changes: 129 additions & 53 deletions ecs/src/query.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,111 @@
use std::{collections::HashSet, marker::PhantomData, ptr::NonNull};
use std::{collections::HashSet, ptr::NonNull};

use crate::{
component::{ComponentEntryRef, ComponentId, ComponentRegistry},
BorrowMutError, Entity,
component::{ComponentEntryRef, ComponentId},
entity::Iter as EntityIter,
BorrowMutError, Entity, World,
};

/// Casts `ptr` to a reference with the lifetime `'a`.
/// # Safety
/// It is the responsibility of the caller to ensure that the lifetime `'a` outlives
/// the lifetime of the data pointed to by `ptr`.
#[allow(clippy::needless_lifetimes)]
pub unsafe fn as_ref_lt<'a, T>(_lifetime: &'a (), ptr: NonNull<T>) -> &'a T {
ptr.as_ref()
}

/// Casts `ptr` to a mutable reference with the lifetime `'a`.
/// # Safety
/// It is the responsibility of the caller to ensure that the lifetime `'a` outlives
/// the lifetime of the data pointed to by `ptr`.
#[allow(clippy::mut_from_ref, clippy::needless_lifetimes)]
pub unsafe fn as_mut_lt<'a, T>(_lifetime: &'a (), mut ptr: NonNull<T>) -> &'a mut T {
ptr.as_mut()
}

#[macro_export]
macro_rules! _query_definition {
( $world:expr, $vec:expr, ($name:ident: Entity, $($tail:tt)*) ) => {{
_query_definition!($world, $vec, ($($tail)*));
}};
( $world:expr, $vec:expr, ($name:ident: $type:ty, $($tail:tt)*) ) => {{
$vec.push(ComponentQuery {
id: $world.component_id::<$type>().unwrap(),
mutable: false,
});
_query_definition!($world, $vec, ($($tail)*));
}};
( $world:expr, $vec:expr, ($name:ident: mut $type:ty, $($tail:tt)*) ) => {{
$vec.push(ComponentQuery {
id: $world.component_id::<$type>().unwrap(),
mutable: true,
});
_query_definition!($world, $vec, ($($tail)*));
}};

// Last entry
( $world:expr, $vec:expr, ($name:ident: Entity) ) => { };
( $world:expr, $vec:expr, ($name:ident: $type:ty) ) => {{
$vec.push(ComponentQuery {
id: $world.component_id::<$type>().unwrap(),
mutable: false,
});
}};
( $world:expr, $vec:expr, ($name:ident: mut $type:ty) ) => {{
$vec.push(ComponentQuery {
id: $world.component_id::<$type>().unwrap(),
mutable: true,
});
}};
}

#[macro_export]
macro_rules! _query_defvars {
( $comps:expr, $lt:expr, $entity:expr, ($name:ident: Entity, $($tail:tt)*) ) => {
let $name = $entity;
_query_defvars!($comps[..], $lt, $entity, ($($tail)*));
};
( $comps:expr, $lt:expr, $entity:expr, ($name:ident: $type:ty, $($tail:tt)*) ) => {
let $name = unsafe { $crate::query::as_ref_lt($lt, $comps[0].cast::<$type>()) };
_query_defvars!($comps[1..], $lt, $entity, ($($tail)*));
};
( $comps:expr, $lt:expr, $entity:expr, ($name:ident: mut $type:ty, $($tail:tt)*) ) => {
let $name = unsafe { $crate::query::as_mut_lt($lt, $comps[0].cast::<$type>()) };
_query_defvars!($comps[1..], $lt, $entity, ($($tail)*));
};

// Last entry
( $comps:expr, $lt:expr, $entity:expr, ($name:ident: Entity) ) => {
let $name = $entity;
};
( $comps:expr, $lt:expr, $entity:expr, ($name:ident: $type:ty) ) => {
let $name = unsafe { $crate::query::as_ref_lt($lt, $comps[0].cast::<$type>()) };
};
( $comps:expr, $lt:expr, $entity:expr, ($name:ident: mut $type:ty) ) => {
let $name = unsafe { $crate::query::as_mut_lt($lt, $comps[0].cast::<$type>()) };
};
}

#[macro_export]
macro_rules! query_iter {
( $world:expr, ($($query:tt)*) => $body:block ) => {{
#[allow(unused_mut)]
let mut v = vec![];
_query_definition!($world, v, ($($query)*));
let q = Query::new(v).expect("Query violates rusts borrow rules");

let mut res = $world.query(&q);

#[allow(unused_variables)]
for (e, comps) in unsafe { res.iter() } {
let lt = ();
$crate::_query_defvars!(comps, &lt, e, ($($query)*));
$body
}
}};
}

/// Represents a valid query for components without multiple mutable access to the same type of
/// component.
/// NOTE: there's currently no way of for example having one query for `mut A` on entities with a
Expand All @@ -19,7 +120,6 @@ pub struct Query {
pub struct ComponentQuery {
pub id: ComponentId,
pub mutable: bool,
// TODO: add optional queries: `optional: bool,`
}

impl Query {
Expand Down Expand Up @@ -50,21 +150,17 @@ impl Query {
}

#[derive(Debug)]
pub struct QueryResponse<'r, 'q> {
_world_marker: PhantomData<&'r ComponentRegistry>,
pub struct QueryResponse<'w, 'q> {
world: &'w World,
entries: Vec<ComponentEntryRef>,
query: &'q Query,
}

impl<'r, 'q> QueryResponse<'r, 'q> {
pub(crate) fn new(
_registry: &'r ComponentRegistry,
query: &'q Query,
entries: Vec<ComponentEntryRef>,
) -> Self {
impl<'w, 'q> QueryResponse<'w, 'q> {
pub(crate) fn new(world: &'w World, query: &'q Query, entries: Vec<ComponentEntryRef>) -> Self {
debug_assert!(query.components().len() == entries.len());
Self {
_world_marker: PhantomData,
world,
entries,
query,
}
Expand All @@ -91,61 +187,41 @@ impl<'r, 'q> QueryResponse<'r, 'q> {

unsafe fn try_get_by_index(&mut self, index: u32) -> Option<Vec<NonNull<u8>>> {
let mut res = Vec::with_capacity(self.entries.len());
for (e, cq) in self.entries.iter_mut().zip(self.query.components().iter()) {
res.push(if cq.mutable {
NonNull::new(e.get_mut().storage.get_mut_ptr(index as usize))?
} else {
NonNull::new(e.get().storage.get_ptr(index as usize) as *mut _)?
});
for (e, _) in self.entries.iter().zip(self.query.components().iter()) {
res.push(NonNull::new(
e.get().storage.get_ptr(index as usize) as *mut _
)?);
}
Some(res)
}

/// Returns the last index of an entity that has at least one component in the query. There
/// might not actually be a hit for this query at this index, but there is definitly no hits
/// after this index.
fn last_index_worth_checking(&self) -> Option<u32> {
self.entries
.iter()
.flat_map(|e| e.get().storage.last_set_index())
.max()
.map(|max| max as u32)
}

pub unsafe fn iter<'a>(&'a mut self) -> Iter<'a, 'r, 'q> {
Iter::new(self, self.last_index_worth_checking())
pub unsafe fn iter<'a>(&'a mut self) -> Iter<'a, 'w, 'q> {
Iter::new(self)
}
}

pub struct Iter<'a, 'r, 'q> {
index: u32,
last: Option<u32>,
res: &'a mut QueryResponse<'r, 'q>,
pub struct Iter<'a, 'w, 'q> {
res: &'a mut QueryResponse<'w, 'q>,
entity_iter: EntityIter<'w>,
}

impl<'a, 'r, 'q> Iter<'a, 'r, 'q> {
pub fn new(res: &'a mut QueryResponse<'r, 'q>, last: Option<u32>) -> Self {
Self {
index: 0,
last,
res,
}
impl<'a, 'w, 'q> Iter<'a, 'w, 'q> {
pub fn new(res: &'a mut QueryResponse<'w, 'q>) -> Self {
let mut entity_iter = res.world.entities().iter();
entity_iter.next(); // Skip the resource entity.
Self { res, entity_iter }
}
}

// TODO: for sparse components this could be optimized
impl<'a, 'r, 'q> Iterator for Iter<'a, 'r, 'q> {
type Item = Vec<NonNull<u8>>;
type Item = (Entity, Vec<NonNull<u8>>);

fn next(&mut self) -> Option<Self::Item> {
while self.index < self.last? {
let index = self.index;
self.index += 1;
let res = unsafe { self.res.try_get_by_index(index) };
if res.is_some() {
return res;
}
}
None
self.entity_iter.next().and_then(|e| unsafe {
self.res
.try_get_by_index(self.entity_iter.entities().id(e).unwrap())
.map(|comps| (e, comps))
})
}
}
7 changes: 6 additions & 1 deletion ecs/src/world.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ impl World {
None => return Err(BorrowMutError::new(c.id)),
}
}
Ok(QueryResponse::new(&self.component_registry, query, entries))
Ok(QueryResponse::new(self, query, entries))
}

/// Tries to query for a set of components. If thats not possible (see `try_query`) this
Expand Down Expand Up @@ -147,4 +147,9 @@ impl World {
.get_mut(id as usize)
})
}

/// Get a reference to the world's entities.
pub fn entities(&self) -> &Entities {
&self.entities
}
}