Wasserstein Distance Rivals Kullback-Leibler Divergence for Knowledge Distillation

Jiaming Lv*, Haoyuan Yang*, Peihua Li

Dalian University of Technology

*Equal contribution. Corresponding author: Peihua Li.

NeurIPS 2024 (Poster)

Code: https://github.com/JiamingLv/WKD | Paper: arXiv | Video: YouTube (in English), Bilibili (in Chinese)

Abstract

Since the pioneering work of Hinton et al., knowledge distillation based on Kullback-Leibler Divergence (KL-Div) has been predominant, and recently its variants have achieved compelling performance. However, KL-Div only compares probabilities of the corresponding category between the teacher and student while lacking a mechanism for cross-category comparison. Besides, KL-Div is problematic when applied to intermediate layers, as it cannot handle non-overlapping distributions and is unaware of the geometry of the underlying manifold. To address these downsides, we propose a methodology of Wasserstein Distance (WD) based knowledge distillation. Specifically, we propose a logit distillation method called WKD-L based on discrete WD, which performs cross-category comparison of probabilities and thus can explicitly leverage rich interrelations among categories. Moreover, we introduce a feature distillation method called WKD-F, which uses a parametric method for modeling feature distributions and adopts continuous WD for transferring knowledge from intermediate layers. Comprehensive evaluations on image classification and object detection show that (1) for logit distillation, WKD-L outperforms very strong KL-Div variants; (2) for feature distillation, WKD-F is superior to the KL-Div counterparts and state-of-the-art competitors.

Highlights

We propose a novel methodology of Wasserstein distance based knowledge distillation (WKD), extending beyond the classical Kullback-Leibler divergence based one pioneered by Hinton et al. Specifically,

  • We present a discrete WD based logit distillation method (WKD-L). It can leverage rich interrelations (IRs) among classes via cross-category comparisons between predicted probabilities of the teacher and student, overcoming the downside of category-to-category KL divergence (see the first sketch after this list).
  • We introduce continuous WD into intermediate layers for feature distillation (WKD-F). It can effectively leverage the geometric structure of the Riemannian space of Gaussians, unlike geometry-unaware KL divergence (see the second sketch after this list).
  • On both image classification and object detection tasks, WKD-L performs better than very strong KL-Div based logit distillation methods, while WKD-F is superior to the KL-Div counterparts and state-of-the-art feature distillation competitors.
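To make the cross-category comparison concrete, below is a minimal PyTorch sketch of a discrete-WD logit distillation loss. It uses entropy-regularized Sinkhorn iterations, and it builds the ground cost from cosine distances between class embeddings (e.g., the teacher's classifier weights). The cost construction, temperature, and regularization strength here are illustrative assumptions, not the exact WKD-L recipe.

```python
import torch
import torch.nn.functional as F

def sinkhorn_wd(p, q, cost, eps=0.1, n_iters=50):
    """Entropy-regularized discrete Wasserstein distance between batched
    categorical distributions p, q of shape (B, K), with a shared (K, K)
    ground-cost matrix. Returns the mean transport cost over the batch."""
    kernel = torch.exp(-cost / eps)                  # Gibbs kernel, (K, K)
    u = torch.ones_like(p)                           # scaling vectors, (B, K)
    v = torch.ones_like(q)
    for _ in range(n_iters):                         # Sinkhorn updates
        u = p / (v @ kernel.t() + 1e-9)              # enforce row marginals
        v = q / (u @ kernel + 1e-9)                  # enforce column marginals
    # Transport plan T = diag(u) * kernel * diag(v), shape (B, K, K)
    T = u.unsqueeze(2) * kernel.unsqueeze(0) * v.unsqueeze(1)
    return (T * cost.unsqueeze(0)).sum(dim=(1, 2)).mean()

def wkd_l_loss(student_logits, teacher_logits, class_emb, tau=4.0, eps=0.1):
    """Hypothetical WKD-L-style loss: discrete WD between softened teacher and
    student predictions, with a ground cost reflecting class interrelations."""
    p_t = F.softmax(teacher_logits / tau, dim=1)
    p_s = F.softmax(student_logits / tau, dim=1)
    emb = F.normalize(class_emb, dim=1)              # (K, D) class embeddings
    cost = 1.0 - emb @ emb.t()                       # cosine-distance ground cost
    return sinkhorn_wd(p_t, p_s, cost, eps=eps)
```

In training, such a term would be combined with the usual cross-entropy on ground-truth labels; gradients reach the student both through the softened predictions and through the unrolled Sinkhorn iterations.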
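For the feature side, the squared 2-Wasserstein distance between two Gaussians has a closed form, W_2^2 = ||m_s - m_t||^2 + Tr(Σ_s + Σ_t - 2(Σ_t^{1/2} Σ_s Σ_t^{1/2})^{1/2}). The sketch below fits Gaussians over the spatial positions of a (B, C, H, W) feature map and evaluates this formula; the estimator (per-batch statistics, the regularization constant, and the implicit assumption that student features are already projected to the teacher's channel dimension) is for illustration and may differ from the exact WKD-F implementation.

```python
import torch

def _sqrtm_psd(mat, eps=1e-6):
    """Matrix square root of a symmetric PSD matrix via eigendecomposition."""
    vals, vecs = torch.linalg.eigh(mat)
    return vecs @ torch.diag(vals.clamp_min(eps).sqrt()) @ vecs.t()

def feats_to_gaussian(feat, eps=1e-5):
    """Fit a Gaussian to a (B, C, H, W) feature map, treating every spatial
    position of every sample as one C-dimensional observation."""
    b, c, h, w = feat.shape
    x = feat.permute(0, 2, 3, 1).reshape(-1, c)              # (B*H*W, C)
    mean = x.mean(dim=0)
    xc = x - mean
    cov = xc.t() @ xc / (x.shape[0] - 1) + eps * torch.eye(c, device=feat.device)
    return mean, cov

def wkd_f_loss(student_feat, teacher_feat):
    """Hypothetical WKD-F-style loss: closed-form squared 2-Wasserstein distance
    between Gaussians fitted to student and teacher feature maps."""
    m_s, cov_s = feats_to_gaussian(student_feat)
    m_t, cov_t = feats_to_gaussian(teacher_feat)
    cov_t_half = _sqrtm_psd(cov_t)
    cross = _sqrtm_psd(cov_t_half @ cov_s @ cov_t_half)
    return (m_s - m_t).pow(2).sum() + torch.trace(cov_s + cov_t - 2.0 * cross)
```

Because the distance depends only on first- and second-order statistics, it remains well defined even when the two feature distributions barely overlap, which is exactly where KL divergence breaks down.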

Experiments

We evaluate WKD for image classification on ImageNet [41] and CIFAR-100 [42]. We also evaluate the effectiveness of WKD on self-knowledge distillation (Self-KD). Further, we extend WKD to object detection and conduct experiments on MS-COCO [43].

Image classification on ImageNet

Results (Acc, %) on ImageNet. In setting (a), the teacher (T) and student (S) are ResNet34 and ResNet18, respectively; in setting (b), they are ResNet50 and MobileNetV1.

Comparison (Top-1 Acc, %) on ImageNet between WKD and the competitors under different setups. Red numbers indicate that the teacher/student model has non-trivially higher performance than the commonly used ones specified in CRD [25]. We report the gains of the distilled student over the corresponding vanilla student.

Image classification on CIFAR-100


Following OFA [46], we evaluate WKD in settings where the teacher is a CNN and the student is a Transformer, or vice versa. We use CNN models including ResNet (RN) [9], MobileNetV2 (MNV2) [58], and ConvNeXt [59], as well as vision transformers including ViT [60], DeiT [61], and Swin Transformer [62].


Self-knowledge distillation on ImageNet


We implement our WKD in the framework of Born-Again Network (BAN) [66] for self-knowledge distillation (Self-KD). We conduct experiments with ResNet18 on ImageNet.


Object detection on MS-COCO

We extend WKD to object detection in the framework of Faster R-CNN [47]. For WKD-L, we use the classification branch in the detection head for logit distillation. For WKD-F, we transfer knowledge from the features directly fed into the classification branch, i.e., the features output by the RoIAlign layer.
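As an illustration of where the feature term attaches in detection, the framework-agnostic sketch below assumes we already have, for matched proposals, the student's and teacher's RoIAlign-pooled features; the tensor names and the diagonal-Gaussian simplification of the Wasserstein term are assumptions for brevity, not the paper's exact setup.

```python
import torch

def roi_wkd_f_loss(roi_feat_s, roi_feat_t, eps=1e-5):
    """Hypothetical WKD-F term on RoIAlign outputs of shape (N, C, H, W):
    per-RoI squared 2-Wasserstein distance between diagonal Gaussians,
    ||m_s - m_t||^2 + ||sigma_s - sigma_t||^2, averaged over the N proposals.
    The WKD-L term would attach to the classification-branch logits of the
    same proposals, exactly as in the classification sketch above."""
    xs = roi_feat_s.flatten(2)                   # (N, C, H*W)
    xt = roi_feat_t.flatten(2)
    m_s, m_t = xs.mean(dim=2), xt.mean(dim=2)    # per-RoI channel means
    s_s = (xs.var(dim=2, unbiased=False) + eps).sqrt()   # per-RoI channel stds
    s_t = (xt.var(dim=2, unbiased=False) + eps).sqrt()
    return ((m_s - m_t).pow(2) + (s_s - s_t).pow(2)).sum(dim=1).mean()
```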

Citation


@inproceedings{WKD-NeurIPS2024,
  title={Wasserstein Distance Rivals Kullback-Leibler Divergence for Knowledge Distillation},
  author={Jiaming Lv and Haoyuan Yang and Peihua Li},
  booktitle={Advances in Neural Information Processing Systems},
  year={2024}
}