From f82e8be0792b90703fff9d7662ad9ff6e83adc08 Mon Sep 17 00:00:00 2001 From: Yun-Jhong Wu Date: Tue, 5 Dec 2023 19:34:13 -0600 Subject: [PATCH] Implement bool logic operations and move them to StrongNumericType (#15) --- README.md | 19 ++++- src/detail/bool_ops.rs | 154 +++++++++++++++++++++++++++++++++++++++++ src/detail/mod.rs | 4 +- src/detail/not.rs | 14 ---- src/strong_type.rs | 10 +-- tests/strong_type.rs | 30 ++++++-- 6 files changed, 204 insertions(+), 27 deletions(-) create mode 100644 src/detail/bool_ops.rs delete mode 100644 src/detail/not.rs diff --git a/README.md b/README.md index 912efbc..ee99dab 100644 --- a/README.md +++ b/README.md @@ -57,7 +57,7 @@ assert_eq!(x.type_id(), y.type_id()); // Same type: Second assert_ne!(y.type_id(), z.type_id()); // Different types: Second versus Minute ``` -#### Named type with arithmetic operations: +#### Named integer type with arithmetic operations: ```rust use strong_type::StrongNumericType; @@ -75,6 +75,23 @@ assert!(y >= x); assert_eq!(x + y, Second(5)); ``` +#### Named bool type with logical operations: + +```rust +use strong_type::StrongNumericType; + +#[derive(StrongNumericType)] +struct IsTrue(bool); + +let x = IsTrue::new(true); +let y = IsTrue::new(false); + +assert_eq!(x & y, IsTrue::new(false)); +assert_eq!(x | y, IsTrue::new(true)); +assert_eq!(x ^ y_ref, IsTrue::new(true)); +assert_eq!(!x, IsTrue::new(false)); +``` + #### Named type with `custom_display`: ```rust diff --git a/src/detail/bool_ops.rs b/src/detail/bool_ops.rs new file mode 100644 index 0000000..0809fef --- /dev/null +++ b/src/detail/bool_ops.rs @@ -0,0 +1,154 @@ +use proc_macro2::TokenStream; +use quote::quote; + +pub(crate) fn implement_bool_ops(name: &syn::Ident) -> TokenStream { + quote! { + impl std::ops::Not for #name { + type Output = Self; + + fn not(self) -> Self::Output { + Self::new(!self.value()) + } + } + + impl std::ops::Not for &#name { + type Output = #name; + + fn not(self) -> Self::Output { + #name::new(!self.value()) + } + } + + impl std::ops::BitAnd for #name { + type Output = Self; + + fn bitand(self, rhs: Self) -> Self::Output { + Self::new(self.value() & rhs.value()) + } + } + + impl std::ops::BitAnd<&Self> for #name { + type Output = Self; + + fn bitand(self, rhs: &Self) -> Self::Output { + Self::new(self.value() & rhs.value()) + } + } + + impl std::ops::BitAnd<#name> for &#name { + type Output = #name; + + fn bitand(self, rhs: #name) -> Self::Output { + #name::new(self.value() & rhs.value()) + } + } + + impl std::ops::BitAnd<&#name> for &#name { + type Output = #name; + + fn bitand(self, rhs: &#name) -> Self::Output { + #name::new(self.value() & rhs.value()) + } + } + + impl std::ops::BitAndAssign for #name { + fn bitand_assign(&mut self, rhs: Self) { + self.0 &= rhs.value() + } + } + + impl std::ops::BitAndAssign<&Self> for #name { + fn bitand_assign(&mut self, rhs: &Self) { + self.0 &= rhs.value() + } + } + + impl std::ops::BitOr for #name { + type Output = Self; + + fn bitor(self, rhs: Self) -> Self::Output { + Self::new(self.value() | rhs.value()) + } + } + + impl std::ops::BitOr<&Self> for #name { + type Output = Self; + + fn bitor(self, rhs: &Self) -> Self::Output { + Self::new(self.value() | rhs.value()) + } + } + + impl std::ops::BitOr<#name> for &#name { + type Output = #name; + + fn bitor(self, rhs: #name) -> Self::Output { + #name::new(self.value() | rhs.value()) + } + } + + impl std::ops::BitOr<&#name> for &#name { + type Output = #name; + + fn bitor(self, rhs: &#name) -> Self::Output { + #name::new(self.value() | rhs.value()) + } + } + + impl std::ops::BitOrAssign for #name { + fn bitor_assign(&mut self, rhs: Self) { + self.0 |= rhs.value() + } + } + + impl std::ops::BitOrAssign<&Self> for #name { + fn bitor_assign(&mut self, rhs: &Self) { + self.0 |= rhs.value() + } + } + + impl std::ops::BitXor for #name { + type Output = Self; + + fn bitxor(self, rhs: Self) -> Self::Output { + Self::new(self.value() ^ rhs.value()) + } + } + + impl std::ops::BitXor<&Self> for #name { + type Output = Self; + + fn bitxor(self, rhs: &Self) -> Self::Output { + Self::new(self.value() ^ rhs.value()) + } + } + + impl std::ops::BitXor<#name> for &#name { + type Output = #name; + + fn bitxor(self, rhs: #name) -> Self::Output { + #name::new(self.value() ^ rhs.value()) + } + } + + impl std::ops::BitXor<&#name> for &#name { + type Output = #name; + + fn bitxor(self, rhs: &#name) -> Self::Output { + #name::new(self.value() ^ rhs.value()) + } + } + + impl std::ops::BitXorAssign for #name { + fn bitxor_assign(&mut self, rhs: Self) { + self.0 ^= rhs.value() + } + } + + impl std::ops::BitXorAssign<&Self> for #name { + fn bitxor_assign(&mut self, rhs: &Self) { + self.0 ^= rhs.value() + } + } + } +} diff --git a/src/detail/mod.rs b/src/detail/mod.rs index 7646b5a..92893e2 100644 --- a/src/detail/mod.rs +++ b/src/detail/mod.rs @@ -2,22 +2,22 @@ mod arithmetic; mod basic; mod basic_primitive; mod basic_string; +mod bool_ops; mod display; mod hash; mod min_max; mod nan; mod negate; -mod not; mod underlying_type; pub(crate) use arithmetic::implement_arithmetic; pub(crate) use basic::implement_basic; pub(crate) use basic_primitive::implement_basic_primitive; pub(crate) use basic_string::implement_basic_string; +pub(crate) use bool_ops::implement_bool_ops; pub(crate) use display::{custom_display, implement_display}; pub(crate) use hash::implement_hash; pub(crate) use min_max::implement_min_max; pub(crate) use nan::implement_nan; pub(crate) use negate::implement_negate; -pub(crate) use not::implement_not; pub(crate) use underlying_type::{get_type_group, get_type_ident, UnderlyingTypeGroup}; diff --git a/src/detail/not.rs b/src/detail/not.rs deleted file mode 100644 index 484f3b6..0000000 --- a/src/detail/not.rs +++ /dev/null @@ -1,14 +0,0 @@ -use proc_macro2::TokenStream; -use quote::quote; - -pub(crate) fn implement_not(name: &syn::Ident) -> TokenStream { - quote! { - impl std::ops::Not for #name { - type Output = Self; - - fn not(self) -> Self::Output { - #name(!self.value()) - } - } - } -} diff --git a/src/strong_type.rs b/src/strong_type.rs index a202e7e..f60c65a 100644 --- a/src/strong_type.rs +++ b/src/strong_type.rs @@ -1,7 +1,7 @@ use crate::detail::{ custom_display, get_type_group, get_type_ident, implement_arithmetic, implement_basic, - implement_basic_primitive, implement_basic_string, implement_display, implement_hash, - implement_min_max, implement_nan, implement_negate, implement_not, UnderlyingTypeGroup, + implement_basic_primitive, implement_basic_string, implement_bool_ops, implement_display, + implement_hash, implement_min_max, implement_nan, implement_negate, UnderlyingTypeGroup, }; use proc_macro2::TokenStream; use quote::quote; @@ -31,7 +31,6 @@ pub(super) fn expand_strong_type(input: DeriveInput, impl_arithmetic: bool) -> T } UnderlyingTypeGroup::Bool => { ast.extend(implement_basic_primitive(name, value_type)); - ast.extend(implement_not(name)); ast.extend(implement_hash(name)); } UnderlyingTypeGroup::Char => { @@ -53,7 +52,10 @@ pub(super) fn expand_strong_type(input: DeriveInput, impl_arithmetic: bool) -> T UnderlyingTypeGroup::UInt => { ast.extend(implement_arithmetic(name)); } - _ => panic!("Non-arithmetic type {value_type}"), + UnderlyingTypeGroup::Bool => { + ast.extend(implement_bool_ops(name)); + } + _ => panic!("Non-numeric type: {value_type}"), } } diff --git a/tests/strong_type.rs b/tests/strong_type.rs index 5fe2910..f226bfb 100644 --- a/tests/strong_type.rs +++ b/tests/strong_type.rs @@ -284,14 +284,32 @@ mod tests { } #[test] - fn test_bool_negate() { - #[derive(StrongType)] + fn test_bool_ops() { + #[derive(StrongNumericType)] struct IsTrue(bool); - let is_true = IsTrue(true); - assert!(is_true.value()); - assert!(!(!is_true).value()); - assert!((!!is_true).value()); + let x = IsTrue::new(true); + let y = IsTrue::new(false); + let x_ref = &x; + let y_ref = &y; + + assert_eq!(x & y, IsTrue::new(false)); + assert_eq!(x | y, IsTrue::new(true)); + assert_eq!(x ^ y, IsTrue::new(true)); + + assert_eq!(x_ref & y, IsTrue::new(false)); + assert_eq!(x_ref | y, IsTrue::new(true)); + assert_eq!(x_ref ^ y, IsTrue::new(true)); + + assert_eq!(x & y_ref, IsTrue::new(false)); + assert_eq!(x | y_ref, IsTrue::new(true)); + assert_eq!(x ^ y_ref, IsTrue::new(true)); + + assert_eq!(x_ref & y_ref, IsTrue::new(false)); + assert_eq!(x_ref | y_ref, IsTrue::new(true)); + + assert_eq!(!x, IsTrue::new(false)); + assert_eq!(!x_ref, IsTrue::new(false)); } #[test]