AI 小撇步-Distilling Knowledge

閱讀時間約 14 分鐘

一.引言

  不知道大家會不會有這種感覺,在使用現今的一些預訓練模型時,雖然好用,但是實際在場域部屬時總感覺殺雞焉用牛刀,實際使用下去後續又沒有時間讓你去優化它,只好將錯就錯反正能用的想法持續使用,現在有個不錯的方法讓你在一開始就可以用相對低廉的成本去優化這個模型,讓後續使用不再懊悔。

二.Distilling Knowledge ?

  這個方法叫做 Distilling Knowledge ,中文可譯作知識蒸餾,這個方法的概念很簡單,我們如果將整個模型訓練過程當作是考試,這個模型就是學生,而訓練資料就是考題,模型(學生)要做的事情很簡單,就是拿到訓練資料(考題)後運算出一個結果,若與正解相似度愈高則愈高分,平常的學生依靠自己的本事答題,但若有一個學生,它有一個家教協助它統整考題,總結出所謂的必勝公式,那麼這個學生是不是會比沒有家教的學生答題準確性及達到及格標準的速度來得高?

利用這個想法,要做到知識蒸餾有三個步驟 :

  • 確定蒸餾目標(家教) : 根據你的需求找尋合適的預訓練模型,並在上面測試你的訓練資料是否結果於和預期,若不符合則需要 finetune 微調
  • 設計濃縮模型(學生) : 根據你的場景設計適當的模型,建議上可以以蒸餾目標的架構為基礎進行刪減(例如 Transformer 架構模型蒸餾後依然還是 Transformer 架構,不能指望單純的架設幾層卷積層就能模仿注意力機制的效果)
  • 設計蒸餾流程 : 選擇特徵層並定義合適的損失函數並加入到濃縮模型的訓練流程中

三.示例

  我們先定義一個場景 : 我們需要實作一個產品線上的檢測系統用來檢測產品上面的記號點,其記號點總共有10種組合,此時你選擇使用VGG19預訓練權重加上調整最後輸出層為10類來解決,但是實際部屬時遇到了效能問題,檢測效率需要再提升,於是你決定使用知識蒸餾來解決,以下為示例 :

步驟一:準備數據和Teacher模型

首先,我們需要定義 VGG19 模型並準備訓練數據。

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models

# 修改 VGG19 模型輸出
class Vgg19(nn.Module):
def __init__(self):
super(Vgg19, self).__init__()
vgg19 = models.vgg19(weights=models.VGG19_Weights.DEFAULT)
self.features = vgg19.features
self.avgpool = vgg19.avgpool
self.classifier = vgg19.classifier
self.classifier[6] = nn.Linear(4096, 10) # 修改最後一層為10個類別

def forward(self, x):
features = self.features(x)
x = self.avgpool(features)
x = torch.flatten(x, 1)
out = self.classifier(x)
return features, out

teacher_model = Vgg19()

# 定義數據轉換和加載數據集
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)

test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)

其中我們額外輸出 feature 結果作為蒸餾目標,該層輸出張量大小如下,後續在設計濃縮網路時需要注意需要相同的張量大小方能計算損失

raw-image

而整個模型參數量如下

raw-image

步驟二:訓練Teacher模型

我們假設VGG19的權重已經預訓練好並適應10個類別,如果需要,可以進一步微調。

# 訓練VGG19模型
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(teacher_model.classifier.parameters(), lr=0.001)

# 訓練循環
teacher_model.train()
for epoch in range(5): # 訓練5個epoch
running_loss = 0.0
for inputs, labels in train_loader:
optimizer.zero_grad()
_,outputs = teacher_model(inputs) #暫時還不需要 feature 輸出
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}")

步驟三:構建Student模型

我們構建一個較小的模型,如簡化的CNN模型。

class Student(nn.Module):
def __init__(self):
super(Student, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
)
self.avgpool = nn.AvgPool2d(7)
self.flat = nn.Flatten()
self.classifier = nn.Linear(512, 10)

def forward(self, x):
features = self.features(x)
x = self.avgpool(features)
x = self.flat(x)
out = self.classifier(x)
return features,out

student_model = Student()

整個模型參數量如下,跟VGG19相比可說是把參數量打到骨折(只是個示例,真正使用時請別一次打這麼多,可以持續嘗試準確度與參數量間的甜蜜點)

raw-image

步驟四:定義蒸餾損失函數

定義蒸餾損失,包括知識蒸餾損失和真實標籤損失。

def feature_distillation_loss(student_features, teacher_features, student_logits, labels, alpha=0.5):
feature_loss = nn.MSELoss()(student_features, teacher_features)
classification_loss = nn.CrossEntropyLoss()(student_logits, labels)
return feature_loss * alpha + classification_loss * (1.0 - alpha)

步驟五:訓練Student模型

使用蒸餾損失來訓練小模型。

optimizer = optim.Adam(student_model.parameters(), lr=0.001)
teacher_model.eval() # Teacher模型設定為評估模式
student_model.train()

for epoch in range(5): # 訓練5個epoch
running_loss = 0.0
for inputs, labels in train_loader:
optimizer.zero_grad()

# 獲取Teacher模型的輸出
with torch.no_grad():
teacher_features, _ = teacher_model(inputs) #只需要 feature

# 獲取Student模型的輸出
student_features, student_logits = student_model(inputs)

# 計算損失
loss = feature_distillation_loss(student_features, teacher_features, student_logits, labels)

# 反向傳播和優化
loss.backward()
optimizer.step()

running_loss += loss.item()
print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}")

步驟六:評估Student模型

評估學生模型在測試集上的表現。

