Skip to content

Commit d44473b

Browse files
committed
add fine tune stage
1 parent 3ce0b41 commit d44473b

File tree

3 files changed

+235
-57
lines changed

3 files changed

+235
-57
lines changed

arguments/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def __init__(self, parser):
6969

7070
class OptimizationParams(ParamGroup):
7171
def __init__(self, parser):
72-
self.iterations = 30_000
72+
self.iterations = 10_000
7373
self.position_lr_init = 0.00016
7474
self.position_lr_final = 0.0000016
7575
self.position_lr_delay_mult = 0.01
@@ -85,6 +85,7 @@ def __init__(self, parser):
8585
self.densify_from_iter = 500
8686
self.densify_until_iter = 15_000
8787
self.densify_grad_threshold = 0.0002
88+
self.fine_tune = False
8889
super().__init__(parser, "Optimization Parameters")
8990

9091
def get_combined_args(parser : ArgumentParser):

fine_tune.py

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
#
2+
# Copyright (C) 2023, Inria
3+
# GRAPHDECO research group, https://team.inria.fr/graphdeco
4+
# All rights reserved.
5+
#
6+
# This software is free for non-commercial, research and evaluation use
7+
# under the terms of the LICENSE.md file.
8+
#
9+
# For inquiries contact george.drettakis@inria.fr
10+
#
11+
12+
import torch
13+
from scene import Scene
14+
import os
15+
from tqdm import tqdm
16+
from os import makedirs
17+
from gaussian_renderer import render
18+
import torchvision
19+
from utils.general_utils import safe_state
20+
from argparse import ArgumentParser
21+
from arguments import ModelParams, PipelineParams, get_combined_args, OptimizationParams
22+
from gaussian_renderer import GaussianModel
23+
from random import randint
24+
from utils.loss_utils import l1_loss, ssim
25+
from utils.image_utils import psnr
26+
27+
28+
def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene: Scene, renderFunc,
29+
renderArgs):
30+
if tb_writer:
31+
tb_writer.add_scalar('train_loss_patches/l1_loss', Ll1.item(), iteration)
32+
tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration)
33+
tb_writer.add_scalar('iter_time', elapsed, iteration)
34+
35+
# Report test and samples of training set
36+
if iteration in testing_iterations:
37+
torch.cuda.empty_cache()
38+
validation_configs = ({'name': 'test', 'cameras': scene.getTestCameras()},
39+
{'name': 'train',
40+
'cameras': [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in
41+
range(5, 30, 5)]})
42+
43+
for config in validation_configs:
44+
if config['cameras'] and len(config['cameras']) > 0:
45+
images = torch.tensor([], device="cuda")
46+
gts = torch.tensor([], device="cuda")
47+
for idx, viewpoint in enumerate(config['cameras']):
48+
image = torch.clamp(renderFunc(viewpoint, scene.gaussians, *renderArgs)["render"], 0.0, 1.0)
49+
gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0)
50+
images = torch.cat((images, image.unsqueeze(0)), dim=0)
51+
gts = torch.cat((gts, gt_image.unsqueeze(0)), dim=0)
52+
if tb_writer and (idx < 5):
53+
tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.image_name),
54+
image[None], global_step=iteration)
55+
if iteration == testing_iterations[0]:
56+
tb_writer.add_images(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name),
57+
gt_image[None], global_step=iteration)
58+
59+
l1_test = l1_loss(images, gts)
60+
psnr_test = psnr(images, gts).mean()
61+
print("\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format(iteration, config['name'], l1_test, psnr_test))
62+
if tb_writer:
63+
tb_writer.add_scalar(config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration)
64+
tb_writer.add_scalar(config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration)
65+
66+
if tb_writer:
67+
tb_writer.add_histogram("scene/opacity_histogram", scene.gaussians.get_opacity, iteration)
68+
tb_writer.add_scalar('total_points', scene.gaussians.get_xyz.shape[0], iteration)
69+
torch.cuda.empty_cache()
70+
71+
72+
def fine_tune_sets(dataset: ModelParams, opt: OptimizationParams, pipe: PipelineParams, iteration: int,
73+
testing_iterations: int, saving_iterations: int):
74+
gaussians = GaussianModel(dataset.sh_degree)
75+
76+
scene = Scene(dataset, gaussians, load_iteration=iteration)
77+
gaussians.training_setup(opt)
78+
79+
bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
80+
background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
81+
82+
iter_start = torch.cuda.Event(enable_timing=True)
83+
iter_end = torch.cuda.Event(enable_timing=True)
84+
85+
viewpoint_stack = None
86+
ema_loss_for_log = 0.0
87+
progress_bar = tqdm(range(opt.iterations), desc="Fine Tune progress")
88+
89+
loaded_iter = scene.loaded_iter + 1
90+
final_iter = opt.iterations + loaded_iter
91+
for iteration in range(loaded_iter, final_iter):
92+
iter_start.record()
93+
94+
# Pick a random Camera
95+
if not viewpoint_stack:
96+
viewpoint_stack = scene.getTrainCameras().copy()
97+
98+
viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack) - 1))
99+
# Render
100+
render_pkg = render(viewpoint_cam, gaussians, pipe, background)
101+
image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], \
102+
render_pkg["visibility_filter"], render_pkg["radii"]
103+
104+
# Loss
105+
gt_image = viewpoint_cam.original_image.cuda()
106+
Ll1 = l1_loss(image, gt_image)
107+
loss = 1.0 - ssim(image, gt_image)
108+
loss.backward()
109+
110+
iter_end.record()
111+
112+
with torch.no_grad():
113+
# Progress bar
114+
ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
115+
if iteration % 10 == 0:
116+
progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
117+
progress_bar.update(10)
118+
if iteration == final_iter:
119+
progress_bar.close()
120+
121+
# Log and save
122+
training_report(None, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end),
123+
testing_iterations, scene, render, (pipe, background))
124+
125+
if (iteration in saving_iterations):
126+
print("\n[ITER {}] Saving Gaussians".format(iteration))
127+
scene.save(iteration)
128+
129+
# Optimizer step
130+
if iteration < final_iter:
131+
gaussians.optimizer.step()
132+
gaussians.optimizer.zero_grad(set_to_none=True)
133+
gaussians.update_learning_rate(iteration)
134+
135+
136+
if __name__ == "__main__":
137+
# Set up command line argument parser
138+
parser = ArgumentParser(description="Testing script parameters") # add argument into parser
139+
model = ModelParams(parser, sentinel=True)
140+
op = OptimizationParams(parser)
141+
pipeline = PipelineParams(parser)
142+
parser.add_argument("--iteration", default=-1, type=int)
143+
parser.add_argument("--test_iterations", nargs="+", type=int, default=[35_000, 40_000])
144+
parser.add_argument("--save_iterations", nargs="+", type=int, default=[35_000, 40_000])
145+
parser.add_argument("--quiet", action="store_true")
146+
args = get_combined_args(parser)
147+
print("Rendering " + args.model_path)
148+
149+
# Initialize system state (RNG)
150+
safe_state(args.quiet)
151+
152+
fine_tune_sets(model.extract(args), op.extract(args), pipeline.extract(args), args.iteration, args.test_iterations,
153+
args.save_iterations)

0 commit comments

Comments
 (0)