Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion examples/models/qwen3_5_moe/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -942,6 +942,7 @@ def _export_cuda(model, config, args):
)
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.passes import MemoryPlanningPass
from executorch.exir.passes.propagate_device_pass import PropagateDeviceConfig
from torch.export import Dim, export

# Coordinate descent recompiles each kernel trying config perturbations,
Expand Down Expand Up @@ -1038,7 +1039,10 @@ def _export_cuda(model, config, args):
extract_delegate_segments=True,
do_quant_fusion_and_const_prop=True,
memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
emit_mutable_buffer_names=True,
propagate_device_config=PropagateDeviceConfig(
skip_h2d_for_method_inputs=True,
skip_d2h_for_method_outputs=True,
),
),
)

Expand Down
113 changes: 102 additions & 11 deletions examples/models/qwen3_5_moe/main.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
Expand All @@ -13,14 +13,18 @@
#include <executorch/extension/llm/runner/util.h>
#include <executorch/extension/module/module.h>
#include <executorch/extension/tensor/tensor.h>
#include <executorch/extension/tensor/tensor_ptr.h>
#include <executorch/runtime/backend/interface.h>
#include <executorch/runtime/core/portable_type/device.h>
#include <executorch/runtime/platform/assert.h>
#include <executorch/runtime/backend/options.h>
#include <executorch/runtime/platform/log.h>
#include <pytorch/tokenizers/hf_tokenizer.h>

#include <algorithm>
#include <cinttypes>
#include <fstream>
#include <numeric>
#include <string>
#include <vector>

Expand Down Expand Up @@ -51,14 +55,22 @@
using ::executorch::extension::TensorPtr;
using ::executorch::runtime::Error;
using ::executorch::runtime::EValue;
#ifdef EXECUTORCH_BUILD_CUDA
using ::executorch::extension::clone_tensor_ptr_to;
#endif

using SizesType = executorch::aten::SizesType;

// Convert a model output tensor to the next sampled token id.
//
// On the CUDA build, the model fuses the sampler in (see sampler.py /
// Qwen35MoE.forward) and returns a single sampled token id as a [B, 1]
// float tensor; we just copy that scalar back from device.
// int64 tensor that lives in CUDA device memory (skip_d2h keeps method
// outputs on-device). We copy just that 8-byte scalar back to host — this
// is the only device->host transfer per decode step, needed for EOS
// detection and streaming detokenization. The token is fed to the next
// step device->device (see the decode loop), so no host round-trip occurs
// for the model input.
//
// On non-CUDA builds (Metal / MLX / CPU), the model returns raw logits
// of shape [B, T, V] in the model dtype (typically bf16). We sample on
Expand All @@ -72,10 +84,10 @@
bool on_device = cudaPointerGetAttributes(&attrs, ptr) == cudaSuccess &&
attrs.type == cudaMemoryTypeDevice;

