淺談Few-Shot Learning:(2): 實作CUB 資料集 Few shot learning

更新於 發佈於 閱讀時間約 26 分鐘

前言

在上一篇我們簡單介紹了 Few-shot learning 的概念與應用價值,還有他最重要的實作參數,那在這一篇我們就來使用CUB資料集,實作一下 Few-shot learning,還沒看過上篇的可以點以下連結: 淺談Few-Shot Learning:(1): 初步認識

Photo by Nubelson Fernandes on Unsplash

Photo by Nubelson Fernandes on Unsplash

資料集簡介

CUB(Caltech-UCSD Birds 200)資料集是一個常用於電腦視覺研究的影像資料集,尤其在細粒度影像辨識領域中非常受歡迎。這個資料集是由加州理工學院(Caltech)和加州大學聖地牙哥分校(UCSD)聯合創建的。

  • 類別數量:包含 200 個不同的鳥類類別。
  • 圖像數量:共有 11,788 張圖像,每個類別大約有 60 張。

資料集用途:

  • 細粒度:這個資料集專注於細粒度物種辨識,所有的圖像都是同一大類(鳥類)中的不同細分類別。
  • 標註資訊:每張圖像除了有類別標籤外,還附有精確的標註信息,包括物種名稱、圖像中鳥類的 bounding box、部分標註的身體部位等。

本次實作中,我們僅使用到類別標籤,以簡化任務並聚焦於 Few-shot learning 的核心任務上。

資料下載及目錄結構

首先資料下載下來後依照這樣的目錄結構放置

data
└── CUB
├── 001 XX
│ ├── 1.jpg
│ └── 2.jpg
└── 002 OO
├── 3.jpg
├── 4.jpg
└── 5.jpg

鳥類影像如下圖:

raw-image
raw-image

如果在 colab 上實作的話,遇到會因為資料量大不能下載的問題,解法參考:Failed to download CUB dataset in Google Colab on example Notebooks

資料讀取

下載準備好 EasyFSL 提供的 CUB 資料集後,第一個步驟就是把資料載入,透過 Pytorch 的 Dataloader 進行數據的 batch process 和 load 。

from easyfsl.datasets import CUB
from torch.utils.data import DataLoader

batch_size = 128
n_workers = 2

# 加載 CUB 資料集的訓練集
train_set = CUB(split="train", training=True)
train_loader = DataLoader(
train_set,
batch_size=batch_size,
num_workers=n_workers,
pin_memory=True,
shuffle=True,
)

以下說明設定參數:

  • batch_size:決定 batch 的數量,每次訓練處理的圖像數量為 128,若 batch_size 越大表示每次迭代處理的數量越多,更有效的使用 GPU ,同時需考量硬體資源,而大批次也可能導致模型收斂速度比較慢 ; 相反的想批次每次需要的資源小,模型收斂快,但可能小批次造成每次模型學習較不穩定,梯度波動性較大。
  • n_workers:使用 2 個工作執行緒來加速數據加載,提升效率。
  • pin_memory:將數據固定在內存中,以加快 GPU 的數據傳輸速度。
  • shuffle:在每個 epoch 隨機打亂數據,增強模型的泛化能力。

模型訓練框架

這裡我們使用了 ResNet12 作為特徵提取的 Backbone(主幹網絡)。ResNet12 是一種輕量級的卷積神經網絡架構,非常適合在 Few-shot Learning 場景下使用,因為它計算效率高且能有效捕捉圖像中的細粒度特徵。

以下為模型初始設定

from easyfsl.modules import resnet12

DEVICE = "cuda"

model = resnet12(
use_fc=True,
num_classes=len(set(train_set.get_labels())),
).to(DEVICE)

參數說明

  • use_fc=True:啟用全連接層(Fully Connected Layer),用於最後的分類任務。這使模型能夠根據提取的特徵,輸出對應於類別數的分類結果。\
  • num_classes=len(set(train_set.get_labels())):根據訓練集中存在的類別數動態設定輸出層的維度。這確保模型輸出的維度與任務的分類需求一致。

