AI 小撇步-Distilling Knowledge

更新於 發佈於 閱讀時間約 14 分鐘

一.引言

  不知道大家會不會有這種感覺,在使用現今的一些預訓練模型時,雖然好用,但是實際在場域部屬時總感覺殺雞焉用牛刀,實際使用下去後續又沒有時間讓你去優化它,只好將錯就錯反正能用的想法持續使用,現在有個不錯的方法讓你在一開始就可以用相對低廉的成本去優化這個模型,讓後續使用不再懊悔。

二.Distilling Knowledge ?

  這個方法叫做 Distilling Knowledge ,中文可譯作知識蒸餾,這個方法的概念很簡單,我們如果將整個模型訓練過程當作是考試,這個模型就是學生,而訓練資料就是考題,模型(學生)要做的事情很簡單,就是拿到訓練資料(考題)後運算出一個結果,若與正解相似度愈高則愈高分,平常的學生依靠自己的本事答題,但若有一個學生,它有一個家教協助它統整考題,總結出所謂的必勝公式,那麼這個學生是不是會比沒有家教的學生答題準確性及達到及格標準的速度來得高?

利用這個想法,要做到知識蒸餾有三個步驟 :

  • 確定蒸餾目標(家教) : 根據你的需求找尋合適的預訓練模型,並在上面測試你的訓練資料是否結果於和預期,若不符合則需要 finetune 微調
  • 設計濃縮模型(學生) : 根據你的場景設計適當的模型,建議上可以以蒸餾目標的架構為基礎進行刪減(例如 Transformer 架構模型蒸餾後依然還是 Transformer 架構,不能指望單純的架設幾層卷積層就能模仿注意力機制的效果)
  • 設計蒸餾流程 : 選擇特徵層並定義合適的損失函數並加入到濃縮模型的訓練流程中

三.示例

  我們先定義一個場景 : 我們需要實作一個產品線上的檢測系統用來檢測產品上面的記號點,其記號點總共有10種組合,此時你選擇使用VGG19預訓練權重加上調整最後輸出層為10類來解決,但是實際部屬時遇到了效能問題,檢測效率需要再提升,於是你決定使用知識蒸餾來解決,以下為示例 :

步驟一:準備數據和Teacher模型

首先,我們需要定義 VGG19 模型並準備訓練數據。

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models

# 修改 VGG19 模型輸出
class Vgg19(nn.Module):
def __init__(self):
super(Vgg19, self).__init__()
vgg19 = models.vgg19(weights=models.VGG19_Weights.DEFAULT)
self.features = vgg19.features
self.avgpool = vgg19.avgpool
self.classifier = vgg19.classifier
self.classifier[6] = nn.Linear(4096, 10) # 修改最後一層為10個類別

def forward(self, x):
features = self.features(x)
x = self.avgpool(features)
x = torch.flatten(x, 1)
out = self.classifier(x)
return features, out

teacher_model = Vgg19()

# 定義數據轉換和加載數據集
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)

test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)

其中我們額外輸出 feature 結果作為蒸餾目標,該層輸出張量大小如下,後續在設計濃縮網路時需要注意需要相同的張量大小方能計算損失

raw-image

而整個模型參數量如下

raw-image

步驟二:訓練Teacher模型

我們假設VGG19的權重已經預訓練好並適應10個類別,如果需要,可以進一步微調。

# 訓練VGG19模型
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(teacher_model.classifier.parameters(), lr=0.001)

# 訓練循環
teacher_model.train()
for epoch in range(5): # 訓練5個epoch
running_loss = 0.0
for inputs, labels in train_loader:
optimizer.zero_grad()
_,outputs = teacher_model(inputs) #暫時還不需要 feature 輸出
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}")

步驟三:構建Student模型

我們構建一個較小的模型,如簡化的CNN模型。

class Student(nn.Module):
def __init__(self):
super(Student, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
)
self.avgpool = nn.AvgPool2d(7)
self.flat = nn.Flatten()
self.classifier = nn.Linear(512, 10)

def forward(self, x):
features = self.features(x)
x = self.avgpool(features)
x = self.flat(x)
out = self.classifier(x)
return features,out

