在AI已經全民運動的年代,Google還是希望有一個更小巧精美的深度學習套件讓大家都能快速上手──JAX就這麼誕生了。
在台灣,本書可說是領先群雄的第一本JAX手冊。不管你是Tensorflow或PyTorch的使用者,都可以試著從MNIST開始。當你發現JAX的程式碼行數是Tensorflow的1/10、PyTorch的1/3,不僅速度更快,程式還更容易理解、更加「Pythnoic」。現在,你真的可以放心的進入JAX的世界,當你上手之後,不論是CNN、RNN、NLP或是GAN,全部可以又快又好又清楚地做出來。
【JAX是什麼?】
「工欲善其事,必先利其器」。人工智慧或其核心理論深度學習也一樣。任何一個好的成果實作並在將來發揮其巨大作用,都需要一個能夠將其實作並應用的基本框架工具。JAX 是機器學習框架領域的新生力量,它具有更快的高階微分計算方法,可以採用先編譯後執行的模式,突破了已有深度學習框架的局限性,同時具有更好的硬體支援,甚至將來可能會成為Google 的主要科學計算深度學習函數庫。
JAX 官方文件的解釋是:「JAX 是CPU、GPU 和TPU 上的NumPy,具有出色的自動差異化功能,可用於高性能機器學習研究。」就像JAX 官方文件解釋的那樣,最簡單的JAX 是加速器支持的NumPy,它具有一些便利的功能,具有一定靈活性,可用於常見的機器學習操作。
【JAX與XLA】
在全面講解JAX 之前先介紹一下XLA。簡單來說,XLA 是將JAX轉化為加速器支持操作的中堅力量。XLA 的全稱是Accelerated Linear Algebra,即加速線性代數。身為深度學習編譯器,其長期以來作為Google 在深度學習領域的重要特性被開發,歷時至今已經超過兩年,特別是作為TensorFlow 2.0 背後支持力量之一,XLA 也終於從試驗特性變成了預設打開的特性。
JAX 可以自動微分本機Python 和NumPy 程式。它可以透過Python的大部分功能(包括迴圈、if、遞迴和閉包)進行微分,甚至可以採用衍生類別的衍生類別。它支援反向模式和正向模式微分,並且兩者可以以任意順序組成。
JAX 的新功能是使用XLA 在諸如GPU 和TPU 的加速器上編譯和執行NumPy 程式。預設情況下,編譯是在後台進行的,而函數庫呼叫將得到即時的編譯和執行。但是,JAX 允許使用單功能API 將Python 函數編譯為XLA 最佳化的核心。編譯和自動微分可以任意組合,因此我們無需離開Python 即可表達複雜的演算法並獲得最佳性能。
【JAX 與NumPy】
JAX 在應用上是想取代NumPy 成為下一代標準運算函數庫。眾所皆知,NumPy 提供了一個功能強大的數字處理API。JAX 吸取NumPy 的優點並使之成為自己框架的部分,同時這也能在不改變使用者使用習慣的基礎上方便使用者快速掌握JAX。
在一定程度上,NumPy 的API 可以無縫平移到JAX 中使用,可以說JAX API 緊接NumPy 的API。然而還是有一些重要的區別的。最重要的區別就是JAX 是被設計為函數式的,就像函數式程式設計一樣(例如Scala 語言)。這背後的原因是JAX 支援的程式轉換類型在函數式程式中更可行。
【用JAX實作GAN生成對抗網路】
GAN 是一種生成式的對抗網路。具體來說,就是透過對抗的方式去學習資料分佈的生成式模型。所謂的對抗,指的是生成網路和判別網路的互相對抗。生成網路盡可能生成逼真樣本,判別網路則盡可能去判別該樣本是真實樣本還是生成的假樣本。
- 判別器:學習不同類別和標籤之間的區分界限。
- 生成器:學習標籤中某一類的機率分佈進行建模。
判別器中的判別演算法能夠判別這幅畫是不是由真正的畫家完成的。
生成器的做法恰恰相反,它不關心向量是什麼形式和內容,只關心給定標籤資訊,嘗試由給定的標籤內容去生成特徵,這也和人類思考的過程相類似。
正如其他一些具有非常大研究價值和潛力的學科一樣,GAN 的發展也越來越受到關注,對其的研究也越深入。GAN 採用簡單的生成與判別關係,在大量重複學習運算之後,可能為行業發展帶來十分巨大的想像力。從基本原理上看,GAN 可以透過不斷地自我判別來推導出更真實、更符合訓練目的的生成樣本。這就給圖片、視訊等領域帶來了極大的想像空間。
--
本書深度解說最新人工智慧套件JAX的使用。從基本概念開始談起,在Windows環境下架設WSL以方便使用GPU,而不需要全新從Linux安裝。
人工智慧時代的來臨造就了Keras的大流行,你可以開始使用JAX連貫所有技能,習得更多元的機器學習技能。
--
本文取自深智數位出版之〈Tensorflow 接班王者:Google JAX 深度學習又快又強大〉