[深度學習][Python]使用簡單的神經網路來訓練辨識fashion_mnist資料

更新於 發佈於 閱讀時間約 8 分鐘

本文主要介紹神經網路訓練辨識的過程,利用fashion_mnist及簡單的神經網路來進行分類。

使用只有兩層的神經網路來訓練辨識fashion_mnist資料。

顯示上圖fashion_mnist 圖片及類別的程式範例

import tensorflow as tf
from tensorflow.keras.datasets import fashion_mnist
import matplotlib.pyplot as plt

# 類別名稱
class_names = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
"Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]

# 載入資料集
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# 定義顯示圖片和標籤的函數
def plot_class_images(images, labels, class_names):
plt.figure(figsize=(10, 10))
for i in range(10):
ax = plt.subplot(5, 2, i + 1)
idx = labels.tolist().index(i)
plt.imshow(images[idx], cmap=plt.cm.gray)
plt.title(f"Class {i}: {class_names[i]}")
plt.axis("off")
plt.show()

# 印出每個類別的編號和對應的名稱
for i, class_name in enumerate(class_names):
print(f"Class {i}: {class_name}")

# 顯示每個類別的圖片
plot_class_images(train_images, train_labels, class_names)

建議用Colab或者是VSCode 在開啟 .ipynb 檔案去跑,可以分段的測試

程式範例

需安裝套件tensorflow scikit-learn

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Data 載入資料
from tensorflow.keras.datasets import fashion_mnist
# 內建的資料集 fashion_mnist
# 切割 訓練檔案 與測試檔案
(x_train_set, y_train_set), (x_test, y_test) = fashion_mnist.load_data()

# Split data 切割資料
from sklearn.model_selection import train_test_split
# 利用sklearn模組來切割資料 將原先訓練資料 在分割出 訓練資料 與驗證資料
x_train, x_valid, y_train, y_valid = train_test_split(x_train_set,
y_train_set,
random_state=1)

# Preprocessing 資料縮放 (0~1)
x_train = x_train / 255.0
x_valid =x_valid / 255.0
x_test = x_test / 255.0

# Build Model 建立模組
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Flatten, Dense

model = Sequential([
#輸入層:將 28*28攤平成一維度
Flatten(input_shape=x_train.shape[1:]),
#輸出層:10類別,10個神經元
Dense (units=10, activation='softmax')
])

print(model.summary())

# Compile model.compile 方法用來配置模型的學習過程
model.compile(loss='sparse_categorical_crossentropy',
optimizer='sgd',
metrics= ['accuracy'])

# Train 分批訓練 預設每批32筆分一批次
train = model.fit(x_train, y_train,
epochs=20,
validation_data=(x_valid, y_valid))

#呈現訓練曲線
pd.DataFrame(train.history).plot()
plt.grid(True)
plt.show()

# Evaluate #顯示結果 誤差值與準確性
model.evaluate(x_test, y_test)

# Predict
y_proba = model.predict(x_test) #實際去算y
#實際上預測結果 會是 一個串列裡面包含對10種類別預測的機率值
#所以 要用​argmax挑出最大的數值
y_pred = np.argmax(y_proba, axis=1)
print(f'預測類別是: {y_pred[1]}')
plt.imshow(x_test[1], cmap=plt.cm.gray)

# Confusion matrix 利用混淆矩陣來看 驗證結果
from sklearn.metrics import confusion_matrix

print(confusion_matrix(y_test, y_pred))

呈現結果

模型介紹說明

raw-image

Train 分批訓練 預設每批32筆分一批次

raw-image

Predict

預測的類別是2,秀出檢測的圖片剛好也是類別2,還好沒漏氣

預測結果

預測結果

Confusion matrix 混淆矩陣

紅色斜線代表預測正確的,對照圖,預測標籤0但真實標籤是1的有3個,真實標籤2有誤判13個。

raw-image


分析混淆矩陣

