博客專欄

EEPW首頁(yè) > 博客 > 混合密度網(wǎng)絡(luò)(MDN)進(jìn)行多元回歸詳解和代碼示例(1)

混合密度網(wǎng)絡(luò)(MDN)進(jìn)行多元回歸詳解和代碼示例(1)

發(fā)布人:數(shù)據(jù)派THU 時(shí)間:2022-03-13 來源:工程師 發(fā)布文章
來源:Deephub Imba


回歸


“回歸預(yù)測(cè)建模是逼近從輸入變量 (X) 到連續(xù)輸出變量 (y) 的映射函數(shù) (f) [...] 回歸問題需要預(yù)測(cè)具體的數(shù)值。具有多個(gè)輸入變量的問題通常被稱為多元回歸問題 例如,預(yù)測(cè)房屋價(jià)值,可能在 100,000 美元到 200,000 美元之間
這是另一個(gè)區(qū)分分類問題和回歸問題的視覺解釋如下:
圖片
另外一個(gè)例子

圖片

密度


DENSITY “密度” 是什么意思?這是一個(gè)快速的通俗示例:
假設(shè)正在為必勝客運(yùn)送比薩。現(xiàn)在記錄剛剛進(jìn)行的每次交付的時(shí)間(以分鐘為單位)。交付 1000 次后,將數(shù)據(jù)可視化以查看工作表現(xiàn)如何。這是結(jié)果:圖片
這是披薩交付時(shí)間數(shù)據(jù)分布的“密度”。平均而言,每次交付需要 30 分鐘(圖中的峰值)。它還表示,在 95% 的情況下(2 個(gè)標(biāo)準(zhǔn)差2sd ),交付需要 20 到 40 分鐘才能完成。密度種類代表時(shí)間結(jié)果的“頻率”?!邦l率”和“密度”的區(qū)別在于:

  • 頻率:如果你在這條曲線下繪制一個(gè)直方圖并對(duì)所有的 bin 進(jìn)行計(jì)數(shù),它將求和為任何整數(shù)(取決于數(shù)據(jù)集中捕獲的觀察總數(shù))。

  • 密度:如果你在這條曲線下繪制一個(gè)直方圖并計(jì)算所有的 bin,它總和為 1。我們也可以將此曲線稱為概率密度函數(shù) (pdf)。

  • 用統(tǒng)計(jì)術(shù)語(yǔ)來說,這是一個(gè)漂亮的正態(tài)/高斯分布。這個(gè)正態(tài)分布有兩個(gè)參數(shù):


均值


  • 標(biāo)準(zhǔn)差:“標(biāo)準(zhǔn)差是一個(gè)數(shù)字,用于說明一組測(cè)量值如何從平均值(平均值)或預(yù)期值中展開。低標(biāo)準(zhǔn)偏差意味著大多數(shù)數(shù)字接近平均值。高標(biāo)準(zhǔn)差意味著數(shù)字更加分散?!?/span>


均值和標(biāo)準(zhǔn)差的變化會(huì)影響分布的形狀。例如:
圖片
有許多具有不同類型參數(shù)的各種不同分布類型。例如:

圖片

混合密度


現(xiàn)在讓我們看看這 3 個(gè)分布:
圖片
如果我們采用這種雙峰分布(也稱為一般分布):

圖片
混合密度網(wǎng)絡(luò)使用這樣的假設(shè),即任何像這種雙峰分布的一般分布都可以分解為正態(tài)分布的混合(該混合也可以與其他類型的分布一起定制 例如拉普拉斯):
圖片

網(wǎng)絡(luò)架構(gòu)


混合密度網(wǎng)絡(luò)也是一種人工神經(jīng)網(wǎng)絡(luò)。這是神經(jīng)網(wǎng)絡(luò)的經(jīng)典示例:
圖片
輸入層(黃色)、隱藏層(綠色)和輸出層(紅色)。
如果我們將神經(jīng)網(wǎng)絡(luò)的目標(biāo)定義為學(xué)習(xí)在給定一些輸入特征的情況下輸出連續(xù)值。在上面的例子中,給定年齡、性別、教育程度和其他特征,那么神經(jīng)網(wǎng)絡(luò)就可以進(jìn)行回歸的運(yùn)算。
圖片


密度網(wǎng)絡(luò)


