AI時代系列(3) 機器學習三部曲: 🔹 第一部:《機器學習 —— AI 智慧的啟航》
46/100 第五週:非監督學習
46. 生成對抗網路(GAN)基礎 🎨 讓 AI 產生逼真的圖片、音樂和語音!
🎨 生成對抗網路(GAN)基礎
讓 AI 產生逼真的圖片、音樂和語音!
________________________________________
🔎 一、什麼是 GAN(Generative Adversarial Network)?
• Ian Goodfellow 於 2014 年提出
• 是一種強大的生成模型(Generative Model)
• 兩個神經網路互相對抗、相互提升
o 生成器(Generator, G):負責「騙人」產生逼真數據
o 判別器(Discriminator, D):負責「抓假」分辨真假
• 最終目標:讓生成器產生的數據逼真到連判別器都分不出真假!
________________________________________
🌟 二、GAN 運作原理(核心概念)
組件 功能
Generator (G): 從隨機噪聲產生類似真實的數據(假圖片、假音樂)
Discriminator (D): 判斷輸入是真實資料還是 G 生成的假資料
對抗學習:G 持續學習如何「騙」D,D 不斷強化辨識能力
損失函數:雙方互相博弈,最終達到均衡
________________________________________
🛠 三、簡化數學公式(直觀理解)
G: 生成器 D: 判別器
-------------------- ----------------------------------
| 亂數 z ~ p(z) | | 真實樣本 x ~ p_data(x) |
| → G(z) → 假樣本 | | |
-------------------- ----------------------------------
↓ ↓
嘗試騙過 D(生成真) 嘗試分辨真假(真假概率越準越好)
整體損失函數:
min_G max_D V(D, G)
V(D, G) = E_x[log D(x)] + E_z[log(1 - D(G(z)))]
這段圖文內容清楚地呈現了 生成對抗網路(GAN) 的核心對抗機制:
生成器(G)從隨機噪聲 z∼p(z)產生假樣本 G(z),目的是讓這些假樣本看起來像真實資料,試圖「騙過」判別器(D);
而判別器的任務則是區分真實樣本 x∼pdata(x)與生成樣本 G(z),提升辨別正確率。兩者透過對抗性的學習方式進行「博弈」,
損失函數 V(D,G) 同時包含 D 判別真實樣本的準確性 logD(x),以及分辨生成樣本的能力 log(1−D(G(z)))。整體學習目標就是 G 嘗試最小化此函數,而 D 嘗試最大化它,最終達成一個動態平衡,生成器能產出幾可亂真的樣本,判別器則無法再準確分辨真假。
🎯 直觀理解:
D 就像警察:越準越好。
G 就像造假者:越逼真越好。
整體是一場博弈:G 和 D 相互進化,直到 D 無法分辨真假為止(也就是達成納什均衡的狀態)。
________________________________________
📈 四、GAN 經典應用場景
✅ 圖片生成(如:DeepFake、人臉生成)
✅ AI 作畫(AI 藝術生成)
✅ 語音合成(Text-to-Speech,TTS)
✅ 音樂創作(AI 自動作曲)
✅ 數據增強(Data Augmentation)
✅ 遊戲場景生成、模擬器生成虛擬環境
________________________________________
🎨 五、AI 生圖範例 - DCGAN(圖片生成)
python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image
import os
# 設定裝置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 生成器模型
class Generator(nn.Module):
def __init__(self):
super().__init__()
self.main = nn.Sequential(
nn.ConvTranspose2d(100, 128, 7, 1, 0, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(True),
nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(True),
nn.ConvTranspose2d(64, 1, 4, 2, 1, bias=False),
nn.Tanh()
)
def forward(self, x):
return self.main(x)
# 判別器模型
class Discriminator(nn.Module):
def __init__(self):
super().__init__()
self.main = nn.Sequential(
nn.Conv2d(1, 64, 4, 2, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(64, 128, 4, 2, 1, bias=False),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2, inplace=True),
nn.Flatten(),
nn.Linear(128 * 7 * 7, 1),
nn.Sigmoid()
)
def forward(self, x):
return self.main(x)
# 超參數設定
batch_size = 64
lr = 0.0002
num_epochs = 10
latent_dim = 100
# 數據處理
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5])
])
dataset = datasets.MNIST(root='data', train=True, transform=transform, download=True)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
# 模型與優化器
G = Generator().to(device)
D = Discriminator().to(device)
loss_fn = nn.BCELoss()
optimizer_G = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
# 訓練
os.makedirs('dcgan_output', exist_ok=True)
for epoch in range(num_epochs):
for i, (real_imgs, _) in enumerate(dataloader):
real_imgs = real_imgs.to(device)
batch_size = real_imgs.size(0)
# 訓練判別器 D
real_labels = torch.ones(batch_size, 1).to(device)
fake_labels = torch.zeros(batch_size, 1).to(device)
z = torch.randn(batch_size, latent_dim, 1, 1).to(device)
fake_imgs = G(z)
real_loss = loss_fn(D(real_imgs), real_labels)
fake_loss = loss_fn(D(fake_imgs.detach()), fake_labels)
d_loss = real_loss + fake_loss
optimizer_D.zero_grad()
d_loss.backward()
optimizer_D.step()
# 訓練生成器 G
g_loss = loss_fn(D(fake_imgs), real_labels)
optimizer_G.zero_grad()
g_loss.backward()
optimizer_G.step()
if i % 100 == 0:
print(f"Epoch [{epoch+1}/{num_epochs}] Step [{i}/{len(dataloader)}] "
f"D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}")
# 每一輪儲存生成圖片
with torch.no_grad():
fake_imgs = G(torch.randn(64, latent_dim, 1, 1).to(device))
save_image(fake_imgs[:25], f"dcgan_output/epoch_{epoch+1}.png", nrow=5, normalize=True)
print("訓練完成,生成圖片已儲存至 dcgan_output/")
🖼 執行後效果:
程式會在每個訓練 epoch 產生 25 張 28x28 的手寫數字圖片,並儲存在 dcgan_output 資料夾中。
________________________________________
📊 六、GAN 進階變種
模型名稱 功能特色
DCGAN 深度卷積 GAN,生成更精緻圖片
CycleGAN 可實現風格轉換(馬→斑馬,照片→油畫)
StyleGAN 高品質人臉生成,超擬真
Pix2Pix 影像到影像的轉換
WGAN 解決傳統 GAN 訓練不穩定問題
________________________________________
📉 七、GAN 優缺點總結
優點 缺點
✅ 生成結果逼真度高 ❌ 訓練極不穩定(模式崩潰)
✅ 應用範圍廣(圖片、音樂、語音) ❌ 需要大量數據與運算資源
✅ 不需標籤資料(Unsupervised) ❌ 難以衡量生成效果好壞(評估困難)
________________________________________
🎯 八、總結與亮點
✔ GAN 是目前最受矚目的 AI 生成技術
✔ 不只讓機器「看」得懂,還能讓機器「創造」內容
✔ 是**生成式 AI(Generative AI)**的核心技術基礎
________________________________________
📌 一句話精華
🎨 GAN = 讓 AI 學會「以假亂真」,創作出人類也分不清的影像、音樂與語音!
________________________________________