diff --git a/hax-bounded-integers/src/lib.rs b/hax-bounded-integers/src/lib.rs index a556b49df..1f4bf6398 100644 --- a/hax-bounded-integers/src/lib.rs +++ b/hax-bounded-integers/src/lib.rs @@ -1,76 +1,194 @@ -use hax::IsRefinement; +use hax::Refinement; use hax_lib as hax; +mod num_traits; + macro_rules! derivate_binop_for_bounded { - ($t:ident, $bounded_t:ident, $trait:ident, $meth:ident) => { + ({$t:ident, $bounded_t:ident}; $($tt:tt)*) => { + derivate_binop_for_bounded!({$t, $bounded_t, get, Self::Output,{},{}}; $($tt)*) ; + }; + ({$t:ident, $bounded_t:ident, $get:ident, $out:ty,{$($ref:tt)?}, {$($unref:tt)?}};) => {}; + ({$t:ident, $bounded_t:ident, $get:ident, $out:ty,{$($ref:tt)?}, {$($unref:tt)?}}; ($trait:ident, $meth:ident), $($tt:tt)*) => { + derivate_binop_for_bounded!(@$t, $bounded_t, $trait, $meth, $get, $out, {$($ref)?}, {$($unref)?}); + derivate_binop_for_bounded!({$t, $bounded_t, $get, $out, {$($ref)?}, {$($unref)?}}; $($tt)*); + }; + (@$t:ident, $bounded_t:ident, $trait:ident, $meth:ident, $get:ident, $out:ty,{$($ref:tt)?}, {$($unref:tt)?}) => { // BoundedT BoundedT impl $trait<$bounded_t> for $bounded_t { type Output = $t; - fn $meth(self, other: $bounded_t) -> Self::Output { - self.value().$meth(other.value()) + #[inline(always)] + fn $meth($($ref)? self, other: $($ref)? $bounded_t) -> $out { + ($($unref)? self.$get()).$meth($($unref)? other.$get()) } } + // BoundedT T impl $trait<$t> for $bounded_t { type Output = $t; - fn $meth(self, other: $t) -> Self::Output { - self.value().$meth(other) + #[inline(always)] + fn $meth($($ref)? self, other: $($ref)? $t) -> $out { + ($($unref)? self.$get()).$meth($($unref)? other) } } // T BoundedT impl $trait<$bounded_t> for $t { type Output = $t; - fn $meth(self, other: $bounded_t) -> Self::Output { - self.$meth(other.value()) + #[inline(always)] + fn $meth($($ref)? self, other: $($ref)? $bounded_t) -> $out { + ($($unref)? self).$meth($($unref)? other.$get()) } } }; } macro_rules! mk_bounded { - ($bounded_t:ident($t: ident)$(,)?) => { + ($bounded_t:ident($t: ident $($bytes:expr)?)$(,)?) => { #[doc = concat!("Bounded ", stringify!($t)," integers. This struct enforces the invariant that values are greater or equal to `MIN` and less or equal to `MAX`.")] - #[hax::newtype_as_refinement(|x| x >= MIN && x <= MAX)] + #[hax::refinement_type(|x| x >= MIN && x <= MAX)] + #[derive(Clone, Copy, Eq, PartialEq, Ord, PartialOrd)] pub struct $bounded_t($t); #[hax::exclude] const _: () = { use core::ops::*; - derivate_binop_for_bounded!($t, $bounded_t, Add, add); - derivate_binop_for_bounded!($t, $bounded_t, Sub, sub); - derivate_binop_for_bounded!($t, $bounded_t, Mul, mul); - derivate_binop_for_bounded!($t, $bounded_t, Div, div); + use num_traits::*; + + derivate_binop_for_bounded!( + {$t, $bounded_t}; + (Add, add), (Sub, sub), (Mul, mul), (Div, div), (Rem, rem), + (BitOr, bitor), (BitAnd, bitand), (BitXor, bitxor), + (Shl, shl), (Shr, shr), + ); + + derivate_binop_for_bounded!( + {$t, $bounded_t, deref, Option, {&}, {*}}; + (CheckedAdd, checked_add), (CheckedSub, checked_sub), + (CheckedMul, checked_mul), (CheckedDiv, checked_div), + ); + + impl CheckedNeg for $bounded_t { + type Output = $t; + #[inline(always)] + fn checked_neg(&self) -> Option<$t> { + self.deref().checked_neg() + } + } + + impl Not for $bounded_t { + type Output = $t; + #[inline(always)] + fn not(self) -> Self::Output { + self.deref().not() + } + } + + impl $bounded_t { + pub const MIN: $t = MIN; + pub const MAX: $t = MAX; + } + + impl Bounded for $bounded_t { + #[inline(always)] + fn min_value() -> Self { + Self::new(MIN) + } + #[inline(always)] + fn max_value() -> Self { + Self::new(MAX) + } + } + + $( + impl FromBytes for $bounded_t { + type BYTES = [u8; $bytes]; + + fn from_le_bytes(bytes: Self::BYTES) -> Self { + Self::new($t::from_le_bytes(bytes)) + } + fn from_be_bytes(bytes: Self::BYTES) -> Self { + Self::new($t::from_be_bytes(bytes)) + } + } + + impl ToBytes for $bounded_t { + fn to_le_bytes(self) -> Self::BYTES { + self.get().to_le_bytes() + } + fn to_be_bytes(self) -> Self::BYTES { + self.get().to_be_bytes() + } + } + )? + + impl MachineInt for $bounded_t { } + + impl BitOps for $bounded_t { + type Output = $t; + + fn count_ones(self) -> u32 { + self.get().count_ones() + } + fn count_zeros(self) -> u32 { + self.get().count_zeros() + } + fn leading_ones(self) -> u32 { + self.get().leading_ones() + } + fn leading_zeros(self) -> u32 { + self.get().leading_zeros() + } + fn trailing_ones(self) -> u32 { + self.get().trailing_ones() + } + fn trailing_zeros(self) -> u32 { + self.get().trailing_zeros() + } + fn rotate_left(self, n: u32) -> Self::Output { + self.get().rotate_left(n) + } + fn rotate_right(self, n: u32) -> Self::Output { + self.get().rotate_right(n) + } + fn from_be(x: Self) -> Self::Output { + Self::Output::from_be(x.get()) + } + fn from_le(x: Self) -> Self::Output { + Self::Output::from_le(x.get()) + } + fn to_be(self) -> Self::Output { + Self::Output::to_be(self.get()) + } + fn to_le(self) -> Self::Output { + Self::Output::to_le(self.get()) + } + fn pow(self, exp: u32) -> Self::Output { + Self::Output::pow(self.get(), exp) + } + } }; }; - ($bounded_t:ident($t: ident), $($tt:tt)+) => { - mk_bounded!($bounded_t($t)); + ($bounded_t:ident($t: ident $($bytes:expr)?), $($tt:tt)+) => { + mk_bounded!($bounded_t($t $($bytes)?)); mk_bounded!($($tt)+); }; } +use hax::int::Int; + mk_bounded!( - BoundedI8(i8), - BoundedI16(i16), - BoundedI32(i32), - BoundedI64(i64), - BoundedI128(i128), + BoundedI8(i8 1), + BoundedI16(i16 2), + BoundedI32(i32 4), + BoundedI64(i64 8), + BoundedI128(i128 16), BoundedIsize(isize), - BoundedU8(u8), - BoundedU16(u16), - BoundedU32(u32), - BoundedU64(u64), - BoundedU128(u128), + BoundedU8(u8 1), + BoundedU16(u16 2), + BoundedU32(u32 4), + BoundedU64(u64 8), + BoundedU128(u128 16), BoundedUsize(usize), ); - -pub fn _test( - x: BoundedU8<0, 20>, - y: BoundedU8<10, 13>, - z: BoundedU8<5, 5>, - d: BoundedU8<1, 1>, -) -> BoundedU8<5, 28> { - BoundedU8::new(x + y - z / d) -} diff --git a/hax-bounded-integers/src/num_traits.rs b/hax-bounded-integers/src/num_traits.rs new file mode 100644 index 000000000..aa63dcd67 --- /dev/null +++ b/hax-bounded-integers/src/num_traits.rs @@ -0,0 +1,101 @@ +use core::ops::*; + +pub trait Zero: Sized + Add { + fn zero() -> Self; +} + +pub trait One: Sized + Mul { + fn one() -> Self; +} + +pub trait NumOps: + Add + + Sub + + Mul + + Div + + Rem +{ +} + +pub trait Bounded { + fn min_value() -> Self; + fn max_value() -> Self; +} + +pub trait CheckedAdd { + type Output; + fn checked_add(&self, v: &Rhs) -> Option; +} + +pub trait CheckedSub { + type Output; + fn checked_sub(&self, v: &Rhs) -> Option; +} + +pub trait CheckedMul { + type Output; + fn checked_mul(&self, v: &Rhs) -> Option; +} + +pub trait CheckedDiv { + type Output; + fn checked_div(&self, v: &Rhs) -> Option; +} + +pub trait CheckedNeg { + type Output; + fn checked_neg(&self) -> Option; +} + +pub trait Num: PartialEq + Zero + NumOps {} + +pub trait FromBytes { + type BYTES; + + fn from_le_bytes(bytes: Self::BYTES) -> Self; + fn from_be_bytes(bytes: Self::BYTES) -> Self; +} + +pub trait ToBytes: FromBytes { + fn to_le_bytes(self) -> Self::BYTES; + fn to_be_bytes(self) -> Self::BYTES; +} + +pub trait MachineInt: + Sized + + Copy + + Bounded + + PartialOrd + + Ord + + Eq + + Not + + BitAnd::Output> + + BitOr::Output> + + BitXor::Output> + + Shl::Output> + + Shr::Output> + + CheckedAdd::Output> + + CheckedSub::Output> + + CheckedMul::Output> + + CheckedDiv::Output> + + BitOps::Output> +{ +} + +pub trait BitOps { + type Output; + + fn count_ones(self) -> u32; + fn count_zeros(self) -> u32; + fn leading_ones(self) -> u32; + fn leading_zeros(self) -> u32; + fn trailing_ones(self) -> u32; + fn trailing_zeros(self) -> u32; + fn rotate_left(self, n: u32) -> Self::Output; + fn rotate_right(self, n: u32) -> Self::Output; + fn from_be(x: Self) -> Self::Output; + fn from_le(x: Self) -> Self::Output; + fn to_be(self) -> Self::Output; + fn to_le(self) -> Self::Output; + fn pow(self, exp: u32) -> Self::Output; +}