Skip to content

Commit

Permalink
Updated modulararithmetic with necessary functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Chillee committed May 4, 2019
1 parent e165e20 commit 464bfe7
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 35 deletions.
32 changes: 14 additions & 18 deletions content/number-theory/ModularArithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,21 @@
*/
#pragma once

#include "euclid.h"

const ll mod = 17; // change to something else
struct Mod {
ll x;
Mod():x(0) {}
Mod(ll xx) : x(xx) {}
Mod operator+(Mod b) { return Mod((x + b.x) % mod); }
Mod operator-(Mod b) { return Mod((x - b.x + mod) % mod); }
Mod operator*(Mod b) { return Mod((x * b.x) % mod); }
Mod operator/(Mod b) { return *this * invert(b); }
Mod invert(Mod a) {
ll x, y, g = euclid(a.x, mod, x, y);
assert(g == 1); return Mod((x + mod) % mod);
}
ll v;
Mod() : v(0) {}
Mod(ll vv) : v(vv % mod) {}
Mod operator+(Mod b) { return Mod((v + b.v) % mod); }
Mod operator-(Mod b) { return Mod(v - b.v + mod); }
Mod operator*(Mod b) { return Mod(v * b.v); }
Mod operator/(Mod b) { return *this * invert(b); }
Mod invert(Mod a) { return a^(mod-2); }
Mod operator^(ll e) {
if (!e) return Mod(1);
Mod r = *this ^ (e / 2); r = r * r;
return e&1 ? *this * r : r;
ll ans = 1, b = (*this).v;
for (; e; b = b * b % mod, e /= 2)
if (e & 1) ans = ans * b % mod;
return ans;
}
explicit operator ll() const { return x; }
};
explicit operator ll() const { return v; }
};
6 changes: 3 additions & 3 deletions content/numerical/PolyBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
#include "../number-theory/ModularArithmetic.h"
#include "FastFourierTransform.h"
#include "FastFourierTransformMod.h"
// #include "NumberTheoreticTransform.h"
#include "NumberTheoreticTransform.h"

typedef Mod num;
typedef vector<num> poly;
Expand All @@ -30,8 +30,8 @@ poly &operator*=(poly &a, const poly &b) {
res[i + j] = (res[i + j] + a[i] * b[j]);
return (a = res);
}
auto res = convMod<mod>(vl(all(a)), vl(all(b)));
// auto res = conv(vl(all(a)), vl(all(b)));
// auto res = convMod<mod>(vl(all(a)), vl(all(b)));
auto res = conv(vl(all(a)), vl(all(b)));
return (a = poly(all(res)));
}
poly operator*(poly a, const num b) {
Expand Down
2 changes: 1 addition & 1 deletion content/numerical/PolyPow.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

poly pow(poly a, ll m) {
int p = 0, n = sz(a);
while (p < sz(a) && a[p].x == 0)
while (p < sz(a) && a[p].v == 0)
++p;
if (ll(m)*p >= sz(a)) return poly(sz(a));
num j = a[p];
Expand Down
27 changes: 14 additions & 13 deletions fuzz-tests/numerical/Polynomial.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -399,20 +399,21 @@ ll modpow(ll a, ll e) {
}
#include "../../content/numerical/FastFourierTransformMod.h"
struct Mod {
ll x;
Mod() : x(0) {}
Mod(ll xx) : x(xx % mod) {}
Mod operator+(Mod b) { return Mod((x + b.x) % mod); }
Mod operator-(Mod b) { return Mod(x < b.x ? x - b.x + mod : x - b.x); }
Mod operator*(Mod b) { return Mod(x * b.x); }
ll v;
Mod() : v(0) {}
Mod(ll vv) : v(vv % mod) {}
Mod operator+(Mod b) { return Mod((v + b.v) % mod); }
Mod operator-(Mod b) { return Mod(v < b.v ? v - b.v + mod : v - b.v); }
Mod operator*(Mod b) { return Mod(v * b.v); }
Mod operator/(Mod b) { return *this * invert(b); }
Mod invert(Mod a) { return a^(mod-2); }
Mod operator^(ll e) {
if (!e) return Mod(1);
Mod r = *this ^ (e / 2); r = r * r;
return e&1 ? *this * r : r;
ll ans = 1, b = (*this).v;
for (; e; b = b * b % mod, e /= 2)
if (e & 1) ans = ans * b % mod;
return ans;
}
explicit operator ll() const { return x; }
explicit operator ll() const { return v; }
};

typedef Mod num;
Expand Down Expand Up @@ -442,10 +443,10 @@ bool checkEqual(mine::poly a, MIT::poly b) {
return false;
int ml = min(sz(a), sz(b));
for (int i = 0; i < ml; i++)
if (a[i].x != b[i].v)
if (a[i].v != b[i].v)
return false;
// for (int i = ml; i < sz(a); i++)
// if (a[i].x != 0)
// if (a[i].v != 0)
// return false;
// for (int i = ml; i < sz(b); i++)
// if (b[i].v != 0)
Expand All @@ -456,7 +457,7 @@ bool checkEqual(mine::poly a, MIT::poly b) {
template <class A, class B> void fail(A mine, B mit) {
cout<<"mine: ";
for (auto i : mine)
cout << i.x << ' ';
cout << i.v << ' ';
cout << endl;
cout<<"MIT: ";
for (auto i : mit)
Expand Down

0 comments on commit 464bfe7

Please sign in to comment.