模仿學習

模仿學習

更新於 發佈於 閱讀時間約 6 分鐘

本文介紹了一種名為Diffusion Model-Augmented Behavioral Cloning,(簡稱 DBC)的新型模仿學習框架。這個框架旨在結合建模條件機率和聯合機率的優勢,以改善模仿學習的效能。


模仿學習的目標是從專家示範中學習政策,而無需與環境互動。現有的不需要環境互動的模仿學習方法通常將專家分佈建模為條件機率 p(a|s)(如行為克隆,BC)或聯合機率 p(s,a)。雖然使用 BC 建模條件機率較為簡單,但通常難以通用化。而建模聯合機率雖然可以改善通用化效能,但推理過程往往耗時,且模型可能遭受流形過擬合問題。DBC 框架採用擴散模型來建模專家行為,並學習一個政策來同時最佳化 BC 損失(條件機率)和作者提出的擴散模型損失(聯合機率)。具體來說,DBC 包含以下步驟:

  1. 使用擴散模型對專家的狀態-動作對進行建模。
  2. 訓練一個政策網路,同時最佳化 BC 損失和擴散模型損失。


BC 損失定義為:

L_BC = E[(s,a)~D, â~π(s)][d(a, â)]其中 d(·,·) 表示動作對之間的距離度量。


擴散模型損失包括代理擴散損失和專家擴散損失:

L_diff^agent = E[s~D, â~π(s)][||φ(s, â, n) - ε||²]

L_diff^expert = E[(s,a)~D][||φ(s, a, n) - ε||²]


最終的擴散模型損失為:

L_DM = E[(s,a)~D, â~π(s)][max(L_diff^agent - L_diff^expert, 0)]


總損失函數為:

L_total = L_BC + λL_DM其中 λ 是一個係數,用於平衡兩個損失項的重要性。


作者在各種連續控制任務中評估了 DBC 的效能,包括導航、機器人手臂操作、靈巧操作和運動控制。實驗結果表明,DBC 在所有任務中都優於或達到與基本方法相當的效能。主要的實驗結果包括:

  1. 在 MAZE 環境中,DBC 達到了 95.4% 的成功率,與擴散政策(95.5%)相當,優於 BC(92.1%)和隱式 BC(78.3%)。
  2. 在 FETCHPICK 任務中,DBC 的成功率為 97.5%,明顯優於其他方法(BC:91.6%,隱式 BC:69.4%,擴散政策:83.9%)。
  3. 在 HANDROTATE 環境中,DBC(60.1%)與擴散政策(61.7%)表現相當,優於 BC(57.5%)和隱式 BC(13.8%)。
  4. 在 CHEETAH 和 WALKER 環境中,DBC 分別達到了 4909.5 和 7034.6 的回報,優於或與 BC 相當。
  5. 在 ANTREACH 任務中,DBC 的成功率為 70.1%,優於所有基本方法。


此外,作者還進行了一系列消融實驗和分析,以驗證 DBC 的設計選擇和效能:

  1. 比較不同生成模型:作者將擴散模型與能量基礎模型(EBM)、變分自動編碼器(VAE)和生成對抗網路(GAN)進行了比較。結果顯示,擴散模型在大多數情況下都能達到最佳效能。
  2. 擴散模型損失係數 λ 的影響:實驗表明,適當選擇 λ 值可以平衡 BC 損失和擴散模型損失,從而獲得最佳效能。
  3. 歸一化項的效果:作者驗證了使用專家擴散損失進行歸一化的有效性,結果顯示歸一化可以提高模型的效能。
  4. 流形過擬合實驗:作者設計了一個實驗來驗證 DBC 在處理低維流形上的高維數據時的效能,結果表明 DBC 能夠有效地克服流形過擬合問題。
  5. 泛化實驗:在 FETCHPICK 環境中,作者通過向初始狀態和目標位置注入不同程度的雜訊來評估模型的泛化能力。結果顯示,DBC 在不同雜訊水平下都能保持較好的效能,優於其他基本方法。


