2024-07-25|閱讀時間 ‧ 約 35 分鐘

[深度學習][Python]訓練MLP的GAN模型來生成圖片_訓練篇

本文主要介紹,如何利用GAN生成對抗網路來訓練生成圖片。

利用tensorflow,中的keras來建立生成器及鑑別器互相競爭訓練,最後利用訓練好的生成器來生成圖片。

GAN生成對抗網路的介紹

它由生成網路(Generator Network)和鑑別網路(Discriminator Network),二個深度類神經網路所組成,二個網路在學習過程中互相的對抗與進化,在學習過程中二個網路相互對抗調整參數,讓每一方都在不斷升級,整個學習的最終目的就是讓鑑別網路誤判生成網絡的輸出結果。

由上述,可使用GAN來生成圖片,利用訓練好的生成網路來生成圖片,請注意,生成網路會生成什麼樣的圖片,是由訓練資料集來決定。例如說,要生成小猴子的照片,就得提供GAN一批小猴子照片當模仿參考。


資料集

採用TF影像資料庫中的fashion_mnist

https://www.tensorflow.org/datasets/catalog/fashion_mnist

在Colab上執行訓練

Python及套件版本

Python version: 3.10.12
matplotlib version: 3.7.1
numpy version: 1.25.2
tensorflow version: 2.15.0
Pandas version: 2.0.3

程式碼

1.載入套件

import numpy as np 
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras

2.載入資料集

從tensorflow.keras.datasets下載fashion_mnist資料集

from tensorflow.keras.datasets import 

(x_train_set, y_train_set), (x_test, y_test) = fashion_mnist. load_data()

3.確認資料集

Show出一張圖,確認是fashion_mnist的圖片

i = 6
print (y_train_set [i])
plt. imshow(x_train_set[i], cmap='binary')
plt.show()

4.分割資料集

主要分出訓練跟驗證集

from sklearn.model_selection import train_test_split

x_train, x_valid, y_train, y_valid = train_test_split(x_train_set, y_train_set,
test_size=0.1, random_state=1)

5. 建立GAM模型

5.1設置了一個基本的 GAN

# 隨機種子設置​
tf.random.set_seed(1) #確保程式碼在不同次運行時能夠產生相同的隨機數據,從而提高可重現性
np.random.seed (1)
d = 30 #Codeing數量
#生成器(Generator)​
generator = keras.models.Sequential([
keras.layers.Dense (100, activation="selu", input_shape= [d]),
keras.layers.Dense (150, activation="selu"),
keras.layers.Dense (28 * 28, activation="sigmoid"),
keras.layers.Reshape ( [28, 28])
])
#判別器(Discriminator)​
discriminator = keras.models.Sequential([
keras.layers.Flatten(input_shape= [28, 28]),
keras.layers.Dense(150, activation="selu"),
keras.layers.Dense(100, activation="selu"),
keras.layers.Dense(1, activation="sigmoid")
])
# GAN 模型​
# 將生成器和判別器結合在一起,形成一個生成對抗網絡(GAN)模型​
gan = keras.models.Sequential([generator, discriminator])

5.2檢視鑑別器的模型

discriminator.summary()

5.3檢視生成器的模型

gan.summary()

5.4檢視結合後GAN的模型

6.模型 Compile

設置判別器的損失函數及優化器,對整個 GAN 模型進行編譯,指定損失函數和優化器。

discriminator.compile(loss="binary_crossentropy", optimizer="rmsprop")

discriminator.trainable = False #將判別器設置為不可訓練

gan. compile(loss="binary_crossentropy", optimizer="rmsprop")

7.訓練

7.1 隨機分布訓練集

#設置了每次訓練時使用的批次大小為 32​
batch_size = 32
#使用 from_tensor_slices 方法將訓練數據 x_train 轉換為 tf.data.Dataset 對象。​
dataset = tf.data.Dataset. from_tensor_slices (x_train)#先暫存到暫存區,讀取效率較好
# 打亂數據​
dataset = dataset.shuffle(1000)#

# 預取一個批次的數據,以提高數據加載的效率​
#drop_remainder:將剩下的踢除
#prefetch(1):預先準備下一批次data
dataset = dataset. batch (batch_size, drop_remainder=True).prefetch(1)

7.2 建立訓練GAN的方法

實現 GAN 的訓練過程,通過交替訓練生成器和判別器來逐步改進生成器的性能

# 繪製生成圖像用
def plot_multiple_images (images, n=None):
#判斷最后一維度若為 1,則消去最后一維度
if images.shape [-1] == 1:
images = np. squeeze (images, axis=-1)
plt.figure(figsize=(n, 1))

for i in range(n):
plt.subplot(1, n, i + 1)
plt. imshow (images [i], cmap="binary")
plt.axis ("off")
plt.show ()

def train_gan(gan, dataset, batch_size, d, n_epochs=10):
'''
函數參數
gan:GAN 模型,由生成器和判別器組成。
dataset:訓練數據集。
batch_size:每個訓練批次的大小。
d:生成器輸入的噪聲維度。
n_epochs:訓練的總迭代次數。
'''
##拆分 GAN 模型​
generator, discriminator = gan. layers
## 迭代訓練 Epochs​
for epoch in range(n_epochs):
print("Epoch {}/{}". format(epoch + 1, n_epochs))
# 遍歷批次數據集
for x_batch in dataset:
# 第一階段 訓練判別器
# 生成噪聲並生成圖像
noise = tf. random. normal (shape= [batch_size, d])
generated_images = generator (noise)
# 合併真實和假圖像
x_fake_and_real = tf.concat ( [generated_images, x_batch], axis=0)
# 設置判別器的訓練標籤
# 為假圖像設置標籤為 0,真實圖像設置標籤為 1。
y1 = tf. constant([[0.]] * batch_size + [[1.]] * batch_size)
# 訓練判別器
discriminator.trainable = True
discriminator.train_on_batch(x_fake_and_real, y1)#不可用fit(),並進行比較

#第二階段 訓練生成器
# 生成新的噪聲
noise = tf. random. normal (shape= [batch_size, d])
# 設置生成器的訓練標籤
# 為生成器的輸出設置標籤為 1,以欺騙判別器。
y2 = tf. constant([ [1.]] * batch_size)
# 訓練生成器
discriminator.trainable = False
gan.train_on_batch(noise, y2)
## 繪製生成圖像​ 每個 epoch 結束後,顯示生成的圖像。
plot_multiple_images (generated_images, 10)#生成 10張圖

7.3 開始訓練

train_gan(gan, dataset, batch_size, d, n_epochs=10)

儲存模型

generator.save('generator.h5')











分享至
成為作者繼續創作的動力吧!
© 2024 vocus All rights reserved.