AI 小撇步-Distilling Knowledge

更新於 發佈於 閱讀時間約 14 分鐘

一.引言

  不知道大家會不會有這種感覺,在使用現今的一些預訓練模型時,雖然好用,但是實際在場域部屬時總感覺殺雞焉用牛刀,實際使用下去後續又沒有時間讓你去優化它,只好將錯就錯反正能用的想法持續使用,現在有個不錯的方法讓你在一開始就可以用相對低廉的成本去優化這個模型,讓後續使用不再懊悔。

二.Distilling Knowledge ?

  這個方法叫做 Distilling Knowledge ,中文可譯作知識蒸餾,這個方法的概念很簡單,我們如果將整個模型訓練過程當作是考試,這個模型就是學生,而訓練資料就是考題,模型(學生)要做的事情很簡單,就是拿到訓練資料(考題)後運算出一個結果,若與正解相似度愈高則愈高分,平常的學生依靠自己的本事答題,但若有一個學生,它有一個家教協助它統整考題,總結出所謂的必勝公式,那麼這個學生是不是會比沒有家教的學生答題準確性及達到及格標準的速度來得高?

利用這個想法,要做到知識蒸餾有三個步驟 :

  • 確定蒸餾目標(家教) : 根據你的需求找尋合適的預訓練模型,並在上面測試你的訓練資料是否結果於和預期,若不符合則需要 finetune 微調
  • 設計濃縮模型(學生) : 根據你的場景設計適當的模型,建議上可以以蒸餾目標的架構為基礎進行刪減(例如 Transformer 架構模型蒸餾後依然還是 Transformer 架構,不能指望單純的架設幾層卷積層就能模仿注意力機制的效果)
  • 設計蒸餾流程 : 選擇特徵層並定義合適的損失函數並加入到濃縮模型的訓練流程中

三.示例

  我們先定義一個場景 : 我們需要實作一個產品線上的檢測系統用來檢測產品上面的記號點,其記號點總共有10種組合,此時你選擇使用VGG19預訓練權重加上調整最後輸出層為10類來解決,但是實際部屬時遇到了效能問題,檢測效率需要再提升,於是你決定使用知識蒸餾來解決,以下為示例 :

步驟一:準備數據和Teacher模型

首先,我們需要定義 VGG19 模型並準備訓練數據。

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models

# 修改 VGG19 模型輸出
class Vgg19(nn.Module):
def __init__(self):
super(Vgg19, self).__init__()
vgg19 = models.vgg19(weights=models.VGG19_Weights.DEFAULT)
self.features = vgg19.features
self.avgpool = vgg19.avgpool
self.classifier = vgg19.classifier
self.classifier[6] = nn.Linear(4096, 10) # 修改最後一層為10個類別

def forward(self, x):
features = self.features(x)
x = self.avgpool(features)
x = torch.flatten(x, 1)
out = self.classifier(x)
return features, out

teacher_model = Vgg19()

# 定義數據轉換和加載數據集
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)

test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)

其中我們額外輸出 feature 結果作為蒸餾目標,該層輸出張量大小如下,後續在設計濃縮網路時需要注意需要相同的張量大小方能計算損失

raw-image

而整個模型參數量如下

raw-image

步驟二:訓練Teacher模型

我們假設VGG19的權重已經預訓練好並適應10個類別,如果需要,可以進一步微調。

# 訓練VGG19模型
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(teacher_model.classifier.parameters(), lr=0.001)

# 訓練循環
teacher_model.train()
for epoch in range(5): # 訓練5個epoch
running_loss = 0.0
for inputs, labels in train_loader:
optimizer.zero_grad()
_,outputs = teacher_model(inputs) #暫時還不需要 feature 輸出
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}")

步驟三:構建Student模型

我們構建一個較小的模型,如簡化的CNN模型。

class Student(nn.Module):
def __init__(self):
super(Student, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
)
self.avgpool = nn.AvgPool2d(7)
self.flat = nn.Flatten()
self.classifier = nn.Linear(512, 10)

def forward(self, x):
features = self.features(x)
x = self.avgpool(features)
x = self.flat(x)
out = self.classifier(x)
return features,out

student_model = Student()

整個模型參數量如下,跟VGG19相比可說是把參數量打到骨折(只是個示例,真正使用時請別一次打這麼多,可以持續嘗試準確度與參數量間的甜蜜點)

