examples/notebook/notebook_neural_field_2d/neural_field_2d.ipynb
from __future__ import annotations
import io
import itertools
import numpy as np
import PIL.Image
import requests
import torch
from tqdm import tqdm
import rerun as rr # pip install rerun-sdk
import rerun.blueprint as rrb
First, we define the neural field class which can be used to represent any continuous ND signal, i.e., it maps an ND point to another ND point. In this notebook we fit fields that map from 2D image coordinates to RGB colors. This way the network weights can be interpreted as encoding a continuous image.
class NeuralField(torch.nn.Module): # type: ignore[misc]
"""Simple neural field composed of positional encoding, MLP, and activation function."""
def __init__(
self,
num_layers: int,
dim_hidden: int,
dim_in: int = 2,
dim_out: int = 3,
activation: str = "sigmoid",
pe_sigma: float | None = None,
) -> None:
super().__init__()
self.num_layers = num_layers
self.dim_hidden = dim_hidden
self.dim_in = dim_in
self.dim_out = dim_out
self.activation = activation
self.pe_sigma = pe_sigma
sizes = [dim_in] + [dim_hidden for _ in range(num_layers - 1)] + [dim_out]
self.linears = torch.nn.ModuleList()
for in_size, out_size in itertools.pairwise(sizes):
self.linears.append(torch.nn.Linear(in_size, out_size))
if self.pe_sigma is not None:
assert isinstance(self.linears[0].weight, torch.Tensor)
torch.nn.init.normal_(self.linears[0].weight, 0.0, self.pe_sigma)
def __str__(self) -> str:
return f"{self.num_layers} lay., {self.dim_hidden} neu., pe σ: {self.pe_sigma}"
def forward(self, input_points: torch.Tensor) -> torch.Tensor:
"""
Compute output for given input points.
Args:
input_points: input points
"""
if self.pe_sigma is None:
out = torch.relu(self.linears[0](input_points))
else:
out = torch.sin(self.linears[0](input_points))
for linear in self.linears[1:-1]:
out = torch.relu(linear(out))
out = self.linears[-1](out)
if self.activation == "sigmoid":
out = torch.sigmoid(out)
return out
Now we create a few neural fields with different parameters and visualize their output as images. We assume that images are fit in a 0 to 1 unit square, so we query in a dense grid (with some additional margin to observe out-of-training behavior) to retrieve the image from the network. Note that the positional encoding encodes how quickly the neural field varies out-of-the-box. This corresponds to the amount of detail that the field can easily represent, but also determines how the field extrapolates outside of the training region.
# Build four fields that differ only in positional-encoding bandwidth (pe_sigma),
# so the grid of views shows how sigma trades detail against extrapolation.
fields = [
    NeuralField(num_layers=5, dim_hidden=128, pe_sigma=5),
    NeuralField(num_layers=5, dim_hidden=128, pe_sigma=15),
    NeuralField(num_layers=5, dim_hidden=128, pe_sigma=30),
    NeuralField(num_layers=5, dim_hidden=128, pe_sigma=100),
]
# Per-field iteration counters so re-running the training cell continues the timeline.
total_iterations = [0 for _ in fields]

# FIX: the recording id previously said "rerun_example_cube" (copy-paste from an
# unrelated example); name it after this notebook instead.
rr.init("rerun_example_neural_field_2d")
# Layout: target image + one view per field on top, shared loss plot below.
blueprint = rrb.Blueprint(
    rrb.Vertical(
        rrb.Grid(
            rrb.Spatial2DView(name="Target", origin="target"),
            *[rrb.Spatial2DView(name=str(field), origin=f"field_{i}") for i, field in enumerate(fields)],
        ),
        rrb.TimeSeriesView(
            name="Losses",
            origin="/",
            plot_legend=rrb.Corner2D.LeftTop,
        ),
        row_shares=[0.7, 0.3],
    ),
    collapse_panels=True,
)
# Register one named loss series per field (static, so it applies to all times).
for i, field in enumerate(fields):
    rr.log(
        f"loss/field_{i}",
        rr.SeriesLines(names=str(field), aggregation_policy=rr.components.AggregationPolicy.auto("Average")),
        static=True,
    )
