使用 Colaboratory 和 Kaggle 資料庫學習線性回歸:梯度下降法與 Scikit-learn 實作

更新於 發佈於 閱讀時間約 14 分鐘
學習理論後,重要的是如何用程式碼實現,這一篇文章我會用 Python 練習手刻梯度下降法,上傳自己的檔案練習。中間有發現數據集跑出來的結果跟我想的不一樣,我也分享如何解決(感謝AI)最後再用 Scikit-learn 加上 Matplotlib 視覺畫圖表,期望之後可以在工作使用到。

使用 Colaboratory 練習,使用 AI 工具輔助,並使用 kaggle 開源資料庫練習

Colaboratory

加上程式碼就可以簡單的練習,不需要架設其他的環境。

raw-image


1. 手刻梯度下降法 (Batch Gradient Descent)

假設我們要做 簡單線性回歸。y = ax+b ,數據我先隨機設定。

實際上應用可以用自己的 Excel 或是 CSV,

import numpy as np

# 生成範例資料 (x:輸入特徵, y:實際值)
X = np.array([1, 2, 7.2, 4.3, 6]) # 單一特徵
y = np.array([3, 7, 8, 9, 11]) # 我隨機設定數字做示範

# 初始化參數
theta_0 = 0 # 截距項 (bias)
theta_1 = 0 # 斜率 (weight)
alpha = 0.01 # 學習率
epochs = 1000 # 迭代次數
m = len(X) # 樣本數 X

# 執行梯度下降
for epoch in range(epochs):
# 計算預測值
y_pred = theta_0 + theta_1 * X

# 計算損失函數 (MSE) #**是次方,所以**2是二次方
loss = np.mean((y_pred - y) ** 2)

# 計算梯度 (偏導數)
d_theta_0 = (2/m) * np.sum(y_pred - y) # 對 theta_0 求偏導
d_theta_1 = (2/m) * np.sum((y_pred - y) * X) # 對 theta_1 求偏導

# 更新參數
theta_0 -= alpha * d_theta_0
theta_1 -= alpha * d_theta_1

# 每 100 次輸出一次 loss
if epoch % 100 == 0:
print(f"Epoch {epoch}, Loss: {loss:.5f}, theta_0: {theta_0:.5f}, theta_1: {theta_1:.5f}")

print(f"最終結果: theta_0 = {theta_0:.5f}, theta_1 = {theta_1:.5f}")

2. 這段程式碼做了什麼?

  1. 初始化參數
    • theta_0 和 theta_1 設為 0,代表模型起點隨機。
    • alpha = 0.01,控制每次參數更新的幅度。
    • epochs = 1000,設定執行 1000 次梯度下降。
  2. 梯度下降步驟
    • 計算預測值--> 計算 MSE 損失函數-->計算梯度 (偏導數)-->更新參數

輸出結果

  • 每 100 次迭代輸出一次 Loss,讓我們觀察模型收斂過程。

最後輸出最終的 theta_0theta_1,得到最佳的線性回歸參數。

raw-image

用 Colaboratory 跑出的結果

Epoch 0, Loss: 44.00000, theta_0: 0.22000, theta_1: 0.66000
Epoch 100, Loss: 0.00313, theta_0: 0.01634, theta_1: 1.98496
Epoch 200, Loss: 0.00000, theta_0: 0.00919, theta_1: 1.99563
...
Epoch 900, Loss: 0.00000, theta_0: 0.00003, theta_1: 1.99999
最終結果: theta_0 = 0.00003, theta_1 = 1.99999

生活上的應用:上傳自己的檔案

  1. 貼上以下程式碼先上傳資料
from google.colab import files
import pandas as pd

# 上傳 CSV 檔案
uploaded = files.upload()

# 讀取 CSV exel 要改成pd.read_excel
df = pd.read_csv(next(iter(uploaded))) # 取得上傳的檔案名稱
print(df.head()) # 顯示前幾筆資料


raw-image


挖 失敗! 出現 NAN

raw-image


NaN"Not a Number" 的縮寫,代表「無效數值」,通常出現在以下情況:

  • 數值運算錯誤

例如 0 、 inf 這種無法計算的數值。

你的 NaN 問題可能來自 Loss 變成 inf 之後發生數值錯誤。

  • 資料有缺失

