2024-01-06|閱讀時間 ‧ 約 31 分鐘

[機器學習]CNN學習MNIST 手寫英文字母資料,用網頁展現成果_模型訓練篇

streamlit與github連動程式庫,呈現即時預測手寫英文字母

前言

此篇主要談論如何用CNN學習MNIST訓練出模型

若要看如何Streamlit程式範例 解說如下方連結

[機器學習]CNN學習MNIST 手寫英文字母資料,用網頁展現成果_Streamlit Web應用程式篇

整理了一下,先前學的機器學習利用Colab來訓練出能辨識手寫A~Z英文字母的模型,使用的模型是CNN(Convolutional Neural Network,CNN)模型

訓練好的模型,當然是要拿來應用,成果呈現的方式由streamlit網頁呈現手寫英文字母按辨識,即可跑出辨識結果。

結果圖

程式碼Github連結

如何連動github與stramlit可以參考一下這個文章解釋的蠻清楚的,我就不要班門弄斧了


Github上檔案說明

streamlit 設定中會指定main file path,就是要連動github開啟哪一個檔案


訓練模型的程式碼說明

載入資料

!pip install emnist
下載emnist手寫資料,詳細說明如下網址
https://pypi.org/project/emnist/
# pip install emnist
# Import Dataset(s)
from emnist import list_datasets
list_datasets()

會顯示出emnist有哪些類別的資料庫可以做使用

  1. balanced: 這個資料集是 EMNIST 的均衡版本,意味著每個字母類別的樣本數量相對平均。
  2. byclass: 按字母類別組織的資料集。每個字母類別都有自己的資料子集。
  3. bymerge: 將一些形狀相似的字母合併到一個類別。這是為了處理某些字母形狀相似度較高的情況。
  4. digits: 只包含數字的資料集,沒有字母。
  5. letters: 包含所有字母的資料集,但沒有數字。
  6. mnist: 是 MNIST 資料集的一個子集,只包含手寫數字。
#導入使用EMNIST Letters(包含A~Z)26類別的資料
from emnist import extract_training_samples
x_train, y_train = extract_training_samples('letters')
from emnist import extract_test_samples
x_test, y_test = extract_test_samples('letters')

載入EMNIST Letters(包含A~Z)26類別的資料 在拆分成測試與驗證集

import numpy as np
class_names = [chr(ord('A')+i) for i in range(26)]
''.join(class_names)
np.array(class_names)[y_train[0:26]]

將模型訓練數據集中前 26 個樣本的類別標籤轉換為對應的字母

  • chr(ord('A')+i) 會將 ASCII 碼中的 'A' 起始值加上索引 i,得到相應的字母。
  • for i in range(26) 用於迭代 i 從 0 到 25,以包含 A 到 Z 的所有字母。
  • class_names 這個列表最終包含了所有英文字母。
  • ''.join(class_names) 將列表中的字母用空字符串連接在一起,形成一個完整的字串,即 A 到 Z 的英文字母序列。
  • np.array(class_names) 將 class_names 轉換為 NumPy 陣列。
  • y_train[0:26] 取得 y_train 陣列的前 26 個元素。
  • np.array(class_names)[y_train[0:26]] 根據這些標籤從 class_names 中取得相應的字母。
# 顯示第1張圖片圖像
import matplotlib.pyplot as plt

# 第一筆資料
X2 = x_train[47,:,:]

# 繪製點陣圖,cmap='gray':灰階
plt.imshow(X2.reshape(28,28), cmap='gray')

# 隱藏刻度
plt.axis('off')

# 顯示圖形
plt.show()

進行特徵工程

# 特徵縮放,使用常態化(Normalization),公式 = (x - min) / (max - min)
# 顏色範圍:0~255,所以,公式簡化為 x / 255
# 注意,顏色0為白色,與RGB顏色不同,(0,0,0) 為黑色。
x_train_norm, x_test_norm = x_train / 255.0, x_test / 255.0
x_train_norm[0]

這段程式碼進行了特徵縮放,使用的方法是歸一化(Normalization)。歸一化的目的是將特徵的值縮放到一個標準範圍,這裡是將顏色值(0~255)縮放到 0 到 1 之間。

建立模型結構

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout

# 建立模型
model = Sequential([
Conv2D(filters=32, kernel_size=3, activation='relu', input_shape=(28, 28, 1)),
MaxPooling2D(pool_size=2),
Conv2D(filters=64, kernel_size=3, activation='relu'),
MaxPooling2D(pool_size=2),
Flatten(),
Dense(units=64, activation='relu'),
Dropout(0.5),
Dense(units=26, activation='softmax')
])
# 設定優化器(optimizer)、損失函數(loss)、效能衡量指標(metrics)的類別
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])

Conv2D 層: 這是卷積層,用於從圖像中學習特徵。filters=32 表示使用 32 個卷積核,kernel_size=3 表示卷積核的大小是 3x3,activation='relu' 使用 ReLU 激活函數。

MaxPooling2D 層: 這是池化層,用於減小圖像的空間尺寸。pool_size=2 表示使用 2x2 的最大池化。

Flatten 層: 將卷積層和池化層的輸出攤平成一維數組,以便與全連接層相連接。

Dense 層 (全連接層): units=64 表示該層有 64 個神經元,activation='relu' 使用 ReLU 激活函數。

Dropout 層: 這是為了防止過擬合,丟棄 50% 的神經元。

Dense 層 (輸出層): units=26 表示輸出層有 26 個神經元activation='softmax' 使用 softmax 激活函數,這對應於多類別分類問題,每個神經元代表一個字母。

最使用 model.compile 配置模型的優化器(optimizer)、損失函數(loss)和效能衡量指標(metrics)。這裡使用了 Adam 優化器,稀疏分類交叉熵作為損失函數,衡量指標是準確度(accuracy)。模型已經建立完成,可以進行訓練了

history = model.fit(x_train_norm, y_train, epochs=100, batch_size=1000, validation_split=0.2)
  • x_train_norm 是經過歸一化的訓練資料,y_train 是相對應的訓練標籤。
  • epochs=100 指定了訓練的輪數,即模型將對整個訓練數據集進行 100 輪訓練。
  • batch_size=1000 定義了每次訓練更新的樣本數,即每次更新模型權重時,使用的樣本數量。
  • validation_split=0.2 表示將訓練數據的 20% 用於驗證,這有助於監控模型的性能。

評分

# 評分(Score Model)
score=model.evaluate(x_test_norm, y_test, verbose=0)

for i, x in enumerate(score):
print(f'{model.metrics_names[i]}: {score[i]:.4f}')

model.evaluate 方法返回模型的損失值和指定的評估指標值

預測

# 顯示第 9 筆的機率
import numpy as np

predictions = model.predict(x_test_norm[8:9])
print(f'0~9預測機率: {np.around(predictions, 2)}')

這段程式碼用於顯示模型對第9筆測試數據的預測機率。model.predict 方法用於獲得模型對輸入數據的預測結果,這裡你取第9筆數據

模型儲存

# 模型存檔
model.save('model.h5')




希望對大家有所幫助,動動小手按下愛心的圖案給點鼓勵



[機器學習]CNN學習MNIST 手寫英文字母資料,用網頁展現成果_Streamlit Web應用程式篇

分享至
成為作者繼續創作的動力吧!
© 2024 vocus All rights reserved.