自拍偷在线精品自拍偷,亚洲欧美中文日韩v在线观看不卡

NeRF是什么?基于NeRF的三維重建是基于體素嗎?

人工智能 新聞
NeRF是一種生成模型,以圖像和精確姿勢(shì)為條件,生成給定圖像的3D場(chǎng)景的新視圖,這一過程通常被稱為“新視圖合成”。

本文經(jīng)自動(dòng)駕駛之心公眾號(hào)授權(quán)轉(zhuǎn)載,轉(zhuǎn)載請(qǐng)聯(lián)系出處。

1介紹

神經(jīng)輻射場(chǎng)(NeRF)是深度學(xué)習(xí)和計(jì)算機(jī)視覺領(lǐng)域的一個(gè)相當(dāng)新的范式。ECCV 2020論文《NeRF:將場(chǎng)景表示為視圖合成的神經(jīng)輻射場(chǎng)》(該論文獲得了最佳論文獎(jiǎng))中介紹了這項(xiàng)技術(shù),該技術(shù)自此大受歡迎,迄今已獲得近800次引用[1]。該方法標(biāo)志著機(jī)器學(xué)習(xí)處理3D數(shù)據(jù)的傳統(tǒng)方式發(fā)生了巨大變化。

神經(jīng)輻射場(chǎng)場(chǎng)景表示和可微分渲染過程:

通過沿著相機(jī)射線采樣5D坐標(biāo)(位置和觀看方向)來合成圖像; 將這些位置輸入MLP以產(chǎn)生顏色和體積密度; 并使用體積渲染技術(shù)將這些值合成圖像; 該渲染函數(shù)是可微分的,因此可以通過最小化合成圖像和真實(shí)觀測(cè)圖像之間的殘差來優(yōu)化場(chǎng)景表示。

2 What is a NeRF?

NeRF是一種生成模型,以圖像和精確姿勢(shì)為條件,生成給定圖像的3D場(chǎng)景的新視圖,這一過程通常被稱為“新視圖合成”。不僅如此,它還將場(chǎng)景的3D形狀和外觀明確定義為連續(xù)函數(shù),可以通過marching cubes生成3D網(wǎng)格。盡管它們直接從圖像數(shù)據(jù)中學(xué)習(xí),但它們既不使用convolutional層,也不使用transformer層。

多年來,機(jī)器學(xué)習(xí)應(yīng)用中表示3D數(shù)據(jù)的方法很多,從3D體素到點(diǎn)云,再到符號(hào)距離(signed distance )函數(shù)。他們最大的共同缺點(diǎn)是需要預(yù)先假設(shè)一個(gè)3D模型,要么使用攝影測(cè)量或激光雷達(dá)等工具來生成3D數(shù)據(jù),要么手工制作3D模型。然而,許多類型的物體,如高反射物體、“網(wǎng)格狀”物體或透明物體,都無法按比例掃描。3D重建方法通常也具有重建誤差,這可能導(dǎo)致影響模型精度的階梯效應(yīng)或漂移。

相比之下,NeRF基于射線光場(chǎng)的概念。光場(chǎng)是描述光傳輸如何在整個(gè)3D體積中發(fā)生的函數(shù)。它描述了光線在空間中的每個(gè)x=(x,y,z)坐標(biāo)和每個(gè)方向d上移動(dòng)的方向,描述為θ和ξ角或單位向量。它們共同形成了描述3D場(chǎng)景中的光傳輸?shù)?D特征空間。受此表示的啟發(fā),NeRF試圖近似一個(gè)函數(shù),該函數(shù)從該空間映射到由顏色c=(R,G,B)和濃度(density)σ組成的4D空間,可以將其視為該5D坐標(biāo)空間處的光線終止的可能性(例如通過遮擋)。因此,標(biāo)準(zhǔn)NeRF是形式F:(x,d)->(c,σ)的函數(shù)。

原始的NeRF論文使用多層感知器將該函數(shù)參數(shù)化,該感知器基于一組姿勢(shì)已知的圖像上訓(xùn)練得到。這是一類稱為generalized scene reconstruction的技術(shù)中的一種方法,旨在直接從圖像集合中描述3D場(chǎng)景。這種方法具備一些非常好的特性:

  • 直接從數(shù)據(jù)中學(xué)習(xí)
  • 場(chǎng)景的連續(xù)表示允許非常薄和復(fù)雜的結(jié)構(gòu),例如樹葉或網(wǎng)格
  • 隱含物理特性,如鏡面性和粗糙度
  • 隱式呈現(xiàn)場(chǎng)景中的照明