.to(DEVICE):將模型移動到 GPU(如 cuda)以加速訓練。如果使用 CPU,則可以改為 DEVICE = "cpu"

為什麼選擇 ResNet12?

  1. 輕量化架構:ResNet12 是 ResNet 系列中較輕量的版本,參數量少,能有效降低計算資源的需求,適合 Few-shot Learning 的小樣本場景。
  2. 良好的特徵提取能力:ResNet12 的殘差結構(Residual Block)能捕捉多層次的特徵,對於細粒度辨識任務(如 CUB 資料集)表現出色。
  3. 適合 Few-shot 任務:ResNet12 經常用於 Few-shot Learning 框架,因其能平衡訓練效率與模型性能,適應各種細粒度數據集的需求。

若不使用 ResNet12 可以考慮的模型

  1. MobileNet 是一種輕量化模型,通過深度可分離卷積(Depthwise Separable Convolution)顯著減少參數量和計算量,非常適合硬體資源受限的情況。
  2. 它對小樣本數據的特徵提取能力穩定,能很好地適應 Few-shot Learning 的場景。

Few-shot 訓練前設定

還記得我們在上一篇說明 Few-shot learning 訓練過程的核心環節嗎?

這裡的任務(Task)基於支持集(Support Set)與查詢集(Query Set)進行學息的,因此在數據建立上需要進行合理分配,以因應 Few-shot learning 特性的訓練和驗證流程。

from easyfsl.methods import PrototypicalNetworks
from easyfsl.samplers import TaskSampler

n_way = 3
n_shot = 5
n_query = 10
n_validation_tasks = 50

val_set = CUB(split="val", training=False)
val_sampler = TaskSampler(
val_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_validation_tasks
)
val_loader = DataLoader(
val_set,
batch_sampler=val_sampler,
num_workers=n_workers,
pin_memory=True,
collate_fn=val_sampler.episodic_collate_fn,
)

few_shot_classifier = PrototypicalNetworks(model).to(DEVICE)

核心邏輯解釋

  • TaskSampler: EasyFSL 框架中的一個模組,用於生成 Episodic Learning 任務。每個任務會隨機選擇 n_way 個類別,並從每個類別中取出 n_shot 張支撐集圖像和 n_query 張查詢集圖像。這種設計模擬了 Few-shot Learning 的場景,幫助模型學習更好的泛化能力。
  • DataLoader:使用 PyTorch 的 DataLoader 將生成的任務組織成批次,並通過多線程(num_workers)和固定內存(pin_memory)加速數據加載。
  • Prototypical Networks:是一種經典的 Few-shot Learning 方法,其核心思想是將支撐集的圖像嵌入到特徵空間中,計算各類別的「原型」(Prototype),並基於查詢集的特徵與原型之間的距離進行分類。

參數設定

  1. n_way = 3:每個任務包含 3 個類別(Way)。這表示模型每次訓練都需要對 3 個不同的類別進行分類。
  2. n_shot = 5:每個類別的支撐集(Support Set)包含 5 張圖像,作為模型學習該類別特徵的基礎。
  3. n_query = 10:每個類別的查詢集(Query Set)包含 10 張圖像,用於驗證模型是否正確學習到類別的特徵。
  4. n_validation_tasks = 50:在驗證過程中,每次將生成 50 個不同的任務,模擬模型在不同情境下的學習與預測能力。

模型學習設定

跟多數的 DL model 類似,要設定 loss function 、 leraning rate、 epoch 、optimizer … 參數,這裡就不多贅述。

from torch.optim import SGD, Optimizer
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.tensorboard import SummaryWriter

LOSS_FUNCTION = nn.CrossEntropyLoss()

n_epochs = 10
scheduler_milestones = [150, 180]
scheduler_gamma = 0.1
learning_rate = 1e-01
tb_logs_dir = Path(".")

