인공지능

모델 경량화 프루닝 (Pruning) - Unstructured

NickTop 2025. 3. 23. 16:53

프루닝이란

프루닝 이미지

프루닝(pruning)은 딥러닝 모델에서 모델의 크기를 줄이고 효율성을 높이기 위해 불필요한 부분을 제거하는 과정입니다. 필요없는 부분을 "가지치기" 해서 모델의 덜 중요한 부분을 걷어냅니다

가지치기하는 가장 쉬운 방법은 모델의 weight를 0으로 만드는 것(masking)입니다

학습방법은 아래와 같습니다

 

1. 모델이 적당히 converge 할때까지 학습함

2. 중요도가 낮은 weight를 0으로 만듬 (마스킹)

3. 마스킹된 weight를 제외하고 모델 재학습

4. 반복

 

그러면 어떤 값을 0으로 만들어야 할까요?

(논문에 Saliency Score 라는 표현이 자주 나오는데 weight의 중요도를 결정하는 값입니다, 중요도가 낮은 것을 없애야 함)

 

Method1 : Hessian

https://proceedings.neurips.cc/paper_files/paper/1989/file/6c9882bbac1c7093bd25041881277658-Paper.pdf#:~:text=analytically%20predict%20the%20effect%20of,of%20the%20parameter

 

첫번째 방법은 second derivative를 기준으로 weight를 0으로 만드는 방법입니다

그 아이디어는 다음과 같습니다

 

작은 weight의 변화량에 Loss가 어떻게 변하는지 봅시다

loss function은 weight에 종속된 함수입니다

$\Delta E = E(w+\Delta w) - E$

$E(w+\Delta w)$를 taylor series로 근사합시다 (2차 도함수까지)

$E(w+\Delta w) \simeq  E(w) + \gradient E^T  \Delta w + \frac{1}{2} \Delta w^T H \Delta w$

따라서, $\Delta E \approx \sum_{i} \frac{\partial E}{\partial w_i} \Delta w_i 
+ \frac{1}{2} \sum_{i,j} \frac{\partial^2 E}{\partial w_i \partial w_j} \Delta w_i \Delta w_j$

(참고로, $w_i$는 스칼라입니다)

 

첫번째 term을 봅시다

모델이 적당히 converge되면 gradient는 0에 가까워집니다.

그러므로 첫번째 term을 0으로 간주할 수 있습니다.

 

두번째 term을 봅시다

인공지능 모델에서 weight element간의 유사성은 거의 없습니다 (i,j 둘 다 고려하면 계산복잡도가 너무 커져서 이를 무시하기 위함임)

따라서 Hessian function에서 diagonol을 제외한 값은 0으로 간주할 수 있습니다

이를 정리하면

 $\Delta E \approx  \frac{1}{2} \sum_{i} \frac{\partial^2 E}{\partial^2 w_i} \Delta w_i ^2$

$\frac{\partial^2 E}{\partial^2 w_i} \Delta w_i$가 작을수록 모델에 영향을 적게 주는 weight입니다.

따라서 $\frac{\partial^2 E}{\partial^2 w_i} \Delta w_i$를 기준으로 정렬한 뒤 작은 값부터 0으로 만들어주면 됩니다

Second order Derivative는 backpropagation과 비슷한 방법으로 구합니다 (정확한 방법은 생략)

 

https://proceedings.neurips.cc/paper/1992/file/303ed4c69846ab36c2904d3ba8573050-Paper.pdf

위 논문에서는 Hessian의 역행렬을 구해서 second order derivative를 구하는 방법에 대해 설명합니다

 

Method2 : Absolute Value

https://arxiv.org/pdf/1506.02626

https://arxiv.org/pdf/1510.00149

가장 원초적인 프루닝방법입니다.

Second order derivative 방법이 더 좋다고 얘기하지만 계산복잡도 때문에 단순 weight의 절댓값이 threshold보다 작은 것을 0으로 만들었습니다

 

 

[참고 : Lottery Ticket Hypothesis - network와 동일한 성능을 가진 sub-network가 존재할것임]

아래부터는 training하기 전에 프루닝 대상을 찾는 방법입니다

Method3 : Gradient (SNIP)

https://arxiv.org/pdf/1810.02340

 

모든 Loss은 Loss를 최종적으로 minimize하는 것이 목적입니다. 프루닝으로 마스킹된 값에 따라 Loss function에도 영향이 있습니다. 그럼에도 불구하고 여전히 loss를 minimize 해야 합니다. c를 프루닝에 의해 마스킹된 값이라고 합시다

 

그렇다면 목적식은 다음과 같습니다

$\underset{\textbf c,\textbf w}{min} L(\textbf c \odot \textbf w ; D)$

$\textbf c$는 마스킹 여부이며 $\textbf c$의 element는 0과 1만 가질 수 있습니다

 

c의 값들 중 element j 하나만 집중해서 변화량을 보겠습니다

talyor series 1계함수로 근사합니다

 