此后,一系列的改進(jìn)論文隨之涌現(xiàn),例如,少鏡頭和單鏡頭學(xué)習(xí)[2,3]、對(duì)動(dòng)態(tài)場(chǎng)景的支持[4,5]、將光場(chǎng)推廣到特征場(chǎng)[6]、從網(wǎng)絡(luò)上的未校準(zhǔn)圖像集合中學(xué)習(xí)[7]、結(jié)合激光雷達(dá)數(shù)據(jù)[8]、大規(guī)模場(chǎng)景表示[9]、在沒有神經(jīng)網(wǎng)絡(luò)的情況下學(xué)習(xí)[10],諸如此類。

3 NeRF Architecture

總體而言,給定一個(gè)經(jīng)過訓(xùn)練的NeRF模型和一個(gè)具有已知姿勢(shì)和圖像維度的相機(jī),我們通過以下過程構(gòu)建場(chǎng)景:

  • 對(duì)于每個(gè)像素,從相機(jī)光心射出光線穿過場(chǎng)景,以在(x,d)位置收集一組樣本
  • 使用每個(gè)樣本的點(diǎn)和視線方向(x,d)作為輸入,以產(chǎn)生輸出(c,σ)值(rgbσ)
  • 使用經(jīng)典的體積渲染技術(shù)構(gòu)建圖像

光射場(chǎng)(很多文獻(xiàn)翻譯為"輻射場(chǎng)",但譯者認(rèn)為"光射場(chǎng)"更直觀)函數(shù)只是幾個(gè)組件中的一個(gè),一旦組合起來,就可以創(chuàng)建之前看到的視頻中的視覺效果??傮w而言,本文包括以下幾個(gè)部分:

  • 位置編碼(Positional encoding)
  • 光射場(chǎng)函數(shù)近似器(MLP)
  • 可微分體渲染器(Differentiable volume renderer)
  • 分層(Stratified)取樣 層次(Hierarchical)體積采樣

為了最大限度地清晰講述,本文將每個(gè)組件的關(guān)鍵元素以盡可能簡(jiǎn)潔的代碼展示。參考了bmild的原始實(shí)現(xiàn)和yenchenlin和krrish94的PyTorch實(shí)現(xiàn)。

3.1 Positional Encoder

就像2017年推出的transformer模型[11]一樣,NeRF也受益于位置編碼器作為其輸入。它使用高頻函數(shù)將其連續(xù)輸入映射到更高維的空間,以幫助模型學(xué)習(xí)數(shù)據(jù)中的高頻變化,從而產(chǎn)生更清晰的模型。這種方法避開(circumvent)了神經(jīng)網(wǎng)絡(luò)對(duì)低頻函數(shù)偏置(bias),使NeRF能夠表示更清晰的細(xì)節(jié)。作者參考了ICML 2019上的一篇論文[12]。

如果熟悉transformerd的位置編碼,NeRF的相關(guān)實(shí)現(xiàn)是很標(biāo)準(zhǔn)的,它具有相同的交替正弦和余弦表達(dá)式。位置編碼器實(shí)現(xiàn):

# py
class PositionalEncoder(nn.Module):
  # sine-cosine positional encoder for input points.
  def __init__( self,
                d_input: int,
                n_freqs: int,
                log_space: bool = False ):
    super().__init__()
    self.d_input = d_input
    self.n_freqs = n_freqs         # 是不是視線上的采樣頻率?
    self.log_space = log_space
    self.d_output = d_input * (1 + 2 * self.n_freqs)
    self.embed_fns = [lambda x: x] # 冒號(hào)前面的x表示函數(shù)參數(shù),后面的表示匿名函數(shù)運(yùn)算

    # Define frequencies in either linear or log scale
    if self.log_space:
      freq_bands = 2.**torch.linspace(0., self.n_freqs - 1, self.n_freqs)
    else:
      freq_bands = torch.linspace(2.**0., 2.**(self.n_freqs - 1), self.n_freqs)

    # Alternate sin and cos
    for freq in freq_bands:
      self.embed_fns.append(lambda x, freq=freq: torch.sin(x * freq))
      self.embed_fns.append(lambda x, freq=freq: torch.cos(x * freq))
  
  def forward(self, x) -> torch.Tensor:
    # Apply positional encoding to input.
    return torch.concat([fn(x) for fn in self.embed_fns], dim=-1)