student_model.eval()
correct = 0
total = 0
with torch.no_grad():
for inputs, labels in test_loader:
_, outputs = student_model(inputs)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()

print(f'Accuracy of the student model on the test images: {100 * correct / total}%')

四.結語

  看了幾周的 AWS ,這次換換口味來分享下一些實際上會使用到的小技巧,在上一份工作中,我常常遇到這種適合使用知識蒸餾的案例,實際訓練上可以撰寫如 Hyperparameter Search 的方式去找尋合適的層數減少數及卷積核減少數等超參數,並且還可以 A 蒸餾出 B , B 蒸餾出 C 的方式去迭代,直接無痛替換現場的肥大模型,可以說是我在實務上使用最多的優化方式。

7會員
19內容數
AI、電腦視覺、圖像處理、AWS等等持續學習時的學習筆記,也包含一些心得,主要是幫助自己學習,若能同時幫助到不小心來到這裡的人,那也是好事一件 : )
留言0
查看全部
發表第一個留言支持創作者!
你可能也想看
【AI小說煉金術01】卡片筆記Scrintal+Claude Pro讓你輕鬆完本! 這篇文章介紹瞭如何利用AI工具和卡片盒筆記來提高小說創作效率,其中介紹了Claude Pro的強大文本創作能力,Scrintal的整理思緒和提綱挈領功能,以及如何利用小說模板和AI工具搭建故事架構。另外還提到如何選擇不同風格的故事來進行創作。最後作者分享了自己的實際寫作經驗以及對AI寫作工具的看法。
Thumbnail
avatar
萊丘 LaichuTV
2024-03-24
AI小眾市場的無人商店也開始表態了 站上10年線127元才開始規劃大方向 目前只是在110~127區間 題材都已經講過風報比機會,就是持續操作 真正規劃站穩10年線才開始 零售模式一定會慢慢改變的,前7年各研發人員和公司產品都只是在測試這個市場 經過2022~2023年後,缺工跟AI的方向更明確也更成熟未來就是會發酵這個
Thumbnail
avatar
金融狡兔
2023-12-14
AI小眾商機緩慢往前#振樺電 還沒攻擊到10年線,真正規劃站穩10年線才開始 零售模式一定會慢慢改變的,前7年各研發人員和公司產品都只是在測試這個市場 經過2022~2023年後,缺工跟AI的方向更明確也更成熟 未來就是會發酵這個題材,在市場就會有溫度可以判斷估值的共鳴
Thumbnail
avatar
金融狡兔
2023-11-30
[AI小學堂(5)]大型語言模型LLM是怎麼一回事? ChatGPT背後的技術 大型語言模型 是否與我們前面介紹的神經網路相同呢? 答案是不同的,這也是我們想要進一步探討了解的課題。今天會先解釋什麼是語言模型,想要做到的是哪些事情。
Thumbnail
avatar
技術PM路易斯
2023-08-27
[AI小學堂(四)]神經網路是什麼? 神經網路怎麼學習? 超簡化講解反向傳播演算法Back Propagation在我們的上一篇文章,我們把神經網路的架構用簡化再簡化的方式來說明,本篇文章我們會說明神經網路怎麼透過很多輸入資料來調整神經網路裡面的權重跟誤差值,藉由得到接近完美個權重跟誤差值,來做到學習的效果
Thumbnail
avatar
技術PM路易斯
2023-06-11
[AI小學堂(三)]神經網路是什麼? 淺談深度學習的神經網路Neural Network的架構-續上篇文章我們解說到了神經網路的基本架構包含了輸入層,輸出層,還有中間的隱藏層,也說明了這是一個把輸入資料拆解出特徵然後依照特徵做判斷的過程。究竟每一層的神經網路,如何影響下一層的神經網路可以辨識出特徵呢? 這些中間的線條(連結)到底是什麼意義呢? 這就是這一篇要告訴你的。
Thumbnail
avatar
技術PM路易斯
2023-05-30
[AI小學堂(二)]神經網路是什麼? 淺談深度學習的神經網路Neural Network的架構在我的上一篇文章中,我們提到了人工智慧 & 機器學習 & 深度學習跟神經網路的關係,我們也了解到了所謂的深度學習是一種基於神經網路上的機器學習方法。那麼神經網路到底是什麼呢? 我們上一篇文章裡面提到的神經網路的層Layer究竟是什麼呢? 到底為什麼神經網路需要這麼多的神經元(Neurons)跟層數呢
Thumbnail
avatar
技術PM路易斯
2023-05-30
[AI小學堂(一)]人工智慧AI vs 機器學習 vs 深度學習我們這個系列就是希望以非常科普的角度來解釋人工智慧。本篇要釐清人工智慧(AI: Artificial Intelligence),機器學習 Machine Learning, 深度學習Deep Learning,另外還有類神經網路,到底互相是什麼關係呢?
Thumbnail
avatar
技術PM路易斯
2023-05-27
AI 小白入門 (一)認識 AI前言 政府最愛說今年是 XX 元年,十年前喊出 AI 會改變世界後,2023 還真的是 AI 元年,不需贅述大家應該被 AI 各種新噱頭花樣洗了好幾輪了吧! 就像蒸氣機發明時,有人興奮覺得可以從勞役中解放人類,有人覺得會造成社會動盪,他們都對,問題是你要站在那邊?
avatar
Scott Hsiao
2023-05-14
AI小說_空中英雄_GPT-4 (下)李傑對父親說了一聲再見,然後關閉了通訊器。他先將戰機飛到了客機的上方,距離客機只有幾米,緊接著打開了戰機的雷達,掃描了客艙的內部情況。
Thumbnail
avatar
都說
2023-04-14