Skip to content

Commit 3a07ac2

Browse files
author
bkerbl
committed
No more activation inside, guardbands removed
1 parent feecabd commit 3a07ac2

File tree

7 files changed

+53
-44
lines changed

7 files changed

+53
-44
lines changed

cuda_rasterizer/auxiliary.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ __forceinline__ __device__ bool in_frustum(int idx,
140140
float3 p_proj = { p_hom.x * p_w, p_hom.y * p_w, p_hom.z * p_w };
141141
p_view = transformPoint4x3(p_orig, viewmatrix);
142142

143-
if (p_view.z <= 0.2f || ((p_proj.x < -1.3 || p_proj.x > 1.3 || p_proj.y < -1.3 || p_proj.y > 1.3)))
143+
if (p_view.z <= 0.2f)// || ((p_proj.x < -1.3 || p_proj.x > 1.3 || p_proj.y < -1.3 || p_proj.y > 1.3)))
144144
{
145145
if (prefiltered)
146146
{

cuda_rasterizer/backward.cu

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,8 @@ __global__ void computeCov2DCUDA(int P,
134134
const float3* means,
135135
const int* radii,
136136
const float* cov3Ds,
137-
float h_x,
138-
float h_y,
137+
const float h_x, float h_y,
138+
const float tan_fovx, float tan_fovy,
139139
const float* view_matrix,
140140
const float* dL_dconics,
141141
float3* dL_dmeans,
@@ -153,11 +153,20 @@ __global__ void computeCov2DCUDA(int P,
153153
float3 mean = means[idx];
154154
float3 dL_dconic = { dL_dconics[4 * idx], dL_dconics[4 * idx + 1], dL_dconics[4 * idx + 3] };
155155
float3 t = transformPoint4x3(mean, view_matrix);
156-
float t_inv_norm = 1.f / sqrt(t.x * t.x + t.y * t.y + t.z * t.z);
156+
157+
const float limx = 1.3f * tan_fovx;
158+
const float limy = 1.3f * tan_fovy;
159+
const float txtz = t.x / t.z;
160+
const float tytz = t.y / t.z;
161+
t.x = min(limx, max(-limx, txtz)) * t.z;
162+
t.y = min(limy, max(-limy, tytz)) * t.z;
163+
164+
const float x_grad_mul = txtz < -limx || txtz > limx ? 0 : 1;
165+
const float y_grad_mul = tytz < -limy || tytz > limy ? 0 : 1;
157166

158167
glm::mat3 J = glm::mat3(h_x / t.z, 0.0f, -(h_x * t.x) / (t.z * t.z),
159168
0.0f, h_y / t.z, -(h_y * t.y) / (t.z * t.z),
160-
t.x * t_inv_norm, t.y * t_inv_norm, t.z * t_inv_norm);
169+
0, 0, 0);
161170

162171
glm::mat3 W = glm::mat3(
163172
view_matrix[0], view_matrix[4], view_matrix[8],
@@ -239,8 +248,8 @@ __global__ void computeCov2DCUDA(int P,
239248
float tz3 = tz2 * tz;
240249

241250
// Gradients of loss w.r.t. transformed Gaussian mean t
242-
float dL_dtx = -h_x * tz2 * dL_dJ02;
243-
float dL_dty = -h_y * tz2 * dL_dJ12;
251+
float dL_dtx = x_grad_mul * -h_x * tz2 * dL_dJ02;
252+
float dL_dty = y_grad_mul * -h_y * tz2 * dL_dJ12;
244253
float dL_dtz = -h_x * tz2 * dL_dJ00 - h_y * tz2 * dL_dJ11 + (2 * h_x * t.x) * tz3 * dL_dJ02 + (2 * h_y * t.y) * tz3 * dL_dJ12;
245254

246255
// Account for transformation of mean to t
@@ -258,7 +267,7 @@ __global__ void computeCov2DCUDA(int P,
258267
__device__ void computeCov3D(int idx, const glm::vec3 scale, float mod, const glm::vec4 rot, const float* dL_dcov3Ds, glm::vec3* dL_dscales, glm::vec4* dL_drots)
259268
{
260269
// Recompute (intermediate) results for the 3D covariance computation.
261-
glm::vec4 q = rot / glm::length(rot);
270+
glm::vec4 q = rot;// / glm::length(rot);
262271
float r = q.x;
263272
float x = q.y;
264273
float y = q.z;
@@ -272,7 +281,7 @@ __device__ void computeCov3D(int idx, const glm::vec3 scale, float mod, const gl
272281

273282
glm::mat3 S = glm::mat3(1.0f);
274283

275-
glm::vec3 s = mod * exp(scale);
284+
glm::vec3 s = mod * scale;
276285
S[0][0] = s.x;
277286
S[1][1] = s.y;
278287
S[2][2] = s.z;
@@ -298,16 +307,16 @@ __device__ void computeCov3D(int idx, const glm::vec3 scale, float mod, const gl
298307
glm::mat3 Rt = glm::transpose(R);
299308
glm::mat3 dL_dMt = glm::transpose(dL_dM);
300309

301-
dL_dMt[0] *= s.x;
302-
dL_dMt[1] *= s.y;
303-
dL_dMt[2] *= s.z;
304-
305310
// Gradients of loss w.r.t. scale
306311
glm::vec3* dL_dscale = dL_dscales + idx;
307312
dL_dscale->x = glm::dot(Rt[0], dL_dMt[0]);
308313
dL_dscale->y = glm::dot(Rt[1], dL_dMt[1]);
309314
dL_dscale->z = glm::dot(Rt[2], dL_dMt[2]);
310315

316+
dL_dMt[0] *= s.x;
317+
dL_dMt[1] *= s.y;
318+
dL_dMt[2] *= s.z;
319+
311320
// Gradients of loss w.r.t. normalized quaternion
312321
glm::vec4 dL_dq;
313322
dL_dq.x = 2 * z * (dL_dMt[0][1] - dL_dMt[1][0]) + 2 * y * (dL_dMt[2][0] - dL_dMt[0][2]) + 2 * x * (dL_dMt[1][2] - dL_dMt[2][1]);
@@ -317,7 +326,7 @@ __device__ void computeCov3D(int idx, const glm::vec3 scale, float mod, const gl
317326

318327
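// Gradients of loss w.r.t. the input quaternion: dL_dq is passed through unchanged because the quaternion is no longer normalized here (dnormvdv disabled)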
319328
float4* dL_drot = (float4*)(dL_drots + idx);
320-
*dL_drot = dnormvdv(float4{ rot.x, rot.y, rot.z, rot.w }, float4{ dL_dq.x, dL_dq.y, dL_dq.z, dL_dq.w });
329+
*dL_drot = float4{ dL_dq.x, dL_dq.y, dL_dq.z, dL_dq.w };//dnormvdv(float4{ rot.x, rot.y, rot.z, rot.w }, float4{ dL_dq.x, dL_dq.y, dL_dq.z, dL_dq.w });
321330
}
322331

323332
// Backward pass of the preprocessing steps, except
@@ -377,7 +386,8 @@ __global__ void preprocessCUDA(
377386

378387
// Backward version of the rendering procedure.
379388
template <uint32_t C>
380-
__global__ void renderCUDA(
389+
__global__ void __launch_bounds__(BLOCK_X * BLOCK_Y)
390+
renderCUDA(
381391
const uint2* __restrict__ ranges,
382392
const uint32_t* __restrict__ point_list,
383393
int W, int H,
@@ -548,6 +558,7 @@ void BACKWARD::preprocess(
548558
const float* viewmatrix,
549559
const float* projmatrix,
550560
const float focal_x, float focal_y,
561+
const float tan_fovx, float tan_fovy,
551562
const glm::vec3* campos,
552563
const float3* dL_dmean2D,
553564
const float* dL_dconic,
@@ -569,6 +580,8 @@ void BACKWARD::preprocess(
569580
cov3Ds,
570581
focal_x,
571582
focal_y,
583+
tan_fovx,
584+
tan_fovy,
572585
viewmatrix,
573586
dL_dconic,
574587
(float3*)dL_dmean3D,

cuda_rasterizer/backward.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ namespace BACKWARD
3939
const float* view,
4040
const float* proj,
4141
const float focal_x, float focal_y,
42+
const float tan_fovx, float tan_fovy,
4243
const glm::vec3* campos,
4344
const float3* dL_dmean2D,
4445
const float* dL_dconics,

cuda_rasterizer/forward.cu

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -60,20 +60,25 @@ __device__ glm::vec3 computeColorFromSH(int idx, int deg, int max_coeffs, const
6060
}
6161

6262
// Forward version of 2D covariance matrix computation
63-
__device__ float3 computeCov2D(const float3& mean, float focal_x, float focal_y, const float* cov3D, const float* viewmatrix)
63+
__device__ float3 computeCov2D(const float3& mean, float focal_x, float focal_y, float tan_fovx, float tan_fovy, const float* cov3D, const float* viewmatrix)
6464
{
6565
// The following models the steps outlined by equations 29
6666
// and 31 in "EWA Splatting" (Zwicker et al., 2002).
6767
// Additionally considers aspect / scaling of viewport.
6868
// Transposes used to account for row-/column-major conventions.
6969
float3 t = transformPoint4x3(mean, viewmatrix);
7070

71-
float t_inv_norm = 1.f / sqrt(t.x * t.x + t.y * t.y + t.z * t.z);
71+
const float limx = 1.3f * tan_fovx;
72+
const float limy = 1.3f * tan_fovy;
73+
const float txtz = t.x / t.z;
74+
const float tytz = t.y / t.z;
75+
t.x = min(limx, max(-limx, txtz)) * t.z;
76+
t.y = min(limy, max(-limy, tytz)) * t.z;
7277

7378
glm::mat3 J = glm::mat3(
7479
focal_x / t.z, 0.0f, -(focal_x * t.x) / (t.z * t.z),
7580
0.0f, focal_y / t.z, -(focal_y * t.y) / (t.z * t.z),
76-
t.x * t_inv_norm, t.y * t_inv_norm, t.z * t_inv_norm);
81+
0, 0, 0);
7782

7883
glm::mat3 W = glm::mat3(
7984
viewmatrix[0], viewmatrix[4], viewmatrix[8],
@@ -98,17 +103,17 @@ __device__ float3 computeCov2D(const float3& mean, float focal_x, float focal_y,
98103

99104
// Forward method for converting scale and rotation properties of each
100105
// Gaussian to a 3D covariance matrix in world space. Also takes care
101-
// of quaternion normalization and scale activation via exp.
106+
// of quaternion normalization.
102107
__device__ void computeCov3D(const glm::vec3 scale, float mod, const glm::vec4 rot, float* cov3D)
103108
{
104109
// Create scaling matrix
105110
glm::mat3 S = glm::mat3(1.0f);
106-
S[0][0] = mod * exp(scale.x);
107-
S[1][1] = mod * exp(scale.y);
108-
S[2][2] = mod * exp(scale.z);
111+
S[0][0] = mod * scale.x;
112+
S[1][1] = mod * scale.y;
113+
S[2][2] = mod * scale.z;
109114

110115
// Use the quaternion as given; it is no longer normalized here (presumably the caller supplies a unit quaternion — confirm)
111-
glm::vec4 q = rot / glm::length(rot);
116+
glm::vec4 q = rot;// / glm::length(rot);
112117
float r = q.x;
113118
float x = q.y;
114119
float y = q.z;
@@ -172,7 +177,7 @@ __global__ void preprocessCUDA(int P, int D, int M,
172177
radii[idx] = 0;
173178
tiles_touched[idx] = 0;
174179

175-
// Perform near and frustum culling with guardband, quit if outside.
180+
// Perform near culling, quit if outside.
176181
float3 p_view;
177182
if (!in_frustum(idx, orig_points, viewmatrix, projmatrix, prefiltered, p_view))
178183
return;
@@ -196,11 +201,8 @@ __global__ void preprocessCUDA(int P, int D, int M,
196201
cov3D = cov3Ds + idx * 6;
197202
}
198203

199-
// Compute max extent of Gaussian for fine-grained fustum culling
200-
float max_dist2 = 9.f * max(cov3D[0], max(cov3D[3], cov3D[5]));
201-
202204
// Compute 2D screen-space covariance matrix
203-
float3 cov = computeCov2D(p_orig, focal_x, focal_y, cov3D, viewmatrix);
205+
float3 cov = computeCov2D(p_orig, focal_x, focal_y, tan_fovx, tan_fovy, cov3D, viewmatrix);
204206

205207
// Invert covariance (EWA algorithm)
206208
float det = (cov.x * cov.z - cov.y * cov.y);
@@ -209,14 +211,6 @@ __global__ void preprocessCUDA(int P, int D, int M,
209211
float det_inv = 1.f / det;
210212
float3 conic = { cov.z * det_inv, -cov.y * det_inv, cov.x * det_inv };
211213

212-
// Fine-grained frustum culling against ellipsoid
213-
float z_at_point = p_view.z + sqrt(max_dist2);
214-
float x_to_border = z_at_point * tan_fovx;
215-
float y_to_border = z_at_point * tan_fovy;
216-
float D2_point = p_view.x * p_view.x + p_view.y * p_view.y;
217-
if (D2_point - (x_to_border * x_to_border + y_to_border * y_to_border) > max_dist2)
218-
return;
219-
220214
// Compute extent in screen space (by finding eigenvalues of
221215
// 2D covariance matrix). Use extent to compute a bounding rectangle
222216
// of screen-space tiles that this Gaussian overlaps with. Quit if
@@ -254,7 +248,8 @@ __global__ void preprocessCUDA(int P, int D, int M,
254248
// block, each thread treats one pixel. Alternates between fetching
255249
// and rasterizing data.
256250
template <uint32_t CHANNELS>
257-
__global__ void renderCUDA(
251+
__global__ void __launch_bounds__(BLOCK_X * BLOCK_Y)
252+
renderCUDA(
258253
const uint2* __restrict__ ranges,
259254
const uint32_t* __restrict__ point_list,
260255
int W, int H,
@@ -407,8 +402,8 @@ void FORWARD::preprocess(int P, int D, int M,
407402
const float* projmatrix,
408403
const glm::vec3* cam_pos,
409404
const int W, int H,
410-
const float tan_fovx, float tan_fovy,
411405
const float focal_x, float focal_y,
406+
const float tan_fovx, float tan_fovy,
412407
int* radii,
413408
float2* means2D,
414409
float* depths,

cuda_rasterizer/forward.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ namespace FORWARD
2424
const float* projmatrix,
2525
const glm::vec3* cam_pos,
2626
const int W, int H,
27-
const float tan_fovx, float tan_fovy,
2827
const float focal_x, float focal_y,
28+
const float tan_fovx, float tan_fovy,
2929
int* radii,
3030
float2* points_xy_image,
3131
float* depths,

cuda_rasterizer/rasterizer_impl.cu

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,8 +247,8 @@ int CudaRasterizer::Rasterizer::forward(
247247
viewmatrix, projmatrix,
248248
(glm::vec3*)cam_pos,
249249
width, height,
250-
tan_fovx, tan_fovy,
251250
focal_x, focal_y,
251+
tan_fovx, tan_fovy,
252252
radii,
253253
geomState.means2D,
254254
geomState.depths,
@@ -408,6 +408,7 @@ void CudaRasterizer::Rasterizer::backward(
408408
viewmatrix,
409409
projmatrix,
410410
focal_x, focal_y,
411+
tan_fovx, tan_fovy,
411412
(glm::vec3*)campos,
412413
(float3*)dL_dmean2D,
413414
dL_dconic,

rasterize_points.cu

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,14 +47,13 @@ RasterizeGaussiansCUDA(
4747
}
4848

4949
const int P = means3D.size(0);
50-
const int N = 1; // batch size hard-coded
5150
const int H = image_height;
5251
const int W = image_width;
5352

5453
auto int_opts = means3D.options().dtype(torch::kInt32);
5554
auto float_opts = means3D.options().dtype(torch::kFloat32);
5655

57-
torch::Tensor out_color = torch::full({N, NUM_CHANNELS, H, W}, 0.0, float_opts);
56+
torch::Tensor out_color = torch::full({NUM_CHANNELS, H, W}, 0.0, float_opts);
5857
torch::Tensor radii = torch::full({P}, 0, means3D.options().dtype(torch::kInt32));
5958

6059
torch::Device device(torch::kCUDA);
@@ -126,8 +125,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
126125
const torch::Tensor& imageBuffer)
127126
{
128127
const int P = means3D.size(0);
129-
const int H = dL_dout_color.size(2);
130-
const int W = dL_dout_color.size(3);
128+
const int H = dL_dout_color.size(1);
129+
const int W = dL_dout_color.size(2);
131130

132131
int M = 0;
133132
if(sh.size(0) != 0)

0 commit comments

Comments
 (0)