알고리즘

[빠른곱셈-1] FFT

NickTop 2023. 2. 12. 19:46

원리

1. 숫자를 n개로 끊어서 다항식으로 변환 (아래 예시)

   ${123456789 \mapsto 123x^2+456x+789 \\
x=1000}$

2. 1번에서 변환된 두 다항식을 DFT로 변환

3. 변환된 두 값을 곱해준 뒤 IDFT로 두 다항식의 곱셈을 구함

4. 다항식(x)에 정수 대입

 

Lagrange's Interpolation

m-1차 다항식은 m개의 점에서의 함수값이 결정되면 유일하게 결정됨

ex) 1차다항식에서 2개의 점을 알고있으면 다항식이 유일하게 결정된다

${f(x) = ax+b \\
f(0)=1 \\
f(1)=2 \\
=> \\
b = 1 \\
a = 1}$

 

Primitive root of unity

${w^n=1 \\
w^k!=1(1\leq k<n, k \in N)}$

를 만족하는 $w$를 Primitive root of unity 라고 한다

여기서 ${w}$를 만족하는 것 중 하나는

$w = exp(2\pi i/n)$이다

참고로 이때 $w^{n/2} = 1$ 이다

 

DFT

다항식 $A(x)$가 있다고 가정

$A(x)=a_0+a_1x^1+...+a_{n-2}x^{n-2}+a_{n-1}x^{n-1}$

$x$에 $w^i$를 대입하자 ($0\leq i< n$)

$(a_0,a_1,..,a_{n-1})x \mapsto (A(w^0),A(w^1),...,A(w^{n-1}))$

이때 얻어지는 변환을 DFT라고 한다 (이산 퓨리에 변환)

 

FFT

DFT를 naive하게 계산하려면 $O(n^2)$

FFT는 퓨리에 변환을 빨리 하는 방법

 

Cooley-tukey

FFT의 방법 중 하나

${\begin{align*}
A(x)&=a_0+a_1x^1+...+a_{n-2}x^{n-2}+a_{n-1}x^{n-1} \\
&= (a_0+a_2x^2+a_4x^4+...)+(a_1x+a_3x^3+a_5x^5+...) \\
&= (a_0+a_2x^2+a_4x^4+...)+x(a_1+a_3x^2+a_5x^4+...) \\
&= A_{even}(x^2)+xA_{odd}(x^2)
\end{align*}}$

$0\leq j<n/2$에 대해

${A(w^j) = A_{even}(w^{2j})+w^jA_{even}(w^{2j}) \\
\begin{align*}
A(w^{j+n/2}) &= A_{even}(w^{2j+n})+w^jw^{n/2}A_{even}(w^{2j+n})\\
&= A_{even}(w^{2j+n})-w^jA_{even}(w^{2j+n})
\end{align*}}$

즉, $A(w^j)$를 알고있다면 $A(w^{j+n/2})$ 는 $O(1)$에 구해진다

 

예를 들어, 3차 다항식을 생각해보자

$f(x) = 1+2x+3x^2+4x^3$

$f(1),f(-1),f(-i),f(-1)$

을 naive하게 구하고자 할 때 각각을 구하기 위해서는 4번의 덧셈을 해야하며, 총 16번의 덧셈이 소요된다

하지만, cooley-tukey의 방법으로

$f(x) = (1+3x^2)+x(2+4x^2)$

$f(1) = (4)+1*(6)$
$f(i) = (-2)+i*(0)$

$f(1)$과 $f(i)$을 구했다면 $f(-1)$과 $f(-i)$를 앞의 결과를 활용하여 구할 수 있다

${f(-1) = (4)-1*(6) \\ 
f(-i) = (-2)-i*(0) }$

3차보다 더 큰 다항식에 대해서는 더 크게 체감이 될 것 같다

 

$A_{even}(x^2)$의 DFT를 구할 때도 $A'(x)$로 생각하여

$A'_{even}(x^2), A'_{odd}(x^2)$를 구할 수 있다 (divide and conquer) 

 

$A_{even}(x^2)$와 $A_{odd}(x^2)$의 DFT를 구하고 나서

A의 DFT를 구하는데, n만큼의 시간이 소요된다 (even과 odd를 더하는 덧셈 때문에)

 

그러면 시간복잡도는

$T(n) = 2T(n/2)+O(n)$ 이므로 $O(nlogn)$이다

 

DFT의 곱

$C(x)=A(x)B(x)$ 이라고 하면

$C(w^j)=A(w^j)B(w^j)$ 이다

 

즉 A의 DFT, B의 DFT를 구하고,

각 요소를 곱해주면 C의 DFT를 구할 수 있다 (pointwise multiplication)

 

A가 n-1차 다항식, B가 n-1차 다항식이므로(두 수의 크기가 다르다면 맞춰줘야 함) 

C는 2n-2차 다항식이다.

그러면 C 다항식을 유일하게 결정하려면 최소 2n-1개의 서로 다른 점이 필요하고 편의상 2n개의 점이 필요하다고 하자

 

그러면

A,B 모두 $w^{2n}=1$인 primitive root of unity로 계산이 필요하다

 

IDFT

서로 다른 2n개의 점으로부터 다항식을 구해야 한다

C의 DFT를 역변환하면 된다(IDFT)

 

역변환을 하는 방법을 알아보자

n-1차 다항식 $A(x)$를 생각해보자

$A(x)=a_0+a_1x^1+...+a_{n-2}x^{n-2}+a_{n-1}x^{n-1}$

$y_k = A(w^k)$라고 하자

 