如果 X 或 y 有 NaN,運算時會導致 NaN 傳播。

但你的 df.dropna() 已經移除 NaN,所以這應該不是主要原因。

  • 數值過大導致溢出 (Overflow)

你的 theta_1 在 Epoch 100 之後變得超大 (10^308 這種等級),導致 Loss 變成 inf,進一步影響計算。

這通常是 學習率 (alpha) 過大,讓梯度下降失控,導致 theta_0 和 theta_1 在幾次迭代後爆炸。

問 ChatGPT 原因:

你的問題應該是 學習率過大,導致 Loss 變成 inf,進而產生 NaN。透過:

  1. 降低學習率 (alpha = 0.0001)
  2. 標準化 Xy
  3. 限制 theta_0theta_1 的範圍

我先說我的嘗試,因為0.0001,Theta_0 結果出現 0 ,有可能是均值接近0,所以可以先嘗試

👉 解決方法

試著先檢查 Xy 的均值:

如果 y_predy 的均值相等,那麼梯度接近 0,theta_0 幾乎不變。

print("X mean:", np.mean(X))
print("y mean:", np.mean(y))

如果 Xy 的均值接近 0,可以嘗試:

  • X 進行標準化,但 不要對 y 標準化
    X = (X - np.mean(X)) / np.std(X)

測試後發現 X meany mean 都非常接近 0 (在 e-17e-16 級別),這表示 X 和 y 已經標準化了

raw-image


因為標準化後的 X 是均值為 0 的對稱分佈,梯度下降在更新 theta_0 時可能變得極小,導致 theta_0 始終維持在 0


為什麼 theta_0 斜率= 0?

如果 X 已標準化 (mean ≈ 0),則:

  • 目標函數的最佳擬合線應該通過 (0, mean(y)),但因為 mean(y) ≈ 0,所以 theta_0 = 0
  • 在梯度更新時,theta_0 的梯度 (d_theta_0) 會接近 0,因此 theta_0 幾乎沒有變化。

這是數學上的結果,不是程式錯誤。


raw-image

在我把 y 標準化的程式碼註解掉就可以看到變化

# 標準化數據
X_mean, X_std = np.mean(X), np.std(X)
y_mean, y_std = np.mean(y), np.std(y)
X = (X - X_mean) / X_std
# 因為我的數據標準化會使截距貼近零,所以暫時不標準化 y = (y - y_mean) / y_std
print("X mean:", np.mean(X))
print("y mean:", np.mean(y))

在公司很常會希望用圖表報告,用 Matplotib 就可以繪製

可視化梯度下降 Matplotlib

用 Matplotlib 繪製梯度下降過程:

我先把這段程式碼貼到我前面的數據庫中

import matplotlib.pyplot as plt

# 繪製數據點 #一般的二維數據就可以跑
plt.scatter(X, y, color='blue', label='Actual data')

# 繪製最終擬合線
y_pred_final = theta_0 + theta_1 * X
plt.plot(X, y_pred_final, color='red', label='Fitted line')

plt.xlabel('X')
plt.ylabel('y')
plt.legend()
plt.title('Linear Regression using Gradient Descent')
plt.show()
剛好我選的數據集成正比

剛好我選的數據集成正比

試著亂改資料點xD

試著亂改資料點xD

改進:使用 Scikit-Learn

如果不想手刻梯度下降,可以使用 scikit-learn 的其中一個功能 LinearRegression 直接做回歸:

糖尿病資料庫

我用別的數據庫測試 https://scikit-learn.org/stable/datasets/toy_dataset.html

import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split

# 1. 載入糖尿病數據集
diabetes = datasets.load_diabetes()

# 2. 取出其中一個特徵 (只用 BMI 來預測疾病指數)
X = diabetes.data[:, np.newaxis, 2] # 選擇第3個特徵
y = diabetes.target # 目標變數

# 3. 分割數據集為訓練集和測試集 (80% 訓練, 20% 測試)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 4. 建立並訓練線性回歸模型
model = LinearRegression()
model.fit(X_train, y_train)

# 5. 預測測試數據
y_pred = model.predict(X_test)

