diff --git a/content/number-theory/ModularArithmetic.h b/content/number-theory/ModularArithmetic.h index b9d01bba8..1e2f89643 100644 --- a/content/number-theory/ModularArithmetic.h +++ b/content/number-theory/ModularArithmetic.h @@ -8,25 +8,21 @@ */ #pragma once -#include "euclid.h" - const ll mod = 17; // change to something else struct Mod { - ll x; - Mod():x(0) {} - Mod(ll xx) : x(xx) {} - Mod operator+(Mod b) { return Mod((x + b.x) % mod); } - Mod operator-(Mod b) { return Mod((x - b.x + mod) % mod); } - Mod operator*(Mod b) { return Mod((x * b.x) % mod); } - Mod operator/(Mod b) { return *this * invert(b); } - Mod invert(Mod a) { - ll x, y, g = euclid(a.x, mod, x, y); - assert(g == 1); return Mod((x + mod) % mod); - } + ll v; + Mod() : v(0) {} + Mod(ll vv) : v(vv % mod) {} + Mod operator+(Mod b) { return Mod((v + b.v) % mod); } + Mod operator-(Mod b) { return Mod(v - b.v + mod); } + Mod operator*(Mod b) { return Mod(v * b.v); } + Mod operator/(Mod b) { return *this * invert(b); } + Mod invert(Mod a) { return a^(mod-2); } Mod operator^(ll e) { - if (!e) return Mod(1); - Mod r = *this ^ (e / 2); r = r * r; - return e&1 ? *this * r : r; + ll ans = 1, b = (*this).v; + for (; e; b = b * b % mod, e /= 2) + if (e & 1) ans = ans * b % mod; + return ans; } - explicit operator ll() const { return x; } -}; + explicit operator ll() const { return v; } +}; \ No newline at end of file diff --git a/content/numerical/PolyBase.h b/content/numerical/PolyBase.h index 2bedfd8e3..b70b9dff1 100644 --- a/content/numerical/PolyBase.h +++ b/content/numerical/PolyBase.h @@ -8,7 +8,7 @@ #include "../number-theory/ModularArithmetic.h" #include "FastFourierTransform.h" #include "FastFourierTransformMod.h" -// #include "NumberTheoreticTransform.h" +#include "NumberTheoreticTransform.h" typedef Mod num; typedef vector poly; @@ -30,8 +30,8 @@ poly &operator*=(poly &a, const poly &b) { res[i + j] = (res[i + j] + a[i] * b[j]); return (a = res); } - auto res = convMod(vl(all(a)), vl(all(b))); - // auto res = conv(vl(all(a)), vl(all(b))); + // auto res = convMod(vl(all(a)), vl(all(b))); + auto res = conv(vl(all(a)), vl(all(b))); return (a = poly(all(res))); } poly operator*(poly a, const num b) { diff --git a/content/numerical/PolyPow.h b/content/numerical/PolyPow.h index eb855e011..b259e7b3d 100644 --- a/content/numerical/PolyPow.h +++ b/content/numerical/PolyPow.h @@ -10,7 +10,7 @@ poly pow(poly a, ll m) { int p = 0, n = sz(a); - while (p < sz(a) && a[p].x == 0) + while (p < sz(a) && a[p].v == 0) ++p; if (ll(m)*p >= sz(a)) return poly(sz(a)); num j = a[p]; diff --git a/fuzz-tests/numerical/Polynomial.cpp b/fuzz-tests/numerical/Polynomial.cpp index 5aa9e91b8..1b53166e0 100644 --- a/fuzz-tests/numerical/Polynomial.cpp +++ b/fuzz-tests/numerical/Polynomial.cpp @@ -399,20 +399,21 @@ ll modpow(ll a, ll e) { } #include "../../content/numerical/FastFourierTransformMod.h" struct Mod { - ll x; - Mod() : x(0) {} - Mod(ll xx) : x(xx % mod) {} - Mod operator+(Mod b) { return Mod((x + b.x) % mod); } - Mod operator-(Mod b) { return Mod(x < b.x ? x - b.x + mod : x - b.x); } - Mod operator*(Mod b) { return Mod(x * b.x); } + ll v; + Mod() : v(0) {} + Mod(ll vv) : v(vv % mod) {} + Mod operator+(Mod b) { return Mod((v + b.v) % mod); } + Mod operator-(Mod b) { return Mod(v < b.v ? v - b.v + mod : v - b.v); } + Mod operator*(Mod b) { return Mod(v * b.v); } Mod operator/(Mod b) { return *this * invert(b); } Mod invert(Mod a) { return a^(mod-2); } Mod operator^(ll e) { - if (!e) return Mod(1); - Mod r = *this ^ (e / 2); r = r * r; - return e&1 ? *this * r : r; + ll ans = 1, b = (*this).v; + for (; e; b = b * b % mod, e /= 2) + if (e & 1) ans = ans * b % mod; + return ans; } - explicit operator ll() const { return x; } + explicit operator ll() const { return v; } }; typedef Mod num; @@ -442,10 +443,10 @@ bool checkEqual(mine::poly a, MIT::poly b) { return false; int ml = min(sz(a), sz(b)); for (int i = 0; i < ml; i++) - if (a[i].x != b[i].v) + if (a[i].v != b[i].v) return false; // for (int i = ml; i < sz(a); i++) - // if (a[i].x != 0) + // if (a[i].v != 0) // return false; // for (int i = ml; i < sz(b); i++) // if (b[i].v != 0) @@ -456,7 +457,7 @@ bool checkEqual(mine::poly a, MIT::poly b) { template void fail(A mine, B mit) { cout<<"mine: "; for (auto i : mine) - cout << i.x << ' '; + cout << i.v << ' '; cout << endl; cout<<"MIT: "; for (auto i : mit)