思考:這個(gè)位置編碼,是對(duì)輸入點(diǎn)(input points)進(jìn)行編碼,這個(gè)輸入點(diǎn)是視線上的采樣點(diǎn)?還是不同的視角位置點(diǎn)? self.n_freqs是不是視線上的采樣頻率?由此理解,應(yīng)該就是視線上的采樣位置,因?yàn)槿绻粚?duì)視線上的采樣位置進(jìn)行編碼,就無法有效表示這些位置,也就無法對(duì)它們的RGBA進(jìn)行訓(xùn)練。

3.2 Radiance Field Function

在原文中,光射場(chǎng)函數(shù)由NeRF模型表示,NeRF模型是一種典型的多層感知器,以編碼的3D點(diǎn)和視角方向作為輸入,并返回RGBA值作為輸出。雖然本文使用的是神經(jīng)網(wǎng)絡(luò),但這里可以使用任何函數(shù)逼近器(function approximator)。例如,Yu等人的后續(xù)論文Plenoxels使用球面諧波(spherical harmonics)實(shí)現(xiàn)了數(shù)量級(jí)的更快訓(xùn)練,同時(shí)獲得有競(jìng)爭(zhēng)力的結(jié)果[10]。

圖片圖片

NeRF模型有8層深,大多數(shù)層的特征維度為256。剩余的連接被放置在層4處。在這些層之后,產(chǎn)生RGB和σ值。RGB值用線性層進(jìn)一步處理,然后與視線方向連接,然后通過另一個(gè)線性層,最后在輸出處與σ重新組合。NeRF模型的PyTorch模塊實(shí)現(xiàn):

