我想要一天分享一點「LLM從底層堆疊的技術」,並且每篇文章長度控制在三分鐘以內,讓大家不會壓力太大,但是又能夠每天成長一點。
# Upsampling Flow
for _ in range(3):
x = keras.layers.Concatenate()([x, outputs.pop()])
x = ResBlock(1280)([x, t_emb])
x = Upsample(1280)(x)
for _ in range(3):
x = keras.layers.Concatenate()([x, outputs.pop()])
x = ResBlock(1280)([x, t_emb])
x = SpatialTransformer(20, 64, fully_connected = True)([x, context])
x = Upsample(1280)(x)
for _ in range(3):
x = keras.layers.Concatenate()([x, outputs.pop()])
x = ResBlock(640)([x, t_emb])
x = SpatialTransformer(10, 64, fully_connected = True)([x, context])
x = Upsample(640)(x)
for _ in range(3):
x = keras.layers.Concatenate()([x, outputs.pop()])
x = ResBlock(320)([x, t_emb])
x = SpatialTransformer(5, 64, fully_connected = True)([x, context])
當中關鍵區塊程式的源碼為:
class Upsample(keras.layers.Layer):
def __init__(self, channels, **kwargs):
super().__init__(**kwargs)
self.ups = keras.layers.UpSampling2D(2)
self.conv = PaddedConv2D(channels, 3, padding = 1)
def call(self, inputs):
return self.conv(self.ups(inputs))
解析如下:
- keras.layers.Concatenate()([x, outputs.pop()])
- Keras 中的 Concatenate 是一個層,用來將多個張量沿指定的軸(默認為最後一個軸)進行拼接
- pop() 是 Python 列表的內建方法,用於移除列表中最後一個元素並返回該元素
- outputs 這個 List 的 Channel 維度依序為:320、320、320、320、640、640、640、1280、1280、1280、1280、1280
- 假設 x 和 outputs.pop() 的形狀分別為 (batch_size, height, width, channel_x) 和 (batch_size, height, width, channel_y),拼接後的形狀將為 (batch_size, height, width, channel_x + channel_y)
- 注意 ResBlock 和 SpatialTransformer 兩函數,一進入均會執行 Channel 維度對準,所以 ResBlock 和 SpatialTransformer 兩函數,各自兩個 Argument 的 Channel 維度不一樣,並不會影響程式運作
- keras.layers.UpSampling2D(2)
- UpSampling2D 是一個上採樣層,透過增加空間維度(高度和寬度)來放大圖像,具體放大倍率由參數 size 控制
- interpolation 指定插值方式,默認為 'nearest',使用最近鄰插值方法,其他選擇包括 'bilinear',提供更平滑的結果