What is NeRF? Is NeRF-based 3D reconstruction voxel-based?
This article is reposted with permission from the 自动驾驶之心 (Autonomous Driving Heart) WeChat account; please contact the original source for reprint permission.
1 Introduction
Neural Radiance Fields (NeRF) are a fairly new paradigm in deep learning and computer vision. The technique was introduced in the ECCV 2020 paper "NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis" (which received a Best Paper Honorable Mention) and has since exploded in popularity, gathering nearly 800 citations to date [1]. The approach marks a sea change in how machine learning traditionally handles 3D data.
The neural radiance field scene representation and the differentiable rendering procedure:
Images are synthesized by sampling 5D coordinates (position and viewing direction) along camera rays; these coordinates are fed into an MLP to produce a color and a volume density; and volume rendering techniques composite these values into an image. Because the rendering function is differentiable, the scene representation can be optimized by minimizing the residual between the synthesized images and the ground-truth observed images.
2 What is a NeRF?
A NeRF is a generative model, conditioned on a collection of images and their accurate poses, that renders new views of the 3D scene captured by those images, a process often referred to as "novel view synthesis." Beyond that, it explicitly defines the 3D shape and appearance of the scene as a continuous function, from which a 3D mesh can be extracted via marching cubes. Although NeRFs learn directly from image data, they use neither convolutional layers nor transformer layers.
Over the years, 3D data has been represented in machine learning applications in many ways, from voxel grids to point clouds to signed distance functions. Their biggest shared drawback is the need to assume a 3D model up front, either generated with tools such as photogrammetry or lidar, or built by hand. Many classes of objects, however, such as highly reflective objects, "mesh-like" objects, or transparent objects, cannot be scanned at scale. 3D reconstruction methods also tend to suffer from reconstruction error, which can show up as stepping artifacts or drift that degrade model accuracy.
By contrast, NeRF builds on the concept of the light field, or radiance field, over rays. A light field is a function describing how light transport occurs throughout a 3D volume: it describes the direction light rays travel at every coordinate x = (x, y, z) in space and in every direction d, expressed either as the angles (θ, φ) or as a unit vector. Together they form a 5D feature space that describes light transport in the 3D scene. Inspired by this representation, NeRF attempts to approximate a function that maps from this space to a 4D space consisting of a color c = (R, G, B) and a density σ, which can be thought of as the likelihood that the ray terminates at this 5D coordinate (for example, by occlusion). The standard NeRF is therefore a function of the form F: (x, d) → (c, σ).
The original NeRF paper parameterizes this function with a multilayer perceptron trained on a set of images with known poses. This is one method in a family of techniques called generalized scene reconstruction, which aim to describe a 3D scene directly from a collection of images. The approach has some very nice properties:
- It learns directly from the data
- The continuous representation of the scene allows very thin and complex structures, such as leaves or wire meshes
- It implicitly captures physical properties such as specularity and roughness
- It implicitly captures the lighting in the scene
A wave of follow-up papers has since appeared, adding, for example, few-shot and single-shot learning [2, 3], support for dynamic scenes [4, 5], generalization of light fields to feature fields [6], learning from uncalibrated image collections from the web [7], incorporation of lidar data [8], large-scale scene representation [9], learning without neural networks [10], and more.
3 NeRF Architecture
At a high level, given a trained NeRF model and a camera with a known pose and image dimensions, we construct the scene with the following procedure (a minimal end-to-end sketch follows the list):
- For each pixel, shoot a ray from the camera's optical center through the scene and collect a set of samples at positions (x, d)
- Feed each sample's point and viewing direction (x, d) into the model to produce the output (c, σ) values (rgbσ)
- Build the image from these values using classical volume rendering
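To make that flow concrete, here is a minimal sketch (illustrative only, not the reference implementation) of how the components defined in the rest of this article fit together for one batch of rays. The near/far bounds, frequency counts, and sample count are assumptions, and torch plus the components defined below (PositionalEncoder, NeRF, sample_stratified, raw2outputs) are assumed to be in scope:
def render_rays_sketch(rays_o, rays_d, near=2.0, far=6.0, n_samples=64):
    # 1. Sample points along each ray: pts [n_rays, n_samples, 3]
    pts, z_vals = sample_stratified(rays_o, rays_d, near, far, n_samples)
    # 2. Encode positions and (normalized) view directions
    # (in practice the encoders and the model are created once, outside this function)
    pos_encoder = PositionalEncoder(3, 10)
    dir_encoder = PositionalEncoder(3, 4)
    viewdirs = rays_d / torch.norm(rays_d, dim=-1, keepdim=True)
    viewdirs = viewdirs[:, None, :].expand(pts.shape).reshape(-1, 3)
    x = pos_encoder(pts.reshape(-1, 3))        # [n_rays * n_samples, 63]
    d = dir_encoder(viewdirs)                  # [n_rays * n_samples, 27]
    # 3. Query the radiance field: one (R, G, B, sigma) per sample point
    model = NeRF(pos_encoder.d_output, d_viewdirs=dir_encoder.d_output)
    raw = model(x, viewdirs=d).reshape(list(pts.shape[:2]) + [4])
    # 4. Composite the samples into pixel colors via volume rendering
    rgb_map, depth_map, acc_map, weights = raw2outputs(raw, z_vals, rays_d)
    return rgb_map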
The radiance field function is just one of several components which, once combined, create the visual effects seen in NeRF demo videos. Altogether, this article covers the following pieces:
- Positional encoding
- The radiance field function approximator (an MLP)
- A differentiable volume renderer
- Stratified sampling
- Hierarchical volume sampling
For maximum clarity, each component is presented with the most concise code that captures its key elements, drawing on bmild's original implementation and the PyTorch implementations by yenchenlin and krrish94.
3.1 Positional Encoder
Much like the transformer model introduced in 2017 [11], NeRF benefits from a positional encoder applied to its inputs. It uses high-frequency functions to map the continuous inputs into a higher-dimensional space, which helps the model learn high-frequency variation in the data and yields sharper results. This approach circumvents the bias of neural networks toward low-frequency functions, allowing NeRF to represent crisper detail; the authors point to a paper from ICML 2019 [12].
If you are familiar with transformer positional encodings, NeRF's version is fairly standard, with the same alternating sine and cosine expressions. The positional encoder implementation:
import torch
from torch import nn
from typing import Optional, Tuple


class PositionalEncoder(nn.Module):
    # Sine-cosine positional encoder for input points.
    def __init__(self,
                 d_input: int,
                 n_freqs: int,
                 log_space: bool = False):
        super().__init__()
        self.d_input = d_input
        self.n_freqs = n_freqs      # number of frequency bands, not a sampling rate along the ray
        self.log_space = log_space
        self.d_output = d_input * (1 + 2 * self.n_freqs)
        self.embed_fns = [lambda x: x]  # identity term: the raw input is kept alongside its encodings

        # Define frequencies in either linear or log scale
        if self.log_space:
            freq_bands = 2. ** torch.linspace(0., self.n_freqs - 1, self.n_freqs)
        else:
            freq_bands = torch.linspace(2. ** 0., 2. ** (self.n_freqs - 1), self.n_freqs)

        # Alternate sin and cos
        for freq in freq_bands:
            self.embed_fns.append(lambda x, freq=freq: torch.sin(x * freq))
            self.embed_fns.append(lambda x, freq=freq: torch.cos(x * freq))

    def forward(self, x) -> torch.Tensor:
        # Apply the positional encoding to the input.
        return torch.concat([fn(x) for fn in self.embed_fns], dim=-1)
A note on what gets encoded: this encoder is applied to the 3D sample points taken along each ray (and, with a separate instance, to the viewing directions), not to camera positions. n_freqs is the number of frequency bands in the encoding, not a sampling rate along the ray. Without encoding these sample positions, the network would struggle to distinguish nearby points, and their RGB and density values could not be learned sharply.
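As a quick sanity check on the output dimensions, here is a usage sketch; the frequency counts of 10 for positions and 4 for view directions follow the original paper, and the batch size is arbitrary:
# Usage sketch: encode a batch of 3D sample points and unit view directions.
pos_encoder = PositionalEncoder(d_input=3, n_freqs=10)
dir_encoder = PositionalEncoder(d_input=3, n_freqs=4)
pts = torch.rand(2048, 3)        # 3D sample points along rays
dirs = torch.rand(2048, 3)       # (normalized) view directions
print(pos_encoder(pts).shape)    # torch.Size([2048, 63]) = 3 * (1 + 2 * 10)
print(dir_encoder(dirs).shape)   # torch.Size([2048, 27]) = 3 * (1 + 2 * 4)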
3.2 Radiance Field Function
In the original paper, the radiance field function is represented by the NeRF model, a fairly typical multilayer perceptron that takes encoded 3D points and viewing directions as input and returns RGBA values as output. Although this article uses a neural network, any function approximator can be used here. For example, Yu et al.'s follow-up paper Plenoxels uses spherical harmonics instead and achieves orders-of-magnitude faster training with competitive results [10].
The NeRF model is 8 layers deep with a feature dimension of 256 for most layers. A skip connection, which re-concatenates the encoded input, is placed at layer 4. After these layers the RGB and σ values are produced: the features are processed by a further linear layer, concatenated with the encoded viewing direction, passed through one more linear layer, and finally recombined with σ at the output. The PyTorch module implementation of the NeRF model:
class NeRF(nn.Module):
    # Neural radiance field module (the MLP itself).
    def __init__(self,
                 d_input: int = 3,
                 n_layers: int = 8,
                 d_filter: int = 256,
                 skip: Tuple[int] = (4,),
                 d_viewdirs: Optional[int] = None):
        super().__init__()
        self.d_input = d_input        # dimension of the (encoded) position input; 3 if no encoding is used
        self.skip = skip              # layer indices after which the encoded input is concatenated back in
        self.act = nn.functional.relu
        self.d_viewdirs = d_viewdirs  # dimension of the encoded view direction (None disables the view branch)

        # Create model layers. Layers whose index appears in `skip` take
        # d_filter + d_input input features, because the encoded input is
        # concatenated to the activations just before them; this skip connection
        # helps the deep MLP retain the high-frequency input encoding.
        self.layers = nn.ModuleList(
            [nn.Linear(self.d_input, d_filter)] +
            [nn.Linear(d_filter + self.d_input, d_filter) if i in skip
             else nn.Linear(d_filter, d_filter) for i in range(n_layers - 1)]
        )

        # Bottleneck layers
        if self.d_viewdirs is not None:
            # If using viewdirs, split alpha (density) and RGB
            self.alpha_out = nn.Linear(d_filter, 1)
            self.rgb_filters = nn.Linear(d_filter, d_filter)
            self.branch = nn.Linear(d_filter + self.d_viewdirs, d_filter // 2)
            self.output = nn.Linear(d_filter // 2, 3)   # narrower view-dependent color head, as in the original architecture
        else:
            # If no viewdirs, use a simpler 4D output head (RGB + sigma)
            self.output = nn.Linear(d_filter, 4)

    def forward(self,
                x: torch.Tensor,                            # encoded sample positions [n_pts, d_input]
                viewdirs: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Forward pass with an optional view direction.
        if self.d_viewdirs is None and viewdirs is not None:
            raise ValueError('Cannot input x_direction')

        # Apply the forward pass up to the bottleneck
        x_input = x   # keep the encoded input around for the skip connection
        for i, layer in enumerate(self.layers):
            x = self.act(layer(x))
            if i in self.skip:
                x = torch.cat([x, x_input], dim=-1)

        # Apply the bottleneck: split off sigma, then condition color on the view direction
        if self.d_viewdirs is not None:
            alpha = self.alpha_out(x)                   # density does not depend on the view direction
            # Pass through the bottleneck to get RGB
            x = self.rgb_filters(x)
            x = torch.concat([x, viewdirs], dim=-1)
            x = self.act(self.branch(x))                # [n_pts, d_filter // 2]
            x = self.output(x)                          # [n_pts, 3]
            # Concatenate alpha to the output
            x = torch.concat([x, alpha], dim=-1)        # [n_pts, 4] = (R, G, B, sigma)
        else:
            # Simple output
            x = self.output(x)
        return x
What are the inputs and outputs of this NeRF class? The __init__ arguments configure the network: its input dimension, depth, and layer widths. The forward pass takes a batch of encoded sample positions and, optionally, their encoded viewing directions, and returns one RGBA value (RGB plus σ) per sample point, so the output is a whole set of per-point values rather than a single pixel color. The positional encoding and the spacing of samples along each ray are handled outside this class, by the encoder and the sampling functions; the per-point RGBA values are only combined into the pixel color seen by the camera later, in the volume renderer.
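A brief usage sketch, assuming the 63-dimensional encoded positions and 27-dimensional encoded view directions from the earlier encoder example:
# Usage sketch: one RGBA value per encoded sample point.
model = NeRF(d_input=63, d_viewdirs=27)
x = torch.rand(2048, 63)       # encoded sample positions
d = torch.rand(2048, 27)       # encoded view directions
out = model(x, viewdirs=d)
print(out.shape)               # torch.Size([2048, 4]) -> (R, G, B, sigma) per point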
3.3 Differentiable Volume Renderer
The RGBA output points live in 3D space, so to composite them into an image we apply the volume integral described by equations 1-3 in Section 4 of the paper. Essentially, we take a weighted sum of all samples along each pixel's ray to obtain the estimated color value for that pixel. Each RGB sample is weighted by its alpha value: a higher alpha indicates a higher likelihood that the sampled region is opaque, and therefore that points farther along the ray are occluded. The cumulative-product operation ensures those farther points are suppressed.
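For reference, the discrete estimator from the paper that the code below implements, where δ_i is the distance between adjacent samples along the ray:
\[
\hat{C}(\mathbf{r}) = \sum_{i=1}^{N} T_i \,\bigl(1 - e^{-\sigma_i \delta_i}\bigr)\,\mathbf{c}_i,
\qquad
T_i = \exp\!\Bigl(-\sum_{j=1}^{i-1} \sigma_j \delta_j\Bigr),
\qquad
\delta_i = t_{i+1} - t_i .
\]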
Volume rendering of the raw NeRF model outputs:
def raw2outputs(raw: torch.Tensor,
                z_vals: torch.Tensor,
                rays_d: torch.Tensor,
                raw_noise_std: float = 0.0,
                white_bkgd: bool = False
                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    # Convert the raw NeRF outputs into RGB and other maps.
    # Distances between consecutive `z_vals` samples along each ray. [n_rays, n_samples]
    dists = z_vals[..., 1:] - z_vals[..., :-1]
    dists = torch.cat([dists, 1e10 * torch.ones_like(dists[..., :1])], dim=-1)

    # Multiply each distance by the norm of its ray direction to convert to
    # real-world distance (accounts for non-unit direction vectors).
    dists = dists * torch.norm(rays_d[..., None, :], dim=-1)

    # Add noise to the model's density predictions to regularize the network
    # during training (helps prevent "floater" artifacts).
    noise = 0.
    if raw_noise_std > 0.:
        noise = torch.randn(raw[..., 3].shape) * raw_noise_std

    # Predict density of each sample along each ray. Higher values imply
    # higher likelihood of being absorbed at this point. [n_rays, n_samples]
    alpha = 1.0 - torch.exp(-nn.functional.relu(raw[..., 3] + noise) * dists)

    # Compute weight for the RGB of each sample along each ray. [n_rays, n_samples]
    # The higher the alpha, the lower subsequent weights are driven.
    weights = alpha * cumprod_exclusive(1. - alpha + 1e-10)

    # Compute the weighted RGB map.
    rgb = torch.sigmoid(raw[..., :3])  # [n_rays, n_samples, 3]
    rgb_map = torch.sum(weights[..., None] * rgb, dim=-2)  # [n_rays, 3]

    # The estimated depth map is the expected distance along the ray.
    depth_map = torch.sum(weights * z_vals, dim=-1)

    # The disparity map is inverse depth.
    disp_map = 1. / torch.max(1e-10 * torch.ones_like(depth_map),
                              depth_map / torch.sum(weights, -1))

    # Sum of weights along each ray. In [0, 1] up to numerical error.
    acc_map = torch.sum(weights, dim=-1)

    # To composite onto a white background, use the accumulated alpha map.
    if white_bkgd:
        rgb_map = rgb_map + (1. - acc_map[..., None])

    return rgb_map, depth_map, acc_map, weights


def cumprod_exclusive(tensor: torch.Tensor) -> torch.Tensor:
    # (Courtesy of https://github.com/krrish94/nerf-pytorch)
    # Compute the regular cumulative product first.
    cumprod = torch.cumprod(tensor, -1)
    # "Roll" the elements along the last dimension by 1 element.
    cumprod = torch.roll(cumprod, 1, -1)
    # Replace the first element by 1, matching tf.cumprod(..., exclusive=True).
    cumprod[..., 0] = 1.
    return cumprod

What does this function do? It takes the raw network outputs raw [n_rays, n_samples, 4], the sample depths z_vals [n_rays, n_samples], and the ray directions rays_d [n_rays, 3], and composites the per-sample RGBA values along each ray into per-ray outputs: an RGB map, a depth map, an accumulated-opacity map, and the per-sample weights (which are reused later for hierarchical sampling).
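A small shape check with random tensors (values are purely illustrative):
# Shape sketch for raw2outputs with random inputs.
n_rays, n_samples = 1024, 64
raw = torch.rand(n_rays, n_samples, 4)
z_vals = torch.linspace(2., 6., n_samples).expand(n_rays, n_samples)
rays_d = torch.rand(n_rays, 3)
rgb_map, depth_map, acc_map, weights = raw2outputs(raw, z_vals, rays_d)
print(rgb_map.shape, weights.shape)   # torch.Size([1024, 3]) torch.Size([1024, 64])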
3.4 Stratified Sampling
The RGB value a camera ultimately records at a pixel is an accumulation of the light samples along the ray passing through that pixel. The classical volume rendering approach is to accumulate points along the ray and integrate over them, estimating at each point the probability that the ray travels that far without hitting any particle. Each pixel therefore needs points sampled along the ray that passes through it. To best approximate the integral, the stratified sampling approach partitions the space evenly into N bins and draws one sample uniformly at random from within each bin. Rather than simply drawing samples at fixed intervals, stratified sampling lets the model sample in continuous space, conditioning the network to learn over a continuous domain.
The stratified sampling implementation in PyTorch:
def sample_stratified(rays_o: torch.Tensor,
                      rays_d: torch.Tensor,
                      near: float,
                      far: float,
                      n_samples: int,
                      perturb: Optional[bool] = True,
                      inverse_depth: bool = False
                      ) -> Tuple[torch.Tensor, torch.Tensor]:
    # Sample along each ray from regularly-spaced bins.
    # Grab samples for space integration along the ray.
    t_vals = torch.linspace(0., 1., n_samples, device=rays_o.device)
    if not inverse_depth:
        # Sample linearly between `near` and `far`
        z_vals = near * (1. - t_vals) + far * (t_vals)
    else:
        # Sample linearly in inverse depth (disparity)
        z_vals = 1. / (1. / near * (1. - t_vals) + 1. / far * (t_vals))

    # Draw uniform samples from bins along the ray
    if perturb:
        mids = .5 * (z_vals[1:] + z_vals[:-1])
        upper = torch.concat([mids, z_vals[-1:]], dim=-1)
        lower = torch.concat([z_vals[:1], mids], dim=-1)
        t_rand = torch.rand([n_samples], device=z_vals.device)
        z_vals = lower + (upper - lower) * t_rand
    z_vals = z_vals.expand(list(rays_o.shape[:-1]) + [n_samples])

    # Apply scale from `rays_d` and offset from `rays_o` to the samples
    # pts: (width, height, n_samples, 3)
    pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None]
    return pts, z_vals
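A usage sketch for a flat batch of rays; the near and far bounds of 2.0 and 6.0 are assumptions typical of the synthetic scenes:
# Usage sketch: 64 stratified samples per ray.
rays_o = torch.zeros(1024, 3)      # ray origins
rays_d = torch.rand(1024, 3)       # ray directions
pts, z_vals = sample_stratified(rays_o, rays_d, near=2.0, far=6.0, n_samples=64)
print(pts.shape, z_vals.shape)     # torch.Size([1024, 64, 3]) torch.Size([1024, 64])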
3.5 Hierarchical Volume Sampling
The radiance field is represented by two multilayer perceptrons: one operating at a coarse level that encodes the broad structural properties of the scene, and a second that refines details at a fine level, enabling thin and complex structures such as meshes and branches. The samples they receive also differ: the coarse model processes wide, mostly regularly spaced samples along the whole ray, while the fine model hones in on regions with strong priors to pick up the salient information.
This "honing in" is carried out by the hierarchical volume sampling procedure. 3D space is in fact very sparse, and occlusions abound, so most points contribute little to the rendered image. It therefore pays to oversample regions with a high likelihood of contributing to the integral. The learned, normalized weights of the first set of samples are used to build a PDF along the ray, and inverse transform sampling is then applied to that PDF to draw a second set of samples. This second set is combined with the first and fed to the fine network to produce the final output.
The hierarchical sampling implementation in PyTorch:
def sample_hierarchical(rays_o: torch.Tensor,
                        rays_d: torch.Tensor,
                        z_vals: torch.Tensor,
                        weights: torch.Tensor,
                        n_samples: int,
                        perturb: bool = False
                        ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # Apply hierarchical sampling to the rays.
    # Draw samples from the PDF using z_vals as bins and weights as probabilities.
    z_vals_mid = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
    new_z_samples = sample_pdf(z_vals_mid, weights[..., 1:-1], n_samples, perturb=perturb)
    new_z_samples = new_z_samples.detach()

    # Resample points along each ray based on the PDF.
    z_vals_combined, _ = torch.sort(torch.cat([z_vals, new_z_samples], dim=-1), dim=-1)
    # [N_rays, N_samples + n_samples, 3]
    pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals_combined[..., :, None]
    return pts, z_vals_combined, new_z_samples


def sample_pdf(bins: torch.Tensor,
               weights: torch.Tensor,
               n_samples: int,
               perturb: bool = False) -> torch.Tensor:
    # Apply inverse transform sampling to a weighted set of points.
    # Normalize the weights to get a PDF. [n_rays, weights.shape[-1]]
    pdf = (weights + 1e-5) / torch.sum(weights + 1e-5, -1, keepdims=True)

    # Convert the PDF to a CDF.
    cdf = torch.cumsum(pdf, dim=-1)  # [n_rays, weights.shape[-1]]
    cdf = torch.concat([torch.zeros_like(cdf[..., :1]), cdf], dim=-1)  # [n_rays, weights.shape[-1] + 1]

    # Take sample positions to grab from the CDF. Linear when perturb == 0.
    if not perturb:
        u = torch.linspace(0., 1., n_samples, device=cdf.device)
        u = u.expand(list(cdf.shape[:-1]) + [n_samples])  # [n_rays, n_samples]
    else:
        u = torch.rand(list(cdf.shape[:-1]) + [n_samples], device=cdf.device)  # [n_rays, n_samples]

    # Find the indices along the CDF where the values in u would be placed.
    u = u.contiguous()  # Returns a contiguous tensor with the same values.
    inds = torch.searchsorted(cdf, u, right=True)  # [n_rays, n_samples]

    # Clamp indices that are out of bounds.
    below = torch.clamp(inds - 1, min=0)
    above = torch.clamp(inds, max=cdf.shape[-1] - 1)
    inds_g = torch.stack([below, above], dim=-1)  # [n_rays, n_samples, 2]

    # Sample from the CDF and the corresponding bin centers.
    matched_shape = list(inds_g.shape[:-1]) + [cdf.shape[-1]]
    cdf_g = torch.gather(cdf.unsqueeze(-2).expand(matched_shape), dim=-1, index=inds_g)
    bins_g = torch.gather(bins.unsqueeze(-2).expand(matched_shape), dim=-1, index=inds_g)

    # Convert the samples to positions along the ray.
    denom = (cdf_g[..., 1] - cdf_g[..., 0])
    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
    t = (u - cdf_g[..., 0]) / denom
    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
    return samples  # [n_rays, n_samples]
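A usage sketch, reusing tensors with the shapes from the earlier examples (rays_o and rays_d of shape [1024, 3], z_vals and weights of shape [1024, 64]):
# Usage sketch: draw 64 additional samples where the coarse weights are large.
pts_fine, z_vals_combined, new_z_samples = sample_hierarchical(
    rays_o, rays_d, z_vals, weights, n_samples=64, perturb=True)
print(pts_fine.shape)            # torch.Size([1024, 128, 3])
print(z_vals_combined.shape)     # torch.Size([1024, 128])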
4 Training
The architecture recommended in the paper, 8 layers per network with 256-dimensional layers, consumes a lot of memory during training. The way to mitigate this is to split the forward pass into smaller chunks and accumulate gradients over them; note the difference from minibatching: gradients are accumulated over a single minibatch of sampled rays, which may itself have been collected into chunks (a minimal chunking helper is sketched below). Without a GPU comparable in performance to the NVIDIA V100 used in the paper, you will likely have to shrink the chunk size accordingly to avoid OOM errors. The Colab notebook uses a smaller architecture and a more modest chunk size.
Personally, I found training NeRF somewhat finicky due to local minima, even with many of the default values in place. Techniques that helped include center cropping during the early training iterations and early restarts. Feel free to experiment with different hyperparameters and techniques to further improve training convergence.
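Two small helpers support these tricks. Below are minimal sketches, not the reference implementations: a chunking helper in the spirit of the get_chunks utility mentioned above, and the crop_center helper called by the training loop further down; the names, default chunk size, and cropping fraction are assumptions.
# Hypothetical chunking helper: split a large flattened batch along its first
# dimension into pieces small enough for a single forward pass; gradients can
# then be accumulated across the resulting chunks.
def get_chunks(inputs: torch.Tensor, chunksize: int = 2**15):
    return [inputs[i:i + chunksize] for i in range(0, inputs.shape[0], chunksize)]

# Hypothetical center-cropping helper: keep only the central region of the target
# image during early iterations so the model first fits the object rather than the background.
def crop_center(img: torch.Tensor, frac: float = 0.5) -> torch.Tensor:
    h_offset = round(img.shape[0] * (1 - frac) / 2)
    w_offset = round(img.shape[1] * (1 - frac) / 2)
    return img[h_offset:-h_offset, w_offset:-w_offset]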
Initialization
def init_models():
    # Initialize the models, encoders, and optimizer for NeRF training.
    # (Hyperparameters such as d_input, n_freqs, use_viewdirs, lr, and device are
    # assumed to be defined at module level, as in the accompanying notebook.)
    encoder = PositionalEncoder(d_input, n_freqs, log_space=log_space)
    encode = lambda x: encoder(x)

    # View direction encoder
    if use_viewdirs:
        encoder_viewdirs = PositionalEncoder(d_input, n_freqs_views, log_space=log_space)
        encode_viewdirs = lambda x: encoder_viewdirs(x)
        d_viewdirs = encoder_viewdirs.d_output
    else:
        encode_viewdirs = None
        d_viewdirs = None

    model = NeRF(encoder.d_output,
                 n_layers=n_layers,
                 d_filter=d_filter, skip=skip, d_viewdirs=d_viewdirs)
    model.to(device)
    model_params = list(model.parameters())
    if use_fine_model:
        fine_model = NeRF(encoder.d_output,
                          n_layers=n_layers,
                          d_filter=d_filter, skip=skip, d_viewdirs=d_viewdirs)
        fine_model.to(device)
        model_params = model_params + list(fine_model.parameters())
    else:
        fine_model = None

    optimizer = torch.optim.Adam(model_params, lr=lr)
    warmup_stopper = EarlyStopping(patience=50)
    return model, fine_model, encode, encode_viewdirs, optimizer, warmup_stopper
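The EarlyStopping helper used above is not defined in this article; here is a minimal sketch consistent with how warmup_stopper is called (iteration and PSNR in, a stop flag out), with the details as assumptions:
# Hypothetical early-stopping helper: signal a restart when validation PSNR has
# not improved for `patience` consecutive checks during warmup.
class EarlyStopping:
    def __init__(self, patience: int = 30, margin: float = 1e-4):
        self.best_fitness = 0.0   # best PSNR seen so far
        self.best_iter = 0
        self.margin = margin
        self.patience = patience

    def __call__(self, iteration: int, fitness: float) -> bool:
        if fitness > self.best_fitness + self.margin:
            self.best_iter = iteration
            self.best_fitness = fitness
        return (iteration - self.best_iter) >= self.patience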
Training
import numpy as np
from tqdm import trange


def train():
    # Launch a training session for NeRF.
    # (Data tensors such as images, poses, focal, testimg, testpose, and the
    # training hyperparameters are assumed to be defined elsewhere in the notebook.)
    # Shuffle rays across all images.
    if not one_image_per_step:
        height, width = images.shape[1:3]
        all_rays = torch.stack([torch.stack(get_rays(height, width, focal, p), 0)
                                for p in poses[:n_training]], 0)
        rays_rgb = torch.cat([all_rays, images[:, None]], 1)
        rays_rgb = torch.permute(rays_rgb, [0, 2, 3, 1, 4])
        rays_rgb = rays_rgb.reshape([-1, 3, 3])
        rays_rgb = rays_rgb.type(torch.float32)
        rays_rgb = rays_rgb[torch.randperm(rays_rgb.shape[0])]
        i_batch = 0

    train_psnrs = []
    val_psnrs = []
    iternums = []
    for i in trange(n_iters):
        model.train()

        if one_image_per_step:
            # Randomly pick an image as the target.
            target_img_idx = np.random.randint(images.shape[0])
            target_img = images[target_img_idx].to(device)
            if center_crop and i < center_crop_iters:
                target_img = crop_center(target_img)
            height, width = target_img.shape[:2]
            target_pose = poses[target_img_idx].to(device)
            rays_o, rays_d = get_rays(height, width, focal, target_pose)
            rays_o = rays_o.reshape([-1, 3])
            rays_d = rays_d.reshape([-1, 3])
        else:
            # Random over all images.
            batch = rays_rgb[i_batch:i_batch + batch_size]
            batch = torch.transpose(batch, 0, 1)
            rays_o, rays_d, target_img = batch
            height, width = target_img.shape[:2]
            i_batch += batch_size
            # Shuffle after one epoch
            if i_batch >= rays_rgb.shape[0]:
                rays_rgb = rays_rgb[torch.randperm(rays_rgb.shape[0])]
                i_batch = 0
        target_img = target_img.reshape([-1, 3])

        # Run one iteration of TinyNeRF and get the rendered RGB image.
        outputs = nerf_forward(rays_o, rays_d,
                               near, far, encode, model,
                               kwargs_sample_stratified=kwargs_sample_stratified,
                               n_samples_hierarchical=n_samples_hierarchical,
                               kwargs_sample_hierarchical=kwargs_sample_hierarchical,
                               fine_model=fine_model,
                               viewdirs_encoding_fn=encode_viewdirs,
                               chunksize=chunksize)

        # Backprop!
        rgb_predicted = outputs['rgb_map']
        loss = torch.nn.functional.mse_loss(rgb_predicted, target_img)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        psnr = -10. * torch.log10(loss)
        train_psnrs.append(psnr.item())

        # Evaluate testimg at the given display rate.
        if i % display_rate == 0:
            model.eval()
            height, width = testimg.shape[:2]
            rays_o, rays_d = get_rays(height, width, focal, testpose)
            rays_o = rays_o.reshape([-1, 3])
            rays_d = rays_d.reshape([-1, 3])
            outputs = nerf_forward(rays_o, rays_d,
                                   near, far, encode, model,
                                   kwargs_sample_stratified=kwargs_sample_stratified,
                                   n_samples_hierarchical=n_samples_hierarchical,
                                   kwargs_sample_hierarchical=kwargs_sample_hierarchical,
                                   fine_model=fine_model,
                                   viewdirs_encoding_fn=encode_viewdirs,
                                   chunksize=chunksize)
            rgb_predicted = outputs['rgb_map']
            loss = torch.nn.functional.mse_loss(rgb_predicted, testimg.reshape(-1, 3))
            val_psnr = -10. * torch.log10(loss)
            val_psnrs.append(val_psnr.item())
            iternums.append(i)

        # Check PSNR for issues and stop if any are found.
        if i == warmup_iters - 1:
            if val_psnr < warmup_min_fitness:
                return False, train_psnrs, val_psnrs
        elif i < warmup_iters:
            if warmup_stopper is not None and warmup_stopper(i, psnr):
                return False, train_psnrs, val_psnrs

    return True, train_psnrs, val_psnrs
Run training
# Run training session(s)
for _ in range(n_restarts):
    model, fine_model, encode, encode_viewdirs, optimizer, warmup_stopper = init_models()
    success, train_psnrs, val_psnrs = train()
    if success and val_psnrs[-1] >= warmup_min_fitness:
        print('Training successful!')
        break

print('Done!')
5 Conclusion
Radiance fields mark a dramatic shift in how we handle 3D data. NeRF models, and differentiable rendering more broadly, are rapidly closing the gap between image creation and volumetric scene creation. While the components may look quite complex, the countless other methods inspired by the vanilla NeRF demonstrate that the basic concept, a continuous function approximator plus a differentiable renderer, is a solid foundation for building a wide variety of solutions for an almost unlimited range of situations.
Original article: NeRF From Nothing: A Tutorial with PyTorch | Towards Data Science
Source link: https://mp.weixin.qq.com/s/zxJAIpAmLgsIuTsPqQqOVg