class NeRF(nn.Module):
  # Neural radiance fields module.
  def __init__( self,
                d_input: int = 3,                  
                n_layers: int = 8,
                d_filter: int = 256,
                skip: Tuple[int] = (4,), # (4,)只有一個(gè)元素4的元組       
                d_viewdirs: Optional[int] = None): 

    super().__init__()
    self.d_input = d_input        # 這里是3D XYZ,?
    self.skip = skip              # 是要跳過什么?為啥要跳過?被遮擋?
    self.act = nn.functional.relu
    self.d_viewdirs = d_viewdirs  # d_viewdirs 是2D方向?

    # Create model layers
    # [if_true 就執(zhí)行的指令] if [if_true條件] else [if_false]
    # 是否skip的區(qū)別是,訓(xùn)練輸入維度是否多3維,
    # if i in skip =  if i in (4,),似乎是判斷i是否等于4
    # self.d_input=3 :如果層id=4,網(wǎng)絡(luò)輸入要加3維,這是為什么?第4層有何特殊的?
    self.layers = nn.ModuleList(
      [nn.Linear(self.d_input, d_filter)] +
      [nn.Linear(d_filter + self.d_input, d_filter) if i in skip else \
       nn.Linear(d_filter  , d_filter) for i in range(n_layers - 1)]
    )

    # Bottleneck layers
    if self.d_viewdirs is not None:
      # If using viewdirs, split alpha and RGB
      self.alpha_out = nn.Linear(d_filter, 1)
      self.rgb_filters = nn.Linear(d_filter, d_filter)
      self.branch = nn.Linear(d_filter + self.d_viewdirs, d_filter // 2)
      self.output = nn.Linear(d_filter // 2, 3) # 為啥要取一半?
    else:
      # If no viewdirs, use simpler output
      self.output = nn.Linear(d_filter, 4) # d_filter=256,輸出是4維RGBA
  
  def forward(self,
              x: torch.Tensor, # ?
              viewdirs: Optional[torch.Tensor] = None) -> torch.Tensor:
   
    # Forward pass with optional view direction.
    if self.d_viewdirs is None and viewdirs is not None:
      raise ValueError('Cannot input x_direction')

    # Apply forward pass up to bottleneck
    x_input = x    # 這里的x是幾維?從下面的分離RGB和A看,應(yīng)該是4D
    # 下面通過8層MLP訓(xùn)練RGBA
    for i, layer in enumerate(self.layers):  # 8層,每一層進(jìn)行運(yùn)算
      x = self.act(layer(x)) 
      if i in self.skip:
        x = torch.cat([x, x_input], dim=-1)

    # Apply bottleneck  bottleneck 瓶頸是啥?是不是最費(fèi)算力的模塊?
    if self.d_viewdirs is not None:
      # 從網(wǎng)絡(luò)輸出分離A,RGB還需要經(jīng)過更多訓(xùn)練
      alpha = self.alpha_out(x)  
      
      # Pass through bottleneck to get RGB
      x = self.rgb_filters(x) 
      x = torch.concat([x, viewdirs], dim=-1)
      x = self.act(self.branch(x)) # self.branch shape: (d_filter // 2)
      x = self.output(x)           # self.output shape: (3)

      # Concatenate alphas to output
      x = torch.concat([x, alpha], dim=-1)
    else:
      # Simple output
      x = self.output(x)
    return x

思考:這個(gè)NERF類的輸入輸出是什么?通過這個(gè)類發(fā)生了啥?從__init__函數(shù)參數(shù)看出,主要是對(duì)神經(jīng)網(wǎng)絡(luò)的輸入、層次和維度等進(jìn)行設(shè)置,輸入了5D數(shù)據(jù),也就是視點(diǎn)位置和視線方向,輸出的是RGBA。問題,這個(gè)輸出的RGBA是一個(gè)點(diǎn)的?還是視線上一串的?如果是一串的,沒有看到位置編碼如何確定每個(gè)采樣點(diǎn)的RGBA?

也沒看到采樣間隔之類的說明;如果是一個(gè)點(diǎn),那這個(gè)RGBA是視線上哪個(gè)點(diǎn)的?是不是眼睛看到的視線采樣點(diǎn)集合成后的點(diǎn)RGBA?從NERF類代碼可以看出,主要是根據(jù)視點(diǎn)位置和視線方向進(jìn)行了多層前饋訓(xùn)練,輸入5D的視點(diǎn)位置和視線方向,輸出4D的RGBA。

3.3 可微分體渲染器(Differentiable Volume Renderer)

RGBA輸出點(diǎn)位于3D空間中,因此要將它們合成圖像,需要應(yīng)用論文第4節(jié)中方程1-3中描述的體積積分。本質(zhì)上,沿著每個(gè)像素的視線對(duì)所有樣本進(jìn)行加權(quán)求和,以獲得該像素的估計(jì)顏色值。每個(gè)RGB采樣都按其透明度alpha值進(jìn)行加權(quán):α值越高,表示采樣區(qū)域不透明的可能性越高,因此沿射線更遠(yuǎn)的點(diǎn)更有可能被遮擋。累積乘積運(yùn)算確保了這些進(jìn)一步的點(diǎn)被抑制。

原始NeRF模型輸出的體繪制:

def raw2outputs(raw: torch.Tensor,
                z_vals: torch.Tensor,
                rays_d: torch.Tensor,
                raw_noise_std: float = 0.0,
                white_bkgd: bool = False) 
  -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:

  # 將原始的NeRF輸出轉(zhuǎn)為RGB和其他映射
  # Difference between consecutive elements of `z_vals`. [n_rays, n_samples]
  dists = z_vals[..., 1:] - z_vals[..., :-1]# ?這里減法的意義是啥?
  dists = torch.cat([dists, 1e10 * torch.ones_like(dists[..., :1])], dim=-1)

  # 將每個(gè)距離乘以其對(duì)應(yīng)方向光線的范數(shù),以轉(zhuǎn)換為真實(shí)世界的距離(考慮非單位方向)
  dists = dists * torch.norm(rays_d[..., None, :], dim=-1)

  # 將噪聲添加到模型對(duì)密度的預(yù)測(cè)中,用于在訓(xùn)練期間規(guī)范網(wǎng)絡(luò)(防止漂浮物偽影)
  noise = 0.
  if raw_noise_std > 0.:
    noise = torch.randn(raw[..., 3].shape) * raw_noise_std

  # Predict density of each sample along each ray. Higher values imply
  # higher likelihood of being absorbed at this point. [n_rays, n_samples]
  alpha = 1.0 - torch.exp(-nn.functional.relu(raw[..., 3] + noise) * dists)

  # Compute weight for RGB of each sample along each ray. [n_rays, n_samples]
  # The higher the alpha, the lower subsequent weights are driven.
  weights = alpha * cumprod_exclusive(1. - alpha + 1e-10)

  # Compute weighted RGB map.
  rgb = torch.sigmoid(raw[..., :3])  # [n_rays, n_samples, 3]
  rgb_map = torch.sum(weights[..., None] * rgb, dim=-2)  # [n_rays, 3]

  # Estimated depth map is predicted distance.
  depth_map = torch.sum(weights * z_vals, dim=-1)

  # Disparity map is inverse depth.
  disp_map = 1. / torch.max(1e-10 * torch.ones_like(depth_map),
                            depth_map / torch.sum(weights, -1))

  # Sum of weights along each ray. In [0, 1] up to numerical error.
  acc_map = torch.sum(weights, dim=-1)

  # To composite onto a white background, use the accumulated alpha map.
  if white_bkgd:
    rgb_map = rgb_map + (1. - acc_map[..., None])

  return rgb_map, depth_map, acc_map, weights


def cumprod_exclusive(tensor: torch.Tensor) -> torch.Tensor:
  # (Courtesy of https://github.com/krrish94/nerf-pytorch)
  # Compute regular cumprod first.
  cumprod = torch.cumprod(tensor, -1)
  # "Roll" the elements along dimension 'dim' by 1 element.
  cumprod = torch.roll(cumprod, 1, -1)
  # Replace the first element by "1" as this is what tf.cumprod(..., exclusive=True) does.
  cumprod[..., 0] = 1.
  return cumprod

問題:這里的主要功能是啥?輸入了什么?輸出了什么?

3.4 Stratified Sampling

相機(jī)最終拾取到的RGB值是沿著穿過該像素視線的光樣本的累積,經(jīng)典的體積渲染方法是沿著該視線累積點(diǎn),然后對(duì)點(diǎn)進(jìn)行積分,在每個(gè)點(diǎn)估計(jì)光線在不撞擊任何粒子的情況下射行的概率。因此,每個(gè)像素都需要沿著穿過它的光線對(duì)點(diǎn)進(jìn)行采樣。為了最好地近似積分,他們的分層采樣方法是將空間均勻地劃分為N個(gè)倉(bins),并從每個(gè)倉中均勻地抽取一個(gè)樣本。stratified sampling方法不是簡(jiǎn)單地以相等的間隔繪制樣本,而是允許模型在連續(xù)空間中采樣,從而調(diào)節(jié)網(wǎng)絡(luò)在連續(xù)空間上學(xué)習(xí)。

圖片圖片

分層采樣PyTorch中實(shí)現(xiàn):

def sample_stratified(rays_o: torch.Tensor,
                      rays_d: torch.Tensor,
                      near: float,
                      far: float,
                      n_samples: int,
                      perturb: Optional[bool] = True,
                      inverse_depth: bool = False)
   -> Tuple[torch.Tensor, torch.Tensor]:
  # Sample along ray from regularly-spaced bins.
  # Grab samples for space integration along ray
  t_vals = torch.linspace(0., 1., n_samples, device=rays_o.device)
  if not inverse_depth:
    # Sample linearly between `near` and `far`
    z_vals = near * (1.-t_vals) + far * (t_vals)
  else:
    # Sample linearly in inverse depth (disparity)
    z_vals = 1./(1./near * (1.-t_vals) + 1./far * (t_vals))

  # Draw uniform samples from bins along ray
  if perturb:
    mids = .5 * (z_vals[1:] + z_vals[:-1])
    upper = torch.concat([mids, z_vals[-1:]], dim=-1)
    lower = torch.concat([z_vals[:1], mids], dim=-1)
    t_rand = torch.rand([n_samples], device=z_vals.device)
    z_vals = lower + (upper - lower) * t_rand
  z_vals = z_vals.expand(list(rays_o.shape[:-1]) + [n_samples])

  # Apply scale from `rays_d` and offset from `rays_o` to samples
  # pts: (width, height, n_samples, 3)
  pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None]
  return pts, z_vals

3.5 層次體積采樣(Hierarchical Volume Sampling)

輻射場(chǎng)由兩個(gè)多層感知器表示:一個(gè)是在粗略級(jí)別上操作,對(duì)場(chǎng)景的廣泛結(jié)構(gòu)屬性進(jìn)行編碼;另一個(gè)是在精細(xì)的層面上細(xì)化細(xì)節(jié),從而實(shí)現(xiàn)網(wǎng)格和分支等薄而復(fù)雜的結(jié)構(gòu)。此外,他們接收的樣本是不同的,粗模型在整個(gè)射線中處理寬的、大多是規(guī)則間隔的樣本,而精細(xì)模型在具有強(qiáng)先驗(yàn)的區(qū)域中珩磨(honing in)以獲得顯著信息。

這種“珩磨”過程是通過層次體積采樣流程完成的。3D空間實(shí)際上非常稀疏,存在遮擋,因此大多數(shù)點(diǎn)對(duì)渲染圖像的貢獻(xiàn)不大。因此,對(duì)具有對(duì)積分貢獻(xiàn)可能性高的區(qū)域進(jìn)行過采樣(oversample)更有好處。他們將學(xué)習(xí)到的歸一化權(quán)重應(yīng)用于第一組樣本,以在光線上創(chuàng)建PDF,然后再將inverse transform sampling應(yīng)用于該P(yáng)DF以收集第二組樣本。該集合與第一集合相結(jié)合,并被饋送到精細(xì)網(wǎng)絡(luò)以產(chǎn)生最終輸出。

分層采樣PyTorch實(shí)現(xiàn):

def sample_hierarchical(rays_o: torch.Tensor,
                        rays_d: torch.Tensor,
                        z_vals: torch.Tensor,
                        weights: torch.Tensor,
                        n_samples: int,
                        perturb: bool = False) 
  -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  # Apply hierarchical sampling to the rays.
  # Draw samples from PDF using z_vals as bins and weights as probabilities.
  z_vals_mid = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
  new_z_samples = sample_pdf(z_vals_mid, weights[..., 1:-1], n_samples, perturb=perturb)
  new_z_samples = new_z_samples.detach()

  # Resample points from ray based on PDF.
  z_vals_combined, _ = torch.sort(torch.cat([z_vals, new_z_samples], dim=-1), dim=-1)
  # [N_rays, N_samples + n_samples, 3]
  pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals_combined[..., :, None]  
  return pts, z_vals_combined, new_z_samples

def sample_pdf(bins: torch.Tensor,
               weights: torch.Tensor,
               n_samples: int,
               perturb: bool = False) -> torch.Tensor:
  # Apply inverse transform sampling to a weighted set of points.
  # Normalize weights to get PDF.
  # [n_rays, weights.shape[-1]]
  pdf = (weights + 1e-5) / torch.sum(weights + 1e-5, -1, keepdims=True) 

  # Convert PDF to CDF.
  cdf = torch.cumsum(pdf, dim=-1) # [n_rays, weights.shape[-1]]
  # [n_rays, weights.shape[-1] + 1]
  cdf = torch.concat([torch.zeros_like(cdf[..., :1]), cdf], dim=-1) 

  # Take sample positions to grab from CDF. Linear when perturb == 0.
  if not perturb:
    u = torch.linspace(0., 1., n_samples, device=cdf.device)
    u = u.expand(list(cdf.shape[:-1]) + [n_samples]) # [n_rays, n_samples]
  else:
    # [n_rays, n_samples]
    u = torch.rand(list(cdf.shape[:-1]) + [n_samples], device=cdf.device) 

  # Find indices along CDF where values in u would be placed.
  u = u.contiguous() # Returns contiguous tensor with same values.
  inds = torch.searchsorted(cdf, u, right=True) # [n_rays, n_samples]

  # Clamp indices that are out of bounds.
  below = torch.clamp(inds - 1, min=0)
  above = torch.clamp(inds, max=cdf.shape[-1] - 1)
  inds_g = torch.stack([below, above], dim=-1) # [n_rays, n_samples, 2]

  # Sample from cdf and the corresponding bin centers.
  matched_shape = list(inds_g.shape[:-1]) + [cdf.shape[-1]]
  cdf_g = torch.gather(cdf.unsqueeze(-2).expand(matched_shape), dim=-1,index=inds_g)
  bins_g = torch.gather(bins.unsqueeze(-2).expand(matched_shape), dim=-1, index=inds_g)

  # Convert samples to ray length.
  denom = (cdf_g[..., 1] - cdf_g[..., 0])
  denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
  t = (u - cdf_g[..., 0]) / denom
  samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])

  return samples # [n_rays, n_samples]