raw-image

步驟四:定義蒸餾損失函數

定義蒸餾損失,包括知識蒸餾損失和真實標籤損失。

def feature_distillation_loss(student_features, teacher_features, student_logits, labels, alpha=0.5):
feature_loss = nn.MSELoss()(student_features, teacher_features)
classification_loss = nn.CrossEntropyLoss()(student_logits, labels)
return feature_loss * alpha + classification_loss * (1.0 - alpha)

步驟五:訓練Student模型

使用蒸餾損失來訓練小模型。

optimizer = optim.Adam(student_model.parameters(), lr=0.001)
teacher_model.eval() # Teacher模型設定為評估模式
student_model.train()

for epoch in range(5): # 訓練5個epoch
running_loss = 0.0
for inputs, labels in train_loader:
optimizer.zero_grad()

# 獲取Teacher模型的輸出
with torch.no_grad():
teacher_features, _ = teacher_model(inputs) #只需要 feature

# 獲取Student模型的輸出
student_features, student_logits = student_model(inputs)

# 計算損失
loss = feature_distillation_loss(student_features, teacher_features, student_logits, labels)

# 反向傳播和優化
loss.backward()
optimizer.step()

running_loss += loss.item()
print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}")

步驟六:評估Student模型

評估學生模型在測試集上的表現。

student_model.eval()
correct = 0
total = 0
with torch.no_grad():
for inputs, labels in test_loader:
_, outputs = student_model(inputs)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()

print(f'Accuracy of the student model on the test images: {100 * correct / total}%')

四.結語

  看了幾周的 AWS ,這次換換口味來分享下一些實際上會使用到的小技巧,在上一份工作中,我常常遇到這種適合使用知識蒸餾的案例,實際訓練上可以撰寫如 Hyperparameter Search 的方式去找尋合適的層數減少數及卷積核減少數等超參數,並且還可以 A 蒸餾出 B , B 蒸餾出 C 的方式去迭代,直接無痛替換現場的肥大模型,可以說是我在實務上使用最多的優化方式。

