Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for generic types #574

Merged
merged 3 commits into from
Nov 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion rustler_codegen/src/context.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use heck::ToSnakeCase;
use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::{Data, Field, Fields, Ident, Lit, Meta, Variant};
use syn::{Data, Field, Fields, Ident, Lifetime, Lit, Meta, TypeParam, Variant};

use super::RustlerAttr;

Expand All @@ -14,6 +14,8 @@ pub(crate) struct Context<'a> {
pub attrs: Vec<RustlerAttr>,
pub ident: &'a proc_macro2::Ident,
pub generics: &'a syn::Generics,
pub lifetimes: Vec<Lifetime>,
pub type_parameters: Vec<TypeParam>,
pub variants: Option<Vec<&'a Variant>>,
pub struct_fields: Option<Vec<&'a Field>>,
pub is_tuple_struct: bool,
Expand Down Expand Up @@ -50,10 +52,33 @@ impl<'a> Context<'a> {
_ => false,
};

let lifetimes: Vec<_> = ast
.generics
.params
.iter()
.filter_map(|g| match g {
syn::GenericParam::Lifetime(l) => Some(l.lifetime.clone()),
_ => None,
})
.collect();

let type_parameters: Vec<_> = ast
.generics
.params
.iter()
.filter_map(|g| match g {
syn::GenericParam::Type(t) => Some(t.clone()),
// Don't keep lifetimes or generic constants
_ => None,
})
.collect();

