Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/uint384 hint #503

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 186 additions & 0 deletions src/hint_processor/builtin_hint_codes.zig
Original file line number Diff line number Diff line change
Expand Up @@ -664,3 +664,189 @@ pub const NONDET_BIGINT3_V2 =
\\from starkware.cairo.common.cairo_secp.secp_utils import split
\\segments.write_arg(ids.res.address_, split(value))
;


// The following hints support the lib https://github.com/NethermindEth/research-basic-Cairo-operations-big-integers/blob/main/lib
pub const UINT384_UNSIGNED_DIV_REM =
\\def split(num: int, num_bits_shift: int, length: int):
\\ a = []
\\ for _ in range(length):
\\ a.append( num & ((1 << num_bits_shift) - 1) )
\\ num = num >> num_bits_shift
\\ return tuple(a)
\\
\\def pack(z, num_bits_shift: int) -> int:
\\ limbs = (z.d0, z.d1, z.d2)
\\ return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs))
\\
\\a = pack(ids.a, num_bits_shift = 128)
\\div = pack(ids.div, num_bits_shift = 128)
\\quotient, remainder = divmod(a, div)
\\
\\quotient_split = split(quotient, num_bits_shift=128, length=3)
\\assert len(quotient_split) == 3
\\
\\ids.quotient.d0 = quotient_split[0]
\\ids.quotient.d1 = quotient_split[1]
\\ids.quotient.d2 = quotient_split[2]
\\
\\remainder_split = split(remainder, num_bits_shift=128, length=3)
\\ids.remainder.d0 = remainder_split[0]
\\ids.remainder.d1 = remainder_split[1]
\\ids.remainder.d2 = remainder_split[2]
;

pub const UINT384_SPLIT_128 =
\\ids.low = ids.a & ((1<<128) - 1)
\\ids.high = ids.a >> 128
;

pub const ADD_NO_UINT384_CHECK =
\\sum_d0 = ids.a.d0 + ids.b.d0
\\ids.carry_d0 = 1 if sum_d0 >= ids.SHIFT else 0
\\sum_d1 = ids.a.d1 + ids.b.d1 + ids.carry_d0
\\ids.carry_d1 = 1 if sum_d1 >= ids.SHIFT else 0
\\sum_d2 = ids.a.d2 + ids.b.d2 + ids.carry_d1
\\ids.carry_d2 = 1 if sum_d2 >= ids.SHIFT else 0
;

pub const UINT384_SQRT =
\\from starkware.python.math_utils import isqrt
\\
\\def split(num: int, num_bits_shift: int, length: int):
\\ a = []
\\ for _ in range(length):
\\ a.append( num & ((1 << num_bits_shift) - 1) )
\\ num = num >> num_bits_shift
\\ return tuple(a)
\\
\\def pack(z, num_bits_shift: int) -> int:
\\ limbs = (z.d0, z.d1, z.d2)
\\ return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs))
\\
\\a = pack(ids.a, num_bits_shift=128)
\\root = isqrt(a)
\\assert 0 <= root < 2 ** 192
\\root_split = split(root, num_bits_shift=128, length=3)
\\ids.root.d0 = root_split[0]
\\ids.root.d1 = root_split[1]
\\ids.root.d2 = root_split[2]
;

pub const SUB_REDUCED_A_AND_REDUCED_B =
\\def split(num: int, num_bits_shift: int, length: int):
\\ a = []
\\ for _ in range(length):
\\ a.append( num & ((1 << num_bits_shift) - 1) )
\\ num = num >> num_bits_shift
\\ return tuple(a)
\\
\\def pack(z, num_bits_shift: int) -> int:
\\ limbs = (z.d0, z.d1, z.d2)
\\ return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs))
\\
\\a = pack(ids.a, num_bits_shift = 128)
\\b = pack(ids.b, num_bits_shift = 128)
\\p = pack(ids.p, num_bits_shift = 128)
\\
\\res = (a - b) % p
\\
\\
\\res_split = split(res, num_bits_shift=128, length=3)
\\
\\ids.res.d0 = res_split[0]
\\ids.res.d1 = res_split[1]
\\ids.res.d2 = res_split[2]
;

