[深度學習][Python]訓練MLP的GAN模型來生成圖片_訓練篇

更新於 2024/07/27閱讀時間約 11 分鐘

本文主要介紹,如何利用GAN生成對抗網路來訓練生成圖片。

利用tensorflow,中的keras來建立生成器及鑑別器互相競爭訓練,最後利用訓練好的生成器來生成圖片。

GAN生成對抗網路的介紹

它由生成網路(Generator Network)和鑑別網路(Discriminator Network),二個深度類神經網路所組成,二個網路在學習過程中互相的對抗與進化,在學習過程中二個網路相互對抗調整參數,讓每一方都在不斷升級,整個學習的最終目的就是讓鑑別網路誤判生成網絡的輸出結果。

raw-image

由上述,可使用GAN來生成圖片,利用訓練好的生成網路來生成圖片,請注意,生成網路會生成什麼樣的圖片,是由訓練資料集來決定。例如說,要生成小猴子的照片,就得提供GAN一批小猴子照片當模仿參考。


資料集

採用TF影像資料庫中的fashion_mnist

https://www.tensorflow.org/datasets/catalog/fashion_mnist

https://www.tensorflow.org/datasets/catalog/fashion_mnist

在Colab上執行訓練

Python及套件版本

Python version: 3.10.12
matplotlib version: 3.7.1
numpy version: 1.25.2
tensorflow version: 2.15.0
Pandas version: 2.0.3

程式碼

1.載入套件

import numpy as np 
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras

2.載入資料集

從tensorflow.keras.datasets下載fashion_mnist資料集

from tensorflow.keras.datasets import 

(x_train_set, y_train_set), (x_test, y_test) = fashion_mnist. load_data()

3.確認資料集

Show出一張圖,確認是fashion_mnist的圖片

i = 6
print (y_train_set [i])
plt. imshow(x_train_set[i], cmap='binary')
plt.show()
raw-image

4.分割資料集

主要分出訓練跟驗證集

from sklearn.model_selection import train_test_split

x_train, x_valid, y_train, y_valid = train_test_split(x_train_set, y_train_set,
test_size=0.1, random_state=1)

5. 建立GAM模型

5.1設置了一個基本的 GAN

# 隨機種子設置​
tf.random.set_seed(1) #確保程式碼在不同次運行時能夠產生相同的隨機數據,從而提高可重現性
np.random.seed (1)
d = 30 #Codeing數量
#生成器(Generator)​
generator = keras.models.Sequential([
keras.layers.Dense (100, activation="selu", input_shape= [d]),
keras.layers.Dense (150, activation="selu"),
keras.layers.Dense (28 * 28, activation="sigmoid"),
keras.layers.Reshape ( [28, 28])
])
#判別器(Discriminator)​
discriminator = keras.models.Sequential([
keras.layers.Flatten(input_shape= [28, 28]),
keras.layers.Dense(150, activation="selu"),
keras.layers.Dense(100, activation="selu"),
keras.layers.Dense(1, activation="sigmoid")
])
# GAN 模型​
# 將生成器和判別器結合在一起,形成一個生成對抗網絡(GAN)模型​
gan = keras.models.Sequential([generator, discriminator])

5.2檢視鑑別器的模型

discriminator.summary()
raw-image

5.3檢視生成器的模型

gan.summary()
raw-image

5.4檢視結合後GAN的模型

raw-image

6.模型 Compile

設置判別器的損失函數及優化器,對整個 GAN 模型進行編譯,指定損失函數和優化器。

discriminator.compile(loss="binary_crossentropy", optimizer="rmsprop")

discriminator.trainable = False #將判別器設置為不可訓練

gan. compile(loss="binary_crossentropy", optimizer="rmsprop")

7.訓練

7.1 隨機分布訓練集

#設置了每次訓練時使用的批次大小為 32​
batch_size = 32
#使用 from_tensor_slices 方法將訓練數據 x_train 轉換為 tf.data.Dataset 對象。​
dataset = tf.data.Dataset. from_tensor_slices (x_train)#先暫存到暫存區,讀取效率較好
# 打亂數據​
dataset = dataset.shuffle(1000)#

# 預取一個批次的數據,以提高數據加載的效率​
#drop_remainder:將剩下的踢除
#prefetch(1):預先準備下一批次data
dataset = dataset. batch (batch_size, drop_remainder=True).prefetch(1)

7.2 建立訓練GAN的方法

實現 GAN 的訓練過程,通過交替訓練生成器和判別器來逐步改進生成器的性能

# 繪製生成圖像用
def plot_multiple_images (images, n=None):
#判斷最后一維度若為 1,則消去最后一維度
if images.shape [-1] == 1:
images = np. squeeze (images, axis=-1)
plt.figure(figsize=(n, 1))

