用JAX訓練神經網絡

更新 發佈閱讀 17 分鐘
raw-image

Google JAX是一種用於轉換數值函數的機器學習框架。

它被描述為匯集了autograd(通過函數微分自動獲得梯度函數)和TensorFlowXLA(加速線性代數)的修改版本。

它旨在盡可能地遵循NumPy的結構和工作流程,並與各種現有框架(如TensorFlowPyTorch)一起工作

JAX 的主要功能是:

  1. grad: 自動微分/求導數
  2. jit:編譯/加速
  3. vmap:自動矢量化/批次處理(batch)
  4. pmap:SPMD編程

首先導入必要的庫

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

再來定義一個函數random_layer_params

輸入為(m,n,key,scale)分別對應輸入神經元數量,輸出神經元數量

隨機key,和一個scale控制數值大小,主要功能是返回一個隨機初始化的層

下面那個函數init_network_params則是給定layer_sizes和隨機key

返回整個神經網路架構,這裡要提的key有點像其他框架的random_seed

目的是讓程式有可再現性.

# A helper function to randomly initialize weights and biases
# for a dense neural network layer
def random_layer_params(m, n, key, scale=1e-2):
w_key, b_key = random.split(key)
return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(sizes, key):
keys = random.split(key, len(sizes))
return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

layer_sizes = [784, 512, 512, 10]
step_size = 0.01
num_epochs = 10
batch_size = 128
n_targets = 10
params = init_network_params(layer_sizes, random.PRNGKey(0))

自動批處理預測

讓我們首先定義我們的預測函數。

請注意,我們正在為單個輸入範例定義這函數。

我們將使用 JAX 的 vmap 函數來自動處理batch(批量),而不會降低性能。

from jax.scipy.special import logsumexp

def relu(x):
return jnp.maximum(0, x)

def predict(params, image):
# per-example predictions
activations = image
for w, b in params[:-1]:
outputs = jnp.dot(w, activations) + b
activations = relu(outputs)

final_w, final_b = params[-1]
logits = jnp.dot(final_w, activations) + final_b
return logits - logsumexp(logits)

讓我們檢查一下我們的預測函數是否僅適用於單個輸入。

# This works on single examples
random_flattened_image = random.normal(random.PRNGKey(1), (28 * 28,))
preds = predict(params, random_flattened_image)
print(preds.shape)
(10,)
# Doesn't work with a batch
random_flattened_images = random.normal(random.PRNGKey(1), (10, 28 * 28))
try:
preds = predict(params, random_flattened_images)
except TypeError:
print('Invalid shapes!')
Invalid shapes!
# Let's upgrade it to handle batches using `vmap`

# Make a batched version of the `predict` function
batched_predict = vmap(predict, in_axes=(None, 0))

# `batched_predict` has the same call signature as `predict`
batched_preds = batched_predict(params, random_flattened_images)
print(batched_preds.shape)
(10, 10)

至此,我們擁有了定義神經網絡並對其進行訓練所需的所有要素。我們已經構建了一個自動批處理版本的預測,我們應該能夠在損失函數中使用它

我們應該能夠使用 grad 對神經網絡參數求損失的導數。最後,我們應該能夠使用 jit 來加速一切。

def one_hot(x, k, dtype=jnp.float32):
"""Create a one-hot encoding of x of size k."""
return jnp.array(x[:, None] == jnp.arange(k), dtype)

def accuracy(params, images, targets):
target_class = jnp.argmax(targets, axis=1)
predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
return jnp.mean(predicted_class == target_class)

def loss(params, images, targets):
preds = batched_predict(params, images)
return -jnp.mean(preds * targets)

@jit
def update(params, x, y):
grads = grad(loss)(params, x, y)
return [(w - step_size * dw, b - step_size * db)
for (w, b), (dw, db) in zip(params, grads)]

使用tensorflow/datasets讀取訓練資料

讓我們使用看看 tensorflow/datasets的dataloader

import tensorflow as tf
# Ensure TF does not see GPU and grab all GPU memory.
tf.config.set_visible_devices([], device_type='GPU')

import tensorflow_datasets as tfds

data_dir = '/tmp/tfds'