pub const UINT384_SIGNED_NN = "memory[ap] = 1 if 0 <= (ids.a.d2 % PRIME) < 2 ** 127 else 0";

pub const UNSIGNED_DIV_REM_UINT768_BY_UINT384 =
\\def split(num: int, num_bits_shift: int, length: int):
\\ a = []
\\ for _ in range(length):
\\ a.append( num & ((1 << num_bits_shift) - 1) )
\\ num = num >> num_bits_shift
\\ return tuple(a)
\\
\\def pack(z, num_bits_shift: int) -> int:
\\ limbs = (z.d0, z.d1, z.d2)
\\ return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs))
\\
\\def pack_extended(z, num_bits_shift: int) -> int:
\\ limbs = (z.d0, z.d1, z.d2, z.d3, z.d4, z.d5)
\\ return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs))
\\
\\a = pack_extended(ids.a, num_bits_shift = 128)
\\div = pack(ids.div, num_bits_shift = 128)
\\
\\quotient, remainder = divmod(a, div)
\\
\\quotient_split = split(quotient, num_bits_shift=128, length=6)
\\
\\ids.quotient.d0 = quotient_split[0]
\\ids.quotient.d1 = quotient_split[1]
\\ids.quotient.d2 = quotient_split[2]
\\ids.quotient.d3 = quotient_split[3]
\\ids.quotient.d4 = quotient_split[4]
\\ids.quotient.d5 = quotient_split[5]
\\
\\remainder_split = split(remainder, num_bits_shift=128, length=3)
\\ids.remainder.d0 = remainder_split[0]
\\ids.remainder.d1 = remainder_split[1]
\\ids.remainder.d2 = remainder_split[2]
;

// equal to UNSIGNED_DIV_REM_UINT768_BY_UINT384 but with some whitespace removed
// in the `num = num >> num_bits_shift` and between `pack` and `pack_extended`
pub const UNSIGNED_DIV_REM_UINT768_BY_UINT384_STRIPPED =
\\def split(num: int, num_bits_shift: int, length: int):
\\ a = []
\\ for _ in range(length):
\\ a.append( num & ((1 << num_bits_shift) - 1) )
\\ num = num >> num_bits_shift
\\ return tuple(a)
\\
\\def pack(z, num_bits_shift: int) -> int:
\\ limbs = (z.d0, z.d1, z.d2)
\\ return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs))
\\
\\def pack_extended(z, num_bits_shift: int) -> int:
\\ limbs = (z.d0, z.d1, z.d2, z.d3, z.d4, z.d5)
\\ return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs))
\\
\\a = pack_extended(ids.a, num_bits_shift = 128)
\\div = pack(ids.div, num_bits_shift = 128)
\\
\\quotient, remainder = divmod(a, div)
\\
\\quotient_split = split(quotient, num_bits_shift=128, length=6)
\\
\\ids.quotient.d0 = quotient_split[0]
\\ids.quotient.d1 = quotient_split[1]
\\ids.quotient.d2 = quotient_split[2]
\\ids.quotient.d3 = quotient_split[3]
\\ids.quotient.d4 = quotient_split[4]
\\ids.quotient.d5 = quotient_split[5]
\\
\\remainder_split = split(remainder, num_bits_shift=128, length=3)
\\ids.remainder.d0 = remainder_split[0]
\\ids.remainder.d1 = remainder_split[1]
\\ids.remainder.d2 = remainder_split[2]
;

