Skip to content

Commit

Permalink
Derive format for enums
Browse files Browse the repository at this point in the history
Fixes #43
  • Loading branch information
vaivaswatha committed Jan 11, 2025
1 parent 8ae78de commit 80adb1b
Show file tree
Hide file tree
Showing 3 changed files with 321 additions and 62 deletions.
216 changes: 191 additions & 25 deletions pliron-derive/src/derive_format.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use proc_macro2::TokenStream;
use quote::{format_ident, quote, ToTokens};
use rustc_hash::FxHashMap;
use syn::{DeriveInput, LitStr, Result};
use syn::{spanned::Spanned, DeriveInput, LitStr, Result};

use crate::irfmt::{
canonical_format_for_structs, canonical_op_format, Directive, Elem, FieldIdent, FmtData,
Format, Lit, UnnamedVar, Var,
canonical_format_for_enums, canonical_format_for_structs, canonical_op_format, Directive, Elem,
FieldIdent, FmtData, Format, Lit, UnnamedVar, Var,
};

/// Data parsed from the macro to derive formatting for Rust types
Expand Down Expand Up @@ -35,17 +35,27 @@ pub(crate) fn derive(

// Prepare the format description.
let args = Into::<TokenStream>::into(args);
let format = if !args.is_empty() {
let format_data = FmtData::try_from(input.clone())?;

let format = if let FmtData::Enum(_) = format_data {
// Enums have only one preset format and are not customizable.
if !args.is_empty() {
return Err(syn::Error::new_spanned(
input,
"Custom format strings are not supported for Enums".to_string(),
));
}
canonical_format_for_enums()
} else if !args.is_empty() {
let format_str = syn::parse2::<LitStr>(args)?;
Format::parse(&format_str.value()).map_err(|e| syn::Error::new_spanned(format_str, e))?
} else if irobj == DeriveIRObject::Op {
canonical_op_format()
} else {
canonical_format_for_structs(&input)?
canonical_format_for_structs(&format_data, input.span())?
};

// Prepare data for the deriver.
let format_data = FmtData::try_from(input.clone())?;
let format_input = FmtInput {
ident: input.ident.clone(),
data: format_data,
Expand Down Expand Up @@ -119,7 +129,65 @@ trait PrintableBuilder<State: Default> {

// Build the body of the outer function Printable::fmt.
fn build_body(input: &FmtInput, state: &mut State) -> Result<TokenStream> {
Self::build_format(input, state)
if let FmtData::Enum(r#enum) = &input.data {
let mut output = quote! {};
for variant in &r#enum.variants {
let variant_name = variant.name.clone();
let fmt_data = FmtData::Struct(variant.clone());
let format = canonical_format_for_structs(&fmt_data, input.ident.span())?;
let variant_input = FmtInput {
ident: variant.name.clone(),
data: fmt_data,
format,
};
let mut variant_fields = quote! {};
if !variant.fields.is_empty() {
let mut is_named = false;
let mut fields = quote! {};
for field in &variant.fields {
let field_name = field.ident.clone();
match field_name {
FieldIdent::Named(name) => {
is_named = true;
let field = format_ident!("{}", name);
fields.extend(quote! {
#field,
});
}
FieldIdent::Unnamed(index) => {
let field = format_ident!("field_at_{}", index);
fields.extend(quote! {
#field,
});
}
}
}
if is_named {
variant_fields.extend(quote! {
{ #fields }
});
} else {
variant_fields.extend(quote! {
( #fields )
});
}
}
let variant_tokens = Self::build_format(&variant_input, state)?;
output.extend(quote! {
Self::#variant_name #variant_fields => {
write!(fmt, "{}", stringify!(#variant_name))?;
#variant_tokens
},
});
}
Ok(quote! {
match self {
#output
}
})
} else {
Self::build_format(input, state)
}
}

fn build_lit(_input: &FmtInput, _state: &mut State, lit: &str) -> TokenStream {
Expand Down Expand Up @@ -159,8 +227,13 @@ struct DeriveBasePrintable;

impl PrintableBuilder<()> for DeriveBasePrintable {
fn build_var(input: &FmtInput, _state: &mut (), name: &str) -> Result<TokenStream> {
let FmtData::Struct(ref struct_fields) = input.data;
if !struct_fields
let FmtData::Struct(ref r#struct) = input.data else {
return Err(syn::Error::new_spanned(
input.ident.clone(),
"Only structs fields as variables are supported".to_string(),
));
};
if !r#struct
.fields
.iter()
.map(|f| &f.ident)
Expand All @@ -171,14 +244,25 @@ impl PrintableBuilder<()> for DeriveBasePrintable {
format!("field name \"{}\" is invalid", name),
));
}
let field = format_ident!("{}", name);
Ok(quote! { ::pliron::printable::Printable::fmt(&self.#field, ctx, state, fmt)?; })
let field_name = format_ident!("{}", name);
// If the field is an enum variant, then we don't need to access it with `&self`.
let field = if r#struct.is_enum_variant {
quote! { #field_name }
} else {
quote! { &self.#field_name }
};
Ok(quote! { ::pliron::printable::Printable::fmt(#field, ctx, state, fmt)?; })
}

fn build_unnamed_var(input: &FmtInput, _state: &mut (), index: usize) -> Result<TokenStream> {
// This is a struct unnamed field access.
let FmtData::Struct(ref struct_fields) = input.data;
if !struct_fields
let FmtData::Struct(ref r#struct) = input.data else {
return Err(syn::Error::new_spanned(
input.ident.clone(),
"Only tuple indices in structs (tuples) are supported".to_string(),
));
};
if !r#struct
.fields
.iter()
.map(|f| &f.ident)
Expand All @@ -190,11 +274,20 @@ impl PrintableBuilder<()> for DeriveBasePrintable {
));
}
let index = syn::Index::from(index);
Ok(quote! { ::pliron::printable::Printable::fmt(&self.#index, ctx, state, fmt)?; })
// If the field is an enum variant, then we don't need to access it with `&self`,
// but instead need to use the index field which we gave a name to, when expanding
// the match arm.
let field = if r#struct.is_enum_variant {
let field_at_index = format_ident!("field_at_{}", index);
quote! { #field_at_index }
} else {
quote! { &self.#index }
};
Ok(quote! { ::pliron::printable::Printable::fmt(#field, ctx, state, fmt)?; })
}

fn build_directive(_input: &FmtInput, _state: &mut (), _d: &Directive) -> Result<TokenStream> {
todo!()
unimplemented!("No directives supported in DeriveBasePrintable")
}
}

Expand Down Expand Up @@ -316,6 +409,10 @@ trait ParsableBuilder<State: Default> {
) -> ::pliron::parsable::ParseResult<'a, Self::Parsed> {
use ::pliron::parsable::IntoParseResult;
use ::combine::Parser;
use ::pliron::input_err;
use ::pliron::location::Located;
let cur_loc = state_stream.loc();

#body
#final_ret_value
}
Expand Down Expand Up @@ -367,8 +464,65 @@ trait ParsableBuilder<State: Default> {
}

fn build_body(input: &FmtInput, state: &mut State) -> Result<TokenStream> {
if let FmtData::Enum(r#enum) = &input.data {
let enum_name = r#enum.name.clone();
assert!(r#enum.variants.len() > 0, "Enum has no variants");
let variant_name_parsed = quote! {
let variant_name_parsed =
::pliron::identifier::Identifier::parser(()).
parse_stream(state_stream).into_result()?.0.to_string();
};

let mut match_arms = quote! {};
for variant in &r#enum.variants {
let variant_name = variant.name.clone();
let variant_name_str = variant_name.to_string();
let parsed_variant = if variant.fields.is_empty() {
quote! {
// Could as well use Self::#variant_name here.
#enum_name::#variant_name
}
} else {
let fmt_data = FmtData::Struct(variant.clone());
let format = canonical_format_for_structs(&fmt_data, input.ident.span())?;
let variant_input = FmtInput {
ident: variant.name.clone(),
data: fmt_data,
format,
};
let built_body = Self::build_body(&variant_input, state)?;
quote! {
#built_body
final_ret_value
}
};
match_arms.extend(quote! {
#variant_name_str => {
#parsed_variant
},
});
}
return Ok(quote! {
#variant_name_parsed
let final_ret_value = match variant_name_parsed.as_str() {
#match_arms
_ => {
return input_err!(
cur_loc,
"Invalid variant name: {}", variant_name_parsed
)?;
}
};
});
}

let mut processed_fmt = Self::build_format(input, state)?;
let FmtData::Struct(r#struct) = &input.data;
let FmtData::Struct(r#struct) = &input.data else {
return Err(syn::Error::new_spanned(
input.ident.clone(),
"Only structs are supported by default impl of ParsableBuilder".to_string(),
));
};
let name = &r#struct.name;
let mut named = true;

Expand All @@ -391,15 +545,20 @@ trait ParsableBuilder<State: Default> {
}
}

let name_prefix = if r#struct.is_enum_variant {
quote! { Self:: }
} else {
quote! {}
};
let final_ret_value = if named {
quote! {
let final_ret_value = #name {
let final_ret_value = #name_prefix #name {
#obj_builder
};
}
} else {
quote! {
let final_ret_value = #name (
let final_ret_value = #name_prefix #name (
#obj_builder
);
}
Expand All @@ -412,7 +571,12 @@ trait ParsableBuilder<State: Default> {
}

fn build_var(input: &FmtInput, _state: &mut State, name: &str) -> Result<TokenStream> {
let FmtData::Struct(r#struct) = &input.data;
let FmtData::Struct(r#struct) = &input.data else {
return Err(syn::Error::new_spanned(
input.ident.clone(),
"Only structs fields as variables are supported".to_string(),
));
};

let Some(crate::irfmt::Field { ty, .. }) = r#struct
.fields
Expand All @@ -435,7 +599,13 @@ trait ParsableBuilder<State: Default> {
_state: &mut State,
index: usize,
) -> Result<TokenStream> {
let FmtData::Struct(r#struct) = &input.data;
let FmtData::Struct(r#struct) = &input.data else {
return Err(syn::Error::new_spanned(
input.ident.clone(),
"Only tuple indices in structs (tuples) are supported".to_string(),
));
};

let crate::irfmt::Field { ty, .. } =
r#struct.fields.get(index).ok_or(syn::Error::new_spanned(
input.ident.clone(),
Expand Down Expand Up @@ -470,7 +640,7 @@ struct DeriveBaseParsable;

impl ParsableBuilder<()> for DeriveBaseParsable {
fn build_directive(_input: &FmtInput, _state: &mut (), _d: &Directive) -> Result<TokenStream> {
todo!()
unimplemented!("No directives supported in DeriveBaseParsable")
}
}

Expand All @@ -491,10 +661,6 @@ impl ParsableBuilder<OpParserState> for DeriveOpParsable {
use ::pliron::op::Op;
use ::pliron::operation::Operation;
use ::pliron::irfmt::parsers::{process_parsed_ssa_defs, ssa_opd_parser, attr_parser};
use ::pliron::input_err;
use ::pliron::location::Located;

let cur_loc = state_stream.loc();
};

let built_format = Self::build_format(input, state)?;
Expand Down
Loading

0 comments on commit 80adb1b

Please sign in to comment.