留言
avatar-img
留言分享你的想法!
avatar-img
貓貓學習筆記
10會員
21內容數
AI、電腦視覺、圖像處理、AWS等等持續學習時的學習筆記,也包含一些心得,主要是幫助自己學習,若能同時幫助到不小心來到這裡的人,那也是好事一件 : )
貓貓學習筆記的其他內容
2024/07/08
我們前面幾篇已經講完TTS技術的一大半架構了,知道了如何將聲學特徵重建回音訊波形,也從中可以知道要是聲學特徵不完善,最終取得的結果也會不自然,剩下要探討該如何將文字轉換成聲學特徵,且能夠自然地表現停頓及細節變化,讓我們開始吧。
Thumbnail
2024/07/08
我們前面幾篇已經講完TTS技術的一大半架構了,知道了如何將聲學特徵重建回音訊波形,也從中可以知道要是聲學特徵不完善,最終取得的結果也會不自然,剩下要探討該如何將文字轉換成聲學特徵,且能夠自然地表現停頓及細節變化,讓我們開始吧。
Thumbnail
2024/06/26
距離上篇已經快過一個月了,這個月我也沒閒著,我FF14生產職拉了不少等級進行了上篇 WaveNet 的後續調試,也比較與其他人實現的效果,又發現了幾個實作上可能造成困難的點,現在就跟各位分享一下~
Thumbnail
2024/06/26
距離上篇已經快過一個月了,這個月我也沒閒著,我FF14生產職拉了不少等級進行了上篇 WaveNet 的後續調試,也比較與其他人實現的效果,又發現了幾個實作上可能造成困難的點,現在就跟各位分享一下~
Thumbnail
2024/06/01
WaveNet 提供了一個先進的架構用於音訊重建,但是,有必要嗎? Mel 頻譜本身就是經過數學轉換而獲得的結果,不能反運算嗎 ? 到底 WaveNet 在其中扮演了甚麼腳色 ?它是如何運作的 ? 讓我們在這篇好好探討下去。
Thumbnail
2024/06/01
WaveNet 提供了一個先進的架構用於音訊重建,但是,有必要嗎? Mel 頻譜本身就是經過數學轉換而獲得的結果,不能反運算嗎 ? 到底 WaveNet 在其中扮演了甚麼腳色 ?它是如何運作的 ? 讓我們在這篇好好探討下去。
Thumbnail
看更多
你可能也想看
Thumbnail
每年4月、5月都是最多稅要繳的月份,當然大部份的人都是有機會繳到「綜合所得稅」,只是相當相當多人還不知道,原來繳給政府的稅!可以透過一些有活動的銀行信用卡或電子支付來繳,從繳費中賺一點點小確幸!就是賺個1%~2%大家也是很開心的,因為你們把沒回饋變成有回饋,就是用卡的最高境界 所得稅線上申報
Thumbnail
每年4月、5月都是最多稅要繳的月份,當然大部份的人都是有機會繳到「綜合所得稅」,只是相當相當多人還不知道,原來繳給政府的稅!可以透過一些有活動的銀行信用卡或電子支付來繳,從繳費中賺一點點小確幸!就是賺個1%~2%大家也是很開心的,因為你們把沒回饋變成有回饋,就是用卡的最高境界 所得稅線上申報
Thumbnail
全球科技產業的焦點,AKA 全村的希望 NVIDIA,於五月底正式發布了他們在今年 2025 第一季的財報 (輝達內部財務年度為 2026 Q1,實際日曆期間為今年二到四月),交出了打敗了市場預期的成績單。然而,在銷售持續高速成長的同時,川普政府加大對於中國的晶片管制......
Thumbnail
全球科技產業的焦點,AKA 全村的希望 NVIDIA,於五月底正式發布了他們在今年 2025 第一季的財報 (輝達內部財務年度為 2026 Q1,實際日曆期間為今年二到四月),交出了打敗了市場預期的成績單。然而,在銷售持續高速成長的同時,川普政府加大對於中國的晶片管制......
Thumbnail
重點摘要: 6 月繼續維持基準利率不變,強調維持高利率主因為關稅 點陣圖表現略為鷹派,收斂 2026、2027 年降息預期 SEP 連續 2 季下修 GDP、上修通膨預測值 --- 1.繼續維持利率不變,強調需要維持高利率是因為關稅: 聯準會 (Fed) 召開 6 月利率會議
Thumbnail
重點摘要: 6 月繼續維持基準利率不變,強調維持高利率主因為關稅 點陣圖表現略為鷹派,收斂 2026、2027 年降息預期 SEP 連續 2 季下修 GDP、上修通膨預測值 --- 1.繼續維持利率不變,強調需要維持高利率是因為關稅: 聯準會 (Fed) 召開 6 月利率會議
Thumbnail
我最近在網上學到了一個非常實用的方法,可以快速了解一個行業。這個方法來自麥肯錫的工作方法,搭配ChatGPT使用非常高效。只要你學會了,就能輕鬆掌握任何行業的基礎知識。 麥肯錫的方法論 第一步:總結關鍵詞
Thumbnail
我最近在網上學到了一個非常實用的方法,可以快速了解一個行業。這個方法來自麥肯錫的工作方法,搭配ChatGPT使用非常高效。只要你學會了,就能輕鬆掌握任何行業的基礎知識。 麥肯錫的方法論 第一步:總結關鍵詞
Thumbnail
因為 AI 領域的技術不斷地迭代更新,無法避免的是需要一直去追新的技術 並且需要在一個有限的時間學會,或是實作應用導入到專案之中。 那我覺得在學習新技術可通過以下步驟: 1. 找一個讀得懂的教學資源 現在網路上的教學資源很多,或是書籍的資源也很豐富, 同時也有像是 chatgpt 的 AI
Thumbnail
因為 AI 領域的技術不斷地迭代更新,無法避免的是需要一直去追新的技術 並且需要在一個有限的時間學會,或是實作應用導入到專案之中。 那我覺得在學習新技術可通過以下步驟: 1. 找一個讀得懂的教學資源 現在網路上的教學資源很多,或是書籍的資源也很豐富, 同時也有像是 chatgpt 的 AI
Thumbnail
AI繪圖要廣泛用於商用還有一大段路,還需要依賴人類的經驗判斷、調整,為什麼呢?
Thumbnail
AI繪圖要廣泛用於商用還有一大段路,還需要依賴人類的經驗判斷、調整,為什麼呢?
Thumbnail
AI 工具雖能在短時間生成內容,但它不瞭解你的客戶,也無法取代你做現場互動交流。在合適的時機選擇使用適合的 AI 工具,幫助我們專注於最重要的人事物上。
Thumbnail
AI 工具雖能在短時間生成內容,但它不瞭解你的客戶,也無法取代你做現場互動交流。在合適的時機選擇使用適合的 AI 工具,幫助我們專注於最重要的人事物上。
Thumbnail
最新的AI趨勢讓人眼花撩亂,不知要如何開始學習?本文介紹了作者對AI的使用和體驗,以及各類AI工具以及推薦的選擇。最後強調了AI是一個很好用的工具,可以幫助人們節省時間並提高效率。鼓勵人們保持好奇心,不停止學習,並提出了對健康生活和開心生活的祝福。
Thumbnail
最新的AI趨勢讓人眼花撩亂,不知要如何開始學習?本文介紹了作者對AI的使用和體驗,以及各類AI工具以及推薦的選擇。最後強調了AI是一個很好用的工具,可以幫助人們節省時間並提高效率。鼓勵人們保持好奇心,不停止學習,並提出了對健康生活和開心生活的祝福。
Thumbnail
為了充分發揮AI的潛力,我們必須深入瞭解其運作模式和思考邏輯,並學會與AI對話的技巧。《ChatGPT提問課,做個懂AI的高效工作者》這本書提供了豐富的實例,讓讀者更容易學會如何提出精準的問題,並享有提問課程的閱讀回饋。這對於想成為懂AI的高效工作者的人來說,是一本值得一看的書。
Thumbnail
為了充分發揮AI的潛力,我們必須深入瞭解其運作模式和思考邏輯,並學會與AI對話的技巧。《ChatGPT提問課,做個懂AI的高效工作者》這本書提供了豐富的實例,讓讀者更容易學會如何提出精準的問題,並享有提問課程的閱讀回饋。這對於想成為懂AI的高效工作者的人來說,是一本值得一看的書。
Thumbnail
不知道大家會不會有這種感覺,在使用現今的一些預訓練模型時,雖然好用,但是實際在場域部屬時總感覺殺雞焉用牛刀,實際使用下去後續又沒有時間讓你去優化它,只好將錯就錯反正能用的想法持續使用,現在有個不錯的方法讓你在一開始就可以用相對低廉的成本去優化這個模型,讓後續使用不再懊悔。
Thumbnail
不知道大家會不會有這種感覺,在使用現今的一些預訓練模型時,雖然好用,但是實際在場域部屬時總感覺殺雞焉用牛刀,實際使用下去後續又沒有時間讓你去優化它,只好將錯就錯反正能用的想法持續使用,現在有個不錯的方法讓你在一開始就可以用相對低廉的成本去優化這個模型,讓後續使用不再懊悔。
Thumbnail
筆記-曲博談AI模型.群聯-24.05.05 https://www.youtube.com/watch?v=JHE88hwx4b0&t=2034s *大型語言模型 三個步驟: 1.預訓練,訓練一次要用幾萬顆處理器、訓練時間要1個月,ChatGPT訓練一次的成本為1000萬美金。 2.微調(
Thumbnail
筆記-曲博談AI模型.群聯-24.05.05 https://www.youtube.com/watch?v=JHE88hwx4b0&t=2034s *大型語言模型 三個步驟: 1.預訓練,訓練一次要用幾萬顆處理器、訓練時間要1個月,ChatGPT訓練一次的成本為1000萬美金。 2.微調(
Thumbnail
這篇文章介紹瞭如何利用生成式AI(GenAI)來提高學習效率,包括文章重點整理、完善知識體系、客製化學習回饋、提供多元觀點等方法。同時提醒使用者應注意內容的信效度,保持學術誠信,適當運用GenAI能大幅提升工作效率。
Thumbnail
這篇文章介紹瞭如何利用生成式AI(GenAI)來提高學習效率,包括文章重點整理、完善知識體系、客製化學習回饋、提供多元觀點等方法。同時提醒使用者應注意內容的信效度,保持學術誠信,適當運用GenAI能大幅提升工作效率。
追蹤感興趣的內容從 Google News 追蹤更多 vocus 的最新精選內容追蹤 Google News