float val;
int64_t val;
if (on_device) {
cudaError_t err =
cudaMemcpy(&val, ptr, sizeof(float), cudaMemcpyDeviceToHost);
cudaMemcpy(&val, ptr, sizeof(int64_t), cudaMemcpyDeviceToHost);
if (err != cudaSuccess) {
ET_LOG(
Error,
Expand All @@ -84,7 +96,7 @@
return 0;
}
} else {
memcpy(&val, ptr, sizeof(float));
memcpy(&val, ptr, sizeof(int64_t));
}
return static_cast<uint64_t>(val);
#else
Expand Down Expand Up @@ -272,10 +284,20 @@
// a third input. Use a very small temperature for greedy to avoid
// division by zero while keeping the Gumbel noise negligible relative
// to logit differences.
//
// The export lowered this program with skip_h2d_for_method_inputs=True,
// so the CUDA backend requires every method input to already live in
// CUDA device memory (no host->device copy is inserted in the graph).
// We therefore stage all inputs on-device via clone_tensor_ptr_to. The
// temperature is constant, so it is cloned to the device exactly once
// and reused for prefill and every decode step.
auto cuda_device =
executorch::aten::Device(executorch::aten::DeviceType::CUDA, 0);
float temp_val =
FLAGS_temperature <= 0.0 ? 1e-6f : static_cast<float>(FLAGS_temperature);
auto temp_tensor =
from_blob(&temp_val, {1}, executorch::aten::ScalarType::Float);
auto temp_tensor = clone_tensor_ptr_to(
from_blob(&temp_val, {1}, executorch::aten::ScalarType::Float),
cuda_device);
#endif

stats.inference_start_ms = llm::time_in_ms();
Expand All @@ -298,14 +320,22 @@
pos_data[i] = i;
}
std::vector<int64_t> token_data(prompt_tokens.begin(), prompt_tokens.end());
auto tokens_tensor = from_blob(
auto tokens_cpu = from_blob(
token_data.data(),
{1, S(num_prompt_tokens)},
executorch::aten::ScalarType::Long);
auto pos_tensor = from_blob(
auto pos_cpu = from_blob(
pos_data.data(),
{S(num_prompt_tokens)},
executorch::aten::ScalarType::Long);
#ifdef EXECUTORCH_BUILD_CUDA
// Stage prefill inputs in CUDA device memory (see temperature note above).
auto tokens_tensor = clone_tensor_ptr_to(tokens_cpu, cuda_device);
auto pos_tensor = clone_tensor_ptr_to(pos_cpu, cuda_device);
#else
auto tokens_tensor = tokens_cpu;
auto pos_tensor = pos_cpu;
#endif

std::vector<EValue> prefill_inputs;
prefill_inputs.push_back(tokens_tensor);
Expand Down Expand Up @@ -348,14 +378,57 @@

std::vector<int64_t> decode_token_data = {static_cast<int64_t>(cur_token)};
std::vector<int64_t> decode_pos_data = {pos};
auto decode_tokens = from_blob(
auto decode_tokens_cpu = from_blob(
decode_token_data.data(), {1, 1}, executorch::aten::ScalarType::Long);
auto decode_pos = from_blob(
auto decode_pos_cpu = from_blob(
decode_pos_data.data(), {1}, executorch::aten::ScalarType::Long);
#ifdef EXECUTORCH_BUILD_CUDA
// Device-resident decode loop. The decode method's token input and its
// fused sampled-token output are both int64 [1,1] living in CUDA memory
// (skip_h2d on inputs, skip_d2h on outputs). We keep fixed device buffers
// (CUDA graph requires stable input addresses) and feed each step's output
// straight into the next step's token input with a device->device copy —
// no host round-trip for the model I/O. The initial clone seeds
// decode_tokens with the prefill-sampled token (one-time H2D at setup).
auto decode_tokens = clone_tensor_ptr_to(decode_tokens_cpu, cuda_device);
auto decode_pos = clone_tensor_ptr_to(decode_pos_cpu, cuda_device);

// Precompute every decode position on-device with a SINGLE H2D up front, so
// the per-step position update becomes a device->device copy (no per-step
// H2D). positions[k] = num_prompt_tokens + k.
std::vector<int64_t> all_pos_data(FLAGS_max_new_tokens);
std::iota(all_pos_data.begin(), all_pos_data.end(), pos);
auto all_pos = clone_tensor_ptr_to(
from_blob(
all_pos_data.data(),
{S(FLAGS_max_new_tokens)},
executorch::aten::ScalarType::Long),
cuda_device);
const auto* all_pos_dev =
static_cast<const int64_t*>(all_pos->const_data_ptr());
#else
auto decode_tokens = decode_tokens_cpu;
auto decode_pos = decode_pos_cpu;
#endif

for (int32_t step = 0; step < FLAGS_max_new_tokens; step++) {
#ifdef EXECUTORCH_BUILD_CUDA
// Set this step's position via device->device copy from the precomputed
// on-device array (no per-step H2D). The token input (decode_tokens)
// already holds the token to feed: the prefill-sampled token on step 0,
// and the previous step's output (copied in device->device at the end of
// the prior iteration) on every later step.
ET_CHECK_MSG(
cudaMemcpy(
decode_pos->mutable_data_ptr(),
all_pos_dev + step,
sizeof(int64_t),
cudaMemcpyDeviceToDevice) == cudaSuccess,
"Failed to set decode position device-to-device");
#else
decode_token_data[0] = static_cast<int64_t>(cur_token);
decode_pos_data[0] = pos;
#endif

std::vector<EValue> decode_inputs;
decode_inputs.push_back(EValue(decode_tokens));
Expand All @@ -370,9 +443,27 @@
return 1;
}
auto& decode_outputs = decode_result.get();
const auto& out_tensor = decode_outputs[0].toTensor();

prev_token = cur_token;
cur_token = read_token(decode_outputs[0].toTensor());
// Single per-step device->host copy: the 8-byte sampled token id, needed
// for EOS detection and streaming detokenization below.
cur_token = read_token(out_tensor);

#ifdef EXECUTORCH_BUILD_CUDA
// Feed this step's sampled token straight into the next step's token input
// on-device (device->device). This replaces the old host re-upload (H2D)
// and, together with read_token's D2H above, leaves exactly one 8-byte
// D2H and zero H2D per decode step. read_token's synchronous D2H has
// already forced the output to be ready, so the copy below is well-ordered.
ET_CHECK_MSG(
cudaMemcpy(
decode_tokens->mutable_data_ptr(),
out_tensor.const_data_ptr(),
sizeof(int64_t),
cudaMemcpyDeviceToDevice) == cudaSuccess,
"Failed to feed decode token device-to-device");
#endif

if (step == 0) {
stats.first_token_ms = llm::time_in_ms();
Expand Down
4 changes: 2 additions & 2 deletions examples/models/qwen3_5_moe/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def sample(
sampler returns the unmodified ``logits`` tensor.

Returns:
``[B, 1]`` float32 tensor of sampled token IDs, or the unmodified
``[B, 1]`` int64 tensor of sampled token IDs, or the unmodified
``logits`` tensor when ``temperature`` is ``None``.
"""
# No sampling configured — return raw logits.
Expand All @@ -57,4 +57,4 @@ def sample(
# float32 note in the docstring.
noise = torch.rand_like(logits)
gumbel = -torch.log(-torch.log(noise + 1e-20) + 1e-20)
return (logits + gumbel).argmax(dim=-1, keepdim=True).float()
return (logits + gumbel).argmax(dim=-1, keepdim=True)
Loading