I want to share a little of "the technology stack that builds LLMs from the ground up" every day, and keep each article within a three-minute read, so nobody feels too much pressure but everyone still grows a little each day.
import keras

class TextEncoder(keras.Model):
    def __init__(self, max_length, vocab_size=49408, name=None, download_weights=True):
        tokens = keras.layers.Input(shape=(max_length,), dtype="int32", name="tokens")
        positions = keras.layers.Input(shape=(max_length,), dtype="int32", name="positions")
        x = CLIPEmbedding(vocab_size, 768, max_length)([tokens, positions])
        # 12 stacked Transformer encoder layers, each with 768-dim embeddings and 12 heads
        for _ in range(12):
            x = CLIPEncoderLayer(768, 12, activation=quick_gelu)(x)
        embedded = keras.layers.LayerNormalization(epsilon=1e-5)(x)
        super().__init__([tokens, positions], embedded, name=name)

        if download_weights:
            text_encoder_weights_fpath = keras.utils.get_file(
                origin="https://huggingface.co/fchollet/stable-diffusion/resolve/main/kcv_encoder.h5",
                file_hash="4789e63e07c0e54d6a34a29b45ce81ece27060c499a709d556c7755b42bb0dc4",
            )
            self.load_weights(text_encoder_weights_fpath)
Explanation of the code:
- keras.layers.Input(shape=(max_length,), dtype="int32", name="tokens")
    - shape=(max_length,) means the input is a one-dimensional vector of length max_length, which usually corresponds to the length of the token sequence
    - dtype="int32" sets the input's data type to integer (int32), which is typically used for token ID encodings
    - name="tokens" gives this input layer the name "tokens", making it easy to refer to it inside the model
- CLIPEmbedding(vocab_size, 768, max_length)([tokens, positions])
    - CLIPEmbedding(vocab_size, 768, max_length) is the embedding layer of the CLIP model; it takes the vocabulary size (vocab_size), the embedding dimension (768), and the maximum sequence length (max_length) as arguments, and converts the token sequence and the positions into 768-dimensional embedding representations
- CLIPEncoderLayer(768, 12, activation=quick_gelu)(x)
    - 12 is the number of heads in the multi-head self-attention, i.e. the model uses 12 different attention heads to capture the relationships between different parts of the input (a usage sketch of the full encoder follows right after this list)
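To tie the pieces together, here is a minimal usage sketch (my own illustration, not part of the original code): it assumes the TextEncoder above and the component classes listed later in this post are all defined, skips the weight download, and simply checks the output shape with dummy inputs.

import numpy as np

# Added usage sketch: assumes TextEncoder, CLIPEmbedding, CLIPEncoderLayer
# and quick_gelu from this post are already defined in the current scope.
max_length = 77  # CLIP's standard text sequence length
text_encoder = TextEncoder(max_length, download_weights=False)  # skip downloading the weight file

# Dummy batch of one prompt; real token IDs would come from the CLIP tokenizer.
tokens = np.zeros((1, max_length), dtype="int32")
positions = np.arange(max_length, dtype="int32")[None, :]  # position indices 0..76

context = text_encoder([tokens, positions])
print(context.shape)  # (1, 77, 768): one 768-dimensional embedding per token position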
Finally, here is the code for the original components:
from keras import ops

def quick_gelu(x):
    return x * ops.sigmoid(x * 1.702)
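quick_gelu is the sigmoid-based approximation of the GELU activation (x * sigmoid(1.702 * x)) used by the original CLIP implementation; it is cheaper to evaluate than exact GELU while producing nearly identical values.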
class CLIPEmbedding(keras.layers.Layer):
    def __init__(self, input_dim=49408, output_dim=768, max_length=77, **kwargs):
        super().__init__(**kwargs)
        self.token_embedding = keras.layers.Embedding(input_dim, output_dim)
        self.position_embedding = keras.layers.Embedding(max_length, output_dim)

    def call(self, inputs):
        tokens, positions = inputs
        tokens = self.token_embedding(tokens)
        positions = self.position_embedding(positions)
        return tokens + positions
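Note that CLIPEmbedding uses a second, learned Embedding table over the max_length positions rather than fixed sinusoidal encodings, and the position embedding is simply added element-wise to the token embedding, so the output keeps the shape (batch, max_length, 768).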
class CLIPEncoderLayer(keras.layers.Layer):
    def __init__(self, embed_dim, num_heads, activation=None, **kwargs):
        super().__init__(**kwargs)
        self.layer_norm1 = keras.layers.LayerNormalization(epsilon=1e-5)
        self.clip_attn = CLIPAttention(embed_dim, num_heads, causal=True)
        self.layer_norm2 = keras.layers.LayerNormalization(epsilon=1e-5)
        self.fc1 = keras.layers.Dense(embed_dim * 4)  # output dimension is embed_dim * 4
        self.fc2 = keras.layers.Dense(embed_dim)
        self.activation = activation

    def call(self, inputs):
        residual = inputs
        x = self.layer_norm1(inputs)
        x = self.clip_attn(x)
        x = residual + x
        residual = x
        x = self.layer_norm2(x)
        x = self.fc1(x)
        x = self.activation(x)
        x = self.fc2(x)
        return x + residual
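This is a pre-LayerNorm Transformer block: each of the two sub-layers (self-attention and the feed-forward network) normalizes its input first and then adds its output back onto a residual connection, and the feed-forward network expands to 4 * embed_dim before projecting back down, matching the standard Transformer design.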
class CLIPAttention(keras.layers.Layer):
    def __init__(self, embed_dim=768, num_heads=12, causal=True, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.causal = causal
        self.head_dim = self.embed_dim // self.num_heads
        self.scale = self.head_dim**-0.5
        self.q_proj = keras.layers.Dense(self.embed_dim)
        self.k_proj = keras.layers.Dense(self.embed_dim)
        self.v_proj = keras.layers.Dense(self.embed_dim)
        self.out_proj = keras.layers.Dense(self.embed_dim)

    def reshape_states(self, x, sequence_length, batch_size):
        x = ops.reshape(x, (batch_size, sequence_length, self.num_heads, self.head_dim))
        return ops.transpose(x, (0, 2, 1, 3))  # bs, heads, sequence_length, head_dim

    def call(self, inputs, attention_mask=None):
        if attention_mask is None and self.causal:
            # Build the causal mask: -inf strictly above the main diagonal, 0 elsewhere
            length = ops.shape(inputs)[1]
            attention_mask = ops.triu(
                ops.ones((1, 1, length, length), dtype=self.compute_dtype) * -float("inf"),
                k=1,
            )
        _, tgt_len, embed_dim = inputs.shape
        query_states = self.q_proj(inputs) * self.scale
        key_states = self.reshape_states(self.k_proj(inputs), tgt_len, -1)
        value_states = self.reshape_states(self.v_proj(inputs), tgt_len, -1)
        # Fold the head dimension into the batch dimension so each head attends independently
        proj_shape = (-1, tgt_len, self.head_dim)
        query_states = self.reshape_states(query_states, tgt_len, -1)
        query_states = ops.reshape(query_states, proj_shape)
        key_states = ops.reshape(key_states, proj_shape)
        src_len = tgt_len
        value_states = ops.reshape(value_states, proj_shape)
        attn_weights = query_states @ ops.transpose(key_states, (0, 2, 1))
        attn_weights = ops.reshape(attn_weights, (-1, self.num_heads, tgt_len, src_len))
        attn_weights = attn_weights + attention_mask
        attn_weights = ops.reshape(attn_weights, (-1, tgt_len, src_len))
        attn_weights = ops.softmax(attn_weights, axis=-1)
        attn_output = attn_weights @ value_states
        attn_output = ops.reshape(attn_output, (-1, self.num_heads, tgt_len, self.head_dim))
        attn_output = ops.transpose(attn_output, (0, 2, 1, 3))
        attn_output = ops.reshape(attn_output, (-1, tgt_len, embed_dim))
        return self.out_proj(attn_output)
Explanation of the code:
- ops.triu(ops.ones((1, 1, length, length), dtype=self.compute_dtype) * -float("inf"), k=1)
    - The second dimension of inputs (i.e. the sequence length) is taken and stored in the variable length
    - A matrix of ones with shape (1, 1, length, length) is created and cast to the model's compute precision self.compute_dtype
    - Multiplying by -float("inf") and applying ops.triu turns this into an upper-triangular matrix: only the positions strictly above the main diagonal become -inf, while the main diagonal and everything below it stay 0
    - With k=0, the main diagonal and everything above it would be set to -inf; with k=1, the -inf region starts one diagonal above the main diagonal, so the main diagonal is left unchanged
    - Such a matrix makes the model ignore future positions in the sequence when computing attention scores, preventing it from peeking at future information in the autoregressive setting (a small demo follows right after this list)
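As a quick check of the bullet points above, here is a small sketch (my own addition, assuming only Keras 3's keras.ops and NumPy) that builds the same mask for a sequence length of 4 and shows how, after adding it to the attention scores and applying softmax, each position only attends to itself and earlier positions.

from keras import ops

length = 4
# Same construction as in CLIPAttention.call: -inf strictly above the main diagonal, 0 elsewhere.
mask = ops.triu(ops.ones((1, 1, length, length), dtype="float32") * -float("inf"), k=1)
print(ops.convert_to_numpy(mask)[0, 0])
# [[  0. -inf -inf -inf]
#  [  0.   0. -inf -inf]
#  [  0.   0.   0. -inf]
#  [  0.   0.   0.   0.]]

# Dummy (all-equal) attention scores: after the mask and softmax, each row
# spreads its probability only over the current and past positions.
scores = ops.zeros((1, 1, length, length), dtype="float32") + mask
print(ops.convert_to_numpy(ops.softmax(scores, axis=-1))[0, 0])
# [[1.    0.    0.    0.   ]
#  [0.5   0.5   0.    0.   ]
#  [0.333 0.333 0.333 0.   ]
#  [0.25  0.25  0.25  0.25 ]]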