用JAX訓練神經網絡

閱讀時間約 16 分鐘
Google JAX是一種用於轉換數值函數的機器學習框架。
它被描述為匯集了autograd(通過函數微分自動獲得梯度函數)和TensorFlowXLA(加速線性代數)的修改版本。
它旨在盡可能地遵循NumPy的結構和工作流程,並與各種現有框架(如TensorFlowPyTorch)一起工作
JAX 的主要功能是:
  1. grad: 自動微分/求導數
  2. jit:編譯/加速
  3. vmap:自動矢量化/批次處理(batch)
  4. pmap:SPMD編程
首先導入必要的庫
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
再來定義一個函數random_layer_params
輸入為(m,n,key,scale)分別對應輸入神經元數量,輸出神經元數量
隨機key,和一個scale控制數值大小,主要功能是返回一個隨機初始化的層
下面那個函數init_network_params則是給定layer_sizes和隨機key
返回整個神經網路架構,這裡要提的key有點像其他框架的random_seed
目的是讓程式有可再現性.
# A helper function to randomly initialize weights and biases
# for a dense neural network layer
def random_layer_params(m, n, key, scale=1e-2):
w_key, b_key = random.split(key)
return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(sizes, key):
keys = random.split(key, len(sizes))
return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

layer_sizes = [784, 512, 512, 10]
step_size = 0.01
num_epochs = 10
batch_size = 128
n_targets = 10
params = init_network_params(layer_sizes, random.PRNGKey(0))
自動批處理預測
讓我們首先定義我們的預測函數。
請注意,我們正在為單個輸入範例定義這函數。
我們將使用 JAX 的 vmap 函數來自動處理batch(批量),而不會降低性能。
from jax.scipy.special import logsumexp

def relu(x):
return jnp.maximum(0, x)

def predict(params, image):
# per-example predictions
activations = image
for w, b in params[:-1]:
outputs = jnp.dot(w, activations) + b
activations = relu(outputs)

final_w, final_b = params[-1]
logits = jnp.dot(final_w, activations) + final_b
return logits - logsumexp(logits)
讓我們檢查一下我們的預測函數是否僅適用於單個輸入。
# This works on single examples
random_flattened_image = random.normal(random.PRNGKey(1), (28 * 28,))
preds = predict(params, random_flattened_image)
print(preds.shape)
(10,)
# Doesn't work with a batch
random_flattened_images = random.normal(random.PRNGKey(1), (10, 28 * 28))
try:
preds = predict(params, random_flattened_images)
except TypeError:
print('Invalid shapes!')
Invalid shapes!
# Let's upgrade it to handle batches using `vmap`

# Make a batched version of the `predict` function
batched_predict = vmap(predict, in_axes=(None, 0))

# `batched_predict` has the same call signature as `predict`
batched_preds = batched_predict(params, random_flattened_images)
print(batched_preds.shape)
(10, 10)
至此,我們擁有了定義神經網絡並對其進行訓練所需的所有要素。我們已經構建了一個自動批處理版本的預測,我們應該能夠在損失函數中使用它
我們應該能夠使用 grad 對神經網絡參數求損失的導數。最後,我們應該能夠使用 jit 來加速一切。
def one_hot(x, k, dtype=jnp.float32):
"""Create a one-hot encoding of x of size k."""
return jnp.array(x[:, None] == jnp.arange(k), dtype)

def accuracy(params, images, targets):
target_class = jnp.argmax(targets, axis=1)
predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
return jnp.mean(predicted_class == target_class)

def loss(params, images, targets):
preds = batched_predict(params, images)
return -jnp.mean(preds * targets)

@jit
def update(params, x, y):
grads = grad(loss)(params, x, y)
return [(w - step_size * dw, b - step_size * db)
for (w, b), (dw, db) in zip(params, grads)]
使用tensorflow/datasets讀取訓練資料
讓我們使用看看 tensorflow/datasets的dataloader
import tensorflow as tf
# Ensure TF does not see GPU and grab all GPU memory.
tf.config.set_visible_devices([], device_type='GPU')

import tensorflow_datasets as tfds

data_dir = '/tmp/tfds'

# Fetch full datasets for evaluation
# tfds.load returns tf.Tensors (or tf.data.Datasets if batch_size != -1)
# You can convert them to NumPy arrays (or iterables of NumPy arrays) with tfds.dataset_as_numpy
mnist_data, info = tfds.load(name="mnist", batch_size=-1, data_dir=data_dir, with_info=True)
mnist_data = tfds.as_numpy(mnist_data)
train_data, test_data = mnist_data['train'], mnist_data['test']
num_labels = info.features['label'].num_classes
h, w, c = info.features['image'].shape
num_pixels = h * w * c

