Spatio-Temporal Graph Neural Networks: Principles and a PyTorch Implementation
In the interconnected world we live in, from microscopic molecular structures to macroscopic social networks to complex urban designs, many systems can be described as graphs of interrelated entities. These graphs are like webs that tie everything together. Graph neural networks (GNNs) are a powerful family of models for learning from such graph data, and they are steadily helping us understand and use it more deeply.
GNNs offer a new way of modeling and learning: they capture not only the spatial structure of data but also the complex relationships encoded in the graph. They have shown great potential in biology, for example in protein structure analysis and drug discovery, and in sociology, for example in social network simulation and public opinion analysis.
Even more exciting, GNNs can be combined with other machine learning models to form more powerful ones. For example, combining a GNN with a sequence model yields a spatio-temporal graph neural network (ST-GNN), which captures both the temporal and the spatial dependencies in the data and gives a more complete picture of its underlying patterns and trends. Such hybrid models open up new possibilities for research and applications in many fields.
In an ST-GNN, a time dimension is added to the graph structure: node features that used to be static now change over time. These changes reflect the dynamic relationships between nodes and provide richer information, which allows more accurate prediction and analysis of complex phenomena.
However, GNNs and sequence models (such as simple RNNs, LSTMs or GRUs) are complex in their own right. Combining them to handle spatial and temporal dependencies is powerful but also complicated: the result is hard to understand and hard to implement.
So in this article we look at the principles behind these models and implement a relatively simple example to understand their capabilities and applications better.
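Graph Neural Networks (GNN)
We start with a brief introduction to the basics of GNNs.
A graph G can be defined as G = (V, E), where V is the set of nodes and E is the set of edges between them.
The feature matrix of a graph with n nodes, each with f features, is the concatenation of all node feature vectors:
$$X = [x_1, x_2, \dots, x_n]^\top \in \mathbb{R}^{n \times f}$$
The key operation in a GNN is message passing between connected nodes. This transformation and aggregation of neighbor features can be written as:
$$X' = (A + I)\,X\,W$$
Here A is the adjacency matrix of the graph, I is the identity matrix that adds self-connections, and W is a learnable weight matrix. This is not the complete equation: a graph convolutional network also normalizes A + I by node degree and applies a nonlinearity. It nevertheless shows the foundation of a graph convolutional network that can learn spatial dependencies between nodes. A classic graph neural network is shown in the figure below: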
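The following minimal numpy sketch makes the aggregation concrete. It applies the simplified propagation step to a three-node path graph and then computes the degree-normalized variant used by GCN-style layers. The graph, the features and the identity stand-in for W are invented for illustration. In a real layer, W is learned.

```python
import numpy as np

# Path graph 0 - 1 - 2 (undirected), so A is symmetric
A = np.array([[0, 1, 0],
              [1, 0, 1],
              [0, 1, 0]], dtype=float)
I = np.eye(3)

# n = 3 nodes, f = 2 features per node: row i is the feature vector x_i
X = np.array([[1.0, 0.0],
              [0.0, 1.0],
              [2.0, 2.0]])

# Stand-in for a learned weight matrix; identity keeps the arithmetic readable
W = np.eye(2)

# Simplified propagation: each node sums its own features and its neighbours'
A_hat = A + I
X_new = A_hat @ X @ W

# Symmetric degree normalization D^{-1/2} (A + I) D^{-1/2}, as in GCN-style layers
deg = A_hat.sum(axis=1)
D_inv_sqrt = np.diag(deg ** -0.5)
X_norm = D_inv_sqrt @ A_hat @ D_inv_sqrt @ X @ W

print(X_new)
```

Node 1 is connected to both other nodes, so its new row is the sum of all three feature vectors. The normalized version weights each neighbor by 1/sqrt(d_i d_j). This keeps high-degree nodes from dominating.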
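Spatio-Temporal Graph Neural Networks (ST-GNN)
In an ST-GNN, each time step is a graph. That graph is passed through a GCN/GAT network, which produces an encoded graph that embeds the spatial interdependencies of the data. These encoded graphs can then be modeled like time-series data, as long as the graph structure of each time step's data is kept intact. The figure below illustrates these two steps. The temporal model can be any sequence model, from ARIMA or a simple recurrent neural network to transformers.
Below, the components of an ST-GNN are drawn with a simple recurrent neural network as the temporal model:
That is the basic principle of ST-GNNs: a GNN combined with a sequence model (RNN, LSTM, GRU, Transformer, etc.). If you already know these sequence models and GNNs, the theory is straightforward. The implementation gets fiddly in practice, though, so next we implement a simple ST-GNN directly in PyTorch.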
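Before moving to PyTorch, the two-step recipe can be sketched in plain numpy. The same graph layer encodes each time step, and then a recurrent update runs over the encoded snapshots. This is an illustrative forward pass only, with random weights. It assumes a row-normalized mean aggregation on a small path graph and a plain Elman RNN as the temporal model. The model implemented below uses GATConv and GRU layers instead.

```python
import numpy as np

rng = np.random.default_rng(0)

T, n, f, h = 5, 4, 3, 8                   # time steps, nodes, input features, hidden size

# Small path graph 0-1-2-3 with self-loops, row-normalized (mean over neighbours)
A_hat = np.eye(n) + np.eye(n, k=1) + np.eye(n, k=-1)
A_norm = A_hat / A_hat.sum(axis=1, keepdims=True)

X_seq = rng.normal(size=(T, n, f))        # one (n, f) feature matrix per time step

# Random parameters stand in for trained weights
W_s = rng.normal(scale=0.1, size=(f, h))  # spatial weights, shared across time steps
W_x = rng.normal(scale=0.1, size=(h, h))  # RNN input-to-hidden weights
W_h = rng.normal(scale=0.1, size=(h, h))  # RNN hidden-to-hidden weights


def spatial_step(X_t):
    """Graph layer for one snapshot: (n, f) -> (n, h)."""
    return np.tanh(A_norm @ X_t @ W_s)


def st_gnn_encode(X_seq):
    """Encode each snapshot spatially, then update a per-node RNN state over time."""
    h_t = np.zeros((X_seq.shape[1], W_h.shape[0]))   # (n, h)
    for X_t in X_seq:                                 # loop over time steps
        H_t = spatial_step(X_t)                       # spatial encoding at step t
        h_t = np.tanh(H_t @ W_x + h_t @ W_h)          # temporal update
    return h_t


final_state = st_gnn_encode(X_seq)
print(final_state.shape)  # (4, 8): one hidden vector per node
```

The node dimension is never flattened, so every node keeps its own temporal state. This is one simple way of keeping each time step's graph structure intact while modeling the sequence.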
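ST-GNN Implementation in PyTorch
First, a caveat: for this demonstration I will use stock market data from large tech companies. This data is not inherently graph-structured, but such a network may capture interdependencies between these companies. For example, one company's performance (good or bad) may in turn affect the value of other companies in the market. This is only a demonstration, and we do not recommend using ST-GNNs for stock market prediction.
Load the data with yfinance, which provides everything we need: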
import datetime as dt

import pandas as pd
import yfinance as yf
from sklearn.preprocessing import StandardScaler
import plotly.graph_objs as go
from plotly.offline import iplot

############ Dataset download #################
start_date = dt.datetime(2013, 1, 1)
end_date = dt.datetime(2024, 3, 7)
# ticker symbol -> column name used in the rest of the code
tickers = {"GOOGL": "google", "MSFT": "microsoft", "AMZN": "amazon",
           "NVDA": "Nvidia", "META": "meta", "AAPL": "apple"}
# loading from yahoo finance: one call for all tickers; selecting "Open" gives one
# column per ticker (newer yfinance releases may return MultiIndex columns even for
# a single ticker, which breaks the per-ticker google['Open'] pattern)
raw = yf.download(list(tickers), start=start_date, end=end_date)
data = raw["Open"][list(tickers)].rename(columns=tickers).dropna()
############## Scaling data ######################
# standardize each company's series to zero mean and unit variance
scaler = StandardScaler()
data_scaled = pd.DataFrame(scaler.fit_transform(data), columns=data.columns, index=data.index)
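StandardScaler standardizes each column (each company) separately, so the model is trained and evaluated in standardized units rather than in dollars. The numpy sketch below reproduces that transformation on invented prices and shows how to invert it. With the fitted scaler, `scaler.inverse_transform`, or `value * scaler.scale_[k] + scaler.mean_[k]` for node k, does the same thing.

```python
import numpy as np

# Toy price matrix: rows are trading days, columns are tickers (values are made up)
prices = np.array([[100.0, 20.0],
                   [102.0, 22.0],
                   [104.0, 24.0]])

# What StandardScaler does per column: subtract the mean and divide by the
# population standard deviation (ddof=0)
mean = prices.mean(axis=0)
scale = prices.std(axis=0)
scaled = (prices - mean) / scale

# Model outputs live in this standardized space; invert to get prices back
recovered = scaled * scale + mean
print(scaled[:, 0])
```

Fitting the scaler on the full series also lets statistics from the test period leak into the training inputs. Fitting it on the training period only avoids this.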
To fit the ST-GNN, we need to transform the data into the form the model expects.
Converting a scalar time-series dataset into a graph data structure is the key step that turns conventional data into something a GNN can process. The functions and classes used here are:
- Defining the adjacency matrix: the AdjacencyMatrix function defines the adjacency matrix (connectivity) of the graph. This is normally derived from the structure of the physical system at hand. Here, however, we simply use an all-ones matrix, i.e. every node is connected to every other node (and to itself).
- The stock market dataset class: the StockMarketDataset class builds the dataset for training the ST-GNN. Its methods are:
  - DatasetCreate: the entry point. It computes the window length, builds the edges and returns the list of graph samples.
  - _create_edges: builds the graph edges from the adjacency matrix.
  - _create_sequences: slides a window over the stock data and turns each window into one graph sample. The first N_hist days of the window are the input and the remaining N_pred days are the target.
This data-preparation code is easy to adapt to other problems. You define how the nodes are connected at each time step, then use a sliding window to extract sequences the model can learn from. This turns a plain time series into a graph data structure with relationships and temporal dependencies, which a GNN can then analyze and forecast.
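The sliding-window step can be checked on a toy array. The hypothetical `sliding_windows` function below mirrors the windowing in `_create_sequences` without the graph wrapper. With num_days days and a window of n_hist + n_pred days, there are num_days - (n_hist + n_pred) + 1 complete windows.

```python
import numpy as np


def sliding_windows(series, n_hist, n_pred):
    """Split a (num_days, n_nodes) array into history/target windows.

    Returns arrays of shape (num_windows, n_nodes, n_hist) and
    (num_windows, n_nodes, n_pred); only complete windows are kept."""
    num_days, _ = series.shape
    n_window = n_hist + n_pred
    xs, ys = [], []
    for start in range(num_days - n_window + 1):
        window = series[start:start + n_window].T   # (n_nodes, n_window)
        xs.append(window[:, :n_hist])
        ys.append(window[:, n_hist:])
    return np.stack(xs), np.stack(ys)


# 10 days, 2 nodes; node 1 is node 0 shifted by 100 so values are easy to trace
days = np.arange(10, dtype=float)
series = np.stack([days, days + 100], axis=1)

X, Y = sliding_windows(series, n_hist=4, n_pred=2)
print(X.shape, Y.shape)  # (5, 2, 4) (5, 2, 2)
```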
import numpy as np
import torch
from torch_geometric.data import Data

def AdjacencyMatrix(L):
    # fully connected graph: every node linked to every node, including itself
    AdjM = np.ones((L, L))
    return AdjM

class StockMarketDataset:
    def __init__(self, W, N_hist, N_pred):
        self.W = W              # (n_node, n_node) adjacency matrix
        self.N_hist = N_hist    # number of past time steps fed to the model
        self.N_pred = N_pred    # number of future time steps to predict

    def DatasetCreate(self):
        # uses the global data_scaled DataFrame; convert to numpy for 2-D slicing
        values = data_scaled.to_numpy()
        num_days, self.n_node = values.shape
        n_window = self.N_hist + self.N_pred
        edge_index, edge_attr = self._create_edges(self.n_node)
        sequences = self._create_sequences(values, self.n_node, n_window, edge_index, edge_attr)
        return sequences

    def _create_edges(self, n_node):
        # COO edge list: column k of edge_index is the (source, target) pair of edge k
        edge_index = torch.zeros((2, n_node**2), dtype=torch.long)
        edge_attr = torch.zeros((n_node**2, 1))
        num_edges = 0
        for i in range(n_node):
            for j in range(n_node):
                if self.W[i, j] != 0:
                    edge_index[:, num_edges] = torch.tensor([i, j], dtype=torch.long)
                    edge_attr[num_edges, 0] = self.W[i, j]
                    num_edges += 1
        edge_index = edge_index[:, :num_edges]
        edge_attr = edge_attr[:num_edges]
        return edge_index, edge_attr

    def _create_sequences(self, data, n_node, n_window, edge_index, edge_attr):
        sequences = []
        num_days, _ = data.shape
        # only start windows that fit entirely inside the data
        for i in range(num_days - n_window + 1):
            sta = i
            end = i + n_window
            # (n_window, n_node) -> (n_node, n_window): one row per node
            full_window = np.swapaxes(data[sta:end, :], 0, 1)
            g = Data(x=torch.tensor(full_window[:, :self.N_hist], dtype=torch.float),
                     y=torch.tensor(full_window[:, self.N_hist:], dtype=torch.float),
                     edge_index=edge_index,
                     num_nodes=n_node)
            sequences.append(g)
        return sequences
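`_create_edges` stores the graph in COO format. `edge_index` has shape [2, num_edges], with source nodes in the first row and target nodes in the second. The same conversion can be written without the double loop. The numpy sketch below is an illustrative alternative, not the original code, and yields edges in the same row-major order as the loop. PyTorch Geometric also provides `torch_geometric.utils.dense_to_sparse` for doing this directly on torch tensors.

```python
import numpy as np


def dense_to_edge_index(W):
    """Convert a dense (n, n) adjacency matrix into a COO edge list.

    Row 0 holds source nodes and row 1 target nodes; np.nonzero scans the
    matrix row by row, matching the order of the nested loop."""
    src, dst = np.nonzero(W)
    edge_index = np.stack([src, dst])
    edge_weight = W[src, dst]
    return edge_index, edge_weight


W = np.ones((3, 3))
edge_index, edge_weight = dense_to_edge_index(W)
print(edge_index.shape)  # (2, 9): every ordered pair, including self-loops
```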
Train-validation-test split:
import random

from torch_geometric.loader import DataLoader

def train_val_test_splits(sequences, splits):
    split_train, split_val, split_test = splits
    # keep only complete windows (full N_pred-step target)
    full_len = sequences[0].y.shape[1]
    indices = [i for i, g in enumerate(sequences) if g.y.shape[1] == full_len]
    # NOTE: windows overlap by up to n_hist + n_pred - 1 days, so a random shuffle
    # lets test windows share days with training windows (optimistic test scores)
    random.shuffle(indices)
    total = len(indices)
    # Calculate split indices
    idx_train = int(total * split_train)
    idx_val = int(total * (split_train + split_val))
    train = [sequences[index] for index in indices[:idx_train]]
    val = [sequences[index] for index in indices[idx_train:idx_val]]
    test = [sequences[index] for index in indices[idx_val:]]
    return train, val, test

'''Setting up the hyperparameters'''
n_nodes = 6
n_hist = 50
n_pred = 10
batch_size = 32
# Adjacency matrix
W = AdjacencyMatrix(n_nodes)
# transform data into graphical time series
dataset = StockMarketDataset(W, n_hist, n_pred)
sequences = dataset.DatasetCreate()
# train, validation, test split
splits = (0.9, 0.05, 0.05)
train, val, test = train_val_test_splits(sequences, splits)
# drop_last=True keeps every batch at exactly batch_size graphs (the model and
# the evaluation code below reshape with a fixed batch_size)
train_dataloader = DataLoader(train, batch_size=batch_size, shuffle=True, drop_last=True)
val_dataloader = DataLoader(val, batch_size=batch_size, shuffle=True, drop_last=True)
test_dataloader = DataLoader(test, batch_size=batch_size, shuffle=True, drop_last=True)
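Consecutive windows overlap by up to n_hist + n_pred - 1 days. Because of this, the random shuffle above lets test windows share most of their days with training windows, and test scores can look better than they would on truly unseen data. One common alternative is a chronological split with a gap between subsets. The helper below is a hypothetical sketch of that idea and not part of the original code. It only computes window indices.

```python
def chronological_split(n_windows, splits, n_window):
    """Split window indices in time order with a gap between subsets.

    Window i covers days [i, i + n_window); windows i < j share days iff
    j - i < n_window, so skipping n_window - 1 indices at each boundary
    removes the overlap between train, validation and test."""
    split_train, split_val, _ = splits
    n_train = int(n_windows * split_train)
    n_val = int(n_windows * split_val)
    gap = n_window - 1
    train = list(range(0, n_train))
    val = list(range(n_train + gap, n_train + n_val))
    test = list(range(n_train + n_val + gap, n_windows))
    return train, val, test


train_idx, val_idx, test_idx = chronological_split(2000, (0.9, 0.05, 0.05), n_window=60)
print(len(train_idx), len(val_idx), len(test_idx))
```

The gap costs n_window - 1 windows at each boundary. With small validation and test fractions, the subsets shrink noticeably, to 41 windows each instead of 100 in this example.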
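Our model uses one GATConv layer and two GRU layers as the encoder, and one GRU layer plus a fully connected layer as the decoder. GATConv is the GNN part and captures spatial dependencies. The GRU layers capture the temporal dynamics of the data. The code does a lot of reshaping so that each layer receives its input in the dimensions it expects. This is the most complex part of the ST-GNN implementation. If you want to follow the input dimensions of each layer, print the shape of x at different stages of the forward pass and compare it with the expected input shapes in the documentation of the GRU and Linear layers.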
import torch
import torch.nn.functional as F
from torch_geometric.nn import GATConv

class ST_GNN_Model(torch.nn.Module):
    def __init__(self, in_channels, out_channels, n_nodes, gru_hs_l1, gru_hs_l2, heads=1, dropout=0.01):
        super().__init__()
        self.n_pred = out_channels
        self.heads = heads
        self.dropout = dropout
        self.n_nodes = n_nodes
        self.gru_hidden_size_l1 = gru_hs_l1
        self.gru_hidden_size_l2 = gru_hs_l2
        self.decoder_hidden_size = self.gru_hidden_size_l2
        # spatial encoder: graph attention over the nodes; concat=False averages the
        # heads, so each node keeps in_channels (= n_hist) output features
        self.gat = GATConv(in_channels=in_channels, out_channels=in_channels,
                           heads=heads, dropout=dropout, concat=False)
        # temporal encoder: the GRUs read the n_hist axis as the sequence and the nodes as input features
        self.encoder_gru_l1 = torch.nn.GRU(input_size=self.n_nodes,
                                           hidden_size=self.gru_hidden_size_l1, num_layers=1, bias=True)
        self.encoder_gru_l2 = torch.nn.GRU(input_size=self.gru_hidden_size_l1,
                                           hidden_size=self.gru_hidden_size_l2, num_layers=1, bias=True)
        # decoder; GRU dropout only acts between stacked layers, so it is omitted for num_layers=1
        self.GRU_decoder = torch.nn.GRU(input_size=self.gru_hidden_size_l2, hidden_size=self.decoder_hidden_size,
                                        num_layers=1, bias=True)
        # maps the last decoder state to n_pred values for every node
        self.prediction_layer = torch.nn.Linear(self.decoder_hidden_size, self.n_nodes * self.n_pred, bias=True)

    def forward(self, data, device=None):
        # device is kept for compatibility with the calling code; the batch is
        # already on the right device, so only the dtype is enforced here
        x, edge_index = data.x.float(), data.edge_index
        x = self.gat(x, edge_index)                                    # (B*N, n_hist)
        x = F.dropout(x, self.dropout, training=self.training)
        batch_size = data.num_graphs
        n_node = data.num_nodes // batch_size
        x = torch.reshape(x, (batch_size, n_node, data.num_features))  # (B, N, n_hist)
        x = torch.movedim(x, 2, 0)                                     # (n_hist, B, N)
        encoderl1_outputs, _ = self.encoder_gru_l1(x)
        x = F.relu(encoderl1_outputs)
        encoderl2_outputs, h2 = self.encoder_gru_l2(x)
        x = F.relu(encoderl2_outputs)
        x, _ = self.GRU_decoder(x, h2)                                 # (n_hist, B, hs2)
        x = x[-1, :, :]                                                # last step: (B, hs2)
        x = self.prediction_layer(x)                                   # (B, N*n_pred)
        x = torch.reshape(x, (batch_size * self.n_nodes, self.n_pred))  # (B*N, n_pred)
        return x
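The reshaping in `forward` depends on how PyTorch Geometric batches graphs. The node-feature matrices of the graphs in a batch are stacked row by row, so `batch.x` has shape (B*N, features). The toy numpy trace below uses invented sizes and values to check that the reshape and `movedim` steps place graph b, node i at position [:, b, i]. That is the (seq_len, batch, input_size) layout `torch.nn.GRU` expects when `batch_first=False`. In the model, the n_hist axis after the GAT layer holds learned output channels (out_channels is set to n_hist), and the GRUs then read those channels as the sequence axis.

```python
import numpy as np

B, N, n_hist = 2, 3, 4   # graphs per batch, nodes per graph, history length

# Batching stacks the (N, n_hist) matrices of the B graphs row-wise, so the
# batch has shape (B * N, n_hist). Graph b, node i is filled with 10 * b + i.
graphs = [np.full((N, n_hist), 10.0 * b) + np.arange(N)[:, None] for b in range(B)]
x = np.concatenate(graphs, axis=0)       # (B*N, n_hist) = (6, 4)

# The same two steps as forward(): reshape, then move the time axis to the front
x = x.reshape(B, N, n_hist)              # (B, N, n_hist)
x = np.moveaxis(x, 2, 0)                 # (n_hist, B, N)

# seq_len = n_hist, batch = B graphs, input_size = N nodes, which is why
# encoder_gru_l1 is built with input_size=n_nodes
print(x.shape)  # (4, 2, 3)
```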
The training procedure is almost identical to training any other network in PyTorch.
import torch
import torch.optim as optim
from tqdm import tqdm

# Hyperparameters
gru_hs_l1 = 16
gru_hs_l2 = 16
learning_rate = 1e-3
Epochs = 50
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = ST_GNN_Model(in_channels=n_hist, out_channels=n_pred, n_nodes=n_nodes,
                     gru_hs_l1=gru_hs_l1, gru_hs_l2=gru_hs_l2)
pretrained = False
model_path = "ST_GNN_Model.pth"
if pretrained:
    model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-7)
criterion = torch.nn.MSELoss()
for epoch in range(Epochs):
    model.train()
    for batch in tqdm(train_dataloader, desc=f"Epoch {epoch}"):
        batch = batch.to(device)
        optimizer.zero_grad()
        y_pred = model(batch, device)                 # (B*N, n_pred)
        loss = criterion(y_pred, batch.y.float())     # batch.y is also (B*N, n_pred)
        loss.backward()
        optimizer.step()
    # loss of the last batch of this epoch
    print(f"Loss: {loss.item():.7f}")
# save the weights so they can be reloaded with pretrained = True
torch.save(model.state_dict(), model_path)
With training done, we can visualize the model's predictive ability. The code below computes the model output for every input in the test set and then plots the output against the ground truth.
@torch.no_grad()
def Extract_results(model, device, dataloader):
    """Run the model over a dataloader and collect predictions and ground truth.

    Returns two tensors of shape (n_batches, batch_size, n_nodes, n_pred).
    Relies on drop_last=True, so every batch holds exactly batch_size graphs."""
    model.eval()
    model.to(device)
    y_pred = torch.zeros(len(dataloader), batch_size * n_nodes, n_pred)
    y_truth = torch.zeros(len(dataloader), batch_size * n_nodes, n_pred)
    # Evaluate model on all data
    for i, batch in enumerate(dataloader):
        batch = batch.to(device)
        pred = model(batch, device)
        truth = batch.y.view(pred.shape)
        y_pred[i] = pred.cpu()
        y_truth[i] = truth.cpu()
    y_pred_flat = torch.reshape(y_pred, (len(dataloader), batch_size, n_nodes, n_pred))
    y_truth_flat = torch.reshape(y_truth, (len(dataloader), batch_size, n_nodes, n_pred))
    return y_pred_flat, y_truth_flat

def plot_results(predictions, actual, step, node):
    """Scatter plot of predicted vs. actual values for one node and one forecast step."""
    pred_values = predictions[:, :, node, step].reshape(-1).numpy()
    actual_values = actual[:, :, node, step].reshape(-1).numpy()
    scatter_trace = go.Scatter(
        x=actual_values,
        y=pred_values,
        mode='markers',
        marker=dict(
            size=10,
            opacity=0.5,
            color='rgba(255,255,255,0)',
            line=dict(width=2, color='rgba(152, 0, 0, .8)'),
        ),
        name='Actual vs Predicted'
    )
    # y = x reference line: points on it are perfect predictions
    line_trace = go.Scatter(
        x=[actual_values.min(), actual_values.max()],
        y=[actual_values.min(), actual_values.max()],
        mode='lines',
        marker=dict(color='blue'),
        name='Perfect Prediction'
    )
    layout = dict(
        title='Actual vs Predicted Values',
        xaxis=dict(title='Actual Values'),
        yaxis=dict(title='Predicted Values'),
        autosize=False,
        width=800,
        height=600
    )
    fig = dict(data=[scatter_trace, line_trace], layout=layout)
    iplot(fig)

y_pred, y_truth = Extract_results(model, device, test_dataloader)
plot_results(y_pred, y_truth, 9, 0)  # step index 9 (10th step), node 0 (google)
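The scatter plot shows one node and one forecast step at a time. To summarize the whole test set, a small helper like the hypothetical `rmse_by_horizon` below computes the RMSE for every node and forecast step. It takes arrays in the (n_batches, batch_size, n_nodes, n_pred) layout returned by `Extract_results`. The synthetic inputs only check the arithmetic. On the real outputs, the values are in standardized units unless converted back with the scaler.

```python
import numpy as np


def rmse_by_horizon(y_pred, y_truth):
    """RMSE for every (node, forecast step), averaged over all windows.

    Inputs have shape (n_batches, batch_size, n_nodes, n_pred); CPU torch
    tensors are accepted because np.asarray converts them."""
    err = np.asarray(y_pred) - np.asarray(y_truth)
    return np.sqrt((err ** 2).mean(axis=(0, 1)))   # (n_nodes, n_pred)


# Synthetic check: the error grows linearly with the forecast step
n_batches, batch_size, n_nodes, n_pred = 2, 4, 3, 5
y_truth = np.zeros((n_batches, batch_size, n_nodes, n_pred))
y_pred = y_truth + np.arange(1, n_pred + 1)        # step k is off by k
r = rmse_by_horizon(y_pred, y_truth)
print(r[0])  # [1. 2. 3. 4. 5.]
```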
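For the 6 nodes (companies), the model receives the past 50 values and makes 10 predictions. Below is the plot of the first node's 10th-step prediction against the ground truth. It looks good, but that does not necessarily mean the model is good. For series that behave close to a random walk, such as stock prices, the previous value is a very hard-to-beat estimate of the next one. If these models are not trained well, they can simply output values close to the last input value instead of capturing the temporal dynamics.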
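A simple way to check whether the model does better than copying the last input is to compare it with a persistence baseline. The sketch below builds that baseline for arbitrary history windows. The random-walk series is synthetic and only demonstrates the mechanics. On the real test set, `batch.x` (history, shape (B*N, n_hist)) and `batch.y` (truth) from the test loader would take the place of `x_hist` and `y_true`, after `.cpu().numpy()`.

```python
import numpy as np


def persistence_forecast(x_hist, n_pred):
    """Naive baseline: repeat the last observed value for every future step.

    x_hist: (..., n_hist) history windows -> returns (..., n_pred)."""
    last = x_hist[..., -1:]
    return np.repeat(last, n_pred, axis=-1)


def mse(a, b):
    return float(np.mean((np.asarray(a) - np.asarray(b)) ** 2))


# Synthetic random walk, where the last value is the best guess by construction
rng = np.random.default_rng(0)
walk = np.cumsum(rng.normal(size=200))
n_hist, n_pred = 50, 10
starts = range(len(walk) - n_hist - n_pred + 1)
x_hist = np.stack([walk[s:s + n_hist] for s in starts])
y_true = np.stack([walk[s + n_hist:s + n_hist + n_pred] for s in starts])

baseline = persistence_forecast(x_hist, n_pred)
print(mse(baseline, y_true))
```

A model worth keeping should reach a lower error than this baseline on the same test windows.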
For a given node, we can plot the historical input, the predictions and the ground truth together to see whether the predictions capture the pattern.
@torch.no_grad()
def forecastModel(model, device, dataloader, node):
    """Plot history, prediction and truth for one node of the first window in a batch.

    test_dataloader shuffles, so each call shows a different test window."""
    model.eval()
    model.to(device)
    batch = next(iter(dataloader)).to(device)
    pred = model(batch, device)
    truth = batch.y.view(pred.shape)
    # the shape should be [batch_size, nodes, number of predictions]
    truth = torch.reshape(truth, [batch_size, n_nodes, n_pred])
    pred = torch.reshape(pred, [batch_size, n_nodes, n_pred])
    x = torch.reshape(batch.x, [batch_size, n_nodes, n_hist])
    # first window of the batch, selected node; move to CPU for plotting
    y_pred = pred[0, node, :].cpu().numpy()
    y_truth = truth[0, node, :].cpu().numpy()
    y_past = x[0, node, :].cpu().numpy()
    t_range = list(range(len(y_past)))
    t_shifted = [t_range[-1] + 1 + t for t in range(len(y_pred))]
    trace1 = go.Scatter(x=t_range, y=y_past, mode="markers", name="Historical data")
    trace2 = go.Scatter(x=t_shifted, y=y_pred, mode="markers", name="pred")
    trace3 = go.Scatter(x=t_shifted, y=y_truth, mode="markers", name="truth")
    layout = go.Layout(title="forecasting", xaxis=dict(title='time'),
                       yaxis=dict(title='y-value'), width=1000, height=600)
    figure = go.Figure(data=[trace1, trace2, trace3], layout=layout)
    iplot(figure)

forecastModel(model, device, test_dataloader, 0)
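The predictions for the first node (Google) at 4 different points in the test set were actually better than I expected. The others did not look as good.
Summary
My view is that future stock prices cannot be predicted by autoregression on historical values alone. Prices are driven by real-world events, and those events are not reflected in past prices. That is why, as stated earlier, we do not recommend ST-GNNs for stock market prediction; we used this dataset only because it is easy to obtain. Finally, remember the goal of this article: to learn the basic concepts of ST-GNNs and to understand how they work through a PyTorch implementation.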