[深度學習]訓練VAE模型用於生成圖片_訓練篇

更新於 2024/07/25閱讀時間約 11 分鐘

本文主要介紹,如何利用VAE變分自編碼器來訓練生成圖片。

訓練集資料將採用TF影像資料庫中的fashion_mnist

VAE變分自編碼器簡單介紹

•VAE(Variational Auto-Encoder)中文名稱變分自編碼器,主要是一種將原始資料編碼到潛在向量空間,再編碼回來的神經網路。

raw-image

在Colab上執行訓練

Python及套件版本

Python version: 3.10.12
NumPy version: 1.25.2
Pandas version: 2.0.3
Matplotlib version: 3.7.1
TensorFlow version: 2.15.0

程式碼

1.載入套件

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras

2.載入資料集

from tensorflow.keras.datasets import fashion_mnist
(x_train_set, y_train_set), (x_test, y_test) = fashion_mnist. load_data()
x_train_set = x_train_set / 255.0x_test = x_test / 255.0

3.確認資料集

顯示隨便一張圖

i = 5
print (y_train_set [i])
plt. imshow(x_train_set[i], cmap='binary')
plt.show()
raw-image


4.分割資料集

主要分出訓練跟驗證集

from sklearn.model_selection import train_test_split

x_train, x_valid, y_train, y_valid = train_test_split(x_train_set, y_train_set,

5.建立VAE模型

5.1 建立編碼器模型

class Sampling (keras.layers.Layer): #用於從潛在空間中進行采樣
def call(self, inputs):
mean, log_var = inputs
#ε 是從標準常態分配中抽出的,其個數必須和(log σ**2)一致
e = tf. random. normal(tf.shape (log_var))
# Z = μ + ε * exp(log σ**2/2)
# 最終的潛在變量,這樣生成的潛在變量具有正確的均值和方差。
return mean + e * tf.math.exp (log_var / 2)
#清除背景舊模型
keras.backend.clear_session() #清除背景舊模型
# 設置隨機種子
# 設置 TensorFlow 和 NumPy 的隨機種子,以確保實驗的可重現性​
tf.random.set_seed(1)
np.random.seed (1)

#定義編碼器模型​
d = 10#代表最后輸出時的數量
in_en = keras. layers. Input (shape= [28, 28])
#因結構特殊,不能使用Sequential,改用Functional方式
c = keras.layers.Flatten() (in_en)#(in_en)代表輸入層,要傳到Flatten層
c = keras.layers.Dense (150, activation="selu")(c)
c = keras.layers.Dense(100, activation="selu")(c)

c_mean = keras.layers.Dense(d)(c)
c_log_var = keras.layers.Dense(d)(c)

out_en = Sampling()([c_mean, c_log_var])

var_encoder = keras.models.Model(inputs= [in_en], outputs= [out_en])

var_encoder.summary()
編碼器模型架構

編碼器模型架構

5.2 建立解碼器模型

#in_de:定義了解碼器的輸入層,形狀為 [d],其中 d 是潛在空間的維度(與編碼器中的輸出一致)
in_de = keras.layers.Input (shape= [d])

x = keras. layers.Dense(100, activation="selu")(in_de)
x = keras. layers.Dense(150, activation="selu")(x)
x = keras. layers.Dense(28 * 28, activation="sigmoid")(x)
# Reshape 層:將展平的輸出轉換回 [28, 28] 的形狀,即恢復為原始圖像的形狀。
out_de = keras.layers.Reshape( [28, 28])(x)

var_decoder = keras.models.Model(inputs=[in_de], outputs=[out_de])

codings = var_encoder(in_en) #通過編碼器 var_encoder 將輸入 in_en 轉換為潛在變量。
rec = var_decoder (codings) #通過解碼器 var_decoder 將潛在變量 codings 重建回圖像空間。

5.3 組合成完整的VAE模型

# 建立一個變分自編碼器(Variational Autoencoder, VAE)的完整模型
# 其輸入為 in_en(原始圖像),輸出為 rec(重建的圖像)
var_ae = keras.models.Model(inputs=[in_en], outputs= [rec])

var_ae.summary()
解碼器模型

解碼器模型

6.模型Compile

損失函數、優化器和評估指標的設置後,模型需要被編譯,才能進行訓練。

6.1 設置損失函數

D_KL = -0.5 * tf. math. reduce_sum (
#公式:∑(1+log σ**2 - σ**2 - μ**2) ; σ**2 = exp(log σ**2)
1 + c_log_var - tf. math.exp (c_log_var) - tf.math. square (c_mean),
axis=1)
#所有累加值求平均
latent_loss = tf.math.reduce_mean(D_KL) / 784.0 #784個神經元
var_ae.add_loss(latent_loss)

6.2 設置評估指標

def rounded_accuracy (y_true, y_pred):
# binary_accuracy用來計算二類分類問題的準確率。準確率是指預測值與真實標籤相匹配的比例。​
return keras.metrics.binary_accuracy(tf.round (y_true), tf. round (y_pred)

6.3 Compile

  • loss='binary_crossentropy':設置模型的損失函數為二元交叉熵,適合二類分類問題。
  • optimizer='rmsprop':使用 RMSprop 優化器來更新模型的權重。
  • metrics=[rounded_accuracy]:設置自定義的準確率指標 rounded_accuracy 來監控模型在訓練過程中的性能。
var_ae.compile(loss='binary_crossentropy',
optimizer="rmsprop",
metrics= [rounded_accuracy])
  • 損失函數:決定了模型在訓練過程中如何計算誤差並調整參數。
  • 優化器:決定了模型權重的更新方式,影響訓練過程的效率和效果。
  • 評估指標:提供了在訓練和評估過程中監控模型性能的方法,使你可以追蹤模型的學習情況。

7.訓練

x_train:這是模型的訓練數據,通常是原始圖像或數據集。

x_train :作為第二個參數:在自編碼器(包括變分自編碼器)中,訓練數據的輸入和目標是相同的,因為目標是將輸入數據重建回來。因此,這裡的 x_train 既是輸入數據也是目標數據。

epochs:訓練過程中完整遍歷訓練數據集的次數。這裡設定為 20,表示將訓練 20 個循環(每個循環都遍歷一次完整的訓練數據集)。

validation_data:這是用於模型驗證的數據集。與訓練數據類似,驗證數據的輸入和目標也是相同的。

train = var_ae.fit(x_train, x_train,
epochs=20, batch_size=256,
validation_data=(x_valid, x_valid))
raw-image

8.查看訓練歷史紀錄

#將這些歷史數據轉換為 Pandas DataFrame,以便於進行進一步的分析和可視化。​
pd.DataFrame(train.history).plot()
plt.grid(True)
plt.show()
raw-image

9.評估模型性能

返回損失值及準確率

var_ae.evaluate(x_test, x_test)

10.比較模型的重建效果

將前 5 個測試樣本的原始圖像和相應的重建圖像並排顯示,以比較模型的重建效果。

這是評估變分自編碼器性能的一個直觀方法,幫助你檢查模型是否能夠有效地重建圖像,並理解其在數據生成方面的能力。

plt.figure(figsize=(10, 4))
j = 0
for i in range (5,10): #印出 第六筆到第十筆
plt.subplot(2, 5, 1 + j)#編號 1開始
plt.imshow(x_test[i], cmap='binary')
plt.title('original')
plt.axis('off')

plt.subplot(2, 5, 1 + 5 + j)#編號 6開始
plt.imshow (x_test_decoded [i], cmap='binary')
plt.title('reconstructed')
plt.axis('off')
j += 1
print(j)
plt.show()
第六到第十

第六到第十

第一筆到第五筆

第一筆到第五筆

11.儲存模型

var_decoder.save('var_decoder.h5')


如何使用儲存好的模型,在下一篇文章展示

[深度學習]訓練VAE模型用於生成圖片_生成篇





avatar-img
128會員
209內容數
本業是影像辨識軟體開發,閒暇時間進修AI相關內容,將學習到的內容寫成文章分享。
留言0
查看全部
avatar-img
發表第一個留言支持創作者!
螃蟹_crab的沙龍 的其他內容
本文下方連結的文章,利用Stable Diffusion生成512 * 512大小的圖片。 輸入的文字是 dog flying in space,此模型需輸入英文句子才會準確生成。 參考文獻 連結該作者在Hugging Face公開的模型去做使用。 本文是在Colab上執行。
長短期記憶(英語:Long Short-Term Memory,LSTM)是一種時間循環神經網路(RNN),論文首次發表於1997年。 LSTM(長短期記憶)是一種特定類型的遞歸神經網絡(RNN),在許多需要處理時間序列數據或順序數據的應用中非常有用。 以下是一些常見的 LSTM 應用:
先前上一篇是使用NLT內置的電影評論數據集 movie_reviews,來訓練出情感分析模型,此篇文章介紹可以導入自己的訓練資料集來建立情感分析模組。 [Python][自然語言]NLTK 實現電影評論情感分析 所需套件 pip install pandas pip install sci
情感分析是一種自然語言處理技術,用於自動識別和分析文本中的情感傾向,通常是正向、負向或中性。 我們可以使用 NLTK 來實現一個基於單純貝斯分類器的情感分析模型。
本文介紹了流行的Python套件NLTK(Natural Language Toolkit)的主要特點、功能和在中文和英文語料上的應用。從安裝到實際應用,深入介紹了分詞、停用詞去除、詞性標註、命名實體識別等NLP任務的具體實現和步驟,幫助讀者理解和應用NLTK。
本文利用pyqt5,使用pyttsx3將QLineEdit(單行輸入框)的字串,轉成語音呈現出來。
本文下方連結的文章,利用Stable Diffusion生成512 * 512大小的圖片。 輸入的文字是 dog flying in space,此模型需輸入英文句子才會準確生成。 參考文獻 連結該作者在Hugging Face公開的模型去做使用。 本文是在Colab上執行。
長短期記憶(英語:Long Short-Term Memory,LSTM)是一種時間循環神經網路(RNN),論文首次發表於1997年。 LSTM(長短期記憶)是一種特定類型的遞歸神經網絡(RNN),在許多需要處理時間序列數據或順序數據的應用中非常有用。 以下是一些常見的 LSTM 應用:
先前上一篇是使用NLT內置的電影評論數據集 movie_reviews,來訓練出情感分析模型,此篇文章介紹可以導入自己的訓練資料集來建立情感分析模組。 [Python][自然語言]NLTK 實現電影評論情感分析 所需套件 pip install pandas pip install sci
情感分析是一種自然語言處理技術,用於自動識別和分析文本中的情感傾向,通常是正向、負向或中性。 我們可以使用 NLTK 來實現一個基於單純貝斯分類器的情感分析模型。
本文介紹了流行的Python套件NLTK(Natural Language Toolkit)的主要特點、功能和在中文和英文語料上的應用。從安裝到實際應用,深入介紹了分詞、停用詞去除、詞性標註、命名實體識別等NLP任務的具體實現和步驟,幫助讀者理解和應用NLTK。
本文利用pyqt5,使用pyttsx3將QLineEdit(單行輸入框)的字串,轉成語音呈現出來。
你可能也想看
Google News 追蹤
Thumbnail
本文探討了複利效應的重要性,並藉由巴菲特的投資理念,說明如何選擇穩定產生正報酬的資產及長期持有的核心理念。透過定期定額的投資方式,不僅能減少情緒影響,還能持續參與全球股市的發展。此外,文中介紹了使用國泰 Cube App 的便利性及低手續費,幫助投資者簡化投資流程,達成長期穩定增長的財務目標。
Thumbnail
前言 讀了許多理論,是時候實際動手做做看了,以下是我的模型訓練初體驗,有點糟就是了XD。 正文 def conv(filters, kernel_size, strides=1): return Conv2D(filters, kernel_size,
三年前,我開始鑽研卡片盒筆記法,逐漸體會到做筆記最困難的部分,其實是看見資訊的內部連結。這種筆記方法不僅能幫助我們更好地組織資料,還能提升研究的效率和質量。 ▋理解資訊連結 卡片盒筆記法的核心在於理解和整理資訊之間的關聯性。這並非僅僅是將資訊羅列起來,而是需要將零散的資訊點連結成一個有機
Thumbnail
在現今少子化的時代,提升學習效率至關重要。卡爾·紐波特的書《DEEP WORK深度工作力》提供了有效的時間管理和學習策略,能夠幫助我們在競爭激烈的社會中脫穎而出。書中介紹的學習方法和策略,不僅適用於大學生,也可應用在日常生活中,幫助我們擁有良好的學習力,增進生活效率。
Thumbnail
這本書訪談了大學學生,並且歸納出幾點建議,書中也提到不必每條條都嚴格遵守,而是選擇一組吸引你的規則,並在大學生活中履行。 我自己在看這本書的時候,結合自己的大學經歷,選取幾點我比較有感觸的部分,分為以下幾點,後面則會提到一些關於書中內容反思
Thumbnail
透過麗鳳督導在心理諮商上的應用,能夠讓我們看待個案問題時有了全新的視角。學理論要浸泡到自動化思考,分析個案時需要考慮家庭結構、互動關係和人際界線等重要元素。此外,心理諮商師需用關係去理解表徵問題,並運用大量的探問與對話,從而從症狀到系統的探索。
Thumbnail
不是只是硬記硬背,而是要用對方法。學習,不分年紀,不分時候,我們隨時都在學習,但有良好的學習技能,像故事/小說書中,電影裡那些擁有超能力的人一樣,可以在自己想學的技能中,一眼就記住,過目不忘的技能,如果擁有或許也是一件不錯的事,但切換到現實,我們認真學習,雖然也能記住,但所要花費的時間成本...
深度學習是機器學習的一個分支,它使用多層神經網絡來模擬和解決複雜的問題。有許多不同的深度學習框架可供選擇,這些框架提供了用於訓練神經網絡的工具和函數。以下是一些常用的深度學習框架的簡介: TensorFlow: TensorFlow由Google開發,是最流行的深度學習框架之一。它具有靈活的計算
Thumbnail
在這資訊爆炸的時代,「知識淺薄化」已成為許多學者和觀察家所擔憂的現象。但究竟這是否真的是個問題,還是只是我們社會進步的副產品?我們需要更深入地去探索。 瞬息萬變的資訊環境確實改變了我們消化資訊的方式。從報紙、雜誌到網絡新聞,再到社群媒體的推送資訊,我們每天都在瀏覽大量的信息,但往往只能浮於表面。隨
Thumbnail
如何與錯誤打交道,就是對於自身的錯誤的察覺,又或者是對於所學的知識正確性如何思辨。 大家好,今天我們來談談「第二層思考」。這是一個相當重要的概念,尤其在現代社會中,我們需要面對各種各樣的資訊和知識,但有時候這些資訊和知識並不是那麼正確。所以,我們必須學會用第二層思考去判斷和分析這些資訊和知識。 首先
Thumbnail
在資訊繁複的環境中,你是否經常感到學習過於分散、無法深入,或是難以理解自己的內心世界? 在這個瞬息萬變的世界裡,我們都渴望抓住並善用所獲得的知識,並深入理解自己的內在世界。你是否曾經苦惱過如何有效地管理學習,讓思考更有深度,或者如何更好地理解自己的情感和想法?這篇文章會為你揭示三個能有效解答這些問題
Thumbnail
本文探討了複利效應的重要性,並藉由巴菲特的投資理念,說明如何選擇穩定產生正報酬的資產及長期持有的核心理念。透過定期定額的投資方式,不僅能減少情緒影響,還能持續參與全球股市的發展。此外,文中介紹了使用國泰 Cube App 的便利性及低手續費,幫助投資者簡化投資流程,達成長期穩定增長的財務目標。
Thumbnail
前言 讀了許多理論,是時候實際動手做做看了,以下是我的模型訓練初體驗,有點糟就是了XD。 正文 def conv(filters, kernel_size, strides=1): return Conv2D(filters, kernel_size,
三年前,我開始鑽研卡片盒筆記法,逐漸體會到做筆記最困難的部分,其實是看見資訊的內部連結。這種筆記方法不僅能幫助我們更好地組織資料,還能提升研究的效率和質量。 ▋理解資訊連結 卡片盒筆記法的核心在於理解和整理資訊之間的關聯性。這並非僅僅是將資訊羅列起來,而是需要將零散的資訊點連結成一個有機
Thumbnail
在現今少子化的時代,提升學習效率至關重要。卡爾·紐波特的書《DEEP WORK深度工作力》提供了有效的時間管理和學習策略,能夠幫助我們在競爭激烈的社會中脫穎而出。書中介紹的學習方法和策略,不僅適用於大學生,也可應用在日常生活中,幫助我們擁有良好的學習力,增進生活效率。
Thumbnail
這本書訪談了大學學生,並且歸納出幾點建議,書中也提到不必每條條都嚴格遵守,而是選擇一組吸引你的規則,並在大學生活中履行。 我自己在看這本書的時候,結合自己的大學經歷,選取幾點我比較有感觸的部分,分為以下幾點,後面則會提到一些關於書中內容反思
Thumbnail
透過麗鳳督導在心理諮商上的應用,能夠讓我們看待個案問題時有了全新的視角。學理論要浸泡到自動化思考,分析個案時需要考慮家庭結構、互動關係和人際界線等重要元素。此外,心理諮商師需用關係去理解表徵問題,並運用大量的探問與對話,從而從症狀到系統的探索。
Thumbnail
不是只是硬記硬背,而是要用對方法。學習,不分年紀,不分時候,我們隨時都在學習,但有良好的學習技能,像故事/小說書中,電影裡那些擁有超能力的人一樣,可以在自己想學的技能中,一眼就記住,過目不忘的技能,如果擁有或許也是一件不錯的事,但切換到現實,我們認真學習,雖然也能記住,但所要花費的時間成本...
深度學習是機器學習的一個分支,它使用多層神經網絡來模擬和解決複雜的問題。有許多不同的深度學習框架可供選擇,這些框架提供了用於訓練神經網絡的工具和函數。以下是一些常用的深度學習框架的簡介: TensorFlow: TensorFlow由Google開發,是最流行的深度學習框架之一。它具有靈活的計算
Thumbnail
在這資訊爆炸的時代,「知識淺薄化」已成為許多學者和觀察家所擔憂的現象。但究竟這是否真的是個問題,還是只是我們社會進步的副產品?我們需要更深入地去探索。 瞬息萬變的資訊環境確實改變了我們消化資訊的方式。從報紙、雜誌到網絡新聞,再到社群媒體的推送資訊,我們每天都在瀏覽大量的信息,但往往只能浮於表面。隨
Thumbnail
如何與錯誤打交道,就是對於自身的錯誤的察覺,又或者是對於所學的知識正確性如何思辨。 大家好,今天我們來談談「第二層思考」。這是一個相當重要的概念,尤其在現代社會中,我們需要面對各種各樣的資訊和知識,但有時候這些資訊和知識並不是那麼正確。所以,我們必須學會用第二層思考去判斷和分析這些資訊和知識。 首先
Thumbnail
在資訊繁複的環境中,你是否經常感到學習過於分散、無法深入,或是難以理解自己的內心世界? 在這個瞬息萬變的世界裡,我們都渴望抓住並善用所獲得的知識,並深入理解自己的內在世界。你是否曾經苦惱過如何有效地管理學習,讓思考更有深度,或者如何更好地理解自己的情感和想法?這篇文章會為你揭示三個能有效解答這些問題