# Fetch full datasets for evaluation
# tfds.load returns tf.Tensors (or tf.data.Datasets if batch_size != -1)
# You can convert them to NumPy arrays (or iterables of NumPy arrays) with tfds.dataset_as_numpy
mnist_data, info = tfds.load(name="mnist", batch_size=-1, data_dir=data_dir, with_info=True)
mnist_data = tfds.as_numpy(mnist_data)
train_data, test_data = mnist_data['train'], mnist_data['test']
num_labels = info.features['label'].num_classes
h, w, c = info.features['image'].shape
num_pixels = h * w * c

# Full train set
train_images, train_labels = train_data['image'], train_data['label']
train_images = jnp.reshape(train_images, (len(train_images), num_pixels))
train_labels = one_hot(train_labels, num_labels)

# Full test set
test_images, test_labels = test_data['image'], test_data['label']
test_images = jnp.reshape(test_images, (len(test_images), num_pixels))
test_labels = one_hot(test_labels, num_labels)
print('Train:', train_images.shape, train_labels.shape)
print('Test:', test_images.shape, test_labels.shape)
Train: (60000, 784) (60000, 10)
Test: (10000, 784) (10000, 10)

訓練迴圈

import time

def get_train_batches():
# as_supervised=True gives us the (image, label) as a tuple instead of a dict
ds = tfds.load(name='mnist', split='train', as_supervised=True, data_dir=data_dir)
# You can build up an arbitrary tf.data input pipeline
ds = ds.batch(batch_size).prefetch(1)
# tfds.dataset_as_numpy converts the tf.data.Dataset into an iterable of NumPy arrays
return tfds.as_numpy(ds)

for epoch in range(num_epochs):
start_time = time.time()
for x, y in get_train_batches():
x = jnp.reshape(x, (len(x), num_pixels))
y = one_hot(y, num_labels)
params = update(params, x, y)
epoch_time = time.time() - start_time

train_acc = accuracy(params, train_images, train_labels)
test_acc = accuracy(params, test_images, test_labels)
print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
print("Training set accuracy {}".format(train_acc))
print("Test set accuracy {}".format(test_acc))
Epoch 0 in 28.30 sec
Training set accuracy 0.8400499820709229
Test set accuracy 0.8469000458717346
Epoch 1 in 14.74 sec
Training set accuracy 0.8743667006492615
Test set accuracy 0.8803000450134277
Epoch 2 in 14.57 sec
Training set accuracy 0.8901500105857849
Test set accuracy 0.8957000374794006
Epoch 3 in 14.36 sec
Training set accuracy 0.8991333246231079
Test set accuracy 0.903700053691864
Epoch 4 in 14.20 sec
Training set accuracy 0.9061833620071411
Test set accuracy 0.9087000489234924
Epoch 5 in 14.89 sec
Training set accuracy 0.9113333225250244
Test set accuracy 0.912600040435791
Epoch 6 in 13.95 sec
Training set accuracy 0.9156833291053772
Test set accuracy 0.9176000356674194
Epoch 7 in 13.32 sec
Training set accuracy 0.9192000031471252
Test set accuracy 0.9214000701904297
Epoch 8 in 13.55 sec
Training set accuracy 0.9222500324249268
Test set accuracy 0.9241000413894653
Epoch 9 in 13.40 sec
Training set accuracy 0.9253666996955872
Test set accuracy 0.9269000291824341

我們現在已經使用了三個 JAX API:

  1. grad 用於求導數(gradient)
  2. jit 用於加速
  3. vmap 用於自動批量化(batch)

我們使用 NumPy 來指定我們所有的計算,並從 tensorflow/datasets 借用了強大的數據加載器,並在 GPU 上運行了整個過程。

