How to Build Your Own NeRF Model in PyTorch – Step‑by‑Step Guide

This tutorial walks through the theory and implementation of Neural Radiance Fields (NeRF) in PyTorch, covering positional encoding, the MLP architecture, differentiable volume rendering, hierarchical sampling, training tricks, and references to the original research.

Code DAO

Introduction

Neural Radiance Fields (NeRF) is a paradigm for representing 3D scenes as continuous functions, introduced in the ECCV 2020 paper “NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis”. This tutorial explains each component needed to build a NeRF model in PyTorch.

What is NeRF?

NeRF is a neural scene representation: given a set of calibrated images with known camera poses, a trained model predicts color and density for any 3D point and viewing direction, enabling novel-view synthesis and operations such as mesh extraction. It uses no convolutional or transformer layers, and a trained model can be compact (5–10 MB).

NeRF Architecture

The model samples points along camera rays, feeds each 3D position x and viewing direction d into a multi-layer perceptron (MLP) to obtain a color (RGB) and density (σ), and then composites an image with classic volume rendering. The pipeline has five components:

- Positional encoding
- Radiance-field MLP
- Differentiable volume renderer
- Stratified sampling
- Hierarchical volume sampling

Positional Encoder

Following the 2017 transformer paper, NeRF uses a sinusoidal positional encoder to map input coordinates to a higher‑dimensional space, allowing the network to represent high‑frequency details. The implementation is shown below.

from typing import Optional, Tuple

import torch
from torch import nn


class PositionalEncoder(nn.Module):
    """Sine-cosine positional encoder for input points."""
    def __init__(self, d_input: int, n_freqs: int, log_space: bool = False):
        super().__init__()
        self.d_input = d_input
        self.n_freqs = n_freqs
        self.log_space = log_space
        # Output: the raw input plus sin and cos at each frequency band.
        self.d_output = d_input * (1 + 2 * self.n_freqs)
        self.embed_fns = [lambda x: x]
        # Frequency bands spaced in log space or linearly between 2^0 and 2^(L-1).
        if self.log_space:
            freq_bands = 2.**torch.linspace(0., self.n_freqs - 1, self.n_freqs)
        else:
            freq_bands = torch.linspace(2.**0., 2.**(self.n_freqs - 1), self.n_freqs)
        for freq in freq_bands:
            self.embed_fns.append(lambda x, freq=freq: torch.sin(x * freq))
            self.embed_fns.append(lambda x, freq=freq: torch.cos(x * freq))

    def forward(self, x) -> torch.Tensor:
        """Apply positional encoding to input."""
        return torch.cat([fn(x) for fn in self.embed_fns], dim=-1)
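To see what the encoder produces, here is a standalone sketch of the same sine–cosine mapping (assuming 3D inputs and 10 log-space frequency bands, as in the paper's coordinate encoder):

```python
import torch

d_input, n_freqs = 3, 10                      # xyz coordinates, 10 frequency bands
freqs = 2.0 ** torch.arange(n_freqs)          # 1, 2, 4, ..., 512 (log-space bands)
x = torch.rand(1024, d_input)                 # a batch of sample points

# gamma(x) = (x, sin(2^0 x), cos(2^0 x), ..., sin(2^(L-1) x), cos(2^(L-1) x))
parts = [x] + [fn(x * f) for f in freqs for fn in (torch.sin, torch.cos)]
encoded = torch.cat(parts, dim=-1)

print(encoded.shape)  # torch.Size([1024, 63]) = d_input * (1 + 2 * n_freqs)
```

The 63-dimensional output matches `d_output` in the class above: the identity term plus one sine and one cosine per frequency band, per input dimension.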

Radiance‑Field Function

The original NeRF paper models the radiance field with an 8‑layer MLP (256‑dimensional hidden units, a skip connection at layer 4). The network outputs RGB and density; optional view‑direction handling adds a second branch for color.

class NeRF(nn.Module):
    """Neural radiance fields module."""
    def __init__(self,
                 d_input: int = 3,
                 n_layers: int = 8,
                 d_filter: int = 256,
                 skip: Tuple[int] = (4,),
                 d_viewdirs: Optional[int] = None):
        super().__init__()
        self.d_input = d_input
        self.skip = skip
        self.act = nn.functional.relu
        self.d_viewdirs = d_viewdirs
        self.layers = nn.ModuleList(
            [nn.Linear(self.d_input, d_filter)] +
            [nn.Linear(d_filter + self.d_input, d_filter) if i in skip
             else nn.Linear(d_filter, d_filter) for i in range(n_layers - 1)]
        )
        if self.d_viewdirs is not None:
            self.alpha_out = nn.Linear(d_filter, 1)
            self.rgb_filters = nn.Linear(d_filter, d_filter)
            self.branch = nn.Linear(d_filter + self.d_viewdirs, d_filter // 2)
            self.output = nn.Linear(d_filter // 2, 3)
        else:
            self.output = nn.Linear(d_filter, 4)

    def forward(self, x: torch.Tensor, viewdirs: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Forward pass with optional view direction."""
        if self.d_viewdirs is None and viewdirs is not None:
            raise ValueError('Cannot input viewdirs if d_viewdirs was not given.')
        if self.d_viewdirs is not None and viewdirs is None:
            raise ValueError('viewdirs is required when d_viewdirs was given.')
        x_input = x
        for i, layer in enumerate(self.layers):
            x = self.act(layer(x))
            if i in self.skip:
                x = torch.cat([x, x_input], dim=-1)
        if self.d_viewdirs is not None:
            # Density depends only on position; color also depends on view direction.
            alpha = self.alpha_out(x)
            x = self.rgb_filters(x)
            x = torch.cat([x, viewdirs], dim=-1)
            x = self.act(self.branch(x))
            x = self.output(x)
            x = torch.cat([x, alpha], dim=-1)
        else:
            x = self.output(x)
        return x
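As a standalone sanity check of the skip-connection bookkeeping, this sketch (with made-up dimensions, separate from the class above) shows why the layer after a skip must widen its input:

```python
import torch
from torch import nn

d_input, d_filter, n_layers, skip = 63, 256, 8, {4}

# The layer that follows a skip receives d_filter + d_input features,
# because the encoded input is concatenated back in.
layers = nn.ModuleList(
    [nn.Linear(d_input, d_filter)] +
    [nn.Linear(d_filter + d_input if i in skip else d_filter, d_filter)
     for i in range(n_layers - 1)]
)

x = x_input = torch.rand(16, d_input)
for i, layer in enumerate(layers):
    x = torch.relu(layer(x))
    if i in skip:
        x = torch.cat([x, x_input], dim=-1)

print(x.shape)  # torch.Size([16, 256])
```

Note the off-by-one: the list comprehension builds layers 1 through n_layers-1, so comprehension index i corresponds to the layer that runs *after* the concatenation at loop index i.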

Differentiable Volume Renderer

Given the raw network outputs (RGB and density σ) and the sampled depths along each ray, the renderer composites them into RGB, depth, accumulation, and weight maps using the volume-rendering equations from the original paper, with optional white-background compositing.

def cumprod_exclusive(tensor: torch.Tensor) -> torch.Tensor:
    """Exclusive cumulative product along the last dimension.

    Mimics tf.math.cumprod(..., exclusive=True): entry i is the product of
    elements [0, i), so the first entry is 1.
    """
    cumprod = torch.cumprod(tensor, dim=-1)
    cumprod = torch.roll(cumprod, 1, dims=-1)
    cumprod[..., 0] = 1.
    return cumprod


def raw2outputs(raw: torch.Tensor,
                z_vals: torch.Tensor,
                rays_d: torch.Tensor,
                raw_noise_std: float = 0.0,
                white_bkgd: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Convert the raw NeRF output into RGB, depth, accumulation, and weight maps."""
    # Distance between consecutive samples; the last interval extends to infinity.
    dists = z_vals[..., 1:] - z_vals[..., :-1]
    dists = torch.cat([dists, 1e10 * torch.ones_like(dists[..., :1])], dim=-1)
    # Convert to real-world distances by scaling with the ray direction norm.
    dists = dists * torch.norm(rays_d[..., None, :], dim=-1)
    # Optional noise on the raw density, used as a regularizer during training.
    noise = 0.
    if raw_noise_std > 0.:
        noise = torch.randn(raw[..., 3].shape, device=raw.device) * raw_noise_std
    # Opacity of each sample: alpha = 1 - exp(-sigma * delta).
    alpha = 1.0 - torch.exp(-nn.functional.relu(raw[..., 3] + noise) * dists)
    # Weight = alpha * transmittance accumulated before (excluding) the sample.
    weights = alpha * cumprod_exclusive(1. - alpha + 1e-10)
    rgb = torch.sigmoid(raw[..., :3])
    rgb_map = torch.sum(weights[..., None] * rgb, dim=-2)
    depth_map = torch.sum(weights * z_vals, dim=-1)
    acc_map = torch.sum(weights, dim=-1)
    if white_bkgd:
        rgb_map = rgb_map + (1. - acc_map[..., None])
    return rgb_map, depth_map, acc_map, weights
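As a concrete check of the compositing equations, here is a tiny standalone example for a single ray with three samples (the densities are made up, not output from a network):

```python
import torch

sigma = torch.tensor([0.5, 2.0, 0.1])   # densities at 3 samples along one ray
dists = torch.tensor([1.0, 1.0, 1.0])   # interval lengths between samples

# alpha_i = 1 - exp(-sigma_i * delta_i): opacity of each sample
alpha = 1.0 - torch.exp(-sigma * dists)

# Exclusive cumulative product of (1 - alpha) is the transmittance T_i
# reaching sample i; weight_i = T_i * alpha_i.
trans = torch.cumprod(torch.cat([torch.ones(1), 1.0 - alpha[:-1]]), dim=0)
weights = alpha * trans

print(weights.sum())  # less than 1: the remainder passed through the ray
```

Dense samples early in the ray occlude later ones, which is exactly the behavior the `weights` map encodes and the hierarchical sampler later exploits.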

Stratified Sampling

To approximate the integral along each ray, the segment between the near and far bounds is divided into n_samples evenly spaced bins, and one sample is drawn uniformly from each bin. This yields the coarse set of samples that the hierarchical pass later refines.

def sample_stratified(rays_o: torch.Tensor,
                      rays_d: torch.Tensor,
                      near: float,
                      far: float,
                      n_samples: int,
                      perturb: bool = True,
                      inverse_depth: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample along each ray from regularly spaced bins."""
    t_vals = torch.linspace(0., 1., n_samples, device=rays_o.device)
    if not inverse_depth:
        # Sample linearly in depth between the near and far bounds.
        z_vals = near * (1. - t_vals) + far * t_vals
    else:
        # Sample linearly in inverse depth (disparity).
        z_vals = 1. / (1. / near * (1. - t_vals) + 1. / far * t_vals)
    if perturb:
        # Jitter each sample uniformly within its bin.
        mids = .5 * (z_vals[1:] + z_vals[:-1])
        upper = torch.cat([mids, z_vals[-1:]], dim=-1)
        lower = torch.cat([z_vals[:1], mids], dim=-1)
        t_rand = torch.rand([n_samples], device=z_vals.device)
        z_vals = lower + (upper - lower) * t_rand
    z_vals = z_vals.expand(list(rays_o.shape[:-1]) + [n_samples])
    pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None]
    return pts, z_vals
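A standalone sketch of the same jittered-bin idea (assuming near=2.0, far=6.0, and 8 samples, values chosen only for illustration) shows that perturbed samples stay inside the ray bounds and remain sorted:

```python
import torch

near, far, n_samples = 2.0, 6.0, 8

# Evenly spaced depths, then jitter each sample within its own bin.
t_vals = torch.linspace(0., 1., n_samples)
z_vals = near * (1. - t_vals) + far * t_vals

mids = .5 * (z_vals[1:] + z_vals[:-1])
upper = torch.cat([mids, z_vals[-1:]])   # upper bin edges
lower = torch.cat([z_vals[:1], mids])    # lower bin edges
z_vals = lower + (upper - lower) * torch.rand(n_samples)

print(z_vals)  # sorted, all within [near, far]
```

Because adjacent bins share an edge at the midpoints, the jittered samples can never cross, so no sort is needed after perturbation.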

Hierarchical Volume Sampling

After the coarse pass, a probability density function (PDF) is built from the coarse weights. Inverse‑transform sampling draws additional points in regions with high expected contribution, which are then processed by a second “fine” MLP.

def sample_hierarchical(rays_o: torch.Tensor,
                        rays_d: torch.Tensor,
                        z_vals: torch.Tensor,
                        weights: torch.Tensor,
                        n_samples: int,
                        perturb: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Apply hierarchical sampling to the rays."""
    z_vals_mid = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
    new_z_samples = sample_pdf(z_vals_mid, weights[..., 1:-1], n_samples, perturb=perturb)
    new_z_samples = new_z_samples.detach()
    z_vals_combined, _ = torch.sort(torch.cat([z_vals, new_z_samples], dim=-1), dim=-1)
    pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals_combined[..., :, None]
    return pts, z_vals_combined, new_z_samples
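The code above calls a `sample_pdf` helper that this article does not show. A minimal sketch of inverse-transform sampling over the coarse weights, simplified from the approach of the reference implementation, could look like this:

```python
import torch

def sample_pdf(bins: torch.Tensor, weights: torch.Tensor, n_samples: int,
               perturb: bool = False) -> torch.Tensor:
    """Draw n_samples depths from the PDF defined by the coarse weights."""
    # Normalize weights into a PDF, then build its CDF with a leading zero,
    # so the CDF has the same length as `bins`.
    pdf = (weights + 1e-5) / torch.sum(weights + 1e-5, dim=-1, keepdim=True)
    cdf = torch.cumsum(pdf, dim=-1)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], dim=-1)

    # Uniform (or random) positions in [0, 1) to invert through the CDF.
    if perturb:
        u = torch.rand(list(cdf.shape[:-1]) + [n_samples], device=cdf.device)
    else:
        u = torch.linspace(0., 1. - 1e-5, n_samples, device=cdf.device)
        u = u.expand(list(cdf.shape[:-1]) + [n_samples])

    # Locate each u inside the CDF and interpolate within its bin.
    inds = torch.searchsorted(cdf, u.contiguous(), right=True)
    below = torch.clamp(inds - 1, min=0)
    above = torch.clamp(inds, max=cdf.shape[-1] - 1)

    cdf_below = torch.gather(cdf, -1, below)
    cdf_above = torch.gather(cdf, -1, above)
    bins_below = torch.gather(bins, -1, below)
    bins_above = torch.gather(bins, -1, above)

    denom = torch.where(cdf_above - cdf_below < 1e-5,
                        torch.ones_like(cdf_above), cdf_above - cdf_below)
    t = (u - cdf_below) / denom
    return bins_below + t * (bins_above - bins_below)
```

With uniform weights this degenerates to evenly spaced samples; peaked weights concentrate new samples where the coarse pass found density, which is the point of the fine pass.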

Training

The standard training regime follows the original paper: an 8-layer MLP with 256 hidden units, optimized with Adam on the mean-squared error between rendered and ground-truth pixel colors. Rays are processed in chunks (accumulating gradients across chunks) to fit into GPU memory; on hardware weaker than an NVIDIA V100, reduce the batch and chunk sizes to avoid out-of-memory errors.
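The loop itself is ordinary supervised regression on pixel colors. A minimal standalone sketch of one such loop, using a toy MLP and random target pixels in place of a real NeRF renderer and dataset, might be:

```python
import torch
from torch import nn

torch.manual_seed(0)

# Stand-ins: a toy model and random "pixels" replace the real renderer/dataset.
model = nn.Sequential(nn.Linear(63, 128), nn.ReLU(), nn.Linear(128, 3))
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)

inputs = torch.rand(4096, 63)    # encoded ray samples (toy)
targets = torch.rand(4096, 3)    # ground-truth pixel colors (toy)
chunk = 1024                     # process rays in chunks to bound memory

losses = []
for step in range(200):
    optimizer.zero_grad()
    step_loss = 0.0
    for i in range(0, inputs.shape[0], chunk):
        pred = model(inputs[i:i + chunk])
        loss = nn.functional.mse_loss(pred, targets[i:i + chunk])
        # Scale so the chunk losses sum to the full-batch mean before stepping.
        (loss * chunk / inputs.shape[0]).backward()
        step_loss += loss.item() * chunk / inputs.shape[0]
    optimizer.step()
    losses.append(step_loss)
```

The chunked inner loop is the memory trick described above: gradients from each chunk accumulate in-place, and `optimizer.step()` runs once per full batch.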

Conclusion

NeRF has reshaped how machine‑learning practitioners handle 3D data. The components described—continuous function approximation and differentiable rendering—form a solid foundation for many subsequent methods.

References

[1] Ben Mildenhall et al., “NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis”, ECCV 2020.

[2] Julian Chibane et al., “Stereo Radiance Fields (SRF)”, CVPR 2021.

[3] Alex Yu et al., “pixelNeRF: Neural Radiance Fields from One or Few Images”, CVPR 2021.

[4] Zhengqi Li et al., “Neural Scene Flow Fields for Space‑Time View Synthesis of Dynamic Scenes”, CVPR 2021.

[5] Albert Pumarola et al., “D‑NeRF: Neural Radiance Fields for Dynamic Scenes”, CVPR 2021.

[6] Michael Niemeyer & Andreas Geiger, “GIRAFFE: Representing Scenes as Compositional Generative Neural Feature Fields”, CVPR 2021.

[7] Zhengfei Kuang et al., “NeROIC: Neural Object Capture and Rendering from Online Image Collections”, arXiv 2022.

[8] Konstantinos Rematas et al., “Urban Radiance Fields”, CVPR 2022.

[9] Matthew Tancik et al., “Block‑NeRF: Scalable Large Scene Neural View Synthesis”, arXiv 2022.

[10] Alex Yu et al., “Plenoxels: Radiance Fields without Neural Networks”, CVPR 2022 (Oral).

[11] Ashish Vaswani et al., “Attention Is All You Need”, NeurIPS 2017.

[12] Nasim Rahaman et al., “On the Spectral Bias of Neural Networks”, ICML 2019.

Written by Code DAO

We deliver AI algorithm tutorials and the latest news, curated by a team of researchers from Peking University, Shanghai Jiao Tong University, Central South University, and leading AI companies such as Huawei, Kuaishou, and SenseTime.