本文介绍了知识蒸馏(Knowledge Distillation)
技术,这是一种将大型、计算成本高昂的模型的知识转移到小型模型上的方法,从而在不损失有效性的情况下实现在计算能力较低的硬件上部署,使得评估过程更快、更高效。
https://pytorch.org/tutorials/beginner/knowledge_distillation_tutorial.html
在本教程中,我们将进行一系列实验,专注于通过使用更强大的网络作为教师来提高轻量级神经网络的准确性。通过本教程,你将学习到:
- 如何修改模型类以提取隐藏层表示,并将其用于进一步计算。
- 如何修改PyTorch中的常规训练循环,以包括额外的损失函数,例如在分类上的交叉熵之外。
- 如何通过使用更复杂的模型作为教师来提高轻量级模型的性能。
步骤 1:读取环境
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
# Check if GPU is available, and if not, use the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
步骤 2:加载数据集
CIFAR-10 是一个非常流行的图像数据集,它包含了十个类别,包含 60,000 张 32x32 像素的彩色图像,分为 50,000 张训练图像和 10,000 张测试图像,每个类别有 6,000 张图像。
# Below we are preprocessing data for CIFAR-10. We use an arbitrary batch size of 128.
transforms_cifar = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Loading the CIFAR-10 dataset:
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms_cifar)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms_cifar)
#Dataloaders
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)
输入图像是RGB格式的,这意味着它们具有3个通道,并且每个通道的尺寸是32x32像素。基本上,每张图像由3072个数值(3个通道乘以32像素宽乘以32像素高)组成,这些数值的范围是从0到255。
步骤3:定义Teacher 和 Student模型
我们使用两种不同的架构,为了确保实验之间的公平比较,我们在实验中保持滤波器的数量固定。这两种架构都是卷积神经网络(CNN),它们具有不同数量的卷积层作为特征提取器。
# Deeper neural network class to be used as teacher:
class DeepNN(nn.Module):
def __init__(self, num_classes=10):
super(DeepNN, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 128, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(128, 64, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(64, 64, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(64, 32, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
)
self.classifier = nn.Sequential(
nn.Linear(2048, 512),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(512, num_classes)
)
def forward(self, x):
x = self.features(x)
x = torch.flatten(x, 1)
x = self.classifier(x)
return x
# Lightweight neural network class to be used as student:
class LightNN(nn.Module):
def __init__(self, num_classes=10):
super(LightNN, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 16, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(16, 16, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
)
self.classifier = nn.Sequential(
nn.Linear(1024, 256),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(256, num_classes)
)
def forward(self, x):
x = self.features(x)
x = torch.flatten(x, 1)
x = self.classifier(x)
return x
步骤4:训练模型
通过正常的交叉熵来训练模型,并对比了两个模型的精度:
- Teacher parameters: 1,186,986
- Student parameters: 267,738
def train(model, train_loader, epochs, learning_rate, device):
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
model.train()
for epoch in range(epochs):
running_loss = 0.0
for inputs, labels in train_loader:
# inputs: A collection of batch_size images
# labels: A vector of dimensionality batch_size with integers denoting class of each image
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(inputs)
# outputs: Output of the network for the collection of images. A tensor of dimensionality batch_size x num_classes
# labels: The actual labels of the images. Vector of dimensionality batch_size
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")
def test(model, test_loader, device):
model.to(device)
model.eval()
correct = 0
total = 0
with torch.no_grad():
for inputs, labels in test_loader:
inputs, labels = inputs.to(device), labels.to(device)
outputs = model(inputs)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
accuracy = 100 * correct / total
print(f"Test Accuracy: {accuracy:.2f}%")
return accuracy
torch.manual_seed(42)
nn_deep = DeepNN(num_classes=10).to(device)
train(nn_deep, train_loader, epochs=10, learning_rate=0.001, device=device)
test_accuracy_deep = test(nn_deep, test_loader, device)
# Instantiate the lightweight network:
torch.manual_seed(42)
nn_light = LightNN(num_classes=10).to(device)
train(nn_light, train_loader, epochs=10, learning_rate=0.001, device=device)
test_accuracy_deep = test(nn_light, test_loader, device)
Epoch 1/10, Loss: 1.3248219371146863
Epoch 2/10, Loss: 0.8667521514855993
Epoch 3/10, Loss: 0.6845618891136726
Epoch 4/10, Loss: 0.5354324993117691
Epoch 5/10, Loss: 0.41447551155944007
Epoch 6/10, Loss: 0.30590455001577394
Epoch 7/10, Loss: 0.2205375938121315
Epoch 8/10, Loss: 0.17421267175918345
Epoch 9/10, Loss: 0.13511126418419353
Epoch 10/10, Loss: 0.12150332882828876
Test Accuracy: 74.84%
步骤5:蒸馏模型(交叉熵蒸馏)
知识蒸馏(Knowledge Distillation, KD)是一种提高学生网络测试准确率的技术,其核心思想是利用教师网络的输出概率分布来辅助学生网络的训练。由于两个网络都输出相同类别的概率分布,因此它们具有相同数量的输出神经元。知识蒸馏的方法是在传统的交叉熵损失中加入一个额外的损失项,这个损失项基于教师网络的softmax输出。
利用软目标中小概率的比率可以帮助实现深度神经网络的潜在目标,即在数据上创建一个相似性结构,使得相似的对象被映射得更接近。
- T(温度参数):控制输出分布的平滑度。较大的T会导致更平滑的分布,从而使得较小的概率得到更大的提升。
- soft_target_loss_weight(软目标损失权重):分配给即将包括的额外目标的权重。
- ce_loss_weight(交叉熵损失权重):分配给交叉熵的权重。调整这些权重可以使网络优化任一目标。
步骤6:蒸馏模型(余弦损失)
在进行余弦损失(Cosine Loss)最小化的过程中,我们的目标是使学生网络的隐藏层状态更接近教师网络的隐藏层状态。这与知识蒸馏的目标相似,但这次我们关注的是网络内部的表示,而不仅仅是输出层。
以下是实现这一目标的步骤和考虑因素:
温度参数(Temperature Parameter):控制softmax函数的平滑度。调整温度参数可以影响损失函数的敏感度。
损失系数(Loss Coefficients):控制不同损失项在总损失中的权重。
余弦嵌入损失(CosineEmbeddingLoss):这是一种衡量两个向量之间余弦相似度的损失函数。它的公式如下:
$$L = 1 - \text{cosine\_similarity}(u, v) $$其中,( u ) 和 ( v ) 是两个向量,余弦相似度计算为:
$$\text{cosine_similarity}(u, v) = \frac{u \cdot v}{\|u\| \|v\|} $$匹配维度(Matching Dimensions):由于教师网络和学生网络在卷积层后的维度可能不同,我们需要通过某种方式(如平均池化层)来调整教师网络的维度,使其与学生网络相匹配。
修改模型类(Modifying Model Classes):我们需要修改现有的模型类或创建新的模型类,以确保前向传播函数不仅返回网络的logits,还返回卷积层后的扁平化隐藏表示。
平均池化(Average Pooling):在教师网络的卷积层后添加平均池化层,以减少其维度,使其与学生网络匹配。
步骤6:蒸馏模型(中间回归)
在进行中间回归器(intermediate regressor)的运行时,我们的目标是通过引入一个额外的可训练网络(即回归器)来改善教师和学生网络特征图的匹配过程。这种方法比简单的余弦损失最小化更为复杂和高级,因为它允许更灵活地调整特征表示以适应两个网络之间的差异。
- 提取特征图(Extract Feature Maps):首先,从教师网络和学生网络的卷积层提取特征图。这些特征图是在分类器之前的最后一层卷积层的输出。
- 中间回归器(Intermediate Regressor):引入一个额外的网络,其目的是将学生网络的特征图转换为与教师网络的特征图具有相同维度的表示。这个回归器是可训练的,并且可以学习如何最好地转换特征表示。
- 损失函数(Loss Function):定义一个损失函数来衡量学生网络特征图和经过回归器转换后的特征表示与教师网络特征图之间的差异。这个损失函数可以是欧几里得距离、余弦相似度或其他适合度量差异的函数。