混合密度網(wǎng)絡(luò)(MDN)進(jìn)行多元回歸詳解和代碼示例(1)
來源:Deephub Imba
“回歸預(yù)測(cè)建模是逼近從輸入變量 (X) 到連續(xù)輸出變量 (y) 的映射函數(shù) (f) [...] 回歸問題需要預(yù)測(cè)具體的數(shù)值。具有多個(gè)輸入變量的問題通常被稱為多元回歸問題 例如,預(yù)測(cè)房屋價(jià)值,可能在 100,000 美元到 200,000 美元之間
這是另一個(gè)區(qū)分分類問題和回歸問題的視覺解釋如下:
另外一個(gè)例子
DENSITY “密度” 是什么意思?這是一個(gè)快速的通俗示例:
假設(shè)正在為必勝客運(yùn)送比薩。現(xiàn)在記錄剛剛進(jìn)行的每次交付的時(shí)間(以分鐘為單位)。交付 1000 次后,將數(shù)據(jù)可視化以查看工作表現(xiàn)如何。這是結(jié)果:
這是披薩交付時(shí)間數(shù)據(jù)分布的“密度”。平均而言,每次交付需要 30 分鐘(圖中的峰值)。它還表示,在 95% 的情況下(2 個(gè)標(biāo)準(zhǔn)差2sd ),交付需要 20 到 40 分鐘才能完成。密度種類代表時(shí)間結(jié)果的“頻率”?!邦l率”和“密度”的區(qū)別在于:
頻率:如果你在這條曲線下繪制一個(gè)直方圖并對(duì)所有的 bin 進(jìn)行計(jì)數(shù),它將求和為任何整數(shù)(取決于數(shù)據(jù)集中捕獲的觀察總數(shù))。
密度:如果你在這條曲線下繪制一個(gè)直方圖并計(jì)算所有的 bin,它總和為 1。我們也可以將此曲線稱為概率密度函數(shù) (pdf)。
用統(tǒng)計(jì)術(shù)語(yǔ)來說,這是一個(gè)漂亮的正態(tài)/高斯分布。這個(gè)正態(tài)分布有兩個(gè)參數(shù):
均值
標(biāo)準(zhǔn)差:“標(biāo)準(zhǔn)差是一個(gè)數(shù)字,用于說明一組測(cè)量值如何從平均值(平均值)或預(yù)期值中展開。低標(biāo)準(zhǔn)偏差意味著大多數(shù)數(shù)字接近平均值。高標(biāo)準(zhǔn)差意味著數(shù)字更加分散?!?/span>
均值和標(biāo)準(zhǔn)差的變化會(huì)影響分布的形狀。例如:
有許多具有不同類型參數(shù)的各種不同分布類型。例如:
現(xiàn)在讓我們看看這 3 個(gè)分布:
如果我們采用這種雙峰分布(也稱為一般分布):
混合密度網(wǎng)絡(luò)使用這樣的假設(shè),即任何像這種雙峰分布的一般分布都可以分解為正態(tài)分布的混合(該混合也可以與其他類型的分布一起定制 例如拉普拉斯):
混合密度網(wǎng)絡(luò)也是一種人工神經(jīng)網(wǎng)絡(luò)。這是神經(jīng)網(wǎng)絡(luò)的經(jīng)典示例:
輸入層(黃色)、隱藏層(綠色)和輸出層(紅色)。
如果我們將神經(jīng)網(wǎng)絡(luò)的目標(biāo)定義為學(xué)習(xí)在給定一些輸入特征的情況下輸出連續(xù)值。在上面的例子中,給定年齡、性別、教育程度和其他特征,那么神經(jīng)網(wǎng)絡(luò)就可以進(jìn)行回歸的運(yùn)算。
密度網(wǎng)絡(luò)
密度網(wǎng)絡(luò)也是神經(jīng)網(wǎng)絡(luò),其目標(biāo)不是簡(jiǎn)單地學(xué)習(xí)輸出單個(gè)連續(xù)值,而是學(xué)習(xí)在給定一些輸入特征的情況下輸出分布參數(shù)(此處為均值和標(biāo)準(zhǔn)差)。在上面的例子中,給定年齡、性別、教育程度等特征,神經(jīng)網(wǎng)絡(luò)學(xué)習(xí)預(yù)測(cè)期望工資分布的均值和標(biāo)準(zhǔn)差。預(yù)測(cè)分布比預(yù)測(cè)單個(gè)值具有很多的優(yōu)勢(shì),例如能夠給出預(yù)測(cè)的不確定性邊界。這是解決回歸問題的“貝葉斯”方法。下面是預(yù)測(cè)每個(gè)預(yù)期連續(xù)值的分布的一個(gè)很好的例子:
下面的圖片向我們展示了每個(gè)預(yù)測(cè)實(shí)例的預(yù)期值分布:
混合密度網(wǎng)絡(luò)
最后回到正題,混合密度網(wǎng)絡(luò)的目標(biāo)是在給定特定輸入特征的情況下,學(xué)習(xí)輸出混合在一般分布中的所有分布的參數(shù)(此處為均值、標(biāo)準(zhǔn)差和 Pi)。新參數(shù)“Pi”是混合參數(shù),它給出最終混合中給定分布的權(quán)重/概率。
最終結(jié)果如下:
上面的定義和理論基礎(chǔ)已經(jīng)介紹完畢,下面我們開始代碼的演示:
import numpy as np
import pandas as pd
from mdn_model import MDN
from sklearn.datasets import make_moons
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.linear_model import LinearRegression
from sklearn.kernel_ridge import KernelRidge
plt.style.use('ggplot')
生成著名的“半月”型的數(shù)據(jù)集:
X, y = make_moons(n_samples=2500, noise=0.03)
y = X[:, 1].reshape(-1,1)
X = X[:, 0].reshape(-1,1)
x_scaler = StandardScaler()
y_scaler = StandardScaler()
X = x_scaler.fit_transform(X)
y = y_scaler.fit_transform(y)
plt.scatter(X, y, alpha = 0.3)
繪制目標(biāo)值 (y) 的密度分布:
sns.kdeplot(y.ravel(), shade=True)
通過查看數(shù)據(jù),我們可以看到有兩個(gè)重疊的簇:
這時(shí)一個(gè)很好的多模態(tài)分布(一般分布)。如果我們?cè)谶@個(gè)數(shù)據(jù)集上嘗試一個(gè)標(biāo)準(zhǔn)的線性回歸來用 X 預(yù)測(cè) y:
model = LinearRegression()
model.fit(X.reshape(-1,1), y.reshape(-1,1))
y_pred = model.predict(X.reshape(-1,1))
plt.scatter(X, y, alpha = 0.3)
plt.scatter(X,y_pred)
plt.title('Linear Regression')
sns.kdeplot(y_pred.ravel(), shade=True, alpha = 0.15, label = 'Linear Pred dist')
sns.kdeplot(y.ravel(), shade=True, label = 'True dist')
效果必須不好!現(xiàn)在讓嘗試一個(gè)非線性模型(徑向基函數(shù)核嶺回歸):
model = KernelRidge(kernel = 'rbf')
model.fit(X, y)
y_pred = model.predict(X)
plt.scatter(X, y, alpha = 0.3)
plt.scatter(X,y_pred)
plt.title('Non Linear Regression')
sns.kdeplot(y_pred.ravel(), shade=True, alpha = 0.15, label = 'NonLinear Pred dist')
sns.kdeplot(y.ravel(), shade=True, label = 'True dist')
雖然結(jié)果也不盡如人意,但是比上面的線性回歸要好很多了。
*博客內(nèi)容為網(wǎng)友個(gè)人發(fā)布,僅代表博主個(gè)人觀點(diǎn),如有侵權(quán)請(qǐng)聯(lián)系工作人員刪除。