[深度學習][Python]訓練MLP的GAN模型來生成圖片_訓練篇

更新 發佈閱讀 11 分鐘

本文主要介紹,如何利用GAN生成對抗網路來訓練生成圖片。

利用tensorflow,中的keras來建立生成器及鑑別器互相競爭訓練,最後利用訓練好的生成器來生成圖片。

GAN生成對抗網路的介紹

它由生成網路(Generator Network)和鑑別網路(Discriminator Network),二個深度類神經網路所組成,二個網路在學習過程中互相的對抗與進化,在學習過程中二個網路相互對抗調整參數,讓每一方都在不斷升級,整個學習的最終目的就是讓鑑別網路誤判生成網絡的輸出結果。

raw-image

由上述,可使用GAN來生成圖片,利用訓練好的生成網路來生成圖片,請注意,生成網路會生成什麼樣的圖片,是由訓練資料集來決定。例如說,要生成小猴子的照片,就得提供GAN一批小猴子照片當模仿參考。


資料集

採用TF影像資料庫中的fashion_mnist

https://www.tensorflow.org/datasets/catalog/fashion_mnist

https://www.tensorflow.org/datasets/catalog/fashion_mnist

在Colab上執行訓練

Python及套件版本

Python version: 3.10.12
matplotlib version: 3.7.1
numpy version: 1.25.2
tensorflow version: 2.15.0
Pandas version: 2.0.3

程式碼

1.載入套件

import numpy as np 
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras

2.載入資料集

從tensorflow.keras.datasets下載fashion_mnist資料集

from tensorflow.keras.datasets import 

(x_train_set, y_train_set), (x_test, y_test) = fashion_mnist. load_data()

3.確認資料集

Show出一張圖,確認是fashion_mnist的圖片

i = 6
print (y_train_set [i])
plt. imshow(x_train_set[i], cmap='binary')
plt.show()
raw-image

4.分割資料集

主要分出訓練跟驗證集

from sklearn.model_selection import train_test_split

x_train, x_valid, y_train, y_valid = train_test_split(x_train_set, y_train_set,
test_size=0.1, random_state=1)

5. 建立GAM模型

5.1設置了一個基本的 GAN

# 隨機種子設置​
tf.random.set_seed(1) #確保程式碼在不同次運行時能夠產生相同的隨機數據,從而提高可重現性
np.random.seed (1)
d = 30 #Codeing數量
#生成器(Generator)​
generator = keras.models.Sequential([
keras.layers.Dense (100, activation="selu", input_shape= [d]),
keras.layers.Dense (150, activation="selu"),
keras.layers.Dense (28 * 28, activation="sigmoid"),
keras.layers.Reshape ( [28, 28])
])
#判別器(Discriminator)​
discriminator = keras.models.Sequential([
keras.layers.Flatten(input_shape= [28, 28]),
keras.layers.Dense(150, activation="selu"),
keras.layers.Dense(100, activation="selu"),
keras.layers.Dense(1, activation="sigmoid")
])
# GAN 模型​
# 將生成器和判別器結合在一起,形成一個生成對抗網絡(GAN)模型​
gan = keras.models.Sequential([generator, discriminator])

5.2檢視鑑別器的模型

discriminator.summary()
raw-image

5.3檢視生成器的模型

gan.summary()
raw-image

5.4檢視結合後GAN的模型

raw-image

6.模型 Compile

設置判別器的損失函數及優化器,對整個 GAN 模型進行編譯,指定損失函數和優化器。

discriminator.compile(loss="binary_crossentropy", optimizer="rmsprop")

discriminator.trainable = False #將判別器設置為不可訓練

gan. compile(loss="binary_crossentropy", optimizer="rmsprop")

7.訓練

7.1 隨機分布訓練集

#設置了每次訓練時使用的批次大小為 32​
batch_size = 32
#使用 from_tensor_slices 方法將訓練數據 x_train 轉換為 tf.data.Dataset 對象。​
dataset = tf.data.Dataset. from_tensor_slices (x_train)#先暫存到暫存區,讀取效率較好
# 打亂數據​
dataset = dataset.shuffle(1000)#

# 預取一個批次的數據,以提高數據加載的效率​
#drop_remainder:將剩下的踢除
#prefetch(1):預先準備下一批次data
dataset = dataset. batch (batch_size, drop_remainder=True).prefetch(1)

7.2 建立訓練GAN的方法

