[深度學習][Python]訓練MLP的GAN模型來生成圖片_訓練篇

更新於 發佈於 閱讀時間約 11 分鐘

本文主要介紹,如何利用GAN生成對抗網路來訓練生成圖片。

利用tensorflow,中的keras來建立生成器及鑑別器互相競爭訓練,最後利用訓練好的生成器來生成圖片。

GAN生成對抗網路的介紹

它由生成網路(Generator Network)和鑑別網路(Discriminator Network),二個深度類神經網路所組成,二個網路在學習過程中互相的對抗與進化,在學習過程中二個網路相互對抗調整參數,讓每一方都在不斷升級,整個學習的最終目的就是讓鑑別網路誤判生成網絡的輸出結果。

raw-image

由上述,可使用GAN來生成圖片,利用訓練好的生成網路來生成圖片,請注意,生成網路會生成什麼樣的圖片,是由訓練資料集來決定。例如說,要生成小猴子的照片,就得提供GAN一批小猴子照片當模仿參考。


資料集

採用TF影像資料庫中的fashion_mnist

https://www.tensorflow.org/datasets/catalog/fashion_mnist

https://www.tensorflow.org/datasets/catalog/fashion_mnist

在Colab上執行訓練

Python及套件版本

Python version: 3.10.12
matplotlib version: 3.7.1
numpy version: 1.25.2
tensorflow version: 2.15.0
Pandas version: 2.0.3

程式碼

1.載入套件

import numpy as np 
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras

2.載入資料集

從tensorflow.keras.datasets下載fashion_mnist資料集

from tensorflow.keras.datasets import 

(x_train_set, y_train_set), (x_test, y_test) = fashion_mnist. load_data()

3.確認資料集

Show出一張圖,確認是fashion_mnist的圖片

i = 6
print (y_train_set [i])
plt. imshow(x_train_set[i], cmap='binary')
plt.show()
raw-image

4.分割資料集

主要分出訓練跟驗證集

from sklearn.model_selection import train_test_split

x_train, x_valid, y_train, y_valid = train_test_split(x_train_set, y_train_set,
test_size=0.1, random_state=1)

5. 建立GAM模型

5.1設置了一個基本的 GAN

# 隨機種子設置​
tf.random.set_seed(1) #確保程式碼在不同次運行時能夠產生相同的隨機數據,從而提高可重現性
np.random.seed (1)
d = 30 #Codeing數量
#生成器(Generator)​
generator = keras.models.Sequential([
keras.layers.Dense (100, activation="selu", input_shape= [d]),
keras.layers.Dense (150, activation="selu"),
keras.layers.Dense (28 * 28, activation="sigmoid"),
keras.layers.Reshape ( [28, 28])
])
#判別器(Discriminator)​
discriminator = keras.models.Sequential([
keras.layers.Flatten(input_shape= [28, 28]),
keras.layers.Dense(150, activation="selu"),
keras.layers.Dense(100, activation="selu"),
keras.layers.Dense(1, activation="sigmoid")
])
# GAN 模型​
# 將生成器和判別器結合在一起,形成一個生成對抗網絡(GAN)模型​
gan = keras.models.Sequential([generator, discriminator])

5.2檢視鑑別器的模型

discriminator.summary()
raw-image

5.3檢視生成器的模型

gan.summary()
raw-image

5.4檢視結合後GAN的模型

raw-image

6.模型 Compile

設置判別器的損失函數及優化器,對整個 GAN 模型進行編譯,指定損失函數和優化器。

discriminator.compile(loss="binary_crossentropy", optimizer="rmsprop")

discriminator.trainable = False #將判別器設置為不可訓練

gan. compile(loss="binary_crossentropy", optimizer="rmsprop")

7.訓練

7.1 隨機分布訓練集

#設置了每次訓練時使用的批次大小為 32​
batch_size = 32
#使用 from_tensor_slices 方法將訓練數據 x_train 轉換為 tf.data.Dataset 對象。​
dataset = tf.data.Dataset. from_tensor_slices (x_train)#先暫存到暫存區,讀取效率較好
# 打亂數據​
dataset = dataset.shuffle(1000)#

# 預取一個批次的數據,以提高數據加載的效率​
#drop_remainder:將剩下的踢除
#prefetch(1):預先準備下一批次data
dataset = dataset. batch (batch_size, drop_remainder=True).prefetch(1)

7.2 建立訓練GAN的方法

實現 GAN 的訓練過程,通過交替訓練生成器和判別器來逐步改進生成器的性能