# Full train set
train_images, train_labels = train_data['image'], train_data['label']
train_images = jnp.reshape(train_images, (len(train_images), num_pixels))
train_labels = one_hot(train_labels, num_labels)

# Full test set
test_images, test_labels = test_data['image'], test_data['label']
test_images = jnp.reshape(test_images, (len(test_images), num_pixels))
test_labels = one_hot(test_labels, num_labels)
print('Train:', train_images.shape, train_labels.shape)
print('Test:', test_images.shape, test_labels.shape)
Train: (60000, 784) (60000, 10)
Test: (10000, 784) (10000, 10)
訓練迴圈
import time

def get_train_batches():
# as_supervised=True gives us the (image, label) as a tuple instead of a dict
ds = tfds.load(name='mnist', split='train', as_supervised=True, data_dir=data_dir)
# You can build up an arbitrary tf.data input pipeline
ds = ds.batch(batch_size).prefetch(1)
# tfds.dataset_as_numpy converts the tf.data.Dataset into an iterable of NumPy arrays
return tfds.as_numpy(ds)

for epoch in range(num_epochs):
start_time = time.time()
for x, y in get_train_batches():
x = jnp.reshape(x, (len(x), num_pixels))
y = one_hot(y, num_labels)
params = update(params, x, y)
epoch_time = time.time() - start_time

train_acc = accuracy(params, train_images, train_labels)
test_acc = accuracy(params, test_images, test_labels)
print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
print("Training set accuracy {}".format(train_acc))
print("Test set accuracy {}".format(test_acc))
Epoch 0 in 28.30 sec
Training set accuracy 0.8400499820709229
Test set accuracy 0.8469000458717346
Epoch 1 in 14.74 sec
Training set accuracy 0.8743667006492615
Test set accuracy 0.8803000450134277
Epoch 2 in 14.57 sec
Training set accuracy 0.8901500105857849
Test set accuracy 0.8957000374794006
Epoch 3 in 14.36 sec
Training set accuracy 0.8991333246231079
Test set accuracy 0.903700053691864
Epoch 4 in 14.20 sec
Training set accuracy 0.9061833620071411
Test set accuracy 0.9087000489234924
Epoch 5 in 14.89 sec
Training set accuracy 0.9113333225250244
Test set accuracy 0.912600040435791
Epoch 6 in 13.95 sec
Training set accuracy 0.9156833291053772
Test set accuracy 0.9176000356674194
Epoch 7 in 13.32 sec
Training set accuracy 0.9192000031471252
Test set accuracy 0.9214000701904297
Epoch 8 in 13.55 sec
Training set accuracy 0.9222500324249268
Test set accuracy 0.9241000413894653
Epoch 9 in 13.40 sec
Training set accuracy 0.9253666996955872
Test set accuracy 0.9269000291824341
我們現在已經使用了三個 JAX API:
  1. grad 用於求導數(gradient)
  2. jit 用於加速
  3. vmap 用於自動批量化(batch)
