알고리즘

[빠른곱셈-2][선행지식] 중국인의 나머지 정리(chinese remainder theorem)

NickTop 2023. 2. 21. 23:12

배경

실질적으로 곱셈에 쓰고 있는 쇤하게-슈트라센 알고리즘을 정리하다 보니, 몇 가지 선행지식이 있습니다.

먼저 짚고 넘어가는 게 좋을 것 같아서 정리합니다

 

합동식

$x\equiv y (mod \,p)$

=> p divides (x-y)

=> x-y = np

 

$p_0, p_1,...,p_{n-1}$가 있다고 하자 (mutually coprime)

For any integer $u$,

$u(mod \,p_i) = u_i$ 라고 하고

$r_u$를 다음과 같이 정의한다

$r_u =(u_0,u_1,...,u_{n-1})$

 

Chinese Remainder Theorem

Let. $p = \sum_{i=0}^{n-1} p_i$

$u$가 [0,p-1] 범위의 정수일때 $u\rightarrow r_u$는 일대일 대응이다 ($r_u$는 서로 겹치지 않음)

 

증명

a,b가 u의 원소라고 하고, $r_a=r_b$라고 하자

$(a_0,a_1,...,a_{n-1}) = (b_0,b_1,...,b_{n-1})$

$a_i = b_i$ (for every i)

$a_i= a(mod \,p_i)=b_i= b(mod \,p_i)$

$p_i$ divides $a-a_i$ and $b-b_i$

그리고 $a_i=b_i$ 이기 때문에 $p_i$ divides $a-b_i$

 

어떤 두 수 x,y가 p로 나누어 떨어진다면,

x-y도 p로 나누어 떨어진다

따라서, $p_i$ divides $a-b$ (for every i)

$p_i$는 서로소이기 때문에

$\prod_{i=0}^{n-1}p_i=p$ divides $a-b$

$a-b = np$

a,b는 [0,p-1]의 범위에 있으므로,

$a-b=0$ 따라서 모순이고 chinese remainder theorem는 성립

 

q

Residue Based Representation of Integer를 효율적으로 계산하기 위해 q를 정의합시다

q는 그 자체로는 큰 의미가 없지만, q를 미리 구해놓으면 나중에 계산할 때 유리한 점이 있습니다

$q(j,i)=q(j-1,i)*q(j-1,i+2^{j-1})$ 라고 정의합시다

(참고로, $q(j,i)=p_ip_{i+1}...p_{i-1+2^j}$)입니다

n은 2^k입니다

j=0,1,...,k

i=0,1*2^j,2*2^j,... 가 가능합니다

 

from functools import reduce
from typing import List

def multiplication(a,b):
    return a*b

def efficiently_calculate_q(p:List[int]):
    n = len(p)
    k = n.bit_length()
    print(n,k)
    q = [[0]*n for _ in range(k)]
    for i in range(n):
        q[0][i]=p[i]
    for j in range(1,k): # 시간복잡도1 : k
        for i in range(0,n,2**j): # 시간복잡도2 : n/(2^j)
            # p[x][y]의 최대 크기를 b 라고 하면
            # q[j][i]의 크기는 최대 (2^j)*b
            q[j][i] = q[j-1][i]*q[j-1][i+2**(j-1)]  # 시간복잡도3 : M((2^j)*b)
    
    assert(q[-1][0]==reduce(multiplication,p))

efficiently_calculate_q([43,346,213476,432,43,3,3,1])

시간복잡도1*시간복잡도2

 $\frac{n}{2^j}M(2^{j-1}b) \leq  \frac{1}{2}M(2^{j-1}b\frac{n}{2^{j-1}})=\frac{1}{2}M(nb)$

따라서 총 시간복잡도(시간복잡도1*시간복잡도2*시간복잡도3)

$O(k*M(nb)) =O(logn*M(nb))$

 

u로부터 $r_u$ 계산

u(j,i)를 정의하자

$u(j,i) = u(mod\,q(j,i))$

그러면

${\begin{align}
u(j-1,i)&=u(mod\,q(j-1,i)) \\
&=u(mod\,q(j-1,i)*q(j-1,i+2^{j-1})) (mod\,q(j-1,i)) \,\,\,\,\,<<note>>\\
&=u(mod\,q(j,i))(mod\,q(j-1,i))\\
&=u(j,i)(mod\,q(j-1,i))\\
\end{align}}$

 