rr.notebook_show(blueprint=blueprint, width=1050, height=600)
@torch.no_grad()  # type: ignore[misc]
def log_field_as_image(
    entity_path: str,
    field: NeuralField,
    min_uv: tuple[float, float],
    max_uv: tuple[float, float],
    uv_resolution: tuple[int, int],
) -> None:
    """Query `field` on a dense UV grid and log the result to Rerun as an RGB image."""
    u_samples = torch.linspace(min_uv[0], max_uv[0], uv_resolution[0])
    v_samples = torch.linspace(min_uv[1], max_uv[1], uv_resolution[1])
    # Offset by half a pixel so samples land on pixel centers.
    half_pixel = 0.5 / torch.tensor(uv_resolution)
    grid_points = torch.cartesian_prod(u_samples, v_samples) + half_pixel
    rgb = field(grid_points).reshape(uv_resolution[0], uv_resolution[1], 3).clamp(0, 1)
    # cartesian_prod enumerates u-major, so swap the first two axes to get
    # image layout (rows=v, cols=u, channels).
    rgb = rgb.transpose(0, 1)
    rr.log(entity_path, rr.Image(rgb.numpy(force=True)))
# Log each untrained field at iteration 0, sampling slightly outside the
# [0, 1] training square to show out-of-training extrapolation.
rr.set_time("iteration", sequence=0)
for i, field in enumerate(fields):
    log_field_as_image(f"field_{i}", field, (-0.1, -0.1), (1.1, 1.1), (100, 100))
Now we train the neural fields for a fixed number of iterations. If you run the cell twice, we continue training where we left off. To reset the fields, run the previous cell again.
# Training configuration.
field_ids = [0, 1, 2, 3]  # if you only want to train one of the fields
num_iterations = 3000  # Run to 10_000 for better fit
batch_size = 1000
learning_rate = 1e-3
log_image_period = 10  # log field images every N iterations

# Fetch the target image and normalize it to float RGB in [0, 1].
response = requests.get("https://storage.googleapis.com/rerun-example-datasets/example_images/tiger.jpg")
# response = requests.get("https://storage.googleapis.com/rerun-example-datasets/example_images/bird.jpg")
response.raise_for_status()  # fail loudly on a bad download instead of a cryptic PIL error
src_array = np.asarray(PIL.Image.open(io.BytesIO(response.content)))
target_image = torch.tensor(src_array).float() / 255
rr.log("target", rr.Image(target_image))

try:
    # One optimizer over all selected fields; their parameter sets are disjoint,
    # so a single step() updates each field from its own gradients.
    parameters = itertools.chain.from_iterable(fields[field_id].parameters() for field_id in field_ids)
    optimizer = torch.optim.Adam(parameters, lr=learning_rate)
    # Loop-invariant scale from unit UVs to (x, y) pixel indices (width, height).
    uv_to_pixels = torch.tensor([target_image.shape[1], target_image.shape[0]])
    for iteration in tqdm(range(num_iterations)):
        optimizer.zero_grad()
        # Sample a random batch of UV points and look up their ground-truth colors
        # (image is indexed [row, col] = [y, x], hence the swapped columns).
        target_uvs = torch.rand(batch_size, 2)
        target_jis = (target_uvs * uv_to_pixels).int()
        target_rgbs = target_image[target_jis[:, 1], target_jis[:, 0]]
        for field_id in field_ids:
            field = fields[field_id]
            total_iterations[field_id] += 1
            predicted_rgbs = field(target_uvs)
            # Conventional (input, target) order; MSE is symmetric so the value
            # and gradients are unchanged.
            loss = torch.nn.functional.mse_loss(predicted_rgbs, target_rgbs)
            rr.set_time("iteration", sequence=total_iterations[field_id])
            rr.log(f"loss/field_{field_id}", rr.Scalars(loss.item()))
            loss.backward()
        optimizer.step()
        if iteration % log_image_period == 0:
            for field_id in field_ids:
                log_field_as_image(f"field_{field_id}", fields[field_id], (-0.1, -0.1), (1.1, 1.1), (100, 100))
except KeyboardInterrupt:
    # Allow interrupting a long training run cleanly, keeping progress so far.
    print("Training stopped.")