歡迎來到Scikit-learn教學系列的第三篇文章!在前兩篇中,我們介紹了Scikit-learn的基礎、環境設置以及資料前處理。這一篇將進入監督式學習的核心,聚焦於分類模型。我們將學習分類問題的基本概念、使用Scikit-learn實現常見分類演算法,並通過實作範例與評估方法來訓練和評估模型。準備好探索機器學習的預測能力吧!
分類問題概述
分類問題是監督式學習的一種,目標是根據特徵預測資料的類別標籤。常見的分類任務包括:- 二元分類:預測兩個類別,例如判斷電子郵件是否為垃圾郵件(是/否)。
- 多類別分類:預測多於兩個類別,例如辨識手寫數字(0-9)。
在Scikit-learn中,分類模型的訓練與預測流程簡單且一致:使用fit()
訓練模型,然後用predict()
進行預測。
常用分類演算法
Scikit-learn提供了多種分類演算法,以下是三種入門級演算法:
- 邏輯回歸(Logistic Regression):基於概率的線性模型,適合二元分類。
- K近鄰(K-Nearest Neighbors, KNN):基於距離的非參數方法,適合小型資料集。
- 決策樹(Decision Tree):基於樹狀結構的模型,易於解釋。
我們將通過實作展示這些演算法的應用。
模型訓練與預測
Scikit-learn的分類器遵循統一的API:
- fit(X, y):使用特徵X和標籤y訓練模型。
- predict(X):對新資料預測類別。
- score(X, y):計算模型在資料上的準確率。
此外,資料通常需要前處理(參考前文),並分割為訓練集與測試集以評估模型表現。
模型評估
為了判斷模型的好壞,我們使用以下指標:
- 準確率(Accuracy):預測正確的比例。
- 混淆矩陣(Confusion Matrix):展示預測類別與真實類別的對應關係。
- 分類報告(Classification Report):包含精確率(Precision)、召回率(Recall)和F1分數。
Scikit-learn的metrics
模組提供了這些評估工具。

實作:使用Iris資料集訓練分類模型
讓我們使用Scikit-learn的內建Iris資料集,訓練並比較三種分類模型(邏輯回歸、KNN、決策樹)。我們將進行資料前處理、模型訓練、預測與評估。
程式碼範例
以下程式碼展示完整的分類流程:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
# 載入Iris資料集
iris = load_iris()
X = iris.data
y = iris.target
feature_names = iris.feature_names
target_names = iris.target_names
# 資料前處理
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
# 分割訓練集與測試集
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.3, random_state=42)
# 定義模型
models = {
'Logistic Regression': LogisticRegression(random_state=42),
'K-Nearest Neighbors': KNeighborsClassifier(n_neighbors=5),
'Decision Tree': DecisionTreeClassifier(random_state=42)
}
# 訓練與評估
results = {}
for name, model in models.items():
# 訓練模型
model.fit(X_train, y_train)
# 預測
y_pred = model.predict(X_test)
# 計算準確率
accuracy = accuracy_score(y_test, y_pred)
results[name] = accuracy
# 輸出結果
print(f"\n{name} Results:")
print(f"Accuracy: {accuracy:.4f}")
print("Classification Report:")
print(classification_report(y_test, y_pred, target_names=target_names))
# 混淆矩陣
cm = confusion_matrix(y_test, y_pred)
plt.figure(figsize=(6, 4))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=target_names, yticklabels=target_names)
plt.title(f'Confusion Matrix - {name}')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.savefig(f'iris_cm_{name.lower().replace(" ", "_")}.png')
plt.show()
# 比較模型表現
plt.figure(figsize=(8, 6))
sns.barplot(x=list(results.values()), y=list(results.keys()))
plt.title('Model Accuracy Comparison')
plt.xlabel('Accuracy')
plt.ylabel('Model')
plt.savefig('iris_model_comparison.png')
plt.show()
程式碼解釋
- 資料載入與前處理: 載入Iris資料集,包含4個特徵與3個類別。 使用StandardScaler標準化特徵,確保模型不受不同尺度影響。
- 資料分割:使用train_test_split將資料分為70%訓練集與30%測試集。
- 模型訓練與預測: 訓練三種模型:邏輯回歸、KNN、決策樹。 使用predict對測試集進行預測。
- 模型評估: 計算準確率並輸出分類報告。 繪製混淆矩陣,展示每個類別的預測表現。
- 視覺化:繪製柱狀圖比較不同模型的準確率。
運行程式碼後,你會看到每種模型的準確率、分類報告與混淆矩陣,並生成比較圖表。







練習:比較Wine資料集的分類表現
請完成以下練習:
- 使用Scikit-learn的load_wine()載入Wine資料集。
- 對特徵進行標準化,並將資料分為訓練集與測試集(比例自選)。
- 訓練以下模型:邏輯回歸、KNN(n_neighbors=3)、決策樹。
- 計算每種模型的準確率,繪製混淆矩陣,並保存為wine_cm_模型名稱.png。
- 繪製柱狀圖比較模型準確率,保存為wine_model_comparison.png。
以下是起點程式碼:
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
# 載入Wine資料集
wine = load_wine()
X = wine.data
y = wine.target
target_names = wine.target_names
# 你的程式碼從這裡開始...
總結
恭喜你學會了分類模型的基礎!本篇文章介紹了分類問題、常見演算法(邏輯回歸、KNN、決策樹)、模型訓練與評估方法,並通過Iris資料集展示了完整流程。你現在能夠訓練分類模型並使用混淆矩陣與分類報告評估其表現。
在下一篇文章中,我們將探索監督學習 - 回歸模型入門,學習如何預測連續變量,例如房價或溫度。
資源與進階學習
- Scikit-learn分類文件:https://scikit-learn.org/stable/supervised_learning.html
- 分類實務:《Introduction to Machine Learning with Python》
- 練習平台:Kaggle分類競賽(https://www.kaggle.com/competitions)