Skip to content

Commit c910de3

Browse files
committed
Fix compute_cap parsing and formatting in sagemaker-entrypoint-cuda-all.sh
1 parent 5a6e220 commit c910de3

File tree

1 file changed

+22
-59
lines changed

1 file changed

+22
-59
lines changed

sagemaker-entrypoint-cuda-all.sh

Lines changed: 22 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,21 @@
11
#!/bin/bash
22

3+
if ! command -v nvidia-smi &>/dev/null; then
4+
echo "Error: 'nvidia-smi' command not found."
5+
exit 1
6+
fi
7+
8+
# Function to compare version numbers
39
#######################################
# Compare two version strings using version-sort order (sort -V).
# NOTE: despite the name, this succeeds only when $2 is STRICTLY lower than $1.
# Arguments:
#   $1 - reference version string
#   $2 - version string tested against the reference
# Returns:
#   0 if $2 sorts strictly before $1 in version order,
#   1 if the two strings are identical or $1 sorts first.
#######################################
verlte() {
  # Identical versions never count as "lower".
  if [ "$1" = "$2" ]; then
    return 1
  fi
  # printf '%s\n' emits both strings verbatim (no backslash-escape
  # interpretation, unlike echo -e); the first line after a version sort is
  # the lower of the two.
  [ "$2" = "$(printf '%s\n%s\n' "$1" "$2" | sort -V | head -n1)" ]
}
612

13+
# CUDA compat libs logic
714
if [ -f /usr/local/cuda/compat/libcuda.so.1 ]; then
815
CUDA_COMPAT_MAX_DRIVER_VERSION=$(readlink /usr/local/cuda/compat/libcuda.so.1 | cut -d"." -f 3-)
916
echo "CUDA compat package requires Nvidia driver ≤${CUDA_COMPAT_MAX_DRIVER_VERSION}"
1017
cat /proc/driver/nvidia/version
11-
NVIDIA_DRIVER_VERSION=$(sed -n 's/^NVRM.*Kernel Module *\([0-9.]*\).*$/\1/p' /proc/driver/nvidia/version 2>/dev/null || true)
18+
NVIDIA_DRIVER_VERSION=$(sed -n 's/^NVRM.*Kernel Module \([0-9.]*\).*$/\1/p' /proc/driver/nvidia/version 2>/dev/null || true)
1219
echo "Current installed Nvidia driver version is ${NVIDIA_DRIVER_VERSION}"
1320
if [ $(verlte "$CUDA_COMPAT_MAX_DRIVER_VERSION" "$NVIDIA_DRIVER_VERSION") ]; then
1421
echo "Setup CUDA compatibility libs path to LD_LIBRARY_PATH"
@@ -21,71 +28,27 @@ else
2128
echo "Skip CUDA compat libs setup as package not found"
2229
fi
2330

31+
# Model variables check
2432
if [[ -z "${HF_MODEL_ID}" ]]; then
25-
echo "HF_MODEL_ID must be set"
26-
exit 1
33+
echo "HF_MODEL_ID must be set"
34+
exit 1
2735
fi
2836
export MODEL_ID="${HF_MODEL_ID}"
2937

3038
if [[ -n "${HF_MODEL_REVISION}" ]]; then
31-
export REVISION="${HF_MODEL_REVISION}"
32-
fi
33-
34-
if ! command -v nvidia-smi &> /dev/null; then
35-
echo "Error: 'nvidia-smi' command not found."
36-
exit 1
37-
fi
38-
39-
# Query GPU name using nvidia-smi
40-
gpu_name=$(nvidia-smi --query-gpu=gpu_name --format=csv | awk 'NR==2')
41-
if [ $? -ne 0 ]; then
42-
echo "Error: $gpu_name"
43-
echo "Query gpu_name failed"
44-
else
45-
echo "Query gpu_name succeeded. Printing output: $gpu_name"
39+
export REVISION="${HF_MODEL_REVISION}"
4640
fi
4741

48-
# Function to get compute capability based on GPU name
49-
get_compute_cap() {
50-
gpu_name="$1"
51-
52-
# Check if the GPU name contains "A10G"
53-
if [[ "$gpu_name" == *"A10G"* ]]; then
54-
echo "86"
55-
# Check if the GPU name contains "A100"
56-
elif [[ "$gpu_name" == *"A100"* ]]; then
57-
echo "80"
58-
# Check if the GPU name contains "H100"
59-
elif [[ "$gpu_name" == *"H100"* ]]; then
60-
echo "90"
61-
# Cover Nvidia T4
62-
elif [[ "$gpu_name" == *"T4"* ]]; then
63-
echo "75"
64-
# Cover Nvidia L4
65-
elif [[ "$gpu_name" == *"L4"* ]]; then
66-
echo "89"
67-
else
68-
echo "80" # Default compute capability
69-
fi
70-
}
71-
72-
if [[ -z "${CUDA_COMPUTE_CAP}" ]]
73-
then
74-
compute_cap=$(get_compute_cap "$gpu_name")
75-
echo "the compute_cap is $compute_cap"
76-
else
77-
compute_cap=$CUDA_COMPUTE_CAP
78-
fi
42+
compute_cap=$(nvidia-smi --query-gpu=compute_cap --format=csv | sed -n '2p' | sed 's/\.//g')
7943

80-
if [[ ${compute_cap} -eq 75 ]]
81-
then
82-
text-embeddings-router-75 --port 8080 --json-output
83-
elif [[ ${compute_cap} -ge 80 && ${compute_cap} -lt 90 ]]
84-
then
85-
text-embeddings-router-80 --port 8080 --json-output
86-
elif [[ ${compute_cap} -eq 90 ]]
87-
then
88-
text-embeddings-router-90 --port 8080 --json-output
44+
# Router selection logic
45+
if [ ${compute_cap} -eq 75 ]; then
46+
exec text-embeddings-router-75 --port 8080 --json-output
47+
elif [ ${compute_cap} -ge 80 -a ${compute_cap} -lt 90 ]; then
48+
exec text-embeddings-router-80 --port 8080 --json-output
49+
elif [ ${compute_cap} -eq 90 ]; then
50+
exec text-embeddings-router-90 --port 8080 --json-output
8951
else
90-
echo "cuda compute cap ${compute_cap} is not supported"; exit 1
52+
echo "cuda compute cap ${compute_cap} is not supported"
53+
exit 1
9154
fi

0 commit comments

Comments
 (0)