大一統(tǒng)視角理解擴散模型Understanding Diffusion Models: A Unified Perspective(2)
在擴散模型里,有幾個重要的假設(shè)。其中一個就是每一步擴散過程的變換,都是對前一步結(jié)果的高斯變換(上一節(jié)MHVAE的限制條件2):
與MHVAE不同,編碼器側(cè)的潛在向量分布并不經(jīng)過學(xué)習(xí)得到,而是固定為線性高斯模型
這一點和VAE有很大不同。VAE里編碼器側(cè)的潛在向量的分布是通過模型訓(xùn)練得到的。而擴散模型里,前向加噪過程里的每一步都是基于上一步結(jié)果的高斯變換。其中 alpha_t 一般當作超參設(shè)置得到。這點對于我們計算擴散模型的證據(jù)下界有很大幫助。因為我們可以基于輸入x0確切地知道前向過程里的某一步的具體狀態(tài),從而監(jiān)督我們的預(yù)測。
基于式31,我們可以遞歸式地對x0不斷加噪變換,得到最終xt的表達式:
xt可以寫為關(guān)于x0的一個高斯分布的采樣結(jié)果
所以對于式58里噪音匹配項里的監(jiān)督信號,我們可以重寫成以下形式,其中根據(jù)式70,我們可以得到q(xt|x0)和q(xt-1|x0)的表達式,而q(xt|xt-1, x0)因為是前向擴散過程,可以應(yīng)用馬爾可夫性質(zhì)看做q(xt|xt-1)使用式31得到具體表達式。
式58里的監(jiān)督信號可以通過x0計算具體的值
代入每一項q所代表的高斯函數(shù)表達式后,我們最后可以得到一個新的高斯分布表達式,其中每一項都是具體可求的:
q(xt-1|xt,x0)的解析形式
參考已經(jīng)證明了前向加噪過程可以寫為一個高斯分布了。在擴散模型的初始論文[2]里提到,對于一個連續(xù)的高斯擴散過程,其逆過程與前向過程的方程形式(functional form)一致。所以我們將對去噪匹配項里的p_theta(xt-1|xt)也采用高斯分布的形式(更加具體的一些推導(dǎo)放在了末尾的補充里)。注意式58里,對兩個高斯分布求KL散度,其解析解的形式如下:
兩個高斯分布的KL散度解析解
我們現(xiàn)在已知其中一個高斯分布(左側(cè))的參數(shù),現(xiàn)在如果我們令右側(cè)的高斯分布和左側(cè)高斯分布的方差保持一致。那么優(yōu)化該KL散度的解析式將簡化為以下形式:
式58的噪音匹配項簡化為最小化前后向均值的預(yù)測誤差
如此一來式58的噪音匹配項就被簡化為最小化前后向均值的預(yù)測誤差(式92)。讀者請注意,以下的大一統(tǒng)的三個角度來看待Diffusion model,實質(zhì)上都是對式92里mu_q的不同變形所推論出來的。 其中mu_q是關(guān)于xt, x0的函數(shù),而mu_theta是關(guān)于xt和t的函數(shù)。其中通過式84,我們有mu_q的準確計算結(jié)果,而因為mu_theta是關(guān)于xt的函數(shù)。我們可以將其寫為類似式84的形式(注意,有關(guān)為什么可以忽略方差并且讓均值選取這個形式放在了最末尾的補充討論里。但關(guān)于這個形式的選擇的深層原因?qū)嵸|(zhì)上開辟了一個全新的領(lǐng)域來研究,并且關(guān)于該領(lǐng)域的研究直接導(dǎo)向了擴散模型之后的一系列加速采樣技術(shù)的出現(xiàn))
將后向預(yù)測的均值寫為類似前向加噪的形式
比較式84與94可知,x_hat是我們通過噪音數(shù)據(jù)xt來預(yù)測原始數(shù)據(jù)x0的神經(jīng)網(wǎng)絡(luò)。那么我們可以將式58里證據(jù)下界的噪音匹配項,最終寫為
噪聲匹配項的最終形式
那么,我們最后得到擴散模型的優(yōu)化,最終表現(xiàn)為訓(xùn)練一個神經(jīng)網(wǎng)絡(luò),以任意時間步的噪音圖像為輸入,來預(yù)測最初的原始圖像!此時優(yōu)化目標轉(zhuǎn)化為了最小化預(yù)測誤差。同時式58上的對所有時間步的噪音匹配項求和的優(yōu)化,可以近似為對每一時間步上的預(yù)測誤差的期望的最小值,而該優(yōu)化目標可以通過隨機采樣近似:
該優(yōu)化目標可以通過隨機采樣實現(xiàn)
為什么Calvin Luo的這篇論文叫做大一統(tǒng)視角來看待擴散模型?以上我們花了不菲的篇幅論證了擴散模型的優(yōu)化目標可以最終轉(zhuǎn)化為訓(xùn)練一個神經(jīng)網(wǎng)絡(luò)在任意時間步從xt預(yù)測原始輸入x0。以下我們將論述如何通過對mu_q不同的推導(dǎo)得到類似的角度看待擴散模型。
首先,我們已經(jīng)知道給定每個時間步的噪聲系數(shù)alpha_t之后,我們可以由初始輸入x0遞歸得到xt。同理,給定xt我們也可以求得x0。那么對式69重置后,我們可以得到式115.
將式69里的xt和x0關(guān)系重置后可得式115
重新將式115代入式84里,我們所得的關(guān)于時間步t的真實均值表達式mu_q后,我們可以得到以下推導(dǎo):
在推導(dǎo)真實均值時替換x0
注意在上一次推導(dǎo)的過程中,mu_q里的xt在計算kl散度的解析式時被抵消掉了,而x0我們采取的是用神經(jīng)網(wǎng)絡(luò)直接擬合的策略。而在這一次的推導(dǎo)過程中,x0被替換成了關(guān)于xt的表達式(關(guān)于alpha_bar和epsilon_0)后,我們可以得到mu_q的新的表達式,依舊關(guān)于xt,只是不再與x0相關(guān),而是與epsilon_0相關(guān)(式124)。其中,和式94一樣,我們忽略方差(將其設(shè)為與前向一致)并將希望擬合的mu_theta寫成與真實均值mu_q一樣的形式,只是將epsilon_0替換為神經(jīng)網(wǎng)絡(luò)的擬合項后我們可以得到式125。
與上次推導(dǎo)時替換x0為神經(jīng)網(wǎng)絡(luò)所擬合項一樣,這次換為擬合初始噪聲項
將我們新得到的兩個均值表達式重新代入KL散度的表達式里,xt再次被抵消掉(因為mu_theta和mu_q選取的形式一致)最終只剩下epsilon_0和epsilon_theta的差值。注意式130和式99的相似性!
最終對證據(jù)下界里的去噪匹配項的優(yōu)化可以寫成關(guān)于初始噪聲和其擬合項的差的最小化
至此,我們得到了對擴散模型的第二種直觀理解。對于一個變分擴散模型VDM,我們優(yōu)化該模型的證據(jù)下界既等價于優(yōu)化其在所有時間步上對初始圖像的預(yù)測誤差的期望,也等價于優(yōu)化在所有時間步上對噪聲的預(yù)測誤差的期望! 事實上DDPM采取的做法就是式130的做法(注意DDPM里的表達式實際上用的是epsilon_t,關(guān)于這點在文末也會討論)。
下面筆者將概括第三種看待VDM的推導(dǎo)方式。這種方式主要來自于SongYang博士的系列論文,非常直觀。并且該系列論文將擴散模型這種離散的多步去噪過程統(tǒng)一成了一個連續(xù)的隨機微分方程(SDE)的特殊形式。SongYang博士因此獲得了ICLR2021的最佳論文獎!后續(xù)來自清華大學(xué)的基于將該SDE轉(zhuǎn)化為常微分方程ODE后的采樣提速論文,也獲得了2022ICLR的最佳論文獎!關(guān)于該論文的一些細節(jié)和直觀理解,SongYang博士在他自己的博客里給出了非常精彩和直觀的講解。有興趣的讀者可以點開本文初始的第二個鏈接查看。以下只對大一統(tǒng)視角下的第三種視角做簡短的概括。
第三種推導(dǎo)方式主要基于Tweedie's formula.該公式主要闡述了對于一個指數(shù)家族的分布的真實均值,在給定了采樣樣本后,可以通過采樣樣本的最大似然概率(即經(jīng)驗均值)加上一個關(guān)于分數(shù)(score)預(yù)估的校正項來預(yù)估。注意score在這里的定義是真實數(shù)據(jù)分布的對數(shù)似然關(guān)于輸入xt的梯度。即
score的定義
根據(jù)Tweedie's formula,對于一個高斯變量z~N(mu_z, sigma_z)來說, 該高斯變量的真實均值的預(yù)估是:
Tweedie's formula對高斯變量的應(yīng)用
我們知道在訓(xùn)練時,模型的輸入xt關(guān)于x0的表達式如下
上文里的式70
我們也知道根據(jù)Tweedie's formula的高斯變量的真實均值預(yù)估我們可以得到下式
將式70的方差代入Tweedie's formula
那么聯(lián)立兩式的關(guān)于均值的表達式后,我們可以得到x0關(guān)于score的表達式133
將x0寫為關(guān)于score的表達式
如上一種推導(dǎo)方式所做的一樣,再一次重新將x0的表達式代入式84對真實均值mu_q的表達式里:(注意式135到136的變形主要在分子里最右邊的alpha_bar_t到alpha_t, 約去了根號下alpha_bar_t-1)
將x0的關(guān)于score表達式代入式84
同樣,將mu_theta采取和mu_q一樣的形式,并用神經(jīng)網(wǎng)絡(luò)s_theta來近似score后, 我們得到了新的mu_theta的表達式143。
關(guān)于score的mu_theta的表達式
再再再同樣,和上種推導(dǎo)里的做法一樣,我們再將新的mu_theta, mu_q代入證據(jù)下界里KL散度的損失項我們可以得到一個最終的優(yōu)化目標
將新的mu的表達式代入證據(jù)下界的優(yōu)化目標里
事實上,比較式148和式130的形式,可以說是非常的接近了。那么我們的score function delta_p(xt)和初始噪聲epsilon_0是否有關(guān)聯(lián)呢?聯(lián)立關(guān)于x0的兩個表達式133和115我們可以得到
score function和初始噪聲間的關(guān)系
讀者如果將式151代入148會發(fā)現(xiàn)和式130等價!直觀上來講,score function描述的是如何在數(shù)據(jù)空間里最大化似然概率的更新向量。而又因為初始噪聲是在原輸入的基礎(chǔ)上加入的,那么往噪聲的反方向(也是最佳方向)更新實質(zhì)上等價于去噪的過程。而數(shù)學(xué)上講,對score function的建模也等價于對初始噪聲乘上負系數(shù)的建模!
至此我們終于將擴散模型的三個形式的所有推導(dǎo)整理完畢!即對變分擴散模型VDM的訓(xùn)練等價于訓(xùn)練一個神經(jīng)網(wǎng)絡(luò)來預(yù)測原輸入x0,也等價于預(yù)測噪聲epsilon, 也等價于預(yù)測初始輸入在特定時間步的score delta_logp(xt)。
讀到這里,相比讀者也已經(jīng)發(fā)現(xiàn),不同的推導(dǎo)所得出的不同結(jié)果,都來自于對證據(jù)下界里去噪匹配項的不同推導(dǎo)過程。而不同的變形,基本上都是利用了MHVAE里最開始提到的三點基本假設(shè)所得。
Drawbacks to Consider盡管擴散模型在最近兩年成功出圈,引爆了業(yè)界,學(xué)術(shù)界甚至普通人對文本生成圖像的AI模型的關(guān)注,但擴散模型這個體系本身依舊存在著一些缺陷:
- 擴散模型本身盡管理論框架已經(jīng)比較完善,公式推導(dǎo)也十分優(yōu)美。但仍然非常不直觀。最起碼從一個完全噪聲的輸入不斷優(yōu)化的這個過程和人類的思維過程相去甚遠。
- 擴散模型和GAN或者VAE相比,所學(xué)的潛在向量不具備任何語義和結(jié)構(gòu)的可解釋性。上文提到了擴散模型可以看做是特殊的MHVAE,但里面每一層的潛在向量間都是線性高斯的形式,變化有限。
- 而擴散模型的潛在向量要求維度與輸入一致這一點,則更加死地限制住了潛在向量的表征能力。
- 擴散模型的多步迭代導(dǎo)致了擴散模型的生成往往耗時良久。
不過學(xué)術(shù)界對以上的一些難題其實也提出了不少解決方案。比如擴散模型的可解釋性問題。筆者最近就發(fā)現(xiàn)了一些工作將score-matching直接應(yīng)用在了普通VAE的潛在向量的采樣上。這是一個非常自然的創(chuàng)新點,就和數(shù)年前的flow-based-vae一樣。而耗時良久的問題,今年ICLR的最佳論文也將采樣這個問題加速和壓縮到了幾十步內(nèi)就可以生成非常高質(zhì)量的結(jié)果。
但是對于擴散模型在文本生成領(lǐng)域的應(yīng)用最近似乎還不多,除了prefix-tuning的作者xiang-lisa-li的一篇論文[3]
之外筆者暫未關(guān)注到任何工作。而具體來講,如果將擴散模型直接用在文本生成上,仍有諸多不便。比如輸入的尺寸在整個擴散過程必須保持一致就決定了使用者必須事先決定好想生成的文本的長度。而且做有引導(dǎo)的條件生成還好,要用擴散模型訓(xùn)練出一個開放域的文本生成模型恐怕難度不低。
本篇筆記著重的是在探討大一統(tǒng)角度下的擴散模型推斷。但具體對score matching如何訓(xùn)練,如何引導(dǎo)擴散模型生成我們想要的條件分布還沒有寫出來。筆者打算在下一篇探討最近一些將擴散模型應(yīng)用在受控文本生成領(lǐng)域的方法調(diào)研里詳細記錄和比較一下
補充- 關(guān)于為什么擴散核是高斯變換的擴散過程的逆過程也是高斯變換的問題,來自清華大神的一篇知乎回答里[4] 給出了比較直觀的解釋。其中第二行是將p_t-1和p_t近似。第三行是對logpt(x_t-1)使用一階泰勒展開消去了logpt(xt)。第四行是直接代入了q(xt|xt-1)的表達式。于是我們得到了一個高斯分布的表達式。
擴散的逆過程也是高斯分布
- 在式94和式125,我們都將對真實高斯分布q的均值mu_q的近似mu_theta建模成了與我們所推導(dǎo)出的mu_q一致的形式,并且將方差設(shè)置為了與q的方差一致的形式。直觀上來講,這樣建模的好處很多,一方面是根據(jù)KL散度對兩個高斯分布的解析式來說,這樣我們可以約掉和抵消掉絕大部分的項,簡化了建模。另一方面真實分布和近似分布都依賴于xt。在訓(xùn)練時我們的輸入就是xt,采取和真實分布形式一樣的表達式?jīng)]有泄漏任何信息。并且在工程上DDPM也驗證了類似的簡化是事實上可行的。但實際上可以這樣做的原因背后是從2021年以來的一系列論文里復(fù)雜的數(shù)理證明所在解釋的目標。 同樣引用清華大佬[4]的回答:
DDPM里簡化去噪的高斯分布的做法其實蘊含著深刻的道理
- 在DDPM里,其最終的優(yōu)化目標是epsilon_t而不是epsilon_0。即預(yù)測的誤差到底是初始誤差還是某個時間步上的初始誤差。誰對誰錯?實際上這個誤解來源于我們對xt關(guān)于x0的表達式的求解中的誤解。從式63開始的連續(xù)幾步推導(dǎo),都應(yīng)用到了一個高斯性質(zhì),即兩個獨立高斯分布的和的均值與方差等于原分布的均值和與方差和。而實質(zhì)上我們在應(yīng)用重參數(shù)化技巧求xt的過程中,是遞歸式的不斷引入了新的epsilon來替換遞歸中的x_n里的epsilon。那么到最后,我們所得到的epsilon無非是一個囊括了所有擴散過程中的epsilon。這個噪聲即可以說是t,也可以說是0,甚至最準確來說應(yīng)該不等于任何一個時間步,就叫做噪聲就好!
DDPM的優(yōu)化目標
- 關(guān)于對證據(jù)下界的不同簡化形式。其中我們提到第二種對噪聲的近似是DDPM所采用的建模方式。但是對初始輸入的近似其實也有論文采用。也就是上文提及的將擴散模型應(yīng)用在可控文本生成的論文里[3]所采用的形式。該論文每輪直接預(yù)測初始Word-embedding。而第三種score-matching的角度可以參照SongYang博士的系列論文[5]來看。里面的優(yōu)化函數(shù)的形式用的是第三種。
- 本篇筆記著重于講述擴散模型的變分下界的公式推導(dǎo),關(guān)于擴散模型與能量模型,朗之萬動力學(xué),隨機微分方程等一系列名詞的關(guān)系本篇筆記并無涉及。 筆者將在另外一篇筆記里梳理相關(guān)的理解。
參考
- ^Improving Variational Inference with Inverse Autoregressive Flow https://arxiv.org/abs/1606.04934
- ^Deep Unsupervised Learning using Nonequilibrium Thermodynamics https://arxiv.org/abs/1503.03585
- ^abDiffusion-LM Improves Controllable Text Generation https://arxiv.org/abs/2205.14217
- ^abdiffusion model最近在圖像生成領(lǐng)域大紅大紫,如何看待它的風(fēng)頭開始超過GAN?- 我想唱high C的回答 - 知乎 https://www.zhihu.com/question/536012286/answer/2533146567
- ^SCORE-BASED GENERATIVE MODELING THROUGH STOCHASTIC DIFFERENTIAL EQUATIONS https://arxiv.org/abs/2011.13456
*博客內(nèi)容為網(wǎng)友個人發(fā)布,僅代表博主個人觀點,如有侵權(quán)請聯(lián)系工作人員刪除。