Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added a FFTPolynomial file #87

Open
wants to merge 49 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
6011aa5
Added FFTMod
Chillee Apr 23, 2019
ba0cc27
Shortened a bit
Chillee Apr 23, 2019
bba7a32
Shortened FFTMod and refactored roots code into FFT itself
Chillee Apr 25, 2019
a2b9413
Updated formatting
Chillee Apr 25, 2019
e83e65b
removed another 2 lines
Chillee Apr 25, 2019
5b5a4f4
Moved FFTMod to different file and fixed -Wconversion errors
Chillee Apr 25, 2019
12f0a0a
Updated headers for FFTMod
Chillee Apr 25, 2019
75ba4da
Updated header
Chillee Apr 25, 2019
81e2ae8
Moved numerical precision commnets to description
Chillee Apr 25, 2019
77adbd4
Merge branch 'master' into fftmod
Chillee Apr 25, 2019
5243182
Fixed typo
Chillee Apr 26, 2019
ba5bc7d
Merge branch 'fftmod' of github.com:Chillee/kactl into fftmod
Chillee Apr 26, 2019
ace08ab
Made things fit within 63 columns
Chillee Apr 26, 2019
7d26dee
Fixed some formatting issues
Chillee Apr 26, 2019
7c16aeb
Fixed rep space issues
Chillee Apr 26, 2019
c77b038
Fixed spacing issues
Chillee Apr 26, 2019
fe7b393
Fixed formatting issues
Chillee Apr 26, 2019
64ef0ab
Modified header
Chillee Apr 26, 2019
f5ecd33
Switched to long double for roots calculations due to precision issues
Chillee Apr 26, 2019
b7eb861
Updated headers with correct error bounds
Chillee Apr 26, 2019
f072eb1
Removed one of the papers about accuracy
Chillee Apr 26, 2019
d5664e4
Fixed formatting
Chillee Apr 27, 2019
2ea261d
Merge branch 'master' of github.com:kth-competitive-programming/kactl…
Chillee Apr 27, 2019
cdd8367
removed extraneous spaces
Chillee Apr 27, 2019
99f761d
Fixed wconversion warning
Chillee Apr 27, 2019
b984354
Changed from vi to vl
Chillee Apr 27, 2019
b230f10
Added FFTPolynomial class
Chillee Apr 27, 2019
6f5b435
Added initial file
Chillee Apr 27, 2019
8bd2091
Fixed some minor formatting issues
Chillee Apr 27, 2019
3fe590e
Added divide, inverse, mod, derive, and integr
Chillee Apr 27, 2019
5ea7caa
Added log/exp
Chillee Apr 27, 2019
4c549b9
made a small change
Chillee Apr 27, 2019
e3ff553
Added pow
Chillee Apr 29, 2019
936726e
Added authors, a fuzz test, and interp
Chillee Apr 29, 2019
4b4f443
Updated in response to comments
Chillee Apr 29, 2019
14cb23e
updated benchmarks
Chillee Apr 29, 2019
6148b2d
Updated benchmarks
Chillee Apr 29, 2019
3ca9b81
Updated with naive mul for lower values
Chillee May 3, 2019
92095f1
Updated polynomial.cpp to import from headers
Chillee May 3, 2019
f277725
Switched to tabs
Chillee May 3, 2019
02e9270
Reorganized functions
Chillee May 4, 2019
120f099
Restructured organization a bit
Chillee May 4, 2019
094e5d5
Split into finer subsections
Chillee May 4, 2019
a4616a9
Fixed some of the includes
Chillee May 4, 2019
640a997
Split out inverse.h
Chillee May 4, 2019
ccb400b
Fixed import issues
Chillee May 4, 2019
e165e20
Merge remote-tracking branch 'true/master' into polynomial
Chillee May 4, 2019
464bfe7
Updated modulararithmetic with necessary functions
Chillee May 4, 2019
613267c
Updated to tabs
Chillee May 5, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 13 additions & 15 deletions content/number-theory/ModularArithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,21 @@
*/
#pragma once

#include "euclid.h"

