torch.nn.Embedding到底在幹嘛

2023/04/28閱讀時間約 6 分鐘

其實跟word2vec, skipgram都沒什麼關係

如果你跟我一樣是先看了transformers或者是至少word embeddings相關的papers才回去設法用pytorch來實作 一開始一定會非常非常困惑
會不知道這個 torch.nn.Embedding在搞什麼鬼
查了官方文件你可能會跟我一樣更困惑...
A simple lookup table that stores embeddings of a fixed dictionary and size.
This module is often used to store word embeddings and retrieve them using indices. The input to the module is a list of indices, and the output is the corresponding word embeddings.
Parameters:
  • num_embeddings (int) – size of the dictionary of embeddings
  • embedding_dim (int) – the size of each embedding vector
... 以下略
什麼 simple lookup table?沒有訓練過這個東西 到底哪來的word embedding? 底下到底是word2vec, gloVe, 還是什麼pretrained的東西?

答案

其實答案很簡單: 都不是
其實就是"隨機"

我們再看一次這個document

其實
num_embeddings, 第一個參數的意思就是, 隨便給定一個vocabulary size, 比方說 3, 那麼
nn.Embedding 就會幫你準備 3個空位
第二個參數embedding_dim 會直接幫決定他幫你準備的隨機的representation要有幾個dimensions, 你幫說5

背後在做什麼

當你設了 nn.Embedding(3, 5)
那麼你可以想成其實就是這樣
{
0: [.123123, .123123, .123123, .12312, .123123], # 五個隨機的floats來代表0 這個token
1: [.456456,.456456,.456456,.456546,.456456,.42342],# 五個隨機的floats來代表1 這個token
2: [.789789, .987987, .98798, .5789, .7896, .794] #五個隨機的floats來代表2 這個token
}
為什麼是5個數字呢? 因為你embedding_dim設成5, 如果你設成384就會有384個隨機數字對應到每一個id

可是我想處理文字 又不是數字 - Tokenizer在幹嘛

你可能接下來會感到困惑的點是... 可是我想處理文字 又不是數字...
所以... 其實tokenizer就是在做這件事
假設你想要把 "你好嗎" 這句話拿去配合什麼東西訓練
那麼你就可能會有個tokenizer做這件事:
{你: 0, 好:1, 嗎:2}
你的文字input經過tokenizer之後就會變成一串數字
比方說"你好好嗎嗎"就會變成[0, 1, 1, 2, 2]
"你嗎嗎好"就會變成[0,2,2,1]

所以

然後經過nn.Embedding的時候他就把剛剛的隨機數字塞進去
所以"你好好嗎嗎" 會被轉成這樣 (就只是去查[0, 1, 1, 2, 2])
[[.123123, .123123, .123123, .12312, .123123],
[.456456,.456456,.456456,.456546,.456456,.42342],
[.456456,.456456,.456456,.456546,.456456,.42342],
[.789789, .987987, .98798, .5789, .7896, .794],
[.789789, .987987, .98798, .5789, .7896, .794]]
"你嗎嗎好"就會變成(就只是去查[0,2,2,1]) (我們這邊先不管padding)
[[.123123, .123123, .123123, .12312, .123123],
[.789789, .987987, .98798, .5789, .7896, .794],
[.789789, .987987, .98798, .5789, .7896, .794],
[.456456,.456456,.456456,.456546,.456456,.42342]
]

更新參數

接下來你會有一個task可能是要訓練model來分類什麼東西
比方說聽起來像不像髒話
那麼
[[.123123, .123123, .123123, .12312, .123123],
[.456456,.456456,.456456,.456546,.456456,.42342],
[.456456,.456456,.456456,.456546,.456456,.42342],
[.789789, .987987, .98798, .5789, .7896, .794],
[.789789, .987987, .98798, .5789, .7896, .794]]
可能會對應到 0 (不像)
[[.123123, .123123, .123123, .12312, .123123],
[.789789, .987987, .98798, .5789, .7896, .794],
[.789789, .987987, .98798, .5789, .7896, .794],
[.456456,.456456,.456456,.456546,.456456,.42342]
]
可能會對應到 1 (有點像)等等
然後只要你不鎖住nn.embedding的參數
那麼這些隨機的數字就會被更新, 已讓你的classification更準

Vocabulary Size的影響

你的第一個參數會影響到你有幾個place holders可以用
剛剛我們設3
所以只有三個不同的tokens可以用
所以一但傳進去的index超過2, 就會出錯(list out of range)
所以大部分的語言模型都會設一個很大的數字像是80000
再搭配tokenizer
為什麼會看到廣告
8會員
15內容數
對工程師友善的(目前免費)英文教材 #工程師 #Coding #Python #Django #English #英文 #文法 #語言學習 #程式
留言0
查看全部
發表第一個留言支持創作者!