4 Training

論文中訓(xùn)練NeRF推薦的每網(wǎng)絡(luò)8層、每層256維的架構(gòu)在訓(xùn)練過程中會(huì)消耗大量?jī)?nèi)存。緩解這種情況的方法是將前傳(forward pass)分成更小的部分,然后在這些部分上積累梯度。注意與minibatching的區(qū)別:梯度是在采樣光線的單個(gè)小批次上累積的,這些光線可能已經(jīng)被收集成塊。如果沒有論文中使用的NVIDIA V100類似性能的GPU,可能必須相應(yīng)地調(diào)整塊大小以避免OOM錯(cuò)誤。Colab筆記本采用了更小的架構(gòu)和更適中的分塊尺寸。

我個(gè)人發(fā)現(xiàn),由于局部極小值,即使選擇了許多默認(rèn)值,NeRF的訓(xùn)練也有些棘手。一些有幫助的技術(shù)包括早期訓(xùn)練迭代和早期重新啟動(dòng)期間的中心裁剪(center cropping)。隨意嘗試不同的超參數(shù)和技術(shù),以進(jìn)一步提高訓(xùn)練收斂性。

初始化

def init_models():
  # Initialize models, encoders, and optimizer for NeRF training.
  encoder = PositionalEncoder(d_input, n_freqs, log_space=log_space)
  encode = lambda x: encoder(x)
  # View direction encoders
  if use_viewdirs:
    encoder_viewdirs = PositionalEncoder(d_input, n_freqs_views,log_space=log_space)
    encode_viewdirs  = lambda x: encoder_viewdirs(x)
    d_viewdirs       = encoder_viewdirs.d_output
  else:
    encode_viewdirs = None
    d_viewdirs = None

  model = NeRF(encoder.d_output, 
               n_layers=n_layers, 
               d_filter=d_filter, skip=skip,d_viewdirs=d_viewdirs)
  model.to(device)
  model_params = list(model.parameters())
  if use_fine_model:
    fine_model = NeRF(encoder.d_output, 
                      n_layers=n_layers, 
                      d_filter=d_filter, skip=skip,d_viewdirs=d_viewdirs)
    fine_model.to(device)
    model_params = model_params + list(fine_model.parameters())
  else:
    fine_model = None

  optimizer      = torch.optim.Adam(model_params, lr=lr)
  warmup_stopper = EarlyStopping(patience=50)
  return model, fine_model, encode, encode_viewdirs, optimizer, warmup_stopper