留言
avatar-img
留言分享你的想法!
avatar-img
于正龍(Ricky)的沙龍
49會員
77內容數
人工智能工作經驗跟研究
2025/03/05
你做錯了。你剛剛發給 ChatGPT 的「寫一個函式來……」的提示?刪掉它吧。這些通用提示就是為什麼你的編碼速度還跟其他人一樣的原因。 在與 AI 進行超過 3,000 小時的結對編程後,我發現了真正有效的方法——而這並不是你想的那樣。 真相是:85% 的開發者陷入了 AI 驅動的複製粘貼循環。
2025/03/05
你做錯了。你剛剛發給 ChatGPT 的「寫一個函式來……」的提示?刪掉它吧。這些通用提示就是為什麼你的編碼速度還跟其他人一樣的原因。 在與 AI 進行超過 3,000 小時的結對編程後,我發現了真正有效的方法——而這並不是你想的那樣。 真相是:85% 的開發者陷入了 AI 驅動的複製粘貼循環。
2025/03/05
簡介 — 我如何停止浪費時間的故事 幾年前,我意識到我花在“做事”上的時間比實際在專案上取得進展的時間要多。我醒來時有無休止的待辦事項清單、回復電子郵件、參加會議、審查檔,但到一天結束時,我覺得我實際上沒有在任何重要的事情上取得進展。 有一天,一個朋友告訴我: 忙碌並不等同於有效。 這讓
2025/03/05
簡介 — 我如何停止浪費時間的故事 幾年前,我意識到我花在“做事”上的時間比實際在專案上取得進展的時間要多。我醒來時有無休止的待辦事項清單、回復電子郵件、參加會議、審查檔,但到一天結束時,我覺得我實際上沒有在任何重要的事情上取得進展。 有一天,一個朋友告訴我: 忙碌並不等同於有效。 這讓
2023/09/30
看到滿多年輕工程師提問:工作時經常查 ChatGPT,感覺不太踏實,沒關係嗎? 讓我簡單談論一下這件事 --- 首先,讓我們把時間倒回 2000 年代 google 剛出來的時候 當時一定也是這樣, 年輕工程師遇到問題狂查 google 資深工程師則覺得 google 可有可無,
2023/09/30
看到滿多年輕工程師提問:工作時經常查 ChatGPT,感覺不太踏實,沒關係嗎? 讓我簡單談論一下這件事 --- 首先,讓我們把時間倒回 2000 年代 google 剛出來的時候 當時一定也是這樣, 年輕工程師遇到問題狂查 google 資深工程師則覺得 google 可有可無,
看更多
你可能也想看
Thumbnail
還在煩惱平凡日常該如何增添一點小驚喜嗎?全家便利商店這次聯手超萌的馬來貘,推出黑白配色的馬來貘雪糕,不僅外觀吸睛,層次豐富的雙層口味更是讓人一口接一口!本文將帶你探索馬來貘雪糕的多種創意吃法,從簡單的豆漿燕麥碗、藍莓果昔,到大人系的奇亞籽布丁下午茶,讓可愛的馬來貘陪你度過每一餐,增添生活中的小確幸!
Thumbnail
還在煩惱平凡日常該如何增添一點小驚喜嗎?全家便利商店這次聯手超萌的馬來貘,推出黑白配色的馬來貘雪糕,不僅外觀吸睛,層次豐富的雙層口味更是讓人一口接一口!本文將帶你探索馬來貘雪糕的多種創意吃法,從簡單的豆漿燕麥碗、藍莓果昔,到大人系的奇亞籽布丁下午茶,讓可愛的馬來貘陪你度過每一餐,增添生活中的小確幸!
Thumbnail
本篇文章介紹如何使用PyTorch構建和訓練圖神經網絡(GNN),並使用Cora資料集進行節點分類任務。通過模型架構的逐步優化,包括引入批量標準化和獨立的消息傳遞層,調整Dropout和聚合函數,顯著提高了模型的分類準確率。實驗結果表明,經過優化的GNN模型在處理圖結構數據具有強大的性能和應用潛力。
Thumbnail
本篇文章介紹如何使用PyTorch構建和訓練圖神經網絡(GNN),並使用Cora資料集進行節點分類任務。通過模型架構的逐步優化,包括引入批量標準化和獨立的消息傳遞層,調整Dropout和聚合函數,顯著提高了模型的分類準確率。實驗結果表明,經過優化的GNN模型在處理圖結構數據具有強大的性能和應用潛力。
Thumbnail
本文參考TensorFlow官網Deep Convolutional Generative Adversarial Network的程式碼來加以實作說明。 示範如何使用深度卷積生成對抗網路(DCGAN) 生成手寫數位影像。
Thumbnail
本文參考TensorFlow官網Deep Convolutional Generative Adversarial Network的程式碼來加以實作說明。 示範如何使用深度卷積生成對抗網路(DCGAN) 生成手寫數位影像。
Thumbnail
延續上一篇訓練GAM模型,這次我們讓神經網路更多層更複雜一點,來看訓練生成的圖片是否效果會更好。 [深度學習][Python]訓練MLP的GAN模型來生成圖片_訓練篇 資料集分割處理的部分在延續上篇文章,從第五點開始後修改即可,前面都一樣 訓練過程,比較圖 是不是CNN的效果比MLP還要好,
Thumbnail
延續上一篇訓練GAM模型,這次我們讓神經網路更多層更複雜一點,來看訓練生成的圖片是否效果會更好。 [深度學習][Python]訓練MLP的GAN模型來生成圖片_訓練篇 資料集分割處理的部分在延續上篇文章,從第五點開始後修改即可,前面都一樣 訓練過程,比較圖 是不是CNN的效果比MLP還要好,
Thumbnail
本文主要介紹,如何利用GAN生成對抗網路來訓練生成圖片。 利用tensorflow,中的keras來建立生成器及鑑別器互相競爭訓練,最後利用訓練好的生成器來生成圖片。 GAN生成對抗網路的介紹 它由生成網路(Generator Network)和鑑別網路(Discriminator Netwo
Thumbnail
本文主要介紹,如何利用GAN生成對抗網路來訓練生成圖片。 利用tensorflow,中的keras來建立生成器及鑑別器互相競爭訓練,最後利用訓練好的生成器來生成圖片。 GAN生成對抗網路的介紹 它由生成網路(Generator Network)和鑑別網路(Discriminator Netwo
Thumbnail
透過這篇文章,我們將瞭解如何使用PyTorch實作圖神經網絡中的訊息傳遞機制,從定義消息傳遞的類別到實作消息傳遞過程。我們也探討了各種不同的消息傳遞機制,並通過對單次和多次傳遞過程的結果,可以看到節點特徵如何逐步傳遞與更新。
Thumbnail
透過這篇文章,我們將瞭解如何使用PyTorch實作圖神經網絡中的訊息傳遞機制,從定義消息傳遞的類別到實作消息傳遞過程。我們也探討了各種不同的消息傳遞機制,並通過對單次和多次傳遞過程的結果,可以看到節點特徵如何逐步傳遞與更新。
Thumbnail
本文主要筆記使用pytorch建立graph的幾個概念與實作。在傳統的神經網路模型中,數據點之間往往是互相連接和影響的,使用GNN,我們不僅處理單獨的數據點或Xb,而是處理一個包含多個數據點和它們之間連結的特徵。GNN的優勢在於其能夠將這些連結關係納入模型中,將關係本身作為特徵進行學習。
Thumbnail
本文主要筆記使用pytorch建立graph的幾個概念與實作。在傳統的神經網路模型中,數據點之間往往是互相連接和影響的,使用GNN,我們不僅處理單獨的數據點或Xb,而是處理一個包含多個數據點和它們之間連結的特徵。GNN的優勢在於其能夠將這些連結關係納入模型中,將關係本身作為特徵進行學習。
Thumbnail
這篇文章探討了生成式對抗網路中機率分佈的使用與相關的訓練方式,包括Generator不同的點、Distriminator的訓練過程、生成圖片的條件設定等。此外,也提到了GAN訓練的困難與解決方式以及不同的learning方式。文章內容豐富且詳細,涵蓋了GAN的各個相關面向。
Thumbnail
這篇文章探討了生成式對抗網路中機率分佈的使用與相關的訓練方式,包括Generator不同的點、Distriminator的訓練過程、生成圖片的條件設定等。此外,也提到了GAN訓練的困難與解決方式以及不同的learning方式。文章內容豐富且詳細,涵蓋了GAN的各個相關面向。
追蹤感興趣的內容從 Google News 追蹤更多 vocus 的最新精選內容追蹤 Google News