const ll mod = 17; // change to something else
struct Mod {
ll x;
Mod(ll xx) : x(xx) {}
Mod operator+(Mod b) { return Mod((x + b.x) % mod); }
Mod operator-(Mod b) { return Mod((x - b.x + mod) % mod); }
Mod operator*(Mod b) { return Mod((x * b.x) % mod); }
ll v;
Mod() : v(0) {}
Mod(ll vv) : v(vv % mod) {}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ll vv = 0 to get rid of the default constructor

Mod operator+(Mod b) { return Mod((v + b.v) % mod); }
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

% mod is unnecessary

Mod operator-(Mod b) { return Mod(v - b.v + mod); }
Mod operator*(Mod b) { return Mod(v * b.v); }
Mod operator/(Mod b) { return *this * invert(b); }
Mod invert(Mod a) {
ll x, y, g = euclid(a.x, mod, x, y);
assert(g == 1); return Mod((x + mod) % mod);
}
Mod invert(Mod a) { return a^(mod-2); }
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

assumes the modulo is prime, which is probably worth a comment. Also, inline this into operator/. Or maybe keep this method with a comment about it being only for composite moduli.

Mod operator^(ll e) {
if (!e) return Mod(1);
Mod r = *this ^ (e / 2); r = r * r;
return e&1 ? *this * r : r;
ll ans = 1, b = (*this).v;
for (; e; b = b * b % mod, e /= 2)
if (e & 1) ans = ans * b % mod;
return ans;
}
};
explicit operator ll() const { return v; }
};
37 changes: 20 additions & 17 deletions content/numerical/FastFourierTransform.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,31 @@
* Date: 2019-01-09
* License: CC0
* Source: http://neerc.ifmo.ru/trains/toulouse/2017/fft2.pdf (do read, it's excellent)
Papers about accuracy: http://www.daemonology.net/papers/fft.pdf, http://www.cs.berkeley.edu/~fateman/papers/fftvsothers.pdf
For integers rounding works if $(|a| + |b|)\max(a, b) < \mathtt{\sim} 10^9$, or in theory maybe $10^6$.
Accuracy bound from http://www.daemonology.net/papers/fft.pdf
* Description: fft(a, ...) computes $\hat f(k) = \sum_x a[x] \exp(2\pi i \cdot k x / N)$ for all $k$. Useful for convolution:
\texttt{conv(a, b) = c}, where $c[x] = \sum a[i]b[x-i]$.
For convolution of complex numbers or more than two vectors: FFT, multiply
pointwise, divide by n, reverse(start+1, end), FFT back.
For integers, consider using a number-theoretic transform instead, to avoid rounding issues.
Let N be $\max(|a|,|b|)$. Is guaranteed safe as long as $N\log_2{N}\max(a)\max(b) < \mathtt{\sim} 10^{16}$ .
Consider using number-theoretic transform or FFTMod instead if precision is an issue.
* Time: O(N \log N), where $N = |A|+|B|-1$ ($\tilde 1s$ for $N=2^{22}$)
* Status: somewhat tested
*/
#pragma once

typedef complex<double> C;
typedef complex<long double> Cd;
typedef vector<double> vd;