訓(xùn)練

def train():
  # Launch training session for NeRF.
  # Shuffle rays across all images.
  if not one_image_per_step:
    height, width = images.shape[1:3]
    all_rays = torch.stack([torch.stack(get_rays(height, width, focal, p), 0)
                           for p in poses[:n_training]], 0)
    rays_rgb = torch.cat([all_rays, images[:, None]], 1)
    rays_rgb = torch.permute(rays_rgb, [0, 2, 3, 1, 4])
    rays_rgb = rays_rgb.reshape([-1, 3, 3])
    rays_rgb = rays_rgb.type(torch.float32)
    rays_rgb = rays_rgb[torch.randperm(rays_rgb.shape[0])]
    i_batch = 0

  train_psnrs = []
  val_psnrs = []
  iternums = []
  for i in trange(n_iters):
    model.train()
    if one_image_per_step:
      # Randomly pick an image as the target.
      target_img_idx = np.random.randint(images.shape[0])
      target_img     = images[target_img_idx].to(device)
      if center_crop and i < center_crop_iters:
        target_img = crop_center(target_img)
      height, width = target_img.shape[:2]
      target_pose = poses[target_img_idx].to(device)
      rays_o, rays_d = get_rays(height, width, focal, target_pose)
      rays_o = rays_o.reshape([-1, 3])
      rays_d = rays_d.reshape([-1, 3])
    else:
      # Random over all images.
      batch = rays_rgb[i_batch:i_batch + batch_size]
      batch = torch.transpose(batch, 0, 1)
      rays_o, rays_d, target_img = batch
      height, width = target_img.shape[:2]
      i_batch += batch_size
      # Shuffle after one epoch
      if i_batch >= rays_rgb.shape[0]:
          rays_rgb = rays_rgb[torch.randperm(rays_rgb.shape[0])]
          i_batch = 0
    target_img = target_img.reshape([-1, 3])

    # Run one iteration of TinyNeRF and get the rendered RGB image.
    outputs = nerf_forward(rays_o, rays_d,
                           near, far, encode, model,
                           kwargs_sample_stratified=kwargs_sample_stratified,
                           n_samples_hierarchical=n_samples_hierarchical,
                           kwargs_sample_hierarchical=kwargs_sample_hierarchical,
                           fine_model=fine_model,
                           viewdirs_encoding_fn=encode_viewdirs,
                           chunksize=chunksize)
    # Backprop!
    rgb_predicted = outputs['rgb_map']
    loss = torch.nn.functional.mse_loss(rgb_predicted, target_img)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    psnr = -10. * torch.log10(loss)
    train_psnrs.append(psnr.item())

    # Evaluate testimg at given display rate.
    if i % display_rate == 0:
      model.eval()
      height, width = testimg.shape[:2]
      rays_o, rays_d = get_rays(height, width, focal, testpose)
      rays_o = rays_o.reshape([-1, 3])
      rays_d = rays_d.reshape([-1, 3])
      outputs = nerf_forward(rays_o, rays_d,
                             near, far, encode, model,
                             kwargs_sample_stratified=kwargs_sample_stratified,
                             n_samples_hierarchical=n_samples_hierarchical,
                             kwargs_sample_hierarchical=kwargs_sample_hierarchical,
                             fine_model=fine_model,
                             viewdirs_encoding_fn=encode_viewdirs,
                             chunksize=chunksize)
      rgb_predicted = outputs['rgb_map']
      loss = torch.nn.functional.mse_loss(rgb_predicted, testimg.reshape(-1, 3))
      val_psnr = -10. * torch.log10(loss)
      val_psnrs.append(val_psnr.item())
      iternums.append(i)

    # Check PSNR for issues and stop if any are found.
    if i == warmup_iters - 1:
      if val_psnr < warmup_min_fitness:
        return False, train_psnrs, val_psnrs
    elif i < warmup_iters:
      if warmup_stopper is not None and warmup_stopper(i, psnr):
        return False, train_psnrs, val_psnrs

  return True, train_psnrs, val_psnrs

