Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
W95Psp committed Apr 30, 2024
1 parent 490d94b commit 93e4b1a
Show file tree
Hide file tree
Showing 2 changed files with 254 additions and 35 deletions.
188 changes: 153 additions & 35 deletions hax-bounded-integers/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,76 +1,194 @@
use hax::IsRefinement;
use hax::Refinement;
use hax_lib as hax;

mod num_traits;

macro_rules! derivate_binop_for_bounded {
($t:ident, $bounded_t:ident, $trait:ident, $meth:ident) => {
({$t:ident, $bounded_t:ident}; $($tt:tt)*) => {
derivate_binop_for_bounded!({$t, $bounded_t, get, Self::Output,{},{}}; $($tt)*) ;
};
({$t:ident, $bounded_t:ident, $get:ident, $out:ty,{$($ref:tt)?}, {$($unref:tt)?}};) => {};
({$t:ident, $bounded_t:ident, $get:ident, $out:ty,{$($ref:tt)?}, {$($unref:tt)?}}; ($trait:ident, $meth:ident), $($tt:tt)*) => {
derivate_binop_for_bounded!(@$t, $bounded_t, $trait, $meth, $get, $out, {$($ref)?}, {$($unref)?});
derivate_binop_for_bounded!({$t, $bounded_t, $get, $out, {$($ref)?}, {$($unref)?}}; $($tt)*);
};
(@$t:ident, $bounded_t:ident, $trait:ident, $meth:ident, $get:ident, $out:ty,{$($ref:tt)?}, {$($unref:tt)?}) => {
// BoundedT<A, B> <OP> BoundedT<C, D>
impl<const MIN_LHS: $t, const MAX_LHS: $t, const MIN_RHS: $t, const MAX_RHS: $t>
$trait<$bounded_t<MIN_RHS, MAX_RHS>> for $bounded_t<MIN_LHS, MAX_LHS>
{
type Output = $t;
fn $meth(self, other: $bounded_t<MIN_RHS, MAX_RHS>) -> Self::Output {
self.value().$meth(other.value())
#[inline(always)]
fn $meth($($ref)? self, other: $($ref)? $bounded_t<MIN_RHS, MAX_RHS>) -> $out {
($($unref)? self.$get()).$meth($($unref)? other.$get())
}
}

// BoundedT<A, B> <OP> T
impl<const MIN: $t, const MAX: $t> $trait<$t> for $bounded_t<MIN, MAX> {
type Output = $t;
fn $meth(self, other: $t) -> Self::Output {
self.value().$meth(other)
#[inline(always)]
fn $meth($($ref)? self, other: $($ref)? $t) -> $out {
($($unref)? self.$get()).$meth($($unref)? other)
}
}

// T <OP> BoundedT<A, B>
impl<const MIN: $t, const MAX: $t> $trait<$bounded_t<MIN, MAX>> for $t {
type Output = $t;
fn $meth(self, other: $bounded_t<MIN, MAX>) -> Self::Output {
self.$meth(other.value())
#[inline(always)]
fn $meth($($ref)? self, other: $($ref)? $bounded_t<MIN, MAX>) -> $out {
($($unref)? self).$meth($($unref)? other.$get())
}
}
};
}

macro_rules! mk_bounded {
($bounded_t:ident($t: ident)$(,)?) => {
($bounded_t:ident($t: ident $($bytes:expr)?)$(,)?) => {
#[doc = concat!("Bounded ", stringify!($t)," integers. This struct enforces the invariant that values are greater or equal to `MIN` and less or equal to `MAX`.")]
#[hax::newtype_as_refinement(|x| x >= MIN && x <= MAX)]
#[hax::refinement_type(|x| x >= MIN && x <= MAX)]
#[derive(Clone, Copy, Eq, PartialEq, Ord, PartialOrd)]
pub struct $bounded_t<const MIN: $t, const MAX: $t>($t);

#[hax::exclude]
const _: () = {
use core::ops::*;
derivate_binop_for_bounded!($t, $bounded_t, Add, add);
derivate_binop_for_bounded!($t, $bounded_t, Sub, sub);
derivate_binop_for_bounded!($t, $bounded_t, Mul, mul);
derivate_binop_for_bounded!($t, $bounded_t, Div, div);
use num_traits::*;

derivate_binop_for_bounded!(
{$t, $bounded_t};
(Add, add), (Sub, sub), (Mul, mul), (Div, div), (Rem, rem),
(BitOr, bitor), (BitAnd, bitand), (BitXor, bitxor),
(Shl, shl), (Shr, shr),
);

derivate_binop_for_bounded!(
{$t, $bounded_t, deref, Option<Self::Output>, {&}, {*}};
(CheckedAdd, checked_add), (CheckedSub, checked_sub),
(CheckedMul, checked_mul), (CheckedDiv, checked_div),
);

impl<const MIN: $t, const MAX: $t> CheckedNeg for $bounded_t<MIN, MAX> {
type Output = $t;
#[inline(always)]
fn checked_neg(&self) -> Option<$t> {
self.deref().checked_neg()
}
}

impl<const MIN: $t, const MAX: $t> Not for $bounded_t<MIN, MAX> {
type Output = $t;
#[inline(always)]
fn not(self) -> Self::Output {
self.deref().not()
}
}

impl<const MIN: $t, const MAX: $t> $bounded_t<MIN, MAX> {
pub const MIN: $t = MIN;
pub const MAX: $t = MAX;
}

impl<const MIN: $t, const MAX: $t> Bounded for $bounded_t<MIN, MAX> {
#[inline(always)]
fn min_value() -> Self {
Self::new(MIN)
}
#[inline(always)]
fn max_value() -> Self {
Self::new(MAX)
}
}

$(
impl<const MIN: $t, const MAX: $t> FromBytes for $bounded_t<MIN, MAX> {
type BYTES = [u8; $bytes];

fn from_le_bytes(bytes: Self::BYTES) -> Self {
Self::new($t::from_le_bytes(bytes))
}
fn from_be_bytes(bytes: Self::BYTES) -> Self {
Self::new($t::from_be_bytes(bytes))
}
}

impl<const MIN: $t, const MAX: $t> ToBytes for $bounded_t<MIN, MAX> {
fn to_le_bytes(self) -> Self::BYTES {
self.get().to_le_bytes()
}
fn to_be_bytes(self) -> Self::BYTES {
self.get().to_be_bytes()
}
}
)?

impl<const MIN: $t, const MAX: $t> MachineInt for $bounded_t<MIN, MAX> { }

impl<const MIN: $t, const MAX: $t> BitOps for $bounded_t<MIN, MAX> {
type Output = $t;

fn count_ones(self) -> u32 {
self.get().count_ones()
}
fn count_zeros(self) -> u32 {
self.get().count_zeros()
}
fn leading_ones(self) -> u32 {
self.get().leading_ones()
}
fn leading_zeros(self) -> u32 {
self.get().leading_zeros()
}
fn trailing_ones(self) -> u32 {
self.get().trailing_ones()
}
fn trailing_zeros(self) -> u32 {
self.get().trailing_zeros()
}
fn rotate_left(self, n: u32) -> Self::Output {
self.get().rotate_left(n)
}
fn rotate_right(self, n: u32) -> Self::Output {
self.get().rotate_right(n)
}
fn from_be(x: Self) -> Self::Output {
Self::Output::from_be(x.get())
}
fn from_le(x: Self) -> Self::Output {
Self::Output::from_le(x.get())
}
fn to_be(self) -> Self::Output {
Self::Output::to_be(self.get())
}
fn to_le(self) -> Self::Output {
Self::Output::to_le(self.get())
}
fn pow(self, exp: u32) -> Self::Output {
Self::Output::pow(self.get(), exp)
}
}
};
};
($bounded_t:ident($t: ident), $($tt:tt)+) => {
mk_bounded!($bounded_t($t));
($bounded_t:ident($t: ident $($bytes:expr)?), $($tt:tt)+) => {
mk_bounded!($bounded_t($t $($bytes)?));
mk_bounded!($($tt)+);
};
}

use hax::int::Int;

mk_bounded!(
BoundedI8(i8),
BoundedI16(i16),
BoundedI32(i32),
BoundedI64(i64),
BoundedI128(i128),
BoundedI8(i8 1),
BoundedI16(i16 2),
BoundedI32(i32 4),
BoundedI64(i64 8),
BoundedI128(i128 16),
BoundedIsize(isize),
BoundedU8(u8),
BoundedU16(u16),
BoundedU32(u32),
BoundedU64(u64),
BoundedU128(u128),
BoundedU8(u8 1),
BoundedU16(u16 2),
BoundedU32(u32 4),
BoundedU64(u64 8),
BoundedU128(u128 16),
BoundedUsize(usize),
);

pub fn _test(
x: BoundedU8<0, 20>,
y: BoundedU8<10, 13>,
z: BoundedU8<5, 5>,
d: BoundedU8<1, 1>,
) -> BoundedU8<5, 28> {
BoundedU8::new(x + y - z / d)
}
101 changes: 101 additions & 0 deletions hax-bounded-integers/src/num_traits.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
use core::ops::*;

pub trait Zero: Sized + Add<Self, Output = Self> {
fn zero() -> Self;
}

pub trait One: Sized + Mul<Self, Output = Self> {
fn one() -> Self;
}

pub trait NumOps<Rhs = Self, Output = Self>:
Add<Rhs, Output = Output>
+ Sub<Rhs, Output = Output>
+ Mul<Rhs, Output = Output>
+ Div<Rhs, Output = Output>
+ Rem<Rhs, Output = Output>
{
}

pub trait Bounded {
fn min_value() -> Self;
fn max_value() -> Self;
}

pub trait CheckedAdd<Rhs = Self> {
type Output;
fn checked_add(&self, v: &Rhs) -> Option<Self::Output>;
}

pub trait CheckedSub<Rhs = Self> {
type Output;
fn checked_sub(&self, v: &Rhs) -> Option<Self::Output>;
}

pub trait CheckedMul<Rhs = Self> {
type Output;
fn checked_mul(&self, v: &Rhs) -> Option<Self::Output>;
}

pub trait CheckedDiv<Rhs = Self> {
type Output;
fn checked_div(&self, v: &Rhs) -> Option<Self::Output>;
}

pub trait CheckedNeg {
type Output;
fn checked_neg(&self) -> Option<Self::Output>;
}

pub trait Num: PartialEq + Zero + NumOps {}

pub trait FromBytes {
type BYTES;

fn from_le_bytes(bytes: Self::BYTES) -> Self;
fn from_be_bytes(bytes: Self::BYTES) -> Self;
}

pub trait ToBytes: FromBytes {
fn to_le_bytes(self) -> Self::BYTES;
fn to_be_bytes(self) -> Self::BYTES;
}

pub trait MachineInt:
Sized
+ Copy
+ Bounded
+ PartialOrd
+ Ord
+ Eq
+ Not
+ BitAnd<Output = <Self as Not>::Output>
+ BitOr<Output = <Self as Not>::Output>
+ BitXor<Output = <Self as Not>::Output>
+ Shl<Self, Output = <Self as Not>::Output>
+ Shr<Self, Output = <Self as Not>::Output>
+ CheckedAdd<Output = <Self as Not>::Output>
+ CheckedSub<Output = <Self as Not>::Output>
+ CheckedMul<Output = <Self as Not>::Output>
+ CheckedDiv<Output = <Self as Not>::Output>
+ BitOps<Output = <Self as Not>::Output>
{
}

pub trait BitOps {
type Output;

fn count_ones(self) -> u32;
fn count_zeros(self) -> u32;
fn leading_ones(self) -> u32;
fn leading_zeros(self) -> u32;
fn trailing_ones(self) -> u32;
fn trailing_zeros(self) -> u32;
fn rotate_left(self, n: u32) -> Self::Output;
fn rotate_right(self, n: u32) -> Self::Output;
fn from_be(x: Self) -> Self::Output;
fn from_le(x: Self) -> Self::Output;
fn to_be(self) -> Self::Output;
fn to_le(self) -> Self::Output;
fn pow(self, exp: u32) -> Self::Output;
}

0 comments on commit 93e4b1a

Please sign in to comment.