實現 GAN 的訓練過程,通過交替訓練生成器和判別器來逐步改進生成器的性能

# 繪製生成圖像用
def plot_multiple_images (images, n=None):
#判斷最后一維度若為 1,則消去最后一維度
if images.shape [-1] == 1:
images = np. squeeze (images, axis=-1)
plt.figure(figsize=(n, 1))

for i in range(n):
plt.subplot(1, n, i + 1)
plt. imshow (images [i], cmap="binary")
plt.axis ("off")
plt.show ()

def train_gan(gan, dataset, batch_size, d, n_epochs=10):
'''
函數參數
gan:GAN 模型,由生成器和判別器組成。
dataset:訓練數據集。
batch_size:每個訓練批次的大小。
d:生成器輸入的噪聲維度。
n_epochs:訓練的總迭代次數。
'''
##拆分 GAN 模型​
generator, discriminator = gan. layers
## 迭代訓練 Epochs​
for epoch in range(n_epochs):
print("Epoch {}/{}". format(epoch + 1, n_epochs))
# 遍歷批次數據集
for x_batch in dataset:
# 第一階段 訓練判別器
# 生成噪聲並生成圖像
noise = tf. random. normal (shape= [batch_size, d])
generated_images = generator (noise)
# 合併真實和假圖像
x_fake_and_real = tf.concat ( [generated_images, x_batch], axis=0)
# 設置判別器的訓練標籤
# 為假圖像設置標籤為 0,真實圖像設置標籤為 1。
y1 = tf. constant([[0.]] * batch_size + [[1.]] * batch_size)
# 訓練判別器
discriminator.trainable = True
discriminator.train_on_batch(x_fake_and_real, y1)#不可用fit(),並進行比較

#第二階段 訓練生成器
# 生成新的噪聲
noise = tf. random. normal (shape= [batch_size, d])
# 設置生成器的訓練標籤
# 為生成器的輸出設置標籤為 1,以欺騙判別器。
y2 = tf. constant([ [1.]] * batch_size)
# 訓練生成器
discriminator.trainable = False
gan.train_on_batch(noise, y2)
## 繪製生成圖像​ 每個 epoch 結束後,顯示生成的圖像。
plot_multiple_images (generated_images, 10)#生成 10張圖

7.3 開始訓練

train_gan(gan, dataset, batch_size, d, n_epochs=10)
raw-image
raw-image

儲存模型

generator.save('generator.h5')











