2024-04-30|閱讀時間 ‧ 約 24 分鐘

類神經網路訓練 局部最小值 (local minima) 與鞍點 (saddle point)

    之前有提到有時我們在微分之後會得到gradient = 0的值,就以為我們已經找到最小值,但其實它只是local minima。

    那這一節主要想跟大家分享我們要怎麼區分是不是Local Minima。

    首先,如果在我們取微分之後,得到gradient = 0的情況,我們統稱為critical point

    那critical point根據圖形我們可以分成兩種類型:

    1. local minima:圖形最低點。
    2. saddle point:只是某個方向的最低點,在不同方向上其實還有路可以走。


    那要如何區分?

    1. 如上圖,可以透過圖形區分
    2. 透過Hessian區分

    Hessian

    計算方式:

    1. 透過泰勒展開式,我們可以將L(θ)近似於:

    𝐿(θ)≈ 𝐿(θ)′+(θ−θ′)𝑇 *g+1/2 * (θ−θ′)𝑇 𝐻(θ−θ′)

    此時處於critical point -> (θθ′)𝑇 *g的值為0

    *g: gradient

    *H: Hessian -> Hij =  (∂2/∂θiθj )*L(θ'): 對Loss函數的二次微分

    1. 計算vTHv,以及所有的eigen value λ:

    a. λ> 0 -> local minima

    b. λ< 0 -> local maxmum

    c. λ有正有負 -> saddle point

    Q: 如果算出來是saddle point呢?

    A: 那我們就能透過eigen value 與vector得出可以更新的方向

    假設我們得到的eigen value λ = 2, -2

    我們就計算出eigen vector

    接著照著eigen vector u的方向更新我們的θ,即能更新我們的參數


    推導過程:

    H可以替換成eigen value λ,v替換成u=[1 1]T

    => uTHu = uTλu = λ|u|2

    再帶回泰勒展開式我們可以得到

    𝐿(θ)= 𝐿(θ)′+1/2 * (θ−θ′)𝑇 𝐻(θ−θ′)

    (θ−θ′) = u帶入

    -> 𝐿(θ)= 𝐿(θ')+1/2 * uTλu

    如果λ < 0 我們能得知 𝐿(θ) < 𝐿(θ')

    => θ−θ′ = u -> θ = θ′ + u

    => 推得我們可以透過+u更新θ值


    Example:

    假設我們設計一個function為 y = w1w2x,目標是找到最接近y= 1的答案:

    我們可以得到Loss function:

    L = (ŷ - ​w1w2x)2 = (1-w1w2x)2 <- 只有一筆y= 1的訓練資料

    接著,我們透過微分得到gradient decent:

    ∂/∂w1 = 2(1-w2x)(-w2)

    ∂/∂w2 = 2(1-w1x)(-w1)

    而我們事先知道當w1 = 0, w2 = 0時為critical point

    (可以各自將w1 = 0的值帶入上述微分後的函數,結果也為0)

    我們計算出Hession: 將每個向量都做2次微分

    向量H = ∂*L/∂w12 ∂*L/∂w1∂w2 0 -2

    ∂*L/∂w2∂w1 ∂*L/∂w22 => -2 0

    => 由此得知 eigen value λ = 2,-2 => saddle point ​

    如何更新參數?

    eigen value λ = -2, eigen vector u = [1 1]T

    => 更新的 θ = θ' - u

    *但要注意,這樣的計算量極大,通常不採用這樣的方法


    那local minima與saddle point哪一個更常見?

    -> 事實上Saddle point 更多。

    我們可以透過檢查eigen value的正負決定

    -> minimum ratio 代表還有多少路可以走 = 正的eigen value數目 / 所有的eigen value數目


    以上是關於如何區分local minima與saddle point的辦法~





    分享至
    成為作者繼續創作的動力吧!
    © 2024 vocus All rights reserved.