train_optimizer = SGD(
model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4
)
train_scheduler = MultiStepLR(
train_optimizer,
milestones=scheduler_milestones,
gamma=scheduler_gamma,
)

tb_writer = SummaryWriter(log_dir=str(tb_logs_dir))

設定說明

  1. 損失函數(LOSS_FUNCTION): 使用交叉熵損失函數(CrossEntropyLoss),適合於多分類任務。
  2. 優化器(train_optimizer): 選擇 SGD(隨機梯度下降)作為優化器,並配置了以下參數:
  • 學習率(lr=1e-01:控制參數更新的步長。
  • 動量(momentum=0.9:幫助加速梯度下降並減少波動性。
  • 權重衰減(weight_decay=5e-4:用於防止過擬合,起到正則化的效果。

3. 學習率調度器(train_scheduler): 可以理解成動態調節學習率

  • 使用 MultiStepLR 調度器在指定的里程碑(150 和 180 epoch)降低學習率。
  • 每次降低學習率時,會將當前學習率乘以 gamma=0.1。這種方式有助於模型在後期更穩定地收斂。

4. 訓練監控工具(tb_writer使用 TensorBoard 作為可視化工具,將訓練過程中的損失、準確率等指標記錄下來,便於監控與分析。

訓練和驗證流程

在這段訓練流程中,模型主要通過多次迭代來學習數據的特徵,每一步都執行特定的任務以確保模型能夠準確、高效地進行學習和驗證。

定義 training_epoch function

def training_epoch(model_: nn.Module, data_loader: DataLoader, optimizer: Optimizer):
all_loss = []
model_.train()
with tqdm(data_loader, total=len(data_loader), desc="Training") as tqdm_train:
for images, labels in tqdm_train:
optimizer.zero_grad()
loss = LOSS_FUNCTION(model_(images.to(DEVICE)), labels.to(DEVICE))
loss.backward()
optimizer.step()
all_loss.append(loss.item())
tqdm_train.set_postfix(loss=mean(all_loss))
return mean(all_loss)

在每一回和的訓練流程中:

  1. 切換模型為訓練模式model_.train():讓模型進入訓練模式,啟用 dropout 和 batch normalization 的特性,為訓練過程準備好。
  2. 迭代數據批次(Batch):使用 DataLoader 提供的數據批次進行迭代,每個批次包含一部分訓練數據(例如 batch_size = 128),這樣可以分批處理,減少對內存的壓力。
  3. 前向傳播(Forward Pass)model_(images.to(DEVICE)):將當前批次的圖像數據傳入模型,計算出對應的預測結果。
  4. 計算損失(Loss Calculation)LOSS_FUNCTION(predictions, labels):比較模型的預測結果與真實標籤,計算損失值,用於指導模型如何更新參數。
  5. 反向傳播(Backward Pass)loss.backward():根據損失值,計算每個參數的梯度,這些梯度表示每個參數對損失值的影響。
  6. 更新參數(Parameter Update)optimizer.step():利用計算出的梯度,通過優化器(如 SGD)更新模型的參數,讓模型在下一次迭代時更接近正確的預測。
  7. 記錄損失all_loss.append(loss.item()):將當前批次的損失記錄下來,用於後續的平均損失計算。

執行訓練和驗證流程

流程主要在做參數更新、驗證、保存最佳模型狀態

from easyfsl.utils import evaluate
# 保存最優模型狀態
best_state = None
best_validation_accuracy = 0.0
validation_frequency = 10

for epoch in range(n_epochs):
print(f"Epoch {epoch}")
average_loss = training_epoch(model, train_loader, train_optimizer)

if epoch % validation_frequency == validation_frequency - 1:
model.set_use_fc(False)
validation_accuracy = evaluate(
few_shot_classifier, val_loader, device=DEVICE, tqdm_prefix="Validation"
)
model.set_use_fc(True)

# 更新最優模型狀態
if validation_accuracy > best_validation_accuracy:
best_validation_accuracy = validation_accuracy
best_state = copy.deepcopy(few_shot_classifier.state_dict())
print("Ding ding ding! We found a new best model!")

tb_writer.add_scalar("Val/acc", validation_accuracy, epoch)

tb_writer.add_scalar("Train/loss", average_loss, epoch)

train_scheduler.step()

# 加載最優模型狀態
if best_state is not None:
few_shot_classifier.load_state_dict(best_state)
print("Loaded the best model state!")
else:
print("No best state was saved.")
  1. 訓練模型參數(training_epoch
  • 使用 train_loader 將數據分批送入模型,通過優化器(train_optimizer)進行參數更新。
  • 每個批次計算損失(loss),通過反向傳播(backward())更新參數,最後記錄平均損失作為當前 epoch 的結果。

2. 定期驗證(validation_frequency

  • 每隔 10 個 epoch 進行一次驗證。
  • 使用驗證集(val_loader)評估模型在 Few-shot 任務下的準確率,檢查模型的泛化能力。

3. 保存最佳模型

  • 如果驗證準確率超過歷史最佳(best_validation_accuracy),保存當前模型狀態(state_dict),確保最終導出的是性能最優的模型。

4. 記錄訓練與驗證結果(TensorBoard)

  • 使用 TensorBoard 記錄每個 epoch 的訓練損失(Train/loss)與驗證準確率(Val/acc),便於後續分析。

5. 學習率調整: 使用學習率調度器(train_scheduler)在訓練過程中降低學習率(根據設置的里程碑),幫助模型在後期穩定收斂。

每個 epoch 的 loss log

每個 epoch 的 loss log

測試評估

經過訓練與驗證後,最後一步是測試模型的分類能力,評估它在未見過的數據上是否能夠有效運作。

n_test_tasks = 10

test_set = CUB(split="test", training=False)
test_sampler = TaskSampler(
test_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_test_tasks
)
test_loader = DataLoader(
test_set,
batch_sampler=test_sampler,
num_workers=n_workers,
pin_memory=True,
collate_fn=test_sampler.episodic_collate_fn,
)

model.set_use_fc(False)

accuracy = evaluate(few_shot_classifier, test_loader, device=DEVICE)
print(f"Average accuracy : {(100 * accuracy):.2f} %")

測試過程說明

  1. 設置測試任務數量 (n_test_tasks = 10)
  • 測試階段採用 10 組不同的 Few-shot 任務,確保模型的泛化能力得到充分測試。
  • 每個任務都會根據 n_way(類別數)、n_shot(支撐集數量)、n_query(查詢集數量)進行組織,模擬真實的小樣本學習場景。

2. 數據準備

  • test_set = CUB(split="test", training=False)加載 CUB 資料集的測試集,確保使用未見過的數據來檢驗模型的泛化能力。
  • TaskSampler 負責組織測試任務,確保每次測試都符合 Few-shot Learning 的典型設計(支撐集 + 查詢集)。
  • DataLoader 用於批次處理測試數據,提高數據加載效率。

3. 模型前向傳播與準確率計算

  • model.set_use_fc(False):測試時禁用全連接層(Fully Connected Layer),直接使用 Prototypical Networks 或 Few-shot 方法的特徵嵌入進行分類。
  • 這樣的設計能更貼近 Few-shot Learning 的核心思想,即透過支撐集學習特徵,並利用距離度量進行分類,而不是傳統的大規模訓練分類器。(以下補充說明為何要這麼做)
  • evaluate(few_shot_classifier, test_loader, device=DEVICE)
  • evaluate 函數負責執行模型的測試流程,計算每個查詢樣本與支撐集樣本的距離,並進行分類決策。
  • 最終輸出測試準確率,衡量模型在新類別上的學習能力。


補充說明model.set_use_fc(False)

當我們執行 set_use_fc(False),實際上就是 讓模型不再輸出傳統分類的機率分佈,而是只提取特徵向量,並透過 Few-shot 方法(如 Prototypical Networks)進行距離度量分類。

對比全連接層分類與 Few-shot Learning

raw-image

測試結果

raw-image
  • 這表示 在所有 10 個測試任務的查詢圖片中,模型的平均分類準確率為 81%
  • 若 n_way = 3,則每個測試任務的查詢集包含 3 個類別,模型需要判斷這些圖片應屬於哪個類別。
  • 這個結果表示,在 10 個不同的 Few-shot 測試場景中,模型大部分時間能夠準確地辨識新類別。

顯示測試的結果

我們印出其中一個 batch 的測試結果,幫助更好理解模型真正預測情形。

  • 上排表示抽樣的 support set
  • 下排表示 query set ,並且包含真實標籤(GT)和預測標籤(Pred),綠色字 表示預測正確、紅色字 表示預測錯誤
raw-image

小結

在這篇文章中,我們使用 Easy Few-shot 框架,結合 CUB 鳥類分類資料集,實作了一個 Few-Shot Learning(FSL)模型,並拆解中間過程細節,從資料讀取、Dataloader 設計、ResNet12 模型設定、Few-Shot 訓練與測試評估 的整個流程。透過這次實作,我們驗證了即使在極少樣本(Few-Shot)場景下,模型仍然能夠學習特徵並執行分類任務。

在測試階段,我們進行了視覺化分析,顯示 支撐集(Support Set)與查詢集(Query Set) 的影像與預測結果,並發現模型有一定程度的準確,需要評估真實應用場景是否可以接受這樣的準確與誤判結果。

雖然這次初步實作的部分已經完成,但 Few-Shot Learning 仍然是一個值得深入探討的領域,未來可以進一步研究 不同 Few-Shot 方法的比較(ProtoNet、Matching Network、MAML)、還有找時間讀一下paper原文,甚至探索如何應用於醫學影像的領域。

這篇文章至此先告一段落,希望這次的實作能夠幫助理解 Few-Shot Learning 的核心概念與實踐流程,也為後續更深入的研究打下基礎。

我們下次見~

附上在 Heptabase 白板整理的 snapshot (白板累積越來越多的卡片)

有需要的話~ 🎁 Heptabase 折扣碼:https://join.heptabase.com?invite-acc-id=889dbd70-632c-446d-b7f8-ffa41b27e716

Heptabase 白板 snapshot

Heptabase 白板 snapshot

相關連結

avatar-img
33會員
44內容數
歡迎來到《桃花源記》專欄。這裡不僅是一個文字的集合,更是一個探索、夢想和自我發現的空間。在這個專欄中,我們將一同走進那些隱藏在日常生活中的"桃花源"——那些讓我們心動、讓我們反思、讓我們找到內心平靜的時刻和地方
留言0
查看全部
avatar-img
發表第一個留言支持創作者!
Karen的沙龍 的其他內容
探索Few-Shot Learning如何在數據稀缺的情況下使機器學習模型迅速學習並做出精確預測。本文將介紹Few-Shot Learning的基本原理、核心策略,以及在實際應用。
探索Few-Shot Learning如何在數據稀缺的情況下使機器學習模型迅速學習並做出精確預測。本文將介紹Few-Shot Learning的基本原理、核心策略,以及在實際應用。
你可能也想看
Google News 追蹤
Thumbnail
本文將延續上一篇文章,經由訓練好的VAE模型其中的解碼器,來生成圖片。 [深度學習]訓練VAE模型用於生成圖片_訓練篇 輸入產生的隨機雜訊,輸入VAE的解碼器後,生成的圖片
Thumbnail
前言 讀了許多理論,是時候實際動手做做看了,以下是我的模型訓練初體驗,有點糟就是了XD。 正文 def conv(filters, kernel_size, strides=1): return Conv2D(filters, kernel_size,
Thumbnail
最近在嘗試使用不同的AI生圖方式混合出圖的方式,採用A平台的優點,並用B平台後製的手法截長補短,創造出自己更想要的小說場景,效果不錯,現在以這張圖為例,來講一下我的製作步驟。
Thumbnail
此篇調查論文探討了Diffusion模型在文字、圖片和聲音轉換為影片,以及影片衍生和編輯的應用類型。作者也介紹了U-Net架構和Vision Transformer等生成圖像架構,並詳細探討了訓練模型的方法以及不同的影像資料集來源。
Thumbnail
Ae 小技巧:宣紙噪點+抽幀效果 動態後記系列會記錄一些我在製作中的記錄,可能是分解動畫、小技巧、發想、腳本......等等。 每篇都是小短篇,就是補充用的小筆記,沒有前後順序,可跳著閱讀。
Thumbnail
寬景Wide view 鳥瞰Bird view 前景Foreground 背景Background 正面Front View 側面Side View 俯視Top View 景深Depth of field 微距鏡頭Macro Shot 超特寫Extreme Close up
Thumbnail
我一看,這沒什麼難啊,但是我知道你們上學期有學過「美圖秀秀」的方法,這次呢,我們就教「Moonshot」的指令。
Thumbnail
這一篇要測試一下Video Linear CFG Guidance這個節點,在網路上很多的教學影片跟網友分享的工作流中會看到這個節點,據說這個節點不只可以用在生成影片的工作流中,也可以使用在一般的生成圖片工作流中。
Thumbnail
本篇文章參考 Youtube 影片(...真實模型推薦...)內容,為大家找出影片中的模型,直接作圖測試,您直接連結過去,就可以在 TensorArt 內直接使用囉!
Thumbnail
本文將延續上一篇文章,經由訓練好的VAE模型其中的解碼器,來生成圖片。 [深度學習]訓練VAE模型用於生成圖片_訓練篇 輸入產生的隨機雜訊,輸入VAE的解碼器後,生成的圖片
Thumbnail
前言 讀了許多理論,是時候實際動手做做看了,以下是我的模型訓練初體驗,有點糟就是了XD。 正文 def conv(filters, kernel_size, strides=1): return Conv2D(filters, kernel_size,
Thumbnail
最近在嘗試使用不同的AI生圖方式混合出圖的方式,採用A平台的優點,並用B平台後製的手法截長補短,創造出自己更想要的小說場景,效果不錯,現在以這張圖為例,來講一下我的製作步驟。
Thumbnail
此篇調查論文探討了Diffusion模型在文字、圖片和聲音轉換為影片,以及影片衍生和編輯的應用類型。作者也介紹了U-Net架構和Vision Transformer等生成圖像架構,並詳細探討了訓練模型的方法以及不同的影像資料集來源。
Thumbnail
Ae 小技巧:宣紙噪點+抽幀效果 動態後記系列會記錄一些我在製作中的記錄,可能是分解動畫、小技巧、發想、腳本......等等。 每篇都是小短篇,就是補充用的小筆記,沒有前後順序,可跳著閱讀。
Thumbnail
寬景Wide view 鳥瞰Bird view 前景Foreground 背景Background 正面Front View 側面Side View 俯視Top View 景深Depth of field 微距鏡頭Macro Shot 超特寫Extreme Close up
Thumbnail
我一看,這沒什麼難啊,但是我知道你們上學期有學過「美圖秀秀」的方法,這次呢,我們就教「Moonshot」的指令。
Thumbnail
這一篇要測試一下Video Linear CFG Guidance這個節點,在網路上很多的教學影片跟網友分享的工作流中會看到這個節點,據說這個節點不只可以用在生成影片的工作流中,也可以使用在一般的生成圖片工作流中。
Thumbnail
本篇文章參考 Youtube 影片(...真實模型推薦...)內容,為大家找出影片中的模型,直接作圖測試,您直接連結過去,就可以在 TensorArt 內直接使用囉!