我們使用 NumPy 來指定我們所有的計算,並從 tensorflow/datasets 借用了強大的數據加載器,並在 GPU 上運行了整個過程。
為什麼會看到廣告
人工智能工作經驗跟研究
留言0
查看全部
avatar-img
發表第一個留言支持創作者!
世界目前正處於人工智能 (AI) 革命之中。 人工智能有可能改變和徹底改變許多行業和我們生活的方方面面, 而且越來越明顯的是,未來世界將嚴重依賴人工智能。 人工智能將產生重大影響的關鍵領域之一是自動化領域。 自動化是指在沒有人工干預的情況下使用技術來執行任務,已經存在了幾十年。 然而,人工智能的
在您的數據上免費使用 GPT3 這是GPT3根據Reddit的一些笑話微調後生成的笑話之一。如需更多 AI 生成的笑話,請滾動至文章末尾,我會在其中寫一些我最喜歡的由 GPT3 生成的笑話。
故事開始於2010年7月28日,「未來道具研究所」社團的兩人,岡部倫太郎和椎名真由理去秋葉原廣播會館參加中鉢博士的時間旅行理論發表會,見到了年僅18歲就在《科學》雜誌上發表學術論文的天才少女牧瀨紅莉栖。發表會結束不久後,在會館8樓深處,岡部發現了身上滿是鮮血的紅莉栖。驚慌失措的他帶著真由理立刻離開會
世界目前正處於人工智能 (AI) 革命之中。 人工智能有可能改變和徹底改變許多行業和我們生活的方方面面, 而且越來越明顯的是,未來世界將嚴重依賴人工智能。 人工智能將產生重大影響的關鍵領域之一是自動化領域。 自動化是指在沒有人工干預的情況下使用技術來執行任務,已經存在了幾十年。 然而,人工智能的
在您的數據上免費使用 GPT3 這是GPT3根據Reddit的一些笑話微調後生成的笑話之一。如需更多 AI 生成的笑話,請滾動至文章末尾,我會在其中寫一些我最喜歡的由 GPT3 生成的笑話。
故事開始於2010年7月28日,「未來道具研究所」社團的兩人,岡部倫太郎和椎名真由理去秋葉原廣播會館參加中鉢博士的時間旅行理論發表會,見到了年僅18歲就在《科學》雜誌上發表學術論文的天才少女牧瀨紅莉栖。發表會結束不久後,在會館8樓深處,岡部發現了身上滿是鮮血的紅莉栖。驚慌失措的他帶著真由理立刻離開會
你可能也想看
Google News 追蹤
Thumbnail
這個秋,Chill 嗨嗨!穿搭美美去賞楓,裝備款款去露營⋯⋯你的秋天怎麼過?秋日 To Do List 等你分享! 秋季全站徵文,我們準備了五個創作主題,參賽還有機會獲得「火烤兩用鍋」,一起來看看如何參加吧~
Thumbnail
11/20日NVDA即將公布最新一期的財報, 今天Sell Side的分析師, 開始調高目標價, 市場的股價也開始反應, 未來一週NVDA將重新回到美股市場的焦點, 今天我們要分析NVDA Sell Side怎麼看待這次NVDA的財報預測, 以及實際上Buy Side的倉位及操作, 從
Thumbnail
Hi 大家好,我是Ethan😊 相近大家都知道保濕是皮膚保養中最基本,也是最重要的一步。無論是在畫室裡長時間對著畫布,還是在旅途中面對各種氣候變化,保持皮膚的水分平衡對我來說至關重要。保濕化妝水不僅能迅速為皮膚補水,還能提升後續保養品的吸收效率。 曾經,我的保養程序簡單到只包括清潔和隨意上乳液
Thumbnail
矢掛町位於岡山縣西南部,氣候溫暖​​,春天,您可以在小田川沿岸和嵐山欣賞櫻花和油菜花盛開的日本鄉村美景;夏日的夜晚,星田川沿岸、美山川沿岸的宇內螢火蟲公園裡,散發著微弱光芒的螢火蟲瘋狂地舞動;秋天,被譽為名勝的大通寺的池塘噴泉觀賞庭園,銀杏樹等紅葉點綴其間,景色十分美麗;冬天,從町內或附近的展望台可
Thumbnail
千萬不要覺得經濟獨立是一切解案,父母的無能正在藉由你的「懂事」來開脫。
Thumbnail
題目敘述 Triangle 題目會給我們一個三角形的二維陣列triangle ,每個元素分別代表每個格子的成本,請問我們從最頂端到底部的下墜路徑的最小成本總和是多少? 每次下墜到下一排的時候,可以有兩種選擇: 1.往左下方的格子點移動。 2.往右下方的格子點移動。 測試範例 Examp
Thumbnail
本來以為那樣的情緒會掃蕩一空 看了好笑的影片 特調幾杯會大醉的酒 跟最親近的人講電話 爆睡了好幾天 那樣子的情緒還是跟著我 委靡不振用在這裡最適切了 但我想去旅行 想要有個人抱著我 我知道我沒事 只是如果有就好了
Thumbnail
I need a man who loves me like my father loves my mom. 我想要有個人能像我爸愛我媽一樣愛我 最近聽到哭的一首歌😭,分享給大家 Jax 的 "Like My Father" 世界上有什麼比得上一個美滿的家?
Thumbnail
接下來要講前後端怎麼溝通,最常見應該都是用 axios, ajax 或 fetch 來 call api。先講結論,個人推薦使用 axios,那他們又有什麼優缺點呢?讓我簡單講解一下 ajax ajax 用法有點麻煩,要先引入 jQuery,用法如下 ajax 比起 axios 較為笨重也較不安全
如有需要可以不用浪費效能重新將整個頁面重新載入,可以使用非同步的JS,使用動態載入資料
Thumbnail
當然,若是執意想買個股的話,我另外的建議是: 1選擇大型權值股,比如台積電、鴻海、中華電…等,相較於其他個股更安全些。 2選擇被很多ETF重覆納入的個股去挑選,如兆豐金就同時被0050、0056、00878納入。 以三大類型的ETF舉例來說:
Thumbnail
自從2020年3月份股市從8千多點反彈至今的1萬7千多點,處在這樣地市場氛圍,確實會有股市短期一直往上的錯覺,反而忽略了下跌的風險,甚至有些人還有想辭掉工作,準備當個全職交易者的念頭。 【關於個股需要做的功課】 【投資新手還有更好的選擇】 以三大類型的ETF舉例來說:
Thumbnail
疫情打亂許多行業的生產秩序,使人們被迫轉換生活方式,這還不打緊在2022年2月底俄國突然天外飛來一樁「烏俄衝突」,油價不斷飆漲通膨也誓不兩立一同標高,對於時下的投資人來說,見招拆招、挑戰突發狀況和反應速度,是目前投資必備的一項勘家本領。 ✅貨幣 ✅黃金貴金屬 ✅債券 _________
Thumbnail
這個秋,Chill 嗨嗨!穿搭美美去賞楓,裝備款款去露營⋯⋯你的秋天怎麼過?秋日 To Do List 等你分享! 秋季全站徵文,我們準備了五個創作主題,參賽還有機會獲得「火烤兩用鍋」,一起來看看如何參加吧~
Thumbnail
11/20日NVDA即將公布最新一期的財報, 今天Sell Side的分析師, 開始調高目標價, 市場的股價也開始反應, 未來一週NVDA將重新回到美股市場的焦點, 今天我們要分析NVDA Sell Side怎麼看待這次NVDA的財報預測, 以及實際上Buy Side的倉位及操作, 從
Thumbnail
Hi 大家好,我是Ethan😊 相近大家都知道保濕是皮膚保養中最基本,也是最重要的一步。無論是在畫室裡長時間對著畫布,還是在旅途中面對各種氣候變化,保持皮膚的水分平衡對我來說至關重要。保濕化妝水不僅能迅速為皮膚補水,還能提升後續保養品的吸收效率。 曾經,我的保養程序簡單到只包括清潔和隨意上乳液
Thumbnail
矢掛町位於岡山縣西南部,氣候溫暖​​,春天,您可以在小田川沿岸和嵐山欣賞櫻花和油菜花盛開的日本鄉村美景;夏日的夜晚,星田川沿岸、美山川沿岸的宇內螢火蟲公園裡,散發著微弱光芒的螢火蟲瘋狂地舞動;秋天,被譽為名勝的大通寺的池塘噴泉觀賞庭園,銀杏樹等紅葉點綴其間,景色十分美麗;冬天,從町內或附近的展望台可
Thumbnail
千萬不要覺得經濟獨立是一切解案,父母的無能正在藉由你的「懂事」來開脫。
Thumbnail
題目敘述 Triangle 題目會給我們一個三角形的二維陣列triangle ,每個元素分別代表每個格子的成本,請問我們從最頂端到底部的下墜路徑的最小成本總和是多少? 每次下墜到下一排的時候,可以有兩種選擇: 1.往左下方的格子點移動。 2.往右下方的格子點移動。 測試範例 Examp
Thumbnail
本來以為那樣的情緒會掃蕩一空 看了好笑的影片 特調幾杯會大醉的酒 跟最親近的人講電話 爆睡了好幾天 那樣子的情緒還是跟著我 委靡不振用在這裡最適切了 但我想去旅行 想要有個人抱著我 我知道我沒事 只是如果有就好了
Thumbnail
I need a man who loves me like my father loves my mom. 我想要有個人能像我爸愛我媽一樣愛我 最近聽到哭的一首歌😭,分享給大家 Jax 的 "Like My Father" 世界上有什麼比得上一個美滿的家?
Thumbnail
接下來要講前後端怎麼溝通,最常見應該都是用 axios, ajax 或 fetch 來 call api。先講結論,個人推薦使用 axios,那他們又有什麼優缺點呢?讓我簡單講解一下 ajax ajax 用法有點麻煩,要先引入 jQuery,用法如下 ajax 比起 axios 較為笨重也較不安全
如有需要可以不用浪費效能重新將整個頁面重新載入,可以使用非同步的JS,使用動態載入資料
Thumbnail
當然,若是執意想買個股的話,我另外的建議是: 1選擇大型權值股,比如台積電、鴻海、中華電…等,相較於其他個股更安全些。 2選擇被很多ETF重覆納入的個股去挑選,如兆豐金就同時被0050、0056、00878納入。 以三大類型的ETF舉例來說:
Thumbnail
自從2020年3月份股市從8千多點反彈至今的1萬7千多點,處在這樣地市場氛圍,確實會有股市短期一直往上的錯覺,反而忽略了下跌的風險,甚至有些人還有想辭掉工作,準備當個全職交易者的念頭。 【關於個股需要做的功課】 【投資新手還有更好的選擇】 以三大類型的ETF舉例來說:
Thumbnail
疫情打亂許多行業的生產秩序,使人們被迫轉換生活方式,這還不打緊在2022年2月底俄國突然天外飛來一樁「烏俄衝突」,油價不斷飆漲通膨也誓不兩立一同標高,對於時下的投資人來說,見招拆招、挑戰突發狀況和反應速度,是目前投資必備的一項勘家本領。 ✅貨幣 ✅黃金貴金屬 ✅債券 _________