圖片來源
我們已經介紹過關於Transformer模型的平台「【Hugging Face】Ep.1 平凡人也能玩的起的AI平台」,而操作的過程中相信也會有不少玩家會遇到這樣的狀況,因此將遇到的問題整理並分享解決方法,讓需要的朋友可以參考一下。
Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor
故事是這樣的, 小明是個軟體工程師, 專門研究語音辨識領域的技術, 有一天進行wav2vec2的語音辨識模型進行辨識時, 竟然在關鍵時刻發生了錯誤, 而這個錯誤可能也會是其他人遇到的狀況, 因此決定將過程好好整理一番, 以幫助在相同技術道路上的夥伴一起突破難關。
首先小明使用了wav2vec2的語辨模型, 並載入中文模型「wav2vec2-large-xlsr-53-chinese-zh-cn-gpt」, 並期望使用GPU加速辨識速度, 因此將DEVICE設定為cuda。
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
SRC_MODEL = 'ydshieh/wav2vec2-large-xlsr-53-chinese-zh-cn-gpt'
DEVICE = 'cuda'
processor = Wav2Vec2Processor.from_pretrained(SRC_MODEL)
model = Wav2Vec2ForCTC.from_pretrained(SRC_MODEL).to(DEVICE)
接著就直接對音檔進行辨識。
audio_buffer, _ = sf.read('test.wav')
input_values = processor(audio_buffer, sampling_rate=16000, return_tensors="pt").input_values
logits = model(input_values).logits
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.decode(predicted_ids[0])
transcription
結果竟然出錯了…, 這應該怎麼辦呢?
RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) ...
根據錯誤訊息推測大致是輸入類型是CPU(torch.FloatTensor), 但模型類型是GPU(torch.cuda.FloatTensor), 因此需要將數據來源轉成GPU的類型才能相符。
我們試著將音訊數據轉成「torch.cuda.FloatTensor」類型的資料。
input_values = input_values.to(DEVICE)
如此一來模型與數據的資料型態就會都一致了, 畢竟GPU與CPU並不相容,因此進行運算時也要非常的小心…。