時(shí)序預(yù)測(cè)的三種方式:統(tǒng)計(jì)學(xué)模型、機(jī)器學(xué)習(xí)、循環(huán)神經(jīng)網(wǎng)絡(luò)
作者 | luanhz來源 | 小數(shù)志
導(dǎo)讀
時(shí)序預(yù)測(cè)是一類經(jīng)典的問題,在學(xué)術(shù)界和工業(yè)界都有著廣泛的研究和應(yīng)用。甚至說,世間萬物加上時(shí)間維度后都可抽象為時(shí)間序列問題,例如股****價(jià)格、天氣變化等等。關(guān)于時(shí)序預(yù)測(cè)問題的相關(guān)理論也極為廣泛,除了經(jīng)典的各種統(tǒng)計(jì)學(xué)模型外,當(dāng)下火熱的機(jī)器學(xué)習(xí)以及深度學(xué)習(xí)中的循環(huán)神經(jīng)網(wǎng)絡(luò)也都可以用于時(shí)序預(yù)測(cè)問題的建模。今天,本文就來介紹三種方式的簡(jiǎn)單應(yīng)用,并在一個(gè)真實(shí)的時(shí)序數(shù)據(jù)集上加以驗(yàn)證。
時(shí)間序列預(yù)測(cè),其主要任務(wù)是基于某一指標(biāo)的歷史數(shù)據(jù)來預(yù)測(cè)其在未來的取值,例如上圖中的曲線記錄了1949年至1960年共12年144個(gè)月份的每月航班乘客數(shù)(具體單位未經(jīng)考證),那么時(shí)序預(yù)測(cè)要解決的問題就是:給定前9年的歷史數(shù)據(jù),例如1949-1957,那么能否預(yù)測(cè)出1958-1960兩年間的乘客數(shù)量的問題。
為了解決這一問題,大概當(dāng)前主流的解決方式有4種:
統(tǒng)計(jì)學(xué)模型,較為經(jīng)典的AR系列,包括AR、MA、ARMA以及ARIMA等,另外Facebook(準(zhǔn)確的講,現(xiàn)在應(yīng)該叫Meta了)推出的Prophet模型,其實(shí)本質(zhì)上也是一種統(tǒng)計(jì)學(xué)模型,只不過是傳統(tǒng)的趨勢(shì)、周期性成分的基礎(chǔ)上,進(jìn)一步細(xì)化考慮了節(jié)假日、時(shí)序拐點(diǎn)等因素的影響,以期帶來更為精準(zhǔn)的時(shí)序規(guī)律刻畫;
機(jī)器學(xué)習(xí)模型,在有監(jiān)督機(jī)器學(xué)習(xí)中,回歸問題主要解決的是基于一系列Feature來預(yù)測(cè)某一Label的可能取值的問題,那么當(dāng)以歷史數(shù)據(jù)作為Feature時(shí)其實(shí)自然也就可以將時(shí)序預(yù)測(cè)問題抽象為回歸問題,從這一角度講,所有回歸模型都可用于解決時(shí)序預(yù)測(cè)。關(guān)于用機(jī)器學(xué)習(xí)抽象時(shí)序預(yù)測(cè),推薦查看這篇論文《Machine Learning Strategies for Time Series Forecasting》;
深度學(xué)習(xí)模型,深度學(xué)習(xí)主流的應(yīng)用場(chǎng)景當(dāng)屬CV和NLP兩大領(lǐng)域,其中后者就是專門用于解決序列問題建模的問題,而時(shí)間序列當(dāng)然屬于序列數(shù)據(jù)的一種特殊形式,所以自然可以運(yùn)用循環(huán)神經(jīng)網(wǎng)絡(luò)來建模時(shí)序預(yù)測(cè);
隱馬爾科夫模型,馬爾科夫模型是用于刻畫相鄰狀態(tài)轉(zhuǎn)換間的經(jīng)典抽象,而隱馬爾科夫模型則在其基礎(chǔ)上進(jìn)一步增加了隱藏狀態(tài),來以此豐富模型的表達(dá)能力。但其一大假設(shè)條件是未來狀態(tài)僅與當(dāng)前狀態(tài)有關(guān),而不利于利用多個(gè)歷史狀態(tài)來共同參與預(yù)測(cè),較為常用的可能就是天氣預(yù)報(bào)的例子了。
本文主要考慮前三種時(shí)序預(yù)測(cè)建模方法,并分別選?。?)Prophet模型,2)RandomForest回歸模型,3)LSTM三種方案加以測(cè)試。
首先在這個(gè)航班乘客真實(shí)數(shù)據(jù)集上進(jìn)行測(cè)試,依次對(duì)比三個(gè)所選模型的預(yù)測(cè)精度。該數(shù)據(jù)集共有12年間每個(gè)月的乘客數(shù)量,以1958年1月作為切分界面劃分訓(xùn)練集和測(cè)試集,即前9年的數(shù)據(jù)作為訓(xùn)練集,后3年的數(shù)據(jù)作為測(cè)試集驗(yàn)證模型效果。數(shù)據(jù)集切分后的示意圖如下:
df = pd.read_csv("AirPassengers.csv", parse_dates=["date"]).rename(columns={"date":"ds", "value":"y"})X_train = df[df.ds<"19580101"]X_test = df[df.ds>="19580101"]
plt.plot(X_train['ds'], X_train['y'])plt.plot(X_test['ds'], X_test['y'])
1.Prophet模型預(yù)測(cè)。Prophet是一個(gè)高度封裝好的時(shí)序預(yù)測(cè)模型,接受一個(gè)DataFrame作為訓(xùn)練集(要求有ds和y兩個(gè)字段列),在預(yù)測(cè)時(shí)也接受一個(gè)DataFrame,但此時(shí)只需有ds列即可,關(guān)于模型的詳細(xì)介紹可參考其官方文檔:https://facebook.github.io/prophet/。模型訓(xùn)練及預(yù)測(cè)部分核心代碼如下:
from prophet import Prophetpro = Prophet()pro.fit(X_train)pred = pro.predict(X_test)
pro.plot(pred)
訓(xùn)練后的結(jié)果示意圖如下:
當(dāng)然,這是通過Prophet內(nèi)置的可視化函數(shù)給出的結(jié)果,也可通過手動(dòng)繪制測(cè)試集真實(shí)標(biāo)簽與預(yù)測(cè)結(jié)果間的對(duì)比:
易見,雖然序列的整體****上具有良好的擬合結(jié)果,但在具體取值上其實(shí)差距還是比較大的。
2.機(jī)器學(xué)習(xí)模型,這里選用常常用作各種baseline的RandomForest模型。在使用機(jī)器學(xué)習(xí)實(shí)現(xiàn)時(shí)序預(yù)測(cè)時(shí),通常需要通過滑動(dòng)窗口的方式來提取特征和標(biāo)簽,而后在實(shí)現(xiàn)預(yù)測(cè)時(shí)實(shí)際上也需滑動(dòng)的截取測(cè)試集特征實(shí)現(xiàn)單步預(yù)測(cè),參考論文《Machine Learning Strategies for Time Series Forecasting》中的做法,該問題可大致描述如下:
據(jù)此,設(shè)置特征提取窗口長(zhǎng)度為12,構(gòu)建訓(xùn)練集和測(cè)試集的方式如下:
data = df.copy()n = 12for i in range(1, n+1): data['ypre_'+str(i)] = data['y'].shift(i)data = data[['ds']+['ypre_'+str(i) for i in range(n, 0, -1)]+['y']]
# 提取訓(xùn)練集和測(cè)試集X_train = data[data['ds']<"19580101"].dropna()[['ypre_'+str(i) for i in range(n, 0, -1)]]y_train = data[data['ds']<"19580101"].dropna()[['y']]X_test = data[data['ds']>="19580101"].dropna()[['ypre_'+str(i) for i in range(n, 0, -1)]]y_test = data[data['ds']>="19580101"].dropna()[['y']]
# 模型訓(xùn)練和預(yù)測(cè)rf = RandomForestRegressor(n_estimators=10, max_depth=5)rf.fit(X_train, y_train)y_pred = rf.predict(X_test)
# 結(jié)果對(duì)比繪圖y_test.assign(yhat=y_pred).plot()
可見,預(yù)測(cè)效果也較為一般,尤其是對(duì)于最后兩年的預(yù)測(cè)結(jié)果,與真實(shí)值差距還是比較大的。用機(jī)器學(xué)習(xí)模型的思維很容易解釋這一現(xiàn)象:隨機(jī)森林模型實(shí)際上是在根據(jù)訓(xùn)練數(shù)據(jù)集來學(xué)習(xí)曲線之間的規(guī)律,由于該時(shí)序整體呈現(xiàn)隨時(shí)間增長(zhǎng)的趨勢(shì),所以歷史數(shù)據(jù)中的最高點(diǎn)也不足以cover住未來的較大值,因而在測(cè)試集中超過歷史數(shù)據(jù)的所有標(biāo)簽其實(shí)都是無法擬合的。
3.深度學(xué)習(xí)中的循環(huán)神經(jīng)網(wǎng)絡(luò),其實(shí)深度學(xué)習(xí)一般要求數(shù)據(jù)集較大時(shí)才能發(fā)揮其優(yōu)勢(shì),而這里的數(shù)據(jù)集顯然是非常小的,所以僅設(shè)計(jì)一個(gè)最為簡(jiǎn)單的模型:1層LSTM+1層Linear。模型搭建如下:
class Model(nn.Module): def __init__(self): super().__init__() self.rnn = nn.LSTM(input_size=1, hidden_size=10, batch_first=True) self.linear = nn.Linear(10, 1)
def forward(self, x): x, _ = self.rnn(x) x = x[:, -1, :] x = self.linear(x) return x
數(shù)據(jù)集構(gòu)建思路整體同前述的機(jī)器學(xué)習(xí)部分,而后,按照進(jìn)行模型訓(xùn)練煉丹,部分結(jié)果如下:
# 數(shù)據(jù)集轉(zhuǎn)化為3DX_train_3d = torch.Tensor(X_train.values).reshape(*X_train.shape, 1)y_train_2d = torch.Tensor(y_train.values).reshape(*y_train.shape, 1)X_test_3d = torch.Tensor(X_test.values).reshape(*X_test.shape, 1)y_test_2d = torch.Tensor(y_test.values).reshape(*y_test.shape, 1)
# 模型、優(yōu)化器、評(píng)估準(zhǔn)則model = Model()creterion = nn.MSELoss()optimizer = optim.Adam(model.parameters())
# 訓(xùn)練過程for i in range(1000): out = model(X_train_3d) loss = creterion(out, y_train_2d) optimizer.zero_grad() loss.backward() optimizer.step()
if (i+1)%100 == 0: y_pred = model(X_test_3d) loss_test = creterion(y_pred, y_test_2d) print(i, loss.item(), loss_test.item())
# 訓(xùn)練結(jié)果99 65492.08984375 188633.796875199 64814.4375 187436.4375299 64462.09765625 186815.5399 64142.70703125 186251.125499 63835.5 185707.46875599 63535.15234375 185175.1875699 63239.39453125 184650.46875799 62947.08203125 184131.21875899 62657.484375 183616.203125999 62370.171875 183104.671875
通過上述1000個(gè)epoch,大體可以推斷該模型不會(huì)很好的擬合了,所以果斷放棄吧!
當(dāng)然必須指出的是,上述測(cè)試效果只能說明3種方案在該數(shù)據(jù)集上的表現(xiàn),而不能代表這一類模型在用于時(shí)序預(yù)測(cè)問題時(shí)的性能。實(shí)際上,時(shí)序預(yù)測(cè)問題本身就是一個(gè)需要具體問題具體分析的場(chǎng)景,沒有放之四海而皆準(zhǔn)的好模型,就像“No Free Lunch”一樣!
本文僅是作為時(shí)序預(yù)測(cè)系列推文的一個(gè)牛刀小試,后續(xù)將不定期更新其他相關(guān)心得和總結(jié)。
*博客內(nèi)容為網(wǎng)友個(gè)人發(fā)布,僅代表博主個(gè)人觀點(diǎn),如有侵權(quán)請(qǐng)聯(lián)系工作人員刪除。