博客專欄

EEPW首頁 > 博客 > Soft Diffusion:谷歌新框架從通用擴散過程中正確調(diào)度、學(xué)習(xí)和采樣

Soft Diffusion:谷歌新框架從通用擴散過程中正確調(diào)度、學(xué)習(xí)和采樣

發(fā)布人:機器之心 時間:2022-10-19 來源:工程師 發(fā)布文章

近來,擴散模型成為 AI 領(lǐng)域的研究熱點。谷歌研究院和 UT-Austin 的研究者在最新的一項研究中充分考慮了「損壞」過程,并提出了一個用于更通用損壞過程的擴散模型設(shè)計框架。


我們知道,基于分?jǐn)?shù)的模型和去噪擴散概率模型(DDPM)是兩類強大的生成模型,它們通過反轉(zhuǎn)擴散過程來產(chǎn)生樣本。這兩類模型已經(jīng)在 Yang Song 等研究者的論文《Score-based generative modeling through stochastic differential equations》中統(tǒng)一到了單一的框架下,并被廣泛地稱為擴散模型。


目前,擴散模型在包括圖像、音頻、視頻生成以及解決逆問題等一系列應(yīng)用中取得了巨大的成功。Tero Karras 等研究者在論文《Elucidating the design space of diffusionbased generative models》中對擴散模型的設(shè)計空間進行了分析,并確定了 3 個階段,分別為 i) 選擇噪聲水平的調(diào)度,ii) 選擇網(wǎng)絡(luò)參數(shù)化(每個參數(shù)化生成一個不同的損失函數(shù)),iii) 設(shè)計采樣算法。


近日,在谷歌研究院和 UT-Austin 合作的一篇 arXiv 論文《Soft Diffusion: Score Matching for General Corruptions》中,幾位研究者認(rèn)為擴散模型仍有一個重要的步驟:損壞(corrupt)。一般來說,損壞是一個添加不同幅度噪聲的過程,對于 DDMP 還需要重縮放。雖然有人嘗試使用不同的分布來進行擴散,但仍缺乏一個通用的框架。因此,研究者提出了一個用于更通用損壞過程的擴散模型設(shè)計框架。


具體地,他們提出了一個名為 Soft Score Matching 的新訓(xùn)練目標(biāo)和一種新穎的采樣方法 Momentum Sampler。理論結(jié)果表明,對于滿足正則條件的損壞過程,Soft Score MatchIng 能夠?qū)W習(xí)它們的分?jǐn)?shù)(即似然梯度),擴散必須將任何圖像轉(zhuǎn)換為具有非零似然的任何圖像。


在實驗部分,研究者在 CelebA 以及 CIFAR-10 上訓(xùn)練模型,其中在 CelebA 上訓(xùn)練的模型實現(xiàn)了線性擴散模型的 SOTA FID 分?jǐn)?shù)——1.85。同時與使用原版高斯去噪擴散訓(xùn)練的模型相比,研究者訓(xùn)練的模型速度顯著更快。


圖片


論文地址:https://arxiv.org/pdf/2209.05442.pdf


方法概覽


通常來說,擴散模型通過反轉(zhuǎn)逐漸增加噪聲的損壞過程來生成圖像。研究者展示了如何學(xué)習(xí)對涉及線性確定性退化和隨機加性噪聲的擴散進行反轉(zhuǎn)。


圖片


具體地,研究者展示了使用更通用損壞模型訓(xùn)練擴散模型的框架,包含有三個部分,分別為新的訓(xùn)練目標(biāo) Soft Score Matching、新穎采樣方法 Momentum Sampler 和損壞機制的調(diào)度。


首先來看訓(xùn)練目標(biāo) Soft Score Matching,這個名字的靈感來自于軟過濾,是一種攝影術(shù)語,指的是去除精細細節(jié)的過濾器。它以一種可證明的方式學(xué)習(xí)常規(guī)線性損壞過程的分?jǐn)?shù),還在網(wǎng)絡(luò)中合并入了過濾過程,并訓(xùn)練模型來預(yù)測損壞后與擴散觀察相匹配的圖像。


