diff --git a/include/mcl/vint.hpp b/include/mcl/vint.hpp index d29cb4e8..25a80fb2 100644 --- a/include/mcl/vint.hpp +++ b/include/mcl/vint.hpp @@ -22,6 +22,17 @@ namespace mcl { +// return the max size in x[0..n) +inline int getRealSize(const Unit *x, int n) +{ + assert(n > 0); + while (n > 0) { + if (x[n - 1]) return n; + n--; + } + return 1; +} + /** signed integer with variable length */ @@ -31,13 +42,12 @@ class Vint { static const int invalidVar = -2147483647 - 1; // abs(invalidVar) is not defined static const size_t N = maxUnitSize * 2 + 1; private: - Unit buf_[N]; + Unit buf_[N]; // assume buf_[size_ - 1] != 0 unless the value is zero size_t size_; bool isNeg_; - void trim(size_t n) + void trim() { - assert(n > 0); - int i = (int)n - 1; + int i = (int)size_ - 1; for (; i > 0; i--) { if (buf_[i]) { size_ = i + 1; @@ -75,7 +85,7 @@ class Vint { c = bint::addUnit(dst + yn, n, c); } dst[xn] = c; - z.trim(xn + 1); + z.trim(); } static void uadd1(Vint& z, const Unit *x, size_t xn, Unit y) { @@ -86,7 +96,7 @@ class Vint { } if (z.buf_ != x) bint::copyN(z.buf_, x, xn); z.buf_[zn - 1] = bint::addUnit(z.buf_, xn, y); - z.trim(zn); + z.trim(); } static void usub1(Vint& z, const Unit *x, size_t xn, Unit y) { @@ -99,7 +109,7 @@ class Vint { Unit c = bint::subUnit(dst, xn, y); (void)c; assert(!c); - z.trim(zn); + z.trim(); } static void usub(Vint& z, const Unit *x, size_t xn, const Unit *y, size_t yn) { @@ -117,7 +127,7 @@ class Vint { c = bint::subUnit(dst, n, c); } assert(!c); - z.trim(xn); + z.trim(); } static void _add(Vint& z, const Vint& x, bool xNeg, const Vint& y, bool yNeg) { @@ -187,13 +197,13 @@ class Vint { } Unit *xx = (Unit*)CYBOZU_ALLOCA(sizeof(Unit) * xn); bint::copyN(xx, x, xn); - Unit *qq = q ? &q->buf_[0] : 0; - size_t rn = bint::div(qq, qn, xx, xn, &y[0], yn); + Unit *qq = q ? q->buf_ : 0; + size_t rn = bint::div(qq, qn, xx, xn, y, yn); r.copy(xx, rn); + r.trim(); if (q) { - q->trim(qn); + q->trim(); } - r.trim(rn); } /* @param x [inout] x <- d @@ -362,7 +372,7 @@ class Vint { if (!*pb) { return; } - trim(unitSize); + trim(); } /* set [0, max) randomly @@ -375,7 +385,7 @@ class Vint { size_ = n; rg.read(pb, buf_, n * sizeof(Unit)); if (!*pb) return; - trim(n); + trim(); *this %= max; } /* @@ -459,7 +469,7 @@ class Vint { buf_[q] |= mask; } else { buf_[q] &= ~mask; - trim(q + 1); + trim(); } } /* @@ -470,17 +480,10 @@ class Vint { */ void setStr(bool *pb, const char *str, int base = 0) { - // allow twice size of MCL_MAX_BIT_SIZE because of multiplication - const size_t maxN = (MCL_MAX_BIT_SIZE * 2 + UnitBitSize - 1) / UnitBitSize; - if (!setSize(maxN)) { - *pb = false; - return; - } - isNeg_ = false; size_t len = strlen(str); - size_t n = fp::strToArray(&isNeg_, &buf_[0], maxN, str, len, base); - if (n == 0) return; - trim(n); + size_t n = fp::strToArray(&isNeg_, buf_, N, str, len, base); + if (n == 0 || !setSize(n)) return; + trim(); *pb = true; } static int compare(const Vint& x, const Vint& y) @@ -543,7 +546,7 @@ class Vint { size_t zn = xn + yn; if (!z.setSize(zn)) return; bint::mulNM(z.buf_, x.buf_, xn, y.buf_, yn); - z.trim(zn); + z.trim(); z.isNeg_ = x.isNeg_ ^ y.isNeg_; } static void sqr(Vint& y, const Vint& x) @@ -565,7 +568,7 @@ class Vint { if (!z.setSize(zn)) return; z.buf_[zn - 1] = bint::mulUnitN(z.buf_, x.buf_, y, xn); z.isNeg_ = x.isNeg_; - z.trim(zn); + z.trim(); } static void divu1(Vint& q, const Vint& x, Unit y) { @@ -609,9 +612,9 @@ class Vint { int r; if (q) { q->isNeg_ = xNeg ^ yNeg; - if (!q->setSize(xn)) return 0; r = (int)bint::divUnit(q->buf_, x.buf_, xn, absY); - q->trim(xn); + q->setSize(xn); + q->trim(); } else { r = (int)bint::modUnit(x.buf_, xn, absY); } @@ -653,12 +656,12 @@ class Vint { } static Unit udivModu1(Vint *q, const Vint& x, Unit y) { - assert(!x.isNeg_); + assert(!x.isNeg_ && y != 0); size_t xn = x.size(); - if (q && !q->setSize(N)) return 0; + if (q) q->setSize(xn); Unit r = bint::divUnit(q ? q->buf_ : 0, x.buf_, xn, y); if (q) { - q->trim(xn); + q->trim(); q->isNeg_ = false; } return r; @@ -693,15 +696,11 @@ class Vint { char buf[1024]; size_t n = fp::local::loadWord(buf, sizeof(buf), is); if (n == 0) return; - const size_t maxN = 384 / (sizeof(MCL_SIZEOF_UNIT) * 8); - if (!setSize(maxN)) { - *pb = false; - return; - } isNeg_ = false; - n = fp::strToArray(&isNeg_, buf_, maxN, buf, n, ioMode); + n = fp::strToArray(&isNeg_, buf_, N, buf, n, ioMode); if (n == 0) return; - trim(n); + size_ = n; + trim(); *pb = true; } // logical left shift (copy sign) @@ -710,10 +709,10 @@ class Vint { assert(shiftBit <= MCL_MAX_BIT_SIZE * 2); // many be too big size_t xn = x.size(); size_t yn = xn + (shiftBit + UnitBitSize - 1) / UnitBitSize; - if (!y.setSize(yn)) return; bint::shiftLeft(y.buf_, x.buf_, shiftBit, xn); y.isNeg_ = x.isNeg_; - y.trim(yn); + y.size_ = yn; + y.trim(); } // logical right shift (copy sign) static void shr(Vint& y, const Vint& x, size_t shiftBit) @@ -725,10 +724,10 @@ class Vint { return; } size_t yn = xn - shiftBit / UnitBitSize; - y.setSize(yn); - bint::shiftRight(&y.buf_[0], &x.buf_[0], shiftBit, xn); + bint::shiftRight(y.buf_, x.buf_, shiftBit, xn); y.isNeg_ = x.isNeg_; - y.trim(yn); + y.size_ = yn; + y.trim(); } static void neg(Vint& y, const Vint& x) { @@ -750,34 +749,29 @@ class Vint { static void orBit(Vint& z, const Vint& x, const Vint& y) { assert(!x.isNeg_ && !y.isNeg_); - const Vint *px = &x, *py = &y; - if (x.size() < y.size()) { - fp::swap_(px, py); + size_t min = x.size_, max = y.size_; + const Unit *src = y.buf_; + if (x.size_ > y.size_) { + min = y.size_; + max = x.size_; + src = x.buf_; } - size_t xn = px->size(); - size_t yn = py->size(); - assert(xn >= yn); - z.setSize(xn); - for (size_t i = 0; i < yn; i++) { + for (size_t i = 0; i < min; i++) { z.buf_[i] = x.buf_[i] | y.buf_[i]; } - bint::copyN(z.buf_ + yn, px->buf_ + yn, xn - yn); - z.trim(xn); + bint::copyN(z.buf_ + min, src + min, max - min); + z.size_ = max; + z.isNeg_ = false; } static void andBit(Vint& z, const Vint& x, const Vint& y) { assert(!x.isNeg_ && !y.isNeg_); - const Vint *px = &x, *py = &y; - if (x.size() < y.size()) { - fp::swap_(px, py); - } - size_t yn = py->size(); - assert(px->size() >= yn); - z.setSize(yn); - for (size_t i = 0; i < yn; i++) { + size_t zn = fp::min_(x.size_, y.size_); + for (size_t i = 0; i < zn; i++) { z.buf_[i] = x.buf_[i] & y.buf_[i]; } - z.trim(yn); + z.size_ = zn; + z.trim(); } static void orBitu1(Vint& z, const Vint& x, Unit y) { diff --git a/test/vint_test.cpp b/test/vint_test.cpp index 998fdccc..fdf07c8a 100644 --- a/test/vint_test.cpp +++ b/test/vint_test.cpp @@ -813,6 +813,7 @@ CYBOZU_TEST_AUTO(shift) y >>= i; CYBOZU_TEST_EQUAL(y, z); } + for (int i = 0; i < 3; i++) { Vint::shr(y, x, i * UnitBitSize); Vint::pow(s, Vint(2), i * UnitBitSize); @@ -1291,8 +1292,15 @@ CYBOZU_TEST_AUTO(andOr) Vint z; z = x & y; CYBOZU_TEST_EQUAL(z, Vint("1209221003550923564822922")); + z.clear(); + z = y & x; + CYBOZU_TEST_EQUAL(z, Vint("1209221003550923564822922")); + z.clear(); z = x | y; CYBOZU_TEST_EQUAL(z, Vint("29348220482094820948208435244134352108849315802")); + z.clear(); + z = y | x; + CYBOZU_TEST_EQUAL(z, Vint("29348220482094820948208435244134352108849315802")); #ifndef MCL_AVOID_EXCEPTION_TEST // CYBOZU_TEST_EXCEPTION(Vint("-2") | Vint("5"), cybozu::Exception); // CYBOZU_TEST_EXCEPTION(Vint("-2") & Vint("5"), cybozu::Exception);