[深度學習][Python]訓練CNN的GAN模型來生成圖片_訓練篇

閱讀時間約 9 分鐘

延續上一篇訓練GAM模型,這次我們讓神經網路更多層更複雜一點,來看訓練生成的圖片是否效果會更好。

[深度學習][Python]訓練MLP的GAN模型來生成圖片_訓練篇

資料集分割處理的部分在延續上篇文章,從第五點開始後修改即可,前面都一樣

訓練過程,比較圖

是不是CNN的效果比MLP還要好,因為CNN比較好捕捉特徵去學習

CNN

CNN

CNN

MLP

raw-image

程式碼

建立建立GAN模型

跟上一篇比起來此模型使用了卷積層和批量正規化,但因為比較多層訓練時間就會比較久。

# 隨機種子設置​
tf.random.set_seed(1)
np.random. seed (1)
d = 100
#生成器(Generator)
generator = keras.models.Sequential([
keras. layers. Dense(7 * 7 * 128, input_shape= [d]),
keras. layers. Reshape ( [7, 7, 128]),
keras. layers. BatchNormalization (),
keras. layers. Conv2DTranspose(64, kernel_size=5, strides=2,
padding="SAME", activation="selu"),
keras. layers. BatchNormalization (),
keras. layers. Conv2DTranspose(1, kernel_size=5, strides=2,
padding="SAME", activation="tanh"),#輸出:-1~1
])
#判別器(Discriminator)​
discriminator = keras.models.Sequential([
keras. layers. Conv2D (64, kernel_size=5, strides=2,
padding="SAME",activation=keras.layers. LeakyReLU(0.2),
input_shape= [28, 28, 1]),
keras.layers. Dropout (0.4),
keras. layers. Conv2D(128, kernel_size=5, strides=2,
padding="SAME",activation=keras. layers. LeakyReLU(0.2)),
keras. layers. Dropout (0.4),
keras. layers. Flatten(),
keras. layers. Dense(1, activation="sigmoid")
])
# GAN 模型​
# 將生成器和判別器結合在一起,形成一個生成對抗網絡
gan = keras.models.Sequential( [generator, discriminator])

訓練

將訓練資料重新塑形並標準化,並顯示訓練過程

# 因為生成器的輸出使用了 tanh 激活函數,該函數的輸出範圍為 [-1, 1]。
# 重新改變形狀 (60000, 28, 28, 1)​ 將其值轉換到 [-1, 1] 範圍
x_train_dcgan = x_train.reshape(-1, 28, 28, 1)* 2. - 1.

batch_size = 32
dataset = tf.data.Dataset.from_tensor_slices(x_train_dcgan)
dataset = dataset.shuffle (1000)
# batch(batch_size, drop_remainder=True) 將資料集分批,每批大小為 batch_size,並且丟棄最後不足一批的樣本。​
# prefetch(1):預先準備一批數據,以加快數據加載速度。​
dataset = dataset.batch (batch_size, drop_remainder=True) .prefetch (1)
# 調用 train_gan 函數,使用創建的資料集 dataset 訓練 GAN:

train_gan(gan, dataset, batch_size, d, n_epochs=20)
raw-image
raw-image
raw-image
raw-image

儲存生成模型

generator.save('generator_deep.h5')

模型詳細說明

生成器(Generator)

生成器的目的是從隨機噪聲生成假圖像。

generator = keras.models.Sequential([
keras.layers.Dense(7 * 7 * 128, input_shape=[d]),
keras.layers.Reshape([7, 7, 128]),
keras.layers.BatchNormalization(),
keras.layers.Conv2DTranspose(64, kernel_size=5, strides=2, padding="SAME", activation="selu"),
keras.layers.BatchNormalization(),
keras.layers.Conv2DTranspose(1, kernel_size=5, strides=2, padding="SAME", activation="tanh"), # 輸出:-11
])
  1. Dense Layer: 第一層是全連接層,將輸入的隨機向量(形狀為 [d])轉換為大小為 7 * 7 * 128 的張量。
  2. Reshape Layer: 將張量重塑為 7x7,每個位置有 128 個特徵。
  3. BatchNormalization Layer: 進行批量正規化,以加速訓練和穩定模型。
  4. Conv2DTranspose Layer: 反捲積層,將 7x7x128 的張量升尺度為 14x14x64。這裡使用了 SELU 激活函數。
  5. BatchNormalization Layer: 再次進行批量正規化。
  6. Conv2DTranspose Layer: 最後一個反捲積層,將 14x14x64 的張量升尺度為 28x28x1,並使用 tanh 激活函數。輸出圖像的值範圍為 -11

