diff --git a/CMakeLists.txt b/CMakeLists.txt index f8cf822e..12e54de2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -11,26 +11,56 @@ cmake_minimum_required(VERSION 3.20) -project(DiffRast LANGUAGES CUDA CXX) +project(DiffRast LANGUAGES CXX) +find_package(HIP) +if(HIP_FOUND) + set(HIP_ENABLED ON) + project(DiffRast LANGUAGES HIP) +else() + project(DiffRast LANGUAGES CUDA) +endif() + set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_EXTENSIONS OFF) -set(CMAKE_CUDA_STANDARD 17) - -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") -add_library(CudaRasterizer - cuda_rasterizer/backward.h - cuda_rasterizer/backward.cu - cuda_rasterizer/forward.h - cuda_rasterizer/forward.cu - cuda_rasterizer/auxiliary.h - cuda_rasterizer/rasterizer_impl.cu - cuda_rasterizer/rasterizer_impl.h - cuda_rasterizer/rasterizer.h -) +if(HIP_ENABLED) + set(CMAKE_HIP_STANDARD 17) + message(STATUS "Building with HIP support") +else() + set(CMAKE_CUDA_STANDARD 17) + message(STATUS "Building with CUDA support") +endif() -set_target_properties(CudaRasterizer PROPERTIES CUDA_ARCHITECTURES "70;75;86") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") -target_include_directories(CudaRasterizer PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/cuda_rasterizer) -target_include_directories(CudaRasterizer PRIVATE third_party/glm ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) +if(HIP_ENABLED) + add_library(HipRasterizer + hip_rasterizer/backward.h + hip_rasterizer/backward.hip + hip_rasterizer/forward.h + hip_rasterizer/forward.hip + hip_rasterizer/auxiliary.h + hip_rasterizer/rasterizer_impl.hip + hip_rasterizer/rasterizer_impl.h + hip_rasterizer/rasterizer.h + ) + set_target_properties(HipRasterizer PROPERTIES HIP_ARCHITECTURES "gfx908;gfx90a;gfx942") + target_include_directories(HipRasterizer PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/hip_rasterizer) + target_include_directories(HipRasterizer PRIVATE ${CMAKE_HIP_TOOLKIT_INCLUDE_DIRECTORIES}) + target_include_directories(HipRasterizer PRIVATE third_party/glm) +else() + add_library(CudaRasterizer + 
cuda_rasterizer/backward.h + cuda_rasterizer/backward.cu + cuda_rasterizer/forward.h + cuda_rasterizer/forward.cu + cuda_rasterizer/auxiliary.h + cuda_rasterizer/rasterizer_impl.cu + cuda_rasterizer/rasterizer_impl.h + cuda_rasterizer/rasterizer.h + ) + set_target_properties(CudaRasterizer PROPERTIES CUDA_ARCHITECTURES "70;75;86") + target_include_directories(CudaRasterizer PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/cuda_rasterizer) + target_include_directories(CudaRasterizer PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) +endif() \ No newline at end of file diff --git a/README.md b/README.md index 6e165b0b..e2f961ac 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,23 @@ Used as the rasterization engine for the paper "3D Gaussian Splatting for Real-Time Rendering of Radiance Fields". If you can make use of it in your own research, please be so kind to cite us. +## Install: +``` +git clone https://github.com/graphdeco-inria/diff-gaussian-rasterization.git +git submodule update --init --recursive +cd diff-gaussian-rasterization +pip install -e . + +# or +git clone --recursive https://github.com/graphdeco-inria/diff-gaussian-rasterization.git +cd diff-gaussian-rasterization +pip install -e . +``` +## Test: +``` +python tests/test_forward.py +``` +

BibTeX

