Commit aa2a7f7

Merge remote-tracking branch 'upstream/main' into akaratza_ci
2 parents: 2a4c027 + 711241c

83 files changed: 1076 additions & 599 deletions


.buildkite/scripts/hardware_ci/run-cpu-test-ppc64le.sh

Lines changed: 6 additions & 4 deletions
@@ -25,20 +25,22 @@ function cpu_tests() {
 
   # offline inference
   podman exec -it "$container_id" bash -c "
+    export TORCH_COMPILE_DISABLE=1
     set -xve
     python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m" >> $HOME/test_basic.log
 
   # Run basic model test
   podman exec -it "$container_id" bash -c "
+    export TORCH_COMPILE_DISABLE=1
     set -evx
     pip install pytest pytest-asyncio einops peft Pillow soundfile transformers_stream_generator matplotlib
-    pip install sentence-transformers datamodel_code_generator
+    pip install sentence-transformers datamodel_code_generator tblib
 
     # Note: disable Bart until supports V1
     # pytest -v -s tests/models/language/generation/test_bart.py -m cpu_model
-    pytest -v -s tests/models/language/generation/test_common.py::test_models[False-5-32-openai-community/gpt2]
-    pytest -v -s tests/models/language/generation/test_common.py::test_models[False-5-32-facebook/opt-125m]
-    pytest -v -s tests/models/language/generation/test_common.py::test_models[False-5-32-google/gemma-1.1-2b-it]
+    pytest -v -s tests/models/language/generation/test_common.py::test_models[False-False-5-32-openai-community/gpt2]
+    pytest -v -s tests/models/language/generation/test_common.py::test_models[False-False-5-32-facebook/opt-125m]
+    pytest -v -s tests/models/language/generation/test_common.py::test_models[False-False-5-32-google/gemma-1.1-2b-it]
     pytest -v -s tests/models/language/pooling/test_classification.py::test_models[float-jason9693/Qwen2.5-1.5B-apeach]
     # TODO: Below test case tests/models/language/pooling/test_embedding.py::test_models[True-ssmits/Qwen2-7B-Instruct-embed-base] fails on ppc64le. Disabling it for time being.
     # pytest -v -s tests/models/language/pooling/test_embedding.py -m cpu_model" >> $HOME/test_rest.log

csrc/cpu/utils.cpp

Lines changed: 45 additions & 22 deletions
@@ -45,31 +45,54 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) {
   // Memory node binding
   if (numa_available() != -1) {
     int mem_node_id = numa_node_of_cpu(omp_cpu_ids.front());
-    // Verify all CPUs are on the same NUMA node
-    for (size_t i = 1; i < omp_cpu_ids.size(); ++i) {
-      int node_id = numa_node_of_cpu(omp_cpu_ids[i]);
-      TORCH_CHECK(node_id == mem_node_id, "CPU ", omp_cpu_ids[i],
-                  " is on NUMA node ", node_id, ", but CPU ",
-                  omp_cpu_ids.front(), " is on NUMA node ", mem_node_id,
-                  ". All CPUs should be on the same NUMA node for optimal "
-                  "performance. Memory will be bound to NUMA node ",
-                  mem_node_id, ".");
+    std::set<int> node_ids;
+    for (const auto& cpu_id : omp_cpu_ids) {
+      int node_id = numa_node_of_cpu(cpu_id);
+      if (node_id != -1) {
+        node_ids.insert(node_id);
+      }
+      TORCH_WARN(node_id == mem_node_id, "CPU ", cpu_id, " is on NUMA node ",
+                 node_id, ", but CPU ", omp_cpu_ids.front(),
+                 " is on NUMA node ", mem_node_id,
+                 ". All CPUs should be on the same NUMA node for optimal "
+                 "performance. Memory will be bound to NUMA node ",
+                 mem_node_id, ".");
     }
-    bitmask* mask = numa_parse_nodestring(std::to_string(mem_node_id).c_str());
-    bitmask* src_mask = numa_get_membind();
-
-    int pid = getpid();
+    // Concatenate all node_ids into a single comma-separated string
+    if (!node_ids.empty()) {
+      std::string node_ids_str;
+      for (const int node_id : node_ids) {
+        if (!node_ids_str.empty()) {
+          node_ids_str += ",";
+        }
+        node_ids_str += std::to_string(node_id);
+      }
 
-    // move all existing pages to the specified numa node.
-    *(src_mask->maskp) = *(src_mask->maskp) ^ *(mask->maskp);
-    int page_num = numa_migrate_pages(pid, src_mask, mask);
-    if (page_num == -1) {
-      TORCH_WARN("numa_migrate_pages failed. errno: " + std::to_string(errno));
+      bitmask* mask = numa_parse_nodestring(node_ids_str.c_str());
+      bitmask* src_mask = numa_get_membind();
+
+      int pid = getpid();
+
+      if (mask && src_mask) {
+        // move all existing pages to the specified numa node.
+        *(src_mask->maskp) = *(src_mask->maskp) ^ *(mask->maskp);
+        int page_num = numa_migrate_pages(pid, src_mask, mask);
+        if (page_num == -1) {
+          TORCH_WARN("numa_migrate_pages failed. errno: " +
+                     std::to_string(errno));
+        }
+
+        // restrict memory allocation node.
+        numa_set_membind(mask);
+        numa_set_strict(1);
+
+        numa_free_nodemask(mask);
+        numa_free_nodemask(src_mask);
+      } else {
+        TORCH_WARN("numa_parse_nodestring or numa_get_membind failed. errno: " +
+                   std::to_string(errno));
+      }
     }
-
-    // restrict memory allocation node.
-    numa_set_membind(mask);
-    numa_set_strict(1);
   }
 
   // OMP threads binding
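The rewrite above drops the hard TORCH_CHECK that every worker CPU share one NUMA node: it now only warns, collects the distinct nodes backing the CPU set, and binds memory to all of them via a comma-separated nodestring. Below is a minimal standalone sketch of that binding pattern, assuming libnuma is installed (build with -lnuma); the CPU list and the simplified error handling are illustrative, not taken from the commit.

// Sketch of the nodestring-based memory binding used above: find the NUMA
// node of each CPU in a (hypothetical) CPU list, join the node ids into a
// comma-separated string, and bind future allocations to those nodes.
// Build with: g++ -std=c++17 membind_sketch.cpp -lnuma
#include <numa.h>

#include <cstdio>
#include <set>
#include <string>
#include <vector>

int main() {
  if (numa_available() == -1) {
    std::fprintf(stderr, "NUMA is not available on this system\n");
    return 1;
  }

  // Hypothetical worker CPUs; in vLLM this set comes from the cpu_ids string
  // passed to init_cpu_threads_env().
  std::vector<int> omp_cpu_ids = {0, 1, 2, 3};

  // Collect the distinct NUMA nodes backing those CPUs.
  std::set<int> node_ids;
  for (int cpu_id : omp_cpu_ids) {
    int node_id = numa_node_of_cpu(cpu_id);
    if (node_id != -1) {
      node_ids.insert(node_id);
    }
  }
  if (node_ids.empty()) {
    return 1;
  }

  // Join the node ids into the "0,1,..." form expected by
  // numa_parse_nodestring().
  std::string node_ids_str;
  for (int node_id : node_ids) {
    if (!node_ids_str.empty()) node_ids_str += ",";
    node_ids_str += std::to_string(node_id);
  }

  bitmask* mask = numa_parse_nodestring(node_ids_str.c_str());
  if (mask == nullptr) {
    std::fprintf(stderr, "numa_parse_nodestring failed for '%s'\n",
                 node_ids_str.c_str());
    return 1;
  }

  // Restrict future allocations to the selected nodes; strict mode fails
  // instead of silently falling back to other nodes.
  numa_set_membind(mask);
  numa_set_strict(1);
  numa_free_nodemask(mask);

  std::printf("memory bound to NUMA node(s) %s\n", node_ids_str.c_str());
  return 0;
}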

csrc/dispatch_utils.h

Lines changed: 21 additions & 0 deletions
@@ -117,3 +117,24 @@
       break;                                       \
     }                                              \
   }
+
+#define VLLM_DISPATCH_RANK234(NUM_DIMS, ...)       \
+  switch (NUM_DIMS) {                              \
+    case 2: {                                      \
+      constexpr int tensor_rank = 2;               \
+      __VA_ARGS__();                               \
+      break;                                       \
+    }                                              \
+    case 3: {                                      \
+      constexpr int tensor_rank = 3;               \
+      __VA_ARGS__();                               \
+      break;                                       \
+    }                                              \
+    case 4: {                                      \
+      constexpr int tensor_rank = 4;               \
+      __VA_ARGS__();                               \
+      break;                                       \
+    }                                              \
+    default:                                       \
+      TORCH_CHECK(false, "Expects rank 2, 3 or 4 tensors but got ", NUM_DIMS); \
+  }
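The new VLLM_DISPATCH_RANK234 macro uses the common switch-to-constexpr dispatch idiom: a runtime dimension count selects a case that defines tensor_rank as a compile-time constant, so the body it wraps can instantiate rank-specialized code. The sketch below shows the same idiom in isolation, without the macro or TORCH_CHECK; process_tensor and dispatch_by_rank are hypothetical names used only for illustration.

// Sketch of the runtime-to-compile-time dispatch pattern behind
// VLLM_DISPATCH_RANK234: a runtime rank is switched into a constexpr that
// parameterizes a template instantiation.
#include <cstdio>
#include <stdexcept>

template <int Rank>
void process_tensor() {
  // In the real kernel, Rank selects the indexing scheme at compile time.
  std::printf("specialized for rank %d\n", Rank);
}

void dispatch_by_rank(int num_dims) {
  switch (num_dims) {
    case 2: {
      constexpr int tensor_rank = 2;  // compile-time constant for this case
      process_tensor<tensor_rank>();
      break;
    }
    case 3: {
      constexpr int tensor_rank = 3;
      process_tensor<tensor_rank>();
      break;
    }
    case 4: {
      constexpr int tensor_rank = 4;
      process_tensor<tensor_rank>();
      break;
    }
    default:
      throw std::invalid_argument("Expects rank 2, 3 or 4 tensors");
  }
}

int main() {
  dispatch_by_rank(3);  // prints "specialized for rank 3"
  return 0;
}

Because tensor_rank is constexpr inside each case, if constexpr branches keyed on it (as in the rms_norm_kernel change below) compile away, leaving each instantiation with only the indexing code for its own rank.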

csrc/layernorm_kernels.cu

Lines changed: 54 additions & 26 deletions
@@ -10,16 +10,38 @@
 namespace vllm {
 
 // TODO(woosuk): Further optimize this kernel.
-template <typename scalar_t, int VEC_SIZE>
+template <typename scalar_t, int VEC_SIZE, int NUM_DIMS>
 __global__ void rms_norm_kernel(
-    scalar_t* __restrict__ out,          // [..., hidden_size]
-    const scalar_t* __restrict__ input,  // [..., hidden_size]
-    const int64_t input_stride,
+    scalar_t* __restrict__ out,          // [..., hidden_size]
+    const scalar_t* __restrict__ input,  // [..., hidden_size]
+    const int64_t input_stride_d2,       // input.stride(-2)
+    const int64_t input_stride_d3,       // input.stride(-3)
+    const int64_t input_stride_d4,       // input.stride(-4)
+    const int64_t input_shape_d2,        // input.size(-2)
+    const int64_t input_shape_d3,        // input.size(-3)
     const scalar_t* __restrict__ weight,  // [hidden_size]
     const float epsilon, const int num_tokens, const int hidden_size) {
   __shared__ float s_variance;
   float variance = 0.0f;
-  const scalar_t* input_row = input + blockIdx.x * input_stride;
+  const scalar_t* input_row;
+  if constexpr (NUM_DIMS == 2) {
+    // 2D for layernorm normal case [batch_size, hidden]
+    input_row = input + blockIdx.x * input_stride_d2;
+  } else if constexpr (NUM_DIMS == 3) {
+    // 3D for q/k norm [batch_size, num_heads, head_size]
+    int batch_idx = blockIdx.x / input_shape_d2;
+    int head_idx = blockIdx.x % input_shape_d2;
+    input_row =
+        input + batch_idx * input_stride_d3 + head_idx * input_stride_d2;
+  } else if constexpr (NUM_DIMS == 4) {
+    // 4D for transformers model_impl qk norm [batch, seq, head, head_dim]
+    int batch_idx = blockIdx.x / (input_shape_d3 * input_shape_d2);
+    int remaining = blockIdx.x % (input_shape_d3 * input_shape_d2);
+    int seq_idx = remaining / input_shape_d2;
+    int head_idx = remaining % input_shape_d2;
+    input_row = input + batch_idx * input_stride_d4 +
+                seq_idx * input_stride_d3 + head_idx * input_stride_d2;
+  }
 
   auto vec_op = [&variance](const vec_n_t<scalar_t, VEC_SIZE>& vec) {
 #pragma unroll
@@ -164,38 +186,44 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
               torch::Tensor& weight,  // [hidden_size]
               double epsilon) {
   TORCH_CHECK(out.is_contiguous());
+  if (input.stride(-1) != 1) {
+    input = input.contiguous();
+  }
   TORCH_CHECK(input.stride(-1) == 1);
   TORCH_CHECK(weight.is_contiguous());
 
   int hidden_size = input.size(-1);
 
-  // We cannot just use `input.stride(-2)` if the tensor is not row-major.
-  // Instead, we use a 2d view to get the second-innermost stride.
-  // That way the dimensions (except the last one) can be arbitrarily permuted.
-  torch::Tensor input_view = input.view({-1, hidden_size});
-
-  int num_tokens = input_view.numel() / hidden_size;
-  int64_t input_stride = input_view.stride(-2);
+  int num_tokens = input.numel() / hidden_size;
+  int num_dims = input.dim();
+  int64_t input_stride_d2 = input.stride(-2);
+  int64_t input_stride_d3 = (num_dims >= 3) ? input.stride(-3) : 0;
+  int64_t input_stride_d4 = (num_dims >= 4) ? input.stride(-4) : 0;
+  int64_t input_shape_d2 = (num_dims >= 3) ? input.size(-2) : 0;
+  int64_t input_shape_d3 = (num_dims >= 4) ? input.size(-3) : 0;
 
   // For large num_tokens, use smaller blocks to increase SM concurrency.
   const int max_block_size = (num_tokens < 256) ? 1024 : 256;
   dim3 grid(num_tokens);
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(input_view));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  VLLM_DISPATCH_FLOATING_TYPES(
-      input_view.scalar_type(), "rms_norm_kernel", [&] {
-        const int calculated_vec_size =
-            std::gcd(16 / sizeof(scalar_t), hidden_size);
-        const int block_size =
-            std::min(hidden_size / calculated_vec_size, max_block_size);
-        dim3 block(block_size);
-        VLLM_DISPATCH_VEC_SIZE(calculated_vec_size, [&] {
-          vllm::rms_norm_kernel<scalar_t, vec_size><<<grid, block, 0, stream>>>(
-              out.data_ptr<scalar_t>(), input_view.data_ptr<scalar_t>(),
-              input_stride, weight.data_ptr<scalar_t>(), epsilon, num_tokens,
-              hidden_size);
-        });
+  VLLM_DISPATCH_RANK234(num_dims, [&] {
+    VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
+      const int calculated_vec_size =
+          std::gcd(16 / sizeof(scalar_t), hidden_size);
+      const int block_size =
+          std::min(hidden_size / calculated_vec_size, max_block_size);
+      dim3 block(block_size);
+      VLLM_DISPATCH_VEC_SIZE(calculated_vec_size, [&] {
+        vllm::rms_norm_kernel<scalar_t, vec_size, tensor_rank>
+            <<<grid, block, 0, stream>>>(
+                out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
+                input_stride_d2, input_stride_d3, input_stride_d4,
+                input_shape_d2, input_shape_d3, weight.data_ptr<scalar_t>(),
+                epsilon, num_tokens, hidden_size);
       });
+    });
+  });
 }
 
 #define LAUNCH_FUSED_ADD_RMS_NORM(width) \
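With the extra stride and shape arguments, rms_norm_kernel now derives each row's base pointer from blockIdx.x: rank 3 splits the flat index into (batch, head) and rank 4 into (batch, seq, head). The host-side sketch below checks that arithmetic against a contiguous row-major layout; the shapes, strides, and names are made-up example values, not part of the commit.

// Host-side check of the row-index decomposition used by the reworked
// kernel for the rank-4 case: blockIdx.x -> (batch, seq, head), then a
// strided offset. For a contiguous tensor the offset must equal the flat
// row offset. Build with: g++ -std=c++17 rms_norm_indexing_sketch.cpp
#include <cassert>
#include <cstdint>

int main() {
  // Hypothetical rank-4 layout [batch, seq, head, head_dim], row-major.
  const int64_t shape_d3 = 5;   // input.size(-3): sequence length
  const int64_t shape_d2 = 8;   // input.size(-2): number of heads
  const int64_t head_dim = 64;  // input.size(-1)

  const int64_t stride_d2 = head_dim;              // input.stride(-2)
  const int64_t stride_d3 = shape_d2 * stride_d2;  // input.stride(-3)
  const int64_t stride_d4 = shape_d3 * stride_d3;  // input.stride(-4)

  // One CUDA block per row; block_idx plays the role of blockIdx.x.
  const int64_t num_rows = 3 * shape_d3 * shape_d2;  // batch of 3
  for (int64_t block_idx = 0; block_idx < num_rows; ++block_idx) {
    const int64_t batch_idx = block_idx / (shape_d3 * shape_d2);
    const int64_t remaining = block_idx % (shape_d3 * shape_d2);
    const int64_t seq_idx = remaining / shape_d2;
    const int64_t head_idx = remaining % shape_d2;

    const int64_t row_offset = batch_idx * stride_d4 + seq_idx * stride_d3 +
                               head_idx * stride_d2;

    // Contiguous layout: strided offset equals flat row index * head_dim.
    assert(row_offset == block_idx * head_dim);
  }
  return 0;
}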

docker/Dockerfile

Lines changed: 0 additions & 17 deletions
@@ -56,7 +56,6 @@ ARG UV_EXTRA_INDEX_URL=${PIP_EXTRA_INDEX_URL}
 
 # PyTorch provides its own indexes for standard and nightly builds
 ARG PYTORCH_CUDA_INDEX_BASE_URL=https://download.pytorch.org/whl
-ARG PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL=https://download.pytorch.org/whl/nightly
 
 # PIP supports multiple authentication schemes, including keyring
 # By parameterizing the PIP_KEYRING_PROVIDER variable and setting it to
@@ -98,7 +97,6 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
 ARG PIP_INDEX_URL UV_INDEX_URL
 ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL
 ARG PYTORCH_CUDA_INDEX_BASE_URL
-ARG PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL
 ARG PIP_KEYRING_PROVIDER UV_KEYRING_PROVIDER
 
 # Activate virtual environment and add uv to PATH
@@ -317,7 +315,6 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
 ARG PIP_INDEX_URL UV_INDEX_URL
 ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL
 ARG PYTORCH_CUDA_INDEX_BASE_URL
-ARG PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL
 ARG PIP_KEYRING_PROVIDER UV_KEYRING_PROVIDER
 
 # Install uv for faster pip installs
@@ -337,20 +334,6 @@ ENV UV_LINK_MODE=copy
 # or future versions of triton.
 RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/
 
-# arm64 (GH200) build follows the practice of "use existing pytorch" build,
-# we need to install torch and torchvision from the nightly builds first,
-# pytorch will not appear as a vLLM dependency in all of the following steps
-# after this step
-RUN --mount=type=cache,target=/root/.cache/uv \
-    if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
-        uv pip install --system \
-            --index-url ${PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') \
-            "torch==2.8.0.dev20250318+cu128" "torchvision==0.22.0.dev20250319" ; \
-        uv pip install --system \
-            --index-url ${PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') \
-            --pre pytorch_triton==3.3.0+gitab727c40 ; \
-    fi
-
 # Install vllm wheel first, so that torch etc will be installed.
 RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \
     --mount=type=cache,target=/root/.cache/uv \
