diff --git a/examples/models/qwen3_5_moe/export.py b/examples/models/qwen3_5_moe/export.py index f0f76fb5c09..aa8be93720c 100644 --- a/examples/models/qwen3_5_moe/export.py +++ b/examples/models/qwen3_5_moe/export.py @@ -942,6 +942,7 @@ def _export_cuda(model, config, args): ) from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.passes import MemoryPlanningPass + from executorch.exir.passes.propagate_device_pass import PropagateDeviceConfig from torch.export import Dim, export # Coordinate descent recompiles each kernel trying config perturbations, @@ -1038,7 +1039,10 @@ def _export_cuda(model, config, args): extract_delegate_segments=True, do_quant_fusion_and_const_prop=True, memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False), - emit_mutable_buffer_names=True, + propagate_device_config=PropagateDeviceConfig( + skip_h2d_for_method_inputs=True, + skip_d2h_for_method_outputs=True, + ), ), ) diff --git a/examples/models/qwen3_5_moe/main.cpp b/examples/models/qwen3_5_moe/main.cpp index 2cd79f0eabe..02aa803dc45 100644 --- a/examples/models/qwen3_5_moe/main.cpp +++ b/examples/models/qwen3_5_moe/main.cpp @@ -13,14 +13,18 @@ #include #include #include +#include #include #include +#include +#include #include #include #include #include #include +#include #include #include @@ -51,6 +55,9 @@ using ::executorch::extension::Module; using ::executorch::extension::TensorPtr; using ::executorch::runtime::Error; using ::executorch::runtime::EValue; +#ifdef EXECUTORCH_BUILD_CUDA +using ::executorch::extension::clone_tensor_ptr_to; +#endif using SizesType = executorch::aten::SizesType; @@ -58,7 +65,12 @@ using SizesType = executorch::aten::SizesType; // // On the CUDA build, the model fuses the sampler in (see sampler.py / // Qwen35MoE.forward) and returns a single sampled token id as a [B, 1] -// float tensor; we just copy that scalar back from device. +// int64 tensor that lives in CUDA device memory (skip_d2h keeps method +// outputs on-device). We copy just that 8-byte scalar back to host — this +// is the only device->host transfer per decode step, needed for EOS +// detection and streaming detokenization. The token is fed to the next +// step device->device (see the decode loop), so no host round-trip occurs +// for the model input. // // On non-CUDA builds (Metal / MLX / CPU), the model returns raw logits // of shape [B, T, V] in the model dtype (typically bf16). We sample on @@ -72,10 +84,10 @@ static uint64_t read_token(const executorch::aten::Tensor& output) { bool on_device = cudaPointerGetAttributes(&attrs, ptr) == cudaSuccess && attrs.type == cudaMemoryTypeDevice; - float val; + int64_t val; if (on_device) { cudaError_t err = - cudaMemcpy(&val, ptr, sizeof(float), cudaMemcpyDeviceToHost); + cudaMemcpy(&val, ptr, sizeof(int64_t), cudaMemcpyDeviceToHost); if (err != cudaSuccess) { ET_LOG( Error, @@ -84,7 +96,7 @@ static uint64_t read_token(const executorch::aten::Tensor& output) { return 0; } } else { - memcpy(&val, ptr, sizeof(float)); + memcpy(&val, ptr, sizeof(int64_t)); } return static_cast(val); #else @@ -272,10 +284,20 @@ int main(int argc, char** argv) { // a third input. Use a very small temperature for greedy to avoid // division by zero while keeping the Gumbel noise negligible relative // to logit differences. + // + // The export lowered this program with skip_h2d_for_method_inputs=True, + // so the CUDA backend requires every method input to already live in + // CUDA device memory (no host->device copy is inserted in the graph). + // We therefore stage all inputs on-device via clone_tensor_ptr_to. The + // temperature is constant, so it is cloned to the device exactly once + // and reused for prefill and every decode step. + auto cuda_device = + executorch::aten::Device(executorch::aten::DeviceType::CUDA, 0); float temp_val = FLAGS_temperature <= 0.0 ? 1e-6f : static_cast(FLAGS_temperature); - auto temp_tensor = - from_blob(&temp_val, {1}, executorch::aten::ScalarType::Float); + auto temp_tensor = clone_tensor_ptr_to( + from_blob(&temp_val, {1}, executorch::aten::ScalarType::Float), + cuda_device); #endif stats.inference_start_ms = llm::time_in_ms(); @@ -298,14 +320,22 @@ int main(int argc, char** argv) { pos_data[i] = i; } std::vector token_data(prompt_tokens.begin(), prompt_tokens.end()); - auto tokens_tensor = from_blob( + auto tokens_cpu = from_blob( token_data.data(), {1, S(num_prompt_tokens)}, executorch::aten::ScalarType::Long); - auto pos_tensor = from_blob( + auto pos_cpu = from_blob( pos_data.data(), {S(num_prompt_tokens)}, executorch::aten::ScalarType::Long); +#ifdef EXECUTORCH_BUILD_CUDA + // Stage prefill inputs in CUDA device memory (see temperature note above). + auto tokens_tensor = clone_tensor_ptr_to(tokens_cpu, cuda_device); + auto pos_tensor = clone_tensor_ptr_to(pos_cpu, cuda_device); +#else + auto tokens_tensor = tokens_cpu; + auto pos_tensor = pos_cpu; +#endif std::vector prefill_inputs; prefill_inputs.push_back(tokens_tensor); @@ -348,14 +378,57 @@ int main(int argc, char** argv) { std::vector decode_token_data = {static_cast(cur_token)}; std::vector decode_pos_data = {pos}; - auto decode_tokens = from_blob( + auto decode_tokens_cpu = from_blob( decode_token_data.data(), {1, 1}, executorch::aten::ScalarType::Long); - auto decode_pos = from_blob( + auto decode_pos_cpu = from_blob( decode_pos_data.data(), {1}, executorch::aten::ScalarType::Long); +#ifdef EXECUTORCH_BUILD_CUDA + // Device-resident decode loop. The decode method's token input and its + // fused sampled-token output are both int64 [1,1] living in CUDA memory + // (skip_h2d on inputs, skip_d2h on outputs). We keep fixed device buffers + // (CUDA graph requires stable input addresses) and feed each step's output + // straight into the next step's token input with a device->device copy — + // no host round-trip for the model I/O. The initial clone seeds + // decode_tokens with the prefill-sampled token (one-time H2D at setup). + auto decode_tokens = clone_tensor_ptr_to(decode_tokens_cpu, cuda_device); + auto decode_pos = clone_tensor_ptr_to(decode_pos_cpu, cuda_device); + + // Precompute every decode position on-device with a SINGLE H2D up front, so + // the per-step position update becomes a device->device copy (no per-step + // H2D). positions[k] = num_prompt_tokens + k. + std::vector all_pos_data(FLAGS_max_new_tokens); + std::iota(all_pos_data.begin(), all_pos_data.end(), pos); + auto all_pos = clone_tensor_ptr_to( + from_blob( + all_pos_data.data(), + {S(FLAGS_max_new_tokens)}, + executorch::aten::ScalarType::Long), + cuda_device); + const auto* all_pos_dev = + static_cast(all_pos->const_data_ptr()); +#else + auto decode_tokens = decode_tokens_cpu; + auto decode_pos = decode_pos_cpu; +#endif for (int32_t step = 0; step < FLAGS_max_new_tokens; step++) { +#ifdef EXECUTORCH_BUILD_CUDA + // Set this step's position via device->device copy from the precomputed + // on-device array (no per-step H2D). The token input (decode_tokens) + // already holds the token to feed: the prefill-sampled token on step 0, + // and the previous step's output (copied in device->device at the end of + // the prior iteration) on every later step. + ET_CHECK_MSG( + cudaMemcpy( + decode_pos->mutable_data_ptr(), + all_pos_dev + step, + sizeof(int64_t), + cudaMemcpyDeviceToDevice) == cudaSuccess, + "Failed to set decode position device-to-device"); +#else decode_token_data[0] = static_cast(cur_token); decode_pos_data[0] = pos; +#endif std::vector decode_inputs; decode_inputs.push_back(EValue(decode_tokens)); @@ -370,9 +443,27 @@ int main(int argc, char** argv) { return 1; } auto& decode_outputs = decode_result.get(); + const auto& out_tensor = decode_outputs[0].toTensor(); prev_token = cur_token; - cur_token = read_token(decode_outputs[0].toTensor()); + // Single per-step device->host copy: the 8-byte sampled token id, needed + // for EOS detection and streaming detokenization below. + cur_token = read_token(out_tensor); + +#ifdef EXECUTORCH_BUILD_CUDA + // Feed this step's sampled token straight into the next step's token input + // on-device (device->device). This replaces the old host re-upload (H2D) + // and, together with read_token's D2H above, leaves exactly one 8-byte + // D2H and zero H2D per decode step. read_token's synchronous D2H has + // already forced the output to be ready, so the copy below is well-ordered. + ET_CHECK_MSG( + cudaMemcpy( + decode_tokens->mutable_data_ptr(), + out_tensor.const_data_ptr(), + sizeof(int64_t), + cudaMemcpyDeviceToDevice) == cudaSuccess, + "Failed to feed decode token device-to-device"); +#endif if (step == 0) { stats.first_token_ms = llm::time_in_ms(); diff --git a/examples/models/qwen3_5_moe/sampler.py b/examples/models/qwen3_5_moe/sampler.py index dd1cec5bcd6..3f8a60a33c1 100644 --- a/examples/models/qwen3_5_moe/sampler.py +++ b/examples/models/qwen3_5_moe/sampler.py @@ -42,7 +42,7 @@ def sample( sampler returns the unmodified ``logits`` tensor. Returns: - ``[B, 1]`` float32 tensor of sampled token IDs, or the unmodified + ``[B, 1]`` int64 tensor of sampled token IDs, or the unmodified ``logits`` tensor when ``temperature`` is ``None``. """ # No sampling configured — return raw logits. @@ -57,4 +57,4 @@ def sample( # float32 note in the docstring. noise = torch.rand_like(logits) gumbel = -torch.log(-torch.log(noise + 1e-20) + 1e-20) - return (logits + gumbel).argmax(dim=-1, keepdim=True).float() + return (logits + gumbel).argmax(dim=-1, keepdim=True)