Self {
attrs,
ident: &ast.ident,
generics: &ast.generics,
lifetimes,
type_parameters,
variants,
struct_fields,
is_tuple_struct,
Expand Down
145 changes: 111 additions & 34 deletions rustler_codegen/src/encode_decode_templates.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::{GenericArgument, PathSegment, TraitBound};

use super::context::Context;

Expand All @@ -11,48 +12,85 @@ pub(crate) fn decoder(ctx: &Context, inner: TokenStream) -> TokenStream {
// The Decoder uses a special lifetime '__rustler_decode_lifetime. We need to ensure that all
// other lifetimes are bound to this lifetime: As we decode from a term (which has a lifetime),
// references to that term may not outlive the term itself.
let lifetimes: Vec<_> = generics
.params
.iter()
.filter_map(|g| match g {
syn::GenericParam::Lifetime(l) => Some(l.lifetime.clone()),
_ => None,
})
.collect();

let mut impl_generics = generics.clone();
let decode_lifetime = syn::Lifetime::new("'__rustler_decode_lifetime", Span::call_site());
let lifetime_def = syn::LifetimeParam::new(decode_lifetime.clone());
impl_generics
.params
.push(syn::GenericParam::Lifetime(lifetime_def));

if !lifetimes.is_empty() {
let where_clause = impl_generics.make_where_clause();
let where_clause = impl_generics.make_where_clause();

for lifetime in lifetimes {
let mut puncated = syn::punctuated::Punctuated::new();
puncated.push(lifetime.clone());
let predicate = syn::PredicateLifetime {
lifetime: decode_lifetime.clone(),
colon_token: syn::token::Colon {
spans: [Span::call_site()],
},
bounds: puncated,
};
where_clause.predicates.push(predicate.into());
for lifetime in ctx.lifetimes.iter() {
let mut punctuated = syn::punctuated::Punctuated::new();
punctuated.push(lifetime.clone());
let predicate = syn::PredicateLifetime {
lifetime: decode_lifetime.clone(),
colon_token: syn::token::Colon {
spans: [Span::call_site()],
},
bounds: punctuated,
};
where_clause.predicates.push(predicate.into());

let mut puncated = syn::punctuated::Punctuated::new();
puncated.push(decode_lifetime.clone());
let predicate = syn::PredicateLifetime {
lifetime: lifetime.clone(),
colon_token: syn::token::Colon {
spans: [Span::call_site()],
},
bounds: puncated,
};
where_clause.predicates.push(predicate.into());
}
let mut punctuated = syn::punctuated::Punctuated::new();
punctuated.push(decode_lifetime.clone());
let predicate = syn::PredicateLifetime {
lifetime: lifetime.clone(),
colon_token: syn::token::Colon {
spans: [Span::call_site()],
},
bounds: punctuated,
};
where_clause.predicates.push(predicate.into());
}

for type_parameter in ctx.type_parameters.iter() {
let mut punctuated = syn::punctuated::Punctuated::new();
punctuated.push(decode_lifetime.clone().into());
punctuated.push(syn::TypeParamBound::Trait(TraitBound {
paren_token: None,
modifier: syn::TraitBoundModifier::None,
lifetimes: None,
path: syn::Path {
leading_colon: Some(syn::token::PathSep::default()),
segments: [
PathSegment {
ident: syn::Ident::new("rustler", Span::call_site()),
arguments: syn::PathArguments::None,
},
PathSegment {
ident: syn::Ident::new("Decoder", Span::call_site()),
arguments: syn::PathArguments::AngleBracketed(
syn::AngleBracketedGenericArguments {
colon2_token: None,
lt_token: Default::default(),
args: std::iter::once(GenericArgument::Lifetime(
decode_lifetime.clone(),
))
.collect(),
gt_token: Default::default(),
},
),
},
]
.iter()
.cloned()
.collect(),
},
}));
let predicate = syn::PredicateType {
lifetimes: None,
bounded_ty: syn::Type::Path(syn::TypePath {
qself: None,
path: type_parameter.clone().ident.into(),
}),
colon_token: syn::token::Colon {
spans: [Span::call_site()],
},
bounds: punctuated,
};
where_clause.predicates.push(predicate.into());
}

let (impl_generics, _, where_clause) = impl_generics.split_for_impl();
Expand All @@ -69,7 +107,46 @@ pub(crate) fn decoder(ctx: &Context, inner: TokenStream) -> TokenStream {

pub(crate) fn encoder(ctx: &Context, inner: TokenStream) -> TokenStream {
let ident = ctx.ident;
let generics = ctx.generics;
let mut generics = ctx.generics.clone();

let where_clause = generics.make_where_clause();

for type_parameter in ctx.type_parameters.iter() {
let mut punctuated = syn::punctuated::Punctuated::new();
punctuated.push(syn::TypeParamBound::Trait(TraitBound {
paren_token: None,
modifier: syn::TraitBoundModifier::None,
lifetimes: None,
path: syn::Path {
leading_colon: Some(syn::token::PathSep::default()),
segments: [
PathSegment {
ident: syn::Ident::new("rustler", Span::call_site()),
arguments: syn::PathArguments::None,
},
PathSegment {
ident: syn::Ident::new("Encoder", Span::call_site()),
arguments: syn::PathArguments::None,
},
]
.iter()
.cloned()
.collect(),
},
}));
let predicate = syn::PredicateType {
lifetimes: None,
bounded_ty: syn::Type::Path(syn::TypePath {
qself: None,
path: type_parameter.ident.clone().into(),
}),
colon_token: syn::token::Colon {
spans: [Span::call_site()],
},
bounds: punctuated,
};
where_clause.predicates.push(predicate.into());
}
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

quote! {
Expand Down
2 changes: 2 additions & 0 deletions rustler_tests/lib/rustler_test.ex
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ defmodule RustlerTest do
def newtype_record_echo(_), do: err()
def tuplestruct_record_echo(_), do: err()
def reserved_keywords_type_echo(_), do: err()
def generic_struct_echo(_), do: err()
def mk_generic_map(_), do: err()

def dirty_io(), do: err()
def dirty_cpu(), do: err()
Expand Down
4 changes: 3 additions & 1 deletion rustler_tests/native/rustler_test/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@ rustler::init!(
test_tuple::maybe_add_one_to_tuple,
test_tuple::add_i32_from_tuple,
test_tuple::greeting_person_from_tuple,
test_codegen::reserved_keywords::reserved_keywords_type_echo
test_codegen::reserved_keywords::reserved_keywords_type_echo,
test_codegen::generic_types::generic_struct_echo,
test_codegen::generic_types::mk_generic_map,
],
load = load
);
Expand Down
25 changes: 25 additions & 0 deletions rustler_tests/native/rustler_test/src/test_codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,3 +262,28 @@ pub mod reserved_keywords {
reserved
}
}

pub mod generic_types {
use rustler::{NifMap, NifStruct};
#[derive(NifStruct)]
#[module = "GenericStruct"]
pub struct GenericStruct<T> {
t: T,
}

#[rustler::nif]
pub fn generic_struct_echo(value: GenericStruct<i32>) -> GenericStruct<i32> {
value
}

#[derive(NifMap)]
pub struct GenericMap<T> {
a: T,
b: T,
}

#[rustler::nif]
pub fn mk_generic_map(value: &str) -> GenericMap<&str> {
GenericMap { a: value, b: value }
}
}
12 changes: 12 additions & 0 deletions rustler_tests/test/codegen_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -434,4 +434,16 @@ defmodule RustlerTest.CodegenTest do
assert {1} == RustlerTest.reserved_keywords_type_echo({1})
assert {:record, 1} == RustlerTest.reserved_keywords_type_echo({:record, 1})
end

describe "generic types" do
test "generic struct" do
assert %{__struct__: GenericStruct, t: 1} ==
RustlerTest.generic_struct_echo(%{__struct__: GenericStruct, t: 1})
end

test "generic map" do
assert %{a: "hello", b: "hello"} ==
RustlerTest.mk_generic_map("hello")
end
end
end