Skip to content

Commit

Permalink
Validate constraints in decode_partial
Browse files Browse the repository at this point in the history
  • Loading branch information
hchataing committed May 17, 2024
1 parent fe1935b commit a10c6d4
Show file tree
Hide file tree
Showing 14 changed files with 286 additions and 0 deletions.
23 changes: 23 additions & 0 deletions pdl-compiler/src/backends/rust/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,27 @@ fn generate_derived_packet_decl(
}
};

// Constraint checks are only run for constraints added to this declaration
// and not parent constraints which are expected to have been validated
// earlier.
let constraint_checks = decl.constraints().map(|c| {
let field_id = c.id.to_ident();
let field_name = &c.id;
let packet_name = id;
let value = constraint_value(&parent_data_fields, c);
let value_str = value.to_string();
quote! {
if parent.#field_id() != #value {
return Err(DecodeError::InvalidFieldValue {
packet: #packet_name,
field: #field_name,
expected: #value_str,
actual: format!("{:?}", parent.#field_id()),
})
}
}
});

let decode_partial = if parent_decl.payload().is_some() {
// Generate an implementation of decode_partial that will decode
// data fields present in the parent payload.
Expand All @@ -453,6 +474,7 @@ fn generate_derived_packet_decl(
quote! {
fn decode_partial(parent: &#parent_name) -> Result<Self, DecodeError> {
let mut buf: &[u8] = &parent.payload;
#( #constraint_checks )*
#field_parser
if buf.is_empty() {
Ok(Self {
Expand All @@ -472,6 +494,7 @@ fn generate_derived_packet_decl(
// return DecodeError::InvalidConstraint.
quote! {
fn decode_partial(parent: &#parent_name) -> Result<Self, DecodeError> {
#( #constraint_checks )*
Ok(Self {
#( #copied_field_ids: parent.#copied_field_ids, )*
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,14 @@ impl TryFrom<&Bar> for Vec<u8> {
impl Bar {
fn decode_partial(parent: &Foo) -> Result<Self, DecodeError> {
let mut buf: &[u8] = &parent.payload;
if parent.a() != 100 {
return Err(DecodeError::InvalidFieldValue {
packet: "Bar",
field: "a",
expected: "100",
actual: format!("{:?}", parent.a()),
});
}
if buf.remaining() < 1 {
return Err(DecodeError::InvalidLengthError {
obj: "Bar",
Expand Down Expand Up @@ -313,6 +321,14 @@ impl TryFrom<&Baz> for Vec<u8> {
impl Baz {
fn decode_partial(parent: &Foo) -> Result<Self, DecodeError> {
let mut buf: &[u8] = &parent.payload;
if parent.b() != Enum16::B {
return Err(DecodeError::InvalidFieldValue {
packet: "Baz",
field: "b",
expected: "Enum16 :: B",
actual: format!("{:?}", parent.b()),
});
}
if buf.remaining() < 2 {
return Err(DecodeError::InvalidLengthError {
obj: "Baz",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,14 @@ impl TryFrom<&Bar> for Vec<u8> {
impl Bar {
fn decode_partial(parent: &Foo) -> Result<Self, DecodeError> {
let mut buf: &[u8] = &parent.payload;
if parent.a() != 100 {
return Err(DecodeError::InvalidFieldValue {
packet: "Bar",
field: "a",
expected: "100",
actual: format!("{:?}", parent.a()),
});
}
if buf.remaining() < 1 {
return Err(DecodeError::InvalidLengthError {
obj: "Bar",
Expand Down Expand Up @@ -313,6 +321,14 @@ impl TryFrom<&Baz> for Vec<u8> {
impl Baz {
fn decode_partial(parent: &Foo) -> Result<Self, DecodeError> {
let mut buf: &[u8] = &parent.payload;
if parent.b() != Enum16::B {
return Err(DecodeError::InvalidFieldValue {
packet: "Baz",
field: "b",
expected: "Enum16 :: B",
actual: format!("{:?}", parent.b()),
});
}
if buf.remaining() < 2 {
return Err(DecodeError::InvalidLengthError {
obj: "Baz",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,14 @@ impl Child {
}
fn decode_partial(parent: &Parent) -> Result<Self, DecodeError> {
let mut buf: &[u8] = &parent.payload;
if parent.foo() != Enum16::A {
return Err(DecodeError::InvalidFieldValue {
packet: "Child",
field: "foo",
expected: "Enum16 :: A",
actual: format!("{:?}", parent.foo()),
});
}
if buf.remaining() < 2 {
return Err(DecodeError::InvalidLengthError {
obj: "Child",
Expand Down Expand Up @@ -395,6 +403,22 @@ impl GrandChild {
}
fn decode_partial(parent: &Child) -> Result<Self, DecodeError> {
let mut buf: &[u8] = &parent.payload;
if parent.bar() != Enum16::A {
return Err(DecodeError::InvalidFieldValue {
packet: "GrandChild",
field: "bar",
expected: "Enum16 :: A",
actual: format!("{:?}", parent.bar()),
});
}
if parent.quux() != Enum16::A {
return Err(DecodeError::InvalidFieldValue {
packet: "GrandChild",
field: "quux",
expected: "Enum16 :: A",
actual: format!("{:?}", parent.quux()),
});
}
let payload = buf.to_vec();
buf.advance(payload.len());
if buf.is_empty() {
Expand Down Expand Up @@ -499,6 +523,14 @@ impl TryFrom<&GrandGrandChild> for Vec<u8> {
impl GrandGrandChild {
fn decode_partial(parent: &GrandChild) -> Result<Self, DecodeError> {
let mut buf: &[u8] = &parent.payload;
if parent.baz() != Enum16::A {
return Err(DecodeError::InvalidFieldValue {
packet: "GrandGrandChild",
field: "baz",
expected: "Enum16 :: A",
actual: format!("{:?}", parent.baz()),
});
}
let payload = buf.to_vec();
buf.advance(payload.len());
if buf.is_empty() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,14 @@ impl Child {
}
fn decode_partial(parent: &Parent) -> Result<Self, DecodeError> {
let mut buf: &[u8] = &parent.payload;
if parent.foo() != Enum16::A {
return Err(DecodeError::InvalidFieldValue {
packet: "Child",
field: "foo",
expected: "Enum16 :: A",
actual: format!("{:?}", parent.foo()),
});
}
if buf.remaining() < 2 {
return Err(DecodeError::InvalidLengthError {
obj: "Child",
Expand Down Expand Up @@ -395,6 +403,22 @@ impl GrandChild {
}
fn decode_partial(parent: &Child) -> Result<Self, DecodeError> {
let mut buf: &[u8] = &parent.payload;
if parent.bar() != Enum16::A {
return Err(DecodeError::InvalidFieldValue {
packet: "GrandChild",
field: "bar",
expected: "Enum16 :: A",
actual: format!("{:?}", parent.bar()),
});
}
if parent.quux() != Enum16::A {
return Err(DecodeError::InvalidFieldValue {
packet: "GrandChild",
field: "quux",
expected: "Enum16 :: A",
actual: format!("{:?}", parent.quux()),
});
}
let payload = buf.to_vec();
buf.advance(payload.len());
if buf.is_empty() {
Expand Down Expand Up @@ -499,6 +523,14 @@ impl TryFrom<&GrandGrandChild> for Vec<u8> {
impl GrandGrandChild {
fn decode_partial(parent: &GrandChild) -> Result<Self, DecodeError> {
let mut buf: &[u8] = &parent.payload;
if parent.baz() != Enum16::A {
return Err(DecodeError::InvalidFieldValue {
packet: "GrandGrandChild",
field: "baz",
expected: "Enum16 :: A",
actual: format!("{:?}", parent.baz()),
});
}
let payload = buf.to_vec();
buf.advance(payload.len());
if buf.is_empty() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,14 @@ impl TryFrom<&NormalChild> for Vec<u8> {
impl NormalChild {
fn decode_partial(parent: &Parent) -> Result<Self, DecodeError> {
let mut buf: &[u8] = &parent.payload;
if parent.v() != Enum8::A {
return Err(DecodeError::InvalidFieldValue {
packet: "NormalChild",
field: "v",
expected: "Enum8 :: A",
actual: format!("{:?}", parent.v()),
});
}
if buf.is_empty() { Ok(Self {}) } else { Err(DecodeError::TrailingBytes) }
}
pub fn encode_partial(&self, buf: &mut impl BufMut) -> Result<(), EncodeError> {
Expand Down Expand Up @@ -337,6 +345,14 @@ impl TryFrom<&NormalGrandChild1> for Vec<u8> {
impl NormalGrandChild1 {
fn decode_partial(parent: &AliasChild) -> Result<Self, DecodeError> {
let mut buf: &[u8] = &parent.payload;
if parent.v() != Enum8::B {
return Err(DecodeError::InvalidFieldValue {
packet: "NormalGrandChild1",
field: "v",
expected: "Enum8 :: B",
actual: format!("{:?}", parent.v()),
});
}
if buf.is_empty() { Ok(Self {}) } else { Err(DecodeError::TrailingBytes) }
}
pub fn encode_partial(&self, buf: &mut impl BufMut) -> Result<(), EncodeError> {
Expand Down Expand Up @@ -401,6 +417,14 @@ impl TryFrom<&NormalGrandChild2> for Vec<u8> {
impl NormalGrandChild2 {
fn decode_partial(parent: &AliasChild) -> Result<Self, DecodeError> {
let mut buf: &[u8] = &parent.payload;
if parent.v() != Enum8::C {
return Err(DecodeError::InvalidFieldValue {
packet: "NormalGrandChild2",
field: "v",
expected: "Enum8 :: C",
actual: format!("{:?}", parent.v()),
});
}
let payload = buf.to_vec();
buf.advance(payload.len());
if buf.is_empty() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,14 @@ impl TryFrom<&NormalChild> for Vec<u8> {
impl NormalChild {
fn decode_partial(parent: &Parent) -> Result<Self, DecodeError> {
let mut buf: &[u8] = &parent.payload;
if parent.v() != Enum8::A {
return Err(DecodeError::InvalidFieldValue {
packet: "NormalChild",
field: "v",
expected: "Enum8 :: A",
actual: format!("{:?}", parent.v()),
});
}
if buf.is_empty() { Ok(Self {}) } else { Err(DecodeError::TrailingBytes) }
}
pub fn encode_partial(&self, buf: &mut impl BufMut) -> Result<(), EncodeError> {
Expand Down Expand Up @@ -337,6 +345,14 @@ impl TryFrom<&NormalGrandChild1> for Vec<u8> {
impl NormalGrandChild1 {
fn decode_partial(parent: &AliasChild) -> Result<Self, DecodeError> {
let mut buf: &[u8] = &parent.payload;
if parent.v() != Enum8::B {
return Err(DecodeError::InvalidFieldValue {
packet: "NormalGrandChild1",
field: "v",
expected: "Enum8 :: B",
actual: format!("{:?}", parent.v()),
});
}
if buf.is_empty() { Ok(Self {}) } else { Err(DecodeError::TrailingBytes) }
}
pub fn encode_partial(&self, buf: &mut impl BufMut) -> Result<(), EncodeError> {
Expand Down Expand Up @@ -401,6 +417,14 @@ impl TryFrom<&NormalGrandChild2> for Vec<u8> {
impl NormalGrandChild2 {
fn decode_partial(parent: &AliasChild) -> Result<Self, DecodeError> {
let mut buf: &[u8] = &parent.payload;
if parent.v() != Enum8::C {
return Err(DecodeError::InvalidFieldValue {
packet: "NormalGrandChild2",
field: "v",
expected: "Enum8 :: C",
actual: format!("{:?}", parent.v()),
});
}
let payload = buf.to_vec();
buf.advance(payload.len());
if buf.is_empty() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,14 @@ impl TryFrom<&Child> for Vec<u8> {
}
impl Child {
fn decode_partial(parent: &Parent) -> Result<Self, DecodeError> {
if parent.v() != Enum8::A {
return Err(DecodeError::InvalidFieldValue {
packet: "Child",
field: "v",
expected: "Enum8 :: A",
actual: format!("{:?}", parent.v()),
});
}
Ok(Self {})
}
pub fn encode_partial(&self, buf: &mut impl BufMut) -> Result<(), EncodeError> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,14 @@ impl TryFrom<&Child> for Vec<u8> {
}
impl Child {
fn decode_partial(parent: &Parent) -> Result<Self, DecodeError> {
if parent.v() != Enum8::A {
return Err(DecodeError::InvalidFieldValue {
packet: "Child",
field: "v",
expected: "Enum8 :: A",
actual: format!("{:?}", parent.v()),
});
}
Ok(Self {})
}
pub fn encode_partial(&self, buf: &mut impl BufMut) -> Result<(), EncodeError> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,14 @@ impl TryFrom<&Bar> for Vec<u8> {
impl Bar {
fn decode_partial(parent: &Foo) -> Result<Self, DecodeError> {
let mut buf: &[u8] = &parent.payload;
if parent.a() != 100 {
return Err(DecodeError::InvalidFieldValue {
packet: "Bar",
field: "a",
expected: "100",
actual: format!("{:?}", parent.a()),
});
}
if buf.remaining() < 1 {
return Err(DecodeError::InvalidLengthError {
obj: "Bar",
Expand Down Expand Up @@ -314,6 +322,14 @@ impl TryFrom<&Baz> for Vec<u8> {
impl Baz {
fn decode_partial(parent: &Foo) -> Result<Self, DecodeError> {
let mut buf: &[u8] = &parent.payload;
if parent.b() != Enum16::B {
return Err(DecodeError::InvalidFieldValue {
packet: "Baz",
field: "b",
expected: "Enum16 :: B",
actual: format!("{:?}", parent.b()),
});
}
if buf.remaining() < 2 {
return Err(DecodeError::InvalidLengthError {
obj: "Baz",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,14 @@ impl TryFrom<&Bar> for Vec<u8> {
impl Bar {
fn decode_partial(parent: &Foo) -> Result<Self, DecodeError> {
let mut buf: &[u8] = &parent.payload;
if parent.a() != 100 {
return Err(DecodeError::InvalidFieldValue {
packet: "Bar",
field: "a",
expected: "100",
actual: format!("{:?}", parent.a()),
});
}
if buf.remaining() < 1 {
return Err(DecodeError::InvalidLengthError {
obj: "Bar",
Expand Down Expand Up @@ -314,6 +322,14 @@ impl TryFrom<&Baz> for Vec<u8> {
impl Baz {
fn decode_partial(parent: &Foo) -> Result<Self, DecodeError> {
let mut buf: &[u8] = &parent.payload;
if parent.b() != Enum16::B {
return Err(DecodeError::InvalidFieldValue {
packet: "Baz",
field: "b",
expected: "Enum16 :: B",
actual: format!("{:?}", parent.b()),
});
}
if buf.remaining() < 2 {
return Err(DecodeError::InvalidLengthError {
obj: "Baz",
Expand Down
Loading

0 comments on commit a10c6d4

Please sign in to comment.