圖片
密度網(wǎng)絡(luò)也是神經(jīng)網(wǎng)絡(luò),其目標(biāo)不是簡(jiǎn)單地學(xué)習(xí)輸出單個(gè)連續(xù)值,而是學(xué)習(xí)在給定一些輸入特征的情況下輸出分布參數(shù)(此處為均值和標(biāo)準(zhǔn)差)。在上面的例子中,給定年齡、性別、教育程度等特征,神經(jīng)網(wǎng)絡(luò)學(xué)習(xí)預(yù)測(cè)期望工資分布的均值和標(biāo)準(zhǔn)差。預(yù)測(cè)分布比預(yù)測(cè)單個(gè)值具有很多的優(yōu)勢(shì),例如能夠給出預(yù)測(cè)的不確定性邊界。這是解決回歸問題的“貝葉斯”方法。下面是預(yù)測(cè)每個(gè)預(yù)期連續(xù)值的分布的一個(gè)很好的例子:

圖片
下面的圖片向我們展示了每個(gè)預(yù)測(cè)實(shí)例的預(yù)期值分布:

圖片


混合密度網(wǎng)絡(luò)


最后回到正題,混合密度網(wǎng)絡(luò)的目標(biāo)是在給定特定輸入特征的情況下,學(xué)習(xí)輸出混合在一般分布中的所有分布的參數(shù)(此處為均值、標(biāo)準(zhǔn)差和 Pi)。新參數(shù)“Pi”是混合參數(shù),它給出最終混合中給定分布的權(quán)重/概率。
圖片
最終結(jié)果如下:

圖片

示例1:?jiǎn)巫兞繑?shù)據(jù)的 MDN 類


上面的定義和理論基礎(chǔ)已經(jīng)介紹完畢,下面我們開始代碼的演示:

import numpy as np
import pandas as pd

from mdn_model import MDN

from sklearn.datasets import make_moons
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.linear_model import LinearRegression
from sklearn.kernel_ridge import KernelRidge

plt.style.use('ggplot')


生成著名的“半月”型的數(shù)據(jù)集:

X, y = make_moons(n_samples=2500, noise=0.03)
y = X[:, 1].reshape(-1,1)
X = X[:, 0].reshape(-1,1)

x_scaler = StandardScaler()
y_scaler = StandardScaler()

X = x_scaler.fit_transform(X)
y = y_scaler.fit_transform(y)

plt.scatter(X, y, alpha = 0.3)

圖片


繪制目標(biāo)值 (y) 的密度分布:

sns.kdeplot(y.ravel(), shade=True)

通過查看數(shù)據(jù),我們可以看到有兩個(gè)重疊的簇:


圖片


這時(shí)一個(gè)很好的多模態(tài)分布(一般分布)。如果我們?cè)谶@個(gè)數(shù)據(jù)集上嘗試一個(gè)標(biāo)準(zhǔn)的線性回歸來用 X 預(yù)測(cè) y:

model = LinearRegression()
model.fit(X.reshape(-1,1), y.reshape(-1,1))
y_pred = model.predict(X.reshape(-1,1))

plt.scatter(X, y, alpha = 0.3)
plt.scatter(X,y_pred)
plt.title('Linear Regression')

圖片

sns.kdeplot(y_pred.ravel(), shade=True, alpha = 0.15, label = 'Linear Pred dist')      
sns.kdeplot(y.ravel(), shade=True, label = 'True dist')

圖片


效果必須不好!現(xiàn)在讓嘗試一個(gè)非線性模型(徑向基函數(shù)核嶺回歸):


model = KernelRidge(kernel = 'rbf')
model.fit(X, y)
y_pred = model.predict(X)


plt.scatter(X, y, alpha = 0.3)
plt.scatter(X,y_pred)
plt.title('Non Linear Regression')

圖片

sns.kdeplot(y_pred.ravel(), shade=True, alpha = 0.15, label = 'NonLinear Pred dist')      
sns.kdeplot(y.ravel(), shade=True, label = 'True dist')

圖片
雖然結(jié)果也不盡如人意,但是比上面的線性回歸要好很多了。


*博客內(nèi)容為網(wǎng)友個(gè)人發(fā)布,僅代表博主個(gè)人觀點(diǎn),如有侵權(quán)請(qǐng)聯(lián)系工作人員刪除。



關(guān)鍵詞: AI

相關(guān)推薦

技術(shù)專區(qū)

關(guān)閉