@@ -164,14 +164,13 @@ extern "C" __global__ void kernel(
164164 b = inputs[1 ];
165165 auto n = params.n ;
166166 void *args[]{&c, &a, &b, &n};
167- CUDA_ASSERT (cuLaunchKernel (
168- h->kernel (),
169- params.gridSize , 1 , 1 ,
170- params.blockSize , 1 , 1 ,
171- 0 , nullptr , args, nullptr ));
167+ h->launch (params.gridSize , 1 , 1 ,
168+ params.blockSize , 1 , 1 ,
169+ 0 , args);
172170 };
171+
173172 } else if (auto rank = broadcaster.strides .size () / (broadcaster.inputsCount + 1 ); rank == 1 ) {
174- static std::vector<dim_t > S0{0 , 1 , 1 }, S1{1 , 0 , 1 };
173+ static const std::vector<dim_t > S0{0 , 1 , 1 }, S1{1 , 0 , 1 };
175174 auto name = fmt::format (" binaryScalar{}" , postfix);
176175 auto code = fmt::format (SCALAR, dt_, op_);
177176 return [params, h = nvrtc::Handler::compile (name.c_str (), code.c_str (), " kernel" ),
@@ -185,12 +184,11 @@ extern "C" __global__ void kernel(
185184 v = inputs[1 - scalar];
186185 auto n = params.n ;
187186 void *args[]{&c, &v, &s, &n};
188- CUDA_ASSERT (cuLaunchKernel (
189- h->kernel (),
190- params.gridSize , 1 , 1 ,
191- params.blockSize , 1 , 1 ,
192- 0 , nullptr , args, nullptr ));
187+ h->launch (params.gridSize , 1 , 1 ,
188+ params.blockSize , 1 , 1 ,
189+ 0 , args);
193190 };
191+
194192 } else {
195193 auto name = fmt::format (" binary{}{}" , rank, postfix);
196194 auto code = fmt::format (BROADCAST, dt_, op_, rank);
@@ -202,11 +200,9 @@ extern "C" __global__ void kernel(
202200 b = inputs[1 ];
203201 auto n = params.n ;
204202 void *args[]{&c, &a, &b, const_cast <dim_t *>(strides.data ()), &n};
205- CUDA_ASSERT (cuLaunchKernel (
206- h->kernel (),
207- params.gridSize , 1 , 1 ,
208- params.blockSize , 1 , 1 ,
209- 0 , nullptr , args, nullptr ));
203+ h->launch (params.gridSize , 1 , 1 ,
204+ params.blockSize , 1 , 1 ,
205+ 0 , args);
210206 };
211207 }
212208 }
0 commit comments