pub const INV_MOD_P_UINT512 =
\\def pack_512(u, num_bits_shift: int) -> int:
\\ limbs = (u.d0, u.d1, u.d2, u.d3)
\\ return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs))
\\
\\x = pack_512(ids.x, num_bits_shift = 128)
\\p = ids.p.low + (ids.p.high << 128)
\\x_inverse_mod_p = pow(x,-1, p)
\\
\\x_inverse_mod_p_split = (x_inverse_mod_p & ((1 << 128) - 1), x_inverse_mod_p >> 128)
\\
\\ids.x_inverse_mod_p.low = x_inverse_mod_p_split[0]
\\ids.x_inverse_mod_p.high = x_inverse_mod_p_split[1]
;

20 changes: 13 additions & 7 deletions src/hint_processor/builtin_hint_processor/secp/bigint_utils.zig
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,22 @@ pub fn BigIntN(comptime NUM_LIMBS: usize) type {
return .{ .limbs = limbs };
}

pub fn insertFromVarName(self: *Self, allocator: std.mem.allocator, var_name: []const u8, vm: *CairoVM, ids_data: std.StringHashMap(HintReference), ap_tracking: ApTracking) !void {
pub fn insertFromVarName(
self: *const Self,
allocator: std.mem.Allocator,
var_name: []const u8,
vm: *CairoVM,
ids_data: std.StringHashMap(HintReference),
ap_tracking: ApTracking,
) !void {
const addr = try hint_utils.getRelocatableFromVarName(var_name, vm, ids_data, ap_tracking);
inline for (0..NUM_LIMBS) |i| {
try vm.insertInMemory(allocator, addr + i, self.limbs[i]);
try vm.insertInMemory(allocator, try addr.addUint(i), MaybeRelocatable.fromFelt(self.limbs[i]));
}
}

pub fn pack(self: *const Self, allocator: std.mem.Allocator) !Int {
const result = packBigInt(allocator, NUM_LIMBS, self.limbs, 128);
return result;
return packBigInt(allocator, NUM_LIMBS, self.limbs, 128);
}

pub fn pack86(self: *const Self, allocator: std.mem.Allocator) !Int {
Expand All @@ -80,9 +86,9 @@ pub fn BigIntN(comptime NUM_LIMBS: usize) type {
return result;
}

pub fn split(self: *Self, num: Int) Self {
const limbs = splitBigInt(std.mem.Allocator, num, self.limbs.len, 128);
return self.fromValues(limbs);

pub fn split(allocator: std.mem.Allocator, num: Int) !Self {
return Self.fromValues(try splitBigInt(allocator, num, NUM_LIMBS, 128));
}

// @TODO: implement from. It is dependent on split function.
Expand Down
21 changes: 21 additions & 0 deletions src/hint_processor/hint_processor_def.zig
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ const segments = @import("segments.zig");

const bigint_utils = @import("../hint_processor/builtin_hint_processor/secp/bigint_utils.zig");
const bigint = @import("bigint.zig");
const uint384 = @import("uint384.zig");
const inv_mod_p_uint512 = @import("vrf/inv_mod_p_uint512.zig");


const deserialize_utils = @import("../parser/deserialize_utils.zig");

Expand Down Expand Up @@ -384,6 +387,24 @@ pub const CairoVMHintProcessor = struct {
try bigint.bigintPackDivModHint(allocator, vm, exec_scopes, hint_data.ids_data, hint_data.ap_tracking);
} else if (std.mem.eql(u8, hint_codes.BIGINT_SAFE_DIV, hint_data.code)) {
try bigint.bigIntSafeDivHint(allocator, vm, exec_scopes, hint_data.ids_data, hint_data.ap_tracking);
} else if (std.mem.eql(u8, hint_codes.UINT384_UNSIGNED_DIV_REM, hint_data.code)) {
try uint384.uint384UnsignedDivRem(allocator, vm, hint_data.ids_data, hint_data.ap_tracking);
} else if (std.mem.eql(u8, hint_codes.UINT384_SPLIT_128, hint_data.code)) {
try uint384.uint384Split128(allocator, vm, hint_data.ids_data, hint_data.ap_tracking);
} else if (std.mem.eql(u8, hint_codes.ADD_NO_UINT384_CHECK, hint_data.code)) {
try uint384.addNoUint384Check(allocator, vm, hint_data.ids_data, hint_data.ap_tracking, constants);
} else if (std.mem.eql(u8, hint_codes.UINT384_SQRT, hint_data.code)) {
try uint384.uint384Sqrt(allocator, vm, hint_data.ids_data, hint_data.ap_tracking);
} else if (std.mem.eql(u8, hint_codes.UINT384_SIGNED_NN, hint_data.code)) {
try uint384.uint384SignedNn(allocator, vm, hint_data.ids_data, hint_data.ap_tracking);
} else if (std.mem.eql(u8, hint_codes.SUB_REDUCED_A_AND_REDUCED_B, hint_data.code)) {
try uint384.subReducedAAndReducedB(allocator, vm, hint_data.ids_data, hint_data.ap_tracking);
} else if (std.mem.eql(u8, hint_codes.UNSIGNED_DIV_REM_UINT768_BY_UINT384, hint_data.code)) {
try uint384.unsignedDivRemUint768ByUint384(allocator, vm, hint_data.ids_data, hint_data.ap_tracking);
} else if (std.mem.eql(u8, hint_codes.UNSIGNED_DIV_REM_UINT768_BY_UINT384_STRIPPED, hint_data.code)) {
try uint384.unsignedDivRemUint768ByUint384(allocator, vm, hint_data.ids_data, hint_data.ap_tracking);
} else if (std.mem.eql(u8, hint_codes.INV_MOD_P_UINT512, hint_data.code)) {
try inv_mod_p_uint512.invModPUint512(allocator, vm, hint_data.ids_data, hint_data.ap_tracking);
} else {
std.log.err("not implemented: {s}\n", .{hint_data.code});
return HintError.HintNotImplemented;
Expand Down
4 changes: 4 additions & 0 deletions src/hint_processor/testing_utils.zig
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ pub fn checkMemory(mem: *Memory, comptime rows: anytype) !void {
pub fn checkMemoryAddress(mem: *Memory, data: anytype) !void {
const expected = if (data[1].len == 2) MaybeRelocatable.fromRelocatable(Relocatable.init(data[1][0], data[1][1])) else MaybeRelocatable.fromInt(u256, data[1][0]);

errdefer {
std.log.err("failed expect: {any}, got: {any}\n", .{ expected, mem.get(Relocatable.init(data[0][0], data[0][1])) });
}

try std.testing.expectEqual(expected, mem.get(Relocatable.init(data[0][0], data[0][1])));
}

Expand Down
10 changes: 10 additions & 0 deletions src/hint_processor/uint256_utils.zig
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ const ApTracking = @import("../vm/types/programjson.zig").ApTracking;
const HintData = @import("hint_processor_def.zig").HintData;
const ExecutionScopes = @import("../vm/types/execution_scopes.zig").ExecutionScopes;

const Int = @import("std").math.big.int.Managed;
const helper = @import("../math/fields/helper.zig");
const MathError = @import("../vm/error.zig").MathError;
const HintError = @import("../vm/error.zig").HintError;
Expand Down Expand Up @@ -55,6 +56,15 @@ pub const Uint256 = struct {
pub fn split(comptime T: type, num: T) Self {
return Self.init(Felt252.fromInt(T, num & std.math.maxInt(u128)), Felt252.fromInt(T, num >> 128));
}

pub fn pack(self: Self, allocator: std.mem.Allocator) !Int {
var result = try Int.initSet(allocator, self.high.toInteger());
errdefer result.deinit();

try result.shiftLeft(&result, 128);
try result.addScalar(&result, self.low.toInteger());
return result;
}
// converting self to biguint value
// optimize by using biguint
// right now using u512, so to not use allocator with big int
Expand Down
Loading
Loading