作者還討論了 BC 損失和擴散模型損失之間的關係。從訓練過程來看,同時最佳化這兩個目標可以使學習到的政策更接近最佳政策。從理論角度來看,BC 損失可以近似為最小化前向 KL 散度,而擴散模型損失可以近似為最小化反向 KL 散度。這兩種散度的結合可以在模式覆蓋和樣本質量之間取得平衡。總的來說,DBC 框架通過結合條件機率和聯合機率建模的優勢,在各種連續控制任務中展現出優秀的效能。它不僅能夠有效地預測給定狀態下的動作,還能更好地泛化到未見過的狀態,同時減輕了流形過擬合問題。然而,DBC 也存在一些限制。首先,它是為了從專家軌跡中學習而設計的,無法從代理軌跡中學習。其次,DBC 的效能可能受到專家示範質量的影響。最後,雖然 DBC 在連續控制任務中表現出色,但在離散動作空間或更複雜的任務中的效能還有待進一步研究。未來的研究方向可能包括:

  1. 擴展 DBC 以納入代理數據,這可能允許在可以與環境互動時進行改進。
  2. 探索 DBC 在更複雜任務和不同類型動作空間中的應用。
  3. 研究如何進一步提高 DBC 的樣本效率和計算效率。
  4. 調查 DBC 在處理具有多模態行為的任務時的效能。
  5. 探索將 DBC 與其他模仿學習和強化學習方法結合的可能性。


總結來說,DBC為模仿學習領域提供了一個新的研究方向,通過結合條件機率和聯合機率建模的優勢,在多個具有挑戰性的連續控制任務中取得了優秀的效能。這種方法不僅提高了模型的一般化能力,還緩解了流形過擬合問題,為未來的研究和應用開闢了新的可能性。


Reference

  1. https://arxiv.org/abs/2302.13335
avatar-img
Kiki的沙龍
1會員
39內容數
心繫正體中文的科學家,立志使用正體中文撰寫文章。 此沙龍預計涵蓋各項資訊科技知識分享與學習心得
留言
avatar-img
留言分享你的想法!
Kiki的沙龍 的其他內容
本文簡介 3GPP 在 Release 18 與 Release 19中引入人工智慧/機器學習(AI/ML)功能到無線電介面、無線電接取網路和核心網路的標準化工作。
MLIR是什麼以及使用MLIR的優點
tcpdump -i <網路介面> 捕捉流經網路介面的通訊。
Raspberry Pi 5 不再支援 raspi-gpio 指令,因此在Raspberry Pi 5 上執行 GPIO 操作指令 raspi-gpio 時,會顯示以下訊息指示「使用 pinctrl」
最近各組織正急於整合大型語言模型(LLMs)以改善其線上用戶體驗。這使它們面臨網路LLM攻擊的風險,這些攻擊嘗試取得不允許存取的資料、API或阻擋使用者。
協調型同時定位與建構地圖(C-SLAM)是在室內、地下、水中等無外部定位系統的環境中,多機器人協同運作的必須要素。傳統的C-SLAM系統可分為集中型和分散型兩類。集中型系統將所有機器人的地圖資料集中到遠端基地站,計算全域SLAM估計。
本文簡介 3GPP 在 Release 18 與 Release 19中引入人工智慧/機器學習(AI/ML)功能到無線電介面、無線電接取網路和核心網路的標準化工作。
MLIR是什麼以及使用MLIR的優點
tcpdump -i <網路介面> 捕捉流經網路介面的通訊。
Raspberry Pi 5 不再支援 raspi-gpio 指令,因此在Raspberry Pi 5 上執行 GPIO 操作指令 raspi-gpio 時,會顯示以下訊息指示「使用 pinctrl」
最近各組織正急於整合大型語言模型(LLMs)以改善其線上用戶體驗。這使它們面臨網路LLM攻擊的風險,這些攻擊嘗試取得不允許存取的資料、API或阻擋使用者。
協調型同時定位與建構地圖(C-SLAM)是在室內、地下、水中等無外部定位系統的環境中,多機器人協同運作的必須要素。傳統的C-SLAM系統可分為集中型和分散型兩類。集中型系統將所有機器人的地圖資料集中到遠端基地站,計算全域SLAM估計。