只要擴散將非零概率指定為任何干凈、損壞的圖像對,則該訓(xùn)練目標(biāo)可以證明學(xué)習(xí)到了分?jǐn)?shù)。另外,當(dāng)損壞中存在加性噪聲時,這一條件總是可以得到滿足。


具體地,研究者探究了如下形式的損壞過程。


圖片


在過程中,研究者發(fā)現(xiàn)噪聲在實證(即更好的結(jié)果)和理論(即為了學(xué)習(xí)分?jǐn)?shù))這兩方面都很重要。這也成為了其與反轉(zhuǎn)確定性損壞的并發(fā)工作 Cold Diffusion 的關(guān)鍵區(qū)別。


其次是采樣方法 Momentum Sampling。研究者證明,采樣器的選擇對生成樣本質(zhì)量具有顯著影響。他們提出了 Momentum Sampler,用于反轉(zhuǎn)通用線性損壞過程。該采樣器使用了不同擴散水平的損壞的凸組合,并受到了優(yōu)化中動量方法的啟發(fā)。


這一采樣方法受到了上文 Yang Song 等人論文提出的擴散模型連續(xù)公式化的啟發(fā)。Momentum Sampler 的算法如下所示。


圖片


下圖直觀展示了不同采樣方法對生成樣本質(zhì)量的影響。圖左使用 Naive Sampler 采樣的圖像似乎有重復(fù)且缺少細節(jié),而圖右 Momentum Sampler 顯著提升了采樣質(zhì)量和 FID 分?jǐn)?shù)。


圖片


最后是調(diào)度。即使退化的類型是預(yù)定義的(如模糊),決定在每個擴散步驟中損壞多少并非易事。研究者提出一個原則性工具來指導(dǎo)損壞過程的設(shè)計。為了找到調(diào)度,他們將沿路徑分布之間的 Wasserstein 距離最小化。直觀地講,研究者希望從完全損壞的分布平穩(wěn)過渡到干凈的分布。


實驗結(jié)果


研究者在 CelebA-64 和 CIFAR-10 上評估了提出的方法,這兩個數(shù)據(jù)集都是圖像生成的標(biāo)準(zhǔn)基線。實驗的主要目的是了解損壞類型的作用。


研究者首先嘗試使用模糊和低幅噪聲進行損壞。結(jié)果表明,他們提出的模型在 CelebA 上實現(xiàn)了 SOTA 結(jié)果,即 FID 分?jǐn)?shù)為 1.85,超越了所有其他僅添加噪聲以及可能重縮放圖像的方法。此外在 CIFAR-10 上獲得的 FID 分?jǐn)?shù)為 4.64,雖未達到 SOTA 但也具有競爭力。


圖片


此外,在 CIFAR-10 和 CelebA 數(shù)據(jù)集上,研究者的方法在另一項指標(biāo)采樣時間上也表現(xiàn)更好。另一個額外的好處是具有顯著的計算優(yōu)勢。與圖像生成去噪方法相比,去模糊(幾乎沒有噪聲)似乎是一種更有效的操縱。


下圖展示了 FID 分?jǐn)?shù)如何隨著函數(shù)評估數(shù)量(Number of Function Evaluations, NFE)而變。從結(jié)果可以看到,在 CIFAR-10 和 CelebA 數(shù)據(jù)集上,研究者的模型可以使用明顯更少的步驟來獲得與標(biāo)準(zhǔn)高斯去噪擴散模型相同或更好的質(zhì)量。


圖片


*博客內(nèi)容為網(wǎng)友個人發(fā)布,僅代表博主個人觀點,如有侵權(quán)請聯(lián)系工作人員刪除。



關(guān)鍵詞: AI

相關(guān)推薦

技術(shù)專區(qū)

關(guān)閉