由混淆矩陣可以快速發現,哪裡誤判比較多,讓我們看預測標籤6 誤判 真實標籤0,2,4特別多,往回看一下圖檔,也難怪外輪廓都長得很像,特徵太接近,圖檔又小才28*28,用的神經網路太簡單,難怪學得不好。


raw-image

用更多層一點的神經網路

model = Sequential([
#第一層:將 28*28攤平成一維度
Flatten(input_shape=x_train.shape[1:]), #x_train.shape[1:] 28,28
#第二層
Dense (units=300, activation='relu'), #
Dense (units=200, activation='relu'),
Dense (units=100, activation='relu'),
#輸出層:10類別,10個神經元
Dense (units=10, activation='softmax')
])

混淆矩陣結果

對照一下2層與5層結果,預測6,誤判0 次數從136降到108了,也與理論實際符合在一定情況下神經網路層越多預測結果越好

raw-image













留言
avatar-img
留言分享你的想法!
avatar-img
螃蟹_crab的沙龍
147會員
261內容數
本業是影像辨識軟體開發,閒暇時間進修AI相關內容,將學習到的內容寫成文章分享。
螃蟹_crab的沙龍的其他內容
2024/06/21
微調(Fine tune)是深度學習中遷移學習的一種方法,其中預訓練模型的權重會在新數據上進行訓練。 本文主要介紹如何使用新的訓練圖檔在tesseract 辨識模型進行Fine tune 有關於安裝的部分可以參考友人的其他文章 Tesseract OCR - 繁體中文【安裝篇】 將所有資料
Thumbnail
2024/06/21
微調(Fine tune)是深度學習中遷移學習的一種方法,其中預訓練模型的權重會在新數據上進行訓練。 本文主要介紹如何使用新的訓練圖檔在tesseract 辨識模型進行Fine tune 有關於安裝的部分可以參考友人的其他文章 Tesseract OCR - 繁體中文【安裝篇】 將所有資料
Thumbnail
2024/06/01
平時都在用tesseract來辨識OCR的部分,在網路上也常常聽說easyOCR比tesseract好用,就拿之前測試的OCR素材來比較看看囉。 以下輸入同樣圖片直接測試,並非絕對誰就比較準,只單純測試數字含英文的部分。 圖片素材就是15碼(英文加數字),檔名為OCR正確結果
Thumbnail
2024/06/01
平時都在用tesseract來辨識OCR的部分,在網路上也常常聽說easyOCR比tesseract好用,就拿之前測試的OCR素材來比較看看囉。 以下輸入同樣圖片直接測試,並非絕對誰就比較準,只單純測試數字含英文的部分。 圖片素材就是15碼(英文加數字),檔名為OCR正確結果
Thumbnail
2024/05/26
本文將展示使用不同激活函數(ReLU 和 Sigmoid)的效果。 一個簡單的多層感知器(MLP)模型來對 Fashion-MNIST 資料集進行分類。 函數定義 Sigmoid 函數 Sigmoid 函數將輸入壓縮到 0到 1 之間: 特性: 輸出範圍是 (0,1)(0, 1)(0,1
Thumbnail
2024/05/26
本文將展示使用不同激活函數(ReLU 和 Sigmoid)的效果。 一個簡單的多層感知器(MLP)模型來對 Fashion-MNIST 資料集進行分類。 函數定義 Sigmoid 函數 Sigmoid 函數將輸入壓縮到 0到 1 之間: 特性: 輸出範圍是 (0,1)(0, 1)(0,1
Thumbnail
看更多
你可能也想看
Thumbnail
沙龍一直是創作與交流的重要空間,這次 vocus 全面改版了沙龍介面,就是為了讓好內容被好好看見! 你可以自由編排你的沙龍首頁版位,新版手機介面也讓每位訪客都能更快找到感興趣的內容、成為你的支持者。 改版完成後可以在社群媒體分享新版面,並標記 @vocus.official⁠ ♥️ ⁠
Thumbnail
沙龍一直是創作與交流的重要空間,這次 vocus 全面改版了沙龍介面,就是為了讓好內容被好好看見! 你可以自由編排你的沙龍首頁版位,新版手機介面也讓每位訪客都能更快找到感興趣的內容、成為你的支持者。 改版完成後可以在社群媒體分享新版面,並標記 @vocus.official⁠ ♥️ ⁠
Thumbnail
每年4月、5月都是最多稅要繳的月份,當然大部份的人都是有機會繳到「綜合所得稅」,只是相當相當多人還不知道,原來繳給政府的稅!可以透過一些有活動的銀行信用卡或電子支付來繳,從繳費中賺一點點小確幸!就是賺個1%~2%大家也是很開心的,因為你們把沒回饋變成有回饋,就是用卡的最高境界 所得稅線上申報
Thumbnail
每年4月、5月都是最多稅要繳的月份,當然大部份的人都是有機會繳到「綜合所得稅」,只是相當相當多人還不知道,原來繳給政府的稅!可以透過一些有活動的銀行信用卡或電子支付來繳,從繳費中賺一點點小確幸!就是賺個1%~2%大家也是很開心的,因為你們把沒回饋變成有回饋,就是用卡的最高境界 所得稅線上申報
Thumbnail
本篇文章介紹如何使用PyTorch構建和訓練圖神經網絡(GNN),並使用Cora資料集進行節點分類任務。通過模型架構的逐步優化,包括引入批量標準化和獨立的消息傳遞層,調整Dropout和聚合函數,顯著提高了模型的分類準確率。實驗結果表明,經過優化的GNN模型在處理圖結構數據具有強大的性能和應用潛力。
Thumbnail
本篇文章介紹如何使用PyTorch構建和訓練圖神經網絡(GNN),並使用Cora資料集進行節點分類任務。通過模型架構的逐步優化,包括引入批量標準化和獨立的消息傳遞層,調整Dropout和聚合函數,顯著提高了模型的分類準確率。實驗結果表明,經過優化的GNN模型在處理圖結構數據具有強大的性能和應用潛力。
Thumbnail
延續上一篇訓練GAM模型,這次我們讓神經網路更多層更複雜一點,來看訓練生成的圖片是否效果會更好。 [深度學習][Python]訓練MLP的GAN模型來生成圖片_訓練篇 資料集分割處理的部分在延續上篇文章,從第五點開始後修改即可,前面都一樣 訓練過程,比較圖 是不是CNN的效果比MLP還要好,
Thumbnail
延續上一篇訓練GAM模型,這次我們讓神經網路更多層更複雜一點,來看訓練生成的圖片是否效果會更好。 [深度學習][Python]訓練MLP的GAN模型來生成圖片_訓練篇 資料集分割處理的部分在延續上篇文章,從第五點開始後修改即可,前面都一樣 訓練過程,比較圖 是不是CNN的效果比MLP還要好,
Thumbnail
本文主要介紹,如何利用GAN生成對抗網路來訓練生成圖片。 利用tensorflow,中的keras來建立生成器及鑑別器互相競爭訓練,最後利用訓練好的生成器來生成圖片。 GAN生成對抗網路的介紹 它由生成網路(Generator Network)和鑑別網路(Discriminator Netwo
Thumbnail
本文主要介紹,如何利用GAN生成對抗網路來訓練生成圖片。 利用tensorflow,中的keras來建立生成器及鑑別器互相競爭訓練,最後利用訓練好的生成器來生成圖片。 GAN生成對抗網路的介紹 它由生成網路(Generator Network)和鑑別網路(Discriminator Netwo
Thumbnail
透過這篇文章,我們將瞭解如何使用PyTorch實作圖神經網絡中的訊息傳遞機制,從定義消息傳遞的類別到實作消息傳遞過程。我們也探討了各種不同的消息傳遞機制,並通過對單次和多次傳遞過程的結果,可以看到節點特徵如何逐步傳遞與更新。
Thumbnail
透過這篇文章,我們將瞭解如何使用PyTorch實作圖神經網絡中的訊息傳遞機制,從定義消息傳遞的類別到實作消息傳遞過程。我們也探討了各種不同的消息傳遞機制,並通過對單次和多次傳遞過程的結果,可以看到節點特徵如何逐步傳遞與更新。
Thumbnail
本篇文章專注於消息傳遞(message passing)在圖神經網絡(GNN)中的應用,並以簡單的例子解釋了消息傳遞的過程和機制。
Thumbnail
本篇文章專注於消息傳遞(message passing)在圖神經網絡(GNN)中的應用,並以簡單的例子解釋了消息傳遞的過程和機制。
Thumbnail
本文主要筆記使用pytorch建立graph的幾個概念與實作。在傳統的神經網路模型中,數據點之間往往是互相連接和影響的,使用GNN,我們不僅處理單獨的數據點或Xb,而是處理一個包含多個數據點和它們之間連結的特徵。GNN的優勢在於其能夠將這些連結關係納入模型中,將關係本身作為特徵進行學習。
Thumbnail
本文主要筆記使用pytorch建立graph的幾個概念與實作。在傳統的神經網路模型中,數據點之間往往是互相連接和影響的,使用GNN,我們不僅處理單獨的數據點或Xb,而是處理一個包含多個數據點和它們之間連結的特徵。GNN的優勢在於其能夠將這些連結關係納入模型中,將關係本身作為特徵進行學習。
Thumbnail
我想要一天分享一點「LLM從底層堆疊的技術」,並且每篇文章長度控制在三分鐘以內,讓大家不會壓力太大,但是又能夠每天成長一點。 接著來談 Transformer 架構中的 Feedforward Network (FFN): 其為全連接的神經網路架構 回顧 AI說書 - 從0開始 - 64
Thumbnail
我想要一天分享一點「LLM從底層堆疊的技術」,並且每篇文章長度控制在三分鐘以內,讓大家不會壓力太大,但是又能夠每天成長一點。 接著來談 Transformer 架構中的 Feedforward Network (FFN): 其為全連接的神經網路架構 回顧 AI說書 - 從0開始 - 64
Thumbnail
上篇我們簡單的了解了 TTS 想要達到的目標,但是對於訓練資料的處理、網路架構、損失函數、輸出分析等考慮到篇幅尚未解釋清楚,這篇將針對訓練資料處理中的文字部分進行詳細說明,讓我們開始吧。
Thumbnail
上篇我們簡單的了解了 TTS 想要達到的目標,但是對於訓練資料的處理、網路架構、損失函數、輸出分析等考慮到篇幅尚未解釋清楚,這篇將針對訓練資料處理中的文字部分進行詳細說明,讓我們開始吧。
Thumbnail
本文將展示使用不同激活函數(ReLU 和 Sigmoid)的效果。 一個簡單的多層感知器(MLP)模型來對 Fashion-MNIST 資料集進行分類。 函數定義 Sigmoid 函數 Sigmoid 函數將輸入壓縮到 0到 1 之間: 特性: 輸出範圍是 (0,1)(0, 1)(0,1
Thumbnail
本文將展示使用不同激活函數(ReLU 和 Sigmoid)的效果。 一個簡單的多層感知器(MLP)模型來對 Fashion-MNIST 資料集進行分類。 函數定義 Sigmoid 函數 Sigmoid 函數將輸入壓縮到 0到 1 之間: 特性: 輸出範圍是 (0,1)(0, 1)(0,1
Thumbnail
本文主要介紹神經網路訓練辨識的過程,利用fashion_mnist及簡單的神經網路來進行分類。 使用只有兩層的神經網路來訓練辨識fashion_mnist資料。
Thumbnail
本文主要介紹神經網路訓練辨識的過程,利用fashion_mnist及簡單的神經網路來進行分類。 使用只有兩層的神經網路來訓練辨識fashion_mnist資料。
追蹤感興趣的內容從 Google News 追蹤更多 vocus 的最新精選內容追蹤 Google News