note that,  x (mod y) = x (mod z*y)(mod y) 

 

비슷한 방법으로

${\begin{align}
u(j-1,i+2^{j-1})&=u(mod\,q(j-1,i+2^{j-1})) \\
&=u(mod\,q(j-1,i+2^{j-1})*q(j-1,i)) (mod\,q(j-1,i+2^{j-1})) \\
&=u(mod\,q(j,i))(mod\,q(j-1,i+2^{j-1}))\\
&=u(j,i)(mod\,q(j-1,i+2^{j-1}))\\
\end{align}}$

 

$(u_0,u_1,...,u_{n-1})$은 $(u(0,0),u(0,1),...,u(0,n-1))$과 같습니다

def efficiently_calculate_u(p:List[int],U:int):
    n = len(p)
    k = n.bit_length()-1
    q = efficiently_calculate_q(p)
    u = [[0]*n for _ in range(k+1)]
    u[k][0]=U
    for j in range(k,0,-1): # 시간복잡도1 : O(k)
        for i in range(0,n,2**j): # 시간복잡도2 : O(n/(2^j))
            # u[j-1][i]의 크기는 q[j-1][i]를 넘지 못함
            # u[j-1][i]의 최대 크기는 (2^j)*b
            u[j-1][i] = u[j][i]%q[j-1][i] # 시간복잡도3-1 : M((2^j)*b)
            u[j-1][i+2**(j-1)] = u[j][i]%q[j-1][i+2**(j-1)] # 시간복잡도3-2 : M((2^j)*b)
    return u[0]

def mod(a,b):
    return a*b

def directly_calculate_u(p:List[int],U:int):
    u = [0]*len(p)
    for i in range(len(p)):
        u[i]=U%p[i]
    return u

p = [2,3,5,7,13,17,19,713]
U = 165423
assert(directly_calculate_u(p,U)==efficiently_calculate_u(p,U))

총 시간복잡도는 q와 동일하게 $O(k*M(nb)) =O(logn*M(nb))$입니다

 

directly로 계산했을 경우 시간복잡도는 $O(b^2*M(b))$입니다 (이 부분은 이해하지 못했습니다)

$M(x)=x^{1+a}$로 계산했을 경우 $\frac{n^{1-a}}{logn}$만큼 개선이 이루어집니다

 

$r_u$로 부터 u 구하기

$c_i=\frac{p}{p_i}$라고 정의하자

$c_id_i\equiv1(mod\,p_i)$인 $d_i$가 존재한다

 

증명

$c_i$와 $p_i$는 mutually coprime.

$ac_i + bp_i=1$을 만족하는 정수 a,b 존재(확장된 유클리드 알고리즘으로 a,b 구할 수 있다)

$ac_i =1(mod\,p_i)$

a가 $d_i$

$d_i$를 [0,$p_i$)의 범위로 둔다면 $d_i$는 유일하다

def extended_euclidean_algo(a,b):
    prev_x,prev_y,prev_r = 0,1,b
    x,y,r = 1,0,a
    while prev_r%r!=0:
        t = prev_r//r
        new_x = prev_x - x*t
        new_y = prev_y - y*t
        new_r = prev_r - r*t
        prev_x,prev_y,prev_r = x,y,r
        x,y,r = new_x,new_y,new_r
    return x,y,r
print(extended_euclidean_algo(11,123))
# (56, -5, 1)
# 56*11-5*123=1

시간복잡도는 $O(log(max(A,B)))$입니다

확장된 유클리드 알고리즘 설명은 건너뛰겠습니다. 저는 아래 영상을 보고 코드를 짰습니다

https://www.youtube.com/watch?v=PmwLXveLtqc&list=PLdEdazAwz5Q884ImnFH_5yEne0qzGHNhS&index=7 

 

$u(mod\,p)$

$u(mod\,p)=\sum_{j=0}^{n-1}u_jc_jd_j(mod\,p)$ 가 성립함

 

왜냐하면

왼쪽을 $mod \, p_i$를 취해주면

$u(mod\,p)(mod \, p_i)=u_i$

 

오른쪽을 $mod \, p_i$를 취해주면