void fft(vector<C> &a, vector<C> &rt, vi& rev, int n) {
void fft(vector<C> &a, int n, int L, vector<C> &rt) {
vi rev(n);
rep(i,0,n) rev[i] = (rev[i / 2] | (i & 1) << L) / 2;
if (rt.empty()) {
rt.assign(n, 1);
for (int k = 2; k < n; k *= 2) {
Cd z[] = {1, polar(1.0, M_PI / k)};
rep(i, k, 2 * k) rt[i] = Cd(rt[i / 2]) * z[i & 1];
}
}
rep(i,0,n) if (i < rev[i]) swap(a[i], a[rev[i]]);

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can merge this loop with the generating loop if you want, would keep things more compact and (marginally) improve cache locality.

for (int k = 1; k < n; k *= 2)
for (int i = 0; i < n; i += 2 * k) rep(j,0,k) {
Expand All @@ -27,25 +36,19 @@ void fft(vector<C> &a, vector<C> &rt, vi& rev, int n) {
C z(x[0]*y[0] - x[1]*y[1], x[0]*y[1] + x[1]*y[0]); /// exclude-line
a[i + j + k] = a[i + j] - z;
a[i + j] += z;
}
}
}

vd conv(const vd& a, const vd& b) {
vd conv(const vd &a, const vd &b) {
if (a.empty() || b.empty()) return {};
vd res(sz(a) + sz(b) - 1);
int L = 32 - __builtin_clz(sz(res)), n = 1 << L;
vector<C> in(n), out(n), rt(n, 1); vi rev(n);
rep(i,0,n) rev[i] = (rev[i/2] | (i&1) << L) / 2;
for (int k = 2; k < n; k *= 2) {
C z[] = {1, polar(1.0, M_PI / k)};
rep(i,k,2*k) rt[i] = rt[i/2] * z[i&1];
}
vector<C> in(n), out(n), rt;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe make these static if you want to reuse the memory?

copy(all(a), begin(in));
rep(i,0,sz(b)) in[i].imag(b[i]);
fft(in, rt, rev, n);
fft(in, n, L, rt);
trav(x, in) x *= x;
rep(i,0,n) out[i] = in[-i & (n - 1)] - conj(in[i]);
fft(out, rt, rev, n);
rep(i,0,sz(res)) res[i] = imag(out[i]) / (4*n);
fft(out, n, L, rt);
rep(i,0,sz(res)) res[i] = imag(out[i]) / (4 * n);
return res;
}
36 changes: 36 additions & 0 deletions content/numerical/FastFourierTransformMod.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/**
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(need to merge with master to get rid of these fft changes from the diff)

* Author: chilli
* Date: 2019-04-25
* License: CC0
* Source: http://neerc.ifmo.ru/trains/toulouse/2017/fft2.pdf
* Description: Higher precision FFT, can be used for convolutions modulo arbitrary integers.
* Let N be $\max(|a|,|b|)$. Is guaranteed safe as long as $N\log_2{N}\sqrt{\max(a)\max(b)} < \mathtt{\sim} 10^{16}$ .
* Time: O(N \log N), where $N = |A|+|B|-1$ (twice as slow as NTT or FFT)
* Status: somewhat tested
*/
#pragma once

#include "FastFourierTransform.h"

typedef vector<ll> vl;
template <int M> vl convMod(const vl &a, const vl &b) {
if (a.empty() || b.empty()) return {};
vl res(sz(a) + sz(b) - 1);
int B=32-__builtin_clz(sz(res)), n = 1<<B, cut=int(sqrt(M));
vector<C> L(n), R(n), outs(n), outl(n), rt;
rep(i,0,sz(a)) L[i] = Cd(a[i] / cut, a[i] % cut);
rep(i,0,sz(b)) R[i] = Cd(b[i] / cut, b[i] % cut);
fft(L, n, B, rt), fft(R, n, B, rt);
rep(i,0,n) {
int j = -i & (n - 1);
outl[j] = (L[i] + conj(L[j])) * R[i] / (2.0 * n);
outs[j] = (L[i] - conj(L[j])) * R[i] / (2.0 * n) / 1i;
}
fft(outl, n, B, rt), fft(outs, n, B, rt);
rep(i,0,sz(res)) {
ll av = ll(outl[i].real()+.5), cv = ll(outs[i].imag()+.5);
ll bv = ll(outl[i].imag()+.5) + ll(outs[i].real()+.5);
res[i] = ((av % M * cut + bv % M) * cut + cv % M) % M;
}
return res;
}
47 changes: 47 additions & 0 deletions content/numerical/PolyBase.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/**
* Author: chilli, Andrew He, Adamant
* Date: 2019-04-27
* Description: A FFT based Polynomial class.
*/
#pragma once

#include "../number-theory/ModularArithmetic.h"
#include "FastFourierTransform.h"
#include "FastFourierTransformMod.h"
#include "NumberTheoreticTransform.h"

typedef Mod num;
typedef vector<num> poly;
poly &operator+=(poly &a, const poly &b) {
a.resize(max(sz(a), sz(b)));
rep(i, 0, sz(b)) a[i] = a[i] + b[i];
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
rep(i, 0, sz(b)) a[i] = a[i] + b[i];
rep(i, 0, sz(b)) a[i] = a[i] + b[i];
rep(i,0,sz(b)) a[i] = a[i] + b[i];

(same in other places)

return a;
}
poly &operator-=(poly &a, const poly &b) {
a.resize(max(sz(a), sz(b)));
rep(i, 0, sz(b)) a[i] = a[i] - b[i];
return a;
}

poly &operator*=(poly &a, const poly &b) {
if (sz(a) + sz(b) < 100){
poly res(sz(a) + sz(b) - 1);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

max(..., 0)? (does this special case ever come up? should we just require size >= 1 for everything? I don't have the experience to say)

rep(i,0,sz(a)) rep(j,0,sz(b))
res[i + j] = (res[i + j] + a[i] * b[j]);
return (a = res);
}
// auto res = convMod<mod>(vl(all(a)), vl(all(b)));
auto res = conv(vl(all(a)), vl(all(b)));
return (a = poly(all(res)));
}
poly operator*(poly a, const num b) {
poly c = a;
trav(i, c) i = i * b;
return c;
}
#define OP(o, oe) \
poly operator o(poly a, poly b) { \
poly c = a; \
return c o##= b; \
}
OP(*, *=) OP(+, +=) OP(-, -=);
25 changes: 25 additions & 0 deletions content/numerical/PolyEvaluate.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/**
* Author: chilli, Andrew He, Adamant
* Date: 2019-04-27
* Description: Multi-point evaluation. Evaluates a given polynomial A at $A(x_0), ... A(x_n)$.
* Time: O(n \log^2 n)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

where n is the max of the two sizes involved, I guess. not sure if worth mentioning

*/
#pragma once

#include "PolyBase.h"
#include "PolyMod.h"

vector<num> eval(const poly &a, const vector<num> &x) {
int n = sz(x);
if (!n) return {};
vector<poly> up(2 * n);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

combine up and down into a single declaration (same elsewhere)

rep(i, 0, n) up[i + n] = poly({num(0) - x[i], 1});
for (int i = n - 1; i > 0; i--)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for (int i = n; i--;), or even while (i--) ...; if i's declaration is moved to that of n

up[i] = up[2 * i] * up[2 * i + 1];
vector<poly> down(2 * n);
down[1] = a % up[1];
rep(i, 2, 2 * n) down[i] = down[i / 2] % up[i];
vector<num> y(n);
rep(i, 0, n) y[i] = down[i + n][0];
return y;
}
22 changes: 22 additions & 0 deletions content/numerical/PolyIntegDeriv.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/**
* Author: chilli, Andrew He, Adamant
* Date: 2019-04-27
* Description: A FFT based Polynomial class.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update the description (I don't understand what integr does at all). Same for other algorithms

*/
#pragma once
#include "PolyBase.h"

poly deriv(poly a) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

const & for the parameters, or modify a in-place

if (a.empty()) return {};
poly b(sz(a) - 1);
rep(i, 1, sz(a)) b[i - 1] = a[i] * num(i);
return b;
}
poly integr(poly a) {
if (a.empty()) return {0};
poly b(sz(a) + 1);
b[1] = num(1);
rep(i, 2, sz(b)) b[i] = b[mod%i]*Mod(-mod/i+mod);
rep(i, 1 ,sz(b)) b[i] = a[i-1] * b[i];
return b;
}
35 changes: 17 additions & 18 deletions content/numerical/PolyInterpolate.h
Original file line number Diff line number Diff line change
@@ -1,25 +1,24 @@
/**
* Author: Simon Lindholm
* Date: 2017-05-10
* License: CC0
* Source: Wikipedia
* Author: chilli, Andrew He, Adamant

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Much of my work comes from THU, probably should give them credit too.

* Date: 2019-04-27
* Description: Given $n$ points (x[i], y[i]), computes an n-1-degree polynomial $p$ that
* passes through them: $p(x) = a[0]*x^0 + ... + a[n-1]*x^{n-1}$.
* For numerical precision, pick $x[k] = c*\cos(k/(n-1)*\pi), k=0 \dots n-1$.
* Time: O(n^2)
* Time: O(n \log^2 n)
*/
#pragma once

typedef vector<double> vd;
vd interpolate(vd x, vd y, int n) {
vd res(n), temp(n);
rep(k,0,n-1) rep(i,k+1,n)
y[i] = (y[i] - y[k]) / (x[i] - x[k]);
double last = 0; temp[0] = 1;
rep(k,0,n) rep(i,0,n) {
res[i] += y[k] * temp[i];
swap(last, temp[i]);
temp[i] -= last * x[k];
}
return res;
#include "PolyBase.h"
#include "PolyIntegDeriv.h"
#include "PolyEvaluate.h"

poly interp(vector<num> x, vector<num> y) {
int n=sz(x);
vector<poly> up(n*2);
rep(i,0,n) up[i+n] = poly({num(0)-x[i], num(1)});
for(int i=n-1; i>0;i--) up[i] = up[2*i]*up[2*i+1];
vector<num> a = eval(deriv(up[1]), x);
vector<poly> down(2*n);
rep(i,0,n) down[i+n] = poly({y[i]*(num(1)/a[i])});
for(int i=n-1;i>0;i--) down[i] = down[i*2] * up[i*2+1] + down[i*2+1] * up[i*2];
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for (int i=n; i-->1;) seems harder to typo

return down[1];
}
25 changes: 25 additions & 0 deletions content/numerical/PolyInterpolateSlow.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/**
* Author: Simon Lindholm
* Date: 2017-05-10
* License: CC0
* Source: Wikipedia
* Description: Given $n$ points (x[i], y[i]), computes an n-1-degree polynomial $p$ that
* passes through them: $p(x) = a[0]*x^0 + ... + a[n-1]*x^{n-1}$.
* For numerical precision, pick $x[k] = c*\cos(k/(n-1)*\pi), k=0 \dots n-1$.
* Time: O(n^2)
*/
#pragma once

typedef vector<double> vd;
vd interpolate(vd x, vd y, int n) {
vd res(n), temp(n);
rep(k,0,n-1) rep(i,k+1,n)
y[i] = (y[i] - y[k]) / (x[i] - x[k]);
double last = 0; temp[0] = 1;
rep(k,0,n) rep(i,0,n) {
res[i] += y[k] * temp[i];
swap(last, temp[i]);
temp[i] -= last * x[k];
}
return res;
}
16 changes: 16 additions & 0 deletions content/numerical/PolyInverse.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
/**
* Author: chilli, Andrew He, Adamant
* Date: 2019-04-27
* Description: A FFT based Polynomial class.
*/
#pragma once

#include "PolyBase.h"

poly modK(poly a, int k) { return {a.begin(), a.begin() + min(k, sz(a))}; }
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return {all(a) - max(0, sz(a) - k)} maybe

poly inverse(poly A) {
poly B = poly({num(1) / A[0]});
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

poly B{num(1) / A[0]};

while (sz(B) < sz(A))
B = modK(B * (poly({num(2)}) - modK(A, 2*sz(B)) * B), 2 * sz(B));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

poly{num(2)} should work as well (or just poly{2}? not sure what nums we are expecting)

return modK(B, sz(A));
}
25 changes: 25 additions & 0 deletions content/numerical/PolyLogExp.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/**
* Author: chilli, Andrew He, Adamant
* Date: 2019-04-27
* Description: A FFT based Polynomial class.
*/
#pragma once

#include "PolyBase.h"
#include "PolyInverse.h"
#include "PolyIntegDeriv.h"

poly log(poly a) {
return modK(integr(deriv(a) * inverse(a)), sz(a));
}
poly exp(poly a) {
poly b(1, num(1));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

poly b{num(1)}

if (a.empty())
return b;
while (sz(b) < sz(a)) {
b.resize(sz(b) * 2);
b *= (poly({num(1)}) + modK(a, sz(b)) - log(b));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

poly{num(1)}

b.resize(sz(b) / 2 + 1);
}
return modK(b, sz(a));
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(nit: missing newline at end of file in a few places)

30 changes: 30 additions & 0 deletions content/numerical/PolyMod.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/**
* Author: chilli, Andrew He, Adamant
* Date: 2019-04-27
* Description: A FFT based Polynomial class.
*/
#pragma once

#include "PolyBase.h"
#include "PolyInverse.h"

poly &operator/=(poly &a, poly b) {
if (sz(a) < sz(b))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same line works

return a = {};
int s = sz(a) - sz(b) + 1;
reverse(all(a)), reverse(all(b));
a.resize(s), b.resize(s);
a = a * inverse(b);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

don't we define operator*=?

a.resize(s), reverse(all(a));
return a;
}
OP(/, /=)
poly &operator%=(poly &a, poly &b) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

const poly& b, same above

if (sz(a) < sz(b))
return a;
poly c = (a / b) * b;
a.resize(sz(b) - 1);
rep(i, 0, sz(a)) a[i] = a[i] - c[i];
return a;
}
OP(%, %=)
Loading