for i in range(n):
plt.subplot(1, n, i + 1)
plt. imshow (images [i], cmap="binary")
plt.axis ("off")
plt.show ()

def train_gan(gan, dataset, batch_size, d, n_epochs=10):
'''
函數參數
gan:GAN 模型,由生成器和判別器組成。
dataset:訓練數據集。
batch_size:每個訓練批次的大小。
d:生成器輸入的噪聲維度。
n_epochs:訓練的總迭代次數。
'''
##拆分 GAN 模型​
generator, discriminator = gan. layers
## 迭代訓練 Epochs​
for epoch in range(n_epochs):
print("Epoch {}/{}". format(epoch + 1, n_epochs))
# 遍歷批次數據集
for x_batch in dataset:
# 第一階段 訓練判別器
# 生成噪聲並生成圖像
noise = tf. random. normal (shape= [batch_size, d])
generated_images = generator (noise)
# 合併真實和假圖像
x_fake_and_real = tf.concat ( [generated_images, x_batch], axis=0)
# 設置判別器的訓練標籤
# 為假圖像設置標籤為 0,真實圖像設置標籤為 1。
y1 = tf. constant([[0.]] * batch_size + [[1.]] * batch_size)
# 訓練判別器
discriminator.trainable = True
discriminator.train_on_batch(x_fake_and_real, y1)#不可用fit(),並進行比較

#第二階段 訓練生成器
# 生成新的噪聲
noise = tf. random. normal (shape= [batch_size, d])
# 設置生成器的訓練標籤
# 為生成器的輸出設置標籤為 1,以欺騙判別器。
y2 = tf. constant([ [1.]] * batch_size)
# 訓練生成器
discriminator.trainable = False
gan.train_on_batch(noise, y2)
## 繪製生成圖像​ 每個 epoch 結束後,顯示生成的圖像。
plot_multiple_images (generated_images, 10)#生成 10張圖

7.3 開始訓練

train_gan(gan, dataset, batch_size, d, n_epochs=10)
raw-image
raw-image

儲存模型

generator.save('generator.h5')











