-
Notifications
You must be signed in to change notification settings - Fork 4.3k
Update inductor_cpp_wrapper_tutorial.rst #3614
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 1 commit
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,4 @@ | ||
| Inductor C++ Wrapper Tutorial | ||
| TorchInductor C++ Wrapper Tutorial | ||
| ============================================================== | ||
|
|
||
| **Author**: `Chunyuan Wu <https://github.com/chunyuan-w>`_, `Bin Bao <https://github.com/desertfire>`__, `Jiong Gong <https://github.com/jgong5>`__ | ||
|
|
@@ -10,85 +10,119 @@ Prerequisites: | |
| Introduction | ||
| ------------ | ||
|
|
||
| Python, as the primary interface of PyTorch, is easy to use and efficient for development and debugging. | ||
| The Inductor's default wrapper generates Python code to invoke generated kernels and external kernels. | ||
| However, in deployments requiring high performance, Python, as an interpreted language, runs relatively slower compared to compiled languages. | ||
| In ``torch.compile``, the default backend **TorchInductor** emits Python wrapper | ||
| code that manages memory allocation and kernel invocation. This design provides | ||
| flexibility and ease of debugging, but the interpreted nature of Python | ||
| introduces runtime overhead in performance-sensitive environments. | ||
|
|
||
| We implemented an Inductor C++ wrapper by leveraging the PyTorch C++ APIs | ||
| to generate pure C++ code that combines the generated and external kernels. | ||
| This allows for the execution of each captured Dynamo graph in pure C++, | ||
| thereby reducing the Python overhead within the graph. | ||
| To address this limitation, TorchInductor includes a specialized mode that | ||
| generates **C++ wrapper code** in place of the Python wrapper, enabling faster | ||
| execution with minimal Python involvement. | ||
|
|
||
|
|
||
| Enabling the API | ||
| Enabling the C++ wrapper mode | ||
| ---------------- | ||
| This feature is still in prototype stage. To activate this feature, add the following to your code: | ||
| To enable this C++ wrapper mode for TorchInductor, add the following config to your code: | ||
|
|
||
| .. code:: python | ||
|
|
||
| import torch._inductor.config as config | ||
| config.cpp_wrapper = True | ||
|
|
||
| This will speed up your models by reducing the Python overhead of the Inductor wrapper. | ||
|
|
||
|
|
||
| Example code | ||
| ------------ | ||
|
|
||
| We will use the below frontend code as an example: | ||
| We will use the following model code as an example: | ||
|
|
||
| .. code:: python | ||
|
|
||
| import torch | ||
| import torch._inductor.config as config | ||
|
|
||
| def fn(x): | ||
| return torch.tensor(list(range(2, 40, 2)), device=x.device) + x | ||
| config.cpp_wrapper = True | ||
|
|
||
| x = torch.randn(1) | ||
| opt_fn = torch.compile()(fn) | ||
| y = opt_fn(x) | ||
| def fn(x, y): | ||
| return (x + y).sum() | ||
|
|
||
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||
| x = torch.randn(128, 128, device=device) | ||
| y = torch.randn(128, 128, device=device) | ||
|
|
||
| opt_fn = torch.compile(fn) | ||
| result = opt_fn(x, y) | ||
|
|
||
|
|
||
| **For CPU** | ||
|
|
||
| The main part of Inductor-generated code with the default Python wrapper will look like this: | ||
| The main part of TorchInductor-generated code with the default Python wrapper will look like this: | ||
|
|
||
| .. code:: python | ||
|
|
||
| def call(args): | ||
| arg0_1, = args | ||
| args.clear() | ||
| assert_size_stride(arg0_1, (1, ), (1, )) | ||
| buf0 = empty_strided((19, ), (1, ), device='cpu', dtype=torch.float32) | ||
| cpp_fused_add_lift_fresh_0(c_void_p(constant0.data_ptr()), c_void_p(arg0_1.data_ptr()), c_void_p(buf0.data_ptr())) | ||
| del arg0_1 | ||
| return (buf0, ) | ||
| class Runner: | ||
| def __init__(self, partitions): | ||
| self.partitions = partitions | ||
|
|
||
| def call(self, args): | ||
| arg0_1, arg1_1 = args | ||
| args.clear() | ||
| assert_size_stride(arg0_1, (128, 128), (128, 1)) | ||
| assert_size_stride(arg1_1, (128, 128), (128, 1)) | ||
| buf0 = empty_strided_cpu((), (), torch.float32) | ||
| cpp_fused_add_sum_0(arg0_1, arg1_1, buf0) | ||
| del arg0_1 | ||
| del arg1_1 | ||
| return (buf0, ) | ||
|
|
||
| By turning on the C++ wrapper, the generated code for the ``call`` function becomes a C++ function | ||
| ``inductor_entry_cpp`` of the C++ extension ``module``: | ||
| ``inductor_entry_impl``: | ||
|
|
||
| .. code:: python | ||
|
|
||
| std::vector<at::Tensor> inductor_entry_cpp(const std::vector<at::Tensor>& args) { | ||
| at::Tensor arg0_1 = args[0]; | ||
| at::Tensor constant0 = args[1]; | ||
| auto buf0 = at::empty_strided({19L, }, {1L, }, at::device(at::kCPU).dtype(at::kFloat)); | ||
| cpp_fused_add_lift_fresh_0((long*)(constant0.data_ptr()), (float*)(arg0_1.data_ptr()), (float*)(buf0.data_ptr())); | ||
| cpp_wrapper_src = ( | ||
| r''' | ||
| #include <torch/csrc/inductor/cpp_wrapper/cpu.h> | ||
| extern "C" void cpp_fused_add_sum_0(const float* in_ptr0, | ||
| const float* in_ptr1, | ||
| float* out_ptr0); | ||
| CACHE_TORCH_DTYPE(float32); | ||
| CACHE_TORCH_DEVICE(cpu); | ||
|
|
||
| void inductor_entry_impl( | ||
| AtenTensorHandle* | ||
| input_handles, // array of input AtenTensorHandle; handles | ||
| // are stolen; the array itself is borrowed | ||
| AtenTensorHandle* | ||
| output_handles // array for writing output AtenTensorHandle; handles | ||
| // will be stolen by the caller; the array itself is | ||
| // borrowed) | ||
| ) { | ||
| py::gil_scoped_release_simple release; | ||
|
|
||
| auto inputs = steal_from_raw_handles_to_raii_handles(input_handles, 2); | ||
| auto arg0_1 = std::move(inputs[0]); | ||
| auto arg1_1 = std::move(inputs[1]); | ||
| static constexpr int64_t *int_array_0=nullptr; | ||
| AtenTensorHandle buf0_handle; | ||
| AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(0, int_array_0, int_array_0, cached_torch_dtype_float32, cached_torch_device_type_cpu, 0, &buf0_handle)); | ||
| RAIIAtenTensorHandle buf0(buf0_handle); | ||
| cpp_fused_add_sum_0((const float*)(arg0_1.data_ptr()), (const float*)(arg1_1.data_ptr()), (float*)(buf0.data_ptr())); | ||
| arg0_1.reset(); | ||
| return {buf0}; | ||
| } | ||
|
|
||
| module = CppWrapperCodeCache.load(cpp_wrapper_src, 'inductor_entry_cpp', 'c2buojsvlqbywxe3itb43hldieh4jqulk72iswa2awalwev7hjn2', False) | ||
|
|
||
| def _wrap_func(f): | ||
| def g(args): | ||
| args_tensor = [arg if isinstance(arg, torch.Tensor) else torch.tensor(arg) for arg in args] | ||
| constants_tensor = [constant0] | ||
| args_tensor.extend(constants_tensor) | ||
|
|
||
| return f(args_tensor) | ||
| return g | ||
| call = _wrap_func(module.inductor_entry_cpp) | ||
| arg1_1.reset(); | ||
| output_handles[0] = buf0.release(); | ||
| } // inductor_entry_impl | ||
| ... | ||
| ''' | ||
| ) | ||
|
|
||
| inductor_entry = CppWrapperCodeCache.load_pybinding( | ||
| argtypes=["std::vector<AtenTensorHandle>"], | ||
| main_code=cpp_wrapper_src, | ||
| device_type="cpu", | ||
| num_outputs=1, | ||
| kernel_code=None, | ||
| ) | ||
|
|
||
| call = _wrap_func(inductor_entry) | ||
|
|
||
| **For GPU** | ||
|
|
||
|
|
@@ -113,47 +147,41 @@ Based on the same example code, the generated code for GPU will look like this: | |
| With the C++ wrapper turned on, the below equivalent C++ code will be generated: | ||
|
|
||
| .. code:: python | ||
|
|
||
| std::vector<at::Tensor> inductor_entry_cpp(const std::vector<at::Tensor>& args) { | ||
| at::Tensor arg0_1 = args[0]; | ||
| at::Tensor constant0 = args[1]; | ||
|
|
||
| at::cuda::CUDAGuard device_guard(0); | ||
| auto buf0 = at::empty_strided({19L, }, {1L, }, at::TensorOptions(c10::Device(at::kCUDA, 0)).dtype(at::kFloat)); | ||
| // Source Nodes: [add, tensor], Original ATen: [aten.add, aten.lift_fresh] | ||
| if (triton_poi_fused_add_lift_fresh_0 == nullptr) { | ||
| triton_poi_fused_add_lift_fresh_0 = loadKernel("/tmp/torchinductor_user/mm/cmm6xjgijjffxjku4akv55eyzibirvw6bti6uqmfnruujm5cvvmw.cubin", "triton_poi_fused_add_lift_fresh_0_0d1d2d3"); | ||
| } | ||
| CUdeviceptr var_0 = reinterpret_cast<CUdeviceptr>(constant0.data_ptr()); | ||
| CUdeviceptr var_1 = reinterpret_cast<CUdeviceptr>(arg0_1.data_ptr()); | ||
| CUdeviceptr var_2 = reinterpret_cast<CUdeviceptr>(buf0.data_ptr()); | ||
| auto var_3 = 19; | ||
| void* kernel_args_var_0[] = {&var_0, &var_1, &var_2, &var_3}; | ||
| cudaStream_t stream0 = at::cuda::getCurrentCUDAStream(0); | ||
| launchKernel(triton_poi_fused_add_lift_fresh_0, 1, 1, 1, 1, 0, kernel_args_var_0, stream0); | ||
| arg0_1.reset(); | ||
| return {buf0}; | ||
| } | ||
|
|
||
| module = CppWrapperCodeCache.load(cpp_wrapper_src, 'inductor_entry_cpp', 'czbpeilh4qqmbyejdgsbpdfuk2ss5jigl2qjb7xs4gearrjvuwem', True) | ||
| inductor_entry = CppWrapperCodeCache.load_pybinding( | ||
| argtypes=["std::vector<AtenTensorHandle>"], | ||
| main_code=cpp_wrapper_src, | ||
| device_type="cuda", | ||
| num_outputs=1, | ||
| kernel_code=None, | ||
| ) | ||
|
|
||
| def _wrap_func(f): | ||
| def g(args): | ||
| args_tensor = [arg if isinstance(arg, torch.Tensor) else torch.tensor(arg) for arg in args] | ||
| constants_tensor = [constant0] | ||
| args_tensor.extend(constants_tensor) | ||
| input_tensors = [arg if isinstance(arg, torch.Tensor) else torch.tensor(arg, device='cpu') for arg in args] | ||
| input_handles = torch._C._aoti.unsafe_alloc_void_ptrs_from_tensors(input_tensors) | ||
|
|
||
| args.clear() | ||
| del input_tensors | ||
|
|
||
| output_handles = f(input_handles) | ||
| output_tensors = torch._C._aoti.alloc_tensors_by_stealing_from_void_ptrs(output_handles) | ||
| return output_tensors | ||
|
|
||
| return f(args_tensor) | ||
| return g | ||
| call = _wrap_func(module.inductor_entry_cpp) | ||
|
|
||
| call = _wrap_func(inductor_entry) | ||
|
|
||
|
|
||
| Conclusion | ||
| ------------ | ||
|
|
||
| In this tutorial, we introduced a new C++ wrapper in TorchInductor to speed up your models with just two lines of code changes. | ||
| We explained the motivation of this new feature and walked through the easy-to-use API to activate this experimental feature. | ||
| Furthermore, we demonstrated the Inductor-generated code using the default Python wrapper and the new C++ wrapper on both CPU and GPU | ||
| to visually showcase the difference between these two wrappers. | ||
| This tutorial introduced the **C++ wrapper** feature in TorchInductor, designed | ||
| to improve model performance with minimal code modification. We described the | ||
| motivation for this feature, detailed the experimental API used to enable it, | ||
| and compared the generated outputs of the default Python wrapper and the new | ||
| C++ wrapper on both CPU and GPU backends to illustrate their distinctions. | ||
|
|
||
| This feature is still in prototype stage. If you have any feature requests or run into any issues, please file a bug report at `GitHub issues <https://github.com/pytorch/pytorch/issues>`_. | ||
| # For more information on torch.compile, see | ||
| # | ||
| # .. _torch.compile tutorial: https://docs.pytorch.org/tutorials/intermediate/torch_compile_tutorial.html | ||
| # .. TORCH_LOGS tutorial: https://docs.pytorch.org/tutorials/recipes/torch_logs.html | ||
|
||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.