2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@

processed_data/
gaussctrl/__pycache__
test.sh
test.py

cmd.sh

45 changes: 27 additions & 18 deletions README.md
@@ -1,6 +1,6 @@
<p align="center">

<h1 align="center"><strong>[ECCV 2024] GaussCtrl: Multi-View Consistent Text-Driven 3D Gaussian Splatting Editing</strong></h3>
<h1 align="center"><strong>🎥 [ECCV 2024] GaussCtrl: Multi-View Consistent Text-Driven 3D Gaussian Splatting Editing</strong></h3>

<p align="center">
<a href="https://jingwu2121.github.io/" class="name-link" target="_blank">Jing Wu<sup>*1</sup> </a>,
@@ -29,7 +29,10 @@

![teaser](./assets/teaser.png)

## Installation
## ✨ News
- [9.4.2024] Our original results were produced with `runwayml/stable-diffusion-v1-5`, which is no longer available. Please switch the diffusion checkpoint to another available model, e.g. `CompVis/stable-diffusion-v1-4`, via `--pipeline.diffusion_ckpt "CompVis/stable-diffusion-v1-4"`. To reproduce our original results, use `--pipeline.diffusion_ckpt "jinggogogo/gaussctrl-sd15"`.

## ⚙️ Installation

- Tested on CUDA11.8 + Ubuntu22.04 + NeRFStudio1.0.0 (NVIDIA RTX A5000 24G)

@@ -51,6 +54,10 @@ GaussCtrl is built upon NeRFStudio, follow [this link](https://docs.nerf.studio/

```bash
pip install nerfstudio==1.0.0

# Try either of these two if one is not working
pip install gsplat==0.1.2
pip install gsplat==0.1.3
```

Install Lang-SAM for mask extraction.
@@ -71,7 +78,7 @@ pip install -e .
ns-train -h
```

## Data
## 🗄️ Data

### Use Our Preprocessed Data

@@ -87,11 +94,11 @@ We thank these authors for their great work!

We recommend pre-processing your data to 512x512, and following [this page](https://docs.nerf.studio/quickstart/custom_dataset.html) to process your data.
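If your captures are not square, a minimal resizing sketch (assuming Pillow; a hypothetical helper, not the official preprocessing pipeline) looks like:

```python
from PIL import Image

def center_crop_resize(img: Image.Image, size: int = 512) -> Image.Image:
    """Center-crop to a square, then resize to size x size."""
    w, h = img.size
    s = min(w, h)
    # Crop the largest centered square so the aspect ratio is not distorted.
    img = img.crop(((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2))
    return img.resize((size, size), Image.LANCZOS)
```

Applying this before `ns-process-data` lets COLMAP estimate camera intrinsics for the resized images directly.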

## Get Started
## :arrow_forward: Get Started
![Method](./assets/method.png)

### 1. Train a 3DGS
To get started, you firstly need to train your 3DGS model. We use `splatfacto` from NeRFStudio.
To get started, you first need to train your 3DGS model. We use `splatfacto` from NeRFStudio.

```bash
ns-train splatfacto --output-dir {output/folder} --experiment-name EXPERIMENT_NAME nerfstudio-data --data {path/to/your/data}
@@ -103,42 +110,44 @@ Once you finish training the `splatfacto` model, the checkpoints will be saved t
Start editing your model by running:

```bash
ns-train gaussctrl --load-checkpoint {output/folder/.../nerfstudio_models/step-000029999.ckpt} --experiment-name EXPEIMENT_NAME --output-dir {output/folder} --pipeline.datamanager.data {path/to/your/data} --pipeline.prompt "YOUR PROMPT" --pipeline.guidance_scale 5 --pipeline.chunk_size {batch size of images during editing} --pipeline.langsam_obj 'OBJECT TO BE EDITED'
ns-train gaussctrl --load-checkpoint {output/folder/.../nerfstudio_models/step-000029999.ckpt} --experiment-name EXPERIMENT_NAME --output-dir {output/folder} --pipeline.datamanager.data {path/to/your/data} --pipeline.edit_prompt "YOUR PROMPT" --pipeline.reverse_prompt "PROMPT TO DESCRIBE THE UNEDITED SCENE" --pipeline.guidance_scale 5 --pipeline.chunk_size {batch size of images during editing} --pipeline.langsam_obj 'OBJECT TO BE EDITED'
```

Please note that Lang-SAM is optional here. If you are editing the whole environment, remove the `--pipeline.langsam_obj` argument:

```bash
ns-train gaussctrl --load-checkpoint {output/folder/.../nerfstudio_models/step-000029999.ckpt} --experiment-name EXPEIMENT_NAME --output-dir {output/folder} --pipeline.datamanager.data {path/to/your/data} --pipeline.prompt "YOUR PROMPT" --pipeline.guidance_scale 5 --pipeline.chunk_size {batch size of images during editing}
ns-train gaussctrl --load-checkpoint {output/folder/.../nerfstudio_models/step-000029999.ckpt} --experiment-name EXPERIMENT_NAME --output-dir {output/folder} --pipeline.datamanager.data {path/to/your/data} --pipeline.edit_prompt "YOUR PROMPT" --pipeline.reverse_prompt "PROMPT TO DESCRIBE THE UNEDITED SCENE" --pipeline.guidance_scale 5 --pipeline.chunk_size {batch size of images during editing}
```

Here, `--pipeline.guidance_scale` denotes the classifier free guidance used when editing the images. `--pipeline.chunk_size` denotes the number of images edited together during 1 batch. We are using **NVIDIA RTX A5000** GPU (24G), and the maximum chunk size is 3. (~22G)
Here, `--pipeline.guidance_scale` denotes the classifier-free guidance scale used when editing the images, and `--pipeline.chunk_size` denotes the number of images edited together in one batch. On an **NVIDIA RTX A5000** GPU (24 GB), the maximum chunk size is 3 (~22 GB).

Control the number of reference views using `--pipeline.ref_view_num`; by default, it is set to 4.
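Under the hood (per the `gc_pipeline.py` change in this diff), the views are split into `ref_view_num` contiguous chunks and one reference view is sampled per chunk; a standalone sketch of that selection:

```python
import random

def pick_reference_views(view_num: int, ref_view_num: int = 4, seed: int = 13789) -> list:
    # Evenly spaced chunk boundaries over [0, view_num], as in the pipeline code.
    anchors = [(view_num * i) // ref_view_num for i in range(ref_view_num)] + [view_num]
    random.seed(seed)
    # One view per chunk; random.randint is inclusive on both ends.
    return [random.randint(a, anchors[i + 1]) for i, a in enumerate(anchors[:-1])]
```

The fixed seed makes the chosen reference views reproducible across runs.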

### Small Tips
- If your find your editings are not as expected, please check the images edited by ControlNet.
- Normally, conditioning your editing on the good ControlNet editing views is very helpful, which means it is better to choose those good ControlNet editing views as reference views.
- If your edits are not as expected, check the individual images edited by ControlNet.
- Empirically, conditioning the edit on well-edited ControlNet views helps a lot, so it is better to choose those views as your reference views.

## Reproduce Our Results
## :wrench: Reproduce Our Results

Experiments in the main paper are inclued in `scripts` folder. To reproduce the results, first train the `splatfacto` model. We take the `bear` case as an example here.
Experiments in the main paper are included in the `scripts` folder. To reproduce the results, first train the `splatfacto` model. We take the `bear` case as an example here.
```bash
ns-train splatfacto --output-dir unedited_models --experiment-name bear nerfstudio-data --data data/bear
```

Then edit the 3DGS by running:
```bash
ns-train gaussctrl --load-checkpoint {unedited_models/bear/splatfacto/.../nerfstudio_models/step-000029999.ckpt} --experiment-name bear --output-dir outputs --pipeline.datamanager.data data/bear --pipeline.prompt "a photo of a polar bear in the forest" --pipeline.guidance_scale 5 --pipeline.chunk_size 3 --pipeline.langsam_obj 'bear'
ns-train gaussctrl --load-checkpoint unedited_models/bear/splatfacto/2024-07-10_170906/nerfstudio_models/step-000029999.ckpt --experiment-name bear --output-dir outputs --pipeline.datamanager.data data/bear --pipeline.edit_prompt "a photo of a polar bear in the forest" --pipeline.reverse_prompt "a photo of a bear statue in the forest" --pipeline.guidance_scale 5 --pipeline.chunk_size 3 --pipeline.langsam_obj 'bear' --viewer.quit-on-train-completion True
```

In our experiments, we randomly sampled 40 views from the entire dataset to accelerate the method, which is set by default in `gc_datamanager.py`. We split the entire set into 4 subsets and randomly sampled 10 images from each subset. Feel free to decrease/increase these numbers via `--pipeline.datamanager.subset-num` and `--pipeline.datamanager.sampled-views-every-subset`. Set `--pipeline.datamanager.load-all` to `True` if you want to edit all the images in the dataset.
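That sampling scheme amounts to something like the following (a hypothetical sketch; the actual logic lives in `gc_datamanager.py`):

```python
import random

def sample_views(n_images: int, subset_num: int = 4, views_per_subset: int = 10, seed: int = 0) -> list:
    """Split image indices into contiguous subsets and sample views from each."""
    rng = random.Random(seed)
    chunk = n_images // subset_num
    sampled = []
    for s in range(subset_num):
        lo = s * chunk
        # The last subset absorbs any remainder.
        hi = n_images if s == subset_num - 1 else (s + 1) * chunk
        sampled += rng.sample(range(lo, hi), min(views_per_subset, hi - lo))
    return sorted(sampled)
```

If the dataset has fewer than `subset_num * views_per_subset` images (or `load-all` is `True`), the datamanager simply uses every view.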

## View Results Using NeRFStudio Viewer
## :camera: View Results Using NeRFStudio Viewer
```bash
ns-viewer --load-config {outputs/.../config.yml}
```

## Render Your Results
- Render the all the dataset views.
## :movie_camera: Render Your Results
- Render all the dataset views.
```bash
ns-gaussctrl-render dataset --load-config {outputs/.../config.yml} --output_path {render/EXPERIMENT_NAME}
```
@@ -157,7 +166,7 @@ If you find this code or find the paper useful for your research, please conside
@article{gaussctrl2024,
author = {Wu, Jing and Bian, Jia-Wang and Li, Xinghui and Wang, Guangrun and Reid, Ian and Torr, Philip and Prisacariu, Victor},
title = {{GaussCtrl: Multi-View Consistent Text-Driven 3D Gaussian Splatting Editing}},
booktitle = {ECCV},
journal = {ECCV},
year = {2024},
}
```
3 changes: 1 addition & 2 deletions gaussctrl/gc_datamanager.py
@@ -86,7 +86,6 @@ def __init__(self,
self.step_every = 1
self.edited_image_dict = {}

breakpoint()
# Sample data
if len(self.train_dataset._dataparser_outputs.image_filenames) <= self.config.subset_num * self.config.sampled_views_every_subset or self.config.load_all:
self.cameras = self.train_dataset.cameras
@@ -226,7 +225,7 @@ def next_train(self, step: int) -> Tuple[Cameras, Dict]:
data["image"] = data["image"].to(self.device)

assert len(self.train_dataset.cameras.shape) == 1, "Assumes single batch dimension"
if len(self.train_dataset._dataparser_outputs.image_filenames) <= self.config.subset_num * self.config.sampled_views_every_subset:
if len(self.train_dataset._dataparser_outputs.image_filenames) <= self.config.subset_num * self.config.sampled_views_every_subset or self.config.load_all:
camera = self.cameras[image_idx : image_idx + 1].to(self.device)
else:
camera = self.cameras[image_idx : image_idx + 1][0].to(self.device)
40 changes: 25 additions & 15 deletions gaussctrl/gc_pipeline.py
@@ -55,8 +55,10 @@ class GaussCtrlPipelineConfig(VanillaPipelineConfig):
"""specifies the datamanager config"""
render_rate: int = 500
"""how many gauss steps for gauss training"""
prompt: str = ""
edit_prompt: str = ""
"""Positive Prompt"""
reverse_prompt: str = ""
"""DDIM Inversion Prompt"""
langsam_obj: str = ""
"""The object to be edited"""
guidance_scale: float = 5
@@ -65,6 +67,10 @@ class GaussCtrlPipelineConfig(VanillaPipelineConfig):
"""Inference steps"""
chunk_size: int = 5
"""Batch size for image editing, feel free to reduce to fit your GPU"""
ref_view_num: int = 4
"""Number of reference frames"""
diffusion_ckpt: str = 'CompVis/stable-diffusion-v1-4'
"""Diffusion checkpoints"""


class GaussCtrlPipeline(VanillaPipeline):
@@ -85,21 +91,23 @@ def __init__(
self.test_mode = test_mode
self.langsam = LangSAM()

self.prompt = self.config.prompt
self.edit_prompt = self.config.edit_prompt
self.reverse_prompt = self.config.reverse_prompt
self.pipe_device = 'cuda:0'
self.ddim_scheduler = DDIMScheduler.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
self.ddim_inverser = DDIMInverseScheduler.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
self.ddim_scheduler = DDIMScheduler.from_pretrained(self.config.diffusion_ckpt, subfolder="scheduler")
self.ddim_inverser = DDIMInverseScheduler.from_pretrained(self.config.diffusion_ckpt, subfolder="scheduler")

controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-depth")
self.pipe = StableDiffusionControlNetPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", controlnet=controlnet).to(self.device).to(torch.float16)
self.pipe = StableDiffusionControlNetPipeline.from_pretrained(self.config.diffusion_ckpt, controlnet=controlnet).to(self.device).to(torch.float16)
self.pipe.to(self.pipe_device)

added_prompt = 'best quality, extremely detailed'
self.positive_prompt = self.prompt + ', ' + added_prompt
self.positive_prompt = self.edit_prompt + ', ' + added_prompt
self.positive_reverse_prompt = self.reverse_prompt + ', ' + added_prompt
self.negative_prompts = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'

view_num = len(self.datamanager.cameras)
anchors = list(range(0, view_num, view_num // 4)) + [view_num]
anchors = [(view_num * i) // self.config.ref_view_num for i in range(self.config.ref_view_num)] + [view_num]

random.seed(13789)
self.ref_indices = [random.randint(anchor, anchors[idx+1]) for idx, anchor in enumerate(anchors[:-1])]
@@ -131,7 +139,7 @@ def render_reverse(self):
disparity = self.depth2disparity_torch(rendered_depth[:,:,0][None])

self.pipe.scheduler = self.ddim_inverser
latent, _ = self.pipe(prompt=self.positive_prompt, # placeholder here, since cfg=0
latent, _ = self.pipe(prompt=self.positive_reverse_prompt, # placeholder here, since cfg=0
num_inference_steps=self.num_inference_steps,
latents=init_latent,
image=disparity, return_dict=False, guidance_scale=0, output_type='latent')
@@ -140,8 +148,11 @@
if self.config.langsam_obj != "":
langsam_obj = self.config.langsam_obj
langsam_rgb_pil = Image.fromarray((rendered_rgb.cpu().numpy() * 255).astype(np.uint8))
masks, _, _, _ = self.langsam.predict(langsam_rgb_pil, langsam_obj)
mask_npy = masks.clone().cpu().numpy()[0] * 1
# The new LangSAM API expects lists; passing a bare string causes it to
# iterate over characters (e.g. "bear" -> ["b","e","a","r"]), breaking batching.
results = self.langsam.predict([langsam_rgb_pil], [langsam_obj])
result_masks = results[0]["masks"] # new API returns list[dict]
mask_npy = result_masks[0] * 1 if len(result_masks) > 0 else None

if self.config.langsam_obj != "":
self.update_datasets(cam_idx, rendered_rgb.cpu(), rendered_depth, latent, mask_npy)
@@ -150,20 +161,19 @@

def edit_images(self):
'''Edit images with ControlNet and AttnAlign'''
# if self.test_mode == "val":
# Set up ControlNet and AttnAlign
self.pipe.scheduler = self.ddim_scheduler
self.pipe.unet.set_attn_processor(
processor=utils.CrossFrameAttnProcessor(self_attn_coeff=0.6,
processor=utils.CrossViewAttnProcessor(self_attn_coeff=0.6,
unet_chunk_size=2))
self.pipe.controlnet.set_attn_processor(
processor=utils.CrossFrameAttnProcessor(self_attn_coeff=0,
processor=utils.CrossViewAttnProcessor(self_attn_coeff=0,
unet_chunk_size=2))
CONSOLE.print("Done Reset Attention Processor", style="bold blue")
CONSOLE.print("Done Resetting Attention Processor", style="bold blue")

print("#############################")
CONSOLE.print("Start Editing: ", style="bold yellow")
CONSOLE.print(f"Reference views are {[j+1 for j in self.ref_indices]}, counting from 1", style="bold yellow")
CONSOLE.print(f"Reference views are {[j+1 for j in self.ref_indices]}", style="bold yellow")
print("#############################")
ref_disparity_list = []
ref_z0_list = []
2 changes: 1 addition & 1 deletion gaussctrl/utils.py
@@ -36,7 +36,7 @@ def compute_attn(attn, query, key, value, video_length, ref_frame_index, attenti
hidden_states_ref_cross = torch.bmm(attention_probs, value_ref_cross)
return hidden_states_ref_cross

class CrossFrameAttnProcessor:
class CrossViewAttnProcessor:
def __init__(self, self_attn_coeff, unet_chunk_size=2):
self.unet_chunk_size = unet_chunk_size
self.self_attn_coeff = self_attn_coeff
3 changes: 1 addition & 2 deletions requirements.txt
@@ -1,3 +1,2 @@
transformers>=4.38.0
diffusers==0.26.0
transformers==4.34.1
pip install huggingface-hub==0.20.3
6 changes: 3 additions & 3 deletions scripts/bear.sh
@@ -1,7 +1,7 @@
ns-train splatfacto --output-dir unedited_models --experiment-name bear --viewer.quit-on-train-completion True nerfstudio-data --data data/bear

ns-train gaussctrl --load-checkpoint unedited_models/bear/splatfacto/2024-07-10_170906/nerfstudio_models/step-000029999.ckpt --experiment-name bear --output-dir outputs --pipeline.datamanager.data data/bear --pipeline.prompt "a photo of a polar bear in the forest" --pipeline.guidance_scale 5 --pipeline.chunk_size 3 --pipeline.langsam_obj 'bear' --viewer.quit-on-train-completion True
ns-train gaussctrl --load-checkpoint unedited_models/bear/splatfacto/2024-07-10_170906/nerfstudio_models/step-000029999.ckpt --experiment-name bear --output-dir outputs --pipeline.datamanager.data data/bear --pipeline.edit_prompt "a photo of a polar bear in the forest" --pipeline.reverse_prompt "a photo of a bear statue in the forest" --pipeline.guidance_scale 5 --pipeline.chunk_size 3 --pipeline.langsam_obj 'bear' --viewer.quit-on-train-completion True

ns-train gaussctrl --load-checkpoint unedited_models/bear/splatfacto/2024-07-10_170906/nerfstudio_models/step-000029999.ckpt --experiment-name bear --output-dir outputs --pipeline.datamanager.data data/bear --pipeline.prompt "a photo of a grizzly bear in the forest" --pipeline.guidance_scale 5 --pipeline.chunk_size 3 --pipeline.langsam_obj 'bear' --viewer.quit-on-train-completion True
ns-train gaussctrl --load-checkpoint unedited_models/bear/splatfacto/2024-07-10_170906/nerfstudio_models/step-000029999.ckpt --experiment-name bear --output-dir outputs --pipeline.datamanager.data data/bear --pipeline.edit_prompt "a photo of a grizzly bear in the forest" --pipeline.reverse_prompt "a photo of a bear statue in the forest" --pipeline.guidance_scale 5 --pipeline.chunk_size 3 --pipeline.langsam_obj 'bear' --viewer.quit-on-train-completion True

ns-train gaussctrl --load-checkpoint unedited_models/bear/splatfacto/2024-07-10_170906/nerfstudio_models/step-000029999.ckpt --experiment-name bear --output-dir outputs --pipeline.datamanager.data data/bear --pipeline.prompt "a photo of a golden bear statue in the forest" --pipeline.guidance_scale 5 --pipeline.chunk_size 3 --pipeline.langsam_obj 'bear' --viewer.quit-on-train-completion True
ns-train gaussctrl --load-checkpoint unedited_models/bear/splatfacto/2024-07-10_170906/nerfstudio_models/step-000029999.ckpt --experiment-name bear --output-dir outputs --pipeline.datamanager.data data/bear --pipeline.edit_prompt "a photo of a golden bear statue in the forest" --pipeline.reverse_prompt "a photo of a bear statue in the forest" --pipeline.guidance_scale 5 --pipeline.chunk_size 3 --pipeline.langsam_obj 'bear' --viewer.quit-on-train-completion True