Skip to content

Commit

Permalink
refactoring vint
Browse files Browse the repository at this point in the history
  • Loading branch information
herumi committed Jan 19, 2024
1 parent 9a30b1a commit 76de55d
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 63 deletions.
120 changes: 57 additions & 63 deletions include/mcl/vint.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,17 @@

namespace mcl {

// return the max size in x[0..n)
inline int getRealSize(const Unit *x, int n)
{
assert(n > 0);
while (n > 0) {
if (x[n - 1]) return n;
n--;
}
return 1;
}

/**
signed integer with variable length
*/
Expand All @@ -31,13 +42,12 @@ class Vint {
static const int invalidVar = -2147483647 - 1; // abs(invalidVar) is not defined
static const size_t N = maxUnitSize * 2 + 1;
private:
Unit buf_[N];
Unit buf_[N]; // assume buf_[size_ - 1] != 0 unless the value is zero
size_t size_;
bool isNeg_;
void trim(size_t n)
void trim()
{
assert(n > 0);
int i = (int)n - 1;
int i = (int)size_ - 1;
for (; i > 0; i--) {
if (buf_[i]) {
size_ = i + 1;
Expand Down Expand Up @@ -75,7 +85,7 @@ class Vint {
c = bint::addUnit(dst + yn, n, c);
}
dst[xn] = c;
z.trim(xn + 1);
z.trim();
}
static void uadd1(Vint& z, const Unit *x, size_t xn, Unit y)
{
Expand All @@ -86,7 +96,7 @@ class Vint {
}
if (z.buf_ != x) bint::copyN(z.buf_, x, xn);
z.buf_[zn - 1] = bint::addUnit(z.buf_, xn, y);
z.trim(zn);
z.trim();
}
static void usub1(Vint& z, const Unit *x, size_t xn, Unit y)
{
Expand All @@ -99,7 +109,7 @@ class Vint {
Unit c = bint::subUnit(dst, xn, y);
(void)c;
assert(!c);
z.trim(zn);
z.trim();
}
static void usub(Vint& z, const Unit *x, size_t xn, const Unit *y, size_t yn)
{
Expand All @@ -117,7 +127,7 @@ class Vint {
c = bint::subUnit(dst, n, c);
}
assert(!c);
z.trim(xn);
z.trim();
}
static void _add(Vint& z, const Vint& x, bool xNeg, const Vint& y, bool yNeg)
{
Expand Down Expand Up @@ -187,13 +197,13 @@ class Vint {
}
Unit *xx = (Unit*)CYBOZU_ALLOCA(sizeof(Unit) * xn);
bint::copyN(xx, x, xn);
Unit *qq = q ? &q->buf_[0] : 0;
size_t rn = bint::div(qq, qn, xx, xn, &y[0], yn);
Unit *qq = q ? q->buf_ : 0;
size_t rn = bint::div(qq, qn, xx, xn, y, yn);
r.copy(xx, rn);
r.trim();
if (q) {
q->trim(qn);
q->trim();
}
r.trim(rn);
}
/*
@param x [inout] x <- d
Expand Down Expand Up @@ -362,7 +372,7 @@ class Vint {
if (!*pb) {
return;
}
trim(unitSize);
trim();
}
/*
set [0, max) randomly
Expand All @@ -375,7 +385,7 @@ class Vint {
size_ = n;
rg.read(pb, buf_, n * sizeof(Unit));
if (!*pb) return;
trim(n);
trim();
*this %= max;
}
/*
Expand Down Expand Up @@ -459,7 +469,7 @@ class Vint {
buf_[q] |= mask;
} else {
buf_[q] &= ~mask;
trim(q + 1);
trim();
}
}
/*
Expand All @@ -470,17 +480,10 @@ class Vint {
*/
void setStr(bool *pb, const char *str, int base = 0)
{
// allow twice size of MCL_MAX_BIT_SIZE because of multiplication
const size_t maxN = (MCL_MAX_BIT_SIZE * 2 + UnitBitSize - 1) / UnitBitSize;
if (!setSize(maxN)) {
*pb = false;
return;
}
isNeg_ = false;
size_t len = strlen(str);
size_t n = fp::strToArray(&isNeg_, &buf_[0], maxN, str, len, base);
if (n == 0) return;
trim(n);
size_t n = fp::strToArray(&isNeg_, buf_, N, str, len, base);
if (n == 0 || !setSize(n)) return;
trim();
*pb = true;
}
static int compare(const Vint& x, const Vint& y)
Expand Down Expand Up @@ -543,7 +546,7 @@ class Vint {
size_t zn = xn + yn;
if (!z.setSize(zn)) return;
bint::mulNM(z.buf_, x.buf_, xn, y.buf_, yn);
z.trim(zn);
z.trim();
z.isNeg_ = x.isNeg_ ^ y.isNeg_;
}
static void sqr(Vint& y, const Vint& x)
Expand All @@ -565,7 +568,7 @@ class Vint {
if (!z.setSize(zn)) return;
z.buf_[zn - 1] = bint::mulUnitN(z.buf_, x.buf_, y, xn);
z.isNeg_ = x.isNeg_;
z.trim(zn);
z.trim();
}
static void divu1(Vint& q, const Vint& x, Unit y)
{
Expand Down Expand Up @@ -609,9 +612,9 @@ class Vint {
int r;
if (q) {
q->isNeg_ = xNeg ^ yNeg;
if (!q->setSize(xn)) return 0;
r = (int)bint::divUnit(q->buf_, x.buf_, xn, absY);
q->trim(xn);
q->setSize(xn);
q->trim();
} else {
r = (int)bint::modUnit(x.buf_, xn, absY);
}
Expand Down Expand Up @@ -653,12 +656,12 @@ class Vint {
}
static Unit udivModu1(Vint *q, const Vint& x, Unit y)
{
assert(!x.isNeg_);
assert(!x.isNeg_ && y != 0);
size_t xn = x.size();
if (q && !q->setSize(N)) return 0;
if (q) q->setSize(xn);
Unit r = bint::divUnit(q ? q->buf_ : 0, x.buf_, xn, y);
if (q) {
q->trim(xn);
q->trim();
q->isNeg_ = false;
}
return r;
Expand Down Expand Up @@ -693,15 +696,11 @@ class Vint {
char buf[1024];
size_t n = fp::local::loadWord(buf, sizeof(buf), is);
if (n == 0) return;
const size_t maxN = 384 / (sizeof(MCL_SIZEOF_UNIT) * 8);
if (!setSize(maxN)) {
*pb = false;
return;
}
isNeg_ = false;
n = fp::strToArray(&isNeg_, buf_, maxN, buf, n, ioMode);
n = fp::strToArray(&isNeg_, buf_, N, buf, n, ioMode);
if (n == 0) return;
trim(n);
size_ = n;
trim();
*pb = true;
}
// logical left shift (copy sign)
Expand All @@ -710,10 +709,10 @@ class Vint {
assert(shiftBit <= MCL_MAX_BIT_SIZE * 2); // many be too big
size_t xn = x.size();
size_t yn = xn + (shiftBit + UnitBitSize - 1) / UnitBitSize;
if (!y.setSize(yn)) return;
bint::shiftLeft(y.buf_, x.buf_, shiftBit, xn);
y.isNeg_ = x.isNeg_;
y.trim(yn);
y.size_ = yn;
y.trim();
}
// logical right shift (copy sign)
static void shr(Vint& y, const Vint& x, size_t shiftBit)
Expand All @@ -725,10 +724,10 @@ class Vint {
return;
}
size_t yn = xn - shiftBit / UnitBitSize;
y.setSize(yn);
bint::shiftRight(&y.buf_[0], &x.buf_[0], shiftBit, xn);
bint::shiftRight(y.buf_, x.buf_, shiftBit, xn);
y.isNeg_ = x.isNeg_;
y.trim(yn);
y.size_ = yn;
y.trim();
}
static void neg(Vint& y, const Vint& x)
{
Expand All @@ -750,34 +749,29 @@ class Vint {
static void orBit(Vint& z, const Vint& x, const Vint& y)
{
assert(!x.isNeg_ && !y.isNeg_);
const Vint *px = &x, *py = &y;
if (x.size() < y.size()) {
fp::swap_(px, py);
size_t min = x.size_, max = y.size_;
const Unit *src = y.buf_;
if (x.size_ > y.size_) {
min = y.size_;
max = x.size_;
src = x.buf_;
}
size_t xn = px->size();
size_t yn = py->size();
assert(xn >= yn);
z.setSize(xn);
for (size_t i = 0; i < yn; i++) {
for (size_t i = 0; i < min; i++) {
z.buf_[i] = x.buf_[i] | y.buf_[i];
}
bint::copyN(z.buf_ + yn, px->buf_ + yn, xn - yn);
z.trim(xn);
bint::copyN(z.buf_ + min, src + min, max - min);
z.size_ = max;
z.isNeg_ = false;
}
static void andBit(Vint& z, const Vint& x, const Vint& y)
{
assert(!x.isNeg_ && !y.isNeg_);
const Vint *px = &x, *py = &y;
if (x.size() < y.size()) {
fp::swap_(px, py);
}
size_t yn = py->size();
assert(px->size() >= yn);
z.setSize(yn);
for (size_t i = 0; i < yn; i++) {
size_t zn = fp::min_(x.size_, y.size_);
for (size_t i = 0; i < zn; i++) {
z.buf_[i] = x.buf_[i] & y.buf_[i];
}
z.trim(yn);
z.size_ = zn;
z.trim();
}
static void orBitu1(Vint& z, const Vint& x, Unit y)
{
Expand Down
8 changes: 8 additions & 0 deletions test/vint_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -813,6 +813,7 @@ CYBOZU_TEST_AUTO(shift)
y >>= i;
CYBOZU_TEST_EQUAL(y, z);
}

for (int i = 0; i < 3; i++) {
Vint::shr(y, x, i * UnitBitSize);
Vint::pow(s, Vint(2), i * UnitBitSize);
Expand Down Expand Up @@ -1291,8 +1292,15 @@ CYBOZU_TEST_AUTO(andOr)
Vint z;
z = x & y;
CYBOZU_TEST_EQUAL(z, Vint("1209221003550923564822922"));
z.clear();
z = y & x;
CYBOZU_TEST_EQUAL(z, Vint("1209221003550923564822922"));
z.clear();
z = x | y;
CYBOZU_TEST_EQUAL(z, Vint("29348220482094820948208435244134352108849315802"));
z.clear();
z = y | x;
CYBOZU_TEST_EQUAL(z, Vint("29348220482094820948208435244134352108849315802"));
#ifndef MCL_AVOID_EXCEPTION_TEST
// CYBOZU_TEST_EXCEPTION(Vint("-2") | Vint("5"), cybozu::Exception);
// CYBOZU_TEST_EXCEPTION(Vint("-2") & Vint("5"), cybozu::Exception);
Expand Down

0 comments on commit 76de55d

Please sign in to comment.