留言
avatar-img
螃蟹_crab的沙龍
159會員
311內容數
本業是影像辨識軟體開發,閒暇時間進修AI相關內容,將學習到的內容寫成文章分享。 興趣是攝影,踏青,探索未知領域。 人生就是不斷的挑戰及自我認清,希望老了躺在床上不會後悔自己什麼都沒做。
螃蟹_crab的沙龍的其他內容
2024/07/27
呈上篇介紹如何訓練模型,此篇就主要介紹如何利用訓練好的模型來生成圖片 [深度學習][Python]DCGAN訓練生成手寫阿拉伯數字_生成篇 生成的結果 生成的圖片大小會根據,當初設置的生成器輸出大小來決定,當你使用生成對抗網絡(GAN)生成圖像時,生成器模型的最後一層通常會決定生成圖
Thumbnail
2024/07/27
呈上篇介紹如何訓練模型,此篇就主要介紹如何利用訓練好的模型來生成圖片 [深度學習][Python]DCGAN訓練生成手寫阿拉伯數字_生成篇 生成的結果 生成的圖片大小會根據,當初設置的生成器輸出大小來決定,當你使用生成對抗網絡(GAN)生成圖像時,生成器模型的最後一層通常會決定生成圖
Thumbnail
2024/07/27
本文參考TensorFlow官網Deep Convolutional Generative Adversarial Network的程式碼來加以實作說明。 示範如何使用深度卷積生成對抗網路(DCGAN) 生成手寫數位影像。
Thumbnail
2024/07/27
本文參考TensorFlow官網Deep Convolutional Generative Adversarial Network的程式碼來加以實作說明。 示範如何使用深度卷積生成對抗網路(DCGAN) 生成手寫數位影像。
Thumbnail
2024/07/26
本文將延續上一篇文章,經由訓練好的GAN模型中的生成器來生成圖片 [深度學習][Python]訓練MLP的GAN模型來生成圖片_訓練篇 [深度學習][Python]訓練CNN的GAN模型來生成圖片_訓練篇 相較之下CNN的GAN生成的效果比較好,但模型也相對比較複雜,訓練時間花的也比較
Thumbnail
2024/07/26
本文將延續上一篇文章,經由訓練好的GAN模型中的生成器來生成圖片 [深度學習][Python]訓練MLP的GAN模型來生成圖片_訓練篇 [深度學習][Python]訓練CNN的GAN模型來生成圖片_訓練篇 相較之下CNN的GAN生成的效果比較好,但模型也相對比較複雜,訓練時間花的也比較
Thumbnail
看更多
你可能也想看
Thumbnail
在 vocus 與你一起探索內容、發掘靈感的路上,我們又將啟動新的冒險——vocus App 正式推出! 現在起,你可以在 iOS App Store 下載全新上架的 vocus App。 無論是在通勤路上、日常空檔,或一天結束後的放鬆時刻,都能自在沈浸在內容宇宙中。
Thumbnail
在 vocus 與你一起探索內容、發掘靈感的路上,我們又將啟動新的冒險——vocus App 正式推出! 現在起,你可以在 iOS App Store 下載全新上架的 vocus App。 無論是在通勤路上、日常空檔,或一天結束後的放鬆時刻,都能自在沈浸在內容宇宙中。
Thumbnail
vocus 慶祝推出 App,舉辦 2026 全站慶。推出精選內容與數位商品折扣,訂單免費與紅包抽獎、新註冊會員專屬活動、Boba Boost 贊助抽紅包,以及全站徵文,並邀請你一起來回顧過去的一年, vocus 與創作者共同留下了哪些精彩創作。
Thumbnail
vocus 慶祝推出 App,舉辦 2026 全站慶。推出精選內容與數位商品折扣,訂單免費與紅包抽獎、新註冊會員專屬活動、Boba Boost 贊助抽紅包,以及全站徵文,並邀請你一起來回顧過去的一年, vocus 與創作者共同留下了哪些精彩創作。
Thumbnail
我想要一天分享一點「LLM從底層堆疊的技術」,並且每篇文章長度控制在三分鐘以內,讓大家不會壓力太大,但是又能夠每天成長一點。 回顧 AI說書 - 從0開始 - 129 中說,Bidirectional Encoder Representations from Transformers (BER
Thumbnail
我想要一天分享一點「LLM從底層堆疊的技術」,並且每篇文章長度控制在三分鐘以內,讓大家不會壓力太大,但是又能夠每天成長一點。 回顧 AI說書 - 從0開始 - 129 中說,Bidirectional Encoder Representations from Transformers (BER
Thumbnail
我想要一天分享一點「LLM從底層堆疊的技術」,並且每篇文章長度控制在三分鐘以內,讓大家不會壓力太大,但是又能夠每天成長一點。 我們已經在 AI說書 - 從0開始 - 114 建立了 Transformer 模型。 現在我們來載入預訓練權重,預訓練的權重包含 Transformer 的智慧
Thumbnail
我想要一天分享一點「LLM從底層堆疊的技術」,並且每篇文章長度控制在三分鐘以內,讓大家不會壓力太大,但是又能夠每天成長一點。 我們已經在 AI說書 - 從0開始 - 114 建立了 Transformer 模型。 現在我們來載入預訓練權重,預訓練的權重包含 Transformer 的智慧
Thumbnail
本篇文章介紹如何使用PyTorch構建和訓練圖神經網絡(GNN),並使用Cora資料集進行節點分類任務。通過模型架構的逐步優化,包括引入批量標準化和獨立的消息傳遞層,調整Dropout和聚合函數,顯著提高了模型的分類準確率。實驗結果表明,經過優化的GNN模型在處理圖結構數據具有強大的性能和應用潛力。
Thumbnail
本篇文章介紹如何使用PyTorch構建和訓練圖神經網絡(GNN),並使用Cora資料集進行節點分類任務。通過模型架構的逐步優化,包括引入批量標準化和獨立的消息傳遞層,調整Dropout和聚合函數,顯著提高了模型的分類準確率。實驗結果表明,經過優化的GNN模型在處理圖結構數據具有強大的性能和應用潛力。
Thumbnail
本文參考TensorFlow官網Deep Convolutional Generative Adversarial Network的程式碼來加以實作說明。 示範如何使用深度卷積生成對抗網路(DCGAN) 生成手寫數位影像。
Thumbnail
本文參考TensorFlow官網Deep Convolutional Generative Adversarial Network的程式碼來加以實作說明。 示範如何使用深度卷積生成對抗網路(DCGAN) 生成手寫數位影像。
Thumbnail
延續上一篇訓練GAM模型,這次我們讓神經網路更多層更複雜一點,來看訓練生成的圖片是否效果會更好。 [深度學習][Python]訓練MLP的GAN模型來生成圖片_訓練篇 資料集分割處理的部分在延續上篇文章,從第五點開始後修改即可,前面都一樣 訓練過程,比較圖 是不是CNN的效果比MLP還要好,
Thumbnail
延續上一篇訓練GAM模型,這次我們讓神經網路更多層更複雜一點,來看訓練生成的圖片是否效果會更好。 [深度學習][Python]訓練MLP的GAN模型來生成圖片_訓練篇 資料集分割處理的部分在延續上篇文章,從第五點開始後修改即可,前面都一樣 訓練過程,比較圖 是不是CNN的效果比MLP還要好,
Thumbnail
本文主要介紹,如何利用GAN生成對抗網路來訓練生成圖片。 利用tensorflow,中的keras來建立生成器及鑑別器互相競爭訓練,最後利用訓練好的生成器來生成圖片。 GAN生成對抗網路的介紹 它由生成網路(Generator Network)和鑑別網路(Discriminator Netwo
Thumbnail
本文主要介紹,如何利用GAN生成對抗網路來訓練生成圖片。 利用tensorflow,中的keras來建立生成器及鑑別器互相競爭訓練,最後利用訓練好的生成器來生成圖片。 GAN生成對抗網路的介紹 它由生成網路(Generator Network)和鑑別網路(Discriminator Netwo
Thumbnail
透過這篇文章,我們將瞭解如何使用PyTorch實作圖神經網絡中的訊息傳遞機制,從定義消息傳遞的類別到實作消息傳遞過程。我們也探討了各種不同的消息傳遞機制,並通過對單次和多次傳遞過程的結果,可以看到節點特徵如何逐步傳遞與更新。
Thumbnail
透過這篇文章,我們將瞭解如何使用PyTorch實作圖神經網絡中的訊息傳遞機制,從定義消息傳遞的類別到實作消息傳遞過程。我們也探討了各種不同的消息傳遞機制,並通過對單次和多次傳遞過程的結果,可以看到節點特徵如何逐步傳遞與更新。
Thumbnail
我想要一天分享一點「LLM從底層堆疊的技術」,並且每篇文章長度控制在三分鐘以內,讓大家不會壓力太大,但是又能夠每天成長一點。 Transformer 可以透過繼承預訓練模型 (Pretrained Model) 來微調 (Fine-Tune) 以執行下游任務。 Pretrained Mo
Thumbnail
我想要一天分享一點「LLM從底層堆疊的技術」,並且每篇文章長度控制在三分鐘以內,讓大家不會壓力太大,但是又能夠每天成長一點。 Transformer 可以透過繼承預訓練模型 (Pretrained Model) 來微調 (Fine-Tune) 以執行下游任務。 Pretrained Mo
Thumbnail
本文主要介紹神經網路訓練辨識的過程,利用fashion_mnist及簡單的神經網路來進行分類。 使用只有兩層的神經網路來訓練辨識fashion_mnist資料。
Thumbnail
本文主要介紹神經網路訓練辨識的過程,利用fashion_mnist及簡單的神經網路來進行分類。 使用只有兩層的神經網路來訓練辨識fashion_mnist資料。
Thumbnail
這篇文章探討了生成式對抗網路中機率分佈的使用與相關的訓練方式,包括Generator不同的點、Distriminator的訓練過程、生成圖片的條件設定等。此外,也提到了GAN訓練的困難與解決方式以及不同的learning方式。文章內容豐富且詳細,涵蓋了GAN的各個相關面向。
Thumbnail
這篇文章探討了生成式對抗網路中機率分佈的使用與相關的訓練方式,包括Generator不同的點、Distriminator的訓練過程、生成圖片的條件設定等。此外,也提到了GAN訓練的困難與解決方式以及不同的learning方式。文章內容豐富且詳細,涵蓋了GAN的各個相關面向。
追蹤感興趣的內容從 Google News 追蹤更多 vocus 的最新精選內容追蹤 Google News