diff --git a/README.md b/README.md index 8b079051..69d32d40 100644 --- a/README.md +++ b/README.md @@ -216,7 +216,7 @@ The derive implementation supports the following attributes: - `codec(encoded_as = "OtherType")`: Needs to be placed above a field and makes the field being encoded by using `OtherType`. - `codec(index = 0)`: Needs to be placed above an enum variant to make the variant use the given - index when encoded. By default the index is determined by counting from `0` beginning wth the + index when encoded. By default the index is determined by counting from `0` beginning with the first variant. - `codec(encode_bound)`, `codec(decode_bound)` and `codec(mel_bound)`: All 3 attributes take in a `where` clause for the `Encode`, `Decode` and `MaxEncodedLen` trait implementation for diff --git a/derive/src/decode.rs b/derive/src/decode.rs index 7f2d08b2..59a42ce0 100644 --- a/derive/src/decode.rs +++ b/derive/src/decode.rs @@ -15,7 +15,7 @@ use proc_macro2::{Ident, Span, TokenStream}; use syn::{spanned::Spanned, Data, Error, Field, Fields}; -use crate::utils; +use crate::utils::{self, UsedIndexes}; /// Generate function block for function `Decode::decode`. /// @@ -57,9 +57,17 @@ pub fn quote( .to_compile_error(); } - let recurse = data_variants().enumerate().map(|(i, v)| { + let mut used_indexes = match UsedIndexes::from_iter(data_variants()) { + Ok(index) => index, + Err(e) => return e.into_compile_error(), + }; + let mut items = vec![]; + for v in data_variants() { let name = &v.ident; - let index = utils::variant_index(v, i); + let index = match used_indexes.variant_index(v) { + Ok(index) => index, + Err(e) => return e.into_compile_error(), + }; let create = create_instance( quote! { #type_name #type_generics :: #name }, @@ -69,7 +77,7 @@ pub fn quote( crate_path, ); - quote_spanned! { v.span() => + let item = quote_spanned! { v.span() => #[allow(clippy::unnecessary_cast)] __codec_x_edqy if __codec_x_edqy == #index as ::core::primitive::u8 => { // NOTE: This lambda is necessary to work around an upstream bug @@ -80,8 +88,9 @@ pub fn quote( #create })(); }, - } - }); + }; + items.push(item); + } let read_byte_err_msg = format!("Could not decode `{type_name}`, failed to read variant byte"); @@ -91,7 +100,7 @@ pub fn quote( match #input.read_byte() .map_err(|e| e.chain(#read_byte_err_msg))? { - #( #recurse )* + #( #items )* _ => { #[allow(clippy::redundant_closure_call)] return (move || { diff --git a/derive/src/encode.rs b/derive/src/encode.rs index 142bb439..2c3600ea 100644 --- a/derive/src/encode.rs +++ b/derive/src/encode.rs @@ -17,7 +17,7 @@ use std::str::from_utf8; use proc_macro2::{Ident, Span, TokenStream}; use syn::{punctuated::Punctuated, spanned::Spanned, token::Comma, Data, Error, Field, Fields}; -use crate::utils; +use crate::{utils, utils::UsedIndexes}; type FieldsList = Punctuated; @@ -313,12 +313,18 @@ fn impl_encode(data: &Data, type_name: &Ident, crate_path: &syn::Path) -> TokenS if data_variants().count() == 0 { return quote!(); } - - let recurse = data_variants().enumerate().map(|(i, f)| { + let mut used_indexes = match UsedIndexes::from_iter(data_variants()) { + Ok(index) => index, + Err(e) => return e.into_compile_error(), + }; + let mut items = vec![]; + for f in data_variants() { let name = &f.ident; - let index = utils::variant_index(f, i); - - match f.fields { + let index = match used_indexes.variant_index(f) { + Ok(index) => index, + Err(e) => return e.into_compile_error(), + }; + let item = match f.fields { Fields::Named(ref fields) => { let fields = &fields.named; let field_name = |_, ident: &Option| quote!(#ident); @@ -396,11 +402,12 @@ fn impl_encode(data: &Data, type_name: &Ident, crate_path: &syn::Path) -> TokenS [hinting, encoding] }, - } - }); + }; + items.push(item) + } - let recurse_hinting = recurse.clone().map(|[hinting, _]| hinting); - let recurse_encoding = recurse.clone().map(|[_, encoding]| encoding); + let recurse_hinting = items.iter().map(|[hinting, _]| hinting); + let recurse_encoding = items.iter().map(|[_, encoding]| encoding); let hinting = quote! { // The variant index uses 1 byte. diff --git a/derive/src/utils.rs b/derive/src/utils.rs index 091a45ee..07f89d16 100644 --- a/derive/src/utils.rs +++ b/derive/src/utils.rs @@ -17,13 +17,14 @@ //! NOTE: attributes finder must be checked using check_attribute first, //! otherwise the macro can panic. -use std::str::FromStr; +use std::{collections::HashSet, str::FromStr}; use proc_macro2::TokenStream; use quote::{quote, ToTokens}; use syn::{ parse::Parse, punctuated::Punctuated, spanned::Spanned, token, Attribute, Data, DeriveInput, - Field, Fields, FieldsNamed, FieldsUnnamed, Lit, Meta, MetaNameValue, NestedMeta, Path, Variant, + ExprLit, Field, Fields, FieldsNamed, FieldsUnnamed, Lit, Meta, MetaNameValue, NestedMeta, Path, + Variant, }; fn find_meta_item<'a, F, R, I, M>(mut itr: I, mut pred: F) -> Option @@ -37,32 +38,96 @@ where }) } -/// Look for a `#[scale(index = $int)]` attribute on a variant. If no attribute -/// is found, fall back to the discriminant or just the variant index. -pub fn variant_index(v: &Variant, i: usize) -> TokenStream { - // first look for an attribute - let index = find_meta_item(v.attrs.iter(), |meta| { - if let NestedMeta::Meta(Meta::NameValue(ref nv)) = meta { - if nv.path.is_ident("index") { - if let Lit::Int(ref v) = nv.lit { - let byte = v - .base10_parse::() - .expect("Internal error, index attribute must have been checked"); - return Some(byte); +pub struct UsedIndexes { + used_set: HashSet, + current: u8, +} + +impl UsedIndexes { + /// Build a Set of used indexes for use with #[scale(index = $int)] attribute on variant + pub fn from_iter<'a, I: Iterator>(values: I) -> syn::Result { + let mut set = HashSet::new(); + for (i, v) in values.enumerate() { + if let Some((index, nv)) = find_meta_item(v.attrs.iter(), |meta| { + if let NestedMeta::Meta(Meta::NameValue(ref nv)) = meta { + if nv.path.is_ident("index") { + if let Lit::Int(ref v) = nv.lit { + let byte = v + .base10_parse::() + .expect("Internal error, index attribute must have been checked"); + return Some((byte, nv.span())); + } + } + } + None + }) { + if !set.insert(index) { + return Err(syn::Error::new(nv.span(), "Duplicate variant index. qed")) + } + set.insert(i.try_into().expect("Will never happen. qed")); + } else { + match v.discriminant.as_ref() { + Some(( + _, + expr @ syn::Expr::Lit(ExprLit { lit: syn::Lit::Int(lit_int), .. }), + )) => { + let index = lit_int + .base10_parse::() + .expect("Internal error, index attribute must have been checked"); + if !set.insert(index) { + return Err(syn::Error::new(expr.span(), "Duplicate variant index. qed")) + } + set.insert(i.try_into().expect("Will never happen. qed")); + }, + _ => (), } } } + Ok(Self { current: 0, used_set: set }) + } - None - }); - - // then fallback to discriminant or just index - index.map(|i| quote! { #i }).unwrap_or_else(|| { - v.discriminant - .as_ref() - .map(|(_, expr)| quote! { #expr }) - .unwrap_or_else(|| quote! { #i }) - }) + /// Look for a `#[scale(index = $int)]` attribute on a variant. If no attribute + /// is found, fall back to the discriminant or just the variant index. + pub fn variant_index(&mut self, v: &Variant) -> syn::Result { + // first look for an attribute + let index = find_meta_item(v.attrs.iter(), |meta| { + if let NestedMeta::Meta(Meta::NameValue(ref nv)) = meta { + if nv.path.is_ident("index") { + if let Lit::Int(ref v) = nv.lit { + let byte = v + .base10_parse::() + .expect("Internal error, index attribute must have been checked"); + return Some(byte); + } + } + } + + None + }); + + index.map_or_else( + || match v.discriminant.as_ref() { + Some((_, expr)) => return Ok(quote! { #expr }), + None => { + let idx = self.next_index(); + return Ok(quote! { #idx }) + }, + }, + |i| Ok(quote! { #i }), + ) + } + + fn next_index(&mut self) -> u8 { + loop { + if self.used_set.contains(&self.current) { + self.current += 1; + } else { + let index = self.current; + self.current += 1; + return index + } + } + } } /// Look for a `#[codec(encoded_as = "SomeType")]` outer attribute on the given diff --git a/tests/variant_number.rs b/tests/variant_number.rs index 9bdaba0a..c38fd748 100644 --- a/tests/variant_number.rs +++ b/tests/variant_number.rs @@ -10,7 +10,7 @@ fn discriminant_variant_counted_in_default_index() { } assert_eq!(T::A.encode(), vec![1]); - assert_eq!(T::B.encode(), vec![1]); + assert_eq!(T::B.encode(), vec![2]); } #[test] @@ -36,5 +36,5 @@ fn index_attr_variant_counted_and_reused_in_default_index() { } assert_eq!(T::A.encode(), vec![1]); - assert_eq!(T::B.encode(), vec![1]); + assert_eq!(T::B.encode(), vec![2]); }