[논문리뷰] Distilling the Knowledge in a Neural Network
https://arxiv.org/abs/1503.02531
Distilling the Knowledge in a Neural Network
A very simple way to improve the performance of almost any machine learning algorithm is to train many different models on the same data and then to average their predictions. Unfortunately, making predictions using a whole ensemble of models is cumbersome
arxiv.org
Abstract
모델의 성능을 향상시키는 가장 간단한 방법은 여러 모델을 학습한 후 예측을 평균하여 성능을 향상시키는 앙상블 기법이다.
하지만, 앙상블 모델을 실제로 배포하기에는 과도한 계산 비용이 발생하는데 특히 대규모 신경망을 여러 개 사용할 경우는 더욱 심하다.
본 논문은 앙상블 모델을 single model로 지식을 압축시켜 배포 가능한 효율적인 모델 제작을 제안한다.
Distillation을 통해서 더 작은 모델이 앙상블 모델의 성능을 그대로 유지하도록 한다.
Introduction
머신러닝에서는 학습 단계와 배포 단계의 요구 사항이 다르다.
- 학습 단계 : 큰 데이터셋에서 복잡한 구조를 학습하는 데 많은 계산 자원을 사용할 수 있음
- 배포 단계 : 실시간 처리와 계산 자원 최적화가 중요함
대규모 데이터에서 더 나은 결과를 얻기 위해 복잡한 모델이나 앙상블 모델을 학습할 수 있다. 이후에는 이 복잡한 모델을 작고 배포에 적합한 모델로 변환해야 하는데 이를 지식 증류(Knowledge Distillation)이라고 부른다.
복잡한 모델의 학습 & 일반화
이때 큰 모델에서 학습된 지식이 단순히 매개변수(파라미터) 값에만 국한되지 않고, 더 추상적인 개념으로 봐야한다.
즉, 모델이 학습한 내용을 입력 데이터와 출력 데이터 간의 매핑 관계로 이해해야 한다.
학습된 지식 = 매개변수 값이라고 생각하면, 모델의 형태를 바꾸더라도 지식을 유지하는 것이 어렵다고 느껴진다.
그러나 지식을 더 추상적인 개념(입력과 출력 간의 관계)로 이해하면 된다.
복잡한 모델은 단순히 정답만을 맞히는 것이 아니라 잘못된 답변에도 어느 정도의 확률을 할당한다.
-> BMW를 보고 쓰레기 트럭으로 잘못 인식할 확률은 아주 작음. 하지만 당근으로 잘못 인식할 확률보다는 훠씬 높음.
이런 오답 간의 상대적 확률 차이가 모델이 데이터를 어떻게 일반화하는지 보여준다.
즉 복잡한 모델은 단순히 정답을 맞추는 것 외에도, 잘못된 답변에도 확률을 부여하여 잘못된 답변들 간의 확률 차이를 통해 데이터를 일반화한다.
작은 모델로 지식 전이
그래서 복잡한 모델이 잘 일반화하는 방식을 작은 모델로 전이하는 것이 핵심이다.
소프트 타겟은 복잡한 모델이 각 클래스에 대해 예측한 확률 분포를 의미한다. 예를 들어, 이미지가 특정 클래스로 분류될 확률을 모델이 예측하면, 이 확률 값을 사용하여 작은 모델을 학습시킬 수 있다.
복잡한 모델에서 얻은 이 확률 분포(소프트 타겟)를 사용해 작은 모델을 학습한다. 복잡한 모델이 여러 개의 더 간단한 모델로 구성된 앙상블(ensemble)일 경우, 각 모델이 예측한 확률을 산술 평균 또는 기하 평균으로 계산하여 소프트 타겟을 만든다.
엔트로피가 높은 소프트 타겟은 즉, 각 클래스에 대한 확률 분포가 넓게 퍼져 있을 때, 학습에 더 많은 정보를 제공한다.
-> 작은 모델이 정답 뿐 아니라 잘못된 클래스에 대한 확률 차이도 배우도록 하여, 데이터를 더 잘 일반화하는 데 도움
소프트 타겟을 사용하면 작은 모델은 복잡한 모델보다 훨씬 적은 데이터로도 학습할 수 있으며, 높은 학습률을 사용할 수 있다.
로짓(logit)을 이용한 학습
MNIST와 같은 데이터셋에서는 복잡한 모델이 매우 높은 신뢰도로 정답을 예측한다.
로짓(logit), 즉 소프트맥스 함수 입력 값을 이용해 작은 모델이 복잡한 모델의 예측을 모방하도록 학습할 수 있다.
예를 들어 '2'라는 숫자를 매우 정확하게 '2'로 예측하는데 이 모델이 예측하는 다른 클래스의 확률('2'가 '3'일 가능성)도 아주 작은 값으로 나오며, 이 값들도 데이터 간의 유사성 구조를 담고 있다.
문제는 매우 작은 확률값들이 학습 과정에서 거의 영향을 미치지 않는다는 것이다. 작은 모델이 소프트 타겟을 학습할 때, 작은 확률의 차이는 교차 엔트로피 비용 함수에 크게 반영되지 않아, 학습이 제대로 이루어지지 않을 수 있다.
따라서 softmax로 계산된 확률값 대신, softmax에 들어가기 전의 로짓(logit) 값을 타겟으로 사용하는 방법을 제안했다.
로짓은 softmax 변환을 거치기 전의 모델 출력 값으로, 이는 원래 확률보다 더 큰 값들을 가진다.
Knowledge Distiilation은 소프트 타겟을 생성할 때 softmax의 온도 파라미터를 높이는 방법이다. 온도 조정을 통해 복잡한 모데르이 예측값을 더 소프트하게 만들고, 이 값을 사용해 작은 모델을 학습한다.
온도를 높이게 되면 예측하는 확률 분포가 더 부드럽게 퍼지게 된다.
전이 세트
작은 모델을 학습하는 데 사용하는 데이터는 레이블이 없는 데이터일 수도 있고, 원래의 훈련 데이터일 수도 있다.
연구에 따르면, 원래의 훈련 세트를 사용하는 것이 더 효과적이며, 특히 정답을 맞추는 항목을 추가하면 작은 모델이 더 좋은 성능을 낼 수 있다.
Distillation