# 繪製生成圖像用
def plot_multiple_images (images, n=None):
#判斷最后一維度若為 1,則消去最后一維度
if images.shape [-1] == 1:
images = np. squeeze (images, axis=-1)
plt.figure(figsize=(n, 1))

for i in range(n):
plt.subplot(1, n, i + 1)
plt. imshow (images [i], cmap="binary")
plt.axis ("off")
plt.show ()

def train_gan(gan, dataset, batch_size, d, n_epochs=10):
'''
函數參數
gan:GAN 模型,由生成器和判別器組成。
dataset:訓練數據集。
batch_size:每個訓練批次的大小。
d:生成器輸入的噪聲維度。
n_epochs:訓練的總迭代次數。
'''
##拆分 GAN 模型​
generator, discriminator = gan. layers
## 迭代訓練 Epochs​
for epoch in range(n_epochs):
print("Epoch {}/{}". format(epoch + 1, n_epochs))
# 遍歷批次數據集
for x_batch in dataset:
# 第一階段 訓練判別器
# 生成噪聲並生成圖像
noise = tf. random. normal (shape= [batch_size, d])
generated_images = generator (noise)
# 合併真實和假圖像
x_fake_and_real = tf.concat ( [generated_images, x_batch], axis=0)
# 設置判別器的訓練標籤
# 為假圖像設置標籤為 0,真實圖像設置標籤為 1。
y1 = tf. constant([[0.]] * batch_size + [[1.]] * batch_size)
# 訓練判別器
discriminator.trainable = True
discriminator.train_on_batch(x_fake_and_real, y1)#不可用fit(),並進行比較

#第二階段 訓練生成器
# 生成新的噪聲
noise = tf. random. normal (shape= [batch_size, d])
# 設置生成器的訓練標籤
# 為生成器的輸出設置標籤為 1,以欺騙判別器。
y2 = tf. constant([ [1.]] * batch_size)
# 訓練生成器
discriminator.trainable = False
gan.train_on_batch(noise, y2)
## 繪製生成圖像​ 每個 epoch 結束後,顯示生成的圖像。
plot_multiple_images (generated_images, 10)#生成 10張圖

7.3 開始訓練

train_gan(gan, dataset, batch_size, d, n_epochs=10)
raw-image
raw-image

儲存模型

generator.save('generator.h5')











