Skip to content

Commit

Permalink
match_nodes: Tweaks
Browse files Browse the repository at this point in the history
  • Loading branch information
Nadrieril committed Mar 3, 2024
1 parent ef6910e commit 8a851d0
Showing 1 changed file with 40 additions and 47 deletions.
87 changes: 40 additions & 47 deletions pest_consume_macros/src/match_nodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,45 +7,61 @@ use syn::{
bracketed, parenthesized, parse_quote, token, Expr, Ident, Pat, Token, Type,
};

#[derive(Clone)]
struct MatchBranchPattern {
struct Pattern {
tag: Option<String>,
rule_name: Option<Ident>,
binder: Pat,
multiple: bool,
}

#[derive(Clone)]
struct MatchBranch {
patterns: Vec<MatchBranchPattern>,
patterns: Punctuated<Pattern, Token![,]>,
body: Expr,
}

#[derive(Clone)]
struct MacroInput {
parser: Type,
input_expr: Expr,
branches: Punctuated<MatchBranch, Token![,]>,
}

impl Parse for MacroInput {
fn parse(input: ParseStream) -> Result<Self> {
let parser = if input.peek(token::Lt) {
let _: token::Lt = input.parse()?;
let parser = input.parse()?;
let _: token::Gt = input.parse()?;
let _: Token![;] = input.parse()?;
parser
} else {
parse_quote!(Self)
};
let input_expr = input.parse()?;
let _: Token![;] = input.parse()?;
let branches = Punctuated::parse_terminated(input)?;

Ok(MacroInput {
parser,
input_expr,
branches,
})
}
}

impl Parse for MatchBranch {
fn parse(input: ParseStream) -> Result<Self> {
let contents;
let _: token::Bracket = bracketed!(contents in input);

let patterns: Punctuated<MatchBranchPattern, Token![,]> =
Punctuated::parse_terminated(&contents)?;
let patterns = Punctuated::parse_terminated(&contents)?;
let _: Token![=>] = input.parse()?;
let body = input.parse()?;

Ok(MatchBranch {
patterns: patterns.into_iter().collect(),
body,
})
Ok(MatchBranch { patterns, body })
}
}

impl Parse for MatchBranchPattern {
impl Parse for Pattern {
fn parse(input: ParseStream) -> Result<Self> {
let mut tag = None;
let binder;
Expand Down Expand Up @@ -81,7 +97,7 @@ impl Parse for MatchBranchPattern {
} else {
return Err(input.error("expected `..` or nothing"));
}
Ok(MatchBranchPattern {
Ok(Pattern {
tag,
rule_name,
binder,
Expand All @@ -90,40 +106,16 @@ impl Parse for MatchBranchPattern {
}
}

impl Parse for MacroInput {
fn parse(input: ParseStream) -> Result<Self> {
let parser = if input.peek(token::Lt) {
let _: token::Lt = input.parse()?;
let parser = input.parse()?;
let _: token::Gt = input.parse()?;
let _: Token![;] = input.parse()?;
parser
} else {
parse_quote!(Self)
};
let input_expr = input.parse()?;
let _: Token![;] = input.parse()?;
let branches = Punctuated::parse_terminated(input)?;

Ok(MacroInput {
parser,
input_expr,
branches,
})
}
}

/// Takes the ident of a mutable slice. Generates code that matches on the pattern and calls
/// `process_item` for each item. Calls `error` if we can't proceed.
fn traverse_pattern(
branch: &MatchBranch,
mut patterns: &[Pattern],
i_iter: &Ident,
matches_pat: impl Fn(&MatchBranchPattern, TokenStream) -> TokenStream,
process_item: impl Fn(&MatchBranchPattern, TokenStream) -> TokenStream,
matches_pat: impl Fn(&Pattern, TokenStream) -> TokenStream,
process_item: impl Fn(&Pattern, TokenStream) -> TokenStream,
error: TokenStream,
) -> TokenStream {
let mut steps = Vec::new();
let mut patterns = branch.patterns.as_slice();

// We will match variable patterns greedily. In order for trailing single patterns like `[x..,
// y, z]` to work, we must handle them first.
Expand Down Expand Up @@ -165,16 +157,17 @@ fn traverse_pattern(
}

fn make_branch(
branch: &MatchBranch,
branch: MatchBranch,
i_nodes: &Ident,
i_node_namer: &Ident,
parser: &Type,
) -> TokenStream {
let i_nodes_iter = Ident::new("___nodes_iter", Span::call_site());
let name_enum = quote!(<#parser as ::pest_consume::NodeMatcher>::NodeName);
let node_namer_ty = quote!(<_ as ::pest_consume::NodeNamer<#parser>>);
let patterns: Vec<_> = branch.patterns.into_iter().collect();

let matches_pat = |pat: &MatchBranchPattern, x| {
let matches_pat = |pat: &Pattern, x| {
let rule_cond = match &pat.rule_name {
Some(rule_name) => {
quote!(#node_namer_ty::node_name(&#i_node_namer, &#x) == #name_enum::#rule_name)
Expand All @@ -191,7 +184,7 @@ fn make_branch(
};

// Determine if we can take this branch.
let process_item = |pat: &MatchBranchPattern, i_matched| {
let process_item = |pat: &Pattern, i_matched| {
if !pat.multiple {
let cond = matches_pat(pat, i_matched);
quote!(
Expand All @@ -205,7 +198,7 @@ fn make_branch(
}
};
let conditions = traverse_pattern(
branch,
patterns.as_slice(),
&i_nodes_iter,
matches_pat,
process_item,
Expand All @@ -217,7 +210,7 @@ fn make_branch(
Some(rule_name) => quote!(#parser::#rule_name(#node)),
None => quote!(Ok(#node)),
};
let process_item = |pat: &MatchBranchPattern, i_matched| {
let process_item = |pat: &Pattern, i_matched| {
if !pat.multiple {
let parse = parse_rule(&pat.rule_name, quote!(#i_matched));
let binder = &pat.binder;
Expand All @@ -236,7 +229,7 @@ fn make_branch(
}
};
let parses = traverse_pattern(
branch,
patterns.as_slice(),
&i_nodes_iter,
matches_pat,
process_item,
Expand Down Expand Up @@ -273,7 +266,7 @@ pub fn match_nodes(
let parser = &input.parser;
let branches = input
.branches
.iter()
.into_iter()
.map(|br| make_branch(br, &i_nodes, &i_node_namer, parser))
.collect::<Vec<_>>();

Expand Down

0 comments on commit 8a851d0

Please sign in to comment.