$\Delta L_j (\textbf w;D) \simeq g_j(\textbf w;D) = \frac{\partial L (\textbf c \odot \textbf  w;D)}{\partial c_j}$

$\textbf c$는 모두 1로 초기화되어있습니다

 

c는 0과 1만 있어 미분가능한 함수가 아닙니다. 하지만 미분가능하다고 가정합니다.

 

sensitivity를 구하는 방법은 배치를 하나만 돌려서 auto diff로 구합니다

z = w @ x 이렇게 들어가야 할 함수에

z = (c * w) @ x 이렇게 넣어줍니다

 

$\tilde{w}_j = c_j \cdot w_j  \quad \Rightarrow \quad  \frac{\partial L}{\partial c_j}  = \frac{\partial L}{\partial \tilde{w}_j} \cdot \frac{\partial \tilde{w}_j}{\partial c_j}  = \frac{\partial L}{\partial \tilde{w}_j} \cdot w_j$

 

$ \frac{\partial L}{\partial \tilde{w}_j}$는 auto diff 계산 중에 나오는 값입니다.

따라서 SNIP는 간단한 계산으로 쉽게 sensitivity를 계산할 수 있는 방법입니다

또한, 대부분의 pruning이 fine tuning처럼 이루어지는데 비해 학습 초기 단계에서 마스킹을 결정할 수 있다는 큰 장점이 있습니다

 

 

Method4 : Derivative of Gradient (GraSP)

SNIP는 gradient를 중요시 합니다.

하지만 gradient는 현재 얼마나 중요한지만 파악 가능하고 미래에 어떻게 될지 반영하지 못한다고 합니다.

미래에 어떻게 될지 반영하려면 gradient의 곡률을 따져야 하며, gradient의 미분값, 즉 Hessian을 구해야한다고 합니다.

1번 방법과 비슷해 보이지만 SNIP에서 파생된 방법입니다

 

Hessian vector product는 auto diff를 통해 Full Hessian matrix (공간복잡도 n*n임)를 구하지 않아도 됩니다

수학적으로 왜 가능한지 정확히 이해를 못해서 다음에 자세히 다루겠습니다

 

논문에서 소개하는 방법은 다음과 같습니다

computing Hessian-gradient product

 

Method5 : SynFlow

https://arxiv.org/pdf/2006.05467

Layer collapse를 막는 것에 중점을 둔 방법입니다

하나의 Layer의 모든 weight가 없어지면 학습이 제대로 이루어지지 않기 때문에 이를 막기 위한 Saliency Score를 만듭니다

방법은 아래와 같습니다1. input x에 1을 넣습니다 [1,1,1,1,1,...1]

2. weight에 절대값을 취합니다

3. R = sum of logits. (참고로 스칼라임)

4. $S(w_i) = \frac{\partial R}{\partial w_i} \odot w_i$

 

왜 이렇게 했는지 살펴봅시다. 목표는 모든 layer 골고루 프루닝하는 것입니다. 그러기 위해서는 $S(w_i)$가 아래를 만족해야 합니다

1. Neuron-wise Conservation: activation 함수 영향 없어야 함 (ex. relu)

activation 영향이 없다는 것은 activation 기준으로 Sin = Sout

모든 weight에 절대값을 취하고 input이 1이므로 만족

2. Network-wise Conservation : 모든 파라미터가 동일하게 score에 반영되어야 함

아래 equation을 만족한다

$\left\langle \frac{\partial \mathcal{R}}{\partial W^{[l]}}, W^{[l]} \right\rangle + \sum_{i=l}^{L} \left\langle \frac{\partial \mathcal{R}}{\partial b^{[i]}}, b^{[i]} \right\rangle = \left\langle \frac{\partial \mathcal{R}}{\partial \mathbf{y}}, \mathbf{y} \right\rangle$

 

SynFlow의 장점 중 하나는 프루닝할때 data 대신 unit input을 넣어 train data가 필요없다는 것입니다

 

 

Pruning 시 중요한 점

https://arxiv.org/pdf/2009.08576

Early pruning method가 얼마나 효용성이 있을까

프루닝 방법 비교

Ramdom 하게 아무거나 프루닝하는 것보다는 좋은 성능을 보였지만, Magnitude after training보다 좋은 성능을 보이지는 못함

 

 

Experiment with initialization

Reinit : 마스킹을 하고 나서 마스킹한 위치를 유지하고 나머지 값들을 초기화함

Shuffled Layerwise : 마스킹 제외 layer내에서 값을 섞음

위 두개를 하더라도 unmodified와 성능차이가 크게 나지 않는다

 

즉, [이 위치에 이 값이 있어야한다]가 중요한 것이 아닌 [이 위치에 마스킹은 들어가야한다]가 더 중요한 요소이다

 

실험에서 한가지 특이한 점있다

Inverted : Saliency Score를 반대로 정렬

GraSP는 정렬을 반대로 하든, 어떤 변형을 가하든 결과가 똑같다, 즉 weight의 중요도를 제대로 측정하지 못한다는 것이다