avatar-img
128會員
209內容數
本業是影像辨識軟體開發,閒暇時間進修AI相關內容,將學習到的內容寫成文章分享。
留言0
查看全部
avatar-img
發表第一個留言支持創作者!
螃蟹_crab的沙龍 的其他內容
本文將延續上一篇文章,經由訓練好的VAE模型其中的解碼器,來生成圖片。 [深度學習]訓練VAE模型用於生成圖片_訓練篇 輸入產生的隨機雜訊,輸入VAE的解碼器後,生成的圖片
本文主要介紹,如何利用VAE變分自編碼器來訓練生成圖片。 訓練集資料將採用TF影像資料庫中的fashion_mnist VAE變分自編碼器簡單介紹 •VAE(Variational Auto-Encoder)中文名稱變分自編碼器,主要是一種將原始資料編碼到潛在向量空間,再編碼回來的神經網路。
本文下方連結的文章,利用Stable Diffusion生成512 * 512大小的圖片。 輸入的文字是 dog flying in space,此模型需輸入英文句子才會準確生成。 參考文獻 連結該作者在Hugging Face公開的模型去做使用。 本文是在Colab上執行。
長短期記憶(英語:Long Short-Term Memory,LSTM)是一種時間循環神經網路(RNN),論文首次發表於1997年。 LSTM(長短期記憶)是一種特定類型的遞歸神經網絡(RNN),在許多需要處理時間序列數據或順序數據的應用中非常有用。 以下是一些常見的 LSTM 應用:
先前上一篇是使用NLT內置的電影評論數據集 movie_reviews,來訓練出情感分析模型,此篇文章介紹可以導入自己的訓練資料集來建立情感分析模組。 [Python][自然語言]NLTK 實現電影評論情感分析 所需套件 pip install pandas pip install sci
情感分析是一種自然語言處理技術,用於自動識別和分析文本中的情感傾向,通常是正向、負向或中性。 我們可以使用 NLTK 來實現一個基於單純貝斯分類器的情感分析模型。
本文將延續上一篇文章,經由訓練好的VAE模型其中的解碼器,來生成圖片。 [深度學習]訓練VAE模型用於生成圖片_訓練篇 輸入產生的隨機雜訊,輸入VAE的解碼器後,生成的圖片
本文主要介紹,如何利用VAE變分自編碼器來訓練生成圖片。 訓練集資料將採用TF影像資料庫中的fashion_mnist VAE變分自編碼器簡單介紹 •VAE(Variational Auto-Encoder)中文名稱變分自編碼器,主要是一種將原始資料編碼到潛在向量空間,再編碼回來的神經網路。
本文下方連結的文章,利用Stable Diffusion生成512 * 512大小的圖片。 輸入的文字是 dog flying in space,此模型需輸入英文句子才會準確生成。 參考文獻 連結該作者在Hugging Face公開的模型去做使用。 本文是在Colab上執行。
長短期記憶(英語:Long Short-Term Memory,LSTM)是一種時間循環神經網路(RNN),論文首次發表於1997年。 LSTM(長短期記憶)是一種特定類型的遞歸神經網絡(RNN),在許多需要處理時間序列數據或順序數據的應用中非常有用。 以下是一些常見的 LSTM 應用:
先前上一篇是使用NLT內置的電影評論數據集 movie_reviews,來訓練出情感分析模型,此篇文章介紹可以導入自己的訓練資料集來建立情感分析模組。 [Python][自然語言]NLTK 實現電影評論情感分析 所需套件 pip install pandas pip install sci
情感分析是一種自然語言處理技術,用於自動識別和分析文本中的情感傾向,通常是正向、負向或中性。 我們可以使用 NLTK 來實現一個基於單純貝斯分類器的情感分析模型。
你可能也想看
Google News 追蹤
Thumbnail
徵的就是你 🫵 超ㄅㄧㄤˋ 獎品搭配超瞎趴的四大主題,等你踹共啦!還有機會獲得經典的「偉士牌樂高」喔!馬上來參加本次的活動吧!
Thumbnail
隨著理財資訊的普及,越來越多台灣人不再將資產侷限於台股,而是將視野拓展到國際市場。特別是美國市場,其豐富的理財選擇,讓不少人開始思考將資金配置於海外市場的可能性。 然而,要參與美國市場並不只是盲目跟隨標的這麼簡單,而是需要策略和方式,尤其對新手而言,除了選股以外還會遇到語言、開戶流程、Ap
Thumbnail
在現今少子化的時代,提升學習效率至關重要。卡爾·紐波特的書《DEEP WORK深度工作力》提供了有效的時間管理和學習策略,能夠幫助我們在競爭激烈的社會中脫穎而出。書中介紹的學習方法和策略,不僅適用於大學生,也可應用在日常生活中,幫助我們擁有良好的學習力,增進生活效率。
Thumbnail
本文介紹了self-attention在處理不固定大小輸入值時的應用,並討論瞭如何計算self-attention以及transformer中的multi-head self-attention。此外,文章還探討了在語音辨識和圖片處理中使用self-attention的方法,以及與CNN的比較。
Thumbnail
這本書訪談了大學學生,並且歸納出幾點建議,書中也提到不必每條條都嚴格遵守,而是選擇一組吸引你的規則,並在大學生活中履行。 我自己在看這本書的時候,結合自己的大學經歷,選取幾點我比較有感觸的部分,分為以下幾點,後面則會提到一些關於書中內容反思
Thumbnail
透過麗鳳督導在心理諮商上的應用,能夠讓我們看待個案問題時有了全新的視角。學理論要浸泡到自動化思考,分析個案時需要考慮家庭結構、互動關係和人際界線等重要元素。此外,心理諮商師需用關係去理解表徵問題,並運用大量的探問與對話,從而從症狀到系統的探索。
Thumbnail
不是只是硬記硬背,而是要用對方法。學習,不分年紀,不分時候,我們隨時都在學習,但有良好的學習技能,像故事/小說書中,電影裡那些擁有超能力的人一樣,可以在自己想學的技能中,一眼就記住,過目不忘的技能,如果擁有或許也是一件不錯的事,但切換到現實,我們認真學習,雖然也能記住,但所要花費的時間成本...
深度學習是機器學習的一個分支,它使用多層神經網絡來模擬和解決複雜的問題。有許多不同的深度學習框架可供選擇,這些框架提供了用於訓練神經網絡的工具和函數。以下是一些常用的深度學習框架的簡介: TensorFlow: TensorFlow由Google開發,是最流行的深度學習框架之一。它具有靈活的計算
Thumbnail
如何與錯誤打交道,就是對於自身的錯誤的察覺,又或者是對於所學的知識正確性如何思辨。 大家好,今天我們來談談「第二層思考」。這是一個相當重要的概念,尤其在現代社會中,我們需要面對各種各樣的資訊和知識,但有時候這些資訊和知識並不是那麼正確。所以,我們必須學會用第二層思考去判斷和分析這些資訊和知識。 首先
Thumbnail
「品酒」已經不再是有錢人的權利,在這個美酒當道的年代,我們要如何像 Somm 電影的品酒師,一口就能辨別出「口感」、「年份」、「產地」,甚至預測下一季爆款的酒呢? 情境: 這時候,機器學習與深度學習都是相當好的辦法,但我們要成為好的品酒工程師之前,我們必須學會理解「數據來源」、「產業知識」、「演算法
Thumbnail
師範大學的 陳佩英 教授來訪均一! 讓我們有機會向教授請益有關個人化學習的前瞻發展可能性。 教授很親切給予我們許多建言與引導,聽完教授的回饋,有三個小心得: 真的覺得自己懂得不過廣泛,也不夠深啊! 2. 也很喜歡教授提醒我們要注意工具背後的教育理念。 a. 特別是對非認知能力的評量,不能用行為主義來
Thumbnail
是什麼讓一個人的成長速度比另一個人更快呢? 《深度學習的技術》的作者楊大輝的答案是:學習的深淺。
Thumbnail
徵的就是你 🫵 超ㄅㄧㄤˋ 獎品搭配超瞎趴的四大主題,等你踹共啦!還有機會獲得經典的「偉士牌樂高」喔!馬上來參加本次的活動吧!
Thumbnail
隨著理財資訊的普及,越來越多台灣人不再將資產侷限於台股,而是將視野拓展到國際市場。特別是美國市場,其豐富的理財選擇,讓不少人開始思考將資金配置於海外市場的可能性。 然而,要參與美國市場並不只是盲目跟隨標的這麼簡單,而是需要策略和方式,尤其對新手而言,除了選股以外還會遇到語言、開戶流程、Ap
Thumbnail
在現今少子化的時代,提升學習效率至關重要。卡爾·紐波特的書《DEEP WORK深度工作力》提供了有效的時間管理和學習策略,能夠幫助我們在競爭激烈的社會中脫穎而出。書中介紹的學習方法和策略,不僅適用於大學生,也可應用在日常生活中,幫助我們擁有良好的學習力,增進生活效率。
Thumbnail
本文介紹了self-attention在處理不固定大小輸入值時的應用,並討論瞭如何計算self-attention以及transformer中的multi-head self-attention。此外,文章還探討了在語音辨識和圖片處理中使用self-attention的方法,以及與CNN的比較。
Thumbnail
這本書訪談了大學學生,並且歸納出幾點建議,書中也提到不必每條條都嚴格遵守,而是選擇一組吸引你的規則,並在大學生活中履行。 我自己在看這本書的時候,結合自己的大學經歷,選取幾點我比較有感觸的部分,分為以下幾點,後面則會提到一些關於書中內容反思
Thumbnail
透過麗鳳督導在心理諮商上的應用,能夠讓我們看待個案問題時有了全新的視角。學理論要浸泡到自動化思考,分析個案時需要考慮家庭結構、互動關係和人際界線等重要元素。此外,心理諮商師需用關係去理解表徵問題,並運用大量的探問與對話,從而從症狀到系統的探索。
Thumbnail
不是只是硬記硬背,而是要用對方法。學習,不分年紀,不分時候,我們隨時都在學習,但有良好的學習技能,像故事/小說書中,電影裡那些擁有超能力的人一樣,可以在自己想學的技能中,一眼就記住,過目不忘的技能,如果擁有或許也是一件不錯的事,但切換到現實,我們認真學習,雖然也能記住,但所要花費的時間成本...
深度學習是機器學習的一個分支,它使用多層神經網絡來模擬和解決複雜的問題。有許多不同的深度學習框架可供選擇,這些框架提供了用於訓練神經網絡的工具和函數。以下是一些常用的深度學習框架的簡介: TensorFlow: TensorFlow由Google開發,是最流行的深度學習框架之一。它具有靈活的計算
Thumbnail
如何與錯誤打交道,就是對於自身的錯誤的察覺,又或者是對於所學的知識正確性如何思辨。 大家好,今天我們來談談「第二層思考」。這是一個相當重要的概念,尤其在現代社會中,我們需要面對各種各樣的資訊和知識,但有時候這些資訊和知識並不是那麼正確。所以,我們必須學會用第二層思考去判斷和分析這些資訊和知識。 首先
Thumbnail
「品酒」已經不再是有錢人的權利,在這個美酒當道的年代,我們要如何像 Somm 電影的品酒師,一口就能辨別出「口感」、「年份」、「產地」,甚至預測下一季爆款的酒呢? 情境: 這時候,機器學習與深度學習都是相當好的辦法,但我們要成為好的品酒工程師之前,我們必須學會理解「數據來源」、「產業知識」、「演算法
Thumbnail
師範大學的 陳佩英 教授來訪均一! 讓我們有機會向教授請益有關個人化學習的前瞻發展可能性。 教授很親切給予我們許多建言與引導,聽完教授的回饋,有三個小心得: 真的覺得自己懂得不過廣泛,也不夠深啊! 2. 也很喜歡教授提醒我們要注意工具背後的教育理念。 a. 特別是對非認知能力的評量,不能用行為主義來
Thumbnail
是什麼讓一個人的成長速度比另一個人更快呢? 《深度學習的技術》的作者楊大輝的答案是:學習的深淺。