判別器(Discriminator)

判別器的目的是區分真實圖像和生成的假圖像。

discriminator = keras.models.Sequential([
keras.layers.Conv2D(64, kernel_size=5, strides=2, padding="SAME", activation=keras.layers.LeakyReLU(0.2), input_shape=[28, 28, 1]),
keras.layers.Dropout(0.4),
keras.layers.Conv2D(128, kernel_size=5, strides=2, padding="SAME", activation=keras.layers.LeakyReLU(0.2)),
keras.layers.Dropout(0.4),
keras.layers.Flatten(),
keras.layers.Dense(1, activation="sigmoid")
])
  1. Conv2D Layer: 第一層是卷積層,輸入形狀為 28x28x1 的圖像。使用 LeakyReLU 激活函數。
  2. Dropout Layer: 用於防止過擬合。
  3. Conv2D Layer: 第二個卷積層,將圖像進一步壓縮,使用 LeakyReLU 激活函數。
  4. Dropout Layer: 再次使用 Dropout 以防止過擬合。
  5. Flatten Layer: 將多維張量展平為一維。
  6. Dense Layer: 最後一層是全連接層,輸出一個標量值,並使用 sigmoid 激活函數以進行二分類。



avatar-img
128會員
213內容數
本業是影像辨識軟體開發,閒暇時間進修AI相關內容,將學習到的內容寫成文章分享。
留言0
查看全部
avatar-img
發表第一個留言支持創作者!
螃蟹_crab的沙龍 的其他內容
本文主要介紹,如何利用GAN生成對抗網路來訓練生成圖片。 利用tensorflow,中的keras來建立生成器及鑑別器互相競爭訓練,最後利用訓練好的生成器來生成圖片。 GAN生成對抗網路的介紹 它由生成網路(Generator Network)和鑑別網路(Discriminator Netwo
本文將延續上一篇文章,經由訓練好的VAE模型其中的解碼器,來生成圖片。 [深度學習]訓練VAE模型用於生成圖片_訓練篇 輸入產生的隨機雜訊,輸入VAE的解碼器後,生成的圖片
本文主要介紹,如何利用VAE變分自編碼器來訓練生成圖片。 訓練集資料將採用TF影像資料庫中的fashion_mnist VAE變分自編碼器簡單介紹 •VAE(Variational Auto-Encoder)中文名稱變分自編碼器,主要是一種將原始資料編碼到潛在向量空間,再編碼回來的神經網路。
本文下方連結的文章,利用Stable Diffusion生成512 * 512大小的圖片。 輸入的文字是 dog flying in space,此模型需輸入英文句子才會準確生成。 參考文獻 連結該作者在Hugging Face公開的模型去做使用。 本文是在Colab上執行。
長短期記憶(英語:Long Short-Term Memory,LSTM)是一種時間循環神經網路(RNN),論文首次發表於1997年。 LSTM(長短期記憶)是一種特定類型的遞歸神經網絡(RNN),在許多需要處理時間序列數據或順序數據的應用中非常有用。 以下是一些常見的 LSTM 應用:
先前上一篇是使用NLT內置的電影評論數據集 movie_reviews,來訓練出情感分析模型,此篇文章介紹可以導入自己的訓練資料集來建立情感分析模組。 [Python][自然語言]NLTK 實現電影評論情感分析 所需套件 pip install pandas pip install sci
本文主要介紹,如何利用GAN生成對抗網路來訓練生成圖片。 利用tensorflow,中的keras來建立生成器及鑑別器互相競爭訓練,最後利用訓練好的生成器來生成圖片。 GAN生成對抗網路的介紹 它由生成網路(Generator Network)和鑑別網路(Discriminator Netwo
本文將延續上一篇文章,經由訓練好的VAE模型其中的解碼器,來生成圖片。 [深度學習]訓練VAE模型用於生成圖片_訓練篇 輸入產生的隨機雜訊,輸入VAE的解碼器後,生成的圖片
本文主要介紹,如何利用VAE變分自編碼器來訓練生成圖片。 訓練集資料將採用TF影像資料庫中的fashion_mnist VAE變分自編碼器簡單介紹 •VAE(Variational Auto-Encoder)中文名稱變分自編碼器,主要是一種將原始資料編碼到潛在向量空間,再編碼回來的神經網路。
本文下方連結的文章,利用Stable Diffusion生成512 * 512大小的圖片。 輸入的文字是 dog flying in space,此模型需輸入英文句子才會準確生成。 參考文獻 連結該作者在Hugging Face公開的模型去做使用。 本文是在Colab上執行。
長短期記憶(英語:Long Short-Term Memory,LSTM)是一種時間循環神經網路(RNN),論文首次發表於1997年。 LSTM(長短期記憶)是一種特定類型的遞歸神經網絡(RNN),在許多需要處理時間序列數據或順序數據的應用中非常有用。 以下是一些常見的 LSTM 應用:
先前上一篇是使用NLT內置的電影評論數據集 movie_reviews,來訓練出情感分析模型,此篇文章介紹可以導入自己的訓練資料集來建立情感分析模組。 [Python][自然語言]NLTK 實現電影評論情感分析 所需套件 pip install pandas pip install sci
你可能也想看
Google News 追蹤
前言 最近在研究GAT,在網路上看到使用torch和DGL實作的GAT模型的程式碼,就想說下載下來自己跑跑看,這篇文章:Understand Graph Attention Network。途中遇到問題,把找到的解法記錄下來,給也有一樣問題的朋友參考。 正文 在Colab直接使用: !p
我想要一天分享一點「LLM從底層堆疊的技術」,並且每篇文章長度控制在三分鐘以內,讓大家不會壓力太大,但是又能夠每天成長一點。 在某些情況下,別人提供的 Pretrained Transformer Model 效果不盡人意,可能會想要自己做 Pretrained Model,但是這會耗費大量運
Thumbnail
我想要一天分享一點「LLM從底層堆疊的技術」,並且每篇文章長度控制在三分鐘以內,讓大家不會壓力太大,但是又能夠每天成長一點。 我們已經在 AI說書 - 從0開始 - 114 建立了 Transformer 模型。 現在我們來載入預訓練權重,預訓練的權重包含 Transformer 的智慧
Thumbnail
本篇文章介紹如何使用PyTorch構建和訓練圖神經網絡(GNN),並使用Cora資料集進行節點分類任務。通過模型架構的逐步優化,包括引入批量標準化和獨立的消息傳遞層,調整Dropout和聚合函數,顯著提高了模型的分類準確率。實驗結果表明,經過優化的GNN模型在處理圖結構數據具有強大的性能和應用潛力。
Thumbnail
透過這篇文章,我們將瞭解如何使用PyTorch實作圖神經網絡中的訊息傳遞機制,從定義消息傳遞的類別到實作消息傳遞過程。我們也探討了各種不同的消息傳遞機制,並通過對單次和多次傳遞過程的結果,可以看到節點特徵如何逐步傳遞與更新。
Thumbnail
前言 讀了許多理論,是時候實際動手做做看了,以下是我的模型訓練初體驗,有點糟就是了XD。 正文 def conv(filters, kernel_size, strides=1): return Conv2D(filters, kernel_size,
Thumbnail
我想要一天分享一點「LLM從底層堆疊的技術」,並且每篇文章長度控制在三分鐘以內,讓大家不會壓力太大,但是又能夠每天成長一點。 從 AI說書 - 從0開始 - 82 到 AI說書 - 從0開始 - 85 的說明,有一個很重要的結論:最適合您的模型不一定是排行榜上最好的模型,您需要學習 NLP 評
Thumbnail
我想要一天分享一點「LLM從底層堆疊的技術」,並且每篇文章長度控制在三分鐘以內,讓大家不會壓力太大,但是又能夠每天成長一點。 Transformer 可以透過繼承預訓練模型 (Pretrained Model) 來微調 (Fine-Tune) 以執行下游任務。 Pretrained Mo
Thumbnail
這篇文章探討了生成式對抗網路中機率分佈的使用與相關的訓練方式,包括Generator不同的點、Distriminator的訓練過程、生成圖片的條件設定等。此外,也提到了GAN訓練的困難與解決方式以及不同的learning方式。文章內容豐富且詳細,涵蓋了GAN的各個相關面向。
這個頻道將提供以下服務: 深入介紹各種Machine Learning技術 深入介紹各種Deep Learning技術 深入介紹各種Reinforcement Learning技術 深入介紹Probabilistic Graphical Model技術 不定時提供讀書筆記 讓我們一起在未
前言 最近在研究GAT,在網路上看到使用torch和DGL實作的GAT模型的程式碼,就想說下載下來自己跑跑看,這篇文章:Understand Graph Attention Network。途中遇到問題,把找到的解法記錄下來,給也有一樣問題的朋友參考。 正文 在Colab直接使用: !p
我想要一天分享一點「LLM從底層堆疊的技術」,並且每篇文章長度控制在三分鐘以內,讓大家不會壓力太大,但是又能夠每天成長一點。 在某些情況下,別人提供的 Pretrained Transformer Model 效果不盡人意,可能會想要自己做 Pretrained Model,但是這會耗費大量運
Thumbnail
我想要一天分享一點「LLM從底層堆疊的技術」,並且每篇文章長度控制在三分鐘以內,讓大家不會壓力太大,但是又能夠每天成長一點。 我們已經在 AI說書 - 從0開始 - 114 建立了 Transformer 模型。 現在我們來載入預訓練權重,預訓練的權重包含 Transformer 的智慧
Thumbnail
本篇文章介紹如何使用PyTorch構建和訓練圖神經網絡(GNN),並使用Cora資料集進行節點分類任務。通過模型架構的逐步優化,包括引入批量標準化和獨立的消息傳遞層,調整Dropout和聚合函數,顯著提高了模型的分類準確率。實驗結果表明,經過優化的GNN模型在處理圖結構數據具有強大的性能和應用潛力。
Thumbnail
透過這篇文章,我們將瞭解如何使用PyTorch實作圖神經網絡中的訊息傳遞機制,從定義消息傳遞的類別到實作消息傳遞過程。我們也探討了各種不同的消息傳遞機制,並通過對單次和多次傳遞過程的結果,可以看到節點特徵如何逐步傳遞與更新。
Thumbnail
前言 讀了許多理論,是時候實際動手做做看了,以下是我的模型訓練初體驗,有點糟就是了XD。 正文 def conv(filters, kernel_size, strides=1): return Conv2D(filters, kernel_size,
Thumbnail
我想要一天分享一點「LLM從底層堆疊的技術」,並且每篇文章長度控制在三分鐘以內,讓大家不會壓力太大,但是又能夠每天成長一點。 從 AI說書 - 從0開始 - 82 到 AI說書 - 從0開始 - 85 的說明,有一個很重要的結論:最適合您的模型不一定是排行榜上最好的模型,您需要學習 NLP 評
Thumbnail
我想要一天分享一點「LLM從底層堆疊的技術」,並且每篇文章長度控制在三分鐘以內,讓大家不會壓力太大,但是又能夠每天成長一點。 Transformer 可以透過繼承預訓練模型 (Pretrained Model) 來微調 (Fine-Tune) 以執行下游任務。 Pretrained Mo
Thumbnail
這篇文章探討了生成式對抗網路中機率分佈的使用與相關的訓練方式,包括Generator不同的點、Distriminator的訓練過程、生成圖片的條件設定等。此外,也提到了GAN訓練的困難與解決方式以及不同的learning方式。文章內容豐富且詳細,涵蓋了GAN的各個相關面向。
這個頻道將提供以下服務: 深入介紹各種Machine Learning技術 深入介紹各種Deep Learning技術 深入介紹各種Reinforcement Learning技術 深入介紹Probabilistic Graphical Model技術 不定時提供讀書筆記 讓我們一起在未