student_model = Student()

整個模型參數量如下,跟VGG19相比可說是把參數量打到骨折(只是個示例,真正使用時請別一次打這麼多,可以持續嘗試準確度與參數量間的甜蜜點)

raw-image

步驟四:定義蒸餾損失函數

定義蒸餾損失,包括知識蒸餾損失和真實標籤損失。

def feature_distillation_loss(student_features, teacher_features, student_logits, labels, alpha=0.5):
feature_loss = nn.MSELoss()(student_features, teacher_features)
classification_loss = nn.CrossEntropyLoss()(student_logits, labels)
return feature_loss * alpha + classification_loss * (1.0 - alpha)

步驟五:訓練Student模型

使用蒸餾損失來訓練小模型。

optimizer = optim.Adam(student_model.parameters(), lr=0.001)
teacher_model.eval() # Teacher模型設定為評估模式
student_model.train()

for epoch in range(5): # 訓練5個epoch
running_loss = 0.0
for inputs, labels in train_loader:
optimizer.zero_grad()

# 獲取Teacher模型的輸出
with torch.no_grad():
teacher_features, _ = teacher_model(inputs) #只需要 feature

# 獲取Student模型的輸出
student_features, student_logits = student_model(inputs)

# 計算損失
loss = feature_distillation_loss(student_features, teacher_features, student_logits, labels)

# 反向傳播和優化
loss.backward()
optimizer.step()

running_loss += loss.item()
print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}")

步驟六:評估Student模型

評估學生模型在測試集上的表現。

student_model.eval()
correct = 0
total = 0
with torch.no_grad():
for inputs, labels in test_loader:
_, outputs = student_model(inputs)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()

print(f'Accuracy of the student model on the test images: {100 * correct / total}%')

四.結語

  看了幾周的 AWS ,這次換換口味來分享下一些實際上會使用到的小技巧,在上一份工作中,我常常遇到這種適合使用知識蒸餾的案例,實際訓練上可以撰寫如 Hyperparameter Search 的方式去找尋合適的層數減少數及卷積核減少數等超參數,並且還可以 A 蒸餾出 B , B 蒸餾出 C 的方式去迭代,直接無痛替換現場的肥大模型,可以說是我在實務上使用最多的優化方式。

avatar-img
8會員
21內容數
AI、電腦視覺、圖像處理、AWS等等持續學習時的學習筆記,也包含一些心得,主要是幫助自己學習,若能同時幫助到不小心來到這裡的人,那也是好事一件 : )
留言0
查看全部
avatar-img
發表第一個留言支持創作者!
貓貓學習筆記 的其他內容
  經過三篇的進展,我們目前實作的網路已經能做到同時訓練多種風格,且後續可以直接進行轉換,不用重新訓練,但是這種方法畢竟還是受到了預訓練的風格制約,無法跳脫出來,那麼有什麼辦法能夠讓他對於沒學過的風格也有一定的反應能力呢?
上篇我們已經把風格融入在一個網路之中,實現了訓練一次就可以轉換不同的圖片成我們訓練的風格,但是這樣還不夠,因為這樣每個風格都得訓練一個網路來轉換,太浪費了,那麼,我們有沒有辦法在同一個網路中訓練多個風格呢?
在第一篇我講到一開始的圖像風格轉換,每產生一張圖片都得重新訓練,這對於使用上難免綁手綁腳,所以理所當然的下一步就是要解決這個問題,看看能不能只要訓練一次,就可以重複使用。
  最近遇到一些人想做音訊的合成,我回答他或許可以從圖像風格轉換中找到些靈感,我才突然想起我對於這部分的認知只止於知道他能做什麼及結果大概如何,對於內部訓練邏輯及結構並沒有認真的去了解,現在剛好趁此機會好好的學習一下。
  經過三篇的進展,我們目前實作的網路已經能做到同時訓練多種風格,且後續可以直接進行轉換,不用重新訓練,但是這種方法畢竟還是受到了預訓練的風格制約,無法跳脫出來,那麼有什麼辦法能夠讓他對於沒學過的風格也有一定的反應能力呢?