訓(xùn)練

# Run training session(s)
for _ in range(n_restarts):
  model, fine_model, encode, encode_viewdirs, optimizer, warmup_stopper = init_models()
  success, train_psnrs, val_psnrs = train()
  if success and val_psnrs[-1] >= warmup_min_fitness:
    print('Training successful!')
    break
print(f'Done!')

5 Conclusion

輻射場(chǎng)標(biāo)志著處理3D數(shù)據(jù)的方式發(fā)生了巨大變化。NeRF模型和更廣泛的可微分渲染正在迅速彌合圖像創(chuàng)建和體積場(chǎng)景創(chuàng)建之間的差距。雖然我們的組件可能看起來非常復(fù)雜,但受vanilla NeRF啟發(fā)的無數(shù)其他方法證明,基本概念(連續(xù)函數(shù)逼近器+可微分渲染器)是構(gòu)建各種解決方案的堅(jiān)實(shí)基礎(chǔ),這些解決方案可用于幾乎無限的情況。

原文:NeRF From Nothing: A Tutorial with PyTorch | Towards Data Science

原文鏈接:https://mp.weixin.qq.com/s/zxJAIpAmLgsIuTsPqQqOVg

責(zé)任編輯:張燕妮 來源: 自動(dòng)駕駛之心
點(diǎn)贊
收藏

51CTO技術(shù)棧公眾號(hào)