不知道大家會不會有這種感覺,在使用現今的一些預訓練模型時,雖然好用,但是實際在場域部屬時總感覺殺雞焉用牛刀,實際使用下去後續又沒有時間讓你去優化它,只好將錯就錯反正能用的想法持續使用,現在有個不錯的方法讓你在一開始就可以用相對低廉的成本去優化這個模型,讓後續使用不再懊悔。
這個方法叫做 Distilling Knowledge ,中文可譯作知識蒸餾,這個方法的概念很簡單,我們如果將整個模型訓練過程當作是考試,這個模型就是學生,而訓練資料就是考題,模型(學生)要做的事情很簡單,就是拿到訓練資料(考題)後運算出一個結果,若與正解相似度愈高則愈高分,平常的學生依靠自己的本事答題,但若有一個學生,它有一個家教協助它統整考題,總結出所謂的必勝公式,那麼這個學生是不是會比沒有家教的學生答題準確性及達到及格標準的速度來得高?
利用這個想法,要做到知識蒸餾有三個步驟 :
我們先定義一個場景 : 我們需要實作一個產品線上的檢測系統用來檢測產品上面的記號點,其記號點總共有10種組合,此時你選擇使用VGG19預訓練權重加上調整最後輸出層為10類來解決,但是實際部屬時遇到了效能問題,檢測效率需要再提升,於是你決定使用知識蒸餾來解決,以下為示例 :
首先,我們需要定義 VGG19 模型並準備訓練數據。
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
# 修改 VGG19 模型輸出
class Vgg19(nn.Module):
def __init__(self):
super(Vgg19, self).__init__()
vgg19 = models.vgg19(weights=models.VGG19_Weights.DEFAULT)
self.features = vgg19.features
self.avgpool = vgg19.avgpool
self.classifier = vgg19.classifier
self.classifier[6] = nn.Linear(4096, 10) # 修改最後一層為10個類別
def forward(self, x):
features = self.features(x)
x = self.avgpool(features)
x = torch.flatten(x, 1)
out = self.classifier(x)
return features, out
teacher_model = Vgg19()
# 定義數據轉換和加載數據集
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)
其中我們額外輸出 feature 結果作為蒸餾目標,該層輸出張量大小如下,後續在設計濃縮網路時需要注意需要相同的張量大小方能計算損失
而整個模型參數量如下
我們假設VGG19的權重已經預訓練好並適應10個類別,如果需要,可以進一步微調。
# 訓練VGG19模型
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(teacher_model.classifier.parameters(), lr=0.001)
# 訓練循環
teacher_model.train()
for epoch in range(5): # 訓練5個epoch
running_loss = 0.0
for inputs, labels in train_loader:
optimizer.zero_grad()
_,outputs = teacher_model(inputs) #暫時還不需要 feature 輸出
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}")
我們構建一個較小的模型,如簡化的CNN模型。
class Student(nn.Module):
def __init__(self):
super(Student, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
)
self.avgpool = nn.AvgPool2d(7)
self.flat = nn.Flatten()
self.classifier = nn.Linear(512, 10)
def forward(self, x):
features = self.features(x)
x = self.avgpool(features)
x = self.flat(x)
out = self.classifier(x)
return features,out
student_model = Student()
整個模型參數量如下,跟VGG19相比可說是把參數量打到骨折(只是個示例,真正使用時請別一次打這麼多,可以持續嘗試準確度與參數量間的甜蜜點)
定義蒸餾損失,包括知識蒸餾損失和真實標籤損失。
def feature_distillation_loss(student_features, teacher_features, student_logits, labels, alpha=0.5):
feature_loss = nn.MSELoss()(student_features, teacher_features)
classification_loss = nn.CrossEntropyLoss()(student_logits, labels)
return feature_loss * alpha + classification_loss * (1.0 - alpha)
使用蒸餾損失來訓練小模型。
optimizer = optim.Adam(student_model.parameters(), lr=0.001)
teacher_model.eval() # Teacher模型設定為評估模式
student_model.train()
for epoch in range(5): # 訓練5個epoch
running_loss = 0.0
for inputs, labels in train_loader:
optimizer.zero_grad()
# 獲取Teacher模型的輸出
with torch.no_grad():
teacher_features, _ = teacher_model(inputs) #只需要 feature
# 獲取Student模型的輸出
student_features, student_logits = student_model(inputs)
# 計算損失
loss = feature_distillation_loss(student_features, teacher_features, student_logits, labels)
# 反向傳播和優化
loss.backward()
optimizer.step()
running_loss += loss.item()
print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}")
評估學生模型在測試集上的表現。
student_model.eval()
correct = 0
total = 0
with torch.no_grad():
for inputs, labels in test_loader:
_, outputs = student_model(inputs)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print(f'Accuracy of the student model on the test images: {100 * correct / total}%')
看了幾周的 AWS ,這次換換口味來分享下一些實際上會使用到的小技巧,在上一份工作中,我常常遇到這種適合使用知識蒸餾的案例,實際訓練上可以撰寫如 Hyperparameter Search 的方式去找尋合適的層數減少數及卷積核減少數等超參數,並且還可以 A 蒸餾出 B , B 蒸餾出 C 的方式去迭代,直接無痛替換現場的肥大模型,可以說是我在實務上使用最多的優化方式。