@@ -16,4 +33,4 @@ Used as the rasterization engine for the paper "3D Gaussian Splatting for Real-T url = {https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/} }
-
\ No newline at end of file + diff --git a/cuda_rasterizer/auxiliary.h b/cuda_rasterizer/auxiliary.h index 4d4b9b78..8ae25b55 100644 --- a/cuda_rasterizer/auxiliary.h +++ b/cuda_rasterizer/auxiliary.h @@ -18,6 +18,13 @@ #define BLOCK_SIZE (BLOCK_X * BLOCK_Y) #define NUM_WARPS (BLOCK_SIZE/32) +#if defined(USE_ROCM) || defined(__HIP_PLATFORM_AMD__) + #include + #define DEVICE_TRAP() assert(false) +#else + __device__ __forceinline__ void DEVICE_TRAP() { __trap(); } +#endif + // Spherical harmonics coefficients __device__ const float SH_C0 = 0.28209479177387814f; __device__ const float SH_C1 = 0.4886025119029199f; @@ -156,7 +163,7 @@ __forceinline__ __device__ bool in_frustum(int idx, if (prefiltered) { printf("Point is filtered although prefiltered is set. This shouldn't happen!"); - __trap(); + DEVICE_TRAP(); } return false; } diff --git a/cuda_rasterizer/backward.cu b/cuda_rasterizer/backward.cu index 4aa41e1c..f6ba0e24 100644 --- a/cuda_rasterizer/backward.cu +++ b/cuda_rasterizer/backward.cu @@ -12,7 +12,6 @@ #include "backward.h" #include "auxiliary.h" #include -#include namespace cg = cooperative_groups; // Backward pass for conversion of spherical harmonics to RGB for @@ -584,7 +583,7 @@ void BACKWARD::preprocess( // Somewhat long, thus it is its own kernel rather than being part of // "preprocess". When done, loss gradient w.r.t. 3D means has been // modified and gradient w.r.t. 3D covariance matrix has been computed. - computeCov2DCUDA << <(P + 255) / 256, 256 >> > ( + computeCov2DCUDA <<<(P + 255) / 256, 256 >>> ( P, means3D, radii, @@ -601,7 +600,7 @@ void BACKWARD::preprocess( // Propagate gradients for remaining steps: finish 3D mean gradients, // propagate color gradients to SH (if desireD), propagate 3D covariance // matrix gradients to scale and rotation. 
- preprocessCUDA << < (P + 255) / 256, 256 >> > ( + preprocessCUDA <<< (P + 255) / 256, 256 >>> ( P, D, M, (float3*)means3D, radii, @@ -638,7 +637,7 @@ void BACKWARD::render( float* dL_dopacity, float* dL_dcolors) { - renderCUDA << > >( + renderCUDA <<>>( ranges, point_list, W, H, diff --git a/cuda_rasterizer/backward.h b/cuda_rasterizer/backward.h index 93dd2e4b..5812f8a4 100644 --- a/cuda_rasterizer/backward.h +++ b/cuda_rasterizer/backward.h @@ -14,7 +14,6 @@ #include #include "cuda_runtime.h" -#include "device_launch_parameters.h" #define GLM_FORCE_CUDA #include diff --git a/cuda_rasterizer/forward.cu b/cuda_rasterizer/forward.cu index c419a328..a973252c 100644 --- a/cuda_rasterizer/forward.cu +++ b/cuda_rasterizer/forward.cu @@ -12,7 +12,6 @@ #include "forward.h" #include "auxiliary.h" #include -#include namespace cg = cooperative_groups; // Forward method for converting the input spherical harmonics @@ -259,118 +258,94 @@ __global__ void preprocessCUDA(int P, int D, int M, // block, each thread treats one pixel. Alternates between fetching // and rasterizing data. template -__global__ void __launch_bounds__(BLOCK_X * BLOCK_Y) +__global__ void __launch_bounds__(BLOCK_SIZE) renderCUDA( - const uint2* __restrict__ ranges, - const uint32_t* __restrict__ point_list, - int W, int H, - const float2* __restrict__ points_xy_image, - const float* __restrict__ features, - const float4* __restrict__ conic_opacity, - float* __restrict__ final_T, - uint32_t* __restrict__ n_contrib, - const float* __restrict__ bg_color, - float* __restrict__ out_color) + const uint2* __restrict__ ranges, + const uint32_t* __restrict__ point_list, + int W, int H, + const float2* __restrict__ points_xy_image, + const float* __restrict__ features, + const float4* __restrict__ conic_opacity, + float* __restrict__ final_T, + uint32_t* __restrict__ n_contrib, + const float* __restrict__ bg_color, + float* __restrict__ out_color) { - // Identify current tile and associated min/max pixel range. 
- auto block = cg::this_thread_block(); - uint32_t horizontal_blocks = (W + BLOCK_X - 1) / BLOCK_X; - uint2 pix_min = { block.group_index().x * BLOCK_X, block.group_index().y * BLOCK_Y }; - uint2 pix_max = { min(pix_min.x + BLOCK_X, W), min(pix_min.y + BLOCK_Y , H) }; - uint2 pix = { pix_min.x + block.thread_index().x, pix_min.y + block.thread_index().y }; - uint32_t pix_id = W * pix.y + pix.x; - float2 pixf = { (float)pix.x, (float)pix.y }; - - // Check if this thread is associated with a valid pixel or outside. - bool inside = pix.x < W&& pix.y < H; - // Done threads can help with fetching, but don't rasterize - bool done = !inside; - - // Load start/end range of IDs to process in bit sorted list. - uint2 range = ranges[block.group_index().y * horizontal_blocks + block.group_index().x]; - const int rounds = ((range.y - range.x + BLOCK_SIZE - 1) / BLOCK_SIZE); - int toDo = range.y - range.x; - - // Allocate storage for batches of collectively fetched data. - __shared__ int collected_id[BLOCK_SIZE]; - __shared__ float2 collected_xy[BLOCK_SIZE]; - __shared__ float4 collected_conic_opacity[BLOCK_SIZE]; - - // Initialize helper variables - float T = 1.0f; - uint32_t contributor = 0; - uint32_t last_contributor = 0; - float C[CHANNELS] = { 0 }; - - // Iterate over batches until all done or range is complete - for (int i = 0; i < rounds; i++, toDo -= BLOCK_SIZE) - { - // End if entire block votes that it is done rasterizing - int num_done = __syncthreads_count(done); - if (num_done == BLOCK_SIZE) - break; - - // Collectively fetch per-Gaussian data from global to shared - int progress = i * BLOCK_SIZE + block.thread_rank(); - if (range.x + progress < range.y) - { - int coll_id = point_list[range.x + progress]; - collected_id[block.thread_rank()] = coll_id; - collected_xy[block.thread_rank()] = points_xy_image[coll_id]; - collected_conic_opacity[block.thread_rank()] = conic_opacity[coll_id]; - } - block.sync(); - - // Iterate over current batch - for (int j = 0; !done && 
j < min(BLOCK_SIZE, toDo); j++) - { - // Keep track of current position in range - contributor++; - - // Resample using conic matrix (cf. "Surface - // Splatting" by Zwicker et al., 2001) - float2 xy = collected_xy[j]; - float2 d = { xy.x - pixf.x, xy.y - pixf.y }; - float4 con_o = collected_conic_opacity[j]; - float power = -0.5f * (con_o.x * d.x * d.x + con_o.z * d.y * d.y) - con_o.y * d.x * d.y; - if (power > 0.0f) - continue; - - // Eq. (2) from 3D Gaussian splatting paper. - // Obtain alpha by multiplying with Gaussian opacity - // and its exponential falloff from mean. - // Avoid numerical instabilities (see paper appendix). - float alpha = min(0.99f, con_o.w * exp(power)); - if (alpha < 1.0f / 255.0f) - continue; - float test_T = T * (1 - alpha); - if (test_T < 0.0001f) - { - done = true; - continue; - } - - // Eq. (3) from 3D Gaussian splatting paper. - for (int ch = 0; ch < CHANNELS; ch++) - C[ch] += features[collected_id[j] * CHANNELS + ch] * alpha * T; - - T = test_T; - - // Keep track of last range entry to update this - // pixel. - last_contributor = contributor; - } - } - - // All threads that treat valid pixel write out their final - // rendering data to the frame and auxiliary buffers. 
- if (inside) - { - final_T[pix_id] = T; - n_contrib[pix_id] = last_contributor; - for (int ch = 0; ch < CHANNELS; ch++) - out_color[ch * H * W + pix_id] = C[ch] + T * bg_color[ch]; - } + int bx = blockIdx.x, by = blockIdx.y; + int tx = threadIdx.x, ty = threadIdx.y; + int tflat = ty * BLOCK_X + tx; + int pix_x = bx * BLOCK_X + tx, pix_y = by * BLOCK_Y + ty; + int pix_id = pix_y * W + pix_x; + bool inside = (pix_x < W && pix_y < H); + + int hblocks = (W + BLOCK_X - 1) / BLOCK_X; + int range_idx = by * hblocks + bx; + uint2 range = ranges[range_idx]; + int n_ids = range.y - range.x; + int rounds = (n_ids + BLOCK_SIZE - 1) / BLOCK_SIZE; + + float2 pixf = {static_cast(pix_x), static_cast(pix_y)}; + float T = 1.0f; + uint32_t contributor = 0, last_contributor = 0; + float C[CHANNELS] = {0.0f}; + bool done = !inside; + + // Cache background color per thread + float bg_val[CHANNELS]; + #pragma unroll + for (int ch = 0; ch < CHANNELS; ++ch) + bg_val[ch] = bg_color[ch]; + + // Shared memory tiling + __shared__ uint32_t s_id[BLOCK_SIZE]; + __shared__ float2 s_xy[BLOCK_SIZE]; + __shared__ float4 s_conic[BLOCK_SIZE]; + __shared__ float s_feat[BLOCK_SIZE * NUM_CHANNELS]; + + for (int chunk = 0, offset = 0; chunk < rounds; ++chunk, offset += BLOCK_SIZE) { + if (__syncthreads_count(done) == BLOCK_SIZE) break; + int gi = offset + tflat; + bool valid = (gi < n_ids); + + int gauss_id = valid ? point_list[range.x + gi] : 0; + s_id[tflat] = gauss_id; + s_xy[tflat] = valid ? points_xy_image[gauss_id] : make_float2(0.0f, 0.0f); + s_conic[tflat] = valid ? conic_opacity[gauss_id] : make_float4(0.0f, 0.0f, 0.0f, 0.0f); + #pragma unroll + for (int ch = 0; ch < CHANNELS; ++ch) + s_feat[tflat*CHANNELS + ch] = valid ? 
features[gauss_id*CHANNELS+ch] : 0.0f; + __syncthreads(); + + int end = min(BLOCK_SIZE, n_ids-offset); + #pragma unroll 8 + for (int j = 0; !done && j < end; ++j) { + contributor++; + float2 xy = s_xy[j]; + float4 co = s_conic[j]; + float dx = xy.x - pixf.x, dy = xy.y - pixf.y; + float power = -0.5f * (co.x*dx*dx + co.z*dy*dy) - co.y*dx*dy; + float alpha = co.w * __expf(power); + + if (power > 0.0f || alpha < (1.0f/255.0f)) continue; + alpha = (alpha > 0.99f) ? 0.99f : alpha; + float new_T = T * (1.0f-alpha); + if (new_T < 0.0001f) { done = true; continue; } + #pragma unroll + for (int ch = 0; ch < CHANNELS; ++ch) + C[ch] = __fmaf_rn(s_feat[j*CHANNELS+ch], alpha*T, C[ch]); + T = new_T; + last_contributor = contributor; + } + __syncthreads(); + } + if (inside) { + final_T[pix_id] = T; + n_contrib[pix_id] = last_contributor; + // Non-temporal store + #pragma unroll + for (int ch = 0; ch < CHANNELS; ++ch) + ((volatile float*)out_color)[ch*H*W + pix_id] = C[ch] + T*bg_val[ch]; + } } void FORWARD::render( @@ -386,7 +361,7 @@ void FORWARD::render( const float* bg_color, float* out_color) { - renderCUDA << > > ( + renderCUDA <<>> ( ranges, point_list, W, H, @@ -425,7 +400,7 @@ void FORWARD::preprocess(int P, int D, int M, uint32_t* tiles_touched, bool prefiltered) { - preprocessCUDA << <(P + 255) / 256, 256 >> > ( + preprocessCUDA <<<(P + 255) / 256, 256 >>> ( P, D, M, means3D, scales, diff --git a/cuda_rasterizer/forward.h b/cuda_rasterizer/forward.h index 3c11cb91..c33c531c 100644 --- a/cuda_rasterizer/forward.h +++ b/cuda_rasterizer/forward.h @@ -14,7 +14,6 @@ #include #include "cuda_runtime.h" -#include "device_launch_parameters.h" #define GLM_FORCE_CUDA #include diff --git a/cuda_rasterizer/rasterizer_impl.cu b/cuda_rasterizer/rasterizer_impl.cu index f8782ac4..60138123 100644 --- a/cuda_rasterizer/rasterizer_impl.cu +++ b/cuda_rasterizer/rasterizer_impl.cu @@ -16,14 +16,12 @@ #include #include #include "cuda_runtime.h" -#include "device_launch_parameters.h" #include 
#include #define GLM_FORCE_CUDA #include #include -#include namespace cg = cooperative_groups; #include "auxiliary.h" @@ -145,7 +143,7 @@ void CudaRasterizer::Rasterizer::markVisible( float* projmatrix, bool* present) { - checkFrustum << <(P + 255) / 256, 256 >> > ( + checkFrustum <<<(P + 255) / 256, 256 >>> ( P, means3D, viewmatrix, projmatrix, @@ -286,7 +284,7 @@ int CudaRasterizer::Rasterizer::forward( // For each instance to be rendered, produce adequate [ tile | depth ] key // and corresponding dublicated Gaussian indices to be sorted - duplicateWithKeys << <(P + 255) / 256, 256 >> > ( + duplicateWithKeys <<<(P + 255) / 256, 256 >>> ( P, geomState.means2D, geomState.depths, @@ -311,7 +309,7 @@ int CudaRasterizer::Rasterizer::forward( // Identify start and end of per-tile workloads in sorted list if (num_rendered > 0) - identifyTileRanges << <(num_rendered + 255) / 256, 256 >> > ( + identifyTileRanges <<<(num_rendered + 255) / 256, 256 >>> ( num_rendered, binningState.point_list_keys, imgState.ranges); diff --git a/cuda_rasterizer/rasterizer_impl.h b/cuda_rasterizer/rasterizer_impl.h index bc3f0ece..f658163c 100644 --- a/cuda_rasterizer/rasterizer_impl.h +++ b/cuda_rasterizer/rasterizer_impl.h @@ -15,16 +15,17 @@ #include #include "rasterizer.h" #include +#include namespace CudaRasterizer { template static void obtain(char*& chunk, T*& ptr, std::size_t count, std::size_t alignment) { - std::size_t offset = (reinterpret_cast(chunk) + alignment - 1) & ~(alignment - 1); + std::size_t offset = (reinterpret_cast(chunk) + alignment - 1) & ~(alignment - 1); ptr = reinterpret_cast(offset); chunk = reinterpret_cast(ptr + count); - } + } struct GeometryState { diff --git a/setup.py b/setup.py index bb7220d2..68627ee4 100644 --- a/setup.py +++ b/setup.py @@ -8,27 +8,52 @@ # # For inquiries contact george.drettakis@inria.fr # - from setuptools import setup -from torch.utils.cpp_extension import CUDAExtension, BuildExtension +from torch.utils.cpp_extension import 
CUDAExtension, BuildExtension, ROCM_HOME +from torch.utils.hipify import hipify_python import os -os.path.dirname(os.path.abspath(__file__)) +import torch + +# Include this line immediately after the import statements +TORCH_MAJOR = int(torch.__version__.split('.')[0]) +TORCH_MINOR = int(torch.__version__.split('.')[1]) +is_rocm = False +if TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 5): + is_rocm = True if ((torch.version.hip is not None) and (ROCM_HOME is not None)) else False + +current_dir = os.path.dirname(os.path.abspath(__file__)) +hipify_source_dir = os.path.join(current_dir, "cuda_rasterizer") +hipify_source_file = os.path.join(current_dir, "rasterize_points.cu") +hipify_dst_dir = os.path.join(current_dir, "hip_rasterizer") +hipify_in_files = [hipify_source_dir, hipify_source_file] +hipify_out_dirs = [hipify_dst_dir, current_dir] +if is_rocm: + print("[INFO] ROCm detected: running hipify...") + for i in range(len(hipify_in_files)): + hipify_python.hipify( + project_directory=hipify_in_files[i], + output_directory=hipify_out_dirs[i], + show_detailed=True, + is_pytorch_extension=True + ) + source_files = ["hip_rasterizer/rasterizer_impl.hip", "hip_rasterizer/forward.hip", "hip_rasterizer/backward.hip", "rasterize_points.hip", "ext.cpp"] +else: + print("[INFO] CUDA detected: using CUDA source directly.") + source_files = ["cuda_rasterizer/rasterizer_impl.cu", "cuda_rasterizer/forward.cu", "cuda_rasterizer/backward.cu", "rasterize_points.cu", "ext.cpp"] + +glm_include_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "third_party/glm/") setup( name="diff_gaussian_rasterization", packages=['diff_gaussian_rasterization'], ext_modules=[ CUDAExtension( name="diff_gaussian_rasterization._C", - sources=[ - "cuda_rasterizer/rasterizer_impl.cu", - "cuda_rasterizer/forward.cu", - "cuda_rasterizer/backward.cu", - "rasterize_points.cu", - "ext.cpp"], - extra_compile_args={"nvcc": ["-I" + os.path.join(os.path.dirname(os.path.abspath(__file__)), 
"third_party/glm/")]}) + sources=source_files, + extra_compile_args={"nvcc":["-I"+glm_include_dir], "cxx":["-I"+glm_include_dir]}) ], cmdclass={ 'build_ext': BuildExtension - } + }, + version='1.0.0' ) diff --git a/tests/data/forward_golden.pt b/tests/data/forward_golden.pt new file mode 100644 index 00000000..0ff8d823 Binary files /dev/null and b/tests/data/forward_golden.pt differ diff --git a/tests/test_forward.py b/tests/test_forward.py new file mode 100644 index 00000000..31546cfd --- /dev/null +++ b/tests/test_forward.py @@ -0,0 +1,71 @@ +import os +import math +import torch +import unittest +from diff_gaussian_rasterization import GaussianRasterizer, GaussianRasterizationSettings + +class TestGaussianRasterization(unittest.TestCase): + def test_gaussian_rasterization(self): + # init parameters + device = "cuda" + num_points = 1000 + + # generate random gaussian parameters + torch.manual_seed(42) + means3D = torch.randn((num_points, 3), dtype=torch.float, device=device) # [N, 3] + colors = torch.rand((num_points, 3), dtype=torch.float, device=device) # [N, 3] RGB in [0,1] + opacities = torch.sigmoid(torch.randn((num_points, 1), dtype=torch.float, device=device)) # [N, 1] + scales = torch.rand((num_points, 3), dtype=torch.float, device=device) # [N, 3] + rotations = torch.randn((num_points, 4), dtype=torch.float, device=device) + rotations = rotations / rotations.norm(dim=1, keepdim=True) # normalization + + # set rasterization parameters + image_height, image_width = 512, 512 + fov_x, fov_y = 60.0, 45.0 # Field Angle (degree) + + raster_settings = GaussianRasterizationSettings( + image_height=image_height, + image_width=image_width, + tanfovx=float(math.tan(fov_x * 0.5 * math.pi / 180)), # Convert to radians and get tangent + tanfovy=float(math.tan(fov_y * 0.5 * math.pi / 180)), + bg=torch.tensor([0, 0, 0], dtype=torch.float, device=device), # background color(black) + scale_modifier=float(1.0), # scale factor + viewmatrix=torch.eye(4, dtype=torch.float, 
device=device), # assume the view matrix is the identity matrix + projmatrix=torch.eye(4, dtype=torch.float, device=device), # assume the projection matrix is the identity matrix + sh_degree=0, # 0 for not use sh + campos=torch.zeros(3, dtype=torch.float, device=device), # camera position + prefiltered=False, + debug=False + ) + + # create rasterizer + rasterizer = GaussianRasterizer(raster_settings=raster_settings) + color, radii = rasterizer( + means3D=means3D, + means2D=None, # not used + shs=None, # as raster_settings.sh_degree = 0, no Spherical Harmonics will affect the final render result + opacities=opacities, # opacity for render if cover or not cover + colors_precomp=colors, # set the color to avoid SH calculate + scales=scales, # scale factor to control Conv3D gaussian + rotations=rotations, # rotation for 3D=>2D + cov3D_precomp=None + ) + + current_dir = os.path.dirname(os.path.abspath(__file__)) + out_file_path = os.path.join(current_dir, 'data/forward_golden.pt') + if not os.path.exists(out_file_path): + if not torch.version.hip: + print("[INFO] saving golden result with cuda.") + torch.save({'color': color.cpu(), 'radii': radii.cpu()}, out_file_path) + else: + print("[ERROR] File not exist!") + self.assertTrue(0) + else: + data = torch.load(out_file_path) + reference_color = data['color'] + reference_radii = data['radii'] + self.assertTrue(torch.allclose(color.cpu(), reference_color)) + self.assertTrue(torch.allclose(radii.cpu(), reference_radii)) + +if __name__ == "__main__": + unittest.main()