在上一篇我們簡單介紹了 Few-shot learning 的概念與應用價值,還有他最重要的實作參數,那在這一篇我們就來使用CUB資料集,實作一下 Few-shot learning,還沒看過上篇的可以點以下連結: 淺談Few-Shot Learning:(1): 初步認識。
Photo by Nubelson Fernandes on Unsplash
CUB(Caltech-UCSD Birds 200)資料集是一個常用於電腦視覺研究的影像資料集,尤其在細粒度影像辨識領域中非常受歡迎。這個資料集是由加州理工學院(Caltech)和加州大學聖地牙哥分校(UCSD)聯合創建的。
資料集用途:
本次實作中,我們僅使用到類別標籤,以簡化任務並聚焦於 Few-shot learning 的核心任務上。
首先資料下載下來後依照這樣的目錄結構放置
data
└── CUB
├── 001 XX 鳥
│ ├── 1.jpg
│ └── 2.jpg
└── 002 OO 鳥
├── 3.jpg
├── 4.jpg
└── 5.jpg
鳥類影像如下圖:
如果在 colab 上實作的話,遇到會因為資料量大不能下載的問題,解法參考:Failed to download CUB dataset in Google Colab on example Notebooks
下載準備好 EasyFSL 提供的 CUB 資料集後,第一個步驟就是把資料載入,透過 Pytorch 的 Dataloader 進行數據的 batch process 和 load 。
from easyfsl.datasets import CUB
from torch.utils.data import DataLoader
batch_size = 128
n_workers = 2
# 加載 CUB 資料集的訓練集
train_set = CUB(split="train", training=True)
train_loader = DataLoader(
train_set,
batch_size=batch_size,
num_workers=n_workers,
pin_memory=True,
shuffle=True,
)
以下說明設定參數:
batch_size
:決定 batch 的數量,每次訓練處理的圖像數量為 128,若 batch_size 越大表示每次迭代處理的數量越多,更有效的使用 GPU ,同時需考量硬體資源,而大批次也可能導致模型收斂速度比較慢 ; 相反的想批次每次需要的資源小,模型收斂快,但可能小批次造成每次模型學習較不穩定,梯度波動性較大。n_workers
:使用 2 個工作執行緒來加速數據加載,提升效率。pin_memory
:將數據固定在內存中,以加快 GPU 的數據傳輸速度。shuffle
:在每個 epoch 隨機打亂數據,增強模型的泛化能力。這裡我們使用了 ResNet12 作為特徵提取的 Backbone(主幹網絡)。ResNet12 是一種輕量級的卷積神經網絡架構,非常適合在 Few-shot Learning 場景下使用,因為它計算效率高且能有效捕捉圖像中的細粒度特徵。
以下為模型初始設定
from easyfsl.modules import resnet12
DEVICE = "cuda"
model = resnet12(
use_fc=True,
num_classes=len(set(train_set.get_labels())),
).to(DEVICE)
use_fc=True
:啟用全連接層(Fully Connected Layer),用於最後的分類任務。這使模型能夠根據提取的特徵,輸出對應於類別數的分類結果。\num_classes=len(set(train_set.get_labels()))
:根據訓練集中存在的類別數動態設定輸出層的維度。這確保模型輸出的維度與任務的分類需求一致。.to(DEVICE)
:將模型移動到 GPU(如 cuda
)以加速訓練。如果使用 CPU,則可以改為 DEVICE = "cpu"
。
還記得我們在上一篇說明 Few-shot learning 訓練過程的核心環節嗎?
這裡的任務(Task)基於支持集(Support Set)與查詢集(Query Set)進行學息的,因此在數據建立上需要進行合理分配,以因應 Few-shot learning 特性的訓練和驗證流程。
from easyfsl.methods import PrototypicalNetworks
from easyfsl.samplers import TaskSampler
n_way = 3
n_shot = 5
n_query = 10
n_validation_tasks = 50
val_set = CUB(split="val", training=False)
val_sampler = TaskSampler(
val_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_validation_tasks
)
val_loader = DataLoader(
val_set,
batch_sampler=val_sampler,
num_workers=n_workers,
pin_memory=True,
collate_fn=val_sampler.episodic_collate_fn,
)
few_shot_classifier = PrototypicalNetworks(model).to(DEVICE)
n_way
個類別,並從每個類別中取出 n_shot
張支撐集圖像和 n_query
張查詢集圖像。這種設計模擬了 Few-shot Learning 的場景,幫助模型學習更好的泛化能力。num_workers
)和固定內存(pin_memory
)加速數據加載。n_way
= 3:每個任務包含 3 個類別(Way)。這表示模型每次訓練都需要對 3 個不同的類別進行分類。n_shot
= 5:每個類別的支撐集(Support Set)包含 5 張圖像,作為模型學習該類別特徵的基礎。n_query
= 10:每個類別的查詢集(Query Set)包含 10 張圖像,用於驗證模型是否正確學習到類別的特徵。n_validation_tasks
= 50:在驗證過程中,每次將生成 50 個不同的任務,模擬模型在不同情境下的學習與預測能力。跟多數的 DL model 類似,要設定 loss function 、 leraning rate、 epoch 、optimizer … 參數,這裡就不多贅述。
from torch.optim import SGD, Optimizer
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.tensorboard import SummaryWriter
LOSS_FUNCTION = nn.CrossEntropyLoss()
n_epochs = 10
scheduler_milestones = [150, 180]
scheduler_gamma = 0.1
learning_rate = 1e-01
tb_logs_dir = Path(".")
train_optimizer = SGD(
model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4
)
train_scheduler = MultiStepLR(
train_optimizer,
milestones=scheduler_milestones,
gamma=scheduler_gamma,
)
tb_writer = SummaryWriter(log_dir=str(tb_logs_dir))
LOSS_FUNCTION
): 使用交叉熵損失函數(CrossEntropyLoss),適合於多分類任務。train_optimizer
): 選擇 SGD(隨機梯度下降)作為優化器,並配置了以下參數:lr=1e-01
):控制參數更新的步長。momentum=0.9
):幫助加速梯度下降並減少波動性。weight_decay=5e-4
):用於防止過擬合,起到正則化的效果。3. 學習率調度器(train_scheduler
): 可以理解成動態調節學習率
MultiStepLR
調度器在指定的里程碑(150 和 180 epoch)降低學習率。gamma=0.1
。這種方式有助於模型在後期更穩定地收斂。4. 訓練監控工具(tb_writer
)使用 TensorBoard 作為可視化工具,將訓練過程中的損失、準確率等指標記錄下來,便於監控與分析。
在這段訓練流程中,模型主要通過多次迭代來學習數據的特徵,每一步都執行特定的任務以確保模型能夠準確、高效地進行學習和驗證。
def training_epoch(model_: nn.Module, data_loader: DataLoader, optimizer: Optimizer):
all_loss = []
model_.train()
with tqdm(data_loader, total=len(data_loader), desc="Training") as tqdm_train:
for images, labels in tqdm_train:
optimizer.zero_grad()
loss = LOSS_FUNCTION(model_(images.to(DEVICE)), labels.to(DEVICE))
loss.backward()
optimizer.step()
all_loss.append(loss.item())
tqdm_train.set_postfix(loss=mean(all_loss))
return mean(all_loss)
在每一回和的訓練流程中:
model_.train()
:讓模型進入訓練模式,啟用 dropout 和 batch normalization 的特性,為訓練過程準備好。DataLoader
提供的數據批次進行迭代,每個批次包含一部分訓練數據(例如 batch_size = 128
),這樣可以分批處理,減少對內存的壓力。model_(images.to(DEVICE))
:將當前批次的圖像數據傳入模型,計算出對應的預測結果。LOSS_FUNCTION(predictions, labels)
:比較模型的預測結果與真實標籤,計算損失值,用於指導模型如何更新參數。loss.backward()
:根據損失值,計算每個參數的梯度,這些梯度表示每個參數對損失值的影響。optimizer.step()
:利用計算出的梯度,通過優化器(如 SGD)更新模型的參數,讓模型在下一次迭代時更接近正確的預測。all_loss.append(loss.item())
:將當前批次的損失記錄下來,用於後續的平均損失計算。流程主要在做參數更新、驗證、保存最佳模型狀態
from easyfsl.utils import evaluate
# 保存最優模型狀態
best_state = None
best_validation_accuracy = 0.0
validation_frequency = 10
for epoch in range(n_epochs):
print(f"Epoch {epoch}")
average_loss = training_epoch(model, train_loader, train_optimizer)
if epoch % validation_frequency == validation_frequency - 1:
model.set_use_fc(False)
validation_accuracy = evaluate(
few_shot_classifier, val_loader, device=DEVICE, tqdm_prefix="Validation"
)
model.set_use_fc(True)
# 更新最優模型狀態
if validation_accuracy > best_validation_accuracy:
best_validation_accuracy = validation_accuracy
best_state = copy.deepcopy(few_shot_classifier.state_dict())
print("Ding ding ding! We found a new best model!")
tb_writer.add_scalar("Val/acc", validation_accuracy, epoch)
tb_writer.add_scalar("Train/loss", average_loss, epoch)
train_scheduler.step()
# 加載最優模型狀態
if best_state is not None:
few_shot_classifier.load_state_dict(best_state)
print("Loaded the best model state!")
else:
print("No best state was saved.")
training_epoch
):train_loader
將數據分批送入模型,通過優化器(train_optimizer
)進行參數更新。loss
),通過反向傳播(backward()
)更新參數,最後記錄平均損失作為當前 epoch 的結果。2. 定期驗證(validation_frequency
):
val_loader
)評估模型在 Few-shot 任務下的準確率,檢查模型的泛化能力。3. 保存最佳模型:
best_validation_accuracy
),保存當前模型狀態(state_dict
),確保最終導出的是性能最優的模型。4. 記錄訓練與驗證結果(TensorBoard):
Train/loss
)與驗證準確率(Val/acc
),便於後續分析。5. 學習率調整: 使用學習率調度器(train_scheduler
)在訓練過程中降低學習率(根據設置的里程碑),幫助模型在後期穩定收斂。
每個 epoch 的 loss log
經過訓練與驗證後,最後一步是測試模型的分類能力,評估它在未見過的數據上是否能夠有效運作。
n_test_tasks = 10
test_set = CUB(split="test", training=False)
test_sampler = TaskSampler(
test_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_test_tasks
)
test_loader = DataLoader(
test_set,
batch_sampler=test_sampler,
num_workers=n_workers,
pin_memory=True,
collate_fn=test_sampler.episodic_collate_fn,
)
model.set_use_fc(False)
accuracy = evaluate(few_shot_classifier, test_loader, device=DEVICE)
print(f"Average accuracy : {(100 * accuracy):.2f} %")
n_test_tasks = 10
)n_way
(類別數)、n_shot
(支撐集數量)、n_query
(查詢集數量)進行組織,模擬真實的小樣本學習場景。2. 數據準備
test_set = CUB(split="test", training=False)
加載 CUB 資料集的測試集,確保使用未見過的數據來檢驗模型的泛化能力。TaskSampler
負責組織測試任務,確保每次測試都符合 Few-shot Learning 的典型設計(支撐集 + 查詢集)。DataLoader
用於批次處理測試數據,提高數據加載效率。3. 模型前向傳播與準確率計算
model.set_use_fc(False)
:測試時禁用全連接層(Fully Connected Layer),直接使用 Prototypical Networks 或 Few-shot 方法的特徵嵌入進行分類。evaluate(few_shot_classifier, test_loader, device=DEVICE)
:evaluate
函數負責執行模型的測試流程,計算每個查詢樣本與支撐集樣本的距離,並進行分類決策。補充說明model.set_use_fc(False)
:
當我們執行 set_use_fc(False)
,實際上就是 讓模型不再輸出傳統分類的機率分佈,而是只提取特徵向量,並透過 Few-shot 方法(如 Prototypical Networks)進行距離度量分類。
對比全連接層分類與 Few-shot Learning
n_way = 3
,則每個測試任務的查詢集包含 3 個類別,模型需要判斷這些圖片應屬於哪個類別。我們印出其中一個 batch 的測試結果,幫助更好理解模型真正預測情形。
在這篇文章中,我們使用 Easy Few-shot 框架,結合 CUB 鳥類分類資料集,實作了一個 Few-Shot Learning(FSL)模型,並拆解中間過程細節,從資料讀取、Dataloader 設計、ResNet12 模型設定、Few-Shot 訓練與測試評估 的整個流程。透過這次實作,我們驗證了即使在極少樣本(Few-Shot)場景下,模型仍然能夠學習特徵並執行分類任務。
在測試階段,我們進行了視覺化分析,顯示 支撐集(Support Set)與查詢集(Query Set) 的影像與預測結果,並發現模型有一定程度的準確,需要評估真實應用場景是否可以接受這樣的準確與誤判結果。
雖然這次初步實作的部分已經完成,但 Few-Shot Learning 仍然是一個值得深入探討的領域,未來可以進一步研究 不同 Few-Shot 方法的比較(ProtoNet、Matching Network、MAML)、還有找時間讀一下paper原文,甚至探索如何應用於醫學影像的領域。
這篇文章至此先告一段落,希望這次的實作能夠幫助理解 Few-Shot Learning 的核心概念與實踐流程,也為後續更深入的研究打下基礎。
我們下次見~
附上在 Heptabase 白板整理的 snapshot (白板累積越來越多的卡片)
有需要的話~ 🎁 Heptabase 折扣碼:https://join.heptabase.com?invite-acc-id=889dbd70-632c-446d-b7f8-ffa41b27e716
Heptabase 白板 snapshot