${\begin{bmatrix}
y_0 \\
y_1 \\
y_2 \\
\vdots \\
y_{n-1}
\end{bmatrix}
=
\begin{bmatrix}
1 &1  &1  &...  &1 \\ 
 1&w  &w^2  &...  &w^{n-1} \\ 
1& w^2&w^4  &...  &w^{2(n-1)} \\ 
 \vdots&\vdots  &\vdots  &\ddots  & \vdots \\ 
 1& w^{n-1} & w^{2(n-1)} & ... & w^{(n-1)(n-1)}
\end{bmatrix}
\begin{bmatrix}
a_0 \\
a_1 \\
a_2 \\
\vdots \\
a_{n-1}
\end{bmatrix}
}$

 

이고 역행렬을 구하면

 

${\frac{1}{n}
\begin{bmatrix}
1 &1  &1  &...  &1 \\ 
 1&w^{-1}  &w^{-2}  &...  &w^{-(n-1)} \\ 
1& w^{-2}&w^{-4}  &...  &w^{-2(n-1)} \\ 
 \vdots&\vdots  &\vdots  &\ddots  & \vdots \\ 
 1& w^{-(n-1)} & w^{-2(n-1)} & ... & w^{-(n-1)(n-1)}
\end{bmatrix}
\begin{bmatrix}
y_0 \\
y_1 \\
y_2 \\
\vdots \\
y_{n-1}
\end{bmatrix}

=

\begin{bmatrix}
a_0 \\
a_1 \\
a_2 \\
\vdots \\
a_{n-1}
\end{bmatrix}}$

이다

 

행렬과 역행렬을 곱하면 단위행렬이 되는 것을 쉽게 볼 수 있다

위 식이 의미하는 바는 $w$ 대신 $w^{n-1}$으로 FFT를 구하고 결과에 n을 나누면 역변환이 가능하다는 뜻이다

 

코드

cooley-tukey FFT를 하려면 다항식의 차수가 2의 제곱이 되어야 함에 주의

from math import exp, pi, cos, sin
X = 0b111
p = 3
# 임의 지정, 2의 지수가 되어야 비트 연산으로 num_to_array 쉽게 계산 가능

def get_size_power_of_2(a_len,b_len):
    # cooley-tukey 계산을 위해 size는 2**n이 되어야 함
    size = 1
    thres = max(a_len, b_len)
    while size<thres:
        size = size<<1
    return size<<1 # C의 차수는 A의 차수의 2배

def num_to_array(n):
    # 숫자를 다항식으로 변환
    # A(x)=a0+a1x^1+...+a_{n-2}x^{n-2}+a_{n-1}x^{n-1} 변환
    N=[]
    while n>0:
        a = n & X
        n = n >> p
        N.append(a)
    return N
    
def fft(a,w):    
    if len(a)==1:
        return a
    a_even = a[::2]
    a_odd = a[1::2]
    A_even = fft(a_even, w*w)
    A_odd = fft(a_odd, w*w)
    A = []
    for i in range(len(A_even)):
        A.append(A_even[i]+(w**i)*A_odd[i])
        # w**i를 미리 계산하면 더 시간복잡도가 줄어듬
    for i in range(len(A_even)):
        A.append(A_even[i]-(w**i)*A_odd[i])
    return A

def pointwise_multiplication(A,B):
    C=[]
    for i in range(len(A)):
        C.append(A[i]*B[i])
    return C

def array_to_num(C):
    n = 0
    for c in reversed(C):
        n = c + (n<<p)
    return n
    
def multiplication(n1,n2):
    A = num_to_array(n1)
    B = num_to_array(n2)
    size = get_size_power_of_2(len(A),len(B))
    A = A+[0]*(size-len(A))
    B = B+[0]*(size-len(B))
    
    w = complex(cos(2*pi/size),sin(2*pi/size))
    A_dft = fft(A,w)
    B_dft = fft(B,w)

    C_dft = pointwise_multiplication(A_dft,B_dft)
    C = fft(C_dft, 1/w)
    C = list(map(lambda x:x/size,C))
    C = list(map(lambda x:round(x.real),C))
    # C중 real 값만 취함, C는 real값만 있어야 하지만 소수 계산 때문에 오차 있음
    n = array_to_num(C)
    return n

n1 = 152346781523467813246784132768914327698143279681679567812
n2 = 71856492708591278915237851293708915231785902175832715832151512531235513251323152135241352
print(n1*n2)
print(multiplication(n1,n2))
# output
# 10947105395718413493159423815491323669741350128895408553063593853966234682079364780093295071747912602980987186335604499511318226374057805670561824
# 10947105395718413493159423815491323669741350128895408553063593853966234682079364780093295071747912602980987186335604499511318226374057805670561824

 

시간복잡도 및 정확도

앞서 설명한 것 처럼 FFT의 시간복잡도가 $O(nlogn)$ 이므로 시간복잡도는 $O(nlogn)$ 이다

하지만, 삼각함수 및 복소수 연산이 있기 때문에 오차가 발생한다

이 오차를 줄이려면 삼각함수를 더 정확하게 계산해야 하는데, 그러면 시간복잡도가 늘어난다

따라서 $O(nlogn)$ 라고 할 수 없다

 

 

출처

https://speakerdeck.com/wookayin/fast-fourier-transform-algorithm?slide=30 

 

Fast Fourier Transform Algorithm

A Introduction to FFT (Fast Fourier Transform) Algorithm with its application in competitive programming. This talk was given in the 2012 SNUPS (Seoul National University Problem Solving Group) Algorithm seminar.

speakerdeck.com

https://mathoverflow.net/questions/19946/what-is-the-time-complexity-of-computing-sinx-to-t-bits-of-precision

 

What is the time complexity of computing sin(x) to t bits of precision?

Short version of the question: Presumably, it's poly$(t)$. But what polynomial, and could you provide a reference? Long version of the question: I'm sort of surprised to be asking this, becaus...

mathoverflow.net