# 6. 視覺化結果
plt.scatter(X_test, y_test, color="black", label="Actual Data") # 真實值
plt.plot(X_test, y_pred, color="red", linewidth=2, label="Predicted Line") # 預測直線
plt.xlabel("BMI")
plt.ylabel("Disease Progression")
plt.legend()
plt.show()

# 7. 輸出模型參數
print(f"Intercept (theta_0): {model.intercept_}")
print(f"Slope (theta_1): {model.coef_[0]}")
raw-image

補充說明語法:

X = diabetes.data[:, np.newaxis, 2] # 選擇第3個特徵

[:] 是 Python 中的「切片語法(slicing)」,在這行程式碼裡的作用是選擇所有的列(rows),而 2 則是選擇第 3 個特徵(索引從 0 開始算)。因為我們只做二維的回歸,所以只能其中一個特徵(BMI) 做疾病指數的預測。

  • diabetes.data
    • 這是 scikit-learn 的糖尿病數據集,裡面有年齡、BMI、血壓,格式是 numpy.ndarray,維度為 (442, 10),代表 442 筆資料,每筆資料有 10 個特徵
  • [:, np.newaxis, 2]
    • : → 選取 所有的列(rows),也就是 442 筆資料 全部選取。
    • np.newaxis → 這個會增加一個維度,讓 X 變成 (442, 1) 的 2D 矩陣,而不是原本的 (442,) 1D 陣列。
    • 2 → 取出 索引 2 的特徵(第 3 個特徵)