上篇我們已經把風格融入在一個網路之中,實現了訓練一次就可以轉換不同的圖片成我們訓練的風格,但是這樣還不夠,因為這樣每個風格都得訓練一個網路來轉換,太浪費了,那麼,我們有沒有辦法在同一個網路中訓練多個風格呢?
在第一篇我講到一開始的圖像風格轉換,每產生一張圖片都得重新訓練,這對於使用上難免綁手綁腳,所以理所當然的下一步就是要解決這個問題,看看能不能只要訓練一次,就可以重複使用。
  最近遇到一些人想做音訊的合成,我回答他或許可以從圖像風格轉換中找到些靈感,我才突然想起我對於這部分的認知只止於知道他能做什麼及結果大概如何,對於內部訓練邏輯及結構並沒有認真的去了解,現在剛好趁此機會好好的學習一下。
你可能也想看
Google News 追蹤
Thumbnail
隨著理財資訊的普及,越來越多台灣人不再將資產侷限於台股,而是將視野拓展到國際市場。特別是美國市場,其豐富的理財選擇,讓不少人開始思考將資金配置於海外市場的可能性。 然而,要參與美國市場並不只是盲目跟隨標的這麼簡單,而是需要策略和方式,尤其對新手而言,除了選股以外還會遇到語言、開戶流程、Ap
Thumbnail
嘿,大家新年快樂~ 新年大家都在做什麼呢? 跨年夜的我趕工製作某個外包設計案,在工作告一段落時趕上倒數。 然後和兩個小孩過了一個忙亂的元旦。在深夜時刻,看到朋友傳來的解籤網站,興致勃勃熬夜體驗了一下,覺得非常好玩,或許有人玩過了,但還是想寫上來分享紀錄一下~
我想要一天分享一點「LLM從底層堆疊的技術」,並且每篇文章長度控制在三分鐘以內,讓大家不會壓力太大,但是又能夠每天成長一點。 在某些情況下,別人提供的 Pretrained Transformer Model 效果不盡人意,可能會想要自己做 Pretrained Model,但是這會耗費大量運
Thumbnail
我最近在網上學到了一個非常實用的方法,可以快速了解一個行業。這個方法來自麥肯錫的工作方法,搭配ChatGPT使用非常高效。只要你學會了,就能輕鬆掌握任何行業的基礎知識。 麥肯錫的方法論 第一步:總結關鍵詞
Thumbnail
前言 讀了許多理論,是時候實際動手做做看了,以下是我的模型訓練初體驗,有點糟就是了XD。 正文 def conv(filters, kernel_size, strides=1): return Conv2D(filters, kernel_size,
1. 在任何事情裡嘗試使用AI來幫忙: - 「你應該在你做的任何事情裡嘗試使用AI來幫忙。」隨著你的實驗,你會發現AI的幫忙可能是滿意,可能是很鳥,可能很垃圾,也可能令你很不安。這樣的過程不僅是利用AI來幫自己的忙,更是讓自己熟悉AI的能力,讓你自己更加瞭解AI能如何協助你,或者威脅你,或者取代你
Thumbnail
最近在嘗試使用不同的AI生圖方式混合出圖的方式,採用A平台的優點,並用B平台後製的手法截長補短,創造出自己更想要的小說場景,效果不錯,現在以這張圖為例,來講一下我的製作步驟。
「你應該在你做的任何事情裡嘗試使用AI來幫忙。」 「隨著你的實驗,你會發現AI的幫忙可能是滿意,可能是很鳥, 可能很垃圾,也可能令你很不安。」「由於AI是“通用科技 (General Purpose Technology)”, 並不會有一本書能幫助你了解它全部的價值,以及他全部的限制。」
Thumbnail
最新的AI趨勢讓人眼花撩亂,不知要如何開始學習?本文介紹了作者對AI的使用和體驗,以及各類AI工具以及推薦的選擇。最後強調了AI是一個很好用的工具,可以幫助人們節省時間並提高效率。鼓勵人們保持好奇心,不停止學習,並提出了對健康生活和開心生活的祝福。
Thumbnail
作者用常見的生活模式為底,分享所見所聞,提供自己對於這些情況的建議。詞彙的使用跟實力培養絕對是需要時間的,也非常需要靠別人的作品來當作自己的養分來源,多閱讀別人的書籍、文字、而且是有意識的學習,然後再加上實踐,相信慢慢就能寫出吸引人的文字、甚至發展出自己的風格。
Thumbnail
延續上週提到的,「有哪些不訓練模型的情況下,能夠強化語言模型的能力」,這堂課接續介紹其中第 3、4 個方法
Thumbnail
隨著理財資訊的普及,越來越多台灣人不再將資產侷限於台股,而是將視野拓展到國際市場。特別是美國市場,其豐富的理財選擇,讓不少人開始思考將資金配置於海外市場的可能性。 然而,要參與美國市場並不只是盲目跟隨標的這麼簡單,而是需要策略和方式,尤其對新手而言,除了選股以外還會遇到語言、開戶流程、Ap
Thumbnail
嘿,大家新年快樂~ 新年大家都在做什麼呢? 跨年夜的我趕工製作某個外包設計案,在工作告一段落時趕上倒數。 然後和兩個小孩過了一個忙亂的元旦。在深夜時刻,看到朋友傳來的解籤網站,興致勃勃熬夜體驗了一下,覺得非常好玩,或許有人玩過了,但還是想寫上來分享紀錄一下~
我想要一天分享一點「LLM從底層堆疊的技術」,並且每篇文章長度控制在三分鐘以內,讓大家不會壓力太大,但是又能夠每天成長一點。 在某些情況下,別人提供的 Pretrained Transformer Model 效果不盡人意,可能會想要自己做 Pretrained Model,但是這會耗費大量運
Thumbnail
我最近在網上學到了一個非常實用的方法,可以快速了解一個行業。這個方法來自麥肯錫的工作方法,搭配ChatGPT使用非常高效。只要你學會了,就能輕鬆掌握任何行業的基礎知識。 麥肯錫的方法論 第一步:總結關鍵詞
Thumbnail
前言 讀了許多理論,是時候實際動手做做看了,以下是我的模型訓練初體驗,有點糟就是了XD。 正文 def conv(filters, kernel_size, strides=1): return Conv2D(filters, kernel_size,
1. 在任何事情裡嘗試使用AI來幫忙: - 「你應該在你做的任何事情裡嘗試使用AI來幫忙。」隨著你的實驗,你會發現AI的幫忙可能是滿意,可能是很鳥,可能很垃圾,也可能令你很不安。這樣的過程不僅是利用AI來幫自己的忙,更是讓自己熟悉AI的能力,讓你自己更加瞭解AI能如何協助你,或者威脅你,或者取代你
Thumbnail
最近在嘗試使用不同的AI生圖方式混合出圖的方式,採用A平台的優點,並用B平台後製的手法截長補短,創造出自己更想要的小說場景,效果不錯,現在以這張圖為例,來講一下我的製作步驟。
「你應該在你做的任何事情裡嘗試使用AI來幫忙。」 「隨著你的實驗,你會發現AI的幫忙可能是滿意,可能是很鳥, 可能很垃圾,也可能令你很不安。」「由於AI是“通用科技 (General Purpose Technology)”, 並不會有一本書能幫助你了解它全部的價值,以及他全部的限制。」
Thumbnail
最新的AI趨勢讓人眼花撩亂,不知要如何開始學習?本文介紹了作者對AI的使用和體驗,以及各類AI工具以及推薦的選擇。最後強調了AI是一個很好用的工具,可以幫助人們節省時間並提高效率。鼓勵人們保持好奇心,不停止學習,並提出了對健康生活和開心生活的祝福。
Thumbnail
作者用常見的生活模式為底,分享所見所聞,提供自己對於這些情況的建議。詞彙的使用跟實力培養絕對是需要時間的,也非常需要靠別人的作品來當作自己的養分來源,多閱讀別人的書籍、文字、而且是有意識的學習,然後再加上實踐,相信慢慢就能寫出吸引人的文字、甚至發展出自己的風格。
Thumbnail
延續上週提到的,「有哪些不訓練模型的情況下,能夠強化語言模型的能力」,這堂課接續介紹其中第 3、4 個方法