最高加速9倍!字節(jié)跳動開源8比特混合精度Transformer引擎(1)
SC22 接收論文:https://sc22.supercomputing.org/presentation/?id=pap211&sess=sess154
代碼地址:https://github.com/bytedance/lightseq
如何繼續(xù)提升速度?降低計(jì)算精度是比較直接的方法。2017 年以來,fp16 混合精度技術(shù) [2] 獲得了廣泛應(yīng)用。在對模型效果無損的前提下,將模型訓(xùn)練和推理的速度提升了 50% 以上。而為了維持模型效果,更低精度的方法(例如 int8)通常需要使用如下傳統(tǒng)方案:
首先使用 fp16 混合精度將模型訓(xùn)練至收斂;
然后在模型計(jì)算密集型算子的權(quán)重、輸入和輸出位置處,插入偽量化結(jié)點(diǎn),進(jìn)行量化感知訓(xùn)練;
最后將帶有偽量化結(jié)點(diǎn)的模型計(jì)算圖轉(zhuǎn)換到專用的 int8 推理引擎中,進(jìn)行服務(wù)部署和模型推理。
雖然在多數(shù)任務(wù)上,上述方案可以實(shí)現(xiàn)模型效果無損,但還是存在以下問題:
使用方法復(fù)雜。例如要多一次量化感知訓(xùn)練 [4] 的過程,并且?guī)в袀瘟炕?jié)點(diǎn)的計(jì)算圖轉(zhuǎn)換復(fù)雜。
訓(xùn)練速度慢。由于目前流行的深度學(xué)習(xí)框架不支持 int8 精度,所以量化感知訓(xùn)練需要插入 fp16 的偽量化結(jié)點(diǎn)來模擬 int8 量化,導(dǎo)致量化感知訓(xùn)練反而比 fp16 混合精度訓(xùn)練慢 2-3 倍。
推理部署難且加速比低。對比 fp32、fp16 等類型,int8 硬件和底層軟件庫優(yōu)化相對滯后。例如在 NVIDIA GPU 上,int8 矩陣乘法加速受限于硬件架構(gòu)和特定 shape,實(shí)際加速比遠(yuǎn)遠(yuǎn)低于理論值。
在下文中,如無特殊說明,量化都是指的 int8 精度的量化。
針對這些問題,字節(jié)跳動推出了全新版本的 LightSeq GPU 量化訓(xùn)練與推理引擎。支持 Transformer 系列模型的量化訓(xùn)練與推理,并做到了開箱即用,用戶友好。LightSeq 快準(zhǔn)狠地實(shí)現(xiàn)了 int8 精度的量化訓(xùn)練和推理:
快:A100 多卡訓(xùn)練最高加速 5.2 倍,T4 單卡推理最高加速 8.9 倍。
準(zhǔn):訓(xùn)練和推理效果基本無損。
狠:相同數(shù)據(jù)量下,顯存占用最高減少 68%,模型存儲空間減少 75%。
總體來說,LightSeq 新版量化訓(xùn)練與推理引擎具有如下幾個優(yōu)點(diǎn):
1. 豐富的支持
支持完整的 Transformer 模塊和多種解碼算法,支持 Transformer、BERT、GPT、BART、ViT 等多種模型結(jié)構(gòu),支持 Fairseq、Hugging Face、NeurST 等多種訓(xùn)練框架接入量化訓(xùn)練、導(dǎo)出模型以及量化推理,提供了豐富的樣例供用戶參考。
2. 卓越的性能
相比于 fp16 精度的 LightSeq 推理引擎,int8 量化還可以進(jìn)一步加速最高 70%,相比于 PyTorch 推理更是達(dá)到了最高 8.9 倍的加速比。同時顯存占用相比 fp16 推理引擎降低了 30% 左右,模型存儲空間只需要原來的四分之一。最后經(jīng)過多個任務(wù)的驗(yàn)證,推理效果幾乎無損。
3. 便捷的使用
LightSeq 已經(jīng)針對多個訓(xùn)練庫進(jìn)行了量化支持,可以一鍵開啟量化訓(xùn)練,然后輕松導(dǎo)出為 LightSeq 支持的模型格式,最后實(shí)現(xiàn)量化推理。除此之外,LightSeq 還支持訓(xùn)練后量化,無需額外訓(xùn)練即可體驗(yàn)量化推理。
使用方法
如上圖所示,為了最大程度減小量化帶來的損失,首先需要用 fp16 精度訓(xùn)練一個浮點(diǎn)數(shù)模型,將模型效果訓(xùn)到最好。然后開啟量化進(jìn)行 finetune,得到微調(diào)過的量化模型,此時模型效果已經(jīng)基本恢復(fù)到浮點(diǎn)數(shù)模型的水平。接著將量化模型轉(zhuǎn)換為 LightSeq 支持的 PB 或者 HDF5 模型格式,最后用 LightSeq 進(jìn)行量化推理。
安裝方法
LightSeq 安裝非常簡單,只需要一行命令即可:
pip install lightseq
量化訓(xùn)練
LightSeq 支持 Fairseq、Hugging Face、NeurST 等訓(xùn)練框架的量化接入,同時也可以自定義模型并開啟量化訓(xùn)練。以 encoder 層為例,只需要先定義浮點(diǎn)數(shù)模型,然后開啟量化即可:
from lightseq.training import LSTransformerEncoderLayerfrom lightseq.training.ops.pytorch.quantization import enable_quant
config = LSTransformerEncoderLayer.get_config( model="bert-base", max_batch_tokens=4096, max_seq_len=512, fp16=True, local_rank=0,)layer = LSTransformerEncoderLayer(config)# 開啟量化layer.apply(enable_quant)
量化推理
LightSeq 提供了便捷的 python 推理接口,只需要三行代碼即可實(shí)現(xiàn)快速的量化推理:
import lightseq.inference as lsi
model = lsi.QuantTransformer(pb_path, batch_size)result = model.infer(input)
此外 LightSeq 還提供了 BERT、GPT、ViT 等模型的 python 接口,分別調(diào)用 QuantBert、QuantGpt 和 QuanVit 即可體驗(yàn)。
梯度通信量化
LightSeq 支持 Transformer 模型的梯度通信量化[5],使用 Fairseq 或者 Hugging Face 即可輕松開啟分布式量化訓(xùn)練,并同時支持浮點(diǎn)數(shù)模型和量化模型。在構(gòu)建模型后,只需要為模型注冊一個 communication hook 即可開啟梯度通信量化,再開始訓(xùn)練過程。
from lightseq.training.gradient_comm_quantization import encode_and_decode, GCQStatefrom torch.nn.parallel import DistributedDataParallel
# model could be from Fairseq or Hugging Face, wrapped by DDPmodel = DistributedDataParallel(model)state = GCQState(process_group)# register hookmodel.register_comm_hook(state=state, hook=encode_and_decode)
性能測試
LightSeq 在多個任務(wù)上測試了量化訓(xùn)練、量化推理和梯度通信量化的速度,并且分析了顯存占用情況和量化模型的效果。
量化訓(xùn)練速度
LightSeq 在 8 張 A100 顯卡上進(jìn)行了訓(xùn)練實(shí)驗(yàn),主要對比對象是 Fairseq 的 Transformer、Hugging Face 的 BERT、GPT2 和 ViT。
可以看出,四種模型結(jié)構(gòu)加速趨勢都是類似的,加速比都會隨著數(shù)據(jù)量的增大而減小,原因有三點(diǎn):
隨著數(shù)據(jù)量的增大,矩陣乘法 GEMM 的占比會明顯增加,因此 PyTorch QAT 增加的額外的偽量化結(jié)點(diǎn)時間占比會逐漸減小,最后速度會和 PyTorch fp16 無限接近。
與此同時,隨著 GEMM 占比升高,LightSeq fp16 自定義算子的提速效果也逐漸減小,因此時間上也會和 PyTorch fp16 無限接近。
由于 Ampere 架構(gòu)顯卡上 int8 GEMM 在 shape 較小時甚至不如 fp16 GEMM 快,在大 shape 下才能稍快一點(diǎn),因此隨著數(shù)據(jù)量增大,LightSeq int8 也會無限接近 LightSeq fp16 的速度。
量化推理速度
LightSeq 在單張 T4 顯卡上進(jìn)行了推理實(shí)驗(yàn),主要對比對象是 Hugging Face 的 Transformer、BERT、GPT2 和 ViT。
可以看出,隨著輸入數(shù)據(jù)量的增大,LightSeq 與 PyTorch 的差距會逐漸減小,這也是 GEMM 占比升高造成的。比較 LightSeq fp16 和 LightSeq int8,可以看出隨著數(shù)據(jù)量的增大,LightSeq int8 越來越快。這是因?yàn)樵?T4 顯卡上,int8 GEMM 的加速會隨著 shape 的增大而有明顯增加。因此在 T4 顯卡上進(jìn)行量化推理時,輸入數(shù)據(jù)量越大,加速效果越好。
LightSeq 還針對機(jī)器翻譯多個語向和多個測試集,測試了不同 batch size 下,LightSeq int8 推理相對于 LightSeq fp16 推理的加速比,實(shí)驗(yàn)同樣是在單張 T4 顯卡上進(jìn)行的,采用的模型都是標(biāo)準(zhǔn)的 Transformer-Big。
可以得到和上文中相同的結(jié)論,隨著 batch size 的增大,量化推理的加速比會逐漸升高。相比于 LightSeq fp16,最高還可以再加速近 70%,這極大地縮短了線上翻譯模型的推理延時。
最后如上圖所示,為了展示自動 GEMM 調(diào)優(yōu)技術(shù)的效果,LightSeq 測試對比了 A100 顯卡上 Transformer 和 BERT 模型 fp16、int8 調(diào)優(yōu)前和 int8 調(diào)優(yōu)后的延時。可以看出調(diào)優(yōu)前某些 shape 的 int8 GEMM 速度甚至比 fp16 還要慢,而調(diào)優(yōu)后全面超越了 fp16。
顯存占用
LightSeq 分析了不同 batch size 下,量化模型相對于浮點(diǎn)數(shù)模型顯存占用的加速比??梢钥闯鲭S著 batch size 的增大,量化模型的顯存占用優(yōu)勢更明顯,最高可以減少 30% 左右。而 LightSeq fp16 引擎相對于 PyTorch 模型也極大程度減少了顯存占用,因此 LightSeq int8 引擎最終能夠減少最多 68% 左右的顯存。
量化模型效果
針對機(jī)器翻譯多個語向和多個測試集,LightSeq 測試了量化模型推理相對于浮點(diǎn)數(shù)模型 BLEU 的損失,采用的模型都是標(biāo)準(zhǔn)的 Transformer-Big。
在數(shù)據(jù)量較大的語向 en2zh 上,LightSeq int8 相對 BLEU 損失較大些,最大達(dá)到了 - 0.4。而在數(shù)據(jù)量較小的語向 en2es 上,LightSeq int8 不僅沒有任何效果損失,反而比浮點(diǎn)數(shù)模型更好。總體而言,int8 量化模型的平均 BLEU 相比浮點(diǎn)數(shù)模型基本無損。在 GLUE 和 SQuAD 等多個任務(wù)上,LightSeq 也驗(yàn)證了量化模型的效果。
梯度通信量化
由于在多機(jī)多卡場景下通信瓶頸更加明顯,所以梯度通信量化主要應(yīng)用在分布式訓(xùn)練場景。因此 LightSeq 在 2 機(jī) 8 卡的 A100 上進(jìn)行了分布式訓(xùn)練的速度測試。
可以看出,梯度通信量化的訓(xùn)練加速效果整體上隨著輸入數(shù)據(jù)的增大而減弱。這主要是因?yàn)殡S著輸入數(shù)據(jù)的增大,計(jì)算時間占比升高,梯度通信時間占比減少,梯度量化的收益也隨之減小。
LightSeq 還額外增加了不同數(shù)量網(wǎng)卡(NIC)下的訓(xùn)練速度測試??梢钥吹绞褂锰荻韧ㄐ帕炕姆植际接?xùn)練速度相比原始的 LightSeq fp16 有大幅度提升。
*博客內(nèi)容為網(wǎng)友個人發(fā)布,僅代表博主個人觀點(diǎn),如有侵權(quán)請聯(lián)系工作人員刪除。