Skip to content

Commit

Permalink
Refactor rust parser generator to use tokenstream directly
Browse files Browse the repository at this point in the history
  • Loading branch information
hchataing committed May 22, 2024
1 parent a603e4b commit cb14a2a
Showing 1 changed file with 26 additions and 29 deletions.
55 changes: 26 additions & 29 deletions pdl-compiler/src/backends/rust/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ pub struct FieldParser<'a> {
packet_name: &'a str,
span: &'a proc_macro2::Ident,
chunk: Vec<BitField<'a>>,
code: Vec<proc_macro2::TokenStream>,
tokens: proc_macro2::TokenStream,
shift: usize,
offset: usize,
}
Expand All @@ -55,7 +55,7 @@ impl<'a> FieldParser<'a> {
packet_name,
span,
chunk: Vec::new(),
code: Vec::new(),
tokens: quote! {},
shift: 0,
offset: 0,
}
Expand Down Expand Up @@ -91,7 +91,7 @@ impl<'a> FieldParser<'a> {
))
.unwrap();

self.code.push(match &field.desc {
self.tokens.extend(match &field.desc {
ast::FieldDesc::Scalar { id, width } => {
let id = id.to_ident();
let value = types::get_uint(self.endianness, *width, self.span);
Expand Down Expand Up @@ -160,7 +160,7 @@ impl<'a> FieldParser<'a> {
let get = types::get_uint(self.endianness, self.shift, self.span);
if self.chunk.len() > 1 {
// Multiple values: we read into a local variable.
self.code.push(quote! {
self.tokens.extend(quote! {
let #chunk_name = #get;
});
}
Expand Down Expand Up @@ -193,7 +193,7 @@ impl<'a> FieldParser<'a> {
v = quote! { #v as #value_type };
}

self.code.push(match &field.desc {
self.tokens.extend(match &field.desc {
ast::FieldDesc::Scalar { id, .. }
| ast::FieldDesc::Flag { id, .. } => {
let id = id.to_ident();
Expand Down Expand Up @@ -327,7 +327,7 @@ impl<'a> FieldParser<'a> {

fn check_size(&mut self, span: &proc_macro2::Ident, wanted: &proc_macro2::TokenStream) {
let packet_name = &self.packet_name;
self.code.push(quote! {
self.tokens.extend(quote! {
if #span.remaining() < #wanted {
return Err(DecodeError::InvalidLengthError {
obj: #packet_name,
Expand Down Expand Up @@ -394,7 +394,7 @@ impl<'a> FieldParser<'a> {
let span = self.span;
let padding_octets = padding_size / 8;
self.check_size(span, &quote!(#padding_octets));
self.code.push(quote! {
self.tokens.extend(quote! {
let (mut head, tail) = #span.split_at(#padding_octets);
#span = tail;
});
Expand All @@ -416,7 +416,7 @@ impl<'a> FieldParser<'a> {
self.check_size(&span, &quote!(#size_field));
let parse_element =
self.parse_array_element(&format_ident!("head"), width, type_id, decl);
self.code.push(quote! {
self.tokens.extend(quote! {
let (mut head, tail) = #span.split_at(#size_field);
#span = tail;
let mut #id = Vec::new();
Expand All @@ -430,7 +430,7 @@ impl<'a> FieldParser<'a> {
// element count is known statically. Parse elements
// item by item as an array.
let count = syn::Index::from(*count);
self.code.push(quote! {
self.tokens.extend(quote! {
// TODO(mgeisler): use
// https://doc.rust-lang.org/std/array/fn.try_from_fn.html
// when stabilized.
Expand All @@ -447,7 +447,7 @@ impl<'a> FieldParser<'a> {
// The element width is not known, but the array
// element count is known by the count field. Parse
// elements item by item as a vector.
self.code.push(quote! {
self.tokens.extend(quote! {
let #id = (0..#count_field)
.map(|_| #parse_element)
.collect::<Result<Vec<_>, DecodeError>>()?;
Expand All @@ -456,7 +456,7 @@ impl<'a> FieldParser<'a> {
(ElementWidth::Unknown, ArrayShape::Unknown) => {
// Neither the count not size is known, parse elements
// until the end of the span.
self.code.push(quote! {
self.tokens.extend(quote! {
let mut #id = Vec::new();
while !#span.is_empty() {
#id.push(#parse_element?);
Expand All @@ -475,7 +475,7 @@ impl<'a> FieldParser<'a> {
quote!(#count * #element_width)
};
self.check_size(&span, &quote! { #array_size });
self.code.push(quote! {
self.tokens.extend(quote! {
// TODO(mgeisler): use
// https://doc.rust-lang.org/std/array/fn.try_from_fn.html
// when stabilized.
Expand All @@ -492,7 +492,7 @@ impl<'a> FieldParser<'a> {
// The element width is known, and the array element
// count is known dynamically by the count field.
self.check_size(&span, &quote!(#count_field * #element_width));
self.code.push(quote! {
self.tokens.extend(quote! {
let #id = (0..#count_field)
.map(|_| #parse_element)
.collect::<Result<Vec<_>, DecodeError>>()?;
Expand All @@ -512,7 +512,7 @@ impl<'a> FieldParser<'a> {
let count_field = format_ident!("{id}_count");
let array_count = if element_width != 1 {
let element_width = syn::Index::from(element_width);
self.code.push(quote! {
self.tokens.extend(quote! {
if #array_size % #element_width != 0 {
return Err(DecodeError::InvalidArraySize {
array: #array_size,
Expand All @@ -526,7 +526,7 @@ impl<'a> FieldParser<'a> {
array_size
};

self.code.push(quote! {
self.tokens.extend(quote! {
let mut #id = Vec::with_capacity(#array_count);
for _ in 0..#array_count {
#id.push(#parse_element?);
Expand All @@ -547,7 +547,7 @@ impl<'a> FieldParser<'a> {
let parse_element =
self.parse_array_element(&format_ident!("chunk"), width, type_id, decl);

self.code.push(quote! {
self.tokens.extend(quote! {
// TODO: use
// https://doc.rust-lang.org/std/array/fn.try_from_fn.html
// when stabilized.
Expand Down Expand Up @@ -578,7 +578,7 @@ impl<'a> FieldParser<'a> {
let parse_element =
self.parse_array_element(&format_ident!("chunk"), width, type_id, decl);

self.code.push(quote! {
self.tokens.extend(quote! {
let #id = #span.chunks(#element_size_field)
.take(#count_field)
.map(|mut chunk| #parse_element.and_then(|value| {
Expand Down Expand Up @@ -606,7 +606,7 @@ impl<'a> FieldParser<'a> {
} else {
quote!(#span.remaining())
};
self.code.push(quote! {
self.tokens.extend(quote! {
if #array_size % #element_size_field != 0 {
return Err(DecodeError::InvalidArraySize {
array: #array_size,
Expand All @@ -618,7 +618,7 @@ impl<'a> FieldParser<'a> {
let parse_element =
self.parse_array_element(&format_ident!("chunk"), width, type_id, decl);

self.code.push(quote! {
self.tokens.extend(quote! {
let #id = #span.chunks(#element_size_field)
.take(#array_size / #element_size_field)
.map(|mut chunk| #parse_element.and_then(|value| {
Expand Down Expand Up @@ -650,7 +650,7 @@ impl<'a> FieldParser<'a> {
let id = id.to_ident();
let type_id = type_id.to_ident();

self.code.push(match self.schema.decl_size(decl.key) {
self.tokens.extend(match self.schema.decl_size(decl.key) {
analyzer::Size::Unknown | analyzer::Size::Dynamic => quote! {
let (#id, #span) = #type_id::decode(#span)?;
},
Expand Down Expand Up @@ -706,7 +706,7 @@ impl<'a> FieldParser<'a> {
// Push code to check that the size is greater than the size
// modifier. Required to safely substract the modifier from the
// size.
self.code.push(quote! {
self.tokens.extend(quote! {
if #size_field < #size_modifier {
return Err(DecodeError::InvalidLengthError {
obj: #packet_name,
Expand All @@ -718,14 +718,14 @@ impl<'a> FieldParser<'a> {
});
}
self.check_size(self.span, &quote!(#size_field ));
self.code.push(quote! {
self.tokens.extend(quote! {
let payload = #span[..#size_field].to_vec();
#span.advance(#size_field);
});
} else if offset_from_end == Some(0) {
// The payload or body is the last field of a packet,
// consume the remaining span.
self.code.push(quote! {
self.tokens.extend(quote! {
let payload = #span.to_vec();
#span.advance(payload.len());
});
Expand All @@ -740,15 +740,15 @@ impl<'a> FieldParser<'a> {
);
let offset_from_end = syn::Index::from(offset_from_end / 8);
self.check_size(self.span, &quote!(#offset_from_end));
self.code.push(quote! {
self.tokens.extend(quote! {
let payload = #span[..#span.len() - #offset_from_end].to_vec();
#span.advance(payload.len());
});
}

let decl = self.scope.typedef[self.packet_name];
if let ast::DeclDesc::Struct { .. } = &decl.desc {
self.code.push(quote! {
self.tokens.extend(quote! {
let payload = Vec::from(payload);
});
}
Expand Down Expand Up @@ -792,10 +792,7 @@ impl<'a> FieldParser<'a> {

impl quote::ToTokens for FieldParser<'_> {
fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
let code = &self.code;
tokens.extend(quote! {
#(#code)*
});
tokens.extend(self.tokens.clone());
}
}

Expand Down

0 comments on commit cb14a2a

Please sign in to comment.