${\begin{align}
\sum_{j=0}^{n-1}u_jc_jd_j(mod\,p)(mod\,p_i)&=\sum_{j=0}^{n-1}u_jc_jd_j(mod\,p_i) \\
&=u_ic_id_i(mod\,p_i) \,\,\,\,<<note>>\\
&=u_i(mod\,p_i) \\
&=u_i
\end{align}}$

note that) $c_i$를 제외한 $c_j$은 $p_i$를 약수로 가지고 있다

 

따라서 앞서 설명했듯이 1-1대응이므로 등식은 성립한다

$u$가 [0,p-1] 범위의 정수일때 $u\rightarrow r_u$는 일대일 대응이다 ($r_u$는 서로 겹치지 않음)

 

$S(j,i)$

S를 다음과 같이 정의합니다

$S(j,i)=\sum_{m=i}^{i-1+2^{j}}\frac{p_ip_{i+1}...p_{i-1+2^{j}}}{p_m}u_md_m$

그러면,

$S(k,0)(mod\,p)=u(mod\,p)$입니다. (앞에서 증명했듯이)

또한, $S(0,i)=u_id_i$입니다

 

그리고 아래식이 성립합니다

\begin{align} 
S(j,i)&=p_ip_{i+1}...p_{(i-1+2^{j})}\sum_{m=i}^{i-1+2^{j}}\frac{u_md_m}{p_m}\\
&=p_i...p_{(i-1+2^{j-1})}p_{(i+2^{j-1})}...p_{(i-1+2^{j})}\sum_{m=i}^{i-1+2^{j-1}}\frac{u_md_m}{p_m}
+p_i...p_{(i-1+2^{j-1})}p_{(i+2^{j-1})}...p_{(i-1+2^{j})}\sum_{m=i+2^{j-1}}^{i-1+2^{j}}\frac{u_md_m}{p_m} \\
&=p_{(i+2^{j-1})}...p_{(i-1+2^{j})}S(j-1,i)+p_i...p_{(i-1+2^{j-1})}S(j-1,i+2^{j-1})\\
&=q(j-1,i+2^{j-1})S(j-1,i)+q(j-1,i)S(j-1,i+2^{j-1})\\
\end{align}

 

def prod_c(p,i):
    output = 1
    for j in range(len(p)):
        if j==i:
            continue
        output*=p[j]
    return output

def remainders_u(p:List[int],U:int):
    # p와 u로부터 U를 구하는 걸 하는 것임
    n = len(p)
    k = n.bit_length()-1
    q = efficiently_calculate_q(p)
    u = efficiently_calculate_u(p,U)
    
    S = [[0]*n for _ in range(k+1)]
    C = [0]*n
    # todo) how to compute C efficiently
    for i in range(n):
        C[i] = prod_c(p,i)
    for i in range(n):
        d = extended_euclidean_algo(C[i],p[i])
        S[0][i]=u[i]*(d[0]%p[i]) # C[i]*d[i]
    for j in range(1,k+1): # 시간복잡도1 : O(k)
        for i in range(0,n,2**j): # 시간복잡도2 : O(n/(2^j))
            # u[j-1][i]의 최대 크기는 (2^j)*b
            # S[j][i]의 크기 : (2^j)*b의 크기의 수를 2^j 번 더함 -> (2^j)*b+2^j-1 이고
            # |S[j][i]| <= (2^(j+1))*b
            S[j][i] += q[j-1][i]*S[j-1][i+2**(j-1)] + q[j-1][i+2**(j-1)]*S[j-1][i]
            # 시간복잡도3 : 2M((2^j)*b)
    P = C[0]*p[0]
    assert(S[k][0]%P==U)
remainders_u([2,3,5,7],500)

총 시간복잡도는 $O(logn*M(nb))$입니다. 이유는 계속 앞에서 같은 것을 계산했기 때문에 넘어가겠습니다

(C를 계산하는 것이 시간복잡도가 많아보이는데 이는 아직 잘 모르겠습니다)

 

결론

$p_i$(mutually coprime)가 주어지고,

어떤수(u)를 $p_i$로 나눈 나머지쌍 $u_i$가 주어진다면

[0,p-1] 범위내의 어떤수(u)를 찾을 수 있고, ($p=\prod p_i$) 

그 시간복잡도는 $O(logn*M(nb))$이다