인공지능/Knowledge Distillation

Knowledge Distillation from A Stronger Teacher

NickTop 2024. 7. 14. 17:52

Strong teacher?

- 성능이 좋은 모델 (파라미터가 큰 모델일수도 있고, data augmentation이나 epoch을 늘린 모델일수도있음)

 

 

이전에 진행되었던 연구 방향

Intermediate representation으로 성능향상

OFD, CRD

student모델과 teacher 모델의 차이가 크면 Student가 잘 학습이 이루어지 않는다

- TAKD, DGKD : teaching assistant가 있다

 

 

Resnet 크기 별 student 성능
Strong 모델과 KL Divergence 비교



모델의 크기가 커지면 오히려 student의 성능이 떨어지는 경우도 있다

왜냐하면, teacher 모델이 복잡해질수록 student가 teacher를 재현하기 힘들다

따라서, teacher의 정확한 값을 복원하기보다 예측값의 상대적인 순위를 student에 distillation 해주는 것만으로도 충분하다

 

DIST : Distillation from A Stronger Teacher

 

이러한 이유로, 본 논문에서는 KL Divergence를 쓰는 대신 Pearson correlation을 썼다

$d_p(u, v) := 1 − ρ_p(u, v)$

distance는 다음과 같이 정의한다

$ρ_p(u, v)$는 피어슨 상관계수이다.

피어슨 상관 계수는 두 Random Variable이 서로 비례하는 관계일 때 1이다

즉, teacher와 student의 Rank가 동일하다면 loss가 거의 없다

 

Batch한번 돌 때 inter-class loss는 다음과 같이 계산된다

$L_{inter} := \frac{1}{B} \sum^B_id_p(Y_{i,:}(s),Y_{i,:}(t))$

 

DIST 방법과 이전의 방법의 비교

 

또한, inter-class만 사용하는 기존의 KD와 달리, intra-class relation도 distillation 시켜준다

 

배치 한번 돌때마다 분포를 구한다

예를들어 cat, dog, plane이라는 class가 있다고 하자

cat클래스에는 cat의 score가 가장 높을 것이다. 그리고 plane의 score가 가장 낮을 것이다

cat > dog > plane 이라는 정보를 student에게 전달해준다

(intra-class를 쓰면 이 정보가 왜 전달된다는 건진 잘 모르겠음)

$L_{intra} := \frac{1}{B} \sum^C_jd_p(Y_{:,j}(s),Y_{:,j}(t))$

 

이를 종합하여 Loss는 다음과 같이 계산된다

L_{tr} = αL_{cls} + βL_{inter} + γL_{intra}

cls는 ground truth이다

 

Experiment

performance on DIST

 

Strong teacher 모델에서 KD를 어떻게 적용해야하는지 시사한다

이를바탕으로 적용가능한 다른 Loss 함수가 있는지 살펴볼수있을것같다.