留言
avatar-img
留言分享你的想法!
avatar-img
越南放大鏡 X 下班資工系
13會員
60內容數
雙重身份:越南放大鏡 X 下班資工系 政大東南亞語言學系是我接觸越南語的起點,畢業後找越南外派工作的生活跟資訊時,發現幾乎都是清單式的分享,很難身歷其境。所以我希望「越南放大鏡」可以帶讀者看到更多細節和深入的觀察。 - 下班資工系則是自學資工系的課程內容,記錄實際操作的過程,學習理論的過程。希望可以跟讀者一起成長。
2025/04/24
本系列文章將循序漸進地介紹 JavaScript 的核心概念,從基礎語法到進階應用,例如非同步程式設計和 React 基礎。內容淺顯易懂,並使用生活化的比喻幫助讀者理解,搭配程式碼範例,適合 JavaScript 初學者學習。
Thumbnail
2025/04/24
本系列文章將循序漸進地介紹 JavaScript 的核心概念,從基礎語法到進階應用,例如非同步程式設計和 React 基礎。內容淺顯易懂,並使用生活化的比喻幫助讀者理解,搭配程式碼範例,適合 JavaScript 初學者學習。
Thumbnail
2025/04/21
本文介紹行動通訊網路的演進歷史,從1G到5G,並說明ITU與3GPP在制定通訊規格上的重要角色,以及5G的三大關鍵應用場景:URLLC、eMBB和mMTC。
Thumbnail
2025/04/21
本文介紹行動通訊網路的演進歷史,從1G到5G,並說明ITU與3GPP在制定通訊規格上的重要角色,以及5G的三大關鍵應用場景:URLLC、eMBB和mMTC。
Thumbnail
2025/04/11
這篇文章說明網路的七層模型、IP 位址、通訊埠、TCP/UDP 協定、HTTP 協定、HTTP 狀態碼以及 WebSocket,並解釋它們之間的關係與互動方式。文中包含許多圖表和範例,幫助讀者理解這些網路概念。
Thumbnail
2025/04/11
這篇文章說明網路的七層模型、IP 位址、通訊埠、TCP/UDP 協定、HTTP 協定、HTTP 狀態碼以及 WebSocket,並解釋它們之間的關係與互動方式。文中包含許多圖表和範例,幫助讀者理解這些網路概念。
Thumbnail
看更多
你可能也想看
Thumbnail
「欸!這是在哪裡買的?求連結 🥺」 誰叫你太有品味,一發就讓大家跟著剁手手? 讓你回購再回購的生活好物,是時候該介紹出場了吧! 「開箱你的美好生活」現正召喚各路好物的開箱使者 🤩
Thumbnail
「欸!這是在哪裡買的?求連結 🥺」 誰叫你太有品味,一發就讓大家跟著剁手手? 讓你回購再回購的生活好物,是時候該介紹出場了吧! 「開箱你的美好生活」現正召喚各路好物的開箱使者 🤩
Thumbnail
我的「媽」呀! 母親節即將到來,vocus 邀請你寫下屬於你的「媽」故事——不管是紀錄爆笑的日常,或是一直想對她表達的感謝,又或者,是你這輩子最想聽她說出的一句話。 也歡迎你曬出合照,分享照片背後的點點滴滴 ♥️ 透過創作,將這份情感表達出來吧!🥹
Thumbnail
我的「媽」呀! 母親節即將到來,vocus 邀請你寫下屬於你的「媽」故事——不管是紀錄爆笑的日常,或是一直想對她表達的感謝,又或者,是你這輩子最想聽她說出的一句話。 也歡迎你曬出合照,分享照片背後的點點滴滴 ♥️ 透過創作,將這份情感表達出來吧!🥹
Thumbnail
這篇是給初學技術分析者的建議,覺得去蕪存菁,最簡潔有效的東西。 1.認識什麼是K線,開高低收,成交量。 2.知道均線與均量的數學意義。 3.學習簡單的走勢型態,比如W底M頭,切線,跳空缺口。 以上3點就足夠了,不管基於什麼說法想法理由,都不要花時間去學任何其他指標。
Thumbnail
這篇是給初學技術分析者的建議,覺得去蕪存菁,最簡潔有效的東西。 1.認識什麼是K線,開高低收,成交量。 2.知道均線與均量的數學意義。 3.學習簡單的走勢型態,比如W底M頭,切線,跳空缺口。 以上3點就足夠了,不管基於什麼說法想法理由,都不要花時間去學任何其他指標。
Thumbnail
前言 這篇會拿Finlab上的策略與機器學習預測線圖的因子進行結合。由於模型是透過2007-2011年的線圖作為訓練資料,回測的時候會從2012年開始以示公平。 還沒看過前面兩篇的可以點下面連結,會比較看得懂接下來的內容。 第一篇: 什麼?!AI也看得懂k線圖?利用機器學習來判斷股票漲
Thumbnail
前言 這篇會拿Finlab上的策略與機器學習預測線圖的因子進行結合。由於模型是透過2007-2011年的線圖作為訓練資料,回測的時候會從2012年開始以示公平。 還沒看過前面兩篇的可以點下面連結,會比較看得懂接下來的內容。 第一篇: 什麼?!AI也看得懂k線圖?利用機器學習來判斷股票漲
Thumbnail
還沒有看過上一篇的可以點擊下面連結 什麼?!AI也看得懂k線圖?利用機器學習來判斷股票漲跌。(1)論文解析。 這一篇會把注意力放在論文提到的技術並套用在台股市場,也會使用論文中的方法進行驗證,看看是否在台股也有一樣的超額報酬。 資料生成 第一步也是最難的一步-資料生成。 這裡
Thumbnail
還沒有看過上一篇的可以點擊下面連結 什麼?!AI也看得懂k線圖?利用機器學習來判斷股票漲跌。(1)論文解析。 這一篇會把注意力放在論文提到的技術並套用在台股市場,也會使用論文中的方法進行驗證,看看是否在台股也有一樣的超額報酬。 資料生成 第一步也是最難的一步-資料生成。 這裡
Thumbnail
在上一期文章中。我們使用TA_lib套件。來協助我們尋找隱藏在股票價格當中的特殊K線形態並把它尋找到的結果輸出到一個表格當中。雖然結果是以100。以及-100的簡明方式來呈現;
Thumbnail
在上一期文章中。我們使用TA_lib套件。來協助我們尋找隱藏在股票價格當中的特殊K線形態並把它尋找到的結果輸出到一個表格當中。雖然結果是以100。以及-100的簡明方式來呈現;
Thumbnail
在AI浪潮下的訊號開發 提到可將AI訓練好的模型產生之訊號當成一個商品來匯入,今天將手把把示範如何把這訊號進行匯入成商品,並在策略撰寫時,可引用至此訊號,當為輔助資訊。 此表格為筆者使用的CNN模型訊號,在此利用開盤價、最高價、最低價、收盤價的技巧,讓多方趨勢的日期呈現紅K、空方趨勢的日期呈現黑K
Thumbnail
在AI浪潮下的訊號開發 提到可將AI訓練好的模型產生之訊號當成一個商品來匯入,今天將手把把示範如何把這訊號進行匯入成商品,並在策略撰寫時,可引用至此訊號,當為輔助資訊。 此表格為筆者使用的CNN模型訊號,在此利用開盤價、最高價、最低價、收盤價的技巧,讓多方趨勢的日期呈現紅K、空方趨勢的日期呈現黑K
Thumbnail
常常我們在財經節目聽到一堆技術指標都可以成功獲利,但真的如此嗎? 這麼簡單的技術指標操作就能獲利,早就人人變成有錢人了! 相信數據會說話,身為軟體工程師就最喜歡用數字來解讀一切了,因此這個篇章將會手把手教你如何使用Python語言來回測你的股票及交易策略。 剛接觸股市時最常聽到的就是KD、RSI、
Thumbnail
常常我們在財經節目聽到一堆技術指標都可以成功獲利,但真的如此嗎? 這麼簡單的技術指標操作就能獲利,早就人人變成有錢人了! 相信數據會說話,身為軟體工程師就最喜歡用數字來解讀一切了,因此這個篇章將會手把手教你如何使用Python語言來回測你的股票及交易策略。 剛接觸股市時最常聽到的就是KD、RSI、
Thumbnail
上一篇我們有介紹如何爬取Goodinfo的資訊並統計分析,還沒閱讀的朋友建議先行閱讀,再進入此篇章會比較容易上手唷,傳送門如下: 🚪【Google Colab系列】以Goodinfo!為例,統計一段時間內的最高、最低殖利率 為什麼要做資料視覺化? 相信圖文甚至影音箱對於文字來說更為吸引我們進
Thumbnail
上一篇我們有介紹如何爬取Goodinfo的資訊並統計分析,還沒閱讀的朋友建議先行閱讀,再進入此篇章會比較容易上手唷,傳送門如下: 🚪【Google Colab系列】以Goodinfo!為例,統計一段時間內的最高、最低殖利率 為什麼要做資料視覺化? 相信圖文甚至影音箱對於文字來說更為吸引我們進
Thumbnail
我們將會對動態設定學習率(learning rate)作為最陡梯度下降法的變異演算法做介紹。內容包括了解釋什麼事循環式的學習率調整排程法和何謂使用指數衰退權重來計算移動平均值,同時也介紹如何對大量參數的變數進行最佳化和目前活躍的演算法變異。如 adagrad, adadelta 和 RMSprop
Thumbnail
我們將會對動態設定學習率(learning rate)作為最陡梯度下降法的變異演算法做介紹。內容包括了解釋什麼事循環式的學習率調整排程法和何謂使用指數衰退權重來計算移動平均值,同時也介紹如何對大量參數的變數進行最佳化和目前活躍的演算法變異。如 adagrad, adadelta 和 RMSprop
Thumbnail
A.i人工智慧真的能預測股市嗎 ? 我們不免俗再提到機器學習,前幾年機器學習,人工智慧這些名詞非常的夯,引領風潮,全世界都在瘋狂,因為AlphaGo 打敗了無數個圍棋高手,開始炒熱機器學習。有人也許好奇,AlphaGo的技術不就是人工神經網路嗎,他的概念由來已久......
Thumbnail
A.i人工智慧真的能預測股市嗎 ? 我們不免俗再提到機器學習,前幾年機器學習,人工智慧這些名詞非常的夯,引領風潮,全世界都在瘋狂,因為AlphaGo 打敗了無數個圍棋高手,開始炒熱機器學習。有人也許好奇,AlphaGo的技術不就是人工神經網路嗎,他的概念由來已久......
Thumbnail
這篇文章的標題有「預測」二字,但看完之後請大家思考一下,這種基於「統計學」、「機器學習」的預測方法,是否跟你心中的「預測」相差甚遠呢?
Thumbnail
這篇文章的標題有「預測」二字,但看完之後請大家思考一下,這種基於「統計學」、「機器學習」的預測方法,是否跟你心中的「預測」相差甚遠呢?
追蹤感興趣的內容從 Google News 追蹤更多 vocus 的最新精選內容追蹤 Google News