Soft Diffusion:谷歌新框架從通用擴(kuò)散過(guò)程中正確調(diào)度、學(xué)習(xí)和采樣
近來(lái),擴(kuò)散模型成為 AI 領(lǐng)域的研究熱點(diǎn)。谷歌研究院和 UT-Austin 的研究者在最新的一項(xiàng)研究中充分考慮了「損壞」過(guò)程,并提出了一個(gè)用于更通用損壞過(guò)程的擴(kuò)散模型設(shè)計(jì)框架。
我們知道,基于分?jǐn)?shù)的模型和去噪擴(kuò)散概率模型(DDPM)是兩類強(qiáng)大的生成模型,它們通過(guò)反轉(zhuǎn)擴(kuò)散過(guò)程來(lái)產(chǎn)生樣本。這兩類模型已經(jīng)在 Yang Song 等研究者的論文《Score-based generative modeling through stochastic differential equations》中統(tǒng)一到了單一的框架下,并被廣泛地稱為擴(kuò)散模型。
目前,擴(kuò)散模型在包括圖像、音頻、視頻生成以及解決逆問(wèn)題等一系列應(yīng)用中取得了巨大的成功。Tero Karras 等研究者在論文《Elucidating the design space of diffusionbased generative models》中對(duì)擴(kuò)散模型的設(shè)計(jì)空間進(jìn)行了分析,并確定了 3 個(gè)階段,分別為 i) 選擇噪聲水平的調(diào)度,ii) 選擇網(wǎng)絡(luò)參數(shù)化(每個(gè)參數(shù)化生成一個(gè)不同的損失函數(shù)),iii) 設(shè)計(jì)采樣算法。
近日,在谷歌研究院和 UT-Austin 合作的一篇 arXiv 論文《Soft Diffusion: Score Matching for General Corruptions》中,幾位研究者認(rèn)為擴(kuò)散模型仍有一個(gè)重要的步驟:損壞(corrupt)。一般來(lái)說(shuō),損壞是一個(gè)添加不同幅度噪聲的過(guò)程,對(duì)于 DDMP 還需要重縮放。雖然有人嘗試使用不同的分布來(lái)進(jìn)行擴(kuò)散,但仍缺乏一個(gè)通用的框架。因此,研究者提出了一個(gè)用于更通用損壞過(guò)程的擴(kuò)散模型設(shè)計(jì)框架。
具體地,他們提出了一個(gè)名為 Soft Score Matching 的新訓(xùn)練目標(biāo)和一種新穎的采樣方法 Momentum Sampler。理論結(jié)果表明,對(duì)于滿足正則條件的損壞過(guò)程,Soft Score MatchIng 能夠?qū)W習(xí)它們的分?jǐn)?shù)(即似然梯度),擴(kuò)散必須將任何圖像轉(zhuǎn)換為具有非零似然的任何圖像。
在實(shí)驗(yàn)部分,研究者在 CelebA 以及 CIFAR-10 上訓(xùn)練模型,其中在 CelebA 上訓(xùn)練的模型實(shí)現(xiàn)了線性擴(kuò)散模型的 SOTA FID 分?jǐn)?shù)——1.85。同時(shí)與使用原版高斯去噪擴(kuò)散訓(xùn)練的模型相比,研究者訓(xùn)練的模型速度顯著更快。
論文地址:https://arxiv.org/pdf/2209.05442.pdf
方法概覽
通常來(lái)說(shuō),擴(kuò)散模型通過(guò)反轉(zhuǎn)逐漸增加噪聲的損壞過(guò)程來(lái)生成圖像。研究者展示了如何學(xué)習(xí)對(duì)涉及線性確定性退化和隨機(jī)加性噪聲的擴(kuò)散進(jìn)行反轉(zhuǎn)。
具體地,研究者展示了使用更通用損壞模型訓(xùn)練擴(kuò)散模型的框架,包含有三個(gè)部分,分別為新的訓(xùn)練目標(biāo) Soft Score Matching、新穎采樣方法 Momentum Sampler 和損壞機(jī)制的調(diào)度。
首先來(lái)看訓(xùn)練目標(biāo) Soft Score Matching,這個(gè)名字的靈感來(lái)自于軟過(guò)濾,是一種攝影術(shù)語(yǔ),指的是去除精細(xì)細(xì)節(jié)的過(guò)濾器。它以一種可證明的方式學(xué)習(xí)常規(guī)線性損壞過(guò)程的分?jǐn)?shù),還在網(wǎng)絡(luò)中合并入了過(guò)濾過(guò)程,并訓(xùn)練模型來(lái)預(yù)測(cè)損壞后與擴(kuò)散觀察相匹配的圖像。
只要擴(kuò)散將非零概率指定為任何干凈、損壞的圖像對(duì),則該訓(xùn)練目標(biāo)可以證明學(xué)習(xí)到了分?jǐn)?shù)。另外,當(dāng)損壞中存在加性噪聲時(shí),這一條件總是可以得到滿足。
具體地,研究者探究了如下形式的損壞過(guò)程。
在過(guò)程中,研究者發(fā)現(xiàn)噪聲在實(shí)證(即更好的結(jié)果)和理論(即為了學(xué)習(xí)分?jǐn)?shù))這兩方面都很重要。這也成為了其與反轉(zhuǎn)確定性損壞的并發(fā)工作 Cold Diffusion 的關(guān)鍵區(qū)別。
其次是采樣方法 Momentum Sampling。研究者證明,采樣器的選擇對(duì)生成樣本質(zhì)量具有顯著影響。他們提出了 Momentum Sampler,用于反轉(zhuǎn)通用線性損壞過(guò)程。該采樣器使用了不同擴(kuò)散水平的損壞的凸組合,并受到了優(yōu)化中動(dòng)量方法的啟發(fā)。
這一采樣方法受到了上文 Yang Song 等人論文提出的擴(kuò)散模型連續(xù)公式化的啟發(fā)。Momentum Sampler 的算法如下所示。
下圖直觀展示了不同采樣方法對(duì)生成樣本質(zhì)量的影響。圖左使用 Naive Sampler 采樣的圖像似乎有重復(fù)且缺少細(xì)節(jié),而圖右 Momentum Sampler 顯著提升了采樣質(zhì)量和 FID 分?jǐn)?shù)。
最后是調(diào)度。即使退化的類型是預(yù)定義的(如模糊),決定在每個(gè)擴(kuò)散步驟中損壞多少并非易事。研究者提出一個(gè)原則性工具來(lái)指導(dǎo)損壞過(guò)程的設(shè)計(jì)。為了找到調(diào)度,他們將沿路徑分布之間的 Wasserstein 距離最小化。直觀地講,研究者希望從完全損壞的分布平穩(wěn)過(guò)渡到干凈的分布。
實(shí)驗(yàn)結(jié)果
研究者在 CelebA-64 和 CIFAR-10 上評(píng)估了提出的方法,這兩個(gè)數(shù)據(jù)集都是圖像生成的標(biāo)準(zhǔn)基線。實(shí)驗(yàn)的主要目的是了解損壞類型的作用。
研究者首先嘗試使用模糊和低幅噪聲進(jìn)行損壞。結(jié)果表明,他們提出的模型在 CelebA 上實(shí)現(xiàn)了 SOTA 結(jié)果,即 FID 分?jǐn)?shù)為 1.85,超越了所有其他僅添加噪聲以及可能重縮放圖像的方法。此外在 CIFAR-10 上獲得的 FID 分?jǐn)?shù)為 4.64,雖未達(dá)到 SOTA 但也具有競(jìng)爭(zhēng)力。
此外,在 CIFAR-10 和 CelebA 數(shù)據(jù)集上,研究者的方法在另一項(xiàng)指標(biāo)采樣時(shí)間上也表現(xiàn)更好。另一個(gè)額外的好處是具有顯著的計(jì)算優(yōu)勢(shì)。與圖像生成去噪方法相比,去模糊(幾乎沒(méi)有噪聲)似乎是一種更有效的操縱。
下圖展示了 FID 分?jǐn)?shù)如何隨著函數(shù)評(píng)估數(shù)量(Number of Function Evaluations, NFE)而變。從結(jié)果可以看到,在 CIFAR-10 和 CelebA 數(shù)據(jù)集上,研究者的模型可以使用明顯更少的步驟來(lái)獲得與標(biāo)準(zhǔn)高斯去噪擴(kuò)散模型相同或更好的質(zhì)量。
*博客內(nèi)容為網(wǎng)友個(gè)人發(fā)布,僅代表博主個(gè)人觀點(diǎn),如有侵權(quán)請(qǐng)聯(lián)系工作人員刪除。