FFT

快速傅里叶变换(Fast Fourier Transform,简称 FFT)是利用计算机计算离散傅里叶变换(DFT)的高效、快速计算方法的统称。

在算法竞赛中,FFT主要用于解决卷积的加速与优化,最典型的例子是求两个多项式的乘积。

譬如说

\[f(x)=a_{0}x^{0}+a_{1}x^{1}+...+a_{n-1}x^{n-1}\]

\[g(x)=b_{0}x^{0}+b_{1}x^{1}+...+b_{n-1}x^{n-1}\]

求 \(f(x)g(x)\)。

如果用传统思路去解决这个问题,似乎再怎么优化也只能做到 \(\Theta(n^{2})\) 的复杂度,因为不管如何考虑,枚举有序对 \((a_{i},b_{j})\) 的过程似乎都没有任何冗余。如果我们不能跳出以上思维的局限,就不可能从根本上降低算法的时间复杂度。

让我们首先换一个角度来考察多项式。

多项式的点值表示法(点值)\(\Theta(n\log n)\)

显然,对于每个 \(x\),都存在唯一的 \(y=f(x)\) 与之对应。那么换一个角度,我们是不是可以用若干点值对 \((x_{0},y_{0}),(x_{1},y_{1}),\dots,(x_{n-1},y_{n-1})\)(其中 \(y_{i}=f(x_{i})\))去表示一个多项式 \(f(x)\) 呢?

引理:对于任意 \(n\) 个点值对组成的集合 \(\{(x_{0},y_{0}),(x_{1},y_{1}),\dots,(x_{n-1},y_{n-1})\}\),其中 \(x_{0},x_{1},\dots,x_{n-1}\) 互不相同,存在唯一的次数不超过 \(n-1\) 的多项式 \(f(x)\),满足 \(y_{i}=f(x_{i})\ (0\leq i\leq n-1)\)。

由于 \(x_{i}\) 是我们自己决定的,我们可以选择恰当的值,使求 \(n\) 个点值的复杂度从朴素的 \(\Theta(n^{2})\) 降为 \(\Theta(n\log n)\)。在这里我们令 \(x_{k}=\varepsilon_{n}^{k}=\cos(2\pi\frac{k}{n})+i\sin(2\pi\frac{k}{n})\),其中 \(i\) 是虚数单位,\(\varepsilon_{n}\) 称为 \(n\) 次单位根。很容易证明 \(\varepsilon_{2n}^{2k}=\cos(2\pi\frac{2k}{2n})+i\sin(2\pi\frac{2k}{2n})=\cos(2\pi\frac{k}{n})+i\sin(2\pi\frac{k}{n})=\varepsilon_{n}^{k}\)(消去引理)。由此可得,当 \(n\) 为偶数时 \(\varepsilon_{n}^{2k}=\varepsilon_{\frac{n}{2}}^{k}\)(折半引理)。那么现在的问题是如何快速求出 \(f(\varepsilon_{n}^{k})\) 呢?我们可以做以下的推导:

\[ \begin{aligned} f(x)&=a_{0}x^{0}+a_{1}x^{1}+\dots+a_{n-1}x^{n-1}\\ \text{let}\ f_{even}(x)&=a_{0}x^{0}+a_{2}x^{1}+\dots+a_{n-2}x^{\frac{n-2}{2}}\\ f_{odd}(x)&=a_{1}x^{0}+a_{3}x^{1}+\dots+a_{n-1}x^{\frac{n-2}{2}}\\ \text{thus}\ f(x)&=f_{even}(x^{2})+xf_{odd}(x^{2})\\ \therefore f(\varepsilon_{n}^{k})&=f_{even}(\varepsilon_{n}^{2k})+\varepsilon_{n}^{k}f_{odd}(\varepsilon_{n}^{2k})\\ &=f_{even}(\varepsilon_{\frac{n}{2}}^{k})+\varepsilon_{n}^{k}f_{odd}(\varepsilon_{\frac{n}{2}}^{k}) \end{aligned} \]

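在这里我们不妨设 \(n\) 是 2 的整数次幂。如果不是,可以在高次补 \(0\) 系数使 \(n\) 变为 2 的幂;仅补成偶数是不够的,因为递归的每一层都要求 \(n\) 为偶数。由于 \(\varepsilon_{n}^{k+\frac{n}{2}}=-\varepsilon_{n}^{k}\),且 \(\varepsilon_{\frac{n}{2}}^{k+\frac{n}{2}}=\varepsilon_{\frac{n}{2}}^{k}\),对 \(0\leq k<\frac{n}{2}\) 有:
\(f(\varepsilon_{n}^{k})=f_{even}(\varepsilon_{\frac{n}{2}}^{k})+\varepsilon_{n}^{k}f_{odd}(\varepsilon_{\frac{n}{2}}^{k})\),
\(f(\varepsilon_{n}^{k+\frac{n}{2}})=f_{even}(\varepsilon_{\frac{n}{2}}^{k})-\varepsilon_{n}^{k}f_{odd}(\varepsilon_{\frac{n}{2}}^{k})\)。
这就是所谓的"蝴蝶操作"。于是问题转化为了两个规模减半的子问题:求 \(f_{even}(x)\) 与 \(f_{odd}(x)\) 在 \(\frac{n}{2}\) 个 \(\frac{n}{2}\) 次单位根处的点值。由递推式 \(T(n)=2T(\frac{n}{2})+\Theta(n)\) 可知,求点值的复杂度为 \(\Theta(n\log n)\)。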

点值多项式的乘法\(\Theta (n)\)

假如我们已经有了两组点值对 \((x_{i},f_{i})\) 与 \((x_{i},g_{i})\)(\(0\leq i\leq n-1\)),分别代表两个多项式 \(f(x)\) 与 \(g(x)\)。那么 \(h(x)=f(x)g(x)\) 的点值表示便很容易在线性时间内算出,即 \((x_{i},f_{i}g_{i})\)。

要注意的是,\(f(x)\) 与 \(g(x)\) 的次数为 \(n-1\),\(h(x)\) 的次数为 \(2n-2\),因此确定 \(h(x)\) 需要 \(2n-1\) 个点值对,而现在我们只有 \(n\) 个。我们可以通过扩充 \(f(x)\) 和 \(g(x)\) 的点值对个数来解决这个问题:在一开始就把两者的点值对个数取为至少 \(2n-1\)。实践中取不小于 \(2n-1\) 的最小的 2 的幂,高次系数补 \(0\)。

将点值表示法再转化为系数表示法(插值)\(\Theta(n\log n)\)

插值过程是求点值过程的逆运算。这个问题看起来比前一个问题更复杂,但通过适当的转化,可以把它化为前一个问题。假设我们已经得知了某一个多项式在 \(x_{k}=\varepsilon_{n}^{k}\) 处的点值表示 \(\{(x_{0},y_{0}),(x_{1},y_{1}),\dots,(x_{n-1},y_{n-1})\}\),即:

\[ \begin{bmatrix} y_{0}\\ y_{1}\\ y_{2}\\ ...\\ y_{n-1}\end{bmatrix} = \begin{bmatrix} 1 & 1 & 1 & ... &1 \\ 1 &\varepsilon_{n}^{1}& \varepsilon_{n}^{2} & ... & \varepsilon_{n}^{n-1}\\ 1 & \varepsilon_{n}^{2} & \varepsilon_{n}^{4} & ... &\varepsilon_{n}^{2n-2} \\ ... & ... & ... & ... & ...\\ 1 & \varepsilon_{n}^{n-1} & \varepsilon_{n}^{2n-2} &... & \varepsilon_{n}^{(n-1)^{2}} \end{bmatrix} \begin{bmatrix}a_{0}\\ a_{1}\\ a_{2}\\ ...\\ a_{n-1}\end{bmatrix} \]

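记为 \(Y=V_{n}A\),其中 \(V_{n}\) 是范德蒙德矩阵,其第 \(j\) 行第 \(k\) 列的元素为 \(\varepsilon_{n}^{jk}\)。利用求和引理可以验证,\(V_{n}\) 的逆矩阵 \(V_{n}^{-1}\) 是一个很有规律的矩阵。求和引理是指:当 \(0\leq j,j'\leq n-1\) 时,\(\sum_{k=0}^{n-1}\varepsilon_{n}^{k(j-j')}\) 在 \(j=j'\) 时等于 \(n\),在 \(j\neq j'\) 时等于 \(0\)。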

\[ V_{n}^{-1} =\frac{1}{n} \begin{bmatrix}1 & 1 &1 & ... &1 \\1 & \varepsilon_{n}^{-1} & \varepsilon_{n}^{-2} & ... & \varepsilon_{n}^{-(n-1)}\\ 1 & \varepsilon_{n}^{-2} & \varepsilon_{n}^{-4} & ... &\varepsilon_{n}^{-(2n-2)} \\ ... & ... & ... & ... & ...\\1 & \varepsilon_{n}^{-(n-1)} & \varepsilon_{n}^{-(2n-2)} &... & \varepsilon_{n}^{-(n-1)^{2}} \end{bmatrix} \]

我们惊喜地发现上述矩阵与原矩阵结构类似,事实上,我们恰恰可以利用这一点,将点值部分的代码稍加变化就成了插值的代码,因为:

\[ A=V_{n}^{-1}Y =\frac{1}{n} \begin{bmatrix} 1 & 1 & 1 & ... &1 \\ 1 & \varepsilon_{n}^{-1} & \varepsilon_{n}^{-2} & ... & \varepsilon_{n}^{-(n-1)}\\ 1 & \varepsilon_{n}^{-2} & \varepsilon_{n}^{-4} & ... &\varepsilon_{n}^{-(2n-2)} \\ ... & ... & ... & ... & ...\\ 1 & \varepsilon_{n}^{-(n-1)} & \varepsilon_{n}^{-(2n-2)} &...& \varepsilon_{n}^{-(n-1)^{2}} \end{bmatrix} \begin{bmatrix}y_{0}\\ y_{1}\\ y_{2}\\ ...\\ y_{n-1}\end{bmatrix} \]

这和之前求点值的计算方法几乎一模一样:只需把 \(\varepsilon_{n}^{k}\) 换成 \(\varepsilon_{n}^{-k}\) 再做一次同样的变换,最后把结果除以 \(n\) 即可。

【例题ex_1】hdu1402

A * B Problem Plus Time Limit: 2000/1000 MS (Java/Others)    Memory Limit: 65536/32768 K (Java/Others) Total Submission(s): 19459    Accepted Submission(s): 4544

Problem Description Calculate A * B.

Input Each line will contain two integers A and B. Process to end of file.

Note: the length of each integer will not exceed 50000.

Output For each case, output A * B in one line.

#include <vector>
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <string>
#define pi acos(-1)
#define eps 1e-6
using namespace std;
// Minimal complex number type used by the FFT below.
class comp
{
public:
    double r, i;  // real and imaginary parts

    comp() : r(0), i(0) {}
    comp(double real, double image) : r(real), i(image) {}

    // The operators do not modify either operand, so they are const and take
    // their argument by const reference; this also lets them be applied to
    // const objects.
    comp operator+(const comp &x) const
    {
        return comp(r + x.r, i + x.i);
    }
    comp operator-(const comp &x) const
    {
        return comp(r - x.r, i - x.i);
    }
    // (r + i*j)(x.r + x.i*j) = (r*x.r - i*x.i) + (r*x.i + x.r*i)*j
    comp operator*(const comp &x) const
    {
        return comp(r*x.r - i*x.i, r*x.i + x.r*i);
    }
};
// Reorders x into bit-reversed index order, the input layout expected by the
// bottom-up butterfly loop in fft(): with x.size() == 2^h, the element that was
// at index i ends up at the index whose h-bit binary form is i's bits reversed.
//
// Precondition: x.size() is 0 or a power of two (main() always pads to one).
// Bit reversal is an involution, so the permutation is done in place with
// swaps. The bit count uses integer arithmetic rather than
// round(log(l) / log(2)), which avoids floating-point rounding and the
// undefined int conversion of log(0) for an empty vector.
void trans(vector<comp> &x)
{
    const int len = x.size();
    int bits = 0;
    while ((1 << bits) < len) ++bits;
    for (int i = 0; i < len; ++i)
    {
        // rev = i with its lowest `bits` bits mirrored.
        int rev = 0;
        for (int j = 0; j < bits; ++j)
            if ((i >> j) & 1) rev |= 1 << (bits - 1 - j);
        // rev < len always holds for power-of-two sizes; the check only keeps
        // accesses in bounds if the precondition is ever violated.
        if (i < rev && rev < len) std::swap(x[i], x[rev]);
    }
}
// In-place iterative radix-2 FFT.
//   tag =  1 : forward transform (evaluate at the powers of e^{2*pi*i/n});
//   tag = -1 : inverse transform WITHOUT the 1/n scaling; the caller divides
//              by x.size() afterwards (main() does this when rounding).
// The butterfly loop assumes x.size() is a power of two.
void fft(vector<comp> &x, int tag)
{
    trans(x);  // bit-reversed order lets the butterflies run bottom-up
    const int n = x.size();
    for (int len = 2; len <= n; len <<= 1)
    {
        const int half = len / 2;
        // Principal len-th root of unity (its conjugate when tag == -1).
        comp step(cos(tag * 2 * pi / len), sin(tag * 2 * pi / len));
        for (int start = 0; start < n; start += len)
        {
            comp twiddle(1, 0);
            for (int k = 0; k < half; k++)
            {
                comp &lo = x[start + k];
                comp &hi = x[start + k + half];
                comp even = lo;
                comp odd = twiddle * hi;
                lo = even + odd;
                hi = even - odd;
                twiddle = twiddle * step;
            }
        }
    }
}
int main()
{
    string a, b;
    while (cin >> a >> b)
    {
        vector<comp> x, y;
        for (int i = a.length() - 1; i >= 0; i--) x.push_back(comp(a[i] - '0', 0));
        for (int i = b.length() - 1; i >= 0; i--) y.push_back(comp(b[i] - '0', 0));
        int l = 1;
        while (l < 2 * x.size() || l < 2 * y.size()) l <<= 1;
        x.resize(l);
        y.resize(l);
        fft(x, 1);
        fft(y, 1);
        for (int i = 0; i < l; i++) x[i] = x[i] * y[i];
        fft(x, -1);
        vector<int> ans;
        ans.resize(l);
        for (int i = 0; i < l; i++) ans[i] = round(x[i].r / l);
        int i = 0;
        while (i < ans.size())
        {
            if (ans[i] >= 10)
            {
                if (i == ans.size() - 1)
                {
                    ans.push_back(ans[i] / 10);
                    ans[i] %= 10;
                }
                else {
                    ans[i + 1] += ans[i] / 10;
                    ans[i] %= 10;
                }
            }
            i++;
        }
        for (i = ans.size() - 1; i >= 0; i--) if (ans[i]) break;
        for (int j = i; j >= 0; j--) printf("%d", ans[j]);
        if (i < 0) puts("0"); else puts("");
    }
    return 0;
}

【例题ex_2】Lonlife-ACM1092

Fate Dog Time Limit:5s Memory Limit:128MByte

Submissions:93Solved:16

DESCRIPTION Mr.Ang was addicted to spend lots of money on mobile gaming Fate/Grand Order recently. As a well-heeled man of the European royal descent, Mr.Ang has a collection of \(n\) servants already. It seems to him the current combat system is a bit complicated, thus he has come up with a new combat system—select three different servants for each combat, and let them take turns to attack, where each servant is able to attack exactly once in each round.

For the \(i\)-th servant in his collection:

It is able to cause \(a_i\) base damages and \(p_i\) percents attack bonus if it is selected as the first attacker. It is able to cause \(b_i\) base damages and get the first one’s attack bonus if it is selected as the second attacker. It is able to cause \(c_i\) base damages and get the first one’s attack bonus if it is selected as the third attacker. In summary, if the selected servants in attack order is the \(i\)-th one, the \(j\)-th one and the \(k\)-th one in his collection, they are able to cause \((a_i+b_j+c_k)(1+\frac{p_i}{100})\) damages in total.

Assuming that there is a BOSS with \(H\) hit points, Mr.Ang intends to know how many ways of the selected servants that are able to kill the BOSS in one round, in other words, they are able to cause no less than \(H\) damages in one round.

INPUT

The first line contains a positive integer \(T\), which represents there are \(T\) test cases. The following is test cases. For each test case: The first line contains two positive integers \(n\) and \(H\), which represent the number of servants and the hit points of the BOSS. In the next \(n\) lines, the \(i\)-th line contains four non-negative integers \(a_i, b_i, c_i\) and \(p_i\). It is guaranteed that no more than 10 test cases do not satisfy \(n, H\leq 10^{3}\). \(1\leq T\leq 100,\ 3\leq n\leq 10^{5},\ 1\leq H\leq 3\cdot 10^{5},\ 0\leq a_i,b_i,c_i\leq 10^{5},\ 0\leq p_i\leq 100\ (i=1,2,\cdots,n)\).

OUTPUT

For each test case, output in one line, contains one integer, which represents the number of the ways.

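官方题解

思路:枚举第一个出手的从者 \(i\)。条件 \((a_i+b_j+c_k)(1+\frac{p_i}{100})\geq H\) 等价于 \(b_j+c_k\geq\left\lceil\frac{100(H-a_i)-a_ip_i}{100+p_i}\right\rceil\),即代码中的 bpc。具体步骤如下:
1. 用 FFT 对所有不超过 \(H\) 的 \(b\) 值与 \(c\) 值的出现次数做卷积,得到满足 \(b_j+c_k=s\) 的有序对 \((j,k)\) 的个数,再取后缀和。
2. \(b_j>H\) 或 \(c_k>H\) 的有序对必然满足条件,单独计数。
3. 最后用容斥原理去掉 \(j=i\)、\(k=i\)、\(j=k\) 这些不合法的情形。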
#include <iostream>
#include <cstdio>
#include <vector>
#include <algorithm>
#define N 100010
#define pi acos(-1)
using namespace std;
typedef long long ll;
// Per-test-case input, read in main():
//   a[i], b[i], c[i]: base damage of servant i as first, second, third attacker.
//   p[i]: percent attack bonus of servant i when it attacks first.
//   H: hit points of the boss.
//   l: FFT length, chosen by init().
int a[N], b[N], c[N], p[N], H, l;
// higher[s]   : after main() post-processes the FFT result, the number of
//               ordered pairs (j, k) with b[j] <= H, c[k] <= H and
//               b[j] + c[k] >= s (indices 0..2H are used).
// xhigher[v]  : after init(), number of servants with b >= v (0 <= v <= H).
// yhigher[v]  : after init(), number of servants with c >= v (0 <= v <= H).
// xyHigher[v] : after init(), number of servants with b + c >= v (0 <= v <= H).
// bE, cE, bcE : number of servants with b > H, c > H, b + c > H respectively.
// n           : number of servants in the current test case.
ll higher[6 * N], xhigher[3 * N], yhigher[3 * N], xyHigher[3 * N],cE,bE,n,bcE;
// Minimal complex number type used by the FFT.
struct comp
{
    double r, i;  // real and imaginary parts

    // Zero-initialise so that default-constructed values (for example the
    // elements created by vector::resize) are well defined rather than
    // indeterminate.
    comp() : r(0), i(0) {}
    comp(double real, double image) : r(real), i(image) {}

    // The operators do not modify either operand, so they are const and take
    // their argument by const reference.
    comp operator+(const comp &x) const { return comp(x.r + r, x.i + i); }
    comp operator-(const comp &x) const { return comp(r - x.r, i - x.i); }
    // (r + i*j)(x.r + x.i*j) = (x.r*r - x.i*i) + (x.r*i + x.i*r)*j
    comp operator*(const comp &x) const { return comp(x.r*r - x.i*i, x.r*i + x.i*r); }
};
// FFT work buffers: init() loads the histogram of b values (<= H) into x and of
// c values (<= H) into y; main() then overwrites x with their convolution.
vector<comp> x, y;
void init()
{
    for (int i = 0; i <= 2 * H; i++) higher[i] = 0;
    for (int i = 0; i <= H; i++) xhigher[i] = yhigher[i] = xyHigher[i] = 0;
    cE = bE = bcE = 0;
    for (l = 1; l <= 2 * H; l <<= 1);
    x.resize(l); y.resize(l);
    for (int i = 0; i < l; i++) x[i] = comp(0, 0), y[i] = comp(0, 0);
    for (int i = 0; i < n; i++)
    {
        if (b[i] <= H) xhigher[b[i]] += 1; else bE++;
        if (c[i] <= H) yhigher[c[i]] += 1; else cE++;
        if (b[i] + c[i] > H) bcE++; else xyHigher[b[i] + c[i]]++;
    }
    for (int i = 0; i <= H; i++) x[i].r = xhigher[i], y[i].r = yhigher[i];
    xhigher[H] += bE; yhigher[H] += cE; xyHigher[H] += bcE;
    for (int i = H - 1; i >= 0; i--)
        xhigher[i] += xhigher[i + 1], yhigher[i] += yhigher[i + 1], xyHigher[i] += xyHigher[i + 1];
}
// Reorders x into bit-reversed index order, the input layout expected by the
// bottom-up butterfly loop in fft(): with x.size() == 2^bits, the element that
// was at index i ends up at the index whose bits-bit binary form is i's bits
// reversed.
//
// Precondition: x.size() is a power of two (init() resizes x and y to l, which
// it chooses as a power of two).
// Bit reversal is an involution, so the permutation is done in place with
// swaps, with an integer bit count. This replaces copying into a fresh vector
// sized by the global l and computing round(log(l) / log(2)) in floating
// point, and it no longer depends on l matching x.size().
void trans(vector<comp> &x)
{
    const int len = x.size();
    int bits = 0;
    while ((1 << bits) < len) ++bits;
    for (int i = 0; i < len; ++i)
    {
        // rev = i with its lowest `bits` bits mirrored.
        int rev = 0;
        for (int j = 0; j < bits; ++j)
            if ((i >> j) & 1) rev |= 1 << (bits - 1 - j);
        // rev < len always holds for power-of-two sizes; the check only keeps
        // accesses in bounds if the precondition is ever violated.
        if (i < rev && rev < len) std::swap(x[i], x[rev]);
    }
}
// In-place iterative radix-2 FFT over the first l entries of x, where l is the
// global transform length chosen by init().
//   tag =  1 : forward transform;
//   tag = -1 : inverse transform WITHOUT the 1/l scaling (main() divides by l).
void fft(vector<comp> &x, int tag)
{
    trans(x);  // bit-reversed order lets the butterflies run bottom-up
    for (int len = 2; len <= l; len <<= 1)
    {
        const int half = len >> 1;
        // Principal len-th root of unity (its conjugate when tag == -1).
        comp root(cos(tag * 2 * pi / len), sin(tag * 2 * pi / len));
        for (int start = 0; start < l; start += len)
        {
            comp cur(1, 0);  // root^k for the k-th butterfly of this group
            for (int k = start; k < start + half; ++k)
            {
                comp &lo = x[k];
                comp &hi = x[k + half];
                comp even = lo;
                comp odd = cur * hi;
                lo = even + odd;
                hi = even - odd;
                cur = root * cur;
            }
        }
    }
}
void work()
{
    ll ans = 0;
    for (int i = 0; i < n; i++)
    {
        int bpc = ceil((1.0*(100 * (H - a[i]) - a[i] * p[i])) / (100 + p[i]));
        if (bpc <= 0)
        {
            ans += (n-1)*(n - 2);
        } else {
            ans += higher[bpc] + (bE + cE)*n - bE*cE;
            if (b[i] >= bpc) ans -= n; else ans -= yhigher[bpc - b[i]];
            if (c[i] >= bpc) ans -= n; else ans -= xhigher[bpc - c[i]];
            ans -= xyHigher[bpc];
            if (b[i] + c[i] >= bpc) ans += 2;
        }
    }
    printf("%lld\n", ans);
}
int main()
{
    int t;
    cin >> t;
    while (t--)
    {
        cin >> n >> H;
        for (int i = 0; i < n; i++) scanf("%d%d%d%d", a + i, b + i, c + i, p + i);
        init();
        fft(x, 1); fft(y, 1);
        for (int i = 0; i < l; i++) x[i] = x[i] * y[i];
        fft(x, -1);
        for (int i = 0; i <= 2 * H; i++) higher[i] = round(x[i].r / l);
        for (int i = 2 * H - 1; i >= 0; i--) higher[i] += higher[i + 1];
        work();
    }
    return 0;
}