인공지능

Dataset Distillation

NickTop 2025. 5. 1. 15:01

https://arxiv.org/pdf/1811.10959

 

Introduction

MNIST와 cifar10 distillation 결과

dataset으로부터 가짜이미지를 만들어 적은 이미지로도 모델을 학습합니다

model poisioning

특정 label(plane)을 attack하는 이미지를 만들어 gradient step하나만 돌려 모델의 정확도를 낮출 수도 있습니다

 

Approach

$\theta_1 = \theta_0 - \tilde{\eta} \nabla_{\theta_0} \ell(\tilde{\mathbf{x}}, \theta_0) $ : weight는 가짜 데이터 (synthetic data)로만 학습됩니다

$\tilde{\mathbf{x}}^*, \tilde{\eta}^* = \arg\min_{\tilde{\mathbf{x}}, \tilde{\eta}} \mathcal{L}(\tilde{\mathbf{x}}, \tilde{\eta}; \theta_0) = \arg\min_{\tilde{\mathbf{x}}, \tilde{\eta}} \ell (\mathbf{x}, \theta_1) = \arg\min_{\tilde{\mathbf{x}}, \tilde{\eta}} \ell(\mathbf{x}, \theta_0 - \tilde{\eta} \nabla_{\theta_0} \ell(\tilde{\mathbf{x}}, \theta_0)) $

 

$\arg\min_{\tilde{\mathbf{x}}, \tilde{\eta}} \ell (\mathbf{x}, \theta_1)$ : 실제 데이터 x의 loss가 줄어드는 방향으로 synthetic data와 learning rate를 선택합니다

 

계산

gradient를 계산하는 방법을 살펴봅시다

 

${\begin{align} 
\frac {\partial L}{\partial \tilde{x}} &=  \frac{\partial L}{\partial \theta_1} \frac{\partial \theta_1}{\partial \tilde{x}} \\
&= \frac{\partial L}{\partial \theta_1} (\frac {\partial}{\partial \tilde x}(\theta_0 - \tilde \eta \nabla \ell (\tilde x, \theta_0))) \\
&= \frac{\partial L}{\partial \theta_1}  (\frac{\partial \theta_0}{\partial \tilde x} - \tilde \eta \frac{\partial^2 \ell (\tilde x, \theta_0))}{\partial \tilde x \partial \theta_0}) \\
&=-\tilde \eta \frac{\partial L}{\partial \theta_1} \frac{\partial^2 \ell (\tilde x, \theta_0))}{\partial \tilde x \partial \theta_0}
\end{align} }$

 

approach에서 weight가 한번만 업데이트 되었지만, 실제로는 여러번 업데이트 됩니다

$\theta_t$ 상태에서 $\theta_T$가 되었다고 합시다 (T-t번 업데이트됨)

그리고 한번 업데이트할때 s만큼의 데이터를 쓴다고 합시다

 

${\theta_{s+1} = \theta_s - \tilde{\eta}_s \nabla_{\theta_s} \ell(\tilde{\mathbf{x}}, \theta_s)}$

미분하면

${\frac {\partial \theta_{s+1}}{\partial \theta_s} = I - \tilde \eta_s H_s}$

 

이를 반영하면

${\begin{align}
\frac {\partial L}{\partial \tilde x} &= \frac {\partial L}{\partial  \theta_T}\frac {\partial \theta_T}{\partial  \theta_{T-1}}\frac {\partial \theta_{T-1}}{\partial  \theta_{T-2}}...\frac {\partial \theta_{t+2}}{\partial  \theta_{t+1}} \frac {\partial \theta_{t+1}}{\partial  \tilde x} \\
&=\frac {\partial L}{\partial  \theta_T} \prod_{s=t+1}^{T}(I-\tilde\eta_sH_s) (-\tilde\eta_t \frac{\partial^2 \ell (\tilde x, \theta_t))}{\partial \tilde x \partial \theta_t})
\end{align}}$

 

하지만 실험과정에서 위와 같이 얻은 synthetic data를 다른 weight initialization에 사용했을때 문제가 있었다고 합니다.

그래서 여러개의 weight initialization에 대해서 synthetic data를 업데이트 했다고 합니다

 

N개의 weight initialization에 대해 synthetic data를 업데이트 했다고 합시다

논문에서 최종 loss는 아래와 같이 계산됩니다

 

${\begin{align}
\frac {\partial L}{\partial \tilde x} 
&=\sum^N_{i=1} \frac {\partial L}{\partial  \theta_T^{(i)}} \prod_{s=t+1}^{T}(I-\tilde\eta_s^{(i)}H_s^{(i)}) (-\tilde\eta_t^{(i)} \frac{\partial^2 \ell (\tilde x, \theta_t^{(i)}))}{\partial \tilde x \partial \theta_t^{(i)}}) \\
\end{align}}$

 

$\tilde \eta$도 학습되는 파라미터입니다

비슷한 방법으로 계산하면

${\frac {\partial L}{\partial \tilde \eta_t^*}=\frac {\partial L}{\partial  \theta_T} \prod_{s=t+1}^{T}(I-\tilde\eta_sH_s) (- \frac{\partial \ell (\tilde x, \theta_t))}{ \partial \theta_t})}$