Skip to content

Commit

Permalink
Use named structs to type ast node keys
Browse files Browse the repository at this point in the history
  • Loading branch information
hchataing committed Apr 3, 2024
1 parent f8ca870 commit 50b210f
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 36 deletions.
40 changes: 23 additions & 17 deletions pdl-compiler/src/analyzer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,10 @@ pub struct Scope<'d> {
/// Gather size information about the full AST.
#[derive(Debug)]
pub struct Schema {
size: HashMap<usize, Size>,
padded_size: HashMap<usize, Option<usize>>,
payload_size: HashMap<usize, Size>,
decl_size: HashMap<DeclKey, Size>,
field_size: HashMap<FieldKey, Size>,
padded_size: HashMap<FieldKey, Option<usize>>,
payload_size: HashMap<DeclKey, Size>,
}

impl Diagnostics {
Expand Down Expand Up @@ -328,7 +329,7 @@ impl Schema {
/// Check correct definition of packet sizes.
/// Annotate fields and declarations with the size in bits.
pub fn new(file: &File) -> Schema {
fn annotate_decl(schema: &mut Schema, scope: &HashMap<String, usize>, decl: &Decl) {
fn annotate_decl(schema: &mut Schema, scope: &HashMap<String, DeclKey>, decl: &Decl) {
// Compute the padding size for each field.
let mut padding = None;
for field in decl.fields().rev() {
Expand All @@ -342,7 +343,7 @@ impl Schema {
let mut size = decl
.parent_id()
.and_then(|parent_id| scope.get(parent_id))
.map(|key| schema.size(*key))
.map(|key| schema.decl_size(*key))
.unwrap_or(Size::Static(0));
let mut payload_size = Size::Static(0);

Expand Down Expand Up @@ -378,13 +379,13 @@ impl Schema {
DeclDesc::Test { .. } => (Size::Static(0), Size::Static(0)),
};

schema.size.insert(decl.key, size);
schema.decl_size.insert(decl.key, size);
schema.payload_size.insert(decl.key, payload_size);
}

fn annotate_field(
schema: &mut Schema,
scope: &HashMap<String, usize>,
scope: &HashMap<String, DeclKey>,
decl: &Decl,
field: &Field,
) -> Size {
Expand Down Expand Up @@ -445,7 +446,7 @@ impl Schema {
FieldDesc::Array { .. } => unreachable!(),
};

schema.size.insert(field.key, size);
schema.field_size.insert(field.key, size);
size
}

Expand All @@ -457,7 +458,8 @@ impl Schema {
}

let mut schema = Schema {
size: Default::default(),
field_size: Default::default(),
decl_size: Default::default(),
padded_size: Default::default(),
payload_size: Default::default(),
};
Expand All @@ -469,20 +471,24 @@ impl Schema {
schema
}

pub fn size(&self, key: usize) -> Size {
*self.size.get(&key).unwrap()
pub fn field_size(&self, key: FieldKey) -> Size {
*self.field_size.get(&key).unwrap()
}

pub fn padded_size(&self, key: usize) -> Option<usize> {
pub fn decl_size(&self, key: DeclKey) -> Size {
*self.decl_size.get(&key).unwrap()
}

pub fn padded_size(&self, key: FieldKey) -> Option<usize> {
*self.padded_size.get(&key).unwrap()
}

pub fn payload_size(&self, key: usize) -> Size {
pub fn payload_size(&self, key: DeclKey) -> Size {
*self.payload_size.get(&key).unwrap()
}

pub fn total_size(&self, key: usize) -> Size {
self.size(key) + self.payload_size(key)
pub fn total_size(&self, key: DeclKey) -> Size {
self.decl_size(key) + self.payload_size(key)
}
}

Expand Down Expand Up @@ -3062,9 +3068,9 @@ mod test {
file.declarations
.iter()
.map(|decl| Annotations {
size: schema.size(decl.key),
size: schema.decl_size(decl.key),
payload_size: schema.payload_size(decl.key),
fields: decl.fields().map(|field| schema.size(field.key)).collect(),
fields: decl.fields().map(|field| schema.field_size(field.key)).collect(),
})
.collect()
}
Expand Down
10 changes: 8 additions & 2 deletions pdl-compiler/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ pub struct Constraint {
pub tag_id: Option<String>,
}

#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct FieldKey(pub usize);

#[derive(Debug, Serialize, Clone, PartialEq, Eq)]
#[serde(tag = "kind")]
pub enum FieldDesc {
Expand Down Expand Up @@ -154,7 +157,7 @@ pub struct Field {
/// Unique identifier used to refer to the AST node in
/// compilation environments.
#[serde(skip_serializing)]
pub key: usize,
pub key: FieldKey,
#[serde(flatten)]
pub desc: FieldDesc,
pub cond: Option<Constraint>,
Expand All @@ -167,6 +170,9 @@ pub struct TestCase {
pub input: String,
}

#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct DeclKey(pub usize);

#[derive(Debug, Serialize, Clone, PartialEq, Eq)]
#[serde(tag = "kind")]
pub enum DeclDesc {
Expand Down Expand Up @@ -202,7 +208,7 @@ pub struct Decl {
/// Unique identifier used to refer to the AST node in
/// compilation environments.
#[serde(skip_serializing)]
pub key: usize,
pub key: DeclKey,
#[serde(flatten)]
pub desc: DeclDesc,
}
Expand Down
4 changes: 3 additions & 1 deletion pdl-compiler/src/backends/rust.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ fn generate_packet_size_getter<'a>(
let mut dynamic_widths = Vec::new();

for field in fields {
if let Some(width) = schema.padded_size(field.key).or(schema.size(field.key).static_()) {
if let Some(width) =
schema.padded_size(field.key).or(schema.field_size(field.key).static_())
{
constant_width += width;
continue;
}
Expand Down
8 changes: 4 additions & 4 deletions pdl-compiler/src/backends/rust/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ impl<'a> FieldParser<'a> {

fn add_bit_field(&mut self, field: &'a ast::Field) {
self.chunk.push(BitField { shift: self.shift, field });
self.shift += self.schema.size(field.key).static_().unwrap();
self.shift += self.schema.field_size(field.key).static_().unwrap();
if self.shift % 8 != 0 {
return;
}
Expand Down Expand Up @@ -184,7 +184,7 @@ impl<'a> FieldParser<'a> {
v = quote! { (#v >> #shift) }
}

let width = self.schema.size(field.key).static_().unwrap();
let width = self.schema.field_size(field.key).static_().unwrap();
let value_type = types::Integer::new(width);
if !single_value && width < value_type.width {
// Mask value if we grabbed more than `width` and if
Expand Down Expand Up @@ -303,7 +303,7 @@ impl<'a> FieldParser<'a> {
let mut offset = 0;
for field in fields {
if let Some(width) =
self.schema.padded_size(field.key).or(self.schema.size(field.key).static_())
self.schema.padded_size(field.key).or(self.schema.field_size(field.key).static_())
{
offset += width;
} else {
Expand Down Expand Up @@ -534,7 +534,7 @@ impl<'a> FieldParser<'a> {
let id = id.to_ident();
let type_id = type_id.to_ident();

self.code.push(match self.schema.size(decl.key) {
self.code.push(match self.schema.decl_size(decl.key) {
analyzer::Size::Unknown | analyzer::Size::Dynamic => quote! {
let #id = #type_id::parse_inner(&mut #span)?;
},
Expand Down
4 changes: 2 additions & 2 deletions pdl-compiler/src/backends/rust/serializer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ impl<'a> FieldSerializer<'a> {
}

fn add_bit_field(&mut self, field: &ast::Field) {
let width = self.schema.size(field.key).static_().unwrap();
let width = self.schema.field_size(field.key).static_().unwrap();
let shift = self.shift;

match &field.desc {
Expand Down Expand Up @@ -405,7 +405,7 @@ impl<'a> FieldSerializer<'a> {
let padding_octets = padding_size / 8;
let element_width = match &width {
Some(width) => Some(*width),
None => self.schema.size(decl.unwrap().key).static_(),
None => self.schema.decl_size(decl.unwrap().key).static_(),
};

let array_size = match element_width {
Expand Down
24 changes: 14 additions & 10 deletions pdl-compiler/src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,12 @@ trait Helpers<'i> {
}

impl<'a> Context<'a> {
fn key(&self) -> usize {
self.key.replace(self.key.get() + 1)
fn field_key(&self) -> ast::FieldKey {
ast::FieldKey(self.key.replace(self.key.get() + 1))
}

fn decl_key(&self) -> ast::DeclKey {
ast::DeclKey(self.key.replace(self.key.get() + 1))
}
}

Expand Down Expand Up @@ -418,7 +422,7 @@ fn parse_field(node: Node<'_>, context: &Context) -> Result<ast::Field, String>
let mut children = desc.children();
Ok(ast::Field {
loc,
key: context.key(),
key: context.field_key(),
cond: cond.map(|constraint| parse_constraint(constraint, context)).transpose()?,
desc: match rule {
Rule::checksum_field => {
Expand Down Expand Up @@ -559,7 +563,7 @@ fn parse_toplevel(root: Node<'_>, context: &Context) -> Result<ast::File, String
let function = parse_string(&mut children)?;
file.declarations.push(ast::Decl {
loc,
key: context.key(),
key: context.decl_key(),
desc: ast::DeclDesc::Checksum { id, function, width },
})
}
Expand All @@ -571,7 +575,7 @@ fn parse_toplevel(root: Node<'_>, context: &Context) -> Result<ast::File, String
let function = parse_string(&mut children)?;
file.declarations.push(ast::Decl {
loc,
key: context.key(),
key: context.decl_key(),
desc: ast::DeclDesc::CustomField { id, function, width },
})
}
Expand All @@ -583,7 +587,7 @@ fn parse_toplevel(root: Node<'_>, context: &Context) -> Result<ast::File, String
let tags = parse_enum_tag_list(&mut children, context)?;
file.declarations.push(ast::Decl {
loc,
key: context.key(),
key: context.decl_key(),
desc: ast::DeclDesc::Enum { id, width, tags },
})
}
Expand All @@ -596,7 +600,7 @@ fn parse_toplevel(root: Node<'_>, context: &Context) -> Result<ast::File, String
let fields = parse_field_list_opt(&mut children, context)?;
file.declarations.push(ast::Decl {
loc,
key: context.key(),
key: context.decl_key(),
desc: ast::DeclDesc::Packet { id, parent_id, constraints, fields },
})
}
Expand All @@ -609,7 +613,7 @@ fn parse_toplevel(root: Node<'_>, context: &Context) -> Result<ast::File, String
let fields = parse_field_list_opt(&mut children, context)?;
file.declarations.push(ast::Decl {
loc,
key: context.key(),
key: context.decl_key(),
desc: ast::DeclDesc::Struct { id, parent_id, constraints, fields },
})
}
Expand All @@ -620,7 +624,7 @@ fn parse_toplevel(root: Node<'_>, context: &Context) -> Result<ast::File, String
let fields = parse_field_list(&mut children, context)?;
file.declarations.push(ast::Decl {
loc,
key: context.key(),
key: context.decl_key(),
desc: ast::DeclDesc::Group { id, fields },
})
}
Expand All @@ -630,7 +634,7 @@ fn parse_toplevel(root: Node<'_>, context: &Context) -> Result<ast::File, String
}
}
file.comments.append(&mut toplevel_comments);
file.max_key = context.key();
file.max_key = context.key.get();
Ok(file)
}

Expand Down

0 comments on commit 50b210f

Please sign in to comment.