신경망은 일방적으로 softmax 출력층을 사용해 각 클래스에 대해 계산된 로짓 값을 확률로 변환한다. 이는 각 클래스에 대해 계산된 로짓 $z_i$를 다른 로짓과 비교하여 클래스 확률 $q_i$로 변환하는 방식이다.
여기서 $T$는 온도(temerature)이며, 크게 설정하면 클래스들 간의 확률 분포가 더 부드럽게 변한다.
증류에서 가장 간단한 형태는 큰 모델이 예측한 소프트 타겟을 사용해서 작은 모델을 훈련시키는 것이다.
전이 세트에 대해 높은 온도에서 생성된 소프트 타겟을 사용하여 작은 모델을 훈련한다.
정답 레이블이 이미 알려져 있다면, 작은 모델을 훈련할 때 정답 레이블을 함께 사용하여 성능을 향상 수 있다.
두 개의 목표 함수를 사용하여 가중 평균을 계산하면 된다.
1. 소프트 타겟에 대한 교차 엔트로피 (큰 모델에서 생성된 소프트 타겟과 동일한 높은 온도를 사용)
2. 정답 레이블에 대한 교차 엔트로피
실험 결과, 두 번째 목표 함수의 가중치를 상대적으로 낮게 설정하는 것이 가장 좋은 결과를 도출한다는 점을 발견
Matching logits is a special case of distillation

로짓 : 신경망의 마지막 레이어에서 나온 값으로, softmax 함수를 거쳐 각 클래스에 대한 확률로 변환되는데 이 확률을 만들기 전에 나온 값
소프트 타겟: 복잡한 모델이 출력한 확률 분포로, 각 클래스에 대해 모델이 얼마나 확신하고 있는지를 보여줌. 소프트 타겟은 일반적으로 "딱 맞는" 정답 (하드 타겟)보다 더 많은 정보를 제공함
복잡한 모델이 로짓 $v_i$를 생성하여 소프트 타겟 확률 $p_i$를 만든다고 할 때, 전이 훈련이 온도 $T$에서 이루어진다.
온도가 로짓의 크기에 비해 높다면 다음과 같이 근사할 수 있다.

각 전이 사례에 대해 로짓이 0을 중심으로 맞춰졌다고 가정한다면 다음과 같이 간단해진다.
$$\sum_{j} z_j = \sum_{j} v_j = 0$$

온도가 높으면, 로짓의 차이를 최소화하는 방식으로 학습하게 되며, 이는 증류 과정에서 로짓을 직접 맞추는 것과 같다.
온도가 낮을 때는 작은 음수 로짓에 대한 중요도가 낮아져서 모델이 이러한 작은 값들을 크게 신경 쓰지 않는다. 그러나, 작은 음수 로짓들도 유용한 정보를 포함할 수 있기에 중간 정도의 온도가 최적의 성능을 제공한다.
즉, 높은 온도는 로짓 간의 차이를 최소화하여 모델을 학습시키는 데 집중하게 하고, 낮은 온도는 노이즈가 많은 음수 로짓에 덜 집중하게 하여 더 유용한 정보를 포착하도록 한다.
Preliminary experiments on MNIST
두 개의 숨겨진 레이어와 각 레이어에 1200개의 Rectified Linear 유닛을 가진 대형 신경망을 훈련한다.
드롭아웃과 가중치 제약을 사용해 강력한 정규화를 적용한다.
성능 비교
- 대형 네트워크는 67개의 테스트 오류를 달성
- 작은 네트워크(800개의 유닛)는 정규화 없이 146개의 테스트 오류를 달성
- 작은 네트워크가 대형 네트워크의 소프트 타겟을 온도 20에서 일치시키는 추가 작업을 통해 74개의 테스트 오류를 달성
- 이는 소프트 타겟이 일반화에 대한 지식을 포함하여, 모델에 많은 정보를 전이할 수 있음을 보여줌
온도 설정
- 증류된 네트워크의 유닛 수가 300 이상일 때, 8 이상의 온도에서 결과가 유사
- 유닛 수를 30으로 줄였을 때, 2.5에서 4 범위의 온도가 더 좋은 성능을 발휘