深度學(xué)習(xí)如何自動微分
在深度學(xué)習(xí)中,求導(dǎo)幾乎是所有優(yōu)化算法的關(guān)鍵步驟,但是對于單個值的求導(dǎo)比較簡單,但是如果對于模型訓(xùn)練過程中每次都需要手動求導(dǎo)就很麻煩,因此深度學(xué)習(xí)框架都提供了自動導(dǎo)數(shù)(自動微分)。
1、PyTorch自動微分
對函數(shù) y = 2X^X 求導(dǎo)(其中X為列向量,這里表示兩段列向量做矩陣乘法),其中PyTorch自動微分的代碼如下:
import torch
x = torch.arange(4.0)
print("x: ", x)
x.requires_grad_(True) # 等價于x=torch.arange(4.0,requires_grad=True)
print("x.grad: ", x.grad)
y = 2 * torch.dot(x, x)
print("y: ", y)
y.backward()
print("x.grad: ", x.grad)
1、先給x賦值 tensor([0., 1., 2., 3.]) ;2、將x設(shè)置為自動微分 ;3、賦值y表達(dá)式,計算y的值,結(jié)果:tensor(28., grad_fn=<MulBackward0>);4、y.backward() 調(diào)用反向傳播函數(shù)來自動計算y關(guān)于x每個分量的梯度,并打印這些梯度;
輸出結(jié)果:tensor([ 0., 4., 8., 12.]) 和 y = 4X 的導(dǎo)數(shù)是一樣的。
2、如何自動微分
自動微分開源實(shí)現(xiàn)很多,其中類似 PyTorch 的 API 包括 karpathy 開源的 https://github.com/karpathy/micrograd 和 https://github.com/tinygrad/tinygrad,這里為了簡單借鑒 micrograd,重寫部分代碼實(shí)現(xiàn)自動微分。
2.1 前向傳播
微分需要支持多個基礎(chǔ)運(yùn)算,如+,-,*,/,power等,代碼如下:
class Value:
def __init__(self, data, _children=(), _op=''):
self.data = data
self.grad = 0
self._prev = set(_children)
self._op = _op
def __add__(self, other):
other = other if isinstance(other, Value) else Value(other)
out = Value(self.data + other.data, (self, other), '+')
return out
def __mul__(self, other):
other = other if isinstance(other, Value) else Value(other)
out = Value(self.data * other.data, (self, other), '*')
return out
def __pow__(self, other):
assert isinstance(other, (int, float)), "only supporting int/float powers for now"
out = Value(self.data**other, (self,), f'**{other}')
return out
def relu(self):
out = Value(0 if self.data < 0 else self.data, (self,), 'ReLU')
return out
def __neg__(self): # -self
return self * -1
def __radd__(self, other): # other + self
return self + other
def __sub__(self, other): # self - other
return self + (-other)
def __rsub__(self, other): # other - self
return other + (-self)
def __rmul__(self, other): # other * self
return self * other
def __truediv__(self, other): # self / other
return self * other**-1
def __rtruediv__(self, other): # other / self
return other * self**-1
def __repr__(self):
return f"Value(data={self.data}, grad={self.grad})"
那么表達(dá)式 a * b + c + d**2,按照賦值變量的運(yùn)行:
a = Value(2.0)
b = Value(-3.0)
c = Value(10.0)
d = Value(2.0)
d = a * b + c
z = d + f**2
print("z: ", z)
結(jié)果:Value(data=8.0, grad=0),同時按照前向傳播路徑畫圖如下:
前向傳播
2.2 反向傳播
在反向傳播的過程,本質(zhì)是求網(wǎng)絡(luò)的每個參數(shù)關(guān)于最終損失函數(shù)的梯度,而該梯度可以成是回傳的全局梯度和局部梯度之乘。
其中梯度代表了當(dāng)前層參數(shù)的變化,對最終預(yù)測損失的影響(變化率),而該變化率實(shí)際取決于當(dāng)前層參數(shù)對下一層輸入的影響,以及下一層輸入對最終預(yù)測損失的影響,兩個變化一乘,就是當(dāng)前層參數(shù)對最終預(yù)測損失的影響。
那么反向傳播的代碼實(shí)現(xiàn)就是要將每個變量與表達(dá)式的結(jié)果關(guān)聯(lián),根據(jù)微積分的鏈?zhǔn)椒▌t(https://zh.wikipedia.org/wiki/%E9%93%BE%E5%BC%8F%E6%B3%95%E5%88%99),如果變量 z 依賴于變量 y,而變量 y 又依賴于變量 x(即 y 和 z 是因變量),那么 z 也通過中間變量 y 來依賴于 x,其中 c = 2 * a; d = c + b; 推倒如下:
c = 2 * a
d = c + b
求導(dǎo)數(shù) dd/da 可以根據(jù)微分傳遞性轉(zhuǎn)換 dd/dc * dc/da
那么 dd/dc = 1(加法的導(dǎo)數(shù)是常數(shù)),dc/da = 2(乘法的導(dǎo)數(shù)對應(yīng)是2),所以 dd/da = 2
根據(jù)前向傳播的代碼添加反向傳播:
class Value:
def __init__(self, data, _children=(), _op=''):
self.data = data
self.grad = 0
self._backward = lambda: None
self._prev = set(_children)
self._op = _op
def __add__(self, other):
other = other if isinstance(other, Value) else Value(other)
out = Value(self.data + other.data, (self, other), '+')
def _backward():
self.grad += out.grad
other.grad += out.grad
out._backward = _backward
return out
def __mul__(self, other):
other = other if isinstance(other, Value) else Value(other)
out = Value(self.data * other.data, (self, other), '*')
def _backward():
self.grad += other.data * out.grad
other.grad += self.data * out.grad
out._backward = _backward
return out
def __pow__(self, other):
assert isinstance(other, (int, float)), "only supporting int/float powers for now"
out = Value(self.data**other, (self,), f'**{other}')
def _backward():
self.grad += (other * self.data**(other-1)) * out.grad
out._backward = _backward
return out
def relu(self):
out = Value(0 if self.data < 0 else self.data, (self,), 'ReLU')
def _backward():
self.grad += (out.data > 0) * out.grad
out._backward = _backward
return out
def backward(self):
topo = []
visited = set()
def build_topo(v):
if v not in visited:
visited.add(v)
for child in v._prev:
build_topo(child)
topo.append(v)
build_topo(self)
self.grad = 1
for v in reversed(topo):
v._backward()
其中 backward 遍歷所有孩子節(jié)點(diǎn),然后 reversed 計算每個_backward,那么表達(dá)式 a * b + c + d**2,按照賦值變量的運(yùn)行:
a = Value(2.0)
b = Value(-3.0)
c = Value(10.0)
f = Value(2.0)
d = a * b + c
z = d + f**2
z.backward()
print("a: ", a.grad)
結(jié)果:a:-3.0,同時按照反向傳播路徑畫圖如下:
反向傳播
3、總結(jié)
以上就是構(gòu)造自動微分的代碼,功能比較簡單,主要是理解梯度的計算方法,并計算各個計算變量在圖節(jié)點(diǎn)上的關(guān)系。