avatar-img
131會員
217內容數
本業是影像辨識軟體開發,閒暇時間進修AI相關內容,將學習到的內容寫成文章分享。
留言0
查看全部
avatar-img
發表第一個留言支持創作者!
螃蟹_crab的沙龍 的其他內容
本文將延續上一篇文章,經由訓練好的VAE模型其中的解碼器,來生成圖片。 [深度學習]訓練VAE模型用於生成圖片_訓練篇 輸入產生的隨機雜訊,輸入VAE的解碼器後,生成的圖片
本文主要介紹,如何利用VAE變分自編碼器來訓練生成圖片。 訓練集資料將採用TF影像資料庫中的fashion_mnist VAE變分自編碼器簡單介紹 •VAE(Variational Auto-Encoder)中文名稱變分自編碼器,主要是一種將原始資料編碼到潛在向量空間,再編碼回來的神經網路。
本文下方連結的文章,利用Stable Diffusion生成512 * 512大小的圖片。 輸入的文字是 dog flying in space,此模型需輸入英文句子才會準確生成。 參考文獻 連結該作者在Hugging Face公開的模型去做使用。 本文是在Colab上執行。
長短期記憶(英語:Long Short-Term Memory,LSTM)是一種時間循環神經網路(RNN),論文首次發表於1997年。 LSTM(長短期記憶)是一種特定類型的遞歸神經網絡(RNN),在許多需要處理時間序列數據或順序數據的應用中非常有用。 以下是一些常見的 LSTM 應用:
先前上一篇是使用NLT內置的電影評論數據集 movie_reviews,來訓練出情感分析模型,此篇文章介紹可以導入自己的訓練資料集來建立情感分析模組。 [Python][自然語言]NLTK 實現電影評論情感分析 所需套件 pip install pandas pip install sci
情感分析是一種自然語言處理技術,用於自動識別和分析文本中的情感傾向,通常是正向、負向或中性。 我們可以使用 NLTK 來實現一個基於單純貝斯分類器的情感分析模型。
本文將延續上一篇文章,經由訓練好的VAE模型其中的解碼器,來生成圖片。 [深度學習]訓練VAE模型用於生成圖片_訓練篇 輸入產生的隨機雜訊,輸入VAE的解碼器後,生成的圖片
本文主要介紹,如何利用VAE變分自編碼器來訓練生成圖片。 訓練集資料將採用TF影像資料庫中的fashion_mnist VAE變分自編碼器簡單介紹 •VAE(Variational Auto-Encoder)中文名稱變分自編碼器,主要是一種將原始資料編碼到潛在向量空間,再編碼回來的神經網路。
本文下方連結的文章,利用Stable Diffusion生成512 * 512大小的圖片。 輸入的文字是 dog flying in space,此模型需輸入英文句子才會準確生成。 參考文獻 連結該作者在Hugging Face公開的模型去做使用。 本文是在Colab上執行。
長短期記憶(英語:Long Short-Term Memory,LSTM)是一種時間循環神經網路(RNN),論文首次發表於1997年。 LSTM(長短期記憶)是一種特定類型的遞歸神經網絡(RNN),在許多需要處理時間序列數據或順序數據的應用中非常有用。 以下是一些常見的 LSTM 應用:
先前上一篇是使用NLT內置的電影評論數據集 movie_reviews,來訓練出情感分析模型,此篇文章介紹可以導入自己的訓練資料集來建立情感分析模組。 [Python][自然語言]NLTK 實現電影評論情感分析 所需套件 pip install pandas pip install sci
情感分析是一種自然語言處理技術,用於自動識別和分析文本中的情感傾向,通常是正向、負向或中性。 我們可以使用 NLTK 來實現一個基於單純貝斯分類器的情感分析模型。
你可能也想看
Google News 追蹤
Thumbnail
現代社會跟以前不同了,人人都有一支手機,只要打開就可以獲得各種資訊。過去想要辦卡或是開戶就要跑一趟銀行,然而如今科技快速發展之下,金融App無聲無息地進到你生活中。但同樣的,每一家銀行都有自己的App時,我們又該如何選擇呢?(本文係由國泰世華銀行邀約) 今天我會用不同角度帶大家看這款國泰世華CUB
Thumbnail
嘿,大家新年快樂~ 新年大家都在做什麼呢? 跨年夜的我趕工製作某個外包設計案,在工作告一段落時趕上倒數。 然後和兩個小孩過了一個忙亂的元旦。在深夜時刻,看到朋友傳來的解籤網站,興致勃勃熬夜體驗了一下,覺得非常好玩,或許有人玩過了,但還是想寫上來分享紀錄一下~
Thumbnail
本文探討了影像生成模型的多種應用,包括文字、圖像和聲音到影片的生成,涵蓋了GAN、Transformer和Diffusion等技術。透過回顧相關研究,分析影像生成技術的未來趨勢與挑戰,為讀者提供全面的理解與啟示。
前言 最近在研究GAT,在網路上看到使用torch和DGL實作的GAT模型的程式碼,就想說下載下來自己跑跑看,這篇文章:Understand Graph Attention Network。途中遇到問題,把找到的解法記錄下來,給也有一樣問題的朋友參考。 正文 在Colab直接使用: !p
Thumbnail
本篇文章介紹如何使用PyTorch構建和訓練圖神經網絡(GNN),並使用Cora資料集進行節點分類任務。通過模型架構的逐步優化,包括引入批量標準化和獨立的消息傳遞層,調整Dropout和聚合函數,顯著提高了模型的分類準確率。實驗結果表明,經過優化的GNN模型在處理圖結構數據具有強大的性能和應用潛力。
Thumbnail
前言 讀了許多理論,是時候實際動手做做看了,以下是我的模型訓練初體驗,有點糟就是了XD。 正文 def conv(filters, kernel_size, strides=1): return Conv2D(filters, kernel_size,
Thumbnail
VQGAN是一種基於GAN(生成對抗式網路)的生成式模型,可以創造新的、逼真的圖像或修改已有圖像。本論文介紹了改進VQGAN用於StableDiffusion中的新方法架構,並提出了一種新的非對稱式VQGAN,具有更強的解碼器和兩個設計條件解碼器。論文下方另附相關資料連結。
Thumbnail
這篇文章探討了生成式對抗網路中機率分佈的使用與相關的訓練方式,包括Generator不同的點、Distriminator的訓練過程、生成圖片的條件設定等。此外,也提到了GAN訓練的困難與解決方式以及不同的learning方式。文章內容豐富且詳細,涵蓋了GAN的各個相關面向。
這個頻道將提供以下服務: 深入介紹各種Machine Learning技術 深入介紹各種Deep Learning技術 深入介紹各種Reinforcement Learning技術 深入介紹Probabilistic Graphical Model技術 不定時提供讀書筆記 讓我們一起在未
Thumbnail
這篇文章介紹瞭如何利用生成式AI(GenAI)來提高學習效率,包括文章重點整理、完善知識體系、客製化學習回饋、提供多元觀點等方法。同時提醒使用者應注意內容的信效度,保持學術誠信,適當運用GenAI能大幅提升工作效率。
Thumbnail
生成式AI(Generative AI)是近年來人工智慧領域中備受矚目的技術之一。它以機器學習為基礎,通過學習大量數據中的模式和關係,能夠生成各種新的內容,涵蓋文字、圖像、音訊等多個領域。本文將深入探討生成式AI的原理、優缺點以及應用範疇。
Thumbnail
可能包含敏感內容
這邊紀錄使用Bing images create 生成原圖並利用Tensor art
Thumbnail
現代社會跟以前不同了,人人都有一支手機,只要打開就可以獲得各種資訊。過去想要辦卡或是開戶就要跑一趟銀行,然而如今科技快速發展之下,金融App無聲無息地進到你生活中。但同樣的,每一家銀行都有自己的App時,我們又該如何選擇呢?(本文係由國泰世華銀行邀約) 今天我會用不同角度帶大家看這款國泰世華CUB
Thumbnail
嘿,大家新年快樂~ 新年大家都在做什麼呢? 跨年夜的我趕工製作某個外包設計案,在工作告一段落時趕上倒數。 然後和兩個小孩過了一個忙亂的元旦。在深夜時刻,看到朋友傳來的解籤網站,興致勃勃熬夜體驗了一下,覺得非常好玩,或許有人玩過了,但還是想寫上來分享紀錄一下~
Thumbnail
本文探討了影像生成模型的多種應用,包括文字、圖像和聲音到影片的生成,涵蓋了GAN、Transformer和Diffusion等技術。透過回顧相關研究,分析影像生成技術的未來趨勢與挑戰,為讀者提供全面的理解與啟示。
前言 最近在研究GAT,在網路上看到使用torch和DGL實作的GAT模型的程式碼,就想說下載下來自己跑跑看,這篇文章:Understand Graph Attention Network。途中遇到問題,把找到的解法記錄下來,給也有一樣問題的朋友參考。 正文 在Colab直接使用: !p
Thumbnail
本篇文章介紹如何使用PyTorch構建和訓練圖神經網絡(GNN),並使用Cora資料集進行節點分類任務。通過模型架構的逐步優化,包括引入批量標準化和獨立的消息傳遞層,調整Dropout和聚合函數,顯著提高了模型的分類準確率。實驗結果表明,經過優化的GNN模型在處理圖結構數據具有強大的性能和應用潛力。
Thumbnail
前言 讀了許多理論,是時候實際動手做做看了,以下是我的模型訓練初體驗,有點糟就是了XD。 正文 def conv(filters, kernel_size, strides=1): return Conv2D(filters, kernel_size,
Thumbnail
VQGAN是一種基於GAN(生成對抗式網路)的生成式模型,可以創造新的、逼真的圖像或修改已有圖像。本論文介紹了改進VQGAN用於StableDiffusion中的新方法架構,並提出了一種新的非對稱式VQGAN,具有更強的解碼器和兩個設計條件解碼器。論文下方另附相關資料連結。
Thumbnail
這篇文章探討了生成式對抗網路中機率分佈的使用與相關的訓練方式,包括Generator不同的點、Distriminator的訓練過程、生成圖片的條件設定等。此外,也提到了GAN訓練的困難與解決方式以及不同的learning方式。文章內容豐富且詳細,涵蓋了GAN的各個相關面向。
這個頻道將提供以下服務: 深入介紹各種Machine Learning技術 深入介紹各種Deep Learning技術 深入介紹各種Reinforcement Learning技術 深入介紹Probabilistic Graphical Model技術 不定時提供讀書筆記 讓我們一起在未
Thumbnail
這篇文章介紹瞭如何利用生成式AI(GenAI)來提高學習效率,包括文章重點整理、完善知識體系、客製化學習回饋、提供多元觀點等方法。同時提醒使用者應注意內容的信效度,保持學術誠信,適當運用GenAI能大幅提升工作效率。
Thumbnail
生成式AI(Generative AI)是近年來人工智慧領域中備受矚目的技術之一。它以機器學習為基礎,通過學習大量數據中的模式和關係,能夠生成各種新的內容,涵蓋文字、圖像、音訊等多個領域。本文將深入探討生成式AI的原理、優缺點以及應用範疇。
Thumbnail
可能包含敏感內容
這邊紀錄使用Bing images create 生成原圖並利用Tensor art