From 80adb1b4c35623e2303d45161058fed52dbb1edb Mon Sep 17 00:00:00 2001 From: Vaivaswatha N Date: Sat, 11 Jan 2025 16:05:16 +0530 Subject: [PATCH] Derive format for enums Fixes #43 --- pliron-derive/src/derive_format.rs | 216 +++++++++++++++++++++++++---- pliron-derive/src/irfmt/mod.rs | 103 +++++++++----- pliron-derive/tests/format_base.rs | 64 ++++++++- 3 files changed, 321 insertions(+), 62 deletions(-) diff --git a/pliron-derive/src/derive_format.rs b/pliron-derive/src/derive_format.rs index 2213f91..0d17422 100644 --- a/pliron-derive/src/derive_format.rs +++ b/pliron-derive/src/derive_format.rs @@ -1,11 +1,11 @@ use proc_macro2::TokenStream; use quote::{format_ident, quote, ToTokens}; use rustc_hash::FxHashMap; -use syn::{DeriveInput, LitStr, Result}; +use syn::{spanned::Spanned, DeriveInput, LitStr, Result}; use crate::irfmt::{ - canonical_format_for_structs, canonical_op_format, Directive, Elem, FieldIdent, FmtData, - Format, Lit, UnnamedVar, Var, + canonical_format_for_enums, canonical_format_for_structs, canonical_op_format, Directive, Elem, + FieldIdent, FmtData, Format, Lit, UnnamedVar, Var, }; /// Data parsed from the macro to derive formatting for Rust types @@ -35,17 +35,27 @@ pub(crate) fn derive( // Prepare the format description. let args = Into::::into(args); - let format = if !args.is_empty() { + let format_data = FmtData::try_from(input.clone())?; + + let format = if let FmtData::Enum(_) = format_data { + // Enums have only one preset format and are not customizable. + if !args.is_empty() { + return Err(syn::Error::new_spanned( + input, + "Custom format strings are not supported for Enums".to_string(), + )); + } + canonical_format_for_enums() + } else if !args.is_empty() { let format_str = syn::parse2::(args)?; Format::parse(&format_str.value()).map_err(|e| syn::Error::new_spanned(format_str, e))? } else if irobj == DeriveIRObject::Op { canonical_op_format() } else { - canonical_format_for_structs(&input)? + canonical_format_for_structs(&format_data, input.span())? }; // Prepare data for the deriver. - let format_data = FmtData::try_from(input.clone())?; let format_input = FmtInput { ident: input.ident.clone(), data: format_data, @@ -119,7 +129,65 @@ trait PrintableBuilder { // Build the body of the outer function Printable::fmt. fn build_body(input: &FmtInput, state: &mut State) -> Result { - Self::build_format(input, state) + if let FmtData::Enum(r#enum) = &input.data { + let mut output = quote! {}; + for variant in &r#enum.variants { + let variant_name = variant.name.clone(); + let fmt_data = FmtData::Struct(variant.clone()); + let format = canonical_format_for_structs(&fmt_data, input.ident.span())?; + let variant_input = FmtInput { + ident: variant.name.clone(), + data: fmt_data, + format, + }; + let mut variant_fields = quote! {}; + if !variant.fields.is_empty() { + let mut is_named = false; + let mut fields = quote! {}; + for field in &variant.fields { + let field_name = field.ident.clone(); + match field_name { + FieldIdent::Named(name) => { + is_named = true; + let field = format_ident!("{}", name); + fields.extend(quote! { + #field, + }); + } + FieldIdent::Unnamed(index) => { + let field = format_ident!("field_at_{}", index); + fields.extend(quote! { + #field, + }); + } + } + } + if is_named { + variant_fields.extend(quote! { + { #fields } + }); + } else { + variant_fields.extend(quote! { + ( #fields ) + }); + } + } + let variant_tokens = Self::build_format(&variant_input, state)?; + output.extend(quote! { + Self::#variant_name #variant_fields => { + write!(fmt, "{}", stringify!(#variant_name))?; + #variant_tokens + }, + }); + } + Ok(quote! { + match self { + #output + } + }) + } else { + Self::build_format(input, state) + } } fn build_lit(_input: &FmtInput, _state: &mut State, lit: &str) -> TokenStream { @@ -159,8 +227,13 @@ struct DeriveBasePrintable; impl PrintableBuilder<()> for DeriveBasePrintable { fn build_var(input: &FmtInput, _state: &mut (), name: &str) -> Result { - let FmtData::Struct(ref struct_fields) = input.data; - if !struct_fields + let FmtData::Struct(ref r#struct) = input.data else { + return Err(syn::Error::new_spanned( + input.ident.clone(), + "Only structs fields as variables are supported".to_string(), + )); + }; + if !r#struct .fields .iter() .map(|f| &f.ident) @@ -171,14 +244,25 @@ impl PrintableBuilder<()> for DeriveBasePrintable { format!("field name \"{}\" is invalid", name), )); } - let field = format_ident!("{}", name); - Ok(quote! { ::pliron::printable::Printable::fmt(&self.#field, ctx, state, fmt)?; }) + let field_name = format_ident!("{}", name); + // If the field is an enum variant, then we don't need to access it with `&self`. + let field = if r#struct.is_enum_variant { + quote! { #field_name } + } else { + quote! { &self.#field_name } + }; + Ok(quote! { ::pliron::printable::Printable::fmt(#field, ctx, state, fmt)?; }) } fn build_unnamed_var(input: &FmtInput, _state: &mut (), index: usize) -> Result { // This is a struct unnamed field access. - let FmtData::Struct(ref struct_fields) = input.data; - if !struct_fields + let FmtData::Struct(ref r#struct) = input.data else { + return Err(syn::Error::new_spanned( + input.ident.clone(), + "Only tuple indices in structs (tuples) are supported".to_string(), + )); + }; + if !r#struct .fields .iter() .map(|f| &f.ident) @@ -190,11 +274,20 @@ impl PrintableBuilder<()> for DeriveBasePrintable { )); } let index = syn::Index::from(index); - Ok(quote! { ::pliron::printable::Printable::fmt(&self.#index, ctx, state, fmt)?; }) + // If the field is an enum variant, then we don't need to access it with `&self`, + // but instead need to use the index field which we gave a name to, when expanding + // the match arm. + let field = if r#struct.is_enum_variant { + let field_at_index = format_ident!("field_at_{}", index); + quote! { #field_at_index } + } else { + quote! { &self.#index } + }; + Ok(quote! { ::pliron::printable::Printable::fmt(#field, ctx, state, fmt)?; }) } fn build_directive(_input: &FmtInput, _state: &mut (), _d: &Directive) -> Result { - todo!() + unimplemented!("No directives supported in DeriveBasePrintable") } } @@ -316,6 +409,10 @@ trait ParsableBuilder { ) -> ::pliron::parsable::ParseResult<'a, Self::Parsed> { use ::pliron::parsable::IntoParseResult; use ::combine::Parser; + use ::pliron::input_err; + use ::pliron::location::Located; + let cur_loc = state_stream.loc(); + #body #final_ret_value } @@ -367,8 +464,65 @@ trait ParsableBuilder { } fn build_body(input: &FmtInput, state: &mut State) -> Result { + if let FmtData::Enum(r#enum) = &input.data { + let enum_name = r#enum.name.clone(); + assert!(r#enum.variants.len() > 0, "Enum has no variants"); + let variant_name_parsed = quote! { + let variant_name_parsed = + ::pliron::identifier::Identifier::parser(()). + parse_stream(state_stream).into_result()?.0.to_string(); + }; + + let mut match_arms = quote! {}; + for variant in &r#enum.variants { + let variant_name = variant.name.clone(); + let variant_name_str = variant_name.to_string(); + let parsed_variant = if variant.fields.is_empty() { + quote! { + // Could as well use Self::#variant_name here. + #enum_name::#variant_name + } + } else { + let fmt_data = FmtData::Struct(variant.clone()); + let format = canonical_format_for_structs(&fmt_data, input.ident.span())?; + let variant_input = FmtInput { + ident: variant.name.clone(), + data: fmt_data, + format, + }; + let built_body = Self::build_body(&variant_input, state)?; + quote! { + #built_body + final_ret_value + } + }; + match_arms.extend(quote! { + #variant_name_str => { + #parsed_variant + }, + }); + } + return Ok(quote! { + #variant_name_parsed + let final_ret_value = match variant_name_parsed.as_str() { + #match_arms + _ => { + return input_err!( + cur_loc, + "Invalid variant name: {}", variant_name_parsed + )?; + } + }; + }); + } + let mut processed_fmt = Self::build_format(input, state)?; - let FmtData::Struct(r#struct) = &input.data; + let FmtData::Struct(r#struct) = &input.data else { + return Err(syn::Error::new_spanned( + input.ident.clone(), + "Only structs are supported by default impl of ParsableBuilder".to_string(), + )); + }; let name = &r#struct.name; let mut named = true; @@ -391,15 +545,20 @@ trait ParsableBuilder { } } + let name_prefix = if r#struct.is_enum_variant { + quote! { Self:: } + } else { + quote! {} + }; let final_ret_value = if named { quote! { - let final_ret_value = #name { + let final_ret_value = #name_prefix #name { #obj_builder }; } } else { quote! { - let final_ret_value = #name ( + let final_ret_value = #name_prefix #name ( #obj_builder ); } @@ -412,7 +571,12 @@ trait ParsableBuilder { } fn build_var(input: &FmtInput, _state: &mut State, name: &str) -> Result { - let FmtData::Struct(r#struct) = &input.data; + let FmtData::Struct(r#struct) = &input.data else { + return Err(syn::Error::new_spanned( + input.ident.clone(), + "Only structs fields as variables are supported".to_string(), + )); + }; let Some(crate::irfmt::Field { ty, .. }) = r#struct .fields @@ -435,7 +599,13 @@ trait ParsableBuilder { _state: &mut State, index: usize, ) -> Result { - let FmtData::Struct(r#struct) = &input.data; + let FmtData::Struct(r#struct) = &input.data else { + return Err(syn::Error::new_spanned( + input.ident.clone(), + "Only tuple indices in structs (tuples) are supported".to_string(), + )); + }; + let crate::irfmt::Field { ty, .. } = r#struct.fields.get(index).ok_or(syn::Error::new_spanned( input.ident.clone(), @@ -470,7 +640,7 @@ struct DeriveBaseParsable; impl ParsableBuilder<()> for DeriveBaseParsable { fn build_directive(_input: &FmtInput, _state: &mut (), _d: &Directive) -> Result { - todo!() + unimplemented!("No directives supported in DeriveBaseParsable") } } @@ -491,10 +661,6 @@ impl ParsableBuilder for DeriveOpParsable { use ::pliron::op::Op; use ::pliron::operation::Operation; use ::pliron::irfmt::parsers::{process_parsed_ssa_defs, ssa_opd_parser, attr_parser}; - use ::pliron::input_err; - use ::pliron::location::Located; - - let cur_loc = state_stream.loc(); }; let built_format = Self::build_format(input, state)?; diff --git a/pliron-derive/src/irfmt/mod.rs b/pliron-derive/src/irfmt/mod.rs index ff506f9..a82b7b7 100644 --- a/pliron-derive/src/irfmt/mod.rs +++ b/pliron-derive/src/irfmt/mod.rs @@ -1,14 +1,15 @@ -use proc_macro2::TokenStream; +use proc_macro2::{Span, TokenStream}; use quote::format_ident; use syn::parse::{Parse, ParseStream}; -use syn::Type; -use syn::{self, DataStruct, DeriveInput}; +use syn::{self, DataEnum, DeriveInput}; use syn::{Data, Ident}; +use syn::{Fields, Type}; mod parser; pub(crate) enum FmtData { Struct(Struct), + Enum(Enum), } impl Parse for FmtData { @@ -23,11 +24,10 @@ impl TryFrom for FmtData { fn try_from(input: DeriveInput) -> syn::Result { match input.data { - Data::Struct(ref data) => Struct::from_syn(input.ident, data).map(FmtData::Struct), - Data::Enum(_) => Err(syn::Error::new_spanned( - &input, - "Format can only be derived for structs", - )), + Data::Struct(ref data) => { + Struct::from_syn(input.ident, &data.fields, false).map(FmtData::Struct) + } + Data::Enum(ref data) => Enum::from_syn(input.ident, data).map(FmtData::Enum), Data::Union(_) => Err(syn::Error::new_spanned( &input, "Format can only be derived for structs", @@ -36,16 +36,25 @@ impl TryFrom for FmtData { } } +/// Enum format data. +#[derive(Clone)] +pub(crate) struct Enum { + pub name: Ident, + pub variants: Vec, +} + /// Struct format data. +#[derive(Clone)] pub(crate) struct Struct { pub name: Ident, pub fields: Vec, + // Whether the struct is for an enum variant + pub is_enum_variant: bool, } impl Struct { - fn from_syn(name: Ident, data: &DataStruct) -> syn::Result { - let fields = data - .fields + fn from_syn(name: Ident, fields: &Fields, is_enum_variant: bool) -> syn::Result { + let fields = fields .iter() .enumerate() .map(|(i, f)| match f.ident { @@ -60,10 +69,27 @@ impl Struct { }) .collect(); - Ok(Self { name, fields }) + Ok(Self { + name, + fields, + is_enum_variant, + }) + } +} + +impl Enum { + fn from_syn(name: Ident, data: &DataEnum) -> syn::Result { + let variants = data + .variants + .iter() + .map(|v| Struct::from_syn(v.ident.clone(), &v.fields, true)) + .collect::>()?; + + Ok(Self { name, variants }) } } +#[derive(Clone)] pub(crate) struct Field { pub(crate) ty: Type, pub(crate) ident: FieldIdent, @@ -419,45 +445,52 @@ pub(crate) fn canonical_op_format() -> Format { elems: vec![Directive::new("canonical").into()], } } +/// Enums have just one preset format, which is: +/// "Variant ". +pub(crate) fn canonical_format_for_enums() -> Format { + Format { + elems: vec![Directive::new("canonical").into()], + } +} /// Describe a canonical syntax for types / attributes defined by a struct. /// This is just "". -pub(crate) fn canonical_format_for_structs(input: &syn::DeriveInput) -> syn::Result { +pub(crate) fn canonical_format_for_structs(input: &FmtData, span: Span) -> syn::Result { // TODO: add support for per field attributes? - let data = match input.data { - Data::Struct(ref data) => data, - _ => { - return Err(syn::Error::new_spanned( - input, - "Format can only be derived for structs", - )) - } + let FmtData::Struct(data) = input else { + return Err(syn::Error::new( + span, + "Format can only be derived for structs", + )); }; - let elems = match data.fields { - syn::Fields::Named(ref fields) => { - let mut elems = vec![]; - for (i, field) in fields.named.iter().enumerate() { - let ident = field.ident.as_ref().unwrap(); + let mut elems = vec![]; + let mut is_named = false; + for (i, field) in data.fields.iter().enumerate() { + match &field.ident { + FieldIdent::Named(field) => { + is_named = true; if i > 0 { elems.push(Elem::new_lit(",")); } - elems.push(Elem::new_lit(format!("{}", ident))); + elems.push(Elem::new_lit(field)); elems.push(Elem::new_lit("=".to_string())); - elems.push(Elem::new_var(ident.to_string())); + elems.push(Elem::new_var(field)); + } + FieldIdent::Unnamed(field) => { + elems.push(Elem::new_unnamed_var(*field)); } - elems } - syn::Fields::Unnamed(ref fields) => (0..(fields.unnamed.len())) - .map(Elem::new_unnamed_var) - .collect::>(), - syn::Fields::Unit => vec![], - }; + } let mut format = Format { elems }; if !format.is_empty() { - format.enclose(Elem::Lit("<".into()), Elem::Lit(">".into())); + if is_named { + format.enclose(Elem::Lit("{".into()), Elem::Lit("}".into())); + } else { + format.enclose(Elem::Lit("(".into()), Elem::Lit(")".into())); + } } Ok(format) } diff --git a/pliron-derive/tests/format_base.rs b/pliron-derive/tests/format_base.rs index 3faca03..b2f75ba 100644 --- a/pliron-derive/tests/format_base.rs +++ b/pliron-derive/tests/format_base.rs @@ -24,7 +24,7 @@ fn int_wrapper() { let test_ty = IntWrapper { inner: int_ty }; let printed = test_ty.disp(ctx).to_string(); - assert_eq!(">", &printed); + assert_eq!("{inner=builtin.int }", &printed); let state_stream = state_stream_from_iterator( printed.chars(), @@ -79,7 +79,7 @@ fn double_wrap() { let printed = test_ty.disp(ctx).to_string(); assert_eq!( - ",two=>>", + "{one=builtin.int ,two={inner=builtin.int }}", &printed ); @@ -92,3 +92,63 @@ fn double_wrap() { .expect("DoubleWrap parser failed"); assert_eq!(res.disp(ctx).to_string(), printed); } + +#[format] +enum Enum { + A(TypePtr), + B { one: TypePtr, two: IntWrapper }, + C, +} + +#[test] +fn enum_test() { + let ctx = &mut setup_context_dialects(); + let int_ty = IntegerType::get(ctx, 64, pliron::builtin::types::Signedness::Signed); + let test_ty = Enum::B { + one: int_ty, + two: IntWrapper { inner: int_ty }, + }; + + let printed = test_ty.disp(ctx).to_string(); + assert_eq!( + "B{one=builtin.int ,two={inner=builtin.int }}", + &printed + ); + + let state_stream = state_stream_from_iterator( + printed.chars(), + parsable::State::new(ctx, location::Source::InMemory), + ); + let (res, _) = Enum::parser(()) + .parse(state_stream) + .expect("Enum parser failed"); + assert_eq!(res.disp(ctx).to_string(), printed); + + let test_ty = Enum::A(int_ty); + let printed = test_ty.disp(ctx).to_string(); + assert_eq!("A(builtin.int )", &printed); + + let state_stream = state_stream_from_iterator( + printed.chars(), + parsable::State::new(ctx, location::Source::InMemory), + ); + let (res, _) = Enum::parser(()) + .parse(state_stream) + .expect("Enum parser failed"); + + assert_eq!(res.disp(ctx).to_string(), printed); + + let test_ty = Enum::C; + let printed = test_ty.disp(ctx).to_string(); + assert_eq!("C", &printed); + + let state_stream = state_stream_from_iterator( + printed.chars(), + parsable::State::new(ctx, location::Source::InMemory), + ); + let (res, _) = Enum::parser(()) + .parse(state_stream) + .expect("Enum parser failed"); + + assert_eq!(res.disp(ctx).to_string(), printed); +}