如何在Java中構(gòu)建神經(jīng)網(wǎng)絡(luò)
譯文譯者 | 李睿
審校 | 重樓
人工神經(jīng)網(wǎng)絡(luò)是深度學(xué)習(xí)的一種形式,也是現(xiàn)代人工智能的支柱之一。用戶真正掌握其工作原理的最佳方法是自己構(gòu)建一個(gè)人工神經(jīng)網(wǎng)絡(luò)。本文將介紹如何用Java構(gòu)建和訓(xùn)練神經(jīng)網(wǎng)絡(luò)。
感興趣的用戶可以查閱軟件架構(gòu)師Matthew Tyson以前撰寫的名為《機(jī)器學(xué)習(xí)的風(fēng)格:神經(jīng)網(wǎng)絡(luò)簡(jiǎn)介》文章,以了解人工神經(jīng)網(wǎng)絡(luò)如何運(yùn)行的概述。本文中的示例將不是一個(gè)生產(chǎn)等級(jí)的系統(tǒng),與其相反,它在一個(gè)易于理解的演示例子中展示了所有的主要組件。
一個(gè)基本的神經(jīng)網(wǎng)絡(luò)
神經(jīng)網(wǎng)絡(luò)是一種稱為神經(jīng)元(Neuron)的節(jié)點(diǎn)圖。神經(jīng)元是計(jì)算的基本單位。它接收輸入并使用每個(gè)輸入的權(quán)重、每個(gè)節(jié)點(diǎn)的偏差和最終函數(shù)處理器(其名稱為激活函數(shù))算法處理它們。例如圖1所示的雙輸入神經(jīng)元。
圖1 神經(jīng)網(wǎng)絡(luò)中的雙輸入神經(jīng)元
這個(gè)模型具有廣泛的可變性,將在下面演示的例子中使用這個(gè)精確的配置。
第一步是建立一個(gè)神經(jīng)元類模型,該類將保持這些值??梢栽谇鍐?中看到神經(jīng)元類。需要注意的是,這是該類的第一個(gè)版本。它將隨著添加的功能而改變。
清單1.簡(jiǎn)單的神經(jīng)元類
class Neuron {
Random random = new Random();
private Double bias = random.nextDouble(-1, 1);
public Double weight1 = random.nextDouble(-1, 1);
private Double weight2 = random.nextDouble(-1, 1);
public double compute(double input1, double input2){
double preActivation = (this.weight1 * input1) + (this.weight2 * input2) + this.bias;
double output = Util.sigmoid(preActivation);
return output;
}
}
可以看到神經(jīng)元(Neuron)類非常簡(jiǎn)單,有三個(gè)成員:bias、weight1和weight2。每個(gè)成員被初始化為-1到1之間的隨機(jī)雙精度。
當(dāng)計(jì)算神經(jīng)元的輸出時(shí),遵循圖1所示的算法:將每個(gè)輸入乘以其權(quán)重,再加上偏差:input1 * weight1 + input2 * weight2 + biass。這提供了通過(guò)激活函數(shù)運(yùn)行的未處理計(jì)算(即預(yù)激活)。在本例中,使用Sigmoid激活函數(shù),它將值壓縮到-1到1的范圍內(nèi)。清單2顯示了Util.sigmoid()靜態(tài)方法。
清單2.Sigmoid激活函數(shù)
public class Util {
public static double sigmoid(double in){
return 1 / (1 + Math.exp(-in));
}
}
現(xiàn)在已經(jīng)了解了神經(jīng)元是如何工作的,可以把一些神經(jīng)元放到一個(gè)網(wǎng)絡(luò)中。然后將使用帶有神經(jīng)元列表的Network類,如清單3所示。
清單3.神經(jīng)網(wǎng)絡(luò)類
class Network {
List<Neuron> neurons = Arrays.asList(
new Neuron(), new Neuron(), new Neuron(), /* input nodes */
new Neuron(), new Neuron(), /* hidden nodes */
new Neuron()); /* output node */
}
}
雖然神經(jīng)元的列表是一維的,但將在使用過(guò)程中將它們連接起來(lái),使它們形成一個(gè)網(wǎng)絡(luò)。前三個(gè)神經(jīng)元是輸入,第二個(gè)和第三個(gè)是隱藏的,最后一個(gè)是輸出節(jié)點(diǎn)。
進(jìn)行預(yù)測(cè)
現(xiàn)在,使用這個(gè)網(wǎng)絡(luò)來(lái)做一個(gè)預(yù)測(cè)。將使用兩個(gè)輸入整數(shù)的簡(jiǎn)單數(shù)據(jù)集和0到1的答案格式。這個(gè)例子使用體重-身高組合來(lái)猜測(cè)某人的性別,這是基于這樣的假設(shè),即體重和身高越高,則表明某人是男性??梢詫?duì)任何兩個(gè)因素使用相同的公式,即單輸出概率。可以將輸入視為一個(gè)向量,因此神經(jīng)元的整體功能將向量轉(zhuǎn)換為標(biāo)量值。
網(wǎng)絡(luò)的預(yù)測(cè)階段如清單4所示。
清單4.網(wǎng)絡(luò)預(yù)測(cè)
public Double predict(Integer input1, Integer input2){
return neurons.get(5).compute(
neurons.get(4).compute(
neurons.get(2).compute(input1, input2),
neurons.get(1).compute(input1, input2)
),
neurons.get(3).compute(
neurons.get(1).compute(input1, input2),
neurons.get(0).compute(input1, input2)
)
);
}
清單4顯示了將兩個(gè)輸入饋入到前三個(gè)神經(jīng)元,然后將前三個(gè)神經(jīng)元的輸出饋入到神經(jīng)元4和5,神經(jīng)元4和5又饋入到輸出神經(jīng)元。這個(gè)過(guò)程被稱為前饋。
現(xiàn)在,可以要求網(wǎng)絡(luò)進(jìn)行預(yù)測(cè),如清單5所示。
清單5.獲取預(yù)測(cè)
Network network = new Network();
Double prediction = network.predict(Arrays.asList(115, 66));
System.out.println(“prediction: “ + prediction);
在這里肯定會(huì)得到一些結(jié)果,但這是隨機(jī)權(quán)重和偏差的結(jié)果。為了進(jìn)行真正的預(yù)測(cè),首先需要訓(xùn)練網(wǎng)絡(luò)。
訓(xùn)練網(wǎng)絡(luò)
訓(xùn)練神經(jīng)網(wǎng)絡(luò)遵循一個(gè)稱為反向傳播的過(guò)程。反向傳播基本上是通過(guò)網(wǎng)絡(luò)向后推動(dòng)更改,使輸出向期望的目標(biāo)移動(dòng)。
可以使用函數(shù)微分進(jìn)行反向傳播,但在這個(gè)例子中,需要做一些不同的事情,將賦予每個(gè)神經(jīng)元“變異”的能力。在每一輪訓(xùn)練(稱為epoch)中,選擇一個(gè)不同的神經(jīng)元對(duì)其屬性之一(weight1,weight2或bias)進(jìn)行小的隨機(jī)調(diào)整,然后檢查結(jié)果是否有所改善。如果結(jié)果有所改善,將使用remember()方法保留該更改。如果結(jié)果惡化,將使用forget()方法放棄更改。
添加類成員(舊版本的權(quán)重和偏差)來(lái)跟蹤變化??梢栽谇鍐?中看到mutate()、remember()和forget()方法。
清單6.Mutate(),remember(),forget()
public class Neuron() {
private Double oldBias = random.nextDouble(-1, 1), bias = random.nextDouble(-1, 1);
public Double oldWeight1 = random.nextDouble(-1, 1), weight1 = random.nextDouble(-1, 1);
private Double oldWeight2 = random.nextDouble(-1, 1), weight2 = random.nextDouble(-1, 1);
public void mutate(){
int propertyToChange = random.nextInt(0, 3);
Double changeFactor = random.nextDouble(-1, 1);
if (propertyToChange == 0){
this.bias += changeFactor;
} else if (propertyToChange == 1){
this.weight1 += changeFactor;
} else {
this.weight2 += changeFactor;
};
}
public void forget(){
bias = oldBias;
weight1 = oldWeight1;
weight2 = oldWeight2;
}
public void remember(){
oldBias = bias;
oldWeight1 = weight1;
oldWeight2 = weight2;
}
}
非常簡(jiǎn)單:mutate()方法隨機(jī)選擇一個(gè)屬性,隨機(jī)選擇-1到1之間的值,然后更改該屬性。forget()方法將更改滾回舊值。remember()方法將新值復(fù)制到緩沖區(qū)。
現(xiàn)在,為了利用神經(jīng)元的新功能,我們向Network添加了一個(gè)train()方法,如清單7所示。
清單7.Network.train()方法
public void train(List<List<Integer>> data, List<Double> answers){
Double bestEpochLoss = null;
for (int epoch = 0; epoch < 1000; epoch++){
// adapt neuron
Neuron epochNeuron = neurons.get(epoch % 6);
epochNeuron.mutate(this.learnFactor);
List<Double> predictions = new ArrayList<Double>();
for (int i = 0; i < data.size(); i++){
predictions.add(i, this.predict(data.get(i).get(0), data.get(i).get(1)));
}
Double thisEpochLoss = Util.meanSquareLoss(answers, predictions);
if (bestEpochLoss == null){
bestEpochLoss = thisEpochLoss;
epochNeuron.remember();
} else {
if (thisEpochLoss < bestEpochLoss){
bestEpochLoss = thisEpochLoss;
epochNeuron.remember();
} else {
epochNeuron.forget();
}
}
}
train()方法對(duì)數(shù)據(jù)重復(fù)1000次,并在參數(shù)中保留回答列表。這些是同樣大小的訓(xùn)練集;數(shù)據(jù)保存輸入值,答案保存已知的良好答案。然后,該方法遍歷這些答案,并得到一個(gè)值,表明網(wǎng)絡(luò)猜測(cè)的結(jié)果與已知的正確答案相比的正確率。然后,它會(huì)讓一個(gè)隨機(jī)的神經(jīng)元發(fā)生突變,如果新的測(cè)試表明這是一個(gè)更好的預(yù)測(cè),它就會(huì)保持這種變化。
檢查結(jié)果
可以使用均方誤差(MSE)公式來(lái)檢查結(jié)果,這是一種在神經(jīng)網(wǎng)絡(luò)中測(cè)試一組結(jié)果的常用方法??梢栽谇鍐?中看到MSE函數(shù)。
清單8.均方誤差函數(shù)
public static Double meanSquareLoss(List<Double> correctAnswers, List<Double> predictedAnswers){
double sumSquare = 0;
for (int i = 0; i < correctAnswers.size(); i++){
double error = correctAnswers.get(i) - predictedAnswers.get(i);
sumSquare += (error * error);
}
return sumSquare / (correctAnswers.size());
}
微調(diào)系統(tǒng)
現(xiàn)在剩下的就是把一些訓(xùn)練數(shù)據(jù)輸入網(wǎng)絡(luò),并用更多的預(yù)測(cè)來(lái)嘗試。清單9顯示了如何提供訓(xùn)練數(shù)據(jù)。
清單9.訓(xùn)練數(shù)據(jù)
List<List<Integer>> data = new ArrayList<List<Integer>>();
data.add(Arrays.asList(115, 66));
data.add(Arrays.asList(175, 78));
data.add(Arrays.asList(205, 72));
data.add(Arrays.asList(120, 67));
List<Double> answers = Arrays.asList(1.0,0.0,0.0,1.0);
Network network = new Network();
network.train(data, answers);
在清單9中,訓(xùn)練數(shù)據(jù)是一個(gè)二維整數(shù)集列表(可以把它們看作體重和身高),然后是一個(gè)答案列表(1.0表示女性,0.0表示男性)。
如果在訓(xùn)練算法中添加一些日志記錄,運(yùn)行它將得到類似清單10的輸出。
清單10.記錄訓(xùn)練器
// Logging:
if (epoch % 10 == 0) System.out.println(String.format("Epoch: %s | bestEpochLoss: %.15f | thisEpochLoss: %.15f", epoch, bestEpochLoss, thisEpochLoss));
// output:
Epoch: 910 | bestEpochLoss: 0.034404863820424 | thisEpochLoss: 0.034437939546120
Epoch: 920 | bestEpochLoss: 0.033875954196897 | thisEpochLoss: 0.431451026477016
Epoch: 930 | bestEpochLoss: 0.032509260025490 | thisEpochLoss: 0.032509260025490
Epoch: 940 | bestEpochLoss: 0.003092720117159 | thisEpochLoss: 0.003098025397281
Epoch: 950 | bestEpochLoss: 0.002990128276146 | thisEpochLoss: 0.431062364628853
Epoch: 960 | bestEpochLoss: 0.001651762688346 | thisEpochLoss: 0.001651762688346
Epoch: 970 | bestEpochLoss: 0.001637709485751 | thisEpochLoss: 0.001636810460399
Epoch: 980 | bestEpochLoss: 0.001083365453009 | thisEpochLoss: 0.391527869500699
Epoch: 990 | bestEpochLoss: 0.001078338540452 | thisEpochLoss: 0.001078338540452
清單10顯示了損失(誤差偏離正右側(cè))緩慢下降;也就是說(shuō),它越來(lái)越接近做出準(zhǔn)確的預(yù)測(cè)。剩下的就是看看模型對(duì)真實(shí)數(shù)據(jù)的預(yù)測(cè)效果如何,如清單11所示。
清單11.預(yù)測(cè)
System.out.println("");
System.out.println(String.format(" male, 167, 73: %.10f", network.predict(167, 73)));
System.out.println(String.format("female, 105, 67: %.10", network.predict(105, 67)));
System.out.println(String.format("female, 120, 72: %.10f | network1000: %.10f", network.predict(120, 72)));
System.out.println(String.format(" male, 143, 67: %.10f | network1000: %.10f", network.predict(143, 67)));
System.out.println(String.format(" male', 130, 66: %.10f | network: %.10f", network.predict(130, 66)));
在清單11中,將訓(xùn)練好的網(wǎng)絡(luò)輸入一些數(shù)據(jù),輸出預(yù)測(cè)結(jié)果。結(jié)果如清單12所示。
清單12.訓(xùn)練有素的預(yù)測(cè)
male, 167, 73: 0.0279697143
female, 105, 67: 0.9075809407
female, 120, 72: 0.9075808235
male, 143, 67: 0.0305401413
male, 130, 66: network: 0.9009811922
在清單12中,看到網(wǎng)絡(luò)對(duì)大多數(shù)值對(duì)(又名向量)都做得很好。它給女性數(shù)據(jù)集的估計(jì)值約為0.907,非常接近1。兩名男性顯示0.027和0.030接近0。離群的男性數(shù)據(jù)集(130,67)被認(rèn)為可能是女性,但可信度較低,為0.900。
結(jié)論
有多種方法可以調(diào)整這一系統(tǒng)上的參數(shù)。首先,訓(xùn)練運(yùn)行中的epoch數(shù)是一個(gè)主要因素。epoch越多,其模型就越適合數(shù)據(jù)。運(yùn)行更多的epoch可以提高符合訓(xùn)練集的實(shí)時(shí)數(shù)據(jù)的準(zhǔn)確性,但也會(huì)導(dǎo)致過(guò)度訓(xùn)練。也就是說(shuō),這是一個(gè)在邊緣情況下自信地預(yù)測(cè)錯(cuò)誤結(jié)果的模型。
文章標(biāo)題:How to build a neural network in Java,作者:Matthew Tyson