2023-07-20|閱讀時間 ‧ 約 4 分鐘

【AI訓練故障篇】FloatTensor and cuda.FloatTensor should be the same

raw-image
圖片來源

我們已經介紹過關於Transformer模型的平台「【Hugging Face】Ep.1 平凡人也能玩的起的AI平台」,而操作的過程中相信也會有不少玩家會遇到這樣的狀況,因此將遇到的問題整理並分享解決方法,讓需要的朋友可以參考一下。

問題

Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor

前景提要

故事是這樣的, 小明是個軟體工程師, 專門研究語音辨識領域的技術, 有一天進行wav2vec2的語音辨識模型進行辨識時, 竟然在關鍵時刻發生了錯誤, 而這個錯誤可能也會是其他人遇到的狀況, 因此決定將過程好好整理一番, 以幫助在相同技術道路上的夥伴一起突破難關。

圖片來源...

首先小明使用了wav2vec2的語辨模型, 並載入中文模型「wav2vec2-large-xlsr-53-chinese-zh-cn-gpt」, 並期望使用GPU加速辨識速度, 因此將DEVICE設定為cuda。

from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
SRC_MODEL = 'ydshieh/wav2vec2-large-xlsr-53-chinese-zh-cn-gpt'
DEVICE = 'cuda'
processor = Wav2Vec2Processor.from_pretrained(SRC_MODEL)
model = Wav2Vec2ForCTC.from_pretrained(SRC_MODEL).to(DEVICE)

接著就直接對音檔進行辨識。

audio_buffer, _ = sf.read('test.wav')

input_values = processor(audio_buffer, sampling_rate=16000, return_tensors="pt").input_values

logits = model(input_values).logits

predicted_ids = torch.argmax(logits, dim=-1)

transcription = processor.decode(predicted_ids[0])

transcription

結果竟然出錯了…, 這應該怎麼辦呢?

RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) ...

原因

根據錯誤訊息推測大致是輸入類型是CPU(torch.FloatTensor), 但模型類型是GPU(torch.cuda.FloatTensor), 因此需要將數據來源轉成GPU的類型才能相符。

如何解決

我們試著將音訊數據轉成「torch.cuda.FloatTensor」類型的資料。

input_values = input_values.to(DEVICE)

如此一來模型與數據的資料型態就會都一致了, 畢竟GPU與CPU並不相容,因此進行運算時也要非常的小心…。

分享至
成為作者繼續創作的動力吧!
© 2024 vocus All rights reserved.