개요
Knowledge distillation에서 logit을 Standardization하는 것이 더 효과가 좋다는 것을 보여준다
위 그림을 보았을 때, 실제로는 오른쪽 student가 더 예측을 잘했음에도 불구하고 오른쪽 student의 KL-divergence가 더 높습니다. standardization을 시키면 그 문제가 해결됩니다
저자는 probability function로 Softmax를 써야한다고 말한 후 수식을 통해 기존의 문제점을 지적하고 standardization을 쓰면 개선된다고 얘기합니다
Derivation of softmax in Classification
softmax는 엔트로피를 높이는 probability function 중 최적의 해이다
probability function을 q라고 하고 logit을 v라고 하자
다음과 같은 목적식이 있습니다 (왜 엔트로피가 최대가 되는 것이 좋은지는 잘 모르겠습니다)
$\underset{q}{max}L_1 = - \sum_{n=1}^{N} \sum_{k=1}^{K} q(v_{n})^{(k)} \log q(v_{n})^{(k)}$
(N은 배치의 크기, K는 레이블의 개수)
두 가지의 constraint를 가집니다
1. 확률을 다 더하면 1
$\sum_{k=1}^{K}q(v_{n})^{(k)} = 1 \quad \forall n$
2. Target class(y_n)로 분포를 맞춘다
$\sum_{k=1}^K v_{n}^{(k)} \, q(v_{n})^{(k)} = v_{n}^{(k)} \quad \forall n$
라그랑주 함수로 최적의 해를 찾아봅시다
Lagrangian multipliers $\alpha$를 적용합시다
$L = - \sum_{n=1}^{N} \sum_{k=1}^{K} q(v_{n})^{(k)} \log q(v_{n})^{(k)} + \sum_{n=1}^N \alpha_{1,n}(\sum_{k=1}^{K}q(v_{n})^{(k)} -1)+\sum_{n=1}^{N} \alpha_{2,n}(\sum_{k=1}^K v_{n}^{(k)} \, q(v_{n})^{(k)} - v_{n}^{(y_n)})$
$L = \sum_{n=1}^{N} \sum_{k=1}^{K}(-q(v_{n})^{(k)} \log q(v_{n})^{(k)} + \alpha_{1,n}q(v_{n})^{(k)} + \alpha_{2,n} v_{n}^{(k)} q(v_{n})^{(k)} ) +C_{constant}$
$ q(v_{n})^{(k)} $로 미분했을 때 0이어야 합니다
($q(v_{n,k})$로 미분한다는 개념이 완전히 이해되지는 않네요)
$-1-\log q(v_{n})^{(k)} + \alpha_{1,n} + \alpha_{2,n} v_{n}^{(k)} = 0$
$\Rightarrow q(v_{n})^{(k)} = exp(A*v_{n}^{(k)})/Z_T$
위 식과 제약식의 조건으로 인해 q는 softmax가 됩니다
Derivation of softmax in KD
v를 teacher의 logit, z를 student의 logit이라고 합시다
$\underset{q}{max}L_2 = - \sum_{n=1}^{N} \sum_{k=1}^{K} q(z_{n})^{(k)} \log q(z_{n})^{(k)}$
다음과 같은 constraints를 가집니다
1. 확률을 다 더하면 1
$\sum_{k=1}^{K}q(z_{n})^{(k)} = 1 \quad \forall n$
2. Target class(y_n)로 분포를 맞춘다
$\sum_{k=1}^K z_{n}^{(k)} \, q(z_{n})^{(k)} = z_{n}^{(k)} \quad \forall n$
3. Student가 teacher의 분포를 따르도록 한다 (식 자체가 잘 와닿지는 않지만 분포를 따른다고 단순히 q(z)=q(v)를 성립하도록 하는 것이 아닌 느슨하게 따르도록 하는 것 같습니다)
$\sum_{k=1}^K z_{n}^{(k)} \, q(z_{n})^{(k)} =\sum_{k=1}^K z_{n}^{(k)} \, q(v_{n})^{(k)}$
동일하게 ragrange 식을 적용합니다
$L_S = L_2 + \sum_{n=1}^{N} \beta_{1,n} \left( \sum_{k=1}^{K} q(z_n)^{(k)} - 1 \right) + \sum_{n=1}^{N} \beta_{2,n} \left( \sum_{k=1}^{K} z_n^{(k)} q(Z_n)^{(k)} - z_{(y_n)} \right) \sum_{n=1}^{N} \beta_{3,n} z_n^{(k)} \left( q(z_n)^{(k)} - q(v_n)^{(k)} \right).$
$\frac{\partial L_S}{\partial q(z_n)^{(k)}} = -1 - \log q(z_n)^{(k)} + \beta_{1,n} + \beta_{2,n} z_n^{(k)} + \beta_{3,n} z_n^{(k)}$
정리하면
$q(z_n)^{(k)} = \exp\left(C * z_n^{(k)}\right)/Z_s$
Classification / KD 모두 softmax를 사용할 수 있다는 결론입니다
두 가지 더 의미가 있습니다
1. Teacher와 student의 temperature를 다르게 해도 된다 (상수 A,C를 다르게 해도 됨)
2. Sample(batch) 마다 다르게 temperature를 정해도 된다
Drawbacks of Shared Temperatures
Shared Temperature의 문제점을 알기 위해 softmax를 다음과 같이 정의합시다. 기존의 softmax에 shift가 추가되었습니다
$q\left(z_n; a_S, b_S\right)^{(k)} = \frac{\exp\left[\left(z_n^{(k)} - a_S\right)/b_S\right]}{\sum_{m=1}^{K} \exp\left[\left(z_n^{(m)} - a_S\right)/b_S\right]}$
잘 학습된 student는 teacher의 분포와 일치해야 합니다
$q\left(z_n; a_S, b_S\right)^{(k)} = q\left(v_n; a_T, b_T\right)^{(k)}$
이 식으로 부터 임의의 label i,j에 대해 다음과 같은 등식이 성립합니다
$\frac{\exp\left[\left(z_n^{(i)} - a_S\right)/b_S\right]}{\exp\left[\left(z_n^{(j)} - a_S\right)/b_S\right]} = \frac{\exp\left[\left(v_n^{(i)} - a_S\right)/b_S\right]}{\exp\left[\left(v_n^{(j)} - a_S\right)/b_S\right]}$
$\Rightarrow \frac{z_n^{(i)} - z_n^{(j)}}{b_S} = \frac{v_n^{(i)} - v_n^{(j)}}{b_T}$
j를 1부터 K까지 더해주면
$\frac{z_n^{(i)} - \bar{z}_n}{b_S} = \frac{v_n^{(i)} - \bar{v}_n}{b_T}$
이로부터 두가지 사실을 얻을 수 있습니다
1. Logit shift를 고려해야함
temperature가 teacher/student 모두 1이라면 $z^{(i)}_n = v^{(i)}_n + ∆n$
(logit shift는 충분히 있을 수 있다고 보는데 식 $q\left(z_n; a_S, b_S\right)^{(k)}$가 유도된 식이 아니라 임의로 정한 식같은데 이로부터 결론을 짓는 것이 맞는지 모르겠습니다, 논문원본: From Eq.9, it can be found that a constant shift exists)
2. Variance가 일치해야함
$\frac{\sigma(z_n)^2}{\sigma(v_n)^2} = \frac{\frac{1}{K} \sum_{i=1}^{K} \left( z_n^{(i)} - \bar{z}_n \right)^2}{\frac{1}{K} \sum_{i=1}^{K} \left( v_n^{(i)} - \bar{v}_n \right)^2} = \frac{b_S^2}{b_T^2}$
$\sigma(z_n) =\sigma(v_n)$가 되도록 temperature를 조절해야 한다
4.3. Logit Standardization
위 수식에 대한 결론으로,
teacher와 student를 standardization 시켜서 문제를 해결합니다
logit standardization외에 적용한 다른 건 없습니다
Experiment
다른 기법들에 logit standardization을 해줬을 때 성능에 개선을 보였습니다
'인공지능 > Knowledge Distillation' 카테고리의 다른 글
Improving Knowledge Distillation via Regularizing Feature Norm and Direction (0) | 2024.08.16 |
---|---|
Revisit the Power of Vanilla Knowledge Distillation: from Small Scale to Large Scale (0) | 2024.07.31 |
Knowledge Distillation from A Stronger Teacher (0) | 2024.07.14 |
Self Distillation - Be Your Own Teacher (0) | 2024.04.12 |